Clash.Meta/transport/anytls/session/client.go

161 lines
3.7 KiB
Go

package session
import (
"context"
"fmt"
"io"
"math"
"net"
"sync"
"sync/atomic"
"time"
"github.com/metacubex/mihomo/transport/anytls/padding"
"github.com/metacubex/mihomo/transport/anytls/skiplist"
"github.com/metacubex/mihomo/transport/anytls/util"
"github.com/sagernet/sing/common"
singAtomic "github.com/sagernet/sing/common/atomic"
)
type Client struct {
die context.Context
dieCancel context.CancelFunc
dialOut func(ctx context.Context) (net.Conn, error)
sessionCounter atomic.Uint64
idleSession *skiplist.SkipList[uint64, *Session]
idleSessionLock sync.Mutex
padding *singAtomic.TypedValue[*padding.PaddingFactory]
idleSessionTimeout time.Duration
}
func NewClient(ctx context.Context, dialOut func(ctx context.Context) (net.Conn, error), _padding *singAtomic.TypedValue[*padding.PaddingFactory], idleSessionCheckInterval, idleSessionTimeout time.Duration) *Client {
c := &Client{
dialOut: dialOut,
padding: _padding,
idleSessionTimeout: idleSessionTimeout,
}
if idleSessionCheckInterval <= time.Second*5 {
idleSessionCheckInterval = time.Second * 30
}
if c.idleSessionTimeout <= time.Second*5 {
c.idleSessionTimeout = time.Second * 30
}
c.die, c.dieCancel = context.WithCancel(ctx)
c.idleSession = skiplist.NewSkipList[uint64, *Session]()
util.StartRoutine(c.die, idleSessionCheckInterval, c.idleCleanup)
return c
}
func (c *Client) CreateStream(ctx context.Context) (net.Conn, error) {
select {
case <-c.die.Done():
return nil, io.ErrClosedPipe
default:
}
var session *Session
var stream *Stream
var err error
for i := 0; i < 3; i++ {
session, err = c.findSession(ctx)
if session == nil {
return nil, fmt.Errorf("failed to create session: %w", err)
}
stream, err = session.OpenStream()
if err != nil {
common.Close(session, stream)
continue
}
break
}
if session == nil || stream == nil {
return nil, fmt.Errorf("too many closed session: %w", err)
}
stream.dieHook = func() {
if session.IsClosed() {
if session.dieHook != nil {
session.dieHook()
}
} else {
c.idleSessionLock.Lock()
session.idleSince = time.Now()
c.idleSession.Insert(math.MaxUint64-session.seq, session)
c.idleSessionLock.Unlock()
}
}
return stream, nil
}
func (c *Client) findSession(ctx context.Context) (*Session, error) {
var idle *Session
c.idleSessionLock.Lock()
if !c.idleSession.IsEmpty() {
it := c.idleSession.Iterate()
idle = it.Value()
c.idleSession.Remove(it.Key())
}
c.idleSessionLock.Unlock()
if idle == nil {
s, err := c.createSession(ctx)
return s, err
}
return idle, nil
}
func (c *Client) createSession(ctx context.Context) (*Session, error) {
underlying, err := c.dialOut(ctx)
if err != nil {
return nil, err
}
session := NewClientSession(underlying, c.padding)
session.seq = c.sessionCounter.Add(1)
session.dieHook = func() {
//logrus.Debugln("session died", session)
c.idleSessionLock.Lock()
c.idleSession.Remove(math.MaxUint64 - session.seq)
c.idleSessionLock.Unlock()
}
session.Run(false)
return session, nil
}
func (c *Client) Close() error {
c.dieCancel()
go c.idleCleanupExpTime(time.Now())
return nil
}
func (c *Client) idleCleanup() {
c.idleCleanupExpTime(time.Now().Add(-c.idleSessionTimeout))
}
func (c *Client) idleCleanupExpTime(expTime time.Time) {
var sessionToRemove = make([]*Session, 0)
c.idleSessionLock.Lock()
it := c.idleSession.Iterate()
for it.IsNotEnd() {
session := it.Value()
if session.idleSince.Before(expTime) {
sessionToRemove = append(sessionToRemove, session)
c.idleSession.Remove(it.Key())
}
it.MoveToNext()
}
c.idleSessionLock.Unlock()
for _, session := range sessionToRemove {
session.Close()
}
}