import os
import json
from Crypto.Cipher import AES
from Crypto.Hash import SHA256

def _read_exact(fp, size):
    """Read exactly ``size`` bytes from the binary file object ``fp``.

    ``fp.read`` silently returns fewer bytes at end-of-file. A truncated
    table file would otherwise produce short tables that only fail much later
    with an ``IndexError`` far from the real cause.

    Raises:
        ValueError: if fewer than ``size`` bytes are available.
    """
    data = fp.read(size)
    if len(data) != size:
        name = getattr(fp, "name", "<stream>")
        raise ValueError(
            f"{name}: unexpected end of file "
            f"(wanted {size} bytes, got {len(data)})"
        )
    return data


def _read_table(fp, rows, size):
    """Read ``rows`` consecutive ``size``-byte records from ``fp`` into a list."""
    # Lives at module level on purpose: a comprehension written directly in
    # the class body could not see the class-local name ``fp``.
    return [_read_exact(fp, size) for _ in range(rows)]


class BADAES:
    """AES-128-shaped block cipher whose primitive operations are lookup tables.

    Every byte-level operation (S-box, key-schedule S-box, the byte
    "combine" steps and the MixColumns multipliers) is a table read from
    ``badaes.bin``. The tables are read when the class body executes, i.e.
    at import time.

    State layout: ``s[row][col]``, loaded from input byte ``p[row + 4 * col]``
    (column-major).
    """

    # Round constants consumed by the key schedule, one per round (10 rounds).
    RCON = [1, 2, 4, 8, 16, 32, 64, 128, 27, 54]

    # Tables in badaes.bin, in file order. Roles are inferred from how each
    # table is used below:
    #   S       256-byte S-box used by SB.
    #   SK      256-byte S-box used by the key schedule.
    #   M       256 x 256 table; not referenced by any method of this class.
    #   XK      256 x 256 binary op combining key-schedule bytes
    #           (presumably a stand-in for XOR -- TODO confirm).
    #   X0      256 x 256 binary op ARK uses to mix a round-key byte into the
    #           state.
    #   X1, X2  256 x 256 binary ops MC uses to combine the four products.
    #   M1..M3  256-byte tables used in MC exactly where standard AES
    #           multiplies by 1, 2, 3 (presumably their stand-ins -- confirm).
    with open("badaes.bin", "rb") as fp:
        S = _read_exact(fp, 2 ** 8)
        SK = _read_exact(fp, 2 ** 8)
        M = _read_table(fp, 2 ** 8, 2 ** 8)
        XK = _read_table(fp, 2 ** 8, 2 ** 8)
        X0 = _read_table(fp, 2 ** 8, 2 ** 8)
        X1 = _read_table(fp, 2 ** 8, 2 ** 8)
        X2 = _read_table(fp, 2 ** 8, 2 ** 8)
        M1 = _read_exact(fp, 2 ** 8)
        M2 = _read_exact(fp, 2 ** 8)
        M3 = _read_exact(fp, 2 ** 8)

    def __init__(self, k):
        """Expand the 16-byte key ``k`` into 11 round keys.

        Follows the AES-128 key-schedule shape (44 four-byte words). The
        S-box is replaced by ``SK`` and byte combination by ``XK``. After
        expansion, ``self.rk[r][c]`` is column ``c`` (a 4-byte word) of
        round key ``r``.

        Raises:
            ValueError: if ``k`` is not exactly 16 bytes long.
        """
        if len(k) != 16:
            raise ValueError(f"key must be 16 bytes, got {len(k)}")
        words = [k[i:i + 4] for i in range(0, 16, 4)]
        for i in range(4, 44):
            temp = words[i - 1][:]
            if i % 4 == 0:
                # RotWord, SubWord through SK, then fold the round constant
                # into the first byte.
                temp = temp[1:] + temp[:1]
                temp = [self.SK[x] for x in temp]
                temp[0] = self.XK[temp[0]][self.RCON[i // 4 - 1]]
            words.append([self.XK[words[i - 4][j]][temp[j]] for j in range(4)])
        # Group the 44 words into 11 round keys of 4 columns each.
        self.rk = [words[i:i + 4] for i in range(0, 44, 4)]

    def ARK(self, s, r):
        """AddRoundKey: mix round key ``r`` into ``s`` in place via ``X0``.

        Each byte becomes ``X0[rk[r][col][row]][s[row][col]]``. Returns ``s``.
        """
        for c in range(4):
            for r0 in range(4):
                s[r0][c] = self.X0[self.rk[r][c][r0]][s[r0][c]]
        return s

    def SB(self, s):
        """SubBytes: replace every state byte with ``S[byte]`` in place. Returns ``s``."""
        for i in range(4):
            for j in range(4):
                s[i][j] = self.S[s[i][j]]
        return s

    def SR(self, s):
        """ShiftRows: rotate row ``i`` left by ``i`` positions. Row 0 is unchanged. Returns ``s``."""
        for i in range(1, 4):
            s[i] = s[i][i:] + s[i][:i]
        return s

    def MC(self, s):
        """MixColumns with table arithmetic, in place. Returns ``s``.

        For each column ``(a0, a1, a2, a3)``, every output byte combines four
        ``M*`` lookups pairwise with ``X1`` and joins the halves with ``X2``.
        The coefficients follow the AES circulant pattern (2, 3, 1, 1).
        """
        for c in range(4):
            a0, a1, a2, a3 = [ s[r][c] for r in range(4) ]
            col = [
                self.X2[self.X1[self.M2[a0]][self.M3[a1]]][self.X1[self.M1[a2]][self.M1[a3]]],
                self.X2[self.X1[self.M1[a0]][self.M2[a1]]][self.X1[self.M3[a2]][self.M1[a3]]],
                self.X2[self.X1[self.M1[a0]][self.M1[a1]]][self.X1[self.M2[a2]][self.M3[a3]]],
                self.X2[self.X1[self.M3[a0]][self.M1[a1]]][self.X1[self.M1[a2]][self.M2[a3]]],
            ]
            for r in range(4):
                s[r][c] = col[r]
        return s

    def encrypt(self, p):
        """Encrypt one 16-byte block ``p`` and return the 16-byte ciphertext.

        Each of the 10 rounds applies ARK, SB, SR and MC in that order, and a
        final ARK with round key 10 follows. Apart from the table
        substitutions, this is the FIPS-197 sequence, except that the last
        round also applies MixColumns.

        Raises:
            ValueError: if ``p`` is not exactly 16 bytes long.
        """
        if len(p) != 16:
            raise ValueError(f"plaintext block must be 16 bytes, got {len(p)}")
        # Column-major load: s[row][col] = p[row + 4 * col].
        s = [
            [ p[r + 4 * c] for c in range(4) ]
            for r in range(4)
        ]
        for r in range(10):
            s = self.ARK(s, r)
            s = self.SB(s)
            s = self.SR(s)
            s = self.MC(s)
        s = self.ARK(s, 10)
        # Column-major store, mirroring the load above.
        return bytes([ s[r][c] for c in range(4) for r in range(4) ])

# Secret key: 8 unknown bytes wrapped in the flag format. The bytes literal is
# not a format string, so "{{" and "}}" are literal characters, and the key is
# 6 + 8 + 2 = 16 bytes: the length BADAES's key schedule indexes.
with open("key.bin", "rb") as fp:
    secret = fp.read(8)
if len(secret) != 8:
    raise ValueError(f"key.bin must hold at least 8 bytes, got {len(secret)}")
key = b"FCSC{{" + secret + b"}}"

# One random known-plaintext / ciphertext pair under the table-based cipher.
bad_cipher = BADAES(key)
p = os.urandom(16)
c = bad_cipher.encrypt(p)

# The flag itself is encrypted with real AES (ECB) keyed by SHA-256(key).
with open("flag.txt", "rb") as fp:
    flag = fp.read()
# Zero-pad to 80 bytes, or to the next 16-byte boundary for a longer flag:
# ECB mode rejects input that is not a whole number of blocks.
padded_len = max(80, -(-len(flag) // 16) * 16)
aes_cipher = AES.new(SHA256.new(key).digest(), AES.MODE_ECB)
enc = aes_cipher.encrypt(flag.ljust(padded_len, b"\x00"))

d = {
    "p": p.hex(),
    "c": c.hex(),
    "enc": enc.hex(),
}

with open("output.json", "w", encoding="utf-8") as fp:
    json.dump(d, fp, indent = 4)
