MuJoCo Warp (MJWarp)#

MuJoCo Warp (MJWarp) 是一个用 Warp 编写的 MuJoCo 实现,并针对 NVIDIA 硬件和并行仿真进行了优化。MJWarp 位于 google-deepmind/mujoco_warp GitHub 仓库中,目前处于测试(beta)阶段。

MJWarp 由 NVIDIAGoogle DeepMind 联合开发和维护。

测试版软件

  • MJWarp 是测试版软件,正在积极开发中。

  • MJWarp 开发者将对 错误报告和功能请求 进行分类和回应。

  • MJWarp 在功能上已基本完备,但仍需要进行性能优化、文档编写和测试。

  • 测试阶段的目标受众是物理引擎爱好者和学习框架集成者。

安装#

MuJoCo Warp 的测试版从 GitHub 安装。请注意,MuJoCo Warp 的测试版并不支持所有版本的 MuJoCo、Warp、CUDA、NVIDIA 驱动程序等。

git clone https://github.com/google-deepmind/mujoco_warp.git
cd mujoco_warp
python3 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install uv
uv pip install -e .[dev,cuda]

测试安装

pytest

基本用法#

安装后,可以通过 import mujoco_warp as mjw 导入该包。结构体、函数和枚举可直接从顶层的 mjw 模块中获取。

结构体#

在 NVIDIA GPU 上运行 MJWarp 函数之前,必须通过 mjw.put_modelmjw.make_datamjw.put_data 函数将结构体复制到设备上。将 mjModel 放置在设备上会生成一个 mjw.Model。将 mjData 放置在设备上会生成一个 mjw.Data

mjm = mujoco.MjModel.from_xml_string("...")
mjd = mujoco.MjData(mjm)
m = mjw.put_model(mjm)
d = mjw.put_data(mjm, mjd)

这些 MJWarp 变体与其对应的 MuJoCo 版本类似,但有几个关键区别:

  1. mjw.Modelmjw.Data 包含复制到设备上的 Warp 数组。

  2. 由于某些功能不受支持,mjw.Modelmjw.Data 中缺少一些字段。

批量大小#

MJWarp 针对并行仿真进行了优化。一批仿真可以通过三个参数来指定:

  • nworld:要仿真的世界数量。

  • nconmax:每个世界的预期接触点数量。所有世界的最大接触点总数为 nconmax * nworld

  • naconmaxnconmax 的替代方案,表示所有世界的最大接触点总数。如果同时设置了 nconmaxnaconmax,则忽略 nconmax

  • njmax:每个世界的最大约束数量。

nconmaxnjmax 的语义差异。

如果所有世界的总接触点数量不超过 nworld x nconmax,那么单个世界的接触点数量可能超过 nconmax。但是,每个世界的约束数量严格受 njmax 的限制。

XML 解析

nconmaxnjmax 的值不会从 size/nconmaxsize/njmax 中解析(这些参数已弃用)。这些参数的值必须提供给 mjw.make_datamjw.put_data

函数#

MuJoCo 函数以同名的 MJWarp 函数形式暴露,但遵循 PEP 8 兼容的命名规范。大部分主仿真函数和一些用于正向仿真的子组件函数可从顶层 mjw 模块中获取。

最小示例#

# Throw a ball at 100 different velocities.

import mujoco
import mujoco_warp as mjw
import warp as wp

_MJCF=r"""
<mujoco>
  <worldbody>
    <body>
      <freejoint/>
      <geom size=".15" mass="1" type="sphere"/>
    </body>
  </worldbody>
</mujoco>
"""

mjm = mujoco.MjModel.from_xml_string(_MJCF)
m = mjw.put_model(mjm)
d = mjw.make_data(mjm, nworld=100)

# initialize velocities
wp.copy(d.qvel, wp.array([[float(i) / 100, 0, 0, 0, 0, 0] for i in range(100)], dtype=float))

# simulate physics
mjw.step(m, d)

print(f'qpos:\n{d.qpos.numpy()}')

命令行脚本#

使用 testspeed 对环境进行基准测试

mjwarp-testspeed benchmark/humanoid/humanoid.xml

使用 MJWarp 进行交互式环境仿真

mjwarp-viewer benchmark/humanoid/humanoid.xml

功能对等性#

MJWarp 支持 MuJoCo 的大部分主仿真功能,但有少数例外。如果要求 MJWarp 将一个字段值引用了不支持功能的 mjModel 复制到设备上,它将引发异常。

MJWarp 不支持以下功能

类别

功能

等式约束

FLEX

积分器

在有流体阻力时不支持 IMPLICITIMPLICITFAST

求解器

PGSnoslip岛屿(islands)

流体模型

椭球体模型

传感器

GEOMDIST, GEOMNORMAL, GEOMFROMTO

Flex

VERTCOLLIDE=false, INTERNAL=true, nflex > 1

雅可比矩阵格式

SPARSE

选项

接触覆盖

插件

SDF 外的所有

用户参数

所有

性能调优#

以下是优化 MJWarp 性能时需要考虑的因素。

图捕捉#

MJWarp 函数,例如 mjw.step,通常包含一系列的内核启动。如果直接调用该函数,Warp 将逐个启动这些内核。为了提高性能,特别是当函数需要被多次调用时,建议将构成该函数的操作捕捉为 CUDA 图:

with wp.ScopedCapture() as capture:
  mjw.step(m, d)

然后可以启动或重新启动该图:

wp.capture_launch(capture.graph)

这通常会比直接调用函数快得多。详情请参阅 Warp 图 API 参考

批量大小#

最大接触点数量和最大约束数量,即 nconmax / naconmaxnjmax,是在使用 mjw.make_datamjw.put_data 创建 mjw.Data 时指定的。内存和计算量会随着这些参数的值而变化。为获得最佳性能,应将这些参数的值设置得尽可能小,同时确保仿真不会超过这些限制。

通常,这些限制的合适值是特定于环境的。在实践中,选择合适的值通常需要反复试验。使用带有 --measure_alloc 标志的 mjwarp-testspeed 来打印每个仿真步骤的接触点和约束数量,以及通过 mjwarp-viewer 与仿真交互并检查溢出错误,都是迭代测试这些参数值的有用技巧。

求解器迭代次数#

MuJoCo 关于最大求解器迭代次数线搜索迭代次数的默认求解器设置预计能提供合理的性能。降低 MJWarp 的设置 Option.iterations 和/或 Option.ls_iterations 的限制可能会提高性能,但这应在调整 nconmax / naconmaxnjmax 之后再作考虑。

将这些限制设置得过低可能会阻止约束求解器收敛,并导致仿真不准确或不稳定。

对性能的影响:MJX (JAX) 与 MJWarp

MJX 中,这些求解器参数是控制仿真性能的关键。相比之下,对于 MJWarp,一旦所有世界都已收敛,求解器可以提前退出以避免不必要的计算。因此,这些设置的值对性能的影响相对较小。

接触传感器匹配#

包含接触传感器的场景有一个参数,用于指定每个传感器匹配的最大接触点数量,即 Option.contact_sensor_max_match。为获得最佳性能,该参数的值应尽可能小,同时确保仿真不会超过此限制。超过此限制的已匹配接触点将被忽略。

该参数的值可以直接设置,例如 model.opt.contact_sensor_maxmatch = 16,或者通过 XML 自定义数字字段设置:

<custom>
  <numeric name="contact_sensor_maxmatch" data="16"/>
</custom>

与最大接触点和约束数量类似,此设置的合适值预计也是特定于环境的。mjwarp-testspeedmjwarp-viewer 可能有助于调整该参数的值。

并行线搜索#

除了约束求解器的迭代式线搜索外,MJWarp 还提供了一个并行线搜索例程,该例程可并行评估一组步长并选择最佳步长。这些步长在从 Model.opt.ls_parallel_min_step 到 1 的范围内按对数间隔分布,要评估的步长数量通过 Model.opt.ls_iterations 设置。

在某些情况下,与约束求解器的默认迭代式线搜索相比,并行例程可能会提供更高的性能。

要启用此例程,请设置 Model.opt.ls_parallel=True 或向 XML 添加一个自定义数字字段:

<custom>
  <numeric name="ls_parallel" data="1"/>
</custom>

实验性功能

并行线搜索目前是一项实验性功能。

批处理的 Model 字段#

为了能够使用不同的模型参数值进行批处理仿真,许多 mjw.Model 字段都有一个前导的批处理维度。默认情况下,前导维度为 1(即 field.shape[0] == 1),相同的值将应用于所有世界。可以用一个前导维度大于 1 的 wp.array 来覆盖其中一个字段。该字段将通过世界 ID 和批处理维度的模运算进行索引:field[worldid % field.shape[0]]。重要的是,字段的形状应在图捕捉(即 wp.ScopedCapture)之前被覆盖:

# override shape and values
m.dof_damping = wp.array([[0.1], [0.2]], dtype=float)

with wp.ScopedCapture() as capture:
  mjw.step(m, d)

也可以在图捕捉之后覆盖字段形状并设置字段值:

# override shape
m.dof_damping = wp.empty((2, 1), dtype=float)

with wp.ScopedCapture() as capture:
  mjw.step(m, d)

# set batched values
dof_damping_batch = wp.array([[0.1], [0.2]], dtype=float)
wp.copy(m.dof_damping, dof_damping_batch)  # m.dof = dof_damping_batch will not update the captured graph

异构世界

异构世界,例如:每个世界有不同的网格或自由度数量,目前尚不可用。

常见问题#

学习框架#

MJWarp 是否能与 JAX 配合使用?

是的。MJWarp 与 JAX 具有互操作性。详情请参阅 Warp 互操作性 文档。

此外,MJX 为 MJWarp 的API 的一个子集提供了 JAX API。后端通过 impl='warp' 指定。

MJWarp 是否能与 PyTorch 配合使用?

是的。MJWarp 与 PyTorch 具有互操作性。详情请参阅 Warp 互操作性 文档。

如何使用 MJWarp 物理引擎训练策略?

有关使用 MJWarp 物理引擎训练策略的示例,请参阅:

功能#

MJWarp 是否可微分?

否。MJWarp 目前无法通过 Warp 的自动微分功能进行微分。团队关于为 MJWarp 启用自动微分的更新情况,请在此 GitHub issue 中跟踪。

MJWarp 是否支持多 GPU?

是的。Warp 的 wp.ScopedDevice 支持多 GPU 计算:

# create a graph for each device
graph = {}
for device in wp.get_cuda_devices():
  with wp.ScopedDevice(device):
    m = mjw.put_model(mjm)
    d = mjw.make_data(mjm)
    with wp.ScopedCapture(device) as capture:
      mjw.step(m, d)
    graph[device] = capture.graph

# launch a graph on each device
for device in wp.get_cuda_devices():
  wp.capture_launch(graph[device])

详情请参阅 Warp 文档,以及一个强化学习示例 mjlab 分布式训练

MJWarp 在 GPU 上是确定性的吗?

否。相同代码的不同执行所计算的结果之间可能存在顺序或*微小*的数值差异。这是 GPU 上非确定性原子操作的特点。要获得确定性结果,请使用 wp.set_device("cpu") 将设备设置为 CPU。

关于在 GPU 上实现确定性结果的进展,请在此 GitHub issue 中跟踪。

方向是如何表示的?

方向由单位四元数表示,并遵循 MuJoCo 的约定w, x, y, z标量, 矢量

wp.quaternion

MJWarp 使用 Warp 的内置类型 wp.quaternion。但重要的是,MJWarp 并不使用 Warp 的 x, y, z, w 四元数约定或操作,而是实现了遵循 MuJoCo 约定的四元数例程。实现细节请参阅 math.py