import base64
import hmac
import os
import pickle
import random
import sys

import numpy as np
from Crypto.Cipher import AES, PKCS1_OAEP
from Crypto.Hash import SHA256, SHAKE256
from Crypto.Protocol.KDF import HKDF
from Crypto.PublicKey import RSA, ECC

class ECElGamalKEM:
    """ElGamal-style KEM over an elliptic curve (ECDH with an ephemeral key).

    The ciphertext is the ephemeral public point R = r*G, serialized as
    big-endian ``x || y`` with each coordinate padded to the curve's
    coordinate width. The shared secret is HKDF-SHA256 (empty salt) of the
    x-coordinate of r*PK, which the holder of sk recomputes as sk*R.
    """

    def __init__(self, curve = "P-256", shared_secret_len = 16):
        self.curve = curve                           # curve name passed to ECC.generate / ECC.EccPoint
        self.shared_secret_len = shared_secret_len   # HKDF output length in bytes

    @staticmethod
    def _encode_coords(x, y, size):
        """Serialize affine coordinates as big-endian ``x || y``, each *size* bytes."""
        return int(x).to_bytes(size, "big") + int(y).to_bytes(size, "big")

    @staticmethod
    def _decode_coords(ct, size):
        """Split ``x || y`` (each *size* bytes, big-endian) back into two ints.

        Raises:
            ValueError: if *ct* is not exactly ``2 * size`` bytes long.
        """
        if len(ct) != 2 * size:
            raise ValueError(f"EC ciphertext must be {2 * size} bytes, got {len(ct)}")
        return int.from_bytes(ct[:size], "big"), int.from_bytes(ct[size:], "big")

    def keygen(self):
        """Generate a key pair on ``self.curve``, store it, and return (pk, sk)."""
        self.sk = ECC.generate(curve = self.curve)
        self.pk = self.sk.public_key()
        return self.pk, self.sk

    def encaps(self):
        """Return (ciphertext, shared_secret) for the stored public key ``self.pk``."""
        eph = ECC.generate(curve = self.curve)
        R = eph.public_key().pointQ
        # Coordinate width follows the curve instead of assuming 32 bytes,
        # so non-P-256 curves (e.g. P-384, P-521) serialize correctly.
        size = R.size_in_bytes()

        S = self.pk.pointQ * eph.d
        K = HKDF(int(S.x).to_bytes(size, "big"), self.shared_secret_len, b"", SHA256)
        return self._encode_coords(R.x, R.y, size), K

    def decaps(self, ct):
        """Recover the shared secret from *ct* using ``self.sk``.

        Raises:
            ValueError: if *ct* has the wrong length, does not encode a curve
                point, or encodes the point at infinity.
        """
        size = self.sk.pointQ.size_in_bytes()
        x, y = self._decode_coords(ct, size)
        # NOTE(review): pycryptodome's EccPoint is expected to raise ValueError
        # for coordinates that are not on the curve — confirm for the pinned version.
        R = ECC.EccPoint(x, y, curve = self.curve)
        # (0, 0) is pycryptodome's encoding of the point at infinity. Accepting
        # it would make sk*R the identity and the derived secret a known constant.
        if R.is_point_at_infinity():
            raise ValueError("EC ciphertext encodes the point at infinity")

        S = R * self.sk.d
        return HKDF(int(S.x).to_bytes(size, "big"), self.shared_secret_len, b"", SHA256)

class RSAKEM:
    """Key encapsulation via RSA-OAEP with SHA-256.

    ``encaps`` draws a fresh random secret and encrypts it to the stored
    public key. ``decaps`` recovers it with the stored private key.
    """

    def __init__(self, module_length = 3072, shared_secret_len = 16):
        # RSA modulus size in bits, and byte length of each encapsulated secret.
        self.module_length, self.shared_secret_len = module_length, shared_secret_len

    @staticmethod
    def _oaep(key):
        """Build a PKCS#1 OAEP cipher object for *key* using SHA-256."""
        return PKCS1_OAEP.new(key, hashAlgo = SHA256)

    def keygen(self):
        """Generate an RSA key pair, store it on the instance, and return (pk, sk)."""
        private_key = RSA.generate(self.module_length)
        self.sk = private_key
        self.pk = private_key.publickey()
        return self.pk, self.sk

    def encaps(self):
        """Return (ciphertext, secret) for a new random secret under ``self.pk``."""
        secret = os.urandom(self.shared_secret_len)
        ciphertext = self._oaep(self.pk).encrypt(secret)
        return ciphertext, secret

    def decaps(self, ct):
        """Decrypt *ct* with ``self.sk`` and return the encapsulated secret."""
        return self._oaep(self.sk).decrypt(ct)

class LatticeKEM:
    """Toy LWE/Regev-style KEM over Z_q.

    keygen: expand an n x n matrix A from a random seed, sample small s and e
    (coefficients in [-B, B]), and publish (seed, t = A.s + e mod q).
    encaps: hide the bits of a random 64-byte value z in v, and return
    (pickle((u, v)), HKDF-SHA256(z)).
    decaps: recover the bits from w = v - s.u and return HKDF-SHA256(z).

    NOTE(review): the construction is weak as written. Fixing it would change
    the key/ciphertext format, so it is flagged here rather than changed:
      * t.dot(r) is a single scalar added to every component of v. So
        v_i - v_j = (q//2)*(m_i - m_j) mod q, and z can be read from the
        ciphertext alone, up to a global bit flip.
      * A's entries come from only 2 SHAKE bytes, so they lie in [0, 2**16)
        instead of being uniform mod q. A.s + e is then usually far smaller
        than q, which likely makes s recoverable by plain linear algebra.
    """

    # CSPRNG for secrets and noise. The module-level `random` functions use a
    # predictable Mersenne Twister and must not drive key material.
    _rng = random.SystemRandom()

    def __init__(self, n = 512, q = 33556993, B = 2, shared_secret_len = 16):
        self.n = n                                   # dimension of s, e, r and A
        self.q = q                                   # modulus
        self.B = B                                   # small coefficients are drawn from [-B, B]
        self.shared_secret_len = shared_secret_len   # HKDF output bytes; also the seed length in keygen

    def _sample_vector(self):
        """Return n coefficients drawn uniformly from [-B, B] as an int64 array."""
        return np.array([self._rng.randint(-self.B, self.B) for _ in range(self.n)], dtype = np.int64)

    def _expand_matrix(self, seed):
        """Deterministically expand the n x n public matrix A from *seed*.

        Entry (i, j) is the first 2 bytes of SHAKE256(seed || i || j), read
        big-endian and reduced mod q. i and j are encoded as 2-byte big-endian
        integers, so n must be below 65536.
        """
        n, q = self.n, self.q
        # int64 explicitly: dtype=int is 32-bit on some platforms, and the later
        # products t.dot(r) and s.dot(u) can exceed 2**31.
        return np.array([
            [int.from_bytes(SHAKE256.new(data = seed + i.to_bytes(2, "big") + j.to_bytes(2, "big")).read(2), "big") % q
             for j in range(n)]
            for i in range(n)
        ], dtype = np.int64)

    @staticmethod
    def _bytes_to_bits(data):
        """Expand bytes into a 0/1 int64 array, most-significant bit first."""
        return np.unpackbits(np.frombuffer(data, dtype = np.uint8)).astype(np.int64)

    @staticmethod
    def _bits_to_bytes(bits):
        """Pack a 0/1 sequence (MSB first) into bytes; inverse of _bytes_to_bits.

        A trailing partial group, if any, is zero-padded on the right.
        """
        return np.packbits(np.asarray(bits, dtype = np.uint8)).tobytes()

    def _decode_bits(self, u, v):
        """Recover the message bits from (u, v) using the secret key ``self.sk``.

        Computes w = v - s.u mod q, maps w into the centered range around 0,
        and decodes a coefficient as 1 when |w| > q//4, otherwise 0.
        """
        q = self.q
        s = np.asarray(self.sk, dtype = np.int64)
        w = (np.asarray(v, dtype = np.int64) - s.dot(np.asarray(u, dtype = np.int64))) % q
        w = np.where(w < q // 2, w, w - q)
        return (np.abs(w) > q // 4).astype(np.int64)

    def keygen(self):
        """Generate (pk, sk) with pk = (seed, t) and sk = s; store both on the instance."""
        seed = os.urandom(self.shared_secret_len)
        A = self._expand_matrix(seed)
        s = self._sample_vector()
        e = self._sample_vector()
        t = (A.dot(s) + e) % self.q
        self.pk = (seed, t)
        self.sk = s
        return self.pk, self.sk

    def encaps(self):
        """Return (ciphertext, shared_secret) for the stored public key ``self.pk``."""
        z = os.urandom(64)
        K = HKDF(z, self.shared_secret_len, b"", SHA256)
        m = self._bytes_to_bits(z)               # 512 message bits, MSB first per byte

        seed, t = self.pk
        A = self._expand_matrix(seed)
        t = np.asarray(t, dtype = np.int64)
        r = self._sample_vector()
        e1 = self._sample_vector()
        e2 = self._rng.randint(-self.B, self.B)

        u = (A.T.dot(r) + e1) % self.q
        # t.dot(r) is a scalar broadcast over all message bits (see class NOTE).
        v = (t.dot(r) + e2 + (self.q // 2) * m) % self.q
        return pickle.dumps((u, v)), K

    def decaps(self, ct):
        """Recover the shared secret from *ct* using ``self.sk``."""
        # SECURITY: pickle.loads can execute arbitrary code embedded in a crafted
        # payload. Never call this on ciphertexts from an untrusted source; a
        # fixed binary encoding of (u, v) would be required to make it safe.
        u, v = pickle.loads(ct)
        z = self._bits_to_bytes(self._decode_bits(u, v))
        return HKDF(z, self.shared_secret_len, b"", SHA256)

class HybridKEM:
    """Combine the EC-ElGamal, lattice and RSA-OAEP KEMs into one KEM.

    The ciphertext is three length-prefixed frames (4-byte big-endian length,
    then the component ciphertext) in the order EC, lattice, RSA. The shared
    secret is the concatenation K0 || K1 || K2 of the component secrets, in
    the same order.
    """

    def __init__(self):
        self.eckem = ECElGamalKEM()
        self.latticekem = LatticeKEM()
        self.rsakem = RSAKEM()

    def keygen(self):
        """Generate all three key pairs; return (pk, sk) as (ec, lattice, rsa) tuples."""
        pk_ec, sk_ec = self.eckem.keygen()
        pk_lat, sk_lat = self.latticekem.keygen()
        pk_rsa, sk_rsa = self.rsakem.keygen()

        self.pk = (pk_ec, pk_lat, pk_rsa)
        self.sk = (sk_ec, sk_lat, sk_rsa)
        return self.pk, self.sk

    def encaps(self):
        """Encapsulate under all three KEMs; return (framed ciphertext, K0 || K1 || K2)."""
        ct0, K0 = self.eckem.encaps()
        ct1, K1 = self.latticekem.encaps()
        ct2, K2 = self.rsakem.encaps()

        ct = b"".join(len(part).to_bytes(4, "big") + part for part in (ct0, ct1, ct2))
        return ct, (K0 + K1 + K2)

    @staticmethod
    def _split_frames(data, count):
        """Parse exactly *count* length-prefixed frames from *data*.

        Raises:
            ValueError: if a length prefix or frame body is truncated, or if
                bytes remain after the last frame.
        """
        frames = []
        offset = 0
        for _ in range(count):
            if len(data) - offset < 4:
                raise ValueError("truncated hybrid ciphertext: missing length prefix")
            size = int.from_bytes(data[offset:offset + 4], "big")
            offset += 4
            if len(data) - offset < size:
                raise ValueError("truncated hybrid ciphertext: frame shorter than its length prefix")
            frames.append(data[offset:offset + size])
            offset += size
        if offset != len(data):
            raise ValueError("trailing bytes after hybrid ciphertext")
        return frames

    def decaps(self, ct):
        """Split *ct* into its three frames and return K0 || K1 || K2.

        Raises:
            ValueError: if *ct* is not exactly three well-formed frames. Errors
                from the component KEMs propagate unchanged.
        """
        ct0, ct1, ct2 = self._split_frames(ct, 3)

        K0 = self.eckem.decaps(ct0)
        K1 = self.latticekem.decaps(ct1)
        K2 = self.rsakem.decaps(ct2)

        return K0 + K1 + K2

class Macroplata:
    """CMAC-style MAC using AES in ECB mode, with caller-supplied subkeys.

    Message blocks are chained CBC-MAC style through AES keyed with K0. Before
    the last encryption, the final block is handled as follows:
      * a complete 16-byte final block is XORed with K1;
      * a partial (or empty) final block is padded with 0x80 then zeros and
        XORed with K2.
    Unlike standard CMAC, K1 and K2 are supplied independently instead of
    being derived from AES_K0(0).

    The object is stateful: ``update`` may be called repeatedly before
    ``finalize``. ``compute`` resets the state and MACs a whole message.
    """

    BLOCK_SIZE = 16  # AES block size in bytes; class-level so every method (incl. _pad) can read it

    def __init__(self, K0, K1, K2):
        # Check subkeys up front: _xor's zip would otherwise silently truncate
        # a wrong-length subkey, and the error would only surface later.
        if len(K1) != self.BLOCK_SIZE or len(K2) != self.BLOCK_SIZE:
            raise ValueError(f"K1 and K2 must each be {self.BLOCK_SIZE} bytes")
        self._cipher = AES.new(K0, AES.MODE_ECB)
        self._K1 = K1
        self._K2 = K2
        self._buffer = b""                        # unprocessed tail; at most one block after each update
        self._state = b"\x00" * self.BLOCK_SIZE   # CBC chaining value

    def _xor(self, a, b):
        """XOR two byte strings, truncating to the shorter one."""
        return bytes(x ^ y for x, y in zip(a, b))

    def _pad(self, block):
        """Pad a partial block (fewer than BLOCK_SIZE bytes) with 0x80 then zero bytes."""
        pad_len = self.BLOCK_SIZE - len(block)
        return block + b"\x80" + b"\x00" * (pad_len - 1)

    def update(self, data):
        """Absorb *data*, encrypting every complete block except the last one.

        The final block (full or partial) stays buffered, because ``finalize``
        must know whether it is complete in order to choose K1 or K2.
        """
        buf = self._buffer + data
        # Largest multiple of BLOCK_SIZE that still leaves 1..BLOCK_SIZE bytes
        # buffered (0 when buf is empty).
        cut = max(len(buf) - 1, 0) // self.BLOCK_SIZE * self.BLOCK_SIZE
        state = self._state
        for off in range(0, cut, self.BLOCK_SIZE):
            state = self._cipher.encrypt(self._xor(state, buf[off:off + self.BLOCK_SIZE]))
        self._state = state
        self._buffer = buf[cut:]

    def finalize(self):
        """Process the buffered final block and return the 16-byte tag.

        Does not reset the internal state; use ``compute`` for a fresh MAC.
        """
        if len(self._buffer) == self.BLOCK_SIZE:
            last = self._xor(self._buffer, self._K1)
        else:
            last = self._xor(self._pad(self._buffer), self._K2)

        return self._cipher.encrypt(self._xor(self._state, last))

    def compute(self, data):
        """Reset the state and return the tag of *data* as a single message."""
        self._buffer = b""
        self._state = b"\x00" * self.BLOCK_SIZE
        self.update(data)
        return self.finalize()

# Challenge driver: derive a fresh hybrid key, publish the public key and the
# ciphertext, and show one message/tag pair under Macroplata. The user wins by
# presenting a *different* message together with a valid tag for it.
HKEM = HybridKEM()
pk, sk = HKEM.keygen()
ct, K = HKEM.encaps()

# Publish public keys: EC and RSA as DER; the lattice key (seed, t) as-is.
pk_ec, pk_lat, pk_rsa = pk
pk = pk_ec.export_key(format = "DER"), pk_lat, pk_rsa.export_key(format = "DER")
print(base64.b64encode(pickle.dumps(pk)).decode())
print(base64.b64encode(ct).decode())

# K is K0 || K1 || K2 from the EC, lattice and RSA KEMs (16 bytes each by
# default). K0 keys AES; K1 and K2 are the final-block subkeys.
Tenuiceps = Macroplata(K[:16], K[16:32], K[32:48])
truth = b"Macroplata is the best pliosaur!"
tag = Tenuiceps.compute(truth)
print(truth.decode())
print(base64.b64encode(tag).decode())

try:
    theirtruth = base64.b64decode(input("Now, it is your turn! Your truth:"))
    theirtag = base64.b64decode(input("Your tag:"))
except (ValueError, EOFError):
    # binascii.Error (malformed base64) subclasses ValueError; EOFError covers closed stdin.
    print("Please check your inputs.")
    sys.exit(1)

# Constant-time comparison so response timing does not reveal how many tag bytes matched.
if theirtruth != truth and hmac.compare_digest(theirtag, Tenuiceps.compute(theirtruth)):
    with open("flag.txt") as flag_file:
        flag = flag_file.read().strip()
    print(flag)
else:
    print("Not the truth!")
