import chainer
import chainer.functions as F
import chainer.links as L
import chainer.iterators as I
import chainer.optimizers as O
import chainer.computational_graph as c
from chainer import serializers
#rdkit
import rdkit
from rdkit import Chem
from rdkit.Chem import rdmolops
from rdkit.Chem import MolFromSmiles
from rdkit.Chem import inchi
from rdkit.Chem.inchi import MolFromInchi
from rdkit.Chem import Draw

# numpy
import numpy
import math
import sys
import six
import os
from os import listdir
from os.path import isfile, join



import load_data_inchi
from load_data_inchi import make_dataset, divide_dataset, get_atom_list, get_edge_matrix, MAX_NUMBER_ATOM
import model
from model import Net

def convert_folder(net, atom2id, infolder, outfolder, batchsize = 300):
	"""Run ``convert_file`` on every regular file found in ``infolder``.

	Each input file is converted into a file of the same name inside
	``outfolder``. The names of the files for which ``convert_file``
	returned 1 are written, one per line, to ``invalid_filenames.txt`` in
	the current working directory.

	Args:
		net: network object forwarded to ``convert_file``.
		atom2id: lookup forwarded to ``convert_file``.
		infolder: directory holding the input files; sub-directories are
			skipped.
		outfolder: directory receiving the converted files.
		batchsize: batch size forwarded to ``convert_file``.
	"""
	filenames = [name for name in listdir(infolder) if isfile(join(infolder, name))]

	# NOTE: ``count`` is the number of *files* flagged by convert_file (one
	# increment per file), even though the summary line below speaks of
	# molecules.
	count = 0
	invFilenames = []
	for filename in filenames:
		print(filename)
		infile = join(infolder, filename)
		outfile = join(outfolder, filename)
		invalid = convert_file(net, atom2id, infile, outfile, batchsize)
		if invalid == 1:
			count += 1
			invFilenames.append(filename)
			print(count)
	print('number of invalid molecules:', count)

	# The context manager closes the report even if a write fails.
	with open('invalid_filenames.txt', 'w') as f:
		for filename in invFilenames:
			f.write(filename)
			f.write('\n')


def convert_file(net, atom2id, ifilename, ofilename, batchsize= 300):
	"""Feed every line of ``ifilename`` through ``net`` and save the outputs.

	The last space-separated field of each line is parsed with
	``MolFromInchi``. A line that cannot be parsed is replaced by an
	all-zero adjacency tensor and an all-zero atom array, so the rows
	written to ``ofilename`` stay aligned with the input lines.

	Args:
		net: callable invoked as ``net(adjs, atom_arrays)``; the ``.data``
			attribute of its result is what gets saved.
		atom2id: lookup indexed with each entry returned by
			``get_atom_list``.
		ifilename: path of the text file to read, one record per line.
		ofilename: path of the space-delimited text file to write.
		batchsize: number of lines passed to ``net`` per call.

	Returns:
		1 if at least one line could not be parsed by ``MolFromInchi``,
		otherwise 0. An empty input produces an empty output file and 0.
	"""
	with open(ifilename) as ifile:
		lines = ifile.read().split('\n')
	# Drop only the empty string that a trailing newline produces, so that a
	# file lacking the final newline does not lose its last record.
	if lines and lines[-1] == '':
		lines.pop()
	nlines = len(lines)

	if nlines == 0:
		# Nothing to convert: leave an empty output file behind instead of
		# handing ``None`` to numpy.savetxt.
		with open(ofilename, 'w'):
			pass
		return 0

	phis = []
	count = 0
	# Inference mode, as in convert_file1, so that any train-only behaviour
	# of the network is switched off while converting.
	with chainer.using_config('train', False):
		for start in range(0, nlines, batchsize):
			atom_arrays = []
			adjs = []
			for line in lines[start:start + batchsize]:
				inchi_str = line.split(' ')[-1]
				mol = MolFromInchi(inchi_str)
				atom_array = numpy.zeros(MAX_NUMBER_ATOM, dtype=numpy.int32)
				if mol is not None:
					adj = get_edge_matrix(mol)
					atom_list = get_atom_list(mol)
					atom_ids = [atom2id[a] for a in atom_list]
					atom_array[:len(atom_list)] = numpy.array(atom_ids)
				else:
					# Placeholder input for an unparsable line.
					size = MAX_NUMBER_ATOM
					adj = numpy.zeros((4, size, size), dtype=numpy.float32)
					count += 1
				adjs.append(adj)
				atom_arrays.append(atom_array)

			phi = net(numpy.asarray(adjs), numpy.asarray(atom_arrays)).data
			phis.append(phi)

	# Join the per-batch outputs once rather than re-copying after each batch.
	numpy.savetxt(ofilename, numpy.concatenate(phis, axis = 0), delimiter=' ')
	return 1 if count != 0 else 0
		

def convert_file1(net, atom2id, ifilename, ofilename, batchsize= 300):
	"""Feed every line of ``ifilename`` through ``net`` in inference mode.

	Each line is split on single spaces, the last field is discarded and
	the field before it is parsed with ``MolFromInchi``. A line that cannot
	be parsed is replaced by an all-zero adjacency tensor and an all-zero
	atom array, so the rows written to ``ofilename`` stay aligned with the
	input lines.

	Args:
		net: callable invoked as ``net(adjs, atom_arrays)``; the ``.data``
			attribute of its result is what gets saved.
		atom2id: lookup indexed with each entry returned by
			``get_atom_list``.
		ifilename: path of the text file to read, one record per line.
		ofilename: path of the space-delimited text file to write.
		batchsize: number of lines passed to ``net`` per call.

	Returns:
		None. An empty input produces an empty output file.
	"""
	with open(ifilename) as ifile:
		lines = ifile.read().split('\n')
	# Drop only the empty string that a trailing newline produces, so that a
	# file lacking the final newline does not lose its last record.
	if lines and lines[-1] == '':
		lines.pop()
	nlines = len(lines)

	if nlines == 0:
		# Nothing to convert: leave an empty output file behind instead of
		# handing ``None`` to numpy.savetxt.
		with open(ofilename, 'w'):
			pass
		return

	phis = []
	with chainer.using_config('train', False):
		for start in range(0, nlines, batchsize):
			atom_arrays = []
			adjs = []
			for line in lines[start:start + batchsize]:
				# The final field is dropped; the InChI is the one before it.
				items = line.split(' ')[:-1]
				inchi_str = items[-1]
				mol = MolFromInchi(inchi_str)
				atom_array = numpy.zeros(MAX_NUMBER_ATOM, dtype=numpy.int32)
				if mol is not None:
					adj = get_edge_matrix(mol)
					atom_list = get_atom_list(mol)
					atom_ids = [atom2id[a] for a in atom_list]
					atom_array[:len(atom_list)] = numpy.array(atom_ids)
				else:
					# Placeholder input for an unparsable line.
					size = MAX_NUMBER_ATOM
					adj = numpy.zeros((4, size, size), dtype=numpy.float32)
				adjs.append(adj)
				atom_arrays.append(atom_array)

			phi = net(numpy.asarray(adjs), numpy.asarray(atom_arrays)).data
			phis.append(phi)

	# Join the per-batch outputs once rather than re-copying after each batch.
	numpy.savetxt(ofilename, numpy.concatenate(phis, axis = 0), delimiter=' ')






def _centered_unit_kernel(kernel_block):
	"""Double-centre a square kernel block and scale it to unit Frobenius norm.

	Computes ``H K H`` with ``H = I - 1/n`` and divides by its Frobenius
	norm. Returns ``None`` when that norm is zero (always the case for a
	1x1 block), because the division would otherwise yield NaN.
	"""
	n = kernel_block.shape[0]
	# Subtracting the length-n vector broadcasts over rows, giving I - 1/n.
	H = numpy.eye(n, dtype=numpy.float32) - 1.0 / n * numpy.ones(n, dtype=numpy.float32)
	centered = numpy.dot(H, numpy.dot(kernel_block, H))
	norm = numpy.linalg.norm(centered, 'fro')
	if norm == 0:
		return None
	return centered / norm


def _mean_or_nan(total, count):
	"""Return ``total / count``, or NaN when nothing was accumulated."""
	return total / count if count else float('nan')


def train_model(model_name, train_descriptors, valid_descriptors, train_kernel, valid_kernel, hidden_dim, out_dim, feat_dim, max_degree, n_atom_type, radius, n_epoch, batchsize, gamma = 0.0):
	"""Train a ``Net`` against a kernel matrix and save it after every epoch.

	Each epoch visits the training samples in a fresh random order, prints
	the sample-weighted mean of ``hsic`` and ``loss`` reported by the
	network, evaluates the validation set in batches of 100 with
	``train`` disabled, and serialises the network to ``model_name``.

	Batches whose centred kernel block has zero Frobenius norm (for example
	a final batch holding a single sample) are skipped, since normalising
	them would produce NaN. The printed means are taken over the samples
	actually processed.

	Args:
		model_name: path handed to ``serializers.save_npz``.
		train_descriptors: indexable collection; items ``[0]`` and ``[1]``
			of each entry are stacked into the two network inputs.
		valid_descriptors: same layout as ``train_descriptors``.
		train_kernel: square array indexed by training-sample position.
		valid_kernel: square array indexed by validation-sample position.
		hidden_dim, out_dim, feat_dim, n_atom_type, radius: forwarded to
			the ``Net`` constructor.
		max_degree: accepted for interface compatibility; not used here.
		n_epoch: number of epochs to run.
		batchsize: number of training samples per batch.
		gamma: accepted for interface compatibility; not used here.

	Returns:
		The trained ``Net`` instance.
	"""
	# initialize model
	print('initializing model...')
	net = Net(hidden_dim, out_dim, feat_dim, n_atom_type, radius)
	# setup an optimizer
	optimizer = chainer.optimizers.Adam()
	optimizer.setup(net)
	print('starting training...')

	n_train = len(train_descriptors)
	n_valid = len(valid_descriptors)
	batchsize_valid = 100

	for e in range(1, n_epoch + 1):
		indexes = numpy.random.permutation(n_train)

		hsic = 0
		tloss = 0
		n_seen = 0
		for i in range(0, n_train, batchsize):
			batch_ids = indexes[i:i + batchsize]
			# Select the rows, then the columns, of the shuffled batch.
			KX = train_kernel[batch_ids, :]
			KX = _centered_unit_kernel(KX[:, batch_ids])
			if KX is None:
				# Degenerate batch: nothing meaningful to fit against.
				continue

			adjs = numpy.asarray([train_descriptors[j][0] for j in batch_ids])
			atom_arrays = numpy.asarray([train_descriptors[j][1] for j in batch_ids])
			phi = net(adjs, atom_arrays)

			loss = net.get_loss_func(phi, KX)
			net.cleargrads()
			loss.backward()
			# No parameter update during the first epoch, so the statistics
			# printed for epoch 1 come from the initial weights.
			if e > 1:
				optimizer.update()
			hsic += float(net.hsic.data) * len(adjs)
			tloss += float(net.loss.data) * len(adjs)
			n_seen += len(adjs)
		print('epoch:', e, ',  HSIC:', _mean_or_nan(hsic, n_seen),  ',LOSS:', _mean_or_nan(tloss, n_seen))

		# for validation
		with chainer.using_config('train', False):
			hsic = 0
			tloss = 0
			n_seen = 0
			for i in range(0, n_valid, batchsize_valid):
				maxid = min(i + batchsize_valid, n_valid)
				KX = valid_kernel[i:maxid, :]
				KX = _centered_unit_kernel(KX[:, i:maxid])
				if KX is None:
					continue

				adjs = numpy.asarray([valid_descriptors[j][0] for j in range(i, maxid)])
				atom_arrays = numpy.asarray([valid_descriptors[j][1] for j in range(i, maxid)])
				phi = net(adjs, atom_arrays)
				# Called for the hsic/loss attributes read just below.
				net.get_loss_func(phi, KX)
				hsic += float(net.hsic.data) * len(adjs)
				tloss += float(net.loss.data) * len(adjs)
				n_seen += len(adjs)
			print('+++ Validation:  HSIC:', _mean_or_nan(hsic, n_seen), ',LOSS:', _mean_or_nan(tloss, n_seen))
		serializers.save_npz(model_name, net)

	return net

def list_file_in_folder(fold_name, outfile):
	"""Write the names of the regular files in ``fold_name`` to ``outfile``.

	One name per line, in the order returned by ``listdir``;
	sub-directories are skipped. The listing is taken before ``outfile``
	is opened.

	Args:
		fold_name: directory to list.
		outfile: path of the text file to (over)write.
	"""
	filenames = [name for name in listdir(fold_name) if isfile(join(fold_name, name))]
	# The context manager closes the file even if a write fails.
	with open(outfile, 'w') as f:
		for filename in filenames:
			f.write(filename)
			f.write('\n')

def main():
	"""Build the data set, train the network and convert the InChI file."""
	# Input and output locations.
	kfilename = 'wholekernel24.txt'
	ifilename = 'wholeinchi.txt'  # list of inchi for all molecules
	ofilename = 'wholefp243.txt'
	model_name = 'my_model243'

	atom2id, index_set, descriptors, kernel = make_dataset(kfilename, ifilename)
	splits = divide_dataset(index_set, descriptors, kernel, 1)
	train_descriptors, valid_descriptors, train_kernel, valid_kernel = splits

	# Parameters for training, keyed by the train_model argument they feed.
	hyper = dict(
		hidden_dim=50,
		out_dim=1000,
		feat_dim=300,
		max_degree=4,
		n_atom_type=12,
		radius=6,
		n_epoch=50,
		batchsize=100,
	)

	numpy.save('atom2id.npy', atom2id)
	print('out file:', ofilename, ', model file: ', model_name, ', kernel file: ', kfilename)
	print('hidden dim:', hyper['hidden_dim'], ', out dim:', hyper['out_dim'], ', feat dim:', hyper['feat_dim'], ', radius: ', hyper['radius'])

	trained_model = train_model(
		model_name,
		train_descriptors,
		valid_descriptors,
		train_kernel,
		valid_kernel,
		**hyper
	)

	print('converting file...')
	convert_file1(trained_model, atom2id, ifilename, ofilename)
	# Folder-wise conversion, currently disabled:
	#convert_folder(trained_model, atom2id, candidate_inchi, candidate_fp, batchsize)

# Run main() only when this file is executed as a script, not on import.
if __name__ == '__main__':
	main()











