diff --git a/hub/executor/executor.go b/hub/executor/executor.go index 4438a5761..b3772211f 100644 --- a/hub/executor/executor.go +++ b/hub/executor/executor.go @@ -195,7 +195,7 @@ func updateGeneral(general *config.General, force bool) { } if err := P.ReCreateSocks(general.SocksPort, tcpIn, udpIn); err != nil { - log.Errorln("Start SOCKS5 server error: %s", err.Error()) + log.Errorln("Start SOCKS server error: %s", err.Error()) } if err := P.ReCreateRedir(general.RedirPort, tcpIn, udpIn); err != nil { @@ -207,7 +207,7 @@ func updateGeneral(general *config.General, force bool) { } if err := P.ReCreateMixed(general.MixedPort, tcpIn, udpIn); err != nil { - log.Errorln("Start Mixed(http and socks5) server error: %s", err.Error()) + log.Errorln("Start Mixed(http and socks) server error: %s", err.Error()) } } diff --git a/listener/listener.go b/listener/listener.go index cf0390e7e..5d0e8d7bf 100644 --- a/listener/listener.go +++ b/listener/listener.go @@ -139,7 +139,7 @@ func ReCreateSocks(port int, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.P socksListener = tcpListener socksUDPListener = udpListener - log.Infoln("SOCKS5 proxy listening at: %s", socksListener.Address()) + log.Infoln("SOCKS proxy listening at: %s", socksListener.Address()) return nil } @@ -271,7 +271,7 @@ func ReCreateMixed(port int, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.P return err } - log.Infoln("Mixed(http+socks5) proxy listening at: %s", mixedListener.Address()) + log.Infoln("Mixed(http+socks) proxy listening at: %s", mixedListener.Address()) return nil } diff --git a/listener/mixed/mixed.go b/listener/mixed/mixed.go index 07ad03b62..7d89e3c8d 100644 --- a/listener/mixed/mixed.go +++ b/listener/mixed/mixed.go @@ -9,6 +9,7 @@ import ( C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/listener/http" "github.com/Dreamacro/clash/listener/socks" + "github.com/Dreamacro/clash/transport/socks4" "github.com/Dreamacro/clash/transport/socks5" ) @@ -58,10 +59,12 @@ func handleConn(conn net.Conn, in chan<- C.ConnContext, cache *cache.Cache) { return } - if head[0] == socks5.Version { - socks.HandleSocks(bufConn, in) - return + switch head[0] { + case socks4.Version: + socks.HandleSocks4(bufConn, in) + case socks5.Version: + socks.HandleSocks5(bufConn, in) + default: + http.HandleConn(bufConn, in, cache) } - - http.HandleConn(bufConn, in, cache) } diff --git a/listener/socks/tcp.go b/listener/socks/tcp.go index 8e12ac71d..60da0e264 100644 --- a/listener/socks/tcp.go +++ b/listener/socks/tcp.go @@ -6,8 +6,10 @@ import ( "net" "github.com/Dreamacro/clash/adapter/inbound" + N "github.com/Dreamacro/clash/common/net" C "github.com/Dreamacro/clash/constant" authStore "github.com/Dreamacro/clash/listener/auth" + "github.com/Dreamacro/clash/transport/socks4" "github.com/Dreamacro/clash/transport/socks5" ) @@ -33,7 +35,7 @@ func New(addr string, in chan<- C.ConnContext) (*Listener, error) { } continue } - go HandleSocks(c, in) + go handleSocks(c, in) } }() @@ -49,7 +51,37 @@ func (l *Listener) Address() string { return l.address } -func HandleSocks(conn net.Conn, in chan<- C.ConnContext) { +func handleSocks(conn net.Conn, in chan<- C.ConnContext) { + bufConn := N.NewBufferedConn(conn) + head, err := bufConn.Peek(1) + if err != nil { + conn.Close() + return + } + + switch head[0] { + case socks4.Version: + HandleSocks4(bufConn, in) + case socks5.Version: + HandleSocks5(bufConn, in) + default: + conn.Close() + } +} + +func HandleSocks4(conn net.Conn, in chan<- C.ConnContext) { + addr, _, err := socks4.ServerHandshake(conn, authStore.Authenticator()) + if err != nil { + conn.Close() + return + } + if c, ok := conn.(*net.TCPConn); ok { + c.SetKeepAlive(true) + } + in <- inbound.NewSocket(socks5.ParseAddr(addr), conn, C.SOCKS) +} + +func HandleSocks5(conn net.Conn, in chan<- C.ConnContext) { target, command, err := socks5.ServerHandshake(conn, authStore.Authenticator()) if err != nil { conn.Close() diff --git a/transport/socks4/socks4.go b/transport/socks4/socks4.go new file mode 100644 index 000000000..c6b2f2db5 --- /dev/null +++ b/transport/socks4/socks4.go @@ -0,0 +1,195 @@ +package socks4 + +import ( + "bytes" + "encoding/binary" + "errors" + "io" + "net" + "strconv" + + "github.com/Dreamacro/clash/component/auth" +) + +const Version = 0x04 + +type Command = uint8 + +const ( + CmdConnect Command = 0x01 + CmdBind Command = 0x02 +) + +type Code = uint8 + +const ( + RequestGranted Code = 90 + RequestRejected Code = 91 + RequestIdentdFailed Code = 92 + RequestIdentdMismatched Code = 93 +) + +var ( + errVersionMismatched = errors.New("version code mismatched") + errCommandNotSupported = errors.New("command not supported") + errIPv6NotSupported = errors.New("IPv6 not supported") + + ErrRequestRejected = errors.New("request rejected or failed") + ErrRequestIdentdFailed = errors.New("request rejected because SOCKS server cannot connect to identd on the client") + ErrRequestIdentdMismatched = errors.New("request rejected because the client program and identd report different user-ids") + ErrRequestUnknownCode = errors.New("request failed with unknown code") +) + +func ServerHandshake(rw io.ReadWriter, authenticator auth.Authenticator) (addr string, command Command, err error) { + var req [8]byte + if _, err = io.ReadFull(rw, req[:]); err != nil { + return + } + + if req[0] != Version { + err = errVersionMismatched + return + } + + if command = req[1]; command != CmdConnect { + err = errCommandNotSupported + return + } + + var ( + dstIP = req[4:8] // [4]byte + dstPort = req[2:4] // [2]byte + ) + + var ( + host string + port string + code uint8 + userID []byte + ) + if userID, err = readUntilNull(rw); err != nil { + return + } + + if isReservedIP(dstIP) { + var target []byte + if target, err = readUntilNull(rw); err != nil { + return + } + host = string(target) + } + + port = strconv.Itoa(int(binary.BigEndian.Uint16(dstPort))) + if host != "" { + addr = net.JoinHostPort(host, port) + } else { + addr = net.JoinHostPort(net.IP(dstIP).String(), port) + } + + // SOCKS4 only support USERID auth. + if authenticator == nil || authenticator.Verify(string(userID), "") { + code = RequestGranted + } else { + code = RequestIdentdMismatched + } + + var reply [8]byte + reply[0] = 0x00 // reply code + reply[1] = code // result code + copy(reply[4:8], dstIP) + copy(reply[2:4], dstPort) + + _, err = rw.Write(reply[:]) + return +} + +func ClientHandshake(rw io.ReadWriter, addr string, command Command, userID string) (err error) { + host, portStr, err := net.SplitHostPort(addr) + if err != nil { + return err + } + + port, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + return err + } + + ip := net.ParseIP(host) + if ip == nil /* HOST */ { + ip = net.IPv4(0, 0, 0, 1).To4() + } else if ip.To4() == nil /* IPv6 */ { + return errIPv6NotSupported + } + + dstIP := ip.To4() + + req := &bytes.Buffer{} + req.WriteByte(Version) + req.WriteByte(command) + binary.Write(req, binary.BigEndian, uint16(port)) + req.Write(dstIP) + req.WriteString(userID) + req.WriteByte(0) /* NULL */ + + if isReservedIP(dstIP) /* SOCKS4A */ { + req.WriteString(host) + req.WriteByte(0) /* NULL */ + } + + if _, err = rw.Write(req.Bytes()); err != nil { + return err + } + + var resp [8]byte + if _, err = io.ReadFull(rw, resp[:]); err != nil { + return err + } + + if resp[0] != 0x00 { + return errVersionMismatched + } + + switch resp[1] { + case RequestGranted: + return nil + case RequestRejected: + return ErrRequestRejected + case RequestIdentdFailed: + return ErrRequestIdentdFailed + case RequestIdentdMismatched: + return ErrRequestIdentdMismatched + default: + return ErrRequestUnknownCode + } +} + +// For version 4A, if the client cannot resolve the destination host's +// domain name to find its IP address, it should set the first three bytes +// of DSTIP to NULL and the last byte to a non-zero value. (This corresponds +// to IP address 0.0.0.x, with x nonzero. As decreed by IANA -- The +// Internet Assigned Numbers Authority -- such an address is inadmissible +// as a destination IP address and thus should never occur if the client +// can resolve the domain name.) +func isReservedIP(ip net.IP) bool { + subnet := net.IPNet{ + IP: net.IPv4zero, + Mask: net.IPv4Mask(0xff, 0xff, 0xff, 0x00), + } + + return !ip.IsUnspecified() && subnet.Contains(ip) +} + +func readUntilNull(r io.Reader) ([]byte, error) { + var buf = &bytes.Buffer{} + var data [1]byte + + for { + if _, err := r.Read(data[:]); err != nil { + return nil, err + } + if data[0] == 0 { + return buf.Bytes(), nil + } + buf.WriteByte(data[0]) + } +}