refactor: Implement extended IO

This commit is contained in:
H1JK 2023-01-16 09:42:03 +08:00
parent 8fa66c13a9
commit d1565bb46f
7 changed files with 219 additions and 39 deletions

View File

@ -4,12 +4,15 @@ import (
"context"
"encoding/json"
"errors"
"github.com/gofrs/uuid"
"net"
"strings"
"github.com/Dreamacro/clash/component/dialer"
C "github.com/Dreamacro/clash/constant"
"github.com/gofrs/uuid"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/network"
)
type Base struct {
@ -166,7 +169,7 @@ func NewBase(opt BaseOption) *Base {
}
type conn struct {
net.Conn
network.ExtendedConn
chain C.Chain
actualRemoteDestination string
}
@ -185,8 +188,15 @@ func (c *conn) AppendToChains(a C.ProxyAdapter) {
c.chain = append(c.chain, a.Name())
}
func (c *conn) Upstream() any {
if wrapper, ok := c.ExtendedConn.(*bufio.ExtendedConnWrapper); ok {
return wrapper.Conn
}
return c.ExtendedConn
}
func NewConn(c net.Conn, a C.ProxyAdapter) C.Conn {
return &conn{c, []string{a.Name()}, parseRemoteDestination(a.Addr())}
return &conn{bufio.NewExtendedConn(c), []string{a.Name()}, parseRemoteDestination(a.Addr())}
}
type packetConn struct {

View File

@ -14,6 +14,8 @@ import (
"github.com/Dreamacro/clash/transport/gun"
"github.com/Dreamacro/clash/transport/trojan"
"github.com/Dreamacro/clash/transport/vless"
"github.com/sagernet/sing/common/bufio"
)
type Trojan struct {
@ -95,7 +97,7 @@ func (t *Trojan) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error)
return c, err
}
err = t.instance.WriteHeader(c, trojan.CommandTCP, serializesSocksAddr(metadata))
return c, err
return bufio.NewExtendedConn(c), err
}
// DialContext implements C.ProxyAdapter

View File

@ -3,18 +3,24 @@ package net
import (
"bufio"
"net"
"github.com/sagernet/sing/common/buf"
sing_bufio "github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/network"
)
var _ network.ExtendedConn = (*BufferedConn)(nil)
type BufferedConn struct {
r *bufio.Reader
net.Conn
network.ExtendedConn
}
func NewBufferedConn(c net.Conn) *BufferedConn {
if bc, ok := c.(*BufferedConn); ok {
return bc
}
return &BufferedConn{bufio.NewReader(c), c}
return &BufferedConn{bufio.NewReader(c), sing_bufio.NewExtendedConn(c)}
}
// Reader returns the internal bufio.Reader.
@ -42,3 +48,18 @@ func (c *BufferedConn) UnreadByte() error {
func (c *BufferedConn) Buffered() int {
return c.r.Buffered()
}
func (c *BufferedConn) ReadBuffer(buffer *buf.Buffer) (err error) {
if c.r.Buffered() > 0 {
_, err = buffer.ReadOnceFrom(c.r)
return
}
return c.ExtendedConn.ReadBuffer(buffer)
}
func (c *BufferedConn) Upstream() any {
if wrapper, ok := c.ExtendedConn.(*sing_bufio.ExtendedConnWrapper); ok {
return wrapper.Conn
}
return c.ExtendedConn
}

View File

@ -1,7 +1,6 @@
package vless
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
@ -9,12 +8,16 @@ import (
"net"
"github.com/gofrs/uuid"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/network"
xtls "github.com/xtls/go"
"google.golang.org/protobuf/proto"
)
type Conn struct {
net.Conn
network.ExtendedConn
dst *DstAddr
id *uuid.UUID
addons *Addons
@ -23,57 +26,82 @@ type Conn struct {
func (vc *Conn) Read(b []byte) (int, error) {
if vc.received {
return vc.Conn.Read(b)
return vc.ExtendedConn.Read(b)
}
if err := vc.recvResponse(); err != nil {
return 0, err
}
vc.received = true
return vc.Conn.Read(b)
return vc.ExtendedConn.Read(b)
}
func (vc *Conn) sendRequest() error {
buf := &bytes.Buffer{}
func (vc *Conn) ReadBuffer(buffer *buf.Buffer) error {
if vc.received {
return vc.ExtendedConn.ReadBuffer(buffer)
}
buf.WriteByte(Version) // protocol version
buf.Write(vc.id.Bytes()) // 16 bytes of uuid
if err := vc.recvResponse(); err != nil {
return err
}
vc.received = true
return vc.ExtendedConn.ReadBuffer(buffer)
}
func (vc *Conn) sendRequest() (err error) {
requestLen := 1 // protocol version
requestLen += 16 // UUID
requestLen += 1 // addons length
var addonsBytes []byte
if vc.addons != nil {
bytes, err := proto.Marshal(vc.addons)
addonsBytes, err = proto.Marshal(vc.addons)
if err != nil {
return err
}
buf.WriteByte(byte(len(bytes)))
buf.Write(bytes)
} else {
buf.WriteByte(0) // addon data length. 0 means no addon data
}
requestLen += len(addonsBytes)
requestLen += 1 // command
if !vc.dst.Mux {
requestLen += 2 // port
requestLen += 1 // addr type
requestLen += len(vc.dst.Addr)
}
_buffer := buf.StackNewSize(requestLen)
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
common.Must(
buffer.WriteByte(Version), // protocol version
common.Error(buffer.Write(vc.id.Bytes())), // 16 bytes of uuid
buffer.WriteByte(byte(len(addonsBytes))),
common.Error(buffer.Write(addonsBytes)),
)
if vc.dst.Mux {
buf.WriteByte(CommandMux)
common.Must(buffer.WriteByte(CommandMux))
} else {
if vc.dst.UDP {
buf.WriteByte(CommandUDP)
common.Must(buffer.WriteByte(CommandUDP))
} else {
buf.WriteByte(CommandTCP)
common.Must(buffer.WriteByte(CommandTCP))
}
// Port AddrType Addr
binary.Write(buf, binary.BigEndian, vc.dst.Port)
buf.WriteByte(vc.dst.AddrType)
buf.Write(vc.dst.Addr)
binary.BigEndian.PutUint16(buffer.Extend(2), vc.dst.Port)
common.Must(
buffer.WriteByte(vc.dst.AddrType),
common.Error(buffer.Write(vc.dst.Addr)),
)
}
_, err := vc.Conn.Write(buf.Bytes())
return err
_, err = vc.ExtendedConn.Write(buffer.Bytes())
return
}
func (vc *Conn) recvResponse() error {
var err error
buf := make([]byte, 1)
_, err = io.ReadFull(vc.Conn, buf)
var buf [1]byte
_, err = io.ReadFull(vc.ExtendedConn, buf[:])
if err != nil {
return err
}
@ -82,25 +110,32 @@ func (vc *Conn) recvResponse() error {
return errors.New("unexpected response version")
}
_, err = io.ReadFull(vc.Conn, buf)
_, err = io.ReadFull(vc.ExtendedConn, buf[:])
if err != nil {
return err
}
length := int64(buf[0])
if length != 0 { // addon data length > 0
io.CopyN(io.Discard, vc.Conn, length) // just discard
io.CopyN(io.Discard, vc.ExtendedConn, length) // just discard
}
return nil
}
func (vc *Conn) Upstream() any {
if wrapper, ok := vc.ExtendedConn.(*bufio.ExtendedConnWrapper); ok {
return wrapper.Conn
}
return vc.ExtendedConn
}
// newConn return a Conn instance
func newConn(conn net.Conn, client *Client, dst *DstAddr) (*Conn, error) {
c := &Conn{
Conn: conn,
id: client.uuid,
dst: dst,
ExtendedConn: bufio.NewExtendedConn(conn),
id: client.uuid,
dst: dst,
}
if !dst.UDP && client.Addons != nil {

View File

@ -5,9 +5,11 @@ import (
"context"
"crypto/tls"
"encoding/base64"
"encoding/binary"
"errors"
"fmt"
"io"
"math/rand"
"net"
"net/http"
"net/url"
@ -15,15 +17,24 @@ import (
"strings"
"sync"
"time"
_ "unsafe"
"github.com/gorilla/websocket"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/network"
)
//go:linkname maskBytes github.com/gorilla/websocket.maskBytes
func maskBytes(key [4]byte, pos int, b []byte) int
type websocketConn struct {
conn *websocket.Conn
reader io.Reader
remoteAddr net.Addr
rawWriter network.ExtendedWriter
// https://godoc.org/github.com/gorilla/websocket#hdr-Concurrency
rMux sync.Mutex
wMux sync.Mutex
@ -31,6 +42,7 @@ type websocketConn struct {
type websocketWithEarlyDataConn struct {
net.Conn
wsWriter network.ExtendedWriter
underlay net.Conn
closed bool
dialed chan bool
@ -79,6 +91,54 @@ func (wsc *websocketConn) Write(b []byte) (int, error) {
return len(b), nil
}
func (wsc *websocketConn) WriteBuffer(buffer *buf.Buffer) error {
var payloadBitLength int
dataLen := buffer.Len()
data := buffer.Bytes()
if dataLen < 126 {
payloadBitLength = 1
} else if dataLen < 65536 {
payloadBitLength = 3
} else {
payloadBitLength = 9
}
var headerLen int
headerLen += 1 // FIN / RSV / OPCODE
headerLen += payloadBitLength
headerLen += 4 // MASK KEY
header := buffer.ExtendHeader(headerLen)
header[0] = websocket.BinaryMessage | 1<<7
header[1] = 1 << 7
if dataLen < 126 {
header[1] |= byte(dataLen)
} else if dataLen < 65536 {
header[1] |= 126
binary.BigEndian.PutUint16(header[2:], uint16(dataLen))
} else {
header[1] |= 127
binary.BigEndian.PutUint64(header[2:], uint64(dataLen))
}
maskKey := rand.Uint32()
binary.BigEndian.PutUint32(header[1+payloadBitLength:], maskKey)
maskBytes(*(*[4]byte)(header[1+payloadBitLength:]), 0, data)
wsc.wMux.Lock()
defer wsc.wMux.Unlock()
return wsc.rawWriter.WriteBuffer(buffer)
}
func (wsc *websocketConn) FrontHeadroom() int {
return 14
}
func (wsc *websocketConn) Upstream() any {
return wsc.conn.UnderlyingConn()
}
func (wsc *websocketConn) Close() error {
var errors []string
if err := wsc.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second*5)); err != nil {
@ -149,6 +209,7 @@ func (wsedc *websocketWithEarlyDataConn) Dial(earlyData []byte) error {
}
wsedc.dialed <- true
wsedc.wsWriter = bufio.NewExtendedWriter(wsedc.Conn)
if earlyDataBuf.Len() != 0 {
_, err = wsedc.Conn.Write(earlyDataBuf.Bytes())
}
@ -170,6 +231,20 @@ func (wsedc *websocketWithEarlyDataConn) Write(b []byte) (int, error) {
return wsedc.Conn.Write(b)
}
func (wsedc *websocketWithEarlyDataConn) WriteBuffer(buffer *buf.Buffer) error {
if wsedc.closed {
return io.ErrClosedPipe
}
if wsedc.Conn == nil {
if err := wsedc.Dial(buffer.Bytes()); err != nil {
return err
}
return nil
}
return wsedc.wsWriter.WriteBuffer(buffer)
}
func (wsedc *websocketWithEarlyDataConn) Read(b []byte) (int, error) {
if wsedc.closed {
return 0, io.ErrClosedPipe
@ -228,6 +303,10 @@ func (wsedc *websocketWithEarlyDataConn) SetWriteDeadline(t time.Time) error {
return wsedc.Conn.SetWriteDeadline(t)
}
func (wsedc *websocketWithEarlyDataConn) Upstream() any {
return wsedc.Conn
}
func streamWebsocketWithEarlyDataConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) {
ctx, cancel := context.WithCancel(context.Background())
conn = &websocketWithEarlyDataConn{
@ -294,6 +373,7 @@ func streamWebsocketConn(conn net.Conn, c *WebsocketConfig, earlyData *bytes.Buf
return &websocketConn{
conn: wsConn,
rawWriter: bufio.NewExtendedWriter(wsConn.UnderlyingConn()),
remoteAddr: conn.RemoteAddr(),
}, nil
}

View File

@ -1,14 +1,16 @@
package tunnel
import (
"context"
"errors"
"net"
"net/netip"
"time"
N "github.com/Dreamacro/clash/common/net"
"github.com/Dreamacro/clash/common/pool"
C "github.com/Dreamacro/clash/constant"
"github.com/sagernet/sing/common/bufio"
)
func handleUDPToRemote(packet C.UDPPacket, pc C.PacketConn, metadata *C.Metadata) error {
@ -60,5 +62,5 @@ func handleUDPToLocal(packet C.UDPPacket, pc net.PacketConn, key string, oAddr,
}
func handleSocket(ctx C.ConnContext, outbound net.Conn) {
N.Relay(ctx.Conn(), outbound)
bufio.CopyConn(context.TODO(), ctx.Conn(), outbound)
}

View File

@ -7,6 +7,9 @@ import (
C "github.com/Dreamacro/clash/constant"
"github.com/gofrs/uuid"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/network"
"go.uber.org/atomic"
)
@ -29,7 +32,9 @@ type trackerInfo struct {
type tcpTracker struct {
C.Conn `json:"-"`
*trackerInfo
manager *Manager
manager *Manager
extendedReader network.ExtendedReader
extendedWriter network.ExtendedWriter
}
func (tt *tcpTracker) ID() string {
@ -44,6 +49,14 @@ func (tt *tcpTracker) Read(b []byte) (int, error) {
return n, err
}
func (tt *tcpTracker) ReadBuffer(buffer *buf.Buffer) (err error) {
err = tt.extendedReader.ReadBuffer(buffer)
download := int64(buffer.Len())
tt.manager.PushDownloaded(download)
tt.DownloadTotal.Add(download)
return
}
func (tt *tcpTracker) Write(b []byte) (int, error) {
n, err := tt.Conn.Write(b)
upload := int64(n)
@ -52,11 +65,26 @@ func (tt *tcpTracker) Write(b []byte) (int, error) {
return n, err
}
func (tt *tcpTracker) WriteBuffer(buffer *buf.Buffer) (err error) {
err = tt.extendedWriter.WriteBuffer(buffer)
var upload int64
if err != nil {
upload = int64(buffer.Len())
}
tt.manager.PushUploaded(upload)
tt.UploadTotal.Add(upload)
return
}
func (tt *tcpTracker) Close() error {
tt.manager.Leave(tt)
return tt.Conn.Close()
}
func (tt *tcpTracker) Upstream() any {
return tt.Conn
}
func NewTCPTracker(conn C.Conn, manager *Manager, metadata *C.Metadata, rule C.Rule) *tcpTracker {
uuid, _ := uuid.NewV4()
if conn != nil {
@ -79,6 +107,8 @@ func NewTCPTracker(conn C.Conn, manager *Manager, metadata *C.Metadata, rule C.R
UploadTotal: atomic.NewInt64(0),
DownloadTotal: atomic.NewInt64(0),
},
extendedReader: bufio.NewExtendedReader(conn),
extendedWriter: bufio.NewExtendedWriter(conn),
}
if rule != nil {