diff --git a/adapter/adapter.go b/adapter/adapter.go index 787a45b13..526866a5c 100644 --- a/adapter/adapter.go +++ b/adapter/adapter.go @@ -136,6 +136,8 @@ func (p *Proxy) URLTest(ctx context.Context, url string) (t uint16, err error) { return http.ErrUseLastResponse }, } + defer client.CloseIdleConnections() + resp, err := client.Do(req) if err != nil { return diff --git a/adapter/provider/healthcheck.go b/adapter/provider/healthcheck.go index 98f934e42..636873d94 100644 --- a/adapter/provider/healthcheck.go +++ b/adapter/provider/healthcheck.go @@ -2,9 +2,9 @@ package provider import ( "context" - "sync" "time" + "github.com/Dreamacro/clash/common/batch" C "github.com/Dreamacro/clash/constant" "go.uber.org/atomic" @@ -60,19 +60,16 @@ func (hc *HealthCheck) touch() { func (hc *HealthCheck) check() { ctx, cancel := context.WithTimeout(context.Background(), defaultURLTestTimeout) - wg := &sync.WaitGroup{} + defer cancel() + b, ctx := batch.WithContext(ctx, batch.WithConcurrencyNum(10)) for _, proxy := range hc.proxies { - wg.Add(1) - - go func(p C.Proxy) { - p.URLTest(ctx, hc.url) - wg.Done() - }(proxy) + p := proxy + b.Go(p.Name(), func() (interface{}, error) { + return p.URLTest(ctx, hc.url) + }) } - - wg.Wait() - cancel() + b.Wait() } func (hc *HealthCheck) close() { diff --git a/common/batch/batch.go b/common/batch/batch.go new file mode 100644 index 000000000..7fbca4217 --- /dev/null +++ b/common/batch/batch.go @@ -0,0 +1,111 @@ +package batch + +import ( + "context" + "sync" +) + +type Option = func(b *Batch) + +type Result struct { + Value interface{} + Err error +} + +type Error struct { + Key string + Err error +} + +func WithConcurrencyNum(n int) Option { + return func(b *Batch) { + q := make(chan struct{}, n) + for i := 0; i < n; i++ { + q <- struct{}{} + } + b.queue = q + } +} + +// Batch similar to errgroup, but can control the maximum number of concurrent +type Batch struct { + result map[string]Result + queue chan struct{} + wg sync.WaitGroup + mux sync.Mutex + err *Error + once sync.Once + cancel func() +} + +func (b *Batch) Go(key string, fn func() (interface{}, error)) { + b.wg.Add(1) + go func() { + defer b.wg.Done() + if b.queue != nil { + <-b.queue + defer func() { + b.queue <- struct{}{} + }() + } + + value, err := fn() + if err != nil { + b.once.Do(func() { + b.err = &Error{key, err} + if b.cancel != nil { + b.cancel() + } + }) + } + + ret := Result{value, err} + b.mux.Lock() + defer b.mux.Unlock() + b.result[key] = ret + }() +} + +func (b *Batch) Wait() *Error { + b.wg.Wait() + if b.cancel != nil { + b.cancel() + } + return b.err +} + +func (b *Batch) WaitAndGetResult() (map[string]Result, *Error) { + err := b.Wait() + return b.Result(), err +} + +func (b *Batch) Result() map[string]Result { + b.mux.Lock() + defer b.mux.Unlock() + copy := map[string]Result{} + for k, v := range b.result { + copy[k] = v + } + return copy +} + +func New(opts ...Option) *Batch { + b := &Batch{ + result: map[string]Result{}, + } + + for _, o := range opts { + o(b) + } + + return b +} + +func WithContext(ctx context.Context, opts ...Option) (*Batch, context.Context) { + ctx, cancel := context.WithCancel(ctx) + + b := New(opts...) + b.cancel = cancel + + return b, ctx +} diff --git a/common/batch/batch_test.go b/common/batch/batch_test.go new file mode 100644 index 000000000..4fcdbe81c --- /dev/null +++ b/common/batch/batch_test.go @@ -0,0 +1,82 @@ +package batch + +import ( + "context" + "errors" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestBatch(t *testing.T) { + b := New() + + now := time.Now() + b.Go("foo", func() (interface{}, error) { + time.Sleep(time.Millisecond * 100) + return "foo", nil + }) + b.Go("bar", func() (interface{}, error) { + time.Sleep(time.Millisecond * 150) + return "bar", nil + }) + result, err := b.WaitAndGetResult() + + assert.Nil(t, err) + + duration := time.Since(now) + assert.Less(t, duration, time.Millisecond*200) + assert.Equal(t, 2, len(result)) + + for k, v := range result { + assert.NoError(t, v.Err) + assert.Equal(t, k, v.Value.(string)) + } +} + +func TestBatchWithConcurrencyNum(t *testing.T) { + b := New( + WithConcurrencyNum(3), + ) + + now := time.Now() + for i := 0; i < 7; i++ { + idx := i + b.Go(strconv.Itoa(idx), func() (interface{}, error) { + time.Sleep(time.Millisecond * 100) + return strconv.Itoa(idx), nil + }) + } + result, _ := b.WaitAndGetResult() + duration := time.Since(now) + assert.Greater(t, duration, time.Millisecond*260) + assert.Equal(t, 7, len(result)) + + for k, v := range result { + assert.NoError(t, v.Err) + assert.Equal(t, k, v.Value.(string)) + } +} + +func TestBatchContext(t *testing.T) { + b, ctx := WithContext(context.Background()) + + b.Go("error", func() (interface{}, error) { + time.Sleep(time.Millisecond * 100) + return nil, errors.New("test error") + }) + + b.Go("ctx", func() (interface{}, error) { + <-ctx.Done() + return nil, ctx.Err() + }) + + result, err := b.WaitAndGetResult() + + assert.NotNil(t, err) + assert.Equal(t, "error", err.Key) + + assert.Equal(t, ctx.Err(), result["ctx"].Err) +}