# Copy fragment in Python

# __author__ = 'Bayes Server'
# __version__= '0.1'


from jpype import *  # pip install jpype1==0.7.5

# TODO change path to Bayes Server jar file
classpath = "C:\\Program Files\\Bayes Server\\Bayes Server 9.2\\API\\Java\\bayesserver-9.2.jar"

# Launch the JVM with the Bayes Server jar on its class path.
# convertStrings=False: Java strings are not auto-converted to Python str.
startJVM(
    getDefaultJVMPath(),
    '-Djava.class.path=' + classpath,
    convertStrings=False,
)

# Root package of the Bayes Server Java API.
bayes = JPackage('com.bayesserver')

# Distribution kinds that are looked up (and copied) for every distribution key.
kinds = [
    getattr(bayes.NodeDistributionKind, kind_name)
    for kind_name in ('PROBABILITY', 'EXPERIENCE', 'FADING')
]


class CacheItem:
    """
    Snapshot of one node distribution together with the key and kind under
    which it was stored, so it can be looked up again after the network has
    been modified.
    """

    def __init__(self, key, kind, distribution):
        # key/kind identify the slot the distribution was read from.
        self.key, self.kind, self.distribution = key, kind, distribution


def map_variable_contexts(source, destination, source_to_dest):
    """
    Maps variable contexts from one distribution to another.

    For each variable context in ``source`` (in order), the corresponding
    destination variable is found through ``source_to_dest``. The index of the
    destination context with that variable and the same time is then taken
    from ``destination.indexOf``.

    :param source: sequence of source variable contexts; each element must
        provide ``getVariable()`` and ``getTime()``.
    :param destination: sequence of destination variable contexts providing
        ``indexOf(variable, time)``.
    :param source_to_dest: mapping from source variables to destination
        variables (a missing key raises whatever the mapping raises, e.g.
        ``KeyError`` for a dict).
    :return: list whose element ``i`` is the destination index that matches
        ``source[i]``.
    :raises ValueError: if ``source`` and ``destination`` differ in length.
    """
    if len(source) != len(destination):
        raise ValueError(
            'Source and destination variable contexts differ in length '
            '({} != {})'.format(len(source), len(destination)))

    return [
        destination.indexOf(source_to_dest[source_v.getVariable()], source_v.getTime())
        for source_v in source
    ]


def copy_distribution(source, dest, source_to_dest):
    """
    Copies one distribution into another, using a mapping between source and
    destination variables.

    The discrete table values are always copied. When ``source`` is a
    ``CLGaussian``, the means, covariances and weights are copied as well.
    Continuous head/tail indices are remapped to the destination's ordering.

    :param source: distribution to read from.
    :param dest: distribution to write into; it must be defined over the
        mapped counterparts of ``source``'s variables.
    :param source_to_dest: mapping from source variables to destination
        variables.
    :raises ValueError: (from ``map_variable_contexts``) if the variable
        counts of source and destination differ.
    """
    source_table = source.getTable()
    dest_table = dest.getTable()
    dest_sorted = dest_table.getSortedVariables()

    # Element i is the index in dest_sorted of the context that matches the
    # i-th sorted source context.
    source_to_dest_discrete = map_variable_contexts(
        source_table.getSortedVariables(),
        dest_sorted,
        source_to_dest)

    # Destination contexts listed in source order. The iterator below is
    # advanced once per source entry p, so each source value lands in the
    # matching destination cell (assumes TableIterator increments in the
    # given variable order - confirm against the Bayes Server docs).
    dest_order_discrete = [dest_sorted.get(index) for index in source_to_dest_discrete]

    iterator_dest = bayes.TableIterator(dest_table, java.util.Arrays.asList(dest_order_discrete))

    order_head = None
    order_tail = None

    copy_gaussian = isinstance(source, bayes.CLGaussian)

    if copy_gaussian:
        # Destination index of each continuous head / tail context.
        order_head = map_variable_contexts(
            source.getSortedContinuousHead(),
            dest.getSortedContinuousHead(),
            source_to_dest)

        order_tail = map_variable_contexts(
            source.getSortedContinuousTail(),
            dest.getSortedContinuousTail(),
            source_to_dest)

    for p in range(source_table.size()):

        iterator_dest.setValue(source_table.get(p))

        if copy_gaussian:
            # NOTE(review): the discrete index p of the source is also used as
            # the destination's discrete index. If the discrete variable order
            # differs between source and dest, the destination index may need
            # to come from the iterator instead - confirm.
            for h1, h1_dest in enumerate(order_head):
                dest.setMean(p, h1_dest, source.getMean(p, h1))

                for h2, h2_dest in enumerate(order_head):
                    dest.setCovariance(p, h1_dest, h2_dest, source.getCovariance(p, h1, h2))

                for t, t_dest in enumerate(order_tail):
                    dest.setWeight(p, h1_dest, t_dest, source.getWeight(p, h1, t))

        iterator_dest.increment()


def copy_fragment(source_network, source_nodes, dest_network, suffix, migrate_distributions):
    """
    Copies nodes from a source network into a destination network.
    The source and destination can be the same.

    The direct parents of ``source_nodes`` are copied too, so that the copied
    nodes' distributions can be copied unchanged. Those parent copies are
    removed again before returning.

    :param source_network: network that contains ``source_nodes``.
    :param source_nodes: nodes to copy.
    :param dest_network: network that receives the copies (may be
        ``source_network``, in which case copies are offset by 50 units).
    :param suffix: appended to the name of every copied node and variable.
    :param migrate_distributions: when true, the distributions of copied
        children are cached before each temporary parent is removed. A new
        distribution is then created for every cached key that still exists.
        Converting the old distribution into the new one is left as a TODO
        below.
    :return: ``dest_network``.
    """
    source_nodes_lookup = set(source_nodes)
    source_nodes_and_parents = []
    source_nodes_and_parents_lookup = set()
    source_node_groups_lookup = set()
    source_node_groups_to_add = []

    # We initially copy parents as well,
    # so that we can make it easier to migrate distributions
    # if necessary.  These are then removed later

    for source_node in source_nodes:

        if source_node not in source_nodes_and_parents_lookup:
            source_nodes_and_parents.append(source_node)
            source_nodes_and_parents_lookup.add(source_node)

        for link in source_node.getLinksIn():
            parent = link.getFrom()

            if parent not in source_nodes_and_parents_lookup:
                source_nodes_and_parents.append(parent)
                source_nodes_and_parents_lookup.add(parent)

        for group_name in source_node.getGroups():

            group = source_network.getNodeGroups().get(group_name)

            # Record each group once, keeping first-seen order.
            if group is not None and group not in source_node_groups_lookup:
                source_node_groups_lookup.add(group)
                source_node_groups_to_add.append(group)

    # Collect every link whose endpoints are both being copied (once each).
    source_links_to_add = []
    source_links_to_add_lookup = set()

    for source_node in source_nodes_and_parents:

        for link in source_node.getLinks():

            if (link.getFrom() in source_nodes_and_parents_lookup
                    and link.getTo() in source_nodes_and_parents_lookup
                    and link not in source_links_to_add_lookup):
                source_links_to_add.append(link)
                source_links_to_add_lookup.add(link)

    # Add in the same order as in the original network
    source_nodes_and_parents.sort(key=lambda x: x.getIndex())
    source_links_to_add.sort(key=lambda x: x.getIndex())

    source_node_to_copy_node = {}

    # Holds both directions: source variable -> copy, and copy -> source.
    variable_map = {}

    for source_group in source_node_groups_to_add:

        if dest_network.getNodeGroups().get(source_group.getName()) is None:
            dest_network.getNodeGroups().add(source_group.copy())

    for source_node in source_nodes_and_parents:

        dest_node = source_node.copy()
        dest_node.setName(dest_node.getName() + suffix)

        if source_network == dest_network:
            # Offset the copy so it does not sit on top of the original.
            source_bounds = source_node.getBounds()
            dest_node.setBounds(bayes.Bounds(
                source_bounds.getX() + 50,
                source_bounds.getY() + 50,
                source_bounds.getWidth(),
                source_bounds.getHeight()))

        for v in range(source_node.getVariables().size()):
            source_variable = source_node.getVariables().get(v)
            dest_variable = dest_node.getVariables().get(v)
            dest_variable.setName(dest_variable.getName() + suffix)

            variable_map[source_variable] = dest_variable
            variable_map[dest_variable] = source_variable

        if source_node not in source_nodes_lookup:
            dest_node.getGroups().clear()  # as we may not have added groups for this node, which will later be deleted

        dest_network.getNodes().add(dest_node)

        source_node_to_copy_node[source_node] = dest_node

    for source_link in source_links_to_add:
        from_dest = source_node_to_copy_node[source_link.getFrom()]
        to_dest = source_node_to_copy_node[source_link.getTo()]

        link_dest = source_link.copy(from_dest, to_dest, source_link.getTemporalOrder())

        dest_network.getLinks().add(link_dest)

    # copy the distributions (only for the requested nodes, not the parents)

    for source_node, node_dest in source_node_to_copy_node.items():

        if source_node not in source_nodes_lookup:
            continue

        for source_key in source_node.getDistributions().getKeys():
            related_node_dest = None

            if source_key.getRelatedNode() is not None:
                related_node_dest = source_node_to_copy_node[source_key.getRelatedNode()]

            key_dest = bayes.NodeDistributionKey(source_key.getOrder(), related_node_dest)

            if not node_dest.getDistributions().canUpdate(key_dest):
                continue

            for kind in kinds:

                source_distribution = source_node.getDistributions().get(source_key, kind)

                if source_distribution is None:
                    continue

                distribution_dest = node_dest.newDistribution(key_dest, kind)
                copy_distribution(source_distribution, distribution_dest, variable_map)
                node_dest.getDistributions().set(key_dest, kind, distribution_dest)

    # Remove the temporary parent copies again.
    for source_parent in source_nodes_and_parents:

        if source_parent in source_nodes_lookup:
            continue

        parent_dest = source_node_to_copy_node[source_parent]

        old_distributions_dest = None

        if migrate_distributions:

            # Cache the copied children's distributions before the parent
            # (and its links) are removed.
            old_distributions_dest = []

            for source_link in source_parent.getLinksOut():

                source_child = source_link.getTo()

                if source_child not in source_nodes_lookup:
                    continue

                child_dest = source_node_to_copy_node[source_child]

                for key_dest in child_dest.getDistributions().getKeys():

                    for kind in kinds:

                        distribution_dest = child_dest.getDistributions().get(key_dest, kind)

                        if distribution_dest is None:
                            continue

                        cache_item = CacheItem(key_dest, kind, distribution_dest)

                        old_distributions_dest.append((child_dest, cache_item))

        dest_network.getNodes().remove(parent_dest)

        if migrate_distributions:

            for child_dest, cache_item in old_distributions_dest:

                old_distribution_dest = cache_item.distribution  # input for the TODO below

                if cache_item.key not in child_dest.getDistributions().getKeys():
                    continue

                new_distribution_dest = child_dest.newDistribution(cache_item.key, cache_item.kind)

                # TODO if you want to convert the old wider distribution to the new
                # you would do that here and then set as follows
                # child_dest.getDistributions().set(cache_item.key, cache_item.kind, new_distribution_dest)

    return dest_network


# TODO change path to Asia network
network_path = 'C:\\ProgramData\\Bayes Server 8.19\\Sample Networks\\Asia.bayes'

network = bayes.Network()
network.load(network_path)

nodes = network.getNodes()

# Look up every node of the Asia network by name (the second argument, True,
# presumably requests a lookup that fails if the node is missing - confirm).
(
    visit_to_asia,
    has_lung_cancer,
    tuberculosis_or_cancer,
    smoker,
    has_tuberculosis,
    dyspnea,
    xray_result,
    has_bronchitis,
) = [nodes.get(node_name, True) for node_name in (
    'Visit to Asia',
    'Has Lung Cancer',
    'Tuberculosis or Cancer',
    'Smoker',
    'Has Tuberculosis',
    'Dyspnea',
    'XRay Result',
    'Has Bronchitis',
)]

# Copy a fragment of the network back into the same network.
dest_network = network
source_nodes = [has_lung_cancer, tuberculosis_or_cancer, dyspnea, has_bronchitis]
copy_fragment(network, source_nodes, dest_network, suffix='_copy', migrate_distributions=True)

print(network.saveToString())