From a6eb11ce18ef2e456ca23d512061cfbe452be1d5 Mon Sep 17 00:00:00 2001 From: yaling888 <73897884+yaling888@users.noreply.github.com> Date: Wed, 6 Apr 2022 04:25:53 +0800 Subject: [PATCH] Refactor: DomainTrie use generics --- component/fakeip/pool.go | 4 ++-- component/fakeip/pool_test.go | 4 ++-- component/resolver/resolver.go | 14 ++++++++------ component/trie/domain.go | 24 ++++++++++++------------ component/trie/domain_test.go | 28 ++++++++++++++-------------- component/trie/node.go | 31 ++++++++++++++----------------- dns/filters.go | 6 +++--- dns/middleware.go | 13 +++++++------ dns/policy.go | 30 ++++++++++++++++++++++++++++++ dns/resolver.go | 14 ++++++++------ 10 files changed, 100 insertions(+), 68 deletions(-) create mode 100644 dns/policy.go diff --git a/component/fakeip/pool.go b/component/fakeip/pool.go index a55e5463e..afc1691b4 100644 --- a/component/fakeip/pool.go +++ b/component/fakeip/pool.go @@ -28,7 +28,7 @@ type Pool struct { broadcast uint32 offset uint32 mux sync.Mutex - host *trie.DomainTrie + host *trie.DomainTrie[bool] ipnet *net.IPNet store store } @@ -138,7 +138,7 @@ func uintToIP(v uint32) net.IP { type Options struct { IPNet *net.IPNet - Host *trie.DomainTrie + Host *trie.DomainTrie[bool] // Size sets the maximum number of entries in memory // and does not work if Persistence is true diff --git a/component/fakeip/pool_test.go b/component/fakeip/pool_test.go index 86e80a2dc..b4add98c0 100644 --- a/component/fakeip/pool_test.go +++ b/component/fakeip/pool_test.go @@ -100,8 +100,8 @@ func TestPool_CycleUsed(t *testing.T) { func TestPool_Skip(t *testing.T) { _, ipnet, _ := net.ParseCIDR("192.168.0.1/29") - tree := trie.New() - tree.Insert("example.com", tree) + tree := trie.New[bool]() + tree.Insert("example.com", true) pools, tempfile, err := createPools(Options{ IPNet: ipnet, Size: 10, diff --git a/component/resolver/resolver.go b/component/resolver/resolver.go index e1100a318..3c8ba384c 100644 --- a/component/resolver/resolver.go +++ b/component/resolver/resolver.go @@ -5,6 +5,7 @@ import ( "errors" "math/rand" "net" + "net/netip" "strings" "time" @@ -23,7 +24,7 @@ var ( DisableIPv6 = true // DefaultHosts aim to resolve hosts - DefaultHosts = trie.New() + DefaultHosts = trie.New[netip.Addr]() // DefaultDNSTimeout defined the default dns request timeout DefaultDNSTimeout = time.Second * 5 @@ -48,8 +49,8 @@ func ResolveIPv4(host string) (net.IP, error) { func ResolveIPv4WithResolver(host string, r Resolver) (net.IP, error) { if node := DefaultHosts.Search(host); node != nil { - if ip := node.Data.(net.IP).To4(); ip != nil { - return ip, nil + if ip := node.Data; ip.Is4() { + return ip.AsSlice(), nil } } @@ -92,8 +93,8 @@ func ResolveIPv6WithResolver(host string, r Resolver) (net.IP, error) { } if node := DefaultHosts.Search(host); node != nil { - if ip := node.Data.(net.IP).To16(); ip != nil { - return ip, nil + if ip := node.Data; ip.Is6() { + return ip.AsSlice(), nil } } @@ -128,7 +129,8 @@ func ResolveIPv6WithResolver(host string, r Resolver) (net.IP, error) { // ResolveIPWithResolver same as ResolveIP, but with a resolver func ResolveIPWithResolver(host string, r Resolver) (net.IP, error) { if node := DefaultHosts.Search(host); node != nil { - return node.Data.(net.IP), nil + ip := node.Data + return ip.Unmap().AsSlice(), nil } if r != nil { diff --git a/component/trie/domain.go b/component/trie/domain.go index 8915eda31..16dd9ae96 100644 --- a/component/trie/domain.go +++ b/component/trie/domain.go @@ -17,8 +17,8 @@ var ErrInvalidDomain = errors.New("invalid domain") // DomainTrie contains the main logic for adding and searching nodes for domain segments. // support wildcard domain (e.g *.google.com) -type DomainTrie struct { - root *Node +type DomainTrie[T comparable] struct { + root *Node[T] } func ValidAndSplitDomain(domain string) ([]string, bool) { @@ -51,7 +51,7 @@ func ValidAndSplitDomain(domain string) ([]string, bool) { // 3. subdomain.*.example.com // 4. .example.com // 5. +.example.com -func (t *DomainTrie) Insert(domain string, data any) error { +func (t *DomainTrie[T]) Insert(domain string, data T) error { parts, valid := ValidAndSplitDomain(domain) if !valid { return ErrInvalidDomain @@ -68,13 +68,13 @@ func (t *DomainTrie) Insert(domain string, data any) error { return nil } -func (t *DomainTrie) insert(parts []string, data any) { +func (t *DomainTrie[T]) insert(parts []string, data T) { node := t.root // reverse storage domain part to save space for i := len(parts) - 1; i >= 0; i-- { part := parts[i] if !node.hasChild(part) { - node.addChild(part, newNode(nil)) + node.addChild(part, newNode(getZero[T]())) } node = node.getChild(part) @@ -88,7 +88,7 @@ func (t *DomainTrie) insert(parts []string, data any) { // 1. static part // 2. wildcard domain // 2. dot wildcard domain -func (t *DomainTrie) Search(domain string) *Node { +func (t *DomainTrie[T]) Search(domain string) *Node[T] { parts, valid := ValidAndSplitDomain(domain) if !valid || parts[0] == "" { return nil @@ -96,26 +96,26 @@ func (t *DomainTrie) Search(domain string) *Node { n := t.search(t.root, parts) - if n == nil || n.Data == nil { + if n == nil || n.Data == getZero[T]() { return nil } return n } -func (t *DomainTrie) search(node *Node, parts []string) *Node { +func (t *DomainTrie[T]) search(node *Node[T], parts []string) *Node[T] { if len(parts) == 0 { return node } if c := node.getChild(parts[len(parts)-1]); c != nil { - if n := t.search(c, parts[:len(parts)-1]); n != nil && n.Data != nil { + if n := t.search(c, parts[:len(parts)-1]); n != nil && n.Data != getZero[T]() { return n } } if c := node.getChild(wildcard); c != nil { - if n := t.search(c, parts[:len(parts)-1]); n != nil && n.Data != nil { + if n := t.search(c, parts[:len(parts)-1]); n != nil && n.Data != getZero[T]() { return n } } @@ -124,6 +124,6 @@ func (t *DomainTrie) search(node *Node, parts []string) *Node { } // New returns a new, empty Trie. -func New() *DomainTrie { - return &DomainTrie{root: newNode(nil)} +func New[T comparable]() *DomainTrie[T] { + return &DomainTrie[T]{root: newNode[T](getZero[T]())} } diff --git a/component/trie/domain_test.go b/component/trie/domain_test.go index 4322699ae..ced44d035 100644 --- a/component/trie/domain_test.go +++ b/component/trie/domain_test.go @@ -1,16 +1,16 @@ package trie import ( - "net" + "net/netip" "testing" "github.com/stretchr/testify/assert" ) -var localIP = net.IP{127, 0, 0, 1} +var localIP = netip.AddrFrom4([4]byte{127, 0, 0, 1}) func TestTrie_Basic(t *testing.T) { - tree := New() + tree := New[netip.Addr]() domains := []string{ "example.com", "google.com", @@ -23,7 +23,7 @@ func TestTrie_Basic(t *testing.T) { node := tree.Search("example.com") assert.NotNil(t, node) - assert.True(t, node.Data.(net.IP).Equal(localIP)) + assert.True(t, node.Data == localIP) assert.NotNil(t, tree.Insert("", localIP)) assert.Nil(t, tree.Search("")) assert.NotNil(t, tree.Search("localhost")) @@ -31,7 +31,7 @@ func TestTrie_Basic(t *testing.T) { } func TestTrie_Wildcard(t *testing.T) { - tree := New() + tree := New[netip.Addr]() domains := []string{ "*.example.com", "sub.*.example.com", @@ -64,7 +64,7 @@ func TestTrie_Wildcard(t *testing.T) { } func TestTrie_Priority(t *testing.T) { - tree := New() + tree := New[int]() domains := []string{ ".dev", "example.dev", @@ -79,18 +79,18 @@ func TestTrie_Priority(t *testing.T) { } for idx, domain := range domains { - tree.Insert(domain, idx) + tree.Insert(domain, idx+1) } - assertFn("test.dev", 0) - assertFn("foo.bar.dev", 0) - assertFn("example.dev", 1) - assertFn("foo.example.dev", 2) - assertFn("test.example.dev", 3) + assertFn("test.dev", 1) + assertFn("foo.bar.dev", 1) + assertFn("example.dev", 2) + assertFn("foo.example.dev", 3) + assertFn("test.example.dev", 4) } func TestTrie_Boundary(t *testing.T) { - tree := New() + tree := New[netip.Addr]() tree.Insert("*.dev", localIP) assert.NotNil(t, tree.Insert(".", localIP)) @@ -99,7 +99,7 @@ func TestTrie_Boundary(t *testing.T) { } func TestTrie_WildcardBoundary(t *testing.T) { - tree := New() + tree := New[netip.Addr]() tree.Insert("+.*", localIP) tree.Insert("stun.*.*.*", localIP) diff --git a/component/trie/node.go b/component/trie/node.go index 05f1d8406..1545d8805 100644 --- a/component/trie/node.go +++ b/component/trie/node.go @@ -1,34 +1,31 @@ package trie // Node is the trie's node -type Node struct { - children map[string]*Node - Data any +type Node[T comparable] struct { + children map[string]*Node[T] + Data T } -func (n *Node) getChild(s string) *Node { - if n.children == nil { - return nil - } - +func (n *Node[T]) getChild(s string) *Node[T] { return n.children[s] } -func (n *Node) hasChild(s string) bool { +func (n *Node[T]) hasChild(s string) bool { return n.getChild(s) != nil } -func (n *Node) addChild(s string, child *Node) { - if n.children == nil { - n.children = map[string]*Node{} - } - +func (n *Node[T]) addChild(s string, child *Node[T]) { n.children[s] = child } -func newNode(data any) *Node { - return &Node{ +func newNode[T comparable](data T) *Node[T] { + return &Node[T]{ Data: data, - children: nil, + children: map[string]*Node[T]{}, } } + +func getZero[T comparable]() T { + var result T + return result +} diff --git a/dns/filters.go b/dns/filters.go index a756d8728..c268d98a4 100644 --- a/dns/filters.go +++ b/dns/filters.go @@ -70,13 +70,13 @@ type fallbackDomainFilter interface { } type domainFilter struct { - tree *trie.DomainTrie + tree *trie.DomainTrie[bool] } func NewDomainFilter(domains []string) *domainFilter { - df := domainFilter{tree: trie.New()} + df := domainFilter{tree: trie.New[bool]()} for _, domain := range domains { - df.tree.Insert(domain, "") + df.tree.Insert(domain, true) } return &df } diff --git a/dns/middleware.go b/dns/middleware.go index 5958fe930..7259df66e 100644 --- a/dns/middleware.go +++ b/dns/middleware.go @@ -2,6 +2,7 @@ package dns import ( "net" + "net/netip" "strings" "time" @@ -20,7 +21,7 @@ type ( middleware func(next handler) handler ) -func withHosts(hosts *trie.DomainTrie) middleware { +func withHosts(hosts *trie.DomainTrie[netip.Addr]) middleware { return func(next handler) handler { return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) { q := r.Question[0] @@ -34,19 +35,19 @@ func withHosts(hosts *trie.DomainTrie) middleware { return next(ctx, r) } - ip := record.Data.(net.IP) + ip := record.Data msg := r.Copy() - if v4 := ip.To4(); v4 != nil && q.Qtype == D.TypeA { + if ip.Is4() && q.Qtype == D.TypeA { rr := &D.A{} rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: dnsDefaultTTL} - rr.A = v4 + rr.A = ip.AsSlice() msg.Answer = []D.RR{rr} - } else if v6 := ip.To16(); v6 != nil && q.Qtype == D.TypeAAAA { + } else if ip.Is6() && q.Qtype == D.TypeAAAA { rr := &D.AAAA{} rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeAAAA, Class: D.ClassINET, Ttl: dnsDefaultTTL} - rr.AAAA = v6 + rr.AAAA = ip.AsSlice() msg.Answer = []D.RR{rr} } else { diff --git a/dns/policy.go b/dns/policy.go new file mode 100644 index 000000000..a8b423e16 --- /dev/null +++ b/dns/policy.go @@ -0,0 +1,30 @@ +package dns + +type Policy struct { + data []dnsClient +} + +func (p *Policy) GetData() []dnsClient { + return p.data +} + +func (p *Policy) Compare(p2 *Policy) int { + if p2 == nil { + return 1 + } + l1 := len(p.data) + l2 := len(p2.data) + if l1 == l2 { + return 0 + } + if l1 > l2 { + return 1 + } + return -1 +} + +func NewPolicy(data []dnsClient) *Policy { + return &Policy{ + data: data, + } +} diff --git a/dns/resolver.go b/dns/resolver.go index 86c3d2d52..c39885cbd 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -6,6 +6,7 @@ import ( "fmt" "math/rand" "net" + "net/netip" "strings" "time" @@ -33,14 +34,14 @@ type result struct { type Resolver struct { ipv6 bool - hosts *trie.DomainTrie + hosts *trie.DomainTrie[netip.Addr] main []dnsClient fallback []dnsClient fallbackDomainFilters []fallbackDomainFilter fallbackIPFilters []fallbackIPFilter group singleflight.Group lruCache *cache.LruCache[string, *D.Msg] - policy *trie.DomainTrie + policy *trie.DomainTrie[*Policy] proxyServer []dnsClient } @@ -194,7 +195,8 @@ func (r *Resolver) matchPolicy(m *D.Msg) []dnsClient { return nil } - return record.Data.([]dnsClient) + p := record.Data + return p.GetData() } func (r *Resolver) shouldOnlyQueryFallback(m *D.Msg) bool { @@ -329,7 +331,7 @@ type Config struct { EnhancedMode C.DNSMode FallbackFilter FallbackFilter Pool *fakeip.Pool - Hosts *trie.DomainTrie + Hosts *trie.DomainTrie[netip.Addr] Policy map[string]NameServer } @@ -355,9 +357,9 @@ func NewResolver(config Config) *Resolver { } if len(config.Policy) != 0 { - r.policy = trie.New() + r.policy = trie.New[*Policy]() for domain, nameserver := range config.Policy { - r.policy.Insert(domain, transform([]NameServer{nameserver}, defaultResolver)) + r.policy.Insert(domain, NewPolicy(transform([]NameServer{nameserver}, defaultResolver))) } }