approximate logistic function with a linear function

Question
-
Hi,
I have a logistic factor node in my graph. Because the output of the logistic is itself multiplied by another variable with a Gaussian distribution, the model is not supported by EP (multiplication of a beta, the logistic output, and a Gaussian is not supported).
Now I want to approximate my logistic with just a series of linear functions.
For example, in this case (for logistic(x)) I would define intervals for x:
if (x <= -1.4) y = 0;
else if (x <= 0.5) y = 0.3;
else if (x <= 0.4) y = 0.4;
else if (x <= 1.3) y = 0.7;
else y = 1;
I have this variable x and I use something like "using(Variable.If(x <= 0.5)) ... "
But the problem is that these intervals are fixed and leave no room for error.
I just get the error "probability zero" for the model.
How can I solve this problem? Should I pursue the multiplication of beta and Gaussian, or this linearization?
Any help is appreciated. Many thanks.
Wednesday, September 10, 2014 11:11 PM
All replies
-
What you describe should work as long as you are not directly observing the output of the logistic. Try posting your Infer.NET code.
Thursday, September 11, 2014 3:40 PM, Owner
-
Thanks Tom.
I will work on it some more and then post it. I thought this implementation was incorrect because the model does not include any noise.
I mean, for example, that for (y = a.x) we always model it with a noise term (Y = A.X + e). I thought this was the reason the model does not converge.
I will check my code again to find the problem and will post it if I can't solve it.
Friday, September 12, 2014 1:30 PM
-
I've attempted to write a function that replicates the logistic function as a combination of linear functions. My code is:
class InfernetTools {
  public static Variable<double> pseudo_logistic(Variable<double> logit) {
    Variable<double> prob_abs = Variable.New<double>();
    Variable<double> logit_abs = Variable.New<double>();
    using (Variable.If(logit > 0)) { logit_abs.SetTo(logit); }
    using (Variable.IfNot(logit > 0)) { logit_abs.SetTo(-logit); }
    using (Variable.If(logit_abs > 8)) { prob_abs.SetTo(0.9999); }
    using (Variable.IfNot(logit_abs > 8)) {
      using (Variable.If(logit_abs > 6)) { prob_abs.SetTo(logit_abs * 0.001 + 0.9919); }
      using (Variable.IfNot(logit_abs > 6)) {
        using (Variable.If(logit_abs > 5)) { prob_abs.SetTo(logit_abs * 0.0042 + 0.9729); }
        using (Variable.IfNot(logit_abs > 5)) {
          using (Variable.If(logit_abs > 4.5)) { prob_abs.SetTo(logit_abs * 0.0086 + 0.9505); }
          using (Variable.IfNot(logit_abs > 4.5)) {
            using (Variable.If(logit_abs > 4)) { prob_abs.SetTo(logit_abs * 0.014 + 0.9262); }
            using (Variable.IfNot(logit_abs > 4)) {
              using (Variable.If(logit_abs > 3.5)) { prob_abs.SetTo(logit_abs * 0.0227 + 0.8918); }
              using (Variable.IfNot(logit_abs > 3.5)) {
                using (Variable.If(logit_abs > 3)) { prob_abs.SetTo(logit_abs * 0.0363 + 0.8444); }
                using (Variable.IfNot(logit_abs > 3)) {
                  using (Variable.If(logit_abs > 2.5)) { prob_abs.SetTo(logit_abs * 0.057 + 0.7828); }
                  using (Variable.IfNot(logit_abs > 2.5)) {
                    using (Variable.If(logit_abs > 2)) { prob_abs.SetTo(logit_abs * 0.0869 + 0.7085); }
                    using (Variable.IfNot(logit_abs > 2)) {
                      using (Variable.If(logit_abs > 1.5)) { prob_abs.SetTo(logit_abs * 0.1268 + 0.6292); }
                      using (Variable.IfNot(logit_abs > 1.5)) {
                        using (Variable.If(logit_abs > 1)) { prob_abs.SetTo(logit_abs * 0.1735 + 0.5594); }
                        using (Variable.IfNot(logit_abs > 1)) {
                          using (Variable.If(logit_abs > 0.5)) { prob_abs.SetTo(logit_abs * 0.2179 + 0.515); }
                          using (Variable.IfNot(logit_abs > 0.5)) {
                            using (Variable.If(logit_abs > 0)) { prob_abs.SetTo(logit_abs * 0.2456 + 0.5005); }
                            using (Variable.IfNot(logit_abs > 0)) {
                              //Do nothing, logit_abs can't be negative
                            }
                          }
                        }
                      }
                    }
                  }
                }
              }
            }
          }
        }
      }
    }
    Variable<double> output = Variable.New<double>();
    using (Variable.If(logit > 0)) { output.SetTo(prob_abs); }
    using (Variable.IfNot(logit > 0)) { output.SetTo(1 - prob_abs); }
    return prob_abs;
  }
}
However, when I try to run my code with Variable.Logistic replaced by InfernetTools.pseudo_logistic, I get a long error beginning:
GateTransform failed with 6 error(s) and 0 warning(s):
Error 0: 'vdouble77' is not defined in all cases. It is only defined for (vbool8=true)(vbool9=false) in ...
What am I doing wrong in my function? I thought I'd covered all cases with the If/IfNot structure.
Wednesday, April 22, 2015 3:28 PM
-
The innermost case is missing. Also, you should make intermediate variables for the comparisons so that they are not recomputed. Finally, this method of using logit_abs is not recommended since it leads to a difficult inference problem. Try to branch only on 'logit'.
Wednesday, April 22, 2015 3:45 PM, Owner
-
Thanks Tom, I'll give those changes a try. The only thing I don't understand is what you mean by intermediate variables to avoid recomputation. Which variable should I be making intermediates of?
Wednesday, April 22, 2015 3:58 PM
-
Make an intermediate for 'logit > 5', etc.
Wednesday, April 22, 2015 4:03 PM, Owner
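A minimal sketch of that advice, assuming a toy three-branch version rather than the full table: each comparison is stored once in a Variable<bool> and reused for both If and IfNot, every path (including the innermost one) assigns prob, and the branching is on logit itself. The class name, method name, and choice of breakpoints are hypothetical; the coefficients are borrowed from the thread purely for illustration, and the namespace is the one used by current Infer.NET releases.

```csharp
// Assumption: current open-source Infer.NET. Older releases used MicrosoftResearch.Infer.Models.
using Microsoft.ML.Probabilistic.Models;

static class PseudoLogisticSketch
{
    // Hypothetical toy example with only two breakpoints (0.5 and 0).
    public static Variable<double> ThreeBranch(Variable<double> logit)
    {
        Variable<double> prob = Variable.New<double>();

        // Intermediate variable: the comparison is built once and shared by If and IfNot.
        Variable<bool> above05 = logit > 0.5;
        using (Variable.If(above05)) { prob.SetTo(logit * 0.2179 + 0.515); }
        using (Variable.IfNot(above05))
        {
            Variable<bool> above0 = logit > 0;
            using (Variable.If(above0)) { prob.SetTo(logit * 0.2456 + 0.5005); }
            // Innermost case: prob must also be assigned here, otherwise it is
            // "not defined in all cases".
            using (Variable.IfNot(above0)) { prob.SetTo(logit * 0.2456 + 0.4995); }
        }
        return prob;
    }
}
```

Writing 'logit > 0.5' inline in both the If and the IfNot would construct the comparison twice, which is what the intermediate variable avoids.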
-
Thanks Tom, I made those changes and it seems to work. For the benefit of others, here's the fixed function:
class InfernetTools {
  public static Variable<double> pseudo_logistic(Variable<double> logit) {
    Variable<double> prob = Variable.New<double>();
    Variable<bool> logit_8 = logit > 8;
    using (Variable.If(logit_8)) { prob.SetTo(0.9999); }
    using (Variable.IfNot(logit_8)) {
      Variable<bool> logit_6 = logit > 6;
      using (Variable.If(logit_6)) { prob.SetTo(logit * 0.001 + 0.9919); }
      using (Variable.IfNot(logit_6)) {
        Variable<bool> logit_5 = logit > 5;
        using (Variable.If(logit_5)) { prob.SetTo(logit * 0.0042 + 0.9729); }
        using (Variable.IfNot(logit_5)) {
          Variable<bool> logit_45 = logit > 4.5;
          using (Variable.If(logit_45)) { prob.SetTo(logit * 0.0086 + 0.9505); }
          using (Variable.IfNot(logit_45)) {
            Variable<bool> logit_4 = logit > 4;
            using (Variable.If(logit_4)) { prob.SetTo(logit * 0.014 + 0.9262); }
            using (Variable.IfNot(logit_4)) {
              Variable<bool> logit_35 = logit > 3.5;
              using (Variable.If(logit_35)) { prob.SetTo(logit * 0.0227 + 0.8918); }
              using (Variable.IfNot(logit_35)) {
                Variable<bool> logit_3 = logit > 3;
                using (Variable.If(logit_3)) { prob.SetTo(logit * 0.0363 + 0.8444); }
                using (Variable.IfNot(logit_3)) {
                  Variable<bool> logit_25 = logit > 2.5;
                  using (Variable.If(logit_25)) { prob.SetTo(logit * 0.057 + 0.7828); }
                  using (Variable.IfNot(logit_25)) {
                    Variable<bool> logit_2 = logit > 2;
                    using (Variable.If(logit_2)) { prob.SetTo(logit * 0.0869 + 0.7085); }
                    using (Variable.IfNot(logit_2)) {
                      Variable<bool> logit_15 = logit > 1.5;
                      using (Variable.If(logit_15)) { prob.SetTo(logit * 0.1268 + 0.6292); }
                      using (Variable.IfNot(logit_15)) {
                        Variable<bool> logit_1 = logit > 1;
                        using (Variable.If(logit_1)) { prob.SetTo(logit * 0.1735 + 0.5594); }
                        using (Variable.IfNot(logit_1)) {
                          Variable<bool> logit_05 = logit > 0.5;
                          using (Variable.If(logit_05)) { prob.SetTo(logit * 0.2179 + 0.515); }
                          using (Variable.IfNot(logit_05)) {
                            Variable<bool> logit_0 = logit > 0;
                            using (Variable.If(logit_0)) { prob.SetTo(logit * 0.2456 + 0.5005); }
                            using (Variable.IfNot(logit_0)) {
                              Variable<bool> logit_m0 = logit > -0.5;
                              using (Variable.If(logit_m0)) { prob.SetTo(logit * 0.2456 + 0.4995); }
                              using (Variable.IfNot(logit_m0)) {
                                Variable<bool> logit_m1 = logit > -1;
                                using (Variable.If(logit_m1)) { prob.SetTo(logit * 0.2179 + 0.485); }
                                using (Variable.IfNot(logit_m1)) {
                                  Variable<bool> logit_m15 = logit > -1.5;
                                  using (Variable.If(logit_m15)) { prob.SetTo(logit * 0.1735 + 0.4406); }
                                  using (Variable.IfNot(logit_m15)) {
                                    Variable<bool> logit_m2 = logit > -2;
                                    using (Variable.If(logit_m2)) { prob.SetTo(logit * 0.1268 + 0.3708); }
                                    using (Variable.IfNot(logit_m2)) {
                                      Variable<bool> logit_m25 = logit > -2.5;
                                      using (Variable.If(logit_m25)) { prob.SetTo(logit * 0.0869 + 0.2915); }
                                      using (Variable.IfNot(logit_m25)) {
                                        Variable<bool> logit_m3 = logit > -3;
                                        using (Variable.If(logit_m3)) { prob.SetTo(logit * 0.057 + 0.2172); }
                                        using (Variable.IfNot(logit_m3)) {
                                          Variable<bool> logit_m35 = logit > -3.5;
                                          using (Variable.If(logit_m35)) { prob.SetTo(logit * 0.0363 + 0.1556); }
                                          using (Variable.IfNot(logit_m35)) {
                                            Variable<bool> logit_m4 = logit > -4;
                                            using (Variable.If(logit_m4)) { prob.SetTo(logit * 0.0227 + 0.1082); }
                                            using (Variable.IfNot(logit_m4)) {
                                              Variable<bool> logit_m45 = logit > -4.5;
                                              using (Variable.If(logit_m45)) { prob.SetTo(logit * 0.014 + 0.0738); }
                                              using (Variable.IfNot(logit_m45)) {
                                                Variable<bool> logit_m5 = logit > -5;
                                                using (Variable.If(logit_m5)) { prob.SetTo(logit * 0.0086 + 0.0495); }
                                                using (Variable.IfNot(logit_m5)) {
                                                  Variable<bool> logit_m6 = logit > -6;
                                                  using (Variable.If(logit_m6)) { prob.SetTo(logit * 0.0042 + 0.0271); }
                                                  using (Variable.IfNot(logit_m6)) {
                                                    Variable<bool> logit_m8 = logit > -8;
                                                    using (Variable.If(logit_m8)) { prob.SetTo(logit * 0.001 + 0.0081); }
                                                    using (Variable.IfNot(logit_m8)) {
                                                      prob.SetTo(0.0001);
                                                    }
                                                  }
                                                }
                                              }
                                            }
                                          }
                                        }
                                      }
                                    }
                                  }
                                }
                              }
                            }
                          }
                        }
                      }
                    }
                  }
                }
              }
            }
          }
        }
      }
    }
    return prob;
  }
}
Wednesday, April 22, 2015 4:44 PM
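For anyone reusing these coefficients, the following is a rough sanity check in plain C# (no Infer.NET required) that evaluates the same piecewise-linear table on ordinary doubles and compares it with the exact logistic. It is only a sketch: the class and method names are hypothetical, the grid range and resolution are arbitrary choices, and the table assumes the breakpoint between 2 and 3 is 2.5.

```csharp
using System;

static class PseudoLogisticCheck
{
    // (lower bound, slope, intercept) for each segment, highest bound first.
    // Values copied from the fixed function in the thread; 2.5 breakpoint assumed.
    static readonly (double Lower, double Slope, double Intercept)[] Segments =
    {
        (8.0, 0.0, 0.9999), (6.0, 0.001, 0.9919), (5.0, 0.0042, 0.9729),
        (4.5, 0.0086, 0.9505), (4.0, 0.014, 0.9262), (3.5, 0.0227, 0.8918),
        (3.0, 0.0363, 0.8444), (2.5, 0.057, 0.7828), (2.0, 0.0869, 0.7085),
        (1.5, 0.1268, 0.6292), (1.0, 0.1735, 0.5594), (0.5, 0.2179, 0.515),
        (0.0, 0.2456, 0.5005), (-0.5, 0.2456, 0.4995), (-1.0, 0.2179, 0.485),
        (-1.5, 0.1735, 0.4406), (-2.0, 0.1268, 0.3708), (-2.5, 0.0869, 0.2915),
        (-3.0, 0.057, 0.2172), (-3.5, 0.0363, 0.1556), (-4.0, 0.0227, 0.1082),
        (-4.5, 0.014, 0.0738), (-5.0, 0.0086, 0.0495), (-6.0, 0.0042, 0.0271),
        (-8.0, 0.001, 0.0081),
    };

    // Same branching order as the nested If/IfNot blocks: first bound exceeded wins.
    public static double PseudoLogistic(double x)
    {
        foreach (var s in Segments)
        {
            if (x > s.Lower) return s.Slope * x + s.Intercept;
        }
        return 0.0001; // x <= -8
    }

    public static double Logistic(double x) => 1.0 / (1.0 + Math.Exp(-x));

    // Largest absolute difference between the two functions on an evenly spaced grid.
    public static double MaxAbsError(double lo, double hi, int steps)
    {
        double worst = 0.0;
        for (int i = 0; i <= steps; i++)
        {
            double x = lo + (hi - lo) * i / steps;
            worst = Math.Max(worst, Math.Abs(PseudoLogistic(x) - Logistic(x)));
        }
        return worst;
    }

    static void Main()
    {
        Console.WriteLine($"max |error| on [-10, 10]: {MaxAbsError(-10, 10, 20000):F4}");
    }
}
```

Keeping the coefficients in one table like this also makes it easier to spot a mistyped breakpoint than scanning the nested blocks.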