// Prediction with Apache Spark
package com.bayesserver.spark.examples.inference
import com.bayesserver.{License, Network}
import org.apache.spark.SparkContext
import com.bayesserver.inference.Evidence
import com.bayesserver.spark.core._
import com.bayesserver.spark.core.PredictVariable
import com.bayesserver.spark.core.PredictVariance
import com.bayesserver.spark.core.PredictState
/** Distributed prediction example for the Waste monitoring network.
  *
  * Loads a Bayes Server network from disk, parallelizes a small set of
  * hand-written test records over Spark, and asks the engine to predict
  * states, means, variances and the per-case log-likelihood.
  */
object WastePredictions {

  /**
   * Runs the predictions and prints each result to stdout.
   *
   * @param path       file-system path of the Bayes Server network to load
   * @param sc         Spark context used to distribute the test records
   * @param licenseKey optional Bayes Server license key, validated when present
   */
  def apply(path: String, sc: SparkContext, licenseKey: Option[String] = None) = {

    // Activate the license (if supplied) before any engine calls are made.
    licenseKey.foreach(k => License.validate(k))

    val network = new Network()
    network.load(path)

    // One monitoring record. Only wasteType and lightPenetrability are set as
    // evidence below; the remaining fields document the full record shape.
    case class Waste(
      burningRegimen: String, cO2Concentration: Double, dustEmission: Double,
      filterEfficiency: Double, filterState: String, lightPenetrability: Double,
      metalsEmission: Double, metalsInWaste: Double, wasteType: String)

    // Hand-crafted fixture rows standing in for a real data source.
    val sampleRows = Seq(
      Waste("Stable", -1.93, 2.54, -3.9, "Intact", 1.5, 3.2, 0.665, "Industrial"),
      Waste("Stable", -1.78, 2.99, -3.2, "Intact", 2.17, 2.53, -0.496, "Household"),
      Waste("Stable", -2.26, 5.59, -0.5, "Defect", 0.609, 5.1, -0.411, "Household")
    )
    val testData = sc.parallelize(sampleRows)

    // Adapts Waste records to Bayes Server evidence: sets the discrete
    // "Waste type" state and the continuous "Light penetrability" reading.
    class WasteReader(val network: Network, val iterator: Iterator[Waste]) extends IteratorEvidenceReader[Waste] {
      val wasteType = network.getVariables.get("Waste type", true)
      val lightPenetrability = network.getVariables.get("Light penetrability", true)

      override def setEvidence(item: Waste, evidence: Evidence): Unit = {
        evidence.setState(wasteType.getStates.get(item.wasteType, true))
        evidence.set(lightPenetrability, item.lightPenetrability)
      }
    }

    // Request distribution + best state for the discrete target, mean and
    // variance for the continuous one, and the log-likelihood of each case.
    val predictions = Prediction.predict[Waste](
      network,
      testData,
      Seq(
        PredictVariable("Burning Regimen"),
        PredictState("Burning Regimen"),
        PredictVariable("CO2 concentration"),
        PredictVariance("CO2 concentration"),
        PredictLogLikelihood()
      ),
      (net, it) => new WasteReader(net, it),
      licenseKey)

    predictions.foreach(println)
  }
}