package csl.tools.information;


import java.util.Arrays;
import java.util.Collections;

public class MutualInformation {

	private static final int splineOrder = 3;
	private static final int numberOfBins = 5;

	/**
	 * Estimates the distribution function of a random variable X using binning.
	 * 
	 * @param X
	 * @return
	 */

	public static double[] distributionEstimatorNaive(Double[] X) {
		double[] p = new double[numberOfBins];
		for (int i = 0; i < p.length; i++) {
			p[i] = 0;
		}
		Double Xmax = Collections.max(Arrays.asList(X));
		Double Xmin = Collections.min(Arrays.asList(X));

		if (Xmax.equals(Xmin)) {
			Xmax = Xmax+1;
			Xmin = Xmin -1;
		}

		for (Double Xmu : X) {
			int k = 0;
			if (!Xmu.equals(Xmax)) {
				k = (int) (numberOfBins * ((Xmu - Xmin) / (Xmax - Xmin)));
			} else {
				k = numberOfBins - 1;
			}
			p[k]++;
		}

		for (int i = 0; i < p.length; i++) {
			p[i] = p[i] / X.length;
		}

		return p;
	}

	public static double[] distributionEstimatorNaive(Double[] X, int nOfBins) {
		double[] p = new double[nOfBins];
		for (int i = 0; i < p.length; i++) {
			p[i] = 0;
		}
		Double Xmax = Collections.max(Arrays.asList(X));
		Double Xmin = Collections.min(Arrays.asList(X));

		if (Xmax.equals(Xmin)) {
			++Xmax;
			--Xmin;
		}

		for (Double Xmu : X) {
			int k = 0;
			if (!Xmu.equals(Xmax)) {
				k = (int) (nOfBins * ((Xmu - Xmin) / (Xmax - Xmin)));
			} else {
				k = nOfBins - 1;
			}
			p[k]++;
		}

		for (int i = 0; i < p.length; i++) {
			p[i] = p[i] / X.length;
		}

		return p;
	}
	
	public static double[] distributionEstimatorNaive(Double[] X, int nOfBins, Double Xmin, Double Xmax) {
		double[] p = new double[nOfBins];
		for (int i = 0; i < p.length; i++) {
			p[i] = 0;
		}
		

		if (Xmax.equals(Xmin)) {
			++Xmax;
			--Xmin;
		}

		for (Double Xmu : X) {
			int k = 0;
			if (!(Xmu.equals(Xmax))) {
				k = (int) (nOfBins * ((Xmu - Xmin) / (Xmax - Xmin)));
			} else {
				k = nOfBins - 1;
			}
			p[k]++;
		}

		for (int i = 0; i < p.length; i++) {
			p[i] = p[i] / X.length;
		}

		return p;
	}

	/**
	 * Estimates the joint distribution function of a couple of random variables
	 * X and Y using binning.
	 * 
	 * @param X
	 * @param Y
	 * @return
	 */
	public static double[][] distributionEstimatorNaive(Double[] X, Double[] Y)
			throws IllegalArgumentException {
		if (X.length != Y.length)
			throw new IllegalArgumentException();

		double[][] p = new double[numberOfBins][numberOfBins];

		for (int i = 0; i < numberOfBins; i++) {
			for (int j = 0; j < numberOfBins; j++) {
				p[i][j] = 0;
			}
		}

		Double Xmax = Collections.max(Arrays.asList(X));
		Double Xmin = Collections.min(Arrays.asList(X));
		Double Ymax = Collections.max(Arrays.asList(Y));
		Double Ymin = Collections.min(Arrays.asList(Y));

		if (Xmax.equals(Xmin)) {
			++Xmax;
			--Xmin;
		}

		if (Ymax.equals(Ymin)) {
			++Ymax;
			--Ymin;
		}

		for (int i = 0; i < X.length; i++) {
			int k = 0;
			if (!X[i].equals(Xmax)) {
				k = (int) (numberOfBins * ((X[i] - Xmin) / (Xmax - Xmin)));

			} else {
				k = numberOfBins - 1;
			}
			int k2 = 0;
			if (!Y[i].equals(Ymax)) {
				k2 = (int) (numberOfBins * ((Y[i] - Ymin) / (Ymax - Ymin)));

			} else {
				k2 = numberOfBins - 1;
			}
			p[k][k2]++;
		}

		for (int i = 0; i < numberOfBins; i++) {
			for (int j = 0; j < numberOfBins; j++) {
				p[i][j] = p[i][j] / X.length;
			}
		}
		return p;
	}

	public static double[][] distributionEstimatorNaive(Double[] X, Double[] Y,
			int nOfBins) throws IllegalArgumentException {
		if (X.length != Y.length)
			throw new IllegalArgumentException();

		double[][] p = new double[nOfBins][nOfBins];

		for (int i = 0; i < nOfBins; i++) {
			for (int j = 0; j < nOfBins; j++) {
				p[i][j] = 0;
			}
		}

		Double Xmax = Collections.max(Arrays.asList(X));
		Double Xmin = Collections.min(Arrays.asList(X));
		Double Ymax = Collections.max(Arrays.asList(Y));
		Double Ymin = Collections.min(Arrays.asList(Y));

		if (Xmax.equals(Xmin)) {
			++Xmax;
			--Xmin;
		}

		if (Ymax.equals(Ymin)) {
			++Ymax;
			--Ymin;
		}

		for (int i = 0; i < X.length; i++) {
			int k = 0;
			if (!X[i].equals(Xmax)) {
				k = (int) (nOfBins * ((X[i] - Xmin) / (Xmax - Xmin)));

			} else {
				k = nOfBins - 1;
			}
			int k2 = 0;
			if (!Y[i].equals(Ymax)) {
				k2 = (int) (nOfBins * ((Y[i] - Ymin) / (Ymax - Ymin)));

			} else {
				k2 = nOfBins - 1;
			}
			p[k][k2]++;
		}

		for (int i = 0; i < nOfBins; i++) {
			for (int j = 0; j < nOfBins; j++) {
				p[i][j] = p[i][j] / X.length;
			}
		}
		return p;
	}

	/**
	 * Estimation of the mutual information between X and Y based on binning
	 * 
	 * @param X
	 * @param Y
	 * @return
	 */
	public static double estimatorNaive(Double[] X, Double[] Y) {
		double pX[] = distributionEstimatorNaive(X);
		double pY[] = distributionEstimatorNaive(Y); // Side probability
														// distributions
		double pXY[][] = distributionEstimatorNaive(X, Y); // Joint probability
															// distribution

		double hX = entropy(pX);
		double hY = entropy(pY);
		double hXY = jointentropy(pXY);

		// B-spline estimation of Mutual Information
		double MI = hX + hY - hXY;
		
		// MI-based distance
		double dist = MI;
		
//		if (hXY == 0) {
//			dist = 0;
//		} else {
//			dist = 1 - MI / hXY;
//		}
	
		return dist;

	}

	public static double estimatorNaive(Double[] X, Double[] Y, int nOfBins) {
		double pX[] = distributionEstimatorNaive(X, nOfBins);
		double pY[] = distributionEstimatorNaive(Y, nOfBins); // Side
																// probability
																// distributions
		double pXY[][] = distributionEstimatorNaive(X, Y, nOfBins); // Joint
																	// probability
																	// distribution

		double hX = entropy(pX);
		double hY = entropy(pY);
		double hXY = jointentropy(pXY);

		// B-spline estimation of Mutual Information
		double MI = hX + hY - hXY;

		// MI-based distance

		double dist = 1 - MI / hXY;

		return dist;

	}

	public static double estimator1NN(Double[] X, Double[] Y) {
		assert Y.length != X.length;
		int k = 1;
		int nX[] = new int[X.length];
		int nY[] = new int[Y.length];
		double epsilon[] = new double[X.length];
		Double Xmax = Collections.max(Arrays.asList(X));
		Double Xmin = Collections.min(Arrays.asList(X));
		Double Ymax = Collections.max(Arrays.asList(Y));
		Double Ymin = Collections.min(Arrays.asList(Y));

		// Evaluating epsilon
		for (int i = 0; i < X.length; i++) {
			Double[] distfromXiYi = new Double[X.length];
			for (int j = 0; j < distfromXiYi.length; j++) {
				if (i != j) {
					distfromXiYi[j] = Math.max(Math.abs(X[i] - X[j]), Math
							.abs(Y[i] - Y[j]));
				} else {
					distfromXiYi[j] = Math.max(Xmax - Xmin, Ymax - Ymin);
				}

			}
			epsilon[i] = Collections.min(Arrays.asList(distfromXiYi));
		}

		// Initialization of nX, nY
		for (int i = 0; i < nX.length; i++) {
			nX[i] = 0;
			nY[i] = 0;
		}

		// Evaluating nX, nY
		for (int i = 0; i < X.length; i++) {
			for (int j = 0; j < Y.length; j++) {
				if (Math.abs(X[i] - X[j]) < epsilon[i]) {
					nX[i]++;
				}
				if (Math.abs(Y[i] - Y[j]) < epsilon[i]) {
					nY[i]++;
				}
			}
		}

		// Evaluating Square Bracket

		double bracket = 0;
		for (int i = 0; i < X.length; i++) {
			bracket = bracket + (digamma(nX[i]) + digamma(nY[i]));
		}
		bracket = (1.0 / X.length) * bracket;

		// Evaluating MI1NN

		double MI1NN = digamma(k) + digamma(X.length) - bracket;
		return MI1NN;
	}

	private static double digamma(int n) {
		double psi = -eulerMascheroniConst();
		if (n > 1) {
			for (int k = 1; k < n; k++) {
				psi = psi + (1.0 / k);
			}
		}
		return psi;
	}

	private static double eulerMascheroniConst() {
		double e = 0.57721566490153286060651209008240243104215933593992;
		return e;
	}

	private static int knot(int i, int numberOfBins) {

		int knot;
		if (i < splineOrder) {
			knot = 0;
		} else {
			if (i < numberOfBins) {
				knot = i - splineOrder + 1;
			} else {
				knot = numberOfBins - 1 - splineOrder + 2;
			}
		}
		return knot;
	}

	/**
	 * Recursive definition of B-spline basis, see "Estimating mutual
	 * information using B-spline functions" by Daub
	 * 
	 * @param z
	 * @return
	 */

	private static double[] spline(double z, int nOfBins) throws IllegalArgumentException {

		if (!(0 <= z && z <= nOfBins - splineOrder + 1)){
			System.out.println(z);
			throw new IllegalArgumentException();
		}
		double b[] = new double[nOfBins];
		for (int i = 0; i < b.length; i++) {
			b[i] = 0;
		}
		int i = 0;
		if (z == nOfBins - splineOrder + 1) {
			i = nOfBins;
		} else {
			while (!(z < knot(i, nOfBins))) {
				i++;
			}
		}
		i = i - 1;
		b[i] = 1;

		for (int k = 2; k <= splineOrder; k++) {
			for (int j = 0; j <= i; j++) {
				double a1 = 1;
				double a2 = 1;
				if (b[j] != 0) {
					a1 = (z - knot(j,nOfBins)) / (knot(j + k - 1,nOfBins) - knot(j,nOfBins));
				}
				if (j == nOfBins - 1) {
					b[j] = b[j] * a1;
				} else {
					if (b[j + 1] != 0) {
						a2 = (knot(j + k,nOfBins) - z) / (knot(j + k,nOfBins) - knot(j + 1,nOfBins));
					}
					b[j] = b[j] * a1 + b[j + 1] * a2;
				}

			}
		}
		return b;
	}

	/**
	 * Estimates the distribution function of a random variable X using B-spline
	 * functions.
	 * 
	 * @param X
	 * @return
	 */

	public static double[] distributionEstimatorBSpline(Double[] X) {
		double[] p = new double[numberOfBins];
		for (int i = 0; i < p.length; i++) {
			p[i] = 0;
		}
		Double Xmax = Collections.max(Arrays.asList(X));
		Double Xmin = Collections.min(Arrays.asList(X));

		if (Xmax.equals(Xmin)) {
			++Xmax;
			--Xmin;
		}

		for (Double Xmu : X) {
			double z = (Xmu - Xmin) * (numberOfBins - splineOrder + 1)
					/ (Xmax - Xmin);
			double[] b = spline(z, numberOfBins);
			for (int i = 0; i < b.length; i++) {
				p[i] = p[i] + b[i];
			}
		}

		for (int i = 0; i < p.length; i++) {
			p[i] = p[i] / X.length;
		}

		return p;
	}
	
	
	/**
	 * Estimates the distribution function of a random variable X using B-spline
	 * functions. 
	 * 
	 * @param X
	 * @param nOfBins
	 * @return
	 */
	public static double[] distributionEstimatorBSpline(Double[] X, int nOfBins) {
		double[] p = new double[nOfBins];
		for (int i = 0; i < p.length; i++) {
			p[i] = 0;
		}
		Double Xmax = Collections.max(Arrays.asList(X));
		Double Xmin = Collections.min(Arrays.asList(X));

		if (Xmax.equals(Xmin)) {
			++Xmax;
			--Xmin;
		}

		for (Double Xmu : X) {
			double z = (Xmu - Xmin) * (nOfBins - splineOrder + 1)
					/ (Xmax - Xmin);
			double[] b = spline(z, nOfBins);
			for (int i = 0; i < b.length; i++) {
				p[i] = p[i] + b[i];
			}
		}

		for (int i = 0; i < p.length; i++) {
			p[i] = p[i] / X.length;
		}

		return p;
	}
	
	
	/**
	 * Estimates the distribution function of a random variable X using B-spline
	 * functions. The range of all possible values of X is specified. 
	 * 
	 * @param X
	 * @param nOfBins
	 * @param Xmin
	 * @param Xmax
	 * @return
	 */
	public static double[] distributionEstimatorBSpline(Double[] X, int nOfBins, Double Xmin, Double Xmax) {
		double[] p = new double[nOfBins];
		for (int i = 0; i < p.length; i++) {
			p[i] = 0;
		}

		if (Xmax.equals(Xmin)) 
			throw new IllegalArgumentException();
		

		for (Double Xmu : X) {
			double z = (Xmu - Xmin) * (nOfBins - splineOrder + 1)
					/ (Xmax - Xmin);
			
			
			
			double[] b = spline(z, nOfBins);
			for (int i = 0; i < b.length; i++) {
				p[i] = p[i] + b[i];
			}
		}

		for (int i = 0; i < p.length; i++) {
			p[i] = p[i] / X.length;
		}

		return p;
	}

	/**
	 * Estimates the joint distribution function of a couple of random variables
	 * X and Y using B-spline functions.
	 * 
	 * @param X
	 * @param Y
	 * @return
	 */
	public static double[][] distributionEstimatorBSpline(Double[] X, Double[] Y)
			throws IllegalArgumentException {
		if (X.length != Y.length)
			throw new IllegalArgumentException();

		double[][] p = new double[numberOfBins][numberOfBins];

		for (int i = 0; i < numberOfBins; i++) {
			for (int j = 0; j < numberOfBins; j++) {
				p[i][j] = 0;
			}
		}

		Double Xmax = Collections.max(Arrays.asList(X));
		Double Xmin = Collections.min(Arrays.asList(X));
		Double Ymax = Collections.max(Arrays.asList(Y));
		Double Ymin = Collections.min(Arrays.asList(Y));

		if (Xmax.equals(Xmin)) {
			++Xmax;
			--Xmin;
		}

		if (Ymax.equals(Ymin)) {
			++Ymax;
			--Ymin;
		}

		for (int k = 0; k < X.length; k++) {
			double Xmu = (X[k] - Xmin) * (numberOfBins - splineOrder + 1)
					/ (Xmax - Xmin);
			double Ymu = (Y[k] - Ymin) * (numberOfBins - splineOrder + 1)
					/ (Ymax - Ymin);
			double[] bX = spline(Xmu, numberOfBins);
			double[] bY = spline(Ymu, numberOfBins);
			for (int i = 0; i < numberOfBins; i++) {
				for (int j = 0; j < numberOfBins; j++) {
					p[i][j] = p[i][j] + bX[i] * bY[j];
				}
			}
		}
		for (int i = 0; i < numberOfBins; i++) {
			for (int j = 0; j < numberOfBins; j++) {
				p[i][j] = p[i][j] / X.length;
			}
		}
		return p;
	}

	public static double[][] distributionEstimatorBSpline(Double[] X,
			Double[] Y, int nOfBins) throws IllegalArgumentException {
		if (X.length != Y.length)
			throw new IllegalArgumentException();

		double[][] p = new double[nOfBins][nOfBins];

		for (int i = 0; i < nOfBins; i++) {
			for (int j = 0; j < nOfBins; j++) {
				p[i][j] = 0;
			}
		}

		Double Xmax = Collections.max(Arrays.asList(X));
		Double Xmin = Collections.min(Arrays.asList(X));
		Double Ymax = Collections.max(Arrays.asList(Y));
		Double Ymin = Collections.min(Arrays.asList(Y));

		if (Xmax.equals(Xmin)) {
			++Xmax;
			--Xmin;
		}

		if (Ymax.equals(Ymin)) {
			++Ymax;
			--Ymin;
		}

		for (int k = 0; k < X.length; k++) {
			double Xmu = (X[k] - Xmin) * (nOfBins - splineOrder + 1)
					/ (Xmax - Xmin);
			double Ymu = (Y[k] - Ymin) * (nOfBins - splineOrder + 1)
					/ (Ymax - Ymin);
			double[] bX = spline(Xmu,nOfBins);
			double[] bY = spline(Ymu,nOfBins);
			for (int i = 0; i < nOfBins; i++) {
				for (int j = 0; j < nOfBins; j++) {
					p[i][j] = p[i][j] + bX[i] * bY[j];
				}
			}
		}
		for (int i = 0; i < nOfBins; i++) {
			for (int j = 0; j < nOfBins; j++) {
				p[i][j] = p[i][j] / X.length;
			}
		}
		return p;
	}

	/**
	 * Estimates the mutual information between two random variables X and Y
	 * using B-spline functions.
	 * 
	 * @param X
	 * @param Y
	 * @return
	 */
	public static double estimatorBSpline(Double[] X, Double[] Y) {

		double pX[] = distributionEstimatorBSpline(X);
		double pY[] = distributionEstimatorBSpline(Y); // Side probability
														// distributions
		double pXY[][] = distributionEstimatorBSpline(X, Y); // Joint
																// probability
																// distribution

		double hX = entropy(pX);
		double hY = entropy(pY);
		double hXY = jointentropy(pXY);

		// B-spline estimation of Mutual Information
		double MI = hX + hY - hXY;

		// MI-based distance

		double dist = 1 - MI / hXY;

		return dist;

	}

	public static double estimatorBSpline(Double[] X, Double[] Y, int nOfBins) {

		double pX[] = distributionEstimatorBSpline(X, nOfBins);
		double pY[] = distributionEstimatorBSpline(Y, nOfBins); // Side
																// probability
																// distributions
		double pXY[][] = distributionEstimatorBSpline(X, Y, nOfBins); // Joint
																		// probability
																		// distribution

		double hX = entropy(pX);
		double hY = entropy(pY);
		double hXY = jointentropy(pXY);

		// B-spline estimation of Mutual Information
		double MI = hX + hY - hXY;

		// MI-based distance

		double dist = 1 - MI / hXY;

		return dist;

	}

	private static double entropy(double[] distribution) {
		double h = 0;
		for (double d : distribution) {
			if (d != 0) {
				h = h + d * Math.log(d);
			}
		}
		h = -h;
		return h;
	}

	private static double jointentropy(double[][] distribution) {
		double h = 0;
		for (double[] ds : distribution) {
			for (double d : ds) {
				if (d != 0) 
					h = h + d * Math.log(d);
			}
		}
		h = -h;
		return h;
	}

}
