# Parameter tuning in Python

# __author__ = 'Bayes Server'
# __version__= '0.2'

import jpype # pip install jpype1 (version 1.2.1 or later)
import jpype.imports
from jpype.types import *

# Path to the Bayes Server Java API jar.
classpath = "lib/bayesserver-10.8.jar" # TODO download the Bayes Server Java API, and adjust the path

# Launch the JVM. JPype supports a single JVM per process and startJVM()
# raises if one is already running, so skip the call when the script is
# re-run in an interactive session (e.g. a notebook) where the JVM is up.
# Note: an already-running JVM keeps the classpath it was started with.
if not jpype.isJVMStarted():
    jpype.startJVM(classpath=[classpath])

# import the Java modules
from com.bayesserver import *
from com.bayesserver.inference import *
from com.bayesserver.analysis import *
from jpype import java

# If you have a licensed version, uncomment the next line and replace "xxx"
# with your license key.
# License.validate("xxx")

# TODO download the network from the Bayes Server User Interface (or Bayes Server Online)
# and adjust the path below.
network_path = 'networks/Asia.bayes'

network = Network()
network.load(network_path)

variables = network.getVariables()


def _variable(name):
    """Return the variable named *name*, via ``variables.get(name, True)``."""
    return variables.get(name, True)


def _state(variable, name):
    """Return the state of *variable* named *name*, via ``getStates().get(name, True)``."""
    return variable.getStates().get(name, True)


# Look up all eight variables of the Asia network by name.
visit_to_asia = _variable('Visit to Asia')
has_lung_cancer = _variable('Has Lung Cancer')
tuberculosis_or_cancer = _variable('Tuberculosis or Cancer')
smoker = _variable('Smoker')
has_tuberculosis = _variable('Has Tuberculosis')
dyspnea = _variable('Dyspnea')
xray_result = _variable('XRay Result')
has_bronchitis = _variable('Has Bronchitis')

# Individual states referenced by the sensitivity / tuning queries.
xRayResultAbnormal = _state(xray_result, 'Abnormal')
smokerFalse = _state(smoker, 'False')
hasLungCancerFalse = _state(has_lung_cancer, 'False')

# Evidence for the queries below; nothing is set on it unless added at the TODO.
evidence = DefaultEvidence(network)

# TODO set any evidence here if you need to...

# Sensitivity-to-parameters analysis, backed by the relevance tree inference factory.
sensitivity = SensitivityToParameters(network, RelevanceTreeInferenceFactory())

# Parameters to analyse: each reference is a node plus the states that select
# one of its parameters (see the Bayes Server ParameterReference docs).
parameters_to_test = [
    ParameterReference(
        has_lung_cancer.getNode(),
        [smokerFalse, hasLungCancerFalse]),
    # TODO add more parameters to test here if necessary
]

print('Node\tParameter\tMin\tMax')

# Closed interval [0.2, 0.25] passed to ParameterTuning.oneWaySimple as the
# tuning target (presumably the desired range for P(XRay Result = Abnormal);
# confirm against the Bayes Server API docs). It is identical for every
# parameter, so build it once instead of on each loop iteration.
target_interval = Interval(
    java.lang.Double(0.2),
    java.lang.Double(0.25),
    IntervalEndPoint.CLOSED,
    IntervalEndPoint.CLOSED)

for parameter in parameters_to_test:

    node_name = parameter.getNode().getName()

    # Describe the parameter by the variable/state pairs that identify it.
    # str() turns the Java strings returned through JPype into Python strings
    # so they can be joined.
    param_states_text = '[' + ', '.join(
        str(s.getVariable().getName()) + ' = ' + str(s.getName())
        for s in parameter.getStates()) + ']'

    # One-way sensitivity of the hypothesis state (XRay Result = Abnormal)
    # to this single parameter, given the evidence.
    oneWay = sensitivity.oneWay(
        evidence,
        xRayResultAbnormal,
        parameter)

    # Keep the try around the tuning call only; the printing runs in `else`.
    try:
        output = ParameterTuning.oneWaySimple(oneWay, target_interval)
    except ConstraintNotSatisfiedException:
        # No parameter value satisfies the target; report which one and move on.
        print('Ignoring {} {}: solution not found for this parameter.'.format(
            node_name, param_states_text))
    else:
        print('{}\t{}\t{}\t{}'.format(
            node_name,
            param_states_text,
            output.getInterval().getMinimum(),
            output.getInterval().getMaximum()
        ))

# Expected output (columns are tab-separated)...

# Node	Parameter	Min	Max
# Has Lung Cancer	[Has Lung Cancer = False, Smoker = False]	0.686390938882659	0.795047852504759