论文解读:Stellaris — 基于 Serverless 的 Staleness-Aware 分布式强化学习
解读 SC 2024 论文,Stellaris 通过 Staleness-Aware 策略结合 Serverless 计算,优化分布式强化学习的训练效率。
论文: Stellaris: Staleness-Aware Distributed Reinforcement Learning with Serverless Computing (SC 2024, Hanfei Yu et al.) 机构: Stevens Institute of Technology, Northeastern University, Stony Brook University, Missouri S&T 代码: https://github.com/IntelliSys-Lab/Stellaris-SC24
1. 一句话总结
Stellaris 是第一个将 异步多 Learner 与 Serverless 计算 结合起来的通用分布式深度强化学习 (DRL) 训练范式 -- 它通过 importance sampling 截断、staleness-aware 梯度聚合和按需 Learner 编排三大核心技术, 在 AWS EC2 与 HPC 集群上实现了最高 2.2 倍的最终奖励提升和 41% 的训练成本降低.
2. 研究动机与要解决的问题
2.1 分布式 DRL 的现状痛点
分布式 DRL 普遍采用 Actor-Learner 架构: Actor 与环境交互收集轨迹, Learner 计算梯度更新策略. 现有方案 (RLlib, MSRL, SEED RL, IMPALA 等) 几乎都依赖 同步 Learner + serverful 基础设施, 存在以下问题:
- 资源利用率低: VM/物理机长期占用, idle 资源浪费严重;
- 扩展性差: Learner/Actor 数量固定, 无法动态适配训练负载;
- 成本高: 无法按需付费.
2.2 为什么 Serverless + 异步 Learner 是好的方向?
Serverless 天然支持:
- 弹性伸缩: 按需启停函数实例, 资源利用率高;
- 按需付费: 只为实际计算时间付费, 成本更低;
- 事件驱动: 天然适合异步学习范式.
论文通过 PPO + Hopper 的实验 (Fig. 2) 直观展示: 同时启用异步学习和 serverless 后, 训练速度更快、成本更低.
2.3 Serverless 异步 DRL 的三大挑战
但是直接把 serverful 多 Learner 或异步学习搬到 serverless 上会遇到三个新问题:
| 挑战 | 具体描述 |
|---|---|
| (1) Dynamic Learner Orchestration | Actor 数量动态变化导致 Learner 需按需伸缩, 传统固定 Learner 数方案无法兼顾 GPU 利用率与训练速度 (Fig. 3a) |
| (2) Dynamic Staleness | 异步 Learner 各自持有不同版本的策略, 产生的梯度存在 staleness, 且 staleness 分布随 Learner 数量增多而加剧 (Fig. 3b) |
| (3) Unstable Policy Updates | 多个 Learner 各自用本地策略计算 importance sampling ratio, 但在聚合时缺乏全局视角, 导致跨 Learner 策略漂移 (cross-learner policy drift), KL 散度剧烈波动 (Fig. 3c) |
3. 系统架构总览
Stellaris 包含 四大组件, 如论文 Fig. 4 所示:
┌─────────────────────────────────────────────────┐
│ Distributed Cache (Redis) │
│ 存储: 轨迹数据, 梯度, 策略模型权重 │
└──────┬──────────────┬──────────────┬─────────────┘
│ │ │
┌────▼────┐ ┌─────▼──────┐ ┌───▼────┐
│ Actors │ │ Learner │ │Parameter│
│(serverful│ │ Functions │ │Function │
│/serverless)│ │(serverless)│ │(serverless)│
└─────────┘ └────────────┘ └──────────┘
- Parameter Function (参数函数): 控制 staleness, 聚合梯度, 更新全局策略;
- Learner Functions (学习器函数): 按需启停的 serverless 函数, 从 Cache 拉取轨迹并计算梯度;
- Actors (执行器): 与 RL 环境交互, 采集轨迹 (支持 serverful 和 serverless 两种模式);
- Distributed Cache (Redis): 内存键值存储, 作为各组件间的数据中枢.
工作流程 (三步循环)
Step 1 -- Importance Sampling 驱动的轨迹采集: Actor 从 Cache 拉取最新策略 -> 与环境交互 -> 轨迹提交回 Cache.
Step 2 -- 按需梯度计算: 当 Cache 中出现新 batch 时, Stellaris 按需启动一批 Learner 函数并发计算梯度, 完成后提交回 Cache.
Step 3 -- Staleness-aware 梯度聚合: Parameter Function 监控 Cache 中的梯度, 评估 staleness, 当 staleness 低于阈值时才执行聚合和策略更新.
4. 核心技术详解
4.1 Importance Sampling Truncated Trajectory Processing (IS 截断轨迹处理)
问题本质: 在同步 Learner 设定下, 所有 Learner 共享同一策略 $\pi_\theta$, 只需限制 learner-actor 之间的 importance sampling ratio $|\frac{\pi_\theta}{\mu_\theta}|$ 即可. 但在异步设定下, 每个 Learner 持有不同策略 $\pi_{\theta_1} \neq \pi_{\theta_2} \neq \ldots \neq \pi_{\theta_n}$, 即使每个 Learner 本地 clip 了自己的 ratio, 聚合时跨 Learner 的 ratio 集合 ${\frac{\pi_{\theta_1}}{\mu_\theta}, \ldots, \frac{\pi_{\theta_n}}{\mu_\theta}}$ 仍然可能爆炸 (Fig. 5a).
解决方案 -- 全局 IS 截断:
$$R' := \min\left(\left|\min_i\left(\frac{\pi_{\theta_i}}{\mu_\theta}\right)\right|, \rho\right), \quad i \in {1, \ldots, n}$$
其中 $\rho$ 是 clip 阈值 (实验中设为 1.0). 这个截断取的是所有 Learner 策略与 Actor 策略之比的最小值, 再与 $\rho$ 取 min, 从而:
- 当任意一个跨 Learner ratio 偏离过大时, 全局 ratio 会被拉回;
- 有效防止 cross-learner policy drift.
对应的策略梯度变为:
$$\nabla J(\pi_\theta) = \mathbb{E}t\left[\mathbb{E}{\tau_t \sim \mu_\theta}\left[\min\left(\left|\min_i\left(\frac{\pi_{\theta_i}}{\mu_\theta}\right)\right|, \rho\right) A_t\right]\right]$$
论文证明了这个截断保持了单调奖励提升的下界 (Theorem 2).
4.2 Staleness-Aware Gradient Aggregation (staleness 感知梯度聚合)
问题本质: 传统异步方法 (如 Softsync, SSP) 假设 worker 数量固定, 用静态 staleness bound. 但 serverless Learner 数量是动态的, staleness 分布随之变化.
Stellaris 的动态阈值方案:
$$\beta_k := \delta_{\max} \times d^k, \quad d \in (0, 1]$$
- $\delta_{\max}$: 第一轮观察到的最大 staleness;
- $d$: 衰减因子 (实验中设为 0.96);
- $k$: 当前训练轮次.
直觉: 训练初期放宽阈值 (允许更多异步梯度参与聚合, 加速探索); 训练后期收紧阈值 (减少 stale 梯度, 保证收敛质量).
当梯度入队时, 检查队列的平均 staleness $\delta$ 是否低于 $\beta_k$:
- 若 $\delta < \beta_k$: 立即聚合;
- 若 $\delta \geq \beta_k$: 延迟等待更新鲜的梯度.
自适应学习率: 同时根据梯度的 staleness 调整学习率:
$$\alpha_c := \frac{\alpha_0}{\sqrt{\delta_c}}, \quad \text{if } \delta_c > 0$$
staleness 越大的梯度, 学习率越小, 减少其对策略更新的影响.
最终聚合公式:
$$g_c := \frac{1}{H_c} \sum_{i=1}^{H_c} \frac{\alpha_0}{\sqrt{\delta_j}} g_{i,j}, \quad \theta_{c+1} := \theta_c - g_c$$
4.3 On-Demand Gradient Calculation (按需梯度计算)
Stellaris 解决了三个 serverless Learner 的工程效率问题:
(1) GPU Data Loader: 一个轻量后台程序, 从 Redis 预取轨迹并加载到 GPU 显存, 解耦数据加载和梯度计算, 类似 serverless function pre-warming 思想.
(2) Learner Schemes: 当 Data Loader 发现新轨迹就绪时, 立即启动 Learner 函数, 传入策略权重指针和轨迹指针. Learner 计算完梯度后提交回 Cache 即终止 -- 完全事件驱动.
(3) Hierarchical Data Passing: 三级通信机制:
- Shared Memory: 同一物理机上的 Learner 间交换梯度和轨迹;
- RPC: 跨机器 Learner 远程通信;
- Distributed Cache (Redis): 持久化存储, 供所有组件访问.
5. 理论保证
5.1 收敛性 (Theorem 1)
使用 SGD 优化器, 在标准假设 (无偏梯度、有界方差、Lipschitz 平滑) 下:
$$\frac{1}{T}\sum_{t=1}^{T} \mathbb{E}(|\nabla J(\theta_t)|^2) \leq 2\sqrt{\frac{2C_1 C_2}{Tb}}$$
其中 $T$ 是策略更新步数, $b$ 是 mini-batch 大小. 收敛率为 $O(1/\sqrt{Tb})$, 与标准 SGD 一致 -- 说明 Stellaris 的 staleness-aware 聚合没有损害收敛速度, 且通过增大 $b$ 可获得近线性加速.
5.2 奖励提升下界 (Theorem 2)
对于任意 Learner 策略 $\pi_i$ 和 Actor 策略 $\mu$, 在 IS 截断下:
$$J(\pi_i) - J(\mu) \geq -\frac{\gamma \epsilon^{\pi_i} \sqrt{2\log\rho}}{(1-\gamma)^2}$$
其中 $\epsilon^{\pi_i} = \max_s |\mathbb{E}_{a \sim \pi}[A^\mu]|$, $\gamma$ 是折扣因子, $\rho$ 是截断阈值. 该下界保证了训练过程中的单调性能提升.
6. 实现细节
- 语言: Python, 约 5000 行代码;
- Serverless 实现: 在 AWS EC2 GPU 实例上, 使用 Docker 容器 + NVIDIA Container Runtime 实现自建 serverless 平台 (因主流 serverless 平台尚不支持 GPU);
- DRL 框架: 核心逻辑基于 PyTorch, 策略网络、staleness-aware 聚合、IS 截断均在其中实现;
- Cache: Redis 内存键值存储, Actor 用 Pickle 序列化轨迹上传, Learner/Parameter Function 从中读写;
- 兼容性: 已集成到 RLlib (替换默认 learner group) 和 MinionsRL (替换同步 learner 为异步 serverless learner);
- 算法支持: On-policy (PPO) 和 Off-policy (IMPACT) 均可, Actor 支持 serverful 和 serverless.
7. 实验评估
7.1 实验设置
| 项目 | 配置 |
|---|---|
| AWS EC2 集群 | 2x p3.2xlarge + 1x c6a.32xlarge, 共 2 块 V100 GPU, 128 AMD EPYC CPU |
| HPC 集群 | 2x p3.16xlarge (16 V100 GPU) + 5x hpc7a.96xlarge (960 CPU cores), Singularity 容器 |
| 环境 | 6 个 OpenAI Gym: MuJoCo (Hopper, Humanoid, Walker2d) + Atari (SpaceInvaders, Qbert, Gravitar) |
| 算法 | PPO (on-policy), IMPACT (off-policy) |
| Stellaris 参数 | $d=0.96$, $v=3$, $\rho=1.0$, 每 V100 最多 4 个 Learner, 每 CPU core 1 个 Actor |
7.2 主要结果
与原始算法对比 (PPO, IMPACT):
- Stellaris + PPO: 最终奖励提升最高 2.2x (Fig. 6);
- Stellaris + IMPACT: 最终奖励提升最高 1.3x (Fig. 7);
- 训练成本降低: PPO 最高 31%, IMPACT 最高 30% (Fig. 8).
与 DRL 框架对比 (RLlib, MinionsRL):
- Stellaris + RLlib: 最终奖励提升最高 1.3x, 成本降低最高 38% (Fig. 9);
- Stellaris + MinionsRL: 最终奖励提升最高 1.6x, 成本降低最高 41% (Fig. 10).
HPC 集群 (PAR-RL):
- Hopper: 最终奖励提升 2.4x, 成本降低 19%;
- Qbert: 最终奖励提升 1.1x, 成本降低 34% (Fig. 12).
7.3 Ablation Study
(a) 梯度聚合策略对比 (Fig. 11a): Stellaris vs Softsync vs SSP vs Pure Async:
- Pure Async 训练最快但收敛差;
- Stellaris 的动态阈值策略在累积奖励上全面领先.
(b) IS 截断的效果 (Fig. 11b):
- 去掉 IS 截断后训练出现明显的性能振荡和不稳定;
- 证实跨 Learner 的全局 IS 截断是稳定训练的关键.
7.4 参数敏感性 (Fig. 13)
| 参数 | 最优值 | 趋势 |
|---|---|---|
| Decay factor $d$ | 0.96 | 增大 $d$ -> 更高奖励但更高成本, 0.96 后奖励饱和 |
| LR smoothness $v$ | 3 | 太大导致学习率对 staleness 不敏感, 太小则过度抑制 |
| IS threshold $\rho$ | 1.0 | 太大允许策略偏离, 太小过于保守 |
7.5 延迟开销 (Fig. 14)
Stellaris 各组件的额外开销不超过单轮训练延迟的 5%, 证明其设计是轻量高效的.
8. 与现有工作的对比
| 框架 | 异步 Learner | 可扩展 Actor | On/Off-policy | Serverless |
|---|---|---|---|---|
| RLlib | x | x | Yes | x |
| MSRL | x | x | Yes | x |
| SEED RL | x | x | Yes | x |
| SRL | x | x | Yes | x |
| PQL | x | x | x | x |
| MinionsRL | x | Yes | x | Yes |
| Stellaris | Yes | Yes | Yes | Yes |
Stellaris 是唯一同时支持以上四项特性的框架. 与最相关的 MinionsRL 相比:
- MinionsRL 只有 serverless Actor + 同步单 Learner, 只支持 on-policy;
- Stellaris 支持异步多 Learner + serverful/serverless Actor + on/off-policy.
9. 个人点评
优势:
- 问题定义清晰: 三个挑战 (动态 Learner 编排、动态 staleness、不稳定策略更新) 层层递进, 解决方案一一对应;
- 理论+系统双重贡献: 不仅有收敛性和奖励下界的理论证明, 还有完整的系统实现和工程优化 (GPU data loader, hierarchical data passing);
- 全局 IS 截断思路巧妙: 通过取所有 Learner ratio 的最小值来做全局截断, 想法简单但非常有效;
- 广泛兼容: 能无缝集成到 RLlib 和 MinionsRL, 不需要大改已有框架;
- 实验充分: 6 个环境、4 个 baseline 框架、2 类算法、EC2 + HPC 两种平台.
可以进一步思考的方向:
- 当前 serverless GPU 平台 (如 AWS Lambda 尚不支持 GPU) 导致作者需要自建容器化 serverless 环境, 未来主流云平台支持 GPU serverless 后, Stellaris 的价值会更大;
- 论文主要测试了 PPO 和 IMPACT, 对更复杂的算法 (如 SAC, TD3, multi-agent RL) 的适用性值得探索;
- 动态 Learner 数量的调度策略目前依赖简单的事件驱动, 未来可以结合负载预测做更智能的 auto-scaling;
- Staleness 阈值的衰减因子 $d$ 是手动设定的超参, 有没有可能根据训练状态自适应调整.
10. 关键公式速查
| 公式 | 含义 |
|---|---|
| $R' = \min(|\min_i(\pi_{\theta_i}/\mu_\theta)|, \rho)$ | 全局 IS 截断, 防止跨 Learner 策略漂移 |
| $\beta_k = \delta_{\max} \cdot d^k$ | 动态 staleness 阈值, 随训练衰减 |
| $\alpha_c = \alpha_0 / \sqrt{\delta_c}$ | staleness 自适应学习率 |
| $g_c = \frac{1}{H_c}\sum_{i=1}^{H_c}\frac{\alpha_0}{\sqrt{\delta_j}}g_{i,j}$ | 加权梯度聚合 |
| $O(1/\sqrt{Tb})$ | 收敛率, 与标准 SGD 一致 |
| $J(\pi_i)-J(\mu) \geq -\frac{\gamma\epsilon^{\pi_i}\sqrt{2\log\rho}}{(1-\gamma)^2}$ | 奖励提升下界 |