Inference (discrete & continuous) with a Bayesian network from Excel functions

This page contains examples of how to embed predictions in Microsoft Excel.

Please also see the Setup page for Excel functions before using this example.

Bayes Server has long supported reading data from, and writing data to, Excel. You can also call Bayes Server directly from Excel worksheet functions, as shown in the example below.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using ExcelDna.Integration;
using BayesServer;
using BayesServer.Inference.RelevanceTree;

namespace BayesServer.VBA
{
    /// <summary>
    /// Functions to perform inference using a pre-built Bayes Server model in Excel.
    /// </summary>
    public static class BayesServerInferenceFunctions
    {
        /// <summary>
        /// Basic prediction function for Excel.
        /// </summary>
        /// <remarks>
        /// The network and a relevance-tree inference engine are created once per call and reused for every
        /// record in <paramref name="data"/>. For each data row, evidence for every input column is either set
        /// from the cell or cleared (blank cell); the engine is then queried for the target's distribution.
        /// Errors while loading the network or resolving the header / target variable names are not caught and
        /// propagate to the caller; errors while processing an individual row are reported in that row only.
        /// </remarks>
        /// <param name="networkPath">The path of the Bayes Server network used to make the predictions (.bayes file).</param>
        /// <param name="data">One column for each variable, with the name of each variable in the header (first row).  Blank cells can be used for missing data.</param>
        /// <param name="targetVariable">The variable to be predicted.</param>
        /// <returns>
        /// A two-column array with the same number of rows as <paramref name="data"/>; row 0 is a header.
        /// For a discrete target, column 0 holds the name of the state with the highest value in the queried
        /// table and column 1 holds that value. For a continuous target, column 0 holds the mean and column 1
        /// the variance of the queried <c>CLGaussian</c>. A row that fails to evaluate contains
        /// <c>ExcelError.ExcelErrorValue</c> (#VALUE!) in both columns.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="data"/> is null.</exception>
        /// <exception cref="ArgumentException">
        /// <paramref name="networkPath"/> or <paramref name="targetVariable"/> is null or empty.
        /// </exception>
        [ExcelFunction(Description = "PredictVariable")]
        public static object[,] PredictVariable(
            [ExcelArgument(Description="The path to a Bayes Server network (.bayes file).")]
            string networkPath,
            [ExcelArgument(Description="Data arranged in rows and columns.  Each column is a variable, each row is a record to make predictions from.  A header row is used to specify variable names.")]
            object[,] data,
            [ExcelArgument(Description="The variable being predicted.")]
            string targetVariable)
        {
            if (data == null)
            {
                throw new ArgumentNullException("data");
            }

            if (string.IsNullOrEmpty(networkPath))
            {
                throw new ArgumentException("networkPath is null or empty", "networkPath");
            }

            if (string.IsNullOrEmpty(targetVariable))
            {
                throw new ArgumentException("targetVariable is null or empty", "targetVariable");
            }

            var rowsWithHeader = data.GetLength(0);
            var cols = data.GetLength(1);

            // We load the network and create an inference engine once.
            // These are then reused for all records in 'data'.

            var network = new Network();
            network.Load(networkPath);

            var factory = new RelevanceTreeInferenceFactory();
            var inference = factory.CreateInferenceEngine(network);
            var queryOptions = factory.CreateQueryOptions();
            var queryOutput = factory.CreateQueryOutput();

            var target = network.Variables[targetVariable, true];

            // Discrete targets are queried as a probability table, continuous targets as a CLGaussian.
            var queryTarget = target.ValueType == VariableValueType.Discrete ? (IDistribution)new Table(target) : new CLGaussian(target);
            inference.QueryDistributions.Add(queryTarget);

            // Resolve the header row to network variables once, up front.
            var inputs = new Variable[cols];

            for (int i = 0; i < cols; i++)
            {
                inputs[i] = network.Variables[data[0, i].ToString(), true];
            }

            // data could contain:
            // Double
            // String
            // Boolean
            // ExcelDna.Integration.ExcelError
            // ExcelDna.Integration.ExcelMissing
            // ExcelDna.Integration.ExcelEmpty

            var result = new object[rowsWithHeader, 2];

            result[0, 0] = string.Format("Predict({0})", targetVariable);
            result[0, 1] = string.Format("{0}({1})",
                target.ValueType == VariableValueType.Discrete ? "PredictProbability" : "PredictVariance",
                targetVariable);

            for (int r = 1; r < rowsWithHeader; r++)
            {
                try
                {
                    // Set (or clear) the evidence for every input column. Every column is touched on every
                    // row, so nothing from a previous row can survive into this one.

                    for (int i = 0; i < cols; i++)
                    {
                        var input = inputs[i];
                        var value = data[r, i];

                        // Only variables whose StateValueType is None are handled by this example.
                        if (input.StateValueType != StateValueType.None)
                        {
                            throw new NotSupportedException("Only variables with StateValueType.None are supported.");
                        }

                        if (value is double)
                        {
                            // Numeric cell: set the value of the input variable.
                            inference.Evidence.Set(input, (double)value);
                        }
                        else if (value is string)
                        {
                            // Text cell: interpreted as the name of one of the input's states.
                            var state = input.States[(string)value, true];
                            inference.Evidence.SetState(state);
                        }
                        else if (value is ExcelMissing || value is ExcelEmpty)
                        {
                            // Blank cell = missing data. Clear THIS input's evidence explicitly: the evidence
                            // object is reused across rows, so a value set on an earlier row would otherwise
                            // leak into this one.
                            inference.Evidence.Clear(input);
                        }
                        else
                        {
                            // Booleans, Excel error values and any other cell types are not supported.
                            throw new NotSupportedException("Unsupported cell value type.");
                        }
                    }

                    // perform the query

                    inference.Query(queryOptions, queryOutput);

                    if (target.ValueType == VariableValueType.Discrete)
                    {
                        int stateIndex;
                        result[r, 1] = queryTarget.Table.GetMaxValue(out stateIndex);
                        result[r, 0] = target.States[stateIndex].Name;
                    }
                    else
                    {
                        var gaussian = (CLGaussian)queryTarget;
                        result[r, 0] = gaussian.GetMean(target);
                        result[r, 1] = gaussian.GetVariance(target);
                    }
                }
                catch (Exception)
                {
                    // Best effort per record: a bad row (unknown state name, unsupported cell type,
                    // inference failure) is reported as #VALUE! in that row only, rather than failing
                    // the whole call.
                    result[r, 0] = ExcelError.ExcelErrorValue;
                    result[r, 1] = ExcelError.ExcelErrorValue;
                }
            }

            return result;
        }

    }
}