Cross validation example (C#)

// --------------------------------------------------------------------------------------------------------------------
// <copyright file="Node.cs" company="Bayes Server">
// Copyright (C) Bayes Server. All rights reserved.
// </copyright>
// --------------------------------------------------------------------------------------------------------------------

namespace BayesServer.HelpSamples
{
using BayesServer.Analysis;
using BayesServer.Data;
using BayesServer.Inference;
using BayesServer.Inference.RelevanceTree;
using BayesServer.Learning.Parameters;
using System;
using System.Collections.Generic;
using System.Data;
using System.Linq;
using System.Text;

public static class CrossValidationExample
{
    /// <summary>
    /// Entry point: builds a small network and cross validates its log-likelihood.
    /// </summary>
    public static void Main()
    {
        var network = LoadNetwork();
        var factory = new RelevanceTreeInferenceFactory();

        var result = Score(
            network,
            3,
            factory,
            (networkCopy, dataPartitioning) => NewEvidenceReaderCommand(networkCopy, dataPartitioning),
            (networkCopy, evidenceReaderCommand) => Learn(networkCopy, evidenceReaderCommand, factory));

        Console.WriteLine("Cross validated Log-likelihood: " + result.Score);
    }

    /// <summary>
    /// Learns the parameters of a Bayesian network from the supplied evidence.
    /// </summary>
    /// <param name="network">The network whose parameters are learned (modified in place).</param>
    /// <param name="evidenceReaderCommand">The source of training evidence.</param>
    /// <param name="inferenceFactory">A factory which can create inference engines.</param>
    private static void Learn(Network network, IEvidenceReaderCommand evidenceReaderCommand, IInferenceFactory inferenceFactory)
    {
        var options = new ParameterLearningOptions();
        var parameterLearning = new ParameterLearning(network, inferenceFactory);

        parameterLearning.Learn(evidenceReaderCommand, options);
    }

    /// <summary>
    /// Uses Cross Validation to calculate the log-likelihood over the entire data set.
    /// </summary>
    /// <param name="network">The Bayesian network.</param>
    /// <param name="partitions">The number of cross validation partitions to use.</param>
    /// <param name="inferenceFactory">A factory which can create inference engines.</param>
    /// <param name="createEvidenceReaderCommand">A method which creates an evidence reader command.</param>
    /// <param name="learn">A method which learns the parameters of a Bayesian network.</param>
    /// <returns>The combined log-likelihood score.</returns>
    private static ICrossValidationScore Score(
        Network network,
        int partitions,
        IInferenceFactory inferenceFactory,
        Func<Network, DataPartitioning, IEvidenceReaderCommand> createEvidenceReaderCommand,
        Action<Network, IEvidenceReaderCommand> learn)
    {
        // Note: in this example we are only calculating a single metric (log-likelihood)
        // but you can calculate multiple metrics also.

        var metricScores = CrossValidation.kFold(
            partitionCount: partitions,
            testMetricCount: 1, // log-likelihood
            learn: (trainingPartitioning) =>
            {
                // Train on a copy so each fold starts from the original network.
                var networkCopy = network.Copy();
                var evidenceReaderCommand = createEvidenceReaderCommand(networkCopy, trainingPartitioning);
                learn(networkCopy, evidenceReaderCommand);
                return new CrossValidationNetwork(networkCopy);
            },
            test: (testPartitioning, crossValidationNetwork) =>
            {
                var inference = inferenceFactory.CreateInferenceEngine(crossValidationNetwork.Network);
                var queryOptions = inferenceFactory.CreateQueryOptions();
                queryOptions.LogLikelihood = true; // request log-likelihood from each query
                var queryOutput = inferenceFactory.CreateQueryOutput();

                var sumLogLikelihood = 0.0;
                var weightedCaseCount = 0.0;

                var evidenceReaderCommand = createEvidenceReaderCommand(crossValidationNetwork.Network, testPartitioning);

                // 'using' guarantees the reader is disposed even if a query throws.
                using (var evidenceReader = evidenceReaderCommand.ExecuteReader())
                {
                    var readOptions = new ReadOptions();

                    while (evidenceReader.Read(inference.Evidence, readOptions))
                    {
                        inference.Query(queryOptions, queryOutput);
                        weightedCaseCount += inference.Evidence.Weight;
                        sumLogLikelihood += queryOutput.LogLikelihood.Value;
                    }
                }

                var testResults = new ICrossValidationTestResult[1];
                testResults[0] = new CrossValidationTestResult(weightedCaseCount, sumLogLikelihood, sumLogLikelihood);
                return testResults;
            },
            combine: (metric, testResults) =>
            {
                // IMPORTANT: CombineMethod will depend on the metric.
                // In this example we are using log-likelihoods which already incorporate the weight of each case,
                // so we can simply use UnweightedSum. However for most metrics (e.g. MAE) WeightedSum is
                // appropriate. For R Squared, you can use the RSquared combine method.

                var combineMethod = CrossValidationCombineMethod.UnweightedSum;

                double score = CrossValidation.Combine(
                    testResults,
                    combineMethod);

                return new CrossValidationScore(score);
            });

        // In this example we are only calculating log-likelihood, so
        // we should have a single metric returned.
        if (metricScores.Length != 1)
            throw new InvalidOperationException();

        return metricScores[0];
    }

    /// <summary>
    /// Creates a trivial network to use in the example.
    /// </summary>
    /// <remarks>
    /// Instead of creating a network manually, you could simply use network.Load instead.</remarks>
    /// <returns>A simple Bayesian network.</returns>
    private static Network LoadNetwork()
    {
        var network = new Network();

        // network.Load("...");

        var nodeA = new Node("A", new string[] { "False", "True" });
        network.Nodes.Add(nodeA);
        var nodeB = new Node("B", new string[] { "False", "True" });
        network.Nodes.Add(nodeB);

        network.Links.Add(new Link(nodeA, nodeB));

        return network;
    }

    /// <summary>
    /// Creates an evidence reader command which maps data columns onto network variables.
    /// </summary>
    /// <param name="network">The network whose variables are bound to data columns.</param>
    /// <param name="dataPartitioning">The partitioning that selects which rows are read.</param>
    /// <returns>A command that can read evidence multiple times.</returns>
    private static IEvidenceReaderCommand NewEvidenceReaderCommand(Network network, DataPartitioning dataPartitioning)
    {
        var dataReaderCommand = NewDataReaderCommand(dataPartitioning);

        // all of our variables in this example are discrete and text based so map by name
        var variableReferences = network.Variables.Select(v => new VariableReference(v, ColumnValueType.Name, v.Name)).ToArray();

        return new EvidenceReaderCommand(
            dataReaderCommand,
            variableReferences,
            new ReaderOptions());
    }

    /// <summary>
    /// Loads data for a given data partitioning.
    /// </summary>
    /// <param name="dataPartitioning">Determines which partitions are included or excluded.</param>
    /// <returns>A command that can read data multiple times.</returns>
    private static IDataReaderCommand NewDataReaderCommand(DataPartitioning dataPartitioning)
    {
        // IMPORTANT: Here we are manually mocking some data to keep the example self contained,
        // but normally you would construct a Sql statement based on the data-partitioning and use a DatabaseDataReaderCommand,
        // or in R, Python, or Spark you could filter a data frame based on the partitioning.

        var table = new DataTable();
        table.Columns.Add("A", typeof(string));
        table.Columns.Add("B", typeof(string));

        var method = dataPartitioning.Method;
        var partition = dataPartitioning.PartitionNumber;

        if (IncludeData(0, partition, method))
        {
            table.Rows.Add("False", "True");
            table.Rows.Add("True", "True");
            table.Rows.Add("True", "True");
            table.Rows.Add("False", "False");
            table.Rows.Add("False", "True");
            table.Rows.Add("True", "False");
            table.Rows.Add("True", "True");
        }

        if (IncludeData(1, partition, method))
        {
            table.Rows.Add("True", "False");
            table.Rows.Add("True", "False");
            table.Rows.Add("False", "False");
            table.Rows.Add("True", "False");
            table.Rows.Add("False", "True");
            table.Rows.Add("False", "False");
        }

        if (IncludeData(2, partition, method))
        {
            table.Rows.Add("True", "True");
            table.Rows.Add("True", "True");
            table.Rows.Add("True", "True");
            table.Rows.Add("False", "True");
            table.Rows.Add("True", "True");
            table.Rows.Add("False", "False");
        }

        Console.WriteLine("Method = {0}, partition = {1}, count = {2}", method, partition, table.Rows.Count);

        return new DataTableDataReaderCommand(table);
    }

    /// <summary>
    /// Determines whether a source partition's rows belong in the current fold's data set.
    /// </summary>
    /// <param name="sourcePartition">The partition the rows were assigned to.</param>
    /// <param name="currentPartition">The partition currently being trained or tested.</param>
    /// <param name="method">Whether the current partition's data is included (test) or excluded (train).</param>
    /// <returns>True if the rows should be included; otherwise false.</returns>
    private static bool IncludeData(int sourcePartition, int currentPartition, DataPartitionMethod method)
    {
        switch (method)
        {
            case DataPartitionMethod.IncludePartitionData:
                return sourcePartition == currentPartition;
            case DataPartitionMethod.ExcludePartitionData:
                return sourcePartition != currentPartition;
            default:
                throw new InvalidOperationException();
        }
    }
}
}