From bee1bddceb4261f5bf694eee21bb0d2eb227d3a7 Mon Sep 17 00:00:00 2001 From: Skyxim Date: Thu, 21 Apr 2022 07:06:08 -0700 Subject: [PATCH] feat: add sniffer port whitelist, when empty will add all ports --- common/utils/range.go | 44 +++++++++++++++++++++++++++++++++ component/sniffer/dispatcher.go | 22 +++++++++++++++-- config/config.go | 37 +++++++++++++++++++++++++-- go.mod | 3 ++- hub/executor/executor.go | 2 +- rule/common/port.go | 33 ++++++++----------------- 6 files changed, 112 insertions(+), 29 deletions(-) create mode 100644 common/utils/range.go diff --git a/common/utils/range.go b/common/utils/range.go new file mode 100644 index 000000000..c569d6a2c --- /dev/null +++ b/common/utils/range.go @@ -0,0 +1,44 @@ +package utils + +import ( + "golang.org/x/exp/constraints" +) + +type Range[T constraints.Ordered] struct { + start T + end T +} + +func NewRange[T constraints.Ordered](start, end T) *Range[T] { + if start > end { + return &Range[T]{ + start: end, + end: start, + } + } + + return &Range[T]{ + start: start, + end: end, + } +} + +func (r *Range[T]) Contains(t T) bool { + return t >= r.start && t <= r.end +} + +func (r *Range[T]) LeftContains(t T) bool { + return t >= r.start && t < r.end +} + +func (r *Range[T]) RightContains(t T) bool { + return t > r.start && t <= r.end +} + +func (r *Range[T]) Start() T { + return r.start +} + +func (r *Range[T]) End() T { + return r.end +} diff --git a/component/sniffer/dispatcher.go b/component/sniffer/dispatcher.go index 545a141c8..e3658d0eb 100644 --- a/component/sniffer/dispatcher.go +++ b/component/sniffer/dispatcher.go @@ -2,11 +2,14 @@ package sniffer import ( "errors" - "github.com/Dreamacro/clash/component/trie" "net" "net/netip" + "strconv" + + "github.com/Dreamacro/clash/component/trie" CN "github.com/Dreamacro/clash/common/net" + "github.com/Dreamacro/clash/common/utils" "github.com/Dreamacro/clash/component/resolver" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/log" @@ -26,6 +29,7 @@ type SnifferDispatcher struct { foreDomain *trie.DomainTrie[bool] skipSNI *trie.DomainTrie[bool] + portRanges *[]utils.Range[uint16] } func (sd *SnifferDispatcher) TCPSniff(conn net.Conn, metadata *C.Metadata) { @@ -35,6 +39,18 @@ func (sd *SnifferDispatcher) TCPSniff(conn net.Conn, metadata *C.Metadata) { } if metadata.Host == "" || sd.foreDomain.Search(metadata.Host) != nil { + port, err := strconv.ParseUint(metadata.DstPort, 10, 16) + if err != nil { + log.Debugln("[Sniffer] Dst port is error") + return + } + + for _, portRange := range *sd.portRanges { + if !portRange.Contains(uint16(port)) { + return + } + } + if host, err := sd.sniffDomain(bufConn, metadata); err != nil { log.Debugln("[Sniffer] All sniffing sniff failed with from [%s:%s] to [%s:%s]", metadata.SrcIP, metadata.SrcPort, metadata.String(), metadata.DstPort) return @@ -102,11 +118,13 @@ func NewCloseSnifferDispatcher() (*SnifferDispatcher, error) { return &dispatcher, nil } -func NewSnifferDispatcher(needSniffer []C.SnifferType, forceDomain *trie.DomainTrie[bool], skipSNI *trie.DomainTrie[bool]) (*SnifferDispatcher, error) { +func NewSnifferDispatcher(needSniffer []C.SnifferType, forceDomain *trie.DomainTrie[bool], + skipSNI *trie.DomainTrie[bool], ports *[]utils.Range[uint16]) (*SnifferDispatcher, error) { dispatcher := SnifferDispatcher{ enable: true, foreDomain: forceDomain, skipSNI: skipSNI, + portRanges: ports, } for _, snifferName := range needSniffer { diff --git a/config/config.go b/config/config.go index a30c8c254..4bd00c91a 100644 --- a/config/config.go +++ b/config/config.go @@ -4,16 +4,19 @@ import ( "container/list" "errors" "fmt" - R "github.com/Dreamacro/clash/rule" - RP "github.com/Dreamacro/clash/rule/provider" "net" "net/netip" "net/url" "os" "runtime" + "strconv" "strings" "time" + "github.com/Dreamacro/clash/common/utils" + R "github.com/Dreamacro/clash/rule" + RP "github.com/Dreamacro/clash/rule/provider" + "github.com/Dreamacro/clash/adapter" "github.com/Dreamacro/clash/adapter/outbound" "github.com/Dreamacro/clash/adapter/outboundgroup" @@ -127,6 +130,7 @@ type Sniffer struct { Reverses *trie.DomainTrie[bool] ForceDomain *trie.DomainTrie[bool] SkipSNI *trie.DomainTrie[bool] + Ports *[]utils.Range[uint16] } // Experimental config @@ -224,6 +228,7 @@ type SnifferRaw struct { Reverse []string `yaml:"reverses" json:"reverses"` ForceDomain []string `yaml:"force-domain" json:"force-domain"` SkipSNI []string `yaml:"skip-sni" json:"skip-sni"` + Ports []string `yaml:"port-whitelist" json:"port-whitelist"` } // Parse config @@ -298,6 +303,7 @@ func UnmarshalRawConfig(buf []byte) (*RawConfig, error) { Reverse: []string{}, ForceDomain: []string{}, SkipSNI: []string{}, + Ports: []string{}, }, Profile: Profile{ StoreSelected: true, @@ -914,6 +920,33 @@ func parseSniffer(snifferRaw SnifferRaw) (*Sniffer, error) { Force: snifferRaw.Force, } + ports := []utils.Range[uint16]{} + if len(snifferRaw.Ports) == 0 { + ports = append(ports, *utils.NewRange[uint16](0, 65535)) + } else { + for _, portRange := range snifferRaw.Ports { + portRaws := strings.Split(portRange, "-") + if len(portRaws) > 1 { + p, err := strconv.ParseUint(portRaws[0], 10, 16) + if err != nil { + return nil, fmt.Errorf("%s format error", portRange) + } + + start := uint16(p) + + p, err = strconv.ParseUint(portRaws[0], 10, 16) + if err != nil { + return nil, fmt.Errorf("%s format error", portRange) + } + + end := uint16(p) + ports = append(ports, *utils.NewRange(start, end)) + } + } + } + + sniffer.Ports = &ports + loadSniffer := make(map[C.SnifferType]struct{}) for _, snifferName := range snifferRaw.Sniffing { diff --git a/go.mod b/go.mod index 769dbe661..360bf21b7 100644 --- a/go.mod +++ b/go.mod @@ -21,7 +21,8 @@ require ( go.uber.org/atomic v1.9.0 go.uber.org/automaxprocs v1.5.1 golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 - golang.org/x/net v0.0.0-20220418201149-a630d4f3e7a2 + golang.org/x/exp v0.0.0-20220414153411-bcd21879b8fd + golang.org/x/net v0.0.0-20220412020605-290c469a71a5 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c golang.org/x/sys v0.0.0-20220412211240-33da011f77ad golang.org/x/time v0.0.0-20220411224347-583f2d630306 diff --git a/hub/executor/executor.go b/hub/executor/executor.go index 09d72ccf3..e70aac3db 100644 --- a/hub/executor/executor.go +++ b/hub/executor/executor.go @@ -222,7 +222,7 @@ func updateTun(tun *config.Tun, dns *config.DNS) { func updateSniffer(sniffer *config.Sniffer) { if sniffer.Enable { - dispatcher, err := SNI.NewSnifferDispatcher(sniffer.Sniffers, sniffer.ForceDomain, sniffer.SkipSNI) + dispatcher, err := SNI.NewSnifferDispatcher(sniffer.Sniffers, sniffer.ForceDomain, sniffer.SkipSNI, sniffer.Ports) if err != nil { log.Warnln("initial sniffer failed, err:%v", err) } diff --git a/rule/common/port.go b/rule/common/port.go index 0e46649bc..06fde6c27 100644 --- a/rule/common/port.go +++ b/rule/common/port.go @@ -5,20 +5,16 @@ import ( "strconv" "strings" + "github.com/Dreamacro/clash/common/utils" C "github.com/Dreamacro/clash/constant" ) -type portReal struct { - portStart int - portEnd int -} - type Port struct { *Base adapter string port string isSource bool - portList []portReal + portList []utils.Range[uint16] } func (p *Port) RuleType() C.RuleType { @@ -45,17 +41,13 @@ func (p *Port) Payload() string { func (p *Port) matchPortReal(portRef string) bool { port, _ := strconv.Atoi(portRef) - var rs bool + for _, pr := range p.portList { - if pr.portEnd == -1 { - rs = port == pr.portStart - } else { - rs = port >= pr.portStart && port <= pr.portEnd - } - if rs { + if pr.Contains(uint16(port)) { return true } } + return false } @@ -65,7 +57,7 @@ func NewPort(port string, adapter string, isSource bool) (*Port, error) { return nil, fmt.Errorf("%s, too many ports to use, maximum support 28 ports", errPayload.Error()) } - var portList []portReal + var portRange []utils.Range[uint16] for _, p := range ports { if p == "" { continue @@ -84,23 +76,18 @@ func NewPort(port string, adapter string, isSource bool) (*Port, error) { switch subPortsLen { case 1: - portList = append(portList, portReal{int(portStart), -1}) + portRange = append(portRange, *utils.NewRange(uint16(portStart), uint16(portStart))) case 2: portEnd, err := strconv.ParseUint(strings.Trim(subPorts[1], "[ ]"), 10, 16) if err != nil { return nil, errPayload } - shouldReverse := portStart > portEnd - if shouldReverse { - portList = append(portList, portReal{int(portEnd), int(portStart)}) - } else { - portList = append(portList, portReal{int(portStart), int(portEnd)}) - } + portRange = append(portRange, *utils.NewRange(uint16(portStart), uint16(portEnd))) } } - if len(portList) == 0 { + if len(portRange) == 0 { return nil, errPayload } @@ -109,7 +96,7 @@ func NewPort(port string, adapter string, isSource bool) (*Port, error) { adapter: adapter, port: port, isSource: isSource, - portList: portList, + portList: portRange, }, nil }