From cc6d496143ecaeaf0049825e481954114e9006c3 Mon Sep 17 00:00:00 2001 From: beyondkmp Date: Sun, 4 Nov 2018 21:36:20 +0800 Subject: [PATCH] Chore: optimize code structure in vmess websocket (#28) * Chore: move conn process of ws to websocket.go * Chore: some routine adjustment --- component/vmess/vmess.go | 81 +++++++++++++----------------------- component/vmess/websocket.go | 52 +++++++++++++++++++++-- 2 files changed, 76 insertions(+), 57 deletions(-) diff --git a/component/vmess/vmess.go b/component/vmess/vmess.go index f9cd27ec7..3356d8bc3 100644 --- a/component/vmess/vmess.go +++ b/component/vmess/vmess.go @@ -5,13 +5,10 @@ import ( "fmt" "math/rand" "net" - "net/url" "runtime" "sync" - "time" "github.com/gofrs/uuid" - "github.com/gorilla/websocket" ) // Version of vmess @@ -67,14 +64,13 @@ type DstAddr struct { // Client is vmess connection generator type Client struct { - user []*ID - uuid *uuid.UUID - security Security - tls bool - host string - websocket bool - websocketPath string - tlsConfig *tls.Config + user []*ID + uuid *uuid.UUID + security Security + tls bool + host string + wsConfig *websocketConfig + tlsConfig *tls.Config } // Config of vmess @@ -92,43 +88,13 @@ type Config struct { // New return a Conn with net.Conn and DstAddr func (c *Client) New(conn net.Conn, dst *DstAddr) (net.Conn, error) { + var err error r := rand.Intn(len(c.user)) - if c.websocket { - dialer := &websocket.Dialer{ - NetDial: func(network, addr string) (net.Conn, error) { - return conn, nil - }, - ReadBufferSize: 4 * 1024, - WriteBufferSize: 4 * 1024, - HandshakeTimeout: time.Second * 8, - } - scheme := "ws" - if c.tls { - scheme = "wss" - dialer.TLSClientConfig = c.tlsConfig - } - - host, port, err := net.SplitHostPort(c.host) - if (scheme == "ws" && port != "80") || (scheme == "wss" && port != "443") { - host = c.host - } - - uri := url.URL{ - Scheme: scheme, - Host: host, - Path: c.websocketPath, - } - - wsConn, resp, err := dialer.Dial(uri.String(), nil) + if c.wsConfig != nil { + conn, err = newWebsocketConn(conn, c.wsConfig) if err != nil { - var reason string - if resp != nil { - reason = resp.Status - } - return nil, fmt.Errorf("Dial %s error: %s", host, reason) + return nil, err } - - conn = newWebsocketConn(wsConn, conn.RemoteAddr()) } else if c.tls { conn = tls.Client(conn, c.tlsConfig) } @@ -174,15 +140,24 @@ func NewClient(config Config) (*Client, error) { } } + var wsConfig *websocketConfig + if config.NetWork == "ws" { + wsConfig = &websocketConfig{ + host: config.Host, + path: config.WebSocketPath, + tls: config.TLS, + tlsConfig: tlsConfig, + } + } + return &Client{ - user: newAlterIDs(newID(&uid), config.AlterID), - uuid: &uid, - security: security, - tls: config.TLS, - host: config.Host, - websocket: config.NetWork == "ws", - websocketPath: config.WebSocketPath, - tlsConfig: tlsConfig, + user: newAlterIDs(newID(&uid), config.AlterID), + uuid: &uid, + security: security, + tls: config.TLS, + host: config.Host, + wsConfig: wsConfig, + tlsConfig: tlsConfig, }, nil } diff --git a/component/vmess/websocket.go b/component/vmess/websocket.go index 21e458140..7bef5e559 100644 --- a/component/vmess/websocket.go +++ b/component/vmess/websocket.go @@ -1,9 +1,11 @@ package vmess import ( + "crypto/tls" "fmt" "io" "net" + "net/url" "strings" "time" @@ -16,6 +18,13 @@ type websocketConn struct { remoteAddr net.Addr } +type websocketConfig struct { + host string + path string + tls bool + tlsConfig *tls.Config +} + // Read implements net.Conn.Read() func (wsc *websocketConn) Read(b []byte) (int, error) { for { @@ -91,9 +100,44 @@ func (wsc *websocketConn) SetWriteDeadline(t time.Time) error { return wsc.conn.SetWriteDeadline(t) } -func newWebsocketConn(conn *websocket.Conn, remoteAddr net.Addr) net.Conn { - return &websocketConn{ - conn: conn, - remoteAddr: remoteAddr, +func newWebsocketConn(conn net.Conn, c *websocketConfig) (net.Conn, error) { + dialer := &websocket.Dialer{ + NetDial: func(network, addr string) (net.Conn, error) { + return conn, nil + }, + ReadBufferSize: 4 * 1024, + WriteBufferSize: 4 * 1024, + HandshakeTimeout: time.Second * 8, } + + scheme := "ws" + if c.tls { + scheme = "wss" + dialer.TLSClientConfig = c.tlsConfig + } + + host, port, err := net.SplitHostPort(c.host) + if (scheme == "ws" && port != "80") || (scheme == "wss" && port != "443") { + host = c.host + } + + uri := url.URL{ + Scheme: scheme, + Host: host, + Path: c.path, + } + + wsConn, resp, err := dialer.Dial(uri.String(), nil) + if err != nil { + var reason string + if resp != nil { + reason = resp.Status + } + return nil, fmt.Errorf("Dial %s error: %s", host, reason) + } + + return &websocketConn{ + conn: wsConn, + remoteAddr: conn.RemoteAddr(), + }, nil }