package csl.tools.weka;

import java.io.Serializable;
import java.util.Formatter;

/**
 * A square confusion matrix over a fixed list of class labels, with per-class
 * precision / recall / F-measure and tab-separated text rendering.
 *
 * <p>
 * Recall is computed over a row and precision over a column, i.e. rows are
 * treated as the actual class and columns as the predicted class.
 *
 * <p>
 * The arrays handed to the constructor are stored by reference (not copied),
 * so later changes made through those arrays are visible here and vice versa.
 * No shape validation is performed: the matrix is expected to be square and
 * {@code classes} is expected to hold one label per row.
 */
public class ConfusionMatrix implements Serializable {

	private static final long	serialVersionUID	= 1706385750197448847L;

	/** Maximum number of label characters shown in abbreviated columns. */
	private static final int	LABEL_WIDTH			= 5;

	final String[]				classes;
	final int[][]				matrix;
	final int					n;

	/**
	 * Wraps an existing count matrix.
	 *
	 * @param classes
	 *            class labels, indexed like the rows/columns of {@code matrix}
	 * @param matrix
	 *            the counts; its length defines the number of classes
	 */
	public ConfusionMatrix(String[] classes, int[][] matrix) {

		this.matrix = matrix;
		this.n = matrix.length;
		this.classes = classes;
	}

	/**
	 * Creates an all-zero matrix with one row and one column per label.
	 *
	 * @param classes
	 *            class labels
	 */
	public ConfusionMatrix(String[] classes) {

		this(classes, new int[classes.length][classes.length]);
	}

	/**
	 * Stores a count in one cell.
	 *
	 * @param val
	 *            the count to store
	 * @param row
	 *            row index
	 * @param col
	 *            column index
	 */
	public void setValue(int val, int row, int col) {

		matrix[row][col] = val;
	}

	/**
	 * Reads the count of one cell.
	 *
	 * @param row
	 *            row index
	 * @param col
	 *            column index
	 * @return the stored count
	 */
	public int getValue(int row, int col) {

		return matrix[row][col];
	}

	/**
	 * Sums all counts in a column.
	 *
	 * @param col
	 *            column index
	 * @return the column total
	 */
	public int sumCol(int col) {

		int sum = 0;
		for (int i = 0; i < n; i++) {
			sum += getValue(i, col);
		}
		return sum;
	}

	/**
	 * Sums all counts in a row.
	 *
	 * @param row
	 *            row index
	 * @return the row total
	 */
	public int sumRow(int row) {

		int sum = 0;
		for (int i = 0; i < n; i++) {
			sum += getValue(row, i);
		}
		return sum;
	}

	/**
	 * Computes the diagonal count divided by the row total.
	 *
	 * @param classIndex
	 *            index of the class
	 * @return the recall, or 0 when the row total is 0
	 */
	public double recall(int classIndex) {

		int sum = sumRow(classIndex);
		if (sum == 0) {
			return 0;
		}
		return ((double) getValue(classIndex, classIndex)) / sum;
	}

	/**
	 * Computes the diagonal count divided by the column total.
	 *
	 * @param classIndex
	 *            index of the class
	 * @return the precision, or 0 when the column total is 0
	 */
	public double precision(int classIndex) {

		int sum = sumCol(classIndex);
		if (sum == 0) {
			return 0;
		}
		return ((double) getValue(classIndex, classIndex)) / sum;
	}

	/**
	 * Computes the harmonic mean of precision and recall for one class.
	 *
	 * @param classIndex
	 *            index of the class
	 * @return {@code 2 * p * r / (p + r)}, or 0 when {@code p + r} is 0
	 */
	public double fmeasure(int classIndex) {

		// Each metric walks a full row/column, so compute them only once.
		double p = precision(classIndex);
		double r = recall(classIndex);
		double pPlusR = p + r;
		if (pPlusR == 0) {
			return 0;
		}
		return 2 * p * r / pPlusR;
	}

	/**
	 * Finds the smallest per-class F-measure.
	 *
	 * @return the minimum F-measure over all classes, or 0 when the matrix
	 *         has no classes (same convention as the other undefined metrics)
	 */
	public double minFmeasure() {

		if (n == 0) {
			return 0;
		}
		double min = fmeasure(0);
		for (int i = 1; i < n; i++) {
			double f = fmeasure(i);
			if (f <= min) {
				min = f;
			}
		}
		return min;
	}

	/**
	 * Renders precision, recall and F-measure for every class as a
	 * tab-separated table with CRLF line ends. Numbers are formatted with
	 * three decimals using the default format locale; each row ends with the
	 * abbreviated class label.
	 *
	 * @return the performance table
	 */
	public String perfString() {

		StringBuilder sb = new StringBuilder();
		sb.append("Prec\tRecall\tF-meas").append("\r\n");
		for (int i = 0; i < n; i++) {
			// String.format uses the same default locale as new Formatter()
			// but leaves no unclosed Formatter behind.
			sb.append(String.format("%1.3f\t%1.3f\t%1.3f", precision(i),
					recall(i), fmeasure(i)));
			sb.append("\t").append(abbreviate(classes[i])).append("\r\n");
		}
		return sb.toString();
	}

	/**
	 * Renders the matrix as tab-separated text with CRLF line ends: a header
	 * of abbreviated labels followed by one line per row, each ending with
	 * the full label of that row.
	 *
	 * @return the textual matrix
	 */
	@Override
	public String toString() {

		StringBuilder sb = new StringBuilder();
		for (int i = 0; i < n; i++) {
			sb.append(abbreviate(classes[i])).append("\t");
		}
		sb.append("<--- as");
		sb.append("\r\n");
		for (int i = 0; i < n; i++) {
			for (int j = 0; j < n; j++) {
				sb.append(matrix[i][j]).append("\t");
			}
			sb.append(classes[i]);
			sb.append("\r\n");
		}
		return sb.toString();
	}

	/** Cuts a label down to at most {@link #LABEL_WIDTH} characters. */
	private static String abbreviate(String label) {

		return label.substring(0, Math.min(LABEL_WIDTH, label.length()));
	}

	/**
	 * Small demo: prints a sample matrix and its performance table.
	 *
	 * @param args
	 *            ignored
	 */
	public static void main(String[] args) {

		String[] classes = { "true", "false", "maybe", "dunno", "doubt it",
				"don't give a damn" };
		int[][] matrix = { { 10, 1, 0, 2, 1, 3 }, { 5, 22, 0, 0, 3, 1 },
				{ 1, 3, 27, 2, 2, 1 }, { 0, 0, 0, 12, 1, 0 },
				{ 0, 0, 0, 0, 1, 0 }, { 0, 0, 0, 0, 0, 1 } };
		ConfusionMatrix mat = new ConfusionMatrix(classes, matrix);
		System.out.println(mat);
		System.out.println(mat.perfString());
	}
}
