// Time Series Prediction with Apache Spark

// --------------------------------------------------------------------------------------------------------------------
// <copyright file="TimeSeriesPredictions.scala" company="Bayes Server">
//   Copyright (C) Bayes Server.  All rights reserved.
// </copyright>
// --------------------------------------------------------------------------------------------------------------------

package com.bayesserver.spark.examples.inference

import com.bayesserver.{License, Network}
import org.apache.spark.SparkContext
import com.bayesserver.inference.Evidence
import com.bayesserver.spark.core._
import com.bayesserver.spark.core.PredictVariable
import com.bayesserver.spark.core.PredictVariance
import scala.Some
import TimeMode._

/**
 * Demonstrates Time Series predictions using Bayes Server on Apache Spark.
 */
object TimeSeriesPredictions {

  /**
   * Demonstrate making time series predictions.
   * @param path The path to the sample time series network included with Bayes Server called 'Walkthrough 3 - Time series network.bayes'
   * @param sc The SparkContext  (see Apache Spark user guide for help on this)
   * @param licenseKey An optional Bayes Server license key; when omitted the library runs in evaluation mode.
   */
  def apply(path: String, sc: SparkContext, licenseKey: Option[String] = None): Unit = {

    licenseKey.foreach(key => License.validate(key))

    val network = new Network()
    network.load(path)

    // Defining case classes in RDDs is optional, but makes the example clearer
    case class Point(x1: Double, x2: Double)

    case class TimeSeries(values: Seq[Point])

    // here we define some test data inline.  Normally you would use data on the cluster.
    // You can use Option[...] for data types if you want to represent missing data
    val testData = sc.parallelize(Seq(
      TimeSeries(
        Seq(
          Point(-0.40524692794093076, 23.776762100584833),
          Point(4.7122694935148735, 19.77962797774531),
          Point(2.9686471004503421, 17.903935430472625),
          Point(5.520077143425917, 15.741589336983196),
          Point(10.566675366294405, 14.71857740776503)
        )
      ),
      TimeSeries(
        Seq(
          Point(5.65151101171709, 16.250092854775264),
          Point(7.8396884451200757, 14.25030395410486),
          Point(9.1972154914500184, 11.07714258018922),
          Point(11.242423644030387, 10.199447302688682),
          Point(12.372335188920115, 10.155607834998676)
        )
      )))

    /**
     * Maps each TimeSeries record onto Bayes Server evidence, setting one
     * (x1, x2) observation per time slice.
     */
    class TimeSeriesReader(val network: Network, val iterator: Iterator[TimeSeries]) extends IteratorEvidenceReader[TimeSeries] {

      val x1 = network.getVariables.get("X1", true)
      val x2 = network.getVariables.get("X2", true)

      override def setEvidence(item: TimeSeries, evidence: Evidence): Unit = {

        // the zip index doubles as the (zero-based) time slice for each point
        item.values.view.zipWithIndex.foreach({
          case (point, time) =>
            evidence.set(x1, point.x1, time)
            evidence.set(x2, point.x2, time)
        })
      }
    }

    // make some time series predictions into the future
    // (mean and variance for X1 and X2 at absolute times 5 and 6 — one step
    // and two steps beyond the 5 observed slices — generated rather than
    // hand-duplicated to avoid copy-paste errors)
    val futureTimes = Seq(5, 6)

    val predictionRequests = futureTimes.flatMap(time =>
      Seq("X1", "X2").flatMap(name =>
        Seq(
          PredictVariable(name, Some(PredictTime(time, Absolute))),
          PredictVariance(name, Some(PredictTime(time, Absolute)))
        ))) :+ PredictLogLikelihood() // this value can be used for Time Series anomaly detection

    val predictions = Prediction.predict[TimeSeries](
      network,
      testData,
      predictionRequests,
      (network, iterator) => new TimeSeriesReader(network, iterator),
      licenseKey)

    // collect() brings the (small) result set back to the driver first;
    // a bare rdd.foreach(println) would print on the executors instead.
    predictions.collect().foreach(println)
  }

}