剪枝相关

比赛时遇到了挺多剪枝相关的题,但是只会套板子,趁着有时间学习总结一下

定义

DFS之剪枝与优化指的是在执行深度优先搜索(DFS, Depth-First Search)时,采取的一系列策略来减少搜索空间,避免无效计算,从而加速找到问题的解。剪枝是指在搜索过程中,当遇到某些条件不符合解的要求或者可以预判后续搜索不会产生有效解时,直接放弃这条搜索路径,这一过程称为剪枝。优化则是指通过调整搜索策略、顺序等,提高搜索效率。

题目(收集ing……)

首尾剪枝

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from Crypto.Util.number import *
from secret import flag

m = bytes_to_long(flag)
p = getPrime(256)
q = getPrime(256)
n = p * q
e = 65537
_q = int(bin(q)[2:][::-1] , 2)
c = pow(m,e,n)

print(p ^ _q)
print(n)
print(c)

'''
47761879279815109356923025519387920397647575481870870315845640832106405230526
10310021142875344535823132048350287610122830618624222175188882916320750885684668357543070611134424902255744858233485983896082731376191044874283981089774677
999963120986258459742830847940927620860107164857685447047839375819380831715400110131705491405902374029088041611909274341590559275004502111124764419485191
'''

题目给出p 与 q 的反方向二进制的异或值,根据异或操作的特性,可知如果当前需搜索的最高位为”1”,则对应两种可能:p该位为1,q对应低位为0;p该位为0,q对应低位为1。对应的剪枝条件为

1.将p、q未搜索到的位全填0,乘积应小于n
2.将p、q未搜索到的位全填1,乘积应大于n
3.p、q 低 k 位乘积再取低 k 位,应与 n 的低 k 位相同

首先定义搜索函数

1
2
3
4
5
6
7
8
9
10
11
12
def find(ph,qh,pl,ql):
l = len(ph)
tmp0 = ph + (256-2*l)*"0" + pl
tmp1 = ph + (256-2*l)*"1" + pl
tmq0 = qh + (256-2*l)*"0" + ql
tmq1 = qh + (256-2*l)*"1" + ql
if(int(tmp0,2)*int(tmq0,2) > n):
return
if(int(tmp1,2)*int(tmq1,2) < n):
return
if(int(pl,2)*int(ql,2) % (2**(l-1)) != n % (2**(l-1))):
return

根据条件进行剪枝操作

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
if(l == 128):
pp0 = int(tmp0,2)
if(n % pp0 == 0):
pf = pp0
qf = n//pp0
phi = (pf-1)*(qf-1)
d = inverse(e,phi)
m1 = pow(c,d,n)
print(long_to_bytes(m1))
exit()

else:
if(pxorq[l] == "1" and pxorq[255-l] == "1"):
find(ph+"1",qh+"0","1"+pl,"0"+ql)
find(ph+"0",qh+"0","1"+pl,"1"+ql)
find(ph+"1",qh+"1","0"+pl,"0"+ql)
find(ph+"0",qh+"1","0"+pl,"1"+ql)
elif(pxorq[l] == "1" and pxorq[255-l] == "0"):
find(ph+"1",qh+"0","0"+pl,"0"+ql)
find(ph+"0",qh+"0","0"+pl,"1"+ql)
find(ph+"1",qh+"1","1"+pl,"0"+ql)
find(ph+"0",qh+"1","1"+pl,"1"+ql)
elif(pxorq[l] == "0" and pxorq[255-l] == "1"):
find(ph+"0",qh+"0","1"+pl,"0"+ql)
find(ph+"0",qh+"1","0"+pl,"0"+ql)
find(ph+"1",qh+"0","1"+pl,"1"+ql)
find(ph+"1",qh+"1","0"+pl,"1"+ql)
elif(pxorq[l] == "0" and pxorq[255-l] == "0"):
find(ph+"0",qh+"0","0"+pl,"0"+ql)
find(ph+"1",qh+"0","0"+pl,"1"+ql)
find(ph+"0",qh+"1","1"+pl,"0"+ql)
find(ph+"1",qh+"1","1"+pl,"1"+ql)

2(名称待定)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
from Crypto.Util.number import *
def myGetPrime():
while True:
x = getRandomNBitInteger(1024) & ((1 << 1024) - 1)//3
if isPrime(x):
return x
p = myGetPrime()
q = myGetPrime()
n = p * q
e = 65537
message = open('flag.txt', 'rb')
m = bytes_to_long(message.read())
c = pow(m, e, n)
open("superstitious-2.txt", "w").write(f"n = {n}\ne = {e}\nc = {c}")

首先关注一下((1 << 1024) - 1)//3这个数,发现是10101010……01,适合剪枝,根据逻辑与操作的特点:全一为一,有零为零。且p*q的低位等于n的低位.

首先知道p和q末尾必是01,再逐步从后向前进行剪枝,又因为只有奇数位有1,每次可以操作两位,用01和00搭配可能性即可

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from Cryptodome.Util.number import inverse, long_to_bytes

n = 550201148354755741271315125069984668413716061796183554308291706476140978529375848655819753667593579308959498512392008673328929157581219035186964125404507736120739215348759388064536447663960474781494820693212364523703341226714116205457869455356277737202439784607342540447463472816215050993875701429638490180199815506308698408730404219351173549572700738532419937183041379726568197333982735249868511771330859806268212026233242635600099895587053175025078998220267857284923478523586874031245098448804533507730432495577952519158565255345194711612376226297640371430160273971165373431548882970946865209008499974693758670929
e = 65537
c = 12785320910832143088122342957660384847883123024416376075086619647021969680401296902000223390419402987207599720081750892719692986089224687862496368722454869160470101334513312534671470957897816352186267364039566768347665078311312979099890672319750445450996125821736515659224070277556345919426352317110605563901547710417861311613471239486750428623317970117574821881877688142593093266784366282508041153548993479036139219677970329934829870592931817113498603787339747542136956697591131562660228145606363369396262955676629503331736406313979079546532031753085902491581634604928829965989997727970438591537519511620204387132


def findp(p, q):
if len(p) == 1024:
pp = int(p, 2)
if n % pp == 0:
print(pp)
print(n // pp)
else:
l = len(p)
pp = int(p, 2)
qq = int(q, 2)
if pp * qq % (2**l) == n % (2**l):
findp("01" + p, "01" + q)
findp("01" + p, "00" + q)
findp("00" + p, "01" + q)
findp("00" + p, "00" + q)

findp('01','01')

p = 11466867937506443031079406557463511000236825156042986330491372554263065048494616429572254582549332374593524344514321333368747919034845244563606383834070804967345648840205613712911286600828703809116499141392947298788689558078395325755136448592591616295144118450804581480471547613492025968699740517273286296657
q = n // p
d = inverse(e, (p - 1) * (q - 1))
print(long_to_bytes(pow(c, d, n)))

3,已知p^q

1
2
3
4
5
6
7
8
9
10
from Crypto.Util.number import *
p = getPrime(128)
q = getPrime(128)
n = p*q
xor = p^q
print(f"n = {n}")
print(f"xor = {xor}")

#n = 81273634095521392491945168120330007101677085824054016784875224305683560308213
#xor = 55012774068906519160740720236510369652

​ 搜索条件:

  • 从低位向高位搜索

  • 若xor当前位为1,则可能为两种情况:p为1,q为0 或者 p为0,q为1;反之xor当前位为0,则p为1,q为1 或者 p为0,q为0.

    剪枝条件:

    • 将p和q剩下位全部填充为1,需要满足 p*q > n
    • 将p和q剩下位全部填充为0,需要满足 p*q < n

其实算是第一道题的简化版,当p*q=n或者n mod p=0时结束

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
n = 81273634095521392491945168120330007101677085824054016784875224305683560308213
xor = 55012774068906519160740720236510369652
pbits = 128
ph = ''
qh = ''
xor = str(bin(xor)[2:]).zfill(pbits)

def find(ph,qh):
l0 = len(ph)
l1 = len(qh)
tmp0 = ph + '0' * (pbits-l0)
tmp1 = ph + '1' * (pbits-l0)
tmq0 = qh + '0' * (pbits-l1)
tmq1 = qh + '1' * (pbits-l1)
if int(tmp0,2) * int(tmq0,2) > n:#剪枝条件1
return
if int(tmp1,2) * int(tmq1,2) < n:#剪枝条件2
return

if l0 == pbits:#结束条件
if int(ph,2) * int(qh,2) == n:
print(f'p = {int(ph,2)}')
print(f'q = {int(qh,2)}')
return

else:
if xor[l1] == '1':
find(ph+'0',qh+'1')
find(ph + '1',qh+'0')
else:
find(ph+'1',qh+'1')
find(ph + '0',qh+'0')

find(ph,qh)


#运行结果
'''
p = 270451921611135557038833183249275131423
q = 300510470073047693263940829088190906731
p = 300510470073047693263940829088190906731
q = 270451921611135557038833183249275131423
'''

4,p ^(q >> 16)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from Crypto.Util.number import *
from secret import secret, flag
def encrypt(m):
return pow(m, e, n)
assert flag == b"dasctf{" + secret + b"}"
e = 11
p = getPrime(512)
q = getPrime(512)
n = p * q
P = getPrime(512)
Q = getPrime(512)
N = P * Q
gift = P ^ (Q >> 16)

print(N, gift, pow(n, e, N))
print(encrypt(bytes_to_long(secret)),
encrypt(bytes_to_long(flag)))

N = 75000029602085996700582008490482326525611947919932949726582734167668021800854674616074297109962078048435714672088452939300776268788888016125632084529419230038436738761550906906671010312930801751000022200360857089338231002088730471277277319253053479367509575754258003761447489654232217266317081318035524086377
gift = 8006730615575401350470175601463518481685396114003290299131469001242636369747855817476589805833427855228149768949773065563676033514362512835553274555294034
pow(n,e,N) = 14183763184495367653522884147951054630177015952745593358354098952173965560488104213517563098676028516541915855754066719475487503348914181674929072472238449853082118064823835322313680705889432313419976738694317594843046001448855575986413338142129464525633835911168202553914150009081557835620953018542067857943
pow(secret,e,n) = 69307306970629523181683439240748426263979206546157895088924929426911355406769672385984829784804673821643976780928024209092360092670457978154309402591145689825571209515868435608753923870043647892816574684663993415796465074027369407799009929334083395577490711236614662941070610575313972839165233651342137645009
pow(flag,e,n) = 46997465834324781573963709865566777091686340553483507705539161842460528999282057880362259416654012854237739527277448599755805614622531827257136959664035098209206110290879482726083191005164961200125296999449598766201435057091624225218351537278712880859703730566080874333989361396420522357001928540408351500991

这里只看一下通过剪枝分解n的操作,其实与上一道题差别不大,但是剪枝条件有所变化。

1,(pp ^ (qq >> 16)) % (2 ** l) == gift % (2 ** l)

2,pp * qq % (2 ** l) == N % (2 ** l)

第二点感觉是都有的,第一点根据题目信息改动即可

tips:因为gift是p异或q右移16位的结果,所以p的最后一位1相当于异或了q的第十七位。这也就是为什么只搜p而不是同时搜p,q,传入的也不是q的末位1而是q的末17位,在调用函数的时候才会有爆破了q后17位的操作。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
N=75000029602085996700582008490482326525611947919932949726582734167668021800854674616074297109962078048435714672088452939300776268788888016125632084529419230038436738761550906906671010312930801751000022200360857089338231002088730471277277319253053479367509575754258003761447489654232217266317081318035524086377
gift=8006730615575401350470175601463518481685396114003290299131469001242636369747855817476589805833427855228149768949773065563676033514362512835553274555294034
c1=14183763184495367653522884147951054630177015952745593358354098952173965560488104213517563098676028516541915855754066719475487503348914181674929072472238449853082118064823835322313680705889432313419976738694317594843046001448855575986413338142129464525633835911168202553914150009081557835620953018542067857943

def findp(p,q):
if len(p)==512:
p1=int(p,2)
if N % p1 ==0:
print(p1,N//p1)
else:
bit=len(p)
p1=int(p,2)
q1=int(q,2)
if (p1^(q1>>16))%(2**bit)==gift%(2**bit) and p1*q1%(2**bit)==N%(2**bit):#当目前深搜出来的位数符合实际,继续搜索。
findp('1'+p,'1'+q)
findp('0'+p,'1'+q)
findp('0'+p,'0'+q)
findp('1'+p,'0'+q)


for i in range(2**17):
findp('1',bin(i)[2:])#其中i可以看作q的低位

小结

时间关系先暂时收录这四种题型,其实归根结底都是一种问题,即通过p*q=n和题目给出的条件进行剪枝分解n,

关于剪枝,其实感觉没有特定的做法,而是一种思想,就是通过各种方法减少搜索规模,从而提高效率。目前遇到的剪枝都是在RSA中,或许其他的密码体系也会有着这种思想存在?

还感觉比较重要的一点是搜索的顺序,剪枝是一种方法,但是有些时候我们可以通过该变搜索的顺序来进一步提高效率,包括上述提到的首尾剪枝,低位向高位剪枝……继续深挖下去发现涉及到更深一步的算法待后续研究。任重而道远捏{}