Writeup by xtea418 for Bugs Buggy Easy

crypto RSA

February 17, 2025

We are given the encrypted flag and a signature oracle:


from Crypto.PublicKey import RSA

def sign(sk, m):
    """Return the CRT-style RSA signature of the integer ``m`` under ``sk``.

    ``sk`` must expose the attributes ``p``, ``q``, ``d``, ``u`` and ``n``.
    Note that each half reduces ``m`` modulo the *other* prime before
    exponentiating: the residue mod ``p`` is raised to a power mod ``q``,
    and the residue mod ``q`` is raised to a power mod ``p``.
    """
    p, q = sk.p, sk.q

    # Half computed modulo q, fed with m reduced modulo p.
    sig_q = pow(m % p, sk.d % (q - 1), q)
    # Half computed modulo p, fed with m reduced modulo q.
    sig_p = pow(m % q, sk.d % (p - 1), p)

    # Recombine the two halves using the key's coefficient u.
    lift = (sk.u * (sig_q - sig_p)) % q
    return (sig_p + lift * p) % sk.n

if __name__ == "__main__":
    # Load the server's private RSA key.
    with open("sk.pem", "r") as fp:
        sk = RSA.import_key(fp.read())

    # Read the flag and interpret its bytes as a big-endian integer.
    with open("flag.txt", "rb") as fp:
        flag = int.from_bytes(fp.read().strip(), "big")

    # Textbook RSA encryption of the flag under the public exponent.
    c = pow(flag, sk.e, sk.n)
    print(f"Encrypted flag: {c}")

    # Signature oracle: sign every integer the client sends, forever.
    while True:
        print("What do you want me to sign?")
        m = int(input(">>> "))
        print(sign(sk, m))

The sign function here is not secure: each CRT branch reduces m modulo the wrong prime, so whenever m ≥ p and/or m ≥ q it returns a faulty signature. If this weren’t the case, we could simply call sign(c) to get the flag.

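First we recover the modulus. Small messages are signed correctly and textbook RSA signatures are multiplicative, so sign(2)^2 - sign(4) and sign(3)^2 - sign(9) are both multiples of N, and their gcd is a small multiple of N: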

# n_approx is N*k where k is a small value
square_diff_2 = sign(2) ** 2 - sign(4)
square_diff_3 = sign(3) ** 2 - sign(9)
n_approx = gcd(square_diff_2, square_diff_3)
# Partial factorisation up to the limit; the base of the last factor is kept as n.
n = n_approx.factor(limit=0x100000)[-1][0]

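Then we want to find a value m such that p ≤ m < q (assuming q > p). For such an m the signature is still correct modulo p but wrong modulo q, so pow(sign(m), e, N) - m is divisible by p but not by q, and gcd(pow(sign(m), e, N) - m, N) returns p. If instead m is smaller than p, the signature is valid and the difference is 0; if m is greater than or equal to q, the signature is wrong modulo both primes and gcd(pow(sign(m), e, N) - m, N) won’t return anything useful.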

The easiest way I could think of to find such an m was to use binary search. There are probably better ways, but this worked:

def binary_search_for_fault(N):
    """Recover a prime factor of ``N`` by binary-searching the sign oracle.

    Relies on the surrounding ``sign`` oracle, the public exponent ``e``
    and ``gcd`` being in scope.

    Each probe ``m`` falls into one of three cases:

    * the signature verifies (``sign(m) ** e == m`` modulo ``N``): ``m`` is
      still below the faulty range, so the search moves up;
    * the signature is faulty and ``gcd(sign(m) ** e - m, N)`` is a proper
      divisor of ``N``: that divisor is returned;
    * the signature is faulty but the gcd is 1: ``m`` is too large, so the
      search moves down.

    Raises:
        ValueError: if the search interval is exhausted without a factor.
    """
    low, high = 0, N - 1
    while low <= high:
        m = (low + high) // 2
        # Query the signing oracle only once per probe.
        verified = pow(sign(m), e, N)

        if verified == m:
            # Valid signature: the faulty range starts above m.
            low = m + 1
            continue

        p_cand = gcd((verified - m) % N, N)
        if p_cand != 1 and p_cand != N:
            return p_cand
        # Faulty, but the difference shares no factor with N: m overshot.
        high = m - 1

    raise ValueError("failed to find p")

Then once we have the factors of N, decrypting c is trivial:

p = binary_search_for_fault(n)
# p divides n, so use exact integer division; `/` is true division and
# would not yield an integer type.
q = n // p

# Rebuild the private exponent from the factorisation of n.
phi = (p - 1) * (q - 1)
d = pow(e, -1, int(phi))

# `flag` holds the encrypted flag; decrypt it and convert the result to bytes.
print(long_to_bytes(pow(int(flag), int(d), int(n)))) # b'FCSC{78ef932f3e1f42a1b4cd25674082fb906bc70f7b1072415268f76c0df4cf7527}'

Full script:

from Crypto.Util.number import long_to_bytes
from pwn import remote
from sage.all import *

# Connect to the challenge service.
rem = remote("localhost", 4000)

# Keep what follows the last ":" on the first line received and parse it
# as an integer (the encrypted flag).
first_line = rem.recvline()
flag = ZZ(first_line.split(b":")[-1].strip().decode())

def sign(n: ZZ) -> ZZ:
    """Send ``n`` to the remote oracle and return its one-line reply as an integer."""
    payload = str(n).encode()
    # Wait for the prompt, then submit the number to be signed.
    rem.sendlineafter(b">>>", payload)
    reply = rem.recvline()
    return ZZ(reply.decode())

# recover modulus
# n_approx is N*k where k is a small value
square_diff_2 = sign(2) ** 2 - sign(4)
square_diff_3 = sign(3) ** 2 - sign(9)
n_approx = gcd(square_diff_2, square_diff_3)
# Partial factorisation up to the limit; the base of the last factor is kept as n.
n = n_approx.factor(limit=0x10001)[-1][0]

# Sanity check: signing 2 must verify under the recovered modulus.
e = 0x10001
assert pow(sign(2), e, n) == 2

def binary_search_for_fault(N):
    """Recover a prime factor of ``N`` by binary-searching the sign oracle.

    Relies on the module-level ``sign`` oracle, the public exponent ``e``
    and ``gcd`` being in scope.

    Each probe ``m`` falls into one of three cases:

    * the signature verifies (``sign(m) ** e == m`` modulo ``N``): ``m`` is
      still below the faulty range, so the search moves up;
    * the signature is faulty and ``gcd(sign(m) ** e - m, N)`` is a proper
      divisor of ``N``: that divisor is returned;
    * the signature is faulty but the gcd is 1: ``m`` is too large, so the
      search moves down.

    Raises:
        ValueError: if the search interval is exhausted without a factor.
    """
    low, high = 0, N - 1
    while low <= high:
        m = (low + high) // 2
        # Query the signing oracle only once per probe.
        verified = pow(sign(m), e, N)

        if verified == m:
            # Valid signature: the faulty range starts above m.
            low = m + 1
            continue

        p_cand = gcd((verified - m) % N, N)
        if p_cand != 1 and p_cand != N:
            return p_cand
        # Faulty, but the difference shares no factor with N: m overshot.
        high = m - 1

    raise ValueError("failed to find p")

# recover p, from signature where m > p and m < q
p = binary_search_for_fault(n)
# p divides n, so use exact integer division; `/` is true division and
# would not yield an integer type.
q = n // p

# Rebuild the private exponent from the factorisation of n.
phi = (p - 1) * (q - 1)
d = pow(e, -1, int(phi))

# `flag` holds the encrypted flag; decrypt it and convert the result to bytes.
print(long_to_bytes(pow(int(flag), int(d), int(n)))) # b'FCSC{78ef932f3e1f42a1b4cd25674082fb906bc70f7b1072415268f76c0df4cf7527}'