Skip to main content

Mixture model (learning) with Apache Spark

// --------------------------------------------------------------------------------------------------------------------
// <copyright file="MixtureModel.scala" company="Bayes Server">
// Copyright (C) Bayes Server. All rights reserved.
// </copyright>
// --------------------------------------------------------------------------------------------------------------------

package com.bayesserver.spark.examples.parameterlearning

import org.apache.spark.{SparkContext}
import com.bayesserver._
import com.bayesserver.learning.parameters.{InitializationMethod, ParameterLearningOptions, ParameterLearning}
import com.bayesserver.inference.Evidence
import com.bayesserver.spark.core.{MemoryNameValues, BroadcastNameValues, BayesSparkDistributer, IteratorEvidenceReader}

/**
  * Example that learns the parameters of a mixture model, also known as a probabilistic cluster model,
  * using Bayes Server distributed parameter learning on Apache Spark.
  */
object MixtureModel {

  // Names of the two continuous variables. Shared by createNetwork (which creates them) and
  // MixtureModelEvidenceReader (which looks them up), so the two cannot drift apart.
  private final val XName = "X"
  private final val YName = "Y"

  /**
    * Learns the parameters of the mixture network from a small hard-coded data set and prints
    * a summary of the learning output.
    *
    * @param sc         The SparkContext (see Apache Spark user guide for help on this)
    * @param licenseKey Optional Bayes Server license key; when present it is validated before learning,
    *                   and it is also passed to the distributer.
    */
  def apply(sc: SparkContext, licenseKey: Option[String] = None): Unit = {

    licenseKey.foreach(key => License.validate(key))

    // Hard-coded test data. Normally you would read data from your cluster.
    // The RDD is cached because learning is iterative (see getIterationCount below) and presumably
    // revisits the data on each iteration.
    val data = createRDD(sc).cache()

    try {
      // A network could be loaded from a file or stream;
      // we create it manually here to keep the example self contained.
      val network = createNetwork

      val parameterLearningOptions = new ParameterLearningOptions

      // Bayes Server supports multi-threaded learning,
      // which we want to turn off as Spark takes care of this.
      parameterLearningOptions.setMaximumConcurrency(1)
      parameterLearningOptions.getInitialization.setMethod(InitializationMethod.CLUSTERING)
      // parameterLearningOptions.setMaximumIterations(...) // this can be useful to limit the number of iterations

      val driverToWorker = new BroadcastNameValues(sc)

      val output = ParameterLearning.learnDistributed(network, parameterLearningOptions,
        new BayesSparkDistributer[(Double, Double)](
          data,
          driverToWorker,
          (ctx, iterator) => new MixtureModelEvidenceReader(ctx.getNetwork, iterator),
          licenseKey
        ))

      // We could now call network.save(...) to a file or stream,
      // and the file could be opened in the Bayes Server User Interface.

      println("Mixture model parameter learning complete")
      println(s"Case count = ${output.getCaseCount}")
      println(s"Log-likelihood = ${output.getLogLikelihood}")
      println(s"Converged = ${output.getConverged}")
      println(s"Iterations = ${output.getIterationCount}")
    } finally {
      // Release the cached partitions even if learning fails, so repeated runs on a
      // long-lived SparkContext do not accumulate cached blocks.
      data.unpersist()
    }
  }

  /**
    * Some test data. Normally you would load the data from the cluster.
    *
    * We have hard coded it here to keep the example self contained.
    *
    * @param sc The SparkContext used to parallelize the data.
    * @return An RDD of (X, Y) pairs, one pair per case.
    */
  def createRDD(sc: SparkContext): org.apache.spark.rdd.RDD[(Double, Double)] = {

    sc.parallelize(Seq(
      (0.176502224, 7.640580199),
      (1.308020831, 8.321963251),
      (7.841271129, 3.34044587),
      (2.623799516, 6.667664279),
      (8.617288623, 3.319091539),
      (0.292639161, 9.070469416),
      (1.717525934, 6.509707265),
      (0.347388367, 9.144193334),
      (4.332228381, 0.129103276),
      (0.550570479, 9.925610034),
      (10.18819907, 3.414009144),
      (9.796154937, 4.335498562),
      (4.492011746, 0.527572356),
      (8.793496377, 3.811848391),
      (0.479689038, 8.041976487),
      (0.460045193, 10.74481444),
      (3.249955813, 5.58667984),
      (1.677468832, 8.742639202),
      (2.567398263, 3.338528008),
      (8.507535409, 3.358378353),
      (8.863647208, 3.533757566),
      (-0.612339597, 11.27289689),
      (10.38075113, 3.657256133),
      (9.443691262, 3.561824026),
      (1.589644185, 7.936062309),
      (7.680055137, 2.541577306),
      (1.047477704, 6.382052946),
      (0.735659679, 8.029083014),
      (0.489446685, 11.40715477),
      (3.258072314, 1.451124598),
      (0.140278917, 7.78885888),
      (9.237538442, 2.647543473),
      (2.28453948, 5.836716478),
      (7.22011534, 1.51979264),
      (1.474811913, 1.942052919),
      (1.674889251, 5.601765101),
      (1.30742068, 6.137114076),
      (6.957133145, 3.957540541),
      (10.87472856, 5.598949484),
      (1.110499364, 9.241584372),
      (7.233905739, 2.322237847),
      (7.474329505, 2.920099189),
      (0.455631413, 7.356350266),
      (1.234318558, 6.592203772),
      (10.72837103, 5.371838788),
      (0.655168407, 6.713544957),
      (2.001307579, 5.30283356),
      (0.061834893, 2.071499561),
      (1.86460938, 6.013710897)
    ))
  }

  /**
    * Create a network in code. An existing network could also be read from file or stream using Network.load.
    *
    * The network has a discrete node "Mixture" with 2 states, linked to a node "Gaussian"
    * holding two continuous variables, X and Y.
    *
    * @return A Bayes Server network.
    */
  def createNetwork: Network = {

    val network = new Network

    val mixture = new Node("Mixture", 2)
    network.getNodes.add(mixture)

    val gaussian = new Node()
    gaussian.setName("Gaussian")
    val x = new Variable(XName, VariableValueType.CONTINUOUS)
    gaussian.getVariables.add(x)
    val y = new Variable(YName, VariableValueType.CONTINUOUS)
    gaussian.getVariables.add(y)
    network.getNodes.add(gaussian)

    network.getLinks.add(new Link(mixture, gaussian))

    network
  }

  /**
    * Implements the Bayes Server EvidenceReader interface, for reading our data.
    *
    * @param network  The network; must contain the continuous variables "X" and "Y".
    * @param iterator The iterator, which will be generated by RDD.mapPartitions.
    */
  class MixtureModelEvidenceReader(val network: Network, val iterator: Iterator[(Double, Double)])
    extends IteratorEvidenceReader[(Double, Double)] {

    // Resolved once per reader (i.e. once per partition iterator), not once per row.
    val x: Variable = lookupVariable(XName)
    val y: Variable = lookupVariable(YName)

    /** Sets the X and Y evidence for one case from the (X, Y) tuple. */
    override def setEvidence(item: (Double, Double), evidence: Evidence): Unit = {

      evidence.set(x, item._1)
      evidence.set(y, item._2)
    }

    // Fails fast with a descriptive message if the lookup yields null, rather than
    // surfacing as a NullPointerException later, inside distributed learning.
    private def lookupVariable(name: String): Variable = {
      val variable = network.getVariables.get(name)
      require(variable != null, s"Variable '$name' not found in the network.")
      variable
    }
  }
}