// reality/utils.go

package reality
import (
"bytes"
"crypto/cipher"
"crypto/sha256"
"encoding/binary"
"errors"
"io"
"math/rand"
"net"
"sync"
"time"
"github.com/mattn/go-colorable"
utls "github.com/refraction-networking/utls"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/hkdf"
)
// Sentinel errors reported by this package.
var (
	// ErrVerifyFailed reports a record or message that failed validation.
	ErrVerifyFailed = errors.New("verify failed")
	// ErrDecryptFailed reports a payload that could not be decrypted.
	ErrDecryptFailed = errors.New("decrypt failed")
	// ErrProxyDie reports that the proxied connection has gone away
	// (not used in this file; NOTE(review): confirm semantics at call sites).
	ErrProxyDie = errors.New("proxy die")
)

var (
	// Prefix is the protocol tag; in this file it is the HKDF salt used by
	// generateNonce.
	Prefix = []byte("REALITY")

	// seqNumerOne is the 8-byte big-endian counter value 1, the same layout
	// incSeq operates on.
	seqNumerOne = [8]byte{7: 1}
)

// DefaultExpireSecond is a nonce-rotation window in seconds (presumably the
// default ExpireSecond for generateNonce — confirm at call sites).
const DefaultExpireSecond = 30
// generateNonce derives a NonceSize-byte nonce from SessionKey and the current
// time window.
//
// The nonce is HKDF-SHA256 output with SessionKey as the secret, Prefix as the
// salt, and, as info, the current Unix time divided by ExpireSecond encoded as
// a big-endian uint64. Every call with the same SessionKey made within the same
// ExpireSecond-aligned window therefore yields the same nonce.
//
// It returns an error if ExpireSecond is zero (the division would otherwise
// panic), if NonceSize is negative, or if HKDF cannot supply NonceSize bytes.
func generateNonce(NonceSize int, SessionKey []byte, ExpireSecond uint32) ([]byte, error) {
	if ExpireSecond == 0 {
		return nil, errors.New("invalid expire second: must be non-zero")
	}
	if NonceSize < 0 {
		return nil, errors.New("invalid nonce size: must be non-negative")
	}
	var info [8]byte
	binary.BigEndian.PutUint64(info[:], uint64(time.Now().Unix()/int64(ExpireSecond)))
	nonce := make([]byte, NonceSize)
	// io.ReadFull guarantees the whole nonce is filled or an error is reported,
	// independent of how the HKDF reader chunks its output.
	if _, err := io.ReadFull(hkdf.New(sha256.New, SessionKey, Prefix, info[:]), nonce); err != nil {
		return nil, err
	}
	return nonce, nil
}
// versionTLS12 is the record-layer protocol version (TLS 1.2, 0x0303) that
// Write stamps on outgoing records and Read requires on incoming ones.
var versionTLS12 = uint16(utls.VersionTLS12)

// recordHeaderLen is the size of a TLS record header:
// 1-byte content type + 2-byte protocol version + 2-byte payload length.
const recordHeaderLen = 1 + 2 + 2
// TLS record content types (ContentType, RFC 5246 §6.2.1 / RFC 8446 §5.1).
// They are consecutive, which readTlsRecord relies on for its range check.
const (
	recordTypeChangeCipherSpec = iota + 20 // 20
	recordTypeAlert                        // 21
	recordTypeHandshake                    // 22
	recordTypeApplicationData              // 23
)
// TLS handshake message types (HandshakeType, RFC 5246 §7.4 / RFC 8446 §4),
// grouped by the protocol versions that use them.
const (
	// Common to TLS 1.2 and TLS 1.3.
	typeClientHello        uint8 = 1
	typeServerHello        uint8 = 2
	typeNewSessionTicket   uint8 = 4
	typeCertificate        uint8 = 11
	typeCertificateRequest uint8 = 13
	typeCertificateVerify  uint8 = 15
	typeFinished           uint8 = 20

	// TLS 1.2 and earlier only.
	typeHelloRequest       uint8 = 0
	typeServerKeyExchange  uint8 = 12
	typeServerHelloDone    uint8 = 14
	typeClientKeyExchange  uint8 = 16
	typeCertificateStatus  uint8 = 22 // RFC 6066 status_request response

	// TLS 1.3 only.
	typeEndOfEarlyData      uint8 = 5
	typeEncryptedExtensions uint8 = 8
	typeKeyUpdate           uint8 = 24
)
// tlsRecord is a single TLS record-layer frame. The 16-bit length that appears
// on the wire is not stored; it is taken from len(recordData) when encoding.
type tlsRecord struct {
	recordType uint8  // content type, e.g. recordTypeApplicationData
	version    uint16 // record-layer protocol version
	recordData []byte // payload
}

// newTLSRecord returns a record with the given header fields. recordData is
// stored as-is (not copied), so the record aliases the caller's slice.
func newTLSRecord(recordType uint8, version uint16, recordData []byte) *tlsRecord {
	rec := new(tlsRecord)
	rec.recordType, rec.version, rec.recordData = recordType, version, recordData
	return rec
}
// marshal encodes r as header || payload, where the recordHeaderLen-byte header
// is the content type followed by the big-endian version and big-endian payload
// length. Only the low 16 bits of len(r.recordData) fit in the length field, so
// callers must keep payloads at or below 0xFFFF bytes.
func (r *tlsRecord) marshal() []byte {
	out := make([]byte, recordHeaderLen, recordHeaderLen+len(r.recordData))
	out[0] = r.recordType
	binary.BigEndian.PutUint16(out[1:3], r.version)
	binary.BigEndian.PutUint16(out[3:5], uint16(len(r.recordData)))
	return append(out, r.recordData...)
}
// writeTo serializes r and writes it to w, returning the number of bytes
// written (header included). A short write without an error from w is
// reported as io.ErrShortWrite by bytes.Reader.WriteTo.
//
// The record length field is only 16 bits wide. A payload longer than 0xFFFF
// bytes is rejected before anything is written. Otherwise it would be sent with
// a truncated length, and the reader would misparse every following record.
func (r *tlsRecord) writeTo(w io.Writer) (int, error) {
	if len(r.recordData) > 0xFFFF {
		return 0, errors.New("tls: record payload too large")
	}
	n, err := bytes.NewReader(r.marshal()).WriteTo(w)
	return int(n), err
}
func readTlsRecord(reader io.Reader) (*tlsRecord, error) {
hdr := make([]byte, recordHeaderLen)
if _, err := io.ReadFull(reader, hdr); err != nil {
return nil, err
}
recordType := hdr[0]
if recordType < recordTypeChangeCipherSpec || recordType > recordTypeApplicationData {
return nil, errors.New("tls: unknown record type")
}
version := uint16(hdr[1])<<8 | uint16(hdr[2])
if version < utls.VersionTLS10 || version > utls.VersionTLS13 {
return nil, errors.New("tls: unknown version")
}
recordLen := int(hdr[3])<<8 | int(hdr[4])
recordData := make([]byte, recordLen)
if _, err := io.ReadFull(reader, recordData); err != nil {
return nil, err
}
return &tlsRecord{
recordType: recordType,
version: version,
recordData: recordData,
}, nil
}
const maxSize = 1400
const minSize = 900

// r is the package-level PRNG used for padding. A *rand.Rand built from
// rand.NewSource is not safe for concurrent use, so access must hold randDataMu.
// NOTE(review): any use of r in other files of this package must take the lock too.
var r = rand.New(rand.NewSource(time.Now().UnixNano()))

// randDataMu serializes access to r.
var randDataMu sync.Mutex

// generateRandomData returns a new buffer whose length is uniformly chosen in
// [minSize, maxSize] (900–1400 bytes). The buffer is filled with pseudo-random
// bytes and then has prefix copied over its start. If prefix is longer than
// the buffer, only the part that fits is copied.
//
// It is safe for concurrent use. The bytes come from math/rand and are not
// cryptographically secure.
func generateRandomData(prefix []byte) []byte {
	randDataMu.Lock()
	size := r.Intn(maxSize-minSize+1) + minSize
	data := make([]byte, size)
	// (*rand.Rand).Read always fills the slice and returns a nil error.
	r.Read(data)
	randDataMu.Unlock()
	copy(data, prefix)
	return data
}
// OverlayData is implemented by connections that expose one byte of overlay
// data associated with the connection.
type OverlayData interface {
	// OverlayData returns the connection's overlay byte.
	OverlayData() byte
}

// Compile-time proof that *warpConn implements OverlayData.
var _ OverlayData = (*warpConn)(nil)
// warpConn wraps a net.Conn and carries application data inside AEAD-sealed
// TLS application-data records (see its Write and Read methods). Methods not
// defined on warpConn are promoted from the embedded Conn.
type warpConn struct {
	net.Conn // underlying transport that records are read from / written to

	// Record protection state.
	aead       cipher.AEAD // seals outgoing and opens incoming record payloads
	seq        []byte      // outgoing sequence number, also Write's AEAD nonce
	maxPayload int         // largest plaintext chunk Write places in one record

	// Read-side buffering and per-direction locking.
	rawInput  *bytes.Buffer // decrypted bytes that did not fit in the caller's buffer
	lockRead  *sync.Mutex   // serializes Read
	lockWrite *sync.Mutex   // serializes Write

	overlayData byte // value reported by OverlayData
}
// newWarpConn wraps conn so that Write seals data with aead and frames it as
// TLS 1.2 application-data records, and Read opens such records.
//
// seq is the starting sequence number. The array is received by value and
// incremented once before use, so the caller's copy is untouched and the first
// record Write sends uses seq+1 as its nonce.
func newWarpConn(conn net.Conn, aead cipher.AEAD, overlayData byte, seq [8]byte) *warpConn {
	incSeq(seq[:])
	return &warpConn{
		Conn:        conn,
		aead:        aead,
		overlayData: overlayData,
		seq:         seq[:],
		lockRead:    &sync.Mutex{},
		lockWrite:   &sync.Mutex{},
		rawInput:    &bytes.Buffer{},
		// The 16-bit record length covers the payload only: the explicit
		// sequence-number prefix (len(seq) bytes) plus the sealed chunk
		// (plaintext + Overhead). The record header is not counted, so
		// subtracting recordHeaderLen here would let a full chunk overflow
		// 0xFFFF by three bytes.
		maxPayload: 0xFFFF - aead.Overhead() - len(seq),
	}
}
// Write encrypts b and sends it as one or more TLS 1.2 application-data
// records. Each record carries at most w.maxPayload bytes of plaintext. Each
// record's payload is the current sequence number (8 bytes when built by
// newWarpConn), which also serves as the AEAD nonce, followed by the sealed
// chunk. The sequence number is incremented after every record. Calls are
// serialized by lockWrite.
//
// It returns how many bytes of b were sent. If the underlying write fails, the
// count covers the records completed before the failure, as the io.Writer
// contract requires.
func (w *warpConn) Write(b []byte) (int, error) {
	w.lockWrite.Lock()
	defer w.lockWrite.Unlock()
	wrote := 0
	for len(b) > 0 {
		m := len(b)
		if m > w.maxPayload {
			m = w.maxPayload
		}
		// Build seq || Seal(chunk) in one fresh buffer. The buffer never
		// shares storage with w.seq, so the incSeq below cannot alter the
		// nonce already copied into this record.
		payload := make([]byte, len(w.seq), len(w.seq)+m+w.aead.Overhead())
		copy(payload, w.seq)
		payload = w.aead.Seal(payload, w.seq, b[:m], nil)
		incSeq(w.seq)
		if _, err := newTLSRecord(recordTypeApplicationData, versionTLS12, payload).writeTo(w.Conn); err != nil {
			return wrote, err
		}
		wrote += m
		b = b[m:]
	}
	return wrote, nil
}
// Read returns decrypted application data. Calls are serialized by lockRead.
//
// Plaintext left over from an earlier record is returned first. Otherwise
// exactly one record is read from the underlying connection. It must be a
// TLS 1.2 application-data record whose payload is an 8-byte nonce followed by
// the AEAD-sealed data. Plaintext that does not fit in b is buffered for the
// next call.
//
// A record of the wrong type or version yields ErrVerifyFailed. A payload too
// short to contain the nonce yields ErrDecryptFailed. An authentication failure
// returns the error from aead.Open.
//
// NOTE(review): the nonce is taken from the record itself and is not compared
// with an expected counter, so captured records could be replayed or reordered
// — confirm whether the protocol guards against this elsewhere.
func (w *warpConn) Read(b []byte) (int, error) {
	w.lockRead.Lock()
	defer w.lockRead.Unlock()
	if w.rawInput.Len() != 0 {
		// Serve buffered plaintext from a previous record first.
		return w.rawInput.Read(b)
	}
	record, err := readTlsRecord(w.Conn)
	if err != nil {
		return 0, err
	}
	if record.recordType != recordTypeApplicationData || record.version != versionTLS12 {
		return 0, ErrVerifyFailed
	}
	const nonceLen = 8 // length of the sequence-number prefix on each record
	data := record.recordData
	if len(data) < nonceLen {
		// The record comes from the network; slicing data[:nonceLen] would panic.
		return 0, ErrDecryptFailed
	}
	nonce, sealed := data[:nonceLen], data[nonceLen:]
	// Decrypt in place: sealed[:0] is the documented way to reuse the
	// ciphertext's storage for the plaintext.
	plaintext, err := w.aead.Open(sealed[:0], nonce, sealed, nil)
	if err != nil {
		return 0, err
	}
	n := copy(b, plaintext)
	if n < len(plaintext) {
		w.rawInput.Write(plaintext[n:])
	}
	return n, nil
}
// OverlayData returns the overlay byte stored on the connection (set by newWarpConn).
func (w *warpConn) OverlayData() byte { return w.overlayData }
// incSeq treats seq[0:8] as a big-endian counter and adds one in place, with
// carry. It wraps to all zeros after 0xFFFFFFFFFFFFFFFF. Bytes after index 7 are
// left untouched, and a seq shorter than 8 bytes causes an index panic.
func incSeq(seq []byte) {
	for pos := 7; pos >= 0; pos-- {
		if seq[pos] != 0xFF {
			seq[pos]++
			return
		}
		// This byte overflows to zero; carry into the next more significant byte.
		seq[pos] = 0
	}
}
// GetLogger returns a new logrus logger that writes colored text without
// timestamps to a go-colorable stderr writer. The level is Debug when debug is
// true and Info otherwise.
func GetLogger(debug bool) logrus.FieldLogger {
	logger := logrus.New()
	logger.Formatter = &logrus.TextFormatter{
		ForceColors:      true,
		DisableTimestamp: true,
	}
	logger.SetOutput(colorable.NewColorableStderr())
	if debug {
		logger.SetLevel(logrus.DebugLevel)
	} else {
		logger.SetLevel(logrus.InfoLevel)
	}
	return logger
}