diff --git a/source/common/common/base64.cc b/source/common/common/base64.cc index ac5ee3d6c0798..b0e7ebcaad062 100644 --- a/source/common/common/base64.cc +++ b/source/common/common/base64.cc @@ -68,6 +68,40 @@ std::string Base64::decode(const std::string& input) { return result; } +void Base64::encodeBase(const uint8_t cur_char, uint64_t pos, uint8_t& next_c, std::string& ret) { + switch (pos % 3) { + case 0: + ret.push_back(CHAR_TABLE[cur_char >> 2]); + next_c = (cur_char & 0x03) << 4; + break; + case 1: + ret.push_back(CHAR_TABLE[next_c | (cur_char >> 4)]); + next_c = (cur_char & 0x0f) << 2; + break; + case 2: + ret.push_back(CHAR_TABLE[next_c | (cur_char >> 6)]); + ret.push_back(CHAR_TABLE[cur_char & 0x3f]); + next_c = 0; + break; + } +} + +void Base64::encodeLast(uint64_t pos, uint8_t last_char, std::string& ret) { + switch (pos % 3) { + case 1: + ret.push_back(CHAR_TABLE[last_char]); + ret.push_back('='); + ret.push_back('='); + break; + case 2: + ret.push_back(CHAR_TABLE[last_char]); + ret.push_back('='); + break; + default: + break; + } +} + std::string Base64::encode(const Buffer::Instance& buffer, uint64_t length) { uint64_t output_length = (std::min(length, buffer.length()) + 2) / 3 * 4; std::string ret; @@ -83,22 +117,7 @@ std::string Base64::encode(const Buffer::Instance& buffer, uint64_t length) { const uint8_t* slice_mem = static_cast(slice.mem_); for (uint64_t i = 0; i < slice.len_ && j < length; ++i, ++j) { - const uint8_t c = slice_mem[i]; - switch (j % 3) { - case 0: - ret.push_back(CHAR_TABLE[c >> 2]); - next_c = (c & 0x03) << 4; - break; - case 1: - ret.push_back(CHAR_TABLE[next_c | (c >> 4)]); - next_c = (c & 0x0f) << 2; - break; - case 2: - ret.push_back(CHAR_TABLE[next_c | (c >> 6)]); - ret.push_back(CHAR_TABLE[c & 0x3f]); - next_c = 0; - break; - } + encodeBase(slice_mem[i], j, next_c, ret); } if (j == length) { @@ -106,19 +125,24 @@ std::string Base64::encode(const Buffer::Instance& buffer, uint64_t length) { } } - switch (j % 3) { - case 1: - ret.push_back(CHAR_TABLE[next_c]); - ret.push_back('='); - ret.push_back('='); - break; - case 2: - ret.push_back(CHAR_TABLE[next_c]); - ret.push_back('='); - break; - default: - break; + encodeLast(j, next_c, ret); + + return ret; +} + +std::string Base64::encode(const char* input, uint64_t length) { + uint64_t output_length = (length + 2) / 3 * 4; + std::string ret; + ret.reserve(output_length); + + uint64_t pos = 0; + uint8_t next_c = 0; + + for (uint64_t i = 0; i < length; ++i) { + encodeBase(input[i], pos++, next_c, ret); } + encodeLast(pos, next_c, ret); + return ret; } diff --git a/source/common/common/base64.h b/source/common/common/base64.h index 69047f1c8cdc8..19b2fe30ca657 100644 --- a/source/common/common/base64.h +++ b/source/common/common/base64.h @@ -11,6 +11,13 @@ class Base64 { */ static std::string encode(const Buffer::Instance& buffer, uint64_t length); + /** + * Base64 encode an input char buffer with a given length. + * @param input char array to encode. + * @param length of the input array. + */ + static std::string encode(const char* input, uint64_t length); + /** * Base64 decode an input string. * @param input supplies the input to decode. @@ -19,4 +26,16 @@ class Base64 { * bytes. */ static std::string decode(const std::string& input); + +private: + /** + * Helper method for encoding. This is used to encode all of the characters from the input string. + */ + static void encodeBase(const uint8_t cur_char, uint64_t pos, uint8_t& next_c, std::string& ret); + + /** + * Encode last characters. It appends '=' chars to the ret if input + * string length is not divisible by 3. + */ + static void encodeLast(uint64_t pos, uint8_t last_char, std::string& ret); }; diff --git a/test/common/common/base64_test.cc b/test/common/common/base64_test.cc index 1cebf7e0a5524..9b6d5e76a2730 100644 --- a/test/common/common/base64_test.cc +++ b/test/common/common/base64_test.cc @@ -22,6 +22,13 @@ TEST(Base64Test, SingleSliceBufferEncode) { EXPECT_EQ("Zm8=", Base64::encode(buffer, 2)); } +TEST(Base64Test, EncodeString) { + EXPECT_EQ("", Base64::encode("", 0)); + EXPECT_EQ("AAA=", Base64::encode("\0\0", 2)); + EXPECT_EQ("Zm9v", Base64::encode("foo", 3)); + EXPECT_EQ("Zm8=", Base64::encode("fo", 2)); +} + TEST(Base64Test, Decode) { EXPECT_EQ("", Base64::decode("")); EXPECT_EQ("foo", Base64::decode("Zm9v")); @@ -42,12 +49,23 @@ TEST(Base64Test, Decode) { EXPECT_FALSE(memcmp(test_string, Base64::decode(Base64::encode(buffer, 36)).data(), 36)); } + { + const char* test_string = "\0\0\0\0als;jkopqitu[\0opbjlcxnb35g]b[\xaa\b\n"; + EXPECT_FALSE(memcmp(test_string, Base64::decode(Base64::encode(test_string, 36)).data(), 36)); + } + { std::string test_string = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; std::string decoded = Base64::decode(test_string); Buffer::OwnedImpl buffer(decoded); EXPECT_EQ(test_string, Base64::encode(buffer, decoded.length())); } + + { + const char* test_string = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + std::string decoded = Base64::decode(test_string); + EXPECT_EQ(test_string, Base64::encode(decoded.c_str(), decoded.length())); + } } TEST(Base64Test, MultiSlicesBufferEncode) {