# Python regular imports
import sys
from random import randrange

import gmpy2
from Crypto.PublicKey import RSA

# Public files (challenge specific)
from machine import Machine
from machine_faulted import FaultedMachine
from assembly import assembly

def sign(machine, msg, n, d):
    """Run the machine's loaded program on (msg, n, d) and return its result.

    The machine is reset, the three operands are placed in registers
    R5 (msg), R6 (n) and R7 (d), and the program is executed.

    Returns:
        0 when the machine reports an error after running, otherwise the
        value left in register R0 (presumably the signature of ``msg`` --
        the actual computation is defined by the program loaded in
        ``machine``).
    """
    machine.reset()

    # Load the operands into the input registers, in the same order as the
    # arguments: R5 <- msg, R6 <- n, R7 <- d.
    for register, value in (("R5", msg), ("R6", n), ("R7", d)):
        setattr(machine, register, value)

    machine.runCode()

    # R0 is only read when the run completed without an error.
    return 0 if machine.error else machine.R0

def challenge(choice):
    """Run one interactive round of the fault-injection challenge.

    A fresh 1024-bit RSA key is generated and its modulus printed. The user
    supplies a whitespace-separated list of integers (at most one distinct
    value is accepted) that is handed to ``FaultedMachine`` as the set of
    faults. A random message is then signed on that machine, the message and
    signature are printed, and the user is asked for the private exponent.
    The guess ``d'`` is accepted when ``msg ** (e * d' - 1) == 1 (mod n)``
    with ``e = 65537``; on success the flag is printed and the process exits,
    otherwise "Nope!" is printed and the function returns.

    Args:
        choice: 1 selects ``RSA_sign.asm`` / ``flag.txt``; 2 selects
            ``RSA_sign_masked.asm`` / ``flag_masked.txt``.

    Raises:
        ValueError: if ``choice`` is not 1 or 2, or if the user input cannot
            be parsed as integers.
        SystemExit: when more than one fault is supplied, or after the flag
            has been printed.
    """
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if choice not in (1, 2):
        raise ValueError(f"choice must be 1 or 2, got {choice!r}")

    e = 2 ** 16 + 1
    bitsize = 1024
    print("Generating RSA Key...")
    k = RSA.generate(bitsize)
    print("Here is the public modulus:")
    print(f"n = {int(k.n)}")

    # Get user input
    print("Enter your list of faulted instructions (for instance 0 1 10000):")
    L = input(">>> ")
    # A set, so repeating the same value still counts as a single fault.
    faults = {int(x) for x in L.split()}
    if len(faults) > 1:
        print("With only one fault please")
        sys.exit()

    # Initialize
    asm_path = "RSA_sign.asm" if choice == 1 else "RSA_sign_masked.asm"
    with open(asm_path) as asm_file:
        code = asm_file.read().splitlines()

    code = assembly(code)
    machine = FaultedMachine(code, faults)

    msg = randrange(k.n)
    s = sign(machine, msg, k.n, k.d)
    print("Here is the message:")
    print(f"msg = {int(msg)}")
    print("Here is the signature:")
    print(f"s = {int(s)}")

    # Check key recovery
    print("What was the private exponent?")
    potential_d = int(input(">>> "))
    # Accept any exponent d' such that msg^(e*d' - 1) == 1 (mod n).
    if gmpy2.powmod(msg, e * potential_d - 1, k.n) != 1:
        print("Nope!")
        return

    # The flag file is only opened once the guess has been validated.
    flag_path = "flag.txt" if choice == 1 else "flag_masked.txt"
    with open(flag_path) as flag_file:
        flag = flag_file.read().strip()
    print("[+] Congrats! Here is the flag:")
    print(flag)
    sys.exit()

if __name__ == "__main__":
    # Interactive menu: keep offering challenges until the user quits, a
    # challenge terminates the process, or an error occurs.
    try:
        while True:
            print("Which flag do you want to grab?")
            print("  0. Quit.")
            print("  1. Hard flag 1/2 - Private exponent is not protected.")
            print("  2. Hardest flag 2/2 - Private exponent is split thanks to a euclidean division.")
            choice = int(input(">>> "))

            if choice == 0:
                # Leave the loop and let the script end normally.
                break
            if choice in (1, 2):
                challenge(choice)
            # Any other number is ignored and the menu is shown again.
    except Exception:
        # Top-level boundary: any failure (bad integer, closed stdin, ...)
        # ends the session with a generic message. SystemExit and
        # KeyboardInterrupt are deliberately NOT caught here, so a normal
        # exit no longer prints this message.
        print("Please check your inputs.")
