From 75cd72385ab67b6d60388c790e8f0ff3b539cbc4 Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Thu, 11 May 2023 13:47:51 +0800 Subject: [PATCH] chore: decrease direct udp read memory used for no-windows platform --- adapter/outbound/base.go | 18 ++++++++-- adapter/outbound/direct.go | 8 +---- common/net/packet.go | 68 ++++++++++++++++++++++++++++++++++++ common/net/packet_posix.go | 64 +++++++++++++++++++++++++++++++++ common/net/packet_windows.go | 15 ++++++++ constant/adapters.go | 2 +- tunnel/connection.go | 10 +++--- tunnel/statistic/tracker.go | 10 ++++++ 8 files changed, 178 insertions(+), 17 deletions(-) create mode 100644 common/net/packet.go create mode 100644 common/net/packet_posix.go create mode 100644 common/net/packet_windows.go diff --git a/adapter/outbound/base.go b/adapter/outbound/base.go index 367638b8c..e4a553b92 100644 --- a/adapter/outbound/base.go +++ b/adapter/outbound/base.go @@ -220,7 +220,7 @@ func NewConn(c net.Conn, a C.ProxyAdapter) C.Conn { } type packetConn struct { - net.PacketConn + N.EnhancePacketConn chain C.Chain adapterName string connID string @@ -242,15 +242,27 @@ func (c *packetConn) AppendToChains(a C.ProxyAdapter) { } func (c *packetConn) LocalAddr() net.Addr { - lAddr := c.PacketConn.LocalAddr() + lAddr := c.EnhancePacketConn.LocalAddr() return N.NewCustomAddr(c.adapterName, c.connID, lAddr) // make quic-go's connMultiplexer happy } +func (c *packetConn) Upstream() any { + return c.EnhancePacketConn +} + +func (c *packetConn) WriterReplaceable() bool { + return true +} + +func (c *packetConn) ReaderReplaceable() bool { + return true +} + func newPacketConn(pc net.PacketConn, a C.ProxyAdapter) C.PacketConn { if _, ok := pc.(syscall.Conn); !ok { // exclusion system conn like *net.UDPConn pc = N.NewDeadlinePacketConn(pc) // most conn from outbound can't handle readDeadline correctly } - return &packetConn{pc, []string{a.Name()}, a.Name(), utils.NewUUIDV4().String(), parseRemoteDestination(a.Addr())} + return &packetConn{N.NewEnhancePacketConn(pc), []string{a.Name()}, a.Name(), utils.NewUUIDV4().String(), parseRemoteDestination(a.Addr())} } func parseRemoteDestination(addr string) string { diff --git a/adapter/outbound/direct.go b/adapter/outbound/direct.go index eae37d7aa..94b59cd0d 100644 --- a/adapter/outbound/direct.go +++ b/adapter/outbound/direct.go @@ -3,8 +3,6 @@ package outbound import ( "context" "errors" - "net" - "github.com/Dreamacro/clash/component/dialer" "github.com/Dreamacro/clash/component/resolver" C "github.com/Dreamacro/clash/constant" @@ -39,11 +37,7 @@ func (d *Direct) ListenPacketContext(ctx context.Context, metadata *C.Metadata, if err != nil { return nil, err } - return newPacketConn(&directPacketConn{pc}, d), nil -} - -type directPacketConn struct { - net.PacketConn + return newPacketConn(pc, d), nil } func NewDirect() *Direct { diff --git a/common/net/packet.go b/common/net/packet.go new file mode 100644 index 000000000..30f1104a7 --- /dev/null +++ b/common/net/packet.go @@ -0,0 +1,68 @@ +package net + +import ( + "net" + + "github.com/Dreamacro/clash/common/pool" +) + +type EnhancePacketConn interface { + net.PacketConn + WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) + Upstream() any +} + +func NewEnhancePacketConn(pc net.PacketConn) EnhancePacketConn { + if udpConn, isUDPConn := pc.(*net.UDPConn); isUDPConn { + return &enhanceUDPConn{UDPConn: udpConn} + } + return &enhancePacketConn{PacketConn: pc} +} + +type enhancePacketConn struct { + net.PacketConn +} + +func (c *enhancePacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) { + return waitReadFrom(c.PacketConn) +} + +func (c *enhancePacketConn) Upstream() any { + return c.PacketConn +} + +func (c *enhancePacketConn) WriterReplaceable() bool { + return true +} + +func (c *enhancePacketConn) ReaderReplaceable() bool { + return true +} + +func (c *enhanceUDPConn) Upstream() any { + return c.UDPConn +} + +func (c *enhanceUDPConn) WriterReplaceable() bool { + return true +} + +func (c *enhanceUDPConn) ReaderReplaceable() bool { + return true +} + +func waitReadFrom(pc net.PacketConn) (data []byte, put func(), addr net.Addr, err error) { + readBuf := pool.Get(pool.UDPBufferSize) + put = func() { + _ = pool.Put(readBuf) + } + var readN int + readN, addr, err = pc.ReadFrom(readBuf) + if readN > 0 { + data = readBuf[:readN] + } else { + put() + put = nil + } + return +} diff --git a/common/net/packet_posix.go b/common/net/packet_posix.go new file mode 100644 index 000000000..18c72a1c9 --- /dev/null +++ b/common/net/packet_posix.go @@ -0,0 +1,64 @@ +//go:build !windows + +package net + +import ( + "io" + "net" + "strconv" + "syscall" + + "github.com/Dreamacro/clash/common/pool" +) + +type enhanceUDPConn struct { + *net.UDPConn + rawConn syscall.RawConn +} + +func (c *enhanceUDPConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) { + if c.rawConn == nil { + c.rawConn, _ = c.UDPConn.SyscallConn() + } + var readErr error + err = c.rawConn.Read(func(fd uintptr) (done bool) { + readBuf := pool.Get(pool.UDPBufferSize) + put = func() { + _ = pool.Put(readBuf) + } + var readFrom syscall.Sockaddr + var readN int + readN, _, _, readFrom, readErr = syscall.Recvmsg(int(fd), readBuf, nil, 0) + if readN > 0 { + data = readBuf[:readN] + } else { + put() + put = nil + } + if readErr == syscall.EAGAIN { + return false + } + if readFrom != nil { + switch from := readFrom.(type) { + case *syscall.SockaddrInet4: + ip := from.Addr // copy from.Addr; ip escapes, so this line allocates 4 bytes + addr = &net.UDPAddr{IP: ip[:], Port: from.Port} + case *syscall.SockaddrInet6: + ip := from.Addr // copy from.Addr; ip escapes, so this line allocates 16 bytes + addr = &net.UDPAddr{IP: ip[:], Port: from.Port, Zone: strconv.FormatInt(int64(from.ZoneId), 10)} + } + } + if readN == 0 { + readErr = io.EOF + } + return true + }) + if err != nil { + return + } + if readErr != nil { + err = readErr + return + } + return +} diff --git a/common/net/packet_windows.go b/common/net/packet_windows.go new file mode 100644 index 000000000..a5bf75aaf --- /dev/null +++ b/common/net/packet_windows.go @@ -0,0 +1,15 @@ +//go:build windows + +package net + +import ( + "net" +) + +type enhanceUDPConn struct { + *net.UDPConn +} + +func (c *enhanceUDPConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) { + return waitReadFrom(c.UDPConn) +} diff --git a/constant/adapters.go b/constant/adapters.go index 2a2c68c1e..73877decb 100644 --- a/constant/adapters.go +++ b/constant/adapters.go @@ -81,7 +81,7 @@ type Conn interface { } type PacketConn interface { - net.PacketConn + N.EnhancePacketConn Connection // Deprecate WriteWithMetadata because of remote resolve DNS cause TURN failed // WriteWithMetadata(p []byte, metadata *Metadata) (n int, err error) diff --git a/tunnel/connection.go b/tunnel/connection.go index c64a52664..c95e33f25 100644 --- a/tunnel/connection.go +++ b/tunnel/connection.go @@ -7,7 +7,6 @@ import ( "time" N "github.com/Dreamacro/clash/common/net" - "github.com/Dreamacro/clash/common/pool" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/log" ) @@ -27,18 +26,16 @@ func handleUDPToRemote(packet C.UDPPacket, pc C.PacketConn, metadata *C.Metadata return nil } -func handleUDPToLocal(packet C.UDPPacket, pc net.PacketConn, key string, oAddr, fAddr netip.Addr) { - buf := pool.Get(pool.UDPBufferSize) +func handleUDPToLocal(packet C.UDPPacket, pc N.EnhancePacketConn, key string, oAddr, fAddr netip.Addr) { defer func() { _ = pc.Close() closeAllLocalCoon(key) natTable.Delete(key) - _ = pool.Put(buf) }() for { _ = pc.SetReadDeadline(time.Now().Add(udpTimeout)) - n, from, err := pc.ReadFrom(buf) + data, put, from, err := pc.WaitReadFrom() if err != nil { return } @@ -54,7 +51,8 @@ func handleUDPToLocal(packet C.UDPPacket, pc net.PacketConn, key string, oAddr, } } - _, err = packet.WriteBack(buf[:n], fromUDPAddr) + _, err = packet.WriteBack(data, fromUDPAddr) + put() if err != nil { return } diff --git a/tunnel/statistic/tracker.go b/tunnel/statistic/tracker.go index 170cbc993..685b5e903 100644 --- a/tunnel/statistic/tracker.go +++ b/tunnel/statistic/tracker.go @@ -186,6 +186,16 @@ func (ut *udpTracker) ReadFrom(b []byte) (int, net.Addr, error) { return n, addr, err } +func (ut *udpTracker) WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) { + data, put, addr, err = ut.PacketConn.WaitReadFrom() + download := int64(len(data)) + if ut.pushToManager { + ut.manager.PushDownloaded(download) + } + ut.DownloadTotal.Add(download) + return +} + func (ut *udpTracker) WriteTo(b []byte, addr net.Addr) (int, error) { n, err := ut.PacketConn.WriteTo(b, addr) upload := int64(n)