chore: Add InUser for http/socks/mixed

This commit is contained in:
xishang0128 2024-04-25 11:48:53 +08:00
parent 2f8f139f7c
commit 8ff56b5bb8
No known key found for this signature in database
GPG Key ID: 44A1E10B5ADF68CB
4 changed files with 16 additions and 12 deletions

View File

@ -51,8 +51,8 @@ func HandleConn(c net.Conn, tunnel C.Tunnel, cache *lru.LruCache[string, bool],
var resp *http.Response var resp *http.Response
if !trusted { if !trusted {
resp = authenticate(request, cache) resp, user := authenticate(request, cache)
additions = append(additions, inbound.WithInUser(user))
trusted = resp == nil trusted = resp == nil
} }
@ -130,7 +130,7 @@ func HandleConn(c net.Conn, tunnel C.Tunnel, cache *lru.LruCache[string, bool],
_ = conn.Close() _ = conn.Close()
} }
func authenticate(request *http.Request, cache *lru.LruCache[string, bool]) *http.Response { func authenticate(request *http.Request, cache *lru.LruCache[string, bool]) (resp *http.Response, u string) {
authenticator := authStore.Authenticator() authenticator := authStore.Authenticator()
if inbound.SkipAuthRemoteAddress(request.RemoteAddr) { if inbound.SkipAuthRemoteAddress(request.RemoteAddr) {
authenticator = nil authenticator = nil
@ -140,23 +140,24 @@ func authenticate(request *http.Request, cache *lru.LruCache[string, bool]) *htt
if credential == "" { if credential == "" {
resp := responseWith(request, http.StatusProxyAuthRequired) resp := responseWith(request, http.StatusProxyAuthRequired)
resp.Header.Set("Proxy-Authenticate", "Basic") resp.Header.Set("Proxy-Authenticate", "Basic")
return resp return resp, ""
} }
authed, exist := cache.Get(credential) authed, exist := cache.Get(credential)
if !exist { if !exist {
user, pass, err := decodeBasicProxyAuthorization(credential) user, pass, err := decodeBasicProxyAuthorization(credential)
authed = err == nil && authenticator.Verify(user, pass) authed = err == nil && authenticator.Verify(user, pass)
u = user
cache.Set(credential, authed) cache.Set(credential, authed)
} }
if !authed { if !authed {
log.Infoln("Auth failed from %s", request.RemoteAddr) log.Infoln("Auth failed from %s", request.RemoteAddr)
return responseWith(request, http.StatusForbidden) return responseWith(request, http.StatusForbidden), u
} }
} }
return nil return nil, u
} }
func responseWith(request *http.Request, statusCode int) *http.Response { func responseWith(request *http.Request, statusCode int) *http.Response {

View File

@ -98,11 +98,12 @@ func HandleSocks4(conn net.Conn, tunnel C.Tunnel, additions ...inbound.Addition)
if inbound.SkipAuthRemoteAddr(conn.RemoteAddr()) { if inbound.SkipAuthRemoteAddr(conn.RemoteAddr()) {
authenticator = nil authenticator = nil
} }
addr, _, err := socks4.ServerHandshake(conn, authenticator) addr, _, user, err := socks4.ServerHandshake(conn, authenticator)
if err != nil { if err != nil {
conn.Close() conn.Close()
return return
} }
additions = append(additions, inbound.WithInUser(user))
tunnel.HandleTCPConn(inbound.NewSocket(socks5.ParseAddr(addr), conn, C.SOCKS4, additions...)) tunnel.HandleTCPConn(inbound.NewSocket(socks5.ParseAddr(addr), conn, C.SOCKS4, additions...))
} }
@ -111,7 +112,7 @@ func HandleSocks5(conn net.Conn, tunnel C.Tunnel, additions ...inbound.Addition)
if inbound.SkipAuthRemoteAddr(conn.RemoteAddr()) { if inbound.SkipAuthRemoteAddr(conn.RemoteAddr()) {
authenticator = nil authenticator = nil
} }
target, command, err := socks5.ServerHandshake(conn, authenticator) target, command, user, err := socks5.ServerHandshake(conn, authenticator)
if err != nil { if err != nil {
conn.Close() conn.Close()
return return
@ -121,5 +122,6 @@ func HandleSocks5(conn net.Conn, tunnel C.Tunnel, additions ...inbound.Addition)
io.Copy(io.Discard, conn) io.Copy(io.Discard, conn)
return return
} }
additions = append(additions, inbound.WithInUser(user))
tunnel.HandleTCPConn(inbound.NewSocket(target, conn, C.SOCKS5, additions...)) tunnel.HandleTCPConn(inbound.NewSocket(target, conn, C.SOCKS5, additions...))
} }

View File

@ -43,7 +43,7 @@ var (
var subnet = netip.PrefixFrom(netip.IPv4Unspecified(), 24) var subnet = netip.PrefixFrom(netip.IPv4Unspecified(), 24)
func ServerHandshake(rw io.ReadWriter, authenticator auth.Authenticator) (addr string, command Command, err error) { func ServerHandshake(rw io.ReadWriter, authenticator auth.Authenticator) (addr string, command Command, user string, err error) {
var req [8]byte var req [8]byte
if _, err = io.ReadFull(rw, req[:]); err != nil { if _, err = io.ReadFull(rw, req[:]); err != nil {
return return
@ -73,6 +73,7 @@ func ServerHandshake(rw io.ReadWriter, authenticator auth.Authenticator) (addr s
if userID, err = readUntilNull(rw); err != nil { if userID, err = readUntilNull(rw); err != nil {
return return
} }
user = string(userID)
if isReservedIP(dstIP) { if isReservedIP(dstIP) {
var target []byte var target []byte
@ -90,7 +91,7 @@ func ServerHandshake(rw io.ReadWriter, authenticator auth.Authenticator) (addr s
} }
// SOCKS4 only support USERID auth. // SOCKS4 only support USERID auth.
if authenticator == nil || authenticator.Verify(string(userID), "") { if authenticator == nil || authenticator.Verify(user, "") {
code = RequestGranted code = RequestGranted
} else { } else {
code = RequestIdentdMismatched code = RequestIdentdMismatched

View File

@ -106,7 +106,7 @@ type User struct {
} }
// ServerHandshake fast-tracks SOCKS initialization to get target address to connect on server side. // ServerHandshake fast-tracks SOCKS initialization to get target address to connect on server side.
func ServerHandshake(rw net.Conn, authenticator auth.Authenticator) (addr Addr, command Command, err error) { func ServerHandshake(rw net.Conn, authenticator auth.Authenticator) (addr Addr, command Command, user string, err error) {
// Read RFC 1928 for request and reply structure and sizes. // Read RFC 1928 for request and reply structure and sizes.
buf := make([]byte, MaxAddrLen) buf := make([]byte, MaxAddrLen)
// read VER, NMETHODS, METHODS // read VER, NMETHODS, METHODS
@ -141,7 +141,7 @@ func ServerHandshake(rw net.Conn, authenticator auth.Authenticator) (addr Addr,
if _, err = io.ReadFull(rw, authBuf[:userLen]); err != nil { if _, err = io.ReadFull(rw, authBuf[:userLen]); err != nil {
return return
} }
user := string(authBuf[:userLen]) user = string(authBuf[:userLen])
// Get password // Get password
if _, err = rw.Read(header[:1]); err != nil { if _, err = rw.Read(header[:1]); err != nil {