Solution de ribt pour RSA Destroyer

crypto RSA

29 novembre 2023

Introduction

Voici le code Python dont il est question :

# **This** destroyes the RSA cryptosystem.

from Crypto.Util.number import isPrime, bytes_to_long
from Crypto.Random.random import getrandbits

def fastPrime(bits, eps = 32):
	while True:
		a, e, u = getrandbits(eps), getrandbits(eps), getrandbits(4 * eps)
		p = a * (2 ** bits - e) + u
		if isPrime(p):
			return p

def generate(bits = 2048):
	p = fastPrime(bits // 2)
	q = fastPrime(bits // 2)
	return p * q, 2 ** 16 + 1

n, e = generate()

p = bytes_to_long(open("flag.txt", "rb").read())
c = pow(p, e, n)

print(f"e = {e}")
print(f"n = {n}")
print(f"c = {c}")

et output.txt contient les valeurs de e, n et c.

Petit rappel sur le chiffrement RSA : n est la clé publique (celle utilisée pour chiffrer un message) et ce nombre est le produit de 2 nombres premiers très grands, p et q et il faut ces deux nombres pour déchiffrer un message. Le n fournit ici fait 2048 bits (636 chiffres en décimal) donc à première vue il est bien trop énorme pour trouver ses facteurs. On vérifie quand même mais il n’est connu de factordb.

Le problème est sans doute dans la fonction fastPrime qui génère p et q (deux nombres premiers de 1024 bits). Chacun de ces deux nombres est créé en faisant a * (2**1024 - e) + u (** c’est puissance) avec a et e des nombres aléatoires de 32 bits et u de 128 bits.

Premiers tâtonnements

Posons :

p = a1 * (2**1024 - e1) + u1
q = a2 * (2**1024 - e2) + u2

On a donc :

n = p * q = (a1 * (2**1024 - e1) + u1) * (a2 * (2**1024 - e2) + u2)

Je me suis “““amusé””” à tout développer puis factoriser et ça donne :

n = 2**2048 * (a1*a2)
	+ 2**1024 * (a1*u2 -a1*a2*e2 - a1*a2*e1 + u1*a2)
	+ (a1*a2*e1*e2 - a1*e1*u2 - u1*a2*e2 + u1*u2)

En divisant n par 2**2048 on trouve 13765971169208528045. D’après factordb c’est 5 · 2203 · 1665479 · 750383357. Le seul moyen de faire deux facteurs inférieurs à 2**32 (car a1 et a2 sont des nombres de 32 bits) c’est que :

a1 = 2203*1665479 = 3669050237
a2 = 5*750383357 = 3751916785

On a déjà bien avancé ! Ensuite, on sait également que :

(n%(2**2048))//(2**1024) = a1*u2 -a1*a2*e2 - a1*a2*e1 + u1*a2

n%(2**1024) = a1*a2*e1*e2 - a1*e1*u2 - u1*a2*e2 + u1*u2

En remplaçant toutes les valeurs connues, on a :

3669050237*u2 -13765971169208528045*e2 - 13765971169208528045*e1 + u1*3751916785 = 1462483866390329830822836164002145062407975244184

13765971169208528045*e1*e2 - 3669050237*e1*u2 - u1*3751916785*e2 + u1*u2 = 37284463254120829734596659590852831388840149328402126048476097877596519338219

J’essaye de factoriser ces grands nombres, de tourner les équations dans tous les sens mais rien y fait… Dépité je vais me coucher en laissant ces deux équations à un solver z3.

Illumination

Le lendemain le solver n’a rien trouvé mais je suis plus apte à réfléchir. J’ai alors une idée brillante : afficher le nombre n fourni en hexadécimal :

0xbf0a8dd7d8f16cad000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001002c0b6fc6c3c2949b0a1e097f3c51eff2e89198000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000526e422445cbd24c429d60a4a3d75cfd20d09708a2945d9ad2d3b65a55f110eb

Pour être sûr de comprendre, je vais relancer le programme donné et comparer le n hexadécimal avec p et q :

from Crypto.Util.number import isPrime, bytes_to_long
from Crypto.Random.random import getrandbits

eps = 32
bits = 2048//2
while True:
    a1, e1, u1 = getrandbits(eps), getrandbits(eps), getrandbits(4 * eps)
    p = a1 * (2**bits - e1) + u1
    if isPrime(p):
        break

while True:
    a2, e2, u2 = getrandbits(eps), getrandbits(eps), getrandbits(4 * eps)
    q = a2 * (2**bits - e2) + u2
    if isPrime(q):
        break

n = p*q
e = 2 ** 16 + 1

m = bytes_to_long(b"FCSC{faux_flag}")
c = pow(p, e, n)
>>> hex(n)
'0x189c82ba8649106a0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000003f3f37601b05ef93069190e776d1a8898a14a7f500000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000027f2612f94e592c8ae6159195be428796c2ebaf567a9ce005e59a00a97fb6585'
>>> hex(p)
'0x616d2c41000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000008d7518fa9ec90d6c7000e5b0cd80d923'
>>> hex(q)
'0x40ab5dea00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000484b018bed52736a192bd386fb0db537'

On retrouve bien 0x616d2c41 * 0x40ab5dea == 0x189c82ba8649106a (les a multipliés par 2**1024) et… une ampoule s’est allumée au dessus de ma tête ! Au lieu d’écrire p comme a1 * (2**1024 - e1) + u1, on peut simplement l’écrire comme a1 * 2**1024 + b1 (avec b1 = u1 - a1*e1 mais on s’en fout) comme ça on réduit le nombre d’inconnues ! On a donc :

n = p * q
  = (a1 * 2**1024 + b1) * (a2 * 2**1024 + b2)
  = a1*2**1024*a2*2**1024 + a1*2**1024*b2 + b1*a2*2**1024 + b1*b2
  = 2**2048*(a1*a2) + 2**1024*(a1*b2+b1*a2) + (b1*b2)

Cela se vérifie sur notre exemple :

>>> 0x616d2c41*0x484b018bed52736a192bd386fb0db537 + 0x8d7518fa9ec90d6c7000e5b0cd80d923*0x40ab5dea == 0x3f3f37601b05ef93069190e776d1a8898a14a7f5
True
>>> 0x8d7518fa9ec90d6c7000e5b0cd80d923*0x484b018bed52736a192bd386fb0db537 == 0x27f2612f94e592c8ae6159195be428796c2ebaf567a9ce005e59a00a97fb6585
True

Si l’on reprend le n donné (dont on connaît déjà a1 et a2) on a :

n = (13765971169208528045*2**2048) + ((3669050237*b2+3751916785*b1)*2**1024) + b1*b2

donc

13765971169208528045 = 0xbf0a8dd7d8f16cad
3669050237*b2+3751916785*b1 = 0x1002c0b6fc6c3c2949b0a1e097f3c51eff2e89198
b1*b2 = 0x526e422445cbd24c429d60a4a3d75cfd20d09708a2945d9ad2d3b65a55f110eb

On met tout ça dans un solver z3 :

from z3 import *

n = 444874973852804286630293120525019547982392964519934608680681255396764239795499482860997657663742247333836933457910503642061679607999128792657151145831533603267962151902191791568052924623477918783346790554917615006885807262798511378178431356140169891510484103567017335784087168191133679976921108092647227149255338118895695993606854195408940572577899625236666854544581041490770396755583819878794842828965377818593455075306655077757834318066860484956428681524881285058664687568640627516452658874124048546780999256640377399347893644988620246748059490751348919880389771785423781356133657866769589669296191804649195706447605778549172906037483
a1 = 3669050237
a2 = 3751916785

b1 = Int('b1')
b2 = Int('b2')

s = Solver()

s.add(3669050237*b2+3751916785*b1 == 0x1002c0b6fc6c3c2949b0a1e097f3c51eff2e89198)
s.add(b1*b2 == 0x526e422445cbd24c429d60a4a3d75cfd20d09708a2945d9ad2d3b65a55f110eb)

print (s.check())
print (s.model())

et après quelques secondes, on trouve :

sat
[b1 = 155855460081744155068217508253103646077,
 b2 = 239224620263965184662787181879747443847]

On a donc trouvé la clé privée :

p = (3669050237*2**1024)+155855460081744155068217508253103646077
q = (3751916785*2**1024)+239224620263965184662787181879747443847

Un petit tour sur notre moteur de recherche préféré pour trouver comment déchiffrer le message avec p et q et on a notre programme final :

e = 65537
n = 444874973852804286630293120525019547982392964519934608680681255396764239795499482860997657663742247333836933457910503642061679607999128792657151145831533603267962151902191791568052924623477918783346790554917615006885807262798511378178431356140169891510484103567017335784087168191133679976921108092647227149255338118895695993606854195408940572577899625236666854544581041490770396755583819878794842828965377818593455075306655077757834318066860484956428681524881285058664687568640627516452658874124048546780999256640377399347893644988620246748059490751348919880389771785423781356133657866769589669296191804649195706447605778549172906037483
c = 95237912740655706597869523108017194269174342313145809624317482236690453533195825723998662803480781411928531102859302761153780930600026069381338457909962825300269319811329312349030179047249481841770850760719178786027583177746485281874469568361239865139247368477628439074063199551773499058148848583822114902905937101832069433266700866684389484684637264625534353716652481372979896491011990121581654120224008271898183948045975282945190669287662303053695007661315593832681112603350797162485915921143973984584370685793424167878687293688079969123983391456553965822470300435648090790538426859154898556069348437896975230111242040448169800372469

p = (3669050237<<1024)+155855460081744155068217508253103646077
q = (3751916785<<1024)+239224620263965184662787181879747443847

def getModInverse(a, m):
    u1, u2, u3 = 1, 0, a
    v1, v2, v3 = 0, 1, m
    while v3 != 0:
        q = u3 // v3
        v1, v2, v3, u1, u2, u3 = (u1 - q * v1), (u2 - q * v2), (u3 - q * v3), v1, v2, v3
    return u1 % m

phi = (p - 1) * (q - 1)
d = getModInverse(e, phi)
flag = pow(c, d, n)
print(bytes.fromhex(hex(flag)[2:]))
b'FCSC{cd43566923980e47f6630e82c2d9a55b388f01043bc78b9ce3354ce02acf22e8}'