Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions app/dns/cache_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ func (c *CacheController) migrate() {
return
}

errors.LogDebug(context.Background(), c.name, " starting background cache migration for ", len(dirtyips), " items.")
errors.LogDebug(context.Background(), c.name, " starting background cache migration for ", len(dirtyips), " items")

batch := make([]migrationEntry, 0, migrationBatchSize)
for domain, recD := range dirtyips {
Expand All @@ -214,7 +214,7 @@ func (c *CacheController) migrate() {
c.dirtyips = nil
c.Unlock()

errors.LogDebug(context.Background(), c.name, " cache migration completed.")
errors.LogDebug(context.Background(), c.name, " cache migration completed")
}

func (c *CacheController) flush(batch []migrationEntry) {
Expand Down
151 changes: 86 additions & 65 deletions app/dns/config.pb.go

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions app/dns/config.proto
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ message NameServer {
bool finalQuery = 12;
repeated xray.app.router.GeoIP unexpected_geoip = 13;
bool actUnprior = 14;
uint32 policyID = 17;
}

enum DomainMatchingType {
Expand Down Expand Up @@ -89,4 +90,6 @@ message Config {

bool disableFallback = 10;
bool disableFallbackIfMatch = 11;

bool enableParallelQuery = 14;
}
256 changes: 215 additions & 41 deletions app/dns/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type DNS struct {
sync.Mutex
disableFallback bool
disableFallbackIfMatch bool
enableParallelQuery bool
ipOption *dns.IPOption
hosts *StaticHosts
clients []*Client
Expand Down Expand Up @@ -157,6 +158,7 @@ func New(ctx context.Context, config *Config) (*DNS, error) {
matcherInfos: matcherInfos,
disableFallback: config.DisableFallback,
disableFallbackIfMatch: config.DisableFallbackIfMatch,
enableParallelQuery: config.EnableParallelQuery,
checkSystem: checkSystem,
}, nil
}
Expand Down Expand Up @@ -235,45 +237,11 @@ func (s *DNS) LookupIP(domain string, option dns.IPOption) ([]net.IP, uint32, er
}

// Name servers lookup
var errs []error
for _, client := range s.sortClients(domain) {
if !option.FakeEnable && strings.EqualFold(client.Name(), "FakeDNS") {
errors.LogDebug(s.ctx, "skip DNS resolution for domain ", domain, " at server ", client.Name())
continue
}

ips, ttl, err := client.QueryIP(s.ctx, domain, option)

if len(ips) > 0 {
if ttl == 0 {
ttl = 1
}
return ips, ttl, nil
}

errors.LogInfoInner(s.ctx, err, "failed to lookup ip for domain ", domain, " at server ", client.Name())
if err == nil {
err = dns.ErrEmptyResponse
}
errs = append(errs, err)

if client.IsFinalQuery() {
break
}
}

if len(errs) > 0 {
allErrs := errors.Combine(errs...)
err0 := errs[0]
if errors.AllEqual(err0, allErrs) {
if go_errors.Is(err0, dns.ErrEmptyResponse) {
return nil, 0, dns.ErrEmptyResponse
}
return nil, 0, errors.New("returning nil for domain ", domain).Base(err0)
}
return nil, 0, errors.New("returning nil for domain ", domain).Base(allErrs)
if s.enableParallelQuery {
return s.parallelQuery(domain, option)
} else {
return s.serialQuery(domain, option)
}
return nil, 0, dns.ErrEmptyResponse
}

func (s *DNS) sortClients(domain string) []*Client {
Expand All @@ -300,6 +268,9 @@ func (s *DNS) sortClients(domain string) []*Client {
clients = append(clients, client)
clientNames = append(clientNames, client.Name())
hasMatch = true
if client.finalQuery {
return clients
}
}

if !(s.disableFallback || s.disableFallbackIfMatch && hasMatch) {
Expand All @@ -311,6 +282,9 @@ func (s *DNS) sortClients(domain string) []*Client {
clientUsed[idx] = true
clients = append(clients, client)
clientNames = append(clientNames, client.Name())
if client.finalQuery {
return clients
}
}
}

Expand All @@ -322,14 +296,214 @@ func (s *DNS) sortClients(domain string) []*Client {
}

if len(clients) == 0 {
clients = append(clients, s.clients[0])
clientNames = append(clientNames, s.clients[0].Name())
errors.LogDebug(s.ctx, "domain ", domain, " will use the first DNS: ", clientNames)
if len(s.clients) > 0 {
clients = append(clients, s.clients[0])
clientNames = append(clientNames, s.clients[0].Name())
errors.LogWarning(s.ctx, "domain ", domain, " will use the first DNS: ", clientNames)
} else {
errors.LogError(s.ctx, "no DNS clients available for domain ", domain, " and no default clients configured")
}
}

return clients
}

func mergeQueryErrors(domain string, errs []error) error {
if len(errs) == 0 {
return dns.ErrEmptyResponse
}

var noRNF error
for _, err := range errs {
if go_errors.Is(err, errRecordNotFound) {
continue // server no response, ignore
} else if noRNF == nil {
noRNF = err
} else if !go_errors.Is(err, noRNF) {
return errors.New("returning nil for domain ", domain).Base(errors.Combine(errs...))
}
}
if go_errors.Is(noRNF, dns.ErrEmptyResponse) {
return dns.ErrEmptyResponse
}
if noRNF == nil {
noRNF = errRecordNotFound
}
return errors.New("returning nil for domain ", domain).Base(noRNF)
}

func (s *DNS) serialQuery(domain string, option dns.IPOption) ([]net.IP, uint32, error) {
var errs []error
for _, client := range s.sortClients(domain) {
if !option.FakeEnable && strings.EqualFold(client.Name(), "FakeDNS") {
errors.LogDebug(s.ctx, "skip DNS resolution for domain ", domain, " at server ", client.Name())
continue
}

ips, ttl, err := client.QueryIP(s.ctx, domain, option)

if len(ips) > 0 {
return ips, ttl, nil
}

errors.LogInfoInner(s.ctx, err, "failed to lookup ip for domain ", domain, " at server ", client.Name(), " in serial query mode")
if err == nil {
err = dns.ErrEmptyResponse
}
errs = append(errs, err)
}
return nil, 0, mergeQueryErrors(domain, errs)
}

func (s *DNS) parallelQuery(domain string, option dns.IPOption) ([]net.IP, uint32, error) {
var errs []error
clients := s.sortClients(domain)

resultsChan := asyncQueryAll(domain, option, clients, s.ctx)

groups, groupOf := makeGroups( /*s.ctx,*/ clients)
results := make([]*queryResult, len(clients))
pending := make([]int, len(groups))
for gi, g := range groups {
pending[gi] = g.end - g.start + 1
}

nextGroup := 0
for range clients {
result := <-resultsChan
results[result.index] = &result

gi := groupOf[result.index]
pending[gi]--

for nextGroup < len(groups) {
g := groups[nextGroup]

// group race, minimum rtt -> return
for j := g.start; j <= g.end; j++ {
r := results[j]
if r != nil && r.err == nil && len(r.ips) > 0 {
return r.ips, r.ttl, nil
}
}

// current group is incomplete and no one success -> continue pending
if pending[nextGroup] > 0 {
break
}

// all failed -> log and continue next group
for j := g.start; j <= g.end; j++ {
r := results[j]
e := r.err
if e == nil {
e = dns.ErrEmptyResponse
}
errors.LogInfoInner(s.ctx, e, "failed to lookup ip for domain ", domain, " at server ", clients[j].Name(), " in parallel query mode")
errs = append(errs, e)
}
nextGroup++
}
}

return nil, 0, mergeQueryErrors(domain, errs)
}

type queryResult struct {
ips []net.IP
ttl uint32
err error
index int
}

func asyncQueryAll(domain string, option dns.IPOption, clients []*Client, ctx context.Context) chan queryResult {
if len(clients) == 0 {
ch := make(chan queryResult)
close(ch)
return ch
}

ch := make(chan queryResult, len(clients))
for i, client := range clients {
if !option.FakeEnable && strings.EqualFold(client.Name(), "FakeDNS") {
errors.LogDebug(ctx, "skip DNS resolution for domain ", domain, " at server ", client.Name())
ch <- queryResult{err: dns.ErrEmptyResponse, index: i}
continue
}

go func(i int, c *Client) {
qctx := ctx
if !c.server.IsDisableCache() {
nctx, cancel := context.WithTimeout(context.WithoutCancel(ctx), c.timeoutMs*2)
qctx = nctx
defer cancel()
}
ips, ttl, err := c.QueryIP(qctx, domain, option)
ch <- queryResult{ips: ips, ttl: ttl, err: err, index: i}
}(i, client)
}
return ch
}

type group struct{ start, end int }

// merge only adjacent and rule-equivalent Client into a single group
func makeGroups( /*ctx context.Context,*/ clients []*Client) ([]group, []int) {
n := len(clients)
if n == 0 {
return nil, nil
}
groups := make([]group, 0, n)
groupOf := make([]int, n)

s, e := 0, 0
for i := 1; i < n; i++ {
if clients[i-1].policyID == clients[i].policyID {
e = i
} else {
for k := s; k <= e; k++ {
groupOf[k] = len(groups)
}
groups = append(groups, group{start: s, end: e})
s, e = i, i
}
}
for k := s; k <= e; k++ {
groupOf[k] = len(groups)
}
groups = append(groups, group{start: s, end: e})

// var b strings.Builder
// b.WriteString("dns grouping: total clients=")
// b.WriteString(strconv.Itoa(n))
// b.WriteString(", groups=")
// b.WriteString(strconv.Itoa(len(groups)))

// for gi, g := range groups {
// b.WriteString("\n [")
// b.WriteString(strconv.Itoa(g.start))
// b.WriteString("..")
// b.WriteString(strconv.Itoa(g.end))
// b.WriteString("] gid=")
// b.WriteString(strconv.Itoa(gi))
// b.WriteString(" pid=")
// b.WriteString(strconv.FormatUint(uint64(clients[g.start].policyID), 10))
// b.WriteString(" members: ")

// for i := g.start; i <= g.end; i++ {
// if i > g.start {
// b.WriteString(", ")
// }
// b.WriteString(strconv.Itoa(i))
// b.WriteByte(':')
// b.WriteString(clients[i].Name())
// }
// }
// errors.LogDebug(ctx, b.String())

return groups, groupOf
}

func init() {
common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
return New(ctx, config.(*Config))
Expand Down
9 changes: 5 additions & 4 deletions app/dns/nameserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ import (
type Server interface {
// Name of the Client.
Name() string

IsDisableCache() bool

// QueryIP sends IP queries to its configured server.
QueryIP(ctx context.Context, domain string, option dns.IPOption) ([]net.IP, uint32, error)
}
Expand All @@ -38,6 +41,7 @@ type Client struct {
finalQuery bool
ipOption *dns.IPOption
checkSystem bool
policyID uint32
}

// NewServer creates a name server object according to the network destination url.
Expand Down Expand Up @@ -199,6 +203,7 @@ func NewClient(
client.finalQuery = ns.FinalQuery
client.ipOption = &ipOption
client.checkSystem = checkSystem
client.policyID = ns.PolicyID
return nil
})
return client, err
Expand All @@ -209,10 +214,6 @@ func (c *Client) Name() string {
return c.server.Name()
}

func (c *Client) IsFinalQuery() bool {
return c.finalQuery
}

// QueryIP sends DNS query to the name server with the client's IP.
func (c *Client) QueryIP(ctx context.Context, domain string, option dns.IPOption) ([]net.IP, uint32, error) {
if c.checkSystem {
Expand Down
4 changes: 2 additions & 2 deletions app/dns/nameserver_cached.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ func queryIP(ctx context.Context, s CachedNameserver, domain string, option dns.
ips, ttl, err := merge(option, rec.A, rec.AAAA)
if !go_errors.Is(err, errRecordNotFound) {
if ttl > 0 {
// errors.LogDebugInner(ctx, err, cache.name, " cache HIT ", fqdn, " -> ", ips)
errors.LogDebugInner(ctx, err, cache.name, " cache HIT ", fqdn, " -> ", ips)
log.Record(&log.DNSLog{Server: cache.name, Domain: fqdn, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err})
return ips, uint32(ttl), err
}
if cache.serveStale && (cache.serveExpiredTTL == 0 || cache.serveExpiredTTL < ttl) {
// errors.LogDebugInner(ctx, err, cache.name, " cache OPTIMISTE ", fqdn, " -> ", ips)
errors.LogDebugInner(ctx, err, cache.name, " cache OPTIMISTE ", fqdn, " -> ", ips)
log.Record(&log.DNSLog{Server: cache.name, Domain: fqdn, Result: ips, Status: log.DNSCacheOptimiste, Elapsed: 0, Error: err})
go pull(ctx, s, fqdn, option)
return ips, 1, err
Expand Down
Loading