diff --git a/lib/std/base64.zig b/lib/std/base64.zig index e88b72343984..99929fe6af1e 100644 --- a/lib/std/base64.zig +++ b/lib/std/base64.zig @@ -108,7 +108,7 @@ pub const Base64Encoder = struct { } } - // dest must be compatible with std.io.Writer's writeAll interface + /// `dest` must be compatible with `std.io.Writer`'s `writeAll` interface. pub fn encodeWriter(encoder: *const Base64Encoder, dest: anytype, source: []const u8) !void { var chunker = window(u8, source, 3, 3); while (chunker.next()) |chunk| { @@ -118,19 +118,19 @@ pub const Base64Encoder = struct { } } - // destWriter must be compatible with std.io.Writer's writeAll interface - // sourceReader must be compatible with std.io.Reader's read interface - pub fn encodeFromReaderToWriter(encoder: *const Base64Encoder, destWriter: anytype, sourceReader: anytype) !void { + /// `dest_writer` must be compatible with `std.io.Writer`'s `writeAll` interface. + /// `source_reader` must be compatible with `std.io.Reader`'s `read` interface. + pub fn encodeFromReaderToWriter(encoder: *const Base64Encoder, dest_writer: anytype, source_reader: anytype) !void { while (true) { - var tempSource: [3]u8 = undefined; - const bytesRead = try sourceReader.read(&tempSource); + var temp_source: [3]u8 = undefined; + const bytesRead = try source_reader.read(&temp_source); if (bytesRead == 0) { break; } var temp: [5]u8 = undefined; - const s = encoder.encode(&temp, tempSource[0..bytesRead]); - try destWriter.writeAll(s); + const s = encoder.encode(&temp, temp_source[0..bytesRead]); + try dest_writer.writeAll(s); } } @@ -310,6 +310,33 @@ pub const Base64Decoder = struct { if (padding_chars != padding_len) return error.InvalidPadding; } } + + /// `dest_writer` must be compatible with `std.io.Writer`'s `writeAll` interface. + pub fn decodeWriter(decoder: *const Base64Decoder, dest_writer: anytype, source: []const u8) !void { + var temp = [_]u8{0} ** 4; + var chunker = window(u8, source, 4, 4); + while (chunker.next()) |chunk| { + const size = try decoder.calcSizeForSlice(chunk); + try decoder.decode(&temp, chunk); + try dest_writer.writeAll(temp[0..size]); + } + } + + /// `dest_writer` must be compatible with `std.io.Writer`'s `writeAll` interface. + /// `source_reader` must be compatible with `std.io.Reader`'s `read` interface. + pub fn decodeFromReaderToWriter(decoder: *const Base64Decoder, dest_writer: anytype, source_reader: anytype) !void { + var temp = [_]u8{0} ** 3; + var temp_source = [_]u8{0} ** 4; + while (true) { + const bytesRead = try source_reader.read(&temp_source); + if (bytesRead == 0) { + break; + } + const size = try decoder.calcSizeForSlice(temp_source[0..bytesRead]); + try decoder.decode(&temp, temp_source[0..bytesRead]); + try dest_writer.writeAll(temp[0..size]); + } + } }; pub const Base64DecoderWithIgnore = struct { @@ -341,54 +368,82 @@ pub const Base64DecoderWithIgnore = struct { return result; } + fn WindowWithIgnore(comptime ReaderType: type) type { + return struct { + const Self = @This(); + const Err = ReaderType.NoEofError; + + reader: ReaderType, + decoder: *const Base64DecoderWithIgnore, + + pub fn init(reader: ReaderType, decoder: *const Base64DecoderWithIgnore) Self { + return .{ .reader = reader, .decoder = decoder }; + } + + pub fn next(self: *Self, buffer: []u8) Err![]u8 { + var size: usize = 0; + while (true) { + const byte = self.reader.readByte() catch |err| switch (err) { + Self.Err.EndOfStream => { + break; + }, + else => return err, + }; + + if (self.decoder.char_is_ignored[byte]) { + continue; + } + buffer[size] = byte; + size += 1; + if (size == 4) { + break; + } + } + if (size == 0) { + return Self.Err.EndOfStream; + } + return buffer[0..size]; + } + }; + } + /// Invalid characters that are not ignored result in error.InvalidCharacter. /// Invalid padding results in error.InvalidPadding. /// Decoding more data than can fit in dest results in error.NoSpaceLeft. See also ::calcSizeUpperBound. /// Returns the number of bytes written to dest. pub fn decode(decoder_with_ignore: *const Base64DecoderWithIgnore, dest: []u8, source: []const u8) Error!usize { - const decoder = &decoder_with_ignore.decoder; - var acc: u12 = 0; - var acc_len: u4 = 0; - var dest_idx: usize = 0; - var leftover_idx: ?usize = null; - for (source, 0..) |c, src_idx| { - if (decoder_with_ignore.char_is_ignored[c]) continue; - const d = decoder.char_to_index[c]; - if (d == Base64Decoder.invalid_char) { - if (decoder.pad_char == null or c != decoder.pad_char.?) return error.InvalidCharacter; - leftover_idx = src_idx; - break; - } - acc = (acc << 6) + d; - acc_len += 6; - if (acc_len >= 8) { - if (dest_idx == dest.len) return error.NoSpaceLeft; - acc_len -= 8; - dest[dest_idx] = @as(u8, @truncate(acc >> acc_len)); - dest_idx += 1; - } - } - if (acc_len > 4 or (acc & (@as(u12, 1) << acc_len) - 1) != 0) { - return error.InvalidPadding; - } - const padding_len = acc_len / 2; - if (leftover_idx == null) { - if (decoder.pad_char != null and padding_len != 0) return error.InvalidPadding; - return dest_idx; - } - const leftover = source[leftover_idx.?..]; - if (decoder.pad_char) |pad_char| { - var padding_chars: usize = 0; - for (leftover) |c| { - if (decoder_with_ignore.char_is_ignored[c]) continue; - if (c != pad_char) { - return if (c == Base64Decoder.invalid_char) error.InvalidCharacter else error.InvalidPadding; - } - padding_chars += 1; - } - if (padding_chars != padding_len) return error.InvalidPadding; + var sourceStream = std.io.fixedBufferStream(source); + const source_reader = sourceStream.reader(); + var dest_stream = std.io.fixedBufferStream(dest); + const DestStreamType = @TypeOf(dest_stream); + const dest_writer = dest_stream.writer(); + decoder_with_ignore.decodeFromReaderToWriter(dest_writer, source_reader) catch |err| switch (err) { + DestStreamType.WriteError.NoSpaceLeft => return error.NoSpaceLeft, + WindowWithIgnore(@TypeOf(source_reader)).Err.EndOfStream => unreachable, + error.InvalidCharacter, error.InvalidPadding => |e| return e, + }; + return dest_stream.pos; + } + + /// `dest_writer` must be compatible with `std.io.Writer`'s `writeAll` interface. + pub fn decodeWriter(decoder_with_ignore: *const Base64DecoderWithIgnore, dest_writer: anytype, source: []const u8) !void { + var stream = std.io.fixedBufferStream(source); + const reader = stream.reader(); + return decoder_with_ignore.decodeFromReaderToWriter(dest_writer, reader); + } + + /// `dest_writer` must be compatible with `std.io.Writer`'s `writeAll` interface. + /// `source_reader` must be compatible with `std.io.Reader`'s `readByte` interface. + pub fn decodeFromReaderToWriter(decoder_with_ignore: *const Base64DecoderWithIgnore, dest_writer: anytype, source_reader: anytype) !void { + var buffer = [_]u8{0} ** 4; + const WindowType = WindowWithIgnore(@TypeOf(source_reader)); + var chunker = WindowType.init(source_reader, decoder_with_ignore); + while (chunker.next(&buffer)) |chunk| { + try decoder_with_ignore.decoder.decodeWriter(dest_writer, chunk); + } else |err| switch (err) { + WindowType.Err.EndOfStream => return, + else => return err, } - return dest_idx; } }; @@ -532,20 +587,55 @@ fn testAllApis(codecs: Codecs, expected_decoded: []const u8, expected_encoded: [ // Base64Decoder { - var buffer: [0x100]u8 = undefined; - const decoded = buffer[0..try codecs.Decoder.calcSizeForSlice(expected_encoded)]; - try codecs.Decoder.decode(decoded, expected_encoded); - try testing.expectEqualSlices(u8, expected_decoded, decoded); + { + var buffer: [0x100]u8 = undefined; + const decoded = buffer[0..try codecs.Decoder.calcSizeForSlice(expected_encoded)]; + try codecs.Decoder.decode(decoded, expected_encoded); + try testing.expectEqualSlices(u8, expected_decoded, decoded); + } + + //stream version + { + var list = try std.BoundedArray(u8, 0x100).init(0); + try codecs.Decoder.decodeWriter(list.writer(), expected_encoded); + try testing.expectEqualSlices(u8, expected_decoded, list.slice()); + } + + // from reader to writer version + { + var list = try std.BoundedArray(u8, 0x100).init(0); + var stream = std.io.fixedBufferStream(expected_encoded); + try codecs.Decoder.decodeFromReaderToWriter(list.writer(), stream.reader()); + try testing.expectEqualSlices(u8, expected_decoded, list.slice()); + } } // Base64DecoderWithIgnore { const decoder_ignore_nothing = codecs.decoderWithIgnore(""); - var buffer: [0x100]u8 = undefined; - const decoded = buffer[0..try decoder_ignore_nothing.calcSizeUpperBound(expected_encoded.len)]; - const written = try decoder_ignore_nothing.decode(decoded, expected_encoded); - try testing.expect(written <= decoded.len); - try testing.expectEqualSlices(u8, expected_decoded, decoded[0..written]); + + { + var buffer: [0x100]u8 = undefined; + const decoded = buffer[0..try decoder_ignore_nothing.calcSizeUpperBound(expected_encoded.len)]; + const written = try decoder_ignore_nothing.decode(decoded, expected_encoded); + try testing.expect(written <= decoded.len); + try testing.expectEqualSlices(u8, expected_decoded, decoded[0..written]); + } + + //stream version + { + var list = try std.BoundedArray(u8, 0x100).init(0); + try decoder_ignore_nothing.decodeWriter(list.writer(), expected_encoded); + try testing.expectEqualSlices(u8, expected_decoded, list.slice()); + } + + // from reader to writer + { + var list = try std.BoundedArray(u8, 0x100).init(0); + var stream = std.io.fixedBufferStream(expected_encoded); + try decoder_ignore_nothing.decodeFromReaderToWriter(list.writer(), stream.reader()); + try testing.expectEqualSlices(u8, expected_decoded, list.slice()); + } } } @@ -555,6 +645,11 @@ fn testDecodeIgnoreSpace(codecs: Codecs, expected_decoded: []const u8, encoded: const decoded = buffer[0..try decoder_ignore_space.calcSizeUpperBound(encoded.len)]; const written = try decoder_ignore_space.decode(decoded, encoded); try testing.expectEqualSlices(u8, expected_decoded, decoded[0..written]); + + //stream version + var list = try std.BoundedArray(u8, 0x100).init(0); + try decoder_ignore_space.decodeWriter(list.writer(), encoded); + try testing.expectEqualSlices(u8, expected_decoded, list.slice()); } fn testError(codecs: Codecs, encoded: []const u8, expected_err: anyerror) !void {