diff --git a/src/base64.cc b/src/base64.cc index f54ddaed..fe8fbbae 100644 --- a/src/base64.cc +++ b/src/base64.cc @@ -16,6 +16,8 @@ #include "base64.h" +#include + namespace base64 { namespace { @@ -27,21 +29,23 @@ constexpr const unsigned char base64url_chars[] = std::string Encode(const std::string &source) { std::string result; #if 0 - const std::size_t remainder = source.size() % 3; + const std::size_t remainder = (source.size() - 1) % 3; for (std::size_t i = 0; i < source.size(); i += 3) { + const bool one_more = i < source.size() - 3 || remainder > 0; + const bool two_more = i < source.size() - 3 || remainder > 1; const unsigned char c0 = source[i]; - const unsigned char c1 = remainder > 0 ? source[i + 1] : 0u; - const unsigned char c2 = remainder > 1 ? source[i + 2] : 0u; + const unsigned char c1 = one_more ? source[i + 1] : 0u; + const unsigned char c2 = two_more ? source[i + 2] : 0u; result.push_back(base64url_chars[0x3f & c0 >> 2]); - result.push_back(base64url_chars[0x3f & c0 << 4 | c1 >> 4]); - if (remainder > 0) { - result.push_back(base64url_chars[0x3f & c1 << 2 | c2 >> 6]); + result.push_back(base64url_chars[(0x3f & c0 << 4) | c1 >> 4]); + if (one_more) { + result.push_back(base64url_chars[(0x3f & c1 << 2) | c2 >> 6]); } - if (remainder > 1) { + if (two_more) { result.push_back(base64url_chars[0x3f & c2]); } } -#endif +#else unsigned int code_buffer = 0; int code_buffer_size = -6; for (unsigned char c : source) { @@ -57,30 +61,29 @@ std::string Encode(const std::string &source) { code_buffer_size += 8; result.push_back(base64url_chars[(code_buffer >> code_buffer_size) & 0x3f]); } +#endif // No padding needed. return result; } -#if 0 std::string Decode(const std::string &source) { std::string result; - std::vector T(256,-1); - for (int i=0; i<64; i++) T[base64url_chars[i]] = i; + std::vector T(256, -1); + for (int i = 0; i < 64; i++) T[base64url_chars[i]] = i; - int val=0, valb=-8; - for (uchar c : source) { + int val = 0, shift = -8; + for (char c : source) { if (T[c] == -1) break; - val = (val<<6) + T[c]; - valb += 6; - if (valb>=0) { - result.push_back(char((val>>valb)&0xFF)); - valb-=8; + val = (val << 6) + T[c]; + shift += 6; + if (shift >= 0) { + result.push_back(char((val >> shift) & 0xFF)); + shift -= 8; } } return result; } -#endif } diff --git a/test/Makefile b/test/Makefile index efb58a36..5b33fc2a 100644 --- a/test/Makefile +++ b/test/Makefile @@ -21,7 +21,8 @@ TEST_DIR=. TEST_SOURCES=$(wildcard $(TEST_DIR)/*_unittest.cc) TEST_OBJS=$(TEST_SOURCES:$(TEST_DIR)/%.cc=%.o) TESTS=\ - format_unittest + format_unittest \ + base64_unittest GTEST_LIB=gtest_lib.a @@ -63,5 +64,7 @@ $(GTEST_LIB): gtest-all.o gtest_main.o format_unittest: $(GTEST_LIB) format_unittest.o $(SRC_DIR)/format.o $(CXX) $(LDFLAGS) $^ $(LDLIBS) -o $@ +base64_unittest: $(GTEST_LIB) base64_unittest.o $(SRC_DIR)/base64.o + $(CXX) $(LDFLAGS) $^ $(LDLIBS) -o $@ .PHONY: all test clean purge diff --git a/test/base64_unittest.cc b/test/base64_unittest.cc new file mode 100644 index 00000000..57f331d2 --- /dev/null +++ b/test/base64_unittest.cc @@ -0,0 +1,61 @@ +#include "../src/base64.h" +#include "gtest/gtest.h" + +namespace { + +TEST(EncodeTest, EmptyEncode) { + EXPECT_EQ("", base64::Encode("")); +} + +TEST(EncodeTest, SimpleEncode) { + EXPECT_EQ("dGVz", base64::Encode("tes")); +} + +// Base64 encoders typically pad messages to ensure output length % 4 == 0. To +// achieve this, encoders will pad messages with either one or two "=". Our +// implementation does not do this. The following two tests ensure that +// base64::Encode does not append one or two "=". +TEST(EncodeTest, OnePhantom) { + EXPECT_EQ("dGVzdDA", base64::Encode("test0")); +} + +TEST(EncodeTest, TwoPhantom) { + EXPECT_EQ("dGVzdA", base64::Encode("test")); +} + +TEST(DecodeTest, EmptyDecode) { + EXPECT_EQ("", base64::Decode("")); +} + +TEST(DecodeTest, SimpleDecode) { + EXPECT_EQ("tes", base64::Decode("dGVz")); +} + +TEST(DecodeTest, OnePadding) { + EXPECT_EQ("test0", base64::Decode("dGVzdDA=")); +} + +TEST(DecodeTest, OnePhantom) { + EXPECT_EQ("test0", base64::Decode("dGVzdDA")); +} + +TEST(DecodeTest, TwoPadding) { + EXPECT_EQ("test", base64::Decode("dGVzdA==")); +} + +TEST(DecodeTest, TwoPhantom) { + EXPECT_EQ("test", base64::Decode("dGVzdA")); +} + +TEST(RoundTripTest, FullString) { + EXPECT_EQ("tes", base64::Decode(base64::Encode("tes"))); +} + +TEST(RoundTripTest, OnePhantom) { + EXPECT_EQ("test0", base64::Decode(base64::Encode("test0"))); +} + +TEST(RoundTripTest, TwoPhantoms) { + EXPECT_EQ("test", base64::Decode(base64::Encode("test"))); +} +} // namespace