/**
 * Expectation-maximisation style estimator over three input states
 * (off, neutral, on).
 *
 * <p>Typical use: construct (or call {@link #create()} to reset), feed flags
 * with {@link #setInput(boolean[], boolean[], boolean[])}, run
 * {@link #iterate()}, then query {@link #isOn()} / {@link #getOnValue()}.
 *
 * <p>The per-state estimates ({@code theta}) start at a small floor value and
 * are never renormalised to sum to 1, so they are scores rather than strict
 * probabilities.
 */
public class EM {
    /**
     * Sizing and threshold constants. Kept as a (non-static) inner class so
     * that existing references such as {@code EM.CONSTANTS.THRESH} keep
     * compiling unchanged.
     */
    public class CONSTANTS {
        /** Number of flags in each input set (inner dimension of the input). */
        public static final int NFRAMES_COUNT = 3;
        /** Number of input sets (outer dimension of the input). */
        public static final int NFRAMES_SETS = 2;
        /** Threshold used by {@link EM#isOn()}. */
        public static final float THRESH = 1.0f / 6.0f;
    }

    // Indices of the three states inside theta / ml / the count tables.
    private static final int OFF = 0;
    private static final int NEUTRAL = 1;
    private static final int ON = 2;
    private static final int STATE_COUNT = 3;

    /** Lower bound / smoothing term added throughout the update rules. */
    private static final float FLOOR = 0.001f;
    /** Smallest multiplier the E-step will apply to a theta value. */
    private static final float MIN_MULTIPLIER = 0.01f;
    /** Number of E-step/M-step rounds performed by {@link #iterate()}. */
    private static final int ITERATIONS = 27;

    private float[] theta;
    private float[][] ml;
    private boolean[][] inputOn, inputNeutral, inputOff;
    private int nsets, ncount;

    /**
     * Creates a ready-to-use estimator. Calling {@link #create()} afterwards
     * is still allowed and simply resets the state again.
     */
    public EM() {
        initState();
    }

    /**
     * Resets the estimator: all stored input flags are cleared and every
     * theta value is set back to its lower bound.
     */
    public void create() {
        initState();
    }

    /** Allocates all working arrays and seeds theta with the floor value. */
    private void initState() {
        this.ncount = CONSTANTS.NFRAMES_COUNT;
        this.nsets = CONSTANTS.NFRAMES_SETS;
        this.theta = new float[STATE_COUNT];
        this.ml = new float[this.nsets][STATE_COUNT];
        // One flag table per input state; dimensions are [set][frame].
        this.inputOn = new boolean[this.nsets][this.ncount];
        this.inputNeutral = new boolean[this.nsets][this.ncount];
        this.inputOff = new boolean[this.nsets][this.ncount];
        for (int state = 0; state < STATE_COUNT; state++) {
            this.theta[state] = FLOOR;
        }
    }

    /**
     * Stores the input flags. Each array is a flattened, set-major table:
     * element {@code set * NFRAMES_COUNT + frame} belongs to that set/frame.
     * Elements beyond {@code NFRAMES_SETS * NFRAMES_COUNT} are ignored.
     *
     * <p>All three arrays are validated before anything is copied, so a
     * rejected call leaves the previously stored input untouched.
     *
     * @param notesOn      flags for the "on" state
     * @param notesNeutral flags for the "neutral" state
     * @param notesOff     flags for the "off" state
     * @throws NullPointerException           if any array is {@code null}
     * @throws ArrayIndexOutOfBoundsException if any array holds fewer than
     *         {@code NFRAMES_SETS * NFRAMES_COUNT} elements
     */
    public void setInput(boolean[] notesOn, boolean[] notesNeutral, boolean[] notesOff) {
        final int required = this.nsets * this.ncount;
        requireLength(notesOn, "notesOn", required);
        requireLength(notesNeutral, "notesNeutral", required);
        requireLength(notesOff, "notesOff", required);

        for (int set = 0; set < this.nsets; set++) {
            final int offset = set * this.ncount;
            System.arraycopy(notesOn, offset, this.inputOn[set], 0, this.ncount);
            System.arraycopy(notesNeutral, offset, this.inputNeutral[set], 0, this.ncount);
            System.arraycopy(notesOff, offset, this.inputOff[set], 0, this.ncount);
        }
    }

    /**
     * Fails fast when an input array cannot supply {@code required} flags.
     * The exception types match what the unchecked copy loop used to throw.
     */
    private static void requireLength(boolean[] flags, String name, int required) {
        if (flags == null) {
            throw new NullPointerException(name + " must not be null");
        }
        if (flags.length < required) {
            throw new ArrayIndexOutOfBoundsException(
                    name + " has " + flags.length + " elements but at least "
                            + required + " are required");
        }
    }

    /** Returns how many entries of {@code flags} are {@code true}. */
    private static int countTrue(boolean[] flags) {
        int total = 0;
        for (boolean flag : flags) {
            if (flag) {
                total++;
            }
        }
        return total;
    }

    /** Builds the per-set table of true-flag counts, indexed [set][state]. */
    private int[][] countStates() {
        final int[][] counts = new int[this.nsets][STATE_COUNT];
        for (int set = 0; set < this.nsets; set++) {
            counts[set][OFF] = countTrue(this.inputOff[set]);
            counts[set][NEUTRAL] = countTrue(this.inputNeutral[set]);
            counts[set][ON] = countTrue(this.inputOn[set]);
        }
        return counts;
    }

    /**
     * E-step multiplier for one state: frames flagged for this state plus
     * frames not flagged for either other state, relative to the set size,
     * clamped from below so a state is never zeroed out completely.
     */
    private float multiplier(int own, int others) {
        final float mult = ((float) (own + (this.ncount - others))) / (float) this.ncount;
        return (mult < MIN_MULTIPLIER) ? MIN_MULTIPLIER : mult;
    }

    /** E-step: scales each theta by its per-set multiplier into {@code ml}. */
    private void eStep(int[][] counts) {
        for (int set = 0; set < this.nsets; set++) {
            final int off = counts[set][OFF];
            final int neutral = counts[set][NEUTRAL];
            final int on = counts[set][ON];
            this.ml[set][OFF] = this.theta[OFF] * multiplier(off, neutral + on);
            this.ml[set][NEUTRAL] = this.theta[NEUTRAL] * multiplier(neutral, off + on);
            this.ml[set][ON] = this.theta[ON] * multiplier(on, neutral + off);
        }
    }

    /** Replaces NaN or infinite values with the floor value. */
    private static float sanitize(float value) {
        return (Float.isNaN(value) || Float.isInfinite(value)) ? FLOOR : value;
    }

    /**
     * M-step: turns the count-weighted {@code ml} values of every set into
     * relative shares, averages the shares over all sets and folds them into
     * theta via {@code sqrt(2 * (FLOOR + theta * share))}.
     */
    private void mStep(int[][] counts) {
        final float[] share = new float[STATE_COUNT];
        final float[] weight = new float[STATE_COUNT];

        for (int set = 0; set < this.nsets; set++) {
            for (int state = 0; state < STATE_COUNT; state++) {
                weight[state] = (float) counts[set][state] * this.ml[set][state] + FLOOR;
            }
            // Summation order is kept fixed so results stay bit-identical.
            final float total = weight[OFF] + weight[NEUTRAL] + weight[ON];
            for (int state = 0; state < STATE_COUNT; state++) {
                share[state] += weight[state] / total + FLOOR;
                share[state] = sanitize(share[state]);
            }
        }

        for (int state = 0; state < STATE_COUNT; state++) {
            share[state] /= (float) this.nsets;
            this.theta[state] = sanitize(
                    (float) Math.sqrt((FLOOR + this.theta[state] * share[state]) * 2.0f));
        }
    }

    /**
     * Runs a fixed number of E-step/M-step rounds on the stored input.
     * The flag counts cannot change while iterating, so they are taken once.
     */
    public void iterate() {
        final int[][] counts = countStates();
        for (int round = 0; round < ITERATIONS; round++) {
            eStep(counts);
            mStep(counts);
        }
    }

    /**
     * @return {@code true} when the "on" estimate is at least
     *         {@code CONSTANTS.THRESH} and the "off" estimate is at most
     *         {@code 1 - CONSTANTS.THRESH}
     */
    public boolean isOn() {
        return this.theta[ON] >= CONSTANTS.THRESH
                && this.theta[OFF] <= 1.0f - CONSTANTS.THRESH;
    }

    /** @return the current estimate for the "on" state */
    public float getOnValue() {
        return this.theta[ON];
    }
}
// Each E-step/M-step iteration runs in O(N), as opposed to a neural network's O(N^3).
// Note: the M-step used here deviates from the standard EM algorithm.