MuJoCo XLA (MJX)#

从 3.0.0 版本开始,MuJoCo 在 mjx 目录下包含了 MuJoCo XLA (MJX)。MJX 允许 MuJoCo 通过 JAX 框架运行在 XLA 编译器支持的计算硬件上。MJX 支持 JAX 支持的所有平台:Nvidia 和 AMD GPU、Apple Silicon 以及 Google Cloud TPU

MJX API 与 MuJoCo API 中的主要模拟函数保持一致,尽管目前缺少一些功能。虽然 API 文档 适用于这两个库,但我们在下面的 注意事项 中指出了 MJX 不支持的功能。

MJX 在 PyPI 上以名为 mujoco-mjx 的独立包分发。尽管它依赖主 mujoco 包进行模型编译和可视化,但它是 MuJoCo 的一个重新实现,使用与 MuJoCo 实现相同的算法。然而,为了充分利用 JAX,MJX 在一些地方有意地偏离了 MuJoCo API,详见下文。

MJX 是 Google Brax 物理和强化学习库中 通用物理管线 的继任者。MJX 由 MuJoCo 和 Brax 的核心贡献者共同构建,他们将一起继续支持 Brax(用于其强化学习算法和包含的环境)和 MJX(用于其物理算法)。Brax 的未来版本将依赖于 mujoco-mjx 包,并且 Brax 现有的 通用管线 将被弃用。这一变化对 Brax 用户来说基本是透明的。

教程 Notebook#

以下 IPython notebook 演示了如何使用 MJX 和强化学习来训练人形和四足机器人运动:colab

安装#

安装此包的推荐方法是通过 PyPI

pip install mujoco-mjx

MuJoCo 库的一个副本作为此包的依赖项提供,无需单独下载或安装。

基本用法#

安装后,可以通过 from mujoco import mjx 导入此包。结构体 (Structs)、函数 (functions) 和枚举 (enums) 可直接从顶层 mjx 模块获取。

结构体 (Structs)#

在加速器设备上运行 MJX 函数之前,必须通过 mjx.put_modelmjx.put_data 函数将结构体复制到设备上。将 mjModel 放置到设备上会得到一个 mjx.Model。将 mjData 放置到设备上会得到一个 mjx.Data

model = mujoco.MjModel.from_xml_string("...")
data = mujoco.MjData(model)
mjx_model = mjx.put_model(model)
mjx_data = mjx.put_data(model, data)

这些 MJX 变体与其 MuJoCo 对应项相似,但有一些主要区别:

  1. mjx.Modelmjx.Data 包含复制到设备上的 JAX 数组。

  2. 对于 MJX 中 不支持 的功能,mjx.Modelmjx.Data 中缺少一些字段。

  3. mjx.Modelmjx.Data 中的 JAX 数组支持添加批次维度 (batch dimensions)。批次维度是表达领域随机化(针对 mjx.Model)或强化学习高吞吐量模拟(针对 mjx.Data)的自然方式。

  4. mjx.Modelmjx.Data 中的 Numpy 数组是控制 JIT 编译输出的结构字段。修改这些数组将强制 JAX 重新编译 MJX 函数。例如,jnt_limited 是从 mjModel 引用传递的 numpy 数组,它决定是否应用关节限制约束。如果修改了 jnt_limited,JAX 将重新编译 MJX 函数。另一方面,jnt_range 是一个可以在运行时修改的 JAX 数组,它只会应用于 jnt_limited 字段指定的有限制关节。

mjx.Modelmjx.Data 都不应手动构造。可以通过调用 mjx.make_data 创建 mjx.Data,这与 MuJoCo 中的 mj_makeData 函数类似。

model = mujoco.MjModel.from_xml_string("...")
mjx_model = mjx.put_model(model)
mjx_data = mjx.make_data(model)

vmap 内部构造批次化 mjx.Data 结构时,使用 mjx.make_data 可能更可取。

函数 (Functions)#

MuJoCo 函数以同名的 MJX 函数暴露,但遵循 PEP 8 兼容命名。大部分 主要模拟 函数以及一些用于正向模拟的 子组件 可从顶层 mjx 模块获取。

MJX 函数默认不进行 JIT 编译——我们将 JIT MJX 函数或 JIT 引用 MJX 函数的用户自己的函数的任务留给用户。参见下面的 最小示例

枚举 (Enums) 和常量 (constants)#

MJX 枚举可用 mjx.EnumType.ENUM_VALUE 访问,例如 mjx.JointType.FREE。MJX 枚举声明中省略了 MJX 不支持功能的枚举。MJX 不声明常量,而是直接引用 MuJoCo 常量。

最小示例#

# Throw a ball at 100 different velocities.

import jax
import mujoco
from mujoco import mjx

XML=r"""
<mujoco>
  <worldbody>
    <body>
      <freejoint/>
      <geom size=".15" mass="1" type="sphere"/>
    </body>
  </worldbody>
</mujoco>
"""

model = mujoco.MjModel.from_xml_string(XML)
mjx_model = mjx.put_model(model)

@jax.vmap
def batched_step(vel):
  mjx_data = mjx.make_data(mjx_model)
  qvel = mjx_data.qvel.at[0].set(vel)
  mjx_data = mjx_data.replace(qvel=qvel)
  pos = mjx.step(mjx_model, mjx_data).qpos[0]
  return pos

vel = jax.numpy.arange(0.0, 1.0, 0.01)
pos = jax.jit(batched_step)(vel)
print(pos)

有用的命令行脚本#

我们在 mujoco-mjx 包中提供了两个命令行脚本:

mjx-testspeed --mjcf=/PATH/TO/MJCF/ --base_path=.

此命令接收 MJCF 文件的路径以及可选参数(使用 --help 获取更多信息),并计算有助于性能调优的指标。该命令将输出总模拟时间、总步数每秒以及总实时因子等信息(这里的“总”是指所有可用设备的总计)。

mjx-viewer --help

此命令在模拟查看器中启动 MJX 模型,允许您可视化模型并与其交互。请注意,这使用 MJX 物理(而非 C MuJoCo)进行模拟步进,因此例如在调试求解器参数时非常有用。

功能对等性#

MJX 支持 MuJoCo 的大部分主要模拟功能,但有少数例外。如果要求 MJX 将包含引用不支持功能字段值的 mjModel 复制到设备上,MJX 将抛出异常。

MJX 完全支持以下功能:

类别

功能

动力学

正向

关节

FREE, BALL, SLIDE, HINGE

传动

JOINT, JOINTINPARENT, SITE, TENDON

致动器动力学

NONE, INTEGRATOR, FILTER, FILTEREXACT, MUSCLE

致动器增益

FIXED, AFFINE, MUSCLE

致动器偏置

NONE, AFFINE, MUSCLE

肌腱缠绕

JOINT, SITE, PULLEY, SPHERE, CYLINDER

几何体 (Geom)

PLANE, HFIELD, SPHERE, CAPSULE, BOX, MESH 已完全实现。ELLIPSOIDCYLINDER 已实现,但仅与其他基本体碰撞,请注意 BOX 被实现为网格。

约束

EQUALITY, LIMIT_JOINT, CONTACT_FRICTIONLESS, CONTACT_PYRAMIDAL, CONTACT_ELLIPTIC, FRICTION_DOF, FRICTION_TENDON

等式约束 (Equality)

CONNECT, WELD, JOINT, TENDON

积分器

EULER, RK4, IMPLICITFAST (流体阻力 不支持 IMPLICITFAST)

锥形 (Cone)

PYRAMIDAL, ELLIPTIC

Condim

1, 3, 4, 6 (ELLIPTIC 不支持 1)

求解器 (Solver)

CG, NEWTON

动力学

逆向

流体模型

惯性模型

肌腱

固定 (Fixed), 空间 (Spatial)

传感器 (Sensors)

MAGNETOMETER, CAMPROJECTION, RANGEFINDER, JOINTPOS, TENDONPOS, ACTUATORPOS, BALLQUAT, FRAMEPOS, FRAMEXAXIS, FRAMEYAXIS, FRAMEZAXIS, FRAMEQUAT, SUBTREECOM, CLOCK, VELOCIMETER, GYRO, JOINTVEL, TENDONVEL, ACTUATORVEL, BALLANGVEL, FRAMELINVEL, FRAMEANGVEL, SUBTREELINVEL, SUBTREEANGMOM, TOUCH, ACCELEROMETER, FORCE, TORQUE, ACTUATORFRC, JOINTACTFRC, TENDONACTFRC, FRAMELINACC, FRAMEANGACC (连接或焊接等式约束不支持 ACCELEROMETER, FORCE, TORQUE)

以下功能正在开发中,即将推出:

类别

功能

几何体 (Geom)

SDF。(SPHERE, BOX, MESH, HFIELD)与 CYLINDER 之间的碰撞。(BOX, MESH, HFIELD)与 ELLIPSOID 之间的碰撞。

积分器

IMPLICIT

流体模型

椭球体模型

传感器 (Sensors)

PLUGIN, USER 外的所有功能

光源

光源的位置和方向

以下功能不支持

类别

功能

margingap

未实现与 Mesh 几何体 (Geom) 的碰撞。

传动

SLIDERCRANK, BODY

致动器动力学

USER

致动器增益

USER

致动器偏置

USER

求解器 (Solver)

PGS

传感器 (Sensors)

PLUGIN, USER

🔪 MJX - 注意事项 🔪#

GPU 和 TPU 具有 MJX 所受到的独特性能权衡。MJX 专门用于使用可在 SIMD 硬件 上高效矢量化的算法模拟大量并行相同的物理场景。这种专业化对于需要海量数据吞吐量的机器学习工作负载(例如 强化学习)非常有用。

MJX 不太适合某些工作流程:

单场景模拟

模拟单个场景(1 个 mjData 实例)时,MJX 可能比 MuJoCo 慢 10 倍,因为 MuJoCo 针对 CPU 进行了精心优化。MJX 在并行模拟数千或数万个场景时表现最佳。

大型网格之间的碰撞

MJX 支持凸网格几何体之间的碰撞。然而,MJX 中的凸碰撞算法实现方式与 MuJoCo 不同。MJX 使用 分离轴测试 (Separating Axis Test, SAT) 的无分支版本来确定几何体是否与凸网格发生碰撞,而 MuJoCo 使用 MPR 或 GJK/EPA,有关详细信息,请参阅 碰撞检测。SAT 对于小型网格效果很好,但对于大型网格来说在运行时和内存方面表现不佳。

对于与凸网格和基本体的碰撞,为了获得合理的性能,网格的凸分解应包含大约 200 个或更少的顶点。对于凸-凸碰撞,凸网格应包含大约 少于 32 个顶点。我们建议在 MuJoCo 编译器中使用 maxhullvert 来获得所需的凸网格属性。经过仔细调优,MJX 可以模拟包含网格碰撞的场景——有关示例,请参见 MJX 的 shadow hand 配置。加速网格碰撞检测是 MJX 当前活跃的开发领域。

具有许多接触点的大型复杂场景

加速器在处理 分支代码 时表现不佳。在宽阶段碰撞检测中会使用分支,即识别场景中大量物体之间的潜在碰撞。MJX 附带了一个简单的无分支宽阶段算法(参见性能调优),但它的功能不如 MuJoCo 中的强大。

为了了解这对模拟的影响,我们考虑一个包含不同数量人形身体(从 1 到 10)的物理场景。我们在 Apple M3 Max 和 64 核 AMD 3995WX 上使用 CPU MuJoCo 模拟该场景,并使用 testspeed 进行计时,使用 2 x numcore 线程。我们在批量大小为 8192 的 Nvidia A100 GPU 和批量大小为 16384 的 8 芯片 v5 TPU 机器上对 MJX 模拟进行计时。注意垂直刻度是对数刻度。

_images/SPS.svg

四种计时架构中单个人形(最左侧数据点)的值分别为每秒 65 万180 万95 万270 万 步。请注意,随着人形数量的增加(这会增加场景中潜在接触点的数量),MJX 的吞吐量下降速度比 MuJoCo 快。

性能调优#

为了让 MJX 表现良好,应根据默认的 MuJoCo 值调整一些配置参数:

option/iterationsoption/ls_iterations

iterationsls_iterations 属性分别控制求解器迭代和线搜索迭代次数,应将其降低到足以使模拟保持稳定的程度。在强化学习中,准确的求解器力并不那么重要,因为通常会使用领域随机化为物理添加噪声以实现 sim-to-real。 NEWTON 求解器 (Solver) 具有非常少的(通常只有一次)求解器迭代次数,并且在 GPU 上表现良好。CG 目前是 TPU 的更好选择。

contact/pair

考虑显式标记用于碰撞检测的几何体,以减少 MJX 在每一步中必须考虑的接触点数量。仅启用有效接触点的显式列表可以显着提高 MJX 的模拟性能。要做好这一点,通常需要对任务有所了解——例如,OpenAI Gym 人形 任务在人形开始跌倒时会重置,因此不需要与地面完全接触。

maxhullvert

maxhullvert 设置为 64 或更小,以获得更好的凸网格碰撞性能。

option/flag/eulerdamp

禁用 eulerdamp 可以提高性能,并且通常为了稳定性不需要它。阅读 数值积分 部分了解此标志的语义详情。

option/jacobian

明确设置“dense”或“sparse”可能会根据您的设备加快模拟速度。现代 TPU 拥有专门的硬件用于快速处理稀疏矩阵,而 GPU 在处理密集矩阵方面往往更快,只要它们能够装载到设备上。因此,在 MJX 中,“auto”默认设置的行为是:如果 nv >= 60(60 个或更多自由度),或者 MJX 检测到 TPU 作为默认后端,则为“sparse”,否则为“dense”。对于 TPU,将“sparse”与牛顿求解器一起使用可以将模拟速度提高 2 到 3 倍。对于 GPU,选择“dense”可能会带来更适度的 10% 到 20% 的加速,只要密集矩阵能够装载到设备上即可。

Broadphase (宽阶段)

虽然 MuJoCo 直接处理宽阶段剔除,但 MJX 需要额外的参数。对于宽阶段的近似版本,请使用实验性自定义数字参数 max_contact_pointsmax_geom_pairsmax_contact_points 限制发送给求解器处理每种 condim 类型的接触点数量。max_geom_pairs 限制发送给每个几何体类型对的相应碰撞函数的几何体对总数。例如,shadow hand 环境就使用了这些参数。

GPU 性能#

应设置以下环境变量:

XLA_FLAGS=--xla_gpu_triton_gemm_any=true

这会启用基于 Triton 的 GEMM (matmul) 发射器,用于它支持的任何 GEMM。这可以在 NVIDIA GPU 上带来 30% 的加速。如果您有多个 GPU,您还可以通过启用与 GPU 间通信 相关的标志而获益。