Infer from which of 2 distributions the data was drawn
-
28 October 2011 10:52
After working through the new 101 document, I am trying to implement some basic models. I ran into a problem trying to implement a quadratic discriminant classifier:
I have two VectorGaussian distributions (representing two classes of data) whose parameters I have learned from training data, and I want to infer the posterior probability that a data point was drawn from the first of the two distributions.
I am trying to do this using a Variable.If/IfNot pair. If my boolean variable is true, the data was drawn from the first distribution, i.e. it belongs to class 1.
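When the class means, precision matrices and class prior are all fixed, this posterior has a closed form by Bayes' rule: P(class 1 | x) = π N1(x) / (π N1(x) + (1 - π) N2(x)), where N1 and N2 are the two Gaussian densities and π is the prior probability of class 1. The plain C# below is a minimal sketch of that calculation, useful as a sanity check on what inference should produce. It does not use Infer.NET, it takes each Gaussian in mean/precision form (matching SetMeanAndPrecision in the code below), and the QdaSketch class and its method names are hypothetical.

```csharp
using System;

// Minimal sketch, assuming known class means, precision matrices and a fixed
// prior probability of class 1. QdaSketch is a hypothetical helper written for
// this illustration; it is not part of the Infer.NET API.
public static class QdaSketch
{
    // Lower-triangular Cholesky factor L with a = L * L^T.
    // Throws if a is not symmetric positive definite.
    static double[,] Cholesky(double[,] a)
    {
        int n = a.GetLength(0);
        var l = new double[n, n];
        for (int i = 0; i < n; i++)
        {
            for (int j = 0; j <= i; j++)
            {
                double sum = a[i, j];
                for (int k = 0; k < j; k++) sum -= l[i, k] * l[j, k];
                if (i == j)
                {
                    if (sum <= 0) throw new ArgumentException("Matrix is not positive definite.");
                    l[i, i] = Math.Sqrt(sum);
                }
                else
                {
                    l[i, j] = sum / l[j, j];
                }
            }
        }
        return l;
    }

    // log N(x; mean, precision^-1) for a multivariate Gaussian given its precision matrix.
    public static double LogDensity(double[] x, double[] mean, double[,] precision)
    {
        int d = x.Length;
        double[,] l = Cholesky(precision);
        double logDet = 0;  // log|precision| = 2 * sum_i log(L_ii)
        for (int i = 0; i < d; i++) logDet += 2 * Math.Log(l[i, i]);
        double quad = 0;    // (x - mean)^T * precision * (x - mean)
        for (int i = 0; i < d; i++)
            for (int j = 0; j < d; j++)
                quad += (x[i] - mean[i]) * precision[i, j] * (x[j] - mean[j]);
        return 0.5 * logDet - 0.5 * d * Math.Log(2 * Math.PI) - 0.5 * quad;
    }

    // P(class 1 | x) by Bayes' rule, with prior1 = P(class 1) held fixed.
    public static double PosteriorClass1(double[] x, double prior1,
        double[] mean1, double[,] prec1, double[] mean2, double[,] prec2)
    {
        double log1 = Math.Log(prior1) + LogDensity(x, mean1, prec1);
        double log2 = Math.Log(1 - prior1) + LogDensity(x, mean2, prec2);
        // 1 / (1 + p2 / p1), with the ratio computed in log space.
        return 1.0 / (1.0 + Math.Exp(log2 - log1));
    }
}
```

For example, with identity precision matrices, means (0, 0) and (3, 3), a prior of 0.5 and x = (0, 0), PosteriorClass1 gives 1 / (1 + e^-9), roughly 0.99988. The model in this question does not fix the prior, though: isClass1Prior has a Beta(1, 1) prior and is inferred jointly with all the isClass1 values by variational message passing, so the Bernoulli posteriors returned by inference need not match this fixed-prior calculation exactly.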
Here is the model code I was hoping to use:
class PredictionModel
{
    public InferenceEngine inferenceEngine;
    protected VariableArray<bool> isClass1;
    protected VectorGaussian class1Dist;
    protected VectorGaussian class2Dist;
    protected Variable<int> nSamples;               // number of data samples
    protected VariableArray<Vector> data;           // input data
    protected VariableArray<Bernoulli> predictions; // predictions
    protected Variable<double> isClass1Prior;

    public PredictionModel(int nDims)
    {
        nSamples = Variable.New<int>().Named("nSamples");
        Range sampleRange = new Range(nSamples);
        isClass1 = Variable.Array<bool>(sampleRange).Named("isClass1");
        isClass1Prior = Variable.Beta(1, 1).Named("isClass1Prior");
        class1Dist = new VectorGaussian(nDims);
        class2Dist = new VectorGaussian(nDims);
        data = Variable.Array<Vector>(sampleRange).Named("Data");
        using (Variable.ForEach(sampleRange))
        {
            isClass1[sampleRange] = Variable.Bernoulli(isClass1Prior);
            using (Variable.If(isClass1[sampleRange]))
                data[sampleRange] = Variable.Random<Vector>(class1Dist);
            using (Variable.IfNot(isClass1[sampleRange]))
                data[sampleRange] = Variable.Random<Vector>(class2Dist);
        }
        if (inferenceEngine == null)
        {
            inferenceEngine = new InferenceEngine(new VariationalMessagePassing());
        }
    }

    // Set model parameters
    public void SetModelData(ModelData class1Para, ModelData class2Para)
    {
        Vector class1Mean = class1Para.mean.GetMean();
        PositiveDefiniteMatrix class1Precision = class1Para.precision.GetMean();
        class1Dist.SetMeanAndPrecision(class1Mean, class1Precision);
        Vector class2Mean = class2Para.mean.GetMean();
        PositiveDefiniteMatrix class2Precision = class2Para.precision.GetMean();
        class2Dist.SetMeanAndPrecision(class2Mean, class2Precision);
    }

    // Infer class of input data
    public void InferClasses(Vector[] inputData)
    {
        nSamples.ObservedValue = inputData.Length;
        Range sampleRange = new Range(nSamples);
        data.ObservedValue = inputData;
        inferenceEngine.ShowFactorGraph = true;
        predictions = Variable.Array<Bernoulli>(sampleRange).Named("Predictions");
        predictions = inferenceEngine.Infer<VariableArray<Bernoulli>>(isClass1);
    }
}
The Main() is this:

// Create an instance of the prediction model
PredictionModel predModel = new PredictionModel(nDims);
// Set the model parameters
predModel.SetModelData(class1Posteriors, class2Posteriors);
// Collect data for Inference
double[] feature1 = DataReader.readFromCSV(".\\data\\feature1.csv");
double[] feature2 = DataReader.readFromCSV(".\\data\\feature2.csv");
double[] feature3 = DataReader.readFromCSV(".\\data\\feature3.csv");
int nSamples = feature1.Length;
Vector[] inputData = new Vector[nSamples];
for (int i = 0; i < nSamples; i++)
{
    inputData[i] = Vector.FromArray(feature2[i], feature3[i], feature1[i]);
}
// Infer classes given the input data
predModel.InferClasses(inputData);
Running this, I get the following error:

"An unhandled exception of type 'System.ArgumentException' occurred in Infer.Runtime.dll
Additional information: Cannot convert distribution type DistributionStructArray<Bernoulli,bool> to type VariableArray<Bernoulli>"

Would you please tell me what I am doing wrong here?
Many thanks!
- Edited by M.Arnold 28 October 2011 10:55
-
28 October 2011 14:18 (Owner)
The results of doing inference are marginal posterior distributions. The Variable types are for defining the model. In your case, the variables you want to infer are Variable<bool> but the results of the inference are Bernoulli distributions. What you need in your code is:
Bernoulli[] predictions = inferenceEngine.Infer<Bernoulli[]>(isClass1);
John
- Edited by John Guiver (Microsoft Employee, Owner) 28 October 2011 14:18
- Marked as answer by M.Arnold 28 October 2011 14:25
-
28 October 2011 14:34
That was it. Many thanks, John!
Mirko