MuJoCo XLA (MJX)#
从 3.0.0 版本开始,MuJoCo 在 mjx 目录下包含了 MuJoCo XLA (MJX)。MJX 允许 MuJoCo 通过 JAX 框架运行在 XLA 编译器支持的计算硬件上。MJX 支持 JAX 支持的所有平台:Nvidia 和 AMD GPU、Apple Silicon 以及 Google Cloud TPU。
MJX API 与 MuJoCo API 中的主要模拟函数保持一致,尽管目前缺少一些功能。虽然 API 文档 适用于这两个库,但我们在下面的 注意事项 中指出了 MJX 不支持的功能。
MJX 在 PyPI 上以名为 mujoco-mjx
的独立包分发。尽管它依赖主 mujoco
包进行模型编译和可视化,但它是 MuJoCo 的一个重新实现,使用与 MuJoCo 实现相同的算法。然而,为了充分利用 JAX,MJX 在一些地方有意地偏离了 MuJoCo API,详见下文。
MJX 是 Google Brax 物理和强化学习库中 通用物理管线 的继任者。MJX 由 MuJoCo 和 Brax 的核心贡献者共同构建,他们将一起继续支持 Brax(用于其强化学习算法和包含的环境)和 MJX(用于其物理算法)。Brax 的未来版本将依赖于 mujoco-mjx
包,并且 Brax 现有的 通用管线 将被弃用。这一变化对 Brax 用户来说基本是透明的。
教程 Notebook#
安装#
安装此包的推荐方法是通过 PyPI
pip install mujoco-mjx
MuJoCo 库的一个副本作为此包的依赖项提供,无需单独下载或安装。
基本用法#
安装后,可以通过 from mujoco import mjx
导入此包。结构体 (Structs)、函数 (functions) 和枚举 (enums) 可直接从顶层 mjx
模块获取。
结构体 (Structs)#
在加速器设备上运行 MJX 函数之前,必须通过 mjx.put_model
和 mjx.put_data
函数将结构体复制到设备上。将 mjModel 放置到设备上会得到一个 mjx.Model
。将 mjData 放置到设备上会得到一个 mjx.Data
。
model = mujoco.MjModel.from_xml_string("...")
data = mujoco.MjData(model)
mjx_model = mjx.put_model(model)
mjx_data = mjx.put_data(model, data)
这些 MJX 变体与其 MuJoCo 对应项相似,但有一些主要区别:
mjx.Model
和mjx.Data
包含复制到设备上的 JAX 数组。对于 MJX 中 不支持 的功能,
mjx.Model
和mjx.Data
中缺少一些字段。mjx.Model
和mjx.Data
中的 JAX 数组支持添加批次维度 (batch dimensions)。批次维度是表达领域随机化(针对mjx.Model
)或强化学习高吞吐量模拟(针对mjx.Data
)的自然方式。mjx.Model
和mjx.Data
中的 Numpy 数组是控制 JIT 编译输出的结构字段。修改这些数组将强制 JAX 重新编译 MJX 函数。例如,jnt_limited
是从 mjModel 引用传递的 numpy 数组,它决定是否应用关节限制约束。如果修改了jnt_limited
,JAX 将重新编译 MJX 函数。另一方面,jnt_range
是一个可以在运行时修改的 JAX 数组,它只会应用于jnt_limited
字段指定的有限制关节。
mjx.Model
和 mjx.Data
都不应手动构造。可以通过调用 mjx.make_data
创建 mjx.Data
,这与 MuJoCo 中的 mj_makeData 函数类似。
model = mujoco.MjModel.from_xml_string("...")
mjx_model = mjx.put_model(model)
mjx_data = mjx.make_data(model)
在 vmap
内部构造批次化 mjx.Data
结构时,使用 mjx.make_data
可能更可取。
函数 (Functions)#
MuJoCo 函数以同名的 MJX 函数暴露,但遵循 PEP 8 兼容命名。大部分 主要模拟 函数以及一些用于正向模拟的 子组件 可从顶层 mjx
模块获取。
MJX 函数默认不进行 JIT 编译——我们将 JIT MJX 函数或 JIT 引用 MJX 函数的用户自己的函数的任务留给用户。参见下面的 最小示例。
枚举 (Enums) 和常量 (constants)#
MJX 枚举可用 mjx.EnumType.ENUM_VALUE
访问,例如 mjx.JointType.FREE
。MJX 枚举声明中省略了 MJX 不支持功能的枚举。MJX 不声明常量,而是直接引用 MuJoCo 常量。
最小示例#
# Throw a ball at 100 different velocities.
import jax
import mujoco
from mujoco import mjx
XML=r"""
<mujoco>
<worldbody>
<body>
<freejoint/>
<geom size=".15" mass="1" type="sphere"/>
</body>
</worldbody>
</mujoco>
"""
model = mujoco.MjModel.from_xml_string(XML)
mjx_model = mjx.put_model(model)
@jax.vmap
def batched_step(vel):
mjx_data = mjx.make_data(mjx_model)
qvel = mjx_data.qvel.at[0].set(vel)
mjx_data = mjx_data.replace(qvel=qvel)
pos = mjx.step(mjx_model, mjx_data).qpos[0]
return pos
vel = jax.numpy.arange(0.0, 1.0, 0.01)
pos = jax.jit(batched_step)(vel)
print(pos)
有用的命令行脚本#
我们在 mujoco-mjx
包中提供了两个命令行脚本:
mjx-testspeed --mjcf=/PATH/TO/MJCF/ --base_path=.
此命令接收 MJCF 文件的路径以及可选参数(使用 --help
获取更多信息),并计算有助于性能调优的指标。该命令将输出总模拟时间、总步数每秒以及总实时因子等信息(这里的“总”是指所有可用设备的总计)。
mjx-viewer --help
此命令在模拟查看器中启动 MJX 模型,允许您可视化模型并与其交互。请注意,这使用 MJX 物理(而非 C MuJoCo)进行模拟步进,因此例如在调试求解器参数时非常有用。
功能对等性#
MJX 支持 MuJoCo 的大部分主要模拟功能,但有少数例外。如果要求 MJX 将包含引用不支持功能字段值的 mjModel 复制到设备上,MJX 将抛出异常。
MJX 完全支持以下功能:
类别 |
功能 |
---|---|
动力学 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1, 3, 4, 6 ( |
|
|
|
动力学 |
|
流体模型 |
|
|
以下功能正在开发中,即将推出:
类别 |
功能 |
---|---|
|
|
|
|
流体模型 |
|
除 |
|
光源 |
光源的位置和方向 |
以下功能不支持:
类别 |
功能 |
---|---|
未实现与 |
|
|
|
|
|
|
|
|
|
|
|
|
🔪 MJX - 注意事项 🔪#
GPU 和 TPU 具有 MJX 所受到的独特性能权衡。MJX 专门用于使用可在 SIMD 硬件 上高效矢量化的算法模拟大量并行相同的物理场景。这种专业化对于需要海量数据吞吐量的机器学习工作负载(例如 强化学习)非常有用。
MJX 不太适合某些工作流程:
- 单场景模拟
模拟单个场景(1 个 mjData 实例)时,MJX 可能比 MuJoCo 慢 10 倍,因为 MuJoCo 针对 CPU 进行了精心优化。MJX 在并行模拟数千或数万个场景时表现最佳。
- 大型网格之间的碰撞
MJX 支持凸网格几何体之间的碰撞。然而,MJX 中的凸碰撞算法实现方式与 MuJoCo 不同。MJX 使用 分离轴测试 (Separating Axis Test, SAT) 的无分支版本来确定几何体是否与凸网格发生碰撞,而 MuJoCo 使用 MPR 或 GJK/EPA,有关详细信息,请参阅 碰撞检测。SAT 对于小型网格效果很好,但对于大型网格来说在运行时和内存方面表现不佳。
对于与凸网格和基本体的碰撞,为了获得合理的性能,网格的凸分解应包含大约 200 个或更少的顶点。对于凸-凸碰撞,凸网格应包含大约 少于 32 个顶点。我们建议在 MuJoCo 编译器中使用 maxhullvert 来获得所需的凸网格属性。经过仔细调优,MJX 可以模拟包含网格碰撞的场景——有关示例,请参见 MJX 的 shadow hand 配置。加速网格碰撞检测是 MJX 当前活跃的开发领域。
- 具有许多接触点的大型复杂场景
加速器在处理 分支代码 时表现不佳。在宽阶段碰撞检测中会使用分支,即识别场景中大量物体之间的潜在碰撞。MJX 附带了一个简单的无分支宽阶段算法(参见性能调优),但它的功能不如 MuJoCo 中的强大。
为了了解这对模拟的影响,我们考虑一个包含不同数量人形身体(从 1 到 10)的物理场景。我们在 Apple M3 Max 和 64 核 AMD 3995WX 上使用 CPU MuJoCo 模拟该场景,并使用 testspeed 进行计时,使用
2 x numcore
线程。我们在批量大小为 8192 的 Nvidia A100 GPU 和批量大小为 16384 的 8 芯片 v5 TPU 机器上对 MJX 模拟进行计时。注意垂直刻度是对数刻度。四种计时架构中单个人形(最左侧数据点)的值分别为每秒 65 万、180 万、95 万 和 270 万 步。请注意,随着人形数量的增加(这会增加场景中潜在接触点的数量),MJX 的吞吐量下降速度比 MuJoCo 快。
性能调优#
为了让 MJX 表现良好,应根据默认的 MuJoCo 值调整一些配置参数:
- option/iterations 和 option/ls_iterations
iterations 和 ls_iterations 属性分别控制求解器迭代和线搜索迭代次数,应将其降低到足以使模拟保持稳定的程度。在强化学习中,准确的求解器力并不那么重要,因为通常会使用领域随机化为物理添加噪声以实现 sim-to-real。
NEWTON
求解器 (Solver) 具有非常少的(通常只有一次)求解器迭代次数,并且在 GPU 上表现良好。CG
目前是 TPU 的更好选择。- contact/pair
考虑显式标记用于碰撞检测的几何体,以减少 MJX 在每一步中必须考虑的接触点数量。仅启用有效接触点的显式列表可以显着提高 MJX 的模拟性能。要做好这一点,通常需要对任务有所了解——例如,OpenAI Gym 人形 任务在人形开始跌倒时会重置,因此不需要与地面完全接触。
- maxhullvert
将 maxhullvert 设置为 64 或更小,以获得更好的凸网格碰撞性能。
- option/flag/eulerdamp
禁用
eulerdamp
可以提高性能,并且通常为了稳定性不需要它。阅读 数值积分 部分了解此标志的语义详情。- option/jacobian
明确设置“dense”或“sparse”可能会根据您的设备加快模拟速度。现代 TPU 拥有专门的硬件用于快速处理稀疏矩阵,而 GPU 在处理密集矩阵方面往往更快,只要它们能够装载到设备上。因此,在 MJX 中,“auto”默认设置的行为是:如果
nv >= 60
(60 个或更多自由度),或者 MJX 检测到 TPU 作为默认后端,则为“sparse”,否则为“dense”。对于 TPU,将“sparse”与牛顿求解器一起使用可以将模拟速度提高 2 到 3 倍。对于 GPU,选择“dense”可能会带来更适度的 10% 到 20% 的加速,只要密集矩阵能够装载到设备上即可。- Broadphase (宽阶段)
虽然 MuJoCo 直接处理宽阶段剔除,但 MJX 需要额外的参数。对于宽阶段的近似版本,请使用实验性自定义数字参数
max_contact_points
和max_geom_pairs
。max_contact_points
限制发送给求解器处理每种 condim 类型的接触点数量。max_geom_pairs
限制发送给每个几何体类型对的相应碰撞函数的几何体对总数。例如,shadow hand 环境就使用了这些参数。
GPU 性能#
应设置以下环境变量:
XLA_FLAGS=--xla_gpu_triton_gemm_any=true
这会启用基于 Triton 的 GEMM (matmul) 发射器,用于它支持的任何 GEMM。这可以在 NVIDIA GPU 上带来 30% 的加速。如果您有多个 GPU,您还可以通过启用与 GPU 间通信 相关的标志而获益。