# Copy fragment in Python

# __author__ = 'Bayes Server'
# __version__= '0.2'

import jpype # pip install jpype1 (version 1.2.1 or later)
import jpype.imports
from jpype.types import *

# Path to the Bayes Server Java API jar.
classpath = "lib/bayesserver-10.8.jar" # TODO download the Bayes Server Java API, and adjust the path

# Launch the JVM. jpype.startJVM raises if a JVM is already running in this
# process (e.g. when the script is re-run in an interactive session), so only
# start it when needed.
if not jpype.isJVMStarted():
    jpype.startJVM(classpath=[classpath])

# import the Java modules
from com.bayesserver import *
from com.bayesserver.inference import *
from jpype import java

# Uncomment the following line and change the license key, if you are using a licensed version
# License.validate("xxx")


# Distribution kinds that copy_fragment copies for every distribution key of a
# copied node, and caches per child when migrating distributions.
kinds = [
    NodeDistributionKind.PROBABILITY,
    NodeDistributionKind.EXPERIENCE,
    NodeDistributionKind.FADING,
]


class CacheItem:
    """
    A distribution remembered together with the key and kind it is stored under.

    copy_fragment creates these for a copied child's distributions before a
    temporary parent copy is removed from the destination network.
    """

    def __init__(self, key, kind, distribution):
        # Fields are stored as given; no validation is performed.
        self.key, self.kind, self.distribution = key, kind, distribution


def map_variable_contexts(source, destination, source_to_dest):
    """
    Maps variable contexts from one distribution to another.

    For each variable context in ``source`` (in order), the corresponding
    destination variable is looked up in ``source_to_dest`` and its position
    in ``destination`` (at the same time value) is returned.

    :param source: ordered sequence of source variable contexts; each must
        provide ``getVariable()`` and ``getTime()``.
    :param destination: destination variable contexts; must support ``len()``
        and ``indexOf(variable, time)``.
    :param source_to_dest: mapping from source variables to destination
        variables.
    :return: list of destination indices, one per source context, in source
        order.
    :raises ValueError: if ``source`` and ``destination`` differ in length.
    :raises KeyError: if a source variable has no entry in ``source_to_dest``.
    """
    if len(source) != len(destination):
        raise ValueError(
            f"cannot map {len(source)} source variable contexts onto "
            f"{len(destination)} destination variable contexts")

    return [
        destination.indexOf(source_to_dest[context.getVariable()], context.getTime())
        for context in source
    ]


def copy_distribution(source, dest, source_to_dest):
    """
    Copies one distribution into another, using a mapping between source and
    destination variables.

    Discrete values are copied by walking the destination table with a
    TableIterator whose variable order follows the source table's sorted
    variables, so that stepping the iterator once per source index ``p``
    presumably visits the matching destination cell. When ``source`` is a
    CLGaussian, means, covariances and weights are copied too, with the
    continuous head/tail positions remapped into the destination's order.

    :param source: distribution to copy from.
    :param dest: distribution to copy into.
    :param source_to_dest: mapping from source variables to destination
        variables.
    :raises ValueError: (from map_variable_contexts) if the source and
        destination variable lists differ in length.
    """
    source_table = source.getTable()
    dest_table = dest.getTable()
    dest_sorted_variables = dest_table.getSortedVariables()

    source_to_dest_discrete = map_variable_contexts(
        source_table.getSortedVariables(),
        dest_sorted_variables,
        source_to_dest)

    # Destination variable contexts listed in *source* order.
    dest_order_discrete = [
        dest_sorted_variables.get(dest_index)
        for dest_index in source_to_dest_discrete
    ]

    iterator_dest = TableIterator(dest_table, java.util.Arrays.asList(dest_order_discrete))

    order_head = None
    order_tail = None

    copy_gaussian = isinstance(source, CLGaussian)

    if copy_gaussian:
        order_head = map_variable_contexts(
            source.getSortedContinuousHead(),
            dest.getSortedContinuousHead(),
            source_to_dest)

        order_tail = map_variable_contexts(
            source.getSortedContinuousTail(),
            dest.getSortedContinuousTail(),
            source_to_dest)

    for p in range(source_table.size()):

        iterator_dest.setValue(source_table.get(p))

        if copy_gaussian:
            # NOTE(review): `p` is the source table index and is reused as
            # the destination index for the Gaussian parameters; this assumes
            # both tables enumerate discrete combinations in the same order.
            for h1, h1_dest in enumerate(order_head):
                dest.setMean(p, h1_dest, source.getMean(p, h1))

                for h2, h2_dest in enumerate(order_head):
                    dest.setCovariance(p, h1_dest, h2_dest, source.getCovariance(p, h1, h2))

                for t, t_dest in enumerate(order_tail):
                    dest.setWeight(p, h1_dest, t_dest, source.getWeight(p, h1, t))

        iterator_dest.increment()


def copy_fragment(source_network, source_nodes, dest_network, suffix, migrate_distributions):
    """
    Copies nodes from a source network into a destination network.
    The source and destination can be the same.

    Each node in ``source_nodes`` is copied with ``suffix`` appended to its
    name and to the names of its variables. Links between copied nodes and
    the node groups of the source nodes are copied as well. The parents of
    the source nodes are temporarily copied too, so that the copied
    distributions can be built against the same parent structure; those
    temporary parent copies are removed from ``dest_network`` at the end.

    :param source_network: network that owns ``source_nodes``.
    :param source_nodes: nodes to copy.
    :param dest_network: network receiving the copies (may be
        ``source_network``; copies are then offset by 50 units).
    :param suffix: text appended to copied node and variable names.
    :param migrate_distributions: if True, the copied children's
        distributions are cached before each temporary parent copy is
        removed, and a fresh distribution is created for each key that still
        exists afterwards. Converting the old distribution into the new one
        is left as a TODO; nothing is written back.
    :return: ``dest_network``.
    """
    source_nodes_lookup = set(source_nodes)
    source_nodes_and_parents = []
    source_nodes_and_parents_lookup = set()
    source_node_groups_lookup = set()
    source_node_groups_to_add = []

    def add_node(node):
        # Record each node once, preserving first-seen order.
        if node not in source_nodes_and_parents_lookup:
            source_nodes_and_parents.append(node)
            source_nodes_and_parents_lookup.add(node)

    # We initially copy parents as well,
    # so that we can make it easier to migrate distributions
    # if necessary. These are then removed later
    for source_node in source_nodes:

        add_node(source_node)

        for link in source_node.getLinksIn():
            add_node(link.getFrom())

        for group_name in source_node.getGroups():
            group = source_network.getNodeGroups().get(group_name)

            # set.add() returns None, so membership must be tested explicitly.
            if group not in source_node_groups_lookup:
                source_node_groups_lookup.add(group)
                source_node_groups_to_add.append(group)

    # Collect every link whose two ends are both being copied.
    source_links_to_add = []
    source_links_to_add_lookup = set()

    for source_node in source_nodes_and_parents:
        for link in source_node.getLinks():
            if (link.getFrom() in source_nodes_and_parents_lookup
                    and link.getTo() in source_nodes_and_parents_lookup
                    and link not in source_links_to_add_lookup):
                source_links_to_add.append(link)
                source_links_to_add_lookup.add(link)

    # Add in the same order as in the original network
    source_nodes_and_parents.sort(key=lambda x: x.getIndex())
    source_links_to_add.sort(key=lambda x: x.getIndex())

    source_node_to_copy_node = {}

    # Maps source variables to their copies, and copies back to sources.
    variable_map = {}

    for source_group in source_node_groups_to_add:
        if dest_network.getNodeGroups().get(source_group.getName()) is None:
            dest_network.getNodeGroups().add(source_group.copy())

    for source_node in source_nodes_and_parents:

        dest_node = source_node.copy()
        dest_node.setName(dest_node.getName() + suffix)

        if source_network == dest_network:
            # Offset the copy so it does not cover the original on the canvas.
            source_bounds = source_node.getBounds()
            dest_node.setBounds(Bounds(
                source_bounds.getX() + 50,
                source_bounds.getY() + 50,
                source_bounds.getWidth(),
                source_bounds.getHeight()))

        source_variables = source_node.getVariables()
        dest_variables = dest_node.getVariables()

        for v in range(source_variables.size()):
            source_variable = source_variables.get(v)
            dest_variable = dest_variables.get(v)
            dest_variable.setName(dest_variable.getName() + suffix)

            variable_map[source_variable] = dest_variable
            variable_map[dest_variable] = source_variable

        if source_node not in source_nodes_lookup:
            dest_node.getGroups().clear()  # as we may not have added groups for this node, which will later be deleted

        dest_network.getNodes().add(dest_node)

        source_node_to_copy_node[source_node] = dest_node

    for source_link in source_links_to_add:
        from_dest = source_node_to_copy_node[source_link.getFrom()]
        to_dest = source_node_to_copy_node[source_link.getTo()]

        link_dest = source_link.copy(from_dest, to_dest, source_link.getTemporalOrder())

        dest_network.getLinks().add(link_dest)

    # Copy the distributions of the requested nodes only; parent copies are
    # temporary and are removed below.
    for source_node, node_dest in source_node_to_copy_node.items():

        if source_node not in source_nodes_lookup:
            continue

        source_distributions = source_node.getDistributions()

        for source_key in source_distributions.getKeys():
            related_node = source_key.getRelatedNode()
            related_node_dest = (
                None if related_node is None else source_node_to_copy_node[related_node])

            key_dest = NodeDistributionKey(source_key.getOrder(), related_node_dest)

            if not node_dest.getDistributions().canUpdate(key_dest):
                continue

            for kind in kinds:

                source_distribution = source_distributions.get(source_key, kind)

                if source_distribution is None:
                    continue

                distribution_dest = node_dest.newDistribution(key_dest, kind)
                copy_distribution(source_distribution, distribution_dest, variable_map)
                node_dest.getDistributions().set(key_dest, kind, distribution_dest)

    # Remove each temporary parent copy, optionally caching the distributions
    # of its copied children first.
    for source_parent in source_nodes_and_parents:

        if source_parent in source_nodes_lookup:
            continue

        parent_dest = source_node_to_copy_node[source_parent]

        # (child copy, CacheItem) pairs for this parent; stays empty unless
        # migrating.
        old_distributions_dest = []

        if migrate_distributions:

            for source_link in source_parent.getLinksOut():

                source_child = source_link.getTo()

                if source_child not in source_nodes_lookup:
                    continue

                child_dest = source_node_to_copy_node[source_child]
                child_distributions = child_dest.getDistributions()

                for key_dest in child_distributions.getKeys():
                    for kind in kinds:
                        distribution_dest = child_distributions.get(key_dest, kind)

                        if distribution_dest is None:
                            continue

                        cache_item = CacheItem(key_dest, kind, distribution_dest)
                        old_distributions_dest.append((child_dest, cache_item))

        dest_network.getNodes().remove(parent_dest)

        # Runs per removed parent, while the cache still holds the
        # distributions captured before this removal.
        for node_dest, cache_item in old_distributions_dest:

            # Kept for the conversion described in the TODO below.
            old_distribution_dest = cache_item.distribution

            if cache_item.key not in node_dest.getDistributions().getKeys():
                continue

            new_distribution_dest = node_dest.newDistribution(cache_item.key, cache_item.kind)

            # TODO if you want to convert the old wider distribution to the new
            # you would do that here and then set as follows
            # node_dest.getDistributions().set(cache_item.key, cache_item.kind, new_distribution_dest)

    return dest_network


# TODO: download the network from the Bayes Server User Interface (or Bayes
# Server Online) and adjust the path below.
network_path = 'networks/Asia.bayes'

network = Network()
network.load(network_path)

nodes = network.getNodes()

# Look up each node of the Asia network by name, in a fixed order.
# NOTE(review): the second argument True presumably makes get() raise when
# the name is missing — confirm against the Bayes Server API.
(
    visit_to_asia,
    has_lung_cancer,
    tuberculosis_or_cancer,
    smoker,
    has_tuberculosis,
    dyspnea,
    xray_result,
    has_bronchitis,
) = (
    nodes.get(node_name, True)
    for node_name in (
        'Visit to Asia',
        'Has Lung Cancer',
        'Tuberculosis or Cancer',
        'Smoker',
        'Has Tuberculosis',
        'Dyspnea',
        'XRay Result',
        'Has Bronchitis',
    )
)

# Copy a fragment of the network back into the same network.
dest_network = network
source_nodes = [has_lung_cancer, tuberculosis_or_cancer, dyspnea, has_bronchitis]
copy_fragment(network, source_nodes, dest_network, '_copy', True)

print(network.saveToString())