diff --git a/component/sniffer/dispatcher.go b/component/sniffer/dispatcher.go index bf0b1bb35..fa1c68276 100644 --- a/component/sniffer/dispatcher.go +++ b/component/sniffer/dispatcher.go @@ -26,14 +26,14 @@ var ( var Dispatcher *SnifferDispatcher type SnifferDispatcher struct { - enable bool - sniffers map[sniffer.Sniffer]SnifferConfig - forceDomain *trie.DomainSet - skipSNI *trie.DomainSet - skipList *cache.LruCache[string, uint8] - rwMux sync.RWMutex - forceDnsMapping bool - parsePureIp bool + enable bool + sniffers map[sniffer.Sniffer]SnifferConfig + forceDomain *trie.DomainSet + skipSNI *trie.DomainSet + skipList *cache.LruCache[string, uint8] + rwMux sync.RWMutex + forceDnsMapping bool + parsePureIp bool } func (sd *SnifferDispatcher) TCPSniff(conn *N.BufferedConn, metadata *C.Metadata) { diff --git a/component/trie/domain.go b/component/trie/domain.go index 86e5245aa..3decbb025 100644 --- a/component/trie/domain.go +++ b/component/trie/domain.go @@ -25,7 +25,7 @@ func ValidAndSplitDomain(domain string) ([]string, bool) { if domain != "" && domain[len(domain)-1] == '.' { return nil, false } - domain=strings.ToLower(domain) + domain = strings.ToLower(domain) parts := strings.Split(domain, domainStep) if len(parts) == 1 { if parts[0] == "" { @@ -126,6 +126,9 @@ func (t *DomainTrie[T]) Optimize() { func (t *DomainTrie[T]) Foreach(print func(domain string, data T)) { for key, data := range t.root.getChildren() { recursion([]string{key}, data, print) + if data != nil && data.inited { + print(joinDomain([]string{key}), data.data) + } } } diff --git a/component/trie/domain_set_test.go b/component/trie/domain_set_test.go index 090bd495d..c4160f6c3 100644 --- a/component/trie/domain_set_test.go +++ b/component/trie/domain_set_test.go @@ -15,6 +15,9 @@ func TestDomainSet(t *testing.T) { "www.google.com", "test.a.net", "test.a.oc", + "Mijia Cloud", + ".qq.com", + "+.cn", } for _, domain := range domainSet { @@ -22,8 +25,13 @@ func TestDomainSet(t *testing.T) { } set := tree.NewDomainSet() assert.NotNil(t, set) + assert.True(t, set.Has("test.cn")) + assert.True(t, set.Has("cn")) + assert.True(t, set.Has("Mijia Cloud")) assert.True(t, set.Has("test.a.net")) + assert.True(t, set.Has("www.qq.com")) assert.True(t, set.Has("google.com")) + assert.False(t, set.Has("qq.com")) assert.False(t, set.Has("www.baidu.com")) }