fix: firstWriteCallBackConn can pass N.ExtendedConn too

This commit is contained in:
wwqgtxx 2023-04-01 20:56:49 +08:00
parent 4af4935e7e
commit cd95cf4849
4 changed files with 78 additions and 35 deletions

View File

@ -37,16 +37,13 @@ func (f *Fallback) DialContext(ctx context.Context, metadata *C.Metadata, opts .
} }
if N.NeedHandshake(c) { if N.NeedHandshake(c) {
c = &callback.FirstWriteCallBackConn{ c = callback.NewFirstWriteCallBackConn(c, func(err error) {
Conn: c, if err == nil {
Callback: func(err error) { f.onDialSuccess()
if err == nil { } else {
f.onDialSuccess() f.onDialFailed(proxy.Type(), err)
} else { }
f.onDialFailed(proxy.Type(), err) })
}
},
}
} }
return c, err return c, err

View File

@ -95,16 +95,13 @@ func (lb *LoadBalance) DialContext(ctx context.Context, metadata *C.Metadata, op
} }
if N.NeedHandshake(c) { if N.NeedHandshake(c) {
c = &callback.FirstWriteCallBackConn{ c = callback.NewFirstWriteCallBackConn(c, func(err error) {
Conn: c, if err == nil {
Callback: func(err error) { lb.onDialSuccess()
if err == nil { } else {
lb.onDialSuccess() lb.onDialFailed(proxy.Type(), err)
} else { }
lb.onDialFailed(proxy.Type(), err) })
}
},
}
} }
return return

View File

@ -45,16 +45,13 @@ func (u *URLTest) DialContext(ctx context.Context, metadata *C.Metadata, opts ..
} }
if N.NeedHandshake(c) { if N.NeedHandshake(c) {
c = &callback.FirstWriteCallBackConn{ c = callback.NewFirstWriteCallBackConn(c, func(err error) {
Conn: c, if err == nil {
Callback: func(err error) { u.onDialSuccess()
if err == nil { } else {
u.onDialSuccess() u.onDialFailed(proxy.Type(), err)
} else { }
u.onDialFailed(proxy.Type(), err) })
}
},
}
} }
return c, err return c, err

View File

@ -1,25 +1,77 @@
package callback package callback
import ( import (
"github.com/Dreamacro/clash/common/buf"
N "github.com/Dreamacro/clash/common/net"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
) )
type FirstWriteCallBackConn struct { type firstWriteCallBackConn struct {
C.Conn C.Conn
Callback func(error) callback func(error)
written bool written bool
} }
func (c *FirstWriteCallBackConn) Write(b []byte) (n int, err error) { func (c *firstWriteCallBackConn) Write(b []byte) (n int, err error) {
defer func() { defer func() {
if !c.written { if !c.written {
c.written = true c.written = true
c.Callback(err) c.callback(err)
} }
}() }()
return c.Conn.Write(b) return c.Conn.Write(b)
} }
func (c *FirstWriteCallBackConn) Upstream() any { func (c *firstWriteCallBackConn) Upstream() any {
return c.Conn return c.Conn
} }
type extendedConn interface {
C.Conn
N.ExtendedConn
}
type firstWriteCallBackExtendedConn struct {
extendedConn
callback func(error)
written bool
}
func (c *firstWriteCallBackExtendedConn) Write(b []byte) (n int, err error) {
defer func() {
if !c.written {
c.written = true
c.callback(err)
}
}()
return c.extendedConn.Write(b)
}
func (c *firstWriteCallBackExtendedConn) WriteBuffer(buffer *buf.Buffer) (err error) {
defer func() {
if !c.written {
c.written = true
c.callback(err)
}
}()
return c.extendedConn.WriteBuffer(buffer)
}
func (c *firstWriteCallBackExtendedConn) Upstream() any {
return c.extendedConn
}
func NewFirstWriteCallBackConn(c C.Conn, callback func(error)) C.Conn {
if c, ok := c.(extendedConn); ok {
return &firstWriteCallBackExtendedConn{
extendedConn: c,
callback: callback,
written: false,
}
}
return &firstWriteCallBackConn{
Conn: c,
callback: callback,
written: false,
}
}