Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions infra/conf/transport_internet.go
Original file line number Diff line number Diff line change
Expand Up @@ -412,8 +412,9 @@ type TLSConfig struct {
MasterKeyLog string `json:"masterKeyLog"`
ServerNameToVerify string `json:"serverNameToVerify"`
VerifyPeerCertInNames []string `json:"verifyPeerCertInNames"`
ECHConfigList string `json:"echConfigList"`
ECHServerKeys string `json:"echServerKeys"`
ECHConfigList string `json:"echConfigList"`
ECHForceQuery bool `json:"echForceQuery"`
}

// Build implements Buildable.
Expand Down Expand Up @@ -485,15 +486,15 @@ func (c *TLSConfig) Build() (proto.Message, error) {
}
config.VerifyPeerCertInNames = c.VerifyPeerCertInNames

config.EchConfigList = c.ECHConfigList

if c.ECHServerKeys != "" {
EchPrivateKey, err := base64.StdEncoding.DecodeString(c.ECHServerKeys)
if err != nil {
return nil, errors.New("invalid ECH Config", c.ECHServerKeys)
}
config.EchServerKeys = EchPrivateKey
}
config.EchForceQuery = c.ECHForceQuery
config.EchConfigList = c.ECHConfigList

return config, nil
}
Expand Down
50 changes: 30 additions & 20 deletions transport/internet/tls/config.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 5 additions & 3 deletions transport/internet/tls/config.proto
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ message Config {
*/
repeated string verify_peer_cert_in_names = 17;

string ech_config_list = 18;
bytes ech_server_keys = 18;

bytes ech_server_keys = 19;
}
string ech_config_list = 19;

bool ech_force_query = 20;
}
83 changes: 52 additions & 31 deletions transport/internet/tls/ech.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,26 @@ func ApplyECH(c *Config, config *tls.Config) error {
nameToQuery := c.ServerName
var DNSServer string

// for server
if len(c.EchServerKeys) != 0 {
KeySets, err := ConvertToGoECHKeys(c.EchServerKeys)
if err != nil {
return errors.New("Failed to unmarshal ECHKeySetList: ", err)
}
config.EncryptedClientHelloKeys = KeySets
}

// for client
if len(c.EchConfigList) != 0 {
defer func() {
// if failed to get ECHConfig, use an invalid one to make connection fail
if err != nil {
if c.EchForceQuery {
ECHConfig = []byte{1, 1, 4, 5, 1, 4}
}
}
config.EncryptedClientHelloConfigList = ECHConfig
}()
// direct base64 config
if strings.Contains(c.EchConfigList, "://") {
// query config from dns
Expand All @@ -51,7 +69,7 @@ func ApplyECH(c *Config, config *tls.Config) error {
if nameToQuery == "" {
return errors.New("Using DNS for ECH Config needs serverName or use Server format example.com+https://1.1.1.1/dns-query")
}
ECHConfig, err = QueryRecord(nameToQuery, DNSServer)
ECHConfig, err = QueryRecord(nameToQuery, DNSServer, c.EchForceQuery)
if err != nil {
return err
}
Expand All @@ -61,17 +79,6 @@ func ApplyECH(c *Config, config *tls.Config) error {
return errors.New("Failed to unmarshal ECHConfigList: ", err)
}
}

config.EncryptedClientHelloConfigList = ECHConfig
}

// for server
if len(c.EchServerKeys) != 0 {
KeySets, err := ConvertToGoECHKeys(c.EchServerKeys)
if err != nil {
return errors.New("Failed to unmarshal ECHKeySetList: ", err)
}
config.EncryptedClientHelloKeys = KeySets
}

return nil
Expand All @@ -86,17 +93,19 @@ type ECHConfigCache struct {
type echConfigRecord struct {
config []byte
expire time.Time
err error
}

var (
// key value must be like this: "example.com|udp://1.1.1.1"
GlobalECHConfigCache = utils.NewTypedSyncMap[string, *ECHConfigCache]()
clientForECHDOH = utils.NewTypedSyncMap[string, *http.Client]()
)

// Update updates the ECH config for given domain and server.
// this method is concurrent safe, only one update request will be sent, others get the cache.
// if isLockedUpdate is true, it will not try to acquire the lock.
func (c *ECHConfigCache) Update(domain string, server string, isLockedUpdate bool) ([]byte, error) {
func (c *ECHConfigCache) Update(domain string, server string, forceQuery bool, isLockedUpdate bool) ([]byte, error) {
if !isLockedUpdate {
c.UpdateLock.Lock()
defer c.UpdateLock.Unlock()
Expand All @@ -105,13 +114,23 @@ func (c *ECHConfigCache) Update(domain string, server string, isLockedUpdate boo
configRecord := c.configRecord.Load()
if configRecord.expire.After(time.Now()) {
errors.LogDebug(context.Background(), "Cache hit for domain after double check: ", domain)
return configRecord.config, nil
return configRecord.config, configRecord.err
}
// Query ECH config from DNS server
errors.LogDebug(context.Background(), "Trying to query ECH config for domain: ", domain, " with ECH server: ", server)
echConfig, ttl, err := dnsQuery(server, domain)
if err != nil {
return nil, err
if forceQuery {
return nil, err
} else {
configRecord = &echConfigRecord{
config: nil,
expire: time.Now().Add(10 * time.Minute),
err: err,
}
c.configRecord.Store(configRecord)
return echConfig, err
}
}
configRecord = &echConfigRecord{
config: echConfig,
Expand All @@ -123,30 +142,31 @@ func (c *ECHConfigCache) Update(domain string, server string, isLockedUpdate boo

// QueryRecord returns the ECH config for given domain.
// If the record is not in cache or expired, it will query the DNS server and update the cache.
func QueryRecord(domain string, server string) ([]byte, error) {
echConfigCache, ok := GlobalECHConfigCache.Load(domain)
func QueryRecord(domain string, server string, forceQuery bool) ([]byte, error) {
GlobalECHConfigCacheKey := domain + "|" + server
echConfigCache, ok := GlobalECHConfigCache.Load(GlobalECHConfigCacheKey)
if !ok {
echConfigCache = &ECHConfigCache{}
echConfigCache.configRecord.Store(&echConfigRecord{})
echConfigCache, _ = GlobalECHConfigCache.LoadOrStore(domain, echConfigCache)
echConfigCache, _ = GlobalECHConfigCache.LoadOrStore(GlobalECHConfigCacheKey, echConfigCache)
}
configRecord := echConfigCache.configRecord.Load()
if configRecord.expire.After(time.Now()) {
errors.LogDebug(context.Background(), "Cache hit for domain: ", domain)
return configRecord.config, nil
return configRecord.config, configRecord.err
}

// If expire is zero value, it means we are in initial state, wait for the query to finish
// otherwise return old value immediately and update in a goroutine
// but if the cache is too old, wait for update
if configRecord.expire == (time.Time{}) || configRecord.expire.Add(time.Hour*6).Before(time.Now()) {
return echConfigCache.Update(domain, server, false)
return echConfigCache.Update(domain, server, false, forceQuery)
} else {
// If someone already acquired the lock, it means it is updating, do not start another update goroutine
if echConfigCache.UpdateLock.TryLock() {
go func() {
defer echConfigCache.UpdateLock.Unlock()
echConfigCache.Update(domain, server, true)
echConfigCache.Update(domain, server, true, forceQuery)
}()
}
return configRecord.config, nil
Expand All @@ -165,7 +185,7 @@ func dnsQuery(server string, domain string) ([]byte, uint32, error) {
m.Id = 0
msg, err := m.Pack()
if err != nil {
return []byte{}, 0, err
return nil, 0, err
}
var client *http.Client
if client, _ = clientForECHDOH.Load(server); client == nil {
Expand Down Expand Up @@ -194,20 +214,20 @@ func dnsQuery(server string, domain string) ([]byte, uint32, error) {
}
req, err := http.NewRequest("POST", server, bytes.NewReader(msg))
if err != nil {
return []byte{}, 0, err
return nil, 0, err
}
req.Header.Set("Content-Type", "application/dns-message")
resp, err := client.Do(req)
if err != nil {
return []byte{}, 0, err
return nil, 0, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return []byte{}, 0, err
return nil, 0, err
}
if resp.StatusCode != http.StatusOK {
return []byte{}, 0, errors.New("query failed with response code:", resp.StatusCode)
return nil, 0, errors.New("query failed with response code:", resp.StatusCode)
}
dnsResolve = respBody
} else if strings.HasPrefix(server, "udp://") { // for classic udp dns server
Expand All @@ -231,24 +251,25 @@ func dnsQuery(server string, domain string) ([]byte, uint32, error) {
}
}()
if err != nil {
return []byte{}, 0, err
return nil, 0, err
}
msg, err := m.Pack()
if err != nil {
return []byte{}, 0, err
return nil, 0, err
}
conn.Write(msg)
udpResponse := make([]byte, 512)
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
_, err = conn.Read(udpResponse)
if err != nil {
return []byte{}, 0, err
return nil, 0, err
}
dnsResolve = udpResponse
}
respMsg := new(dns.Msg)
err := respMsg.Unpack(dnsResolve)
if err != nil {
return []byte{}, 0, errors.New("failed to unpack dns response for ECH: ", err)
return nil, 0, errors.New("failed to unpack dns response for ECH: ", err)
}
if len(respMsg.Answer) > 0 {
for _, answer := range respMsg.Answer {
Expand All @@ -262,7 +283,7 @@ func dnsQuery(server string, domain string) ([]byte, uint32, error) {
}
}
}
return []byte{}, 0, errors.New("no ech record found")
return nil, 0, errors.New("no ech record found")
}

// reference github.com/OmarTariq612/goech
Expand Down
Loading
Loading