刚好在ByteCTF遇到了这道magic_lfsr,也是因此才开始对LFSR进行学习

以下是赛后参考结合了W&M的WriteUp和队里师傅x1ao的wp进行了复现

题面:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from Crypto.Cipher import AES
from Crypto.Util.number import *
from Crypto.Util.Padding import pad
from hashlib import sha512
from secret import flag

mask1 = 211151158277430590850506190902325379931
mask2 = 314024231732616562506949148198103849397
mask3 = 175840838278158851471916948124781906887
mask4 = 270726596087586267913580004170375666103


def lfsr(R, mask):
R_bin = [int(b) for b in bin(R)[2:].zfill(128)]
mask_bin = [int(b) for b in bin(mask)[2:].zfill(128)]
s = sum([R_bin[i] * mask_bin[i] for i in range(128)]) & 1
R_bin = [s] + R_bin[:-1]
return (int("".join(map(str, R_bin)), 2), s)


def ff(x0, x1, x2, x3):
return (int(sha512(long_to_bytes(x0 * x2 + x0 + x1**4 + x3**5 + x0 * x1 * x2 * x3 + (x1 * x3) ** 4)).hexdigest(), 16) & 1)


def round(R, R1_mask, R2_mask, R3_mask, R4_mask):
out = 0
R1_NEW, _ = lfsr(R, R1_mask)
R2_NEW, _ = lfsr(R, R2_mask)
R3_NEW, _ = lfsr(R, R3_mask)
R4_NEW, _ = lfsr(R, R4_mask)
for _ in range(270):
R1_NEW, x1 = lfsr(R1_NEW, R1_mask)
R2_NEW, x2 = lfsr(R2_NEW, R2_mask)
R3_NEW, x3 = lfsr(R3_NEW, R3_mask)
R4_NEW, x4 = lfsr(R4_NEW, R4_mask)
out = (out << 1) + ff(x1, x2, x3, x4)
return out


key = getRandomNBitInteger(128)
out = round(key, mask1, mask2, mask3, mask4)
cipher = AES.new(long_to_bytes(key), mode=AES.MODE_ECB)
print(f"out = {out}")
print(f"enc = {cipher.encrypt(pad(flag, 16))}")
# out = 1024311481407054984168503188572082106228007820628496173132204098971130766468510095
# enc = b'\r\x9du\xa15q\xd2\x81\x0b\xceq2\x8d)*\xe9\xf0;a\xad\xbe?\xa2\xb2\x1f\x88O\x8c\xa2[\xcb\x9a\xa9\x8d\xac\x0b\x85\xb4j@]\x0e}EF{\xb1\x91'

先进行分析:

提供了mask1~mask4四个掩码

1
2
3
4
5
6
def lfsr(R, mask):
R_bin = [int(b) for b in bin(R)[2:].zfill(128)]
mask_bin = [int(b) for b in bin(mask)[2:].zfill(128)]
s = sum([R_bin[i] * mask_bin[i] for i in range(128)]) & 1
R_bin = [s] + R_bin[:-1]
return (int("".join(map(str, R_bin)), 2), s)
  • 先把R转为二进制数并且填充到128位,mask同理
  • 然后将Rmask按位与,然后对所有位数求和取二进制形式的最低位得到s
  • R左移一位,把s加到R的最右边
  • 返回了Rs
1
2
def ff(x0, x1, x2, x3):
return (int(sha512(long_to_bytes(x0 * x2 + x0 + x1**4 + x3**5 + x0 * x1 * x2 * x3 + (x1 * x3) ** 4)).hexdigest(), 16) & 1)

此处是对数据的一种加密的处理,返回一个一位二进制的布尔值

1
2
3
4
5
6
7
8
9
10
11
12
13
def round(R, R1_mask, R2_mask, R3_mask, R4_mask):
out = 0
R1_NEW, _ = lfsr(R, R1_mask)
R2_NEW, _ = lfsr(R, R2_mask)
R3_NEW, _ = lfsr(R, R3_mask)
R4_NEW, _ = lfsr(R, R4_mask)
for _ in range(270):
R1_NEW, x1 = lfsr(R1_NEW, R1_mask)
R2_NEW, x2 = lfsr(R2_NEW, R2_mask)
R3_NEW, x3 = lfsr(R3_NEW, R3_mask)
R4_NEW, x4 = lfsr(R4_NEW, R4_mask)
out = (out << 1) + ff(x1, x2, x3, x4)
return out
  • 先是根据初始状态R和相应的掩码来初始化四个LFSR的状态。
  • 然后是循环更新状态,在这一步中,进行了270次迭代,每次迭代:
    • 更新每个LFSR的状态;
    • 获取每个LFSR的最后一位输出x1, x2, x3, x4
    • 将这些输出位传递给ff函数,得到一个布尔值(即0或1),该布尔值被追加到out的最低有效位(LSB)。
1
cipher = AES.new(long_to_bytes(key), mode=AES.MODE_ECB)

然后把key当作AES加密的密钥,在**cipher.encrypt(pad(flag, 16))**中对flag进行AES加密

思路:

已知mask1~mask4,以及AES加密后的结果,所以我们的目的也就是从lfsr中恢复出key

但是难点是ff函数难以逆向,所以考虑爆破出ff的真值表重新定义ff函数

题解:

出于学习的目的,我这里将两种思路都思考了一下

x1ao(未出flag,但值得学习):

先对ff函数进行爆破得到真值表

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from Crypto.Util.number import *
from Crypto.Util.Padding import pad
from hashlib import sha512

def ff(x0, x1, x2, x3):
return (int(sha512(long_to_bytes(x0 * x2 + x0 + x1**4 + x3**5 + x0 * x1 * x2 * x3 + (x1 * x3) ** 4)).hexdigest(), 16) & 1)

n0 = 0
for x0 in range(2):
for x1 in range(2):
for x2 in range(2):
for x3 in range(2):
r = ff(x0,x1,x2,x3)
print((x0,x1,x2,x3),r)

得到的ff真值表如图: image-20240924172123002

然后我们重构ff函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
#from z3 import *
def ff(x0, x1, x2, x3):
if x0==BitVecVal(0, 128) and x1==BitVecVal(0, 128) and x2==BitVecVal(0, 128) and x3==BitVecVal(0, 128):
return BitVecVal(0, 128)
elif x0==BitVecVal(0, 128) and x1==BitVecVal(0, 128) and x2==BitVecVal(0, 128) and x3==BitVecVal(1, 128):
return BitVecVal(1, 128)
elif x0==BitVecVal(0, 128) and x1==BitVecVal(0, 128) and x2==BitVecVal(1, 128) and x3==BitVecVal(0, 128):
return BitVecVal(0, 128)
elif x0==BitVecVal(0, 128) and x1==BitVecVal(0, 128) and x2==BitVecVal(1, 128) and x3==BitVecVal(1, 128):
return BitVecVal(1, 128)
elif x0==BitVecVal(0, 128) and x1==BitVecVal(1, 128) and x2==BitVecVal(0, 128) and x3==BitVecVal(0, 128):
return BitVecVal(1, 128)
elif x0==BitVecVal(0, 128) and x1==BitVecVal(1, 128) and x2==BitVecVal(0, 128) and x3==BitVecVal(1, 128):
return BitVecVal(0, 128)
elif x0==BitVecVal(0, 128) and x1==BitVecVal(1, 128) and x2==BitVecVal(1, 128) and x3==BitVecVal(0, 128):
return BitVecVal(1, 128)
elif x0==BitVecVal(0, 128) and x1==BitVecVal(1, 128) and x2==BitVecVal(1, 128) and x3==BitVecVal(1, 128):
return BitVecVal(0, 128)
elif x0==BitVecVal(1, 128) and x1==BitVecVal(0, 128) and x2==BitVecVal(0, 128) and x3==BitVecVal(0, 128):
return BitVecVal(1, 128)
elif x0==BitVecVal(1, 128) and x1==BitVecVal(0, 128) and x2==BitVecVal(0, 128) and x3==BitVecVal(1, 128):
return BitVecVal(0, 128)
elif x0==BitVecVal(1, 128) and x1==BitVecVal(0, 128) and x2==BitVecVal(1, 128) and x3==BitVecVal(0, 128):
return BitVecVal(0, 128)
elif x0==BitVecVal(1, 128) and x1==BitVecVal(0, 128) and x2==BitVecVal(1, 128) and x3==BitVecVal(1, 128):
return BitVecVal(0, 128)
elif x0==BitVecVal(1, 128) and x1==BitVecVal(1, 128) and x2==BitVecVal(0, 128) and x3==BitVecVal(0, 128):
return BitVecVal(0, 128)
elif x0==BitVecVal(1, 128) and x1==BitVecVal(1, 128) and x2==BitVecVal(0, 128) and x3==BitVecVal(1, 128):
return BitVecVal(1, 128)
elif x0==BitVecVal(1, 128) and x1==BitVecVal(1, 128) and x2==BitVecVal(1, 128) and x3==BitVecVal(0, 128):
return BitVecVal(0, 128)
else:
return BitVecVal(1, 128)

然后利用z3的SMT求解器添加约束条件后进行求解即可得到key

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
solver = Solver()	#创建一个SMT求解器实例solver

key = BitVec('key',128) #创建一个128位比特向量key

#将mask1~mask4都转换为128位的比特向量
mask1_bv = BitVecVal(mask1, 128)
mask2_bv = BitVecVal(mask2, 128)
mask3_bv = BitVecVal(mask3, 128)
mask4_bv = BitVecVal(mask4, 128)

out = 1024311481407054984168503188572082106228007820628496173132204098971130766468510095
out_128 = int(bin(out)[-128:], 2)

round_output = round(key,mask1_bv,mask2_bv,mask3_bv,mask4_bv)

solver.add(round_output == BitVecVal(out_128,128))
print(solver.check())

完整代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from z3 import *
from hashlib import sha512
from Crypto.Util.number import *

mask1 = 211151158277430590850506190902325379931
mask2 = 314024231732616562506949148198103849397
mask3 = 175840838278158851471916948124781906887
mask4 = 270726596087586267913580004170375666103

def lfsr(R, mask):
R_bin = R
mask_bin = mask
s = LShR(R_bin & mask_bin, 127)
R_new = (R_bin << 1) & ((1 << 128) - 1)
R_new = R_new | s
return R_new, s


def ff(x0, x1, x2, x3):
if x0==BitVecVal(0, 128) and x1==BitVecVal(0, 128) and x2==BitVecVal(0, 128) and x3==BitVecVal(0, 128):
return BitVecVal(0, 128)
elif x0==BitVecVal(0, 128) and x1==BitVecVal(0, 128) and x2==BitVecVal(0, 128) and x3==BitVecVal(1, 128):
return BitVecVal(1, 128)
elif x0==BitVecVal(0, 128) and x1==BitVecVal(0, 128) and x2==BitVecVal(1, 128) and x3==BitVecVal(0, 128):
return BitVecVal(0, 128)
elif x0==BitVecVal(0, 128) and x1==BitVecVal(0, 128) and x2==BitVecVal(1, 128) and x3==BitVecVal(1, 128):
return BitVecVal(1, 128)
elif x0==BitVecVal(0, 128) and x1==BitVecVal(1, 128) and x2==BitVecVal(0, 128) and x3==BitVecVal(0, 128):
return BitVecVal(1, 128)
elif x0==BitVecVal(0, 128) and x1==BitVecVal(1, 128) and x2==BitVecVal(0, 128) and x3==BitVecVal(1, 128):
return BitVecVal(0, 128)
elif x0==BitVecVal(0, 128) and x1==BitVecVal(1, 128) and x2==BitVecVal(1, 128) and x3==BitVecVal(0, 128):
return BitVecVal(1, 128)
elif x0==BitVecVal(0, 128) and x1==BitVecVal(1, 128) and x2==BitVecVal(1, 128) and x3==BitVecVal(1, 128):
return BitVecVal(0, 128)
elif x0==BitVecVal(1, 128) and x1==BitVecVal(0, 128) and x2==BitVecVal(0, 128) and x3==BitVecVal(0, 128):
return BitVecVal(1, 128)
elif x0==BitVecVal(1, 128) and x1==BitVecVal(0, 128) and x2==BitVecVal(0, 128) and x3==BitVecVal(1, 128):
return BitVecVal(0, 128)
elif x0==BitVecVal(1, 128) and x1==BitVecVal(0, 128) and x2==BitVecVal(1, 128) and x3==BitVecVal(0, 128):
return BitVecVal(0, 128)
elif x0==BitVecVal(1, 128) and x1==BitVecVal(0, 128) and x2==BitVecVal(1, 128) and x3==BitVecVal(1, 128):
return BitVecVal(0, 128)
elif x0==BitVecVal(1, 128) and x1==BitVecVal(1, 128) and x2==BitVecVal(0, 128) and x3==BitVecVal(0, 128):
return BitVecVal(0, 128)
elif x0==BitVecVal(1, 128) and x1==BitVecVal(1, 128) and x2==BitVecVal(0, 128) and x3==BitVecVal(1, 128):
return BitVecVal(1, 128)
elif x0==BitVecVal(1, 128) and x1==BitVecVal(1, 128) and x2==BitVecVal(1, 128) and x3==BitVecVal(0, 128):
return BitVecVal(0, 128)
else:
return BitVecVal(1, 128)



def round(R, R1_mask, R2_mask, R3_mask, R4_mask):
out = 0
R1_NEW, _ = lfsr(R, R1_mask)
R2_NEW, _ = lfsr(R, R2_mask)
R3_NEW, _ = lfsr(R, R3_mask)
R4_NEW, _ = lfsr(R, R4_mask)

for _ in range(270):
R1_NEW, x1 = lfsr(R1_NEW, R1_mask)
R2_NEW, x2 = lfsr(R2_NEW, R2_mask)
R3_NEW, x3 = lfsr(R3_NEW, R3_mask)
R4_NEW, x4 = lfsr(R4_NEW, R4_mask)
out = (out << 1) | ff(x1, x2, x3, x4)
return out


solver = Solver()

key = BitVec('key', 128)

mask1_bv = BitVecVal(mask1, 128)
mask2_bv = BitVecVal(mask2, 128)
mask3_bv = BitVecVal(mask3, 128)
mask4_bv = BitVecVal(mask4, 128)

out=1024311481407054984168503188572082106228007820628496173132204098971130766468510095
out_128=int(bin(out)[-128:],2)



round_output = round(key, mask1_bv, mask2_bv, mask3_bv, mask4_bv)

solver.add(round_output == out_128)
print(solver.check())
# unsat

W&M ‘s WriteUp:

根据真值表猜测了一个可能的逻辑关系:(x0+x1+x3)%2

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from Crypto.Util.number import *
from Crypto.Util.Padding import pad
from hashlib import sha512

def ff(x0, x1, x2, x3):
return (int(sha512(long_to_bytes(x0 * x2 + x0 + x1**4 + x3**5 + x0 * x1 * x2 * x3 + (x1 * x3) ** 4)).hexdigest(), 16) & 1)

n0 = 0
for x0 in range(2):
for x1 in range(2):
for x2 in range(2):
for x3 in range(2):
r = ff(x0,x1,x2,x3)
print((x0,x1,x2,x3),r,(x0+x1+x3)%2)
n0 += int(r==(x0+x1+x3)%2) #计算符合度
print(n0)

发现仅有一种情况不符合: image-20240924201524642

因此选取所有的1值(总共127个)造矩阵,至于剩下的那个把所有的0值选取一遍爆破一下即可

这里先贴个wp,然后再进行分析:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from Crypto.Cipher import AES
from Crypto.Util.number import *
from Crypto.Util.Padding import pad
from hashlib import sha512
from copy import deepcopy

out = 1024311481407054984168503188572082106228007820628496173132204098971130766468510095
enc = b'\r\x9du\xa15q\xd2\x81\x0b\xceq2\x8d)*\xe9\xf0;a\xad\xbe?\xa2\xb2\x1f\x88O\x8c\xa2[\xcb\x9a\xa9\x8d\xac\x0b\x85\xb4j@]\x0e}EF{\xb1\x91'


mask1 = 211151158277430590850506190902325379931
mask2 = 314024231732616562506949148198103849397
mask3 = 175840838278158851471916948124781906887
mask4 = 270726596087586267913580004170375666103

m = matrix(GF(2),128,129)
M1,M2,M3 = matrix(GF(2),128,128),matrix(GF(2),128,128),matrix(GF(2),128,128)

for i in range(128):
M1[i,0] = int(bin(mask1)[2:][i])
M2[i,0] = int(bin(mask2)[2:][i])
M3[i,0] = int(bin(mask4)[2:][i])
if i != 127:
M1[i,i+1] = 1
M2[i,i+1] = 1
M3[i,i+1] = 1

out = bin(out)[2:].zfill(270)
print(out)
si = []
for i in range(270):
if out[i] == '1':
si.append(i)
print(si,len(si))

n = 0
MM1,MM2,MM3 = identity_matrix(GF(2),128),identity_matrix(GF(2),128),identity_matrix(GF(2),128)
MM1 *= M1
MM2 *= M2
MM3 *= M3
for i in range(270):
MM1 *= M1
MM2 *= M2
MM3 *= M3
if i in si:
m[:,n] = (MM1+MM2+MM3)[:,0]
n += 1

assert n == 127
output = matrix([1]*127+[0]*2)
MM1,MM2,MM3 = identity_matrix(GF(2),128),identity_matrix(GF(2),128),identity_matrix(GF(2),128)
MM1 *= M1
MM2 *= M2
MM3 *= M3
for i in range(270):
MM1 *= M1
MM2 *= M2
MM3 *= M3
if i not in si:
m[:,-2] = (MM1+MM2+MM3)[:,0]
MMM1,MMM2,MMM3 = deepcopy(MM1),deepcopy(MM2),deepcopy(MM3)
for j in range(i+1,270):
MMM1 *= M1
MMM2 *= M2
MMM3 *= M3
if j not in si:
m[:,-1] = (MMM1+MMM2+MMM3)[:,0]
try:
key = (m.solve_left(output))[0]
key = ''.join(str(_) for _ in key)
key = int(key,2)
cipher = AES.new(long_to_bytes(key), mode=AES.MODE_ECB)
print(cipher.decrypt(enc))
except:
continue

首先是:

1
2
3
4
5
6
7
8
9
10
11
m = matrix(GF(2),128,129)
M1,M2,M3 = matrix(GF(2),128,128),matrix(GF(2),128,128),matrix(GF(2),128,128)

for i in range(128):
M1[i,0] = int(bin(mask1)[2:][i])
M2[i,0] = int(bin(mask2)[2:][i])
M3[i,0] = int(bin(mask4)[2:][i])
if i != 127:
M1[i,i+1] = 1
M2[i,i+1] = 1
M3[i,i+1] = 1

创建了几个矩阵,用掩码填充了矩阵M1M2M3,同时设置了下三角矩阵的部分

1
2
3
4
5
6
out = bin(out)[2:].zfill(270)
print(out)
si = []
for i in range(270):
if out[i] == '1':
si.append(i)

把out转换为二进制字符串,并记录其中1的位置索引

1
2
3
4
5
n = 0
MM1,MM2,MM3 = identity_matrix(GF(2),128),identity_matrix(GF(2),128),identity_matrix(GF(2),128)
MM1 *= M1
MM2 *= M2
MM3 *= M3

n用于跟踪当前要填充的列索引

identity_matrix(GF(2),128)创建了一个有限域为2的128维单位矩阵,

将单位矩阵与掩码矩阵M1M2M3相乘,这样每个矩阵就变成了形如**[mask,1,0,…,0]**的形式

1
2
3
4
5
6
7
for i in range(270):
MM1 *= M1
MM2 *= M2
MM3 *= M3
if i in si:
m[:,n] = (MM1+MM2+MM3)[:,0]
n += 1
  • 迭代:然后进行迭代270次,模拟的是270次的LFSR的迭代
  • 更新矩阵MM1MM2MM3:每次迭代都将MM1, MM2, MM3 分别与掩码矩阵 M1, M2, M3 相乘,模拟一次 LFSR 的迭代。
  • 检查索引i是否在si:如果索引isi(即out字符串中为1的位置),则将(MM1+MM2+MM3)[:,0]填充矩阵m的第n
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
assert n == 127
output = matrix([1]*127+[0]*2)
MM1,MM2,MM3 = identity_matrix(GF(2),128),identity_matrix(GF(2),128),identity_matrix(GF(2),128)
MM1 *= M1
MM2 *= M2
MM3 *= M3
for i in range(270):
MM1 *= M1
MM2 *= M2
MM3 *= M3
if i not in si:
m[:,-2] = (MM1+MM2+MM3)[:,0]
MMM1,MMM2,MMM3 = deepcopy(MM1),deepcopy(MM2),deepcopy(MM3)
for j in range(i+1,270):
MMM1 *= M1
MMM2 *= M2
MMM3 *= M3
if j not in si:
m[:,-1] = (MMM1+MMM2+MMM3)[:,0]

由以上部分得到了n == 127

  • 定义了一个输出向量output,他是以一个129×1的矩阵,前127个1和2个0,目的是为了构造线性方程组,是的方程组的形式为
    $$
    m\ ·\ key\ =\ output
    $$

  • 重新初始化了三个单位矩阵,并将它们与掩码矩阵相乘,以模拟初始状态

  • 在每次迭代中更新矩阵MM1MM2,MM3,模拟LFSR的迭代

  • 检查索引i是否不在si中,如果不在,则填充矩阵m的倒数第二列

  • 使用深拷贝创建新的矩阵副本,并再次迭代更新矩阵,知道找到一个不在si中的索引j,然后填充矩阵m的最后一列

1
2
3
4
5
6
7
8
try:
key = (m.solve_left(output))[0]
key = ''.join(str(_) for _ in key)
key = int(key,2)
cipher = AES.new(long_to_bytes(key), mode=AES.MODE_ECB)
print(cipher.decrypt(enc))
except:
continue

尝试求解线性方程组:
$$
m\ ·\ key\ =\ output
$$
找到解key后,将其转换为字节,并使用AES解密enc