# Python regular imports
import gmpy2
from Crypto import Random
from Crypto.PublicKey import RSA
from Crypto.Hash import SHA256
from Crypto.Signature import pss
from Crypto.Signature.pss import _EMSA_PSS_ENCODE, MGF1
from Crypto.Util.number import bytes_to_long, long_to_bytes

# Public files (challenge specific)
from machine import Machine
from machine_faulted import FaultedMachine
from assembly import assembly

def sign(machine, msg, key):
    """Run the loaded signing program on ``machine`` and return its result.

    The machine is reset, then its input registers R5..RA are filled with
    ``msg`` and private-key material from ``key``, and the program is run.

    Args:
        machine: a (possibly faulted) machine with the signing program loaded.
        msg: value placed in R5 (the caller passes the PSS-encoded message as
            an integer).
        key: RSA private key exposing ``q``, ``p``, ``u`` and ``d``.

    Returns:
        The content of register R0 after execution, or ``0`` when the machine
        sets its ``error`` flag.
    """
    machine.reset()

    # (register name, value) pairs, assigned in this exact order.
    register_inputs = (
        ("R5", msg),
        ("R6", key.q),
        ("R7", key.p),
        ("R8", key.u),
        # Private exponent reduced modulo q-1 and p-1 respectively.
        ("R9", int(key.d % (key.q - 1))),
        ("RA", int(key.d % (key.p - 1))),
    )
    for register, value in register_inputs:
        setattr(machine, register, value)

    machine.runCode()
    return 0 if machine.error else machine.R0

def PSS_encoding(h, bitsize):
    """Return the EMSA-PSS encoding of the hash object ``h`` as an integer.

    Uses PyCryptodome's private ``_EMSA_PSS_ENCODE`` helper with MGF1 over
    SHA-256 as mask generation function, a salt as long as the digest, and
    ``Random.get_random_bytes`` as the salt source.

    Args:
        h: hash object of the message to encode.
        bitsize: bit length of the RSA modulus; the encoding is produced for
            ``bitsize - 1`` bits.

    Returns:
        The encoded message, converted with ``bytes_to_long``.
    """
    def mask_generator(seed, length):
        # MGF1 instantiated with SHA-256.
        return MGF1(seed, length, SHA256)

    encoded = _EMSA_PSS_ENCODE(
        h,
        bitsize - 1,
        Random.get_random_bytes,
        mask_generator,
        h.digest_size,
    )
    return bytes_to_long(encoded)

def verify(h, s, key):
    """Check an RSA-PSS signature ``s`` over the hash object ``h``.

    Args:
        h: hash object of the signed message.
        s: signature as a byte string.
        key: RSA key; only its public part (``key.publickey()``) is used.

    Returns:
        True if PSS verification succeeds, False if it rejects the signature.
    """
    verifier = pss.new(key.publickey())
    try:
        verifier.verify(h, s)
    except (ValueError, TypeError):
        # PyCryptodome signals an invalid signature with ValueError; TypeError
        # covers malformed arguments. Unlike the former bare ``except``, this
        # no longer swallows KeyboardInterrupt/SystemExit or unrelated bugs.
        return False
    return True

if __name__ == "__main__":
    try:
        # Key generation
        e = 2 ** 16 + 1
        bitsize = 1024
        print("Generating RSA Key...")
        key = RSA.generate(bitsize, e=e)
        print("Here is the public modulus:")
        print(f"n = {int(key.n)}")

        # Get user input: whitespace-separated instruction indices to fault.
        print("Enter your list of faulted instructions (for instance 0 1 10000):")
        fault_line = input(">>> ")
        faults = {int(x) for x in fault_line.split()}

        print("Enter your message to be signed:")
        message = input(">>> ")
        print(f"{message = }")

        # Initialize: assemble the signing program and load it into a machine
        # that injects the requested faults. ``with`` closes the file promptly.
        with open("rsa-sign.asm", encoding="utf-8") as asm_file:
            asm_lines = asm_file.read().splitlines()
        code = assembly(asm_lines)
        machine = FaultedMachine(code, faults)

        h = SHA256.new(message.encode("utf-8"))
        EM = PSS_encoding(h, bitsize)
        s = sign(machine, EM, key)
        print("Here is the signature:")
        print(f"s = {int(s)}")

        print(f"Is the signature correct? {verify(h, long_to_bytes(s, bitsize >> 3), key)}")

        # Check key recovery
        print("Give a prime factor of the public modulus")
        potential_factor = int(input(">>> "))
        if potential_factor in (key.p, key.q):
            with open("flag.txt", encoding="utf-8") as flag_file:
                flag = flag_file.read().strip()
            print("Congrats! Here is the flag:")
            print(flag)
        else:
            print("Nope!")
    except Exception:
        # Malformed input (non-integer indices/factor, EOF, ...) ends the
        # session with a generic message. Catching Exception rather than using
        # a bare ``except`` lets Ctrl-C and sys.exit terminate normally.
        print("Please check your inputs.")
