diff --git a/genesis/engine/solvers/rigid/abd/forward_dynamics.py b/genesis/engine/solvers/rigid/abd/forward_dynamics.py
index ecb764b0e..a7f82a736 100644
--- a/genesis/engine/solvers/rigid/abd/forward_dynamics.py
+++ b/genesis/engine/solvers/rigid/abd/forward_dynamics.py
@@ -1004,6 +1004,163 @@ def func_solve_mass(
         )
 
 
+@qd.func
+def func_solve_mass_tiled(
+    vec: array_class.V_ANNOTATION,
+    out: array_class.V_ANNOTATION,
+    entities_info: array_class.EntitiesInfo,
+    rigid_global_info: array_class.RigidGlobalInfo,
+    static_rigid_sim_config: qd.template(),
+):
+    """Wave-cooperative solve of `M @ out = vec` using the pre-factored LDLT of M.
+
+    One block of BLOCK_DIM=64 threads per (entity, env): one 64-wide wavefront, or
+    two 32-lane warps on the CUDA shuffle path (WARP_SIZE=32, NUM_WARPS=2). Mirrors
+    the structure of `func_cholesky_solve_tiled` in constraint/solver.py, adapted
+    for LDLT (unit L, diagonal D stored inverted in `mass_mat_D_inv`) and indexed
+    per entity's DoF range.
+
+    Three phases per entity (all serialized across i_d, parallelized across
+    lanes within each serial step):
+      Phase 1: solve L^T @ w = vec     # unit diagonal
+      Phase 2: z = D_inv * w           # elementwise
+      Phase 3: solve L @ out = z       # unit diagonal
+
+    (entity, env) pairs whose `mass_mat_mask` entry is false are skipped, leaving
+    the corresponding `out` entries untouched.
+
+    NOTE(review): the LDS tiles are sized by `tiled_n_dofs_per_entity`; this
+    assumes every entity satisfies n_dofs <= that value -- confirm at the call site.
+    """
+    BLOCK_DIM = qd.static(64)
+    MAX_DOFS = qd.static(static_rigid_sim_config.tiled_n_dofs_per_entity)
+    ENABLE_WARP_REDUCTION = qd.static(
+        static_rigid_sim_config.backend == gs.cuda and gs.qd_float == qd.f32
+    )
+    WARP_SIZE = qd.static(32)
+    NUM_WARPS = qd.static(BLOCK_DIM // WARP_SIZE)
+
+    _B = out.shape[1]
+    n_entities = entities_info.n_links.shape[0]
+
+    qd.loop_config(block_dim=BLOCK_DIM)
+    for i in range(n_entities * _B * BLOCK_DIM):
+        # Flat index -> (lane, entity, env); the loop bound already guarantees i_b < _B.
+        tid = i % BLOCK_DIM
+        i_e = (i // BLOCK_DIM) % n_entities
+        i_b = i // (BLOCK_DIM * n_entities)
+        warp_id = tid // WARP_SIZE
+        lane_id = tid % WARP_SIZE
+
+        if not rigid_global_info.mass_mat_mask[i_e, i_b]:
+            continue
+
+        entity_dof_start = entities_info.dof_start[i_e]
+        n_dofs = entities_info.n_dofs[i_e]
+
+        # LDS: entity-local lower-triangular L, D_inv, working vector v, and reduction buffer.
+        L = qd.simt.block.SharedArray((MAX_DOFS, MAX_DOFS + 1), gs.qd_float)
+        D_inv = qd.simt.block.SharedArray((MAX_DOFS,), gs.qd_float)
+        v = qd.simt.block.SharedArray((MAX_DOFS,), gs.qd_float)
+        partial = qd.simt.block.SharedArray(
+            (NUM_WARPS if qd.static(ENABLE_WARP_REDUCTION) else BLOCK_DIM,), gs.qd_float
+        )
+
+        # Copy entity-local lower triangle of L from global mass_mat_L into LDS.
+        # L[j_local, i_local] is the (j_local, i_local) entry below the diagonal
+        # for j_local > i_local. Source global index is (entity_dof_start + local).
+        # The flat index walks the full n_dofs x n_dofs square; entries on or above
+        # the diagonal are skipped.
+        n_square = n_dofs * n_dofs
+        i_flat = tid
+        while i_flat < n_square:
+            r = i_flat // n_dofs
+            c = i_flat % n_dofs
+            if c < r:  # strict lower triangle
+                L[r, c] = rigid_global_info.mass_mat_L[entity_dof_start + r, entity_dof_start + c, i_b]
+            i_flat = i_flat + BLOCK_DIM
+
+        # Copy D_inv and the input vector into LDS (strided).
+        k_d = tid
+        while k_d < n_dofs:
+            D_inv[k_d] = rigid_global_info.mass_mat_D_inv[entity_dof_start + k_d, i_b]
+            v[k_d] = vec[entity_dof_start + k_d, i_b]
+            k_d = k_d + BLOCK_DIM
+        qd.simt.block.sync()
+
+        # ------------------------------------------------------------------
+        # Phase 1: solve L^T @ w = y (y = vec) in-place in v.
+        # For each i_d from end-1 down to 0:
+        #   w[i_d] = y[i_d] - sum_{j > i_d} L[j, i_d] * w[j]
+        # Note: serial scan must go DECREASING i_d because w[j] for j > i_d
+        # must already be finalized. L is unit-diagonal, no divide.
+        # ------------------------------------------------------------------
+        for i_d_ in range(n_dofs):
+            i_d = n_dofs - 1 - i_d_
+            dot = gs.qd_float(0.0)
+            j_d = i_d + 1 + tid
+            while j_d < n_dofs:
+                dot = dot + L[j_d, i_d] * v[j_d]
+                j_d = j_d + BLOCK_DIM
+            if qd.static(ENABLE_WARP_REDUCTION):
+                for offset in qd.static([16, 8, 4, 2, 1]):
+                    dot = dot + qd.simt.warp.shfl_down_f32(qd.u32(0xFFFFFFFF), dot, offset)
+                if lane_id == 0:
+                    partial[warp_id] = dot
+            else:
+                partial[tid] = dot
+            qd.simt.block.sync()
+
+            if tid == 0:
+                total = gs.qd_float(0.0)
+                for k in qd.static(range(NUM_WARPS)) if qd.static(ENABLE_WARP_REDUCTION) else range(BLOCK_DIM):
+                    total = total + partial[k]
+                v[i_d] = v[i_d] - total
+            qd.simt.block.sync()
+
+        # ------------------------------------------------------------------
+        # Phase 2: z = D^{-1} @ w (elementwise, fully parallel).
+        # ------------------------------------------------------------------
+        k_d = tid
+        while k_d < n_dofs:
+            v[k_d] = v[k_d] * D_inv[k_d]
+            k_d = k_d + BLOCK_DIM
+        qd.simt.block.sync()
+
+        # ------------------------------------------------------------------
+        # Phase 3: solve L @ x = z in-place in v.
+        # For each i_d from 0 upward:
+        #   x[i_d] = z[i_d] - sum_{j < i_d} L[i_d, j] * x[j]
+        # ------------------------------------------------------------------
+        for i_d in range(n_dofs):
+            dot = gs.qd_float(0.0)
+            j_d = tid
+            while j_d < i_d:
+                dot = dot + L[i_d, j_d] * v[j_d]
+                j_d = j_d + BLOCK_DIM
+            if qd.static(ENABLE_WARP_REDUCTION):
+                for offset in qd.static([16, 8, 4, 2, 1]):
+                    dot = dot + qd.simt.warp.shfl_down_f32(qd.u32(0xFFFFFFFF), dot, offset)
+                if lane_id == 0:
+                    partial[warp_id] = dot
+            else:
+                partial[tid] = dot
+            qd.simt.block.sync()
+
+            if tid == 0:
+                total = gs.qd_float(0.0)
+                for k in qd.static(range(NUM_WARPS)) if qd.static(ENABLE_WARP_REDUCTION) else range(BLOCK_DIM):
+                    total = total + partial[k]
+                v[i_d] = v[i_d] - total
+            qd.simt.block.sync()
+
+        # Copy the final solution from LDS back to global out[entity_dof_start..end, i_b].
+        k_d = tid
+        while k_d < n_dofs:
+            out[entity_dof_start + k_d, i_b] = v[k_d]
+            k_d = k_d + BLOCK_DIM
+
+
 @qd.func
 def func_torque_and_passive_force(
     entities_state: array_class.EntitiesState,
diff --git a/genesis/engine/solvers/rigid/constraint/solver.py b/genesis/engine/solvers/rigid/constraint/solver.py
index ed08afc47..bca1d1abc 100644
--- a/genesis/engine/solvers/rigid/constraint/solver.py
+++ b/genesis/engine/solvers/rigid/constraint/solver.py
@@ -9,7 +9,7 @@
 import genesis.utils.array_class as array_class
 import genesis.utils.geom as gu
 
-from genesis.engine.solvers.rigid.abd import func_solve_mass_batch
+from genesis.engine.solvers.rigid.abd import func_solve_mass_batch, func_solve_mass_tiled
 from genesis.utils.misc import qd_to_torch, indices_to_mask, assign_indexed_tensor
 
 from ..collider.contact_island import ContactIsland
@@ -2771,18 +2771,16 @@ def func_update_gradient_tiled(
     )
 
     if qd.static(static_rigid_sim_config.solver_type == gs.constraint_solver.CG):
-        qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL, block_dim=32)
-        for i_b in range(_B):
-            func_solve_mass_batch(
-                i_b,
-                constraint_state.grad,
-                constraint_state.Mgrad,
-                array_class.PLACEHOLDER,
-                entities_info=entities_info,
-                rigid_global_info=rigid_global_info,
-                static_rigid_sim_config=static_rigid_sim_config,
-                is_backward=False,
-            )
+        # Wave-cooperative LDLT back-solve: 1 wave (BLOCK_DIM=64) per (entity, env),
+        # mirrors func_cholesky_solve_tiled for the Newton path. Replaces the serial
+        # per-env func_solve_mass_batch loop.
+        func_solve_mass_tiled(
+            constraint_state.grad,
+            constraint_state.Mgrad,
+            entities_info=entities_info,
+            rigid_global_info=rigid_global_info,
+            static_rigid_sim_config=static_rigid_sim_config,
+        )
 
     if qd.static(static_rigid_sim_config.solver_type == gs.constraint_solver.Newton):
         func_cholesky_solve_tiled(constraint_state, static_rigid_sim_config)
diff --git a/genesis/engine/solvers/rigid/constraint/solver_breakdown.py b/genesis/engine/solvers/rigid/constraint/solver_breakdown.py
index e5cb7f655..9ebbe1052 100644
--- a/genesis/engine/solvers/rigid/constraint/solver_breakdown.py
+++ b/genesis/engine/solvers/rigid/constraint/solver_breakdown.py
@@ -181,19 +181,22 @@ def _kernel_update_gradient(
     rigid_global_info: array_class.RigidGlobalInfo,
     static_rigid_sim_config: ti.template(),
 ):
-    """Step 5: Update gradient"""
-    _B = constraint_state.grad.shape[1]
-    ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL, block_dim=32)
-    for i_b in range(_B):
-        if constraint_state.n_constraints[i_b] > 0 and constraint_state.improved[i_b]:
-            solver.func_update_gradient_batch(
-                i_b,
-                dofs_state=dofs_state,
-                entities_info=entities_info,
-                rigid_global_info=rigid_global_info,
-                constraint_state=constraint_state,
-                static_rigid_sim_config=static_rigid_sim_config,
-            )
+    """Step 5: Update gradient.
+
+    Delegates to the `solver.func_update_gradient` dispatcher so AMDGPU gets the
+    tiled Cholesky/mass-solve path (func_update_gradient_tiled) instead of the
+    per-env serial func_update_gradient_batch.
+    """
+    # NOTE(review): the loop this replaces only updated envs with
+    # `n_constraints[i_b] > 0 and improved[i_b]`; confirm the dispatcher applies
+    # an equivalent filter, otherwise envs failing that test are now updated too.
+    solver.func_update_gradient(
+        dofs_state=dofs_state,
+        entities_info=entities_info,
+        rigid_global_info=rigid_global_info,
+        constraint_state=constraint_state,
+        static_rigid_sim_config=static_rigid_sim_config,
+    )
 
 
 @ti.kernel(fastcache=gs.use_fastcache)
@@ -215,7 +218,7 @@ def _kernel_update_search_direction(
     )
 
 
-@solver.func_solve_body.register(is_compatible=lambda *args, **kwargs: gs.backend in {gs.cuda})
+@solver.func_solve_body.register(is_compatible=lambda *args, **kwargs: gs.backend is not gs.cpu)
 def func_solve_decomposed(
     entities_info,
     dofs_state,