diff --git a/dns/client.go b/dns/client.go index fc76c1241..3b4efed10 100644 --- a/dns/client.go +++ b/dns/client.go @@ -78,7 +78,9 @@ func (c *client) ExchangeContext(ctx context.Context, m *D.Msg) (*D.Msg, error) options = append(options, dialer.WithInterface(c.iface)) } - conn, err := getDialHandler(c.r, c.proxyAdapter, c.proxyName, options...)(ctx, network, net.JoinHostPort(ip.String(), c.port)) + dialHandler := getDialHandler(c.r, c.proxyAdapter, c.proxyName, options...) + addr := net.JoinHostPort(ip.String(), c.port) + conn, err := dialHandler(ctx, network, addr) if err != nil { return nil, err } @@ -111,7 +113,16 @@ func (c *client) ExchangeContext(ctx context.Context, m *D.Msg) (*D.Msg, error) if msg != nil && msg.Truncated && c.Client.Net == "" { tcpClient := *c.Client // copy a client tcpClient.Net = "tcp" + network = "tcp" log.Debugln("[DNS] Truncated reply from %s:%s for %s over UDP, retrying over TCP", c.host, c.port, m.Question[0].String()) + dConn.Conn, err = dialHandler(ctx, network, addr) + if err != nil { + ch <- result{msg, err} + return + } + defer func() { + _ = conn.Close() + }() msg, _, err = tcpClient.ExchangeWithConn(m, dConn) }