Prediction with Apache Spark
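This example shows how to make discrete, continuous, and log-likelihood predictions with Bayes Server over an Apache Spark RDD, using the sample Waste network.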

// --------------------------------------------------------------------------------------------------------------------
// <copyright file="WastePredictions.scala" company="Bayes Server">
//   Copyright (C) Bayes Server.  All rights reserved.
// </copyright>
// --------------------------------------------------------------------------------------------------------------------

package com.bayesserver.spark.examples.inference

import com.bayesserver.{License, Network}
import org.apache.spark.SparkContext
import com.bayesserver.inference.Evidence
import com.bayesserver.spark.core._

/**
 * Demonstrates discrete, continuous and log-likelihood predictions on the Waste sample network.
 */
object WastePredictions {

  /**
   * Makes various predictions on the sample Waste network.
   * @param path The path to the sample network 'Waste.bayes' included with Bayes Server.
   * @param sc The SparkContext (see the Apache Spark user guide for details).
   * @param licenseKey An optional Bayes Server license key.
   */
  def apply(path: String, sc: SparkContext, licenseKey: Option[String] = None) = {

    licenseKey.foreach(key => License.validate(key))

    val network = new Network()
    network.load(path)

    // Make some predictions about P(Burning Regimen) and P(CO2 concentration), given evidence of Waste type and Light penetrability

    // Using a case class for the RDD records is optional, but makes the example clearer
    case class Waste(
                      burningRegimen: String, cO2Concentration: Double, dustEmission: Double,
                      filterEfficiency: Double, filterState: String, lightPenetrability: Double,
                      metalsEmission: Double, metalsInWaste: Double, wasteType: String)

    // Here we define some test data inline. Normally the data would already reside on the cluster.
    // Use Option[...] field types if you need to represent missing data.
    val testData = sc.parallelize(Seq(
      Waste("Stable", -1.93, 2.54, -3.9, "Intact", 1.5, 3.2, 0.665, "Industrial"),
      Waste("Stable", -1.78, 2.99, -3.2, "Intact", 2.17, 2.53, -0.496, "Household"),
      Waste("Stable", -2.26, 5.59, -0.5, "Defect", 0.609, 5.1, -0.411, "Household")
    ))

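    // An evidence reader maps each record in the RDD to evidence on network variables.
    // Here evidence is set on 'Waste type' (discrete) and 'Light penetrability' (continuous).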
    class WasteReader(val network: Network, val iterator: Iterator[Waste]) extends IteratorEvidenceReader[Waste] {

      val wasteType = network.getVariables.get("Waste type", true)
      val lightPenetrability = network.getVariables.get("Light penetrability", true)

      override def setEvidence(item: Waste, evidence: Evidence): Unit = {

        evidence.setState(wasteType.getStates.get(item.wasteType, true))
        evidence.set(lightPenetrability, item.lightPenetrability)
      }
    }

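    // Request predictions for 'Burning Regimen' (distribution and most likely state),
    // 'CO2 concentration' (mean and variance), and the log-likelihood of each case.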
    val predictions = Prediction.predict[Waste](
      network,
      testData,
      Seq(
        PredictVariable("Burning Regimen"),
        PredictState("Burning Regimen"),
        PredictVariable("CO2 concentration"),
        PredictVariance("CO2 concentration"),
        PredictLogLikelihood()
      ),
      (network, iterator) => new WasteReader(network, iterator),
      licenseKey)

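    // Note: on a cluster, foreach(println) prints on the executors, not the driver.
    // For small result sets, collect() first; otherwise save the RDD to storage.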
    predictions.foreach(println)
  }
}
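
For reference, a minimal sketch of how this object might be invoked. The app name, master URL, and file path below are illustrative assumptions, not part of the sample.

// Hypothetical driver program; adjust the master URL and path for your environment.

import com.bayesserver.spark.examples.inference.WastePredictions
import org.apache.spark.{SparkConf, SparkContext}

object WastePredictionsApp {

  def main(args: Array[String]): Unit = {

    // local[*] runs Spark in-process; replace with your cluster's master URL.
    val conf = new SparkConf().setAppName("WastePredictions").setMaster("local[*]")
    val sc = new SparkContext(conf)

    try {
      // No license key supplied here; pass Some(key) if you have one.
      WastePredictions("/path/to/Waste.bayes", sc)
    } finally {
      sc.stop()
    }
  }
}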