Writeup by the_one_and_only_az for Tight Schedule

crypto

April 16, 2024

Tight Schedule

Category | Author            | Solves | Points
Crypto   | Cryptanalyse & hg | 4      | 493

Approaching the challenge

The goal is fairly obvious. Given a plaintext/ciphertext pair, we have to retrieve the key used to initialize a substitution-permutation network. The first reflex is to google the first few bytes of the sbox to see whether we’re dealing with anything custom (read: suspicious) or widely known (read: nothing-up-my-sleeve); it turns out to be AES’s. As such, we’ll work under the assumption that the sbox is basically unbreakable. Since the challenge is called Tight Schedule, the vulnerability is expected to lie in the key schedule anyway.

Let’s look at the expandKey function.

    def expandKey(self, k):
        rk = [k]
        for _ in range(10):
            rk.append(self._round(rk[-1], self.RCON[len(rk)]))
        return rk

It takes the cipher key and derives ten more round keys from it, one call to _round each. Nothing special. We’ll have to dig a tiny little bit deeper and analyze the _round function.

    def _round(self, x, cst = 0):
        a, b, c, d = x[-4:]
        t = bytes([self.S[b] ^ cst, self.S[c], self.S[d], self.S[a]])
        y  = xor(x[ 0: 4], t)
        y += xor(x[ 4: 8], y[-4:])
        y += xor(x[ 8:12], y[-4:])
        y += xor(x[12:16], y[-4:])
        return y

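For the purpose of this function, I’ll write x = [x0, x1, x2, x3], where each xi is a 4-byte word. Basically, it defines an “offset” t = S[x3] (strictly speaking, the sbox applied bytewise to x3 rotated by one byte, with the round constant cst xored into the first byte), then computes a new state with the following equations, where [t] means t xored into every word:

y ^ [t] = [ x0, x0 ^ x1, x0 ^ x1 ^ x2, x0 ^ x1 ^ x2 ^ x3 ]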
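As a quick sanity check of this identity, here is a minimal Python sketch. It assumes t is just some 4-byte word (the sbox lookup and the round constant are abstracted away, since only the xor chain matters here), and `round_with_offset` and `prefix_xor` are hypothetical helper names, not functions from the challenge.

```python
import os

def xor(a, b):
    # bytewise xor of two equal-length byte strings
    return bytes(u ^ v for u, v in zip(a, b))

def round_with_offset(x, t):
    # the xor chain of _round, with the offset t passed in explicitly
    y0 = xor(x[0:4], t)
    y1 = xor(x[4:8], y0)
    y2 = xor(x[8:12], y1)
    y3 = xor(x[12:16], y2)
    return [y0, y1, y2, y3]

def prefix_xor(x):
    # [x0, x0^x1, x0^x1^x2, x0^x1^x2^x3] on the four 4-byte words of x
    out, acc = [], bytes(4)
    for i in range(4):
        acc = xor(acc, x[4 * i:4 * i + 4])
        out.append(acc)
    return out

x, t = os.urandom(16), os.urandom(4)
y = round_with_offset(x, t)
assert [xor(yi, t) for yi in y] == prefix_xor(x)
```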

Ok. So what? Nothing too obvious, and it’s already past midnight, so let’s head to sleep and think about it with a clear head tomorrow.

time to sleep

Never mind. It’s actually glaringly obvious (and somehow I did manage to sleep after that realisation): if we consider t as an external, completely independent variable, there is no diffusion. In other words, knowing the i-th bit of x0, x1, x2, x3 and t immediately yields the i-th bit of y0, y1, y2 and y3. This sounds absolutely disastrous, so let’s keep looking in that direction!
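The “no diffusion” claim can be checked mechanically. The sketch below (Python, words as 32-bit integers, t supplied as an independent input; `round_words` and `is_bitsliced` are hypothetical names) recomputes every bit of y from the matching bit of the inputs only, then flips one input bit to show that the difference stays in the same bit position.

```python
import random

def round_words(x, t):
    # one round on four 32-bit words, with t treated as an independent input
    y0 = x[0] ^ t
    y1 = x[1] ^ y0
    y2 = x[2] ^ y1
    y3 = x[3] ^ y2
    return [y0, y1, y2, y3]

def bit(w, i):
    return (w >> i) & 1

def is_bitsliced(x, t, width=32):
    # bit i of the output must follow from bit i of x0..x3 and of t alone
    y = round_words(x, t)
    for i in range(width):
        from_slice = round_words([bit(w, i) for w in x], bit(t, i))
        if from_slice != [bit(w, i) for w in y]:
            return False
    return True

x = [random.getrandbits(32) for _ in range(4)]
t = random.getrandbits(32)
print(is_bitsliced(x, t))  # True

# flipping bit 7 of x1 only ever touches bit 7 of the output
flipped = round_words([x[0], x[1] ^ (1 << 7), x[2], x[3]], t)
diff = [a ^ b for a, b in zip(flipped, round_words(x, t))]
print(diff)  # [0, 128, 128, 128]
```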

The main intuition we’ll follow for now is that working bit by bit (or byte by byte) will reduce the search space enough that we can simply bruteforce the key. It’ll probably be something around 2^32 iterations for each chunk, or a variation of that. Ideally, we’ll be able to maintain some kind of invariant throughout the encryption/decryption which will tell us whether our guess is correct.

Of course, the above is just vague intuition. Just talking won’t give us the invariants; we can’t get something from nothing. By playing around with the defining equation for y, we see that it is possible to erase t and get some linear equations in x for z = round(y) (for instance z1 ^ z3 = x3). This is of course highly suspicious, and sounds like a pain to work through by hand (given that there are 50 rounds to unroll). It is high time for my beloved symbolic analysis <3

We’ll model the cipher in a very simple way:

  • 4 variables for the key (x0, x1, x2, x3), which will represent the i-th bit of the actual x0…x3.
  • 10 variables for the t that appear in the key expansion (t0…t9).
  • 50 variables for the t that appear in the encryption (s0…s49).
  • 4 variables for the plaintext (p0…p3).

Given these symbols, the cipher becomes fully linear over GF(2) and is easy to reimplement (cf. solve.sage). Running the symbolic encryption gives equations for the i-th bit of each ciphertext word:
R = GF(2)[vars('p',4, 'x',4, 't',10, 's',50)]
ps = R.gens()[:4]
xs = R.gens()[-64:-60]
ts = R.gens()[-60:-50]
ss = R.gens()[-50:]

rk = symexpand(xs)
symc = symenc(ps, rk)
for ci in symc:
    print('-'*50)
    print(ci)
--------------------------------------------------
p00 + x00 + t01 + t03 + t05 + t07 + t09 + s00 + s01 + s02 + s03 + s04 + s05 + s06 + s07 + s08 + s09 + s10 + s11 + s12 + s13 + s14 + s15 + s16 + s17 + s18 + s19 + s20 + s21 + s22 + s23 + s24 + s25 + s26 + s27 + s28 + s29 + s30 + s31 + s32 + s33 + s34 + s35 + s36 + s37 + s38 + s39 + s40 + s41 + s42 + s43 + s44 + s45 + s46 + s47 + s48 + s49
--------------------------------------------------
p01 + x01 + t01 + t03 + t05 + t07 + t09 + s01 + s03 + s05 + s07 + s09 + s11 + s13 + s15 + s17 + s19 + s21 + s23 + s25 + s27 + s29 + s31 + s33 + s35 + s37 + s39 + s41 + s43 + s45 + s47 + s49
--------------------------------------------------
p00 + p02 + x00 + x02 + t01 + t05 + t09 + s00 + s01 + s04 + s05 + s08 + s09 + s12 + s13 + s16 + s17 + s20 + s21 + s24 + s25 + s28 + s29 + s32 + s33 + s36 + s37 + s40 + s41 + s44 + s45 + s48 + s49
--------------------------------------------------
p01 + p03 + x01 + x03 + t01 + t05 + t09 + s01 + s05 + s09 + s13 + s17 + s21 + s25 + s29 + s33 + s37 + s41 + s45 + s49
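For readers without Sage at hand, here is a dependency-free sketch of the same symbolic model in plain Python. Since every expression is linear over GF(2), each one can be stored as an integer bitmask (bit k set means variable k is present), and adding two expressions is just xor. This replaces Sage’s polynomial ring and is our own substitution, not the author’s code; it should reproduce the four equations printed above.

```python
# Each GF(2)-linear expression is an int bitmask: bit k set <=> variable k present.
NAMES = ([f"p{i:02}" for i in range(4)] + [f"x{i:02}" for i in range(4)]
         + [f"t{i:02}" for i in range(10)] + [f"s{i:02}" for i in range(50)])
VAR = {name: 1 << k for k, name in enumerate(NAMES)}
ps = [VAR[f"p{i:02}"] for i in range(4)]
xs = [VAR[f"x{i:02}"] for i in range(4)]
ts = [VAR[f"t{i:02}"] for i in range(10)]
ss = [VAR[f"s{i:02}"] for i in range(50)]

def symxor(x, y):
    # adding GF(2)-linear expressions = xoring their masks
    return [a ^ b for a, b in zip(x, y)]

def symround(x, t):
    y0 = x[0] ^ t
    y1 = x[1] ^ y0
    y2 = x[2] ^ y1
    y3 = x[3] ^ y2
    return [y0, y1, y2, y3]

def symexpand(k):
    rk = [k]
    for i in range(10):
        rk.append(symround(rk[-1], ts[i]))
    return rk

def symenc(p, rk):
    c = p[:]
    for i, sk in enumerate(rk[:-1]):
        c = symxor(c, sk)
        for j in range(5):
            c = symround(c, ss[5 * i + j])
    return symxor(c, rk[-1])

def show(expr):
    # print in the same style as the Sage output
    return " + ".join(n for n in NAMES if expr & VAR[n])

symc = symenc(ps, symexpand(xs))
for ci in symc:
    print('-' * 50)
    print(show(ci))
```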

Very long, not very interesting. We already know there are bound to be lots of cancellations going on, so a little LLL to shorten the equations won’t hurt! (Recall that LLL combines rows linearly, so everything it outputs is still computable from the bits of c.)

LLL-everywhere

print('\n'+'*'*100+'\n')

M = matrix(ZZ, [
    [ci.monomial_coefficient(gi) for gi in R.gens()]
    for ci in symc
]).LLL().change_ring(GF(2))
print(M)

g = vector(R, R.gens())
V = GF(2)^R.ngens()

for v in M:
    w = sum(map(int, v))
    if w < 25:
        print(w, '|', g * V(v))

print('\n'+'*'*100+'\n')
****************************************************************************************************

[0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 0]
[0 0 1 1 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0]
[0 1 0 1 0 1 0 1 0 1 0 0 0 1 0 0 0 1 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 1]
[1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0]
16 | p03 + x03 + t03 + t07 + s03 + s07 + s11 + s15 + s19 + s23 + s27 + s31 + s35 + s39 + s43 + s47
16 | p02 + p03 + x02 + x03 + s02 + s06 + s10 + s14 + s18 + s22 + s26 + s30 + s34 + s38 + s42 + s46
20 | p01 + p03 + x01 + x03 + t01 + t05 + t09 + s01 + s05 + s09 + s13 + s17 + s21 + s25 + s29 + s33 + s37 + s41 + s45 + s49
21 | p00 + p01 + p02 + p03 + x00 + x01 + x02 + x03 + s00 + s04 + s08 + s12 + s16 + s20 + s24 + s28 + s32 + s36 + s40 + s44 + s48

****************************************************************************************************

This is already a lot more interesting. See how every equation only depends on every fourth variable? This is suspicious as hell, a definite sign we’re going in the right direction. It’s still not obvious at all how to exploit the pattern, however, so let’s dig deeper.
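One way to see where the period of 4 comes from (this is our own reading of the pattern, not a step spelled out in the writeup): once t is dropped, a round is the running-xor map on the four words, and applying that map four times gives back the input. Likewise, a t injected at some round reaches later states through one of only four patterns. A small Python sketch checking both facts (`lin_round` and `iterate` are hypothetical names):

```python
import random

def lin_round(x):
    # a round with t dropped: running xor of the four words
    y0 = x[0]
    y1 = x[1] ^ y0
    y2 = x[2] ^ y1
    y3 = x[3] ^ y2
    return [y0, y1, y2, y3]

def iterate(x, n):
    for _ in range(n):
        x = lin_round(x)
    return x

x = [random.getrandbits(32) for _ in range(4)]
assert iterate(x, 2) == [x[0], x[1], x[0] ^ x[2], x[1] ^ x[3]]
assert iterate(x, 4) == x  # four t-free rounds are the identity

# how "t xored into every word" spreads over the following rounds
orbit = [[1, 1, 1, 1]]
for _ in range(4):
    orbit.append(lin_round(orbit[-1]))
print(orbit)  # [[1, 1, 1, 1], [1, 0, 1, 0], [1, 1, 0, 0], [1, 0, 0, 0], [1, 1, 1, 1]]
```

For instance, s_m is injected as [1, 1, 1, 1] and then goes through 49 - m more rounds, so it shows up in the last ciphertext word only when 49 - m is a multiple of 4, i.e. m = 1 mod 4: exactly the s01 + s05 + … + s49 terms in the equations above.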

Derivation tables

Here, we shall write y = round(x); z = round(y); u = round(z); v = round(u). The t in y ^ [t] = [ x0, x0 ^ x1, x0 ^ x1 ^ x2, x0 ^ x1 ^ x2 ^ x3 ] shall be renamed t[y], so that we remember which level it is tied to.

Trying to keep the t from appearing for as long as possible (given the above patterns, it’s going to be 4 levels), we obtain the following equation:

v3 ^ t[v] ← u0 ^ u1 ^ u2 ^ u3
          ← z1 ^ z3
          ← y2 ^ y3
          ← x3

(TL;DR: t appears as soon as you xor an odd number of words from the same level.) Knowing a bit gives you a relation FIVE levels down?? Very interesting, very suspicious.

We can in fact derive similar equations for the other bits, e.g.

y0 ^ t[y] ← x0      and      z1 ^ t[z] ← y0 ^ y1
                                       ← x1

Less interesting, still as suspicious.
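All three chains can be checked numerically with independent random offsets at every level. A minimal Python sketch, where t_y, t_z, t_u, t_v stand for t[y], t[z], t[u], t[v] and `rnd` / `chains_hold` are hypothetical names. Note that Python’s `^` binds tighter than `==`, so the chained comparisons read exactly like the derivations.

```python
import random

def rnd(x, t):
    # one round with an externally chosen offset t (words as 32-bit ints)
    y0 = x[0] ^ t
    y1 = x[1] ^ y0
    y2 = x[2] ^ y1
    y3 = x[3] ^ y2
    return [y0, y1, y2, y3]

def chains_hold():
    x = [random.getrandbits(32) for _ in range(4)]
    t_y, t_z, t_u, t_v = (random.getrandbits(32) for _ in range(4))
    y = rnd(x, t_y)
    z = rnd(y, t_z)
    u = rnd(z, t_u)
    v = rnd(u, t_v)
    # x3 survives four rounds and only picks up the last offset t[v]
    ok3 = v[3] ^ t_v == u[0] ^ u[1] ^ u[2] ^ u[3] == z[1] ^ z[3] == y[2] ^ y[3] == x[3]
    # the shorter chains for x0 and x1
    ok0 = y[0] ^ t_y == x[0]
    ok1 = z[1] ^ t_z == y[0] ^ y[1] == x[1]
    return ok3 and ok0 and ok1

print(chains_hold())  # True
```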

Suspicious

However, it is clear that handling these equations by hand is going to be absolutely awful. We need something much more compact and visual. Something in which we can easily display what is known and deduced. This is the point where I introduce a new notation, dubbed the derivation table.

It is something very simple:

__| 0  1  2  3  t
x |          .        
y |       ^  ^  b
z |    ^     ^
u | ^  ^  ^  ^
v |          *  .

 .: given
 ^: deduced up to xor
 *: deduced
 b: deduced if full byte (warning: not in the same table)

Given x3 and t[v], we know y2 ^ y3, z1 ^ z3, u0 ^ u1 ^ u2 ^ u3 and v3. Here are the derivation tables for x0, x1 and x2:

__| 0 1 2 3 t     __| 0 1 2 3 t     __| 0 1 2 3 t
x | .             x |   .           x |     .        
y | ^       ^     y | ^ ^           y |   ^ ^    
                  z |   ^     ^     z | ^   ^    
                                    u | ^ ^ ^   ^
                                    v |     ^   ^

The point is that we can play around and combine them for different patterns of given (read: bruteforced) information. For instance, if we bruteforce bit (3) at every level, we see that after some time we get every bit at every level of the derivation table. If we bruteforce the full byte, we also get information in another table:

__| 0  1  2  3  t           __| 0  1  2  3  t
x |          .              … |       .  .
y |       *  .  b    (b)    … |    ^  ^  ^   
z |    *  *  .  b   ~~~~~>  x | ^  ^  ^  ^   
u | *  *  *  .  b           y | ^  ^  ^  *  @
v | *  *  *  .  b           z |       *  *  @

 @: external information

After playing around a bit, one may arrive at the following by repeatedly applying the table for bit (2) back and forth:

__| 0 1 2 3 t           __| 0 1 2 3 t           __| 0 1 2 3 t           __| 0 1 2 3 t
… |                     … |                     … |                     … |     *    
x |                     x |     *               x |     *               x |     *    
y |     .        ==>    y |     .        ==>    y |   * .        ==>    y |   * .        ==>    ...
z |   * .               z |   * .               z | * * .               z | * * .    
u | * * .               u | * * .               u | * * .   *           u | * * .   *
v | * * .   *           v | * * .   *           v | * * .   *           v | * * .   *

HYPE?

Bruteforcing only four bits’ worth of information gives you the entire table?! Even though t is supposed to depend on external information?? This is very suspicious, but not for the same reasons as before.

Too much information

Even so, it is difficult to see where this has gone wrong. After rechecking this derivation 50 times, I go to sleep. Implementing it in the morning is sure to tell where the equations fail, after all.

Zzz.

Woops. Turns out the table for bit (2) was wrong. The underlying derivation

v2 ^ t[v] ^ t[u] ← u0 ^ u1 ^ u2 ^ t[u]
                 ← z0 ^ z2
                 ← y1 ^ y2
                 ← x2

shows in fact that x2 does not give v2 ^ t[v] but v2 ^ t[v] ^ t[u]. The corrected table looks like this:

__| 0  1  2  3  t
x |       .      
y |    ^  ^
z | ^     ^
u | ^  ^  ^     ^
v |       ^     ↑

 ↑: deduced up to xor (with dependence on upward bit)

Hence the previous magical derivation fails when we try to go up a level.
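The corrected bit (2) chain can be checked the same way. In this Python sketch (`rnd` and `check_bit2` are hypothetical names, offsets are random and independent), the first flag tests the corrected equation and the second one tests the original claim v2 ^ t[v] = x2, which only holds when t[u] happens to be zero.

```python
import random

def rnd(x, t):
    # one round with an externally chosen offset t (words as 32-bit ints)
    y0 = x[0] ^ t
    y1 = x[1] ^ y0
    y2 = x[2] ^ y1
    y3 = x[3] ^ y2
    return [y0, y1, y2, y3]

def check_bit2(x, t_y, t_z, t_u, t_v):
    y = rnd(x, t_y)
    z = rnd(y, t_z)
    u = rnd(z, t_u)
    v = rnd(u, t_v)
    # corrected chain: the offset t[u] of the level above does not cancel
    corrected = v[2] ^ t_v ^ t_u == u[0] ^ u[1] ^ u[2] ^ t_u == z[0] ^ z[2] == y[1] ^ y[2] == x[2]
    # the first (wrong) reading of the table: v2 ^ t[v] = x2
    naive = v[2] ^ t_v == x[2]
    return corrected, naive

x = [random.getrandbits(32) for _ in range(4)]
offsets = [random.getrandbits(32) for _ in range(4)]
print(check_bit2(x, *offsets))  # (True, False) unless t[u] happens to be 0
```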

The solve

After playing around a while longer, it becomes clear that derivation tables are not going to work: even if we were to know x0, x1, x2, x3 and everything before that, it is impossible to actually deduce anything five levels down :c

It is also very frustrating that x3 actually gives more information (the full-byte b entries), but we can’t use it because that information ends up in another byte’s table. However, not all is lost!

Four tables

Following the conventions of _round, we call a = x3[0], b = x3[1], c = x3[2] and d = x3[3], and use a, b, c, d to label the four byte columns. Then, if we are very determined to use the external information, we might start with something like this:

         a           b           c           d     
   | 0 1 2 3 t | 0 1 2 3 t | 0 1 2 3 t | 0 1 2 3 t |
────────────────────────────────────────────────────
 x |       .   |           |           |           |
 y |         b |           |           |       . @ |
 z |           |           |       . @ |         b |
 u |           |       . @ |         b |           |
 v |         @ |         b |           |           |

And then… a miracle happens!

         a           b           c           d     
 r | 0 1 2 3 t | 0 1 2 3 t | 0 1 2 3 t | 0 1 2 3 t |
────────────────────────────────────────────────────
 0 |       .   |     ^ ^   |   ^   ^   | ^ ^ ^ ^   |
 1 |     ^ ^ b |   ^   ^   | ^ ^ ^ ^   |       . @ |
 2 |   ^   ^   | ^ ^ ^ ^   |       . @ |     ^ ^ b |
 3 | ^ ^ ^ ^   |       . @ |     ^ ^ b |   ^   ^   |
 4 |       * @ |     ^ ^ b |   ^   ^   | ^ ^ ^ ^   |
 5 |     ^ ^ b |   ^   ^   | ^ ^ ^ ^   |       * @ |
 6 |   ^   ^   | ^ ^ ^ ^   |       * @ |     ^ ^ b |
 7 | ^ ^ ^ ^   |       * @ |     ^ ^ b |   ^   ^   |
 8 |       * @ |     ^ ^ b |   ^   ^   | ^ ^ ^ ^   |
 9 |     ^ ^ b |   ^   ^   | ^ ^ ^ ^   |       * @ |
10 |   ^   ^   | ^ ^ ^ ^   |       * @ |     ^ ^ b |

At the mere price of 4 bytes, we get a self-propagating pattern! (This may very well be the long-sought-after invariant :o)

But the good news does not stop there… Note that we don’t need to bruteforce things across 4 different rows. After all, given x3[a], knowing x0[d] ^ x1[d] ^ x2[d] ^ x3[d] or y3[d] is pretty much the same, since the two differ by S[x3[a]]. This means that the information we brute can all lie at a single level, assuming a given shape. There are four different shapes, tied to the phase: the offset from the first row, modulo 4.

But the good news does not stop there… Another miracle happens!

         a           b           c           d     
 r p | 0 1 2 3 t | 0 1 2 3 t | 0 1 2 3 t | 0 1 2 3 t |
──────────────────────────────────────────────────────
 0 0 |       .   |     ^ ^   |   ^   ^   | ^ ^ ^ ^   | ← k0 ; c
 1 1 |     ^ ^ b |   ^   ^   | ^ ^ ^ ^   |       . @ | ← k1
 2 2 |   ^   ^   | ^ ^ ^ ^   |       . @ |     ^ ^ b | ← k2
 3 3 | ^ ^ ^ ^   |       . @ |     ^ ^ b |   ^   ^   |
 4 0 |       * @ |     ^ ^ b |   ^   ^   | ^ ^ ^ ^   |
 5 1 |     ^ ^ b |   ^   ^   | ^ ^ ^ ^   |       * @ | ← round5(c+k0)
 6 2 |   ^   ^   | ^ ^ ^ ^   |       * @ |     ^ ^ b |
 7 3 | ^ ^ ^ ^   |       * @ |       ^ b |   ^   ^   |
 8 0 |       * @ |     ^ ^ b |     ^ ^   | ^ ^ ^ ^   |
 9 1 |     ^ ^ b |   ^   ^   |   ^   ^   |       * @ |
10 2 |   ^   ^   | ^ ^ ^ ^   | ^ ^ ^ * @ |     ^ ^ b | ← round5(c+k1)

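Whenever we need to xor the current ciphertext c with a round key, it turns out they are both in the same phase: each round key is one key-schedule round further along than the previous one, while the state is five cipher rounds further along, and 5 = 1 mod 4. Funny coincidence, isn’t it? Almost as if the challenge were designed for that.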
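The phase bookkeeping fits in a few lines of Python (a sketch; `phases_aligned` is a hypothetical name). It also shows where the `end_phase = (phase + 50) % 4` used in solve.py comes from.

```python
# Phase = rounds applied so far, modulo 4 (start = phase of the plaintext and key).
def phases_aligned(start):
    for i in range(11):
        key_phase = (start + i) % 4        # rk[i] went through i key-schedule rounds
        state_phase = (start + 5 * i) % 4  # state before xoring rk[i]: 5*i cipher rounds
        if key_phase != state_phase:
            return False
    return True

assert all(phases_aligned(s) for s in range(4))  # holds because 5 = 1 (mod 4)

# after all 50 cipher rounds, a run started at phase `start` ends at:
end_phase = {start: (start + 50) % 4 for start in range(4)}
print(end_phase)  # {0: 2, 1: 3, 2: 0, 3: 1}
```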

With all of this, the solve becomes fairly straightforward:

  • For every phase:
    • brute the phased key (2^32 iterations);
    • encrypt the phased plaintext;
    • compare with the phased ciphertext.

The Python script is fairly simple (cf. solve.py). However, our friend tqdm predicts ONE HUNDRED EIGHTY-SIX hours for a single bruteforce.
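Before spending CPU time on the search, it is worth checking that the phased model really agrees with the full cipher, as the assertion block of solve.py does. Here is a self-contained Python sketch of that check followed by a toy brute force. Assumptions, stated loudly: the sbox is a seeded random permutation rather than AES’s (the phased equations only ever look S up, so any byte substitution exercises the same algebra); the encryption layout (ten times: add round key, five rounds; then add the last key) is taken from the symbolic model in solve.sage; `partial` is rewritten compactly via a (column + phase) mod 4 rule that is equivalent to the explicit cases in solve.py; `model_matches` and `brute_last_byte` are hypothetical helpers. The toy search fixes three of the four phased-key bytes to their true values, so it loops 256 times instead of 2^32.

```python
import os
import random

# Stand-in sbox (assumption): a seeded random permutation, not AES's table.
# The phased equations only ever look S up, so any byte substitution exercises them.
S = list(range(256))
random.Random(2024).shuffle(S)
RCON = [0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36]

def xor(x, y):
    return [a ^ b for a, b in zip(x, y)]

def full_round(x, cst=0):
    # same computation as TightSchedule._round, on a list of 16 byte values
    a, b, c, d = x[-4:]
    t = [S[b] ^ cst, S[c], S[d], S[a]]
    y = xor(x[0:4], t)
    y += xor(x[4:8], y[-4:])
    y += xor(x[8:12], y[-4:])
    y += xor(x[12:16], y[-4:])
    return y

def expand_key(k):
    rk = [k]
    for _ in range(10):
        rk.append(full_round(rk[-1], RCON[len(rk)]))
    return rk

def encrypt(p, k):
    # assumed layout (as in solve.sage): 10 x (add round key, 5 rounds), then add the last key
    rk = expand_key(k)
    c = p
    for sk in rk[:-1]:
        c = xor(c, sk)
        for _ in range(5):
            c = full_round(c)
    return xor(c, rk[-1])

def partial(x, phase):
    # column j keeps combination number (j + phase) % 4 of the words w0..w3,
    # which matches the explicit cases of partial() in solve.py
    w = [x[4 * i:4 * i + 4] for i in range(4)]
    out = []
    for j in range(4):
        combos = (w[3][j], w[2][j] ^ w[3][j], w[1][j] ^ w[3][j],
                  w[0][j] ^ w[1][j] ^ w[2][j] ^ w[3][j])
        out.append(combos[(j + phase) % 4])
    return out

def partial_round(px, phase, cst=0):
    # same rules as partial_round in solve.py: one byte absorbs one sbox output
    a, b, c, d = px
    return [[a, b, c, d ^ S[a]], [a, b, c ^ S[d], d],
            [a, b ^ S[c], c, d], [a ^ S[b] ^ cst, b, c, d]][phase]

def partial_expand(pk, phase):
    prk = [pk]
    for _ in range(10):
        prk.append(partial_round(prk[-1], phase, RCON[len(prk)]))
        phase = (phase + 1) % 4
    return prk

def partial_encrypt(pp, prk, phase):
    c = pp
    for sk in prk[:-1]:
        c = xor(c, sk)
        for _ in range(5):
            c = partial_round(c, phase)
            phase = (phase + 1) % 4
    return xor(c, prk[-1]), phase

def model_matches(k, p):
    # the 4-byte phased model (start phase 0) must agree with the full 16-byte cipher
    c = encrypt(p, k)
    pc, end = partial_encrypt(partial(p, 0), partial_expand(partial(k, 0), 0), 0)
    return end == 2 and pc == partial(c, end)

def brute_last_byte(pp, pc, phase, known3):
    # toy version of the 2^32 search: only the last phased-key byte is unknown
    hits = []
    for d in range(256):
        pk = known3 + [d]
        if partial_encrypt(pp, partial_expand(pk, phase), phase)[0] == pc:
            hits.append(pk)
    return hits

k, p = list(os.urandom(16)), list(os.urandom(16))
print(model_matches(k, p))  # True
hits = brute_last_byte(partial(p, 0), partial(encrypt(p, k), 2), 0, partial(k, 0)[:3])
print(partial(k, 0) in hits)  # True
```

The real search in solve.c simply replaces the single loop over d with four nested loops over all phased-key bytes, once per starting phase.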

Time to reimplement everything in C (cf. solve.c) and let it run during dinner. A plate of spaghetti later, we have the phased keys.

Bruting phased key (phase 0)
a = 0 | a = 1 | a = 2 | a = 3 | a = 4 | a = 5 | a = 6 | a = 7 | a = 8 | a = 9 | a = 10 | a = 11 | a = 12 | a = 13 | a = 14 | a = 15 | a = 16 | a = 17 | a = 18 | a = 19 | a = 20 | a = 21 | a = 22 | a = 23 | a = 24 | a = 25 | a = 26 | a = 27 | a = 28 | a = 29 | a = 30 | a = 31 | a = 32 | a = 33 | a = 34 | a = 35 | a = 36 | a = 37 | a = 38 | a = 39 | a = 40 | a = 41 | a = 42 | a = 43 | a = 44 | a = 45 | a = 46 | a = 47 | a = 48 | a = 49 | a = 50 | a = 51 | a = 52 | a = 53 | a = 54 | a = 55 | a = 56 | a = 57 | a = 58 | a = 59 | a = 60 | a = 61 | a = 62 | a = 63 | a = 64 | a = 65 | a = 66 | a = 67 | a = 68 | a = 69 | a = 70 | a = 71 | a = 72 | a = 73 | a = 74 | a = 75 | a = 76 | DING DING DING!!!
partk = [76, 127, 191, 108]
Bruting phased key (phase 1)
a = 0 | a = 1 | a = 2 | a = 3 | a = 4 | a = 5 | a = 6 | a = 7 | a = 8 | a = 9 | a = 10 | a = 11 | a = 12 | a = 13 | a = 14 | a = 15 | a = 16 | a = 17 | a = 18 | a = 19 | a = 20 | a = 21 | a = 22 | a = 23 | a = 24 | a = 25 | a = 26 | a = 27 | a = 28 | a = 29 | a = 30 | a = 31 | a = 32 | a = 33 | a = 34 | a = 35 | a = 36 | a = 37 | a = 38 | a = 39 | a = 40 | a = 41 | a = 42 | a = 43 | a = 44 | a = 45 | a = 46 | a = 47 | a = 48 | a = 49 | a = 50 | a = 51 | a = 52 | a = 53 | a = 54 | a = 55 | a = 56 | a = 57 | a = 58 | a = 59 | a = 60 | a = 61 | a = 62 | a = 63 | a = 64 | a = 65 | a = 66 | a = 67 | a = 68 | a = 69 | a = 70 | a = 71 | DING DING DING!!!
partk = [71, 142, 75, 90]
Bruting phased key (phase 2)
a = 0 | a = 1 | a = 2 | a = 3 | a = 4 | a = 5 | a = 6 | a = 7 | a = 8 | a = 9 | a = 10 | a = 11 | a = 12 | a = 13 | a = 14 | a = 15 | a = 16 | a = 17 | a = 18 | a = 19 | a = 20 | a = 21 | a = 22 | a = 23 | a = 24 | a = 25 | a = 26 | a = 27 | a = 28 | a = 29 | a = 30 | a = 31 | a = 32 | a = 33 | a = 34 | a = 35 | a = 36 | a = 37 | a = 38 | a = 39 | a = 40 | a = 41 | a = 42 | a = 43 | a = 44 | a = 45 | a = 46 | a = 47 | a = 48 | a = 49 | a = 50 | a = 51 | a = 52 | a = 53 | a = 54 | a = 55 | a = 56 | a = 57 | a = 58 | a = 59 | a = 60 | a = 61 | a = 62 | a = 63 | a = 64 | a = 65 | a = 66 | a = 67 | a = 68 | a = 69 | a = 70 | a = 71 | a = 72 | a = 73 | a = 74 | a = 75 | a = 76 | a = 77 | a = 78 | a = 79 | a = 80 | a = 81 | a = 82 | a = 83 | a = 84 | a = 85 | a = 86 | a = 87 | a = 88 | a = 89 | a = 90 | a = 91 | a = 92 | a = 93 | a = 94 | a = 95 | a = 96 | DING DING DING!!!
partk = [96, 185, 153, 233]
Bruting phased key (phase 3)
a = 0 | a = 1 | a = 2 | a = 3 | a = 4 | a = 5 | a = 6 | a = 7 | DING DING DING!!!
partk = [7, 64, 187, 9]
Segmentation fault

The only thing left is to recombine the key from its phases and get the flag! (cf. solve_final.py)
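The recombination boils down to inverting a small linear map per byte column: across the four phases, column j yields w3, w2 ^ w3, w1 ^ w3 and w0 ^ w1 ^ w2 ^ w3 of the key words, which peel back one at a time. Here is our own Python sketch of it, not necessarily how solve_final.py does it; `recombine` is a hypothetical name, and `partial` uses the compact (column + phase) mod 4 form of the explicit cases in solve.py. Decrypting the flag is left out, since the mode of operation behind `iv` and `flag_enc` is not shown here.

```python
import os

def partial(x, phase):
    # same byte combinations as partial() in solve.py: column j keeps
    # combination number (j + phase) % 4 of the key words w0..w3
    w = [x[4 * i:4 * i + 4] for i in range(4)]
    out = []
    for j in range(4):
        combos = (w[3][j], w[2][j] ^ w[3][j], w[1][j] ^ w[3][j],
                  w[0][j] ^ w[1][j] ^ w[2][j] ^ w[3][j])
        out.append(combos[(j + phase) % 4])
    return out

def recombine(partks):
    # partks[phase] == partial(k, phase) for phase = 0..3; returns k
    words = [[0] * 4 for _ in range(4)]
    for j in range(4):
        # combination number q of column j was recovered in phase (q - j) % 4
        v = [partks[(q - j) % 4][j] for q in range(4)]
        k3 = v[0]
        k2 = v[1] ^ k3
        k1 = v[2] ^ k3
        k0 = v[3] ^ k1 ^ k2 ^ k3
        for w, val in enumerate((k0, k1, k2, k3)):
            words[w][j] = val
    return bytes(b for word in words for b in word)

k = os.urandom(16)
assert recombine([partial(k, ph) for ph in range(4)]) == k

# phased keys as printed by solve.c in the log above
logged = [[76, 127, 191, 108], [71, 142, 75, 90], [96, 185, 153, 233], [7, 64, 187, 9]]
print(recombine(logged).hex())
```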

solve.*

Appendices

solve.sage

Backup from: https://github.com/AZ-0/Writeups/blob/main/2024/fcsc/crypto-tight-schedule/solve.sage.

def symxor(x, y):
    return [a+b for a,b in zip(x,y)]

def symround(x, t):
    y = [None]*4
    y[0] = x[0] + t
    y[1] = x[1] + y[0]
    y[2] = x[2] + y[1]
    y[3] = x[3] + y[2]
    return y

def symexpand(k):
    rk = [k]
    for i in range(10):
        rk.append(symround(rk[-1], ts[i]))
    return rk

def symenc(p, rk):
    c = p[:]
    for i, sk in enumerate(rk[:-1]):
        c = symxor(c, sk)
        for j in range(5):
            c = symround(c, ss[5*i+j])
    return symxor(c, rk[-1])

def vars(*args):
    names = args[::2]
    ns = args[1::2]
    return ','.join(
        ','.join(
            f"{name}{i:0>2}"
            for i in range(n)
        )
        for name, n in zip(names, ns)
    )


p  = bytes.fromhex("0dfa4c6052fb87ef0a8f03f705dd5101")
c  = bytes.fromhex("d4ed19e0694101b6b151e11c2db973bf")
iv = bytes.fromhex("cd31cb6e6ded184efbb9a398e31ffdbb")
flag_enc = bytes.fromhex("653ec0cdd7e3a98c33414be8ef07c583d87b876afbff1d960f8f43b5a338e9ff96d87da4406ebe39a439dab3a84697d40c24557cd1ea6f433053451d20ce1fbf191270f4b8cc7891f8779eb615d35c9f")

R = GF(2)[vars('p',4, 'x',4, 't',10, 's',50)]
ps = R.gens()[:4]
xs = R.gens()[-64:-60]
ts = R.gens()[-60:-50]
ss = R.gens()[-50:]

rk = symexpand(xs)
symc = symenc(ps, rk)
for ci in symc:
    print('-'*50)
    print(ci)


print('*'*100)

M = matrix(ZZ, [
    [ci.monomial_coefficient(gi) for gi in R.gens()]
    for ci in symc
]).LLL().change_ring(GF(2))
print(M)

g = vector(R, R.gens())
V = GF(2)^R.ngens()

for v in M.row_space():
    w = sum(map(int, v))
    if w < 25:
        print(w, '|', g * V(v))

print('\n'+'*'*100+'\n')

for i, ki in enumerate(rk):
    print('-'*50)

    M = matrix(QQ, [
        [ci.monomial_coefficient(gi) for gi in R.gens()]
        for ci in ki
    ]).augment(matrix.identity(QQ, 4)).dense_matrix()

    D = matrix.diagonal(QQ, [1000]*R.ngens() + [1]*4)
    M = ((M*D).LLL()/D).change_ring(GF(2))

    for v in M:
        w = sum(map(int, v[:-4]))
        eq = ' + '.join(f'k{4-i}' if v[-i] else ' 0' for i in range(1,5))
        print(w, '|', eq, '=', g * V(v[:-4]))

solve.py

Backup from: https://github.com/AZ-0/Writeups/blob/main/2024/fcsc/crypto-tight-schedule/solve.py.

# __| 0  1  2  3  t
# … |
# … |       …
# … |    …  !
# x | …  !  .
# y | !  *  .     …
# z | *  *  .     !
# u | *  *  .     *   # a^b^c^d; a^d; a^b; c
# v | ^  ^  ^     ^
#
#  !: deduced from inverse table 2), and then after.
#  suspicious derivation (an information-theory problem?)
#  implement it to test

from tight_schedule import os, TightSchedule as TS
S = TS.S
RCON = TS.RCON

#            a           b           c           d
#  k p | 0 1 2 3 t | 0 1 2 3 t | 0 1 2 3 t | 0 1 2 3 t |
# ──────────────────────────────────────────────────────
#  0 0 |       .   |     ^ ^   |   ^   ^   | ^ ^ ^ ^   | ← k0
#  1 1 |     ^ ^ b |   ^   ^   | ^ ^ ^ ^   |       . @ | ← k1
#  2 2 |   ^   ^   | ^ ^ ^ ^   |       . @ |     ^ ^ b | ← k2
#  3 3 | ^ ^ ^ ^   |       . @ |     ^ ^ b |           |
#  4 0 |       * @ |     ^ ^ b |   ^   ^   |           |
#  5 1 |     ^ ^ b |   ^   ^   | ^ ^ ^ ^   |       * @ | ← round 5 of p+k0
#  6 2 |           |           |       * @ |         b |
#  7 3 |           |       * @ |         b |           |
#  8 0 |       * @ |     ^ ^ b |           |           |
#  9 1 |     ^ ^ b |   ^   ^   |           |       * @ |
# 10 2 |   ^   ^   | ^ ^ ^ ^   |       * @ |     ^ ^ b | ← round 5 of p+k1

def xor(x, y):
    return [x^y for x,y in zip(x, y)]

def blocks(x, n=4):
    return [x[i:i+n] for i in range(0, len(x), n)]

def partial(k, phase):
    bk = blocks(k)
    a, b, c, d = 0, 1, 2, 3 # for readability

    if phase == 0:
        return [
            bk[3][a],
            bk[2][b] ^ bk[3][b],
            bk[1][c] ^ bk[3][c],
            bk[0][d] ^ bk[1][d] ^ bk[2][d] ^ bk[3][d],
        ]

    if phase == 1:
        return [
            bk[2][a] ^ bk[3][a],
            bk[1][b] ^ bk[3][b],
            bk[0][c] ^ bk[1][c] ^ bk[2][c] ^ bk[3][c],
            bk[3][d],
        ]

    if phase == 2:
        return [
            bk[1][a] ^ bk[3][a],
            bk[0][b] ^ bk[1][b] ^ bk[2][b] ^ bk[3][b],
            bk[3][c],
            bk[2][d] ^ bk[3][d],
        ]

    if phase == 3:
        return [
            bk[0][a] ^ bk[1][a] ^ bk[2][a] ^ bk[3][a],
            bk[3][b],
            bk[2][c] ^ bk[3][c],
            bk[1][d] ^ bk[3][d],
        ]

    raise ValueError('phase should be one of 0,1,2,3')

def partial_round(partx, phase, cst=0):
    a, b, c, d = partx
    if phase == 0:
        return [a, b, c, d ^ S[a]]
    if phase == 1:
        return [a, b, c ^ S[d], d]
    if phase == 2:
        return [a, b ^ S[c], c, d]
    if phase == 3:
        return [a ^ S[b] ^ cst, b, c, d]
    raise ValueError('phase should be one of 0,1,2,3')

def partial_expand(partk, phase):
    partrk = [partk]
    for _ in range(10):
        partrk.append(partial_round(partrk[-1], phase, RCON[len(partrk)]))
        phase = (phase + 1) % 4
    return partrk

def partial_encrypt(partp, partrk, phase):
    c = partp
    for sk in partrk[:-1]:
        c = xor(c, sk)
        for _ in range(5):
            c = partial_round(c, phase)
            phase = (phase + 1) % 4
    return xor(c, partrk[-1]), phase

k = os.urandom(16)
p = os.urandom(16)
P = TS(k)
c = P.encrypt(p)

partk  = partial(k, 0)
partrk = partial_expand(partk, 0)
partp  = partial(p, 0)
partc, phase = partial_encrypt(partp, partrk, 0)

for i, (sk, partsk) in enumerate(zip(P.rk, partrk)):
    assert partsk == partial(sk, i%4), f'failed #{i}'

assert partc == partial(c, phase)
print('Assertions passed!')

p = bytes.fromhex("0dfa4c6052fb87ef0a8f03f705dd5101")
c = bytes.fromhex("d4ed19e0694101b6b151e11c2db973bf")
iv = bytes.fromhex("cd31cb6e6ded184efbb9a398e31ffdbb")
flag_enc = bytes.fromhex("653ec0cdd7e3a98c33414be8ef07c583d87b876afbff1d960f8f43b5a338e9ff96d87da4406ebe39a439dab3a84697d40c24557cd1ea6f433053451d20ce1fbf191270f4b8cc7891f8779eb615d35c9f")

from tqdm import trange

for phase in range(4):
    end_phase = (phase + 50) % 4
    partp = partial(p, phase)
    realc = partial(c, end_phase)

    print(f'Bruting partial key relations (phase {phase})')
    for partk in trange(2**32):
        partk  = [partk & 0xFF, (partk >> 8) & 0xFF, (partk >> 16) & 0xFF, (partk >> 24) & 0xFF]
        partrk = partial_expand(partk, phase)
        partc, _ = partial_encrypt(partp, partrk, phase)
        if partc == realc:
            print('DING DING DING!!!')
            print('partk =', partk)
            with open('partk.txt', 'a') as file:
                file.write(f'phase{phase} = {partrk}\n')  # one line per phase
            break
    else:
        print('THIS. IS. A BUG! >:c')

print('DONE!')

solve.c

Backup from: https://github.com/AZ-0/Writeups/blob/main/2024/fcsc/crypto-tight-schedule/solve.c.

#include <stdio.h>
#include <string.h>
#include <stdlib.h>

typedef unsigned char byte;

byte S[] = {
    0x63,0x7C,0x77,0x7B,0xF2,0x6B,0x6F,0xC5,0x30,0x01,0x67,0x2B,0xFE,0xD7,0xAB,0x76,
    0xCA,0x82,0xC9,0x7D,0xFA,0x59,0x47,0xF0,0xAD,0xD4,0xA2,0xAF,0x9C,0xA4,0x72,0xC0,
    0xB7,0xFD,0x93,0x26,0x36,0x3F,0xF7,0xCC,0x34,0xA5,0xE5,0xF1,0x71,0xD8,0x31,0x15,
    0x04,0xC7,0x23,0xC3,0x18,0x96,0x05,0x9A,0x07,0x12,0x80,0xE2,0xEB,0x27,0xB2,0x75,
    0x09,0x83,0x2C,0x1A,0x1B,0x6E,0x5A,0xA0,0x52,0x3B,0xD6,0xB3,0x29,0xE3,0x2F,0x84,
    0x53,0xD1,0x00,0xED,0x20,0xFC,0xB1,0x5B,0x6A,0xCB,0xBE,0x39,0x4A,0x4C,0x58,0xCF,
    0xD0,0xEF,0xAA,0xFB,0x43,0x4D,0x33,0x85,0x45,0xF9,0x02,0x7F,0x50,0x3C,0x9F,0xA8,
    0x51,0xA3,0x40,0x8F,0x92,0x9D,0x38,0xF5,0xBC,0xB6,0xDA,0x21,0x10,0xFF,0xF3,0xD2,
    0xCD,0x0C,0x13,0xEC,0x5F,0x97,0x44,0x17,0xC4,0xA7,0x7E,0x3D,0x64,0x5D,0x19,0x73,
    0x60,0x81,0x4F,0xDC,0x22,0x2A,0x90,0x88,0x46,0xEE,0xB8,0x14,0xDE,0x5E,0x0B,0xDB,
    0xE0,0x32,0x3A,0x0A,0x49,0x06,0x24,0x5C,0xC2,0xD3,0xAC,0x62,0x91,0x95,0xE4,0x79,
    0xE7,0xC8,0x37,0x6D,0x8D,0xD5,0x4E,0xA9,0x6C,0x56,0xF4,0xEA,0x65,0x7A,0xAE,0x08,
    0xBA,0x78,0x25,0x2E,0x1C,0xA6,0xB4,0xC6,0xE8,0xDD,0x74,0x1F,0x4B,0xBD,0x8B,0x8A,
    0x70,0x3E,0xB5,0x66,0x48,0x03,0xF6,0x0E,0x61,0x35,0x57,0xB9,0x86,0xC1,0x1D,0x9E,
    0xE1,0xF8,0x98,0x11,0x69,0xD9,0x8E,0x94,0x9B,0x1E,0x87,0xE9,0xCE,0x55,0x28,0xDF,
    0x8C,0xA1,0x89,0x0D,0xBF,0xE6,0x42,0x68,0x41,0x99,0x2D,0x0F,0xB0,0x54,0xBB,0x16
};
byte RCON[] = { 0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36 };
byte p[] = { 0x0d,0xfa,0x4c,0x60,0x52,0xfb,0x87,0xef,0x0a,0x8f,0x03,0xf7,0x05,0xdd,0x51,0x01 };
byte c[] = { 0xd4,0xed,0x19,0xe0,0x69,0x41,0x01,0xb6,0xb1,0x51,0xe1,0x1c,0x2d,0xb9,0x73,0xbf };


void xor4(byte* x, byte* y) {
    for (int i = 0; i < 4; ++i)
        x[i] ^= y[i];
}


void partial(byte* partk, byte* bk, int phase) {
    int a = 0, b = 1, c = 2, d = 3;

    switch (phase)
    {
        case 0:
            partk[0] = bk[3*4 + a];
            partk[1] = bk[2*4 + b] ^ bk[3*4 + b];
            partk[2] = bk[1*4 + c] ^ bk[3*4 + c];
            partk[3] = bk[0*4 + d] ^ bk[1*4 + d] ^ bk[2*4 + d] ^ bk[3*4 + d];
            return;

        case 1:
            partk[0] = bk[2*4 + a] ^ bk[3*4 + a];
            partk[1] = bk[1*4 + b] ^ bk[3*4 + b];
            partk[2] = bk[0*4 + c] ^ bk[1*4 + c] ^ bk[2*4 + c] ^ bk[3*4 + c];
            partk[3] = bk[3*4 + d];
            return;

        case 2:
            partk[0] = bk[1*4 + a] ^ bk[3*4 + a];
            partk[1] = bk[0*4 + b] ^ bk[1*4 + b] ^ bk[2*4 + b] ^ bk[3*4 + b];
            partk[2] = bk[3*4 + c];
            partk[3] = bk[2*4 + d] ^ bk[3*4 + d];
            return;

        case 3:
            partk[0] = bk[0*4 + a] ^ bk[1*4 + a] ^ bk[2*4 + a] ^ bk[3*4 + a];
            partk[1] = bk[3*4 + b];
            partk[2] = bk[2*4 + c] ^ bk[3*4 + c];
            partk[3] = bk[1*4 + d] ^ bk[3*4 + d];
            return;

        default:
            printf("ERROR: In partial, phase was %d.\n", phase);
            exit(1);
    }
}

void partial_round(byte* partx, int phase, byte cst) {
    unsigned char a = partx[0], b = partx[1], c = partx[2], d = partx[3];
    switch (phase)
    {
        case 0:
            partx[3] ^= S[a];
            return;

        case 1:
            partx[2] ^= S[d];
            return;

        case 2:
            partx[1] ^= S[c];
            return;

        case 3:
            partx[0] ^= S[b] ^ cst;
            return;

        default:
            printf("ERROR: In partial_round, phase was %d.\n", phase);
            exit(1);
    }
}

void partial_expand(byte** partrk, byte* partk, int phase) {
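    /* copy instead of aliasing: partrk[0] must stay the malloc'd buffer, otherwise
       main() ends up calling free() on the stack array partk (likely the cause of
       the "Segmentation fault" at the end of the log above) */
    memcpy(partrk[0], partk, 4*sizeof(byte));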
    for (int i = 1; i < 11; ++i) {
        memcpy(partrk[i], partrk[i-1], 4*sizeof(byte));
        partial_round(partrk[i], phase, RCON[i]);
        phase = (phase + 1) & 0x3;
    }
}

void partial_encrypt(byte* c, byte** partrk, int phase) {
    for (int i = 0; i < 10; ++i) {
        xor4(c, partrk[i]);
        for (int _ = 0; _ < 5; ++_) {
            partial_round(c, phase, 0);
            phase = (phase + 1) & 3;
        }
    }
    xor4(c, partrk[10]);
}

int main() {
    byte partp[4];
    byte partc[4];
    byte realc[4];
    byte partk[4];
    byte *partrk[11];
    for (int i = 0; i < 11; ++i)
        partrk[i] = malloc(4*sizeof(byte));

    setbuf(stdout, NULL);

    for (int phase = 0; phase < 4; ++phase) {
        int end_phase = (phase + 50) & 3;

        partial(partp, p, phase);
        partial(realc, c, end_phase);

        printf("Bruting phased key (phase %d)\n", phase);
        int found = 0;

        for (int a = 0; a <= 0xFF; ++a) {
            printf("a = %d | ", a);
        for (int b = 0; b <= 0xFF; ++b) {
            printf("%.2x\b\b", b);
        for (int c = 0; c <= 0xFF; ++c) {
        for (int d = 0; d <= 0xFF; ++d) {
            partk[0] = a;
            partk[1] = b;
            partk[2] = c;
            partk[3] = d;

            partial_expand(partrk, partk, phase);
            memcpy(partc, partp, 4*sizeof(byte));
            partial_encrypt(partc, partrk, phase);

            if (memcmp(partc, realc, 4*sizeof(byte)) == 0) {
                printf("DING DING DING!!!\n");
                printf("partk = [%d, %d, %d, %d]\n", a, b, c, d);
                found = 1;
                goto out;
            }
        }}}}

        out:
        if (found == 0)
            printf("Didn't find phased key.\n");
    }

    for (int i = 0; i < 11; ++i)
        free(partrk[i]);
    printf("DONE!\n");
}

solve_final.py

Backup from: https://github.com/AZ-0/Writeups/blob/main/2024/fcsc/crypto-tight-schedule/solve_final.py.

# __| 0  1  2  3  t
# … |
# … |       …
# … |    …  !
# x | …  !  .
# y | !  *  .     …
# z | *  *  .     !
# u | *  *  .     *   # a^b^c^d; a^d; a^b; c
# v | ^  ^  ^     ^
#
#  !: deduced from inverse table 2), and then after.
#  suspicious derivation (an information-theory problem?)
#  implement it to test

from tight_schedule import os, TightSchedule as TS
S = TS.S
RCON = TS.RCON

#            a           b           c           d
#  k p | 0 1 2 3 t | 0 1 2 3 t | 0 1 2 3 t | 0 1 2 3 t |
# ──────────────────────────────────────────────────────
#  0 0 |       .   |     ^ ^   |   ^   ^   | ^ ^ ^ ^   | ← k0
#  1 1 |     ^ ^ b |   ^   ^   | ^ ^ ^ ^   |       . @ | ← k1
#  2 2 |   ^   ^   | ^ ^ ^ ^   |       . @ |     ^ ^ b | ← k2
#  3 3 | ^ ^ ^ ^   |       . @ |     ^ ^ b |           |
#  4 0 |       * @ |     ^ ^ b |   ^   ^   |           |
#  5 1 |     ^ ^ b |   ^   ^   | ^ ^ ^ ^   |       * @ | ← round 5 of p+k0
#  6 2 |           |           |       * @ |         b |
#  7 3 |           |       * @ |         b |           |
#  8 0 |       * @ |     ^ ^ b |           |           |
#  9 1 |     ^ ^ b |   ^   ^   |           |       * @ |
# 10 2 |   ^   ^   | ^ ^ ^ ^   |       * @ |     ^ ^ b | ← round 5 of p+k1

def xor(x, y):
    return [x^y for x,y in zip(x, y)]

def blocks(x, n=4):
    return [x[i:i+n] for i in range(0, len(x), n)]

def partial(k, phase):
    bk = blocks(k)
    a, b, c, d = 0, 1, 2, 3 # for readability

    if phase == 0:
        return [
            bk[3][a],
            bk[2][b] ^ bk[3][b],
            bk[1][c] ^ bk[3][c],
            bk[0][d] ^ bk[1][d] ^ bk[2][d] ^ bk[3][d],
        ]

    if phase == 1:
        return [
            bk[2][a] ^ bk[3][a],
            bk[1][b] ^ bk[3][b],
            bk[0][c] ^ bk[1][c] ^ bk[2][c] ^ bk[3][c],
            bk[3][d],
        ]

    if phase == 2:
        return [
            bk[1][a] ^ bk[3][a],
            bk[0][b] ^ bk[1][b] ^ bk[2][b] ^ bk[3][b],
            bk[3][c],
            bk[2][d] ^ bk[3][d],
        ]

    if phase == 3:
        return [
            bk[0][a] ^ bk[1][a] ^ bk[2][a] ^ bk[3][a],
            bk[3][b],
            bk[2][c] ^ bk[3][c],
            bk[1][d] ^ bk[3][d],
        ]

    raise ValueError('phase should be one of 0,1,2,3')

def partial_round(partx, phase, cst=0):
    """One round restricted to a phase slice: only one byte changes, via one S-box lookup."""
    a, b, c, d = partx
    if phase == 0:
        return [a, b, c, d ^ S[a]]
    if phase == 1:
        return [a, b, c ^ S[d], d]
    if phase == 2:
        return [a, b ^ S[c], c, d]
    if phase == 3:
        return [a ^ S[b] ^ cst, b, c, d]
    raise ValueError('phase should be one of 0,1,2,3')

def partial_expand(partk, phase):
    partrk = [partk]
    for _ in range(10):
        partrk.append(partial_round(partrk[-1], phase, RCON[len(partrk)]))
        phase = (phase + 1) % 4
    return partrk

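def partial_encrypt(partp, partrk, phase):
    """Encrypt a plaintext slice: for every round key but the last, XOR it in and
    apply 5 slice rounds; return the final slice and the phase it ends in."""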
    c = partp
    for sk in partrk[:-1]:
        c = xor(c, sk)
        for _ in range(5):
            c = partial_round(c, phase)
            phase = (phase + 1) % 4
    return xor(c, partrk[-1]), phase

# Sanity check: the slice model must agree with the real cipher on a random key/plaintext
k = os.urandom(16)
p = os.urandom(16)
P = TS(k)
c = P.encrypt(p)

partk  = partial(k, 0)
partrk = partial_expand(partk, 0)
partp  = partial(p, 0)
partc, phase = partial_encrypt(partp, partrk, 0)

for i, (sk, partsk) in enumerate(zip(P.rk, partrk)):
    assert partsk == partial(sk, i%4), f'failed #{i}'

assert partc == partial(c, phase)
print('Assertions passed!')

# Challenge data: known plaintext/ciphertext pair, CBC IV, encrypted flag
p = bytes.fromhex("0dfa4c6052fb87ef0a8f03f705dd5101")
c = bytes.fromhex("d4ed19e0694101b6b151e11c2db973bf")
iv = bytes.fromhex("cd31cb6e6ded184efbb9a398e31ffdbb")
flag_enc = bytes.fromhex("653ec0cdd7e3a98c33414be8ef07c583d87b876afbff1d960f8f43b5a338e9ff96d87da4406ebe39a439dab3a84697d40c24557cd1ea6f433053451d20ce1fbf191270f4b8cc7891f8779eb615d35c9f")

# Output of the phased-key brute force (progress counter abridged):
# Bruting phased key (phase 0)
# a = 0 | a = 1 | ... | a = 76 | DING DING DING!!!
# partk = [76, 127, 191, 108]
# Bruting phased key (phase 1)
# a = 0 | a = 1 | ... | a = 71 | DING DING DING!!!
# partk = [71, 142, 75, 90]
# Bruting phased key (phase 2)
# a = 0 | a = 1 | ... | a = 96 | DING DING DING!!!
# partk = [96, 185, 153, 233]
# Bruting phased key (phase 3)
# a = 0 | a = 1 | ... | a = 7 | DING DING DING!!!
# partk = [7, 64, 187, 9]
# Segmentation fault

partks = [
    [76, 127, 191, 108],
    [71, 142, 75, 90],
    [96, 185, 153, 233],
    [7, 64, 187, 9]
]

for phase in range(4):
    # 10 layers of 5 rounds = 50 slice rounds, each advancing the phase by one
    end_phase = (phase + 50) % 4
    partp = partial(p, phase)
    realc = partial(c, end_phase)

    partk = partks[phase]
    partrk = partial_expand(partk, phase)
    partc, _ = partial_encrypt(partp, partrk, phase)
    if partc == realc:
        print('DING DING DING!!!')
    else:
        print('NOOOOOOOOOO')

#            a           b           c           d
#  k p | 0 1 2 3 t | 0 1 2 3 t | 0 1 2 3 t | 0 1 2 3 t |
# ──────────────────────────────────────────────────────
#  0 0 |       .   |     ^ ^   |   ^   ^   | ^ ^ ^ ^   | ← k0
#  1 1 |     ^ ^ b |   ^   ^   | ^ ^ ^ ^   |       . @ | ← k1
#  2 2 |   ^   ^   | ^ ^ ^ ^   |       . @ |     ^ ^ b | ← k2
#  3 3 | ^ ^ ^ ^   |       . @ |     ^ ^ b |   ^   ^   |
#  4 0 |       * @ |     ^ ^ b |   ^   ^   | ^ ^ ^ ^   |
#  5 1 |     ^ ^ b |   ^   ^   | ^ ^ ^ ^   |       * @ | ← round 5 of p+k0
#  6 2 |           |           |       * @ |         b |
#  7 3 |           |       * @ |         b |           |
#  8 0 |       * @ |     ^ ^ b |           |           |
#  9 1 |     ^ ^ b |   ^   ^   |           |       * @ |
# 10 2 |   ^   ^   | ^ ^ ^ ^   |       * @ |     ^ ^ b | ← round 5 of p+k1

pk0, pk1, pk2, pk3 = partks

a, b, c, d = 0, 1, 2, 3
k = [None]*16

k[3*4+a] = pk0[a]
k[3*4+b] = pk3[b]
k[3*4+c] = pk2[c]
k[3*4+d] = pk1[d]

k[2*4+a] = pk0[a] ^ pk1[a]
k[2*4+b] = pk3[b] ^ pk0[b]
k[2*4+c] = pk2[c] ^ pk3[c]
k[2*4+d] = pk1[d] ^ pk2[d]

k[1*4+a] = pk0[a] ^ pk2[a]
k[1*4+b] = pk3[b] ^ pk1[b]
k[1*4+c] = pk2[c] ^ pk0[c]
k[1*4+d] = pk1[d] ^ pk3[d]

k[0*4+a] = pk0[a] ^ pk1[a] ^ pk2[a] ^ pk3[a]
k[0*4+b] = pk0[b] ^ pk1[b] ^ pk2[b] ^ pk3[b]
k[0*4+c] = pk0[c] ^ pk1[c] ^ pk2[c] ^ pk3[c]
k[0*4+d] = pk0[d] ^ pk1[d] ^ pk2[d] ^ pk3[d]

k = bytes(k)

# `c` was overwritten by the byte-index constants above (a, b, c, d = 0, 1, 2, 3),
# so restore the known ciphertext before checking the recovered key.
c = bytes.fromhex("d4ed19e0694101b6b151e11c2db973bf")

assert TS(k).encrypt(p) == c

# Decrypt the flag with the recovered key (AES-CBC with the given IV)
from Crypto.Cipher import AES
E = AES.new(k, AES.MODE_CBC, iv = iv)
flag = E.decrypt(flag_enc)
print(flag)
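
The reconstruction step in the script inverts the four phase projections of the key. Each slice is only 4 bytes, so it can be searched independently over at most 2^32 candidates instead of 2^128 for the whole key. Once all four slices are known, the 16 key bytes follow from XORs alone. Below is a hedged, self-contained sketch of that inversion with a round-trip check on random keys. `project` and `rebuild_key` are hypothetical names, not functions from the challenge. The sketch solves each byte position generically instead of spelling out the sixteen formulas.

```python
import os

# For each phase and byte position: which 4-byte blocks are XORed together
# (same layout as `partial` in the solve script).
MASKS = [
    [(3,), (2, 3), (1, 3), (0, 1, 2, 3)],
    [(2, 3), (1, 3), (0, 1, 2, 3), (3,)],
    [(1, 3), (0, 1, 2, 3), (3,), (2, 3)],
    [(0, 1, 2, 3), (3,), (2, 3), (1, 3)],
]


def project(k, phase):
    """4-byte phase slice of a 16-byte key."""
    out = []
    for pos, blocks in enumerate(MASKS[phase]):
        v = 0
        for j in blocks:
            v ^= k[4 * j + pos]
        out.append(v)
    return out


def rebuild_key(slices):
    """Invert the four phase slices (phases 0..3) back into the 16-byte key."""
    key = [0] * 16
    for pos in range(4):
        # Across the four phases, byte `pos` sees each block subset exactly once.
        seen = {MASKS[phase][pos]: slices[phase][pos] for phase in range(4)}
        b3 = seen[(3,)]
        b2 = seen[(2, 3)] ^ b3
        b1 = seen[(1, 3)] ^ b3
        b0 = seen[(0, 1, 2, 3)] ^ b1 ^ b2 ^ b3
        for j, v in enumerate((b0, b1, b2, b3)):
            key[4 * j + pos] = v
    return bytes(key)


if __name__ == "__main__":
    for _ in range(1000):
        k = os.urandom(16)
        assert rebuild_key([project(k, p) for p in range(4)]) == k
    # Slices recovered by the brute force recorded in the script
    partks = [[76, 127, 191, 108], [71, 142, 75, 90], [96, 185, 153, 233], [7, 64, 187, 9]]
    print(rebuild_key(partks).hex())
```

Grouping the equations by byte position makes the triangular structure visible: the last block is read off directly, the middle two blocks need one XOR each, and the first block needs the other three.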