Feature: support domain in fallback filter (#964)

This commit is contained in:
Melvin 2020-09-28 22:17:10 +08:00 committed by GitHub
parent e09931dcf7
commit a6444bb449
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 89 additions and 19 deletions

View File

@ -69,6 +69,7 @@ type DNS struct {
type FallbackFilter struct {
GeoIP bool `yaml:"geoip"`
IPCIDR []*net.IPNet `yaml:"ipcidr"`
Domain []string `yaml:"domain"`
}
// Experimental config
@ -103,6 +104,7 @@ type RawDNS struct {
type RawFallbackFilter struct {
GeoIP bool `yaml:"geoip"`
IPCIDR []string `yaml:"ipcidr"`
Domain []string `yaml:"domain"`
}
type RawConfig struct {
@ -561,6 +563,7 @@ func parseDNS(cfg RawDNS, hosts *trie.DomainTrie) (*DNS, error) {
if fallbackip, err := parseFallbackIPCIDR(cfg.FallbackFilter.IPCIDR); err == nil {
dnsCfg.FallbackFilter.IPCIDR = fallbackip
}
dnsCfg.FallbackFilter.Domain = cfg.FallbackFilter.Domain
if cfg.UseHosts {
dnsCfg.Hosts = hosts

View File

@ -4,9 +4,10 @@ import (
"net"
"github.com/Dreamacro/clash/component/mmdb"
"github.com/Dreamacro/clash/component/trie"
)
type fallbackFilter interface {
type fallbackIPFilter interface {
Match(net.IP) bool
}
@ -24,3 +25,22 @@ type ipnetFilter struct {
func (inf *ipnetFilter) Match(ip net.IP) bool {
return inf.ipnet.Contains(ip)
}
type fallbackDomainFilter interface {
Match(domain string) bool
}
type domainFilter struct {
tree *trie.DomainTrie
}
func NewDomainFilter(domains []string) *domainFilter {
df := domainFilter{tree: trie.New()}
for _, domain := range domains {
df.tree.Insert(domain, "")
}
return &df
}
func (df *domainFilter) Match(domain string) bool {
return df.tree.Search(domain) != nil
}

View File

@ -7,6 +7,7 @@ import (
"fmt"
"math/rand"
"net"
"strings"
"time"
"github.com/Dreamacro/clash/common/cache"
@ -34,13 +35,14 @@ type result struct {
}
type Resolver struct {
ipv6 bool
hosts *trie.DomainTrie
main []dnsClient
fallback []dnsClient
fallbackFilters []fallbackFilter
group singleflight.Group
lruCache *cache.LruCache
ipv6 bool
hosts *trie.DomainTrie
main []dnsClient
fallback []dnsClient
fallbackDomainFilters []fallbackDomainFilter
fallbackIPFilters []fallbackIPFilter
group singleflight.Group
lruCache *cache.LruCache
}
// ResolveIP request with TypeA and TypeAAAA, priority return TypeA
@ -78,8 +80,8 @@ func (r *Resolver) ResolveIPv6(host string) (ip net.IP, err error) {
return r.resolveIP(host, D.TypeAAAA)
}
func (r *Resolver) shouldFallback(ip net.IP) bool {
for _, filter := range r.fallbackFilters {
func (r *Resolver) shouldIPFallback(ip net.IP) bool {
for _, filter := range r.fallbackIPFilters {
if filter.Match(ip) {
return true
}
@ -126,7 +128,7 @@ func (r *Resolver) exchangeWithoutCache(m *D.Msg) (msg *D.Msg, err error) {
isIPReq := isIPRequest(q)
if isIPReq {
return r.fallbackExchange(m)
return r.ipExchange(m)
}
return r.batchExchange(r.main, m)
@ -170,19 +172,49 @@ func (r *Resolver) batchExchange(clients []dnsClient, m *D.Msg) (msg *D.Msg, err
return
}
func (r *Resolver) fallbackExchange(m *D.Msg) (msg *D.Msg, err error) {
func (r *Resolver) shouldOnlyQueryFallback(m *D.Msg) bool {
if r.fallback == nil || len(r.fallbackDomainFilters) == 0 {
return false
}
domain := r.msgToDomain(m)
if domain == "" {
return false
}
for _, df := range r.fallbackDomainFilters {
if df.Match(domain) {
return true
}
}
return false
}
func (r *Resolver) ipExchange(m *D.Msg) (msg *D.Msg, err error) {
onlyFallback := r.shouldOnlyQueryFallback(m)
if onlyFallback {
res := <-r.asyncExchange(r.fallback, m)
return res.Msg, res.Error
}
msgCh := r.asyncExchange(r.main, m)
if r.fallback == nil {
if r.fallback == nil { // directly return if no fallback servers are available
res := <-msgCh
msg, err = res.Msg, res.Error
return
}
fallbackMsg := r.asyncExchange(r.fallback, m)
res := <-msgCh
if res.Error == nil {
if ips := r.msgToIP(res.Msg); len(ips) != 0 {
if !r.shouldFallback(ips[0]) {
msg = res.Msg
if !r.shouldIPFallback(ips[0]) {
msg = res.Msg // no need to wait for fallback result
err = res.Error
return msg, err
}
@ -240,6 +272,14 @@ func (r *Resolver) msgToIP(msg *D.Msg) []net.IP {
return ips
}
func (r *Resolver) msgToDomain(msg *D.Msg) string {
if len(msg.Question) > 0 {
return strings.TrimRight(msg.Question[0].Name, ".")
}
return ""
}
func (r *Resolver) asyncExchange(client []dnsClient, msg *D.Msg) <-chan *result {
ch := make(chan *result, 1)
go func() {
@ -257,6 +297,7 @@ type NameServer struct {
type FallbackFilter struct {
GeoIP bool
IPCIDR []*net.IPNet
Domain []string
}
type Config struct {
@ -286,14 +327,19 @@ func NewResolver(config Config) *Resolver {
r.fallback = transform(config.Fallback, defaultResolver)
}
fallbackFilters := []fallbackFilter{}
fallbackIPFilters := []fallbackIPFilter{}
if config.FallbackFilter.GeoIP {
fallbackFilters = append(fallbackFilters, &geoipFilter{})
fallbackIPFilters = append(fallbackIPFilters, &geoipFilter{})
}
for _, ipnet := range config.FallbackFilter.IPCIDR {
fallbackFilters = append(fallbackFilters, &ipnetFilter{ipnet: ipnet})
fallbackIPFilters = append(fallbackIPFilters, &ipnetFilter{ipnet: ipnet})
}
r.fallbackIPFilters = fallbackIPFilters
if len(config.FallbackFilter.Domain) != 0 {
fallbackDomainFilters := []fallbackDomainFilter{NewDomainFilter(config.FallbackFilter.Domain)}
r.fallbackDomainFilters = fallbackDomainFilters
}
r.fallbackFilters = fallbackFilters
return r
}

View File

@ -118,6 +118,7 @@ func updateDNS(c *config.DNS) {
FallbackFilter: dns.FallbackFilter{
GeoIP: c.FallbackFilter.GeoIP,
IPCIDR: c.FallbackFilter.IPCIDR,
Domain: c.FallbackFilter.Domain,
},
Default: c.DefaultNameserver,
}