150 changes: 150 additions & 0 deletions genesis/engine/solvers/rigid/abd/forward_dynamics.py
@@ -1004,6 +1004,156 @@ def func_solve_mass(
)


@qd.func
def func_solve_mass_tiled(
    vec: array_class.V_ANNOTATION,
    out: array_class.V_ANNOTATION,
    entities_info: array_class.EntitiesInfo,
    rigid_global_info: array_class.RigidGlobalInfo,
    static_rigid_sim_config: qd.template(),
):
"""Wave-cooperative solve of `M @ out = vec` using the pre-factored LDLT of M.

One wavefront (BLOCK_DIM=64) per (entity, env). Mirrors the structure of
`func_cholesky_solve_tiled` in constraint/solver.py, adapted for LDLT
(unit L, diagonal D stored inverted in `mass_mat_D_inv`) and indexed per
entity's DoF range.

Three phases per entity (all serialized across i_d, parallelized across
lanes within each serial step):
Phase 1: solve L^T @ w = vec # unit diagonal
Phase 2: z = D_inv * w # elementwise
Phase 3: solve L @ out = z # unit diagonal
"""
    BLOCK_DIM = qd.static(64)
    MAX_DOFS = qd.static(static_rigid_sim_config.tiled_n_dofs_per_entity)
    ENABLE_WARP_REDUCTION = qd.static(
        static_rigid_sim_config.backend == gs.cuda and gs.qd_float == qd.f32
    )
    WARP_SIZE = qd.static(32)
    NUM_WARPS = qd.static(BLOCK_DIM // WARP_SIZE)

    _B = out.shape[1]
    n_entities = entities_info.n_links.shape[0]

    qd.loop_config(block_dim=BLOCK_DIM)
    for i in range(n_entities * _B * BLOCK_DIM):
        tid = i % BLOCK_DIM
        i_e = (i // BLOCK_DIM) % n_entities
        i_b = i // (BLOCK_DIM * n_entities)
        warp_id = tid // WARP_SIZE
        lane_id = tid % WARP_SIZE

        if i_b >= _B:
            continue
        if not rigid_global_info.mass_mat_mask[i_e, i_b]:
            continue

        entity_dof_start = entities_info.dof_start[i_e]
        entity_dof_end = entities_info.dof_end[i_e]
        n_dofs = entities_info.n_dofs[i_e]

        # LDS: entity-local lower-triangular L, D_inv, working vector v, and reduction buffer.
        L = qd.simt.block.SharedArray((MAX_DOFS, MAX_DOFS + 1), gs.qd_float)
        D_inv = qd.simt.block.SharedArray((MAX_DOFS,), gs.qd_float)
        v = qd.simt.block.SharedArray((MAX_DOFS,), gs.qd_float)
        partial = qd.simt.block.SharedArray(
            (NUM_WARPS if qd.static(ENABLE_WARP_REDUCTION) else BLOCK_DIM,), gs.qd_float
        )

        # Copy entity-local lower triangle of L from global mass_mat_L into LDS.
        # L[j_local, i_local] is the (j_local, i_local) entry below the diagonal
        # for j_local > i_local. Source global index is (entity_dof_start + local).
        n_tri = n_dofs * n_dofs
        i_flat = tid
        while i_flat < n_tri:
            r = i_flat // n_dofs
            c = i_flat % n_dofs
            if c < r:  # strict lower triangle
                L[r, c] = rigid_global_info.mass_mat_L[entity_dof_start + r, entity_dof_start + c, i_b]
            i_flat = i_flat + BLOCK_DIM

        # Copy D_inv and the input vector into LDS (strided).
        k_d = tid
        while k_d < n_dofs:
            D_inv[k_d] = rigid_global_info.mass_mat_D_inv[entity_dof_start + k_d, i_b]
            v[k_d] = vec[entity_dof_start + k_d, i_b]
            k_d = k_d + BLOCK_DIM
        qd.simt.block.sync()

        # ------------------------------------------------------------------
        # Phase 1: solve L^T @ w = vec in-place in v.
        # For each i_d from n_dofs - 1 down to 0:
        #     w[i_d] = vec[i_d] - sum_{j > i_d} L[j, i_d] * w[j]
        # Note: the serial scan must go over DECREASING i_d because w[j] for
        # j > i_d must already be finalized. L is unit-diagonal, so no divide.
        # ------------------------------------------------------------------
        for i_d_ in range(n_dofs):
            i_d = n_dofs - 1 - i_d_
            dot = gs.qd_float(0.0)
            j_d = i_d + 1 + tid
            while j_d < n_dofs:
                dot = dot + L[j_d, i_d] * v[j_d]
                j_d = j_d + BLOCK_DIM
            if qd.static(ENABLE_WARP_REDUCTION):
                for offset in qd.static([16, 8, 4, 2, 1]):
                    dot = dot + qd.simt.warp.shfl_down_f32(qd.u32(0xFFFFFFFF), dot, offset)
                if lane_id == 0:
                    partial[warp_id] = dot
            else:
                partial[tid] = dot
            qd.simt.block.sync()

            if tid == 0:
                total = gs.qd_float(0.0)
                for k in qd.static(range(NUM_WARPS)) if qd.static(ENABLE_WARP_REDUCTION) else range(BLOCK_DIM):
                    total = total + partial[k]
                v[i_d] = v[i_d] - total
            qd.simt.block.sync()

        # ------------------------------------------------------------------
        # Phase 2: z = D^{-1} @ w (elementwise, fully parallel).
        # ------------------------------------------------------------------
        k_d = tid
        while k_d < n_dofs:
            v[k_d] = v[k_d] * D_inv[k_d]
            k_d = k_d + BLOCK_DIM
        qd.simt.block.sync()

        # ------------------------------------------------------------------
        # Phase 3: solve L @ x = z in-place in v.
        # For each i_d from 0 upward:
        #     x[i_d] = z[i_d] - sum_{j < i_d} L[i_d, j] * x[j]
        # ------------------------------------------------------------------
        for i_d in range(n_dofs):
            dot = gs.qd_float(0.0)
            j_d = tid
            while j_d < i_d:
                dot = dot + L[i_d, j_d] * v[j_d]
                j_d = j_d + BLOCK_DIM
            if qd.static(ENABLE_WARP_REDUCTION):
                for offset in qd.static([16, 8, 4, 2, 1]):
                    dot = dot + qd.simt.warp.shfl_down_f32(qd.u32(0xFFFFFFFF), dot, offset)
                if lane_id == 0:
                    partial[warp_id] = dot
            else:
                partial[tid] = dot
            qd.simt.block.sync()

            if tid == 0:
                total = gs.qd_float(0.0)
                for k in qd.static(range(NUM_WARPS)) if qd.static(ENABLE_WARP_REDUCTION) else range(BLOCK_DIM):
                    total = total + partial[k]
                v[i_d] = v[i_d] - total
            qd.simt.block.sync()

        # Copy the final solution from LDS back to global out[entity_dof_start..end, i_b].
        k_d = tid
        while k_d < n_dofs:
            out[entity_dof_start + k_d, i_b] = v[k_d]
            k_d = k_d + BLOCK_DIM


@qd.func
def func_torque_and_passive_force(
    entities_state: array_class.EntitiesState,
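For reference, here is a minimal NumPy sketch of the three-phase solve that `func_solve_mass_tiled` performs per entity, assuming the mass matrix is pre-factored as `M = L.T @ D @ L` with unit-diagonal lower-triangular `L` and `D` stored inverted (as in `mass_mat_D_inv`). The names below are illustrative only, not Genesis APIs.

```python
import numpy as np

def solve_mass_reference(L, D_inv, vec):
    """Solve (L.T @ D @ L) @ x = vec, with unit-diagonal lower-triangular L and D_inv = 1 / diag(D)."""
    n = vec.shape[0]
    v = vec.astype(float)
    # Phase 1: backward substitution for L.T @ w = vec (L.T is unit upper-triangular).
    for i in range(n - 1, -1, -1):
        v[i] -= L[i + 1:, i] @ v[i + 1:]
    # Phase 2: elementwise scaling by D^{-1}.
    v *= D_inv
    # Phase 3: forward substitution for L @ x = z.
    for i in range(n):
        v[i] -= L[i, :i] @ v[:i]
    return v

# Self-check against a dense solve on a randomly factored matrix.
rng = np.random.default_rng(0)
n = 6
L = np.eye(n) + np.tril(rng.standard_normal((n, n)), k=-1)
D = np.diag(rng.uniform(0.5, 2.0, size=n))
M = L.T @ D @ L
b = rng.standard_normal(n)
assert np.allclose(solve_mass_reference(L, 1.0 / np.diag(D), b), np.linalg.solve(M, b))
```

The kernel above parallelizes the inner dot products of Phases 1 and 3 across the 64 lanes of a wavefront while keeping the outer i_d scan serial, exactly as described in its comments.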
24 changes: 11 additions & 13 deletions genesis/engine/solvers/rigid/constraint/solver.py
@@ -9,7 +9,7 @@

import genesis.utils.array_class as array_class
import genesis.utils.geom as gu
from genesis.engine.solvers.rigid.abd import func_solve_mass_batch
from genesis.engine.solvers.rigid.abd import func_solve_mass_batch, func_solve_mass_tiled
from genesis.utils.misc import qd_to_torch, indices_to_mask, assign_indexed_tensor

from ..collider.contact_island import ContactIsland
@@ -2771,18 +2771,16 @@ def func_update_gradient_tiled(
)

    if qd.static(static_rigid_sim_config.solver_type == gs.constraint_solver.CG):
        qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL, block_dim=32)
        for i_b in range(_B):
            func_solve_mass_batch(
                i_b,
                constraint_state.grad,
                constraint_state.Mgrad,
                array_class.PLACEHOLDER,
                entities_info=entities_info,
                rigid_global_info=rigid_global_info,
                static_rigid_sim_config=static_rigid_sim_config,
                is_backward=False,
            )
        # Wave-cooperative LDLT back-solve: 1 wave (BLOCK_DIM=64) per (entity, env),
        # mirrors func_cholesky_solve_tiled for the Newton path. Replaces the serial
        # per-env func_solve_mass_batch loop.
        func_solve_mass_tiled(
            constraint_state.grad,
            constraint_state.Mgrad,
            entities_info=entities_info,
            rigid_global_info=rigid_global_info,
            static_rigid_sim_config=static_rigid_sim_config,
        )

    if qd.static(static_rigid_sim_config.solver_type == gs.constraint_solver.Newton):
        func_cholesky_solve_tiled(constraint_state, static_rigid_sim_config)
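The replacement call above drops the serialized per-env loop: each (entity, env) pair now gets one 64-lane wavefront, and every serial step i_d reduces lane-strided partial products in two levels (a per-warp shuffle sum, then a cross-warp sum of partials). A small plain-Python sketch of that index mapping and reduction; it only emulates the arithmetic, and the names are not Genesis APIs.

```python
BLOCK_DIM, WARP_SIZE = 64, 32
NUM_WARPS = BLOCK_DIM // WARP_SIZE

def decompose(i, n_entities):
    """Map the kernel's flat loop index to (lane-in-block, entity, env)."""
    tid = i % BLOCK_DIM
    i_e = (i // BLOCK_DIM) % n_entities
    i_b = i // (BLOCK_DIM * n_entities)
    return tid, i_e, i_b

def block_row_dot(L_col, v, i_d, n_dofs):
    """What one wavefront computes for a single serial step i_d of Phase 1:
    each lane accumulates a stride-BLOCK_DIM slice of sum_{j > i_d} L[j, i_d] * v[j];
    per-warp sums stand in for the shfl_down tree, and thread 0 adds the partials."""
    lane_acc = [0.0] * BLOCK_DIM
    for tid in range(BLOCK_DIM):
        j = i_d + 1 + tid
        while j < n_dofs:
            lane_acc[tid] += L_col[j] * v[j]
            j += BLOCK_DIM
    partial = [sum(lane_acc[w * WARP_SIZE:(w + 1) * WARP_SIZE]) for w in range(NUM_WARPS)]
    return sum(partial)  # what tid == 0 subtracts from v[i_d]
```

On CUDA with f32, the per-warp sum is done with shfl_down offsets 16/8/4/2/1 so only NUM_WARPS partials touch shared memory; on other backends every lane writes its own partial and thread 0 sums all BLOCK_DIM of them, matching the if/else in the kernel.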
25 changes: 11 additions & 14 deletions genesis/engine/solvers/rigid/constraint/solver_breakdown.py
@@ -181,19 +181,16 @@ def _kernel_update_gradient(
    rigid_global_info: array_class.RigidGlobalInfo,
    static_rigid_sim_config: ti.template(),
):
"""Step 5: Update gradient"""
_B = constraint_state.grad.shape[1]
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL, block_dim=32)
for i_b in range(_B):
if constraint_state.n_constraints[i_b] > 0 and constraint_state.improved[i_b]:
solver.func_update_gradient_batch(
i_b,
dofs_state=dofs_state,
entities_info=entities_info,
rigid_global_info=rigid_global_info,
constraint_state=constraint_state,
static_rigid_sim_config=static_rigid_sim_config,
)
"""Step 5: Update gradient — delegates to the dispatcher so AMDGPU gets
the tiled Cholesky/mass-solve path (func_update_gradient_tiled) instead
of per-env serial func_update_gradient_batch."""
solver.func_update_gradient(
dofs_state=dofs_state,
entities_info=entities_info,
rigid_global_info=rigid_global_info,
constraint_state=constraint_state,
static_rigid_sim_config=static_rigid_sim_config,
)


@ti.kernel(fastcache=gs.use_fastcache)
@@ -215,7 +212,7 @@ def _kernel_update_search_direction(
)


@solver.func_solve_body.register(is_compatible=lambda *args, **kwargs: gs.backend in {gs.cuda})
@solver.func_solve_body.register(is_compatible=lambda *args, **kwargs: gs.backend is not gs.cpu)
def func_solve_decomposed(
    entities_info,
    dofs_state,
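The solver_breakdown changes route Step 5 through the dispatcher and widen the registration predicate from CUDA-only to every non-CPU backend, so AMDGPU also takes the decomposed/tiled path. Below is a minimal sketch of that predicate-based registration pattern under assumed names (a simple registry and a hypothetical `Backend` enum); it is illustrative, not the actual Genesis dispatcher.

```python
from enum import Enum, auto

class Backend(Enum):
    cpu = auto()
    cuda = auto()
    amdgpu = auto()

backend = Backend.amdgpu  # selected at init in the real system

class Dispatcher:
    """Call the most recently registered implementation whose predicate accepts the call."""
    def __init__(self, fallback):
        self._impls = []  # (is_compatible, fn) pairs
        self._fallback = fallback

    def register(self, is_compatible):
        def decorator(fn):
            self._impls.append((is_compatible, fn))
            return fn
        return decorator

    def __call__(self, *args, **kwargs):
        for is_compatible, fn in reversed(self._impls):
            if is_compatible(*args, **kwargs):
                return fn(*args, **kwargs)
        return self._fallback(*args, **kwargs)

func_solve_body = Dispatcher(fallback=lambda *args, **kwargs: "monolithic solve")

# Before this change: only CUDA qualified. After: any non-CPU backend does.
@func_solve_body.register(is_compatible=lambda *args, **kwargs: backend is not Backend.cpu)
def func_solve_decomposed(*args, **kwargs):
    return "decomposed solve"

print(func_solve_body())  # -> "decomposed solve" on any GPU backend
```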