diff --git a/common/singleflight/singleflight.go b/common/singleflight/singleflight.go new file mode 100644 index 000000000..e31c4eb6f --- /dev/null +++ b/common/singleflight/singleflight.go @@ -0,0 +1,224 @@ +// copy and modify from "golang.org/x/sync/singleflight" + +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package singleflight provides a duplicate function call suppression +// mechanism. +package singleflight + +import ( + "bytes" + "errors" + "fmt" + "runtime" + "runtime/debug" + "sync" +) + +// errGoexit indicates the runtime.Goexit was called in +// the user given function. +var errGoexit = errors.New("runtime.Goexit was called") + +// A panicError is an arbitrary value recovered from a panic +// with the stack trace during the execution of given function. +type panicError struct { + value interface{} + stack []byte +} + +// Error implements error interface. +func (p *panicError) Error() string { + return fmt.Sprintf("%v\n\n%s", p.value, p.stack) +} + +func (p *panicError) Unwrap() error { + err, ok := p.value.(error) + if !ok { + return nil + } + + return err +} + +func newPanicError(v interface{}) error { + stack := debug.Stack() + + // The first line of the stack trace is of the form "goroutine N [status]:" + // but by the time the panic reaches Do the goroutine may no longer exist + // and its status will have changed. Trim out the misleading line. + if line := bytes.IndexByte(stack[:], '\n'); line >= 0 { + stack = stack[line+1:] + } + return &panicError{value: v, stack: stack} +} + +// call is an in-flight or completed singleflight.Do call +type call[T any] struct { + wg sync.WaitGroup + + // These fields are written once before the WaitGroup is done + // and are only read after the WaitGroup is done. + val T + err error + + // These fields are read and written with the singleflight + // mutex held before the WaitGroup is done, and are read but + // not written after the WaitGroup is done. + dups int + chans []chan<- Result[T] +} + +// Group represents a class of work and forms a namespace in +// which units of work can be executed with duplicate suppression. +type Group[T any] struct { + mu sync.Mutex // protects m + m map[string]*call[T] // lazily initialized + + StoreResult bool +} + +// Result holds the results of Do, so they can be passed +// on a channel. +type Result[T any] struct { + Val T + Err error + Shared bool +} + +// Do executes and returns the results of the given function, making +// sure that only one execution is in-flight for a given key at a +// time. If a duplicate comes in, the duplicate caller waits for the +// original to complete and receives the same results. +// The return value shared indicates whether v was given to multiple callers. +func (g *Group[T]) Do(key string, fn func() (T, error)) (v T, err error, shared bool) { + g.mu.Lock() + if g.m == nil { + g.m = make(map[string]*call[T]) + } + if c, ok := g.m[key]; ok { + c.dups++ + g.mu.Unlock() + c.wg.Wait() + + if e, ok := c.err.(*panicError); ok { + panic(e) + } else if c.err == errGoexit { + runtime.Goexit() + } + return c.val, c.err, true + } + c := new(call[T]) + c.wg.Add(1) + g.m[key] = c + g.mu.Unlock() + + g.doCall(c, key, fn) + return c.val, c.err, c.dups > 0 +} + +// DoChan is like Do but returns a channel that will receive the +// results when they are ready. +// +// The returned channel will not be closed. +func (g *Group[T]) DoChan(key string, fn func() (T, error)) <-chan Result[T] { + ch := make(chan Result[T], 1) + g.mu.Lock() + if g.m == nil { + g.m = make(map[string]*call[T]) + } + if c, ok := g.m[key]; ok { + c.dups++ + c.chans = append(c.chans, ch) + g.mu.Unlock() + return ch + } + c := &call[T]{chans: []chan<- Result[T]{ch}} + c.wg.Add(1) + g.m[key] = c + g.mu.Unlock() + + go g.doCall(c, key, fn) + + return ch +} + +// doCall handles the single call for a key. +func (g *Group[T]) doCall(c *call[T], key string, fn func() (T, error)) { + normalReturn := false + recovered := false + + // use double-defer to distinguish panic from runtime.Goexit, + // more details see https://golang.org/cl/134395 + defer func() { + // the given function invoked runtime.Goexit + if !normalReturn && !recovered { + c.err = errGoexit + } + + g.mu.Lock() + defer g.mu.Unlock() + c.wg.Done() + if g.m[key] == c && !g.StoreResult { + delete(g.m, key) + } + + if e, ok := c.err.(*panicError); ok { + // In order to prevent the waiting channels from being blocked forever, + // needs to ensure that this panic cannot be recovered. + if len(c.chans) > 0 { + go panic(e) + select {} // Keep this goroutine around so that it will appear in the crash dump. + } else { + panic(e) + } + } else if c.err == errGoexit { + // Already in the process of goexit, no need to call again + } else { + // Normal return + for _, ch := range c.chans { + ch <- Result[T]{c.val, c.err, c.dups > 0} + } + } + }() + + func() { + defer func() { + if !normalReturn { + // Ideally, we would wait to take a stack trace until we've determined + // whether this is a panic or a runtime.Goexit. + // + // Unfortunately, the only way we can distinguish the two is to see + // whether the recover stopped the goroutine from terminating, and by + // the time we know that, the part of the stack trace relevant to the + // panic has been discarded. + if r := recover(); r != nil { + c.err = newPanicError(r) + } + } + }() + + c.val, c.err = fn() + normalReturn = true + }() + + if !normalReturn { + recovered = true + } +} + +// Forget tells the singleflight to forget about a key. Future calls +// to Do for this key will call the function rather than waiting for +// an earlier call to complete. +func (g *Group[T]) Forget(key string) { + g.mu.Lock() + delete(g.m, key) + g.mu.Unlock() +} + +func (g *Group[T]) Reset() { + g.mu.Lock() + g.m = nil + g.mu.Unlock() +} diff --git a/component/geodata/attr.go b/component/geodata/attr.go index a9742aca9..2fd41ad6c 100644 --- a/component/geodata/attr.go +++ b/component/geodata/attr.go @@ -7,7 +7,7 @@ import ( ) type AttributeList struct { - matcher []AttributeMatcher + matcher []BooleanMatcher } func (al *AttributeList) Match(domain *router.Domain) bool { @@ -23,6 +23,14 @@ func (al *AttributeList) IsEmpty() bool { return len(al.matcher) == 0 } +func (al *AttributeList) String() string { + matcher := make([]string, len(al.matcher)) + for i, match := range al.matcher { + matcher[i] = string(match) + } + return strings.Join(matcher, ",") +} + func parseAttrs(attrs []string) *AttributeList { al := new(AttributeList) for _, attr := range attrs { diff --git a/component/geodata/router/condition.go b/component/geodata/router/condition.go index 5261d2fee..fb47e4a40 100644 --- a/component/geodata/router/condition.go +++ b/component/geodata/router/condition.go @@ -33,12 +33,13 @@ func domainToMatcher(domain *Domain) (strmatcher.Matcher, error) { type DomainMatcher interface { ApplyDomain(string) bool + Count() int } type succinctDomainMatcher struct { set *trie.DomainSet otherMatchers []strmatcher.Matcher - not bool + count int } func (m *succinctDomainMatcher) ApplyDomain(domain string) bool { @@ -51,16 +52,17 @@ func (m *succinctDomainMatcher) ApplyDomain(domain string) bool { } } } - if m.not { - isMatched = !isMatched - } return isMatched } -func NewSuccinctMatcherGroup(domains []*Domain, not bool) (DomainMatcher, error) { +func (m *succinctDomainMatcher) Count() int { + return m.count +} + +func NewSuccinctMatcherGroup(domains []*Domain) (DomainMatcher, error) { t := trie.New[struct{}]() m := &succinctDomainMatcher{ - not: not, + count: len(domains), } for _, d := range domains { switch d.Type { @@ -90,10 +92,10 @@ func NewSuccinctMatcherGroup(domains []*Domain, not bool) (DomainMatcher, error) type v2rayDomainMatcher struct { matchers strmatcher.IndexMatcher - not bool + count int } -func NewMphMatcherGroup(domains []*Domain, not bool) (DomainMatcher, error) { +func NewMphMatcherGroup(domains []*Domain) (DomainMatcher, error) { g := strmatcher.NewMphMatcherGroup() for _, d := range domains { matcherType, f := matcherTypeMap[d.Type] @@ -108,119 +110,80 @@ func NewMphMatcherGroup(domains []*Domain, not bool) (DomainMatcher, error) { g.Build() return &v2rayDomainMatcher{ matchers: g, - not: not, + count: len(domains), }, nil } func (m *v2rayDomainMatcher) ApplyDomain(domain string) bool { - isMatched := len(m.matchers.Match(strings.ToLower(domain))) > 0 - if m.not { - isMatched = !isMatched - } - return isMatched + return len(m.matchers.Match(strings.ToLower(domain))) > 0 } -type GeoIPMatcher struct { - countryCode string - reverseMatch bool - cidrSet *cidr.IpCidrSet +func (m *v2rayDomainMatcher) Count() int { + return m.count } -func (m *GeoIPMatcher) Init(cidrs []*CIDR) error { - for _, cidr := range cidrs { - addr, ok := netip.AddrFromSlice(cidr.Ip) - if !ok { - return fmt.Errorf("error when loading GeoIP: invalid IP: %s", cidr.Ip) - } - err := m.cidrSet.AddIpCidr(netip.PrefixFrom(addr, int(cidr.Prefix))) - if err != nil { - return fmt.Errorf("error when loading GeoIP: %w", err) - } - } - return m.cidrSet.Merge() +type notDomainMatcher struct { + DomainMatcher } -func (m *GeoIPMatcher) SetReverseMatch(isReverseMatch bool) { - m.reverseMatch = isReverseMatch +func (m notDomainMatcher) ApplyDomain(domain string) bool { + return !m.DomainMatcher.ApplyDomain(domain) +} + +func NewNotDomainMatcherGroup(matcher DomainMatcher) DomainMatcher { + return notDomainMatcher{matcher} +} + +type IPMatcher interface { + Match(ip netip.Addr) bool + Count() int +} + +type geoIPMatcher struct { + cidrSet *cidr.IpCidrSet + count int } // Match returns true if the given ip is included by the GeoIP. -func (m *GeoIPMatcher) Match(ip netip.Addr) bool { - match := m.cidrSet.IsContain(ip) - if m.reverseMatch { - return !match +func (m *geoIPMatcher) Match(ip netip.Addr) bool { + return m.cidrSet.IsContain(ip) +} + +func (m *geoIPMatcher) Count() int { + return m.count +} + +func NewGeoIPMatcher(cidrList []*CIDR) (IPMatcher, error) { + m := &geoIPMatcher{ + cidrSet: cidr.NewIpCidrSet(), + count: len(cidrList), } - return match -} - -// GeoIPMatcherContainer is a container for GeoIPMatchers. It keeps unique copies of GeoIPMatcher by country code. -type GeoIPMatcherContainer struct { - matchers []*GeoIPMatcher -} - -// Add adds a new GeoIP set into the container. -// If the country code of GeoIP is not empty, GeoIPMatcherContainer will try to find an existing one, instead of adding a new one. -func (c *GeoIPMatcherContainer) Add(geoip *GeoIP) (*GeoIPMatcher, error) { - if len(geoip.CountryCode) > 0 { - for _, m := range c.matchers { - if m.countryCode == geoip.CountryCode && m.reverseMatch == geoip.ReverseMatch { - return m, nil - } + for _, cidr := range cidrList { + addr, ok := netip.AddrFromSlice(cidr.Ip) + if !ok { + return nil, fmt.Errorf("error when loading GeoIP: invalid IP: %s", cidr.Ip) + } + err := m.cidrSet.AddIpCidr(netip.PrefixFrom(addr, int(cidr.Prefix))) + if err != nil { + return nil, fmt.Errorf("error when loading GeoIP: %w", err) } } - - m := &GeoIPMatcher{ - countryCode: geoip.CountryCode, - reverseMatch: geoip.ReverseMatch, - cidrSet: cidr.NewIpCidrSet(), - } - if err := m.Init(geoip.Cidr); err != nil { - return nil, err - } - if len(geoip.CountryCode) > 0 { - c.matchers = append(c.matchers, m) - } - return m, nil -} - -var globalGeoIPContainer GeoIPMatcherContainer - -type MultiGeoIPMatcher struct { - matchers []*GeoIPMatcher -} - -func NewGeoIPMatcher(geoip *GeoIP) (*GeoIPMatcher, error) { - matcher, err := globalGeoIPContainer.Add(geoip) + err := m.cidrSet.Merge() if err != nil { return nil, err } - return matcher, nil + return m, nil } -func (m *MultiGeoIPMatcher) ApplyIp(ip netip.Addr) bool { - for _, matcher := range m.matchers { - if matcher.Match(ip) { - return true - } - } - - return false +type notIPMatcher struct { + IPMatcher } -func NewMultiGeoIPMatcher(geoips []*GeoIP) (*MultiGeoIPMatcher, error) { - var matchers []*GeoIPMatcher - for _, geoip := range geoips { - matcher, err := globalGeoIPContainer.Add(geoip) - if err != nil { - return nil, err - } - matchers = append(matchers, matcher) - } - - matcher := &MultiGeoIPMatcher{ - matchers: matchers, - } - - return matcher, nil +func (m notIPMatcher) Match(ip netip.Addr) bool { + return !m.IPMatcher.Match(ip) +} + +func NewNotIpMatcherGroup(matcher IPMatcher) IPMatcher { + return notIPMatcher{matcher} } diff --git a/component/geodata/utils.go b/component/geodata/utils.go index 981d7eba4..a16e255e1 100644 --- a/component/geodata/utils.go +++ b/component/geodata/utils.go @@ -5,8 +5,7 @@ import ( "fmt" "strings" - "golang.org/x/sync/singleflight" - + "github.com/metacubex/mihomo/common/singleflight" "github.com/metacubex/mihomo/component/geodata/router" C "github.com/metacubex/mihomo/constant" "github.com/metacubex/mihomo/log" @@ -71,21 +70,22 @@ func SetSiteMatcher(newMatcher string) { func Verify(name string) error { switch name { case C.GeositeName: - _, _, err := LoadGeoSiteMatcher("CN") + _, err := LoadGeoSiteMatcher("CN") return err case C.GeoipName: - _, _, err := LoadGeoIPMatcher("CN") + _, err := LoadGeoIPMatcher("CN") return err default: return fmt.Errorf("not support name") } } -var loadGeoSiteMatcherSF = singleflight.Group{} +var loadGeoSiteMatcherListSF = singleflight.Group[[]*router.Domain]{StoreResult: true} +var loadGeoSiteMatcherSF = singleflight.Group[router.DomainMatcher]{StoreResult: true} -func LoadGeoSiteMatcher(countryCode string) (router.DomainMatcher, int, error) { +func LoadGeoSiteMatcher(countryCode string) (router.DomainMatcher, error) { if countryCode == "" { - return nil, 0, fmt.Errorf("country code could not be empty") + return nil, fmt.Errorf("country code could not be empty") } not := false @@ -97,73 +97,84 @@ func LoadGeoSiteMatcher(countryCode string) (router.DomainMatcher, int, error) { parts := strings.Split(countryCode, "@") if len(parts) == 0 { - return nil, 0, errors.New("empty rule") + return nil, errors.New("empty rule") } listName := strings.TrimSpace(parts[0]) attrVal := parts[1:] + attrs := parseAttrs(attrVal) if listName == "" { - return nil, 0, fmt.Errorf("empty listname in rule: %s", countryCode) + return nil, fmt.Errorf("empty listname in rule: %s", countryCode) } - v, err, shared := loadGeoSiteMatcherSF.Do(listName, func() (interface{}, error) { - geoLoader, err := GetGeoDataLoader(geoLoaderName) + matcherName := listName + if !attrs.IsEmpty() { + matcherName += "@" + attrs.String() + } + matcher, err, shared := loadGeoSiteMatcherSF.Do(matcherName, func() (router.DomainMatcher, error) { + log.Infoln("Load GeoSite rule: %s", matcherName) + domains, err, shared := loadGeoSiteMatcherListSF.Do(listName, func() ([]*router.Domain, error) { + geoLoader, err := GetGeoDataLoader(geoLoaderName) + if err != nil { + return nil, err + } + return geoLoader.LoadGeoSite(listName) + }) if err != nil { + if !shared { + loadGeoSiteMatcherListSF.Forget(listName) // don't store the error result + } return nil, err } - return geoLoader.LoadGeoSite(listName) + + if attrs.IsEmpty() { + if strings.Contains(countryCode, "@") { + log.Warnln("empty attribute list: %s", countryCode) + } + } else { + filteredDomains := make([]*router.Domain, 0, len(domains)) + hasAttrMatched := false + for _, domain := range domains { + if attrs.Match(domain) { + hasAttrMatched = true + filteredDomains = append(filteredDomains, domain) + } + } + if !hasAttrMatched { + log.Warnln("attribute match no rule: geosite: %s", countryCode) + } + domains = filteredDomains + } + + /** + linear: linear algorithm + matcher, err := router.NewDomainMatcher(domains) + mph:minimal perfect hash algorithm + */ + if geoSiteMatcher == "mph" { + return router.NewMphMatcherGroup(domains) + } else { + return router.NewSuccinctMatcherGroup(domains) + } }) if err != nil { if !shared { - loadGeoSiteMatcherSF.Forget(listName) // don't store the error result + loadGeoSiteMatcherSF.Forget(matcherName) // don't store the error result } - return nil, 0, err + return nil, err } - domains := v.([]*router.Domain) - - attrs := parseAttrs(attrVal) - if attrs.IsEmpty() { - if strings.Contains(countryCode, "@") { - log.Warnln("empty attribute list: %s", countryCode) - } - } else { - filteredDomains := make([]*router.Domain, 0, len(domains)) - hasAttrMatched := false - for _, domain := range domains { - if attrs.Match(domain) { - hasAttrMatched = true - filteredDomains = append(filteredDomains, domain) - } - } - if !hasAttrMatched { - log.Warnln("attribute match no rule: geosite: %s", countryCode) - } - domains = filteredDomains + if not { + matcher = router.NewNotDomainMatcherGroup(matcher) } - /** - linear: linear algorithm - matcher, err := router.NewDomainMatcher(domains) - mph:minimal perfect hash algorithm - */ - var matcher router.DomainMatcher - if geoSiteMatcher == "mph" { - matcher, err = router.NewMphMatcherGroup(domains, not) - } else { - matcher, err = router.NewSuccinctMatcherGroup(domains, not) - } - if err != nil { - return nil, 0, err - } - - return matcher, len(domains), nil + return matcher, nil } -var loadGeoIPMatcherSF = singleflight.Group{} +var loadGeoIPMatcherSF = singleflight.Group[router.IPMatcher]{StoreResult: true} -func LoadGeoIPMatcher(country string) (*router.GeoIPMatcher, int, error) { +func LoadGeoIPMatcher(country string) (router.IPMatcher, error) { if len(country) == 0 { - return nil, 0, fmt.Errorf("country code could not be empty") + return nil, fmt.Errorf("country code could not be empty") } not := false @@ -173,35 +184,33 @@ func LoadGeoIPMatcher(country string) (*router.GeoIPMatcher, int, error) { } country = strings.ToLower(country) - v, err, shared := loadGeoIPMatcherSF.Do(country, func() (interface{}, error) { + matcher, err, shared := loadGeoIPMatcherSF.Do(country, func() (router.IPMatcher, error) { + log.Infoln("Load GeoIP rule: %s", country) geoLoader, err := GetGeoDataLoader(geoLoaderName) if err != nil { return nil, err } - return geoLoader.LoadGeoIP(country) + cidrList, err := geoLoader.LoadGeoIP(country) + if err != nil { + return nil, err + } + return router.NewGeoIPMatcher(cidrList) }) if err != nil { if !shared { loadGeoIPMatcherSF.Forget(country) // don't store the error result + log.Warnln("Load GeoIP rule: %s", country) } - return nil, 0, err + return nil, err } - records := v.([]*router.CIDR) - - geoIP := &router.GeoIP{ - CountryCode: country, - Cidr: records, - ReverseMatch: not, + if not { + matcher = router.NewNotIpMatcherGroup(matcher) } - - matcher, err := router.NewGeoIPMatcher(geoIP) - if err != nil { - return nil, 0, err - } - return matcher, len(records), nil + return matcher, nil } func ClearCache() { - loadGeoSiteMatcherSF = singleflight.Group{} - loadGeoIPMatcherSF = singleflight.Group{} + loadGeoSiteMatcherListSF.Reset() + loadGeoSiteMatcherSF.Reset() + loadGeoIPMatcherSF.Reset() } diff --git a/component/updater/update_geo.go b/component/updater/update_geo.go index b07cd3158..4d16c1287 100644 --- a/component/updater/update_geo.go +++ b/component/updater/update_geo.go @@ -137,7 +137,7 @@ func getUpdateTime() (err error, time time.Time) { return nil, fileInfo.ModTime() } -func RegisterGeoUpdater(onSuccess func()) { +func RegisterGeoUpdater() { if C.GeoUpdateInterval <= 0 { log.Errorln("[GEO] Invalid update interval: %d", C.GeoUpdateInterval) return @@ -159,8 +159,6 @@ func RegisterGeoUpdater(onSuccess func()) { if err := UpdateGeoDatabases(); err != nil { log.Errorln("[GEO] Failed to update GEO database: %s", err.Error()) return - } else { - onSuccess() } } @@ -168,8 +166,6 @@ func RegisterGeoUpdater(onSuccess func()) { log.Infoln("[GEO] updating database every %d hours", C.GeoUpdateInterval) if err := UpdateGeoDatabases(); err != nil { log.Errorln("[GEO] Failed to update GEO database: %s", err.Error()) - } else { - onSuccess() } } }() diff --git a/dns/resolver.go b/dns/resolver.go index 7b9aafd0e..3cc7a41e3 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -8,6 +8,7 @@ import ( "github.com/metacubex/mihomo/common/arc" "github.com/metacubex/mihomo/common/lru" + "github.com/metacubex/mihomo/common/singleflight" "github.com/metacubex/mihomo/component/fakeip" "github.com/metacubex/mihomo/component/resolver" "github.com/metacubex/mihomo/component/trie" @@ -18,7 +19,6 @@ import ( D "github.com/miekg/dns" "github.com/samber/lo" "golang.org/x/exp/maps" - "golang.org/x/sync/singleflight" ) type dnsClient interface { @@ -44,7 +44,7 @@ type Resolver struct { fallback []dnsClient fallbackDomainFilters []C.Rule fallbackIPFilters []C.Rule - group singleflight.Group + group singleflight.Group[*D.Msg] cache dnsCache policy []dnsPolicy proxyServer []dnsClient @@ -169,19 +169,20 @@ func (r *Resolver) exchangeWithoutCache(ctx context.Context, m *D.Msg) (msg *D.M retryNum := 0 retryMax := 3 - fn := func() (result any, err error) { + fn := func() (result *D.Msg, err error) { ctx, cancel := context.WithTimeout(context.Background(), resolver.DefaultDNSTimeout) // reset timeout in singleflight defer cancel() cache := false defer func() { if err != nil { - result = retryNum + result = &D.Msg{} + result.Opcode = retryNum retryNum++ return } - msg := result.(*D.Msg) + msg := result if cache { // OPT RRs MUST NOT be cached, forwarded, or stored in or loaded from master files. @@ -208,7 +209,7 @@ func (r *Resolver) exchangeWithoutCache(ctx context.Context, m *D.Msg) (msg *D.M ch := r.group.DoChan(q.String(), fn) - var result singleflight.Result + var result singleflight.Result[*D.Msg] select { case result = <-ch: @@ -221,7 +222,7 @@ func (r *Resolver) exchangeWithoutCache(ctx context.Context, m *D.Msg) (msg *D.M go func() { // start a retrying monitor in background result := <-ch ret, err, shared := result.Val, result.Err, result.Shared - if err != nil && !shared && ret.(int) < retryMax { // retry + if err != nil && !shared && ret.Opcode < retryMax { // retry r.group.DoChan(q.String(), fn) } }() @@ -230,12 +231,12 @@ func (r *Resolver) exchangeWithoutCache(ctx context.Context, m *D.Msg) (msg *D.M } ret, err, shared := result.Val, result.Err, result.Shared - if err != nil && !shared && ret.(int) < retryMax { // retry + if err != nil && !shared && ret.Opcode < retryMax { // retry r.group.DoChan(q.String(), fn) } if err == nil { - msg = ret.(*D.Msg) + msg = ret if shared { msg = msg.Copy() } diff --git a/go.mod b/go.mod index 90c252798..bc9fccba9 100644 --- a/go.mod +++ b/go.mod @@ -53,7 +53,6 @@ require ( golang.org/x/crypto v0.26.0 golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 golang.org/x/net v0.28.0 - golang.org/x/sync v0.8.0 golang.org/x/sys v0.23.0 google.golang.org/protobuf v1.34.2 gopkg.in/yaml.v3 v3.0.1 @@ -108,6 +107,7 @@ require ( gitlab.com/yawning/bsaes.git v0.0.0-20190805113838-0a714cd429ec // indirect go.uber.org/mock v0.4.0 // indirect golang.org/x/mod v0.19.0 // indirect + golang.org/x/sync v0.8.0 // indirect golang.org/x/text v0.17.0 // indirect golang.org/x/time v0.5.0 // indirect golang.org/x/tools v0.23.0 // indirect diff --git a/hub/route/configs.go b/hub/route/configs.go index a4dcaa525..c7b340c60 100644 --- a/hub/route/configs.go +++ b/hub/route/configs.go @@ -408,17 +408,5 @@ func updateGeoDatabases(w http.ResponseWriter, r *http.Request) { return } - cfg, err := executor.ParseWithPath(C.Path.Config()) - if err != nil { - log.Errorln("[GEO] update GEO databases failed: %v", err) - render.Status(r, http.StatusInternalServerError) - render.JSON(w, r, newError("Error parsing configuration")) - return - } - - log.Warnln("[GEO] update GEO databases success, applying config") - - executor.ApplyConfig(cfg, false) - render.NoContent(w, r) } diff --git a/main.go b/main.go index cd903ce6d..06a04ca17 100644 --- a/main.go +++ b/main.go @@ -120,17 +120,7 @@ func main() { } if C.GeoAutoUpdate { - updater.RegisterGeoUpdater(func() { - cfg, err := executor.ParseWithPath(C.Path.Config()) - if err != nil { - log.Errorln("[GEO] update GEO databases failed: %v", err) - return - } - - log.Warnln("[GEO] update GEO databases success, applying config") - - executor.ApplyConfig(cfg, false) - }) + updater.RegisterGeoUpdater() } defer executor.Shutdown() diff --git a/rules/common/geoip.go b/rules/common/geoip.go index b50680a47..839253212 100644 --- a/rules/common/geoip.go +++ b/rules/common/geoip.go @@ -1,6 +1,7 @@ package common import ( + "errors" "fmt" "strings" @@ -14,12 +15,11 @@ import ( type GEOIP struct { *Base - country string - adapter string - noResolveIP bool - isSourceIP bool - geoIPMatcher *router.GeoIPMatcher - recodeSize int + country string + adapter string + noResolveIP bool + isSourceIP bool + geodata bool } var _ C.Rule = (*GEOIP)(nil) @@ -78,7 +78,11 @@ func (g *GEOIP) Match(metadata *C.Metadata) (bool, string) { return false, g.adapter } - match := g.geoIPMatcher.Match(ip) + matcher, err := g.GetIPMatcher() + if err != nil { + return false, "" + } + match := matcher.Match(ip) if match && !g.isSourceIP { metadata.DstGeoIP = append(metadata.DstGeoIP, g.country) } @@ -101,12 +105,22 @@ func (g *GEOIP) GetCountry() string { return g.country } -func (g *GEOIP) GetIPMatcher() *router.GeoIPMatcher { - return g.geoIPMatcher +func (g *GEOIP) GetIPMatcher() (router.IPMatcher, error) { + if g.geodata { + geoIPMatcher, err := geodata.LoadGeoIPMatcher(g.country) + if err != nil { + return nil, fmt.Errorf("[GeoIP] %w", err) + } + return geoIPMatcher, nil + } + return nil, errors.New("geoip country not set") } func (g *GEOIP) GetRecodeSize() int { - return g.recodeSize + if matcher, err := g.GetIPMatcher(); err == nil { + return matcher.Count() + } + return 0 } func NewGEOIP(country string, adapter string, isSrc, noResolveIP bool) (*GEOIP, error) { @@ -116,31 +130,23 @@ func NewGEOIP(country string, adapter string, isSrc, noResolveIP bool) (*GEOIP, } country = strings.ToLower(country) + geoip := &GEOIP{ + Base: &Base{}, + country: country, + adapter: adapter, + noResolveIP: noResolveIP, + isSourceIP: isSrc, + } if !C.GeodataMode || country == "lan" { - geoip := &GEOIP{ - Base: &Base{}, - country: country, - adapter: adapter, - noResolveIP: noResolveIP, - isSourceIP: isSrc, - } return geoip, nil } - geoIPMatcher, size, err := geodata.LoadGeoIPMatcher(country) + geoip.geodata = true + geoIPMatcher, err := geoip.GetIPMatcher() // test load if err != nil { - return nil, fmt.Errorf("[GeoIP] %w", err) + return nil, err } - log.Infoln("Start initial GeoIP rule %s => %s, records: %d", country, adapter, size) - geoip := &GEOIP{ - Base: &Base{}, - country: country, - adapter: adapter, - noResolveIP: noResolveIP, - isSourceIP: isSrc, - geoIPMatcher: geoIPMatcher, - recodeSize: size, - } + log.Infoln("Finished initial GeoIP rule %s => %s, records: %d", country, adapter, geoIPMatcher.Count()) return geoip, nil } diff --git a/rules/common/geosite.go b/rules/common/geosite.go index 1e3c1ab5a..a728e9917 100644 --- a/rules/common/geosite.go +++ b/rules/common/geosite.go @@ -15,7 +15,6 @@ type GEOSITE struct { *Base country string adapter string - matcher router.DomainMatcher recodeSize int } @@ -28,7 +27,11 @@ func (gs *GEOSITE) Match(metadata *C.Metadata) (bool, string) { if len(domain) == 0 { return false, "" } - return gs.matcher.ApplyDomain(domain), gs.adapter + matcher, err := gs.GetDomainMatcher() + if err != nil { + return false, "" + } + return matcher.ApplyDomain(domain), gs.adapter } func (gs *GEOSITE) Adapter() string { @@ -39,12 +42,19 @@ func (gs *GEOSITE) Payload() string { return gs.country } -func (gs *GEOSITE) GetDomainMatcher() router.DomainMatcher { - return gs.matcher +func (gs *GEOSITE) GetDomainMatcher() (router.DomainMatcher, error) { + matcher, err := geodata.LoadGeoSiteMatcher(gs.country) + if err != nil { + return nil, fmt.Errorf("load GeoSite data error, %w", err) + } + return matcher, nil } func (gs *GEOSITE) GetRecodeSize() int { - return gs.recodeSize + if matcher, err := gs.GetDomainMatcher(); err == nil { + return matcher.Count() + } + return 0 } func NewGEOSITE(country string, adapter string) (*GEOSITE, error) { @@ -53,21 +63,19 @@ func NewGEOSITE(country string, adapter string) (*GEOSITE, error) { return nil, err } - matcher, size, err := geodata.LoadGeoSiteMatcher(country) - if err != nil { - return nil, fmt.Errorf("load GeoSite data error, %s", err.Error()) - } - - log.Infoln("Start initial GeoSite rule %s => %s, records: %d", country, adapter, size) - geoSite := &GEOSITE{ - Base: &Base{}, - country: country, - adapter: adapter, - matcher: matcher, - recodeSize: size, + Base: &Base{}, + country: country, + adapter: adapter, } + matcher, err := geoSite.GetDomainMatcher() // test load + if err != nil { + return nil, err + } + + log.Infoln("Finished initial GeoSite rule %s => %s, records: %d", country, adapter, matcher.Count()) + return geoSite, nil }