Common code for using Bayes Server with Apache Spark, covering distributed parameter learning and distributed prediction for Bayesian networks.

// --------------------------------------------------------------------------------------------------------------------
// <copyright file="BayesSparkDistributer.scala" company="Bayes Server">
//   Copyright (C) Bayes Server.  All rights reserved.
// </copyright>
// <version>0.13</version>
// <dependencies>
//  <dependency>
//    <name>bayes-server</name>
//    <version>7.x</version>
//  </dependency>
//  <dependency>
//    <name>spark-core</name>
//    <version>1.x.x</version>
//  </dependency>
//  <dependency>
//    <name>scala</name>
//    <version>2.10.4</version>
//  </dependency>
// </dependencies>
// --------------------------------------------------------------------------------------------------------------------

package com.bayesserver.spark.core

import java.io._
import java.util.zip.{GZIPInputStream, GZIPOutputStream}

import com.bayesserver._
import com.bayesserver.inference._
import com.bayesserver.data._
import com.bayesserver.data.distributed._ // Remove this line if using Bayes Server 6.x
import com.bayesserver.learning.parameters._

import org.apache.spark.SparkContext
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.slf4j.LoggerFactory

import scala.collection.JavaConversions
import scala.language.implicitConversions
import scala.util.Try

/**
 * Implements the Bayes Server Distributer interface for distributed parameter learning of Bayesian networks.
 *
 * To use, call com.bayesserver.learning.parameters.ParameterLearning.learnDistributed(...), passing an
 * instance of this class.  A usage sketch follows the class definition below.
 *
 * @param data The Spark RDD containing the training data.
 * @param driverToWorker The configuration name value pairs to be passed to the mappers and reducers.
 * @param newEvidenceReader A factory method for creating evidence readers when required.
 * @param licenseKey An optional Bayes Server license key, validated on each worker before use.
 * @param workerToDriver A factory that creates the name value store used to pass results from the workers back to the driver.
 */
class BayesSparkDistributer[T](
  val data: RDD[T],
  val driverToWorker: NameValuesReader with NameValuesWriter,
  val newEvidenceReader: (DistributedMapperContext, Iterator[T]) => EvidenceReader,
  val licenseKey: Option[String] = None,
  val workerToDriver: () => NameValuesReader with NameValuesWriter = () => new MemoryNameValues())
  extends Distributer[DistributerContext] with Serializable {

  // Transient and lazy so that the logger, which may not be serializable, is not captured when Spark serializes this class.
  @transient private lazy val logger = LoggerFactory.getLogger(classOf[BayesSparkDistributer[T]])

  /**
   * @inheritdoc
   */
  override def getConfiguration: NameValuesWriter = this.driverToWorker

  /**
   * @inheritdoc
   */
  override def distribute(ctx: DistributerContext): NameValuesReader = {

    logger.info("Distributer stage: " + ctx.getName)

    // We use mapPartitions as the Bayes Server library reads/pulls the data via its EvidenceReader interface
    // The calls to Bayes Server are not inline as the code is not serializable which is required by Spark

    data
      .mapPartitions(iterator => {

        // this code executes on the worker nodes, so a license will not yet have been validated
        licenseKey.foreach(s => License.validate(s))

        new Mapper(driverToWorker, workerToDriver).call(
          iterator,
          newEvidenceReader
        )
      })
      .reduce((a, b) => {

        // this code executes on the worker nodes, so a license will not yet have been validated
        licenseKey.foreach(s => License.validate(s))

        // this calls the Bayes Server ParameterLearning.learnDistributedReducer method
        new Reducer(driverToWorker).call(a, b, workerToDriver())
      })
  }

  /**
   * Calls the mapper phase of Bayes Server distributed parameter learning.
   * @param driverToWorker Configuration information which is required on all nodes.
   * @param workerToDriver A factory that creates the name value store used to pass mapper results back to the driver.
   */
  class Mapper(val driverToWorker: NameValuesReader, val workerToDriver: () => NameValuesReader with NameValuesWriter) extends Serializable {

    /**
     * Performs the Bayes Server map operation on an iterator generated from the RDD.mapPartitions.
     * @param iterator The iterator generated from RDD.mapPartitions.
     * @param newEvidenceReader A factory for creating a new EvidenceReader.
     * @return An iterator of outputs from the map operation.
     */
    def call(
              iterator: Iterator[T],
              newEvidenceReader: (DistributedMapperContext, Iterator[T]) => EvidenceReader
              ): Iterator[NameValuesReader] = {

      if (iterator.isEmpty)
        return Iterator.empty // Spark RDD.mapPartitions can pass in an empty Iterator

      val output = workerToDriver()

      ParameterLearning.learnDistributedMapper(
        new EvidencePartition[DistributedMapperContext] {
          override def createEvidenceReader(ctx: DistributedMapperContext): EvidenceReader = {
            newEvidenceReader(ctx, iterator)
          }
        },
        driverToWorker,
        output,
        new RelevanceTreeInferenceFactory)

      Iterator(output)
    }
  }

  /**
   * Calls the reducer phase of Bayes Server distributed parameter learning.
   *
   * @param configuration Configuration information which is required on all nodes.
   */
  class Reducer(val configuration: NameValuesReader) extends Serializable {

    /**
     * Combines two sets of mapper/reducer outputs into a single output.
     * @param a The first set of name value pairs.
     * @param b The second set of name value pairs.
     * @param output A new store to which the combined output is written.
     * @return The combined output.
     */
    def call(a: NameValuesReader, b: NameValuesReader, output: NameValuesReader with NameValuesWriter): NameValuesReader = {

      val inputs = Iterable(a, b)

      ParameterLearning.learnDistributedReducer(JavaConversions.asJavaIterable(inputs), configuration, output)

      output
    }
  }

}
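
/**
 * A minimal usage sketch (not part of the original code) showing how BayesSparkDistributer is wired
 * into distributed parameter learning.  It assumes an RDD of Array(x, y) rows and uses the example
 * ArrayEvidenceReader defined further below (after the IteratorEvidenceReader trait).  The use of
 * DistributedMapperContext.getNetwork and the learnDistributed argument order are assumptions;
 * check the Bayes Server documentation for the exact API of your version.
 */
object DistributedLearningExample {

  def learn(network: Network, data: RDD[Array[Double]], licenseKey: Option[String] = None): Unit = {

    // Configuration written by Bayes Server on the driver and broadcast to the workers.
    val driverToWorker = new BroadcastNameValues(data.sparkContext)

    val distributer = new BayesSparkDistributer[Array[Double]](
      data,
      driverToWorker,
      // Assumption: the mapper context exposes the worker-side copy of the network being learnt.
      (ctx, iterator) => new ArrayEvidenceReader(ctx.getNetwork, iterator),
      licenseKey)

    val learningOptions = new ParameterLearningOptions // default options; adjust as required

    // Assumption: the argument order of learnDistributed may differ between Bayes Server versions.
    ParameterLearning.learnDistributed(network, learningOptions, distributer)
  }
}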

/**
 * Methods to compress and decompress values from a store such as MemoryNameValues or BroadcastNameValues.
 */
object CompressedNameValues {

  def write(writeStreamAction: WriteStreamAction, output: OutputStream) = {

    val zipped = new GZIPOutputStream(output)
    writeStreamAction.write(zipped)
    zipped.finish() // required to write final bytes
  }

  def read(input: InputStream): InputStream = {
    new GZIPInputStream(input)
  }
}

/**
 * An adapter that can be used to add compression to an existing store.
 * MemoryNameValues and BroadcastNameValues both have compression options already.
 * @param wrapped The underlying store
 */
class CompressedNameValues(val wrapped: NameValuesReader with NameValuesWriter with Serializable) extends NameValuesReader with NameValuesWriter with Serializable {

  /**
   * @inheritdoc
   */
  def write(name: String, writeStreamAction: WriteStreamAction) = {

    this.wrapped.write(name, new WriteStreamAction {
      override def write(output: OutputStream) = CompressedNameValues.write(writeStreamAction, output)
    })
  }

  /**
   * @inheritdoc
   */
  def contains(name: String): Boolean = this.wrapped.contains(name)

  /**
   * @inheritdoc
   */
  override def read(name: String): InputStream = CompressedNameValues.read(this.wrapped.read(name))

}
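
/**
 * A brief sketch (not part of the original code) of wrapping an existing store with CompressedNameValues.
 * Here a MemoryNameValues with its own compression disabled is wrapped purely for illustration;
 * in practice you would wrap a custom store that has no compression support of its own.
 */
object CompressedNameValuesExample {

  def create(): NameValuesReader with NameValuesWriter =
    new CompressedNameValues(new MemoryNameValues(compress = false))
}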

/**
 * A Spark broadcast variable based implementation of both NameValuesReader and NameValuesWriter.
 * Note that values can only be written in driver code, not on the workers.
 * @param sc The SparkContext, used to create broadcast variables.  Marked transient so the context itself is never serialized.
 * @param compress Whether values are gzip compressed before being broadcast.
 */
class BroadcastNameValues(@transient val sc: SparkContext, val compress: Boolean = true) extends NameValuesReader with NameValuesWriter with Serializable {

  val map = new scala.collection.mutable.HashMap[String, Broadcast[Array[Byte]]]

  /**
   * @inheritdoc
   */
  def write(name: String, writeStreamAction: WriteStreamAction) {

    if (sc == null)
      throw new IllegalStateException("SparkContext is null.  BroadcastNameValues.write is not supported on worker nodes")

    val output: ByteArrayOutputStream = new ByteArrayOutputStream

    compress match {
      case true => CompressedNameValues.write(writeStreamAction, output)
      case false => writeStreamAction.write(output)
    }

    this.map.put(name, sc.broadcast(output.toByteArray))
  }

  /**
   * @inheritdoc
   */
  def contains(name: String): Boolean = this.map.contains(name)

  /**
   * @inheritdoc
   */
  override def read(name: String): InputStream = {

    val value = this.map.get(name).get

    compress match {
      case true => CompressedNameValues.read(new ByteArrayInputStream(value.value))
      case false => new ByteArrayInputStream(value.value)
    }
  }
}

/**
 * An in-memory implementation of both NameValuesReader and NameValuesWriter.
 * @param compress Whether values are gzip compressed before being stored.
 */
class MemoryNameValues(val compress: Boolean = true) extends NameValuesReader with NameValuesWriter with Serializable {

  val map = new scala.collection.mutable.HashMap[String, Array[Byte]]

  /**
   * @inheritdoc
   */
  def write(name: String, writeStreamAction: WriteStreamAction) {

    val output: ByteArrayOutputStream = new ByteArrayOutputStream

    compress match {
      case true => CompressedNameValues.write(writeStreamAction, output)
      case false => writeStreamAction.write(output)
    }

    this.map.put(name, output.toByteArray)
  }

  /**
   * @inheritdoc
   */
  def contains(name: String): Boolean = this.map.contains(name)

  /**
   * @inheritdoc
   */
  override def read(name: String): InputStream = {

    val value = this.map.get(name).get

    compress match {
      case true => CompressedNameValues.read(new ByteArrayInputStream(value))
      case false => new ByteArrayInputStream(value)
    }
  }
}

/**
 * Trait which can be used to help implement the Bayes Server EvidenceReader interface.
 * @tparam T The type of data contained in the RDD.  This could be a class, an Array[Double] or anything else which is convenient.
 */
trait IteratorEvidenceReader[T] extends EvidenceReader {

  val iterator: Iterator[T]
  require(iterator != null)

  if (iterator.isEmpty)
    throw new UnsupportedOperationException("Iterator is empty.")

  /**
   * Maps information from the current RDD element to Bayes Server variables using the Bayes Server evidence instance.
   * @param item The RDD element.
   * @param evidence The evidence to be updated.
   */
  def setEvidence(item: T, evidence: Evidence)

  /**
   * @inheritdoc
   */
  override def read(evidence: Evidence, readOptions: ReadOptions): Boolean = {

    readOption(evidence, readOptions).isDefined

  }

  /**
   * Converts the next RDD element into evidence.
   * @param evidence The evidence instance which is to be updated.
   * @param readOptions Options affecting the read.
   * @return The RDD element, or None if no more records.
   */
  def readOption(evidence: Evidence, readOptions: ReadOptions): Option[T] = {

    if (!iterator.hasNext)
      return None

    if (!readOptions.getCleared) {
      evidence.clear()
    }

    val current = iterator.next()

    setEvidence(current, evidence)

    Some(current)

  }

  /**
   * @inheritdoc
   */
  override def close(): Unit = {}
}
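
/**
 * A minimal example implementation of IteratorEvidenceReader (a sketch, not part of the original code).
 * It assumes each RDD element is an Array(x, y) and that the network contains two continuous variables
 * named "X" and "Y"; change the variable names and mapping to match your own data and network.  The
 * Evidence.set(Variable, Double) overload for continuous values is assumed; see the Bayes Server API
 * docs for the exact evidence setters.  Note that the iterator is supplied as a constructor parameter
 * so that it is already initialized when the checks in the trait body run.
 */
class ArrayEvidenceReader(val network: Network, val iterator: Iterator[Array[Double]])
  extends IteratorEvidenceReader[Array[Double]] {

  private val x = network.getVariables.get("X", true)
  private val y = network.getVariables.get("Y", true)

  override def setEvidence(item: Array[Double], evidence: Evidence): Unit = {
    evidence.set(x, item(0)) // continuous evidence for X
    evidence.set(y, item(1)) // continuous evidence for Y
  }
}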


object TimeMode extends Enumeration {
  type TimeMode = Value

  /**
   * Query times are absolute and zero based.
   */
  val Absolute = Value

  /**
   * Query times are zero based but relative to the maximum evidence time.
   */
  val Relative = Value
}

import TimeMode._

case class PredictTime(time: Int, timeMode: TimeMode)

/**
 * Base class for predictions.
 */
sealed abstract class PredictValue

/**
 * The predicted probability of a discrete variable state.  When the state is not specified, the probability of the most likely state (modal) is returned.
 * @param variable The variable to predict.
 * @param state The state to predict, or None to return the probability of the most likely state (modal).
 * @param time The time at which to predict.  Only required for temporal variables (time series predictions).
 */
final case class PredictState(variable: String, state: Option[String] = None, time: Option[PredictTime] = None) extends PredictValue


/**
 * Predicts the most likely state (modal) for discrete variables or the predicted mean for continuous variables.
 * @param name The variable name.
 * @param time The time at which to predict.  Only required for temporal variables (time series predictions).
 */
final case class PredictVariable(name: String, time: Option[PredictTime] = None) extends PredictValue

/**
 * Predicts the variance of a continuous variable.
 * @param name The variable name.
 * @param time The time at which to predict.  Only required for temporal variables (time series predictions).
 */
final case class PredictVariance(name: String, time: Option[PredictTime] = None) extends PredictValue

/**
 * Predicts the log-likelihood of the case.
 */
final case class PredictLogLikelihood() extends PredictValue
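
/**
 * A small sketch (not part of the original code) showing how the prediction case classes above can be
 * combined.  The variable and state names ("Cluster", "Y", "Sales") are hypothetical and must exist in
 * your own network.
 */
object PredictValueExamples {

  /** Predictions for non-temporal variables. */
  val nonTemporal: Seq[PredictValue] = Seq(
    PredictVariable("Cluster"),                // most likely state of a discrete variable
    PredictState("Cluster", Some("Cluster1")), // probability of a particular state
    PredictVariable("Y"),                      // predicted mean of a continuous variable
    PredictVariance("Y"),                      // predicted variance of a continuous variable
    PredictLogLikelihood()                     // log-likelihood of each case
  )

  /** A time series prediction, two time steps beyond the last evidence time (see TimeMode.Relative). */
  val temporal: Seq[PredictValue] = Seq(
    PredictVariable("Sales", Some(PredictTime(2, TimeMode.Relative)))
  )
}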

/**
 * Helper to make predictions easier.
 */
object Prediction {

  private final class Reader[T](
                                 val network: Network,
                                 val factory: InferenceFactory,
                                 val reader: IteratorEvidenceReader[T],
                                 val predictions: Iterable[PredictValue]
                                 ) extends Iterator[(T, Try[Seq[Double]])] {

    private val inference = factory.createInferenceEngine(network)
    private val queryOptions = factory.createQueryOptions()
    private val queryOutput = factory.createQueryOutput()
    private val readOptions = new DefaultReadOptions()

    private val predictionQueries: Seq[(PredictValue, Option[Distribution], Option[PredictTime])] = addQueries(predictions.toSeq, inference, queryOptions)
    private val relativePredictions = predictionQueries.filter(pv => pv._3 match {
      case Some(x) => x.timeMode == Relative
      case None => false
    })

    private var current: Option[T] = read()

    /**
     * Adjusts the time for queries that use relative times.
     * @param shift The amount to shift the query times, or null when there is no evidence time.
     * @param plus True to apply the shift, false to reverse it.
     */
    private def adjustRelativeTimes(shift: Integer, plus: Boolean) = {

      if (shift != null) {
        this.relativePredictions.foreach(p => {
          p._2.get.timeShift(if (plus) shift else -shift)
        })
      }
    }


    override def hasNext: Boolean = current.isDefined

    /**
     * @inheritdoc
     */
    override def next(): (T, Try[Seq[Double]]) = {

      require(current.isDefined)

      val result = (this.current.get, Try({

        val maxEvidenceTime = this.inference.getEvidence.getMaxTime

        adjustRelativeTimes(maxEvidenceTime, plus = true)

        this.inference.query(this.queryOptions, this.queryOutput)

        adjustRelativeTimes(maxEvidenceTime, plus = false) // reset

        this.predictionQueries.map({
          case (prediction, query, _) =>

            prediction match {

              case PredictState(variableName, stateName, time) =>
                val variable = network.getVariables.get(variableName, true)

                stateName match {
                  case Some(name) => query.get.getTable.get(variable.getStates.get(name, true))
                  case None => query.get.getTable.getMaxValue.getValue
                }

              case PredictVariable(variableName, time) =>
                val variable = network.getVariables.get(variableName, true)
                variable.getValueType match {
                  case VariableValueType.DISCRETE => query.get.getTable.getMaxValue.getIndex
                  case VariableValueType.CONTINUOUS => query.get.asInstanceOf[CLGaussian].getMean(0, 0)
                }
              case PredictVariance(variable, time) => query.get.asInstanceOf[CLGaussian].getVariance(0, 0)
              case PredictLogLikelihood() =>
                this.queryOutput.getLogLikelihood.doubleValue()
            }
        })
      }))



      this.current = read()

      result

    }

    /**
     * Reads the next element in the underlying evidence reader.
     * @return The original RDD element or None if no further elements are available in this partition.
     */
    private def read(): Option[T] = reader.readOption(this.inference.getEvidence, this.readOptions)

    /**
     * Add queries to the inference engine to cover all predictions, but do not duplicate them.
     * @param predictions The predictions.
     * @param inference The inference engine.
     * @param queryOptions The query options, updated if the log-likelihood is required.
     * @return The predictions and their associated queries.  Note that the same query may be used for multiple predictions.
     */
    private def addQueries(predictions: Seq[PredictValue], inference: Inference, queryOptions: QueryOptions): Seq[(PredictValue, Option[Distribution], Option[PredictTime])] = {

      case class VariableTime(variable: String, time: Option[PredictTime])

      val variableTimes: Seq[(PredictValue, Option[VariableTime])] = for (prediction <- predictions) yield {

        prediction match {

          case PredictState(variableName, state, time) => (prediction, Some(VariableTime(variableName, time)))
          case PredictVariance(variableName, time) => (prediction, Some(VariableTime(variableName, time)))
          case PredictVariable(variableName, time) => (prediction, Some(VariableTime(variableName, time)))
          case PredictLogLikelihood() => queryOptions.setLogLikelihood(true); (prediction, None)
        }
      }

      val distinctVariableTimes = variableTimes.filter(_._2.isDefined).map(_._2.get).distinct

      implicit def toJavaInt(time: Option[PredictTime]): Integer = time match {
        case Some(PredictTime(t, _)) => t // Relative times are adjusted per case
        case None => null
      }

      val queries = distinctVariableTimes.map(vt => {
        val variable = network.getVariables.get(vt.variable, true)

        variable.getValueType match {
          case VariableValueType.DISCRETE => (vt, new Table(variable, vt.time).asInstanceOf[Distribution])
          case VariableValueType.CONTINUOUS => (vt, new CLGaussian(variable, vt.time).asInstanceOf[Distribution])
        }
      })

      for (query <- queries)
        inference.getQueryDistributions.add(query._2)

      val queryMap: Map[VariableTime, Distribution] = queries.toMap


      variableTimes.map(pvt => (pvt._1, pvt._2.map(queryMap(_)), pvt._2.map(vt => vt.time).flatten))

    }

  }

  /**
   * Make predictions from data.  A usage sketch follows this object.
   * @param network The trained network.
   * @param data The data to make predictions on.  Typically test data.
   * @param predictions The predictions to be made.
   * @param newReader Creates a reader for a partition.  Typically the variable you are trying to predict is not mapped as evidence.
   * @param licenseKey An optional Bayes Server license key, validated on each worker before use.
   * @tparam T The RDD element type.
   * @return An RDD of pairs containing the original RDD element and the predictions as Double values.
   */
  def predict[T](
                  network: Network,
                  data: RDD[T],
                  predictions: Iterable[PredictValue],
                  newReader: (Network, Iterator[T]) => IteratorEvidenceReader[T],
                  licenseKey: Option[String] = None): RDD[(T, Try[Seq[Double]])] = {

    // save the network to a string, as the Network class does not support serialization which is required by mapPartitions
    val networkString = data.sparkContext.broadcast(network.saveToString())


    data.mapPartitions(iterator => {

      licenseKey.foreach(s => License.validate(s))

      if (iterator.isEmpty)
        Iterator.empty // as of Spark 1.1 mapPartitions can pass empty iterators
      else {

        val networkPartition = new Network
        networkPartition.loadFromString(networkString.value)

        // TODO allow configuration of Inference engine
        new Reader(networkPartition, new RelevanceTreeInferenceFactory, newReader(networkPartition, iterator), predictions)
      }
    })
  }
}
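
/**
 * A predictor-only evidence reader used by the usage sketch below (not part of the original code).
 * It maps only the continuous predictor "X" onto evidence, leaving the target "Y" unmapped so that it
 * can be predicted.  The variable names are hypothetical and the Evidence.set(Variable, Double)
 * overload for continuous values is assumed.
 */
class PredictorOnlyEvidenceReader(val network: Network, val iterator: Iterator[Array[Double]])
  extends IteratorEvidenceReader[Array[Double]] {

  private val x = network.getVariables.get("X", true)

  override def setEvidence(item: Array[Double], evidence: Evidence): Unit = {
    evidence.set(x, item(0)) // set evidence on the predictor only
  }
}

/**
 * A minimal usage sketch (not part of the original code) for Prediction.predict, assuming a trained
 * network with a continuous predictor "X" and a continuous target "Y", and an RDD of Array(x, y) test rows.
 */
object PredictionExample {

  def predictY(network: Network, testData: RDD[Array[Double]], licenseKey: Option[String] = None): RDD[(Array[Double], Try[Seq[Double]])] = {

    // Request the predicted mean and variance of "Y", plus the log-likelihood of each case.
    val predictions = Seq(PredictVariable("Y"), PredictVariance("Y"), PredictLogLikelihood())

    Prediction.predict(
      network,
      testData,
      predictions,
      (networkPartition, iterator) => new PredictorOnlyEvidenceReader(networkPartition, iterator),
      licenseKey)
  }
}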