ステートフルなアプリケーションの例

この記事には、カスタム ステートフル アプリケーションのコード例が含まれています。 Databricks では、集計や結合などの一般的な操作に組み込みのステートフル メソッドを使用することをお勧めします。

この記事のパターンでは、Databricks Runtime 16.2 以降のパブリック プレビューで使用できる transformWithState 演算子と関連クラスを使用します。「カスタム ステートフル アプリケーションの構築」を参照してください。

注:

Python は transformWithStateInPandas 演算子を使用して同じ機能を提供します。 以下の例は、Python と Scala のコードを示しています。

要件

transformWithState演算子と関連する APIs およびクラスには、次の要件があります。

  • Databricks Runtime 16.2 以降で使用できます。

  • コンピュートは、専用アクセスモードまたは非分離アクセスモードを使用する必要があります。

  • RocksDB状態ストア プロバイダーを使用する必要があります。Databricks では、コンピュート構成の一部として RocksDB を有効にすることをお勧めします。

注:

現在のセッションで RocksDB 状態ストア プロバイダーを有効にするには、次のコマンドを実行します。

spark.conf.set("spark.sql.streaming.stateStore.providerClass", "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")

緩やかに変化する寸法 (SCD) タイプ1

次のコードは、 transformWithStateを使用して SCD タイプ 1 を実装する例です。 SCD タイプ 1 は、特定のフィールドの最新の値のみを追跡します。

注:

ストリーミング テーブルとストリーミング APPLY CHANGES INTO を使用して、Delta Lake でサポートされるテーブルを使用して SCD タイプ 1 またはタイプ 2 を実装できます。 この例では SCD 状態ストアにタイプ 1 を実装し、リアルタイムに近いアプリケーションの待機時間を短縮します。

import pandas as pd
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import StructType, StructField, LongType, StringType
from typing import Iterator
spark.conf.set("spark.sql.streaming.stateStore.providerClass","org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")


output_schema = StructType([
    StructField("user", StringType(), True),
    StructField("time", LongType(), True),
    StructField("location", StringType(), True)
])

class SCDType1StatefulProcessor(StatefulProcessor):
  def init(self, handle: StatefulProcessorHandle) -> None:
    value_state_schema = StructType([
        StructField("user", StringType(), True),
        StructField("time", LongType(), True),
        StructField("location", StringType(), True)
    ])
    self.latest_location = handle.getValueState("latestLocation", value_state_schema)

  def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]:
    max_row = None
    max_time = float('-inf')
    for pdf in rows:
      for _, pd_row in pdf.iterrows():
        time_value = pd_row["time"]
        if time_value > max_time:
            max_time = time_value
            max_row = tuple(pd_row)
    exists = self.latest_location.exists()
    if not exists or max_row[1] > self.latest_location.get()[1]:
      self.latest_location.update(max_row)
      yield pd.DataFrame(
              {"user": (max_row[0],), "time": (max_row[1],), "location": (max_row[2],)}
          )
    yield pd.DataFrame()

  def close(self) -> None:
    pass


(df.groupBy("user")
  .transformWithStateInPandas(
      statefulProcessor=SCDType1StatefulProcessor(),
      outputStructType=output_schema,
      outputMode="Update",
      timeMode="None",
  )
  .writeStream...
)
case class UserLocation(
    user: String,
    time: Long,
    location: String)

class SCDType1StatefulProcessor extends StatefulProcessor[String, UserLocation, UserLocation] {
  import org.apache.spark.sql.{Encoders}

  @transient private var _latestLocation: ValueState[UserLocation] = _

  override def init(
      outputMode: OutputMode,
      timeMode: TimeMode): Unit = {
    _latestLocation = getHandle.getValueState[UserLocation]("locationState",
      Encoders.product[UserLocation], TTLConfig.NONE)
  }

  override def handleInputRows(
      key: String,
      inputRows: Iterator[UserLocation],
      timerValues: TimerValues): Iterator[UserLocation] = {
    val maxNewLocation = inputRows.maxBy(_.time)
    if (_latestLocation.getOption().isEmpty || maxNewLocation.time > _latestLocation.get().time) {
      _latestLocation.update(maxNewLocation)
      Iterator.single(maxNewLocation)
    } else {
      Iterator.empty
    }
  }
}

ダウンタイム検出器

transformWithState タイマーを実装して、特定のキーのレコードがマイクロバッチで処理されない場合でも、経過時間に基づいてアクションを実行できるようにします。

次の例では、ダウンタイム検出機能のパターンを実装します。 特定のキーに新しい値が表示されるたびに、 lastSeen 状態値が更新され、既存のタイマーがクリアされ、将来のタイマーがリセットされます。

タイマーの期限が切れると、アプリケーションはキーについて最後に観測されたイベントからの経過時間を出力します。 その後、新しいタイマーを設定して、10 秒後に更新を出力します。

import datetime
import time

class DownTimeDetectorStatefulProcessor(StatefulProcessor):
    def init(self, handle: StatefulProcessorHandle) -> None:
        state_schema = StructType([StructField("value", TimestampType(), True)])
        self.handle = handle
        self.last_seen = handle.getValueState("last_seen", state_schema)

    def handleExpiredTimer(self, key, timerValues, expiredTimerInfo) -> Iterator[pd.DataFrame]:
        latest_from_existing = self.last_seen.get()
        downtime_duration = timerValues.getCurrentProcessingTimeInMs() - int(time.time() * 1000)
        self.handle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + 10000)
        yield pd.DataFrame(
            {
                "id": key,
                "timeValues": str(downtime_duration),
            }
        )

    def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
        max_row = max((tuple(pdf.iloc[0]) for pdf in rows), key=lambda row: row[1])
        if self.last_seen.exists():
            latest_from_existing = self.last_seen.get()
        else:
            latest_from_existing = datetime.fromtimestamp(0)

        if latest_from_existing < max_row[1]:
            for timer in self.handle.listTimers():
                self.handle.deleteTimer(timer)
            self.last_seen.update((max_row[1],))

        self.handle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + 5000)

        timestamp_in_millis = str(timerValues.getCurrentProcessingTimeInMs())

        yield pd.DataFrame({"id": key, "timeValues": timestamp_in_millis})

    def close(self) -> None:
        pass
import java.sql.Timestamp
import org.apache.spark.sql.Encoders

// The (String, Timestamp) schema represents an (id, time). We want to do downtime
// detection on every single unique sensor, where each sensor has a sensor ID.
class DowntimeDetector(duration: Duration) extends
  StatefulProcessor[String, (String, Timestamp), (String, Duration)] {

  @transient private var _lastSeen: ValueState[Timestamp] = _

  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
    _lastSeen = getHandle.getValueState[Timestamp]("lastSeen", Encoders.TIMESTAMP, TTLConfig.NONE)
  }

  // The logic here is as follows: find the largest timestamp seen so far. Set a timer for
  // the duration later.
  override def handleInputRows(
      key: String,
      inputRows: Iterator[(String, Timestamp)],
      timerValues: TimerValues): Iterator[(String, Duration)] = {
    val latestRecordFromNewRows = inputRows.maxBy(_._2.getTime)

    // Use getOrElse to initiate state variable if it doesn't exist
    val latestTimestampFromExistingRows = _lastSeen.getOption().getOrElse(new Timestamp(0))
    val latestTimestampFromNewRows = latestRecordFromNewRows._2

    if (latestTimestampFromNewRows.after(latestTimestampFromExistingRows)) {
      // Cancel the one existing timer, since we have a new latest timestamp.
      // We call "listTimers()" just because we don't know ahead of time what
      // the timestamp of the existing timer is.
      getHandle.listTimers().foreach(timer => getHandle.deleteTimer(timer))

      _lastSeen.update(latestTimestampFromNewRows)
      // Use timerValues to schedule a timer using processing time.
      getHandle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + duration.toMillis)
    } else {
      // No new latest timestamp, so no need to update state or set a timer.
    }

    Iterator.empty
  }

  override def handleExpiredTimer(
    key: String,
    timerValues: TimerValues,
    expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, Duration)] = {
      val latestTimestamp = _lastSeen.get()
      val downtimeDuration = new Duration(
        timerValues.getCurrentProcessingTimeInMs() - latestTimestamp.getTime)

      // Register another timer that will fire in 10 seconds.
      // Timers can be registered anywhere but init()
      getHandle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + 10000)

      Iterator((key, downtimeDuration))
  }
}

既存の状態情報を移行する

次の例は、初期状態を受け入れるステートフル アプリケーションを実装する方法を示しています。 初期状態処理は任意のステートフル アプリケーションに追加できますが、初期状態はアプリケーションを最初に初期化するときにのみ設定できます。

この例では、 statestore リーダーを使用して、チェックポイント パスから既存の状態情報を読み込みます。 このパターンの使用例として、従来のステートフルアプリケーションから transformWithStateへの移行があります。

import pandas as pd
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import StructType, StructField, LongType, StringType
from typing import Iterator
spark.conf.set("spark.sql.streaming.stateStore.providerClass","org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")

"""
Input schema is as below

input_schema = StructType(
    [StructField("id", StringType(), True)],
    [StructField("value", StringType(), True)]
)
"""

output_schema = StructType([
    StructField("id", StringType(), True),
    StructField("accumulated", StringType(), True)
])

class AccumulatedCounterStatefulProcessorWithInitialState(StatefulProcessor):

    def init(self, handle: StatefulProcessorHandle) -> None:
        state_schema = StructType([StructField("value", IntegerType(), True)])
        self.counter_state = handle.getValueState("counter_state", state_schema)
        self.handle = handle

    def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
        exists = self.counter_state.exists()
        if exists:
            value_row = self.counter_state.get()
            existing_value = value_row[0]
        else:
            existing_value = 0

        accumulated_value = existing_value

        for pdf in rows:
            value = pdf["value"].astype(int).sum()
            accumulated_value += value

        self.counter_state.update((accumulated_value,))

        yield pd.DataFrame({"id": key, "accumulated": str(accumulated_value)})

    def handleInitialState(self, key, initialState, timerValues) -> None:
        init_val = initialState.at[0, "initVal"]
        self.counter_state.update((init_val,))

    def close(self) -> None:
        pass

initial_state = spark.read.format("statestore")
  .option("path", "$checkpointsDir")
  .load()

df.groupBy("id")
  .transformWithStateInPandas(
      statefulProcessor=AccumulatedCounterStatefulProcessorWithInitialState(),
      outputStructType=output_schema,
      outputMode="Update",
      timeMode="None",
      initialState=initial_state,
  )
  .writeStream...
import org.apache.spark.sql.streaming._
import org.apache.spark.sql.{Dataset, Encoder, Encoders , DataFrame}
import org.apache.spark.sql.types._

class InitialStateStatefulProcessor extends StatefulProcessorWithInitialState[String, (String, String, String), (String, String), (String, Int)] {
  @transient protected var valueState: ValueState[Int] = _

  override def init(
      outputMode: OutputMode,
      timeMode: TimeMode): Unit = {
    valueState = getHandle.getValueState[Int]("valueState",
      Encoders.scalaInt, TTLConfig.NONE)
  }

  override def handleInputRows(
      key: String,
      inputRows: Iterator[(String, String, String)],
      timerValues: TimerValues): Iterator[(String, String)] = {
    var existingValue = 0
    if (valueState.exists()) {
      existingValue += valueState.get()
    }
    var accumulatedValue = existingValue
    for (row <- inputRows) {
      accumulatedValue += row._2.toInt
    }
    valueState.update(accumulatedValue)
    Iterator((key, accumulatedValue.toString))
  }

  override def handleInitialState(
      key: String, initialState: (String, Int), timerValues: TimerValues): Unit = {
    valueState.update(initialState._2)
  }
}