From df0b7a4279f603737afff171f3128612e2c05025 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milan=20=C5=A0pinka?= Date: Sun, 26 Jan 2025 15:17:42 +0100 Subject: [PATCH] Implement rest of AES operations. --- src/primitive/blockcipher/aes.zig | 211 ++++++++++++++++++++++++++++-- 1 file changed, 203 insertions(+), 8 deletions(-) diff --git a/src/primitive/blockcipher/aes.zig b/src/primitive/blockcipher/aes.zig index bf0c618..421fd57 100644 --- a/src/primitive/blockcipher/aes.zig +++ b/src/primitive/blockcipher/aes.zig @@ -220,13 +220,50 @@ fn aes_sub_bytes(state: *[AES_BLOCK_SIZE]u8) void { } fn aes_shift_rows(state: *[AES_BLOCK_SIZE]u8) void { - _ = state; - // TODO + var tmp: u8 = undefined; + + // Note: Since we store the state matrix as an array of columns, + // we're technically shifting columns instead of rows. + + // Row 1 is shifted left by 1 position. + tmp = state[0 * 4 + 1]; + state[0 * 4 + 1] = state[1 * 4 + 1]; + state[1 * 4 + 1] = state[2 * 4 + 1]; + state[2 * 4 + 1] = state[3 * 4 + 1]; + state[3 * 4 + 1] = tmp; + + // Row 2 is shifted left by 2 positions. + tmp = state[0 * 4 + 2]; + state[0 * 4 + 2] = state[2 * 4 + 2]; + state[2 * 4 + 2] = tmp; + tmp = state[1 * 4 + 2]; + state[1 * 4 + 2] = state[3 * 4 + 2]; + state[3 * 4 + 2] = tmp; + + // Row 3 is shifted left by 3 positions. + tmp = state[0 * 4 + 3]; + state[0 * 4 + 3] = state[3 * 4 + 3]; + state[3 * 4 + 3] = state[2 * 4 + 3]; + state[2 * 4 + 3] = state[1 * 4 + 3]; + state[1 * 4 + 3] = tmp; +} + +fn aes_mix_one_column(column: *[4]u8) void { + const c0 = column[0]; + const c1 = column[1]; + const c2 = column[2]; + const c3 = column[3]; + + column[0] = xtime(c0) ^ (xtime(c1) ^ c1) ^ c2 ^ c3; + column[1] = c0 ^ xtime(c1) ^ (xtime(c2) ^ c2) ^ c3; + column[2] = c0 ^ c1 ^ xtime(c2) ^ (xtime(c3) ^ c3); + column[3] = (xtime(c0) ^ c0) ^ c1 ^ c2 ^ xtime(c3); } fn aes_mix_columns(state: *[AES_BLOCK_SIZE]u8) void { - _ = state; - // TODO + for (0..4) |i| { + aes_mix_one_column(@ptrCast(state[(4 * i)..(4 * i + 4)])); + } } fn aes_inv_sub_bytes(state: *[AES_BLOCK_SIZE]u8) void { @@ -236,13 +273,50 @@ fn aes_inv_sub_bytes(state: *[AES_BLOCK_SIZE]u8) void { } fn aes_inv_shift_rows(state: *[AES_BLOCK_SIZE]u8) void { - _ = state; - // TODO + var tmp: u8 = undefined; + + // Note: Since we store the state matrix as an array of columns, + // we're technically shifting columns instead of rows. + + // Row 1 is shifted right by 1 position. + tmp = state[3 * 4 + 1]; + state[3 * 4 + 1] = state[2 * 4 + 1]; + state[2 * 4 + 1] = state[1 * 4 + 1]; + state[1 * 4 + 1] = state[0 * 4 + 1]; + state[0 * 4 + 1] = tmp; + + // Row 2 is shifted right by 2 positions. + tmp = state[2 * 4 + 2]; + state[2 * 4 + 2] = state[0 * 4 + 2]; + state[0 * 4 + 2] = tmp; + tmp = state[3 * 4 + 2]; + state[3 * 4 + 2] = state[1 * 4 + 2]; + state[1 * 4 + 2] = tmp; + + // Row 3 is shifted right by 3 positions. + tmp = state[3 * 4 + 3]; + state[3 * 4 + 3] = state[0 * 4 + 3]; + state[0 * 4 + 3] = state[1 * 4 + 3]; + state[1 * 4 + 3] = state[2 * 4 + 3]; + state[2 * 4 + 3] = tmp; +} + +fn aes_inv_mix_one_column(column: *[4]u8) void { + const c0 = column[0]; + const c1 = column[1]; + const c2 = column[2]; + const c3 = column[3]; + + column[0] = gfmult(0x0e, c0) ^ gfmult(0x0b, c1) ^ gfmult(0x0d, c2) ^ gfmult(0x09, c3); + column[1] = gfmult(0x09, c0) ^ gfmult(0x0e, c1) ^ gfmult(0x0b, c2) ^ gfmult(0x0d, c3); + column[2] = gfmult(0x0d, c0) ^ gfmult(0x09, c1) ^ gfmult(0x0e, c2) ^ gfmult(0x0b, c3); + column[3] = gfmult(0x0b, c0) ^ gfmult(0x0d, c1) ^ gfmult(0x09, c2) ^ gfmult(0x0e, c3); } fn aes_inv_mix_columns(state: *[AES_BLOCK_SIZE]u8) void { - _ = state; - // TODO + for (0..4) |i| { + aes_inv_mix_one_column(@ptrCast(state[(4 * i)..(4 * i + 4)])); + } } fn aes_sub_word(word: u32) u32 { @@ -258,6 +332,37 @@ fn aes_rot_word(word: u32) u32 { return bytes_to_word(&.{ bytes[1], bytes[2], bytes[3], bytes[0] }); } +// ----------------------------------- GALOIS FIELD HELPERS ----------------------------------- // + +fn xtime(element: u8) u8 { + return if (element & 0x80 != 0) + // reduction modulo the AES irreducible polynomial + ((element << 1) ^ 0x1b) + else + (element << 1); +} + +fn gfmult(factor: comptime_int, element: u8) u8 { + const xe = xtime(element); + const x2e = xtime(xe); + const x3e = xtime(x2e); + + return if (factor == 0x09) + // 0x09 = 0x08 ^ 0x01 + x3e ^ element + else if (factor == 0x0b) + // 0x0b = 0x08 ^ 0x03 + x3e ^ xe ^ element + else if (factor == 0x0d) + // 0x0d = 0x08 ^ 0x05 + x3e ^ x2e ^ element + else if (factor == 0x0e) + // 0x0e = 0x08 ^ 0x06 + x3e ^ x2e ^ xe + else + unreachable; +} + // ----------------------------------- ENDIANNESS HELPERS ----------------------------------- // fn word_to_bytes(word: u32) [4]u8 { @@ -503,3 +608,93 @@ test "AES-256 ECB decryption" { try testing.expect(std.mem.eql(u8, &buffer, &plaintext[i])); } } + +test "AES S-box inverse" { + for (0..256) |x| { + try testing.expectEqual(x, AES_INV_SBOX[AES_SBOX[x]]); + } +} + +test "AES SubBytes" { + var state = [AES_BLOCK_SIZE]u8{ + 0x40, 0xBF, 0xAB, 0xF4, 0x06, 0xEE, 0x4D, 0x30, + 0x42, 0xCA, 0x6B, 0x99, 0x7A, 0x5C, 0x58, 0x16, + }; + const reference = [AES_BLOCK_SIZE]u8{ + 0x09, 0x08, 0x62, 0xBF, 0x6F, 0x28, 0xE3, 0x04, + 0x2C, 0x74, 0x7F, 0xEE, 0xDA, 0x4A, 0x6A, 0x47, + }; + + aes_sub_bytes(&state); + try testing.expect(std.mem.eql(u8, &state, &reference)); +} + +test "AES ShiftRows" { + var state = [AES_BLOCK_SIZE]u8{ + 0x09, 0x08, 0x62, 0xBF, 0x6F, 0x28, 0xE3, 0x04, + 0x2C, 0x74, 0x7F, 0xEE, 0xDA, 0x4A, 0x6A, 0x47, + }; + const reference = [AES_BLOCK_SIZE]u8{ + 0x09, 0x28, 0x7F, 0x47, 0x6F, 0x74, 0x6A, 0xBF, + 0x2C, 0x4A, 0x62, 0x04, 0xDA, 0x08, 0xE3, 0xEE, + }; + + aes_shift_rows(&state); + try testing.expect(std.mem.eql(u8, &state, &reference)); +} + +test "AES MixColumns" { + var state = [AES_BLOCK_SIZE]u8{ + 0x09, 0x28, 0x7F, 0x47, 0x6F, 0x74, 0x6A, 0xBF, + 0x2C, 0x4A, 0x62, 0x04, 0xDA, 0x08, 0xE3, 0xEE, + }; + const reference = [AES_BLOCK_SIZE]u8{ + 0x52, 0x9F, 0x16, 0xC2, 0x97, 0x86, 0x15, 0xCA, + 0xE0, 0x1A, 0xAE, 0x54, 0xBA, 0x1A, 0x26, 0x59, + }; + + aes_mix_columns(&state); + try testing.expect(std.mem.eql(u8, &state, &reference)); +} + +test "AES InvSubBytes" { + const reference = [AES_BLOCK_SIZE]u8{ + 0x40, 0xBF, 0xAB, 0xF4, 0x06, 0xEE, 0x4D, 0x30, + 0x42, 0xCA, 0x6B, 0x99, 0x7A, 0x5C, 0x58, 0x16, + }; + var state = [AES_BLOCK_SIZE]u8{ + 0x09, 0x08, 0x62, 0xBF, 0x6F, 0x28, 0xE3, 0x04, + 0x2C, 0x74, 0x7F, 0xEE, 0xDA, 0x4A, 0x6A, 0x47, + }; + + aes_inv_sub_bytes(&state); + try testing.expect(std.mem.eql(u8, &state, &reference)); +} + +test "AES InvShiftRows" { + const reference = [AES_BLOCK_SIZE]u8{ + 0x09, 0x08, 0x62, 0xBF, 0x6F, 0x28, 0xE3, 0x04, + 0x2C, 0x74, 0x7F, 0xEE, 0xDA, 0x4A, 0x6A, 0x47, + }; + var state = [AES_BLOCK_SIZE]u8{ + 0x09, 0x28, 0x7F, 0x47, 0x6F, 0x74, 0x6A, 0xBF, + 0x2C, 0x4A, 0x62, 0x04, 0xDA, 0x08, 0xE3, 0xEE, + }; + + aes_inv_shift_rows(&state); + try testing.expect(std.mem.eql(u8, &state, &reference)); +} + +test "AES InvMixColumns" { + const reference = [AES_BLOCK_SIZE]u8{ + 0x09, 0x28, 0x7F, 0x47, 0x6F, 0x74, 0x6A, 0xBF, + 0x2C, 0x4A, 0x62, 0x04, 0xDA, 0x08, 0xE3, 0xEE, + }; + var state = [AES_BLOCK_SIZE]u8{ + 0x52, 0x9F, 0x16, 0xC2, 0x97, 0x86, 0x15, 0xCA, + 0xE0, 0x1A, 0xAE, 0x54, 0xBA, 0x1A, 0x26, 0x59, + }; + + aes_inv_mix_columns(&state); + try testing.expect(std.mem.eql(u8, &state, &reference)); +}