diff --git a/.gitignore b/.gitignore index 2b0c6e4..a80978c 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,8 @@ # Output of the go coverage tool, specifically when used with LiteIDE *.out +*.pprof + # Dependency directories (remove the comment below to include it) # vendor/ diff --git a/README.md b/README.md index a906690..9446276 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,8 @@ $ basic_hmac_auth -h Usage of /usr/local/bin/basic_hmac_auth: -buffer-size int initial buffer size for stream parsing + -cpu-profile string + write CPU profile to file -secret string hex-encoded HMAC secret value -secret-file string diff --git a/cmd/main.go b/cmd/main.go index f41553d..b57486f 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -9,6 +9,7 @@ import ( "io" "log" "os" + "runtime/pprof" "github.com/SenseUnit/basic_hmac_auth/handler" ) @@ -24,6 +25,7 @@ var ( hexSecret = flag.String("secret", "", "hex-encoded HMAC secret value") hexSecretFile = flag.String("secret-file", "", "file containing single line with hex-encoded secret") showVersion = flag.Bool("version", false, "show program version and exit") + cpuProfile = flag.String("cpu-profile", "", "write CPU profile to file") ) func run() int { @@ -65,6 +67,16 @@ func run() int { return 3 } + if *cpuProfile != "" { + f, err := os.Create(*cpuProfile) + if err != nil { + log.Fatal(err) + } + defer f.Close() + pprof.StartCPUProfile(f) + defer pprof.StopCPUProfile() + } + err = (&handler.BasicHMACAuthHandler{ Secret: secret, BufferSize: *bufferSize, diff --git a/handler/handler.go b/handler/handler.go index 2901e05..d24ab91 100644 --- a/handler/handler.go +++ b/handler/handler.go @@ -28,20 +28,36 @@ func (a *BasicHMACAuthHandler) Run(input io.Reader, output io.Writer) error { rd := bufio.NewReaderSize(input, bufSize) scanner := proto.NewElasticLineScanner(rd, '\n') + verifier := hmac.NewVerifier(a.Secret) + + emitter := proto.NewResponseEmitter(output) + for scanner.Scan() { - parts := bytes.SplitN(scanner.Bytes(), []byte{' '}, 4) - if len(parts) < 3 { - err := fmt.Errorf("bad request line sent to auth helper: %q", string(scanner.Bytes())) - return err + line := scanner.Bytes() + + before, after, found := bytes.Cut(line, []byte{' '}) + if !found { + return fmt.Errorf("bad request line sent to auth helper: %q", line) } - channelID := parts[0] - username := proto.RFC1738Unescape(parts[1]) - password := proto.RFC1738Unescape(parts[2]) + channelID := before + + before, after, found = bytes.Cut(after, []byte{' '}) + if !found { + return fmt.Errorf("bad request line sent to auth helper: %q", line) + } + username := proto.RFC1738Unescape(before) + + before, _, _ = bytes.Cut(after, []byte{' '}) + password := proto.RFC1738Unescape(before) - if hmac.VerifyHMACLoginAndPassword(a.Secret, username, password) { - fmt.Fprintf(output, "%s OK\n", channelID) + if verifier.VerifyLoginAndPassword(username, password) { + if err := emitter.EmitOK(channelID); err != nil { + return fmt.Errorf("response write failed: %w", err) + } } else { - fmt.Fprintf(output, "%s ERR\n", channelID) + if err := emitter.EmitERR(channelID); err != nil { + return fmt.Errorf("response write failed: %w", err) + } } } diff --git a/hmac/hmac.go b/hmac/hmac.go index b8dfbbe..3e96a8b 100644 --- a/hmac/hmac.go +++ b/hmac/hmac.go @@ -1,46 +1,79 @@ package hmac import ( - "bytes" "crypto/hmac" "crypto/sha256" "encoding/base64" "encoding/binary" + "hash" "time" ) const ( HMACSignaturePrefix = "dumbproxy grant token v1" - HMACSignatureSize = 32 + HMACExpireSize = 8 + passwordBufferSize = HMACExpireSize + 64 // for worst case if 512-bit hash is used for some reason ) var hmacSignaturePrefix = []byte(HMACSignaturePrefix) -type HMACToken struct { - Expire int64 - Signature [HMACSignatureSize]byte +func NewHasher(secret []byte) hash.Hash { + return hmac.New(sha256.New, secret) } -func VerifyHMACLoginAndPassword(secret, login, password []byte) bool { - rd := base64.NewDecoder(base64.RawURLEncoding, bytes.NewReader(password)) +type Verifier struct { + mac hash.Hash + buf []byte +} + +func NewVerifier(secret []byte) *Verifier { + return &Verifier{ + mac: hmac.New(sha256.New, secret), + } +} - var token HMACToken - if err := binary.Read(rd, binary.BigEndian, &token); err != nil { +func (v *Verifier) ensureBufferSize(size int) { + if len(v.buf) < size { + v.buf = make([]byte, size) + } +} + +func (v *Verifier) VerifyLoginAndPassword(login, password []byte) bool { + v.ensureBufferSize(base64.RawURLEncoding.DecodedLen(len(password))) + buf := v.buf + n, err := base64.RawURLEncoding.Decode(buf, password) + if err != nil { return false } + buf = buf[:n] - if time.Unix(token.Expire, 0).Before(time.Now()) { + var expire int64 + if len(buf) < HMACExpireSize { return false } + expire = int64(binary.BigEndian.Uint64(buf[:HMACExpireSize])) + buf = buf[HMACExpireSize:] - expectedMAC := CalculateHMACSignature(secret, login, token.Expire) - return hmac.Equal(token.Signature[:], expectedMAC) + if time.Unix(expire, 0).Before(time.Now()) { + return false + } + + if len(buf) < v.mac.Size() { + return false + } + + expectedMAC := v.calculateHMACSignature(login, expire) + return hmac.Equal(buf[:v.mac.Size()], expectedMAC) } -func CalculateHMACSignature(secret, username []byte, expire int64) []byte { - mac := hmac.New(sha256.New, secret) - mac.Write(hmacSignaturePrefix) - mac.Write(username) - binary.Write(mac, binary.BigEndian, expire) - return mac.Sum(nil) +func (v *Verifier) calculateHMACSignature(username []byte, expire int64) []byte { + var buf [HMACExpireSize]byte + binary.BigEndian.PutUint64(buf[:], uint64(expire)) + + v.mac.Reset() + v.mac.Write(hmacSignaturePrefix) + v.mac.Write(username) + v.mac.Write(buf[:]) + + return v.mac.Sum(nil) } diff --git a/proto/emit.go b/proto/emit.go new file mode 100644 index 0000000..2c2df08 --- /dev/null +++ b/proto/emit.go @@ -0,0 +1,46 @@ +package proto + +import ( + "bytes" + "io" +) + +const ( + OK = "OK" + ERR = "ERR" +) + +type ResponseEmitter struct { + writer io.Writer + buffer bytes.Buffer +} + +func NewResponseEmitter(writer io.Writer) *ResponseEmitter { + return &ResponseEmitter{ + writer: writer, + } +} + +func (e *ResponseEmitter) EmitOK(channelID []byte) error { + e.beginResponse(channelID) + e.buffer.WriteString(OK) + return e.finishResponse() +} + +func (e *ResponseEmitter) EmitERR(channelID []byte) error { + e.beginResponse(channelID) + e.buffer.WriteString(ERR) + return e.finishResponse() +} + +func (e *ResponseEmitter) beginResponse(channelID []byte) { + e.buffer.Reset() + e.buffer.Write(channelID) + e.buffer.WriteByte(' ') +} + +func (e *ResponseEmitter) finishResponse() error { + e.buffer.WriteByte('\n') + _, err := e.buffer.WriteTo(e.writer) + return err +} diff --git a/proto/scanner.go b/proto/scanner.go index 1545c78..e97e994 100644 --- a/proto/scanner.go +++ b/proto/scanner.go @@ -1,20 +1,23 @@ package proto -import "io" +import ( + "bufio" + "io" +) -type BytesReader interface { - ReadBytes(byte) ([]byte, error) +type ReadSlicer interface { + ReadSlice(byte) ([]byte, error) } type ElasticLineScanner struct { line []byte - reader BytesReader + reader ReadSlicer lastErr error done bool delim byte } -func NewElasticLineScanner(reader BytesReader, delim byte) *ElasticLineScanner { +func NewElasticLineScanner(reader ReadSlicer, delim byte) *ElasticLineScanner { return &ElasticLineScanner{ reader: reader, delim: delim, @@ -37,19 +40,28 @@ func (els *ElasticLineScanner) Scan() bool { return false } - data, err := els.reader.ReadBytes(els.delim) + els.line = els.line[:0] + var ( + data []byte + err error + ) + for data, err = els.reader.ReadSlice(els.delim); ; data, err = els.reader.ReadSlice(els.delim) { + els.line = append(els.line, data...) + if err != bufio.ErrBufferFull { + break + } + } if err != nil { els.done = true els.lastErr = err - if len(data) == 0 { + if len(els.line) == 0 { return false } } else { // strip delimiter if needed - if len(data) > 0 && data[len(data)-1] == els.delim { - data = data[:len(data)-1] + if len(els.line) > 0 && els.line[len(els.line)-1] == els.delim { + els.line = els.line[:len(els.line)-1] } } - els.line = data return true }