Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions app/dns/dnscommon.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@ type record struct {

// IPRecord is a cacheable item for a resolved domain
type IPRecord struct {
ReqID uint16
IP []net.Address
Expire time.Time
RCode dnsmessage.RCode
ReqID uint16
IP []net.Address
Expire time.Time
RCode dnsmessage.RCode
RawHeader *dnsmessage.Header
}

func (r *IPRecord) getIPs() ([]net.Address, error) {
Expand Down Expand Up @@ -179,9 +180,10 @@ func parseResponse(payload []byte) (*IPRecord, error) {

now := time.Now()
ipRecord := &IPRecord{
ReqID: h.ID,
RCode: h.RCode,
Expire: now.Add(time.Second * 600),
ReqID: h.ID,
RCode: h.RCode,
Expire: now.Add(time.Second * 600),
RawHeader: &h,
}

L:
Expand Down
8 changes: 5 additions & 3 deletions app/dns/dnscommon_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func Test_parseResponse(t *testing.T) {
}{
{
"empty",
&IPRecord{0, []net.Address(nil), time.Time{}, dnsmessage.RCodeSuccess},
&IPRecord{0, []net.Address(nil), time.Time{}, dnsmessage.RCodeSuccess, nil},
false,
},
{
Expand All @@ -66,12 +66,13 @@ func Test_parseResponse(t *testing.T) {
[]net.Address{net.ParseAddress("8.8.8.8"), net.ParseAddress("8.8.4.4")},
time.Time{},
dnsmessage.RCodeSuccess,
nil,
},
false,
},
{
"aaaa record",
&IPRecord{2, []net.Address{net.ParseAddress("2001::123:8888"), net.ParseAddress("2001::123:8844")}, time.Time{}, dnsmessage.RCodeSuccess},
&IPRecord{2, []net.Address{net.ParseAddress("2001::123:8888"), net.ParseAddress("2001::123:8844")}, time.Time{}, dnsmessage.RCodeSuccess, nil},
false,
},
}
Expand All @@ -84,8 +85,9 @@ func Test_parseResponse(t *testing.T) {
}

if got != nil {
// reset the time
// reset the time and RawHeader
got.Expire = time.Time{}
got.RawHeader = nil
}
if cmp.Diff(got, tt.want) != "" {
t.Error(cmp.Diff(got, tt.want))
Expand Down
40 changes: 35 additions & 5 deletions app/dns/nameserver_udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,19 @@ type ClassicNameServer struct {
name string
address *net.Destination
ips map[string]*record
requests map[uint16]*dnsRequest
requests map[uint16]*udpDnsRequest
pub *pubsub.Service
udpServer *udp.Dispatcher
cleanup *task.Periodic
reqID uint32
queryStrategy QueryStrategy
}

type udpDnsRequest struct {
dnsRequest
ctx context.Context
}

// NewClassicNameServer creates udp server object for remote resolving.
func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher, queryStrategy QueryStrategy) *ClassicNameServer {
// default to 53 if unspecific
Expand All @@ -45,7 +50,7 @@ func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher
s := &ClassicNameServer{
address: &address,
ips: make(map[string]*record),
requests: make(map[uint16]*dnsRequest),
requests: make(map[uint16]*udpDnsRequest),
pub: pubsub.NewService(),
name: strings.ToUpper(address.String()),
queryStrategy: queryStrategy,
Expand Down Expand Up @@ -101,7 +106,7 @@ func (s *ClassicNameServer) Cleanup() error {
}

if len(s.requests) == 0 {
s.requests = make(map[uint16]*dnsRequest)
s.requests = make(map[uint16]*udpDnsRequest)
}

return nil
Expand All @@ -128,6 +133,27 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
return
}

// if truncated, retry with EDNS0 option(udp payload size: 1350)
if ipRec.RawHeader.Truncated {
// if already has EDNS0 option, no need to retry
if ok && len(req.msg.Additionals) == 0 {
// copy necessary meta data from original request
// and add EDNS0 option
opt := new(dnsmessage.Resource)
common.Must(opt.Header.SetEDNS0(1350, 0xfe00, true))
opt.Body = &dnsmessage.OPTResource{}
newMsg := *req.msg
newReq := *req
newMsg.Additionals = append(newMsg.Additionals, *opt)
newMsg.ID = s.newReqID()
newReq.msg = &newMsg
s.addPendingRequest(&newReq)
b, _ := dns.PackMessage(newReq.msg)
s.udpServer.Dispatch(toDnsContext(newReq.ctx, s.address.String()), *s.address, b)
return
}
}

var rec record
switch req.reqType {
case dnsmessage.TypeA:
Expand Down Expand Up @@ -179,7 +205,7 @@ func (s *ClassicNameServer) newReqID() uint16 {
return uint16(atomic.AddUint32(&s.reqID, 1))
}

func (s *ClassicNameServer) addPendingRequest(req *dnsRequest) {
func (s *ClassicNameServer) addPendingRequest(req *udpDnsRequest) {
s.Lock()
defer s.Unlock()

Expand All @@ -194,7 +220,11 @@ func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, client
reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(clientIP))

for _, req := range reqs {
s.addPendingRequest(req)
udpReq := &udpDnsRequest{
dnsRequest: *req,
ctx: ctx,
}
s.addPendingRequest(udpReq)
b, _ := dns.PackMessage(req.msg)
s.udpServer.Dispatch(toDnsContext(ctx, s.address.String()), *s.address, b)
}
Expand Down