chore: reduce the performance overhead of not enabling LoopBackDetector

This commit is contained in:
wwqgtxx 2024-11-05 09:29:01 +08:00
parent 69454b030e
commit d4478dbfa2
2 changed files with 30 additions and 14 deletions

View File

@ -3,18 +3,12 @@ package outbound
import ( import (
"context" "context"
"errors" "errors"
"os"
"strconv"
"github.com/metacubex/mihomo/component/dialer" "github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/loopback" "github.com/metacubex/mihomo/component/loopback"
"github.com/metacubex/mihomo/component/resolver" "github.com/metacubex/mihomo/component/resolver"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/constant/features"
) )
var DisableLoopBackDetector, _ = strconv.ParseBool(os.Getenv("DISABLE_LOOPBACK_DETECTOR"))
type Direct struct { type Direct struct {
*Base *Base
loopBack *loopback.Detector loopBack *loopback.Detector
@ -27,10 +21,8 @@ type DirectOption struct {
// DialContext implements C.ProxyAdapter // DialContext implements C.ProxyAdapter
func (d *Direct) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) { func (d *Direct) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) {
if !features.CMFA && !DisableLoopBackDetector { if err := d.loopBack.CheckConn(metadata); err != nil {
if err := d.loopBack.CheckConn(metadata); err != nil { return nil, err
return nil, err
}
} }
opts = append(opts, dialer.WithResolver(resolver.DirectHostResolver)) opts = append(opts, dialer.WithResolver(resolver.DirectHostResolver))
c, err := dialer.DialContext(ctx, "tcp", metadata.RemoteAddress(), d.Base.DialOptions(opts...)...) c, err := dialer.DialContext(ctx, "tcp", metadata.RemoteAddress(), d.Base.DialOptions(opts...)...)
@ -42,10 +34,8 @@ func (d *Direct) DialContext(ctx context.Context, metadata *C.Metadata, opts ...
// ListenPacketContext implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (d *Direct) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) { func (d *Direct) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) {
if !features.CMFA && !DisableLoopBackDetector { if err := d.loopBack.CheckPacketConn(metadata); err != nil {
if err := d.loopBack.CheckPacketConn(metadata); err != nil { return nil, err
return nil, err
}
} }
// net.UDPConn.WriteTo only working with *net.UDPAddr, so we need a net.UDPAddr // net.UDPConn.WriteTo only working with *net.UDPAddr, so we need a net.UDPAddr
if !metadata.Resolved() { if !metadata.Resolved() {

View File

@ -4,14 +4,25 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/netip" "net/netip"
"os"
"strconv"
"github.com/metacubex/mihomo/common/callback" "github.com/metacubex/mihomo/common/callback"
"github.com/metacubex/mihomo/component/iface" "github.com/metacubex/mihomo/component/iface"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/constant/features"
"github.com/puzpuzpuz/xsync/v3" "github.com/puzpuzpuz/xsync/v3"
) )
var disableLoopBackDetector, _ = strconv.ParseBool(os.Getenv("DISABLE_LOOPBACK_DETECTOR"))
func init() {
if features.CMFA {
disableLoopBackDetector = true
}
}
var ErrReject = errors.New("reject loopback connection") var ErrReject = errors.New("reject loopback connection")
type Detector struct { type Detector struct {
@ -20,6 +31,9 @@ type Detector struct {
} }
func NewDetector() *Detector { func NewDetector() *Detector {
if disableLoopBackDetector {
return nil
}
return &Detector{ return &Detector{
connMap: xsync.NewMapOf[netip.AddrPort, struct{}](), connMap: xsync.NewMapOf[netip.AddrPort, struct{}](),
packetConnMap: xsync.NewMapOf[uint16, struct{}](), packetConnMap: xsync.NewMapOf[uint16, struct{}](),
@ -27,6 +41,9 @@ func NewDetector() *Detector {
} }
func (l *Detector) NewConn(conn C.Conn) C.Conn { func (l *Detector) NewConn(conn C.Conn) C.Conn {
if l == nil || l.connMap == nil {
return conn
}
metadata := C.Metadata{} metadata := C.Metadata{}
if metadata.SetRemoteAddr(conn.LocalAddr()) != nil { if metadata.SetRemoteAddr(conn.LocalAddr()) != nil {
return conn return conn
@ -42,6 +59,9 @@ func (l *Detector) NewConn(conn C.Conn) C.Conn {
} }
func (l *Detector) NewPacketConn(conn C.PacketConn) C.PacketConn { func (l *Detector) NewPacketConn(conn C.PacketConn) C.PacketConn {
if l == nil || l.packetConnMap == nil {
return conn
}
metadata := C.Metadata{} metadata := C.Metadata{}
if metadata.SetRemoteAddr(conn.LocalAddr()) != nil { if metadata.SetRemoteAddr(conn.LocalAddr()) != nil {
return conn return conn
@ -58,6 +78,9 @@ func (l *Detector) NewPacketConn(conn C.PacketConn) C.PacketConn {
} }
func (l *Detector) CheckConn(metadata *C.Metadata) error { func (l *Detector) CheckConn(metadata *C.Metadata) error {
if l == nil || l.connMap == nil {
return nil
}
connAddr := metadata.SourceAddrPort() connAddr := metadata.SourceAddrPort()
if !connAddr.IsValid() { if !connAddr.IsValid() {
return nil return nil
@ -69,6 +92,9 @@ func (l *Detector) CheckConn(metadata *C.Metadata) error {
} }
func (l *Detector) CheckPacketConn(metadata *C.Metadata) error { func (l *Detector) CheckPacketConn(metadata *C.Metadata) error {
if l == nil || l.packetConnMap == nil {
return nil
}
connAddr := metadata.SourceAddrPort() connAddr := metadata.SourceAddrPort()
if !connAddr.IsValid() { if !connAddr.IsValid() {
return nil return nil