Skip to main content

Causal inference with a Bayesian network in Java

package com.bayesserver.examples;

import com.bayesserver.*;
import com.bayesserver.inference.*;

import javax.xml.stream.XMLStreamException;
import java.io.IOException;
import java.text.NumberFormat;

/**
 * Example program contrasting an observational (non-causal) estimate of a drug's
 * effectiveness with an interventional (causal) estimate on the same Bayesian network.
 *
 * <p>Effectiveness is measured as the difference between the probability of
 * {@code Recovered=True} when {@code Drug=True} and when {@code Drug=False}.
 */
public final class CausalInferenceExample {

    /**
     * Loads the network file, runs the non-causal and causal comparisons in turn and
     * prints one effectiveness figure for each to standard output.
     *
     * @param args command line arguments (not used)
     * @throws XMLStreamException           declared for loading the network file
     * @throws IOException                  declared for loading the network file
     * @throws InconsistentEvidenceException declared for running the inference queries
     */
    public static void main(String[] args) throws XMLStreamException, IOException, InconsistentEvidenceException {

        // TODO download the network from the Bayes Server User Interface (or Bayes Server Online)
        // and adjust the following path
        Network net = new Network();
        net.load("Causal Inference Simple.bayes");

        // Look up the treatment variable and both of its states by name.
        Variable drugVar = net.getVariables().get("Drug", true);
        State tookDrug = drugVar.getStates().get("True", true);
        State noDrug = drugVar.getStates().get("False", true);

        // Look up the outcome variable and the state whose probability is reported.
        Variable recoveredVar = net.getVariables().get("Recovered", true);
        State didRecover = recoveredVar.getStates().get("True", true);

        // One engine, one options/output pair and one query table are reused for all four queries.
        InferenceFactory engineFactory = new RelevanceTreeInferenceFactory();
        Inference engine = engineFactory.createInferenceEngine(net);
        QueryOptions options = engineFactory.createQueryOptions();
        QueryOutput output = engineFactory.createQueryOutput();
        Table recoveredDistribution = new Table(recoveredVar);
        engine.getQueryDistributions().add(recoveredDistribution);

        // Percentages are printed with exactly two fraction digits.
        NumberFormat asPercent = NumberFormat.getPercentInstance();
        asPercent.setMinimumFractionDigits(2);
        asPercent.setMaximumFractionDigits(2);

        // ---- Non-causal comparison --------------------------------------------------
        // P(Recovered=True|Drug=True) - P(Recovered=True|Drug=False) using plain evidence,
        // i.e. no intervention. The example labels this result as the incorrect one.
        System.out.println("Non-causal version (incorrect)...");

        engine.getEvidence().setState(tookDrug);
        double observedWithDrug = queryProbability(engine, options, output, recoveredDistribution, didRecover);

        engine.getEvidence().setState(noDrug);
        double observedWithoutDrug = queryProbability(engine, options, output, recoveredDistribution, didRecover);

        double observedDifference = observedWithDrug - observedWithoutDrug;
        System.out.printf("Effectiveness of drug (non-causal) = %s%n", asPercent.format(observedDifference));

        System.out.println();

        // ---- Causal comparison ------------------------------------------------------
        // P(Recovered=True|Do(Drug=True)) - P(Recovered=True|Do(Drug=False)): the same
        // difference, but the drug state is set as a Do intervention rather than as an
        // observation. The example labels this result as the correct one.
        System.out.println("Causal version...");

        engine.getEvidence().setState(tookDrug, null, InterventionType.DO);
        double interventionWithDrug = queryProbability(engine, options, output, recoveredDistribution, didRecover);

        engine.getEvidence().setState(noDrug, null, InterventionType.DO);
        double interventionWithoutDrug = queryProbability(engine, options, output, recoveredDistribution, didRecover);

        double interventionDifference = interventionWithDrug - interventionWithoutDrug;
        System.out.printf("Effectiveness of drug (causal) = %s%n", asPercent.format(interventionDifference));

        // Expected output (figures as published with the example; the exact rendering of
        // the percent sign presumably follows the JVM's default locale):
        //
        // Non-causal version (incorrect)...
        // Effectiveness of drug (non-causal) = -4.91 %
        //
        // Causal version...
        // Effectiveness of drug (causal) = 5.02 %
    }

    /**
     * Runs a query against whatever evidence is currently set on the engine and reads a
     * single value out of the given query table.
     *
     * @param engine       inference engine on which {@code distribution} was registered as a query
     * @param options      query options passed through to the engine
     * @param output       query output object passed through to the engine
     * @param distribution table read after the query has run
     * @param state        state whose value is read from {@code distribution}
     * @return the value stored in {@code distribution} for {@code state}
     * @throws InconsistentEvidenceException declared for the query call
     */
    private static double queryProbability(Inference engine, QueryOptions options, QueryOutput output,
                                           Table distribution, State state) throws InconsistentEvidenceException {
        engine.query(options, output);
        return distribution.get(state);
    }
}