MuJoCo XLA (MJX)#
从 3.0.0 版本开始,MuJoCo 在 mjx 目录下包含了 MuJoCo XLA (MJX)。MJX 允许 MuJoCo 通过 JAX 框架在 XLA 编译器支持的计算硬件上运行。MJX 可以在 JAX 支持的所有平台上运行:Nvidia 和 AMD GPU、Apple Silicon 以及 Google Cloud TPU。
MJX API 与 MuJoCo API 中的主要仿真函数一致,但目前缺少一些功能。虽然 API 文档 适用于这两个库,但我们在下面的注释中指出了 MJX 不支持的功能。
MJX 是一个名为 mujoco-mjx 的独立软件包,发布在 PyPI 上。虽然它依赖于主要的 mujoco 软件包进行模型编译和可视化,但它是 MuJoCo 的重新实现,使用了与 MuJoCo 实现相同的算法。然而,为了更好地利用 JAX,MJX 在几个地方有意地与 MuJoCo API 有所不同,详见下文。
MJX 是 Google 的 Brax 物理和强化学习库中通用物理管线的继任者。MJX 由 MuJoCo 和 Brax 的核心贡献者共同构建,他们将继续共同支持 Brax(因其强化学习算法和内置环境)和 MJX(因其物理算法)。未来版本的 Brax 将依赖于 mujoco-mjx 软件包,而 Brax 现有的通用管线将被弃用。这一变化对 Brax 用户来说将基本是透明的。
教程笔记本#
安装#
推荐通过 PyPI 安装此软件包。
pip install mujoco-mjx
MuJoCo 库的副本已作为此软件包依赖项的一部分提供,不需要单独下载或安装。
基本用法#
安装后,可以通过 from mujoco import mjx 导入该包。结构体、函数和枚举可直接从顶层 mjx 模块中获取。
结构体#
在加速器设备上运行 MJX 函数之前,必须通过 mjx.put_model 和 mjx.put_data 函数将结构体复制到设备上。将一个 mjModel 放置在设备上会产生一个 mjx.Model。将一个 mjData 放置在设备上会产生一个 mjx.Data。
model = mujoco.MjModel.from_xml_string("...")
data = mujoco.MjData(model)
mjx_model = mjx.put_model(model)
mjx_data = mjx.put_data(model, data)
这些 MJX 变体与其 MuJoCo 对应物类似,但有几个关键区别:
mjx.Model和mjx.Data包含复制到设备上的 JAX 数组。由于某些功能在 MJX 中不受支持,
mjx.Model和mjx.Data中缺少一些字段。mjx.Model和mjx.Data中的 JAX 数组支持添加批处理维度。批处理维度是表达领域随机化(对于mjx.Model)或强化学习中的高吞吐量仿真(对于mjx.Data)的自然方式。mjx.Model和mjx.Data中的 Numpy 数组是控制 JIT 编译输出的结构性字段。修改这些数组将强制 JAX 重新编译 MJX 函数。例如,jnt_limited是一个从 mjModel 按引用传递的 numpy 数组,它决定了是否应用关节限位约束。如果jnt_limited被修改,JAX 将重新编译 MJX 函数。另一方面,jnt_range是一个可以在运行时修改的 JAX 数组,它只适用于由jnt_limited字段指定的有限位关节。
mjx.Model 和 mjx.Data 都不应手动构建。可以通过调用 mjx.make_data 创建一个 mjx.Data,这与 MuJoCo 中的 mj_makeData 函数类似。
model = mujoco.MjModel.from_xml_string("...")
mjx_model = mjx.put_model(model)
mjx_data = mjx.make_data(model)
在 vmap 内部构建批处理的 mjx.Data 结构时,使用 mjx.make_data 可能更可取。
函数#
MuJoCo 函数以 MJX 函数的形式公开,名称相同,但遵循 PEP 8 兼容的命名规范。大多数主要仿真函数和一些用于正向仿真的子组件函数可从顶层 mjx 模块中获取。
默认情况下,MJX 函数不会进行JIT 编译——我们让用户自行决定是否对 MJX 函数进行 JIT,或者对引用 MJX 函数的自定义函数进行 JIT。请参阅下面的最小示例。
枚举和常量#
MJX 枚举可通过 mjx.EnumType.ENUM_VALUE 访问,例如 mjx.JointType.FREE。对于不受支持的 MJX 功能,其枚举会从 MJX 枚举声明中省略。MJX 不声明任何常量,而是直接引用 MuJoCo 的常量。
最小示例#
# Throw a ball at 100 different velocities.
import jax
import mujoco
from mujoco import mjx
XML=r"""
<mujoco>
<worldbody>
<body>
<freejoint/>
<geom size=".15" mass="1" type="sphere"/>
</body>
</worldbody>
</mujoco>
"""
model = mujoco.MjModel.from_xml_string(XML)
mjx_model = mjx.put_model(model)
@jax.vmap
def batched_step(vel):
mjx_data = mjx.make_data(mjx_model)
qvel = mjx_data.qvel.at[0].set(vel)
mjx_data = mjx_data.replace(qvel=qvel)
pos = mjx.step(mjx_model, mjx_data).qpos[0]
return pos
vel = jax.numpy.arange(0.0, 1.0, 0.01)
pos = jax.jit(batched_step)(vel)
print(pos)
有用的命令行脚本#
我们随 mujoco-mjx 软件包提供了两个命令行脚本。
mjx-testspeed --mjcf=/PATH/TO/MJCF/ --base_path=.
此命令接收一个 MJCF 文件的路径以及可选参数(使用 --help 获取更多信息),并计算用于性能调优的有用指标。该命令将输出总仿真时间、每秒总步数和总实时因子(这里的总和是跨所有可用设备的总和)等信息。
mjx-viewer --help
此命令在 simulate 查看器中启动 MJX 模型,允许您可视化模型并与之交互。请注意,这是使用 MJX 物理(而非 C MuJoCo)进行仿真步进,因此例如对于调试求解器参数会很有帮助。
功能对等性#
MJX 支持 MuJoCo 的大部分主要仿真功能,但有少数例外。如果要求 MJX 将一个引用了不支持功能的字段值的 mjModel 复制到设备上,它将引发异常。
以下功能在 MJX 中得到完全支持:
类别 |
功能 |
|---|---|
动力学 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1、3、4、6(1 不支持与 |
|
|
|
流体模型 |
|
|
|
光源 |
光源的位置和方向 |
以下功能正在开发中,即将推出:
类别 |
功能 |
|---|---|
|
|
|
|
流体模型 |
|
除了 |
以下功能不受支持:
🔪 MJX - 尖锐之处 🔪#
GPU 和 TPU 具有独特的性能权衡,MJX 也受其影响。MJX 专门用于模拟大量并行的相同物理场景,其使用的算法可以在 SIMD 硬件上高效地进行矢量化。这种专业化对于需要海量数据吞吐量的机器学习工作负载(如强化学习)非常有用。
有些工作流程不适合使用 MJX:
- 单一场景仿真
模拟单个场景(1 个 mjData 实例)时,MJX 的速度可能比为 CPU 精心优化的 MuJoCo 慢 10 倍。当并行模拟数千或数万个场景时,MJX 的效果最佳。
- 大型网格之间的碰撞
MJX 支持凸面网格几何体之间的碰撞。然而,MJX 中的凸面碰撞算法与 MuJoCo 的实现方式不同。MJX 使用分离轴测试 (SAT) 的无分支版本来确定几何体是否与凸面网格发生碰撞,而 MuJoCo 使用 MPR 或 GJK/EPA,更多细节请参阅碰撞检测。SAT 对于较小的网格效果很好,但对于较大的网格,在运行时间和内存方面都会受到影响。
对于与凸面网格和基本体的碰撞,网格的凸分解应具有大约 200 个或更少的顶点才能获得合理的性能。对于凸-凸碰撞,凸面网格应具有大约少于 32 个顶点。我们建议在 MuJoCo 编译器中使用 maxhullvert 来实现所需的凸面网格属性。通过仔细调整,MJX 可以模拟带有网格碰撞的场景——请参阅 MJX shadow hand 配置作为示例。加速网格碰撞检测是 MJX 的一个活跃开发领域。
- 具有许多接触点的大型复杂场景
加速器对于分支代码的性能表现不佳。分支用于宽相碰撞检测,即在场景中大量物体之间识别潜在碰撞时。MJX 提供了一个简单的无分支宽相算法(见性能调优),但它不如 MuJoCo 中的强大。
为了了解这对仿真的影响,让我们考虑一个具有不断增加的人形物体数量的物理场景,从 1 个变化到 10 个。我们在 Apple M3 Max 和 64 核 AMD 3995WX 上使用 CPU MuJoCo 模拟此场景,并使用 testspeed 进行计时,使用
2 x numcore个线程。我们在 Nvidia A100 GPU 上计时 MJX 仿真,批处理大小为 8192;在 8 芯片 v5 TPU 机器上计时,批处理大小为 16384。注意纵坐标为对数刻度。对于单个(最左侧数据点)人形机器人,四种被测架构的值分别为每秒 65 万、180 万、95 万 和 270 万 步。请注意,随着我们增加人形机器人的数量(这会增加场景中潜在接触点的数量),MJX 的吞吐量下降速度比 MuJoCo 更快。
性能调优#
为了让 MJX 表现良好,应调整一些配置参数,使其偏离默认的 MuJoCo 值:
- option/iterations 和 option/ls_iterations
iterations 和 ls_iterations 属性分别控制求解器和线搜索迭代次数,应将其降低到刚好能保持仿真稳定的最低值。在强化学习中,精确的求解器力不是那么重要,因为通常会使用领域随机化向物理中添加噪声以实现模拟到现实的转换。
NEWTON求解器 在很少(通常只有一个)的求解器迭代次数下就能提供出色的收敛性,并且在 GPU 上表现良好。CG目前是 TPU 的更好选择。- contact/pair
考虑明确标记用于碰撞检测的几何体,以减少 MJX 在每一步中必须考虑的接触点数量。仅启用一个明确的有效接触点列表,可以对 MJX 的仿真性能产生显著影响。要做好这一点,通常需要了解任务——例如,OpenAI Gym Humanoid 任务在人形机器人开始摔倒时会重置,因此不需要与地面完全接触。
- maxhullvert
将 maxhullvert 设置为
64或更小,以获得更好的凸面网格碰撞性能。- option/flag/eulerdamp
禁用
eulerdamp有助于提高性能,而且通常对于稳定性而言不是必需的。有关此标志语义的详细信息,请阅读数值积分部分。- option/jacobian
明确设置为“dense”(密集)或“sparse”(稀疏)可能会根据您的设备加快仿真速度。现代 TPU 具有专门的硬件,可以快速操作稀疏矩阵,而 GPU 在处理密集矩阵方面速度更快,只要它们能装入设备内存。因此,MJX 中默认的“auto”设置的行为是:如果
nv >= 60(自由度为 60 或更多),或者如果 MJX 检测到 TPU 作为默认后端,则为稀疏;否则为“dense”。对于 TPU,将“sparse”与 Newton 求解器一起使用可以将仿真速度提高 2 倍到 3 倍。对于 GPU,选择“dense”可能会带来 10% 到 20% 的适度加速,只要密集矩阵可以装入设备内存。- 宽相碰撞检测
虽然 MuJoCo 开箱即用地处理宽相剔除,但 MJX 需要额外的参数。对于宽相的近似版本,请使用实验性的自定义数值参数
max_contact_points和max_geom_pairs。max_contact_points限制了发送到求解器的每种 condim 类型的接触点数量。max_geom_pairs限制了发送到相应碰撞函数的每种几何体类型对的总几何体对数量。例如,shadow hand 环境就利用了这些参数。
GPU 性能#
应设置以下环境变量:
XLA_FLAGS=--xla_gpu_triton_gemm_any=true这将为任何支持的 GEMM(矩阵乘法)启用基于 Triton 的 GEMM 发射器。这可以在 NVIDIA GPU 上带来 30% 的速度提升。如果您有多个 GPU,启用与GPU 间通信相关的标志也可能对您有益。
