// Noisy nodes in Java

package com.bayesserver.examples;

// --------------------------------------------------------------------------------------------------------------------
// <copyright file="NoisyNodesExample.java" company="Bayes Server">
//   Copyright (C) Bayes Server.  All rights reserved.
// </copyright>
// --------------------------------------------------------------------------------------------------------------------

import com.bayesserver.*;
import com.bayesserver.inference.*;

/**
 * Programmatically builds the 'Noisy Or' Bayesian network and runs three example queries against it.
 *
 * <p>The network has three disease nodes (Infection, Osteoarthritis, Rheumatoid arthritis), each linked
 * to three symptom nodes (Temperatures, Sore joints, Nail problems). The symptom nodes are configured as
 * noisy nodes ({@code NoisyType.NOISY_OR_MAX}), so instead of a single table over all parents they carry
 * one distribution per parent plus a leak distribution.
 */
public class NoisyNodesExample {

    /**
     * Creates the network, assigns all distributions, validates it and prints the results of
     * three queries (Scenario A, B and C) to standard output.
     *
     * @param args command line arguments (not used)
     * @throws InconsistentEvidenceException declared because the inference queries may throw it
     */
    public static void main(String[] args) throws InconsistentEvidenceException {

        // In this example we programmatically create a Bayesian network with Noisy nodes.
        // The network created can also be found within the example network (called Noisy Or) installed with Bayes Server.

        // Note that you can automatically define nodes from data using
        // classes in com.bayesserver.data.discovery,
        // and you can automatically learn the parameters using classes in
        // com.bayesserver.learning.parameters,
        // however here we build the 'Noisy OR' network manually.

        Network network = new Network("Noisy Or");

        // add the nodes (variables)

        State infectionFalse = new State("False");
        State infectionTrue = new State("True");
        Variable infection = new Variable("Infection", infectionFalse, infectionTrue);
        Node infectionNode = new Node(infection);

        State osteoFalse = new State("False");
        State osteoTrue = new State("True");
        Variable osteo = new Variable("Osteoarthritis", osteoFalse, osteoTrue);
        Node osteoNode = new Node(osteo);

        State rheumatoidFalse = new State("False");
        State rheumatoidTrue = new State("True");
        Variable rheumatoid = new Variable("Rheumatoid arthritis", rheumatoidFalse, rheumatoidTrue);
        Node rheumatoidNode = new Node(rheumatoid);

        // The three symptom nodes below are marked as noisy nodes.

        State temperaturesFalse = new State("False");
        State temperaturesTrue = new State("True");
        Variable temperatures = new Variable("Temperatures", temperaturesFalse, temperaturesTrue);
        Node temperaturesNode = new Node(temperatures);
        temperaturesNode.getDistributionOptions().setNoisyType(NoisyType.NOISY_OR_MAX);

        State soreJointsFalse = new State("False");
        State soreJointsTrue = new State("True");
        Variable soreJoints = new Variable("Sore joints", soreJointsFalse, soreJointsTrue);
        Node soreJointsNode = new Node(soreJoints);
        soreJointsNode.getDistributionOptions().setNoisyType(NoisyType.NOISY_OR_MAX);

        State nailProblemsFalse = new State("False");
        State nailProblemsTrue = new State("True");
        Variable nailProblems = new Variable("Nail problems", nailProblemsFalse, nailProblemsTrue);
        Node nailProblemsNode = new Node(nailProblems);
        nailProblemsNode.getDistributionOptions().setNoisyType(NoisyType.NOISY_OR_MAX);

        NetworkNodeCollection nodes = network.getNodes();

        nodes.add(infectionNode);
        nodes.add(osteoNode);
        nodes.add(rheumatoidNode);
        nodes.add(temperaturesNode);
        nodes.add(soreJointsNode);
        nodes.add(nailProblemsNode);


        // add some directed links (every disease is a parent of every symptom)

        NetworkLinkCollection links = network.getLinks();

        links.add(new Link(infectionNode, temperaturesNode));
        links.add(new Link(infectionNode, soreJointsNode));
        links.add(new Link(infectionNode, nailProblemsNode));
        links.add(new Link(osteoNode, temperaturesNode));
        links.add(new Link(osteoNode, soreJointsNode));
        links.add(new Link(osteoNode, nailProblemsNode));
        links.add(new Link(rheumatoidNode, temperaturesNode));
        links.add(new Link(rheumatoidNode, soreJointsNode));
        links.add(new Link(rheumatoidNode, nailProblemsNode));

        // All the links in this particular network have ascending causal effects
        for (Link link : links) {
            link.setNoisyOrder(NoisyOrder.ASCENDING);
        }

        // at this point we have fully specified the structural (graphical) specification of the Bayesian Network.

        // We must define the necessary probability distributions for each node.

        // We will setup the nodes which are not 'Noisy' first in the usual way

        {
            Table tableInfection = infectionNode.newDistribution().getTable();
            tableInfection.set(0.95, infectionFalse);
            tableInfection.set(0.05, infectionTrue);
            infectionNode.setDistribution(tableInfection);
        }

        {
            Table tableOsteoarthritis = osteoNode.newDistribution().getTable();
            tableOsteoarthritis.set(0.99, osteoFalse);
            tableOsteoarthritis.set(0.01, osteoTrue);
            osteoNode.setDistribution(tableOsteoarthritis);
        }

        {
            Table tableRheumatoidArthritis = rheumatoidNode.newDistribution().getTable();
            tableRheumatoidArthritis.set(0.9999, rheumatoidFalse);
            tableRheumatoidArthritis.set(0.0001, rheumatoidTrue);
            rheumatoidNode.setDistribution(tableRheumatoidArthritis);
        }

        // For noisy nodes, we require a distribution given each parent and an additional leak distribution.
        // In each call below the two numbers are P(symptom=False) and P(symptom=True) for parent=True.

        // Define the distribution for the 'Temperatures' node
        setNoisyDistribution(
                temperaturesNode, infectionNode, infectionTrue, temperaturesFalse, temperaturesTrue, 0.4, 0.6);
        setNoisyDistribution(
                temperaturesNode, osteoNode, osteoTrue, temperaturesFalse, temperaturesTrue, 1.0, 0.0);
        setNoisyDistribution(
                temperaturesNode, rheumatoidNode, rheumatoidTrue, temperaturesFalse, temperaturesTrue, 0.3, 0.7);
        setLeakDistribution(temperaturesNode, temperaturesFalse, temperaturesTrue, 0.9, 0.1);

        // Define the distribution for the 'Sore Joints' node
        setNoisyDistribution(
                soreJointsNode, infectionNode, infectionTrue, soreJointsFalse, soreJointsTrue, 0.9, 0.1);
        setNoisyDistribution(
                soreJointsNode, osteoNode, osteoTrue, soreJointsFalse, soreJointsTrue, 0.01, 0.99);
        setNoisyDistribution(
                soreJointsNode, rheumatoidNode, rheumatoidTrue, soreJointsFalse, soreJointsTrue, 0.01, 0.99);
        setLeakDistribution(soreJointsNode, soreJointsFalse, soreJointsTrue, 0.9, 0.1);

        // Define the distribution for the 'Nail Problems' node
        setNoisyDistribution(
                nailProblemsNode, infectionNode, infectionTrue, nailProblemsFalse, nailProblemsTrue, 0.85, 0.15);
        setNoisyDistribution(
                nailProblemsNode, osteoNode, osteoTrue, nailProblemsFalse, nailProblemsTrue, 1.0, 0.0);
        setNoisyDistribution(
                nailProblemsNode, rheumatoidNode, rheumatoidTrue, nailProblemsFalse, nailProblemsTrue, 0.5, 0.5);
        setLeakDistribution(nailProblemsNode, nailProblemsFalse, nailProblemsTrue, 0.9, 0.1);

        network.validate(new ValidationOptions());


        // Setup an inference engine to allow us to make predictions

        InferenceFactory factory = new RelevanceTreeInferenceFactory();
        Inference inference = factory.createInferenceEngine(network);
        QueryOptions queryOptions = factory.createQueryOptions();
        QueryOutput queryOutputs = factory.createQueryOutput();

        // Now let's make some predictions (queries).

        QueryDistributionCollection queries = inference.getQueryDistributions();

        // Scenario A - Predict disease from Symptoms (not all symptoms have been accounted for)

        Table queryInfection = new Table(infection);
        queries.add(queryInfection);

        Table queryOsteo = new Table(osteo);
        queries.add(queryOsteo);

        Table queryRheumatoid = new Table(rheumatoid);
        queries.add(queryRheumatoid);

        Evidence evidence = inference.getEvidence();

        evidence.setState(temperaturesFalse);
        evidence.setState(soreJointsTrue);

        inference.query(queryOptions, queryOutputs);

        System.out.println("Scenario A");
        System.out.println("-----------");
        System.out.println("P(Infection=True | evidence) = " + queryInfection.get(infectionTrue));
        System.out.println("P(Osteoarthritis=True | evidence) = " + queryOsteo.get(osteoTrue));
        System.out.println("P(Rheumatoid arthritis=True | evidence) = " + queryRheumatoid.get(rheumatoidTrue));
        System.out.println();

        // Scenario B - Predict disease from Symptoms having tested and ruled out certain diseases
        // (the same evidence object is extended with Infection=False)

        evidence.setState(infectionFalse);

        inference.query(queryOptions, queryOutputs);

        System.out.println("Scenario B");
        System.out.println("-----------");
        System.out.println("P(Osteoarthritis=True | evidence) = " + queryOsteo.get(osteoTrue));
        System.out.println("P(Rheumatoid arthritis=True | evidence) = " + queryRheumatoid.get(rheumatoidTrue));
        System.out.println();

        // Scenario C - Also predict Symptom given other Symptoms having ruled out certain diseases

        Table queryNailProblems = new Table(nailProblems);
        queries.add(queryNailProblems);

        inference.query(queryOptions, queryOutputs);

        System.out.println("Scenario C");
        System.out.println("-----------");
        System.out.println("P(Nail Problem=True | evidence) = " + queryNailProblems.get(nailProblemsTrue));
        System.out.println();
    }

    /**
     * Assigns the distribution of a noisy {@code child} node that is keyed by one of its parents.
     *
     * <p>Only the two entries for {@code parentTrue} are written; all other entries of the newly
     * created table are left untouched.
     *
     * @param child            the noisy node receiving the distribution
     * @param parent           the parent node used to build the distribution key
     * @param parentTrue       the 'True' state of the parent's variable
     * @param childFalse       the 'False' state of the child's variable
     * @param childTrue        the 'True' state of the child's variable
     * @param probabilityFalse value stored for (parentTrue, childFalse)
     * @param probabilityTrue  value stored for (parentTrue, childTrue)
     */
    private static void setNoisyDistribution(
            Node child,
            Node parent,
            State parentTrue,
            State childFalse,
            State childTrue,
            double probabilityFalse,
            double probabilityTrue) {

        NodeDistributionKey key = new NodeDistributionKey(parent);
        Table table = child.newDistribution(key).getTable();
        table.set(probabilityFalse, parentTrue, childFalse);
        table.set(probabilityTrue, parentTrue, childTrue);
        child.getDistributions().set(key, table);
    }

    /**
     * Assigns the leak distribution of a noisy {@code child} node, i.e. the distribution whose
     * key is built from the child node itself.
     *
     * @param child            the noisy node receiving the leak distribution
     * @param childFalse       the 'False' state of the child's variable
     * @param childTrue        the 'True' state of the child's variable
     * @param probabilityFalse value stored for childFalse
     * @param probabilityTrue  value stored for childTrue
     */
    private static void setLeakDistribution(
            Node child,
            State childFalse,
            State childTrue,
            double probabilityFalse,
            double probabilityTrue) {

        NodeDistributionKey key = new NodeDistributionKey(child);
        Table table = child.newDistribution(key).getTable();
        table.set(probabilityFalse, childFalse);
        table.set(probabilityTrue, childTrue);
        child.getDistributions().set(key, table);
    }
}