diff --git a/README.md b/README.md index 6ec51b62a..7e834d34e 100644 --- a/README.md +++ b/README.md @@ -122,7 +122,7 @@ Proxy: # The types of cipher are consistent with go-shadowsocks2 # support AEAD_AES_128_GCM AEAD_AES_192_GCM AEAD_AES_256_GCM AEAD_CHACHA20_POLY1305 AES-128-CTR AES-192-CTR AES-256-CTR AES-128-CFB AES-192-CFB AES-256-CFB CHACHA20-IETF XCHACHA20 # In addition to what go-shadowsocks2 supports, it also supports chacha20 rc4-md5 xchacha20-ietf-poly1305 -- { name: "ss1", type: ss, server: server, port: 443, cipher: AEAD_CHACHA20_POLY1305, password: "password" } +- { name: "ss1", type: ss, server: server, port: 443, cipher: AEAD_CHACHA20_POLY1305, password: "password", udp: true } # old obfs configuration remove after prerelease - name: "ss2" @@ -226,5 +226,5 @@ https://clash.gitbook.io/ - [x] Complementing the necessary rule operators - [x] Redir proxy -- [ ] UDP support +- [ ] UDP support (vmess, outbound socks5) - [ ] Connection manager diff --git a/adapters/inbound/http.go b/adapters/inbound/http.go index 8aa21e7c2..e97bcc3a1 100644 --- a/adapters/inbound/http.go +++ b/adapters/inbound/http.go @@ -10,26 +10,16 @@ import ( // HTTPAdapter is a adapter for HTTP connection type HTTPAdapter struct { + net.Conn metadata *C.Metadata - conn net.Conn R *http.Request } -// Close HTTP connection -func (h *HTTPAdapter) Close() { - h.conn.Close() -} - // Metadata return destination metadata func (h *HTTPAdapter) Metadata() *C.Metadata { return h.metadata } -// Conn return raw net.Conn of HTTP -func (h *HTTPAdapter) Conn() net.Conn { - return h.conn -} - // NewHTTP is HTTPAdapter generator func NewHTTP(request *http.Request, conn net.Conn) *HTTPAdapter { metadata := parseHTTPAddr(request) @@ -37,7 +27,7 @@ func NewHTTP(request *http.Request, conn net.Conn) *HTTPAdapter { return &HTTPAdapter{ metadata: metadata, R: request, - conn: conn, + Conn: conn, } } diff --git a/adapters/inbound/https.go b/adapters/inbound/https.go index e95126865..5207a5daf 100644 --- a/adapters/inbound/https.go +++ b/adapters/inbound/https.go @@ -11,6 +11,6 @@ func NewHTTPS(request *http.Request, conn net.Conn) *SocketAdapter { metadata.SourceIP = parseSourceIP(conn) return &SocketAdapter{ metadata: metadata, - conn: conn, + Conn: conn, } } diff --git a/adapters/inbound/socket.go b/adapters/inbound/socket.go index 66f13f65e..0d1be44bb 100644 --- a/adapters/inbound/socket.go +++ b/adapters/inbound/socket.go @@ -9,33 +9,24 @@ import ( // SocketAdapter is a adapter for socks and redir connection type SocketAdapter struct { - conn net.Conn + net.Conn metadata *C.Metadata } -// Close socks and redir connection -func (s *SocketAdapter) Close() { - s.conn.Close() -} - // Metadata return destination metadata func (s *SocketAdapter) Metadata() *C.Metadata { return s.metadata } -// Conn return raw net.Conn -func (s *SocketAdapter) Conn() net.Conn { - return s.conn -} - // NewSocket is SocketAdapter generator -func NewSocket(target socks.Addr, conn net.Conn, source C.SourceType) *SocketAdapter { +func NewSocket(target socks.Addr, conn net.Conn, source C.SourceType, netType C.NetWork) *SocketAdapter { metadata := parseSocksAddr(target) + metadata.NetWork = netType metadata.Source = source metadata.SourceIP = parseSourceIP(conn) return &SocketAdapter{ - conn: conn, + Conn: conn, metadata: metadata, } } diff --git a/adapters/inbound/util.go b/adapters/inbound/util.go index b059920aa..c29c06c85 100644 --- a/adapters/inbound/util.go +++ b/adapters/inbound/util.go @@ -11,7 +11,6 @@ import ( func parseSocksAddr(target socks.Addr) *C.Metadata { metadata := &C.Metadata{ - NetWork: C.TCP, AddrType: int(target[0]), } diff --git a/adapters/outbound/base.go b/adapters/outbound/base.go index c73a15424..588771211 100644 --- a/adapters/outbound/base.go +++ b/adapters/outbound/base.go @@ -2,6 +2,7 @@ package adapters import ( "encoding/json" + "errors" "net" "net/http" "time" @@ -13,6 +14,7 @@ import ( type Base struct { name string tp C.AdapterType + udp bool } func (b *Base) Name() string { @@ -23,6 +25,14 @@ func (b *Base) Type() C.AdapterType { return b.tp } +func (b *Base) DialUDP(metadata *C.Metadata) (net.PacketConn, net.Addr, error) { + return nil, nil, errors.New("no support") +} + +func (b *Base) SupportUDP() bool { + return b.udp +} + func (b *Base) Destroy() {} func (b *Base) MarshalJSON() ([]byte, error) { diff --git a/adapters/outbound/direct.go b/adapters/outbound/direct.go index 245c28b33..1bd0a6f66 100644 --- a/adapters/outbound/direct.go +++ b/adapters/outbound/direct.go @@ -24,11 +24,22 @@ func (d *Direct) Dial(metadata *C.Metadata) (net.Conn, error) { return c, nil } +func (d *Direct) DialUDP(metadata *C.Metadata) (net.PacketConn, net.Addr, error) { + pc, err := net.ListenPacket("udp", "") + if err != nil { + return nil, nil, err + } + + addr, _ := net.ResolveUDPAddr("udp", net.JoinHostPort(metadata.String(), metadata.Port)) + return pc, addr, nil +} + func NewDirect() *Direct { return &Direct{ Base: &Base{ name: "DIRECT", tp: C.Direct, + udp: true, }, } } diff --git a/adapters/outbound/fallback.go b/adapters/outbound/fallback.go index 78f573ae2..913383a44 100644 --- a/adapters/outbound/fallback.go +++ b/adapters/outbound/fallback.go @@ -35,6 +35,16 @@ func (f *Fallback) Dial(metadata *C.Metadata) (net.Conn, error) { return proxy.Dial(metadata) } +func (f *Fallback) DialUDP(metadata *C.Metadata) (net.PacketConn, net.Addr, error) { + proxy := f.findAliveProxy() + return proxy.DialUDP(metadata) +} + +func (f *Fallback) SupportUDP() bool { + proxy := f.findAliveProxy() + return proxy.SupportUDP() +} + func (f *Fallback) MarshalJSON() ([]byte, error) { var all []string for _, proxy := range f.proxies { diff --git a/adapters/outbound/loadbalance.go b/adapters/outbound/loadbalance.go index df2145b27..761a184f8 100644 --- a/adapters/outbound/loadbalance.go +++ b/adapters/outbound/loadbalance.go @@ -67,6 +67,24 @@ func (lb *LoadBalance) Dial(metadata *C.Metadata) (net.Conn, error) { return lb.proxies[0].Dial(metadata) } +func (lb *LoadBalance) DialUDP(metadata *C.Metadata) (net.PacketConn, net.Addr, error) { + key := uint64(murmur3.Sum32([]byte(getKey(metadata)))) + buckets := int32(len(lb.proxies)) + for i := 0; i < lb.maxRetry; i, key = i+1, key+1 { + idx := jumpHash(key, buckets) + proxy := lb.proxies[idx] + if proxy.Alive() { + return proxy.DialUDP(metadata) + } + } + + return lb.proxies[0].DialUDP(metadata) +} + +func (lb *LoadBalance) SupportUDP() bool { + return true +} + func (lb *LoadBalance) Destroy() { lb.done <- struct{}{} } diff --git a/adapters/outbound/selector.go b/adapters/outbound/selector.go index a0a126a59..39b4ac0d9 100644 --- a/adapters/outbound/selector.go +++ b/adapters/outbound/selector.go @@ -24,6 +24,14 @@ func (s *Selector) Dial(metadata *C.Metadata) (net.Conn, error) { return s.selected.Dial(metadata) } +func (s *Selector) DialUDP(metadata *C.Metadata) (net.PacketConn, net.Addr, error) { + return s.selected.DialUDP(metadata) +} + +func (s *Selector) SupportUDP() bool { + return s.selected.SupportUDP() +} + func (s *Selector) MarshalJSON() ([]byte, error) { var all []string for k := range s.proxies { diff --git a/adapters/outbound/shadowsocks.go b/adapters/outbound/shadowsocks.go index 62045997b..f477d9d36 100644 --- a/adapters/outbound/shadowsocks.go +++ b/adapters/outbound/shadowsocks.go @@ -8,6 +8,7 @@ import ( "net" "strconv" + "github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/common/structure" obfs "github.com/Dreamacro/clash/component/simple-obfs" v2rayObfs "github.com/Dreamacro/clash/component/v2ray-plugin" @@ -34,6 +35,7 @@ type ShadowSocksOption struct { Port int `proxy:"port"` Password string `proxy:"password"` Cipher string `proxy:"cipher"` + UDP bool `proxy:"udp,omitempty"` Plugin string `proxy:"plugin,omitempty"` PluginOpts map[string]interface{} `proxy:"plugin-opts,omitempty"` @@ -80,6 +82,19 @@ func (ss *ShadowSocks) Dial(metadata *C.Metadata) (net.Conn, error) { return c, err } +func (ss *ShadowSocks) DialUDP(metadata *C.Metadata) (net.PacketConn, net.Addr, error) { + pc, err := net.ListenPacket("udp", "") + if err != nil { + return nil, nil, err + } + + addr, _ := net.ResolveUDPAddr("udp", ss.server) + remoteAddr, _ := net.ResolveUDPAddr("udp", net.JoinHostPort(metadata.String(), metadata.Port)) + + pc = ss.cipher.PacketConn(pc) + return &ssUDPConn{PacketConn: pc, rAddr: remoteAddr}, addr, nil +} + func (ss *ShadowSocks) MarshalJSON() ([]byte, error) { return json.Marshal(map[string]string{ "type": ss.Type().String(), @@ -144,6 +159,7 @@ func NewShadowSocks(option ShadowSocksOption) (*ShadowSocks, error) { Base: &Base{ name: option.Name, tp: C.Shadowsocks, + udp: option.UDP, }, server: server, cipher: ciph, @@ -173,3 +189,24 @@ func serializesSocksAddr(metadata *C.Metadata) []byte { } return bytes.Join(buf, nil) } + +type ssUDPConn struct { + net.PacketConn + rAddr net.Addr +} + +func (uc *ssUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { + buf := pool.BufPool.Get().([]byte) + defer pool.BufPool.Put(buf[:cap(buf)]) + rAddr := socks.ParseAddr(uc.rAddr.String()) + copy(buf[len(rAddr):], b) + copy(buf, rAddr) + return uc.PacketConn.WriteTo(buf[:len(rAddr)+len(b)], addr) +} + +func (uc *ssUDPConn) ReadFrom(b []byte) (int, net.Addr, error) { + n, a, e := uc.PacketConn.ReadFrom(b) + addr := socks.SplitAddr(b[:n]) + copy(b, b[len(addr):]) + return n - len(addr), a, e +} diff --git a/adapters/outbound/socks5.go b/adapters/outbound/socks5.go index 2adb999a1..bd8680d5f 100644 --- a/adapters/outbound/socks5.go +++ b/adapters/outbound/socks5.go @@ -31,6 +31,7 @@ type Socks5Option struct { UserName string `proxy:"username,omitempty"` Password string `proxy:"password,omitempty"` TLS bool `proxy:"tls,omitempty"` + UDP bool `proxy:"udp,omitempty"` SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"` } @@ -126,6 +127,7 @@ func NewSocks5(option Socks5Option) *Socks5 { Base: &Base{ name: option.Name, tp: C.Socks5, + udp: option.UDP, }, addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)), user: option.UserName, diff --git a/adapters/outbound/urltest.go b/adapters/outbound/urltest.go index 6ea6dd407..cd20c62fd 100644 --- a/adapters/outbound/urltest.go +++ b/adapters/outbound/urltest.go @@ -43,6 +43,14 @@ func (u *URLTest) Dial(metadata *C.Metadata) (net.Conn, error) { return a, err } +func (u *URLTest) DialUDP(metadata *C.Metadata) (net.PacketConn, net.Addr, error) { + return u.fast.DialUDP(metadata) +} + +func (u *URLTest) SupportUDP() bool { + return u.fast.SupportUDP() +} + func (u *URLTest) MarshalJSON() ([]byte, error) { var all []string for _, proxy := range u.proxies { diff --git a/common/pool/pool.go b/common/pool/pool.go new file mode 100644 index 000000000..aa651dc79 --- /dev/null +++ b/common/pool/pool.go @@ -0,0 +1,15 @@ +package pool + +import ( + "sync" +) + +const ( + // io.Copy default buffer size is 32 KiB + // but the maximum packet size of vmess/shadowsocks is about 16 KiB + // so define a buffer of 20 KiB to reduce the memory of each TCP relay + bufferSize = 20 * 1024 +) + +// BufPool provide buffer for relay +var BufPool = sync.Pool{New: func() interface{} { return make([]byte, bufferSize) }} diff --git a/component/simple-obfs/http.go b/component/simple-obfs/http.go index 4c7c60dd6..cd62cf766 100644 --- a/component/simple-obfs/http.go +++ b/component/simple-obfs/http.go @@ -8,6 +8,8 @@ import ( "math/rand" "net" "net/http" + + "github.com/Dreamacro/clash/common/pool" ) // HTTPObfs is shadowsocks http simple-obfs implementation @@ -32,15 +34,15 @@ func (ho *HTTPObfs) Read(b []byte) (int, error) { } if ho.firstResponse { - buf := bufPool.Get().([]byte) + buf := pool.BufPool.Get().([]byte) n, err := ho.Conn.Read(buf) if err != nil { - bufPool.Put(buf[:cap(buf)]) + pool.BufPool.Put(buf[:cap(buf)]) return 0, err } idx := bytes.Index(buf[:n], []byte("\r\n\r\n")) if idx == -1 { - bufPool.Put(buf[:cap(buf)]) + pool.BufPool.Put(buf[:cap(buf)]) return 0, io.EOF } ho.firstResponse = false @@ -50,7 +52,7 @@ func (ho *HTTPObfs) Read(b []byte) (int, error) { ho.buf = buf[:idx+4+length] ho.offset = idx + 4 + n } else { - bufPool.Put(buf[:cap(buf)]) + pool.BufPool.Put(buf[:cap(buf)]) } return n, nil } diff --git a/component/simple-obfs/tls.go b/component/simple-obfs/tls.go index b763eefb4..37db79a1e 100644 --- a/component/simple-obfs/tls.go +++ b/component/simple-obfs/tls.go @@ -6,8 +6,9 @@ import ( "io" "math/rand" "net" - "sync" "time" + + "github.com/Dreamacro/clash/common/pool" ) func init() { @@ -18,8 +19,6 @@ const ( chunkSize = 1 << 14 // 2 ** 14 == 16 * 1024 ) -var bufPool = sync.Pool{New: func() interface{} { return make([]byte, 2048) }} - // TLSObfs is shadowsocks tls simple-obfs implementation type TLSObfs struct { net.Conn @@ -30,12 +29,12 @@ type TLSObfs struct { } func (to *TLSObfs) read(b []byte, discardN int) (int, error) { - buf := bufPool.Get().([]byte) + buf := pool.BufPool.Get().([]byte) _, err := io.ReadFull(to.Conn, buf[:discardN]) if err != nil { return 0, err } - bufPool.Put(buf[:cap(buf)]) + pool.BufPool.Put(buf[:cap(buf)]) sizeBuf := make([]byte, 2) _, err = io.ReadFull(to.Conn, sizeBuf) @@ -103,7 +102,7 @@ func (to *TLSObfs) write(b []byte) (int, error) { return len(b), err } - size := bufPool.Get().([]byte) + size := pool.BufPool.Get().([]byte) binary.BigEndian.PutUint16(size[:2], uint16(len(b))) buf := &bytes.Buffer{} @@ -111,7 +110,7 @@ func (to *TLSObfs) write(b []byte) (int, error) { buf.Write(size[:2]) buf.Write(b) _, err := to.Conn.Write(buf.Bytes()) - bufPool.Put(size[:cap(size)]) + pool.BufPool.Put(size[:cap(size)]) return len(b), err } diff --git a/component/vmess/aead.go b/component/vmess/aead.go index 342f373e0..5c8da0deb 100644 --- a/component/vmess/aead.go +++ b/component/vmess/aead.go @@ -5,6 +5,8 @@ import ( "encoding/binary" "errors" "io" + + "github.com/Dreamacro/clash/common/pool" ) type aeadWriter struct { @@ -20,8 +22,8 @@ func newAEADWriter(w io.Writer, aead cipher.AEAD, iv []byte) *aeadWriter { } func (w *aeadWriter) Write(b []byte) (n int, err error) { - buf := bufPool.Get().([]byte) - defer bufPool.Put(buf[:cap(buf)]) + buf := pool.BufPool.Get().([]byte) + defer pool.BufPool.Put(buf[:cap(buf)]) length := len(b) for { if length == 0 { @@ -71,7 +73,7 @@ func (r *aeadReader) Read(b []byte) (int, error) { n := copy(b, r.buf[r.offset:]) r.offset += n if r.offset == len(r.buf) { - bufPool.Put(r.buf[:cap(r.buf)]) + pool.BufPool.Put(r.buf[:cap(r.buf)]) r.buf = nil } return n, nil @@ -87,10 +89,10 @@ func (r *aeadReader) Read(b []byte) (int, error) { return 0, errors.New("Buffer is larger than standard") } - buf := bufPool.Get().([]byte) + buf := pool.BufPool.Get().([]byte) _, err = io.ReadFull(r.Reader, buf[:size]) if err != nil { - bufPool.Put(buf[:cap(buf)]) + pool.BufPool.Put(buf[:cap(buf)]) return 0, err } @@ -105,7 +107,7 @@ func (r *aeadReader) Read(b []byte) (int, error) { realLen := size - r.Overhead() n := copy(b, buf[:realLen]) if len(b) >= realLen { - bufPool.Put(buf[:cap(buf)]) + pool.BufPool.Put(buf[:cap(buf)]) return n, nil } diff --git a/component/vmess/chunk.go b/component/vmess/chunk.go index b396fdd6e..7f30d0f0f 100644 --- a/component/vmess/chunk.go +++ b/component/vmess/chunk.go @@ -4,7 +4,8 @@ import ( "encoding/binary" "errors" "io" - "sync" + + "github.com/Dreamacro/clash/common/pool" ) const ( @@ -13,8 +14,6 @@ const ( maxSize = 17 * 1024 // 2 + chunkSize + aead.Overhead() ) -var bufPool = sync.Pool{New: func() interface{} { return make([]byte, maxSize) }} - type chunkReader struct { io.Reader buf []byte @@ -35,7 +34,7 @@ func (cr *chunkReader) Read(b []byte) (int, error) { n := copy(b, cr.buf[cr.offset:]) cr.offset += n if cr.offset == len(cr.buf) { - bufPool.Put(cr.buf[:cap(cr.buf)]) + pool.BufPool.Put(cr.buf[:cap(cr.buf)]) cr.buf = nil } return n, nil @@ -60,10 +59,10 @@ func (cr *chunkReader) Read(b []byte) (int, error) { return size, nil } - buf := bufPool.Get().([]byte) + buf := pool.BufPool.Get().([]byte) _, err = io.ReadFull(cr.Reader, buf[:size]) if err != nil { - bufPool.Put(buf[:cap(buf)]) + pool.BufPool.Put(buf[:cap(buf)]) return 0, err } n := copy(b, cr.buf[:]) @@ -77,8 +76,8 @@ type chunkWriter struct { } func (cw *chunkWriter) Write(b []byte) (n int, err error) { - buf := bufPool.Get().([]byte) - defer bufPool.Put(buf[:cap(buf)]) + buf := pool.BufPool.Get().([]byte) + defer pool.BufPool.Put(buf[:cap(buf)]) length := len(b) for { if length == 0 { diff --git a/constant/adapters.go b/constant/adapters.go index be755554a..e05cd58ed 100644 --- a/constant/adapters.go +++ b/constant/adapters.go @@ -20,14 +20,16 @@ const ( ) type ServerAdapter interface { + net.Conn Metadata() *Metadata - Close() } type ProxyAdapter interface { Name() string Type() AdapterType Dial(metadata *Metadata) (net.Conn, error) + DialUDP(metadata *Metadata) (net.PacketConn, net.Addr, error) + SupportUDP() bool Destroy() MarshalJSON() ([]byte, error) } diff --git a/go.mod b/go.mod index 31c879ec2..43fa2bb03 100644 --- a/go.mod +++ b/go.mod @@ -1,7 +1,7 @@ module github.com/Dreamacro/clash require ( - github.com/Dreamacro/go-shadowsocks2 v0.1.3-0.20190202135136-da4602d8f112 + github.com/Dreamacro/go-shadowsocks2 v0.1.3-0.20190406142755-9128a199439f github.com/eapache/queue v1.1.0 // indirect github.com/go-chi/chi v4.0.1+incompatible github.com/go-chi/cors v1.0.0 diff --git a/go.sum b/go.sum index 2a2f1273e..ac71fa0c5 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/Dreamacro/go-shadowsocks2 v0.1.3-0.20190202135136-da4602d8f112 h1:1axYxE0ZLJy40+ulq46XQt7MaJDJr4iGer1NQz7jmKw= -github.com/Dreamacro/go-shadowsocks2 v0.1.3-0.20190202135136-da4602d8f112/go.mod h1:giIuN+TuUudTxHc1jjTOyyQYiJ3VXp1pWOHdJbSCAPo= +github.com/Dreamacro/go-shadowsocks2 v0.1.3-0.20190406142755-9128a199439f h1:nlImrmI6I2AVjJ2AvE3w3f7fi8rhLQAhZO1Gs31+/nE= +github.com/Dreamacro/go-shadowsocks2 v0.1.3-0.20190406142755-9128a199439f/go.mod h1:giIuN+TuUudTxHc1jjTOyyQYiJ3VXp1pWOHdJbSCAPo= github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY= github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da/go.mod h1:eHEWzANqSiWQsof+nXEI9bUVUyV6F53Fp89EuCh2EAA= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= diff --git a/proxy/redir/tcp.go b/proxy/redir/tcp.go index ac88710a2..1c0fee657 100644 --- a/proxy/redir/tcp.go +++ b/proxy/redir/tcp.go @@ -59,5 +59,5 @@ func handleRedir(conn net.Conn) { return } conn.(*net.TCPConn).SetKeepAlive(true) - tun.Add(adapters.NewSocket(target, conn, C.REDIR)) + tun.Add(adapters.NewSocket(target, conn, C.REDIR, C.TCP)) } diff --git a/proxy/socks/tcp.go b/proxy/socks/tcp.go index 83220c091..3d884ff58 100644 --- a/proxy/socks/tcp.go +++ b/proxy/socks/tcp.go @@ -1,9 +1,11 @@ package socks import ( + "io" "net" + "strconv" - "github.com/Dreamacro/clash/adapters/inbound" + adapters "github.com/Dreamacro/clash/adapters/inbound" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/log" "github.com/Dreamacro/clash/tunnel" @@ -15,6 +17,41 @@ var ( tun = tunnel.Instance() ) +// Error represents a SOCKS error +type Error byte + +func (err Error) Error() string { + return "SOCKS error: " + strconv.Itoa(int(err)) +} + +// SOCKS request commands as defined in RFC 1928 section 4. +const ( + CmdConnect = 1 + CmdBind = 2 + CmdUDPAssociate = 3 +) + +// SOCKS address types as defined in RFC 1928 section 5. +const ( + AtypIPv4 = 1 + AtypDomainName = 3 + AtypIPv6 = 4 +) + +const MaxAddrLen = 1 + 1 + 255 + 2 + +// SOCKS errors as defined in RFC 1928 section 6. +const ( + ErrGeneralFailure = Error(1) + ErrConnectionNotAllowed = Error(2) + ErrNetworkUnreachable = Error(3) + ErrHostUnreachable = Error(4) + ErrConnectionRefused = Error(5) + ErrTTLExpired = Error(6) + ErrCommandNotSupported = Error(7) + ErrAddressNotSupported = Error(8) +) + type SockListener struct { net.Listener address string @@ -55,11 +92,78 @@ func (l *SockListener) Address() string { } func handleSocks(conn net.Conn) { - target, err := socks.Handshake(conn) + target, command, err := handshake(conn) if err != nil { conn.Close() return } conn.(*net.TCPConn).SetKeepAlive(true) - tun.Add(adapters.NewSocket(target, conn, C.SOCKS)) + if command == CmdUDPAssociate { + tun.Add(adapters.NewSocket(target, conn, C.SOCKS, C.UDP)) + return + } + tun.Add(adapters.NewSocket(target, conn, C.SOCKS, C.TCP)) +} + +// handshake fast-tracks SOCKS initialization to get target address to connect. +func handshake(rw io.ReadWriter) (addr socks.Addr, command int, err error) { + // Read RFC 1928 for request and reply structure and sizes. + buf := make([]byte, MaxAddrLen) + // read VER, NMETHODS, METHODS + if _, err = io.ReadFull(rw, buf[:2]); err != nil { + return + } + nmethods := buf[1] + if _, err = io.ReadFull(rw, buf[:nmethods]); err != nil { + return + } + // write VER METHOD + if _, err = rw.Write([]byte{5, 0}); err != nil { + return + } + // read VER CMD RSV ATYP DST.ADDR DST.PORT + if _, err = io.ReadFull(rw, buf[:3]); err != nil { + return + } + if buf[1] != CmdConnect && buf[1] != CmdUDPAssociate { + err = ErrCommandNotSupported + return + } + + command = int(buf[1]) + addr, err = readAddr(rw, buf) + if err != nil { + return + } + // write VER REP RSV ATYP BND.ADDR BND.PORT + _, err = rw.Write([]byte{5, 0, 0, 1, 0, 0, 0, 0, 0, 0}) + return +} + +func readAddr(r io.Reader, b []byte) (socks.Addr, error) { + if len(b) < MaxAddrLen { + return nil, io.ErrShortBuffer + } + _, err := io.ReadFull(r, b[:1]) // read 1st byte for address type + if err != nil { + return nil, err + } + + switch b[0] { + case AtypDomainName: + _, err = io.ReadFull(r, b[1:2]) // read 2nd byte for domain length + if err != nil { + return nil, err + } + _, err = io.ReadFull(r, b[2:2+b[1]+2]) + return b[:1+1+b[1]+2], err + case AtypIPv4: + _, err = io.ReadFull(r, b[1:1+net.IPv4len+2]) + return b[:1+net.IPv4len+2], err + case AtypIPv6: + _, err = io.ReadFull(r, b[1:1+net.IPv6len+2]) + return b[:1+net.IPv6len+2], err + } + + return nil, ErrAddressNotSupported } diff --git a/tunnel/connection.go b/tunnel/connection.go index 25c3656e4..a8928a624 100644 --- a/tunnel/connection.go +++ b/tunnel/connection.go @@ -6,21 +6,12 @@ import ( "net" "net/http" "strings" - "sync" "time" adapters "github.com/Dreamacro/clash/adapters/inbound" + "github.com/Dreamacro/clash/common/pool" ) -const ( - // io.Copy default buffer size is 32 KiB - // but the maximum packet size of vmess/shadowsocks is about 16 KiB - // so define a buffer of 20 KiB to reduce the memory of each TCP relay - bufferSize = 20 * 1024 -) - -var bufPool = sync.Pool{New: func() interface{} { return make([]byte, bufferSize) }} - func (t *Tunnel) handleHTTP(request *adapters.HTTPAdapter, outbound net.Conn) { conn := newTrafficTrack(outbound, t.traffic) req := request.R @@ -50,7 +41,7 @@ func (t *Tunnel) handleHTTP(request *adapters.HTTPAdapter, outbound net.Conn) { } else { resp.Close = true } - err = resp.Write(request.Conn()) + err = resp.Write(request) if err != nil || resp.Close { break } @@ -59,7 +50,7 @@ func (t *Tunnel) handleHTTP(request *adapters.HTTPAdapter, outbound net.Conn) { break } - req, err = http.ReadRequest(bufio.NewReader(request.Conn())) + req, err = http.ReadRequest(bufio.NewReader(request)) if err != nil { break } @@ -72,9 +63,52 @@ func (t *Tunnel) handleHTTP(request *adapters.HTTPAdapter, outbound net.Conn) { } } -func (t *Tunnel) handleSOCKS(request *adapters.SocketAdapter, outbound net.Conn) { +func (t *Tunnel) handleSocket(request *adapters.SocketAdapter, outbound net.Conn) { conn := newTrafficTrack(outbound, t.traffic) - relay(request.Conn(), conn) + relay(request, conn) +} + +func (t *Tunnel) handleUDPOverTCP(conn net.Conn, pc net.PacketConn, addr net.Addr) error { + ch := make(chan error, 1) + + go func() { + buf := pool.BufPool.Get().([]byte) + defer pool.BufPool.Put(buf) + for { + n, err := conn.Read(buf) + if err != nil { + ch <- err + return + } + pc.SetReadDeadline(time.Now().Add(120 * time.Second)) + if _, err = pc.WriteTo(buf[:n], addr); err != nil { + ch <- err + return + } + t.traffic.Up() <- int64(n) + ch <- nil + } + }() + + buf := pool.BufPool.Get().([]byte) + defer pool.BufPool.Put(buf) + + for { + pc.SetReadDeadline(time.Now().Add(120 * time.Second)) + n, _, err := pc.ReadFrom(buf) + if err != nil { + break + } + + if _, err := conn.Write(buf[:n]); err != nil { + break + } + + t.traffic.Down() <- int64(n) + } + + <-ch + return nil } // relay copies between left and right bidirectionally. @@ -82,16 +116,16 @@ func relay(leftConn, rightConn net.Conn) { ch := make(chan error) go func() { - buf := bufPool.Get().([]byte) + buf := pool.BufPool.Get().([]byte) _, err := io.CopyBuffer(leftConn, rightConn, buf) - bufPool.Put(buf[:cap(buf)]) + pool.BufPool.Put(buf[:cap(buf)]) leftConn.SetReadDeadline(time.Now()) ch <- err }() - buf := bufPool.Get().([]byte) + buf := pool.BufPool.Get().([]byte) io.CopyBuffer(rightConn, leftConn, buf) - bufPool.Put(buf[:cap(buf)]) + pool.BufPool.Put(buf[:cap(buf)]) rightConn.SetReadDeadline(time.Now()) <-ch } diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 6fec24eb5..775f2f1c0 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -19,7 +19,7 @@ var ( once sync.Once ) -// Tunnel handle proxy socket and HTTP/SOCKS socket +// Tunnel handle relay inbound proxy and outbound proxy type Tunnel struct { queue *channels.InfiniteChannel rules []C.Rule @@ -143,6 +143,12 @@ func (t *Tunnel) handleConn(localConn C.ServerAdapter) { } } + if metadata.NetWork == C.UDP { + pc, addr, _ := proxy.DialUDP(metadata) + t.handleUDPOverTCP(localConn, pc, addr) + return + } + remoConn, err := proxy.Dial(metadata) if err != nil { log.Warnln("Proxy[%s] connect [%s --> %s] error: %s", proxy.Name(), metadata.SourceIP.String(), metadata.String(), err.Error()) @@ -154,7 +160,7 @@ func (t *Tunnel) handleConn(localConn C.ServerAdapter) { case *InboundAdapter.HTTPAdapter: t.handleHTTP(adapter, remoConn) case *InboundAdapter.SocketAdapter: - t.handleSOCKS(adapter, remoConn) + t.handleSocket(adapter, remoConn) } } @@ -177,10 +183,17 @@ func (t *Tunnel) match(metadata *C.Metadata) (C.Proxy, error) { } if rule.IsMatch(metadata) { - if a, ok := t.proxies[rule.Adapter()]; ok { - log.Infoln("%s --> %v match %s using %s", metadata.SourceIP.String(), metadata.String(), rule.RuleType().String(), rule.Adapter()) - return a, nil + adapter, ok := t.proxies[rule.Adapter()] + if !ok { + continue } + + if metadata.NetWork == C.UDP && !adapter.SupportUDP() { + continue + } + + log.Infoln("%s --> %v match %s using %s", metadata.SourceIP.String(), metadata.String(), rule.RuleType().String(), rule.Adapter()) + return adapter, nil } } log.Infoln("%s --> %v doesn't match any rule using DIRECT", metadata.SourceIP.String(), metadata.String())