fix: wrong usage of RLock

This commit is contained in:
wwqgtxx 2024-07-22 09:57:57 +08:00
parent fd5b537ab1
commit 4eb13a73bf
5 changed files with 42 additions and 20 deletions

View File

@ -205,7 +205,6 @@ func strategyStickySessions(url string) strategyFn {
proxy := proxies[nowIdx] proxy := proxies[nowIdx]
if proxy.AliveForTestUrl(url) { if proxy.AliveForTestUrl(url) {
if nowIdx != idx { if nowIdx != idx {
lruCache.Delete(key)
lruCache.Set(key, nowIdx) lruCache.Set(key, nowIdx)
} }
@ -215,7 +214,6 @@ func strategyStickySessions(url string) strategyFn {
} }
} }
lruCache.Delete(key)
lruCache.Set(key, 0) lruCache.Set(key, 0)
return proxies[0] return proxies[0]
} }

View File

@ -223,6 +223,10 @@ func (c *LruCache[K, V]) Delete(key K) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
c.delete(key)
}
func (c *LruCache[K, V]) delete(key K) {
if le, ok := c.cache[key]; ok { if le, ok := c.cache[key]; ok {
c.deleteElement(le) c.deleteElement(le)
} }
@ -255,6 +259,34 @@ func (c *LruCache[K, V]) Clear() error {
return nil return nil
} }
// Compute either sets the computed new value for the key or deletes
// the value for the key. When the delete result of the valueFn function
// is set to true, the value will be deleted, if it exists. When delete
// is set to false, the value is updated to the newValue.
// The ok result indicates whether value was computed and stored, thus, is
// present in the map. The actual result contains the new value in cases where
// the value was computed and stored.
func (c *LruCache[K, V]) Compute(
key K,
valueFn func(oldValue V, loaded bool) (newValue V, delete bool),
) (actual V, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
if el := c.get(key); el != nil {
actual, ok = el.value, true
}
if newValue, del := valueFn(actual, ok); del {
if ok { // data not in cache, so needn't delete
c.delete(key)
}
return lo.Empty[V](), false
} else {
c.set(key, newValue)
return newValue, true
}
}
type entry[K comparable, V any] struct { type entry[K comparable, V any] struct {
key K key K
value V value V

View File

@ -59,8 +59,8 @@ func (q *Queue[T]) Copy() []T {
// Len returns the number of items in this queue. // Len returns the number of items in this queue.
func (q *Queue[T]) Len() int64 { func (q *Queue[T]) Len() int64 {
q.lock.Lock() q.lock.RLock()
defer q.lock.Unlock() defer q.lock.RUnlock()
return int64(len(q.items)) return int64(len(q.items))
} }

View File

@ -17,8 +17,8 @@ func NewCallback[T any]() *Callback[T] {
} }
func (c *Callback[T]) Register(item func(T)) io.Closer { func (c *Callback[T]) Register(item func(T)) io.Closer {
c.mutex.RLock() c.mutex.Lock()
defer c.mutex.RUnlock() defer c.mutex.Unlock()
element := c.list.PushBack(item) element := c.list.PushBack(item)
return &callbackCloser[T]{ return &callbackCloser[T]{
element: element, element: element,

View File

@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"sync"
"time" "time"
"github.com/metacubex/mihomo/common/lru" "github.com/metacubex/mihomo/common/lru"
@ -30,7 +29,6 @@ type SnifferDispatcher struct {
forceDomain *trie.DomainSet forceDomain *trie.DomainSet
skipSNI *trie.DomainSet skipSNI *trie.DomainSet
skipList *lru.LruCache[string, uint8] skipList *lru.LruCache[string, uint8]
rwMux sync.RWMutex
forceDnsMapping bool forceDnsMapping bool
parsePureIp bool parsePureIp bool
} }
@ -85,14 +83,11 @@ func (sd *SnifferDispatcher) TCPSniff(conn *N.BufferedConn, metadata *C.Metadata
return false return false
} }
sd.rwMux.RLock()
dst := fmt.Sprintf("%s:%d", metadata.DstIP, metadata.DstPort) dst := fmt.Sprintf("%s:%d", metadata.DstIP, metadata.DstPort)
if count, ok := sd.skipList.Get(dst); ok && count > 5 { if count, ok := sd.skipList.Get(dst); ok && count > 5 {
log.Debugln("[Sniffer] Skip sniffing[%s] due to multiple failures", dst) log.Debugln("[Sniffer] Skip sniffing[%s] due to multiple failures", dst)
defer sd.rwMux.RUnlock()
return false return false
} }
sd.rwMux.RUnlock()
if host, err := sd.sniffDomain(conn, metadata); err != nil { if host, err := sd.sniffDomain(conn, metadata); err != nil {
sd.cacheSniffFailed(metadata) sd.cacheSniffFailed(metadata)
@ -104,9 +99,7 @@ func (sd *SnifferDispatcher) TCPSniff(conn *N.BufferedConn, metadata *C.Metadata
return false return false
} }
sd.rwMux.RLock()
sd.skipList.Delete(dst) sd.skipList.Delete(dst)
sd.rwMux.RUnlock()
sd.replaceDomain(metadata, host, overrideDest) sd.replaceDomain(metadata, host, overrideDest)
return true return true
@ -176,14 +169,13 @@ func (sd *SnifferDispatcher) sniffDomain(conn *N.BufferedConn, metadata *C.Metad
} }
func (sd *SnifferDispatcher) cacheSniffFailed(metadata *C.Metadata) { func (sd *SnifferDispatcher) cacheSniffFailed(metadata *C.Metadata) {
sd.rwMux.Lock()
dst := fmt.Sprintf("%s:%d", metadata.DstIP, metadata.DstPort) dst := fmt.Sprintf("%s:%d", metadata.DstIP, metadata.DstPort)
count, _ := sd.skipList.Get(dst) sd.skipList.Compute(dst, func(oldValue uint8, loaded bool) (newValue uint8, delete bool) {
if count <= 5 { if oldValue <= 5 {
count++ oldValue++
} }
sd.skipList.Set(dst, count) return oldValue, false
sd.rwMux.Unlock() })
} }
func NewCloseSnifferDispatcher() (*SnifferDispatcher, error) { func NewCloseSnifferDispatcher() (*SnifferDispatcher, error) {