ステートフルなアプリケーションの例
この記事には、カスタム ステートフル アプリケーションのコード例が含まれています。 Databricks では、集計や結合などの一般的な操作に組み込みのステートフル メソッドを使用することをお勧めします。
この記事のパターンでは、Databricks Runtime 16.2 以降のパブリック プレビューで使用できる transformWithState
演算子と関連クラスを使用します。「カスタム ステートフル アプリケーションの構築」を参照してください。
注:
Python は transformWithStateInPandas
演算子を使用して同じ機能を提供します。 以下の例は、Python と Scala のコードを示しています。
要件
transformWithState
演算子と関連する APIs およびクラスには、次の要件があります。
Databricks Runtime 16.2 以降で使用できます。
コンピュートは、専用アクセスモードまたは非分離アクセスモードを使用する必要があります。
RocksDB状態ストア プロバイダーを使用する必要があります。Databricks では、コンピュート構成の一部として RocksDB を有効にすることをお勧めします。
注:
現在のセッションで RocksDB 状態ストア プロバイダーを有効にするには、次のコマンドを実行します。
spark.conf.set("spark.sql.streaming.stateStore.providerClass", "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")
緩やかに変化する寸法 (SCD) タイプ1
次のコードは、 transformWithState
を使用して SCD タイプ 1 を実装する例です。 SCD タイプ 1 は、特定のフィールドの最新の値のみを追跡します。
注:
ストリーミング テーブルとストリーミング APPLY CHANGES INTO
を使用して、Delta Lake でサポートされるテーブルを使用して SCD タイプ 1 またはタイプ 2 を実装できます。 この例では SCD 状態ストアにタイプ 1 を実装し、リアルタイムに近いアプリケーションの待機時間を短縮します。
import pandas as pd
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import StructType, StructField, LongType, StringType
from typing import Iterator
spark.conf.set("spark.sql.streaming.stateStore.providerClass","org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")
output_schema = StructType([
StructField("user", StringType(), True),
StructField("time", LongType(), True),
StructField("location", StringType(), True)
])
class SCDType1StatefulProcessor(StatefulProcessor):
def init(self, handle: StatefulProcessorHandle) -> None:
value_state_schema = StructType([
StructField("user", StringType(), True),
StructField("time", LongType(), True),
StructField("location", StringType(), True)
])
self.latest_location = handle.getValueState("latestLocation", value_state_schema)
def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]:
max_row = None
max_time = float('-inf')
for pdf in rows:
for _, pd_row in pdf.iterrows():
time_value = pd_row["time"]
if time_value > max_time:
max_time = time_value
max_row = tuple(pd_row)
exists = self.latest_location.exists()
if not exists or max_row[1] > self.latest_location.get()[1]:
self.latest_location.update(max_row)
yield pd.DataFrame(
{"user": (max_row[0],), "time": (max_row[1],), "location": (max_row[2],)}
)
yield pd.DataFrame()
def close(self) -> None:
pass
(df.groupBy("user")
.transformWithStateInPandas(
statefulProcessor=SCDType1StatefulProcessor(),
outputStructType=output_schema,
outputMode="Update",
timeMode="None",
)
.writeStream...
)
case class UserLocation(
user: String,
time: Long,
location: String)
class SCDType1StatefulProcessor extends StatefulProcessor[String, UserLocation, UserLocation] {
import org.apache.spark.sql.{Encoders}
@transient private var _latestLocation: ValueState[UserLocation] = _
override def init(
outputMode: OutputMode,
timeMode: TimeMode): Unit = {
_latestLocation = getHandle.getValueState[UserLocation]("locationState",
Encoders.product[UserLocation], TTLConfig.NONE)
}
override def handleInputRows(
key: String,
inputRows: Iterator[UserLocation],
timerValues: TimerValues): Iterator[UserLocation] = {
val maxNewLocation = inputRows.maxBy(_.time)
if (_latestLocation.getOption().isEmpty || maxNewLocation.time > _latestLocation.get().time) {
_latestLocation.update(maxNewLocation)
Iterator.single(maxNewLocation)
} else {
Iterator.empty
}
}
}
ダウンタイム検出器
transformWithState
タイマーを実装して、特定のキーのレコードがマイクロバッチで処理されない場合でも、経過時間に基づいてアクションを実行できるようにします。
次の例では、ダウンタイム検出機能のパターンを実装します。 特定のキーに新しい値が表示されるたびに、 lastSeen
状態値が更新され、既存のタイマーがクリアされ、将来のタイマーがリセットされます。
タイマーの期限が切れると、アプリケーションはキーについて最後に観測されたイベントからの経過時間を出力します。 その後、新しいタイマーを設定して、10 秒後に更新を出力します。
import datetime
import time
class DownTimeDetectorStatefulProcessor(StatefulProcessor):
def init(self, handle: StatefulProcessorHandle) -> None:
state_schema = StructType([StructField("value", TimestampType(), True)])
self.handle = handle
self.last_seen = handle.getValueState("last_seen", state_schema)
def handleExpiredTimer(self, key, timerValues, expiredTimerInfo) -> Iterator[pd.DataFrame]:
latest_from_existing = self.last_seen.get()
downtime_duration = timerValues.getCurrentProcessingTimeInMs() - int(time.time() * 1000)
self.handle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + 10000)
yield pd.DataFrame(
{
"id": key,
"timeValues": str(downtime_duration),
}
)
def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
max_row = max((tuple(pdf.iloc[0]) for pdf in rows), key=lambda row: row[1])
if self.last_seen.exists():
latest_from_existing = self.last_seen.get()
else:
latest_from_existing = datetime.fromtimestamp(0)
if latest_from_existing < max_row[1]:
for timer in self.handle.listTimers():
self.handle.deleteTimer(timer)
self.last_seen.update((max_row[1],))
self.handle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + 5000)
timestamp_in_millis = str(timerValues.getCurrentProcessingTimeInMs())
yield pd.DataFrame({"id": key, "timeValues": timestamp_in_millis})
def close(self) -> None:
pass
import java.sql.Timestamp
import org.apache.spark.sql.Encoders
// The (String, Timestamp) schema represents an (id, time). We want to do downtime
// detection on every single unique sensor, where each sensor has a sensor ID.
class DowntimeDetector(duration: Duration) extends
StatefulProcessor[String, (String, Timestamp), (String, Duration)] {
@transient private var _lastSeen: ValueState[Timestamp] = _
override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
_lastSeen = getHandle.getValueState[Timestamp]("lastSeen", Encoders.TIMESTAMP, TTLConfig.NONE)
}
// The logic here is as follows: find the largest timestamp seen so far. Set a timer for
// the duration later.
override def handleInputRows(
key: String,
inputRows: Iterator[(String, Timestamp)],
timerValues: TimerValues): Iterator[(String, Duration)] = {
val latestRecordFromNewRows = inputRows.maxBy(_._2.getTime)
// Use getOrElse to initiate state variable if it doesn't exist
val latestTimestampFromExistingRows = _lastSeen.getOption().getOrElse(new Timestamp(0))
val latestTimestampFromNewRows = latestRecordFromNewRows._2
if (latestTimestampFromNewRows.after(latestTimestampFromExistingRows)) {
// Cancel the one existing timer, since we have a new latest timestamp.
// We call "listTimers()" just because we don't know ahead of time what
// the timestamp of the existing timer is.
getHandle.listTimers().foreach(timer => getHandle.deleteTimer(timer))
_lastSeen.update(latestTimestampFromNewRows)
// Use timerValues to schedule a timer using processing time.
getHandle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + duration.toMillis)
} else {
// No new latest timestamp, so no need to update state or set a timer.
}
Iterator.empty
}
override def handleExpiredTimer(
key: String,
timerValues: TimerValues,
expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, Duration)] = {
val latestTimestamp = _lastSeen.get()
val downtimeDuration = new Duration(
timerValues.getCurrentProcessingTimeInMs() - latestTimestamp.getTime)
// Register another timer that will fire in 10 seconds.
// Timers can be registered anywhere but init()
getHandle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + 10000)
Iterator((key, downtimeDuration))
}
}
既存の状態情報を移行する
次の例は、初期状態を受け入れるステートフル アプリケーションを実装する方法を示しています。 初期状態処理は任意のステートフル アプリケーションに追加できますが、初期状態はアプリケーションを最初に初期化するときにのみ設定できます。
この例では、 statestore
リーダーを使用して、チェックポイント パスから既存の状態情報を読み込みます。 このパターンの使用例として、従来のステートフルアプリケーションから transformWithState
への移行があります。
import pandas as pd
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import StructType, StructField, LongType, StringType
from typing import Iterator
spark.conf.set("spark.sql.streaming.stateStore.providerClass","org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")
"""
Input schema is as below
input_schema = StructType(
[StructField("id", StringType(), True)],
[StructField("value", StringType(), True)]
)
"""
output_schema = StructType([
StructField("id", StringType(), True),
StructField("accumulated", StringType(), True)
])
class AccumulatedCounterStatefulProcessorWithInitialState(StatefulProcessor):
def init(self, handle: StatefulProcessorHandle) -> None:
state_schema = StructType([StructField("value", IntegerType(), True)])
self.counter_state = handle.getValueState("counter_state", state_schema)
self.handle = handle
def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
exists = self.counter_state.exists()
if exists:
value_row = self.counter_state.get()
existing_value = value_row[0]
else:
existing_value = 0
accumulated_value = existing_value
for pdf in rows:
value = pdf["value"].astype(int).sum()
accumulated_value += value
self.counter_state.update((accumulated_value,))
yield pd.DataFrame({"id": key, "accumulated": str(accumulated_value)})
def handleInitialState(self, key, initialState, timerValues) -> None:
init_val = initialState.at[0, "initVal"]
self.counter_state.update((init_val,))
def close(self) -> None:
pass
initial_state = spark.read.format("statestore")
.option("path", "$checkpointsDir")
.load()
df.groupBy("id")
.transformWithStateInPandas(
statefulProcessor=AccumulatedCounterStatefulProcessorWithInitialState(),
outputStructType=output_schema,
outputMode="Update",
timeMode="None",
initialState=initial_state,
)
.writeStream...
import org.apache.spark.sql.streaming._
import org.apache.spark.sql.{Dataset, Encoder, Encoders , DataFrame}
import org.apache.spark.sql.types._
class InitialStateStatefulProcessor extends StatefulProcessorWithInitialState[String, (String, String, String), (String, String), (String, Int)] {
@transient protected var valueState: ValueState[Int] = _
override def init(
outputMode: OutputMode,
timeMode: TimeMode): Unit = {
valueState = getHandle.getValueState[Int]("valueState",
Encoders.scalaInt, TTLConfig.NONE)
}
override def handleInputRows(
key: String,
inputRows: Iterator[(String, String, String)],
timerValues: TimerValues): Iterator[(String, String)] = {
var existingValue = 0
if (valueState.exists()) {
existingValue += valueState.get()
}
var accumulatedValue = existingValue
for (row <- inputRows) {
accumulatedValue += row._2.toInt
}
valueState.update(accumulatedValue)
Iterator((key, accumulatedValue.toString))
}
override def handleInitialState(
key: String, initialState: (String, Int), timerValues: TimerValues): Unit = {
valueState.update(initialState._2)
}
}