Skip to content

Approximate inference in aGrUM (pyAgrum)

Creative Commons License · aGrUM · interactive online version

There are several approximate inference algorithms for BNs in aGrUM (pyAgrum). They share the same API as exact inference.

  • Loopy Belief Propagation: LBP is an approximate inference that applies the message-passing computations that are exact when the BN is a tree, even if the BN is not a tree. LBP is a special case of inference: the algorithm may not converge, and even if it converges, it may converge to something other than the exact posterior. LBP is, however, fast and usually gives reasonably good results.
  • Sampling inference: sampling inference uses sampling to compute the posterior. The sampling may be (very) slow, but these algorithms converge to the exact distribution. aGrUM implements:
    • Monte Carlo sampling,
    • Weighted sampling,
    • Importance sampling,
    • Gibbs sampling.
  • Finally, aGrUM proposes the so-called ‘loopy version’ of the sampling algorithms: the idea is to use LBP as a Dirichlet prior for the sampling algorithm. A loopy version of each sampling algorithm is proposed.
%matplotlib inline
from pylab import *
import matplotlib.pyplot as plt
def unsharpen(bn):
    """
    Smooth every CPT of ``bn`` in place so its parameters move away from 0 and 1.

    For each node, the CPT is translated by one tenth of ``bn.maxParam()``
    (re-read for every node) and then renormalized with ``normalizeAsCPT()``.

    Parameters
    ----------
    bn : the Bayesian network whose CPTs are modified in place.
    """
    for node_id in bn.nodes():
        table = bn.cpt(node_id)
        # maxParam() is queried after fetching each CPT, exactly once per node.
        offset = bn.maxParam() / 10
        table.translate(offset).normalizeAsCPT()
def compareInference(ie, ie2, ax=None, bn=None):
    """
    Compare two inferences by plotting the points (posterior(ie), posterior(ie2)).

    For every node, the posterior probabilities computed by ``ie`` (x axis) and
    by ``ie2`` (y axis) are plotted as red dots. The title reports both engine
    names, ``ie2``'s stopping message, the maximum absolute difference between
    the posteriors, and ``ie2``'s running time. A perfect approximation puts
    every point on the diagonal.

    Parameters
    ----------
    ie : reference inference engine (typically exact, e.g. LazyPropagation).
    ie2 : approximate inference engine; must provide
        ``messageApproximationScheme()`` and ``currentTime()``.
    ax : matplotlib Axes to draw on; when None, a new 4x4 figure is created.
    bn : Bayesian network whose nodes are compared; defaults to ``ie.BN()``.
    """
    if bn is None:
        bn = ie.BN()

    exact = []
    appro = []
    # Float (not int) so that "{:2.4}" can always format it: a precision
    # specifier on an int raises ValueError.
    errmax = 0.0
    for node in bn.nodes():
        p_exact = ie.posterior(node)
        p_appro = ie2.posterior(node)
        # Tensors as list
        exact.extend(p_exact.tolist())
        appro.extend(p_appro.tolist())
        errmax = max(errmax, (p_exact - p_appro).abs().max())
    if errmax < 1e-10:
        # Treat numerical noise as an exact match.
        errmax = 0.0

    if ax is None:
        # Create a dedicated figure: plt.gca() alone would draw on whatever
        # figure is current and ignore the requested size.
        _, ax = plt.subplots(figsize=(4, 4))

    ax.plot(exact, appro, "ro")
    ax.set_title(
        "{} vs {}\n {}\nMax error {:2.4} in {:2.4} seconds".format(
            type(ie).__name__.split("_")[0],  # name of first inference
            type(ie2).__name__.split("_")[0],  # name of second inference
            ie2.messageApproximationScheme(),
            errmax,
            ie2.currentTime(),
        )
    )
import pyagrum as gum
import pyagrum.lib.notebook as gnb
# Load the ALARM network and smooth its CPTs (unsharpen) so that parameters
# are pushed away from 0 and 1.
bn = gum.loadBN("res/alarm.dsl")
unsharpen(bn)
# Exact inference: the reference against which the approximate engines below
# are compared (through the module-level name `ie`).
ie = gum.LazyPropagation(bn)
ie.makeInference()
gnb.showBN(bn, size="8")

svg

# Display the junction tree of `bn` next to the posteriors of every node.
gnb.sideBySide(gnb.getJunctionTreeMap(bn), gnb.getInference(bn, size="8"))  # using LazyPropagation by default
# Exact posterior of KINKEDTUBE, used as a reference for the samplers below.
print(ie.posterior("KINKEDTUBE"))
0 0~16 0--0~16 1 1~32 1--1~32 2 2~33 2--2~33 3 3~4 3--3~4 4 4~22 4--4~22 5 5~22 5--5~22 6 6~23 6--6~23 7 7~26 7--7~26 8 8~17 8--8~17 10 10~14 10--10~14 11 11~16 11--11~16 12 12~13 12--12~13 13 13~30 13--13~30 14 14~26 14--14~26 16 16~17 16--16~17 17 17~24 17--17~24 19 19~27 19--19~27 20 20~33 20--20~33 22 22~33 22--22~33 23 23~27 23--23~27 23~31 23--23~31 24 24~26 24--24~26 26 26~27 26--26~27 27 30 30~31 30--30~31 31 31~32 31--31~32 32 32~33 32--32~33 33 19~27--27 12~13--13 2~33--33 23~27--27 22~33--33 11~16--16 24~26--26 31~32--32 10~14--14 26~27--27 13~30--30 5~22--22 7~26--26 20~33--33 16~17--17 32~33--33 23~31--31 8~17--17 1~32--32 3~4--4 4~22--22 17~24--24 14~26--26 30~31--31 6~23--23 0~16--16
structs Inference in  16.44ms KINKEDTUBE 2025-10-29T14:06:14.778300 image/svg+xml Matplotlib v3.10.7, VENTLUNG 2025-10-29T14:06:15.193439 image/svg+xml Matplotlib v3.10.7, KINKEDTUBE->VENTLUNG PRESS 2025-10-29T14:06:15.242099 image/svg+xml Matplotlib v3.10.7, KINKEDTUBE->PRESS HYPOVOLEMIA 2025-10-29T14:06:14.795719 image/svg+xml Matplotlib v3.10.7, STROKEVOLUME 2025-10-29T14:06:15.011039 image/svg+xml Matplotlib v3.10.7, HYPOVOLEMIA->STROKEVOLUME LVEDVOLUME 2025-10-29T14:06:15.048698 image/svg+xml Matplotlib v3.10.7, HYPOVOLEMIA->LVEDVOLUME INTUBATION 2025-10-29T14:06:14.816481 image/svg+xml Matplotlib v3.10.7, SHUNT 2025-10-29T14:06:15.109500 image/svg+xml Matplotlib v3.10.7, INTUBATION->SHUNT INTUBATION->VENTLUNG MINVOL 2025-10-29T14:06:15.216670 image/svg+xml Matplotlib v3.10.7, INTUBATION->MINVOL INTUBATION->PRESS VENTALV 2025-10-29T14:06:15.265928 image/svg+xml Matplotlib v3.10.7, INTUBATION->VENTALV MINVOLSET 2025-10-29T14:06:14.837139 image/svg+xml Matplotlib v3.10.7, VENTMACH 2025-10-29T14:06:15.070290 image/svg+xml Matplotlib v3.10.7, MINVOLSET->VENTMACH PULMEMBOLUS 2025-10-29T14:06:14.855641 image/svg+xml Matplotlib v3.10.7, PAP 2025-10-29T14:06:14.991892 image/svg+xml Matplotlib v3.10.7, PULMEMBOLUS->PAP PULMEMBOLUS->SHUNT INSUFFANESTH 2025-10-29T14:06:14.874147 image/svg+xml Matplotlib v3.10.7, CATECHOL 2025-10-29T14:06:15.392311 image/svg+xml Matplotlib v3.10.7, INSUFFANESTH->CATECHOL ERRLOWOUTPUT 2025-10-29T14:06:14.891242 image/svg+xml Matplotlib v3.10.7, HRBP 2025-10-29T14:06:15.429657 image/svg+xml Matplotlib v3.10.7, ERRLOWOUTPUT->HRBP ERRCAUTER 2025-10-29T14:06:14.907481 image/svg+xml Matplotlib v3.10.7, HRSAT 2025-10-29T14:06:15.449429 image/svg+xml Matplotlib v3.10.7, ERRCAUTER->HRSAT HREKG 2025-10-29T14:06:15.487626 image/svg+xml Matplotlib v3.10.7, ERRCAUTER->HREKG FIO2 2025-10-29T14:06:14.923978 image/svg+xml Matplotlib v3.10.7, PVSAT 2025-10-29T14:06:15.329528 image/svg+xml Matplotlib v3.10.7, FIO2->PVSAT LVFAILURE 
2025-10-29T14:06:14.941007 image/svg+xml Matplotlib v3.10.7, LVFAILURE->STROKEVOLUME LVFAILURE->LVEDVOLUME HISTORY 2025-10-29T14:06:15.128055 image/svg+xml Matplotlib v3.10.7, LVFAILURE->HISTORY DISCONNECT 2025-10-29T14:06:14.957782 image/svg+xml Matplotlib v3.10.7, VENTTUBE 2025-10-29T14:06:15.147866 image/svg+xml Matplotlib v3.10.7, DISCONNECT->VENTTUBE ANAPHYLAXIS 2025-10-29T14:06:14.973922 image/svg+xml Matplotlib v3.10.7, TPR 2025-10-29T14:06:15.030357 image/svg+xml Matplotlib v3.10.7, ANAPHYLAXIS->TPR CO 2025-10-29T14:06:15.468172 image/svg+xml Matplotlib v3.10.7, STROKEVOLUME->CO TPR->CATECHOL BP 2025-10-29T14:06:15.506853 image/svg+xml Matplotlib v3.10.7, TPR->BP PCWP 2025-10-29T14:06:15.091541 image/svg+xml Matplotlib v3.10.7, LVEDVOLUME->PCWP CVP 2025-10-29T14:06:15.170478 image/svg+xml Matplotlib v3.10.7, LVEDVOLUME->CVP VENTMACH->VENTTUBE SAO2 2025-10-29T14:06:15.349879 image/svg+xml Matplotlib v3.10.7, SHUNT->SAO2 VENTTUBE->VENTLUNG VENTTUBE->PRESS VENTLUNG->MINVOL VENTLUNG->VENTALV EXPCO2 2025-10-29T14:06:15.371306 image/svg+xml Matplotlib v3.10.7, VENTLUNG->EXPCO2 ARTCO2 2025-10-29T14:06:15.289224 image/svg+xml Matplotlib v3.10.7, VENTALV->ARTCO2 VENTALV->PVSAT ARTCO2->EXPCO2 ARTCO2->CATECHOL PVSAT->SAO2 SAO2->CATECHOL HR 2025-10-29T14:06:15.410572 image/svg+xml Matplotlib v3.10.7, CATECHOL->HR HR->HRBP HR->HRSAT HR->CO HR->HREKG CO->BP
KINKEDTUBE |
TRUE |FALSE |
---------|---------|
0.1167 | 0.8833 |

Gibbs inference iterations can be stopped:

  • by the value of error (epsilon)
  • by the rate of change of epsilon (MinEpsilonRate)
  • by the number of iterations (MaxIteration)
  • by the duration of the algorithm (MaxTime)
# Gibbs sampling with an epsilon threshold of 1e-2; every other stopping
# criterion keeps its default value.
ie2 = gum.GibbsSampling(bn)
ie2.setEpsilon(1e-2)
# No explicit makeInference() here: judging by the printed posterior and
# stopping message, showInference runs the inference with `engine`.
gnb.showInference(bn, engine=ie2, size="8")
print(ie2.posterior("KINKEDTUBE"))
# Reports which stopping criterion ended the sampling.
print(ie2.messageApproximationScheme())
compareInference(ie, ie2)

svg

KINKEDTUBE |
TRUE |FALSE |
---------|---------|
0.0968 | 0.9032 |
stopped with rate=0.00673795

svg

With default parameters (only epsilon was set), this inference was stopped because the rate of change of epsilon became too low.

# Gibbs sampling capped at 1000 iterations, with a tighter epsilon (5e-3);
# the printed message shows which of the two criteria stopped it.
ie2 = gum.GibbsSampling(bn)
ie2.setMaxIter(1000)
ie2.setEpsilon(5e-3)
ie2.makeInference()
# Nodes can be addressed by id as well as by name (id 2 prints as INTUBATION).
print(ie2.posterior(2))
print(ie2.messageApproximationScheme())
INTUBATION |
NORMAL |ESOPHAGEA|ONESIDED |
---------|---------|---------|
0.8736 | 0.0664 | 0.0600 |
stopped with max iteration=1000
# Exact posteriors (x) vs. posteriors of the current `ie2` engine (y).
compareInference(ie, ie2)

svg

# Gibbs sampling with a time budget of 3 (seconds, as reported by the
# comparison plot title); the other stopping criteria keep their defaults
# and may end the sampling before the budget is used up.
ie2 = gum.GibbsSampling(bn)
ie2.setMaxTime(3)
ie2.makeInference()
print(ie2.posterior(2))
print(ie2.messageApproximationScheme())
compareInference(ie, ie2)
INTUBATION |
NORMAL |ESOPHAGEA|ONESIDED |
---------|---------|---------|
0.6433 | 0.1900 | 0.1667 |
stopped with epsilon=0.201897

svg

# Gibbs sampling with epsilon = 10**-1.8, a burn-in of 300 samples and a
# period size of 300 (presumably the number of samples between two checks of
# the stopping criteria — confirm in the pyAgrum docs).
ie2 = gum.GibbsSampling(bn)
ie2.setEpsilon(10**-1.8)
ie2.setBurnIn(300)
ie2.setPeriodSize(300)
# NOTE(review): presumably makes Gibbs resample the variables in random order.
ie2.setDrawnAtRandom(True)
# Set up the animation of the approximation scheme before running it.
gnb.animApproximationScheme(ie2)
ie2.makeInference()

svg

# Exact posteriors (x) vs. posteriors of the current `ie2` engine (y).
compareInference(ie, ie2)

svg

# Importance sampling with the same epsilon and period size as the Gibbs run
# above, limited to a 10-second budget.
ie4 = gum.ImportanceSampling(bn)
ie4.setEpsilon(10**-1.8)
ie4.setMaxTime(10)  # 10 seconds for inference
ie4.setPeriodSize(300)
ie4.makeInference()
compareInference(ie, ie4)

svg

Every sampling inference has a ‘hybrid’ version, which consists of first running a loopy belief propagation and then using its result as a prior for the probability estimations by sampling.

# Hybrid ("loopy") Gibbs sampling, configured like the importance-sampling run.
ie3 = gum.LoopyGibbsSampling(bn)
ie3.setEpsilon(10**-1.8)
ie3.setMaxTime(10)  # 10 seconds for inference
ie3.setPeriodSize(300)
ie3.makeInference()
compareInference(ie, ie3)

svg

These computations may take a while.

def compareAllInference(bn, evs=None, epsilon=10**-1.6, epsilonRate=1e-8, maxTime=20):
    """
    Run every approximate inference of aGrUM on ``bn`` and plot each one against
    the exact posteriors computed by LazyPropagation.

    Parameters
    ----------
    bn : the Bayesian network to run the inferences on.
    evs : evidence passed to ``setEvidence`` of every engine (default: none).
    epsilon : epsilon stopping threshold for the sampling engines.
    epsilonRate : minimal epsilon rate stopping threshold for the sampling engines.
    maxTime : time budget (seconds) for each sampling engine.

    Loopy Belief Propagation keeps its default stopping criteria; only the
    sampling engines receive ``epsilon``, ``epsilonRate`` and ``maxTime``.
    One row of subplots is drawn, one panel per approximate engine.
    """
    if evs is None:
        # Avoid a shared mutable default argument.
        evs = {}

    # Engines are created in the same order as they are plotted.
    exact = gum.LazyPropagation(bn)
    loopy = gum.LoopyBeliefPropagation(bn)
    gibbs = [gum.GibbsSampling(bn), gum.LoopyGibbsSampling(bn)]
    others = [
        gum.WeightedSampling(bn),
        gum.LoopyWeightedSampling(bn),
        gum.ImportanceSampling(bn),
        gum.LoopyImportanceSampling(bn),
    ]
    samplers = gibbs + others

    # burn in for Gibbs samplings
    for engine in gibbs:
        engine.setBurnIn(300)
        engine.setDrawnAtRandom(True)

    for engine in samplers:
        engine.setEpsilon(epsilon)
        engine.setMinEpsilonRate(epsilonRate)
        engine.setPeriodSize(300)
        engine.setMaxTime(maxTime)

    approximations = [loopy] + samplers
    for engine in [exact] + approximations:
        engine.setEvidence(evs)
        engine.makeInference()

    _, axes = plt.subplots(1, len(approximations), figsize=(35, 3), num="gpplot")
    for engine, axis in zip(approximations, axes):
        compareInference(exact, engine, axis)
# Loose stopping threshold for every sampler: epsilon = 0.1.
compareAllInference(bn, epsilon=1e-1)

svg

# Tighter stopping threshold: epsilon = 0.01.
compareAllInference(bn, epsilon=1e-2)

svg

# With epsilon = 1e-8 the samplers are expected to be stopped by another
# criterion, most likely the 1-second time budget.
compareAllInference(bn, maxTime=1, epsilon=1e-8)

svg

# Same as the 1-second run, with a 2-second budget per sampler.
compareAllInference(bn, maxTime=2, epsilon=1e-8)

svg

# Evidence on BP, PCWP, EXPCO2 and HISTORY (integer values are presumably
# state indices), compared with a 1-second budget per sampler.
funny = {"BP": 1, "PCWP": 2, "EXPCO2": 0, "HISTORY": 0}
compareAllInference(bn, maxTime=1, evs=funny, epsilon=1e-8)

svg

# Same evidence, 4-second budget per sampler.
compareAllInference(bn, maxTime=4, evs=funny, epsilon=1e-8)

svg

# Same evidence, 10-second budget per sampler.
compareAllInference(bn, maxTime=10, evs=funny, epsilon=1e-8)

svg