chore: packet deadline support CreateReadWaiter interface

This commit is contained in:
wwqgtxx 2023-05-20 11:44:11 +08:00
parent 2b1e69153b
commit b047ca0294
3 changed files with 171 additions and 70 deletions

View File

@ -3,6 +3,7 @@ package deadline
import ( import (
"net" "net"
"os" "os"
"runtime"
"time" "time"
"github.com/Dreamacro/clash/common/atomic" "github.com/Dreamacro/clash/common/atomic"
@ -13,8 +14,6 @@ type readResult struct {
data []byte data []byte
addr net.Addr addr net.Addr
err error err error
enhanceReadResult
singReadResult
} }
type NetPacketConn struct { type NetPacketConn struct {
@ -23,14 +22,14 @@ type NetPacketConn struct {
pipeDeadline pipeDeadline pipeDeadline pipeDeadline
disablePipe atomic.Bool disablePipe atomic.Bool
inRead atomic.Bool inRead atomic.Bool
resultCh chan *readResult resultCh chan any
} }
func NewNetPacketConn(pc net.PacketConn) net.PacketConn { func NewNetPacketConn(pc net.PacketConn) net.PacketConn {
npc := &NetPacketConn{ npc := &NetPacketConn{
PacketConn: pc, PacketConn: pc,
pipeDeadline: makePipeDeadline(), pipeDeadline: makePipeDeadline(),
resultCh: make(chan *readResult, 1), resultCh: make(chan any, 1),
} }
npc.resultCh <- nil npc.resultCh <- nil
if enhancePC, isEnhance := pc.(packet.EnhancePacketConn); isEnhance { if enhancePC, isEnhance := pc.(packet.EnhancePacketConn); isEnhance {
@ -65,20 +64,28 @@ func NewNetPacketConn(pc net.PacketConn) net.PacketConn {
} }
func (c *NetPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { func (c *NetPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
select { FOR:
case result := <-c.resultCh: for {
if result != nil { select {
n = copy(p, result.data) case result := <-c.resultCh:
addr = result.addr if result != nil {
err = result.err if result, ok := result.(*readResult); ok {
c.resultCh <- nil // finish cache read n = copy(p, result.data)
return addr = result.addr
} else { err = result.err
c.resultCh <- nil c.resultCh <- nil // finish cache read
break return
}
c.resultCh <- result // another type of read
runtime.Gosched() // allowing other goroutines to run
continue FOR
} else {
c.resultCh <- nil
break FOR
}
case <-c.pipeDeadline.wait():
return 0, nil, os.ErrDeadlineExceeded
} }
case <-c.pipeDeadline.wait():
return 0, nil, os.ErrDeadlineExceeded
} }
if c.disablePipe.Load() { if c.disablePipe.Load() {
@ -100,11 +107,11 @@ func (c *NetPacketConn) pipeReadFrom(size int) {
buffer := make([]byte, size) buffer := make([]byte, size)
n, addr, err := c.PacketConn.ReadFrom(buffer) n, addr, err := c.PacketConn.ReadFrom(buffer)
buffer = buffer[:n] buffer = buffer[:n]
c.resultCh <- &readResult{ result := &readResult{}
data: buffer, result.data = buffer
addr: addr, result.addr = addr
err: err, result.err = err
} c.resultCh <- result
} }
func (c *NetPacketConn) SetReadDeadline(t time.Time) error { func (c *NetPacketConn) SetReadDeadline(t time.Time) error {

View File

@ -3,6 +3,7 @@ package deadline
import ( import (
"net" "net"
"os" "os"
"runtime"
"github.com/Dreamacro/clash/common/net/packet" "github.com/Dreamacro/clash/common/net/packet"
) )
@ -19,7 +20,10 @@ func NewEnhancePacketConn(pc packet.EnhancePacketConn) packet.EnhancePacketConn
} }
type enhanceReadResult struct { type enhanceReadResult struct {
put func() data []byte
put func()
addr net.Addr
err error
} }
type enhancePacketConn struct { type enhancePacketConn struct {
@ -28,21 +32,29 @@ type enhancePacketConn struct {
} }
func (c *enhancePacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) { func (c *enhancePacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) {
select { FOR:
case result := <-c.netPacketConn.resultCh: for {
if result != nil { select {
data = result.data case result := <-c.netPacketConn.resultCh:
put = result.put if result != nil {
addr = result.addr if result, ok := result.(*enhanceReadResult); ok {
err = result.err data = result.data
c.netPacketConn.resultCh <- nil // finish cache read put = result.put
return addr = result.addr
} else { err = result.err
c.netPacketConn.resultCh <- nil c.netPacketConn.resultCh <- nil // finish cache read
break return
}
c.netPacketConn.resultCh <- result // another type of read
runtime.Gosched() // allowing other goroutines to run
continue FOR
} else {
c.netPacketConn.resultCh <- nil
break FOR
}
case <-c.netPacketConn.pipeDeadline.wait():
return nil, nil, nil, os.ErrDeadlineExceeded
} }
case <-c.netPacketConn.pipeDeadline.wait():
return nil, nil, nil, os.ErrDeadlineExceeded
} }
if c.netPacketConn.disablePipe.Load() { if c.netPacketConn.disablePipe.Load() {
@ -62,12 +74,10 @@ func (c *enhancePacketConn) WaitReadFrom() (data []byte, put func(), addr net.Ad
func (c *enhancePacketConn) pipeWaitReadFrom() { func (c *enhancePacketConn) pipeWaitReadFrom() {
data, put, addr, err := c.enhancePacketConn.WaitReadFrom() data, put, addr, err := c.enhancePacketConn.WaitReadFrom()
c.netPacketConn.resultCh <- &readResult{ result := &enhanceReadResult{}
data: data, result.data = data
enhanceReadResult: enhanceReadResult{ result.put = put
put: put, result.addr = addr
}, result.err = err
addr: addr, c.netPacketConn.resultCh <- result
err: err,
}
} }

View File

@ -2,10 +2,13 @@ package deadline
import ( import (
"os" "os"
"runtime"
"github.com/Dreamacro/clash/common/net/packet" "github.com/Dreamacro/clash/common/net/packet"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
) )
type SingPacketConn struct { type SingPacketConn struct {
@ -33,6 +36,7 @@ var _ packet.EnhanceSingPacketConn = (*EnhanceSingPacketConn)(nil)
type singReadResult struct { type singReadResult struct {
buffer *buf.Buffer buffer *buf.Buffer
destination M.Socksaddr destination M.Socksaddr
err error
} }
type singPacketConn struct { type singPacketConn struct {
@ -41,26 +45,34 @@ type singPacketConn struct {
} }
func (c *singPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { func (c *singPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
select { FOR:
case result := <-c.netPacketConn.resultCh: for {
if result != nil { select {
destination = result.destination case result := <-c.netPacketConn.resultCh:
err = result.err if result != nil {
buffer.Resize(result.buffer.Start(), 0) if result, ok := result.(*singReadResult); ok {
n := copy(buffer.FreeBytes(), result.buffer.Bytes()) destination = result.destination
buffer.Truncate(n) err = result.err
result.buffer.Advance(n) buffer.Resize(result.buffer.Start(), 0)
if result.buffer.IsEmpty() { n := copy(buffer.FreeBytes(), result.buffer.Bytes())
result.buffer.Release() buffer.Truncate(n)
result.buffer.Advance(n)
if result.buffer.IsEmpty() {
result.buffer.Release()
}
c.netPacketConn.resultCh <- nil // finish cache read
return
}
c.netPacketConn.resultCh <- result // another type of read
runtime.Gosched() // allowing other goroutines to run
continue FOR
} else {
c.netPacketConn.resultCh <- nil
break FOR
} }
c.netPacketConn.resultCh <- nil // finish cache read case <-c.netPacketConn.pipeDeadline.wait():
return return M.Socksaddr{}, os.ErrDeadlineExceeded
} else {
c.netPacketConn.resultCh <- nil
break
} }
case <-c.netPacketConn.pipeDeadline.wait():
return M.Socksaddr{}, os.ErrDeadlineExceeded
} }
if c.netPacketConn.disablePipe.Load() { if c.netPacketConn.disablePipe.Load() {
@ -82,15 +94,87 @@ func (c *singPacketConn) pipeReadPacket(bufLen int, bufStart int) {
buffer := buf.NewSize(bufLen) buffer := buf.NewSize(bufLen)
buffer.Advance(bufStart) buffer.Advance(bufStart)
destination, err := c.singPacketConn.ReadPacket(buffer) destination, err := c.singPacketConn.ReadPacket(buffer)
c.netPacketConn.resultCh <- &readResult{ result := &singReadResult{}
singReadResult: singReadResult{ result.destination = destination
buffer: buffer, result.err = err
destination: destination, c.netPacketConn.resultCh <- result
},
err: err,
}
} }
func (c *singPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { func (c *singPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
return c.singPacketConn.WritePacket(buffer, destination) return c.singPacketConn.WritePacket(buffer, destination)
} }
func (c *singPacketConn) CreateReadWaiter() (N.PacketReadWaiter, bool) {
prw, isReadWaiter := bufio.CreatePacketReadWaiter(c.singPacketConn)
if isReadWaiter {
return &singPacketReadWaiter{
netPacketConn: c.netPacketConn,
packetReadWaiter: prw,
}, true
}
return nil, false
}
var _ N.PacketReadWaiter = (*singPacketReadWaiter)(nil)
type singPacketReadWaiter struct {
netPacketConn *NetPacketConn
packetReadWaiter N.PacketReadWaiter
}
type singWaitReadResult singReadResult
func (c *singPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
c.packetReadWaiter.InitializeReadWaiter(newBuffer)
}
func (c *singPacketReadWaiter) WaitReadPacket() (destination M.Socksaddr, err error) {
FOR:
for {
select {
case result := <-c.netPacketConn.resultCh:
if result != nil {
if result, ok := result.(*singWaitReadResult); ok {
destination = result.destination
err = result.err
c.netPacketConn.resultCh <- nil // finish cache read
return
}
c.netPacketConn.resultCh <- result // another type of read
runtime.Gosched() // allowing other goroutines to run
continue FOR
} else {
c.netPacketConn.resultCh <- nil
break FOR
}
case <-c.netPacketConn.pipeDeadline.wait():
return M.Socksaddr{}, os.ErrDeadlineExceeded
}
}
if c.netPacketConn.disablePipe.Load() {
return c.packetReadWaiter.WaitReadPacket()
} else if c.netPacketConn.deadline.Load().IsZero() {
c.netPacketConn.inRead.Store(true)
defer c.netPacketConn.inRead.Store(false)
destination, err = c.packetReadWaiter.WaitReadPacket()
return
}
<-c.netPacketConn.resultCh
go c.pipeWaitReadPacket()
return c.WaitReadPacket()
}
func (c *singPacketReadWaiter) pipeWaitReadPacket() {
destination, err := c.packetReadWaiter.WaitReadPacket()
result := &singWaitReadResult{}
result.destination = destination
result.err = err
c.netPacketConn.resultCh <- result
}
func (c *singPacketReadWaiter) Upstream() any {
return c.packetReadWaiter
}