Time Series Prediction with Apache Spark

// --------------------------------------------------------------------------------------------------------------------
// <copyright file="TimeSeriesPredictions.scala" company="Bayes Server">
// Copyright (C) Bayes Server. All rights reserved.
// </copyright>
// --------------------------------------------------------------------------------------------------------------------

package com.bayesserver.spark.examples.inference

import com.bayesserver.{License, Network}
import org.apache.spark.SparkContext
import com.bayesserver.inference.Evidence
import com.bayesserver.spark.core._
import com.bayesserver.spark.core.PredictVariable
import com.bayesserver.spark.core.PredictVariance
import scala.Some
import TimeMode._

/**
 * Demonstrates Time Series predictions using Bayes Server on Apache Spark.
 */
object TimeSeriesPredictions {

  /**
   * Demonstrate making time series predictions.
   *
   * @param path       The path to the sample time series network included with Bayes Server
   *                   called 'Walkthrough 3 - Time series network.bayes'
   * @param sc         The SparkContext (see Apache Spark user guide for help on this)
   * @param licenseKey Optional Bayes Server license key, validated only when supplied.
   */
  def apply(path: String, sc: SparkContext, licenseKey: Option[String] = None) = {

    // Only validate when a key was actually provided.
    licenseKey.foreach(key => License.validate(key))

    val timeSeriesNetwork = new Network()
    timeSeriesNetwork.load(path)

    // Defining case classes for RDD records is optional, but makes the example clearer.
    // A field can be declared Option[...] when missing data must be represented.
    case class Point(x1: Double, x2: Double)

    case class TimeSeries(values: Seq[Point])

    // Test data defined inline for the example; a real job would read data
    // that already lives on the cluster.
    val testData = sc.parallelize(
      Seq(
        TimeSeries(
          Seq(
            Point(-0.40524692794093076, 23.776762100584833),
            Point(4.7122694935148735, 19.77962797774531),
            Point(2.9686471004503421, 17.903935430472625),
            Point(5.520077143425917, 15.741589336983196),
            Point(10.566675366294405, 14.71857740776503)
          )
        ),
        TimeSeries(
          Seq(
            Point(5.65151101171709, 16.250092854775264),
            Point(7.8396884451200757, 14.25030395410486),
            Point(9.1972154914500184, 11.07714258018922),
            Point(11.242423644030387, 10.199447302688682),
            Point(12.372335188920115, 10.155607834998676)
          )
        )
      ))

    /**
     * Converts each TimeSeries record into temporal evidence on variables X1 and X2.
     */
    class TimeSeriesReader(val network: Network, val iterator: Iterator[TimeSeries])
      extends IteratorEvidenceReader[TimeSeries] {

      val x1 = network.getVariables.get("X1", true)
      val x2 = network.getVariables.get("X2", true)

      override def setEvidence(item: TimeSeries, evidence: Evidence): Unit = {
        // A point's position in the sequence is used as its time index.
        for ((point, time) <- item.values.view.zipWithIndex) {
          evidence.set(x1, point.x1, time)
          evidence.set(x2, point.x2, time)
        }
      }
    }

    // Make time series predictions (mean, variance and log-likelihood)
    // at absolute future times 5 and 6.
    val predictions = Prediction.predict[TimeSeries](
      timeSeriesNetwork,
      testData,
      Seq(
        PredictVariable("X1", Some(PredictTime(5, Absolute))), PredictVariance("X1", Some(PredictTime(5, Absolute))),
        PredictVariable("X2", Some(PredictTime(5, Absolute))), PredictVariance("X2", Some(PredictTime(5, Absolute))),
        PredictVariable("X1", Some(PredictTime(6, Absolute))), PredictVariance("X1", Some(PredictTime(6, Absolute))),
        PredictVariable("X2", Some(PredictTime(6, Absolute))), PredictVariance("X2", Some(PredictTime(6, Absolute))),
        PredictLogLikelihood() // this value can be used for Time Series anomaly detection
      ),
      (network, iterator) => new TimeSeriesReader(network, iterator),
      licenseKey)

    predictions.foreach(println)
  }

}