Cross validation example (C#)

// --------------------------------------------------------------------------------------------------------------------
// <copyright file="Node.cs" company="Bayes Server">
// Copyright (C) Bayes Server. All rights reserved.
// </copyright>
// --------------------------------------------------------------------------------------------------------------------

namespace BayesServer.HelpSamples
{
using BayesServer.Analysis;
using BayesServer.Data;
using BayesServer.Inference;
using BayesServer.Inference.RelevanceTree;
using BayesServer.Learning.Parameters;
using System;
using System.Collections.Generic;
using System.Data;
using System.Linq;
using System.Text;

public static class CrossValidationExample
{
    /// <summary>
    /// Entry point: builds a small network and cross validates its log-likelihood.
    /// </summary>
    public static void Main()
    {
        var network = LoadNetwork();
        var factory = new RelevanceTreeInferenceFactory();

        var result = Score(
            network,
            3,
            factory,
            (networkCopy, dataPartitioning) => NewEvidenceReaderCommand(networkCopy, dataPartitioning),
            (networkCopy, evidenceReaderCommand) => Learn(networkCopy, evidenceReaderCommand, factory));

        Console.WriteLine("Cross validated Log-likelihood: " + result.Score);
    }

    /// <summary>
    /// Learns the parameters of a Bayesian network from the supplied evidence.
    /// </summary>
    /// <param name="network">The network whose parameters are learned (modified in place).</param>
    /// <param name="evidenceReaderCommand">The source of training evidence.</param>
    /// <param name="inferenceFactory">A factory which can create inference engines.</param>
    private static void Learn(Network network, IEvidenceReaderCommand evidenceReaderCommand, IInferenceFactory inferenceFactory)
    {
        var options = new ParameterLearningOptions();
        var parameterLearning = new ParameterLearning(network, inferenceFactory);

        parameterLearning.Learn(evidenceReaderCommand, options);
    }

    /// <summary>
    /// Uses Cross Validation to calculate the log-likelihood over the entire data set.
    /// </summary>
    /// <param name="network">The Bayesian network.</param>
    /// <param name="partitions">The number of cross validation partitions to use.</param>
    /// <param name="inferenceFactory">A factory which can create inference engines.</param>
    /// <param name="createEvidenceReaderCommand">A method which creates an evidence reader command.</param>
    /// <param name="learn">A method which learns the parameters of a Bayesian network.</param>
    /// <returns>The combined log-likelihood score.</returns>
    private static ICrossValidationScore Score(
        Network network,
        int partitions,
        IInferenceFactory inferenceFactory,
        Func<Network, DataPartitioning, IEvidenceReaderCommand> createEvidenceReaderCommand,
        Action<Network, IEvidenceReaderCommand> learn)
    {
        // Note: in this example we are only calculating a single metric (log-likelihood)
        // but you can calculate multiple metrics also.

        var metricScores = CrossValidation.kFold(
            partitionCount: partitions,
            testMetricCount: 1, // log-likelihood
            learn: (trainingPartitioning) =>
            {
                // Train on a copy so each fold starts from the original network.
                var networkCopy = network.Copy();
                var evidenceReaderCommand = createEvidenceReaderCommand(networkCopy, trainingPartitioning);
                learn(networkCopy, evidenceReaderCommand);
                return new CrossValidationNetwork(networkCopy);
            },
            test: (testPartitioning, crossValidationNetwork) =>
            {
                var inference = inferenceFactory.CreateInferenceEngine(crossValidationNetwork.Network);
                var queryOptions = inferenceFactory.CreateQueryOptions();
                queryOptions.LogLikelihood = true; // request log-likelihood from each query
                var queryOutput = inferenceFactory.CreateQueryOutput();

                var sumLogLikelihood = 0.0;
                var weightedCaseCount = 0.0;

                var evidenceReaderCommand = createEvidenceReaderCommand(crossValidationNetwork.Network, testPartitioning);

                // 'using' guarantees the reader is disposed even if a query throws.
                using (var evidenceReader = evidenceReaderCommand.ExecuteReader())
                {
                    var readOptions = new ReadOptions();

                    while (evidenceReader.Read(inference.Evidence, readOptions))
                    {
                        inference.Query(queryOptions, queryOutput);
                        weightedCaseCount += inference.Evidence.Weight;
                        sumLogLikelihood += queryOutput.LogLikelihood.Value;
                    }
                }

                var testResults = new ICrossValidationTestResult[1];
                testResults[0] = new CrossValidationTestResult(weightedCaseCount, sumLogLikelihood, sumLogLikelihood);
                return testResults;
            },
            combine: (metric, testResults) =>
            {
                // IMPORTANT: CombineMethod will depend on the metric.
                // In this example we are using log-likelihoods which already incorporate the weight of each case,
                // so we can simply use UnweightedSum. However for most metrics (e.g. MAE) WeightedSum is
                // appropriate. For R Squared, you can use the RSquared combine method.

                var combineMethod = CrossValidationCombineMethod.UnweightedSum;

                double score = CrossValidation.Combine(
                    testResults,
                    combineMethod);

                return new CrossValidationScore(score);
            });

        // In this example we are only calculating log-likelihood, so
        // we should have a single metric returned.
        if (metricScores.Length != 1)
            throw new InvalidOperationException();

        return metricScores[0];
    }

    /// <summary>
    /// Creates a trivial network to use in the example.
    /// </summary>
    /// <remarks>
    /// Instead of creating a network manually, you could simply use network.Load instead.</remarks>
    /// <returns>A simple Bayesian network.</returns>
    private static Network LoadNetwork()
    {
        var network = new Network();

        // network.Load("...");

        var nodeA = new Node("A", new string[] { "False", "True" });
        network.Nodes.Add(nodeA);
        var nodeB = new Node("B", new string[] { "False", "True" });
        network.Nodes.Add(nodeB);

        network.Links.Add(new Link(nodeA, nodeB));

        return network;
    }

    /// <summary>
    /// Creates an evidence reader command which maps data columns onto network variables.
    /// </summary>
    /// <param name="network">The network whose variables are bound to data columns.</param>
    /// <param name="dataPartitioning">The partitioning that selects which rows are read.</param>
    /// <returns>A command that can read evidence multiple times.</returns>
    private static IEvidenceReaderCommand NewEvidenceReaderCommand(Network network, DataPartitioning dataPartitioning)
    {
        var dataReaderCommand = NewDataReaderCommand(dataPartitioning);

        // all of our variables in this example are discrete and text based so map by name
        var variableReferences = network.Variables.Select(v => new VariableReference(v, ColumnValueType.Name, v.Name)).ToArray();

        return new EvidenceReaderCommand(
            dataReaderCommand,
            variableReferences,
            new ReaderOptions());
    }

    /// <summary>
    /// Loads data for a given data partitioning.
    /// </summary>
    /// <param name="dataPartitioning">Determines which partitions are included or excluded.</param>
    /// <returns>A command that can read data multiple times.</returns>
    private static IDataReaderCommand NewDataReaderCommand(DataPartitioning dataPartitioning)
    {
        // IMPORTANT: Here we are manually mocking some data to keep the example self contained,
        // but normally you would construct a Sql statement based on the data-partitioning and use a DatabaseDataReaderCommand,
        // or in R, Python, or Spark you could filter a data frame based on the partitioning.

        var table = new DataTable();
        table.Columns.Add("A", typeof(string));
        table.Columns.Add("B", typeof(string));

        var method = dataPartitioning.Method;
        var partition = dataPartitioning.PartitionNumber;

        if (IncludeData(0, partition, method))
        {
            table.Rows.Add("False", "True");
            table.Rows.Add("True", "True");
            table.Rows.Add("True", "True");
            table.Rows.Add("False", "False");
            table.Rows.Add("False", "True");
            table.Rows.Add("True", "False");
            table.Rows.Add("True", "True");
        }

        if (IncludeData(1, partition, method))
        {
            table.Rows.Add("True", "False");
            table.Rows.Add("True", "False");
            table.Rows.Add("False", "False");
            table.Rows.Add("True", "False");
            table.Rows.Add("False", "True");
            table.Rows.Add("False", "False");
        }

        if (IncludeData(2, partition, method))
        {
            table.Rows.Add("True", "True");
            table.Rows.Add("True", "True");
            table.Rows.Add("True", "True");
            table.Rows.Add("False", "True");
            table.Rows.Add("True", "True");
            table.Rows.Add("False", "False");
        }

        Console.WriteLine("Method = {0}, partition = {1}, count = {2}", method, partition, table.Rows.Count);

        return new DataTableDataReaderCommand(table);
    }

    /// <summary>
    /// Determines whether a source partition's rows belong in the current fold's data set.
    /// </summary>
    /// <param name="sourcePartition">The partition the rows were assigned to.</param>
    /// <param name="currentPartition">The partition currently being trained or tested.</param>
    /// <param name="method">Whether the current partition's data is included (test) or excluded (train).</param>
    /// <returns>True if the rows should be included; otherwise false.</returns>
    private static bool IncludeData(int sourcePartition, int currentPartition, DataPartitionMethod method)
    {
        switch (method)
        {
            case DataPartitionMethod.IncludePartitionData:
                return sourcePartition == currentPartition;
            case DataPartitionMethod.ExcludePartitionData:
                return sourcePartition != currentPartition;
            default:
                throw new InvalidOperationException();
        }
    }
}
}