import itertools, operator
from functools import reduce

import autograd.numpy as np
import numpy.random as npr
from autograd.test_util import check_grads

from pymanopt.manifolds import Sphere, Product
from pymanopt import Problem
from pymanopt.solvers import ConjugateGradient

# Building blocks for mc_terms: one sign and one 2x2 matrix per index 0..3.
# Presumably these encode the expansion of C^M named in roundCorrelation's
# docstring -- TODO confirm against the derivation.
lambdas = [1, 1, 1, -1]
As = [
  np.array([[1, 0], [0, 1]]),
  np.array([[0, 1], [1, 0]]),
  np.array([[1, 0], [0, -1]]),
  np.array([[0, 1], [-1, 0]])
]

# One (sign, matrix) pair per index tuple js in {0, 1, 2, 3}^4, in
# itertools.product order (256 terms in total):
#   sign   = lambdas[js[0]] * lambdas[js[1]] * lambdas[js[2]] * lambdas[js[3]]
#   matrix = As[js[0]] (x) As[js[1]] (x) As[js[2]] (x) As[js[3]]
# where (x) is the Kronecker product, so every matrix is 16x16.
# `reduce` is functools.reduce, which works on Python 2.6+ and Python 3
# (the bare builtin no longer exists in Python 3).
mc_terms = [
  (reduce(operator.mul, (lambdas[j] for j in js)),
   reduce(np.kron, (As[j] for j in js)))
  for js in itertools.product(range(4), repeat=4)
]

# CS: 16x16 matrix; presumably the S-box correlation matrix C^S named in
# roundCorrelation's docstring (TODO confirm). Row 0 and column 0 are both the
# unit vector e_0. roundCorrelation uses column 0 (CS[:, 0]) directly for
# inactive (None) positions.
CS = np.array([
  [1, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],
  [0, 0.25, 0.50, 0.25,-0.25, 0.00, 0.25, 0.00,-0.25, 0.00, 0.25, 0.00, 0.50,-0.25, 0.00,-0.25],
  [0, 0.50, 0.00, 0.00, 0.50, 0.00, 0.00, 0.00,-0.50, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00, 0.00],
  [0, 0.25, 0.00, 0.25,-0.25, 0.00, 0.25, 0.50, 0.25,-0.50,-0.25, 0.00, 0.00, 0.25, 0.00, 0.25],
  [0,-0.25, 0.50,-0.25, 0.25, 0.00,-0.25, 0.00,-0.25,-0.50,-0.25, 0.00, 0.00,-0.25, 0.00, 0.25],
  [0, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00,-0.50,-0.50, 0.00, 0.00, 0.50,-0.50],
  [0, 0.25, 0.00, 0.25,-0.25, 0.00, 0.25,-0.50,-0.25, 0.00,-0.25, 0.00,-0.50,-0.25, 0.00, 0.25],
  [0, 0.00, 0.00, 0.50, 0.00, 0.00,-0.50, 0.00, 0.00, 0.00, 0.00,-0.50, 0.00, 0.00,-0.50, 0.00],
  [0,-0.25,-0.50, 0.25,-0.25, 0.00,-0.25, 0.00,-0.50,-0.25, 0.00, 0.25, 0.25, 0.00, 0.25, 0.00],
  [0, 0.00, 0.00,-0.50,-0.50, 0.00, 0.00, 0.00,-0.25, 0.25,-0.25,-0.25, 0.25, 0.25,-0.25, 0.25],
  [0, 0.25, 0.00,-0.25,-0.25,-0.50,-0.25, 0.00, 0.00,-0.25, 0.50,-0.25,-0.25, 0.00, 0.25, 0.00],
  [0, 0.00, 0.00, 0.00, 0.00,-0.50, 0.00,-0.50, 0.25,-0.25,-0.25, 0.25, 0.25, 0.25,-0.25,-0.25],
  [0, 0.50, 0.00, 0.00, 0.00, 0.00,-0.50, 0.00, 0.25, 0.25,-0.25, 0.25, 0.25,-0.25, 0.25, 0.25],
  [0,-0.25, 0.50, 0.25,-0.25, 0.00,-0.25, 0.00, 0.00, 0.25, 0.00, 0.25,-0.25, 0.50, 0.25, 0.00],
  [0, 0.00, 0.00, 0.00, 0.00, 0.50, 0.00,-0.50, 0.25,-0.25, 0.25,-0.25, 0.25, 0.25, 0.25, 0.25],
  [0,-0.25, 0.00, 0.25, 0.25,-0.50, 0.25, 0.00, 0.00, 0.25, 0.00,-0.25, 0.25, 0.00, 0.25, 0.50]
])
# CS without its first column (16x15). roundCorrelation computes
# np.dot(CS_r, x) to map a 15-entry mask vector x to 16 entries, which is why
# the optimized masks live in R^15.
CS_r = CS[:, 1:]

def toArray(x):
  """Return x as an array with every length-1 axis removed."""
  as_array = np.asarray(x)
  return np.squeeze(as_array)

def roundCorrelation(u, v):
  """Compute <C^S u, C^M C^S v> for one round.

  u, v: sequences indexed by position; only positions 0..3 are read. Each
    entry is either None, meaning the position is inactive and the fixed
    16-entry column CS[:, 0] is used as-is, or a 15-entry vector, which is
    mapped to 16 entries with CS_r.

  With z_i / w_i the mapped u[i] / v[i], the result is
      (1/16) * sum over (s, A) in mc_terms of  s * prod_{i=0..3} z_i . (A w_i)
  returned as a scalar.

  `reduce` is functools.reduce, so this also runs under Python 3.
  """
  positions = range(4)

  def lift(x):
    # Inactive positions contribute the fixed column CS[:, 0] unchanged.
    return CS[:, 0] if x is None else np.dot(CS_r, x)

  z = [lift(u[i]) for i in positions]
  w = [lift(v[i]) for i in positions]

  result = 0
  for s, A in mc_terms:
    # Same 16x16 matrix A is applied at every position; the per-position
    # bilinear forms are multiplied together, then weighted by the sign s.
    result = result + s * reduce(
        operator.mul, (np.dot(z[i], np.dot(A, w[i])) for i in positions))
  return result / 16


# Hand-picked sanity checks for roundCorrelation. Each argument holds one
# 15-entry mask vector per position; here the same vector is used at all four.
# print(...) with a single argument behaves identically in Python 2 and 3.
_check_u = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, -.5, -.5, 0, 0, .5, -.5])
_check_v = np.array([0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

# Expected output: 1
print(roundCorrelation([_check_u] * 4, [_check_v] * 4))

# Expected output: 0
print(roundCorrelation([_check_u] * 4, [_check_u] * 4))

# Four copies (distinct arrays) of the sanity-check mask vector with an extra
# leading 0 prepended (16 entries). Not referenced by the code below. Note that
# roundCorrelation expects 15-entry vectors, so it cannot be passed there as-is.
v = [np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -.5, -.5, 0, 0, .5, -.5])
     for _ in range(4)]

# Number of rounds
r = 2
# Activity pattern for all masks: 4 entries per group of positions, 4 * r
# entries in total (cost reads positions 0 .. 4*r - 1). 1 marks an active
# position, which gets its own optimized vector; 0 marks an inactive one.
activity_pattern = [
    1, 1, 0, 0,
    1, 1, 0, 0
]
# A pattern of the wrong length would silently treat the missing positions
# as inactive in cost, so reject it up front.
if len(activity_pattern) != 4 * r:
  raise ValueError("activity_pattern must have 4 * r = %d entries, got %d"
                   % (4 * r, len(activity_pattern)))
active_positions = [i for (i, x) in enumerate(activity_pattern) if x == 1]

def cost(vs):
  """Objective handed to the solver: -log |product of round correlations|.

  vs holds one vector per active position, ordered like active_positions.
  Position k is filled with vs[active_positions.index(k)] when k is active
  and with None otherwise. For each of the r - 1 transitions, the four
  positions starting at j are paired with the four starting at j + 4 and
  passed to roundCorrelation.
  """
  def masks_at(start):
    # The four masks beginning at position `start`; inactive ones are None.
    return [vs[active_positions.index(k)] if k in active_positions else None
            for k in range(start, start + 4)]

  c = 1
  for j in range(0, 4 * (r - 1), 4):
    c = c * roundCorrelation(masks_at(j), masks_at(j + 4))
  return -np.log(np.abs(c))

# Numerical gradient check (disabled). cost expects one 15-entry vector per
# active position, so the random point needs len(active_positions) entries.
# check_grads(cost, modes=['rev'], order=2)([npr.rand(15) for i in range(len(active_positions))])

# Search space: one unit sphere in R^15 per active position.
manifold = Product(
    tuple([Sphere(15) for i in range(activity_pattern.count(1))])
)

problem = Problem(manifold=manifold, cost=cost, verbosity=2)
solver = ConjugateGradient(logverbosity=2)

nb_runs = 10
eps = 1e-3

# Solve nb_runs times (no starting point is passed, so the solver chooses its
# own each time) and record every (cost, optimizer) pair.
runs = []
for i in range(nb_runs):
  opt, optlog = solver.solve(problem)
  opt_cost = cost(opt)
  runs.append((opt_cost, opt))

# Select only after all runs, so the eps band is anchored at the true minimum
# cost. An incremental "reset only if better by more than eps" rule never
# lowers the reference for smaller improvements, letting the band drift.
current_best = min(c for c, _ in runs) if runs else None
optima = [c for c, _ in runs if c <= current_best + eps]
optimizers = [o for c, o in runs if c <= current_best + eps]

# Single-argument print(...) behaves the same in Python 2 and 3.
print("Optima: " + str(optima))
for opt in optimizers:
  print(opt)
  print("")