diff --git a/src/base32.zig b/src/base32.zig index 99c0cfb..3a6fd12 100644 --- a/src/base32.zig +++ b/src/base32.zig @@ -3,11 +3,11 @@ const expect = std.testing.expect; const expectEqualStrings = std.testing.expectEqualStrings; const Allocator = std.mem.Allocator; -const Base32Error = error{InvalidLength}; +const Base32Error = error{ InvalidLength, InvalidCharacter, InvalidPadding }; const Base32 = struct { /// - pub fn decodeU8(allocator: Allocator, data: []const u8) ![]const u8 { + pub fn decodeU8_old(allocator: Allocator, data: []const u8) ![]const u8 { const alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"; if (data.len % 8 != 0) { @@ -52,14 +52,114 @@ const Base32 = struct { return result; } + + pub fn decodeU8(allocator: Allocator, data: []const u8) ![]const u8 { + const alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"; + const expectedLength = data.len / 8 * 5 + (data.len % 8) * 5 / 8; + const result = try allocator.alloc(u8, expectedLength); + + var dest_index: usize = 0; + var bytes: u40 = 0; + var bits_read: u4 = 0; + var padding_index: ?usize = null; + + for (data, 0..) |c, src_index| { + const bits = std.mem.indexOfScalar(u8, alphabet, c); + if (bits == null) { + if (c != '=') { + return Base32Error.InvalidCharacter; + } else { + padding_index = src_index; + break; + } + } + + const bitsAsU8: u8 = if (bits != null) @truncate(bits.?) else 0; + bytes = bytes << 5 | bitsAsU8; + + bits_read += 5; + + if (bits_read >= 8) { + bits_read -= 8; + result[dest_index] = @as(u8, @truncate(bytes >> bits_read)); + dest_index += 1; + } + } + + // The input has too many bytes. Each input byte encodes 5 bits. If there are 5 or more + // trailing bits, then the input could have been one byte shorter. + if (bits_read >= 5) { + return Base32Error.InvalidPadding; + } + + // The trailing bits must all be 0. + if ((bytes & (@as(u16, 1) << bits_read) - 1) != 0) { + return Base32Error.InvalidPadding; + } + + // ensure there are the correct number of '=' at the end + // (padding is optional) + if (padding_index) |p_index| { + var count: usize = 0; + const padding = data[p_index..]; + for (padding) |c| { + if (c != '=') { + return Base32Error.InvalidPadding; + } + count += 1; + } + if (count != 8 - (p_index % 8)) { + return Base32Error.InvalidPadding; + } + } + return result[0..dest_index]; + } }; -test "base32 decode base32('abcde') " { +test "base32 decode base32('abcde', 'MFRGGZDF') " { try testDecode("abcde", "MFRGGZDF"); } -test "base32 decode base32('aaaaa') " { +test "base32 decode base32('aaaaa', 'MFQWCYLB') " { try testDecode("aaaaa", "MFQWCYLB"); } +test "base32 decode base32('aaaaabbbbb', 'MFQWCYLBMJRGEYTC') " { + try testDecode("aaaaabbbbb", "MFQWCYLBMJRGEYTC"); +} +test "base32 decode base32('aaaa', 'MFQWCYI') " { + try testDecode("aaaa", "MFQWCYI"); +} +test "base32 decode base32('aaaa', 'MFQWCYI=') " { + try testDecode("aaaa", "MFQWCYI="); +} + +test "base32 decode base32('a', 'ME') " { + try testDecode("a", "ME"); +} +test "base32 decode base32('a', 'ME======') " { + try testDecode("a", "ME======"); +} +test "base32 decode base32('aa', 'MFQQ') " { + try testDecode("aa", "MFQQ"); +} +test "base32 decode base32('aa', 'MFQQ====') " { + try testDecode("aa", "MFQQ===="); +} + +test "invalid base32 decode base32('MF') - trailing bits" { + try testError(Base32Error.InvalidPadding, "MF"); +} +test "invalid base32 decode base32('ME==') - incorrect number or '='" { + try testError(Base32Error.InvalidPadding, "MF=="); +} +test "invalid base32 decode base32('ME=======') - incorrect number or '='" { + try testError(Base32Error.InvalidPadding, "MF========="); +} +test "invalid base32 decode base32('MFQWC=A=') - character after '='" { + try testError(Base32Error.InvalidPadding, "MFQWC=A="); +} +test "invalid base32 decode base32('Me') - invalid character" { + try testError(Base32Error.InvalidCharacter, "Me"); +} fn testDecode(expected_decoded: []const u8, encoded: []const u8) !void { var gpa = std.heap.GeneralPurposeAllocator(.{}){}; @@ -69,16 +169,10 @@ fn testDecode(expected_decoded: []const u8, encoded: []const u8) !void { try expectEqualStrings(expected_decoded, decoded); } -test "base32 decode - invalid length" { +fn testError(expected_error: anyerror, encoded: []const u8) !void { var gpa = std.heap.GeneralPurposeAllocator(.{}){}; const allocator = gpa.allocator(); - try std.testing.expect(Base32.decodeU8(allocator, "1") == error.InvalidLength); - try std.testing.expect(Base32.decodeU8(allocator, "12") == error.InvalidLength); - try std.testing.expect(Base32.decodeU8(allocator, "123") == error.InvalidLength); - try std.testing.expect(Base32.decodeU8(allocator, "1234") == error.InvalidLength); - try std.testing.expect(Base32.decodeU8(allocator, "12345") == error.InvalidLength); - try std.testing.expect(Base32.decodeU8(allocator, "123456") == error.InvalidLength); - try std.testing.expect(Base32.decodeU8(allocator, "1234567") == error.InvalidLength); - try std.testing.expect(Base32.decodeU8(allocator, "123456789") == error.InvalidLength); + const decoded = Base32.decodeU8(allocator, encoded); + try std.testing.expectError(expected_error, decoded); }