Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions app/dns/dnscommon_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,31 +18,31 @@ func Test_parseResponse(t *testing.T) {

ans := new(dns.Msg)
ans.Id = 0
p = append(p, common.Must2(ans.Pack()).([]byte))
p = append(p, common.Must2(ans.Pack()))

p = append(p, []byte{})

ans = new(dns.Msg)
ans.Id = 1
ans.Answer = append(ans.Answer,
common.Must2(dns.NewRR("google.com. IN CNAME m.test.google.com")).(dns.RR),
common.Must2(dns.NewRR("google.com. IN CNAME fake.google.com")).(dns.RR),
common.Must2(dns.NewRR("google.com. IN A 8.8.8.8")).(dns.RR),
common.Must2(dns.NewRR("google.com. IN A 8.8.4.4")).(dns.RR),
common.Must2(dns.NewRR("google.com. IN CNAME m.test.google.com")),
common.Must2(dns.NewRR("google.com. IN CNAME fake.google.com")),
common.Must2(dns.NewRR("google.com. IN A 8.8.8.8")),
common.Must2(dns.NewRR("google.com. IN A 8.8.4.4")),
)
p = append(p, common.Must2(ans.Pack()).([]byte))
p = append(p, common.Must2(ans.Pack()))

ans = new(dns.Msg)
ans.Id = 2
ans.Answer = append(ans.Answer,
common.Must2(dns.NewRR("google.com. IN CNAME m.test.google.com")).(dns.RR),
common.Must2(dns.NewRR("google.com. IN CNAME fake.google.com")).(dns.RR),
common.Must2(dns.NewRR("google.com. IN CNAME m.test.google.com")).(dns.RR),
common.Must2(dns.NewRR("google.com. IN CNAME test.google.com")).(dns.RR),
common.Must2(dns.NewRR("google.com. IN AAAA 2001::123:8888")).(dns.RR),
common.Must2(dns.NewRR("google.com. IN AAAA 2001::123:8844")).(dns.RR),
common.Must2(dns.NewRR("google.com. IN CNAME m.test.google.com")),
common.Must2(dns.NewRR("google.com. IN CNAME fake.google.com")),
common.Must2(dns.NewRR("google.com. IN CNAME m.test.google.com")),
common.Must2(dns.NewRR("google.com. IN CNAME test.google.com")),
common.Must2(dns.NewRR("google.com. IN AAAA 2001::123:8888")),
common.Must2(dns.NewRR("google.com. IN AAAA 2001::123:8844")),
)
p = append(p, common.Must2(ans.Pack()).([]byte))
p = append(p, common.Must2(ans.Pack()))

tests := []struct {
name string
Expand Down
4 changes: 3 additions & 1 deletion common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ func Must(err error) {
}

// Must2 panics if the second parameter is not nil, otherwise returns the first parameter.
func Must2(v interface{}, err error) interface{} {
// This is useful when function returned "sth, err" and avoid many "if err != nil"
// Internal usage only, if user input can cause err, it must be handled
func Must2[T any](v T, err error) T {
Must(err)
return v
}
Expand Down
6 changes: 2 additions & 4 deletions common/crypto/aes.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@ func NewAesCTRStream(key []byte, iv []byte) cipher.Stream {

// NewAesGcm creates a AEAD cipher based on AES-GCM.
func NewAesGcm(key []byte) cipher.AEAD {
block, err := aes.NewCipher(key)
common.Must(err)
aead, err := cipher.NewGCM(block)
common.Must(err)
block := common.Must2(aes.NewCipher(key))
aead := common.Must2(cipher.NewGCM(block))
return aead
}
14 changes: 3 additions & 11 deletions common/crypto/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ package crypto_test

import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"io"
"testing"
Expand All @@ -18,11 +16,8 @@ import (
func TestAuthenticationReaderWriter(t *testing.T) {
key := make([]byte, 16)
rand.Read(key)
block, err := aes.NewCipher(key)
common.Must(err)

aead, err := cipher.NewGCM(block)
common.Must(err)
aead := NewAesGcm(key)

const payloadSize = 1024 * 80
rawPayload := make([]byte, payloadSize)
Expand Down Expand Up @@ -71,7 +66,7 @@ func TestAuthenticationReaderWriter(t *testing.T) {
t.Error(r)
}

_, err = reader.ReadMultiBuffer()
_, err := reader.ReadMultiBuffer()
if err != io.EOF {
t.Error("error: ", err)
}
Expand All @@ -80,11 +75,8 @@ func TestAuthenticationReaderWriter(t *testing.T) {
func TestAuthenticationReaderWriterPacket(t *testing.T) {
key := make([]byte, 16)
common.Must2(rand.Read(key))
block, err := aes.NewCipher(key)
common.Must(err)

aead, err := cipher.NewGCM(block)
common.Must(err)
aead := NewAesGcm(key)

cache := buf.New()
iv := make([]byte, 12)
Expand Down
4 changes: 2 additions & 2 deletions proxy/dokodemo/dokodemo.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,15 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st
}
}
if dest.Port == 0 {
dest.Port = net.Port(common.Must2(strconv.Atoi(port)).(int))
dest.Port = net.Port(common.Must2(strconv.Atoi(port)))
}
if d.portMap != nil && d.portMap[port] != "" {
h, p, _ := net.SplitHostPort(d.portMap[port])
if len(h) > 0 {
dest.Address = net.ParseAddress(h)
}
if len(p) > 0 {
dest.Port = net.Port(common.Must2(strconv.Atoi(p)).(int))
dest.Port = net.Port(common.Must2(strconv.Atoi(p)))
}
}
}
Expand Down
7 changes: 1 addition & 6 deletions proxy/shadowsocks/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package shadowsocks

import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/md5"
"crypto/sha1"
Expand Down Expand Up @@ -58,11 +57,7 @@ func (a *MemoryAccount) CheckIV(iv []byte) error {
}

func createAesGcm(key []byte) cipher.AEAD {
block, err := aes.NewCipher(key)
common.Must(err)
gcm, err := cipher.NewGCM(block)
common.Must(err)
return gcm
return crypto.NewAesGcm(key)
}

func createChaCha20Poly1305(key []byte) cipher.AEAD {
Expand Down
43 changes: 5 additions & 38 deletions proxy/vmess/aead/encrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@ package aead

import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/binary"
"io"
"time"

"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/crypto"
)

func SealVMessAEADHeader(key [16]byte, data []byte) []byte {
Expand All @@ -34,15 +33,7 @@ func SealVMessAEADHeader(key [16]byte, data []byte) []byte {

payloadHeaderLengthAEADNonce := KDF(key[:], KDFSaltConstVMessHeaderPayloadLengthAEADIV, string(generatedAuthID[:]), string(connectionNonce))[:12]

payloadHeaderLengthAEADAESBlock, err := aes.NewCipher(payloadHeaderLengthAEADKey)
if err != nil {
panic(err.Error())
}

payloadHeaderAEAD, err := cipher.NewGCM(payloadHeaderLengthAEADAESBlock)
if err != nil {
panic(err.Error())
}
payloadHeaderAEAD := crypto.NewAesGcm(payloadHeaderLengthAEADKey)

payloadHeaderLengthAEADEncrypted = payloadHeaderAEAD.Seal(nil, payloadHeaderLengthAEADNonce, aeadPayloadLengthSerializedByte, generatedAuthID[:])
}
Expand All @@ -54,15 +45,7 @@ func SealVMessAEADHeader(key [16]byte, data []byte) []byte {

payloadHeaderAEADNonce := KDF(key[:], KDFSaltConstVMessHeaderPayloadAEADIV, string(generatedAuthID[:]), string(connectionNonce))[:12]

payloadHeaderAEADAESBlock, err := aes.NewCipher(payloadHeaderAEADKey)
if err != nil {
panic(err.Error())
}

payloadHeaderAEAD, err := cipher.NewGCM(payloadHeaderAEADAESBlock)
if err != nil {
panic(err.Error())
}
payloadHeaderAEAD := crypto.NewAesGcm(payloadHeaderAEADKey)

payloadHeaderAEADEncrypted = payloadHeaderAEAD.Seal(nil, payloadHeaderAEADNonce, data, generatedAuthID[:])
}
Expand Down Expand Up @@ -104,15 +87,7 @@ func OpenVMessAEADHeader(key [16]byte, authid [16]byte, data io.Reader) ([]byte,

payloadHeaderLengthAEADNonce := KDF(key[:], KDFSaltConstVMessHeaderPayloadLengthAEADIV, string(authid[:]), string(nonce[:]))[:12]

payloadHeaderAEADAESBlock, err := aes.NewCipher(payloadHeaderLengthAEADKey)
if err != nil {
panic(err.Error())
}

payloadHeaderLengthAEAD, err := cipher.NewGCM(payloadHeaderAEADAESBlock)
if err != nil {
panic(err.Error())
}
payloadHeaderLengthAEAD := crypto.NewAesGcm(payloadHeaderLengthAEADKey)

decryptedAEADHeaderLengthPayload, erropenAEAD := payloadHeaderLengthAEAD.Open(nil, payloadHeaderLengthAEADNonce, payloadHeaderLengthAEADEncrypted[:], authid[:])

Expand Down Expand Up @@ -145,15 +120,7 @@ func OpenVMessAEADHeader(key [16]byte, authid [16]byte, data io.Reader) ([]byte,
return nil, false, bytesRead, err
}

payloadHeaderAEADAESBlock, err := aes.NewCipher(payloadHeaderAEADKey)
if err != nil {
panic(err.Error())
}

payloadHeaderAEAD, err := cipher.NewGCM(payloadHeaderAEADAESBlock)
if err != nil {
panic(err.Error())
}
payloadHeaderAEAD := crypto.NewAesGcm(payloadHeaderAEADKey)

decryptedAEADHeaderPayload, erropenAEAD := payloadHeaderAEAD.Open(nil, payloadHeaderAEADNonce, payloadHeaderAEADEncrypted, authid[:])

Expand Down
8 changes: 2 additions & 6 deletions proxy/vmess/encoding/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ package encoding
import (
"bytes"
"context"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"encoding/binary"
Expand Down Expand Up @@ -182,8 +180,7 @@ func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.Respon
aeadResponseHeaderLengthEncryptionKey := vmessaead.KDF16(c.responseBodyKey[:], vmessaead.KDFSaltConstAEADRespHeaderLenKey)
aeadResponseHeaderLengthEncryptionIV := vmessaead.KDF(c.responseBodyIV[:], vmessaead.KDFSaltConstAEADRespHeaderLenIV)[:12]

aeadResponseHeaderLengthEncryptionKeyAESBlock := common.Must2(aes.NewCipher(aeadResponseHeaderLengthEncryptionKey)).(cipher.Block)
aeadResponseHeaderLengthEncryptionAEAD := common.Must2(cipher.NewGCM(aeadResponseHeaderLengthEncryptionKeyAESBlock)).(cipher.AEAD)
aeadResponseHeaderLengthEncryptionAEAD := crypto.NewAesGcm(aeadResponseHeaderLengthEncryptionKey)

var aeadEncryptedResponseHeaderLength [18]byte
var decryptedResponseHeaderLength int
Expand All @@ -205,8 +202,7 @@ func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.Respon
aeadResponseHeaderPayloadEncryptionKey := vmessaead.KDF16(c.responseBodyKey[:], vmessaead.KDFSaltConstAEADRespHeaderPayloadKey)
aeadResponseHeaderPayloadEncryptionIV := vmessaead.KDF(c.responseBodyIV[:], vmessaead.KDFSaltConstAEADRespHeaderPayloadIV)[:12]

aeadResponseHeaderPayloadEncryptionKeyAESBlock := common.Must2(aes.NewCipher(aeadResponseHeaderPayloadEncryptionKey)).(cipher.Block)
aeadResponseHeaderPayloadEncryptionAEAD := common.Must2(cipher.NewGCM(aeadResponseHeaderPayloadEncryptionKeyAESBlock)).(cipher.AEAD)
aeadResponseHeaderPayloadEncryptionAEAD := crypto.NewAesGcm(aeadResponseHeaderPayloadEncryptionKey)

encryptedResponseHeaderBuffer := make([]byte, decryptedResponseHeaderLength+16)

Expand Down
8 changes: 2 additions & 6 deletions proxy/vmess/encoding/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ package encoding

import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/sha256"
"encoding/binary"
"hash/fnv"
Expand Down Expand Up @@ -350,8 +348,7 @@ func (s *ServerSession) EncodeResponseHeader(header *protocol.ResponseHeader, wr
aeadResponseHeaderLengthEncryptionKey := vmessaead.KDF16(s.responseBodyKey[:], vmessaead.KDFSaltConstAEADRespHeaderLenKey)
aeadResponseHeaderLengthEncryptionIV := vmessaead.KDF(s.responseBodyIV[:], vmessaead.KDFSaltConstAEADRespHeaderLenIV)[:12]

aeadResponseHeaderLengthEncryptionKeyAESBlock := common.Must2(aes.NewCipher(aeadResponseHeaderLengthEncryptionKey)).(cipher.Block)
aeadResponseHeaderLengthEncryptionAEAD := common.Must2(cipher.NewGCM(aeadResponseHeaderLengthEncryptionKeyAESBlock)).(cipher.AEAD)
aeadResponseHeaderLengthEncryptionAEAD := crypto.NewAesGcm(aeadResponseHeaderLengthEncryptionKey)

aeadResponseHeaderLengthEncryptionBuffer := bytes.NewBuffer(nil)

Expand All @@ -365,8 +362,7 @@ func (s *ServerSession) EncodeResponseHeader(header *protocol.ResponseHeader, wr
aeadResponseHeaderPayloadEncryptionKey := vmessaead.KDF16(s.responseBodyKey[:], vmessaead.KDFSaltConstAEADRespHeaderPayloadKey)
aeadResponseHeaderPayloadEncryptionIV := vmessaead.KDF(s.responseBodyIV[:], vmessaead.KDFSaltConstAEADRespHeaderPayloadIV)[:12]

aeadResponseHeaderPayloadEncryptionKeyAESBlock := common.Must2(aes.NewCipher(aeadResponseHeaderPayloadEncryptionKey)).(cipher.Block)
aeadResponseHeaderPayloadEncryptionAEAD := common.Must2(cipher.NewGCM(aeadResponseHeaderPayloadEncryptionKeyAESBlock)).(cipher.AEAD)
aeadResponseHeaderPayloadEncryptionAEAD := crypto.NewAesGcm(aeadResponseHeaderPayloadEncryptionKey)

aeadEncryptedHeaderPayload := aeadResponseHeaderPayloadEncryptionAEAD.Seal(nil, aeadResponseHeaderPayloadEncryptionIV, aeadEncryptedHeaderBuffer.Bytes(), nil)
common.Must2(io.Copy(writer, bytes.NewReader(aeadEncryptedHeaderPayload)))
Expand Down
6 changes: 2 additions & 4 deletions transport/internet/kcp/cryptreal.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
package kcp

import (
"crypto/aes"
"crypto/cipher"
"crypto/sha256"

"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/crypto"
)

func NewAEADAESGCMBasedOnSeed(seed string) cipher.AEAD {
hashedSeed := sha256.Sum256([]byte(seed))
aesBlock := common.Must2(aes.NewCipher(hashedSeed[:16])).(cipher.Block)
return common.Must2(cipher.NewGCM(aesBlock)).(cipher.AEAD)
return crypto.NewAesGcm(hashedSeed[:])
}
5 changes: 1 addition & 4 deletions transport/internet/reality/reality.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ package reality
import (
"bytes"
"context"
"crypto/aes"
"crypto/cipher"
"crypto/ecdh"
"crypto/ed25519"
"crypto/hmac"
Expand Down Expand Up @@ -169,8 +167,7 @@ func UClient(c net.Conn, config *Config, ctx context.Context, dest net.Destinati
if _, err := hkdf.New(sha256.New, uConn.AuthKey, hello.Random[:20], []byte("REALITY")).Read(uConn.AuthKey); err != nil {
return nil, err
}
block, _ := aes.NewCipher(uConn.AuthKey)
aead, _ := cipher.NewGCM(block)
aead := crypto.NewAesGcm(uConn.AuthKey)
if config.Show {
fmt.Printf("REALITY localAddr: %v\tuConn.AuthKey[:16]: %v\tAEAD: %T\n", localAddr, uConn.AuthKey[:16], aead)
}
Expand Down
2 changes: 1 addition & 1 deletion transport/internet/splithttp/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
if transportConfiguration.DownloadSettings != nil {
globalDialerAccess.Lock()
if streamSettings.DownloadSettings == nil {
streamSettings.DownloadSettings = common.Must2(internet.ToMemoryStreamConfig(transportConfiguration.DownloadSettings)).(*internet.MemoryStreamConfig)
streamSettings.DownloadSettings = common.Must2(internet.ToMemoryStreamConfig(transportConfiguration.DownloadSettings))
if streamSettings.SocketSettings != nil && streamSettings.SocketSettings.Penetrate {
streamSettings.DownloadSettings.SocketSettings = streamSettings.SocketSettings
}
Expand Down
Loading