From 0c384b1e4225709d1dcec61db37dc8af0388bfa3 Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Wed, 7 Feb 2024 21:07:41 +0800 Subject: [PATCH] fix: tproxy start error --- listener/tproxy/setsockopt_linux.go | 18 +++++++++++++----- listener/tproxy/udp_linux.go | 20 +++++++++++++++++--- 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/listener/tproxy/setsockopt_linux.go b/listener/tproxy/setsockopt_linux.go index b83b28a40..9189f1152 100644 --- a/listener/tproxy/setsockopt_linux.go +++ b/listener/tproxy/setsockopt_linux.go @@ -36,13 +36,21 @@ func setsockopt(rc syscall.RawConn, addr string) error { } if err == nil { - err = syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_RECVTOS, 1) - } - - if err == nil { - err = syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, syscall.IPV6_RECVTCLASS, 1) + _ = setDSCPsockopt(fd, isIPv6) } }) return err } + +func setDSCPsockopt(fd uintptr, isIPv6 bool) (err error) { + if err == nil { + err = syscall.SetsockoptInt(int(fd), syscall.SOL_IP, syscall.IP_RECVTOS, 1) + } + + if err == nil && isIPv6 { + err = syscall.SetsockoptInt(int(fd), syscall.SOL_IPV6, syscall.IPV6_RECVTCLASS, 1) + } + + return +} diff --git a/listener/tproxy/udp_linux.go b/listener/tproxy/udp_linux.go index 02b513793..c96d4cc73 100644 --- a/listener/tproxy/udp_linux.go +++ b/listener/tproxy/udp_linux.go @@ -104,7 +104,14 @@ func getOrigDst(oob []byte) (netip.AddrPort, error) { } // retrieve the destination address from the SCM. - sa, err := unix.ParseOrigDstAddr(&scms[1]) + var sa unix.Sockaddr + for i := range scms { + sa, err = unix.ParseOrigDstAddr(&scms[i]) + if err == nil { + break + } + } + if err != nil { return netip.AddrPort{}, fmt.Errorf("retrieve destination: %w", err) } @@ -123,12 +130,19 @@ func getOrigDst(oob []byte) (netip.AddrPort, error) { return rAddr, nil } -func getDSCP (oob []byte) (uint8, error) { +func getDSCP(oob []byte) (uint8, error) { scms, err := unix.ParseSocketControlMessage(oob) if err != nil { return 0, fmt.Errorf("parse control message: %w", err) } - dscp, err := parseDSCP(&scms[0]) + var dscp uint8 + for i := range scms { + dscp, err = parseDSCP(&scms[i]) + if err == nil { + break + } + } + if err != nil { return 0, fmt.Errorf("retrieve DSCP: %w", err) }