function [ ] = run_IOKR_GNPS( datadir, results_file, iokr_param, select_param, ky_param)
%======================================================
% DESCRIPTION:
% This script is modified from the original implementation
% of Brouard et al (Bioinformatics 2016).
% Script for running IOKR on the GNPS dataset with a 10-fold
% cross-validation. For every fold the model is trained on the training
% examples, the candidates of each evaluated test example are scored, and
% the rank of the true InChI among the sorted candidates is recorded.
%
% INPUTS:
% datadir:      string corresponding to the directory that contains the data
%               (must contain cv_ind.txt and candidates/candidate_set_<mf>.mat)
% results_file: string of the file in which the results should be saved (should be a .txt file)
% iokr_param:   1*1 struct array containing information relative to
%               centering and multiple kernel learning
% select_param: 1*1 struct array containing information related to the parameter selection
% ky_param:     1*1 struct array containing the information related to the output kernel
%
% OUTPUTS:
% None. The function prints the per-fold top-1/top-10/top-20 ratios with
% their mean and standard deviation, and saves the 20*1 vector 'topk'
% (percentage of evaluated examples whose rank is <= k, k = 1..20) to
% results_file in ASCII format.
%
% SIDE EFFECTS:
% - Adds 'general_functions' to the MATLAB path.
% - Runs an external Python script once per evaluated test example and
%   then reads 'temp.txt' from the current directory.
%   NOTE(review): the script is presumably what writes temp.txt -- confirm.
%
%======================================================

    addpath('general_functions');

    % Fixed resources used by the external candidate feature extractor.
    % They do not depend on the fold or on the test example, so they are
    % defined once here.
    model_name   = '../Step1/my_model243';
    atom2id_name = '../Step1/atom2id.npy';
    python_path  = '/Users/haidnguyen0909/miniconda3/envs/my-rdkit-env/bin/python';
    command_path = '../Step1/extract_feature.py';
    tmp_file     = 'temp.txt';
    MAX_NUMBER_CANDIDATES = 10000;
    RANK_NOT_FOUND = 10000; % rank given when the true InChI is not among the candidates

    % Load Data
    % 'eval_set' (indices of the examples used for evaluation) and 'ranks'
    % are deliberately not called 'eval' / 'rank' so that the MATLAB
    % built-ins of the same name are not shadowed.
    [KX_list, Y1, inchi, mf, eval_set] = load_data_GNPS(datadir);
    n_sample = size(Y1,2);
    KX_list1 = KX_list(1:24); % only the first 24 input kernels are used
    Y = importdata('../Step1/wholefp243.txt', ' ')'; % output representation, one column per example
    n = size(Y, 2);
    ranks = zeros(n, 1) + 1000; % default rank for examples that are never evaluated

    %--------------------------------------------------------------
    % Cross-validation
    %--------------------------------------------------------------

    disp('Begin cross-validation')

    n_folds = 10; % number of folds
    ind_fold = load(fullfile(datadir, 'cv_ind.txt')); % indices of the different folds
    avg_time = 0;  % accumulated time: kernel preprocessing + scoring + ranking
    avg_time_ = 0; % accumulated time: scoring + ranking only
    list_top1 = zeros(1, n_folds);
    list_top10 = zeros(1, n_folds);
    list_top20 = zeros(1, n_folds);

    for i = 1:n_folds
        disp(['Now starting iteration ', int2str(i), ' out of ', int2str(n_folds)])

        % Create training and test sets. The training set excludes the
        % whole fold; only afterwards is the test set restricted to the
        % examples used for evaluation.
        fold_members = find(ind_fold == i);
        train_set = setdiff(1:n_sample, fold_members);
        test_set = intersect(fold_members, eval_set);
        n_test = length(test_set);

        % Training
        KX_list_train = cellfun(@(x) x(train_set,train_set), KX_list1, 'UniformOutput', false);
        Y_train = Y(:,train_set);
        train_model = Train_IOKR(KX_list_train, Y_train, ky_param, select_param, iokr_param);

        % Candidate sets for the test examples
        Y_C_test = cell(n_test,1);
        inchi_C_test = cell(n_test,1);
        for j = 1:n_test
            display(num2str(j));
            formula = mf{test_set(j)};

            % Compute the candidate features with the external Python script.
            cand_file_name = ['../DeepIOKR/candidates_inchi/' formula '_' num2str(test_set(j)) '.txt'];
            systemCommand = [python_path ' ' command_path ' ' atom2id_name ' ' model_name ' ' cand_file_name];
            status = system(systemCommand);
            if status ~= 0
                % Without this check a failed extraction would go unnoticed
                % and the temp file left by a previous candidate set would
                % be read instead.
                error('run_IOKR_GNPS:featureExtractionFailed', ...
                      'Feature extraction failed (exit status %d) for command: %s', ...
                      status, systemCommand);
            end
            cand_features = importdata(tmp_file, ' ');
            Y_C_test{j} = cand_features';

            % InChIs of the candidates, truncated to MAX_NUMBER_CANDIDATES.
            % NOTE(review): the feature matrix above is not truncated here;
            % this assumes the Python script emits the same candidates in
            % the same order and with the same limit -- confirm.
            cand_set = load(fullfile(datadir, 'candidates', ['candidate_set_' formula '.mat']));
            n_inchi = min(length(cand_set.inchi), MAX_NUMBER_CANDIDATES);
            inchi_C_test{j} = cand_set.inchi(1:n_inchi);
        end

        % Prediction
        display('prediction');

        t1 = clock;
        KX_list_train_test = cellfun(@(x) x(train_set,test_set), KX_list1, 'UniformOutput', false);
        KX_list_test = cellfun(@(x) x(test_set,test_set), KX_list1, 'UniformOutput', false);
        KX_train_test = input_kernel_preprocessing_test(KX_list_train_test, KX_list_test, train_model.process_input, iokr_param.center);
        t1_ = clock;
        score = Test_IOKR(KX_train_test, train_model, Y_train, Y_C_test, iokr_param);

        % Rank computation: position of the true InChI in the candidate
        % list sorted by decreasing score.
        for j = 1:n_test
            k = test_set(j);
            [~,IX] = sort(score{j},'descend');
            % Keep only the first (best) match: the true InChI may occur
            % more than once among the candidates, and assigning a vector
            % to ranks(k) would raise an error.
            index = find(strcmp(inchi_C_test{j}(IX),inchi{k}), 1);
            if isempty(index)
                ranks(k) = RANK_NOT_FOUND;
            else
                ranks(k) = index;
            end
        end

        % Fraction of the test examples of this fold ranked in the top 1/10/20
        fold_ranks = ranks(test_set);
        list_top1(i) = sum(fold_ranks <= 1) / n_test;
        list_top10(i) = sum(fold_ranks <= 10) / n_test;
        list_top20(i) = sum(fold_ranks <= 20) / n_test;

        t2 = clock;
        avg_time = avg_time + etime(t2, t1);
        avg_time_ = avg_time_ + etime(t2, t1_);
    end

    % Computation of the percentage of identified metabolites in the top k
    fprintf('computing and saving the results into files');
    topk = zeros(20,1);
    for k = 1:20
        topk(k) = sum(ranks(eval_set) <= k) / length(eval_set)*100;
    end

    % The accumulated times are averaged over all n_sample examples, not
    % only over the evaluated test examples.
    avg_time = avg_time / n_sample;
    avg_time_ = avg_time_ / n_sample;
    display(num2str(avg_time));
    display(num2str(avg_time_));

    mean_top1 = mean(list_top1); std_top1 = std(list_top1);
    mean_top10 = mean(list_top10); std_top10 = std(list_top10);
    mean_top20 = mean(list_top20); std_top20 = std(list_top20);
    display(list_top1);
    display(list_top10);
    display(list_top20);
    fprintf('top 1: %f, %f\n', mean_top1, std_top1);
    fprintf('top 10: %f, %f\n', mean_top10, std_top10);
    fprintf('top 20: %f, %f\n', mean_top20, std_top20);

    % 'topk' must keep this name: save() looks the variable up by name.
    save(results_file,'topk','-ascii');
end


