package csl.tools.weka;

import java.io.File;
import java.io.FileFilter;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Random;

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.trees.J48;
import weka.core.Instance;
import weka.core.Instances;
import csl.tools.files.FileTools;

/**
 * Runs repeated, stratified train/test split evaluations of a J48 classifier
 * on the instances of an ARFF file.
 * <p>
 * Each evaluation shuffles the instances of every class independently, puts
 * the first {@code splitPercent} fraction of each class into the training set
 * and the remainder into the test set, trains a fresh classifier on the
 * training set and evaluates it on the test set. A human-readable report is
 * written to the log file and one row of figures per evaluation is written to
 * the CSV file.
 */
public class CrossSplit {

	private static final String CSV_SEPARATOR = ";";

	/** Line terminator used for every CSV line, header included. */
	private static final String CSV_EOL = "\r\n";

	/** Extension (matched case-insensitively) of the files processed by {@link #main}. */
	private static final String ARFF_EXTENSION = ".arff";

	/** Default input directory for {@link #main} when no argument is given. */
	private static final String DEFAULT_ROOT = "D:/users/roy/sony/papers/submitted/Animal Cognition -- Parrot Vocalizations/Experiment #4 (IB1 100 Samples)/statistical tests";

	/** Source of randomness for the per-class shuffles. */
	private final Random g = new Random();
	/**
	 * The ARFF file containing the Weka Instances to use for the evaluation
	 */
	private File arff;
	/**
	 * The text output file
	 */
	private File log;
	/**
	 * The CSV output file
	 */
	private File csv;

	/** @return the ARFF input file, or {@code null} if not set */
	public File getArff() {
		return arff;
	}

	/** @param arff the ARFF input file to evaluate */
	public void setArff(File arff) {
		this.arff = arff;
	}

	/** @return the CSV output file, or {@code null} if not set */
	public File getCSV() {
		return csv;
	}

	/** @param csv the CSV output file (overwritten on each run) */
	public void setCSV(File csv) {
		this.csv = csv;
	}

	/** @return the text log output file, or {@code null} if not set */
	public File getLog() {
		return log;
	}

	/** @param log the text log output file (overwritten on each run) */
	public void setLog(File log) {
		this.log = log;
	}

	/**
	 * Loads the ARFF file and runs {@code n} independent split evaluations on
	 * it, writing the results to the log and CSV files.
	 *
	 * @param n            number of evaluations to run
	 * @param splitPercent fraction, in [0, 1], of each class's instances that
	 *                     goes into the training set; the rest forms the test set
	 * @throws IllegalArgumentException if {@code splitPercent} is outside [0, 1]
	 *                                  or the class attribute is not nominal
	 * @throws FileNotFoundException    if the ARFF file is unset or missing, or
	 *                                  the log or CSV file is unset
	 * @throws Exception                if loading, training or evaluation fails
	 */
	public void run(int n, float splitPercent) throws Exception {
		// Written as a negated range test so that NaN is rejected too.
		if (!(splitPercent >= 0f && splitPercent <= 1f)) {
			throw new IllegalArgumentException(
					"splitPercent must be in [0, 1], got " + splitPercent);
		}
		runEvaluations(n, loadInstances(), splitPercent);
	}

	/**
	 * Opens the output files, writes the CSV header and runs
	 * {@code evalCount} evaluations. Both writers are closed even when an
	 * evaluation fails.
	 */
	private void runEvaluations(int evalCount, Instances instances,
			float splitPercent) throws Exception {
		if (log == null) {
			throw new FileNotFoundException("No log file set");
		}
		if (csv == null) {
			throw new FileNotFoundException("No CSV file set");
		}
		// The header and the per-class stratification both index class values.
		// Check this before any output file is created.
		if (!instances.classAttribute().isNominal()) {
			throw new IllegalArgumentException("Class attribute '"
					+ instances.classAttribute().name() + "' must be nominal");
		}

		PrintWriter logWriter = FileTools.printWriterOn(log);
		try {
			PrintWriter csvWriter = FileTools.printWriterOn(csv);
			try {
				writeCsvHeader(csvWriter, instances);
				for (int i = 0; i < evalCount; i++) {
					logWriter.println("Evaluation #" + (i + 1)
							+ "\r\n----------------\r\n");
					csvWriter.print((i + 1) + CSV_SEPARATOR);
					runOneEvaluation(logWriter, csvWriter, instances, splitPercent);
				}
				// PrintWriter never throws on write failures; surface them here.
				if (csvWriter.checkError()) {
					throw new IOException("Error while writing " + csv);
				}
				if (logWriter.checkError()) {
					throw new IOException("Error while writing " + log);
				}
			} finally {
				csvWriter.close();
			}
		} finally {
			logWriter.close();
		}
	}

	/**
	 * Writes the CSV header line, which has these columns:
	 * <ul>
	 * <li>experiment number, correct count and incorrect count</li>
	 * <li>one F-measure column per class</li>
	 * <li>the mean F-measure</li>
	 * <li>one column per confusion-matrix cell, named "actual as predicted"</li>
	 * </ul>
	 */
	private static void writeCsvHeader(PrintWriter csvWriter, Instances instances) {
		int numClasses = instances.numClasses();
		csvWriter.print("exp" + CSV_SEPARATOR + "Correct" + CSV_SEPARATOR
				+ "Incorrect" + CSV_SEPARATOR);
		for (int i = 0; i < numClasses; i++) {
			csvWriter.print("FM: " + instances.classAttribute().value(i)
					+ CSV_SEPARATOR);
		}
		csvWriter.print("Mean FM" + CSV_SEPARATOR);
		for (int i = 0; i < numClasses; i++) {
			for (int j = 0; j < numClasses; j++) {
				csvWriter.print(instances.classAttribute().value(i) + " as "
						+ instances.classAttribute().value(j) + CSV_SEPARATOR);
			}
		}
		csvWriter.print(CSV_EOL);
	}

	/**
	 * Performs one split/train/evaluate cycle. It appends the full Weka report
	 * to the log and completes the current CSV row with:
	 * <ul>
	 * <li>the correct and incorrect counts</li>
	 * <li>the per-class F-measures and their mean</li>
	 * <li>the confusion-matrix cells, in row-major order</li>
	 * </ul>
	 */
	private void runOneEvaluation(PrintWriter logWriter, PrintWriter csvWriter,
			Instances instances, float splitPercent) throws Exception {
		Instances[] split = split(instances, splitPercent);
		Instances train = split[0];
		Instances test = split[1];

		Classifier classifier = getNewClassifier();
		classifier.buildClassifier(train);
		Evaluation eval = new Evaluation(train);
		eval.evaluateModel(classifier, test);

		logWriter.println(eval.toSummaryString());
		logWriter.println(eval.toClassDetailsString());
		logWriter.println(eval.toMatrixString());

		// Weka's confusionMatrix() returns a copy, so fetch it once rather
		// than once per loop test and per cell.
		double[][] matrix = eval.confusionMatrix();
		int numClasses = matrix.length;

		csvWriter.print(eval.correct() + CSV_SEPARATOR);
		csvWriter.print(eval.incorrect() + CSV_SEPARATOR);
		double sumFM = 0d;
		for (int i = 0; i < numClasses; i++) {
			double fm = eval.fMeasure(i);
			csvWriter.print(fm + CSV_SEPARATOR);
			sumFM += fm;
		}
		csvWriter.print((sumFM / numClasses) + CSV_SEPARATOR);
		for (double[] row : matrix) {
			for (double cell : row) {
				csvWriter.print(cell + CSV_SEPARATOR);
			}
		}
		csvWriter.print(CSV_EOL);
	}

	/**
	 * Splits {@code instances} into a training and a test set, stratified by
	 * class. Each class's instances are shuffled, and the first
	 * {@code (int) (size * splitPercent)} of them go to training.
	 * <p>
	 * Instances with a missing class value form their own stratum instead of
	 * being filed under the first class value.
	 * <p>
	 * The caller guarantees that the class attribute is nominal.
	 *
	 * @return a two-element array {@code {train, test}}
	 */
	private Instances[] split(Instances instances, float splitPercent) {
		// Empty sets sharing the input's header, without copying its data.
		Instances train = new Instances(instances, 0);
		Instances test = new Instances(instances, 0);

		// One bucket per class value, plus a last bucket for missing classes.
		int numClasses = instances.numClasses();
		List<List<Instance>> buckets = new ArrayList<List<Instance>>(numClasses + 1);
		for (int c = 0; c <= numClasses; c++) {
			buckets.add(new ArrayList<Instance>());
		}
		for (int i = 0; i < instances.numInstances(); i++) {
			Instance inst = instances.instance(i);
			int bucket = inst.classIsMissing() ? numClasses : (int) inst.classValue();
			buckets.get(bucket).add(inst);
		}

		for (List<Instance> bucket : buckets) {
			Collections.shuffle(bucket, g);
			int size = bucket.size();
			int trainSize = (int) (size * splitPercent);
			for (int i = 0; i < size; i++) {
				if (i < trainSize) {
					train.add(bucket.get(i));
				} else {
					test.add(bucket.get(i));
				}
			}
		}

		return new Instances[] { train, test };
	}

	/** Creates a fresh, untrained classifier, called once per evaluation. */
	private Classifier getNewClassifier() {
		return new J48();
	}

	/**
	 * Reads the configured ARFF file.
	 *
	 * @throws FileNotFoundException if no ARFF file is set or it does not exist
	 */
	private Instances loadInstances() throws IOException {
		if (arff == null || !arff.exists()) {
			throw new FileNotFoundException("ARFF file not found: " + arff);
		}
		return Weka.instancesFromARFF(arff);
	}

	/**
	 * Runs 5 evaluations with a 10% training split on every {@code .arff}
	 * file in a directory. For input {@code X.arff} it writes
	 * {@code "N=5; Split=0.1; X.log"} and {@code "N=5; Split=0.1; X.csv"}
	 * next to the input.
	 *
	 * @param args optional; {@code args[0]} overrides the default input directory
	 */
	public static void main(String[] args) throws Exception {
		File root = new File(args.length > 0 ? args[0] : DEFAULT_ROOT);
		File[] arffFiles = root.listFiles(new FileFilter() {

			public boolean accept(File pathname) {
				return pathname.isFile()
						&& pathname.getName().toLowerCase().endsWith(ARFF_EXTENSION);
			}
		});
		if (arffFiles == null) {
			throw new FileNotFoundException("Not a readable directory: " + root);
		}

		int n = 5;
		float percent = 0.1f; // fraction of each class used for training (10%)
		String prefix = "N=" + n + "; Split=" + percent + "; ";
		CrossSplit split = new CrossSplit();
		for (File arff : arffFiles) {
			// The filter matches the extension case-insensitively, so strip it
			// by length. A case-sensitive replace(".arff", ...) would leave
			// "X.ARFF" unchanged, and the log and CSV outputs would then get
			// the same file name.
			String name = arff.getName();
			String baseName = name.substring(0, name.length() - ARFF_EXTENSION.length());
			split.setArff(arff);
			split.setLog(new File(arff.getParentFile(), prefix + baseName + ".log"));
			split.setCSV(new File(arff.getParentFile(), prefix + baseName + ".csv"));
			split.run(n, percent);
		}
	}
}
