Common code for using Bayes Server and Apache Spark

// --------------------------------------------------------------------------------------------------------------------
// <copyright file="BayesSparkDistributer.scala" company="Bayes Server">
// Copyright (C) Bayes Server. All rights reserved.
// </copyright>
// <version>0.13</version>
// <dependencies>
// <dependency>
// <name>bayes-server</name>
// <version>7.x</version>
// </dependency>
// <dependency>
// <name>spark-core</name>
// <version>1.x.x</version>
// </dependency>
// <dependency>
// <name>scala</name>
// <version>2.10.4</version>
// </dependency>
// </dependencies>
// --------------------------------------------------------------------------------------------------------------------

package com.bayesserver.spark.core


import java.util.zip.{GZIPInputStream, GZIPOutputStream}

import com.bayesserver._
import com.bayesserver.inference._
import com.bayesserver.data.DefaultReadOptions
import org.apache.spark.SparkContext
import org.apache.spark.broadcast.Broadcast
import org.slf4j.LoggerFactory
import scala.util.Try

import com.bayesserver.learning.parameters._
import com.bayesserver.Network
import com.bayesserver.Table
import com.bayesserver.VariableValueType
import org.apache.spark.rdd.RDD
import com.bayesserver.data._
import com.bayesserver.data.distributed._ // Remove this line if using Bayes Server 6.x
import com.bayesserver.{WriteStreamAction, NameValuesReader, NameValuesWriter, Distributer}
import scala.collection.JavaConversions
import java.io._
import scala.language.implicitConversions

/**
* Implements the Bayes Server Distributer interface for distributed parameter learning of Bayesian networks.
*
* To use, call com.bayesserver.learning.parameters.ParameterLearning.learnDistributed(...)
*
* @param data The Spark RDD containing the training data.
* @param driverToWorker The configuration name-value pairs passed from the driver to the mappers and reducers.
* @param newEvidenceReader A factory method for creating evidence readers when required.
* @param licenseKey An optional Bayes Server license key, validated on the worker nodes before learning.
* @param workerToDriver A factory for the name-value store used to pass results from the workers back to the driver.
*/
class BayesSparkDistributer[T]
(
val data: RDD[T],
val driverToWorker: NameValuesReader with NameValuesWriter,
val newEvidenceReader: (DistributedMapperContext, Iterator[T]) => EvidenceReader,
val licenseKey: Option[String] = None,
val workerToDriver: () => NameValuesReader with NameValuesWriter = () => new MemoryNameValues()

) extends Distributer[DistributerContext] with Serializable {

private val logger = LoggerFactory.getLogger(classOf[BayesSparkDistributer[T]])

/**
* @inheritdoc
*/
override def getConfiguration: NameValuesWriter = this.driverToWorker

/**
* @inheritdoc
*/
override def distribute(ctx: DistributerContext): NameValuesReader = {

logger.info("Distributer stage: " + ctx.getName)

// We use mapPartitions because the Bayes Server library pulls the data itself via its EvidenceReader interface.
// The Bayes Server calls are wrapped in the serializable Mapper and Reducer classes rather than written inline,
// because Spark requires the closures it sends to the worker nodes to be serializable.

data
.mapPartitions(iterator => {

// this code executes on the worker nodes, so a license will not yet have been validated
licenseKey.foreach(s => License.validate(s))

new Mapper(driverToWorker, workerToDriver).call(
iterator,
newEvidenceReader
)
})
.reduce((a, b) => {

// this code executes on the worker nodes, so a license will not yet have been validated
licenseKey.foreach(s => License.validate(s))

new Reducer(driverToWorker).call(a, b, workerToDriver()) // this calls the Bayes Server ParameterLearning.learnDistributedReducer method
})
}

/**
* Calls the mapper phase of Bayes Server distributed parameter learning.
* @param driverToWorker Configuration information which is required on all nodes.
* @param workerToDriver Creates a new name-value store, which Bayes Server uses to pass results from the workers back to the driver.
*/
class Mapper(val driverToWorker: NameValuesReader, val workerToDriver: () => NameValuesReader with NameValuesWriter) extends Serializable {

/**
* Performs the Bayes Server map operation on an iterator generated from the RDD.mapPartitions.
* @param iterator The iterator generated from RDD.mapPartitions.
* @param newEvidenceReader A factory for creating a new EvidenceReader.
* @return An iterator of outputs from the map operation.
*/
def call(
iterator: Iterator[T],
newEvidenceReader: (DistributedMapperContext, Iterator[T]) => EvidenceReader
): Iterator[NameValuesReader] = {

if (iterator.isEmpty)
return Iterator.empty // Spark RDD.mapPartitions can pass in an empty Iterator

val output = workerToDriver()

ParameterLearning.learnDistributedMapper(
new EvidencePartition[DistributedMapperContext] {
override def createEvidenceReader(ctx: DistributedMapperContext): EvidenceReader = {
newEvidenceReader(ctx, iterator)
}
},
driverToWorker,
output,
new RelevanceTreeInferenceFactory)

Iterator(output)
}
}

/**
* Calls the reducer phase of Bayes Server distributed parameter learning.
*
* @param configuration Configuration information which is required on all nodes.
*/
class Reducer(val configuration: NameValuesReader) extends Serializable {

def call(a: NameValuesReader, b: NameValuesReader, output: NameValuesReader with NameValuesWriter): NameValuesReader = {

val inputs = Iterable(a, b)

ParameterLearning.learnDistributedReducer(JavaConversions.asJavaIterable(inputs), configuration, output)

output
}
}

}
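
/*
 * Usage sketch (illustrative only). The exact ParameterLearning.learnDistributed overload,
 * and the way the network is obtained inside the evidence reader factory (shown here as a
 * hypothetical ctx.getNetwork call), should be checked against the Bayes Server documentation
 * for your version. ArrayEvidenceReader is the example reader defined later in this file.
 *
 *   val driverToWorker = new BroadcastNameValues(sc)
 *
 *   ParameterLearning.learnDistributed(
 *     network,
 *     new ParameterLearningOptions,
 *     new BayesSparkDistributer[Array[Double]](
 *       data,
 *       driverToWorker,
 *       (ctx, iterator) => new ArrayEvidenceReader(ctx.getNetwork, iterator, Seq("X", "Y")), // hypothetical names
 *       licenseKey = None))
 */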

/**
* Methods to compress and decompress values from a store such as MemoryNameValues or BroadcastNameValues.
*/
object CompressedNameValues {

def write(writeStreamAction: WriteStreamAction, output: OutputStream) = {

val zipped = new GZIPOutputStream(output)
writeStreamAction.write(zipped)
zipped.finish() // required to write final bytes
}

def read(input: InputStream): InputStream = {
new GZIPInputStream(input)
}
}

/**
* An adapter that can be used to add compression to an existing store.
* MemoryNameValues and BroadcastNameValues both have compression options already.
* @param wrapped The underlying store
*/
class CompressedNameValues(val wrapped: NameValuesReader with NameValuesWriter with Serializable) extends NameValuesReader with NameValuesWriter with Serializable {

/**
* @inheritdoc
*/
def write(name: String, writeStreamAction: WriteStreamAction) = {

this.wrapped.write(name, new WriteStreamAction {
override def write(output: OutputStream) = CompressedNameValues.write(writeStreamAction, output)
})
}

/**
* @inheritdoc
*/
def contains(name: String): Boolean = this.wrapped.contains(name)

/**
* @inheritdoc
*/
override def read(name: String): InputStream = CompressedNameValues.read(this.wrapped.read(name))

}
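
// Example only: wrap a store that does not compress its values with the compression adapter.
// MemoryNameValues and BroadcastNameValues can already compress via their constructor flag,
// so the adapter is mainly useful for other NameValuesReader/NameValuesWriter implementations.
//
//   val compressed = new CompressedNameValues(new MemoryNameValues(compress = false))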

/**
* A Spark broadcast variable based implementation of both NameValuesReader and NameValuesWriter.
* Note that values can only be written in driver code, not on the workers.
*/
class BroadcastNameValues(@transient val sc: SparkContext, val compress: Boolean = true) extends NameValuesReader with NameValuesWriter with Serializable {

val map = new scala.collection.mutable.HashMap[String, Broadcast[Array[Byte]]]

/**
* @inheritdoc
*/
def write(name: String, writeStreamAction: WriteStreamAction) {

if (sc == null)
throw new IllegalStateException("SparkContext is null. BroadcastNameValues.write is not supported on worker nodes")

val output: ByteArrayOutputStream = new ByteArrayOutputStream

compress match {
case true => CompressedNameValues.write(writeStreamAction, output)
case false => writeStreamAction.write(output)
}

this.map.put(name, sc.broadcast(output.toByteArray))
}

/**
* @inheritdoc
*/
def contains(name: String): Boolean = {
this.map.contains(name)

}

/**
* @inheritdoc
*/
override def read(name: String): InputStream = {

val value = this.map.get(name).get

compress match {
case true => CompressedNameValues.read(new ByteArrayInputStream(value.value))
case false => new ByteArrayInputStream(value.value)
}
}
}

/**
* An in-memory implementation of both NameValuesReader and NameValuesWriter.
*/
class MemoryNameValues(val compress: Boolean = true) extends NameValuesReader with NameValuesWriter with Serializable {

val map = new scala.collection.mutable.HashMap[String, Array[Byte]]

/**
* @inheritdoc
*/
def write(name: String, writeStreamAction: WriteStreamAction) {

val output: ByteArrayOutputStream = new ByteArrayOutputStream

compress match {
case true => CompressedNameValues.write(writeStreamAction, output)
case false => writeStreamAction.write(output)
}

this.map.put(name, output.toByteArray)
}

/**
* @inheritdoc
*/
def contains(name: String): Boolean = {
this.map.contains(name)

}

/**
* @inheritdoc
*/
override def read(name: String): InputStream = {

val value = this.map.get(name).get

compress match {
case true => CompressedNameValues.read(new ByteArrayInputStream(value))
case false => new ByteArrayInputStream(value)
}
}
}

/**
* Trait which can be used to help implement the Bayes Server EvidenceReader interface.
* @tparam T The type of data contained in the RDD. This could be a class, an Array[Double] or anything else which is convenient.
*/
trait IteratorEvidenceReader[T] extends EvidenceReader {

val iterator: Iterator[T]
require(iterator != null)

if (iterator.isEmpty)
throw new UnsupportedOperationException("Iterator is empty.")

/**
* Maps information from the current RDD element to Bayes Server variables using the Bayes Server evidence instance.
* @param item The RDD element.
* @param evidence The evidence to be updated.
*/
def setEvidence(item: T, evidence: Evidence)

/**
* @inheritdoc
*/
override def read(evidence: Evidence, readOptions: ReadOptions): Boolean = {

readOption(evidence, readOptions).isDefined

}

/**
* Converts the next RDD element into evidence.
* @param evidence The evidence instance which is to be updated.
* @param readOptions Options affecting the read.
* @return The RDD element, or None if no more records.
*/
def readOption(evidence: Evidence, readOptions: ReadOptions): Option[T] = {

if (!iterator.hasNext)
return None

if (!readOptions.getCleared) {
evidence.clear()
}

val current = iterator.next()

setEvidence(current, evidence)

Some(current)

}

/**
* @inheritdoc
*/
override def close(): Unit = {}
}
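
/**
 * Example only: a minimal IteratorEvidenceReader sketch, assuming each RDD element is an
 * Array[Double] whose columns correspond, by position, to continuous variables named in
 * variableNames (set here via Evidence.set). The class name and mapping are illustrative;
 * adapt setEvidence to your own data layout and network.
 * @param network The network containing the variables to map.
 * @param iterator The partition iterator supplied by mapPartitions.
 * @param variableNames The continuous variable names, in column order.
 */
class ArrayEvidenceReader(
  val network: Network,
  val iterator: Iterator[Array[Double]],
  val variableNames: Seq[String]) extends IteratorEvidenceReader[Array[Double]] {

  // Resolve the Bayes Server variables once per reader rather than once per record.
  private val variables = variableNames.map(name => network.getVariables.get(name, true))

  override def setEvidence(item: Array[Double], evidence: Evidence): Unit = {

    // Map each column onto the corresponding continuous variable.
    // The java.lang.Double ascription selects the continuous evidence overload.
    variables.zipWithIndex.foreach { case (variable, index) =>
      evidence.set(variable, item(index): java.lang.Double)
    }
  }
}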


object TimeMode extends Enumeration {
type TimeMode = Value

/**
* Query times are absolute and zero based.
*/
val Absolute = Value

/**
* Query times are zero based but relative to the maximum evidence time.
*/
val Relative = Value
}

import TimeMode._

case class PredictTime(time: Int, timeMode: TimeMode)

/**
* Base class for predictions.
*/
sealed abstract class PredictValue

/**
* The predicted probability of a discrete variable state. When a state is not specified, the probability of the most likely state (modal) is returned.
* @param variable The variable to predict.
* @param state The state to predict, or None to return the probability of the most likely state (modal).
* @param time The time at which to predict. Only required for temporal variables (time series predictions).
*/
final case class PredictState(variable: String, state: Option[String] = None, time: Option[PredictTime] = None) extends PredictValue


/**
* Predicts the index of the most likely state (modal) for discrete variables, or the predicted mean for continuous variables.
* @param name The variable name.
* @param time The time at which to predict. Only required for temporal variables (time series predictions).
*/
final case class PredictVariable(name: String, time: Option[PredictTime] = None) extends PredictValue

/**
* Predicts the variance of a continuous variable.
* @param name The variable name.
* @param time The time at which to predict. Only required for temporal variables (time series predictions).
*/
final case class PredictVariance(name: String, time: Option[PredictTime] = None) extends PredictValue

/**
* Predicts the log-likelihood of the case.
*/
final case class PredictLogLikelihood() extends PredictValue
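
/**
 * Example only: a prediction list that could be passed to Prediction.predict below.
 * The variable names are hypothetical; "Cluster" is assumed discrete and "X" is assumed
 * to be a continuous temporal variable, so PredictTime(1, Relative) queries one time step
 * beyond the last evidence time of each case.
 */
object ExamplePredictions {

  val values: Seq[PredictValue] = Seq(
    PredictVariable("Cluster"), // index of the most likely state
    PredictState("Cluster"), // probability of the most likely state
    PredictVariable("X", Some(PredictTime(1, Relative))), // predicted mean, one step ahead
    PredictVariance("X", Some(PredictTime(1, Relative))), // predicted variance, one step ahead
    PredictLogLikelihood() // log-likelihood of each case
  )
}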

/**
* Helper to make predictions easier.
*/
object Prediction {

private final class Reader[T](
val network: Network,
val factory: InferenceFactory,
val reader: IteratorEvidenceReader[T],
val predictions: Iterable[PredictValue]
) extends Iterator[(T, Try[Seq[Double]])] {

private val inference = factory.createInferenceEngine(network)
private val queryOptions = factory.createQueryOptions()
private val queryOutput = factory.createQueryOutput()
private val readOptions = new DefaultReadOptions()

private val predictionQueries: Seq[(PredictValue, Option[Distribution], Option[PredictTime])] = addQueries(predictions.toSeq, inference, queryOptions)
private val relativePredictions = predictionQueries.filter(pv => pv._3 match {
case Some(x) => x.timeMode == Relative
case None => false
})

private var current: Option[T] = read()

/**
* Adjusts the time for queries that use relative times.
* @param shift The amount to shift the query times, or null when there is no evidence time.
* @param plus True to apply the shift, false to reverse it.
*/
private def adjustRelativeTimes(shift: Integer, plus: Boolean) = {

if (shift != null) {
this.relativePredictions.foreach(p => {
p._2.get.timeShift(if (plus) shift else -shift)
})
}
}


override def hasNext: Boolean = current.isDefined

/**
* @inheritdoc
*/
override def next(): (T, Try[Seq[Double]]) = {

require(current.isDefined)

val result = (this.current.get, Try({

val maxEvidenceTime = this.inference.getEvidence.getMaxTime

adjustRelativeTimes(maxEvidenceTime, plus = true)

this.inference.query(this.queryOptions, this.queryOutput)

adjustRelativeTimes(maxEvidenceTime, plus = false) // reset

this.predictionQueries.map({
case (prediction, query, _) =>

prediction match {

case PredictState(variableName, stateName, time) =>
val variable = network.getVariables.get(variableName, true)

stateName match {
case Some(name) => query.get.getTable.get(variable.getStates.get(name, true))
case None => query.get.getTable.getMaxValue.getValue
}

case PredictVariable(variableName, time) =>
val variable = network.getVariables.get(variableName, true)
variable.getValueType match {
case VariableValueType.DISCRETE => query.get.getTable.getMaxValue.getIndex
case VariableValueType.CONTINUOUS => query.get.asInstanceOf[CLGaussian].getMean(0, 0)
}
case PredictVariance(variable, time) => query.get.asInstanceOf[CLGaussian].getVariance(0, 0)
case PredictLogLikelihood() =>
this.queryOutput.getLogLikelihood.doubleValue()
}
})
}))



this.current = read()

result

}

/**
* Reads the next element in the underlying evidence reader.
* @return The original RDD element or None if no further elements are available in this partition.
*/
private def read(): Option[T] = reader.readOption(this.inference.getEvidence, this.readOptions)

/**
* Adds queries to the inference engine to cover all predictions, without adding duplicate queries.
* @param predictions The predictions.
* @param inference The inference engine.
* @param queryOptions The query options, updated when the log-likelihood is required.
* @return The predictions and their associated queries. Note that the same query may be used for multiple predictions.
*/
private def addQueries(predictions: Seq[PredictValue], inference: Inference, queryOptions: QueryOptions): Seq[(PredictValue, Option[Distribution], Option[PredictTime])] = {

case class VariableTime(variable: String, time: Option[PredictTime])

val variableTimes: Seq[(PredictValue, Option[VariableTime])] = for (prediction <- predictions) yield {

prediction match {

case PredictState(variableName, state, time) => (prediction, Some(VariableTime(variableName, time)))
case PredictVariance(variableName, time) => (prediction, Some(VariableTime(variableName, time)))
case PredictVariable(variableName, time) => (prediction, Some(VariableTime(variableName, time)))
case PredictLogLikelihood() => queryOptions.setLogLikelihood(true); (prediction, None)
}
}

val distinctVariableTimes = variableTimes.filter(_._2.isDefined).map(_._2.get).distinct

implicit def toJavaInt(time: Option[PredictTime]): Integer = time match {
case Some(PredictTime(t, _)) => t // Relative times will be adjusted per case
case None => null
case _ => throw new MatchError(time) // defensive; the cases above are exhaustive
}

val queries = distinctVariableTimes.map(vt => {
val variable = network.getVariables.get(vt.variable, true)

variable.getValueType match {
case VariableValueType.DISCRETE => (vt, new Table(variable, vt.time).asInstanceOf[Distribution])
case VariableValueType.CONTINUOUS => (vt, new CLGaussian(variable, vt.time).asInstanceOf[Distribution])
}
})

for (query <- queries)
inference.getQueryDistributions.add(query._2)

val queryMap: Map[VariableTime, Distribution] = queries.toMap


variableTimes.map(pvt => (pvt._1, pvt._2.map(queryMap(_)), pvt._2.map(vt => vt.time).flatten))

}

}

/**
* Make predictions from data.
* @param network The trained network.
* @param data The data to make predictions on. Typically test data.
* @param predictions The predictions to be made.
* @param newReader Creates an evidence reader for a partition. Typically the variable you are trying to predict is not mapped to evidence.
* @tparam T The RDD element type.
* @return An RDD of pairs containing the original RDD element and the predictions as Double values.
*/
def predict[T](
network: Network,
data: RDD[T],
predictions: Iterable[PredictValue],
newReader: (Network, Iterator[T]) => IteratorEvidenceReader[T],
licenseKey: Option[String] = None): RDD[(T, Try[Seq[Double]])] = {

// Save the network to a string and broadcast it, as the Network class does not support Java serialization, which mapPartitions requires.
val networkString = data.sparkContext.broadcast(network.saveToString())


data.mapPartitions(iterator => {

licenseKey.foreach(s => License.validate(s))

if (iterator.isEmpty)
Iterator.empty // as of Spark 1.1 mapPartitions can pass empty iterators
else {

val networkPartition = new Network
networkPartition.loadFromString(networkString.value)

// TODO allow configuration of Inference engine
new Reader(networkPartition, new RelevanceTreeInferenceFactory, newReader(networkPartition, iterator), predictions)
}
})
}
}
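
/**
 * Example only: calling Prediction.predict with the hypothetical ArrayEvidenceReader defined
 * earlier in this file. "Y" is the (assumed continuous) target variable and is deliberately
 * not mapped by the evidence reader, while "X1" and "X2" are the mapped predictor columns.
 */
object PredictionExample {

  def run(network: Network, testData: RDD[Array[Double]]): RDD[(Array[Double], Try[Seq[Double]])] = {

    Prediction.predict(
      network,
      testData,
      Seq(PredictVariable("Y"), PredictVariance("Y"), PredictLogLikelihood()),
      // Map only the predictor columns, leaving "Y" unmapped so that it is predicted.
      (net, iterator) => new ArrayEvidenceReader(net, iterator, Seq("X1", "X2")))
  }
}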