// Common code for using Bayes Server and Apache Spark.
package com.bayesserver.spark.core
import java.io._
import java.util.zip.{GZIPInputStream, GZIPOutputStream}

import com.bayesserver._
import com.bayesserver.data._
import com.bayesserver.data.distributed._
import com.bayesserver.inference._
import com.bayesserver.learning.parameters._

import org.apache.spark.SparkContext
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.slf4j.LoggerFactory

import scala.collection.JavaConversions
import scala.language.implicitConversions
import scala.util.Try
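
/**
 * Runs Bayes Server distributed parameter learning on Apache Spark.
 *
 * Each call to [[distribute]] executes the Bayes Server map step over the partitions of
 * `data` and then combines the partial results with the Bayes Server reduce step via an
 * RDD reduce.
 *
 * @param data              the data to learn from
 * @param driverToWorker    name/value store used to pass configuration from the driver to the workers
 * @param newEvidenceReader creates an [[EvidenceReader]] over the elements of a partition
 * @param licenseKey        optional Bayes Server license key, validated on each worker
 * @param workerToDriver    creates the name/value store used to return partial results to the driver
 */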
class BayesSparkDistributer[T](
  val data: RDD[T],
  val driverToWorker: NameValuesReader with NameValuesWriter,
  val newEvidenceReader: (DistributedMapperContext, Iterator[T]) => EvidenceReader,
  val licenseKey: Option[String] = None,
  val workerToDriver: () => NameValuesReader with NameValuesWriter = () => new MemoryNameValues()
) extends Distributer[DistributerContext] with Serializable {

  // Lazy and transient so the (not necessarily serializable) logger is recreated where needed rather than shipped with closures.
  @transient private lazy val logger = LoggerFactory.getLogger(classOf[BayesSparkDistributer[T]])

  override def getConfiguration: NameValuesWriter = this.driverToWorker

  override def distribute(ctx: DistributerContext): NameValuesReader = {

    logger.info(s"Distributer stage: ${ctx.getName}")

    data
      .mapPartitions(iterator => {
        // Validate the license on each worker before using the Bayes Server API.
        licenseKey.foreach(s => License.validate(s))

        new Mapper(driverToWorker, workerToDriver).call(
          iterator,
          newEvidenceReader
        )
      })
      .reduce((a, b) => {
        licenseKey.foreach(s => License.validate(s))

        new Reducer(driverToWorker).call(a, b, workerToDriver())
      })
  }
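
  /** Runs the Bayes Server distributed learning map step (learnDistributedMapper) over a single partition. */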
  class Mapper(val driverToWorker: NameValuesReader, val workerToDriver: () => NameValuesReader with NameValuesWriter) extends Serializable {

    def call(
      iterator: Iterator[T],
      newEvidenceReader: (DistributedMapperContext, Iterator[T]) => EvidenceReader
    ): Iterator[NameValuesReader] = {

      // An empty partition contributes nothing to the reduce step.
      if (iterator.isEmpty)
        return Iterator.empty

      val output = workerToDriver()

      ParameterLearning.learnDistributedMapper(
        new EvidencePartition[DistributedMapperContext] {
          override def createEvidenceReader(ctx: DistributedMapperContext): EvidenceReader = {
            newEvidenceReader(ctx, iterator)
          }
        },
        driverToWorker,
        output,
        new RelevanceTreeInferenceFactory)

      Iterator(output)
    }
  }
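
  /** Combines two partial results into one using the Bayes Server distributed learning reduce step (learnDistributedReducer). */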
  class Reducer(val configuration: NameValuesReader) extends Serializable {

    def call(a: NameValuesReader, b: NameValuesReader, output: NameValuesReader with NameValuesWriter): NameValuesReader = {

      val inputs = Iterable(a, b)

      ParameterLearning.learnDistributedReducer(JavaConversions.asJavaIterable(inputs), configuration, output)

      output
    }
  }
}
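
/** Gzip helpers shared by the name/value stores below. */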
object CompressedNameValues {

  def write(writeStreamAction: WriteStreamAction, output: OutputStream): Unit = {
    val zipped = new GZIPOutputStream(output)
    writeStreamAction.write(zipped)
    // Finish (rather than close) so the underlying stream remains open for the caller.
    zipped.finish()
  }

  def read(input: InputStream): InputStream = {
    new GZIPInputStream(input)
  }
}

/** Wraps another name/value store, compressing values on write and decompressing them on read. */
class CompressedNameValues(val wrapped: NameValuesReader with NameValuesWriter with Serializable) extends NameValuesReader with NameValuesWriter with Serializable {

  def write(name: String, writeStreamAction: WriteStreamAction): Unit = {
    this.wrapped.write(name, new WriteStreamAction {
      override def write(output: OutputStream): Unit = CompressedNameValues.write(writeStreamAction, output)
    })
  }

  def contains(name: String): Boolean = this.wrapped.contains(name)

  override def read(name: String): InputStream = CompressedNameValues.read(this.wrapped.read(name))
}
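
/**
 * Name/value store backed by Spark broadcast variables, so that values written on the
 * driver can be read cheaply on the workers. Writing requires a live [[SparkContext]]
 * and is therefore only supported on the driver.
 */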
class BroadcastNameValues(@transient val sc: SparkContext, val compress: Boolean = true) extends NameValuesReader with NameValuesWriter with Serializable {

  val map = new scala.collection.mutable.HashMap[String, Broadcast[Array[Byte]]]

  def write(name: String, writeStreamAction: WriteStreamAction): Unit = {

    if (sc == null)
      throw new IllegalStateException("SparkContext is null. BroadcastNameValues.write is not supported on worker nodes")

    val output = new ByteArrayOutputStream()

    if (compress)
      CompressedNameValues.write(writeStreamAction, output)
    else
      writeStreamAction.write(output)

    this.map.put(name, sc.broadcast(output.toByteArray))
  }

  def contains(name: String): Boolean = {
    this.map.contains(name)
  }

  override def read(name: String): InputStream = {

    val value = this.map(name)

    if (compress)
      CompressedNameValues.read(new ByteArrayInputStream(value.value))
    else
      new ByteArrayInputStream(value.value)
  }
}
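
/** Simple in-memory name/value store, typically used on workers to collect results for the driver. */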
class MemoryNameValues(val compress: Boolean = true) extends NameValuesReader with NameValuesWriter with Serializable {

  val map = new scala.collection.mutable.HashMap[String, Array[Byte]]

  def write(name: String, writeStreamAction: WriteStreamAction): Unit = {

    val output = new ByteArrayOutputStream()

    if (compress)
      CompressedNameValues.write(writeStreamAction, output)
    else
      writeStreamAction.write(output)

    this.map.put(name, output.toByteArray)
  }

  def contains(name: String): Boolean = {
    this.map.contains(name)
  }

  override def read(name: String): InputStream = {

    val value = this.map(name)

    if (compress)
      CompressedNameValues.read(new ByteArrayInputStream(value))
    else
      new ByteArrayInputStream(value)
  }
}
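
/**
 * Adapts a Scala [[Iterator]] to the Bayes Server [[EvidenceReader]] interface.
 * Implementations only need to provide [[setEvidence]], mapping one item onto evidence.
 */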
trait IteratorEvidenceReader[T] extends EvidenceReader {

  /** The items to read evidence from. Typically supplied as a constructor value parameter so it is initialized before the checks below run. */
  val iterator: Iterator[T]

  require(iterator != null)

  if (iterator.isEmpty)
    throw new UnsupportedOperationException("Iterator is empty.")

  /** Maps a single item onto the given evidence. */
  def setEvidence(item: T, evidence: Evidence): Unit

  override def read(evidence: Evidence, readOptions: ReadOptions): Boolean = {
    readOption(evidence, readOptions).isDefined
  }

  /** Reads the next item into the evidence, returning it, or None when the iterator is exhausted. */
  def readOption(evidence: Evidence, readOptions: ReadOptions): Option[T] = {

    if (!iterator.hasNext)
      return None

    if (!readOptions.getCleared) {
      evidence.clear()
    }

    val current = iterator.next()

    setEvidence(current, evidence)

    Some(current)
  }

  override def close(): Unit = {}
}
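
/** Whether a prediction time is an absolute time index or relative to the latest evidence time. */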
object TimeMode extends Enumeration {
  type TimeMode = Value
  val Absolute, Relative = Value
}
import TimeMode._
case class PredictTime(time: Int, timeMode: TimeMode)
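
/**
 * The values that can be requested from [[Prediction.predict]]: a state probability,
 * a variable's most likely state index or mean, a variance, or the case log-likelihood.
 */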
sealed abstract class PredictValue
final case class PredictState(variable: String, state: Option[String] = None, time: Option[PredictTime] = None) extends PredictValue
final case class PredictVariable(name: String, time: Option[PredictTime] = None) extends PredictValue
final case class PredictVariance(name: String, time: Option[PredictTime] = None) extends PredictValue
final case class PredictLogLikelihood() extends PredictValue
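
/** Distributed prediction over an RDD using a trained Bayes Server network. */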
object Prediction {
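
  /**
   * Iterator that reads evidence for one item at a time from the wrapped reader, queries
   * the network, and yields the requested prediction values (wrapped in [[Try]]) for that item.
   */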
  private final class Reader[T](
    val network: Network,
    val factory: InferenceFactory,
    val reader: IteratorEvidenceReader[T],
    val predictions: Iterable[PredictValue]
  ) extends Iterator[(T, Try[Seq[Double]])] {

    private val inference = factory.createInferenceEngine(network)
    private val queryOptions = factory.createQueryOptions()
    private val queryOutput = factory.createQueryOutput()
    private val readOptions = new DefaultReadOptions()

    // Each requested prediction paired with its query distribution (shared per distinct variable/time) and optional time.
    private val predictionQueries: Seq[(PredictValue, Option[Distribution], Option[PredictTime])] =
      addQueries(predictions.toSeq, inference, queryOptions)

    // Queries whose time is relative to the latest evidence time; these are shifted before and after each query.
    private val relativePredictions = predictionQueries.filter(_._3.exists(_.timeMode == Relative))

    // Evidence for the first item is read eagerly so that hasNext is accurate.
    private var current: Option[T] = read()

    private def adjustRelativeTimes(shift: Integer, plus: Boolean) = {
      if (shift != null) {
        this.relativePredictions.foreach(p => {
          p._2.get.timeShift(if (plus) shift else -shift)
        })
      }
    }
    override def hasNext: Boolean = current.isDefined

    override def next(): (T, Try[Seq[Double]]) = {

      require(current.isDefined)

      val result = (this.current.get, Try({

        val maxEvidenceTime = this.inference.getEvidence.getMaxTime

        adjustRelativeTimes(maxEvidenceTime, plus = true)
        this.inference.query(this.queryOptions, this.queryOutput)
        adjustRelativeTimes(maxEvidenceTime, plus = false)

        this.predictionQueries.map({
          case (prediction, query, _) =>
            prediction match {
              case PredictState(variableName, stateName, _) =>
                val variable = network.getVariables.get(variableName, true)
                stateName match {
                  // Probability of the named state, or the highest probability when no state is named.
                  case Some(name) => query.get.getTable.get(variable.getStates.get(name, true))
                  case None => query.get.getTable.getMaxValue.getValue
                }
              case PredictVariable(variableName, _) =>
                val variable = network.getVariables.get(variableName, true)
                variable.getValueType match {
                  // Index of the most probable state for discrete variables, mean for continuous ones.
                  case VariableValueType.DISCRETE => query.get.getTable.getMaxValue.getIndex
                  case VariableValueType.CONTINUOUS => query.get.asInstanceOf[CLGaussian].getMean(0, 0)
                }
              case PredictVariance(_, _) => query.get.asInstanceOf[CLGaussian].getVariance(0, 0)
              case PredictLogLikelihood() =>
                this.queryOutput.getLogLikelihood.doubleValue()
            }
        })
      }))

      this.current = read()

      result
    }

    private def read(): Option[T] = reader.readOption(this.inference.getEvidence, this.readOptions)
    private def addQueries(predictions: Seq[PredictValue], inference: Inference, queryOptions: QueryOptions): Seq[(PredictValue, Option[Distribution], Option[PredictTime])] = {

      case class VariableTime(variable: String, time: Option[PredictTime])

      val variableTimes: Seq[(PredictValue, Option[VariableTime])] = for (prediction <- predictions) yield {
        prediction match {
          case PredictState(variableName, _, time) => (prediction, Some(VariableTime(variableName, time)))
          case PredictVariance(variableName, time) => (prediction, Some(VariableTime(variableName, time)))
          case PredictVariable(variableName, time) => (prediction, Some(VariableTime(variableName, time)))
          case PredictLogLikelihood() => queryOptions.setLogLikelihood(true); (prediction, None)
        }
      }

      // Only one query distribution is needed per distinct (variable, time) pair.
      val distinctVariableTimes = variableTimes.filter(_._2.isDefined).map(_._2.get).distinct

      // Converts an optional prediction time to the nullable java.lang.Integer used when constructing the distributions.
      implicit def toJavaInt(time: Option[PredictTime]): Integer = time match {
        case Some(PredictTime(t, _)) => t
        case None => null
      }

      val queries = distinctVariableTimes.map(vt => {
        val variable = network.getVariables.get(vt.variable, true)

        variable.getValueType match {
          case VariableValueType.DISCRETE => (vt, new Table(variable, vt.time).asInstanceOf[Distribution])
          case VariableValueType.CONTINUOUS => (vt, new CLGaussian(variable, vt.time).asInstanceOf[Distribution])
        }
      })

      for (query <- queries)
        inference.getQueryDistributions.add(query._2)

      val queryMap: Map[VariableTime, Distribution] = queries.toMap

      variableTimes.map(pvt => (pvt._1, pvt._2.map(queryMap(_)), pvt._2.flatMap(_.time)))
    }
  }
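
  /**
   * Makes the requested predictions for each element of `data`, loading a copy of the
   * network on each partition from a broadcast string.
   *
   * A minimal usage sketch (the `Row` type, the variable name "Cluster" and the body of
   * `setEvidence` are illustrative assumptions only):
   *
   * {{{
   * case class Row(x: Double, y: Double)
   *
   * class RowReader(val network: Network, val iterator: Iterator[Row]) extends IteratorEvidenceReader[Row] {
   *   def setEvidence(item: Row, evidence: Evidence): Unit = {
   *     // map the fields of item onto evidence for the variables of network here
   *   }
   * }
   *
   * val results = Prediction.predict(
   *   network,
   *   rows, // RDD[Row]
   *   Seq(PredictVariable("Cluster"), PredictLogLikelihood()),
   *   (net, it) => new RowReader(net, it))
   * }}}
   */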
  def predict[T](
    network: Network,
    data: RDD[T],
    predictions: Iterable[PredictValue],
    newReader: (Network, Iterator[T]) => IteratorEvidenceReader[T],
    licenseKey: Option[String] = None): RDD[(T, Try[Seq[Double]])] = {

    // Broadcast the network once as a string, rather than serializing it with every task.
    val networkString = data.sparkContext.broadcast(network.saveToString())

    data.mapPartitions(iterator => {

      licenseKey.foreach(s => License.validate(s))

      if (iterator.isEmpty)
        Iterator.empty
      else {
        // Each partition works on its own copy of the network.
        val networkPartition = new Network
        networkPartition.loadFromString(networkString.value)

        new Reader(networkPartition, new RelevanceTreeInferenceFactory, newReader(networkPartition, iterator), predictions)
      }
    })
  }
}