function [ mse ] = IOKR_eval_mse( KX_train, Y_train, ky_param, select_param, iokr_param )
%======================================================
% DESCRIPTION:
% Computation of the mean squared errors (mse) for different values of the 
% regularization parameter, either by cross-validation on a given partition
% (select_param.cv_type = 'cv') or by closed-form leave-one-out
% cross-validation (select_param.cv_type = 'loocv').
%
% The error is measured in the output feature space,
%   mse = 1/n * sum_i || psi_pred_i - psi_i ||^2,
% using explicit output feature vectors when ky_param.type is 'linear' and
% the output Gram matrix (kernel trick) otherwise.
%
% INPUTS:
% KX_train:     training input Gram matrix (n_train*n_train)
% Y_train:      matrix of size d*n_train containing the training fingerprint vectors
% ky_param:     1*1 struct array containing the information related to the output kernel
%               (field 'type' is read here; 'linear' selects explicit features)
% select_param: 1*1 struct array containing information related to the parameter selection 
%               (fields 'lambda' and 'cv_type'; for 'cv' also 'cv_partition',
%               queried through training(c,j)/test(c,j) -- presumably a
%               cvpartition object -- and 'num_folds')
% iokr_param:   1*1 struct array containing information relative to
%               centering and multiple kernel learning (field 'center' is read here)
%
% OUTPUT:
% mse:          column vector containing the mse obtained for each regularization parameter
%               in select_param.lambda
%
%====================================================== 

    % Vector containing the possible values for the regularization parameter
    val_lambda = select_param.lambda;
    n_lambda = length(val_lambda);
    
    n_train = size(KX_train,1);
    
    is_linear = strcmp(ky_param.type, 'linear');
    
    if is_linear
        % Output feature vectors processing (norma with the training mean and
        % the centering option)
        Psi_train = norma(Y_train, mean(Y_train, 2), iokr_param.center);
    else
        % Output kernel processing: Gram matrix over the training examples
        KY_train = output_kernel_preprocessing_train(Y_train, ky_param, iokr_param.center);
    end
    
    switch select_param.cv_type
        
        % Parameter selection using inner cross-validation
        case 'cv'
        
            c = select_param.cv_partition; % CV partition
            n_folds = select_param.num_folds; % number of folds

            mse_cv = zeros(n_lambda, n_folds);

            for j = 1:n_folds % Cross-validation
                train_set_cv = find(training(c,j));
                test_set_cv = find(test(c,j));
                
                n_train_cv = length(train_set_cv);
                n_test_cv = length(test_set_cv);

                KX_train_cv = KX_train(train_set_cv, train_set_cv);
                KX_train_test_cv = KX_train(train_set_cv, test_set_cv);
                I_cv = eye(n_train_cv);
                
                % Fold-specific output blocks do not depend on lambda
                if is_linear
                    Psi_train_cv = Psi_train(:, train_set_cv);
                    Psi_test_cv = Psi_train(:, test_set_cv);
                else
                    KY_train_cv = KY_train(train_set_cv, train_set_cv);
                    KY_train_test_cv = KY_train(train_set_cv, test_set_cv);
                    % Sum of squared norms of the true test output features,
                    % i.e. the constant term of ||psi_pred - psi||^2. It is
                    % n_test_cv only for a unit-diagonal (normalized,
                    % uncentered) output kernel, so it is computed exactly.
                    sq_norm_test = trace(KY_train(test_set_cv, test_set_cv));
                end
                
                for il = 1:n_lambda
                    
                    % Training and prediction: B = (lambda*I + K)^{-1} K_train_test
                    B_cv = (val_lambda(il)*I_cv + KX_train_cv) \ KX_train_test_cv;
                    
                    % Computation of the mean squared error
                    if is_linear
                        Psi_pred_cv = Psi_train_cv*B_cv;
                        mse_cv(il,j) = 1/n_test_cv*norm(Psi_pred_cv - Psi_test_cv,'fro')^2;
                    else
                        % sum_i ||psi_pred_i - psi_i||^2
                        %   = tr(B'*KY*B) - 2*tr(B'*KY_train_test) + tr(KY_test_test),
                        % with tr(A'*C) evaluated as sum(sum(A.*C)) to avoid
                        % forming n_test*n_test matrices
                        sq_err = sum(sum(B_cv .* (KY_train_cv*B_cv))) ...
                                 - 2*sum(sum(B_cv .* KY_train_test_cv)) ...
                                 + sq_norm_test;
                        mse_cv(il,j) = sq_err / n_test_cv;
                    end
                end
            end
            mse = mean(mse_cv,2); % mean over the different folds
            
     
        % Parameter selection using leave-one-out cross-validation
        case 'loocv'
        
            mse = zeros(n_lambda,1);
            I = eye(n_train);
            for il = 1:n_lambda

                % Training predictions are given by (output features)*B
                B = (val_lambda(il)*I + KX_train) \ KX_train;

                % Closed-form leave-one-out residual coefficients: column i of
                % (I - B) divided by (1 - B(i,i))
                LOOE = bsxfun(@rdivide, I - B, (1 - diag(B))');
                
                % Computation of the mean squared error
                if is_linear
                    mse(il) = 1/n_train * norm(Psi_train * LOOE,'fro')^2;
                else
                    % trace(LOOE'*KY*LOOE) without forming the n*n product
                    mse(il) = 1/n_train * sum(sum(LOOE .* (KY_train * LOOE)));
                end
            end
            
        otherwise
            error('IOKR_eval_mse:unknownCvType', ...
                  'Unknown select_param.cv_type ''%s'' (expected ''cv'' or ''loocv'').', ...
                  select_param.cv_type);
    end

end

