From b047ca02942c8e6d154f400331725f0ee805b43d Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Sat, 20 May 2023 11:44:11 +0800 Subject: [PATCH] chore: packet deadline support CreateReadWaiter interface --- common/net/deadline/packet.go | 51 +++++----- common/net/deadline/packet_enhance.go | 56 ++++++----- common/net/deadline/packet_sing.go | 134 +++++++++++++++++++++----- 3 files changed, 171 insertions(+), 70 deletions(-) diff --git a/common/net/deadline/packet.go b/common/net/deadline/packet.go index f68aadaf9..bcf2db9d8 100644 --- a/common/net/deadline/packet.go +++ b/common/net/deadline/packet.go @@ -3,6 +3,7 @@ package deadline import ( "net" "os" + "runtime" "time" "github.com/Dreamacro/clash/common/atomic" @@ -13,8 +14,6 @@ type readResult struct { data []byte addr net.Addr err error - enhanceReadResult - singReadResult } type NetPacketConn struct { @@ -23,14 +22,14 @@ type NetPacketConn struct { pipeDeadline pipeDeadline disablePipe atomic.Bool inRead atomic.Bool - resultCh chan *readResult + resultCh chan any } func NewNetPacketConn(pc net.PacketConn) net.PacketConn { npc := &NetPacketConn{ PacketConn: pc, pipeDeadline: makePipeDeadline(), - resultCh: make(chan *readResult, 1), + resultCh: make(chan any, 1), } npc.resultCh <- nil if enhancePC, isEnhance := pc.(packet.EnhancePacketConn); isEnhance { @@ -65,20 +64,28 @@ func NewNetPacketConn(pc net.PacketConn) net.PacketConn { } func (c *NetPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - select { - case result := <-c.resultCh: - if result != nil { - n = copy(p, result.data) - addr = result.addr - err = result.err - c.resultCh <- nil // finish cache read - return - } else { - c.resultCh <- nil - break +FOR: + for { + select { + case result := <-c.resultCh: + if result != nil { + if result, ok := result.(*readResult); ok { + n = copy(p, result.data) + addr = result.addr + err = result.err + c.resultCh <- nil // finish cache read + return + } + c.resultCh <- result // another type of read + runtime.Gosched() // allowing other goroutines to run + continue FOR + } else { + c.resultCh <- nil + break FOR + } + case <-c.pipeDeadline.wait(): + return 0, nil, os.ErrDeadlineExceeded } - case <-c.pipeDeadline.wait(): - return 0, nil, os.ErrDeadlineExceeded } if c.disablePipe.Load() { @@ -100,11 +107,11 @@ func (c *NetPacketConn) pipeReadFrom(size int) { buffer := make([]byte, size) n, addr, err := c.PacketConn.ReadFrom(buffer) buffer = buffer[:n] - c.resultCh <- &readResult{ - data: buffer, - addr: addr, - err: err, - } + result := &readResult{} + result.data = buffer + result.addr = addr + result.err = err + c.resultCh <- result } func (c *NetPacketConn) SetReadDeadline(t time.Time) error { diff --git a/common/net/deadline/packet_enhance.go b/common/net/deadline/packet_enhance.go index 589e1447e..5b7d767f0 100644 --- a/common/net/deadline/packet_enhance.go +++ b/common/net/deadline/packet_enhance.go @@ -3,6 +3,7 @@ package deadline import ( "net" "os" + "runtime" "github.com/Dreamacro/clash/common/net/packet" ) @@ -19,7 +20,10 @@ func NewEnhancePacketConn(pc packet.EnhancePacketConn) packet.EnhancePacketConn } type enhanceReadResult struct { - put func() + data []byte + put func() + addr net.Addr + err error } type enhancePacketConn struct { @@ -28,21 +32,29 @@ type enhancePacketConn struct { } func (c *enhancePacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) { - select { - case result := <-c.netPacketConn.resultCh: - if result != nil { - data = result.data - put = result.put - addr = result.addr - err = result.err - c.netPacketConn.resultCh <- nil // finish cache read - return - } else { - c.netPacketConn.resultCh <- nil - break +FOR: + for { + select { + case result := <-c.netPacketConn.resultCh: + if result != nil { + if result, ok := result.(*enhanceReadResult); ok { + data = result.data + put = result.put + addr = result.addr + err = result.err + c.netPacketConn.resultCh <- nil // finish cache read + return + } + c.netPacketConn.resultCh <- result // another type of read + runtime.Gosched() // allowing other goroutines to run + continue FOR + } else { + c.netPacketConn.resultCh <- nil + break FOR + } + case <-c.netPacketConn.pipeDeadline.wait(): + return nil, nil, nil, os.ErrDeadlineExceeded } - case <-c.netPacketConn.pipeDeadline.wait(): - return nil, nil, nil, os.ErrDeadlineExceeded } if c.netPacketConn.disablePipe.Load() { @@ -62,12 +74,10 @@ func (c *enhancePacketConn) WaitReadFrom() (data []byte, put func(), addr net.Ad func (c *enhancePacketConn) pipeWaitReadFrom() { data, put, addr, err := c.enhancePacketConn.WaitReadFrom() - c.netPacketConn.resultCh <- &readResult{ - data: data, - enhanceReadResult: enhanceReadResult{ - put: put, - }, - addr: addr, - err: err, - } + result := &enhanceReadResult{} + result.data = data + result.put = put + result.addr = addr + result.err = err + c.netPacketConn.resultCh <- result } diff --git a/common/net/deadline/packet_sing.go b/common/net/deadline/packet_sing.go index f69022ab1..a3da34f40 100644 --- a/common/net/deadline/packet_sing.go +++ b/common/net/deadline/packet_sing.go @@ -2,10 +2,13 @@ package deadline import ( "os" + "runtime" "github.com/Dreamacro/clash/common/net/packet" "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" ) type SingPacketConn struct { @@ -33,6 +36,7 @@ var _ packet.EnhanceSingPacketConn = (*EnhanceSingPacketConn)(nil) type singReadResult struct { buffer *buf.Buffer destination M.Socksaddr + err error } type singPacketConn struct { @@ -41,26 +45,34 @@ type singPacketConn struct { } func (c *singPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { - select { - case result := <-c.netPacketConn.resultCh: - if result != nil { - destination = result.destination - err = result.err - buffer.Resize(result.buffer.Start(), 0) - n := copy(buffer.FreeBytes(), result.buffer.Bytes()) - buffer.Truncate(n) - result.buffer.Advance(n) - if result.buffer.IsEmpty() { - result.buffer.Release() +FOR: + for { + select { + case result := <-c.netPacketConn.resultCh: + if result != nil { + if result, ok := result.(*singReadResult); ok { + destination = result.destination + err = result.err + buffer.Resize(result.buffer.Start(), 0) + n := copy(buffer.FreeBytes(), result.buffer.Bytes()) + buffer.Truncate(n) + result.buffer.Advance(n) + if result.buffer.IsEmpty() { + result.buffer.Release() + } + c.netPacketConn.resultCh <- nil // finish cache read + return + } + c.netPacketConn.resultCh <- result // another type of read + runtime.Gosched() // allowing other goroutines to run + continue FOR + } else { + c.netPacketConn.resultCh <- nil + break FOR } - c.netPacketConn.resultCh <- nil // finish cache read - return - } else { - c.netPacketConn.resultCh <- nil - break + case <-c.netPacketConn.pipeDeadline.wait(): + return M.Socksaddr{}, os.ErrDeadlineExceeded } - case <-c.netPacketConn.pipeDeadline.wait(): - return M.Socksaddr{}, os.ErrDeadlineExceeded } if c.netPacketConn.disablePipe.Load() { @@ -82,15 +94,87 @@ func (c *singPacketConn) pipeReadPacket(bufLen int, bufStart int) { buffer := buf.NewSize(bufLen) buffer.Advance(bufStart) destination, err := c.singPacketConn.ReadPacket(buffer) - c.netPacketConn.resultCh <- &readResult{ - singReadResult: singReadResult{ - buffer: buffer, - destination: destination, - }, - err: err, - } + result := &singReadResult{} + result.destination = destination + result.err = err + c.netPacketConn.resultCh <- result } func (c *singPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { return c.singPacketConn.WritePacket(buffer, destination) } + +func (c *singPacketConn) CreateReadWaiter() (N.PacketReadWaiter, bool) { + prw, isReadWaiter := bufio.CreatePacketReadWaiter(c.singPacketConn) + if isReadWaiter { + return &singPacketReadWaiter{ + netPacketConn: c.netPacketConn, + packetReadWaiter: prw, + }, true + } + return nil, false +} + +var _ N.PacketReadWaiter = (*singPacketReadWaiter)(nil) + +type singPacketReadWaiter struct { + netPacketConn *NetPacketConn + packetReadWaiter N.PacketReadWaiter +} + +type singWaitReadResult singReadResult + +func (c *singPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) { + c.packetReadWaiter.InitializeReadWaiter(newBuffer) +} + +func (c *singPacketReadWaiter) WaitReadPacket() (destination M.Socksaddr, err error) { +FOR: + for { + select { + case result := <-c.netPacketConn.resultCh: + if result != nil { + if result, ok := result.(*singWaitReadResult); ok { + destination = result.destination + err = result.err + c.netPacketConn.resultCh <- nil // finish cache read + return + } + c.netPacketConn.resultCh <- result // another type of read + runtime.Gosched() // allowing other goroutines to run + continue FOR + } else { + c.netPacketConn.resultCh <- nil + break FOR + } + case <-c.netPacketConn.pipeDeadline.wait(): + return M.Socksaddr{}, os.ErrDeadlineExceeded + } + } + + if c.netPacketConn.disablePipe.Load() { + return c.packetReadWaiter.WaitReadPacket() + } else if c.netPacketConn.deadline.Load().IsZero() { + c.netPacketConn.inRead.Store(true) + defer c.netPacketConn.inRead.Store(false) + destination, err = c.packetReadWaiter.WaitReadPacket() + return + } + + <-c.netPacketConn.resultCh + go c.pipeWaitReadPacket() + + return c.WaitReadPacket() +} + +func (c *singPacketReadWaiter) pipeWaitReadPacket() { + destination, err := c.packetReadWaiter.WaitReadPacket() + result := &singWaitReadResult{} + result.destination = destination + result.err = err + c.netPacketConn.resultCh <- result +} + +func (c *singPacketReadWaiter) Upstream() any { + return c.packetReadWaiter +}