From bd431fbf49e01819d625c4b6758ea59539c99661 Mon Sep 17 00:00:00 2001 From: H1JK Date: Sat, 6 May 2023 15:49:10 +0800 Subject: [PATCH] fix: Update unsafe pointer add usage --- component/geodata/strmatcher/mph_matcher.go | 26 +++++++++------------ component/tls/reality.go | 2 +- transport/vless/conn.go | 12 +++++----- 3 files changed, 18 insertions(+), 22 deletions(-) diff --git a/component/geodata/strmatcher/mph_matcher.go b/component/geodata/strmatcher/mph_matcher.go index 3c10cb492..8d8b05089 100644 --- a/component/geodata/strmatcher/mph_matcher.go +++ b/component/geodata/strmatcher/mph_matcher.go @@ -234,26 +234,26 @@ tail: case s == 0: case s < 4: h ^= uint64(*(*byte)(p)) - h ^= uint64(*(*byte)(add(p, s>>1))) << 8 - h ^= uint64(*(*byte)(add(p, s-1))) << 16 + h ^= uint64(*(*byte)(unsafe.Add(p, s>>1))) << 8 + h ^= uint64(*(*byte)(unsafe.Add(p, s-1))) << 16 h = rotl31(h*m1) * m2 case s <= 8: h ^= uint64(readUnaligned32(p)) - h ^= uint64(readUnaligned32(add(p, s-4))) << 32 + h ^= uint64(readUnaligned32(unsafe.Add(p, s-4))) << 32 h = rotl31(h*m1) * m2 case s <= 16: h ^= readUnaligned64(p) h = rotl31(h*m1) * m2 - h ^= readUnaligned64(add(p, s-8)) + h ^= readUnaligned64(unsafe.Add(p, s-8)) h = rotl31(h*m1) * m2 case s <= 32: h ^= readUnaligned64(p) h = rotl31(h*m1) * m2 - h ^= readUnaligned64(add(p, 8)) + h ^= readUnaligned64(unsafe.Add(p, 8)) h = rotl31(h*m1) * m2 - h ^= readUnaligned64(add(p, s-16)) + h ^= readUnaligned64(unsafe.Add(p, s-16)) h = rotl31(h*m1) * m2 - h ^= readUnaligned64(add(p, s-8)) + h ^= readUnaligned64(unsafe.Add(p, s-8)) h = rotl31(h*m1) * m2 default: v1 := h @@ -263,16 +263,16 @@ tail: for s >= 32 { v1 ^= readUnaligned64(p) v1 = rotl31(v1*m1) * m2 - p = add(p, 8) + p = unsafe.Add(p, 8) v2 ^= readUnaligned64(p) v2 = rotl31(v2*m2) * m3 - p = add(p, 8) + p = unsafe.Add(p, 8) v3 ^= readUnaligned64(p) v3 = rotl31(v3*m3) * m4 - p = add(p, 8) + p = unsafe.Add(p, 8) v4 ^= readUnaligned64(p) v4 = rotl31(v4*m4) * m1 - p = add(p, 8) + p = unsafe.Add(p, 8) s -= 32 } h = v1 ^ v2 ^ v3 ^ v4 @@ -285,10 +285,6 @@ tail: return uintptr(h) } -func add(p unsafe.Pointer, x uintptr) unsafe.Pointer { - return unsafe.Pointer(uintptr(p) + x) -} - func readUnaligned32(p unsafe.Pointer) uint32 { q := (*[4]byte)(p) return uint32(q[0]) | uint32(q[1])<<8 | uint32(q[2])<<16 | uint32(q[3])<<24 diff --git a/component/tls/reality.go b/component/tls/reality.go index dd4f3af8a..b8a7fa3a5 100644 --- a/component/tls/reality.go +++ b/component/tls/reality.go @@ -141,7 +141,7 @@ var pOffset = utils.MustOK(reflect.TypeOf((*utls.UConn)(nil)).Elem().FieldByName func (c *realityVerifier) VerifyPeerCertificate(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { //p, _ := reflect.TypeOf(c.Conn).Elem().FieldByName("peerCertificates") - certs := *(*[]*x509.Certificate)(unsafe.Pointer(uintptr(unsafe.Pointer(c.Conn)) + pOffset)) + certs := *(*[]*x509.Certificate)(unsafe.Add(unsafe.Pointer(c.Conn), pOffset)) if pub, ok := certs[0].PublicKey.(ed25519.PublicKey); ok { h := hmac.New(sha512.New, c.authKey) h.Write(pub) diff --git a/transport/vless/conn.go b/transport/vless/conn.go index 6c3714e01..9289afcfe 100644 --- a/transport/vless/conn.go +++ b/transport/vless/conn.go @@ -474,34 +474,34 @@ func newConn(conn net.Conn, client *Client, dst *DstAddr) (*Conn, error) { c.writeFilterApplicationData = true c.addons = client.Addons var t reflect.Type - var p uintptr + var p unsafe.Pointer switch underlying := conn.(type) { case *gotls.Conn: //log.Debugln("type tls") c.Conn = underlying.NetConn() c.tlsConn = underlying t = reflect.TypeOf(underlying).Elem() - p = uintptr(unsafe.Pointer(underlying)) + p = unsafe.Pointer(underlying) case *utls.UConn: //log.Debugln("type *utls.UConn") c.Conn = underlying.NetConn() c.tlsConn = underlying t = reflect.TypeOf(underlying.Conn).Elem() - p = uintptr(unsafe.Pointer(underlying.Conn)) + p = unsafe.Pointer(underlying.Conn) case *tlsC.UConn: //log.Debugln("type *tlsC.UConn") c.Conn = underlying.NetConn() c.tlsConn = underlying.UConn t = reflect.TypeOf(underlying.Conn).Elem() //log.Debugln("t:%v", t) - p = uintptr(unsafe.Pointer(underlying.Conn)) + p = unsafe.Pointer(underlying.Conn) default: return nil, fmt.Errorf(`failed to use %s, maybe "security" is not "tls" or "utls"`, client.Addons.Flow) } i, _ := t.FieldByName("input") r, _ := t.FieldByName("rawInput") - c.input = (*bytes.Reader)(unsafe.Pointer(p + i.Offset)) - c.rawInput = (*bytes.Buffer)(unsafe.Pointer(p + r.Offset)) + c.input = (*bytes.Reader)(unsafe.Add(p, i.Offset)) + c.rawInput = (*bytes.Buffer)(unsafe.Add(p, r.Offset)) //if _, ok := c.Conn.(*net.TCPConn); !ok { // log.Debugln("XTLS underlying conn is not *net.TCPConn, got %T", c.Conn) //}