feat: sniff add skip-src-address and skip-dst-address

This commit is contained in:
wwqgtxx 2024-08-27 20:33:43 +08:00
parent 3e2c9ce821
commit 8483178524
6 changed files with 200 additions and 119 deletions

View File

@ -57,6 +57,10 @@ func (set *IpCidrSet) Merge() error {
return nil
}
func (set *IpCidrSet) IsEmpty() bool {
return set == nil || len(set.rr) == 0
}
func (set *IpCidrSet) Foreach(f func(prefix netip.Prefix) bool) {
for _, r := range set.rr {
for _, prefix := range r.Prefixes() {

View File

@ -2,7 +2,6 @@ package sniffer
import (
"errors"
"fmt"
"net"
"net/netip"
"time"
@ -20,19 +19,29 @@ var (
ErrNoClue = errors.New("not enough information for making a decision")
)
var Dispatcher *SnifferDispatcher
type SnifferDispatcher struct {
type Dispatcher struct {
enable bool
sniffers map[sniffer.Sniffer]SnifferConfig
forceDomain []C.Rule
skipSrcAddress []C.Rule
skipDstAddress []C.Rule
skipDomain []C.Rule
skipList *lru.LruCache[string, uint8]
skipList *lru.LruCache[netip.AddrPort, uint8]
forceDnsMapping bool
parsePureIp bool
}
func (sd *SnifferDispatcher) shouldOverride(metadata *C.Metadata) bool {
func (sd *Dispatcher) shouldOverride(metadata *C.Metadata) bool {
for _, rule := range sd.skipDstAddress {
if ok, _ := rule.Match(&C.Metadata{DstIP: metadata.DstIP}); ok {
return false
}
}
for _, rule := range sd.skipSrcAddress {
if ok, _ := rule.Match(&C.Metadata{DstIP: metadata.SrcIP}); ok {
return false
}
}
if metadata.Host == "" && sd.parsePureIp {
return true
}
@ -47,10 +56,9 @@ func (sd *SnifferDispatcher) shouldOverride(metadata *C.Metadata) bool {
return false
}
func (sd *SnifferDispatcher) UDPSniff(packet C.PacketAdapter) bool {
func (sd *Dispatcher) UDPSniff(packet C.PacketAdapter) bool {
metadata := packet.Metadata()
if sd.shouldOverride(packet.Metadata()) {
if sd.shouldOverride(metadata) {
for sniffer, config := range sd.sniffers {
if sniffer.SupportNetwork() == C.UDP || sniffer.SupportNetwork() == C.ALLNet {
inWhitelist := sniffer.SupportPort(metadata.DstPort)
@ -73,7 +81,7 @@ func (sd *SnifferDispatcher) UDPSniff(packet C.PacketAdapter) bool {
}
// TCPSniff returns true if the connection is sniffed to have a domain
func (sd *SnifferDispatcher) TCPSniff(conn *N.BufferedConn, metadata *C.Metadata) bool {
func (sd *Dispatcher) TCPSniff(conn *N.BufferedConn, metadata *C.Metadata) bool {
if sd.shouldOverride(metadata) {
inWhitelist := false
overrideDest := false
@ -91,17 +99,19 @@ func (sd *SnifferDispatcher) TCPSniff(conn *N.BufferedConn, metadata *C.Metadata
return false
}
dst := fmt.Sprintf("%s:%d", metadata.DstIP, metadata.DstPort)
dst := metadata.AddrPort()
if count, ok := sd.skipList.Get(dst); ok && count > 5 {
log.Debugln("[Sniffer] Skip sniffing[%s] due to multiple failures", dst)
return false
}
if host, err := sd.sniffDomain(conn, metadata); err != nil {
host, err := sd.sniffDomain(conn, metadata)
if err != nil {
sd.cacheSniffFailed(metadata)
log.Debugln("[Sniffer] All sniffing sniff failed with from [%s:%d] to [%s:%d]", metadata.SrcIP, metadata.SrcPort, metadata.String(), metadata.DstPort)
return false
} else {
}
for _, rule := range sd.skipDomain {
if ok, _ := rule.Match(&C.Metadata{Host: host}); ok {
log.Debugln("[Sniffer] Skip sni[%s]", host)
@ -114,11 +124,10 @@ func (sd *SnifferDispatcher) TCPSniff(conn *N.BufferedConn, metadata *C.Metadata
sd.replaceDomain(metadata, host, overrideDest)
return true
}
}
return false
}
func (sd *SnifferDispatcher) replaceDomain(metadata *C.Metadata, host string, overrideDest bool) {
func (sd *Dispatcher) replaceDomain(metadata *C.Metadata, host string, overrideDest bool) {
metadata.SniffHost = host
if overrideDest {
log.Debugln("[Sniffer] Sniff %s [%s]-->[%s] success, replace domain [%s]-->[%s]",
@ -131,11 +140,11 @@ func (sd *SnifferDispatcher) replaceDomain(metadata *C.Metadata, host string, ov
metadata.DNSMode = C.DNSNormal
}
func (sd *SnifferDispatcher) Enable() bool {
return sd.enable
func (sd *Dispatcher) Enable() bool {
return sd != nil && sd.enable
}
func (sd *SnifferDispatcher) sniffDomain(conn *N.BufferedConn, metadata *C.Metadata) (string, error) {
func (sd *Dispatcher) sniffDomain(conn *N.BufferedConn, metadata *C.Metadata) (string, error) {
for s := range sd.sniffers {
if s.SupportNetwork() == C.TCP {
_ = conn.SetReadDeadline(time.Now().Add(1 * time.Second))
@ -178,8 +187,8 @@ func (sd *SnifferDispatcher) sniffDomain(conn *N.BufferedConn, metadata *C.Metad
return "", ErrorSniffFailed
}
func (sd *SnifferDispatcher) cacheSniffFailed(metadata *C.Metadata) {
dst := fmt.Sprintf("%s:%d", metadata.DstIP, metadata.DstPort)
func (sd *Dispatcher) cacheSniffFailed(metadata *C.Metadata) {
dst := metadata.AddrPort()
sd.skipList.Compute(dst, func(oldValue uint8, loaded bool) (newValue uint8, delete bool) {
if oldValue <= 5 {
oldValue++
@ -188,32 +197,35 @@ func (sd *SnifferDispatcher) cacheSniffFailed(metadata *C.Metadata) {
})
}
func NewCloseSnifferDispatcher() (*SnifferDispatcher, error) {
dispatcher := SnifferDispatcher{
enable: false,
}
return &dispatcher, nil
type Config struct {
Enable bool
Sniffers map[sniffer.Type]SnifferConfig
ForceDomain []C.Rule
SkipSrcAddress []C.Rule
SkipDstAddress []C.Rule
SkipDomain []C.Rule
ForceDnsMapping bool
ParsePureIp bool
}
func NewSnifferDispatcher(snifferConfig map[sniffer.Type]SnifferConfig,
forceDomain []C.Rule, skipDomain []C.Rule,
forceDnsMapping bool, parsePureIp bool) (*SnifferDispatcher, error) {
dispatcher := SnifferDispatcher{
enable: true,
forceDomain: forceDomain,
skipDomain: skipDomain,
skipList: lru.New(lru.WithSize[string, uint8](128), lru.WithAge[string, uint8](600)),
forceDnsMapping: forceDnsMapping,
parsePureIp: parsePureIp,
sniffers: make(map[sniffer.Sniffer]SnifferConfig, 0),
func NewDispatcher(snifferConfig *Config) (*Dispatcher, error) {
dispatcher := Dispatcher{
enable: snifferConfig.Enable,
forceDomain: snifferConfig.ForceDomain,
skipSrcAddress: snifferConfig.SkipSrcAddress,
skipDstAddress: snifferConfig.SkipDstAddress,
skipDomain: snifferConfig.SkipDomain,
skipList: lru.New(lru.WithSize[netip.AddrPort, uint8](128), lru.WithAge[netip.AddrPort, uint8](600)),
forceDnsMapping: snifferConfig.ForceDnsMapping,
parsePureIp: snifferConfig.ParsePureIp,
sniffers: make(map[sniffer.Sniffer]SnifferConfig, len(snifferConfig.Sniffers)),
}
for snifferName, config := range snifferConfig {
for snifferName, config := range snifferConfig.Sniffers {
s, err := NewSniffer(snifferName, config)
if err != nil {
log.Errorln("Sniffer name[%s] is error", snifferName)
return &SnifferDispatcher{enable: false}, err
return &Dispatcher{enable: false}, err
}
dispatcher.sniffers[s] = config
}

View File

@ -25,7 +25,7 @@ import (
"github.com/metacubex/mihomo/component/geodata"
P "github.com/metacubex/mihomo/component/process"
"github.com/metacubex/mihomo/component/resolver"
SNIFF "github.com/metacubex/mihomo/component/sniffer"
"github.com/metacubex/mihomo/component/sniffer"
tlsC "github.com/metacubex/mihomo/component/tls"
"github.com/metacubex/mihomo/component/trie"
"github.com/metacubex/mihomo/component/updater"
@ -161,16 +161,6 @@ type Profile struct {
StoreFakeIP bool
}
// Sniffer config
type Sniffer struct {
Enable bool
Sniffers map[snifferTypes.Type]SNIFF.SnifferConfig
ForceDomain []C.Rule
SkipDomain []C.Rule
ForceDnsMapping bool
ParsePureIp bool
}
// TLS config
type TLS struct {
Certificate string
@ -196,7 +186,7 @@ type Config struct {
Providers map[string]providerTypes.ProxyProvider
RuleProviders map[string]providerTypes.RuleProvider
Tunnels []LC.Tunnel
Sniffer *Sniffer
Sniffer *sniffer.Config
TLS *TLS
}
@ -331,10 +321,13 @@ type RawSniffer struct {
OverrideDest bool `yaml:"override-destination" json:"override-destination"`
Sniffing []string `yaml:"sniffing" json:"sniffing"`
ForceDomain []string `yaml:"force-domain" json:"force-domain"`
SkipSrcAddress []string `yaml:"skip-src-address" json:"skip-src-address"`
SkipDstAddress []string `yaml:"skip-dst-address" json:"skip-dst-address"`
SkipDomain []string `yaml:"skip-domain" json:"skip-domain"`
Ports []string `yaml:"port-whitelist" json:"port-whitelist"`
ForceDnsMapping bool `yaml:"force-dns-mapping" json:"force-dns-mapping"`
ParsePureIp bool `yaml:"parse-pure-ip" json:"parse-pure-ip"`
Sniff map[string]RawSniffingConfig `yaml:"sniff" json:"sniff"`
}
@ -1477,7 +1470,7 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[resolver.HostValue], rul
var rule C.Rule
if len(cfg.Fallback) != 0 {
if cfg.FallbackFilter.GeoIP {
rule, err = RC.NewGEOIP(cfg.FallbackFilter.GeoIPCode, "", false, true)
rule, err = RC.NewGEOIP(cfg.FallbackFilter.GeoIPCode, "dns.fallback-filter.geoip", false, true)
if err != nil {
return nil, fmt.Errorf("load GeoIP dns fallback filter error, %w", err)
}
@ -1507,7 +1500,7 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[resolver.HostValue], rul
}
}
rule = RP.NewDomainSet(domainTrie.NewDomainSet(), "dns.fallback-filter.domain")
dnsCfg.FallbackIPFilter = append(dnsCfg.FallbackIPFilter, rule)
dnsCfg.FallbackDomainFilter = append(dnsCfg.FallbackDomainFilter, rule)
}
if len(cfg.FallbackFilter.GeoSite) > 0 {
log.Warnln("replace fallback-filter.geosite with nameserver-policy, it will be removed in the future")
@ -1516,7 +1509,7 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie[resolver.HostValue], rul
if err != nil {
return nil, fmt.Errorf("DNS FallbackGeosite[%d] format error: %w", idx, err)
}
dnsCfg.FallbackIPFilter = append(dnsCfg.FallbackIPFilter, rule)
dnsCfg.FallbackDomainFilter = append(dnsCfg.FallbackDomainFilter, rule)
}
}
}
@ -1618,13 +1611,13 @@ func parseTuicServer(rawTuic RawTuicServer, general *General) error {
return nil
}
func parseSniffer(snifferRaw RawSniffer, ruleProviders map[string]providerTypes.RuleProvider) (*Sniffer, error) {
sniffer := &Sniffer{
func parseSniffer(snifferRaw RawSniffer, ruleProviders map[string]providerTypes.RuleProvider) (*sniffer.Config, error) {
snifferConfig := &sniffer.Config{
Enable: snifferRaw.Enable,
ForceDnsMapping: snifferRaw.ForceDnsMapping,
ParsePureIp: snifferRaw.ParsePureIp,
}
loadSniffer := make(map[snifferTypes.Type]SNIFF.SnifferConfig)
loadSniffer := make(map[snifferTypes.Type]sniffer.SnifferConfig)
if len(snifferRaw.Sniff) != 0 {
for sniffType, sniffConfig := range snifferRaw.Sniff {
@ -1640,7 +1633,7 @@ func parseSniffer(snifferRaw RawSniffer, ruleProviders map[string]providerTypes.
for _, snifferType := range snifferTypes.List {
if snifferType.String() == strings.ToUpper(sniffType) {
find = true
loadSniffer[snifferType] = SNIFF.SnifferConfig{
loadSniffer[snifferType] = sniffer.SnifferConfig{
Ports: ports,
OverrideDest: overrideDest,
}
@ -1652,7 +1645,7 @@ func parseSniffer(snifferRaw RawSniffer, ruleProviders map[string]providerTypes.
}
}
} else {
if sniffer.Enable && len(snifferRaw.Sniffing) != 0 {
if snifferConfig.Enable && len(snifferRaw.Sniffing) != 0 {
// Deprecated: Use Sniff instead
log.Warnln("Deprecated: Use Sniff instead")
}
@ -1666,7 +1659,7 @@ func parseSniffer(snifferRaw RawSniffer, ruleProviders map[string]providerTypes.
for _, snifferType := range snifferTypes.List {
if snifferType.String() == strings.ToUpper(snifferName) {
find = true
loadSniffer[snifferType] = SNIFF.SnifferConfig{
loadSniffer[snifferType] = sniffer.SnifferConfig{
Ports: globalPorts,
OverrideDest: snifferRaw.OverrideDest,
}
@ -1679,21 +1672,80 @@ func parseSniffer(snifferRaw RawSniffer, ruleProviders map[string]providerTypes.
}
}
sniffer.Sniffers = loadSniffer
snifferConfig.Sniffers = loadSniffer
forceDomain, err := parseDomain(snifferRaw.ForceDomain, nil, "sniffer.force-domain", ruleProviders)
if err != nil {
return nil, fmt.Errorf("error in force-domain, error:%w", err)
}
sniffer.ForceDomain = forceDomain
snifferConfig.ForceDomain = forceDomain
skipSrcAddress, err := parseIPCIDR(snifferRaw.SkipSrcAddress, nil, "sniffer.skip-src-address", ruleProviders)
if err != nil {
return nil, fmt.Errorf("error in skip-src-address, error:%w", err)
}
snifferConfig.SkipSrcAddress = skipSrcAddress
skipDstAddress, err := parseIPCIDR(snifferRaw.SkipDstAddress, nil, "sniffer.skip-src-address", ruleProviders)
if err != nil {
return nil, fmt.Errorf("error in skip-dst-address, error:%w", err)
}
snifferConfig.SkipDstAddress = skipDstAddress
skipDomain, err := parseDomain(snifferRaw.SkipDomain, nil, "sniffer.skip-domain", ruleProviders)
if err != nil {
return nil, fmt.Errorf("error in skip-domain, error:%w", err)
}
sniffer.SkipDomain = skipDomain
snifferConfig.SkipDomain = skipDomain
return sniffer, nil
return snifferConfig, nil
}
func parseIPCIDR(addresses []string, cidrSet *cidr.IpCidrSet, adapterName string, ruleProviders map[string]providerTypes.RuleProvider) (ipRules []C.Rule, err error) {
var rule C.Rule
for _, ipcidr := range addresses {
ipcidrLower := strings.ToLower(ipcidr)
if strings.Contains(ipcidrLower, "geoip:") {
subkeys := strings.Split(ipcidr, ":")
subkeys = subkeys[1:]
subkeys = strings.Split(subkeys[0], ",")
for _, country := range subkeys {
rule, err = RC.NewGEOIP(country, adapterName, false, false)
if err != nil {
return nil, err
}
ipRules = append(ipRules, rule)
}
} else if strings.Contains(ipcidrLower, "rule-set:") {
subkeys := strings.Split(ipcidr, ":")
subkeys = subkeys[1:]
subkeys = strings.Split(subkeys[0], ",")
for _, domainSetName := range subkeys {
rule, err = parseIPRuleSet(domainSetName, adapterName, ruleProviders)
if err != nil {
return nil, err
}
ipRules = append(ipRules, rule)
}
} else {
if cidrSet == nil {
cidrSet = cidr.NewIpCidrSet()
}
err = cidrSet.AddIpCidrForString(ipcidr)
if err != nil {
return nil, err
}
}
}
if !cidrSet.IsEmpty() {
err = cidrSet.Merge()
if err != nil {
return nil, err
}
rule = RP.NewIpCidrSet(cidrSet, adapterName)
ipRules = append(ipRules, rule)
}
return
}
func parseDomain(domains []string, domainTrie *trie.DomainTrie[struct{}], adapterName string, ruleProviders map[string]providerTypes.RuleProvider) (domainRules []C.Rule, err error) {
@ -1739,6 +1791,21 @@ func parseDomain(domains []string, domainTrie *trie.DomainTrie[struct{}], adapte
return
}
func parseIPRuleSet(domainSetName string, adapterName string, ruleProviders map[string]providerTypes.RuleProvider) (C.Rule, error) {
if rp, ok := ruleProviders[domainSetName]; !ok {
return nil, fmt.Errorf("not found rule-set: %s", domainSetName)
} else {
switch rp.Behavior() {
case providerTypes.Domain:
return nil, fmt.Errorf("rule provider type error, except ipcidr,actual %s", rp.Behavior())
case providerTypes.Classical:
log.Warnln("%s provider is %s, only matching it contain ip rule", rp.Name(), rp.Behavior())
default:
}
}
return RP.NewRuleSet(domainSetName, adapterName, true)
}
func parseDomainRuleSet(domainSetName string, adapterName string, ruleProviders map[string]providerTypes.RuleProvider) (C.Rule, error) {
if rp, ok := ruleProviders[domainSetName]; !ok {
return nil, fmt.Errorf("not found rule-set: %s", domainSetName)

View File

@ -190,6 +190,10 @@ sniffer:
override-destination: true
force-domain:
- +.v2ex.com
# skip-src-address: # 对于来源ip跳过嗅探
# - 192.168.0.3/32
# skip-dst-address: # 对于目标ip跳过嗅探
# - 192.168.0.3/32
## 对嗅探结果进行跳过
# skip-domain:
# - Mijia Cloud

View File

@ -21,7 +21,7 @@ import (
"github.com/metacubex/mihomo/component/profile"
"github.com/metacubex/mihomo/component/profile/cachefile"
"github.com/metacubex/mihomo/component/resolver"
SNI "github.com/metacubex/mihomo/component/sniffer"
"github.com/metacubex/mihomo/component/sniffer"
tlsC "github.com/metacubex/mihomo/component/tls"
"github.com/metacubex/mihomo/component/trie"
"github.com/metacubex/mihomo/component/updater"
@ -361,25 +361,17 @@ func hcCompatibleProvider(proxyProviders map[string]provider.ProxyProvider) {
}
func updateSniffer(sniffer *config.Sniffer) {
if sniffer.Enable {
dispatcher, err := SNI.NewSnifferDispatcher(
sniffer.Sniffers, sniffer.ForceDomain, sniffer.SkipDomain,
sniffer.ForceDnsMapping, sniffer.ParsePureIp,
)
func updateSniffer(snifferConfig *sniffer.Config) {
dispatcher, err := sniffer.NewDispatcher(snifferConfig)
if err != nil {
log.Warnln("initial sniffer failed, err:%v", err)
}
tunnel.UpdateSniffer(dispatcher)
if snifferConfig.Enable {
log.Infoln("Sniffer is loaded and working")
} else {
dispatcher, err := SNI.NewCloseSnifferDispatcher()
if err != nil {
log.Warnln("initial sniffer failed, err:%v", err)
}
tunnel.UpdateSniffer(dispatcher)
log.Infoln("Sniffer is closed")
}
}

View File

@ -39,7 +39,6 @@ var (
proxies = make(map[string]C.Proxy)
providers map[string]provider.ProxyProvider
ruleProviders map[string]provider.RuleProvider
sniffingEnable = false
configMux sync.RWMutex
// Outbound Rule
@ -52,6 +51,9 @@ var (
fakeIPRange netip.Prefix
snifferDispatcher *sniffer.Dispatcher
sniffingEnable = false
ruleUpdateCallback = utils.NewCallback[provider.RuleProvider]()
)
@ -115,7 +117,7 @@ func FakeIPRange() netip.Prefix {
}
func SetSniffing(b bool) {
if sniffer.Dispatcher.Enable() {
if snifferDispatcher.Enable() {
configMux.Lock()
sniffingEnable = b
configMux.Unlock()
@ -208,9 +210,9 @@ func UpdateListeners(newListeners map[string]C.InboundListener) {
listeners = newListeners
}
func UpdateSniffer(dispatcher *sniffer.SnifferDispatcher) {
func UpdateSniffer(dispatcher *sniffer.Dispatcher) {
configMux.Lock()
sniffer.Dispatcher = dispatcher
snifferDispatcher = dispatcher
sniffingEnable = dispatcher.Enable()
configMux.Unlock()
}
@ -347,8 +349,8 @@ func handleUDPConn(packet C.PacketAdapter) {
return
}
if sniffer.Dispatcher.Enable() && sniffingEnable {
sniffer.Dispatcher.UDPSniff(packet)
if sniffingEnable && snifferDispatcher.Enable() {
snifferDispatcher.UDPSniff(packet)
}
// local resolve UDP dns
@ -456,10 +458,10 @@ func handleTCPConn(connCtx C.ConnContext) {
conn := connCtx.Conn()
conn.ResetPeeked() // reset before sniffer
if sniffer.Dispatcher.Enable() && sniffingEnable {
if sniffingEnable && snifferDispatcher.Enable() {
// Try to sniff a domain when `preHandleMetadata` failed, this is usually
// caused by a "Fake DNS record missing" error when enhanced-mode is fake-ip.
if sniffer.Dispatcher.TCPSniff(conn, metadata) {
if snifferDispatcher.TCPSniff(conn, metadata) {
// we now have a domain name
preHandleFailed = false
}