Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 17 additions & 13 deletions genesis/engine/solvers/rigid/abd/forward_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1429,18 +1429,19 @@ def func_update_acc(
if qd.static(static_rigid_sim_config.use_hibernation)
else i_0
)

link_start = entities_info.link_start[i_e]
link_end = entities_info.link_end[i_e]
for i_l_ in (
range(entities_info.link_start[i_e], entities_info.link_end[i_e])
range(link_start, link_end)
if qd.static(not BW)
else qd.static(range(static_rigid_sim_config.max_n_links_per_entity))
):
i_l = i_l_ if qd.static(not BW) else (i_l_ + entities_info.link_start[i_e])
i_l = i_l_ if qd.static(not BW) else (i_l_ + link_start)

if func_check_index_range(
i_l,
entities_info.link_start[i_e],
entities_info.link_end[i_e],
link_start,
link_end,
BW,
):
I_l = [i_l, i_b] if qd.static(static_rigid_sim_config.batch_links_info) else i_l
Expand All @@ -1460,32 +1461,35 @@ def func_update_acc(
if qd.static(update_cacc):
links_state.cacc_lin[i_l, i_b] = links_state.cacc_lin[i_p, i_b]
links_state.cacc_ang[i_l, i_b] = links_state.cacc_ang[i_p, i_b]

dof_start = links_info.dof_start[I_l]
dof_end = links_info.dof_end[I_l]
for i_d_ in (
range(links_info.dof_start[I_l], links_info.dof_end[I_l])
range(dof_start, dof_end)
if qd.static(not BW)
else qd.static(range(static_rigid_sim_config.max_n_dofs_per_link))
):
i_d = i_d_ if qd.static(not BW) else (i_d_ + links_info.dof_start[I_l])
i_d = i_d_ if qd.static(not BW) else (i_d_ + dof_start)

if func_check_index_range(i_d, links_info.dof_start[I_l], links_info.dof_end[I_l], BW):
if func_check_index_range(i_d, dof_start, dof_end, BW):
# cacc = cacc_parent + cdofdot * qvel + cdof * qacc
local_cdd_vel = dofs_state.cdofd_vel[i_d, i_b] * dofs_state.vel[i_d, i_b]
local_cdd_ang = dofs_state.cdofd_ang[i_d, i_b] * dofs_state.vel[i_d, i_b]
vel = dofs_state.vel[i_d, i_b]
acc = dofs_state.acc[i_d, i_b]
local_cdd_vel = dofs_state.cdofd_vel[i_d, i_b] * vel
local_cdd_ang = dofs_state.cdofd_ang[i_d, i_b] * vel

func_add_safe_backward(links_state.cdd_vel, [i_l, i_b], local_cdd_vel, BW)
func_add_safe_backward(links_state.cdd_ang, [i_l, i_b], local_cdd_ang, BW)
if qd.static(update_cacc):
func_add_safe_backward(
links_state.cacc_lin,
[i_l, i_b],
local_cdd_vel + dofs_state.cdof_vel[i_d, i_b] * dofs_state.acc[i_d, i_b],
local_cdd_vel + dofs_state.cdof_vel[i_d, i_b] * acc,
BW,
)
func_add_safe_backward(
links_state.cacc_ang,
[i_l, i_b],
local_cdd_ang + dofs_state.cdof_ang[i_d, i_b] * dofs_state.acc[i_d, i_b],
local_cdd_ang + dofs_state.cdof_ang[i_d, i_b] * acc,
BW,
)

Expand Down
51 changes: 31 additions & 20 deletions genesis/engine/solvers/rigid/abd/forward_kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -975,16 +975,19 @@ def func_forward_kinematics_entity(
R = qd.static(func_read_field_if)
WR = qd.static(func_write_and_read_field_if)
i_b = qd.cast(i_b, qd.i32)
link_start = entities_info.link_start[i_e]
link_end = entities_info.link_end[i_e]

# Becomes static loop in backward pass, because we assume this loop is an inner loop
qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL,block_dim=64)
for i_l_ in (
range(entities_info.link_start[i_e], entities_info.link_end[i_e])
range(link_start, link_end)
if qd.static(not BW)
else qd.static(range(static_rigid_sim_config.max_n_links_per_entity))
):
i_l = gs.qd_int(i_l_ if qd.static(not BW) else (i_l_ + entities_info.link_start[i_e]))
i_l = gs.qd_int(i_l_ if qd.static(not BW) else (i_l_ + link_start))

if func_check_index_range(i_l, entities_info.link_start[i_e], entities_info.link_end[i_e], BW):
if func_check_index_range(i_l, link_start, link_end, BW):
I_l = [i_l, i_b] if qd.static(static_rigid_sim_config.batch_links_info) else i_l
I_l0 = (i_l, 0, i_b)

Expand Down Expand Up @@ -1386,15 +1389,17 @@ def func_forward_velocity_entity(
R = qd.static(func_read_field_if)
A = qd.static(func_atomic_add_if)
i_b = qd.cast(i_b, qd.i32)

link_start = entities_info.link_start[i_e]
link_end = entities_info.link_end[i_e]
qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL,block_dim=64)
for i_l_ in (
range(entities_info.link_start[i_e], entities_info.link_end[i_e])
range(link_start, link_end)
if qd.static(not BW)
else qd.static(range(static_rigid_sim_config.max_n_links_per_entity))
):
i_l = gs.qd_int(i_l_ if qd.static(not BW) else (i_l_ + entities_info.link_start[i_e]))
i_l = gs.qd_int(i_l_ if qd.static(not BW) else (i_l_ + link_start))

if func_check_index_range(i_l, entities_info.link_start[i_e], entities_info.link_end[i_e], BW):
if func_check_index_range(i_l, link_start, link_end, BW):
I_l = [i_l, i_b] if qd.static(static_rigid_sim_config.batch_links_info) else i_l
n_joints = links_info.joint_end[I_l] - links_info.joint_start[I_l]

Expand Down Expand Up @@ -1423,38 +1428,44 @@ def func_forward_velocity_entity(

if joint_type == gs.JOINT_TYPE.FREE:
for i_3 in qd.static(range(3)):
_vel = dofs_state.cdof_vel[dof_start + i_3, i_b] * dofs_state.vel[dof_start + i_3, i_b]
_ang = dofs_state.cdof_ang[dof_start + i_3, i_b] * dofs_state.vel[dof_start + i_3, i_b]

cvel_vel = cvel_vel + A(links_state.cd_vel_bw, curr_I, _vel, BW)
cvel_ang = cvel_ang + A(links_state.cd_ang_bw, curr_I, _ang, BW)

idx = dof_start + i_3

v = dofs_state.vel[idx, i_b]
tmp_vel = dofs_state.cdof_vel[idx, i_b] * v
tmp_ang = dofs_state.cdof_ang[idx, i_b] * v
cvel_vel = cvel_vel + A(links_state.cd_vel_bw, curr_I, tmp_vel, BW)
cvel_ang = cvel_ang + A(links_state.cd_ang_bw, curr_I, tmp_ang, BW)

ang_curr = R(links_state.cd_ang_bw, curr_I, cvel_ang, BW)
vel_curr = R(links_state.cd_vel_bw, curr_I, cvel_vel, BW)
for i_3 in qd.static(range(3)):
(
dofs_state.cdofd_ang[dof_start + i_3, i_b],
dofs_state.cdofd_vel[dof_start + i_3, i_b],
) = qd.Vector.zero(gs.qd_float, 3), qd.Vector.zero(gs.qd_float, 3)

(
dofs_state.cdofd_ang[dof_start + i_3 + 3, i_b],
dofs_state.cdofd_vel[dof_start + i_3 + 3, i_b],
) = gu.motion_cross_motion(
R(links_state.cd_ang_bw, curr_I, cvel_ang, BW),
R(links_state.cd_vel_bw, curr_I, cvel_vel, BW),
ang_curr,
vel_curr,
dofs_state.cdof_ang[dof_start + i_3 + 3, i_b],
dofs_state.cdof_vel[dof_start + i_3 + 3, i_b],
)

if qd.static(BW):
links_state.cd_vel_bw[next_I] = links_state.cd_vel_bw[curr_I]
links_state.cd_ang_bw[next_I] = links_state.cd_ang_bw[curr_I]

for i_3 in qd.static(range(3)):
idx = dof_start + i_3 + 3
v = dofs_state.vel[idx, i_b]
_vel = (
dofs_state.cdof_vel[dof_start + i_3 + 3, i_b] * dofs_state.vel[dof_start + i_3 + 3, i_b]
dofs_state.cdof_vel[idx, i_b] * v
)
_ang = (
dofs_state.cdof_ang[dof_start + i_3 + 3, i_b] * dofs_state.vel[dof_start + i_3 + 3, i_b]
dofs_state.cdof_ang[idx, i_b] * v
)
cvel_vel = cvel_vel + A(links_state.cd_vel_bw, next_I, _vel, BW)
cvel_ang = cvel_ang + A(links_state.cd_ang_bw, next_I, _ang, BW)
Expand Down
5 changes: 3 additions & 2 deletions genesis/engine/solvers/rigid/collider/broadphase.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def func_check_collision_valid(
i_lb = geoms_info.link_idx[i_gb]

# Filter out collision pairs that are involved in dynamically registered weld equality constraints
qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL, block_dim=64)
for i_eq in range(rigid_global_info.n_equalities[None], constraint_state.qd_n_equalities[i_b]):
if equalities_info.eq_type[i_eq, i_b] == gs.EQUALITY_TYPE.WELD:
i_leqa = equalities_info.eq_obj1id[i_eq, i_b]
Expand Down Expand Up @@ -79,7 +80,7 @@ def func_collision_clear(
):
_B = collider_state.n_contacts.shape[0]

qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL, block_dim=64)
for i_b in range(_B):
if qd.static(static_rigid_sim_config.use_hibernation):
collider_state.n_contacts_hibernated[i_b] = 0
Expand Down Expand Up @@ -164,7 +165,7 @@ def func_broad_phase(
# Clear collider state
func_collision_clear(links_state, links_info, collider_state, static_rigid_sim_config)

qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL,block_dim=64)
for i_b in range(_B):
axis = 0

Expand Down
108 changes: 85 additions & 23 deletions genesis/engine/solvers/rigid/constraint/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -1559,28 +1559,41 @@ def func_hessian_direct_tiled(
qd.simt.block.sync()

# Compute `H += J.T @ D @ J` for a single Hessian block
pid = tid
numel = n_dofs_tile_row * n_dofs_tile_col
while pid < numel:
i_d1_ = pid // n_dofs_tile_col
i_d2_ = pid % n_dofs_tile_col
i_d1 = i_d1_ + i_d1_start
i_d2 = i_d2_ + i_d2_start
if i_d1 >= i_d2:
if is_diag_tile:
n_lower_tri_tile = n_dofs_tile_row * (n_dofs_tile_row + 1) // 2
pid = tid
while pid < n_lower_tri_tile:
i_d1_, i_d2_ = linear_to_lower_tri(pid)
i_d1 = i_d1_ + i_d1_start
i_d2 = i_d2_ + i_d2_start
coef = gs.qd_float(0.0)
if i_c_start == 0:
coef = rigid_global_info.mass_mat[i_d1, i_d2, i_b]
if is_diag_tile:
for j_c_ in range(n_conts_tile):
coef = coef + jac_row[j_c_, i_d1_] * jac_row[j_c_, i_d2_] * efc_D[j_c_]
for j_c_ in range(n_conts_tile):
coef = coef + jac_row[j_c_, i_d1_] * jac_row[j_c_, i_d2_] * efc_D[j_c_]
if i_c_start == 0:
constraint_state.nt_H[i_b, i_d1, i_d2] = coef
else:
for j_c_ in range(n_conts_tile):
coef = coef + jac_row[j_c_, i_d1_] * jac_col[j_c_, i_d2_] * efc_D[j_c_]
constraint_state.nt_H[i_b, i_d1, i_d2] = constraint_state.nt_H[i_b, i_d1, i_d2] + coef
pid = pid + BLOCK_DIM
else:
numel = n_dofs_tile_row * n_dofs_tile_col
pid = tid
while pid < numel:
i_d1_ = pid // n_dofs_tile_col
i_d2_ = pid % n_dofs_tile_col
i_d1 = i_d1_ + i_d1_start
i_d2 = i_d2_ + i_d2_start
coef = gs.qd_float(0.0)
if i_c_start == 0:
coef = rigid_global_info.mass_mat[i_d1, i_d2, i_b]
for j_c_ in range(n_conts_tile):
coef = coef + jac_row[j_c_, i_d1_] * jac_col[j_c_, i_d2_] * efc_D[j_c_]
if i_c_start == 0:
constraint_state.nt_H[i_b, i_d1, i_d2] = coef
else:
constraint_state.nt_H[i_b, i_d1, i_d2] = constraint_state.nt_H[i_b, i_d1, i_d2] + coef
pid = pid + BLOCK_DIM
pid = pid + BLOCK_DIM
qd.simt.block.sync()

i_d2_start = i_d2_start + MAX_DOFS_PER_BLOCK
Expand Down Expand Up @@ -2886,22 +2899,71 @@ def initialize_Jaref(
qacc: array_class.V_ANNOTATION,
constraint_state: array_class.ConstraintState,
static_rigid_sim_config: qd.template(),
):
if qd.static(static_rigid_sim_config.parallel_init):
_initialize_Jaref_parallel(
qacc=qacc,
constraint_state=constraint_state,
static_rigid_sim_config=static_rigid_sim_config,
)
else:
_initialize_Jaref_per_env(
qacc=qacc,
constraint_state=constraint_state,
static_rigid_sim_config=static_rigid_sim_config,
)


@qd.func
def _initialize_Jaref_body(
    i_c,
    i_b,
    n_dofs,
    qacc: array_class.V_ANNOTATION,
    constraint_state: array_class.ConstraintState,
    static_rigid_sim_config: qd.template(),
):
    """Compute and store ``Jaref`` for one (constraint, env) pair.

    The value written to ``constraint_state.Jaref[i_c, i_b]`` is the dot product of Jacobian row
    ``i_c`` with ``qacc[:, i_b]``, minus ``constraint_state.aref[i_c, i_b]``.

    Args:
        i_c: Constraint index.
        i_b: Environment (batch) index.
        n_dofs: Number of Jacobian columns summed on the dense path.
        qacc: Acceleration array, indexed as ``qacc[i_d, i_b]``.
        constraint_state: Holds ``aref``, ``jac``, the relevant-dof tables and the ``Jaref`` output.
        static_rigid_sim_config: Compile-time config; ``sparse_solve`` selects the summation path.
    """
    # Seed the accumulator with -aref so the row/qacc products are added on top of it.
    row_dot = -constraint_state.aref[i_c, i_b]
    if qd.static(static_rigid_sim_config.sparse_solve):
        # Sparse path: only visit the dofs recorded as relevant for this constraint.
        n_relevant = constraint_state.jac_n_relevant_dofs[i_c, i_b]
        for i_rel in range(n_relevant):
            i_d = constraint_state.jac_relevant_dofs[i_c, i_rel, i_b]
            row_dot = row_dot + constraint_state.jac[i_c, i_d, i_b] * qacc[i_d, i_b]
    else:
        # Dense path: sum over every dof column.
        for i_d in range(n_dofs):
            row_dot = row_dot + constraint_state.jac[i_c, i_d, i_b] * qacc[i_d, i_b]
    constraint_state.Jaref[i_c, i_b] = row_dot


@qd.func
def _initialize_Jaref_per_env(
    qacc: array_class.V_ANNOTATION,
    constraint_state: array_class.ConstraintState,
    static_rigid_sim_config: qd.template(),
):
    """Initialize ``Jaref`` with the outer loop over environments.

    For each environment ``i_b``, every active constraint ``i_c < n_constraints[i_b]`` is processed
    by ``_initialize_Jaref_body``, which performs the accumulation and the store. The per-constraint
    math lives only in that helper, so each ``Jaref`` entry is computed and written exactly once.

    Args:
        qacc: Acceleration array, forwarded unchanged to the helper.
        constraint_state: Provides ``jac`` (for the shapes), ``n_constraints`` and the output.
        static_rigid_sim_config: Compile-time config; ``para_level`` decides whether the outer loop is serialized.
    """
    _B = constraint_state.jac.shape[2]
    n_dofs = constraint_state.jac.shape[1]

    qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
    for i_b in range(_B):
        for i_c in range(constraint_state.n_constraints[i_b]):
            _initialize_Jaref_body(i_c, i_b, n_dofs, qacc, constraint_state, static_rigid_sim_config)


@qd.func
def _initialize_Jaref_parallel(
    qacc: array_class.V_ANNOTATION,
    constraint_state: array_class.ConstraintState,
    static_rigid_sim_config: qd.template(),
):
    """Initialize ``Jaref`` by iterating over (constraint slot, env) pairs in a single loop.

    Parallelizes over (constraints, envs), which is better when the GPU is not saturated by envs
    alone. The loop covers the full first dimension of ``constraint_state.Jaref``; slots at or
    beyond ``n_constraints[i_b]`` are skipped.

    Args:
        qacc: Acceleration array, forwarded unchanged to ``_initialize_Jaref_body``.
        constraint_state: Provides the array shapes, ``n_constraints`` and the ``Jaref`` output.
        static_rigid_sim_config: Compile-time config; ``para_level`` decides whether the loop is serialized.
    """
    n_dofs = constraint_state.jac.shape[1]
    n_envs = constraint_state.jac.shape[2]
    n_slots = constraint_state.Jaref.shape[0]

    qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
    for i_c, i_b in qd.ndrange(n_slots, n_envs):
        # Only slots below this env's active constraint count hold real constraints.
        if i_c < constraint_state.n_constraints[i_b]:
            _initialize_Jaref_body(i_c, i_b, n_dofs, qacc, constraint_state, static_rigid_sim_config)


@qd.func
Expand Down
24 changes: 24 additions & 0 deletions genesis/engine/solvers/rigid/rigid_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,29 @@ def build(self):
self._func_vel_at_point = func_vel_at_point
self._func_apply_coupling_force = func_apply_coupling_force

def _should_use_parallel_init(self):
    """Return whether parallel init (ndrange over constraints + envs) should be used.

    Parallel init is chosen when the environments alone are not expected to saturate the GPU,
    i.e. when ``self.n_envs`` does not exceed an estimated GPU core count. It is always disabled
    on the CPU backend and when the simulation requires gradients.

    Returns:
        bool: ``True`` if ``self.n_envs`` is at most the estimated core count, ``False`` otherwise.
    """
    # Guard clause: never use parallel init on CPU or in differentiable mode.
    if gs.backend == gs.cpu or self.sim.options.requires_grad:
        return False

    import torch

    if torch.cuda.is_available():
        device_props = torch.cuda.get_device_properties(torch.cuda.current_device())
        # Heuristic lanes per compute unit: 128 CUDA cores per SM on NVIDIA,
        # 64 stream processors per CU on AMD/ROCm builds.
        lanes_per_unit = 128
        if torch.version.hip:
            lanes_per_unit = 64
        estimated_cores = device_props.multi_processor_count * lanes_per_unit
    elif gs.backend == gs.metal:
        # Upper-bound estimate for Apple Silicon: 40 GPU cores x 128 ALUs.
        estimated_cores = 5120
    else:
        # Fallback for other GPU backends (e.g. Vulkan).
        estimated_cores = 16384

    return estimated_cores >= self.n_envs

def _build_static_config(self):
static_rigid_sim_config = dict(
backend=gs.backend,
Expand All @@ -366,6 +389,7 @@ def _build_static_config(self):
sparse_solve=self._options.sparse_solve,
integrator=self._integrator,
solver_type=self._options.constraint_solver,
parallel_init=self._should_use_parallel_init(),
)

if self.is_active:
Expand Down
1 change: 1 addition & 0 deletions genesis/utils/array_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -2022,6 +2022,7 @@ class StructRigidSimStaticConfig(metaclass=AutoInitMeta):
integrator: int
solver_type: int
requires_grad: bool
parallel_init: bool = False # parallelize init over (constraints, envs) when GPU is not saturated by envs alone
enable_tiled_cholesky_mass_matrix: bool = False
enable_tiled_cholesky_hessian: bool = False
tiled_n_dofs_per_entity: int = -1
Expand Down
Loading