% Parameter learning with a Bayesian network in MATLAB


% Parameter learning example
%
% Copyright (C) Bayes Server. All rights reserved.


% Uncomment to unload previously loaded Java classes, e.g. after switching
% to a different Bayes Server version in the same MATLAB session.
% clear java;

% TODO change to the version you are using
bayesVersion = '7.24';

% TODO update to the path of the Bayes Server jar
jarPath = ['/Users/xxx/Downloads/bayesserver-', bayesVersion, '/Java/BayesServer-', bayesVersion, '.jar'];

% Fail fast with a clear message when the jar is missing. Otherwise the first
% Bayes Server call further down fails with an unresolved-class error that
% does not mention the jar path at all.
if exist(jarPath, 'file') ~= 2
    error('BayesServer:jarNotFound', ...
        'Bayes Server jar not found: %s (update bayesVersion/jarPath above)', jarPath);
end

% Only extend the dynamic Java class path when the jar is not already on it,
% so re-running the script does not modify the class path again.
if ~any(strcmp(javaclasspath('-dynamic'), jarPath))
    javaaddpath(jarPath);
end

import com.bayesserver.*;
import com.bayesserver.data.*;
import com.bayesserver.inference.*;
import com.bayesserver.learning.parameters.*;

%
% Example code that learns the parameters of a Bayesian network from data.
%

% Main flow: build the network structure, load the training data, learn the
% network's parameters with RelevanceTree inference, and print the resulting
% log likelihood.
network = createNetworkStructure(); % we manually construct the network here, but it could be loaded from a file
% Look up the two continuous variables created by createNetworkStructure.
% NOTE(review): the second argument `true` presumably makes the lookup throw
% when the name is missing - confirm against the Bayes Server API docs.
x = network.getVariables().get('X', true);
y = network.getVariables().get('Y', true);

% Now learn the parameters from the data in Tutorial 2 - Mixture model.

T = readtable('mixture_model.csv'); % Note: Saved as csv from the Tutorial data installed with Bayes Server

% Note that the data does not have to be loaded from CSV. Any MATLAB table can
% be used here, or you can connect to a database instead.

% toDataTable is a helper that is not defined in this file; presumably it
% converts the MATLAB table into a Bayes Server DataTable - TODO confirm.
dt = toDataTable(T);

% We use the RelevanceTree algorithm here, as it is optimized for parameter learning.
learning = ParameterLearning(network, RelevanceTreeInferenceFactory());
learningOptions = ParameterLearningOptions();

dataReaderCommand = DataTableDataReaderCommand(dt); % In memory, but you could also connect to a database with DatabaseDataReaderCommand

readerOptions = ReaderOptions(); % we do not have a case column in this example

% Map each network variable to a data column.
% In this case the variables and data columns have the same names, so the
% variable name is reused as the column name.
refX = VariableReference(x, ColumnValueType.VALUE, x.getName());
refY = VariableReference(y, ColumnValueType.VALUE, y.getName());

% Concatenating the two Java objects yields a Java array of VariableReference;
% toJavaList (a helper not defined in this file) presumably wraps it as a
% java.util.List for the reader command below - TODO confirm.
variableReferences = [refX, refY];

% Note that although this example only has non-temporal data,
% we could have included additional temporal variables and data.

evidenceReaderCommand = DefaultEvidenceReaderCommand(dataReaderCommand, toJavaList(variableReferences), readerOptions);

% Run parameter learning over the evidence read from the data table.
result = learning.learn(evidenceReaderCommand, learningOptions);

% double() converts the returned value to a MATLAB double before formatting.
disp(['Log likelihood = ', num2str(double(result.getLogLikelihood()))]);


function [network] = createNetworkStructure()
% createNetworkStructure  Build the (parameter-free) mixture model structure.
%   network = createNetworkStructure() returns a new Bayes Server Network
%   containing:
%     * a discrete node 'Cluster' with states Cluster1, Cluster2, Cluster3
%     * a node 'Position' holding the continuous variables X and Y
%     * a single link Cluster -> Position
%   Only the structure is defined here; the parameters (distributions) are
%   left to be learned from data by the caller.

import com.bayesserver.*;

network = Network();

% Discrete parent: one node with three cluster states.
clusterNode = Node('Cluster', toJavaArray([ ...
    State('Cluster1'), ...
    State('Cluster2'), ...
    State('Cluster3')]));

% Continuous child: a single node carrying both position variables.
positionNode = Node('Position', toJavaArray([ ...
    Variable('X', VariableValueType.CONTINUOUS), ...
    Variable('Y', VariableValueType.CONTINUOUS)]));

% Register the nodes (Cluster first, then Position) and connect them.
networkNodes = network.getNodes();
networkNodes.add(clusterNode);
networkNodes.add(positionNode);

network.getLinks().add(Link(clusterNode, positionNode));

end