We are given the encrypted flag and access to a signing oracle. The challenge source:
#!/usr/local/bin/python3.8
from Crypto.PublicKey import RSA
def sign(sk, m: int) -> int:
    """RSA-CRT signature of ``m`` exactly as written in the challenge (faulty).

    Note the crossed residues: ``mp`` is ``m`` reduced mod **q** and ``mq`` is
    ``m`` reduced mod **p**, yet ``mq`` is exponentiated modulo q and ``mp``
    modulo p. Each CRT half is therefore only right when the residue taken
    modulo the "other" prime still equals ``m``. For 0 <= m < n, the result
    equals ``m**d mod n`` only when m < p and m < q.
    """
    mp, dq = m % sk.q, sk.d % (sk.q - 1)
    mq, dp = m % sk.p, sk.d % (sk.p - 1)
    # Half modulo q, but computed from m mod p: only correct while m < p.
    s1 = pow(mq, dq, sk.q)
    # Half modulo p, but computed from m mod q: only correct while m < q.
    s2 = pow(mp, dp, sk.p)
    # Garner recombination; assumes sk.u == p^-1 mod q (PyCryptodome's
    # RsaKey.u convention) -- TODO confirm for this key.
    h = (sk.u * (s1 - s2)) % sk.q
    s = (s2 + h * sk.p) % sk.n
    return s
if __name__ == "__main__":
    # Load the server's RSA private key.
    with open("sk.pem", "r") as fp:
        sk = RSA.import_key(fp.read())
    # Read the flag and interpret its bytes as a big-endian integer.
    with open("flag.txt", "rb") as fp:
        flag = int.from_bytes(fp.read().strip(), "big")
    # Unpadded RSA encryption of the flag; only the ciphertext is printed.
    c = pow(flag, sk.e, sk.n)
    print(f"Encrypted flag: {c}")
    # Signing oracle: sign any integer the client sends, with no query limit.
    while True:
        print("What do you want me to sign?")
        try:
            m = int(input(">>> "))
            print(sign(sk, m))
        except:
            # Bare except: any error (non-integer input, EOF, ...) ends the session.
            break
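The `sign` function here is not secure. It reduces m modulo the wrong prime in each CRT half: `m % p` is exponentiated modulo q, and `m % q` is exponentiated modulo p. For 0 ≤ m < N the signature is therefore only correct when m < p and m < q. If m ≥ p and/or m ≥ q, it returns faulty signatures.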
If this weren’t the case, we could simply compute sign(c) (i.e. c^d mod N) to get the flag.
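First we recover the modulus N. Small inputs are signed correctly, so sign(a)^2 - sign(a^2) is a multiple of N. The gcd of two such differences is k·N for some small k, which trial division then removes: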
# For small a the oracle signs correctly, so sign(a)**2 - sign(a**2) is a
# multiple of N; the gcd of two such differences is N*k for a small k.
diff_2 = sign(2) ** 2 - sign(4)
diff_3 = sign(3) ** 2 - sign(9)
n_approx = gcd(diff_2, diff_3)  # n_approx is N*k where k is a small value
# Trial division up to the limit strips k (assuming k has no prime factor
# above the limit), so the last -- largest -- factor is N.
*_, (n, _) = n_approx.factor(limit=0x100000)
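Then we want to find a value m such that p ≤ m < q (assuming p < q). For such an m, only the half computed modulo p is correct. So pow(sign(m), e, N) - m is divisible by p but not by q, and gcd(pow(sign(m), e, N) - m, N) = p. If instead m < p, the signature is correct and the difference is divisible by N. If m ≥ q, both halves are wrong and the gcd is 1. Neither of these cases gives anything useful.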
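The easiest way I could think of to find such an m was binary search. The range [0, N) splits into three contiguous regions: [0, p) has no fault, [p, q) gives gcd = p, and [q, N) gives gcd = 1. So we move up while the signature verifies and down while the gcd is trivial. There are probably better ways, but this worked: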
def binary_search_for_fault(N, oracle=None, exponent=None):
    """Find a prime factor of ``N`` by binary-searching the faulty signer.

    With p < q, candidate messages fall into three contiguous ranges:

    * ``m < p``      -- the signature verifies: ``s**e == m (mod N)``;
    * ``p <= m < q`` -- only the mod-p half is right, so
      ``gcd(s**e - m, N) == p``;
    * ``m >= q``     -- both halves are wrong and the gcd is 1.

    The search moves up while the signature verifies and down while the gcd
    is trivial, until it lands in the middle range.

    Args:
        N: the RSA modulus.
        oracle: signing function ``m -> s``; defaults to the global ``sign``.
        exponent: public exponent; defaults to the global ``e``.

    Returns:
        A non-trivial factor of ``N``.

    Raises:
        ValueError: if no signature revealing a factor is found.
    """
    if oracle is None:
        oracle = sign
    if exponent is None:
        exponent = e
    low, high = 0, N - 1
    while low <= high:
        m = (low + high) // 2
        # One oracle query per candidate: each call is a network round trip.
        recovered = pow(oracle(m), exponent, N)
        if recovered == m:
            # Correct signature: m is below both primes, search higher.
            low = m + 1
            continue
        p_cand = gcd((recovered - m) % N, N)
        if p_cand != 1 and p_cand != N:
            return p_cand
        # Trivial gcd: m is past the fault window, search lower.
        high = m - 1
    raise ValueError("failed to find p")
Once we have the factors of N, decrypting c is trivial:
p = binary_search_for_fault(n)
# Exact integer division: ``n / p`` yields a Rational in Sage and a lossy
# float in plain Python.
q = n // p
phi = (p - 1) * (q - 1)
d = pow(e, -1, int(phi))
# ``flag`` holds the ciphertext c read from the server, not the plaintext.
print(long_to_bytes(pow(int(flag), int(d), int(n)))) # b'FCSC{78ef932f3e1f42a1b4cd25674082fb906bc70f7b1072415268f76c0df4cf7527}'
Full script:
from Crypto.Util.number import long_to_bytes
from pwn import remote
from sage.all import *
rem = remote("localhost", 4000)
# The first server line is "Encrypted flag: <c>". Despite its name, ``flag``
# holds the ciphertext c, not the plaintext.
banner = rem.recvline()
flag = ZZ(banner.split(b":")[-1].strip().decode())
def sign(n: ZZ) -> ZZ:
    """Send ``n`` to the remote signing oracle and return its signature.

    Note: the parameter ``n`` is the message to sign, not the modulus.
    """
    rem.sendlineafter(b">>>", str(n).encode())
    return ZZ(rem.recvline().decode())
# recover modulus: small inputs are signed correctly, so sign(a)**2 - sign(a**2)
# is a multiple of N; the gcd of two such differences is N*k for a small k.
n_approx = gcd(sign(2) ** 2 - sign(4), sign(3) ** 2 - sign(9))
*_, (n, _) = n_approx.factor(limit=0x10001) # n_approx is N*k where k is a small value
print(f"{n=}")
e = 0x10001
# Sanity-check the recovered modulus with a known-correct signature. This is
# an explicit check rather than ``assert`` so it still runs under ``python -O``.
if pow(sign(2), e, n) != 2:
    raise ValueError(f"recovered modulus is wrong: {n}")
def binary_search_for_fault(N, oracle=None, exponent=None):
    """Find a prime factor of ``N`` by binary-searching the faulty signer.

    With p < q, candidate messages fall into three contiguous ranges:

    * ``m < p``      -- the signature verifies: ``s**e == m (mod N)``;
    * ``p <= m < q`` -- only the mod-p half is right, so
      ``gcd(s**e - m, N) == p``;
    * ``m >= q``     -- both halves are wrong and the gcd is 1.

    The search moves up while the signature verifies and down while the gcd
    is trivial, until it lands in the middle range.

    Args:
        N: the RSA modulus.
        oracle: signing function ``m -> s``; defaults to the global ``sign``.
        exponent: public exponent; defaults to the global ``e``.

    Returns:
        A non-trivial factor of ``N``.

    Raises:
        ValueError: if no signature revealing a factor is found.
    """
    if oracle is None:
        oracle = sign
    if exponent is None:
        exponent = e
    low, high = 0, N - 1
    while low <= high:
        m = (low + high) // 2
        # One oracle query per candidate: each call is a network round trip.
        recovered = pow(oracle(m), exponent, N)
        if recovered == m:
            # Correct signature: m is below both primes, search higher.
            low = m + 1
            continue
        p_cand = gcd((recovered - m) % N, N)
        if p_cand != 1 and p_cand != N:
            return p_cand
        # Trivial gcd: m is past the fault window, search lower.
        high = m - 1
    raise ValueError("failed to find p")
# recover p from a faulty signature of some m with p <= m < q
p = binary_search_for_fault(n)
# Exact integer division: in Sage ``n / p`` returns a Rational.
q = n // p
phi = (p - 1) * (q - 1)
d = pow(e, -1, int(phi))
# ``flag`` holds the ciphertext c read from the server, not the plaintext.
print(long_to_bytes(pow(int(flag), int(d), int(n)))) # b'FCSC{78ef932f3e1f42a1b4cd25674082fb906bc70f7b1072415268f76c0df4cf7527}'