import os
from base64 import b64encode
from zlib import compress
from pickle import dumps
import time

# https://github.com/gmh5225/slh-dsa-py/blob/main/slh_dsa.py
from slh_dsa import SLH_DSA

# Custom, deliberately small SLH-DSA parameter set (paramid "FCSC").
# NOTE(review): the keyword names look like the FIPS 205 parameter table
# (n, h, d, h', a, k, lg_w, m) -- confirm against the slh_dsa module.
# With these values h == d * hp (10 == 5 * 2), and 2 ** h bounds how many
# signatures the session below is willing to hand out.
_slh_params = {
    "hashname": "SHAKE",
    "paramid": "FCSC",
    "n": 16,
    "h": 10,
    "d": 5,
    "hp": 2,
    "a": 4,
    "k": 10,
    "lg_w": 4,
    "m": 12,
}
# os.urandom is passed as the library's randomness source ("rbg").
HashingGoesBrrr = SLH_DSA(**_slh_params, rbg=os.urandom)
del _slh_params  # temporary; keep the module namespace unchanged

# Generate this session's key pair; only the public half is disclosed.
pk, sk = HashingGoesBrrr.keygen()
print(pk.hex())

try:
    # Signing oracle: hand out between 1 and 2 ** h signatures. Message i is
    # the 2-byte big-endian encoding of i followed by b"FCSC".
    N = int(input("How many signatures do you need? "))
    if 1 <= N <= 2 ** HashingGoesBrrr.h:
        for i in range(N):
            # Pass byteorder explicitly: int.to_bytes only has a default
            # ("big") from Python 3.11 on; older versions raise TypeError.
            msg = i.to_bytes(2, "big") + b"FCSC"
            sig = HashingGoesBrrr.slh_sign(msg, sk)
            # pickle is used only to *serialize* our own signature object for
            # output (pickle -> zlib -> base64); no user data is unpickled.
            print(b64encode(compress(dumps(sig))).decode())
    else:
        # Note: execution still continues to the signature prompt below.
        print("Invalid number.")

    # The timer covers only the wait for this input line, so it starts right
    # after the signatures above have been printed.
    start = time.perf_counter()
    your_signature = bytes.fromhex(input("Your signature: "))
    end = time.perf_counter()

    if end - start > 60:
        print("Hash faster next time. Brrr brrrrrrr.")
    elif HashingGoesBrrr.slh_verify(b"hashinggoesbrrr", your_signature, pk):
        # Close the flag file deterministically instead of leaking the handle.
        with open("flag.txt", "rb") as flag_file:
            flag = flag_file.read().strip()
        print(flag)
    else:
        print("Better luck next time!")
# Catch Exception rather than using a bare except, so KeyboardInterrupt and
# SystemExit still propagate. Malformed input (e.g. ValueError from int() or
# bytes.fromhex(), EOFError from input()) ends up here.
except Exception:
    print("Brrrease check your inputs.")
