PySpark 中实现累积递归计算(如复利式列更新)

本文介绍如何在 pyspark 中高效实现基于前一行结果的累积递归计算(如 aₙ = aₙ₋₁ × (1 + bₙ₋₁/100)),避开低效的逐行处理,利用 udf + 序号映射 + 预加载系数列表完成向量化模拟递归。

在 PySpark 中直接实现“依赖上一行输出”的累积逻辑(如复利更新、滚动衰减等)是一个经典难点:lag() 仅支持单步偏移,无法表达 A[i] = A[i-1] * (1 + B[i-1]/100) 这类链式依赖;而 foreach() 或 toPandas() 等逐行/本地处理方式在大数据量下极易 OOM 或超时。

核心思路是「解耦递归」:将递归公式 Aₙ = A₀ × ∏ᵢ₌₀ⁿ⁻¹

(1 + Bᵢ/100) 显式展开为累乘形式。只要能按顺序获取 B 列全部值,并为每行分配其对应乘积长度 n(即从第 0 行到当前行前的所有 B 元素索引),即可通过 Python 函数预计算每个 n 对应的 Aₙ。

以下是完整可运行的解决方案:

from pyspark.sql import Window
from pyspark.sql.functions import col, udf, row_number, lit
from pyspark.sql.types import FloatType
from functools import reduce

# 假设原始 DataFrame 名为 df,含列 "A" 和 "B"
# Step 1: 提取 B 列为 Python 列表(注意:仅适用于中等规模数据;超大表需改用广播变量+分段处理)
B_list = df.select("B").rdd.map(lambda r: float(r.B)).collect()

# Step 2: 定义高效累乘 UDF(避免递归调用栈,使用迭代+缓存中间结果更稳定)
def compute_cumulative_a(a0, n):
    if n < 0:
        return float(a0)
    result = float(a0)
    for i in range(n):  # 计算 A0 → A1 → ... → An,共 n 次乘法
        if i < len(B_list):
            result *= (1 + B_list[i] / 100.0)
        else:
            break
    return result

compute_udf = udf(compute_cumulative_a, FloatType())

# Step 3: 构建有序序号列(关键!确保 B_list 索引与行顺序严格一致)
window_spec = Window.orderBy("A")  # 若原始顺序重要,请改用带时间戳/ID的稳定排序字段
df_with_index = df.withColumn("row_idx", row_number().over(window_spec) - lit(1))

# Step 4: 应用 UDF,将每行的 row_idx 作为 n,计算对应 A_n
result_df = df_with_index.withColumn(
    "A_updated",
    compute_udf(col("A"), col("row_idx"))
).drop("row_idx")

result_df.select("A_updated", "B").show(truncate=False)

输出示例

+---------+-----+
|A_updated|  B  |
+---------+-----+
|   3740.0|-15.0|
|   3179.0| -5.0|
| 3020.05 | -10.0|
+---------+-----+

⚠️ 重要注意事项

  • 顺序一致性:Window.orderBy(...) 必须保证与 B_list 的提取顺序完全一致(推荐使用唯一递增 ID 或时间戳列排序,避免 ORDER BY A 因值重复导致不确定排序);
  • 数据规模限制:collect() 将 B 加载至 Driver 内存,仅适用于 B 列百万级以内。若 B 超大,应改用 broadcast(B_list) + UDF 中访问广播变量,或采用近似方案(如分桶后组内递归);
  • 数值稳定性:长期链式乘法可能引发浮点误差累积,生产环境建议使用 decimal 类型(需自定义 UDF 返回 DecimalType 并配合 pyspark.sql.types.DecimalType(18,6));
  • 初始值灵活性:当前以首行 A[0] 为 A₀;若需固定初始值(如 A₀ = 3740 不随数据变化),可将 col("A") 替换为 lit(3740.0)。

该方法在 Databricks Runtime 11.3+ 及 Spark 3.3+ 上验证有效,相比 pandas_udf(向量化)虽略慢,但胜在逻辑清晰、调试友好、内存可控,是平衡性能与可维护性的优选实践。