# Causal inference in Python

# __author__ = 'Bayes Server'
# __version__= '0.3'

import jpype # pip install jpype1 (version 1.2.1 or later)
import jpype.imports
from jpype.types import *

classpath = "lib/bayesserver-10.8.jar" # TODO download the Bayes Server Java API, and adjust the path

# Launch the JVM. JPype can start the JVM only once per process, and calling
# startJVM a second time raises an error. Skip the call when a JVM is already
# running, e.g. when this script is re-executed in an interactive session such
# as Jupyter.
# NOTE(review): if a JVM was started earlier without the Bayes Server jar on
# its classpath, the com.bayesserver imports below will still fail.
if not jpype.isJVMStarted():
    jpype.startJVM(classpath=[classpath])

# import the Java modules

from com.bayesserver import *
from com.bayesserver.inference import *

network = Network()

# TODO download network from the Bayes Server User Interface (or Bayes Server Online)
# and adjust the following path
network.load("networks/Causal Inference Simple.bayes")

# Look up the variables and states used by the queries below.
# NOTE(review): the trailing True argument is presumably Bayes Server's
# "throw if not found" flag — confirm against the Bayes Server API docs.
drug = network.getVariables().get('Drug', True)
drugTrue = drug.getStates().get('True', True)
drugFalse = drug.getStates().get('False', True)

recovered = network.getVariables().get('Recovered', True)
recoveredTrue = recovered.getStates().get('True', True)

factory = RelevanceTreeInferenceFactory()
inference = factory.createInferenceEngine(network)
queryOptions = factory.createQueryOptions()
queryOutput = factory.createQueryOutput()

# Table registered as a query distribution. P(Recovered=True) is read back
# from it after every inference.query call.
queryRecovered = Table(recovered)
inference.getQueryDistributions().add(queryRecovered)


def probability_recovered_given(drug_state, intervention=None):
    """Return P(Recovered=True) after setting ``drug_state`` as evidence on Drug.

    Sets the evidence on the shared ``inference`` engine, runs a query, and
    reads P(Recovered=True) from ``queryRecovered``.

    :param drug_state: the Drug state to set (``drugTrue`` or ``drugFalse``).
    :param intervention: ``None`` for ordinary (observational) evidence, or an
        ``InterventionType`` such as ``InterventionType.DO`` to set the state
        as a causal intervention.
    :returns: the probability that Recovered is True under that evidence.
    """
    evidence = inference.getEvidence()
    if intervention is None:
        evidence.setState(drug_state)
    else:
        # The middle argument is left as None (Java null).
        # NOTE(review): presumably an optional time value for temporal
        # networks — confirm against the Bayes Server API docs.
        evidence.setState(drug_state, None, intervention)
    inference.query(queryOptions, queryOutput)
    return queryRecovered.get(recoveredTrue)


print('Non-causal version (incorrect)...')

# First, let's calculate P(Recovered=True|Drug=True) - P(Recovered=True|Drug=False)
# without an intervention, i.e. non-causally, which gives us the wrong result.

pRecoveredGivenDrugTrue = probability_recovered_given(drugTrue)
pRecoveredGivenDrugFalse = probability_recovered_given(drugFalse)

effectivenessNonCausal = pRecoveredGivenDrugTrue - pRecoveredGivenDrugFalse
print(f'Effectiveness of drug (non-causal) = {effectivenessNonCausal:.2%}')

print()

print('Causal version...')

# Now let's calculate P(Recovered=True|Do(Drug=True)) - P(Recovered=True|Do(Drug=False))
# with an intervention, i.e. causally, which gives us the correct result.

pRecoveredGivenDoDrugTrue = probability_recovered_given(drugTrue, InterventionType.DO)
pRecoveredGivenDoDrugFalse = probability_recovered_given(drugFalse, InterventionType.DO)

effectivenessCausal = pRecoveredGivenDoDrugTrue - pRecoveredGivenDoDrugFalse
print(f"Effectiveness of drug (causal) = {effectivenessCausal:.2%}")

# Expected output:
#
# Non-causal version (incorrect)...
# Effectiveness of drug (non-causal) = -4.91%
#
# Causal version...
# Effectiveness of drug (causal) = 5.02%