From 4f8a5a5f54ef082dfe02d5db4179e82292ee61d4 Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Sat, 27 Jul 2024 10:36:11 +0800 Subject: [PATCH] feat: add `mrs` format ipcidr ruleset --- component/cidr/ipcidr_set_bin.go | 77 ++++++++++++++++++++++++++++ component/trie/domain_set_bin.go | 40 +++++---------- constant/provider/interface.go | 13 +++++ docs/config.yaml | 2 +- rules/provider/classical_strategy.go | 5 ++ rules/provider/domain_strategy.go | 13 +++-- rules/provider/ipcidr_strategy.go | 28 ++++++++++ rules/provider/mrs_converter.go | 33 ++++++++++++ rules/provider/mrs_reader.go | 72 ++++++++++++++++++++++++++ rules/provider/provider.go | 16 ++---- 10 files changed, 255 insertions(+), 44 deletions(-) create mode 100644 component/cidr/ipcidr_set_bin.go create mode 100644 rules/provider/mrs_reader.go diff --git a/component/cidr/ipcidr_set_bin.go b/component/cidr/ipcidr_set_bin.go new file mode 100644 index 000000000..f6a034885 --- /dev/null +++ b/component/cidr/ipcidr_set_bin.go @@ -0,0 +1,77 @@ +package cidr + +import ( + "encoding/binary" + "errors" + "io" + "net/netip" + + "go4.org/netipx" +) + +func (ss *IpCidrSet) WriteBin(w io.Writer) (err error) { + // version + _, err = w.Write([]byte{1}) + if err != nil { + return err + } + + // rr + err = binary.Write(w, binary.BigEndian, int64(len(ss.rr))) + if err != nil { + return err + } + for _, r := range ss.rr { + err = binary.Write(w, binary.BigEndian, r.From().As16()) + if err != nil { + return err + } + err = binary.Write(w, binary.BigEndian, r.To().As16()) + if err != nil { + return err + } + } + + return nil +} + +func ReadIpCidrSet(r io.Reader) (ss *IpCidrSet, err error) { + // version + version := make([]byte, 1) + _, err = io.ReadFull(r, version) + if err != nil { + return nil, err + } + if version[0] != 1 { + return nil, errors.New("version is invalid") + } + + ss = NewIpCidrSet() + var length int64 + + // rr + err = binary.Read(r, binary.BigEndian, &length) + if err != nil { + return nil, err + } + if length < 1 { + return nil, errors.New("length is invalid") + } + ss.rr = make([]netipx.IPRange, length) + for i := int64(0); i < length; i++ { + var a16 [16]byte + err = binary.Read(r, binary.BigEndian, &a16) + if err != nil { + return nil, err + } + from := netip.AddrFrom16(a16).Unmap() + err = binary.Read(r, binary.BigEndian, &a16) + if err != nil { + return nil, err + } + to := netip.AddrFrom16(a16).Unmap() + ss.rr[i] = netipx.IPRangeFrom(from, to) + } + + return ss, nil +} diff --git a/component/trie/domain_set_bin.go b/component/trie/domain_set_bin.go index e32d4e1a3..27d15802e 100644 --- a/component/trie/domain_set_bin.go +++ b/component/trie/domain_set_bin.go @@ -6,19 +6,13 @@ import ( "io" ) -func (ss *DomainSet) WriteBin(w io.Writer, count int64) (err error) { +func (ss *DomainSet) WriteBin(w io.Writer) (err error) { // version _, err = w.Write([]byte{1}) if err != nil { return err } - // count - err = binary.Write(w, binary.BigEndian, count) - if err != nil { - return err - } - // leaves err = binary.Write(w, binary.BigEndian, int64(len(ss.leaves))) if err != nil { @@ -56,21 +50,15 @@ func (ss *DomainSet) WriteBin(w io.Writer, count int64) (err error) { return nil } -func ReadDomainSetBin(r io.Reader) (ds *DomainSet, count int64, err error) { +func ReadDomainSetBin(r io.Reader) (ds *DomainSet, err error) { // version version := make([]byte, 1) _, err = io.ReadFull(r, version) if err != nil { - return nil, 0, err + return nil, err } if version[0] != 1 { - return nil, 0, errors.New("version is invalid") - } - - // count - err = binary.Read(r, binary.BigEndian, &count) - if err != nil { - return nil, 0, err + return nil, errors.New("version is invalid") } ds = &DomainSet{} @@ -79,49 +67,49 @@ func ReadDomainSetBin(r io.Reader) (ds *DomainSet, count int64, err error) { // leaves err = binary.Read(r, binary.BigEndian, &length) if err != nil { - return nil, 0, err + return nil, err } if length < 1 { - return nil, 0, errors.New("length is invalid") + return nil, errors.New("length is invalid") } ds.leaves = make([]uint64, length) for i := int64(0); i < length; i++ { err = binary.Read(r, binary.BigEndian, &ds.leaves[i]) if err != nil { - return nil, 0, err + return nil, err } } // labelBitmap err = binary.Read(r, binary.BigEndian, &length) if err != nil { - return nil, 0, err + return nil, err } if length < 1 { - return nil, 0, errors.New("length is invalid") + return nil, errors.New("length is invalid") } ds.labelBitmap = make([]uint64, length) for i := int64(0); i < length; i++ { err = binary.Read(r, binary.BigEndian, &ds.labelBitmap[i]) if err != nil { - return nil, 0, err + return nil, err } } // labels err = binary.Read(r, binary.BigEndian, &length) if err != nil { - return nil, 0, err + return nil, err } if length < 1 { - return nil, 0, errors.New("length is invalid") + return nil, errors.New("length is invalid") } ds.labels = make([]byte, length) _, err = io.ReadFull(r, ds.labels) if err != nil { - return nil, 0, err + return nil, err } ds.init() - return ds, count, nil + return ds, nil } diff --git a/constant/provider/interface.go b/constant/provider/interface.go index c86e61633..bd6b6e947 100644 --- a/constant/provider/interface.go +++ b/constant/provider/interface.go @@ -112,6 +112,19 @@ func (rt RuleBehavior) String() string { } } +func (rt RuleBehavior) Byte() byte { + switch rt { + case Domain: + return 0 + case IPCIDR: + return 1 + case Classical: + return 2 + default: + return 255 + } +} + func ParseBehavior(s string) (behavior RuleBehavior, err error) { switch s { case "domain": diff --git a/docs/config.yaml b/docs/config.yaml index 2d3343cf1..6e29f1642 100644 --- a/docs/config.yaml +++ b/docs/config.yaml @@ -942,7 +942,7 @@ rule-providers: interval: 259200 path: /path/to/save/file.yaml type: file - rule3: # mrs类型ruleset,目前仅支持domain,可以通过“mihomo convert-ruleset domain yaml XXX.yaml XXX.mrs”转换得到 + rule3: # mrs类型ruleset,目前仅支持domain和ipcidr,可以通过“mihomo convert-ruleset domain yaml XXX.yaml XXX.mrs”转换得到 type: http url: "url" format: mrs diff --git a/rules/provider/classical_strategy.go b/rules/provider/classical_strategy.go index 8353ebce4..205a8e599 100644 --- a/rules/provider/classical_strategy.go +++ b/rules/provider/classical_strategy.go @@ -5,6 +5,7 @@ import ( "strings" C "github.com/metacubex/mihomo/constant" + P "github.com/metacubex/mihomo/constant/provider" "github.com/metacubex/mihomo/log" ) @@ -16,6 +17,10 @@ type classicalStrategy struct { parse func(tp, payload, target string, params []string) (parsed C.Rule, parseErr error) } +func (c *classicalStrategy) Behavior() P.RuleBehavior { + return P.Classical +} + func (c *classicalStrategy) Match(metadata *C.Metadata) bool { for _, rule := range c.rules { if m, _ := rule.Match(metadata); m { diff --git a/rules/provider/domain_strategy.go b/rules/provider/domain_strategy.go index 0104fdf90..462d37dcf 100644 --- a/rules/provider/domain_strategy.go +++ b/rules/provider/domain_strategy.go @@ -6,6 +6,7 @@ import ( "github.com/metacubex/mihomo/component/trie" C "github.com/metacubex/mihomo/constant" + P "github.com/metacubex/mihomo/constant/provider" "github.com/metacubex/mihomo/log" ) @@ -15,6 +16,10 @@ type domainStrategy struct { domainSet *trie.DomainSet } +func (d *domainStrategy) Behavior() P.RuleBehavior { + return P.Domain +} + func (d *domainStrategy) ShouldFindProcess() bool { return false } @@ -51,12 +56,12 @@ func (d *domainStrategy) FinishInsert() { d.domainTrie = nil } -func (d *domainStrategy) FromMrs(r io.Reader) error { - domainSet, count, err := trie.ReadDomainSetBin(r) +func (d *domainStrategy) FromMrs(r io.Reader, count int) error { + domainSet, err := trie.ReadDomainSetBin(r) if err != nil { return err } - d.count = int(count) + d.count = count d.domainSet = domainSet return nil } @@ -65,7 +70,7 @@ func (d *domainStrategy) WriteMrs(w io.Writer) error { if d.domainSet == nil { return errors.New("nil domainSet") } - return d.domainSet.WriteBin(w, int64(d.count)) + return d.domainSet.WriteBin(w) } var _ mrsRuleStrategy = (*domainStrategy)(nil) diff --git a/rules/provider/ipcidr_strategy.go b/rules/provider/ipcidr_strategy.go index d0545c7cc..87cf7a2d8 100644 --- a/rules/provider/ipcidr_strategy.go +++ b/rules/provider/ipcidr_strategy.go @@ -1,8 +1,12 @@ package provider import ( + "errors" + "io" + "github.com/metacubex/mihomo/component/cidr" C "github.com/metacubex/mihomo/constant" + P "github.com/metacubex/mihomo/constant/provider" "github.com/metacubex/mihomo/log" "go4.org/netipx" @@ -15,6 +19,10 @@ type ipcidrStrategy struct { //trie *trie.IpCidrTrie } +func (i *ipcidrStrategy) Behavior() P.RuleBehavior { + return P.IPCIDR +} + func (i *ipcidrStrategy) ShouldFindProcess() bool { return false } @@ -54,6 +62,26 @@ func (i *ipcidrStrategy) FinishInsert() { i.cidrSet.Merge() } +func (i *ipcidrStrategy) FromMrs(r io.Reader, count int) error { + cidrSet, err := cidr.ReadIpCidrSet(r) + if err != nil { + return err + } + i.count = count + i.cidrSet = cidrSet + if i.count > 0 { + i.shouldResolveIP = true + } + return nil +} + +func (i *ipcidrStrategy) WriteMrs(w io.Writer) error { + if i.cidrSet == nil { + return errors.New("nil cidrSet") + } + return i.cidrSet.WriteBin(w) +} + func (i *ipcidrStrategy) ToIpCidr() *netipx.IPSet { return i.cidrSet.ToIPSet() } diff --git a/rules/provider/mrs_converter.go b/rules/provider/mrs_converter.go index 3b93b4a4b..c8f63fdfe 100644 --- a/rules/provider/mrs_converter.go +++ b/rules/provider/mrs_converter.go @@ -1,6 +1,7 @@ package provider import ( + "encoding/binary" "io" "os" @@ -27,6 +28,38 @@ func ConvertToMrs(buf []byte, behavior P.RuleBehavior, format P.RuleFormat, w io err = zstdErr } }() + + // header + _, err = encoder.Write(MrsMagicBytes[:]) + if err != nil { + return err + } + + // behavior + _behavior := []byte{behavior.Byte()} + _, err = encoder.Write(_behavior[:]) + if err != nil { + return err + } + + // count + count := int64(_strategy.Count()) + err = binary.Write(encoder, binary.BigEndian, count) + if err != nil { + return err + } + + // extra (reserved for future using) + var extra []byte + err = binary.Write(encoder, binary.BigEndian, int64(len(extra))) + if err != nil { + return err + } + _, err = encoder.Write(extra) + if err != nil { + return err + } + return _strategy.WriteMrs(encoder) } else { return ErrInvalidFormat diff --git a/rules/provider/mrs_reader.go b/rules/provider/mrs_reader.go new file mode 100644 index 000000000..66f62127c --- /dev/null +++ b/rules/provider/mrs_reader.go @@ -0,0 +1,72 @@ +package provider + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + + "github.com/klauspost/compress/zstd" +) + +var MrsMagicBytes = [4]byte{'M', 'R', 'S', 1} // MRSv1 + +func rulesMrsParse(buf []byte, strategy ruleStrategy) (ruleStrategy, error) { + if _strategy, ok := strategy.(mrsRuleStrategy); ok { + reader, err := zstd.NewReader(bytes.NewReader(buf)) + if err != nil { + return nil, err + } + defer reader.Close() + + // header + var header [4]byte + _, err = io.ReadFull(reader, header[:]) + if err != nil { + return nil, err + } + if header != MrsMagicBytes { + return nil, fmt.Errorf("invalid MrsMagic bytes") + } + + // behavior + var _behavior [1]byte + _, err = io.ReadFull(reader, _behavior[:]) + if err != nil { + return nil, err + } + if _behavior[0] != strategy.Behavior().Byte() { + return nil, fmt.Errorf("invalid behavior") + } + + // count + var count int64 + err = binary.Read(reader, binary.BigEndian, &count) + if err != nil { + return nil, err + } + + // extra (reserved for future using) + var length int64 + err = binary.Read(reader, binary.BigEndian, &length) + if err != nil { + return nil, err + } + if length < 0 { + return nil, errors.New("length is invalid") + } + if length > 0 { + extra := make([]byte, length) + _, err = io.ReadFull(reader, extra) + if err != nil { + return nil, err + } + } + + err = _strategy.FromMrs(reader, int(count)) + return strategy, err + } else { + return nil, ErrInvalidFormat + } +} diff --git a/rules/provider/provider.go b/rules/provider/provider.go index a4d8883df..8c5d7f940 100644 --- a/rules/provider/provider.go +++ b/rules/provider/provider.go @@ -14,7 +14,6 @@ import ( C "github.com/metacubex/mihomo/constant" P "github.com/metacubex/mihomo/constant/provider" - "github.com/klauspost/compress/zstd" "gopkg.in/yaml.v3" ) @@ -45,6 +44,7 @@ type RulePayload struct { } type ruleStrategy interface { + Behavior() P.RuleBehavior Match(metadata *C.Metadata) bool Count() int ShouldResolveIP() bool @@ -56,7 +56,7 @@ type ruleStrategy interface { type mrsRuleStrategy interface { ruleStrategy - FromMrs(r io.Reader) error + FromMrs(r io.Reader, count int) error WriteMrs(w io.Writer) error } @@ -165,17 +165,7 @@ var ErrInvalidFormat = errors.New("invalid format") func rulesParse(buf []byte, strategy ruleStrategy, format P.RuleFormat) (ruleStrategy, error) { strategy.Reset() if format == P.MrsRule { - if _strategy, ok := strategy.(mrsRuleStrategy); ok { - reader, err := zstd.NewReader(bytes.NewReader(buf)) - if err != nil { - return nil, err - } - defer reader.Close() - err = _strategy.FromMrs(reader) - return strategy, err - } else { - return nil, ErrInvalidFormat - } + return rulesMrsParse(buf, strategy) } schema := &RulePayload{}