import chainer
import chainer.functions as F
import chainer.links as L
from chainer import serializers
import numpy
import six
import kernel_funcs
from kernel_funcs import kernel_linear, kernel_gaussian, trace


# Number of distinct edge types. SubNet's edge_layer emits one
# hidden_dim-sized message per edge type for every node, and the adjacency
# input is reshaped to (batch * NUM_EDGE_TYPE, s1, s1) so each edge type gets
# its own adjacency matrix.
NUM_EDGE_TYPE = 4
# Not referenced by the code visible in this file; presumably an upper bound
# on atom-type ids — TODO confirm against the data loader.
MAX_ATOM_TYPE = 11

class SubNet(chainer.Chain):
	"""One graph message-passing layer with a per-node softmax readout.

	Node states ``x`` are mixed along ``NUM_EDGE_TYPE`` adjacency matrices,
	summed over edge types, added back to ``x`` as a residual and passed
	through ReLU. Each node is then projected to ``out_dim`` logits and
	softmax-normalised; the readout is that softmax summed over nodes.
	"""

	def __init__(self, hidden_dim, out_dim):
		super(SubNet, self).__init__()
		with self.init_scope():
			# One hidden_dim-sized message per edge type for each node.
			self.edge_layer = L.Linear(hidden_dim, NUM_EDGE_TYPE * hidden_dim)
			# Per-node projection to out_dim logits (softmaxed in _propagate).
			self.out_layer = L.Linear(hidden_dim, out_dim)
		self.hidden_dim = hidden_dim
		self.out_dim = out_dim
		# Neither attribute is read by this class; kept so external code that
		# reads them keeps working.
		self.train = True
		self.dr = 0.5

	def _propagate(self, x, adj):
		"""Shared forward pass used by __call__ and compute_activation.

		Args:
			x: node states of shape (s0, s1, s2); s2 must equal hidden_dim
				because x is fed to ``edge_layer``.
			adj: adjacency data with s0 * NUM_EDGE_TYPE * s1 * s1 elements;
				presumably laid out as (s0, NUM_EDGE_TYPE, s1, s1) — TODO confirm.

		Returns:
			Tuple ``(out_x, dh)``. ``out_x`` is the updated node state with
			the same shape as ``x``. ``dh`` is the per-node softmax of shape
			(s0, s1, out_dim).
		"""
		s0, s1, s2 = x.shape
		tmp = self.edge_layer(F.reshape(x, (s0 * s1, s2)))
		m = F.reshape(tmp, (s0, s1, s2, NUM_EDGE_TYPE))
		# Move the edge-type axis next to the batch axis so every
		# (sample, edge type) pair becomes one matrix for batch_matmul.
		m = F.transpose(m, (0, 3, 1, 2))
		adj_t = F.reshape(adj, (s0 * NUM_EDGE_TYPE, s1, s1))
		m = F.reshape(m, (s0 * NUM_EDGE_TYPE, s1, s2))
		m = F.batch_matmul(adj_t, m)
		m = F.reshape(m, (s0, NUM_EDGE_TYPE, s1, s2))
		# Aggregate messages over edge types, then add the residual input.
		m = F.sum(m, axis=1)
		m = m + x
		out_x = F.relu(m)
		dh = self.out_layer(F.reshape(out_x, (s0 * s1, s2)))
		dh = F.softmax(dh)
		dh = F.reshape(dh, (s0, s1, self.out_dim))
		return out_x, dh

	def __call__(self, x, h, adj):
		"""Apply the layer.

		Args:
			x: node states, shape (s0, s1, hidden_dim).
			h: not read; accepted because Net threads the previous layer's
				readout through every layer call.
			adj: adjacency data (see ``_propagate``).

		Returns:
			Tuple ``(out_x, out_h)``. ``out_x`` is the updated node state.
			``out_h`` has shape (s0, out_dim) and is the per-node softmax
			summed over the s1 nodes.
		"""
		out_x, dh = self._propagate(x, adj)
		out_h = F.sum(dh, axis=1)
		return out_x, out_h

	def compute_activation(self, x, h, adj):
		"""Like __call__, but also return the per-node softmax.

		Args:
			x: node states, shape (s0, s1, hidden_dim).
			h: not read; kept for signature parity with __call__.
			adj: adjacency data (see ``_propagate``).

		Returns:
			Tuple ``(activation, out_x, out_h)``. ``activation`` is the
			per-node softmax of shape (s0, s1, out_dim). ``out_x`` and
			``out_h`` equal what __call__ returns for the same inputs.
		"""
		out_x, activation = self._propagate(x, adj)
		# Sum over the node axis of the 3-D softmax so out_h is (s0, out_dim),
		# matching __call__. Summing the flat (s0 * s1, out_dim) softmax over
		# axis 1 would instead give a vector of ones.
		out_h = F.sum(activation, axis=1)
		return activation, out_x, out_h


class Net(chainer.Chain):
	"""Graph network: atom embedding, ``radius`` SubNet layers, linear head.

	The readouts of all SubNet layers are summed and projected by
	``feature_layer`` to a ``feat_dim``-dimensional feature vector per sample.
	"""

	def __init__(self, hidden_dim, out_dim, feat_dim, n_atom_type, radius):
		super(Net, self).__init__()
		with self.init_scope():
			self.embed = L.EmbedID(n_atom_type, hidden_dim)
			self.layers = chainer.ChainList(
				*[SubNet(hidden_dim, out_dim) for _ in range(radius)])
			self.feature_layer = L.Linear(out_dim, feat_dim)
		self.hidden_dim = hidden_dim
		self.radius = radius
		self.out_dim = out_dim

	def compute_activations(self, adjs, atom_arrays, fingerId, mol_ids):
		"""Locate the strongest activation of one output unit across layers.

		Runs every layer's ``compute_activation``. It scans the per-node
		softmax value at index ``fingerId`` for each sample/node/layer and
		prints each new running maximum as it is found.

		Args:
			adjs: adjacency data forwarded to every layer.
			atom_arrays: integer atom-type ids fed to ``embed``.
			fingerId: index along the out_dim axis of each layer's activation.
			mol_ids: labels for the batch rows in batch order; ``mol_ids[i]``
				labels sample ``i``.

		Returns:
			Tuple ``(max_value, max_mol_id, max_atom_id, max_level_id)``. The
			ids stay -1 and ``max_value`` stays 0.0 if no activation exceeds
			0.0.
		"""
		x = self.embed(atom_arrays)
		h = 0
		max_value = 0.0
		max_mol_id = -1
		max_atom_id = -1
		max_level_id = -1
		for level, layer in enumerate(self.layers):
			activation, x, h = layer.compute_activation(x, h, adjs)
			# activation is (batch, nodes, out_dim); read the raw array once.
			act_data = activation.data
			n_nodes = activation.shape[1]
			for order, mol_id in enumerate(mol_ids):
				for atom_id in range(n_nodes):
					act = act_data[order, atom_id, fingerId]
					if act > max_value:
						max_value = act
						max_mol_id = mol_id
						max_atom_id = atom_id
						max_level_id = level
						print(level, max_mol_id, max_atom_id, max_value)
		return max_value, max_mol_id, max_atom_id, max_level_id

	def __call__(self, adjs, atom_arrays):
		"""Compute the feature matrix ``phi`` for a batch.

		Args:
			adjs: adjacency data forwarded unchanged to every SubNet layer.
			atom_arrays: integer atom-type ids fed to ``embed``.

		Returns:
			``feature_layer`` applied to the sum of every layer's readout.
		"""
		x = self.embed(atom_arrays)
		h = 0
		h_total = 0
		# Reset the stored loss; get_loss_func assigns it again.
		self.loss = 0
		for layer in self.layers:
			x, h = layer(x, h, adjs)
			h_total += h
		phi = self.feature_layer(h_total)
		return phi

	def get_loss_func(self, phi, KX, ktype='linear', sigma = 1.0, gamma = 0.0):
		"""Return the negative kernel-alignment loss between ``phi`` and ``KX``.

		The rows of ``phi`` are L2-normalised and then each column is
		centred. A kernel matrix is built from the result and scaled to unit
		Frobenius norm. The loss is ``-trace(KX @ K_phi)``, which the original
		code describes as HSIC between phi and the kernel.

		Args:
			phi: 2-D feature matrix, one row per sample (e.g. the output of
				__call__).
			KX: target kernel matrix; presumably (n, n) with n = len(phi) —
				TODO confirm.
			ktype: 'linear' or 'gaussian'.
			sigma: bandwidth passed to kernel_gaussian; ignored for 'linear'.
			gamma: weight of the regulariser. ``self.reg`` is always 0.0 here,
				so gamma currently has no effect.

		Returns:
			The loss Variable. It is also stored as ``self.loss``, alongside
			``self.hsic`` and ``self.reg``.

		Raises:
			ValueError: if ``ktype`` is neither 'linear' nor 'gaussian'.
		"""
		# Fail fast instead of hitting an unbound K_phi after the normalisation.
		if ktype not in ('linear', 'gaussian'):
			raise ValueError(
				"ktype must be 'linear' or 'gaussian', got %r" % (ktype,))

		n, d = phi.shape
		# L2-normalise each row. NOTE: an all-zero row produces NaN here.
		phi_norm = F.sqrt(F.sum(phi ** 2, axis=1))
		phi = phi / F.tile(F.reshape(phi_norm, (n, 1)), (1, d))
		# Centre each feature column over the batch.
		mean_phi = F.mean(phi, axis=0)
		phi = phi - F.tile(mean_phi, (n, 1))

		if ktype == 'linear':
			K_phi = kernel_linear(phi, phi)
		else:
			K_phi = kernel_gaussian(phi, phi, sigma)

		# Scale K_phi to unit Frobenius norm (explicit broadcast via tile).
		frob_norm = F.sqrt(F.sum(F.square(K_phi)))
		K_phi = K_phi / F.tile(frob_norm, K_phi.shape)
		self.hsic = - trace(F.matmul(KX, K_phi))
		self.reg = 0.0
		self.loss = self.hsic + gamma * self.reg
		return self.loss





















