From ca5399a16ed7f54ea3abc281e88c026253c89472 Mon Sep 17 00:00:00 2001 From: Dreamacro <305009791@qq.com> Date: Sat, 23 Feb 2019 20:31:59 +0800 Subject: [PATCH] Fix: dns cache behavior --- dns/client.go | 23 ++++++++--------------- dns/util.go | 28 ++++++++++++++++++++++++---- 2 files changed, 32 insertions(+), 19 deletions(-) diff --git a/dns/client.go b/dns/client.go index cfae7b9f8..234f14f57 100644 --- a/dns/client.go +++ b/dns/client.go @@ -12,7 +12,6 @@ import ( "github.com/Dreamacro/clash/common/cache" "github.com/Dreamacro/clash/common/picker" C "github.com/Dreamacro/clash/constant" - "github.com/Dreamacro/clash/log" D "github.com/miekg/dns" geoip2 "github.com/oschwald/geoip2-golang" @@ -55,23 +54,17 @@ func (r *Resolver) Exchange(m *D.Msg) (msg *D.Msg, err error) { cache, expireTime := r.cache.GetWithExpire(q.String()) if cache != nil { msg = cache.(*D.Msg).Copy() - if len(msg.Answer) > 0 { - ttl := uint32(expireTime.Sub(time.Now()).Seconds()) - for _, answer := range msg.Answer { - answer.Header().Ttl = ttl - } - } + setMsgTTL(msg, uint32(expireTime.Sub(time.Now()).Seconds())) return } defer func() { - if msg != nil { - putMsgToCache(r.cache, q.String(), msg) - if r.mapping { - ips, err := r.msgToIP(msg) - if err != nil { - log.Debugln("[DNS] msg to ip error: %s", err.Error()) - return - } + if msg == nil { + return + } + + putMsgToCache(r.cache, q.String(), msg) + if r.mapping { + if ips, err := r.msgToIP(msg); err == nil { for _, ip := range ips { putMsgToCache(r.cache, ip.String(), msg) } diff --git a/dns/util.go b/dns/util.go index 4102a584e..2eee70c47 100644 --- a/dns/util.go +++ b/dns/util.go @@ -79,11 +79,31 @@ func (e EnhancedMode) String() string { } func putMsgToCache(c *cache.Cache, key string, msg *D.Msg) { - if len(msg.Answer) == 0 { - log.Debugln("[DNS] answer length is zero: %#v", msg) + var ttl time.Duration + if len(msg.Answer) != 0 { + ttl = time.Duration(msg.Answer[0].Header().Ttl) * time.Second + } else if len(msg.Ns) != 0 { + ttl = time.Duration(msg.Ns[0].Header().Ttl) * time.Second + } else if len(msg.Extra) != 0 { + ttl = time.Duration(msg.Extra[0].Header().Ttl) * time.Second + } else { + log.Debugln("[DNS] response msg error: %#v", msg) return } - ttl := time.Duration(msg.Answer[0].Header().Ttl) * time.Second - c.Put(key, msg, ttl) + c.Put(key, msg.Copy(), ttl) +} + +func setMsgTTL(msg *D.Msg, ttl uint32) { + for _, answer := range msg.Answer { + answer.Header().Ttl = ttl + } + + for _, ns := range msg.Ns { + ns.Header().Ttl = ttl + } + + for _, extra := range msg.Extra { + extra.Header().Ttl = ttl + } }