import random
from operator import xor
from sage.crypto.sboxes import Midori_Sb0 as Sb0
from sage.crypto.boolean_function import BooleanFunction

def xor3(a, b, c):
    """Return the XOR of three values, grouped as a xor (b xor c)."""
    # NOTE: operator.xor is used on purpose instead of the `^` operator; this
    # file uses Sage preparser syntax, under which `^` means exponentiation.
    right_pair = xor(b, c)
    return xor(a, right_pair)

def mixColumn(nibbles):
    """Mix one column of four cells.

    Each output cell is the XOR of the three *other* input cells, so output
    position k never depends directly on nibbles[k].

    nibbles -- sequence whose first four entries form the column
    Returns a new list of four mixed cells; the input is not modified.
    """
    mixed = []
    for skipped in range(4):
        # Collect the three cells whose position differs from the output slot.
        others = [nibbles[pos] for pos in range(4) if pos != skipped]
        mixed.append(xor3(*others))
    return mixed

def subCell(nibbles):
    """Apply the S-box Sb0 to each of the first 16 cells, in place.

    nibbles -- mutable sequence of at least 16 cell values
    Returns None; the substitution is written back into nibbles.
    """
    for cell in range(16):
        substituted = Sb0(nibbles[cell])
        nibbles[cell] = substituted

def addKey(nibbles, key):
    """XOR the first 16 entries of key into nibbles, in place.

    nibbles -- mutable sequence of at least 16 cell values
    key     -- sequence of at least 16 key cells
    Returns None; entries beyond index 15 are left untouched.
    """
    for cell in range(16):
        keyed = xor(nibbles[cell], key[cell])
        nibbles[cell] = keyed

# Round-constant table: one row per round index, each row holding one bit
# (0 or 1) for each of the 16 state cells.
RC = [
    [0,0,0,1,0,1,0,1,1,0,1,1,0,0,1,1],  # row 0
    [0,1,1,1,1,0,0,0,1,1,0,0,0,0,0,0],  # row 1
    [1,0,1,0,0,1,0,0,0,0,1,1,0,1,0,1],  # row 2
    [0,1,1,0,0,0,1,0,0,0,0,1,0,0,1,1],  # row 3
    [0,0,0,1,0,0,0,0,0,1,0,0,1,1,1,1],  # row 4
    [1,1,0,1,0,0,0,1,0,1,1,1,0,0,0,0],  # row 5
    [0,0,0,0,0,0,1,0,0,1,1,0,0,1,1,0],  # row 6
    [0,0,0,0,1,0,1,1,1,1,0,0,1,1,0,0],  # row 7
    [1,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1],  # row 8
    [0,1,0,0,0,0,0,0,1,0,1,1,1,0,0,0],  # row 9
    [0,1,1,1,0,0,0,1,1,0,0,1,0,1,1,1],  # row 10
    [0,0,1,0,0,0,1,0,1,0,0,0,1,1,1,0],  # row 11
    [0,1,0,1,0,0,0,1,0,0,1,1,0,0,0,0],  # row 12
    [1,1,1,1,1,0,0,0,1,1,0,0,1,0,1,0],  # row 13
    [1,1,0,1,1,1,1,1,1,0,0,1,0,0,0,0],  # row 14
]

def addRoundConstants(nibbles, r, b):
    """XOR row r of RC into the first 16 cells of nibbles, in place.

    nibbles -- mutable sequence of at least 16 cell values
    r       -- row of RC to use (an out-of-range index raises IndexError)
    b       -- left shift applied to every constant bit before the XOR, so
               the constant lands on bit b of each cell
    Returns None.
    """
    constant_bits = RC[r]
    for cell in range(16):
        shifted_bit = constant_bits[cell] << b
        nibbles[cell] = xor(nibbles[cell], shifted_bit)

# Cell permutation: output cell i is taken from input cell ShuffleCell[i].
ShuffleCell = [0, 10, 5, 15, 14, 4, 11, 1, 9, 3, 12, 6, 7, 13, 2, 8]
def shuffleCells(nibbles):
    """Return a new 16-entry list with the cells permuted by ShuffleCell.

    nibbles -- sequence of at least 16 cell values; it is not modified
    """
    return [nibbles[source] for source in ShuffleCell]

def midori64(nibbles, rounds, key, b = 0):
    """Encrypt a 16-cell state and return the result as a new list.

    Structure: whitening-key addition, then rounds - 1 full rounds
    (S-box layer, cell shuffle, column mixing, round-constant addition,
    round-key addition), then a final S-box layer and a second
    whitening-key addition.

    nibbles -- sequence of 16 cell values (the plaintext); it is NOT
               modified, the function works on a private copy
    rounds  -- total number of rounds; full round i uses row i of RC, so
               rounds - 1 must not exceed the number of rows in RC
               (otherwise IndexError is raised)
    key     -- pair of 16-entry sequences; full round i uses key[i % 2],
               and the whitening key is the cell-wise XOR of both halves
    b       -- bit position at which the round-constant bits are injected
               (forwarded to addRoundConstants; default 0)
    """
    # Work on a copy: the in-place layers below would otherwise leave the
    # caller's list half-encrypted (whitening + first S-box layer) once the
    # state is rebound by shuffleCells.
    state = list(nibbles)

    whitening_key = [xor(key[0][i], key[1][i]) for i in range(16)]
    addKey(state, whitening_key)
    for i in range(rounds - 1):
        subCell(state)
        state = shuffleCells(state)
        # Mix each group of four consecutive cells as one column.
        for col in range(0, 16, 4):
            state[col:col + 4] = mixColumn(state[col:col + 4])
        addRoundConstants(state, i, b)
        # Round keys alternate between the two key halves.
        addKey(state, key[i % 2])
    # The last round consists of the S-box layer only.
    subCell(state)
    addKey(state, whitening_key)
    return state


# --- Experiment script -------------------------------------------------------
# Requires Sage: the `R.<...> = ...` form below is Sage preparser syntax, and
# the `print` statement / builtin `reduce` further down are Python 2 only.

# Boolean polynomial ring in four variables, one per bit of a cell.
R.<x0, x1, x2, x3> = BooleanPolynomialRing(4)
# NOTE(review): f is not used anywhere in the visible code - confirm whether
# it is meant to replace g in the projections below or can be removed.
f = BooleanFunction(x0*x2 + x0 + x1 + x3)
# g is the per-cell function whose values are XORed into the projections.
g = BooleanFunction(x0 + x2)

# All-zero key: both 16-cell halves are zero.
key = [[0] * 16, [0] * 16]

# Test vector: all-zero plaintext, all-zero key, 16 rounds, default b = 0
# (round constants on bit 0). Aborts the script if the cipher is wrong.
assert midori64([0] * 16, 16, key) == \
    [3, 12, 9, 12, 12, 14, 13, 10, 2, 11, 11, 13, 4, 4, 9, 10]

# Number of random plaintexts to sample.
nb_tests = 100
# Bit position passed to midori64: the round-constant bits are shifted left by
# b before being XORed in, so b = 1 differs from the b = 0 test vector above.
b        = 1   # Add RC to bit b

# counts[0]: trials where input and output projections agree;
# counts[1]: trials where they differ (the index is the XOR of the two bits).
counts = [0, 0]
for i in range(nb_tests):
    # Uniformly random 16-cell plaintext, each cell in 0..15.
    input_value = [random.randint(0, 15) for i in range(16)]
    # Projection bit: XOR of g evaluated on every cell. The plaintext is
    # projected first, then encrypted.
    input_projection  = reduce(xor, map(g, input_value))
    output_value = midori64(input_value, 16, key, b)
    output_projection = reduce(xor, map(g, output_value))
    counts[xor(input_projection, output_projection)] += 1

# NOTE(review): this prints 2 * (fraction of trials where the projections
# DIFFER) - 1, i.e. the negative of the usual "2 * Pr[equal] - 1" correlation;
# confirm the intended sign convention.
# NOTE(review): presumably relies on the Sage preparser turning the literal 2
# into a Sage Integer so that `/` is exact; under plain Python 2 this would be
# floor division - confirm.
print "Correlation: ", 2 * counts[1] / sum(counts) - 1