chore: ensures packets can be sent without blocking the tunnel

This commit is contained in:
wwqgtxx 2024-09-26 11:21:07 +08:00
parent 5812a7bdeb
commit 4fa15c6334
4 changed files with 218 additions and 147 deletions

View File

@ -10,47 +10,30 @@ import (
)
type Table struct {
mapping *xsync.MapOf[string, *Entry]
lockMap *xsync.MapOf[string, *sync.Cond]
mapping *xsync.MapOf[string, *entry]
}
type Entry struct {
PacketConn C.PacketConn
WriteBackProxy C.WriteBackProxy
type entry struct {
PacketSender C.PacketSender
LocalUDPConnMap *xsync.MapOf[string, *net.UDPConn]
LocalLockMap *xsync.MapOf[string, *sync.Cond]
}
func (t *Table) Set(key string, e C.PacketConn, w C.WriteBackProxy) {
t.mapping.Store(key, &Entry{
PacketConn: e,
WriteBackProxy: w,
LocalUDPConnMap: xsync.NewMapOf[string, *net.UDPConn](),
LocalLockMap: xsync.NewMapOf[string, *sync.Cond](),
func (t *Table) GetOrCreate(key string, maker func() C.PacketSender) (C.PacketSender, bool) {
item, loaded := t.mapping.LoadOrCompute(key, func() *entry {
return &entry{
PacketSender: maker(),
LocalUDPConnMap: xsync.NewMapOf[string, *net.UDPConn](),
LocalLockMap: xsync.NewMapOf[string, *sync.Cond](),
}
})
}
func (t *Table) Get(key string) (C.PacketConn, C.WriteBackProxy) {
entry, exist := t.getEntry(key)
if !exist {
return nil, nil
}
return entry.PacketConn, entry.WriteBackProxy
}
func (t *Table) GetOrCreateLock(key string) (*sync.Cond, bool) {
item, loaded := t.lockMap.LoadOrCompute(key, makeLock)
return item, loaded
return item.PacketSender, loaded
}
func (t *Table) Delete(key string) {
t.mapping.Delete(key)
}
func (t *Table) DeleteLock(lockKey string) {
t.lockMap.Delete(lockKey)
}
func (t *Table) GetForLocalConn(lAddr, rAddr string) *net.UDPConn {
entry, exist := t.getEntry(lAddr)
if !exist {
@ -105,7 +88,7 @@ func (t *Table) DeleteLockForLocalConn(lAddr, key string) {
entry.LocalLockMap.Delete(key)
}
func (t *Table) getEntry(key string) (*Entry, bool) {
func (t *Table) getEntry(key string) (*entry, bool) {
return t.mapping.Load(key)
}
@ -116,7 +99,6 @@ func makeLock() *sync.Cond {
// New return *Cache
func New() *Table {
return &Table{
mapping: xsync.NewMapOf[string, *Entry](),
lockMap: xsync.NewMapOf[string, *sync.Cond](),
mapping: xsync.NewMapOf[string, *entry](),
}
}

View File

@ -255,12 +255,16 @@ type UDPPacketInAddr interface {
// PacketAdapter is a UDP Packet adapter for socks/redir/tun
type PacketAdapter interface {
UDPPacket
// Metadata returns destination metadata
Metadata() *Metadata
// Key is a SNAT key
Key() string
}
type packetAdapter struct {
UDPPacket
metadata *Metadata
key string
}
// Metadata returns destination metadata
@ -268,10 +272,16 @@ func (s *packetAdapter) Metadata() *Metadata {
return s.metadata
}
// Key is a SNAT key
func (s *packetAdapter) Key() string {
return s.key
}
func NewPacketAdapter(packet UDPPacket, metadata *Metadata) PacketAdapter {
return &packetAdapter{
packet,
metadata,
packet.LocalAddr().String(),
}
}
@ -284,17 +294,19 @@ type WriteBackProxy interface {
UpdateWriteBack(wb WriteBack)
}
type PacketSender interface {
// Send will send PacketAdapter nonblocking
// the implement must call UDPPacket.Drop() inside Send
Send(PacketAdapter)
Process(PacketConn, WriteBackProxy)
Close()
}
type NatTable interface {
Set(key string, e PacketConn, w WriteBackProxy)
Get(key string) (PacketConn, WriteBackProxy)
GetOrCreateLock(key string) (*sync.Cond, bool)
GetOrCreate(key string, maker func() PacketSender) (PacketSender, bool)
Delete(key string)
DeleteLock(key string)
GetForLocalConn(lAddr, rAddr string) *net.UDPConn
AddForLocalConn(lAddr, rAddr string, conn *net.UDPConn) bool

View File

@ -1,6 +1,7 @@
package tunnel
import (
"context"
"errors"
"net"
"net/netip"
@ -11,7 +12,78 @@ import (
"github.com/metacubex/mihomo/log"
)
type packetSender struct {
ctx context.Context
cancel context.CancelFunc
ch chan C.PacketAdapter
}
// newPacketSender return a chan based C.PacketSender
// It ensures that packets can be sent sequentially and without blocking
func newPacketSender() C.PacketSender {
ctx, cancel := context.WithCancel(context.Background())
ch := make(chan C.PacketAdapter, senderCapacity)
return &packetSender{
ctx: ctx,
cancel: cancel,
ch: ch,
}
}
func (s *packetSender) Process(pc C.PacketConn, proxy C.WriteBackProxy) {
for {
select {
case <-s.ctx.Done():
return // sender closed
case packet := <-s.ch:
if proxy != nil {
proxy.UpdateWriteBack(packet)
}
_ = handleUDPToRemote(packet, pc, packet.Metadata())
packet.Drop()
}
}
}
func (s *packetSender) dropAll() {
for {
select {
case data := <-s.ch:
data.Drop() // drop all data still in chan
default:
return // no data, exit goroutine
}
}
}
func (s *packetSender) Send(packet C.PacketAdapter) {
select {
case <-s.ctx.Done():
packet.Drop() // sender closed before Send()
return
default:
}
select {
case s.ch <- packet:
// put ok, so don't drop packet, will process by other side of chan
case <-s.ctx.Done():
packet.Drop() // sender closed when putting data to chan
default:
packet.Drop() // chan is full
}
}
func (s *packetSender) Close() {
s.cancel()
s.dropAll()
}
func handleUDPToRemote(packet C.UDPPacket, pc C.PacketConn, metadata *C.Metadata) error {
if err := resolveUDP(metadata); err != nil {
return err
}
addr := metadata.UDPAddr()
if addr == nil {
return errors.New("udp addr invalid")
@ -26,8 +98,9 @@ func handleUDPToRemote(packet C.UDPPacket, pc C.PacketConn, metadata *C.Metadata
return nil
}
func handleUDPToLocal(writeBack C.WriteBack, pc N.EnhancePacketConn, key string, oAddrPort netip.AddrPort, fAddr netip.Addr) {
func handleUDPToLocal(writeBack C.WriteBack, pc N.EnhancePacketConn, sender C.PacketSender, key string, oAddrPort netip.AddrPort, fAddr netip.Addr) {
defer func() {
sender.Close()
_ = pc.Close()
closeAllLocalCoon(key)
natTable.Delete(key)

View File

@ -28,11 +28,14 @@ import (
"github.com/metacubex/mihomo/tunnel/statistic"
)
const queueSize = 200
const (
queueCapacity = 64 // chan capacity tcpQueue and udpQueue
senderCapacity = 128 // chan capacity of PacketSender
)
var (
status = newAtomicStatus(Suspend)
tcpQueue = make(chan C.ConnContext, queueSize)
udpInit sync.Once
udpQueues []chan C.PacketAdapter
natTable = nat.New()
rules []C.Rule
@ -43,6 +46,12 @@ var (
ruleProviders map[string]provider.RuleProvider
configMux sync.RWMutex
// for compatibility, lazy init
tcpQueue chan C.ConnContext
tcpInOnce sync.Once
udpQueue chan C.PacketAdapter
udpInOnce sync.Once
// Outbound Rule
mode = Rule
@ -70,15 +79,33 @@ func (t tunnel) HandleTCPConn(conn net.Conn, metadata *C.Metadata) {
handleTCPConn(connCtx)
}
func (t tunnel) HandleUDPPacket(packet C.UDPPacket, metadata *C.Metadata) {
packetAdapter := C.NewPacketAdapter(packet, metadata)
func initUDP() {
numUDPWorkers := 4
if num := runtime.GOMAXPROCS(0); num > numUDPWorkers {
numUDPWorkers = num
}
hash := utils.MapHash(metadata.SourceAddress() + "-" + metadata.RemoteAddress())
udpQueues = make([]chan C.PacketAdapter, numUDPWorkers)
for i := 0; i < numUDPWorkers; i++ {
queue := make(chan C.PacketAdapter, queueCapacity)
udpQueues[i] = queue
go processUDP(queue)
}
}
func (t tunnel) HandleUDPPacket(packet C.UDPPacket, metadata *C.Metadata) {
udpInit.Do(initUDP)
packetAdapter := C.NewPacketAdapter(packet, metadata)
key := packetAdapter.Key()
hash := utils.MapHash(key)
queueNo := uint(hash) % uint(len(udpQueues))
select {
case udpQueues[queueNo] <- packetAdapter:
default:
packet.Drop()
}
}
@ -134,21 +161,32 @@ func IsSniffing() bool {
return sniffingEnable
}
func init() {
go process()
}
// TCPIn return fan-in queue
// Deprecated: using Tunnel instead
func TCPIn() chan<- C.ConnContext {
tcpInOnce.Do(func() {
tcpQueue = make(chan C.ConnContext, queueCapacity)
go func() {
for connCtx := range tcpQueue {
go handleTCPConn(connCtx)
}
}()
})
return tcpQueue
}
// UDPIn return fan-in udp queue
// Deprecated: using Tunnel instead
func UDPIn() chan<- C.PacketAdapter {
// compatibility: first queue is always available for external callers
return udpQueues[0]
udpInOnce.Do(func() {
udpQueue = make(chan C.PacketAdapter, queueCapacity)
go func() {
for packet := range udpQueue {
Tunnel.HandleUDPPacket(packet, packet.Metadata())
}
}()
})
return udpQueue
}
// NatTable return nat table
@ -249,32 +287,6 @@ func isHandle(t C.Type) bool {
return status == Running || (status == Inner && t == C.INNER)
}
// processUDP starts a loop to handle udp packet
func processUDP(queue chan C.PacketAdapter) {
for conn := range queue {
handleUDPConn(conn)
}
}
func process() {
numUDPWorkers := 4
if num := runtime.GOMAXPROCS(0); num > numUDPWorkers {
numUDPWorkers = num
}
udpQueues = make([]chan C.PacketAdapter, numUDPWorkers)
for i := 0; i < numUDPWorkers; i++ {
queue := make(chan C.PacketAdapter, queueSize)
udpQueues[i] = queue
go processUDP(queue)
}
queue := tcpQueue
for conn := range queue {
go handleTCPConn(conn)
}
}
func needLookupIP(metadata *C.Metadata) bool {
return resolver.MappingEnabled() && metadata.Host == "" && metadata.DstIP.IsValid()
}
@ -334,6 +346,25 @@ func resolveMetadata(metadata *C.Metadata) (proxy C.Proxy, rule C.Rule, err erro
return
}
func resolveUDP(metadata *C.Metadata) error {
// local resolve UDP dns
if !metadata.Resolved() {
ip, err := resolver.ResolveIP(context.Background(), metadata.Host)
if err != nil {
return err
}
metadata.DstIP = ip
}
return nil
}
// processUDP starts a loop to handle udp packet
func processUDP(queue chan C.PacketAdapter) {
for conn := range queue {
handleUDPConn(conn)
}
}
func handleUDPConn(packet C.PacketAdapter) {
if !isHandle(packet.Metadata().Type) {
packet.Drop()
@ -363,85 +394,58 @@ func handleUDPConn(packet C.PacketAdapter) {
snifferDispatcher.UDPSniff(packet)
}
// local resolve UDP dns
if !metadata.Resolved() {
ip, err := resolver.ResolveIP(context.Background(), metadata.Host)
if err != nil {
return
}
metadata.DstIP = ip
}
key := packet.LocalAddr().String()
handle := func() bool {
pc, proxy := natTable.Get(key)
if pc != nil {
if proxy != nil {
proxy.UpdateWriteBack(packet)
key := packet.Key()
sender, loaded := natTable.GetOrCreate(key, newPacketSender)
if !loaded {
dial := func() (C.PacketConn, C.WriteBackProxy, error) {
if err := resolveUDP(metadata); err != nil {
log.Warnln("[UDP] Resolve Ip error: %s", err)
return nil, nil, err
}
_ = handleUDPToRemote(packet, pc, metadata)
return true
}
return false
}
if handle() {
packet.Drop()
return
}
proxy, rule, err := resolveMetadata(metadata)
if err != nil {
log.Warnln("[UDP] Parse metadata failed: %s", err.Error())
return nil, nil, err
}
cond, loaded := natTable.GetOrCreateLock(key)
ctx, cancel := context.WithTimeout(context.Background(), C.DefaultUDPTimeout)
defer cancel()
rawPc, err := retry(ctx, func(ctx context.Context) (C.PacketConn, error) {
return proxy.ListenPacketContext(ctx, metadata.Pure())
}, func(err error) {
logMetadataErr(metadata, rule, proxy, err)
})
if err != nil {
return nil, nil, err
}
logMetadata(metadata, rule, rawPc)
go func() {
defer packet.Drop()
pc := statistic.NewUDPTracker(rawPc, statistic.DefaultManager, metadata, rule, 0, 0, true)
if loaded {
cond.L.Lock()
cond.Wait()
handle()
cond.L.Unlock()
return
if rawPc.Chains().Last() == "REJECT-DROP" {
_ = pc.Close()
return nil, nil, errors.New("rejected drop packet")
}
oAddrPort := metadata.AddrPort()
writeBackProxy := nat.NewWriteBackProxy(packet)
go handleUDPToLocal(writeBackProxy, pc, sender, key, oAddrPort, fAddr)
return pc, writeBackProxy, nil
}
defer func() {
natTable.DeleteLock(key)
cond.Broadcast()
go func() {
pc, proxy, err := dial()
if err != nil {
sender.Close()
natTable.Delete(key)
return
}
sender.Process(pc, proxy)
}()
proxy, rule, err := resolveMetadata(metadata)
if err != nil {
log.Warnln("[UDP] Parse metadata failed: %s", err.Error())
return
}
ctx, cancel := context.WithTimeout(context.Background(), C.DefaultUDPTimeout)
defer cancel()
rawPc, err := retry(ctx, func(ctx context.Context) (C.PacketConn, error) {
return proxy.ListenPacketContext(ctx, metadata.Pure())
}, func(err error) {
logMetadataErr(metadata, rule, proxy, err)
})
if err != nil {
return
}
logMetadata(metadata, rule, rawPc)
pc := statistic.NewUDPTracker(rawPc, statistic.DefaultManager, metadata, rule, 0, 0, true)
if rawPc.Chains().Last() == "REJECT-DROP" {
pc.Close()
return
}
oAddrPort := metadata.AddrPort()
writeBackProxy := nat.NewWriteBackProxy(packet)
natTable.Set(key, pc, writeBackProxy)
go handleUDPToLocal(writeBackProxy, pc, key, oAddrPort, fAddr)
handle()
}()
}
sender.Send(packet) // nonblocking
}
func handleTCPConn(connCtx C.ConnContext) {