reality/server.go

409 lines
10 KiB
Go
Raw Normal View History

2024-10-10 11:08:23 +08:00
package reality
import (
"bufio"
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/ecdh"
"crypto/ed25519"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"io"
"net"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/cryptobyte"
)
// ServerConfig holds the server-side settings. The exported fields are the
// JSON-serialisable form; the unexported fields hold the parsed values and are
// filled in by NewServerConfig or Validate.
type ServerConfig struct {
	// SNIAddr is the host:port of the real TLS site whose handshake is relayed
	// to clients and to which unverified clients stay connected.
	SNIAddr string `json:"sni_addr"`
	// ServerAddr is copied into the client config; Validate requires it.
	ServerAddr string `json:"server_addr"`
	// SkipVerify is copied into the client config.
	SkipVerify bool `json:"skip_verify"`

	// PrivateKeyECDH is the X25519 private key, base64 (standard encoding).
	PrivateKeyECDH string `json:"private_key_ecdh"`
	// PrivateKeySign is the 64-byte Ed25519 private key, base64 (standard
	// encoding), used to sign the reply to a verified client.
	PrivateKeySign string `json:"private_key_sign"`
	// ExpireSecond is passed to the nonce derivation; Validate replaces 0 with
	// DefaultExpireSecond. Presumably a validity window in seconds — confirm.
	ExpireSecond uint32 `json:"expire_second"`
	// Debug is passed to GetLogger and gates handshake-failure warnings.
	Debug bool `json:"debug"`
	// ClientFingerPrint is copied into the client config; Validate defaults it
	// to "chrome".
	ClientFingerPrint string `json:"finger_print,omitempty"`

	privateKeyECDH *ecdh.PrivateKey   // parsed PrivateKeyECDH
	privateKeySign ed25519.PrivateKey // parsed PrivateKeySign
	sniHost        string             // host part of SNIAddr
	sniPort        string             // port part of SNIAddr
}
// NewServerConfig builds a ServerConfig with freshly generated keys: an X25519
// key for the ECDH exchange and an Ed25519 key for signing. sniAddr must be a
// host:port pair (it is split with net.SplitHostPort); serverAddr is stored as
// given. ExpireSecond is set to DefaultExpireSecond. Both the base64 and the
// parsed form of each key are filled in.
func NewServerConfig(sniAddr string, serverAddr string) (*ServerConfig, error) {
	ecdhKey, err := ecdh.X25519().GenerateKey(rand.Reader)
	if err != nil {
		return nil, err
	}
	_, signKey, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		return nil, err
	}
	host, port, err := net.SplitHostPort(sniAddr)
	if err != nil {
		return nil, err
	}

	enc := base64.StdEncoding
	cfg := &ServerConfig{
		SNIAddr:      sniAddr,
		ServerAddr:   serverAddr,
		ExpireSecond: DefaultExpireSecond,
	}
	// Keep the serialisable (base64) and parsed forms of each key in sync.
	cfg.PrivateKeyECDH, cfg.privateKeyECDH = enc.EncodeToString(ecdhKey.Bytes()), ecdhKey
	cfg.PrivateKeySign, cfg.privateKeySign = enc.EncodeToString(signKey), signKey
	cfg.sniHost, cfg.sniPort = host, port
	return cfg, nil
}
// Validate checks the exported (JSON) fields, fills in defaults and derives
// the unexported parsed fields. A config decoded from JSON has its parsed
// keys and SNI host/port unset until Validate succeeds.
//
// Defaults: ExpireSecond 0 becomes DefaultExpireSecond, and an empty
// ClientFingerPrint becomes "chrome". Errors from address parsing and key
// decoding are wrapped with %w, so errors.Is/As still see the cause.
func (c *ServerConfig) Validate() error {
	if c.SNIAddr == "" {
		return errors.New("SNI is required")
	}
	var err error
	c.sniHost, c.sniPort, err = net.SplitHostPort(c.SNIAddr)
	if err != nil {
		return fmt.Errorf("parse sni_addr: %w", err)
	}
	if c.ServerAddr == "" {
		return errors.New("server address is required")
	}

	data, err := base64.StdEncoding.DecodeString(c.PrivateKeyECDH)
	if err != nil {
		return fmt.Errorf("decode private_key_ecdh: %w", err)
	}
	c.privateKeyECDH, err = ecdh.X25519().NewPrivateKey(data)
	if err != nil {
		return fmt.Errorf("parse private_key_ecdh: %w", err)
	}

	data, err = base64.StdEncoding.DecodeString(c.PrivateKeySign)
	if err != nil {
		return fmt.Errorf("decode private_key_sign: %w", err)
	}
	if len(data) != ed25519.PrivateKeySize {
		return errors.New("private key sign length error")
	}
	// An Ed25519 private key is seed || public key. ed25519.Sign hashes the
	// stored public half, and ToClientConfig hands that same half to clients.
	// If it does not match the seed, every signature fails verification, so
	// reject such a key here instead of failing every handshake later.
	if !bytes.Equal(ed25519.NewKeyFromSeed(data[:ed25519.SeedSize]), data) {
		return errors.New("private key sign is inconsistent: public half does not match seed")
	}
	c.privateKeySign = ed25519.PrivateKey(data)

	if c.ExpireSecond == 0 {
		c.ExpireSecond = DefaultExpireSecond
	}
	if c.ClientFingerPrint == "" {
		c.ClientFingerPrint = "chrome"
	}
	return nil
}
// SNIHost returns the host part of SNIAddr as split by NewServerConfig or Validate.
func (c *ServerConfig) SNIHost() string { return c.sniHost }

// SNIPort returns the port part of SNIAddr as split by NewServerConfig or Validate.
func (c *ServerConfig) SNIPort() string { return c.sniPort }
2024-10-17 18:55:28 +08:00
func (s *ServerConfig) ToClientConfig(overlayData byte) *ClientConfig {
2024-10-10 11:08:23 +08:00
return &ClientConfig{
SNI: s.sniHost,
ServerAddr: s.ServerAddr,
SkipVerify: s.SkipVerify,
2024-10-10 11:08:23 +08:00
PublicKeyECDH: base64.StdEncoding.EncodeToString(s.privateKeyECDH.PublicKey().Bytes()),
PublicKeyVerify: base64.StdEncoding.EncodeToString(s.privateKeySign.Public().(ed25519.PublicKey)),
ExpireSecond: s.ExpireSecond,
Debug: s.Debug,
FingerPrint: s.ClientFingerPrint,
2024-10-17 18:55:28 +08:00
OverlayData: overlayData,
2024-10-10 11:08:23 +08:00
}
}
// Listener wraps a TCP listener and runs handshake on every accepted
// connection in its own goroutine. Accept only returns connections whose
// handshake succeeded.
type Listener struct {
	net.Listener
	config   *ServerConfig
	chanConn chan net.Conn // hands handshaked connections to Accept
	done     chan struct{} // closed once the inner accept loop has stopped
	err      error         // why the inner accept loop stopped; written before done is closed
	logger   logrus.FieldLogger
}

// Listen listens on the TCP address laddr and returns a Listener that performs
// the server handshake on each incoming connection. Closing the returned
// listener makes the inner Accept fail, which stops the accept loop; from then
// on every Accept call returns that error.
func Listen(laddr string, config *ServerConfig) (net.Listener, error) {
	inner, err := net.Listen("tcp", laddr)
	if err != nil {
		return nil, err
	}
	l := &Listener{
		Listener: inner,
		config:   config,
		chanConn: make(chan net.Conn),
		done:     make(chan struct{}),
		logger:   GetLogger(config.Debug),
	}
	go l.acceptLoop()
	return l, nil
}

// acceptLoop accepts raw connections until the inner listener fails. It then
// records the error and closes done, which wakes every pending and future
// Accept call as well as handshake goroutines waiting to deliver a connection.
func (l *Listener) acceptLoop() {
	for {
		conn, err := l.Listener.Accept()
		if err != nil {
			l.err = err
			close(l.done)
			return
		}
		go l.handleConn(conn)
	}
}

// handleConn runs the handshake on conn and hands the result to Accept. If the
// listener stops before anyone accepts it, the connection is closed instead
// of leaving this goroutine blocked forever.
func (l *Listener) handleConn(conn net.Conn) {
	c, err := l.handshake(conn)
	if err != nil {
		if l.config.Debug {
			l.logger.Warnln("handshake", conn.RemoteAddr(), err)
		}
		return
	}
	select {
	case l.chanConn <- c:
	case <-l.done:
		c.Close()
	}
}

// Accept blocks until a connection has completed the handshake or the inner
// listener has failed. After a failure (e.g. net.ErrClosed following Close)
// every call returns that error.
func (l *Listener) Accept() (net.Conn, error) {
	select {
	case c := <-l.chanConn:
		return c, nil
	case <-l.done:
		// l.err was written before done was closed, so this read is race-free.
		return nil, l.err
	}
}
// handshake tries to run the private handshake on clientConn. While doing so it
// dials config.SNIAddr and relays bytes in both directions, so the client's TLS
// handshake is actually answered by the real SNI target.
//
// On success the target connection is closed and clientConn is returned
// wrapped by newWarpConn. If the client hello does not verify, or the record
// flow does not match the expected orders, clientConn and the target
// connection are left spliced together by dup and an error is returned. In
// every other error path clientConn is closed before returning.
func (l *Listener) handshake(clientConn net.Conn) (net.Conn, error) {
	logger := l.logger
	targetConn, err := net.Dial("tcp", l.config.SNIAddr)
	if err != nil {
		// Nothing else owns clientConn yet; close it so it does not leak.
		clientConn.Close()
		return nil, errors.Join(ErrProxyDie, err)
	}
	// bufio.Reader: a read pulls in as many records as are available instead
	// of one record at a time.
	// io.TeeReader: everything read from one side is written to the other side
	// at the same time.
	clientReader := bufio.NewReader(io.TeeReader(clientConn, targetConn))
	targetReader := bufio.NewReader(io.TeeReader(targetConn, clientConn))
	var aead cipher.AEAD
	var plaintext []byte
	// readClientHello reads the client hello record. It derives a key from the
	// client random (used as an X25519 public key) and the server's private
	// key, then tries to decrypt the session ID with it. On success it sets
	// aead and plaintext.
	readClientHello := func() error {
		recordClientHello, err := readTlsRecord(clientReader)
		if err != nil {
			return err
		}
		var random, sessionId []byte
		s := cryptobyte.String(recordClientHello.recordData)
		if !s.Skip(6) || // skip type(1) length(3) version(2)
			!s.ReadBytes(&random, 32) ||
			!s.ReadUint8LengthPrefixed((*cryptobyte.String)(&sessionId)) ||
			len(sessionId) != 32 {
			// %s, not %x: the value is already hex-encoded.
			return fmt.Errorf("invalid client hello: %s", hex.EncodeToString(recordClientHello.recordData))
		}
		logger.Debugf("random(public for ecdh): %x", random)
		logger.Debugf("sessionId(ciphertext): %x", sessionId)
		pub, err := ecdh.X25519().NewPublicKey(random)
		if err != nil {
			return err
		}
		sessionKey, err := l.config.privateKeyECDH.ECDH(pub)
		if err != nil {
			return err
		}
		logger.Debugf("sessionKey: %x", sessionKey)
		block, err := aes.NewCipher(sessionKey)
		if err != nil {
			return err
		}
		aead, err = cipher.NewGCMWithNonceSize(block, 8)
		if err != nil {
			return err
		}
		nonce, err := generateNonce(aead.NonceSize(), sessionKey, l.config.ExpireSecond)
		if err != nil {
			return err
		}
		logger.Debugf("nonce: %x", nonce)
		plaintext, err = aead.Open(nil, nonce, sessionId, nil)
		if err != nil {
			return err
		}
		logger.Debugf("plaintext: %x", plaintext)
		if !bytes.HasPrefix(plaintext, Prefix) {
			// Report the whole plaintext: slicing it to len(Prefix) would
			// panic when the plaintext is shorter than the prefix.
			return fmt.Errorf("invalid prefix: %x", plaintext)
		}
		logger.Debug("handshake ok")
		return nil
	}
	if err = readClientHello(); err != nil {
		// Not one of our clients: keep relaying so it only ever talks to the target.
		go dup(clientConn, targetConn)
		return nil, errors.Join(ErrVerifyFailed, err)
	}
	if _, err = serverOrder1.wait(targetReader, logger); err != nil {
		go dup(clientConn, targetConn)
		return nil, err
	}
	if _, err = clientOrder.wait(clientReader, logger); err != nil {
		go dup(clientConn, targetConn)
		return nil, err
	}
	records, err := serverOrder2.wait(targetReader, logger)
	if err != nil {
		go dup(clientConn, targetConn)
		return nil, err
	}
	// The TLS handshake between the client and the target is complete, so the
	// target connection can be closed.
	targetConn.Close()
	// Take the sequence-number prefix from the last record of the target's
	// final flight, if it has one; the signature record below mimics it.
	seq := [8]byte{}
	copy(seq[:], seqNumerOne[:])
	if len(records) > 0 {
		record := records[len(records)-1]
		recordData := record.recordData
		if len(recordData) > len(seq) {
			copy(seq[:], recordData[:len(seq)])
		}
	}
	logger.Debugf("seqNumer: %x", seq)
	incSeq(seq[:])
	// Read the overlay record sent by the client.
	// NOTE(review): this reads clientConn directly, not clientReader. Presumably
	// that is because clientReader's TeeReader would now fail writing to the
	// closed targetConn. Any bytes clientReader already buffered past the
	// client's Finished record are therefore not seen here — confirm clients
	// never send the overlay record before the target's Finished arrives.
	record, err := readTlsRecord(clientConn)
	if err != nil {
		clientConn.Close()
		return nil, err
	}
	if len(record.recordData) == 0 {
		// Indexing an empty record below would panic and take the process down.
		clientConn.Close()
		return nil, errors.New("empty overlay record")
	}
	overlayData := record.recordData[len(record.recordData)-1]
	logger.Debugf("overlayData: %x", overlayData)
	// Send the server's signature over the decrypted session ID.
	sign := ed25519.Sign(l.config.privateKeySign, plaintext)
	logger.Debugf("sign: %x", sign)
	record = newTLSRecord(
		recordTypeApplicationData, versionTLS12,
		generateRandomData(append(seq[:], sign...)), // record data starts with seq to mimic the target
	)
	if _, err = record.writeTo(clientConn); err != nil {
		clientConn.Close()
		return nil, err
	}
	return newWarpConn(clientConn, aead, overlayData, seq), nil
}
// dup relays bytes between clientConn and proxyConn in both directions and
// returns once the proxy-to-client direction ends, closing both connections.
//
// When the client-to-proxy direction ends first (client EOF or error), the
// write side of proxyConn is shut down if it supports CloseWrite (e.g.
// *net.TCPConn). The target then sees EOF and can finish and close its side;
// without this it could keep the relay and both connections open until its
// own idle timeout.
func dup(clientConn net.Conn, proxyConn net.Conn) {
	defer clientConn.Close()
	defer proxyConn.Close()
	go func() {
		// Copy errors just end this direction; there is nobody to report them to.
		_, _ = io.Copy(proxyConn, clientConn)
		if cw, ok := proxyConn.(interface{ CloseWrite() error }); ok {
			// Best effort: proxyConn may already be closed by the other direction.
			_ = cw.CloseWrite()
		}
	}()
	_, _ = io.Copy(clientConn, proxyConn)
}
// recordOrders is an expected sequence of TLS records, consumed by wait. Each
// entry carries the expected TLS record type and, when handshakeType is
// non-zero, the expected first payload byte (the handshake message type).
// Entries marked optional may be absent from the stream.
type recordOrders []struct {
	recordType    byte
	handshakeType byte
	optional      bool
}

var (
	// serverOrder1 is the target's first flight of a TLS 1.2 full handshake:
	// ServerHello, Certificate, ServerKeyExchange, ServerHelloDone.
	serverOrder1 = recordOrders{
		{recordType: recordTypeHandshake, handshakeType: typeServerHello},
		{recordType: recordTypeHandshake, handshakeType: typeCertificate},
		{recordType: recordTypeHandshake, handshakeType: typeServerKeyExchange},
		{recordType: recordTypeHandshake, handshakeType: typeServerHelloDone},
	}

	// clientOrder is the client's flight: an optional Certificate,
	// ClientKeyExchange, an optional CertificateVerify, ChangeCipherSpec, and
	// the encrypted Finished handshake message (whose type cannot be read, so
	// handshakeType is left 0).
	clientOrder = recordOrders{
		{recordType: recordTypeHandshake, handshakeType: typeCertificate, optional: true},
		{recordType: recordTypeHandshake, handshakeType: typeClientKeyExchange},
		{recordType: recordTypeHandshake, handshakeType: typeCertificateVerify, optional: true},
		{recordType: recordTypeChangeCipherSpec},
		{recordType: recordTypeHandshake},
	}

	// serverOrder2 is the target's closing flight: an optional
	// NewSessionTicket, ChangeCipherSpec, and the encrypted Finished message.
	serverOrder2 = recordOrders{
		{recordType: recordTypeHandshake, handshakeType: typeNewSessionTicket, optional: true},
		{recordType: recordTypeChangeCipherSpec},
		{recordType: recordTypeHandshake},
	}
)
// wait reads TLS records from reader and matches them, in order, against
// orders. It returns every record read once the end of orders is reached.
//
// A record matches an entry when its record type equals the entry's
// recordType and, if the entry's handshakeType is non-zero, its first payload
// byte equals handshakeType. When a record does not match an optional entry,
// that entry is skipped and the record is tried against the next one. When it
// does not match a required entry, wait returns an error. Read errors are
// returned as-is.
func (orders recordOrders) wait(reader io.Reader, logger logrus.FieldLogger) ([]*tlsRecord, error) {
	records := make([]*tlsRecord, 0, len(orders))
	orderPos := 0
	for {
		record, err := readTlsRecord(reader)
		if err != nil {
			return nil, err
		}
		records = append(records, record)
		for pos := orderPos; pos < len(orders); pos++ {
			o := orders[pos]
			// The record's content type must match too. Otherwise an entry with
			// handshakeType 0 would accept any record. Likewise a fatal alert
			// (payload starting with level 2) would pass as a ServerHello
			// (handshake type 2).
			matched := byte(record.recordType) == o.recordType &&
				(o.handshakeType == 0 ||
					(len(record.recordData) != 0 && record.recordData[0] == o.handshakeType))
			if matched {
				orderPos = pos + 1
				break
			}
			if !o.optional {
				return nil, fmt.Errorf(
					"invalid record, want %+v, got %d %x,",
					o, record.recordType, record.recordData,
				)
			}
			// The current entry is optional: skip it and try the next one.
			logger.Debugf("try optional record: %+v", record)
			orderPos = pos + 1
		}
		if orderPos == len(orders) {
			return records, nil
		}
	}
}