Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 51 additions & 21 deletions proxy/dns/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
go_errors "errors"
"io"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -168,11 +169,15 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.
}

ctx, cancel := context.WithCancel(ctx)
timer := signal.CancelAfterInactivity(ctx, cancel, h.timeout)
terminate := func() {
cancel()
conn.Close()
}
timer := signal.CancelAfterInactivity(ctx, terminate, h.timeout)
defer timer.SetTimeout(0)

request := func() error {
defer conn.Close()

defer timer.SetTimeout(0)
for {
b, err := reader.ReadMessage()
if err == io.EOF {
Expand All @@ -190,24 +195,33 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.
if len(h.blockTypes) > 0 {
for _, blocktype := range h.blockTypes {
if blocktype == int32(qType) {
b.Release()
errors.LogInfo(ctx, "blocked type ", qType, " query for domain ", domain)
if h.nonIPQuery == "reject" {
go h.rejectNonIPQuery(id, qType, domain, writer)
err := h.rejectNonIPQuery(id, qType, domain, writer)
if err != nil {
return err
}
}
errors.LogInfo(ctx, "blocked type ", qType, " query for domain ", domain)
return nil
}
}
}
if isIPQuery {
go h.handleIPQuery(id, qType, domain, writer)
b.Release()
go h.handleIPQuery(id, qType, domain, writer, timer)
continue
}
if isIPQuery || h.nonIPQuery == "drop" {
if h.nonIPQuery == "drop" {
b.Release()
continue
}
if h.nonIPQuery == "reject" {
go h.rejectNonIPQuery(id, qType, domain, writer)
b.Release()
err := h.rejectNonIPQuery(id, qType, domain, writer)
if err != nil {
return err
}
continue
}
}
Expand All @@ -219,6 +233,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.
}

response := func() error {
defer timer.SetTimeout(0)
for {
b, err := connReader.ReadMessage()
if err == io.EOF {
Expand All @@ -244,7 +259,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.
return nil
}

func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter) {
func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter, timer *signal.ActivityTimer) {
var ips []net.IP
var err error

Expand Down Expand Up @@ -319,16 +334,21 @@ func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string,
if err != nil {
errors.LogInfoInner(context.Background(), err, "pack message")
b.Release()
return
timer.SetTimeout(0)
}
b.Resize(0, int32(len(msgBytes)))

if err := writer.WriteMessage(b); err != nil {
errors.LogInfoInner(context.Background(), err, "write IP answer")
timer.SetTimeout(0)
}
}

func (h *Handler) rejectNonIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter) {
func (h *Handler) rejectNonIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter) error {
domainT := strings.TrimSuffix(domain, ".")
if domainT == "" {
return errors.New("empty domain name")
}
b := buf.New()
rawBytes := b.Extend(buf.Size)
builder := dnsmessage.NewBuilder(rawBytes[:0], dnsmessage.Header{
Expand All @@ -349,20 +369,22 @@ func (h *Handler) rejectNonIPQuery(id uint16, qType dnsmessage.Type, domain stri
if err != nil {
errors.LogInfo(context.Background(), "unexpected domain ", domain, " when building reject message: ", err)
b.Release()
return
return err
}

msgBytes, err := builder.Finish()
if err != nil {
errors.LogInfoInner(context.Background(), err, "pack reject message")
b.Release()
return
return err
}
b.Resize(0, int32(len(msgBytes)))

if err := writer.WriteMessage(b); err != nil {
errors.LogInfoInner(context.Background(), err, "write reject answer")
return err
}
return nil
}

type outboundConn struct {
Expand All @@ -371,6 +393,7 @@ type outboundConn struct {

conn net.Conn
connReady chan struct{}
closed bool
}

func (c *outboundConn) dial() error {
Expand All @@ -385,12 +408,16 @@ func (c *outboundConn) dial() error {

func (c *outboundConn) Write(b []byte) (int, error) {
c.access.Lock()
if c.closed {
c.access.Unlock()
return 0, errors.New("outbound connection closed")
}

if c.conn == nil {
if err := c.dial(); err != nil {
c.access.Unlock()
errors.LogWarningInner(context.Background(), err, "failed to dial outbound connection")
return len(b), nil
return 0, err
}
}

Expand All @@ -400,24 +427,27 @@ func (c *outboundConn) Write(b []byte) (int, error) {
}

func (c *outboundConn) Read(b []byte) (int, error) {
var conn net.Conn
c.access.Lock()
conn = c.conn
c.access.Unlock()
if c.closed {
c.access.Unlock()
return 0, io.EOF
}

if conn == nil {
if c.conn == nil {
c.access.Unlock()
_, open := <-c.connReady
if !open {
return 0, io.EOF
}
conn = c.conn
return c.conn.Read(b)
}

return conn.Read(b)
c.access.Unlock()
return c.conn.Read(b)
}

func (c *outboundConn) Close() error {
c.access.Lock()
c.closed = true
close(c.connReady)
if c.conn != nil {
c.conn.Close()
Expand Down
Loading