Feature: add default-nameserver and outbound interface

This commit is contained in:
Dreamacro 2020-02-15 21:42:46 +08:00
parent f69f635e0b
commit d75cb069d9
28 changed files with 578 additions and 347 deletions

View File

@ -114,6 +114,7 @@ external-controller: 127.0.0.1:9090
# experimental feature # experimental feature
experimental: experimental:
ignore-resolve-fail: true # ignore dns resolve fail, default value is true ignore-resolve-fail: true # ignore dns resolve fail, default value is true
# interface-name: en0 # outbound interface name
# authentication of local SOCKS5/HTTP(S) server # authentication of local SOCKS5/HTTP(S) server
# authentication: # authentication:
@ -130,6 +131,9 @@ experimental:
# enable: true # set true to enable dns (default is false) # enable: true # set true to enable dns (default is false)
# ipv6: false # default is false # ipv6: false # default is false
# listen: 0.0.0.0:53 # listen: 0.0.0.0:53
# # default-nameserver: # resolve dns nameserver host, should fill pure IP
# # - 114.114.114.114
# # - 8.8.8.8
# enhanced-mode: redir-host # or fake-ip # enhanced-mode: redir-host # or fake-ip
# # fake-ip-range: 198.18.0.1/16 # if you don't know what it is, don't change it # # fake-ip-range: 198.18.0.1/16 # if you don't know what it is, don't change it
# fake-ip-filter: # fake ip white domain list # fake-ip-filter: # fake ip white domain list

View File

@ -18,7 +18,7 @@ func (d *Direct) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn,
address = net.JoinHostPort(metadata.DstIP.String(), metadata.DstPort) address = net.JoinHostPort(metadata.DstIP.String(), metadata.DstPort)
} }
c, err := dialContext(ctx, "tcp", address) c, err := dialer.DialContext(ctx, "tcp", address)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -13,6 +13,7 @@ import (
"net/url" "net/url"
"strconv" "strconv"
"github.com/Dreamacro/clash/component/dialer"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
) )
@ -35,7 +36,7 @@ type HttpOption struct {
} }
func (h *Http) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { func (h *Http) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
c, err := dialContext(ctx, "tcp", h.addr) c, err := dialer.DialContext(ctx, "tcp", h.addr)
if err == nil && h.tlsConfig != nil { if err == nil && h.tlsConfig != nil {
cc := tls.Client(c, h.tlsConfig) cc := tls.Client(c, h.tlsConfig)
err = cc.Handshake() err = cc.Handshake()

View File

@ -60,7 +60,7 @@ type v2rayObfsOption struct {
} }
func (ss *ShadowSocks) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { func (ss *ShadowSocks) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
c, err := dialContext(ctx, "tcp", ss.server) c, err := dialer.DialContext(ctx, "tcp", ss.server)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %w", ss.server, err) return nil, fmt.Errorf("%s connect error: %w", ss.server, err)
} }

View File

@ -7,6 +7,7 @@ import (
"strconv" "strconv"
"github.com/Dreamacro/clash/common/structure" "github.com/Dreamacro/clash/common/structure"
"github.com/Dreamacro/clash/component/dialer"
obfs "github.com/Dreamacro/clash/component/simple-obfs" obfs "github.com/Dreamacro/clash/component/simple-obfs"
"github.com/Dreamacro/clash/component/snell" "github.com/Dreamacro/clash/component/snell"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
@ -28,7 +29,7 @@ type SnellOption struct {
} }
func (s *Snell) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { func (s *Snell) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
c, err := dialContext(ctx, "tcp", s.server) c, err := dialer.DialContext(ctx, "tcp", s.server)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %w", s.server, err) return nil, fmt.Errorf("%s connect error: %w", s.server, err)
} }

View File

@ -36,7 +36,7 @@ type Socks5Option struct {
} }
func (ss *Socks5) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { func (ss *Socks5) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
c, err := dialContext(ctx, "tcp", ss.addr) c, err := dialer.DialContext(ctx, "tcp", ss.addr)
if err == nil && ss.tls { if err == nil && ss.tls {
cc := tls.Client(c, ss.tlsConfig) cc := tls.Client(c, ss.tlsConfig)
@ -64,7 +64,7 @@ func (ss *Socks5) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn
func (ss *Socks5) DialUDP(metadata *C.Metadata) (_ C.PacketConn, err error) { func (ss *Socks5) DialUDP(metadata *C.Metadata) (_ C.PacketConn, err error) {
ctx, cancel := context.WithTimeout(context.Background(), tcpTimeout) ctx, cancel := context.WithTimeout(context.Background(), tcpTimeout)
defer cancel() defer cancel()
c, err := dialContext(ctx, "tcp", ss.addr) c, err := dialer.DialContext(ctx, "tcp", ss.addr)
if err != nil { if err != nil {
err = fmt.Errorf("%s connect error: %w", ss.addr, err) err = fmt.Errorf("%s connect error: %w", ss.addr, err)
return return

View File

@ -2,7 +2,6 @@ package outbound
import ( import (
"bytes" "bytes"
"context"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"net" "net"
@ -11,10 +10,9 @@ import (
"sync" "sync"
"time" "time"
"github.com/Dreamacro/clash/component/dialer" "github.com/Dreamacro/clash/component/resolver"
"github.com/Dreamacro/clash/component/socks5" "github.com/Dreamacro/clash/component/socks5"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/dns"
) )
const ( const (
@ -88,92 +86,13 @@ func serializesSocksAddr(metadata *C.Metadata) []byte {
return bytes.Join(buf, nil) return bytes.Join(buf, nil)
} }
func dialContext(ctx context.Context, network, address string) (net.Conn, error) {
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
returned := make(chan struct{})
defer close(returned)
type dialResult struct {
net.Conn
error
resolved bool
ipv6 bool
done bool
}
results := make(chan dialResult)
var primary, fallback dialResult
startRacer := func(ctx context.Context, host string, ipv6 bool) {
dialer := dialer.Dialer()
result := dialResult{ipv6: ipv6, done: true}
defer func() {
select {
case results <- result:
case <-returned:
if result.Conn != nil {
result.Conn.Close()
}
}
}()
var ip net.IP
if ipv6 {
ip, result.error = dns.ResolveIPv6(host)
} else {
ip, result.error = dns.ResolveIPv4(host)
}
if result.error != nil {
return
}
result.resolved = true
if ipv6 {
result.Conn, result.error = dialer.DialContext(ctx, "tcp6", net.JoinHostPort(ip.String(), port))
} else {
result.Conn, result.error = dialer.DialContext(ctx, "tcp4", net.JoinHostPort(ip.String(), port))
}
}
go startRacer(ctx, host, false)
go startRacer(ctx, host, true)
for {
select {
case res := <-results:
if res.error == nil {
return res.Conn, nil
}
if !res.ipv6 {
primary = res
} else {
fallback = res
}
if primary.done && fallback.done {
if primary.resolved {
return nil, primary.error
} else if fallback.resolved {
return nil, fallback.error
} else {
return nil, primary.error
}
}
}
}
}
func resolveUDPAddr(network, address string) (*net.UDPAddr, error) { func resolveUDPAddr(network, address string) (*net.UDPAddr, error) {
host, port, err := net.SplitHostPort(address) host, port, err := net.SplitHostPort(address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ip, err := dns.ResolveIP(host) ip, err := resolver.ResolveIP(host)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -7,6 +7,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/Dreamacro/clash/component/dialer"
"github.com/Dreamacro/clash/component/vmess" "github.com/Dreamacro/clash/component/vmess"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
) )
@ -33,7 +34,7 @@ type VmessOption struct {
} }
func (v *Vmess) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { func (v *Vmess) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
c, err := dialContext(ctx, "tcp", v.server) c, err := dialer.DialContext(ctx, "tcp", v.server)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error", v.server) return nil, fmt.Errorf("%s connect error", v.server)
} }
@ -45,7 +46,7 @@ func (v *Vmess) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn,
func (v *Vmess) DialUDP(metadata *C.Metadata) (C.PacketConn, error) { func (v *Vmess) DialUDP(metadata *C.Metadata) (C.PacketConn, error) {
ctx, cancel := context.WithTimeout(context.Background(), tcpTimeout) ctx, cancel := context.WithTimeout(context.Background(), tcpTimeout)
defer cancel() defer cancel()
c, err := dialContext(ctx, "tcp", v.server) c, err := dialer.DialContext(ctx, "tcp", v.server)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error", v.server) return nil, fmt.Errorf("%s connect error", v.server)
} }

View File

@ -2,7 +2,10 @@ package dialer
import ( import (
"context" "context"
"errors"
"net" "net"
"github.com/Dreamacro/clash/component/resolver"
) )
func Dialer() *net.Dialer { func Dialer() *net.Dialer {
@ -28,11 +31,124 @@ func Dial(network, address string) (net.Conn, error) {
} }
func DialContext(ctx context.Context, network, address string) (net.Conn, error) { func DialContext(ctx context.Context, network, address string) (net.Conn, error) {
dailer := Dialer() switch network {
return dailer.DialContext(ctx, network, address) case "tcp4", "tcp6", "udp4", "udp6":
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
dialer := Dialer()
var ip net.IP
switch network {
case "tcp4", "udp4":
ip, err = resolver.ResolveIPv4(host)
default:
ip, err = resolver.ResolveIPv6(host)
}
if err != nil {
return nil, err
}
if DialHook != nil {
DialHook(dialer, network, ip)
}
return dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port))
case "tcp", "udp":
return dualStackDailContext(ctx, network, address)
default:
return nil, errors.New("network invalid")
}
} }
func ListenPacket(network, address string) (net.PacketConn, error) { func ListenPacket(network, address string) (net.PacketConn, error) {
lc := ListenConfig() lc := ListenConfig()
if ListenPacketHook != nil && address == "" {
ip := ListenPacketHook()
if ip != nil {
address = net.JoinHostPort(ip.String(), "0")
}
}
return lc.ListenPacket(context.Background(), network, address) return lc.ListenPacket(context.Background(), network, address)
} }
func dualStackDailContext(ctx context.Context, network, address string) (net.Conn, error) {
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
returned := make(chan struct{})
defer close(returned)
type dialResult struct {
net.Conn
error
resolved bool
ipv6 bool
done bool
}
results := make(chan dialResult)
var primary, fallback dialResult
startRacer := func(ctx context.Context, network, host string, ipv6 bool) {
dialer := Dialer()
result := dialResult{ipv6: ipv6, done: true}
defer func() {
select {
case results <- result:
case <-returned:
if result.Conn != nil {
result.Conn.Close()
}
}
}()
var ip net.IP
if ipv6 {
ip, result.error = resolver.ResolveIPv6(host)
} else {
ip, result.error = resolver.ResolveIPv4(host)
}
if result.error != nil {
return
}
result.resolved = true
if DialHook != nil {
DialHook(dialer, network, ip)
}
result.Conn, result.error = dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port))
}
go startRacer(ctx, network+"4", host, false)
go startRacer(ctx, network+"6", host, true)
for {
select {
case res := <-results:
if res.error == nil {
return res.Conn, nil
}
if !res.ipv6 {
primary = res
} else {
fallback = res
}
if primary.done && fallback.done {
if primary.resolved {
return nil, primary.error
} else if fallback.resolved {
return nil, fallback.error
} else {
return nil, primary.error
}
}
}
}
}

View File

@ -1,11 +1,142 @@
package dialer package dialer
import "net" import (
"errors"
"net"
"time"
type DialerHookFunc = func(*net.Dialer) "github.com/Dreamacro/clash/common/singledo"
)
type DialerHookFunc = func(dialer *net.Dialer)
type DialHookFunc = func(dialer *net.Dialer, network string, ip net.IP)
type ListenConfigHookFunc = func(*net.ListenConfig) type ListenConfigHookFunc = func(*net.ListenConfig)
type ListenPacketHookFunc = func() net.IP
var ( var (
DialerHook DialerHookFunc = nil DialerHook DialerHookFunc
ListenConfigHook ListenConfigHookFunc = nil DialHook DialHookFunc
ListenConfigHook ListenConfigHookFunc
ListenPacketHook ListenPacketHookFunc
) )
var (
ErrAddrNotFound = errors.New("addr not found")
ErrNetworkNotSupport = errors.New("network not support")
)
func lookupTCPAddr(ip net.IP, addrs []net.Addr) (*net.TCPAddr, error) {
ipv4 := ip.To4() != nil
for _, elm := range addrs {
addr, ok := elm.(*net.IPNet)
if !ok {
continue
}
addrV4 := addr.IP.To4() != nil
if addrV4 && ipv4 {
return &net.TCPAddr{IP: addr.IP, Port: 0}, nil
} else if !addrV4 && !ipv4 {
return &net.TCPAddr{IP: addr.IP, Port: 0}, nil
}
}
return nil, ErrAddrNotFound
}
func lookupUDPAddr(ip net.IP, addrs []net.Addr) (*net.UDPAddr, error) {
ipv4 := ip.To4() != nil
for _, elm := range addrs {
addr, ok := elm.(*net.IPNet)
if !ok {
continue
}
addrV4 := addr.IP.To4() != nil
if addrV4 && ipv4 {
return &net.UDPAddr{IP: addr.IP, Port: 0}, nil
} else if !addrV4 && !ipv4 {
return &net.UDPAddr{IP: addr.IP, Port: 0}, nil
}
}
return nil, ErrAddrNotFound
}
func ListenPacketWithInterface(name string) ListenPacketHookFunc {
single := singledo.NewSingle(5 * time.Second)
return func() net.IP {
elm, err, _ := single.Do(func() (interface{}, error) {
iface, err := net.InterfaceByName(name)
if err != nil {
return nil, err
}
addrs, err := iface.Addrs()
if err != nil {
return nil, err
}
return addrs, nil
})
if err != nil {
return nil
}
addrs := elm.([]net.Addr)
for _, elm := range addrs {
addr, ok := elm.(*net.IPNet)
if !ok || addr.IP.To4() == nil {
continue
}
return addr.IP
}
return nil
}
}
func DialerWithInterface(name string) DialHookFunc {
single := singledo.NewSingle(5 * time.Second)
return func(dialer *net.Dialer, network string, ip net.IP) {
elm, err, _ := single.Do(func() (interface{}, error) {
iface, err := net.InterfaceByName(name)
if err != nil {
return nil, err
}
addrs, err := iface.Addrs()
if err != nil {
return nil, err
}
return addrs, nil
})
if err != nil {
return
}
addrs := elm.([]net.Addr)
switch network {
case "tcp", "tcp4", "tcp6":
if addr, err := lookupTCPAddr(ip, addrs); err == nil {
dialer.LocalAddr = addr
}
case "udp", "udp4", "udp6":
if addr, err := lookupUDPAddr(ip, addrs); err == nil {
dialer.LocalAddr = addr
}
}
}
}

View File

@ -1,16 +1,32 @@
package dns package resolver
import ( import (
"errors" "errors"
"net" "net"
"strings" "strings"
trie "github.com/Dreamacro/clash/component/domain-trie"
) )
var ( var (
errIPNotFound = errors.New("couldn't find ip") // DefaultResolver aim to resolve ip
errIPVersion = errors.New("ip version error") DefaultResolver Resolver
// DefaultHosts aim to resolve hosts
DefaultHosts = trie.New()
) )
var (
ErrIPNotFound = errors.New("couldn't find ip")
ErrIPVersion = errors.New("ip version error")
)
type Resolver interface {
ResolveIP(host string) (ip net.IP, err error)
ResolveIPv4(host string) (ip net.IP, err error)
ResolveIPv6(host string) (ip net.IP, err error)
}
// ResolveIPv4 with a host, return ipv4 // ResolveIPv4 with a host, return ipv4
func ResolveIPv4(host string) (net.IP, error) { func ResolveIPv4(host string) (net.IP, error) {
if node := DefaultHosts.Search(host); node != nil { if node := DefaultHosts.Search(host); node != nil {
@ -24,7 +40,7 @@ func ResolveIPv4(host string) (net.IP, error) {
if !strings.Contains(host, ":") { if !strings.Contains(host, ":") {
return ip, nil return ip, nil
} }
return nil, errIPVersion return nil, ErrIPVersion
} }
if DefaultResolver != nil { if DefaultResolver != nil {
@ -42,7 +58,7 @@ func ResolveIPv4(host string) (net.IP, error) {
} }
} }
return nil, errIPNotFound return nil, ErrIPNotFound
} }
// ResolveIPv6 with a host, return ipv6 // ResolveIPv6 with a host, return ipv6
@ -58,7 +74,7 @@ func ResolveIPv6(host string) (net.IP, error) {
if strings.Contains(host, ":") { if strings.Contains(host, ":") {
return ip, nil return ip, nil
} }
return nil, errIPVersion return nil, ErrIPVersion
} }
if DefaultResolver != nil { if DefaultResolver != nil {
@ -76,7 +92,7 @@ func ResolveIPv6(host string) (net.IP, error) {
} }
} }
return nil, errIPNotFound return nil, ErrIPNotFound
} }
// ResolveIP with a host, return ip // ResolveIP with a host, return ip
@ -86,11 +102,8 @@ func ResolveIP(host string) (net.IP, error) {
} }
if DefaultResolver != nil { if DefaultResolver != nil {
if DefaultResolver.ipv6 {
return DefaultResolver.ResolveIP(host) return DefaultResolver.ResolveIP(host)
} }
return DefaultResolver.ResolveIPv4(host)
}
ip := net.ParseIP(host) ip := net.ParseIP(host)
if ip != nil { if ip != nil {

View File

@ -1,6 +1,7 @@
package config package config
import ( import (
"errors"
"fmt" "fmt"
"net" "net"
"net/url" "net/url"
@ -30,7 +31,7 @@ type General struct {
Authentication []string `json:"authentication"` Authentication []string `json:"authentication"`
AllowLan bool `json:"allow-lan"` AllowLan bool `json:"allow-lan"`
BindAddress string `json:"bind-address"` BindAddress string `json:"bind-address"`
Mode T.Mode `json:"mode"` Mode T.TunnelMode `json:"mode"`
LogLevel log.LogLevel `json:"log-level"` LogLevel log.LogLevel `json:"log-level"`
ExternalController string `json:"-"` ExternalController string `json:"-"`
ExternalUI string `json:"-"` ExternalUI string `json:"-"`
@ -46,6 +47,7 @@ type DNS struct {
FallbackFilter FallbackFilter `yaml:"fallback-filter"` FallbackFilter FallbackFilter `yaml:"fallback-filter"`
Listen string `yaml:"listen"` Listen string `yaml:"listen"`
EnhancedMode dns.EnhancedMode `yaml:"enhanced-mode"` EnhancedMode dns.EnhancedMode `yaml:"enhanced-mode"`
DefaultNameserver []dns.NameServer `yaml:"default-nameserver"`
FakeIPRange *fakeip.Pool FakeIPRange *fakeip.Pool
} }
@ -58,6 +60,7 @@ type FallbackFilter struct {
// Experimental config // Experimental config
type Experimental struct { type Experimental struct {
IgnoreResolveFail bool `yaml:"ignore-resolve-fail"` IgnoreResolveFail bool `yaml:"ignore-resolve-fail"`
Interface string `yaml:"interface-name"`
} }
// Config is clash config manager // Config is clash config manager
@ -82,6 +85,7 @@ type RawDNS struct {
EnhancedMode dns.EnhancedMode `yaml:"enhanced-mode"` EnhancedMode dns.EnhancedMode `yaml:"enhanced-mode"`
FakeIPRange string `yaml:"fake-ip-range"` FakeIPRange string `yaml:"fake-ip-range"`
FakeIPFilter []string `yaml:"fake-ip-filter"` FakeIPFilter []string `yaml:"fake-ip-filter"`
DefaultNameserver []string `yaml:"default-nameserver"`
} }
type RawFallbackFilter struct { type RawFallbackFilter struct {
@ -96,7 +100,7 @@ type RawConfig struct {
Authentication []string `yaml:"authentication"` Authentication []string `yaml:"authentication"`
AllowLan bool `yaml:"allow-lan"` AllowLan bool `yaml:"allow-lan"`
BindAddress string `yaml:"bind-address"` BindAddress string `yaml:"bind-address"`
Mode T.Mode `yaml:"mode"` Mode T.TunnelMode `yaml:"mode"`
LogLevel log.LogLevel `yaml:"log-level"` LogLevel log.LogLevel `yaml:"log-level"`
ExternalController string `yaml:"external-controller"` ExternalController string `yaml:"external-controller"`
ExternalUI string `yaml:"external-ui"` ExternalUI string `yaml:"external-ui"`
@ -143,6 +147,10 @@ func UnmarshalRawConfig(buf []byte) (*RawConfig, error) {
GeoIP: true, GeoIP: true,
IPCIDR: []string{}, IPCIDR: []string{},
}, },
DefaultNameserver: []string{
"114.114.114.114",
"8.8.8.8",
},
}, },
} }
@ -433,21 +441,21 @@ func parseHosts(cfg *RawConfig) (*trie.Trie, error) {
return tree, nil return tree, nil
} }
func hostWithDefaultPort(host string, defPort string) (string, error) { func hostWithDefaultPort(host string, defPort string) (string, string, error) {
if !strings.Contains(host, ":") { if !strings.Contains(host, ":") {
host += ":" host += ":"
} }
hostname, port, err := net.SplitHostPort(host) hostname, port, err := net.SplitHostPort(host)
if err != nil { if err != nil {
return "", err return "", "", err
} }
if port == "" { if port == "" {
port = defPort port = defPort
} }
return net.JoinHostPort(hostname, port), nil return net.JoinHostPort(hostname, port), hostname, nil
} }
func parseNameServer(servers []string) ([]dns.NameServer, error) { func parseNameServer(servers []string) ([]dns.NameServer, error) {
@ -463,20 +471,21 @@ func parseNameServer(servers []string) ([]dns.NameServer, error) {
return nil, fmt.Errorf("DNS NameServer[%d] format error: %s", idx, err.Error()) return nil, fmt.Errorf("DNS NameServer[%d] format error: %s", idx, err.Error())
} }
var host, dnsNetType string var addr, dnsNetType, host string
switch u.Scheme { switch u.Scheme {
case "udp": case "udp":
host, err = hostWithDefaultPort(u.Host, "53") addr, host, err = hostWithDefaultPort(u.Host, "53")
dnsNetType = "" // UDP dnsNetType = "" // UDP
case "tcp": case "tcp":
host, err = hostWithDefaultPort(u.Host, "53") addr, host, err = hostWithDefaultPort(u.Host, "53")
dnsNetType = "tcp" // TCP dnsNetType = "tcp" // TCP
case "tls": case "tls":
host, err = hostWithDefaultPort(u.Host, "853") addr, host, err = hostWithDefaultPort(u.Host, "853")
dnsNetType = "tcp-tls" // DNS over TLS dnsNetType = "tcp-tls" // DNS over TLS
case "https": case "https":
clearURL := url.URL{Scheme: "https", Host: u.Host, Path: u.Path} clearURL := url.URL{Scheme: "https", Host: u.Host, Path: u.Path}
host = clearURL.String() addr = clearURL.String()
_, host, err = hostWithDefaultPort(u.Host, "853")
dnsNetType = "https" // DNS over HTTPS dnsNetType = "https" // DNS over HTTPS
default: default:
return nil, fmt.Errorf("DNS NameServer[%d] unsupport scheme: %s", idx, u.Scheme) return nil, fmt.Errorf("DNS NameServer[%d] unsupport scheme: %s", idx, u.Scheme)
@ -490,7 +499,8 @@ func parseNameServer(servers []string) ([]dns.NameServer, error) {
nameservers, nameservers,
dns.NameServer{ dns.NameServer{
Net: dnsNetType, Net: dnsNetType,
Addr: host, Addr: addr,
Host: host,
}, },
) )
} }
@ -534,6 +544,19 @@ func parseDNS(cfg RawDNS) (*DNS, error) {
return nil, err return nil, err
} }
if len(cfg.DefaultNameserver) == 0 {
return nil, errors.New("default nameserver should have at least one nameserver")
}
if dnsCfg.DefaultNameserver, err = parseNameServer(cfg.DefaultNameserver); err != nil {
return nil, err
}
// check default nameserver is pure ip addr
for _, ns := range dnsCfg.DefaultNameserver {
if net.ParseIP(ns.Host) == nil {
return nil, errors.New("default nameserver should be pure IP")
}
}
if cfg.EnhancedMode == dns.FAKEIP { if cfg.EnhancedMode == dns.FAKEIP {
_, ipnet, err := net.ParseCIDR(cfg.FakeIPRange) _, ipnet, err := net.ParseCIDR(cfg.FakeIPRange)
if err != nil { if err != nil {

View File

@ -2,6 +2,7 @@ package dns
import ( import (
"context" "context"
"strings"
"github.com/Dreamacro/clash/component/dialer" "github.com/Dreamacro/clash/component/dialer"
@ -10,7 +11,9 @@ import (
type client struct { type client struct {
*D.Client *D.Client
Address string r *Resolver
addr string
host string
} }
func (c *client) Exchange(m *D.Msg) (msg *D.Msg, err error) { func (c *client) Exchange(m *D.Msg) (msg *D.Msg, err error) {
@ -18,7 +21,22 @@ func (c *client) Exchange(m *D.Msg) (msg *D.Msg, err error) {
} }
func (c *client) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) { func (c *client) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) {
c.Client.Dialer = dialer.Dialer() network := "udp"
if strings.HasPrefix(c.Client.Net, "tcp") {
network = "tcp"
}
ip, err := c.r.ResolveIPv4(c.host)
if err != nil {
return nil, err
}
d := dialer.Dialer()
if dialer.DialHook != nil {
dialer.DialHook(d, network, ip)
}
c.Client.Dialer = d
// miekg/dns ExchangeContext doesn't respond to context cancel. // miekg/dns ExchangeContext doesn't respond to context cancel.
// this is a workaround // this is a workaround
@ -28,7 +46,7 @@ func (c *client) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err
} }
ch := make(chan result, 1) ch := make(chan result, 1)
go func() { go func() {
msg, _, err := c.Client.ExchangeContext(ctx, m, c.Address) msg, _, err := c.Client.Exchange(m, c.addr)
ch <- result{msg, err} ch <- result{msg, err}
}() }()

View File

@ -5,6 +5,7 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"io/ioutil" "io/ioutil"
"net"
"net/http" "net/http"
"github.com/Dreamacro/clash/component/dialer" "github.com/Dreamacro/clash/component/dialer"
@ -17,13 +18,9 @@ const (
dotMimeType = "application/dns-message" dotMimeType = "application/dns-message"
) )
var dohTransport = &http.Transport{
TLSClientConfig: &tls.Config{ClientSessionCache: globalSessionCache},
DialContext: dialer.DialContext,
}
type dohClient struct { type dohClient struct {
url string url string
transport *http.Transport
} }
func (dc *dohClient) Exchange(m *D.Msg) (msg *D.Msg, err error) { func (dc *dohClient) Exchange(m *D.Msg) (msg *D.Msg, err error) {
@ -58,7 +55,7 @@ func (dc *dohClient) newRequest(m *D.Msg) (*http.Request, error) {
} }
func (dc *dohClient) doRequest(req *http.Request) (msg *D.Msg, err error) { func (dc *dohClient) doRequest(req *http.Request) (msg *D.Msg, err error) {
client := &http.Client{Transport: dohTransport} client := &http.Client{Transport: dc.transport}
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
@ -73,3 +70,25 @@ func (dc *dohClient) doRequest(req *http.Request) (msg *D.Msg, err error) {
err = msg.Unpack(buf) err = msg.Unpack(buf)
return msg, err return msg, err
} }
func newDoHClient(url string, r *Resolver) *dohClient {
return &dohClient{
url: url,
transport: &http.Transport{
TLSClientConfig: &tls.Config{ClientSessionCache: globalSessionCache},
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
ip, err := r.ResolveIPv4(host)
if err != nil {
return nil, err
}
return dialer.DialContext(ctx, "tcp4", net.JoinHostPort(ip.String(), port))
},
},
}
}

View File

@ -11,26 +11,18 @@ import (
"github.com/Dreamacro/clash/common/cache" "github.com/Dreamacro/clash/common/cache"
"github.com/Dreamacro/clash/common/picker" "github.com/Dreamacro/clash/common/picker"
trie "github.com/Dreamacro/clash/component/domain-trie"
"github.com/Dreamacro/clash/component/fakeip" "github.com/Dreamacro/clash/component/fakeip"
"github.com/Dreamacro/clash/component/resolver"
D "github.com/miekg/dns" D "github.com/miekg/dns"
"golang.org/x/sync/singleflight" "golang.org/x/sync/singleflight"
) )
var (
// DefaultResolver aim to resolve ip
DefaultResolver *Resolver
// DefaultHosts aim to resolve hosts
DefaultHosts = trie.New()
)
var ( var (
globalSessionCache = tls.NewLRUClientSessionCache(64) globalSessionCache = tls.NewLRUClientSessionCache(64)
) )
type resolver interface { type dnsClient interface {
Exchange(m *D.Msg) (msg *D.Msg, err error) Exchange(m *D.Msg) (msg *D.Msg, err error)
ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error)
} }
@ -45,8 +37,8 @@ type Resolver struct {
mapping bool mapping bool
fakeip bool fakeip bool
pool *fakeip.Pool pool *fakeip.Pool
main []resolver main []dnsClient
fallback []resolver fallback []dnsClient
fallbackFilters []fallbackFilter fallbackFilters []fallbackFilter
group singleflight.Group group singleflight.Group
cache *cache.Cache cache *cache.Cache
@ -74,7 +66,7 @@ func (r *Resolver) ResolveIP(host string) (ip net.IP, err error) {
ip, open := <-ch ip, open := <-ch
if !open { if !open {
return nil, errIPNotFound return nil, resolver.ErrIPNotFound
} }
return ip, nil return ip, nil
@ -174,7 +166,7 @@ func (r *Resolver) IsFakeIP(ip net.IP) bool {
return false return false
} }
func (r *Resolver) batchExchange(clients []resolver, m *D.Msg) (msg *D.Msg, err error) { func (r *Resolver) batchExchange(clients []dnsClient, m *D.Msg) (msg *D.Msg, err error) {
fast, ctx := picker.WithTimeout(context.Background(), time.Second*5) fast, ctx := picker.WithTimeout(context.Background(), time.Second*5)
for _, client := range clients { for _, client := range clients {
r := client r := client
@ -238,7 +230,7 @@ func (r *Resolver) resolveIP(host string, dnsType uint16) (ip net.IP, err error)
ips := r.msgToIP(msg) ips := r.msgToIP(msg)
ipLength := len(ips) ipLength := len(ips)
if ipLength == 0 { if ipLength == 0 {
return nil, errIPNotFound return nil, resolver.ErrIPNotFound
} }
ip = ips[rand.Intn(ipLength)] ip = ips[rand.Intn(ipLength)]
@ -260,7 +252,7 @@ func (r *Resolver) msgToIP(msg *D.Msg) []net.IP {
return ips return ips
} }
func (r *Resolver) asyncExchange(client []resolver, msg *D.Msg) <-chan *result { func (r *Resolver) asyncExchange(client []dnsClient, msg *D.Msg) <-chan *result {
ch := make(chan *result) ch := make(chan *result)
go func() { go func() {
res, err := r.batchExchange(client, msg) res, err := r.batchExchange(client, msg)
@ -272,6 +264,7 @@ func (r *Resolver) asyncExchange(client []resolver, msg *D.Msg) <-chan *result {
type NameServer struct { type NameServer struct {
Net string Net string
Addr string Addr string
Host string
} }
type FallbackFilter struct { type FallbackFilter struct {
@ -281,6 +274,7 @@ type FallbackFilter struct {
type Config struct { type Config struct {
Main, Fallback []NameServer Main, Fallback []NameServer
Default []NameServer
IPv6 bool IPv6 bool
EnhancedMode EnhancedMode EnhancedMode EnhancedMode
FallbackFilter FallbackFilter FallbackFilter FallbackFilter
@ -288,9 +282,14 @@ type Config struct {
} }
func New(config Config) *Resolver { func New(config Config) *Resolver {
defaultResolver := &Resolver{
main: transform(config.Default, nil),
cache: cache.New(time.Second * 60),
}
r := &Resolver{ r := &Resolver{
ipv6: config.IPv6, ipv6: config.IPv6,
main: transform(config.Main), main: transform(config.Main, defaultResolver),
cache: cache.New(time.Second * 60), cache: cache.New(time.Second * 60),
mapping: config.EnhancedMode == MAPPING, mapping: config.EnhancedMode == MAPPING,
fakeip: config.EnhancedMode == FAKEIP, fakeip: config.EnhancedMode == FAKEIP,
@ -298,7 +297,7 @@ func New(config Config) *Resolver {
} }
if len(config.Fallback) != 0 { if len(config.Fallback) != 0 {
r.fallback = transform(config.Fallback) r.fallback = transform(config.Fallback, defaultResolver)
} }
fallbackFilters := []fallbackFilter{} fallbackFilters := []fallbackFilter{}

View File

@ -8,9 +8,9 @@ import (
"github.com/Dreamacro/clash/common/cache" "github.com/Dreamacro/clash/common/cache"
"github.com/Dreamacro/clash/log" "github.com/Dreamacro/clash/log"
yaml "gopkg.in/yaml.v2"
D "github.com/miekg/dns" D "github.com/miekg/dns"
yaml "gopkg.in/yaml.v2"
) )
var ( var (
@ -117,11 +117,11 @@ func isIPRequest(q D.Question) bool {
return false return false
} }
func transform(servers []NameServer) []resolver { func transform(servers []NameServer, resolver *Resolver) []dnsClient {
ret := []resolver{} ret := []dnsClient{}
for _, s := range servers { for _, s := range servers {
if s.Net == "https" { if s.Net == "https" {
ret = append(ret, &dohClient{url: s.Addr}) ret = append(ret, newDoHClient(s.Addr, resolver))
continue continue
} }
@ -136,7 +136,8 @@ func transform(servers []NameServer) []resolver {
UDPSize: 4096, UDPSize: 4096,
Timeout: 5 * time.Second, Timeout: 5 * time.Second,
}, },
Address: s.Addr, addr: s.Addr,
host: s.Host,
}) })
} }
return ret return ret

View File

@ -8,14 +8,16 @@ import (
"github.com/Dreamacro/clash/adapters/provider" "github.com/Dreamacro/clash/adapters/provider"
"github.com/Dreamacro/clash/component/auth" "github.com/Dreamacro/clash/component/auth"
"github.com/Dreamacro/clash/component/dialer"
trie "github.com/Dreamacro/clash/component/domain-trie" trie "github.com/Dreamacro/clash/component/domain-trie"
"github.com/Dreamacro/clash/component/resolver"
"github.com/Dreamacro/clash/config" "github.com/Dreamacro/clash/config"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/dns" "github.com/Dreamacro/clash/dns"
"github.com/Dreamacro/clash/log" "github.com/Dreamacro/clash/log"
P "github.com/Dreamacro/clash/proxy" P "github.com/Dreamacro/clash/proxy"
authStore "github.com/Dreamacro/clash/proxy/auth" authStore "github.com/Dreamacro/clash/proxy/auth"
T "github.com/Dreamacro/clash/tunnel" "github.com/Dreamacro/clash/tunnel"
) )
// forward compatibility before 1.0 // forward compatibility before 1.0
@ -83,7 +85,7 @@ func ApplyConfig(cfg *config.Config, force bool) {
updateRules(cfg.Rules) updateRules(cfg.Rules)
updateDNS(cfg.DNS) updateDNS(cfg.DNS)
updateHosts(cfg.Hosts) updateHosts(cfg.Hosts)
updateExperimental(cfg.Experimental) updateExperimental(cfg)
} }
func GetGeneral() *config.General { func GetGeneral() *config.General {
@ -100,20 +102,30 @@ func GetGeneral() *config.General {
Authentication: authenticator, Authentication: authenticator,
AllowLan: P.AllowLan(), AllowLan: P.AllowLan(),
BindAddress: P.BindAddress(), BindAddress: P.BindAddress(),
Mode: T.Instance().Mode(), Mode: tunnel.Mode(),
LogLevel: log.Level(), LogLevel: log.Level(),
} }
return general return general
} }
func updateExperimental(c *config.Experimental) { func updateExperimental(c *config.Config) {
T.Instance().UpdateExperimental(c.IgnoreResolveFail) cfg := c.Experimental
tunnel.UpdateExperimental(cfg.IgnoreResolveFail)
if cfg.Interface != "" && c.DNS.Enable {
dialer.DialHook = dialer.DialerWithInterface(cfg.Interface)
dialer.ListenPacketHook = dialer.ListenPacketWithInterface(cfg.Interface)
} else {
dialer.DialHook = nil
dialer.ListenPacketHook = nil
}
} }
func updateDNS(c *config.DNS) { func updateDNS(c *config.DNS) {
if c.Enable == false { if c.Enable == false {
dns.DefaultResolver = nil resolver.DefaultResolver = nil
tunnel.SetResolver(nil)
dns.ReCreateServer("", nil) dns.ReCreateServer("", nil)
return return
} }
@ -127,8 +139,10 @@ func updateDNS(c *config.DNS) {
GeoIP: c.FallbackFilter.GeoIP, GeoIP: c.FallbackFilter.GeoIP,
IPCIDR: c.FallbackFilter.IPCIDR, IPCIDR: c.FallbackFilter.IPCIDR,
}, },
Default: c.DefaultNameserver,
}) })
dns.DefaultResolver = r resolver.DefaultResolver = r
tunnel.SetResolver(r)
if err := dns.ReCreateServer(c.Listen, r); err != nil { if err := dns.ReCreateServer(c.Listen, r); err != nil {
log.Errorln("Start DNS server error: %s", err.Error()) log.Errorln("Start DNS server error: %s", err.Error())
return return
@ -140,11 +154,10 @@ func updateDNS(c *config.DNS) {
} }
func updateHosts(tree *trie.Trie) { func updateHosts(tree *trie.Trie) {
dns.DefaultHosts = tree resolver.DefaultHosts = tree
} }
func updateProxies(proxies map[string]C.Proxy, providers map[string]provider.ProxyProvider) { func updateProxies(proxies map[string]C.Proxy, providers map[string]provider.ProxyProvider) {
tunnel := T.Instance()
oldProviders := tunnel.Providers() oldProviders := tunnel.Providers()
// close providers goroutine // close providers goroutine
@ -156,12 +169,12 @@ func updateProxies(proxies map[string]C.Proxy, providers map[string]provider.Pro
} }
func updateRules(rules []C.Rule) { func updateRules(rules []C.Rule) {
T.Instance().UpdateRules(rules) tunnel.UpdateRules(rules)
} }
func updateGeneral(general *config.General) { func updateGeneral(general *config.General) {
log.SetLevel(general.LogLevel) log.SetLevel(general.LogLevel)
T.Instance().SetMode(general.Mode) tunnel.SetMode(general.Mode)
allowLan := general.AllowLan allowLan := general.AllowLan
P.SetAllowLan(allowLan) P.SetAllowLan(allowLan)

View File

@ -8,7 +8,7 @@ import (
"github.com/Dreamacro/clash/hub/executor" "github.com/Dreamacro/clash/hub/executor"
"github.com/Dreamacro/clash/log" "github.com/Dreamacro/clash/log"
P "github.com/Dreamacro/clash/proxy" P "github.com/Dreamacro/clash/proxy"
T "github.com/Dreamacro/clash/tunnel" "github.com/Dreamacro/clash/tunnel"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/go-chi/render" "github.com/go-chi/render"
@ -28,7 +28,7 @@ type configSchema struct {
RedirPort *int `json:"redir-port"` RedirPort *int `json:"redir-port"`
AllowLan *bool `json:"allow-lan"` AllowLan *bool `json:"allow-lan"`
BindAddress *string `json:"bind-address"` BindAddress *string `json:"bind-address"`
Mode *T.Mode `json:"mode"` Mode *tunnel.TunnelMode `json:"mode"`
LogLevel *log.LogLevel `json:"log-level"` LogLevel *log.LogLevel `json:"log-level"`
} }
@ -67,7 +67,7 @@ func patchConfigs(w http.ResponseWriter, r *http.Request) {
P.ReCreateRedir(pointerOrDefault(general.RedirPort, ports.RedirPort)) P.ReCreateRedir(pointerOrDefault(general.RedirPort, ports.RedirPort))
if general.Mode != nil { if general.Mode != nil {
T.Instance().SetMode(*general.Mode) tunnel.SetMode(*general.Mode)
} }
if general.LogLevel != nil { if general.LogLevel != nil {

View File

@ -5,7 +5,7 @@ import (
"net/http" "net/http"
"github.com/Dreamacro/clash/adapters/provider" "github.com/Dreamacro/clash/adapters/provider"
T "github.com/Dreamacro/clash/tunnel" "github.com/Dreamacro/clash/tunnel"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/go-chi/render" "github.com/go-chi/render"
@ -25,7 +25,7 @@ func proxyProviderRouter() http.Handler {
} }
func getProviders(w http.ResponseWriter, r *http.Request) { func getProviders(w http.ResponseWriter, r *http.Request) {
providers := T.Instance().Providers() providers := tunnel.Providers()
render.JSON(w, r, render.M{ render.JSON(w, r, render.M{
"providers": providers, "providers": providers,
}) })
@ -63,7 +63,7 @@ func parseProviderName(next http.Handler) http.Handler {
func findProviderByName(next http.Handler) http.Handler { func findProviderByName(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
name := r.Context().Value(CtxKeyProviderName).(string) name := r.Context().Value(CtxKeyProviderName).(string)
providers := T.Instance().Providers() providers := tunnel.Providers()
provider, exist := providers[name] provider, exist := providers[name]
if !exist { if !exist {
render.Status(r, http.StatusNotFound) render.Status(r, http.StatusNotFound)

View File

@ -10,7 +10,7 @@ import (
"github.com/Dreamacro/clash/adapters/outbound" "github.com/Dreamacro/clash/adapters/outbound"
"github.com/Dreamacro/clash/adapters/outboundgroup" "github.com/Dreamacro/clash/adapters/outboundgroup"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
T "github.com/Dreamacro/clash/tunnel" "github.com/Dreamacro/clash/tunnel"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/go-chi/render" "github.com/go-chi/render"
@ -40,7 +40,7 @@ func parseProxyName(next http.Handler) http.Handler {
func findProxyByName(next http.Handler) http.Handler { func findProxyByName(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
name := r.Context().Value(CtxKeyProxyName).(string) name := r.Context().Value(CtxKeyProxyName).(string)
proxies := T.Instance().Proxies() proxies := tunnel.Proxies()
proxy, exist := proxies[name] proxy, exist := proxies[name]
if !exist { if !exist {
render.Status(r, http.StatusNotFound) render.Status(r, http.StatusNotFound)
@ -54,7 +54,7 @@ func findProxyByName(next http.Handler) http.Handler {
} }
func getProxies(w http.ResponseWriter, r *http.Request) { func getProxies(w http.ResponseWriter, r *http.Request) {
proxies := T.Instance().Proxies() proxies := tunnel.Proxies()
render.JSON(w, r, render.M{ render.JSON(w, r, render.M{
"proxies": proxies, "proxies": proxies,
}) })

View File

@ -3,7 +3,7 @@ package route
import ( import (
"net/http" "net/http"
T "github.com/Dreamacro/clash/tunnel" "github.com/Dreamacro/clash/tunnel"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/go-chi/render" "github.com/go-chi/render"
@ -22,7 +22,7 @@ type Rule struct {
} }
func getRules(w http.ResponseWriter, r *http.Request) { func getRules(w http.ResponseWriter, r *http.Request) {
rawRules := T.Instance().Rules() rawRules := tunnel.Rules()
rules := []Rule{} rules := []Rule{}
for _, rule := range rawRules { for _, rule := range rawRules {

View File

@ -16,10 +16,6 @@ import (
"github.com/Dreamacro/clash/tunnel" "github.com/Dreamacro/clash/tunnel"
) )
var (
tun = tunnel.Instance()
)
type HttpListener struct { type HttpListener struct {
net.Listener net.Listener
address string address string
@ -100,9 +96,9 @@ func handleConn(conn net.Conn, cache *cache.Cache) {
if err != nil { if err != nil {
return return
} }
tun.Add(adapters.NewHTTPS(request, conn)) tunnel.Add(adapters.NewHTTPS(request, conn))
return return
} }
tun.Add(adapters.NewHTTP(request, conn)) tunnel.Add(adapters.NewHTTP(request, conn))
} }

View File

@ -9,10 +9,6 @@ import (
"github.com/Dreamacro/clash/tunnel" "github.com/Dreamacro/clash/tunnel"
) )
var (
tun = tunnel.Instance()
)
type RedirListener struct { type RedirListener struct {
net.Listener net.Listener
address string address string
@ -59,5 +55,5 @@ func handleRedir(conn net.Conn) {
return return
} }
conn.(*net.TCPConn).SetKeepAlive(true) conn.(*net.TCPConn).SetKeepAlive(true)
tun.Add(inbound.NewSocket(target, conn, C.REDIR, C.TCP)) tunnel.Add(inbound.NewSocket(target, conn, C.REDIR, C.TCP))
} }

View File

@ -13,10 +13,6 @@ import (
"github.com/Dreamacro/clash/tunnel" "github.com/Dreamacro/clash/tunnel"
) )
var (
tun = tunnel.Instance()
)
type SockListener struct { type SockListener struct {
net.Listener net.Listener
address string address string
@ -68,5 +64,5 @@ func handleSocks(conn net.Conn) {
io.Copy(ioutil.Discard, conn) io.Copy(ioutil.Discard, conn)
return return
} }
tun.Add(adapters.NewSocket(target, conn, C.SOCKS, C.TCP)) tunnel.Add(adapters.NewSocket(target, conn, C.SOCKS, C.TCP))
} }

View File

@ -7,6 +7,7 @@ import (
"github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/common/pool"
"github.com/Dreamacro/clash/component/socks5" "github.com/Dreamacro/clash/component/socks5"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/tunnel"
) )
type SockUDPListener struct { type SockUDPListener struct {
@ -62,5 +63,5 @@ func handleSocksUDP(pc net.PacketConn, buf []byte, addr net.Addr) {
payload: payload, payload: payload,
bufRef: buf, bufRef: buf,
} }
tun.AddPacket(adapters.NewPacket(target, packet, C.SOCKS)) tunnel.AddPacket(adapters.NewPacket(target, packet, C.SOCKS))
} }

View File

@ -14,7 +14,7 @@ import (
"github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/common/pool"
) )
func (t *Tunnel) handleHTTP(request *adapters.HTTPAdapter, outbound net.Conn) { func handleHTTP(request *adapters.HTTPAdapter, outbound net.Conn) {
req := request.R req := request.R
host := req.Host host := req.Host
@ -81,17 +81,17 @@ func (t *Tunnel) handleHTTP(request *adapters.HTTPAdapter, outbound net.Conn) {
} }
} }
func (t *Tunnel) handleUDPToRemote(packet C.UDPPacket, pc net.PacketConn, addr net.Addr) { func handleUDPToRemote(packet C.UDPPacket, pc net.PacketConn, addr net.Addr) {
if _, err := pc.WriteTo(packet.Data(), addr); err != nil { if _, err := pc.WriteTo(packet.Data(), addr); err != nil {
return return
} }
DefaultManager.Upload() <- int64(len(packet.Data())) DefaultManager.Upload() <- int64(len(packet.Data()))
} }
func (t *Tunnel) handleUDPToLocal(packet C.UDPPacket, pc net.PacketConn, key string) { func handleUDPToLocal(packet C.UDPPacket, pc net.PacketConn, key string) {
buf := pool.BufPool.Get().([]byte) buf := pool.BufPool.Get().([]byte)
defer pool.BufPool.Put(buf[:cap(buf)]) defer pool.BufPool.Put(buf[:cap(buf)])
defer t.natTable.Delete(key) defer natTable.Delete(key)
defer pc.Close() defer pc.Close()
for { for {
@ -109,7 +109,7 @@ func (t *Tunnel) handleUDPToLocal(packet C.UDPPacket, pc net.PacketConn, key str
} }
} }
func (t *Tunnel) handleSocket(request *adapters.SocketAdapter, outbound net.Conn) { func handleSocket(request *adapters.SocketAdapter, outbound net.Conn) {
relay(request, outbound) relay(request, outbound)
} }

View File

@ -5,11 +5,11 @@ import (
"errors" "errors"
) )
type Mode int type TunnelMode int
var ( var (
// ModeMapping is a mapping for Mode enum // ModeMapping is a mapping for Mode enum
ModeMapping = map[string]Mode{ ModeMapping = map[string]TunnelMode{
Global.String(): Global, Global.String(): Global,
Rule.String(): Rule, Rule.String(): Rule,
Direct.String(): Direct, Direct.String(): Direct,
@ -17,13 +17,13 @@ var (
) )
const ( const (
Global Mode = iota Global TunnelMode = iota
Rule Rule
Direct Direct
) )
// UnmarshalJSON unserialize Mode // UnmarshalJSON unserialize Mode
func (m *Mode) UnmarshalJSON(data []byte) error { func (m *TunnelMode) UnmarshalJSON(data []byte) error {
var tp string var tp string
json.Unmarshal(data, &tp) json.Unmarshal(data, &tp)
mode, exist := ModeMapping[tp] mode, exist := ModeMapping[tp]
@ -35,7 +35,7 @@ func (m *Mode) UnmarshalJSON(data []byte) error {
} }
// UnmarshalYAML unserialize Mode with yaml // UnmarshalYAML unserialize Mode with yaml
func (m *Mode) UnmarshalYAML(unmarshal func(interface{}) error) error { func (m *TunnelMode) UnmarshalYAML(unmarshal func(interface{}) error) error {
var tp string var tp string
unmarshal(&tp) unmarshal(&tp)
mode, exist := ModeMapping[tp] mode, exist := ModeMapping[tp]
@ -47,11 +47,11 @@ func (m *Mode) UnmarshalYAML(unmarshal func(interface{}) error) error {
} }
// MarshalJSON serialize Mode // MarshalJSON serialize Mode
func (m Mode) MarshalJSON() ([]byte, error) { func (m TunnelMode) MarshalJSON() ([]byte, error) {
return json.Marshal(m.String()) return json.Marshal(m.String())
} }
func (m Mode) String() string { func (m TunnelMode) String() string {
switch m { switch m {
case Global: case Global:
return "Global" return "Global"

View File

@ -10,6 +10,7 @@ import (
"github.com/Dreamacro/clash/adapters/inbound" "github.com/Dreamacro/clash/adapters/inbound"
"github.com/Dreamacro/clash/adapters/provider" "github.com/Dreamacro/clash/adapters/provider"
"github.com/Dreamacro/clash/component/nat" "github.com/Dreamacro/clash/component/nat"
"github.com/Dreamacro/clash/component/resolver"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/dns" "github.com/Dreamacro/clash/dns"
"github.com/Dreamacro/clash/log" "github.com/Dreamacro/clash/log"
@ -18,136 +19,136 @@ import (
) )
var ( var (
tunnel *Tunnel tcpQueue = channels.NewInfiniteChannel()
once sync.Once udpQueue = channels.NewInfiniteChannel()
natTable = nat.New()
// default timeout for UDP session
udpTimeout = 60 * time.Second
)
// Tunnel handle relay inbound proxy and outbound proxy
type Tunnel struct {
tcpQueue *channels.InfiniteChannel
udpQueue *channels.InfiniteChannel
natTable *nat.Table
rules []C.Rule rules []C.Rule
proxies map[string]C.Proxy proxies = make(map[string]C.Proxy)
providers map[string]provider.ProxyProvider providers map[string]provider.ProxyProvider
configMux sync.RWMutex configMux sync.RWMutex
enhancedMode *dns.Resolver
// experimental features // experimental features
ignoreResolveFail bool ignoreResolveFail bool
// Outbound Rule // Outbound Rule
mode Mode mode = Rule
// default timeout for UDP session
udpTimeout = 60 * time.Second
)
func init() {
go process()
} }
// Add request to queue // Add request to queue
func (t *Tunnel) Add(req C.ServerAdapter) { func Add(req C.ServerAdapter) {
t.tcpQueue.In() <- req tcpQueue.In() <- req
} }
// AddPacket add udp Packet to queue // AddPacket add udp Packet to queue
func (t *Tunnel) AddPacket(packet *inbound.PacketAdapter) { func AddPacket(packet *inbound.PacketAdapter) {
t.udpQueue.In() <- packet udpQueue.In() <- packet
} }
// Rules return all rules // Rules return all rules
func (t *Tunnel) Rules() []C.Rule { func Rules() []C.Rule {
return t.rules return rules
} }
// UpdateRules handle update rules // UpdateRules handle update rules
func (t *Tunnel) UpdateRules(rules []C.Rule) { func UpdateRules(newRules []C.Rule) {
t.configMux.Lock() configMux.Lock()
t.rules = rules rules = newRules
t.configMux.Unlock() configMux.Unlock()
} }
// Proxies return all proxies // Proxies return all proxies
func (t *Tunnel) Proxies() map[string]C.Proxy { func Proxies() map[string]C.Proxy {
return t.proxies return proxies
} }
// Providers return all compatible providers // Providers return all compatible providers
func (t *Tunnel) Providers() map[string]provider.ProxyProvider { func Providers() map[string]provider.ProxyProvider {
return t.providers return providers
} }
// UpdateProxies handle update proxies // UpdateProxies handle update proxies
func (t *Tunnel) UpdateProxies(proxies map[string]C.Proxy, providers map[string]provider.ProxyProvider) { func UpdateProxies(newProxies map[string]C.Proxy, newProviders map[string]provider.ProxyProvider) {
t.configMux.Lock() configMux.Lock()
t.proxies = proxies proxies = newProxies
t.providers = providers providers = newProviders
t.configMux.Unlock() configMux.Unlock()
} }
// UpdateExperimental handle update experimental config // UpdateExperimental handle update experimental config
func (t *Tunnel) UpdateExperimental(ignoreResolveFail bool) { func UpdateExperimental(value bool) {
t.configMux.Lock() configMux.Lock()
t.ignoreResolveFail = ignoreResolveFail ignoreResolveFail = value
t.configMux.Unlock() configMux.Unlock()
} }
// Mode return current mode // Mode return current mode
func (t *Tunnel) Mode() Mode { func Mode() TunnelMode {
return t.mode return mode
} }
// SetMode change the mode of tunnel // SetMode change the mode of tunnel
func (t *Tunnel) SetMode(mode Mode) { func SetMode(m TunnelMode) {
t.mode = mode mode = m
}
// SetResolver set custom dns resolver for enhanced mode
func SetResolver(r *dns.Resolver) {
enhancedMode = r
} }
// processUDP starts a loop to handle udp packet // processUDP starts a loop to handle udp packet
func (t *Tunnel) processUDP() { func processUDP() {
queue := t.udpQueue.Out() queue := udpQueue.Out()
for elm := range queue { for elm := range queue {
conn := elm.(*inbound.PacketAdapter) conn := elm.(*inbound.PacketAdapter)
t.handleUDPConn(conn) handleUDPConn(conn)
} }
} }
func (t *Tunnel) process() { func process() {
numUDPWorkers := 4 numUDPWorkers := 4
if runtime.NumCPU() > numUDPWorkers { if runtime.NumCPU() > numUDPWorkers {
numUDPWorkers = runtime.NumCPU() numUDPWorkers = runtime.NumCPU()
} }
for i := 0; i < numUDPWorkers; i++ { for i := 0; i < numUDPWorkers; i++ {
go t.processUDP() go processUDP()
} }
queue := t.tcpQueue.Out() queue := tcpQueue.Out()
for elm := range queue { for elm := range queue {
conn := elm.(C.ServerAdapter) conn := elm.(C.ServerAdapter)
go t.handleTCPConn(conn) go handleTCPConn(conn)
} }
} }
func (t *Tunnel) resolveIP(host string) (net.IP, error) { func needLookupIP(metadata *C.Metadata) bool {
return dns.ResolveIP(host) return enhancedMode != nil && (enhancedMode.IsMapping() || enhancedMode.FakeIPEnabled()) && metadata.Host == "" && metadata.DstIP != nil
} }
func (t *Tunnel) needLookupIP(metadata *C.Metadata) bool { func preHandleMetadata(metadata *C.Metadata) error {
return dns.DefaultResolver != nil && (dns.DefaultResolver.IsMapping() || dns.DefaultResolver.FakeIPEnabled()) && metadata.Host == "" && metadata.DstIP != nil
}
func (t *Tunnel) preHandleMetadata(metadata *C.Metadata) error {
// handle IP string on host // handle IP string on host
if ip := net.ParseIP(metadata.Host); ip != nil { if ip := net.ParseIP(metadata.Host); ip != nil {
metadata.DstIP = ip metadata.DstIP = ip
} }
// preprocess enhanced-mode metadata // preprocess enhanced-mode metadata
if t.needLookupIP(metadata) { if needLookupIP(metadata) {
host, exist := dns.DefaultResolver.IPToHost(metadata.DstIP) host, exist := enhancedMode.IPToHost(metadata.DstIP)
if exist { if exist {
metadata.Host = host metadata.Host = host
metadata.AddrType = C.AtypDomainName metadata.AddrType = C.AtypDomainName
if dns.DefaultResolver.FakeIPEnabled() { if enhancedMode.FakeIPEnabled() {
metadata.DstIP = nil metadata.DstIP = nil
} }
} else if dns.DefaultResolver.IsFakeIP(metadata.DstIP) { } else if enhancedMode.IsFakeIP(metadata.DstIP) {
return fmt.Errorf("fake DNS record %s missing", metadata.DstIP) return fmt.Errorf("fake DNS record %s missing", metadata.DstIP)
} }
} }
@ -155,18 +156,18 @@ func (t *Tunnel) preHandleMetadata(metadata *C.Metadata) error {
return nil return nil
} }
func (t *Tunnel) resolveMetadata(metadata *C.Metadata) (C.Proxy, C.Rule, error) { func resolveMetadata(metadata *C.Metadata) (C.Proxy, C.Rule, error) {
var proxy C.Proxy var proxy C.Proxy
var rule C.Rule var rule C.Rule
switch t.mode { switch mode {
case Direct: case Direct:
proxy = t.proxies["DIRECT"] proxy = proxies["DIRECT"]
case Global: case Global:
proxy = t.proxies["GLOBAL"] proxy = proxies["GLOBAL"]
// Rule // Rule
default: default:
var err error var err error
proxy, rule, err = t.match(metadata) proxy, rule, err = match(metadata)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -174,23 +175,23 @@ func (t *Tunnel) resolveMetadata(metadata *C.Metadata) (C.Proxy, C.Rule, error)
return proxy, rule, nil return proxy, rule, nil
} }
func (t *Tunnel) handleUDPConn(packet *inbound.PacketAdapter) { func handleUDPConn(packet *inbound.PacketAdapter) {
metadata := packet.Metadata() metadata := packet.Metadata()
if !metadata.Valid() { if !metadata.Valid() {
log.Warnln("[Metadata] not valid: %#v", metadata) log.Warnln("[Metadata] not valid: %#v", metadata)
return return
} }
if err := t.preHandleMetadata(metadata); err != nil { if err := preHandleMetadata(metadata); err != nil {
log.Debugln("[Metadata PreHandle] error: %s", err) log.Debugln("[Metadata PreHandle] error: %s", err)
return return
} }
key := packet.LocalAddr().String() key := packet.LocalAddr().String()
pc := t.natTable.Get(key) pc := natTable.Get(key)
if pc != nil { if pc != nil {
if !metadata.Resolved() { if !metadata.Resolved() {
ip, err := t.resolveIP(metadata.Host) ip, err := resolver.ResolveIP(metadata.Host)
if err != nil { if err != nil {
log.Warnln("[UDP] Resolve %s failed: %s, %#v", metadata.Host, err.Error(), metadata) log.Warnln("[UDP] Resolve %s failed: %s, %#v", metadata.Host, err.Error(), metadata)
return return
@ -198,20 +199,20 @@ func (t *Tunnel) handleUDPConn(packet *inbound.PacketAdapter) {
metadata.DstIP = ip metadata.DstIP = ip
} }
t.handleUDPToRemote(packet, pc, metadata.UDPAddr()) handleUDPToRemote(packet, pc, metadata.UDPAddr())
return return
} }
lockKey := key + "-lock" lockKey := key + "-lock"
wg, loaded := t.natTable.GetOrCreateLock(lockKey) wg, loaded := natTable.GetOrCreateLock(lockKey)
go func() { go func() {
if !loaded { if !loaded {
wg.Add(1) wg.Add(1)
proxy, rule, err := t.resolveMetadata(metadata) proxy, rule, err := resolveMetadata(metadata)
if err != nil { if err != nil {
log.Warnln("[UDP] Parse metadata failed: %s", err.Error()) log.Warnln("[UDP] Parse metadata failed: %s", err.Error())
t.natTable.Delete(lockKey) natTable.Delete(lockKey)
wg.Done() wg.Done()
return return
} }
@ -219,7 +220,7 @@ func (t *Tunnel) handleUDPConn(packet *inbound.PacketAdapter) {
rawPc, err := proxy.DialUDP(metadata) rawPc, err := proxy.DialUDP(metadata)
if err != nil { if err != nil {
log.Warnln("[UDP] dial %s error: %s", proxy.Name(), err.Error()) log.Warnln("[UDP] dial %s error: %s", proxy.Name(), err.Error())
t.natTable.Delete(lockKey) natTable.Delete(lockKey)
wg.Done() wg.Done()
return return
} }
@ -228,36 +229,36 @@ func (t *Tunnel) handleUDPConn(packet *inbound.PacketAdapter) {
switch true { switch true {
case rule != nil: case rule != nil:
log.Infoln("[UDP] %s --> %v match %s using %s", metadata.SourceAddress(), metadata.String(), rule.RuleType().String(), rawPc.Chains().String()) log.Infoln("[UDP] %s --> %v match %s using %s", metadata.SourceAddress(), metadata.String(), rule.RuleType().String(), rawPc.Chains().String())
case t.mode == Global: case mode == Global:
log.Infoln("[UDP] %s --> %v using GLOBAL", metadata.SourceAddress(), metadata.String()) log.Infoln("[UDP] %s --> %v using GLOBAL", metadata.SourceAddress(), metadata.String())
case t.mode == Direct: case mode == Direct:
log.Infoln("[UDP] %s --> %v using DIRECT", metadata.SourceAddress(), metadata.String()) log.Infoln("[UDP] %s --> %v using DIRECT", metadata.SourceAddress(), metadata.String())
default: default:
log.Infoln("[UDP] %s --> %v doesn't match any rule using DIRECT", metadata.SourceAddress(), metadata.String()) log.Infoln("[UDP] %s --> %v doesn't match any rule using DIRECT", metadata.SourceAddress(), metadata.String())
} }
t.natTable.Set(key, pc) natTable.Set(key, pc)
t.natTable.Delete(lockKey) natTable.Delete(lockKey)
wg.Done() wg.Done()
go t.handleUDPToLocal(packet.UDPPacket, pc, key) go handleUDPToLocal(packet.UDPPacket, pc, key)
} }
wg.Wait() wg.Wait()
pc := t.natTable.Get(key) pc := natTable.Get(key)
if pc != nil { if pc != nil {
if !metadata.Resolved() { if !metadata.Resolved() {
ip, err := dns.ResolveIP(metadata.Host) ip, err := resolver.ResolveIP(metadata.Host)
if err != nil { if err != nil {
return return
} }
metadata.DstIP = ip metadata.DstIP = ip
} }
t.handleUDPToRemote(packet, pc, metadata.UDPAddr()) handleUDPToRemote(packet, pc, metadata.UDPAddr())
} }
}() }()
} }
func (t *Tunnel) handleTCPConn(localConn C.ServerAdapter) { func handleTCPConn(localConn C.ServerAdapter) {
defer localConn.Close() defer localConn.Close()
metadata := localConn.Metadata() metadata := localConn.Metadata()
@ -266,12 +267,12 @@ func (t *Tunnel) handleTCPConn(localConn C.ServerAdapter) {
return return
} }
if err := t.preHandleMetadata(metadata); err != nil { if err := preHandleMetadata(metadata); err != nil {
log.Debugln("[Metadata PreHandle] error: %s", err) log.Debugln("[Metadata PreHandle] error: %s", err)
return return
} }
proxy, rule, err := t.resolveMetadata(metadata) proxy, rule, err := resolveMetadata(metadata)
if err != nil { if err != nil {
log.Warnln("Parse metadata failed: %v", err) log.Warnln("Parse metadata failed: %v", err)
return return
@ -288,9 +289,9 @@ func (t *Tunnel) handleTCPConn(localConn C.ServerAdapter) {
switch true { switch true {
case rule != nil: case rule != nil:
log.Infoln("[TCP] %s --> %v match %s using %s", metadata.SourceAddress(), metadata.String(), rule.RuleType().String(), remoteConn.Chains().String()) log.Infoln("[TCP] %s --> %v match %s using %s", metadata.SourceAddress(), metadata.String(), rule.RuleType().String(), remoteConn.Chains().String())
case t.mode == Global: case mode == Global:
log.Infoln("[TCP] %s --> %v using GLOBAL", metadata.SourceAddress(), metadata.String()) log.Infoln("[TCP] %s --> %v using GLOBAL", metadata.SourceAddress(), metadata.String())
case t.mode == Direct: case mode == Direct:
log.Infoln("[TCP] %s --> %v using DIRECT", metadata.SourceAddress(), metadata.String()) log.Infoln("[TCP] %s --> %v using DIRECT", metadata.SourceAddress(), metadata.String())
default: default:
log.Infoln("[TCP] %s --> %v doesn't match any rule using DIRECT", metadata.SourceAddress(), metadata.String()) log.Infoln("[TCP] %s --> %v doesn't match any rule using DIRECT", metadata.SourceAddress(), metadata.String())
@ -298,33 +299,33 @@ func (t *Tunnel) handleTCPConn(localConn C.ServerAdapter) {
switch adapter := localConn.(type) { switch adapter := localConn.(type) {
case *inbound.HTTPAdapter: case *inbound.HTTPAdapter:
t.handleHTTP(adapter, remoteConn) handleHTTP(adapter, remoteConn)
case *inbound.SocketAdapter: case *inbound.SocketAdapter:
t.handleSocket(adapter, remoteConn) handleSocket(adapter, remoteConn)
} }
} }
func (t *Tunnel) shouldResolveIP(rule C.Rule, metadata *C.Metadata) bool { func shouldResolveIP(rule C.Rule, metadata *C.Metadata) bool {
return !rule.NoResolveIP() && metadata.Host != "" && metadata.DstIP == nil return !rule.NoResolveIP() && metadata.Host != "" && metadata.DstIP == nil
} }
func (t *Tunnel) match(metadata *C.Metadata) (C.Proxy, C.Rule, error) { func match(metadata *C.Metadata) (C.Proxy, C.Rule, error) {
t.configMux.RLock() configMux.RLock()
defer t.configMux.RUnlock() defer configMux.RUnlock()
var resolved bool var resolved bool
if node := dns.DefaultHosts.Search(metadata.Host); node != nil { if node := resolver.DefaultHosts.Search(metadata.Host); node != nil {
ip := node.Data.(net.IP) ip := node.Data.(net.IP)
metadata.DstIP = ip metadata.DstIP = ip
resolved = true resolved = true
} }
for _, rule := range t.rules { for _, rule := range rules {
if !resolved && t.shouldResolveIP(rule, metadata) { if !resolved && shouldResolveIP(rule, metadata) {
ip, err := t.resolveIP(metadata.Host) ip, err := resolver.ResolveIP(metadata.Host)
if err != nil { if err != nil {
if !t.ignoreResolveFail { if !ignoreResolveFail {
return nil, nil, fmt.Errorf("[DNS] resolve %s error: %s", metadata.Host, err.Error()) return nil, nil, fmt.Errorf("[DNS] resolve %s error: %s", metadata.Host, err.Error())
} }
log.Debugln("[DNS] resolve %s error: %s", metadata.Host, err.Error()) log.Debugln("[DNS] resolve %s error: %s", metadata.Host, err.Error())
@ -336,7 +337,7 @@ func (t *Tunnel) match(metadata *C.Metadata) (C.Proxy, C.Rule, error) {
} }
if rule.Match(metadata) { if rule.Match(metadata) {
adapter, ok := t.proxies[rule.Adapter()] adapter, ok := proxies[rule.Adapter()]
if !ok { if !ok {
continue continue
} }
@ -348,24 +349,6 @@ func (t *Tunnel) match(metadata *C.Metadata) (C.Proxy, C.Rule, error) {
return adapter, rule, nil return adapter, rule, nil
} }
} }
return t.proxies["DIRECT"], nil, nil
}
func newTunnel() *Tunnel { return proxies["DIRECT"], nil, nil
return &Tunnel{
tcpQueue: channels.NewInfiniteChannel(),
udpQueue: channels.NewInfiniteChannel(),
natTable: nat.New(),
proxies: make(map[string]C.Proxy),
mode: Rule,
}
}
// Instance return singleton instance of Tunnel
func Instance() *Tunnel {
once.Do(func() {
tunnel = newTunnel()
go tunnel.process()
})
return tunnel
} }