Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 62 additions & 15 deletions agent/utils/ssl/manual_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,39 @@ type Resolve struct {
}

func (c *ManualClient) GetDNSResolve(ctx context.Context, websiteSSL *model.WebsiteSSL) (map[string]Resolve, error) {
order, err := c.client.AuthorizeOrder(ctx, acme.DomainIDs(getWebsiteSSLDomains(websiteSSL)...))
var order *acme.Order
var err error

// Check if we have an existing valid order for this SSL
existingOrder, exists := Orders[websiteSSL.ID]
if exists && existingOrder != nil {
// Verify the order is still valid (not expired and still pending)
// If Expires is zero, order is still valid (ACME doesn't always set expiry immediately)
isNotExpired := existingOrder.Expires.IsZero() || existingOrder.Expires.After(time.Now())
if isNotExpired && (existingOrder.Status == acme.StatusPending || existingOrder.Status == acme.StatusReady) {
// Try to reuse the existing order
records, err := c.extractDNSChallenges(ctx, existingOrder)
if err == nil && len(records) > 0 {
return records, nil
}
// If extraction failed, fall through to create a new order
}
// Existing order is expired or invalid, remove it
delete(Orders, websiteSSL.ID)
}

// Create a new order
order, err = c.client.AuthorizeOrder(ctx, acme.DomainIDs(getWebsiteSSLDomains(websiteSSL)...))
if err != nil {
return nil, err
}
Orders[websiteSSL.ID] = order

return c.extractDNSChallenges(ctx, order)
}

// extractDNSChallenges extracts DNS-01 challenge values from an ACME order
func (c *ManualClient) extractDNSChallenges(ctx context.Context, order *acme.Order) (map[string]Resolve, error) {
records := make(map[string]Resolve)

for _, authzURL := range order.AuthzURLs {
Expand All @@ -103,23 +130,30 @@ func (c *ManualClient) GetDNSResolve(ctx context.Context, websiteSSL *model.Webs
return nil, err
}

records[domain] = Resolve{
// Use different map key for wildcard vs non-wildcard to avoid overwriting
// Both use the same DNS record name (_acme-challenge.domain) but different values
mapKey := domain
if authz.Wildcard {
mapKey = "*." + domain
}

records[mapKey] = Resolve{
Key: fmt.Sprintf("_acme-challenge.%s", domain),
Value: txtValue,
}
}
return records, nil
}

func queryDNSRecords(domain string) (map[string]string, error) {
func queryDNSRecords(domain string) (map[string][]string, error) {
recordName := fmt.Sprintf("_acme-challenge.%s", domain)
txts, err := net.LookupTXT(recordName)
if err != nil {
return nil, err
}
records := make(map[string]string)
records := make(map[string][]string)
if len(txts) > 0 {
records[recordName] = txts[0]
records[recordName] = txts
}
return records, nil
}
Expand Down Expand Up @@ -159,7 +193,7 @@ func (c *ManualClient) handleAuthorization(ctx context.Context, authzURL string,

for {
c.logger.Printf("[INFO] [%s] acme: Checking DNS record propagation.", domain)
var currentRecords map[string]string
var currentRecords map[string][]string
var queryErr error
if len(nameservers) == 0 {
currentRecords, queryErr = queryDNSRecords(domain)
Expand All @@ -177,16 +211,26 @@ func (c *ManualClient) handleAuthorization(ctx context.Context, authzURL string,
return fmt.Errorf("failed to query DNS records: %v", queryErr)
}
recordName := fmt.Sprintf("_acme-challenge.%s", domain)
providedRecord, exists := currentRecords[recordName]
if exists && providedRecord == expectedRecord {
providedRecords, exists := currentRecords[recordName]
// Check if expected record is in any of the TXT values
found := false
if exists {
for _, record := range providedRecords {
if record == expectedRecord {
found = true
break
}
}
}
if found {
break
}
if time.Now().After(deadline) {
if !exists {
if !exists || len(providedRecords) == 0 {
return fmt.Errorf("TXT record not provided for domain %s after retrying", domain)
}
c.logger.Printf("[INFO] [%s] TXT record mismatch for %s: expected %s, got %s\"", domain, domain, expectedRecord, providedRecord)
return fmt.Errorf("TXT record mismatch for %s: expected %s, got %s", domain, expectedRecord, providedRecord)
c.logger.Printf("[INFO] [%s] TXT record mismatch for %s: expected %s, got %v", domain, domain, expectedRecord, providedRecords)
return fmt.Errorf("TXT record mismatch for %s: expected %s, got %v", domain, expectedRecord, providedRecords)
}
time.Sleep(pollingInterval)
}
Expand Down Expand Up @@ -339,7 +383,7 @@ func handleNameserver(nameserver string) string {
return fmt.Sprintf("%s:53", nameserver)
}

func queryDNSRecordsWithResolver(ctx context.Context, logger *log.Logger, domain string, dnsServer string) (map[string]string, error) {
func queryDNSRecordsWithResolver(ctx context.Context, logger *log.Logger, domain string, dnsServer string) (map[string][]string, error) {
recordName := fmt.Sprintf("_acme-challenge.%s", domain)
c := new(dns.Client)
c.Timeout = 10 * time.Second
Expand Down Expand Up @@ -367,16 +411,19 @@ func queryDNSRecordsWithResolver(ctx context.Context, logger *log.Logger, domain
return nil, fmt.Errorf("DNS query failed with code: %s", dns.RcodeToString[r.Rcode])
}

records := make(map[string]string)
records := make(map[string][]string)
var txtValues []string

for _, answer := range r.Answer {
if txt, ok := answer.(*dns.TXT); ok {
if len(txt.Txt) > 0 {
records[recordName] = txt.Txt[0]
break
txtValues = append(txtValues, txt.Txt[0])
}
}
}
if len(txtValues) > 0 {
records[recordName] = txtValues
}

return records, nil
}
Expand Down