import itertools
from sage.crypto.sboxes import Midori_Sb0 as Sb0
from operator import xor, __and__, __or__
# Correlation matrix of the Midori Sb0 S-box.  The division by 8 presumably
# rescales the entries of Sage's linear approximation matrix to correlations;
# NOTE(review): confirm against the LAT scaling of the Sage version in use.
CSb0 = Sb0.linear_approximation_matrix() / 8
# Fourier basis matrix: the 4-fold Kronecker (tensor) power of the 2x2
# Hadamard matrix m, a 16x16 matrix with entries +1/-1.  The Kronecker
# product is associative, so the nesting order does not matter.
m = matrix([[1, 1], [1, -1]])
basis = m.tensor_product(
    m.tensor_product(
        m.tensor_product(m, subdivide=False),
        subdivide=False),
    subdivide=False)
def isPositiveDefinite(v):
    """ Return True iff every coefficient of basis * v is non-negative.

    `basis` is the module-level Fourier basis matrix.  The product is
    flattened with .list() so each comparison is between scalars: iterating
    a Sage matrix directly (e.g. the 16x1 column matrices built by the main
    search loop) yields its rows as vectors, not its entries.
    """
    return all(x >= 0 for x in (basis * v).list())
def vectorToInt(v):
    """ Encode the first four coordinates of v as an integer.

    Coordinate i contributes int(v[i]) * 2**i, so v[0] is the least
    significant position.  Coordinates beyond index 3 are ignored.
    """
    value = 0
    for position in range(4):
        value += int(v[position]) * 2**position
    return value
def vectorSpaces(ks):
    """ Yield every linear subspace of GF(2)^4 whose dimension is in ks.

    Each subspace is yielded as a list of ints, one per element, obtained
    by encoding every vector with vectorToInt.
    """
    for dimension in ks:
        ambient = VectorSpace(GF(2), 4)
        for subspace in ambient.subspaces(dimension):
            yield list(map(vectorToInt, subspace))
def affineSpaces(ks):
    """ Yield every affine subspace of GF(2)^4 whose dimension is in ks.

    For each linear subspace produced by vectorSpaces(ks), first yield the
    subspace itself (as produced), then each of its other cosets
    {x XOR c : x in subspace} as a sorted list of ints in range(16), until
    all 16 points have been covered.
    """
    for linear in vectorSpaces(ks):
        yield linear
        # Points of range(16) not yet covered by a yielded coset.
        uncovered = set(range(16)) - set(linear)
        while uncovered:
            offset = uncovered.pop()
            coset = sorted(xor(point, offset) for point in linear)
            uncovered -= set(coset)
            yield coset
def isMixColumnInvariant(v):
    """ Implements Theorem 9 (the case len(supp) = 16 is not covered).

    Return True iff all of the following hold for the 16 entries of v:
      * the support has 1, 2, 4 or 8 elements;
      * all nonzero entries have the same absolute value;
      * the support is closed under a0 XOR a XOR b (a0 = smallest support
        index), i.e. it is an affine subspace;
      * for an 8-element support, an even number of entries equal -1.

    v may be a Sage matrix/vector (anything with a .list() method) or a
    plain sequence of numbers.
    """
    # Work on scalar entries: indexing a 16x1 Sage matrix as v[i] returns a
    # row vector, so comparisons such as v[i] == -1 would compare a vector
    # with a scalar instead of testing the entry.
    entries = list(v.list()) if hasattr(v, "list") else list(v)
    supp = [i for i in range(16) if entries[i] != 0]
    if len(supp) not in (1, 2, 4, 8):
        return False
    if len({abs(entries[i]) for i in supp}) != 1:
        return False
    supp_set = set(supp)
    base = supp[0]
    for a, b in itertools.combinations(supp, 2):
        if xor(base, xor(a, b)) not in supp_set:
            return False
    nb_minus_one = sum(1 for i in range(16) if entries[i] == -1)
    if len(supp) == 8 and nb_minus_one % 2 != 0:
        return False
    return True
def getNbFreeBits(v):
    """ Returns the number of key bits that do not affect v (excluding sign).

    Counts the bit positions (out of 4) that are constant across the support
    of v: positions set in every support index (and_supp) plus positions set
    in none of them (the complement of or_supp).  An all-zero v yields 4.
    """
    from functools import reduce  # reduce is not a builtin on Python 3
    supp = [i for i in range(16) if v[i] != 0]
    # Seed with the identities of AND / OR over 4-bit values so that an
    # empty support does not make reduce() raise TypeError.
    and_supp = reduce(__and__, supp, 0xf)
    or_supp = reduce(__or__, supp, 0)
    return bin(__or__(and_supp, xor(0xf, or_supp))).count("1")
def formatNum(x):
    """ Right-align str(x) in a field of width 2; longer strings are
    returned unchanged. """
    return str(x).rjust(2)
# Search over candidate vectors v with entries +1/-1 supported on an affine
# subspace of GF(2)^4 of dimension 0-3, i.e. on 1, 2, 4 or 8 points.
for space in affineSpaces([0, 1, 2, 3]):
    for signs in itertools.product([1, -1], repeat=len(space) - 1):
        # The first support element always gets +1; presumably because v and
        # -v are treated as equivalent (cf. the w == -v test below).
        values = [1] + list(signs)
        # Such a v always has affine support and entries of equal magnitude,
        # so the only condition of isMixColumnInvariant it can fail is the
        # even-number-of-(-1) rule for 8-point supports.  Note that len()
        # counts points, not dimension: dimension 3 means 8 points.
        if len(space) == 8 and values.count(-1) % 2 != 0:
            continue
        v = matrix(16, 1)
        for index, sign in zip(space, values):
            v[index, 0] = sign
        w = CSb0 * v
        if w == v or w == -v:
            # v is an eigenvector of CSb0 with eigenvalue +1 or -1.
            nb_free = getNbFreeBits(v) * 32
            t = "{IP, ~}" if isPositiveDefinite(v) else "{IF, ~}"
        elif isMixColumnInvariant(w):
            nb_free = (getNbFreeBits(v) + getNbFreeBits(w)) * 16
            t = "{IP," if isPositiveDefinite(v) else "{IF,"
            t += "IP}" if isPositiveDefinite(w) else "IF}"
        else:
            continue
        # Print one pre-joined string: the output is then identical whether
        # print is the Python 2 statement or the Python 3 function.
        print(" ".join([v.T.str(rep_mapping=formatNum),
                        "[Free bits = %d]" % nb_free, t]))