diff --git a/common/cache/cache.go b/common/cache/cache.go index e587d77b7..b87392b43 100644 --- a/common/cache/cache.go +++ b/common/cache/cache.go @@ -7,50 +7,50 @@ import ( ) // Cache store element with a expired time -type Cache struct { - *cache +type Cache[K comparable, V any] struct { + *cache[K, V] } -type cache struct { +type cache[K comparable, V any] struct { mapping sync.Map - janitor *janitor + janitor *janitor[K, V] } -type element struct { +type element[V any] struct { Expired time.Time - Payload any + Payload V } // Put element in Cache with its ttl -func (c *cache) Put(key any, payload any, ttl time.Duration) { - c.mapping.Store(key, &element{ +func (c *cache[K, V]) Put(key K, payload V, ttl time.Duration) { + c.mapping.Store(key, &element[V]{ Payload: payload, Expired: time.Now().Add(ttl), }) } // Get element in Cache, and drop when it expired -func (c *cache) Get(key any) any { +func (c *cache[K, V]) Get(key K) V { item, exist := c.mapping.Load(key) if !exist { - return nil + return getZero[V]() } - elm := item.(*element) + elm := item.(*element[V]) // expired if time.Since(elm.Expired) > 0 { c.mapping.Delete(key) - return nil + return getZero[V]() } return elm.Payload } // GetWithExpire element in Cache with Expire Time -func (c *cache) GetWithExpire(key any) (payload any, expired time.Time) { +func (c *cache[K, V]) GetWithExpire(key K) (payload V, expired time.Time) { item, exist := c.mapping.Load(key) if !exist { return } - elm := item.(*element) + elm := item.(*element[V]) // expired if time.Since(elm.Expired) > 0 { c.mapping.Delete(key) @@ -59,10 +59,10 @@ func (c *cache) GetWithExpire(key any) (payload any, expired time.Time) { return elm.Payload, elm.Expired } -func (c *cache) cleanup() { +func (c *cache[K, V]) cleanup() { c.mapping.Range(func(k, v any) bool { key := k.(string) - elm := v.(*element) + elm := v.(*element[V]) if time.Since(elm.Expired) > 0 { c.mapping.Delete(key) } @@ -70,12 +70,12 @@ func (c *cache) cleanup() { }) } -type janitor struct { +type janitor[K comparable, V any] struct { interval time.Duration stop chan struct{} } -func (j *janitor) process(c *cache) { +func (j *janitor[K, V]) process(c *cache[K, V]) { ticker := time.NewTicker(j.interval) for { select { @@ -88,19 +88,19 @@ func (j *janitor) process(c *cache) { } } -func stopJanitor(c *Cache) { +func stopJanitor[K comparable, V any](c *Cache[K, V]) { c.janitor.stop <- struct{}{} } // New return *Cache -func New(interval time.Duration) *Cache { - j := &janitor{ +func New[K comparable, V any](interval time.Duration) *Cache[K, V] { + j := &janitor[K, V]{ interval: interval, stop: make(chan struct{}), } - c := &cache{janitor: j} + c := &cache[K, V]{janitor: j} go j.process(c) - C := &Cache{c} - runtime.SetFinalizer(C, stopJanitor) + C := &Cache[K, V]{c} + runtime.SetFinalizer(C, stopJanitor[K, V]) return C } diff --git a/common/cache/cache_test.go b/common/cache/cache_test.go index cf4a39148..0945d905e 100644 --- a/common/cache/cache_test.go +++ b/common/cache/cache_test.go @@ -11,48 +11,50 @@ import ( func TestCache_Basic(t *testing.T) { interval := 200 * time.Millisecond ttl := 20 * time.Millisecond - c := New(interval) + c := New[string, int](interval) c.Put("int", 1, ttl) - c.Put("string", "a", ttl) + + d := New[string, string](interval) + d.Put("string", "a", ttl) i := c.Get("int") - assert.Equal(t, i.(int), 1, "should recv 1") + assert.Equal(t, i, 1, "should recv 1") - s := c.Get("string") - assert.Equal(t, s.(string), "a", "should recv 'a'") + s := d.Get("string") + assert.Equal(t, s, "a", "should recv 'a'") } func TestCache_TTL(t *testing.T) { interval := 200 * time.Millisecond ttl := 20 * time.Millisecond now := time.Now() - c := New(interval) + c := New[string, int](interval) c.Put("int", 1, ttl) c.Put("int2", 2, ttl) i := c.Get("int") _, expired := c.GetWithExpire("int2") - assert.Equal(t, i.(int), 1, "should recv 1") + assert.Equal(t, i, 1, "should recv 1") assert.True(t, now.Before(expired)) time.Sleep(ttl * 2) i = c.Get("int") j, _ := c.GetWithExpire("int2") - assert.Nil(t, i, "should recv nil") - assert.Nil(t, j, "should recv nil") + assert.True(t, i == 0, "should recv 0") + assert.True(t, j == 0, "should recv 0") } func TestCache_AutoCleanup(t *testing.T) { interval := 10 * time.Millisecond ttl := 15 * time.Millisecond - c := New(interval) + c := New[string, int](interval) c.Put("int", 1, ttl) time.Sleep(ttl * 2) i := c.Get("int") j, _ := c.GetWithExpire("int") - assert.Nil(t, i, "should recv nil") - assert.Nil(t, j, "should recv nil") + assert.True(t, i == 0, "should recv 0") + assert.True(t, j == 0, "should recv 0") } func TestCache_AutoGC(t *testing.T) { @@ -60,7 +62,7 @@ func TestCache_AutoGC(t *testing.T) { go func() { interval := 10 * time.Millisecond ttl := 15 * time.Millisecond - c := New(interval) + c := New[string, int](interval) c.Put("int", 1, ttl) sign <- struct{}{} }() diff --git a/listener/http/proxy.go b/listener/http/proxy.go index 18f1e5d47..e8a805a93 100644 --- a/listener/http/proxy.go +++ b/listener/http/proxy.go @@ -15,7 +15,7 @@ import ( "github.com/Dreamacro/clash/log" ) -func HandleConn(c net.Conn, in chan<- C.ConnContext, cache *cache.Cache) { +func HandleConn(c net.Conn, in chan<- C.ConnContext, cache *cache.Cache[string, bool]) { client := newClient(c.RemoteAddr(), in) defer client.CloseIdleConnections() @@ -98,7 +98,7 @@ func HandleConn(c net.Conn, in chan<- C.ConnContext, cache *cache.Cache) { conn.Close() } -func authenticate(request *http.Request, cache *cache.Cache) *http.Response { +func authenticate(request *http.Request, cache *cache.Cache[string, bool]) *http.Response { authenticator := authStore.Authenticator() if authenticator != nil { credential := parseBasicProxyAuthorization(request) @@ -108,13 +108,13 @@ func authenticate(request *http.Request, cache *cache.Cache) *http.Response { return resp } - var authed any - if authed = cache.Get(credential); authed == nil { + var authed bool + if authed = cache.Get(credential); !authed { user, pass, err := decodeBasicProxyAuthorization(credential) authed = err == nil && authenticator.Verify(user, pass) cache.Put(credential, authed, time.Minute) } - if !authed.(bool) { + if !authed { log.Infoln("Auth failed from %s", request.RemoteAddr) return responseWith(request, http.StatusForbidden) diff --git a/listener/http/server.go b/listener/http/server.go index bfdd9f1b6..6b9661439 100644 --- a/listener/http/server.go +++ b/listener/http/server.go @@ -40,9 +40,9 @@ func NewWithAuthenticate(addr string, in chan<- C.ConnContext, authenticate bool return nil, err } - var c *cache.Cache + var c *cache.Cache[string, bool] if authenticate { - c = cache.New(time.Second * 30) + c = cache.New[string, bool](time.Second * 30) } hl := &Listener{ diff --git a/listener/mixed/mixed.go b/listener/mixed/mixed.go index 57fd055e0..14a81bc38 100644 --- a/listener/mixed/mixed.go +++ b/listener/mixed/mixed.go @@ -16,7 +16,7 @@ import ( type Listener struct { listener net.Listener addr string - cache *cache.Cache + cache *cache.Cache[string, bool] closed bool } @@ -45,7 +45,7 @@ func New(addr string, in chan<- C.ConnContext) (*Listener, error) { ml := &Listener{ listener: l, addr: addr, - cache: cache.New(30 * time.Second), + cache: cache.New[string, bool](30 * time.Second), } go func() { for { @@ -63,7 +63,7 @@ func New(addr string, in chan<- C.ConnContext) (*Listener, error) { return ml, nil } -func handleConn(conn net.Conn, in chan<- C.ConnContext, cache *cache.Cache) { +func handleConn(conn net.Conn, in chan<- C.ConnContext, cache *cache.Cache[string, bool]) { conn.(*net.TCPConn).SetKeepAlive(true) bufConn := N.NewBufferedConn(conn)