multiple Infer calls (Migrated from community.research.microsoft.com)

  • Question

  • laura posted on 02-09-2009 3:41 PM

    While tracking down the problem described in the thread "Learning a Beta", I realised that I do not understand the effect of multiple calls to InferenceEngine.Infer.


    My code looks like this:

    // model specification

    for iter = 0 to maxOuterIter do
       inferenceEngine.NumberOfIterations <- maxInnerIter
       inferenceEngine.Infer<...>(edgeLabel) |> ignore
       let output = inferenceEngine.GetOutputMessage<...>(edgeLabel)
       // cross messages on output and feed back to edgeLabelSync

    // Solution: inferenceEngine.NumberOfIterations <- 50

    let ePost = inferenceEngine.Infer<...<..., bool[]>>(e)
    let epsilonPost = inferenceEngine.Infer<...>(epsilon)
    // used to generate the output described in "Learning a Beta"

    I was naively assuming that each call to "Infer" advances the global state of the model, i.e. that once its iterations have finished, no further message passing iterations are carried out. Since I am not using InferAll, I also assumed that all variables are always updated. I used maxInnerIter=1 (both inside the iter-loop and after it), which yielded wrong posteriors on epsilon (drawn from a Beta). When I set inferenceEngine.NumberOfIterations to 50 after the loop (as indicated by the comment), the posterior on epsilon looks good.

    This insight leaves me with some questions...

    Under which circumstances is a message passing iteration carried out? It seems that across multiple calls to Infer(variable), the result is cached.

    Which variables are affected? Obviously some variables (such as "e") were updated to get a result for "edgeLabel".

    Is there a way to ensure that all variables are kept up to date, given that I have to infer them later anyway?



    Friday, June 3, 2011 4:44 PM


  • jwinn replied on 02-09-2009 5:13 PM

    In general, when you call Infer(v) the inference engine may opportunistically compute marginals for variables other than v.  These marginals will then be cached and so future calls to Infer() may return immediately with one of these cached values.  To see when inference actually occurs, just switch on the transform browser - the browser only appears when inference is being performed.

    Changing the model in any way, e.g. changing the ObservedValue of a variable such as edgeLabelSync, invalidates all cached values.  So when you infer ePost and epsilonPost, inference will run for just one iteration from the initial values, which is unlikely to lead to convergence.  This is why you need to increase the number of iterations to get good results.
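
    The caching and invalidation behaviour can be illustrated with a small toy in F#. This is only a sketch: ToyEngine, Observe and the two variable names are hypothetical stand-ins invented for illustration, not part of the Infer.NET API, and the "inference" is just a scalar moving halfway towards the observed value on each iteration.

```fsharp
// Toy stand-in for an inference engine; NOT the Infer.NET API.
type ToyEngine(numberOfIterations: int) =
    let mutable observed = 0.0
    let mutable cache : Map<string, float> = Map.empty
    let mutable runs = 0

    // Every run restarts from the initial value 0.0 and moves halfway
    // towards the observed value on each iteration.
    let run () =
        runs <- runs + 1
        let mutable estimate = 0.0
        for _ in 1 .. numberOfIterations do
            estimate <- estimate + 0.5 * (observed - estimate)
        // One run opportunistically fills in results for several variables.
        cache <- Map.ofList [ "a", estimate; "b", 2.0 * estimate ]

    // Number of times inference has actually been executed.
    member this.Runs = runs

    // Changing an observed value invalidates everything cached so far.
    member this.Observe(value: float) =
        observed <- value
        cache <- Map.empty

    // Returns a cached result when one exists, otherwise runs inference.
    member this.Infer(name: string) =
        if not (cache.ContainsKey name) then run ()
        cache.[name]

let engine = ToyEngine(1)
engine.Observe 1.0
let a1 = engine.Infer "a"   // runs inference
let b1 = engine.Infer "b"   // served from the cache, no new run
let runsBeforeChange = engine.Runs
engine.Observe 1.0          // any change clears the cache
let a2 = engine.Infer "a"   // runs again, starting from the initial value
printfn "runs before change: %d, runs after: %d, a = %g" runsBeforeChange engine.Runs a2
```

    With a single iteration per run, the second call for "a" gives the same unconverged value as the first, because nothing is carried over once the cache has been cleared.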

    If you prefer to enforce explicitly which variables you want to compute marginals for, then use InferAll(v1,v2,...).  Marginals will be computed (and cached) for exactly this list of variables and no others.  You can then use Infer() to retrieve the marginals, as normal.

    However, in your example, you are in fact trying to perform operations in the middle of the inference iterations, i.e. when you use GetOutputMessage() you really want to be crossing over messages inside the inference loop.  You cannot do this using Infer(), because each call reinitialises the message passing and then performs NumberOfIterations iterations of inference.  You must instead take fine-grained control of the inference procedure using a CompiledAlgorithm object, retrieved using GetCompiledInferenceAlgorithm(v1,v2,...) as described on this page.  You need to Reset() and Initialise() the algorithm object, and then do the message crossover after each call to Update().  This should also be faster than your current method, since it does not waste time reallocating and reinitialising the messages inside the loop, and convergence should be faster because the messages are not reinitialised each time.
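
    The difference between restarting on every call and updating a persistent algorithm object can be sketched with another F# toy. ToyAlgorithm and inferFromScratch are hypothetical names invented for illustration and do not reproduce the CompiledAlgorithm interface; the single Update step here simply moves a scalar halfway towards its fixed point.

```fsharp
// Toy stand-in for a compiled algorithm object; NOT the Infer.NET API.
type ToyAlgorithm(target: float) =
    let mutable state = 0.0
    member this.State = state
    // Discard all progress and return to the initial value.
    member this.Reset() = state <- 0.0
    // One sweep: move halfway towards the fixed point.
    member this.Update() = state <- state + 0.5 * (target - state)

// Infer-style call: reinitialise, then run a fixed number of iterations.
let inferFromScratch (alg: ToyAlgorithm) (iterations: int) =
    alg.Reset()
    for _ in 1 .. iterations do alg.Update()
    alg.State

let alg = ToyAlgorithm(1.0)

// Ten outer calls with one inner iteration each: no progress is carried over.
let mutable restarted = 0.0
for _ in 1 .. 10 do
    restarted <- inferFromScratch alg 1

// Reset once, then update ten times; custom work can go between updates.
alg.Reset()
for _ in 1 .. 10 do
    alg.Update()
    // (the message crossover would happen here)
let continued = alg.State

printfn "restarted: %g, continued: %g" restarted continued
```

    In this toy the restarted variant stays at its one-iteration value however many outer calls are made, while the continued variant approaches the fixed point, which mirrors why the outer loop with maxInnerIter=1 could not converge.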

    Let me know if you need clarification of any of these points.

    John W.

    Friday, June 3, 2011 4:44 PM