import rdkit
from rdkit import Chem
from rdkit.Chem import rdmolops
from rdkit.Chem import MolFromSmiles
from rdkit.Chem import inchi
from rdkit.Chem.inchi import MolFromInchi
from rdkit.Chem import Draw

import os
from os import listdir
from os.path import isfile, join
import numpy
MAX_NUMBER_ATOM = 75

def check_candidate_set(folder):
	"""Scan every InChI file in *folder* and report the atom symbols seen.

	Each regular file directly inside *folder* is read as one InChI string
	per line (blank lines are ignored). Lines that ``MolFromInchi`` cannot
	parse are counted as failures.

	Args:
		folder: Directory containing the candidate InChI files.

	Returns:
		A ``(atom_list, count)`` tuple: the distinct atom symbols in
		first-seen order, and the number of lines that failed to parse.
		Both values are also printed.
	"""
	count = 0
	atom_list = []
	filenames = [name for name in listdir(folder) if isfile(join(folder, name))]
	for filename in filenames:
		fname = join(folder, filename)
		print(fname)
		# `with` closes the handle even if reading raises.
		with open(fname) as fhandle:
			lines = fhandle.read().splitlines()
		# splitlines() keeps the final entry when the file has no trailing
		# newline (split('\n')[:-1] silently dropped it).
		for line in lines:
			inchi_str = line.strip()
			if not inchi_str:
				# A blank line is not a failed InChI; don't count it.
				continue
			mol = MolFromInchi(inchi_str)
			if mol is None:
				print('count=', count)
				count += 1
				continue
			for atom in mol.GetAtoms():
				symbol = atom.GetSymbol()
				if symbol not in atom_list:
					atom_list.append(symbol)
	print(atom_list)
	print(count)
	return atom_list, count



	


def load_kernel(kfilename):
	"""Load a whitespace-separated kernel matrix from *kfilename*.

	Each non-blank line becomes one row of floats. The final row is kept
	whether or not the file ends with a newline.

	Args:
		kfilename: Path to the text file holding the kernel.

	Returns:
		A ``numpy.float32`` array built from the parsed rows.
	"""
	print('loading kernel file...')
	with open(kfilename) as fhandle:
		lines = fhandle.read().splitlines()
	# Skip blank lines so a stray empty line can't yield a ragged row.
	kernel = [list(map(float, line.split())) for line in lines if line.strip()]
	return numpy.array(kernel, dtype=numpy.float32)

def load_inchi(ifilename):
	"""Parse an ``<index> <InChI>`` file into indices and RDKit molecules.

	Each non-blank line must contain an integer index followed by an InChI
	string, separated by whitespace. The largest atom count seen is printed.

	Args:
		ifilename: Path to the index/InChI text file.

	Returns:
		``(index_set, mol_set)``: an ``int32`` array of the indices and a
		list of the corresponding molecules from ``MolFromInchi``, in file
		order.

	Raises:
		ValueError: if a line's index is not an integer or its InChI cannot
			be parsed.
	"""
	print('loading inchi file...')
	with open(ifilename) as fhandle:
		lines = fhandle.read().splitlines()
	index_set = []
	mol_set = []
	maxatom = 0
	for lineno, line in enumerate(lines, start=1):
		fields = line.split()
		if not fields:
			# Blank lines carry no entry; skip instead of crashing.
			continue
		index = int(fields[0])
		inchi_str = fields[1]
		mol = MolFromInchi(inchi_str)
		if mol is None:
			# Fail loudly here rather than with AttributeError on None.
			raise ValueError('%s:%d: could not parse InChI %r'
					% (ifilename, lineno, inchi_str))
		maxatom = max(maxatom, mol.GetNumAtoms())
		index_set.append(index)
		mol_set.append(mol)
	print('MAX NUMBER of ATOMS is ', maxatom)
	return numpy.array(index_set, dtype=numpy.int32), mol_set
def check_atom_bond(mol_set):
	"""Collect the distinct atom symbols and bond types used in *mol_set*.

	Args:
		mol_set: Iterable of RDKit molecules.

	Returns:
		``(atom_list, bond_list)``: atom symbols and ``str(GetBondType())``
		names, each de-duplicated in first-seen order.
	"""
	atom_list = []
	bond_list = []
	for mol in mol_set:
		# Visit every atom; the old range(N - 1) loop skipped the last one.
		for atom in mol.GetAtoms():
			symbol = atom.GetSymbol()
			if symbol not in atom_list:
				atom_list.append(symbol)
		# Walk the actual bonds instead of probing all O(N^2) atom pairs.
		for bond in mol.GetBonds():
			bond_type = str(bond.GetBondType())
			if bond_type not in bond_list:
				bond_list.append(bond_type)
	return atom_list, bond_list


def get_atom_list(mol):
	"""Return the element symbol of every atom in *mol*, in atom order.

	``None`` is passed straight through so a failed parse propagates.
	"""
	if mol is not None:
		return [atom.GetSymbol() for atom in mol.GetAtoms()]
	return None

def get_edge_matrix(mol):
	"""Build a one-hot, bond-type adjacency tensor for *mol*.

	Channels 0/1/2/3 mark SINGLE/DOUBLE/TRIPLE/AROMATIC bonds. The tensor is
	zero-padded to ``MAX_NUMBER_ATOM`` atoms and is symmetric in its last
	two axes.

	Args:
		mol: An RDKit molecule, or ``None``.

	Returns:
		A ``(4, MAX_NUMBER_ATOM, MAX_NUMBER_ATOM)`` ``float32`` array, or
		``None`` if *mol* is ``None``.

	Raises:
		ValueError: if *mol* has more than ``MAX_NUMBER_ATOM`` atoms, or
			contains a bond type outside the four channels.
	"""
	if mol is None:
		return None
	size = MAX_NUMBER_ATOM
	n_atoms = mol.GetNumAtoms()
	if n_atoms > size:
		raise ValueError('molecule has %d atoms but MAX_NUMBER_ATOM is %d'
				% (n_atoms, size))
	channel_of = {'SINGLE': 0, 'DOUBLE': 1, 'TRIPLE': 2, 'AROMATIC': 3}
	adj = numpy.zeros((4, size, size), dtype=numpy.float32)
	for bond in mol.GetBonds():
		bond_type = str(bond.GetBondType())
		if bond_type not in channel_of:
			# Raise rather than `assert False`, which -O strips out.
			raise ValueError('[ERROR] Unknown bond type %s' % bond_type)
		channel = channel_of[bond_type]
		i = bond.GetBeginAtomIdx()
		j = bond.GetEndAtomIdx()
		# Bonds are undirected; fill both halves like the old all-pairs scan.
		adj[channel, i, j] = 1.0
		adj[channel, j, i] = 1.0
	return adj

def getMols(filename):
	"""Load *filename* and pair each molecule's edge tensor with its atoms.

	Returns:
		``(index_set, descriptors)``: the indices from ``load_inchi``, and a
		list whose k-th element is ``(get_edge_matrix(mol),
		get_atom_list(mol))`` for the k-th molecule.
	"""
	index_set, mol_set = load_inchi(filename)
	descriptors = [(get_edge_matrix(mol), get_atom_list(mol)) for mol in mol_set]
	return index_set, descriptors
def getAtom2id(descriptors):
	"""Assign integer ids to atom symbols and encode every descriptor.

	Id 0 is reserved for ``'empty'`` (the padding value). Other symbols get
	consecutive ids in first-seen order across *descriptors*.

	Returns:
		``(atom2id, encoded)``, where each ``encoded[k]`` is
		``(adj, atom_array)``. ``atom_array`` is an ``int32`` vector of
		length ``MAX_NUMBER_ATOM`` holding the atom ids, zero-padded.
	"""
	atom2id = {'empty': 0}
	for desc in descriptors:
		for symbol in desc[1]:
			# len() is evaluated before insertion, so new ids are consecutive.
			atom2id.setdefault(symbol, len(atom2id))

	encoded = []
	for desc in descriptors:
		adj, symbols = desc[0], desc[1]
		padded_ids = numpy.zeros(MAX_NUMBER_ATOM, dtype=numpy.int32)
		padded_ids[:len(symbols)] = numpy.array([atom2id[s] for s in symbols])
		encoded.append((adj, padded_ids))
	return atom2id, encoded

def make_dataset(kfilename, ifilename):
	"""Load the kernel and InChI files and encode the molecules.

	Returns:
		``(atom2id, index_set, descriptors, kernel)``.
	"""
	kernel = load_kernel(kfilename)
	index_set, raw_descriptors = getMols(ifilename)
	atom2id, encoded = getAtom2id(raw_descriptors)
	return atom2id, index_set, encoded, kernel
def divide_dataset(index_set, descriptors, kernel, valid_id):
	"""Split descriptors and the kernel into training and validation parts.

	Entries whose ``index_set`` value equals *valid_id* form the validation
	split; all others form the training split. Each split's kernel keeps
	only the rows and columns of its own entries.

	Args:
		index_set: Per-entry group index, aligned with *descriptors*.
		descriptors: Per-entry descriptors.
		kernel: Square NumPy array aligned with *descriptors*.
		valid_id: Group index selecting the validation entries.

	Returns:
		``(train_set, valid_set, train_K, valid_K)``.
	"""
	print('Dividing dataset into valid and training set')
	train_set = []
	valid_set = []
	train_index = []
	valid_index = []
	for i, index in enumerate(index_set):
		if index == valid_id:
			valid_set.append(descriptors[i])
			valid_index.append(i)
		else:
			train_set.append(descriptors[i])
			train_index.append(i)
	# numpy.int was removed in NumPy 1.24; intp is the native index dtype.
	train_index = numpy.array(train_index, dtype=numpy.intp)
	valid_index = numpy.array(valid_index, dtype=numpy.intp)
	# ix_ selects the (rows, columns) sub-matrix in a single indexing step.
	valid_K = kernel[numpy.ix_(valid_index, valid_index)]
	train_K = kernel[numpy.ix_(train_index, train_index)]

	print('train size: ', len(train_set))
	print('valid size: ', len(valid_set))
	print('train kernel size:', train_K.shape)
	print('valid kernel size:', valid_K.shape)

	return train_set, valid_set, train_K, valid_K






#check_candidate_set('./candidates_inchi/')

#atom2id, index_set, descriptors, kernel = make_dataset('wholekernel.txt', 'wholeinchi.txt')
#print(atom2id)


#des = descriptors[0]
#print('list atom id : ', des[1])

#print('num of descriptrs:', len(descriptors))
#print('size of kernel: ', kernel.shape)
#print('size of index set :', len(index_set))











