// Example: Prediction with Apache Spark

// --------------------------------------------------------------------------------------------------------------------
// <copyright file="WastePredictions.scala" company="Bayes Server">
// Copyright (C) Bayes Server. All rights reserved.
// </copyright>
// --------------------------------------------------------------------------------------------------------------------

package com.bayesserver.spark.examples.inference

import com.bayesserver.{License, Network}
import org.apache.spark.SparkContext
import com.bayesserver.inference.Evidence
import com.bayesserver.spark.core._
import com.bayesserver.spark.core.PredictVariable
import com.bayesserver.spark.core.PredictVariance
import com.bayesserver.spark.core.PredictState

/**
* Demonstrates discrete, continuous and log-likelihood predictions on the Waste sample network.
*/
/**
 * Demonstrates discrete, continuous and log-likelihood predictions on the Waste sample network.
 */
object WastePredictions {

  /**
   * Runs a set of example predictions against the sample Waste network.
   *
   * @param path       The path to the sample Waste network included with Bayes Server called 'Waste.bayes'.
   * @param sc         The SparkContext (see Apache Spark user guide for help on this).
   * @param licenseKey Optional Bayes Server license key, validated when supplied.
   */
  def apply(path: String, sc: SparkContext, licenseKey: Option[String] = None) = {

    licenseKey.foreach(License.validate)

    val network = new Network()
    network.load(path)

    // A case class is optional for RDD records, but keeps the example readable.
    // Use Option[...] for a field's type when missing data must be represented.
    case class Waste(
      burningRegimen: String, cO2Concentration: Double, dustEmission: Double,
      filterEfficiency: Double, filterState: String, lightPenetrability: Double,
      metalsEmission: Double, metalsInWaste: Double, wasteType: String)

    // Inline test records for the example; a real application would read
    // data already resident on the cluster.
    val testData = sc.parallelize(Seq(
      Waste("Stable", -1.93, 2.54, -3.9, "Intact", 1.5, 3.2, 0.665, "Industrial"),
      Waste("Stable", -1.78, 2.99, -3.2, "Intact", 2.17, 2.53, -0.496, "Household"),
      Waste("Stable", -2.26, 5.59, -0.5, "Defect", 0.609, 5.1, -0.411, "Household")
    ))

    // Maps each Waste record onto evidence: here only the 'Waste type' state
    // and the 'Light penetrability' value are asserted as evidence.
    class WasteReader(val network: Network, val iterator: Iterator[Waste]) extends IteratorEvidenceReader[Waste] {

      private val wasteType = network.getVariables.get("Waste type", true)
      private val lightPenetrability = network.getVariables.get("Light penetrability", true)

      override def setEvidence(item: Waste, evidence: Evidence): Unit = {
        evidence.setState(wasteType.getStates.get(item.wasteType, true))
        evidence.set(lightPenetrability, item.lightPenetrability)
      }
    }

    // Requested outputs: P(Burning Regimen) and its most likely state, the
    // mean and variance of 'CO2 concentration', and the log-likelihood of
    // the evidence for each record.
    val requests = Seq(
      PredictVariable("Burning Regimen"),
      PredictState("Burning Regimen"),
      PredictVariable("CO2 concentration"),
      PredictVariance("CO2 concentration"),
      PredictLogLikelihood()
    )

    val predictions = Prediction.predict[Waste](
      network,
      testData,
      requests,
      (net, iter) => new WasteReader(net, iter),
      licenseKey)

    predictions.foreach(println)
  }

}