A Bayesian network learned from data can be used as a classifier for any one of its nodes. For more about Bayesian-network classifiers, see `pyAgrum.skbn`.
import sys
import os
import numpy as np
import pyAgrum as gum
import pyAgrum.lib.notebook as gnb
%matplotlib inline
from pyAgrum.lib.bn2roc import showROC
from pyAgrum.lib.bn2roc import showPR
from pyAgrum.lib.bn2roc import showROC_PR
# Number of rows in each database sampled from the reference network.
SIZE_LEARN = 10000
SIZE_VALID = 2000

# Reference network loaded from res/alarm.dsl; the bare expression below
# renders it when run as a notebook cell.
bn = gum.loadBN("res/alarm.dsl")
bn

# Draw two labelled databases from `bn`.
# NOTE(review): despite its name, "out/train.csv" receives the SIZE_VALID
# (validation) sample, not the training data; the name is kept unchanged in
# case later cells read this file.
gum.generateSample(bn, SIZE_LEARN, "out/learn.csv", show_progress=True, with_labels=True)
gum.generateSample(bn, SIZE_VALID, "out/train.csv", show_progress=True, with_labels=True)
out/learn.csv: 100%|███████████████████████████████████████|
Log2-Likelihood : -151554.4230546606
out/train.csv: 100%|███████████████████████████████████████|
Log2-Likelihood : -30406.139733341093
-30406.139733341093
# Learning a BN from the database
# Learn from the SIZE_LEARN-row database "out/learn.csv". The previous version
# read "out/train.csv", which holds only the SIZE_VALID sample. That left the
# learning database unused and made the caption below report the wrong size.
learner = gum.BNLearner("out/learn.csv")
bn2 = learner.useMIIC().learnBN()
# Duration reported by the learner; the caption formats it in seconds.
currentTime = learner.currentTime()
gnb.flow.add(gnb.getBN(bn2, size="9"), f"Learned with {SIZE_LEARN} lines in {currentTime:.3f}s")
gnb.flow.display()
import pyAgrum.lib.bn_vs_bn as bnvsbn
# Compare the MIIC-learned structure (bn2) against the reference network (bn),
# then show the diff with its legend.
miic_diff = gnb.getBNDiff(bn, bn2, size="8!")
gnb.flow.add(miic_diff, "Diff with MIIC")
gnb.flow.add(bnvsbn.graphDiffLegend())
gnb.flow.display()
# Re-learn with greedy hill climbing and the BDeu score on the same learner,
# then compare the result (bn3) against the reference network (bn).
# NOTE(review): pyAgrum documents useNMLCorrection() as a MIIC/3off2 option, so
# it probably has no effect on greedy hill climbing -- confirm.
bn3 = learner.useGreedyHillClimbing().useNMLCorrection().useScoreBDeu().learnBN()
# The caption previously said "NMD"; the correction actually used is NML.
gnb.flow.add(gnb.getBNDiff(bn, bn3, size="8!"), "Diff with GHC/NML/BDEU")
gnb.flow.add(bnvsbn.graphDiffLegend())
gnb.flow.display()
# Same GHC/BDeu setup, but seed the search with the DAG found by MIIC (bn2),
# then compare the result (bn4) against the reference network (bn).
# NOTE(review): setInitialDAG() is configured on the shared `learner`, so it
# presumably stays in effect for any later learnBN() call on it -- confirm.
bn4 = learner.useGreedyHillClimbing().useNMLCorrection().useScoreBDeu().setInitialDAG(bn2.dag()).learnBN()
# Caption fixed: "NMD" -> "NML" (the correction actually used), "intial" -> "initial".
gnb.flow.add(gnb.getBNDiff(bn, bn4, size="8!"), "Diff with GHC/NML/BDEU with initial DAG from MIIC")
gnb.flow.add(bnvsbn.graphDiffLegend())
gnb.flow.display()
# List the variable names of the learned network.
learned_names = bn2.names()
print(learned_names)
{'MINVOL', 'KINKEDTUBE', 'VENTTUBE', 'DISCONNECT', 'SAO2', 'ERRLOWOUTPUT', 'BP', 'INTUBATION', 'LVFAILURE', 'HYPOVOLEMIA', 'STROKEVOLUME', 'ERRCAUTER', 'ARTCO2', 'ANAPHYLAXIS', 'HRBP', 'LVEDVOLUME', 'PVSAT', 'VENTALV', 'INSUFFANESTH', 'MINVOLSET', 'PCWP', 'TPR', 'CVP', 'CATECHOL', 'VENTLUNG', 'CO', 'VENTMACH', 'PAP', 'PULMEMBOLUS', 'FIO2', 'HISTORY', 'PRESS', 'HREKG', 'HRSAT', 'HR', 'EXPCO2', 'SHUNT'}
# Display inference on the learned network with no evidence (empty evidence dict).
gnb.showInference(bn2, evs=dict(), size="10")