
import chainer
import chainer.functions as F
import chainer.links as L
import chainer.iterators as I
import chainer.optimizers as O
import chainer.computational_graph as c
from chainer import serializers
#rdkit
import rdkit
from rdkit import Chem
from rdkit.Chem import rdmolops
from rdkit.Chem import MolFromSmiles
from rdkit.Chem import inchi
from rdkit.Chem.inchi import MolFromInchi
from rdkit.Chem import Draw
from rdkit.Chem.Draw import DrawingOptions

import load_data_inchi
from load_data_inchi import make_dataset, divide_dataset, get_atom_list, get_edge_matrix, MAX_NUMBER_ATOM

import model
from model import Net
# numpy
import numpy
import math
import sys
import six
import os
from os import listdir
from os.path import isfile, join
MAX_NUMBER_CANDIDATES = 10000

def _encode_inchi(inchi_str, atom2id):
	"""Return the ``(adj, atom_array)`` model inputs for one InChI string.

	If RDKit cannot parse ``inchi_str`` (``MolFromInchi`` returns ``None``), an
	all-zero adjacency tensor and all-zero atom array are returned instead, so
	the batch keeps one entry per input line.
	"""
	atom_array = numpy.zeros(MAX_NUMBER_ATOM, dtype=numpy.int32)
	mol = MolFromInchi(inchi_str)
	if mol is None:
		# Placeholder; assumed to match the shape get_edge_matrix returns — TODO confirm.
		adj = numpy.zeros((4, MAX_NUMBER_ATOM, MAX_NUMBER_ATOM), dtype=numpy.float32)
		return adj, atom_array
	adj = get_edge_matrix(mol)
	atom_list = get_atom_list(mol)
	atom_ids = [atom2id[a] for a in atom_list]
	# Raises ValueError if the molecule has more than MAX_NUMBER_ATOM atoms.
	atom_array[:len(atom_ids)] = numpy.array(atom_ids)
	return adj, atom_array


def extract_feature(atom2id_file, model_file, inchi_file, batchsize = 100):
	"""Run the trained ``Net`` over the InChI strings listed in a file.

	Each line of ``inchi_file`` is split on single spaces and its last field is
	parsed as an InChI. Unparseable entries are still fed to the model (as
	all-zero inputs) so that outputs stay aligned with input lines. At most
	``MAX_NUMBER_CANDIDATES`` lines are processed.

	Args:
		atom2id_file: path to a NumPy file whose single object item is a dict
			mapping each atom returned by ``get_atom_list`` to an integer id.
		model_file: path to an ``.npz`` snapshot of ``Net`` weights, loaded with
			``chainer.serializers.load_npz``.
		inchi_file: path to a text file with one candidate per line.
		batchsize: number of lines fed to the model per forward pass.

	Returns:
		The model outputs of all batches concatenated along axis 0 (presumably
		one row per input line, in file order), or ``None`` if the file has no
		lines.

	Raises:
		ValueError: if ``batchsize`` is not positive, or a molecule has more
			than ``MAX_NUMBER_ATOM`` atoms.
		KeyError: if a molecule contains an atom missing from ``atom2id``.
	"""
	if batchsize < 1:
		raise ValueError('batchsize must be a positive integer, got %r' % (batchsize,))

	# Architecture hyper-parameters; presumably must match the configuration
	# used to train ``model_file`` or load_npz will fail — confirm with trainer.
	hidden_dim = 50
	out_dim = 1000
	feat_dim = 300
	n_atom_type = 12
	radius = 6

	# The mapping is stored as a pickled 0-d object array, which numpy refuses
	# to load by default since 1.16.3. Only load atom2id files you trust.
	atom2id = numpy.load(atom2id_file, allow_pickle=True).item()
	net = Net(hidden_dim, out_dim, feat_dim, n_atom_type, radius)
	serializers.load_npz(model_file, net)

	# splitlines() keeps a final line that lacks a trailing newline and strips
	# '\r' from CRLF files, both of which split('\n')[:-1] got wrong.
	with open(inchi_file) as f:
		lines = f.read().splitlines()[:MAX_NUMBER_CANDIDATES]

	outputs = []
	# Inference only: disable train-mode behaviour and graph construction.
	with chainer.using_config('train', False), chainer.no_backprop_mode():
		for start in range(0, len(lines), batchsize):
			encoded = [
				_encode_inchi(line.split(' ')[-1], atom2id)
				for line in lines[start:start + batchsize]
			]
			adjs = numpy.asarray([adj for adj, _ in encoded])
			atom_arrays = numpy.asarray([atoms for _, atoms in encoded])
			outputs.append(net(adjs, atom_arrays).data)

	if not outputs:
		return None
	# One concatenate at the end instead of one per batch (avoids O(n^2) copying).
	return numpy.concatenate(outputs, axis=0)


if __name__ == '__main__':
	# Usage: <script> ATOM2ID_FILE MODEL_FILE INCHI_FILE [OUTPUT_FILE]
	# OUTPUT_FILE defaults to the original pipeline location ../Step2/temp.txt.
	if len(sys.argv) < 4:
		sys.exit('usage: %s ATOM2ID_FILE MODEL_FILE INCHI_FILE [OUTPUT_FILE]' % sys.argv[0])
	atom2id_path, model_path, inchi_path = sys.argv[1:4]
	output_path = sys.argv[4] if len(sys.argv) > 4 else '../Step2/temp.txt'

	PHI = extract_feature(atom2id_path, model_path, inchi_path)
	if PHI is None:
		# extract_feature returns None when the input file has no lines;
		# numpy.savetxt cannot write None, so fail with a clear message.
		sys.exit('no candidates found in %s' % inchi_path)
	numpy.savetxt(output_path, PHI, delimiter=' ')
	
	




