chore: share RelayDnsPacket function code

This commit is contained in:
wwqgtxx 2024-03-04 22:12:08 +08:00
parent fe4acebb8b
commit 8b9813079b
3 changed files with 102 additions and 108 deletions

View File

@ -4,7 +4,6 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"net/netip"
"time" "time"
"github.com/metacubex/mihomo/common/pool" "github.com/metacubex/mihomo/common/pool"
@ -12,8 +11,6 @@ import (
"github.com/metacubex/mihomo/component/resolver" "github.com/metacubex/mihomo/component/resolver"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/log" "github.com/metacubex/mihomo/log"
D "github.com/miekg/dns"
) )
type Dns struct { type Dns struct {
@ -79,12 +76,12 @@ func (d *dnsPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
} }
func (d *dnsPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { func (d *dnsPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
ctx, cancel := context.WithTimeout(d.ctx, time.Second*5) ctx, cancel := context.WithTimeout(d.ctx, resolver.DefaultDnsRelayTimeout)
defer cancel() defer cancel()
buf := pool.Get(2048) buf := pool.Get(resolver.SafeDnsPacketSize)
put := func() { _ = pool.Put(buf) } put := func() { _ = pool.Put(buf) }
buf, err = RelayDnsPacket(ctx, p, buf) buf, err = resolver.RelayDnsPacket(ctx, p, buf)
if err != nil { if err != nil {
put() put()
return 0, err return 0, err
@ -110,7 +107,11 @@ func (d *dnsPacketConn) Close() error {
} }
func (*dnsPacketConn) LocalAddr() net.Addr { func (*dnsPacketConn) LocalAddr() net.Addr {
return net.UDPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:53")) return &net.UDPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 53,
Zone: "",
}
} }
func (*dnsPacketConn) SetDeadline(t time.Time) error { func (*dnsPacketConn) SetDeadline(t time.Time) error {
@ -139,22 +140,3 @@ func NewDnsWithOption(option DnsOption) *Dns {
}, },
} }
} }
// copied from listener/sing_mux/dns.go
func RelayDnsPacket(ctx context.Context, payload []byte, target []byte) ([]byte, error) {
msg := &D.Msg{}
if err := msg.Unpack(payload); err != nil {
return nil, err
}
r, err := resolver.ServeMsg(ctx, msg)
if err != nil {
m := new(D.Msg)
m.SetRcode(msg, D.RcodeServerFailure)
return m.PackBuffer(target)
}
r.SetRcode(msg, r.Rcode)
r.Compress = true
return r.PackBuffer(target)
}

View File

@ -0,0 +1,88 @@
package resolver
import (
"context"
"encoding/binary"
"io"
"net"
"time"
"github.com/metacubex/mihomo/common/pool"
D "github.com/miekg/dns"
)
const DefaultDnsReadTimeout = time.Second * 10
const DefaultDnsRelayTimeout = time.Second * 5
const SafeDnsPacketSize = 2 * 1024 // safe size which is 1232 from https://dnsflagday.net/2020/, so 2048 is enough
func RelayDnsConn(ctx context.Context, conn net.Conn) error {
buff := pool.Get(pool.UDPBufferSize)
defer func() {
_ = pool.Put(buff)
_ = conn.Close()
}()
for {
if conn.SetReadDeadline(time.Now().Add(DefaultDnsReadTimeout)) != nil {
break
}
length := uint16(0)
if err := binary.Read(conn, binary.BigEndian, &length); err != nil {
break
}
if int(length) > len(buff) {
break
}
n, err := io.ReadFull(conn, buff[:length])
if err != nil {
break
}
err = func() error {
ctx, cancel := context.WithTimeout(ctx, DefaultDnsRelayTimeout)
defer cancel()
inData := buff[:n]
msg, err := RelayDnsPacket(ctx, inData, buff)
if err != nil {
return err
}
err = binary.Write(conn, binary.BigEndian, uint16(len(msg)))
if err != nil {
return err
}
_, err = conn.Write(msg)
if err != nil {
return err
}
return nil
}()
if err != nil {
return err
}
}
return nil
}
func RelayDnsPacket(ctx context.Context, payload []byte, target []byte) ([]byte, error) {
msg := &D.Msg{}
if err := msg.Unpack(payload); err != nil {
return nil, err
}
r, err := ServeMsg(ctx, msg)
if err != nil {
m := new(D.Msg)
m.SetRcode(msg, D.RcodeServerFailure)
return m.PackBuffer(target)
}
r.SetRcode(msg, r.Rcode)
r.Compress = true
return r.PackBuffer(target)
}

View File

@ -2,29 +2,21 @@ package sing_tun
import ( import (
"context" "context"
"encoding/binary"
"io"
"net" "net"
"net/netip" "net/netip"
"sync" "sync"
"time" "time"
"github.com/metacubex/mihomo/common/pool"
"github.com/metacubex/mihomo/component/resolver" "github.com/metacubex/mihomo/component/resolver"
"github.com/metacubex/mihomo/listener/sing" "github.com/metacubex/mihomo/listener/sing"
"github.com/metacubex/mihomo/log" "github.com/metacubex/mihomo/log"
D "github.com/miekg/dns"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/bufio"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/network"
) )
const DefaultDnsReadTimeout = time.Second * 10
const DefaultDnsRelayTimeout = time.Second * 5
type ListenerHandler struct { type ListenerHandler struct {
*sing.ListenerHandler *sing.ListenerHandler
DnsAdds []netip.AddrPort DnsAdds []netip.AddrPort
@ -45,61 +37,11 @@ func (h *ListenerHandler) ShouldHijackDns(targetAddr netip.AddrPort) bool {
func (h *ListenerHandler) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { func (h *ListenerHandler) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
if h.ShouldHijackDns(metadata.Destination.AddrPort()) { if h.ShouldHijackDns(metadata.Destination.AddrPort()) {
log.Debugln("[DNS] hijack tcp:%s", metadata.Destination.String()) log.Debugln("[DNS] hijack tcp:%s", metadata.Destination.String())
buff := pool.Get(pool.UDPBufferSize) return resolver.RelayDnsConn(ctx, conn)
defer func() {
_ = pool.Put(buff)
_ = conn.Close()
}()
for {
if conn.SetReadDeadline(time.Now().Add(DefaultDnsReadTimeout)) != nil {
break
}
length := uint16(0)
if err := binary.Read(conn, binary.BigEndian, &length); err != nil {
break
}
if int(length) > len(buff) {
break
}
n, err := io.ReadFull(conn, buff[:length])
if err != nil {
break
}
err = func() error {
ctx, cancel := context.WithTimeout(ctx, DefaultDnsRelayTimeout)
defer cancel()
inData := buff[:n]
msg, err := RelayDnsPacket(ctx, inData, buff)
if err != nil {
return err
}
err = binary.Write(conn, binary.BigEndian, uint16(len(msg)))
if err != nil {
return err
}
_, err = conn.Write(msg)
if err != nil {
return err
}
return nil
}()
if err != nil {
return err
}
}
return nil
} }
return h.ListenerHandler.NewConnection(ctx, conn, metadata) return h.ListenerHandler.NewConnection(ctx, conn, metadata)
} }
const SafeDnsPacketSize = 2 * 1024 // safe size which is 1232 from https://dnsflagday.net/2020/, so 2048 is enough
func (h *ListenerHandler) NewPacketConnection(ctx context.Context, conn network.PacketConn, metadata M.Metadata) error { func (h *ListenerHandler) NewPacketConnection(ctx context.Context, conn network.PacketConn, metadata M.Metadata) error {
if h.ShouldHijackDns(metadata.Destination.AddrPort()) { if h.ShouldHijackDns(metadata.Destination.AddrPort()) {
log.Debugln("[DNS] hijack udp:%s from %s", metadata.Destination.String(), metadata.Source.String()) log.Debugln("[DNS] hijack udp:%s from %s", metadata.Destination.String(), metadata.Source.String())
@ -114,7 +56,7 @@ func (h *ListenerHandler) NewPacketConnection(ctx context.Context, conn network.
rwOptions := network.ReadWaitOptions{ rwOptions := network.ReadWaitOptions{
FrontHeadroom: network.CalculateFrontHeadroom(conn), FrontHeadroom: network.CalculateFrontHeadroom(conn),
RearHeadroom: network.CalculateRearHeadroom(conn), RearHeadroom: network.CalculateRearHeadroom(conn),
MTU: SafeDnsPacketSize, MTU: resolver.SafeDnsPacketSize,
} }
readWaiter, isReadWaiter := bufio.CreatePacketReadWaiter(conn) readWaiter, isReadWaiter := bufio.CreatePacketReadWaiter(conn)
if isReadWaiter { if isReadWaiter {
@ -126,7 +68,7 @@ func (h *ListenerHandler) NewPacketConnection(ctx context.Context, conn network.
dest M.Socksaddr dest M.Socksaddr
err error err error
) )
_ = conn.SetReadDeadline(time.Now().Add(DefaultDnsReadTimeout)) _ = conn.SetReadDeadline(time.Now().Add(resolver.DefaultDnsReadTimeout))
readBuff = nil // clear last loop status, avoid repeat release readBuff = nil // clear last loop status, avoid repeat release
if isReadWaiter { if isReadWaiter {
readBuff, dest, err = readWaiter.WaitReadPacket() readBuff, dest, err = readWaiter.WaitReadPacket()
@ -147,15 +89,15 @@ func (h *ListenerHandler) NewPacketConnection(ctx context.Context, conn network.
return err return err
} }
go func() { go func() {
ctx, cancel := context.WithTimeout(ctx, DefaultDnsRelayTimeout) ctx, cancel := context.WithTimeout(ctx, resolver.DefaultDnsRelayTimeout)
defer cancel() defer cancel()
inData := readBuff.Bytes() inData := readBuff.Bytes()
writeBuff := readBuff writeBuff := readBuff
writeBuff.Resize(writeBuff.Start(), 0) writeBuff.Resize(writeBuff.Start(), 0)
if len(writeBuff.FreeBytes()) < SafeDnsPacketSize { // only create a new buffer when space don't enough if len(writeBuff.FreeBytes()) < resolver.SafeDnsPacketSize { // only create a new buffer when space don't enough
writeBuff = rwOptions.NewPacketBuffer() writeBuff = rwOptions.NewPacketBuffer()
} }
msg, err := RelayDnsPacket(ctx, inData, writeBuff.FreeBytes()) msg, err := resolver.RelayDnsPacket(ctx, inData, writeBuff.FreeBytes())
if writeBuff != readBuff { if writeBuff != readBuff {
readBuff.Release() readBuff.Release()
} }
@ -182,21 +124,3 @@ func (h *ListenerHandler) NewPacketConnection(ctx context.Context, conn network.
} }
return h.ListenerHandler.NewPacketConnection(ctx, conn, metadata) return h.ListenerHandler.NewPacketConnection(ctx, conn, metadata)
} }
func RelayDnsPacket(ctx context.Context, payload []byte, target []byte) ([]byte, error) {
msg := &D.Msg{}
if err := msg.Unpack(payload); err != nil {
return nil, err
}
r, err := resolver.ServeMsg(ctx, msg)
if err != nil {
m := new(D.Msg)
m.SetRcode(msg, D.RcodeServerFailure)
return m.PackBuffer(target)
}
r.SetRcode(msg, r.Rcode)
r.Compress = true
return r.PackBuffer(target)
}