import itertools
from sage.crypto.sboxes import Midori_Sb0 as Sb0
from operator import xor, __and__, __or__

# Correlation matrix of the Midori S-box Sb0: Sage's linear approximation
# matrix scaled by 1/8. NOTE(review): this assumes Sage's default
# absolute-bias scaling, so that dividing by 2^(4-1) yields correlations --
# confirm against the installed Sage version.
CSb0 = Sb0.linear_approximation_matrix() / 8

# Fourier basis matrix for 4-bit values: the four-fold Kronecker (tensor)
# power of the 2x2 matrix m, built up one factor at a time.
m = matrix([[1, 1], [1, -1]])
basis = m.tensor_product(m, subdivide=False)
basis = basis.tensor_product(m, subdivide=False)
basis = basis.tensor_product(m, subdivide=False)

def isPositiveDefinite(v):
    """ Returns True when every item of basis * v (the Fourier transform of v)
    compares >= 0, and False as soon as one does not. """
    for coefficient in basis * v:
        if not coefficient >= 0:
            return False
    return True

def vectorToInt(v):
    """ Returns sum(int(v[i]) * 2**i for i in 0..3): reads v[0..3] as the
    bits of an integer, least significant bit first. """
    value = 0
    # Horner evaluation, starting from the most significant position.
    for position in (3, 2, 1, 0):
        value = value * 2 + int(v[position])
    return value

def vectorSpaces(ks):
    """ Generates all vector spaces of dimension k for all k in ks.

    Every subspace of GF(2)^4 with a dimension listed in ks is yielded as a
    list of ints, one per vector, each converted with vectorToInt.
    """
    for dimension in ks:
        ambient = VectorSpace(GF(2), 4)
        for subspace in ambient.subspaces(dimension):
            yield list(map(vectorToInt, subspace))

def affineSpaces(ks):
    """ Generates all affine spaces of dimension k for all k in ks.

    For each linear subspace produced by vectorSpaces(ks), yields the subspace
    itself (unchanged) and then each of its other cosets {x xor c} within
    range(16), every coset as a sorted list of ints.
    """
    for linear in vectorSpaces(ks):
        yield linear
        uncovered = set(range(16)) - set(linear)
        while uncovered:
            offset = uncovered.pop()
            # xor() instead of the ^ operator: the Sage preparser rewrites ^
            # as exponentiation.
            coset = sorted(xor(element, offset) for element in linear)
            uncovered.difference_update(coset)
            yield coset

def isMixColumnInvariant(v):
    """ Implements Theorem 9 (the case len(supp) = 16 is not covered).

    Returns True when the 16-entry vector v passes all of these checks:

    * its support (indices i in range(16) with v[i] != 0) has size 1, 2, 4
      or 8;
    * all nonzero entries have the same absolute value;
    * for s0 the first support index, s0 xor a xor b lies in the support for
      every pair a, b of support indices (an affine subspace of GF(2)^4);
    * for a support of size 8, an even number of entries are negative.

    Otherwise returns False.
    """
    supp = [i for i in range(16) if v[i] != 0]
    if len(supp) not in (1, 2, 4, 8):
        return False
    if len({abs(v[i]) for i in supp}) != 1:
        return False
    supp_set = set(supp)
    # xor() instead of ^: the Sage preparser rewrites ^ as exponentiation.
    for (a, b) in itertools.combinations(supp, 2):
        if xor(supp[0], xor(a, b)) not in supp_set:
            return False
    # Count negative entries, not entries equal to -1: the common magnitude
    # is only known to be equal across the support (checked above), not to
    # be 1, e.g. when v is CSb0 * (a +/-1 vector).
    if len(supp) == 8 and sum(1 for i in supp if v[i] < 0) % 2 != 0:
        return False

    return True

def getNbFreeBits(v):
    """ Returns the number of key bits that do not affect v (excluding sign).

    Concretely: over the support of v (indices i in range(16) with
    v[i] != 0), counts the bit positions 0..3 that take the same value in
    every support index. An all-zero v yields 4.
    """
    # reduce is only a builtin in Python 2; import it so this also runs
    # under Python 3.
    from functools import reduce
    supp = [i for i in range(16) if v[i] != 0]
    # Bits set in every support index, and bits set in at least one. The
    # initial values are the identities of & and | on 4-bit values, so an
    # empty support no longer raises TypeError.
    and_supp = reduce(__and__, supp, 0xf)
    or_supp  = reduce(__or__, supp, 0)
    # Constant bits: always 1 (and_supp) or always 0 (complement of or_supp).
    # xor() instead of ^: the Sage preparser rewrites ^ as exponentiation.
    constant_bits = __or__(and_supp, xor(0xf, or_supp))
    return bin(constant_bits).count("1")

def formatNum(x):
    """ Right-aligns str(x) in a field of width 2; longer strings are
    returned unchanged. """
    return str(x).rjust(2)

# Enumerate candidate vectors v that are +/-1 on an affine subspace of
# GF(2)^4 of dimension 0..3 and zero elsewhere, and report those for which
# w = CSb0 * v is +/-v or passes isMixColumnInvariant.
for space in affineSpaces([0, 1, 2, 3]):
    for values in itertools.product([1, -1], repeat = len(space)-1):
        # The entry at space[0] is always +1, so only one of v and -v is tried.
        values = [1] + list(values)
        # len(space) is the support size (1, 2, 4 or 8), not the dimension:
        # for size-8 supports, skip sign patterns with an odd number of -1
        # entries, the same parity rule isMixColumnInvariant applies.
        if len(space) == 8 and values.count(-1) % 2 != 0:
            continue
        v = matrix(16, 1)
        for j, i in enumerate(space):
            v[i, 0] = values[j]

        w = CSb0 * v
        if w == v or w == -v:
            # v is an eigenvector of CSb0 with eigenvalue +1 or -1.
            nb_free = getNbFreeBits(v) * 32
            t = "{IP, ~}" if isPositiveDefinite(v) else "{IF, ~}"
            # Single-argument print: identical output under Python 2 and 3.
            print("%s [Free bits = %d] %s"
                  % (v.T.str(rep_mapping=formatNum), nb_free, t))
        elif isMixColumnInvariant(w):
            nb_free = (getNbFreeBits(v) + getNbFreeBits(w)) * 16
            t  = "{IP," if isPositiveDefinite(v) else "{IF,"
            t += "IP}"  if isPositiveDefinite(w) else "IF}"
            print("%s [Free bits = %d] %s"
                  % (v.T.str(rep_mapping=formatNum), nb_free, t))