chore: using http/httptrace to get local/remoteAddr for grpc client

This commit is contained in:
wwqgtxx 2025-04-03 19:47:49 +08:00
parent 7b37fcfc8d
commit 23ffe451f4
3 changed files with 35 additions and 36 deletions

View File

@ -13,6 +13,7 @@ import (
"io" "io"
"net" "net"
"net/http" "net/http"
"net/http/httptrace"
"net/url" "net/url"
"sync" "sync"
"time" "time"
@ -38,7 +39,7 @@ var defaultHeader = http.Header{
type DialFn = func(network, addr string) (net.Conn, error) type DialFn = func(network, addr string) (net.Conn, error)
type Conn struct { type Conn struct {
initFn func() (io.ReadCloser, error) initFn func() (io.ReadCloser, netAddr, error)
writer io.Writer writer io.Writer
flusher http.Flusher flusher http.Flusher
netAddr netAddr
@ -60,7 +61,7 @@ type Config struct {
} }
func (g *Conn) initReader() { func (g *Conn) initReader() {
reader, err := g.initFn() reader, addr, err := g.initFn()
if err != nil { if err != nil {
g.err = err g.err = err
if closer, ok := g.writer.(io.Closer); ok { if closer, ok := g.writer.(io.Closer); ok {
@ -68,6 +69,7 @@ func (g *Conn) initReader() {
} }
return return
} }
g.netAddr = addr
if !g.close.Load() { if !g.close.Load() {
g.reader = reader g.reader = reader
@ -209,15 +211,11 @@ func (g *Conn) SetDeadline(t time.Time) error {
} }
func NewHTTP2Client(dialFn DialFn, tlsConfig *tls.Config, Fingerprint string, realityConfig *tlsC.RealityConfig) *TransportWrap { func NewHTTP2Client(dialFn DialFn, tlsConfig *tls.Config, Fingerprint string, realityConfig *tlsC.RealityConfig) *TransportWrap {
wrap := TransportWrap{}
dialFunc := func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { dialFunc := func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
pconn, err := dialFn(network, addr) pconn, err := dialFn(network, addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
wrap.remoteAddr = pconn.RemoteAddr()
wrap.localAddr = pconn.LocalAddr()
if tlsConfig == nil { if tlsConfig == nil {
return pconn, nil return pconn, nil
@ -269,15 +267,17 @@ func NewHTTP2Client(dialFn DialFn, tlsConfig *tls.Config, Fingerprint string, re
return conn, nil return conn, nil
} }
wrap.Transport = &http2.Transport{ transport := &http2.Transport{
DialTLSContext: dialFunc, DialTLSContext: dialFunc,
TLSClientConfig: tlsConfig, TLSClientConfig: tlsConfig,
AllowHTTP: false, AllowHTTP: false,
DisableCompression: true, DisableCompression: true,
PingTimeout: 0, PingTimeout: 0,
} }
wrap := &TransportWrap{
return &wrap Transport: transport,
}
return wrap
} }
func StreamGunWithTransport(transport *TransportWrap, cfg *Config) (net.Conn, error) { func StreamGunWithTransport(transport *TransportWrap, cfg *Config) (net.Conn, error) {
@ -304,15 +304,22 @@ func StreamGunWithTransport(transport *TransportWrap, cfg *Config) (net.Conn, er
} }
conn := &Conn{ conn := &Conn{
initFn: func() (io.ReadCloser, error) { initFn: func() (io.ReadCloser, netAddr, error) {
nAddr := netAddr{}
trace := &httptrace.ClientTrace{
GotConn: func(connInfo httptrace.GotConnInfo) {
nAddr.localAddr = connInfo.Conn.LocalAddr()
nAddr.remoteAddr = connInfo.Conn.RemoteAddr()
},
}
request = request.WithContext(httptrace.WithClientTrace(request.Context(), trace))
response, err := transport.RoundTrip(request) response, err := transport.RoundTrip(request)
if err != nil { if err != nil {
return nil, err return nil, nAddr, err
} }
return response.Body, nil return response.Body, nAddr, nil
}, },
writer: writer, writer: writer,
netAddr: transport.netAddr,
} }
go conn.Init() go conn.Init()

View File

@ -43,20 +43,21 @@ func NewServerHandler(options ServerOption) http.Handler {
writer.WriteHeader(http.StatusOK) writer.WriteHeader(http.StatusOK)
conn := &Conn{ conn := &Conn{
initFn: func() (io.ReadCloser, error) { initFn: func() (io.ReadCloser, netAddr, error) {
return request.Body, nil nAddr := netAddr{}
},
writer: writer,
flusher: writer.(http.Flusher),
}
if request.RemoteAddr != "" { if request.RemoteAddr != "" {
metadata := C.Metadata{} metadata := C.Metadata{}
if err := metadata.SetRemoteAddress(request.RemoteAddr); err == nil { if err := metadata.SetRemoteAddress(request.RemoteAddr); err == nil {
conn.remoteAddr = net.TCPAddrFromAddrPort(metadata.AddrPort()) nAddr.remoteAddr = net.TCPAddrFromAddrPort(metadata.AddrPort())
} }
} }
if addr, ok := request.Context().Value(http.LocalAddrContextKey).(net.Addr); ok { if addr, ok := request.Context().Value(http.LocalAddrContextKey).(net.Addr); ok {
conn.localAddr = addr nAddr.localAddr = addr
}
return request.Body, nAddr, nil
},
writer: writer,
flusher: writer.(http.Flusher),
} }
wrapper := &h2ConnWrapper{ wrapper := &h2ConnWrapper{

View File

@ -7,15 +7,6 @@ import (
type TransportWrap struct { type TransportWrap struct {
*http2.Transport *http2.Transport
netAddr
}
func (tw *TransportWrap) RemoteAddr() net.Addr {
return tw.remoteAddr
}
func (tw *TransportWrap) LocalAddr() net.Addr {
return tw.localAddr
} }
type netAddr struct { type netAddr struct {
@ -23,10 +14,10 @@ type netAddr struct {
localAddr net.Addr localAddr net.Addr
} }
func (addr *netAddr) RemoteAddr() net.Addr { func (addr netAddr) RemoteAddr() net.Addr {
return addr.remoteAddr return addr.remoteAddr
} }
func (addr *netAddr) LocalAddr() net.Addr { func (addr netAddr) LocalAddr() net.Addr {
return addr.localAddr return addr.localAddr
} }