2024-10-10 11:08:23 +08:00
|
|
|
package reality
|
|
|
|
|
|
|
|
import (
|
|
|
|
"bytes"
|
|
|
|
"crypto/cipher"
|
|
|
|
"crypto/sha256"
|
|
|
|
"encoding/binary"
|
|
|
|
"errors"
|
|
|
|
"io"
|
|
|
|
"math/rand"
|
|
|
|
"net"
|
|
|
|
"sync"
|
|
|
|
"time"
|
|
|
|
|
|
|
|
"github.com/mattn/go-colorable"
|
|
|
|
utls "github.com/refraction-networking/utls"
|
|
|
|
"github.com/sirupsen/logrus"
|
|
|
|
"golang.org/x/crypto/hkdf"
|
|
|
|
)
|
|
|
|
|
|
|
|
var (
|
|
|
|
ErrVerifyFailed = errors.New("verify failed")
|
|
|
|
ErrDecryptFailed = errors.New("decrypt failed")
|
|
|
|
ErrProxyDie = errors.New("proxy die")
|
|
|
|
)
|
|
|
|
|
|
|
|
var Prefix = []byte("REALITY")
|
|
|
|
|
|
|
|
const DefaultExpireSecond = 30
|
|
|
|
|
|
|
|
var seqNumerOne = [8]byte{0, 0, 0, 0, 0, 0, 0, 1}
|
|
|
|
|
|
|
|
// generateNonce 根据SessionKey和ExpireSecond生成Nonce
|
|
|
|
func generateNonce(NonceSize int, SessionKey []byte, ExpireSecond uint32) ([]byte, error) {
|
|
|
|
info := make([]byte, 8)
|
2024-10-16 13:56:38 +08:00
|
|
|
binary.BigEndian.PutUint64(info, uint64(time.Now().Unix()/int64(ExpireSecond)))
|
2024-10-10 11:08:23 +08:00
|
|
|
nonce := make([]byte, NonceSize)
|
|
|
|
_, err := hkdf.New(sha256.New, SessionKey[:], Prefix, info).Read(nonce[:])
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
return nonce, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
var versionTLS12 = uint16(utls.VersionTLS12)
|
|
|
|
|
|
|
|
const recordHeaderLen = 5
|
|
|
|
const (
|
|
|
|
recordTypeChangeCipherSpec = 20
|
|
|
|
recordTypeAlert = 21
|
|
|
|
recordTypeHandshake = 22
|
|
|
|
recordTypeApplicationData = 23
|
|
|
|
)
|
|
|
|
|
|
|
|
const (
|
|
|
|
typeHelloRequest uint8 = 0
|
|
|
|
typeClientHello uint8 = 1
|
|
|
|
typeServerHello uint8 = 2
|
|
|
|
typeNewSessionTicket uint8 = 4
|
|
|
|
typeEndOfEarlyData uint8 = 5
|
|
|
|
typeEncryptedExtensions uint8 = 8
|
|
|
|
typeCertificate uint8 = 11
|
|
|
|
typeServerKeyExchange uint8 = 12
|
|
|
|
typeCertificateRequest uint8 = 13
|
|
|
|
typeServerHelloDone uint8 = 14
|
|
|
|
typeCertificateVerify uint8 = 15
|
|
|
|
typeClientKeyExchange uint8 = 16
|
|
|
|
typeFinished uint8 = 20
|
|
|
|
typeCertificateStatus uint8 = 22
|
|
|
|
typeKeyUpdate uint8 = 24
|
|
|
|
)
|
|
|
|
|
|
|
|
type tlsRecord struct {
|
|
|
|
recordType uint8
|
|
|
|
version uint16
|
|
|
|
recordData []byte
|
|
|
|
}
|
|
|
|
|
|
|
|
func newTLSRecord(recordType uint8, version uint16, recordData []byte) *tlsRecord {
|
|
|
|
return &tlsRecord{
|
|
|
|
recordType: recordType,
|
|
|
|
version: version,
|
|
|
|
recordData: recordData,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (r *tlsRecord) marshal() []byte {
|
|
|
|
data := make([]byte, recordHeaderLen+len(r.recordData))
|
|
|
|
data[0] = r.recordType
|
|
|
|
data[1] = byte(r.version >> 8)
|
|
|
|
data[2] = byte(r.version)
|
|
|
|
data[3] = byte(len(r.recordData) >> 8)
|
|
|
|
data[4] = byte(len(r.recordData))
|
|
|
|
copy(data[5:], r.recordData)
|
|
|
|
return data
|
|
|
|
}
|
|
|
|
|
|
|
|
func (r *tlsRecord) writeTo(w io.Writer) (int, error) {
|
|
|
|
n, err := bytes.NewReader(r.marshal()).WriteTo(w)
|
|
|
|
return int(n), err
|
|
|
|
}
|
|
|
|
|
|
|
|
func readTlsRecord(reader io.Reader) (*tlsRecord, error) {
|
|
|
|
hdr := make([]byte, recordHeaderLen)
|
|
|
|
if _, err := io.ReadFull(reader, hdr); err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
recordType := hdr[0]
|
|
|
|
if recordType < recordTypeChangeCipherSpec || recordType > recordTypeApplicationData {
|
|
|
|
return nil, errors.New("tls: unknown record type")
|
|
|
|
}
|
|
|
|
version := uint16(hdr[1])<<8 | uint16(hdr[2])
|
|
|
|
if version < utls.VersionTLS10 || version > utls.VersionTLS13 {
|
|
|
|
return nil, errors.New("tls: unknown version")
|
|
|
|
}
|
|
|
|
recordLen := int(hdr[3])<<8 | int(hdr[4])
|
|
|
|
|
|
|
|
recordData := make([]byte, recordLen)
|
|
|
|
if _, err := io.ReadFull(reader, recordData); err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
return &tlsRecord{
|
|
|
|
recordType: recordType,
|
|
|
|
version: version,
|
|
|
|
recordData: recordData,
|
|
|
|
}, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
const maxSize = 1400
|
|
|
|
const minSize = 900
|
|
|
|
|
|
|
|
var r = rand.New(rand.NewSource(time.Now().UnixNano()))
|
|
|
|
|
|
|
|
// generateRandomData 生成随机900-1400数据
|
|
|
|
func generateRandomData(prefix []byte) []byte {
|
|
|
|
len := r.Intn(maxSize-minSize+1) + minSize
|
|
|
|
data := make([]byte, len)
|
|
|
|
r.Read(data)
|
|
|
|
copy(data, prefix)
|
|
|
|
return data
|
|
|
|
}
|
|
|
|
|
|
|
|
type OverlayData interface {
|
|
|
|
OverlayData() byte
|
|
|
|
}
|
|
|
|
|
|
|
|
var _ OverlayData = (*warpConn)(nil)
|
|
|
|
|
|
|
|
type warpConn struct {
|
|
|
|
net.Conn
|
|
|
|
aead cipher.AEAD
|
|
|
|
overlayData byte
|
|
|
|
seq []byte
|
|
|
|
lockRead *sync.Mutex
|
|
|
|
lockWrite *sync.Mutex
|
|
|
|
rawInput *bytes.Buffer
|
|
|
|
maxPayload int
|
|
|
|
}
|
|
|
|
|
|
|
|
func newWarpConn(conn net.Conn, aead cipher.AEAD, overlayData byte, seq [8]byte) *warpConn {
|
|
|
|
incSeq(seq[:])
|
|
|
|
w := &warpConn{
|
|
|
|
Conn: conn,
|
|
|
|
lockRead: &sync.Mutex{},
|
|
|
|
lockWrite: &sync.Mutex{},
|
|
|
|
rawInput: &bytes.Buffer{},
|
|
|
|
maxPayload: 0xFFFF - aead.Overhead() - recordHeaderLen,
|
|
|
|
aead: aead,
|
|
|
|
overlayData: overlayData,
|
|
|
|
seq: seq[:],
|
|
|
|
}
|
|
|
|
return w
|
|
|
|
}
|
|
|
|
|
|
|
|
func (w *warpConn) Write(b []byte) (int, error) {
|
|
|
|
w.lockWrite.Lock()
|
|
|
|
defer w.lockWrite.Unlock()
|
|
|
|
wrote := 0
|
|
|
|
for len(b) > 0 {
|
|
|
|
m := len(b)
|
|
|
|
if m > w.maxPayload {
|
|
|
|
m = w.maxPayload
|
|
|
|
}
|
|
|
|
data := w.aead.Seal(nil, w.seq[:], b[:m], nil)
|
|
|
|
data = append(w.seq[:], data...)
|
|
|
|
record := newTLSRecord(recordTypeApplicationData, versionTLS12, data)
|
|
|
|
incSeq(w.seq)
|
|
|
|
_, err := record.writeTo(w.Conn)
|
|
|
|
if err != nil {
|
|
|
|
return 0, err
|
|
|
|
}
|
|
|
|
wrote += m
|
|
|
|
b = b[m:]
|
|
|
|
}
|
|
|
|
return wrote, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (w *warpConn) Read(b []byte) (int, error) {
|
|
|
|
w.lockRead.Lock()
|
|
|
|
defer w.lockRead.Unlock()
|
|
|
|
if w.rawInput.Len() != 0 {
|
|
|
|
// 缓存中有数据,从缓存返回
|
|
|
|
return w.rawInput.Read(b)
|
|
|
|
}
|
|
|
|
|
|
|
|
record, err := readTlsRecord(w.Conn)
|
|
|
|
if err != nil {
|
|
|
|
return 0, err
|
|
|
|
}
|
|
|
|
if record.recordType != recordTypeApplicationData {
|
|
|
|
return 0, ErrVerifyFailed
|
|
|
|
}
|
|
|
|
if record.version != versionTLS12 {
|
|
|
|
return 0, ErrVerifyFailed
|
|
|
|
}
|
|
|
|
data := record.recordData
|
|
|
|
plaintext, err := w.aead.Open(nil, data[:8], data[8:], nil)
|
|
|
|
if err != nil {
|
|
|
|
return 0, err
|
|
|
|
}
|
|
|
|
n := copy(b, plaintext)
|
|
|
|
if n < len(plaintext) {
|
|
|
|
w.rawInput.Write(plaintext[n:])
|
|
|
|
}
|
|
|
|
return n, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (w *warpConn) OverlayData() byte {
|
|
|
|
return w.overlayData
|
|
|
|
}
|
|
|
|
|
|
|
|
func incSeq(seq []byte) {
|
|
|
|
for i := 7; i >= 0; i-- {
|
|
|
|
seq[i]++
|
|
|
|
if seq[i] != 0 {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func GetLogger(debug bool) logrus.FieldLogger {
|
|
|
|
level := logrus.InfoLevel
|
|
|
|
if debug {
|
|
|
|
level = logrus.DebugLevel
|
|
|
|
}
|
|
|
|
logger := logrus.New()
|
|
|
|
logger.SetLevel(level)
|
|
|
|
logger.SetOutput(colorable.NewColorableStderr())
|
|
|
|
logger.Formatter = &logrus.TextFormatter{
|
|
|
|
ForceColors: true,
|
|
|
|
DisableTimestamp: true,
|
|
|
|
}
|
|
|
|
return logger
|
|
|
|
}
|