Approximate logistic function with a linear function

  • Question

  • Hi

    I have a logistic factor node in my graph. Because the output of the logistic is multiplied by another variable with a Gaussian distribution, the model is not supported by EP (the product of the Beta-distributed logistic output and a Gaussian variable is not supported).

    So I want to approximate the logistic with just a series of linear functions.

    For example, for logistic(x) I would define these intervals for x:

    if (x <= -1.4) y = 0;
    else if (x <= -0.5) y = 0.3;
    else if (x <= 0.4) y = 0.4;
    else if (x <= 1.3) y = 0.7;
    else y = 1;

    x is a random variable, and I implement each test with something like "using (Variable.If(x <= 0.5)) ...".

    The problem is that these intervals are fixed and there is no room for error here.

    I just get a "probability zero" error for the model.

    How can I solve this: by supporting the multiplication of a Beta and a Gaussian, or by making this linearization work?

    Any help is appreciated. Many thanks.

    Wednesday, September 10, 2014 11:11 PM

All replies

  • What you describe should work as long as you are not directly observing the output of the logistic. Try posting your Infer.NET code.
    Thursday, September 11, 2014 3:40 PM
    Owner
  • Thanks Tom.

    I'll work on it some more and then post it. I thought this implementation was incorrect because the model doesn't include any noise.

    For example, for y = a·x we usually model it with a noise term: Y = A·X + e. I thought this was why the model does not converge.

    I'll check my code again to find the problem, and will post it if I can't solve it.

    Friday, September 12, 2014 1:30 PM
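    One way to see why a noise term can matter: if the output of a deterministic function is observed (directly or through a downstream constraint), any observed value the function cannot produce has probability zero, whereas adding Gaussian noise, as in Y = A·X + e, gives every observed value a positive density. The plain C# sketch here (no Infer.NET; `NoiseIllustration` and its methods are hypothetical names) evaluates both cases for the step approximation from the question, assuming thresholds of -1.4, -0.5, 0.4 and 1.3:

```csharp
using System;

// Hypothetical illustration in plain C# (no Infer.NET): compares an observed y under
// a deterministic model y = f(x) and a noisy model y = f(x) + e, e ~ N(0, variance).
public static class NoiseIllustration
{
    // Step approximation of logistic(x). The thresholds -1.4, -0.5, 0.4, 1.3 are an
    // assumption: with 0.5 as the second threshold, the x <= 0.4 branch is unreachable.
    public static double StepLogistic(double x)
    {
        if (x <= -1.4) return 0.0;
        if (x <= -0.5) return 0.3;
        if (x <= 0.4) return 0.4;
        if (x <= 1.3) return 0.7;
        return 1.0;
    }

    // Deterministic model: the observation has probability 1 if it equals f(x)
    // exactly, and probability 0 otherwise (exact comparison is intentional).
    public static double LikelihoodNoNoise(double y, double x) =>
        y == StepLogistic(x) ? 1.0 : 0.0;

    // Noisy model: Gaussian density of y centred on f(x); positive for every y.
    public static double LikelihoodWithNoise(double y, double x, double variance)
    {
        double d = y - StepLogistic(x);
        return Math.Exp(-d * d / (2.0 * variance)) / Math.Sqrt(2.0 * Math.PI * variance);
    }
}
```

    For instance, at x = 0 the step function gives 0.4, so an observed y = 0.5 has likelihood 0 without noise but a positive density with variance 0.01. In an Infer.NET model the analogous step would be to draw the observed quantity from a Gaussian centred on the approximation (for example with Variable.GaussianFromMeanAndVariance) rather than observing the approximation itself.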
  • I've attempted to write a function that approximates the logistic function as a combination of linear functions. My code is:

    class InfernetTools
        {
            public static Variable<double> pseudo_logistic(Variable<double> logit)
            {
                Variable<double> prob_abs = Variable.New<double>();
                Variable<double> logit_abs = Variable.New<double>();
                using (Variable.If(logit > 0))
                {
                    logit_abs.SetTo(logit);
                }
                using (Variable.IfNot(logit > 0))
                {
                    logit_abs.SetTo(-logit);
                }
                using (Variable.If(logit_abs > 8))
                {
                    prob_abs.SetTo(0.9999);
                }
                using (Variable.IfNot(logit_abs > 8))
                {
                    using (Variable.If(logit_abs > 6))
                    {
                        prob_abs.SetTo(logit_abs * 0.001 + 0.9919);
                    }
                    using (Variable.IfNot(logit_abs > 6))
                    {
                        using (Variable.If(logit_abs > 5))
                        {
                            prob_abs.SetTo(logit_abs * 0.0042 + 0.9729);
                        }
                        using (Variable.IfNot(logit_abs > 5))
                        {
                            using (Variable.If(logit_abs > 4.5))
                            {
                                prob_abs.SetTo(logit_abs * 0.0086 + 0.9505);
                            }
                            using (Variable.IfNot(logit_abs > 4.5))
                            {
                                using (Variable.If(logit_abs > 4))
                                {
                                    prob_abs.SetTo(logit_abs * 0.014 + 0.9262);
                                }
                                using (Variable.IfNot(logit_abs > 4))
                                {
                                    using (Variable.If(logit_abs > 3.5))
                                    {
                                        prob_abs.SetTo(logit_abs * 0.0227 + 0.8918);
                                    }
                                    using (Variable.IfNot(logit_abs > 3.5))
                                    {
                                        using (Variable.If(logit_abs > 3))
                                        {
                                            prob_abs.SetTo(logit_abs * 0.0363 + 0.8444);
                                        }
                                        using (Variable.IfNot(logit_abs > 3))
                                        {
                                            using (Variable.If(logit_abs > 2.5))
                                            {
                                                prob_abs.SetTo(logit_abs * 0.057 + 0.7828);
                                            }
                                            using (Variable.IfNot(logit_abs > 2.5))
                                            {
                                                using (Variable.If(logit_abs > 2))
                                                {
                                                    prob_abs.SetTo(logit_abs * 0.0869 + 0.7085);
                                                }
                                                using (Variable.IfNot(logit_abs > 2))
                                                {
                                                    using (Variable.If(logit_abs > 1.5))
                                                    {
                                                        prob_abs.SetTo(logit_abs * 0.1268 + 0.6292);
                                                    }
                                                    using (Variable.IfNot(logit_abs > 1.5))
                                                    {
                                                        using (Variable.If(logit_abs > 1))
                                                        {
                                                            prob_abs.SetTo(logit_abs * 0.1735 + 0.5594);
                                                        }
                                                        using (Variable.IfNot(logit_abs > 1))
                                                        {
                                                            using (Variable.If(logit_abs > 0.5))
                                                            {
                                                                prob_abs.SetTo(logit_abs * 0.2179 + 0.515);
                                                            }
                                                            using (Variable.IfNot(logit_abs > 0.5))
                                                            {
                                                                using (Variable.If(logit_abs > 0))
                                                                {
                                                                    prob_abs.SetTo(logit_abs * 0.2456 + 0.5005);
                                                                }
                                                                using (Variable.IfNot(logit_abs > 0))
                                                                {
                                                                    //Do nothing, logit_abs can't be negative
                                                                }
                                                            }
                                                        }
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
                Variable<double> output = Variable.New<double>();
                using (Variable.If(logit > 0))
                {
                    output.SetTo(prob_abs);
                }
                using (Variable.IfNot(logit > 0))
                {
                    output.SetTo(1 - prob_abs);
                }
                return output;
            }
        }

    However, when I try to run my code with Variable.Logistic replaced by InfernetTools.pseudo_logistic, I get a long error beginning:

    GateTransform failed with 6 error(s) and 0 warning(s):

    Error 0: 'vdouble77' is not defined in all cases.  It is only defined for (vbool8=true)(vbool9=false) in ...

    What am I doing wrong in my function? I thought I'd covered off all cases with the If/IfNot structure.

    Wednesday, April 22, 2015 3:28 PM
  • The innermost case is missing: prob_abs is never set in the IfNot(logit_abs > 0) branch. Also, you should make intermediate variables for the comparisons so that they are not recomputed. Finally, this method of using logit_abs is not recommended since it leads to a difficult inference problem. Try to branch only on 'logit'.
    Wednesday, April 22, 2015 3:45 PM
    Owner
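    Branching only on logit amounts to defining the approximation directly as a piecewise-linear function of x = logit, with breakpoints t_0 = 0 < t_1 < t_2 < ... on the positive side. The negative-side coefficients then need not be fitted separately, because the logistic is symmetric. A sketch of the relation (endpoint conventions aside, which do not matter for a continuous variable):

```latex
\tilde\sigma(x) = a_k x + b_k, \quad t_{k-1} < x \le t_k
\qquad\text{and}\qquad
\sigma(-x) = 1 - \sigma(x)
\quad\Longrightarrow\quad
\tilde\sigma(x) = 1 - \tilde\sigma(-x) = a_k x + (1 - b_k), \quad -t_k \le x < -t_{k-1}
```

    For example, the segment 0.2456x + 0.5005 on (0, 0.5] maps to 0.2456x + 0.4995 on the negative side, which is the pair of coefficients used for 0 < logit <= 0.5 and -0.5 < logit <= 0 in the fixed function later in this thread.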
  • Thanks Tom, I'll give those changes a try. The only thing I don't understand is what you mean by intermediate variables to avoid recomputation: which expressions should I make intermediates for?
    Wednesday, April 22, 2015 3:58 PM
  • Make an intermediate for 'logit > 5', etc.
    Wednesday, April 22, 2015 4:03 PM
    Owner
  • Thanks Tom, I made those changes and it seems to work. For the benefit of others, here's the fixed function:

    class InfernetTools
        {
            public static Variable<double> pseudo_logistic(Variable<double> logit)
            {
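                // Piecewise-linear approximation of logistic(logit), branching only on logit.
                // Each comparison is stored in its own Variable<bool> so it is computed once.
                Variable<double> prob = Variable.New<double>();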
                Variable<bool> logit_8 = logit > 8;
                using (Variable.If(logit_8))
                {
                    prob.SetTo(0.9999);
                }
                using (Variable.IfNot(logit_8))
                {
                    Variable<bool> logit_6 = logit > 6;
                    using (Variable.If(logit_6))
                    {
                        prob.SetTo(logit * 0.001 + 0.9919);
                    }
                    using (Variable.IfNot(logit_6))
                    {
                        Variable<bool> logit_5 = logit > 5;
                        using (Variable.If(logit_5))
                        {
                            prob.SetTo(logit * 0.0042 + 0.9729);
                        }
                        using (Variable.IfNot(logit_5))
                        {
                            Variable<bool> logit_45 = logit > 4.5;
                            using (Variable.If(logit_45))
                            {
                                prob.SetTo(logit * 0.0086 + 0.9505);
                            }
                            using (Variable.IfNot(logit_45))
                            {
                                Variable<bool> logit_4 = logit > 4;
                                using (Variable.If(logit_4))
                                {
                                    prob.SetTo(logit * 0.014 + 0.9262);
                                }
                                using (Variable.IfNot(logit_4))
                                {
                                    Variable<bool> logit_35 = logit > 3.5;
                                    using (Variable.If(logit_35))
                                    {
                                        prob.SetTo(logit * 0.0227 + 0.8918);
                                    }
                                    using (Variable.IfNot(logit_35))
                                    {
                                        Variable<bool> logit_3 = logit > 3;
                                        using (Variable.If(logit_3))
                                        {
                                            prob.SetTo(logit * 0.0363 + 0.8444);
                                        }
                                        using (Variable.IfNot(logit_3))
                                        {
                                            Variable<bool> logit_25 = logit > 2.5;
                                            using (Variable.If(logit_25))
                                            {
                                                prob.SetTo(logit * 0.057 + 0.7828);
                                            }
                                            using (Variable.IfNot(logit_25))
                                            {
                                                Variable<bool> logit_2 = logit > 2;
                                                using (Variable.If(logit_2))
                                                {
                                                    prob.SetTo(logit * 0.0869 + 0.7085);
                                                }
                                                using (Variable.IfNot(logit_2))
                                                {
                                                    Variable<bool> logit_15 = logit > 1.5;
                                                    using (Variable.If(logit_15))
                                                    {
                                                        prob.SetTo(logit * 0.1268 + 0.6292);
                                                    }
                                                    using (Variable.IfNot(logit_15))
                                                    {
                                                        Variable<bool> logit_1 = logit > 1;
                                                        using (Variable.If(logit_1))
                                                        {
                                                            prob.SetTo(logit * 0.1735 + 0.5594);
                                                        }
                                                        using (Variable.IfNot(logit_1))
                                                        {
                                                            Variable<bool> logit_05 = logit > 0.5;
                                                            using (Variable.If(logit_05))
                                                            {
                                                                prob.SetTo(logit * 0.2179 + 0.515);
                                                            }
                                                            using (Variable.IfNot(logit_05))
                                                            {
                                                                Variable<bool> logit_0 = logit > 0;
                                                                using (Variable.If(logit_0))
                                                                {
                                                                    prob.SetTo(logit * 0.2456 + 0.5005);
                                                                }
                                                                using (Variable.IfNot(logit_0))
                                                                {
                                                                    Variable<bool> logit_m0 = logit > -0.5;
                                                                    using (Variable.If(logit_m0))
                                                                    {
                                                                        prob.SetTo(logit * 0.2456 + 0.4995);
                                                                    }
                                                                    using (Variable.IfNot(logit_m0))
                                                                    {
                                                                        Variable<bool> logit_m1 = logit > -1;
                                                                        using (Variable.If(logit_m1))
                                                                        {
                                                                            prob.SetTo(logit * 0.2179 + 0.485);
                                                                        }
                                                                        using (Variable.IfNot(logit_m1))
                                                                        {
                                                                            Variable<bool> logit_m15 = logit > -1.5;
                                                                            using (Variable.If(logit_m15))
                                                                            {
                                                                                prob.SetTo(logit * 0.1735 + 0.4406);
                                                                            }
                                                                            using (Variable.IfNot(logit_m15))
                                                                            {
                                                                                Variable<bool> logit_m2 = logit > -2;
                                                                                using (Variable.If(logit_m2))
                                                                                {
                                                                                    prob.SetTo(logit * 0.1268 + 0.3708);
                                                                                }
                                                                                using (Variable.IfNot(logit_m2))
                                                                                {
                                                                                    Variable<bool> logit_m25 = logit > -2.5;
                                                                                    using (Variable.If(logit_m25))
                                                                                    {
                                                                                        prob.SetTo(logit * 0.0869 + 0.2915);
                                                                                    }
                                                                                    using (Variable.IfNot(logit_m25))
                                                                                    {
                                                                                        Variable<bool> logit_m3 = logit > -3;
                                                                                        using (Variable.If(logit_m3))
                                                                                        {
                                                                                            prob.SetTo(logit * 0.057 + 0.2172);
                                                                                        }
                                                                                        using (Variable.IfNot(logit_m3))
                                                                                        {
                                                                                            Variable<bool> logit_m35 = logit > -3.5;
                                                                                            using (Variable.If(logit_m35))
                                                                                            {
                                                                                                prob.SetTo(logit * 0.0363 + 0.1556);
                                                                                            }
                                                                                            using (Variable.IfNot(logit_m35))
                                                                                            {
                                                                                                Variable<bool> logit_m4 = logit > -4;
                                                                                                using (Variable.If(logit_m4))
                                                                                                {
                                                                                                    prob.SetTo(logit * 0.0227 + 0.1082);
                                                                                                }
                                                                                                using (Variable.IfNot(logit_m4))
                                                                                                {
                                                                                                    Variable<bool> logit_m45 = logit > -4.5;
                                                                                                    using (Variable.If(logit_m45))
                                                                                                    {
                                                                                                        prob.SetTo(logit * 0.014 + 0.0738);
                                                                                                    }
                                                                                                    using (Variable.IfNot(logit_m45))
                                                                                                    {
                                                                                                        Variable<bool> logit_m5 = logit > -5;
                                                                                                        using (Variable.If(logit_m5))
                                                                                                        {
                                                                                                            prob.SetTo(logit * 0.0086 + 0.0495);
                                                                                                        }
                                                                                                        using (Variable.IfNot(logit_m5))
                                                                                                        {
                                                                                                            Variable<bool> logit_m6 = logit > -6;
                                                                                                            using (Variable.If(logit_m6))
                                                                                                            {
                                                                                                                prob.SetTo(logit * 0.0042 + 0.0271);
                                                                                                            }
                                                                                                            using (Variable.IfNot(logit_m6))
                                                                                                            {
                                                                                                                Variable<bool> logit_m8 = logit > -8;
                                                                                                                using (Variable.If(logit_m8))
                                                                                                                {
                                                                                                                    prob.SetTo(logit * 0.001 + 0.0081);
                                                                                                                }
                                                                                                                using (Variable.IfNot(logit_m8))
                                                                                                                {
                                                                                                                    prob.SetTo(0.0001);
                                                                                                                }
                                                                                                            }
                                                                                                        }
                                                                                                    }
                                                                                                }
                                                                                            }
                                                                                        }
                                                                                    }
                                                                                }
                                                                            }
                                                                        }
                                                                    }
                                                                }
                                                            }
                                                        }
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
                return prob;
            }
        }

    Wednesday, April 22, 2015 4:44 PM
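    To check the coefficients of the fixed function numerically, here is a plain C# sketch (no Infer.NET; `PseudoLogisticCheck` and its members are hypothetical names). It evaluates the same segments in the same branch order as pseudo_logistic, using 2.5 as the threshold for the (2.5, 3] segment, and reports the largest absolute deviation from the exact logistic over a grid:

```csharp
using System;

// Hypothetical plain C# check (no Infer.NET) of the segments used in pseudo_logistic.
public static class PseudoLogisticCheck
{
    // Segment k covers (Bounds[k], Bounds[k + 1]] on the positive side.
    static readonly double[] Bounds = { 0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 6.0, 8.0 };
    static readonly double[] Slope = { 0.2456, 0.2179, 0.1735, 0.1268, 0.0869, 0.057, 0.0363, 0.0227, 0.014, 0.0086, 0.0042, 0.001 };
    static readonly double[] Intercept = { 0.5005, 0.515, 0.5594, 0.6292, 0.7085, 0.7828, 0.8444, 0.8918, 0.9262, 0.9505, 0.9729, 0.9919 };

    public static double Logistic(double x) => 1.0 / (1.0 + Math.Exp(-x));

    // Same decision order as the nested If/IfNot blocks: test thresholds from the top down.
    public static double Approx(double x)
    {
        if (x > 0)
        {
            if (x > Bounds[Bounds.Length - 1]) return 0.9999;
            // Always matches some segment, because Bounds[0] == 0 and x > 0.
            for (int k = Slope.Length - 1; k >= 0; k--)
                if (x > Bounds[k]) return Slope[k] * x + Intercept[k];
        }
        // x <= 0: mirrored segments, same slope and intercept 1 - b (from sigma(-x) = 1 - sigma(x)).
        for (int k = 0; k < Slope.Length; k++)
            if (x > -Bounds[k + 1]) return Slope[k] * x + (1.0 - Intercept[k]);
        return 0.0001;
    }

    // Largest |Approx(x) - Logistic(x)| over steps + 1 evenly spaced points in [from, to].
    public static double MaxAbsError(double from, double to, int steps)
    {
        double worst = 0.0;
        for (int i = 0; i <= steps; i++)
        {
            double x = from + (to - from) * i / steps;
            worst = Math.Max(worst, Math.Abs(Approx(x) - Logistic(x)));
        }
        return worst;
    }
}
```

    With these thresholds the deviation on [-10, 10] stays at about 0.002. Using `logit > 25` instead of `logit > 2.5` routes (2.5, 3] to the (2, 2.5] segment, and the error at logit = 3 then grows to about 0.017.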