fix: wireguard not working in CMFA

This commit is contained in:
wwqgtxx 2024-08-13 13:33:24 +08:00
parent c17d7c0281
commit 5bf22422d9
3 changed files with 14 additions and 15 deletions

View File

@ -127,10 +127,6 @@ func GetTcpConcurrent() bool {
} }
func dialContext(ctx context.Context, network string, destination netip.Addr, port string, opt *option) (net.Conn, error) { func dialContext(ctx context.Context, network string, destination netip.Addr, port string, opt *option) (net.Conn, error) {
if features.CMFA && DefaultSocketHook != nil {
return dialContextHooked(ctx, network, destination, port)
}
var address string var address string
if IP4PEnable { if IP4PEnable {
destination, port = lookupIP4P(destination, port) destination, port = lookupIP4P(destination, port)
@ -149,6 +145,14 @@ func dialContext(ctx context.Context, network string, destination netip.Addr, po
} }
dialer := netDialer.(*net.Dialer) dialer := netDialer.(*net.Dialer)
if opt.mpTcp {
setMultiPathTCP(dialer)
}
if features.CMFA && DefaultSocketHook != nil { // ignore interfaceName, routingMark and tfo in CMFA
return dialContextHooked(ctx, dialer, network, address)
}
if opt.interfaceName != "" { if opt.interfaceName != "" {
bind := bindIfaceToDialer bind := bindIfaceToDialer
if opt.fallbackBind { if opt.fallbackBind {
@ -161,9 +165,6 @@ func dialContext(ctx context.Context, network string, destination netip.Addr, po
if opt.routingMark != 0 { if opt.routingMark != 0 {
bindMarkToDialer(opt.routingMark, dialer, network, destination) bindMarkToDialer(opt.routingMark, dialer, network, destination)
} }
if opt.mpTcp {
setMultiPathTCP(dialer)
}
if opt.tfo && !DisableTFO { if opt.tfo && !DisableTFO {
return dialTFO(ctx, *dialer, network, address) return dialTFO(ctx, *dialer, network, address)
} }

View File

@ -5,7 +5,6 @@ package dialer
import ( import (
"context" "context"
"net" "net"
"net/netip"
"syscall" "syscall"
) )
@ -13,12 +12,12 @@ type SocketControl func(network, address string, conn syscall.RawConn) error
var DefaultSocketHook SocketControl var DefaultSocketHook SocketControl
func dialContextHooked(ctx context.Context, network string, destination netip.Addr, port string) (net.Conn, error) { func dialContextHooked(ctx context.Context, dialer *net.Dialer, network string, address string) (net.Conn, error) {
dialer := &net.Dialer{ addControlToDialer(dialer, func(ctx context.Context, network, address string, c syscall.RawConn) error {
Control: DefaultSocketHook, return DefaultSocketHook(network, address, c)
} })
conn, err := dialer.DialContext(ctx, network, net.JoinHostPort(destination.String(), port)) conn, err := dialer.DialContext(ctx, network, address)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -5,7 +5,6 @@ package dialer
import ( import (
"context" "context"
"net" "net"
"net/netip"
"syscall" "syscall"
) )
@ -13,7 +12,7 @@ type SocketControl func(network, address string, conn syscall.RawConn) error
var DefaultSocketHook SocketControl var DefaultSocketHook SocketControl
func dialContextHooked(ctx context.Context, network string, destination netip.Addr, port string) (net.Conn, error) { func dialContextHooked(ctx context.Context, dialer *net.Dialer, network string, address string) (net.Conn, error) {
return nil, nil return nil, nil
} }