# JPype bridge to the JVM. The order of the statements in this header matters:
# the JVM is started first, and only then are the Java packages imported.
import jpype
# Imported for its side effect only; presumably this is what allows the
# `from com.bayesserver import ...` lines below to resolve (see JPype docs).
import jpype.imports
from jpype.types import *

# Relative path to the Bayes Server jar that is put on the JVM classpath.
classpath = "lib/bayesserver-10.8.jar"

# Start the JVM with the jar on its classpath, before any Java package import.
jpype.startJVM(classpath=[classpath])

# Java packages, imported through JPype once the JVM is running.
from com.bayesserver import *
from com.bayesserver.inference import *
from jpype import java
def nullable_int(x):
    """Wrap ``x`` in a ``java.lang.Integer``.

    The script uses the result wherever a Java nullable integer is passed
    as a time argument (e.g. to ``StateContext`` and ``VariableContext``).
    """
    boxed = java.lang.Integer(x)
    return boxed
# --- Network structure -----------------------------------------------------
network = Network('DBN')

# Discrete 'Transition' variable with three states.
cluster1, cluster2, cluster3 = (
    State(state_name) for state_name in ('Cluster1', 'Cluster2', 'Cluster3')
)
varTransition = Variable('Transition', cluster1, cluster2, cluster3)
nodeTransition = Node(varTransition)

# Four continuous variables that share a single 'Observation' node.
varObs1, varObs2, varObs3, varObs4 = (
    Variable(variable_name, VariableValueType.CONTINUOUS)
    for variable_name in ('Obs1', 'Obs2', 'Obs3', 'Obs4')
)
nodeObservation = Node('Observation', [varObs1, varObs2, varObs3, varObs4])

# Mark both nodes as temporal before registering them with the network.
for temporal_node in (nodeTransition, nodeObservation):
    temporal_node.setTemporalType(TemporalType.TEMPORAL)

network_nodes = network.getNodes()
for temporal_node in (nodeTransition, nodeObservation):
    network_nodes.add(temporal_node)

# Transition -> Observation, plus a Transition -> Transition link built with
# an extra argument of 1.
network_links = network.getLinks()
network_links.add(Link(nodeTransition, nodeObservation))
network_links.add(Link(nodeTransition, nodeTransition, 1))
# --- Distributions for the Transition node ---------------------------------
# Times used to tag state contexts: 0 and -1.
t0 = nullable_int(0)
tMinus1 = nullable_int(-1)

cluster1Time0, cluster2Time0, cluster3Time0 = (
    StateContext(state, t0) for state in (cluster1, cluster2, cluster3)
)
cluster1TimeM1, cluster2TimeM1, cluster3TimeM1 = (
    StateContext(state, tMinus1) for state in (cluster1, cluster2, cluster3)
)

# Table obtained from newDistribution(0): one probability per time-0 state.
prior = nodeTransition.newDistribution(0).getTable()
for probability, current in (
    (0.2, cluster1Time0),
    (0.3, cluster2Time0),
    (0.5, cluster3Time0),
):
    prior.set(probability, current)
nodeTransition.setDistribution(prior)

# Table obtained from newDistribution(1): each row is keyed by a time -1
# state and lists the probabilities of Cluster1, Cluster2, Cluster3 at time 0.
transition = nodeTransition.newDistribution(1).getTable()
time0_contexts = (cluster1Time0, cluster2Time0, cluster3Time0)
transition_rows = (
    (cluster1TimeM1, (0.2, 0.3, 0.5)),
    (cluster2TimeM1, (0.4, 0.4, 0.2)),
    (cluster3TimeM1, (0.9, 0.09, 0.01)),
)
for previous, row in transition_rows:
    for probability, current in zip(row, time0_contexts):
        transition.set(probability, previous, current)
nodeTransition.getDistributions().set(1, transition)
# --- Gaussian distribution for the Observation node ------------------------
gaussian = nodeObservation.newDistribution()

# Head contexts for the four observation variables, all tagged with t0.
varObs1Time0, varObs2Time0, varObs3Time0, varObs4Time0 = (
    VariableContext(variable, t0, HeadTail.HEAD)
    for variable in (varObs1, varObs2, varObs3, varObs4)
)
observation_contexts = (varObs1Time0, varObs2Time0, varObs3Time0, varObs4Time0)

# Variable pairs in the order the covariance values are listed below.
covariance_pairs = (
    (varObs1Time0, varObs2Time0),
    (varObs1Time0, varObs3Time0),
    (varObs1Time0, varObs4Time0),
    (varObs2Time0, varObs3Time0),
    (varObs2Time0, varObs4Time0),
    (varObs3Time0, varObs4Time0),
)

# One row per Transition state:
# (state context, means, variances, covariances).
# Means and variances follow Obs1..Obs4; covariances follow covariance_pairs.
cluster_parameters = (
    (
        cluster1Time0,
        (3.2, 2.4, -1.7, 6.2),
        (2.3, 2.1, 3.2, 1.4),
        (-0.3, 0.5, 0.35, 0.12, 0.1, 0.23),
    ),
    (
        cluster2Time0,
        (3.0, 2.8, -2.5, 6.9),
        (2.1, 2.2, 3.3, 1.5),
        (-0.4, 0.5, 0.45, 0.22, 0.15, 0.24),
    ),
    (
        cluster3Time0,
        (3.8, 2.0, -1.9, 6.25),
        (2.34, 2.11, 3.22, 1.43),
        (-0.31, 0.52, 0.353, 0.124, 0.15, 0.236),
    ),
)

for cluster_context, means, variances, covariances in cluster_parameters:
    for observation_context, mean in zip(observation_contexts, means):
        gaussian.setMean(observation_context, mean, cluster_context)
    for observation_context, variance in zip(observation_contexts, variances):
        gaussian.setVariance(observation_context, variance, cluster_context)
    for (first, second), covariance in zip(covariance_pairs, covariances):
        gaussian.setCovariance(first, second, covariance, cluster_context)

nodeObservation.setDistribution(gaussian)
# --- Validate the network and set up inference -----------------------------
network.validate(ValidationOptions())

inference = RelevanceTreeInference(network)
queryOptions = RelevanceTreeQueryOptions()
queryOutput = RelevanceTreeQueryOutput()

# Temporal evidence: four values per variable, passed with the trailing
# arguments (0, 0, 4). A None entry leaves that value unset.
evidence = inference.getEvidence()
for observed_variable, observed_values in (
    (varObs1, [2.2, 2.4, 2.6, 2.9]),
    (varObs2, [None, 4.0, 4.1, 4.88]),
    (varObs3, [-2.5, -2.3, None, -4.0]),
    (varObs4, [4.0, 6.5, 4.9, 4.4]),
):
    evidence.set(observed_variable, observed_values, 0, 0, 4)

queryOptions.setLogLikelihood(True)

# Queries at time 4: one CLGaussian per observation variable, in the node's
# variable order, followed by a joint query over Obs1 and Obs2.
predict_time = nullable_int(4)
gaussian_future = [
    CLGaussian(variable, predict_time)
    for variable in nodeObservation.getVariables()
]

query_distributions = inference.getQueryDistributions()
for future_query in gaussian_future:
    query_distributions.add(future_query)

jointFuture = CLGaussian(java.util.Arrays.asList([varObs1, varObs2]), predict_time)
query_distributions.add(jointFuture)

inference.query(queryOptions, queryOutput)
# --- Report the query results ----------------------------------------------
print('LogLikelihood: ', queryOutput.getLogLikelihood())
print()

# gaussian_future was filled in the order of nodeObservation.getVariables(),
# so zipping the two pairs each variable with its own query. The loop uses a
# dedicated name (`marginal`) so that the module-level `gaussian`, which holds
# the Observation node's distribution, is not overwritten.
for variableH, marginal in zip(nodeObservation.getVariables(), gaussian_future):
    print('P({}(t=4)|evidence)={}'.format(variableH.getName(), marginal.getMean(variableH, predict_time)))

print()

# Joint query over Obs1 and Obs2: the means, then the 2x2 covariance matrix.
print('P({},{}|evidence)='.format(varObs1.getName(), varObs2.getName()))
print('{}\t{}'.format(jointFuture.getMean(varObs1, predict_time), jointFuture.getMean(varObs2, predict_time)))
print('{}\t{}'.format(jointFuture.getVariance(varObs1, predict_time), jointFuture.getCovariance(varObs1, predict_time, varObs2, predict_time)))
print('{}\t{}'.format(jointFuture.getCovariance(varObs2, predict_time, varObs1, predict_time), jointFuture.getVariance(varObs2, predict_time)))