diff --git a/adapter/outbound/direct.go b/adapter/outbound/direct.go index b9b9fefce..7def7b208 100644 --- a/adapter/outbound/direct.go +++ b/adapter/outbound/direct.go @@ -3,6 +3,7 @@ package outbound import ( "context" "errors" + "fmt" "net/netip" N "github.com/metacubex/mihomo/common/net" @@ -13,6 +14,7 @@ import ( type Direct struct { *Base + loopBack *loopBackDetector } type DirectOption struct { @@ -22,17 +24,23 @@ type DirectOption struct { // DialContext implements C.ProxyAdapter func (d *Direct) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) { + if d.loopBack.CheckConn(metadata.SourceAddrPort()) { + return nil, fmt.Errorf("reject loopback connection to: %s", metadata.RemoteAddress()) + } opts = append(opts, dialer.WithResolver(resolver.DefaultResolver)) c, err := dialer.DialContext(ctx, "tcp", metadata.RemoteAddress(), d.Base.DialOptions(opts...)...) if err != nil { return nil, err } N.TCPKeepAlive(c) - return NewConn(c, d), nil + return d.loopBack.NewConn(NewConn(c, d)), nil } // ListenPacketContext implements C.ProxyAdapter func (d *Direct) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) { + if d.loopBack.CheckPacketConn(metadata.SourceAddrPort()) { + return nil, fmt.Errorf("reject loopback connection to: %s", metadata.RemoteAddress()) + } // net.UDPConn.WriteTo only working with *net.UDPAddr, so we need a net.UDPAddr if !metadata.Resolved() { ip, err := resolver.ResolveIPWithResolver(ctx, metadata.Host, resolver.DefaultResolver) @@ -45,7 +53,7 @@ func (d *Direct) ListenPacketContext(ctx context.Context, metadata *C.Metadata, if err != nil { return nil, err } - return newPacketConn(pc, d), nil + return d.loopBack.NewPacketConn(newPacketConn(pc, d)), nil } func NewDirectWithOption(option DirectOption) *Direct { @@ -60,6 +68,7 @@ func NewDirectWithOption(option DirectOption) *Direct { rmark: option.RoutingMark, prefer: C.NewDNSPrefer(option.IPVersion), }, + loopBack: newLoopBackDetector(), } } @@ -71,6 +80,7 @@ func NewDirect() *Direct { udp: true, prefer: C.DualStack, }, + loopBack: newLoopBackDetector(), } } @@ -82,5 +92,6 @@ func NewCompatible() *Direct { udp: true, prefer: C.DualStack, }, + loopBack: newLoopBackDetector(), } } diff --git a/adapter/outbound/direct_loopback_detect.go b/adapter/outbound/direct_loopback_detect.go new file mode 100644 index 000000000..410d5a2fc --- /dev/null +++ b/adapter/outbound/direct_loopback_detect.go @@ -0,0 +1,68 @@ +package outbound + +import ( + "net/netip" + + "github.com/metacubex/mihomo/common/callback" + C "github.com/metacubex/mihomo/constant" + + "github.com/puzpuzpuz/xsync/v3" +) + +type loopBackDetector struct { + connMap *xsync.MapOf[netip.AddrPort, struct{}] + packetConnMap *xsync.MapOf[netip.AddrPort, struct{}] +} + +func newLoopBackDetector() *loopBackDetector { + return &loopBackDetector{ + connMap: xsync.NewMapOf[netip.AddrPort, struct{}](), + packetConnMap: xsync.NewMapOf[netip.AddrPort, struct{}](), + } +} + +func (l *loopBackDetector) NewConn(conn C.Conn) C.Conn { + metadata := C.Metadata{} + if metadata.SetRemoteAddr(conn.LocalAddr()) != nil { + return conn + } + connAddr := metadata.AddrPort() + if !connAddr.IsValid() { + return conn + } + l.connMap.Store(connAddr, struct{}{}) + return callback.NewCloseCallbackConn(conn, func() { + l.connMap.Delete(connAddr) + }) +} + +func (l *loopBackDetector) NewPacketConn(conn C.PacketConn) C.PacketConn { + metadata := C.Metadata{} + if metadata.SetRemoteAddr(conn.LocalAddr()) != nil { + return conn + } + connAddr := metadata.AddrPort() + if !connAddr.IsValid() { + return conn + } + l.packetConnMap.Store(connAddr, struct{}{}) + return callback.NewCloseCallbackPacketConn(conn, func() { + l.packetConnMap.Delete(connAddr) + }) +} + +func (l *loopBackDetector) CheckConn(connAddr netip.AddrPort) bool { + if !connAddr.IsValid() { + return false + } + _, ok := l.connMap.Load(connAddr) + return ok +} + +func (l *loopBackDetector) CheckPacketConn(connAddr netip.AddrPort) bool { + if !connAddr.IsValid() { + return false + } + _, ok := l.packetConnMap.Load(connAddr) + return ok +} diff --git a/common/callback/close_callback.go b/common/callback/close_callback.go new file mode 100644 index 000000000..630ee5d7f --- /dev/null +++ b/common/callback/close_callback.go @@ -0,0 +1,61 @@ +package callback + +import ( + "sync" + + C "github.com/metacubex/mihomo/constant" +) + +type closeCallbackConn struct { + C.Conn + closeFunc func() + closeOnce sync.Once +} + +func (w *closeCallbackConn) Close() error { + w.closeOnce.Do(w.closeFunc) + return w.Conn.Close() +} + +func (w *closeCallbackConn) ReaderReplaceable() bool { + return true +} + +func (w *closeCallbackConn) WriterReplaceable() bool { + return true +} + +func (w *closeCallbackConn) Upstream() any { + return w.Conn +} + +func NewCloseCallbackConn(conn C.Conn, callback func()) C.Conn { + return &closeCallbackConn{Conn: conn, closeFunc: callback} +} + +type closeCallbackPacketConn struct { + C.PacketConn + closeFunc func() + closeOnce sync.Once +} + +func (w *closeCallbackPacketConn) Close() error { + w.closeOnce.Do(w.closeFunc) + return w.PacketConn.Close() +} + +func (w *closeCallbackPacketConn) ReaderReplaceable() bool { + return true +} + +func (w *closeCallbackPacketConn) WriterReplaceable() bool { + return true +} + +func (w *closeCallbackPacketConn) Upstream() any { + return w.PacketConn +} + +func NewCloseCallbackPacketConn(conn C.PacketConn, callback func()) C.PacketConn { + return &closeCallbackPacketConn{PacketConn: conn, closeFunc: callback} +} diff --git a/constant/metadata.go b/constant/metadata.go index 09a2f152e..3c7129093 100644 --- a/constant/metadata.go +++ b/constant/metadata.go @@ -148,8 +148,8 @@ type Metadata struct { SpecialRules string `json:"specialRules"` RemoteDst string `json:"remoteDestination"` - RawSrcAddr net.Addr `json:"-"` - RawDstAddr net.Addr `json:"-"` + RawSrcAddr net.Addr `json:"-"` + RawDstAddr net.Addr `json:"-"` // Only domain rule SniffHost string `json:"sniffHost"` } @@ -162,6 +162,10 @@ func (m *Metadata) SourceAddress() string { return net.JoinHostPort(m.SrcIP.String(), strconv.FormatUint(uint64(m.SrcPort), 10)) } +func (m *Metadata) SourceAddrPort() netip.AddrPort { + return netip.AddrPortFrom(m.SrcIP.Unmap(), m.SrcPort) +} + func (m *Metadata) SourceDetail() string { if m.Type == INNER { return fmt.Sprintf("%s", MihomoName)