import jpype
import jpype.imports
from jpype.types import *
# Bayes Server jar; it must be on the JVM classpath before the
# `com.bayesserver` imports that follow can be resolved by JPype.
classpath = "lib/bayesserver-10.8.jar"
# JPype cannot start a second JVM in the same process, so only start it when
# none is running yet (e.g. avoids an error when this script is re-run in an
# interactive session). NOTE(review): if a JVM was already started elsewhere,
# it must have been started with this jar on its classpath for the
# `com.bayesserver` imports to succeed.
if not jpype.isJVMStarted():
    jpype.startJVM(classpath=[classpath])
from com.bayesserver import *
from com.bayesserver.inference import *
# Load the example network and look up the variables/states used below.
# NOTE(review): the second `True` argument to `get` is presumably Bayes
# Server's "throw if not found" flag — confirm against the API docs.
network = Network()
network.load("networks/Causal Inference Simple.bayes")

drug = network.getVariables().get('Drug', True)
drugTrue = drug.getStates().get('True', True)
drugFalse = drug.getStates().get('False', True)

recovered = network.getVariables().get('Recovered', True)
recoveredTrue = recovered.getStates().get('True', True)

# Relevance-tree inference engine plus query options/output objects that are
# reused by every query below.
factory = RelevanceTreeInferenceFactory()
inference = factory.createInferenceEngine(network)
queryOptions = factory.createQueryOptions()
queryOutput = factory.createQueryOutput()

# Query distribution over Recovered; each call to `inference.query` refills
# this same table, so values must be read out before the next query.
queryRecovered = Table(recovered)
inference.getQueryDistributions().add(queryRecovered)


def _queryRecoveredTrue(drugState, intervention=None):
    """Set evidence on Drug, run the query, and read the Recovered=True entry.

    Args:
        drugState: the Drug state to enter as evidence (replaces any earlier
            evidence on Drug).
        intervention: when None, the state is entered with the plain
            one-argument ``setState`` call; otherwise it is passed to the
            three-argument ``setState(state, None, intervention)`` overload
            (e.g. ``InterventionType.DO`` for the causal query).

    Returns:
        The value stored for ``recoveredTrue`` in ``queryRecovered`` after
        the query has run.
    """
    evidence = inference.getEvidence()
    if intervention is None:
        evidence.setState(drugState)
    else:
        evidence.setState(drugState, None, intervention)
    inference.query(queryOptions, queryOutput)
    return queryRecovered.get(recoveredTrue)


# Observational comparison: conditioning on Drug, which the script labels as
# the incorrect way to measure effectiveness.
print('Non-causal version (incorrect)...')
pRecoveredGivenDrugTrue = _queryRecoveredTrue(drugTrue)
pRecoveredGivenDrugFalse = _queryRecoveredTrue(drugFalse)
effectivenessNonCausal = pRecoveredGivenDrugTrue - pRecoveredGivenDrugFalse
print(f'Effectiveness of drug (non-causal) = {effectivenessNonCausal:.2%}')
print()

# Interventional comparison: Drug is set via the DO intervention type.
print('Causal version...')
pRecoveredGivenDoDrugTrue = _queryRecoveredTrue(drugTrue, InterventionType.DO)
pRecoveredGivenDoDrugFalse = _queryRecoveredTrue(drugFalse, InterventionType.DO)
effectivenessCausal = pRecoveredGivenDoDrugTrue - pRecoveredGivenDoDrugFalse
print(f"Effectiveness of drug (causal) = {effectivenessCausal:.2%}")