From 5bf22422d9653933a6dc481c6f1021f743b9772f Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Tue, 13 Aug 2024 13:33:24 +0800 Subject: [PATCH] fix: wireguard not working in CMFA --- component/dialer/dialer.go | 15 ++++++++------- component/dialer/patch_android.go | 11 +++++------ component/dialer/patch_common.go | 3 +-- 3 files changed, 14 insertions(+), 15 deletions(-) diff --git a/component/dialer/dialer.go b/component/dialer/dialer.go index 54a1aa6ac..ba95c31b8 100644 --- a/component/dialer/dialer.go +++ b/component/dialer/dialer.go @@ -127,10 +127,6 @@ func GetTcpConcurrent() bool { } func dialContext(ctx context.Context, network string, destination netip.Addr, port string, opt *option) (net.Conn, error) { - if features.CMFA && DefaultSocketHook != nil { - return dialContextHooked(ctx, network, destination, port) - } - var address string if IP4PEnable { destination, port = lookupIP4P(destination, port) @@ -149,6 +145,14 @@ func dialContext(ctx context.Context, network string, destination netip.Addr, po } dialer := netDialer.(*net.Dialer) + if opt.mpTcp { + setMultiPathTCP(dialer) + } + + if features.CMFA && DefaultSocketHook != nil { // ignore interfaceName, routingMark and tfo in CMFA + return dialContextHooked(ctx, dialer, network, address) + } + if opt.interfaceName != "" { bind := bindIfaceToDialer if opt.fallbackBind { @@ -161,9 +165,6 @@ func dialContext(ctx context.Context, network string, destination netip.Addr, po if opt.routingMark != 0 { bindMarkToDialer(opt.routingMark, dialer, network, destination) } - if opt.mpTcp { - setMultiPathTCP(dialer) - } if opt.tfo && !DisableTFO { return dialTFO(ctx, *dialer, network, address) } diff --git a/component/dialer/patch_android.go b/component/dialer/patch_android.go index 7c33a6c0c..079b9772a 100644 --- a/component/dialer/patch_android.go +++ b/component/dialer/patch_android.go @@ -5,7 +5,6 @@ package dialer import ( "context" "net" - "net/netip" "syscall" ) @@ -13,12 +12,12 @@ type SocketControl func(network, address string, conn syscall.RawConn) error var DefaultSocketHook SocketControl -func dialContextHooked(ctx context.Context, network string, destination netip.Addr, port string) (net.Conn, error) { - dialer := &net.Dialer{ - Control: DefaultSocketHook, - } +func dialContextHooked(ctx context.Context, dialer *net.Dialer, network string, address string) (net.Conn, error) { + addControlToDialer(dialer, func(ctx context.Context, network, address string, c syscall.RawConn) error { + return DefaultSocketHook(network, address, c) + }) - conn, err := dialer.DialContext(ctx, network, net.JoinHostPort(destination.String(), port)) + conn, err := dialer.DialContext(ctx, network, address) if err != nil { return nil, err } diff --git a/component/dialer/patch_common.go b/component/dialer/patch_common.go index bad0ef488..2c96fe60b 100644 --- a/component/dialer/patch_common.go +++ b/component/dialer/patch_common.go @@ -5,7 +5,6 @@ package dialer import ( "context" "net" - "net/netip" "syscall" ) @@ -13,7 +12,7 @@ type SocketControl func(network, address string, conn syscall.RawConn) error var DefaultSocketHook SocketControl -func dialContextHooked(ctx context.Context, network string, destination netip.Addr, port string) (net.Conn, error) { +func dialContextHooked(ctx context.Context, dialer *net.Dialer, network string, address string) (net.Conn, error) { return nil, nil }