Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
4283d56
Rename quadrants alias from ti to qd in solver_breakdown.py
hughperkins Mar 29, 2026
fc94ee6
Add cuda_graph_counter ndarray to StructConstraintState
hughperkins Mar 29, 2026
1aeec31
Fix cuda_graph_counter annotation to always be ndarray
hughperkins Mar 29, 2026
df7169a
Convert decomposed solver to CUDA graph with graph_do_while
hughperkins Mar 29, 2026
62edd2a
Rename cuda_graph_counter to graph_continue_loop
hughperkins Mar 29, 2026
310e083
Rename graph_continue_loop to graph_counter
hughperkins Mar 29, 2026
1b4977f
Disable CUDA graph solver when requires_grad is active
hughperkins Mar 29, 2026
7ed9707
update quadrants version
hughperkins Mar 29, 2026
a3d5798
b2
hughperkins Mar 29, 2026
a460681
upgrade Genesis, and rename from cuda_graph to gpu_graph
hughperkins Mar 29, 2026
6635211
always run if not autodiff
hughperkins Mar 29, 2026
cb0c533
precommit
hughperkins Mar 29, 2026
b182d2f
Merge branch 'main' into hp/cuda-graph
hughperkins Mar 29, 2026
bd76fb2
only run decomposed on cuda
hughperkins Mar 30, 2026
d5abc21
fix: use decomposed solver only on CUDA backend, not non-CUDA
hughperkins Mar 30, 2026
8692ac8
Parallelize _func_check_early_exit for 60% dex_hand speedup
hughperkins Mar 30, 2026
329fb75
Merge remote-tracking branch 'myself/hp/cuda-graph-par-exit' into hp/…
hughperkins Mar 30, 2026
c3dceb8
Allow pre-release quadrants in production_build.sh
hughperkins Mar 30, 2026
dad9c23
Rewrap docstring comments to 120-char line width
hughperkins Mar 30, 2026
2328cbe
Remove --prerelease=allow to fix CI EGL crashes
hughperkins Mar 30, 2026
a47e95f
Use atomic_max for early_exit_flag in parallel early exit check
hughperkins Mar 30, 2026
430140d
Revert early_exit_flag parallel kernel back to serial loop
hughperkins Mar 30, 2026
9e5f9ae
b2
hughperkins Mar 30, 2026
d0ed2c3
Revert "Revert early_exit_flag parallel kernel back to serial loop"
hughperkins Mar 30, 2026
7370643
Merge origin/main into hp/cuda-graph
hughperkins Mar 30, 2026
6b3c15f
Broaden cuda-graph solver compatibility to all GPU backends
hughperkins Mar 30, 2026
7bab0c7
quadrants 0.5.0
hughperkins Mar 31, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 104 additions & 93 deletions genesis/engine/solvers/rigid/constraint/solver_breakdown.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
import quadrants as ti
import numpy as np
import quadrants as qd

import genesis as gs
import genesis.utils.array_class as array_class
from genesis.engine.solvers.rigid.constraint import solver


@ti.kernel(fastcache=gs.use_fastcache)
def _kernel_linesearch(
@qd.func
def _func_linesearch(
entities_info: array_class.EntitiesInfo,
dofs_state: array_class.DofsState,
constraint_state: array_class.ConstraintState,
rigid_global_info: array_class.RigidGlobalInfo,
static_rigid_sim_config: ti.template(),
static_rigid_sim_config: qd.template(),
):
_B = constraint_state.grad.shape[1]
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL, block_dim=32)
qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL, block_dim=32)
for i_b in range(_B):
if constraint_state.n_constraints[i_b] > 0 and constraint_state.improved[i_b]:
solver.func_linesearch_and_apply_alpha(
Expand All @@ -29,34 +30,32 @@ def _kernel_linesearch(
constraint_state.improved[i_b] = False


@qd.func
def _func_cg_only_save_prev_grad(
    constraint_state: array_class.ConstraintState,
    static_rigid_sim_config: qd.template(),
):
    """Snapshot prev_grad and prev_Mgrad before the next CG update (CG solver only).

    Batch elements with no active constraints, or whose `improved` flag is unset,
    are left untouched.
    """
    n_envs = constraint_state.grad.shape[1]
    qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL, block_dim=32)
    for i_env in range(n_envs):
        has_constraints = constraint_state.n_constraints[i_env] > 0
        if has_constraints and constraint_state.improved[i_env]:
            solver.func_save_prev_grad(i_env, constraint_state=constraint_state)


@ti.kernel(fastcache=gs.use_fastcache)
def _kernel_update_constraint_forces(
@qd.func
def _func_update_constraint_forces(
constraint_state: array_class.ConstraintState,
static_rigid_sim_config: ti.template(),
static_rigid_sim_config: qd.template(),
):
"""Compute active flags and efc_force, parallelized over (constraint, env)."""
len_constraints = constraint_state.active.shape[0]
_B = constraint_state.grad.shape[1]

for i_c, i_b in ti.ndrange(len_constraints, _B):
for i_c, i_b in qd.ndrange(len_constraints, _B):
if i_c < constraint_state.n_constraints[i_b] and constraint_state.improved[i_b]:
ne = constraint_state.n_constraints_equality[i_b]
nef = ne + constraint_state.n_constraints_frictionloss[i_b]

if ti.static(static_rigid_sim_config.solver_type == gs.constraint_solver.Newton):
if qd.static(static_rigid_sim_config.solver_type == gs.constraint_solver.Newton):
constraint_state.prev_active[i_c, i_b] = constraint_state.active[i_c, i_b]

constraint_state.active[i_c, i_b] = True
Expand All @@ -78,16 +77,15 @@ def _kernel_update_constraint_forces(
)


@ti.kernel(fastcache=gs.use_fastcache)
def _kernel_update_constraint_qfrc(
@qd.func
def _func_update_constraint_qfrc(
constraint_state: array_class.ConstraintState,
static_rigid_sim_config: ti.template(),
static_rigid_sim_config: qd.template(),
):
"""Compute qfrc_constraint = J^T @ efc_force, parallelized over (dof, env)."""
n_dofs = constraint_state.qfrc_constraint.shape[0]
_B = constraint_state.grad.shape[1]

for i_d, i_b in ti.ndrange(n_dofs, _B):
for i_d, i_b in qd.ndrange(n_dofs, _B):
if constraint_state.n_constraints[i_b] > 0 and constraint_state.improved[i_b]:
n_con = constraint_state.n_constraints[i_b]
qfrc = gs.qd_float(0.0)
Expand All @@ -96,16 +94,15 @@ def _kernel_update_constraint_qfrc(
constraint_state.qfrc_constraint[i_d, i_b] = qfrc


@ti.kernel(fastcache=gs.use_fastcache)
def _kernel_update_constraint_cost(
@qd.func
def _func_update_constraint_cost(
dofs_state: array_class.DofsState,
constraint_state: array_class.ConstraintState,
static_rigid_sim_config: ti.template(),
static_rigid_sim_config: qd.template(),
):
"""Compute gauss and cost (reductions over dofs and constraints). One thread per env."""
_B = constraint_state.grad.shape[1]

ti.loop_config(block_dim=32)
qd.loop_config(block_dim=32)
for i_b in range(_B):
if constraint_state.n_constraints[i_b] > 0 and constraint_state.improved[i_b]:
n_dofs = constraint_state.qfrc_constraint.shape[0]
Expand All @@ -118,7 +115,6 @@ def _kernel_update_constraint_cost(
cost_i = gs.qd_float(0.0)
gauss_i = gs.qd_float(0.0)

# Gauss cost from dofs
for i_d in range(n_dofs):
v = (
0.5
Expand All @@ -128,7 +124,6 @@ def _kernel_update_constraint_cost(
gauss_i += v
cost_i += v

# Constraint cost: quadratic + friction linear
for i_c in range(n_con):
cost_i += 0.5 * (
constraint_state.Jaref[i_c, i_b] ** 2
Expand All @@ -149,41 +144,39 @@ def _kernel_update_constraint_cost(
constraint_state.cost[i_b] = cost_i


@qd.func
def _func_newton_only_nt_hessian(
    constraint_state: array_class.ConstraintState,
    rigid_global_info: array_class.RigidGlobalInfo,
    static_rigid_sim_config: qd.template(),
):
    """Step 4: Newton Hessian update (Newton only).

    Assembles the Hessian with the tiled kernel, then Cholesky-factorizes it.
    The factorization path is selected at compile time by
    `enable_tiled_cholesky_hessian`.
    """
    solver.func_hessian_direct_tiled(constraint_state=constraint_state, rigid_global_info=rigid_global_info)
    if qd.static(static_rigid_sim_config.enable_tiled_cholesky_hessian):
        # Compile-time fast path: tiled factorization handles every batch element in one call.
        solver.func_cholesky_factor_direct_tiled(
            constraint_state=constraint_state,
            rigid_global_info=rigid_global_info,
            static_rigid_sim_config=static_rigid_sim_config,
        )
    else:
        # Fallback: factorize per batch element, skipping empty / non-improving ones.
        n_envs = constraint_state.jac.shape[2]
        qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL, block_dim=32)
        for i_env in range(n_envs):
            has_constraints = constraint_state.n_constraints[i_env] > 0
            if has_constraints and constraint_state.improved[i_env]:
                solver.func_cholesky_factor_direct_batch(
                    i_b=i_env, constraint_state=constraint_state, rigid_global_info=rigid_global_info
                )


@ti.kernel(fastcache=gs.use_fastcache)
def _kernel_update_gradient(
@qd.func
def _func_update_gradient(
entities_info: array_class.EntitiesInfo,
dofs_state: array_class.DofsState,
constraint_state: array_class.ConstraintState,
rigid_global_info: array_class.RigidGlobalInfo,
static_rigid_sim_config: ti.template(),
static_rigid_sim_config: qd.template(),
):
"""Step 5: Update gradient"""
_B = constraint_state.grad.shape[1]
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL, block_dim=32)
qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL, block_dim=32)
for i_b in range(_B):
if constraint_state.n_constraints[i_b] > 0 and constraint_state.improved[i_b]:
solver.func_update_gradient_batch(
Expand All @@ -196,15 +189,14 @@ def _kernel_update_gradient(
)


@ti.kernel(fastcache=gs.use_fastcache)
def _kernel_update_search_direction(
@qd.func
def _func_update_search_direction(
constraint_state: array_class.ConstraintState,
rigid_global_info: array_class.RigidGlobalInfo,
static_rigid_sim_config: ti.template(),
static_rigid_sim_config: qd.template(),
):
"""Step 6: Check convergence and update search direction"""
_B = constraint_state.grad.shape[1]
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL, block_dim=32)
qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL, block_dim=32)
for i_b in range(_B):
if constraint_state.n_constraints[i_b] > 0 and constraint_state.improved[i_b]:
solver.func_terminate_or_update_descent_batch(
Expand All @@ -215,7 +207,57 @@ def _kernel_update_search_direction(
)


@solver.func_solve_body.register(is_compatible=lambda *args, **kwargs: gs.backend is not gs.cpu)
@qd.func
def _func_check_early_exit(
    constraint_state: array_class.ConstraintState,
    graph_counter: qd.types.ndarray(qd.i32, ndim=0),
):
    """Decrement iteration counter and exit early if no batch element improved.

    Runs entirely on device so it can be captured inside the GPU-graph solver
    loop: `graph_counter` is the 0-d i32 counter consumed by
    `qd.graph_do_while`; driving it to 0 ends the loop.
    """
    # NOTE(review): the single-iteration `for _ in range(1)` loops presumably force these
    # scalar updates to run as their own (serial) device tasks, ordered before/after the
    # parallel scan below — confirm against quadrants' loop-emission semantics.
    for _ in range(1):
        graph_counter[()] = graph_counter[()] - 1  # one solver iteration consumed
        constraint_state.early_exit_flag[()] = 0  # reset flag before the scan

    # Parallel scan over the batch: raise the flag if any element still improves.
    _B = constraint_state.grad.shape[1]
    for i_b in range(_B):
        if constraint_state.improved[i_b]:
            # atomic_max is race-safe here; every writer stores the same value (1).
            qd.atomic_max(constraint_state.early_exit_flag[()], 1)

    for _ in range(1):
        if constraint_state.early_exit_flag[()] == 0:
            # No batch element improved: zero the counter so the graph loop terminates.
            graph_counter[()] = 0


@qd.kernel(gpu_graph=True, fastcache=gs.use_fastcache)
def _kernel_solve_gpu_graph(
    entities_info: array_class.EntitiesInfo,
    dofs_state: array_class.DofsState,
    constraint_state: array_class.ConstraintState,
    rigid_global_info: array_class.RigidGlobalInfo,
    static_rigid_sim_config: qd.template(),
    graph_counter: qd.types.ndarray(qd.i32, ndim=0),
):
    """Run the decomposed constraint-solver iteration loop as one GPU graph.

    The caller preloads `graph_counter` (0-d i32 ndarray) with the iteration
    budget; `_func_check_early_exit` decrements it on every pass and zeroes it
    once no batch element improved, which stops the loop.
    """
    # NOTE(review): qd.graph_do_while presumably re-runs the body while graph_counter is
    # nonzero — confirm against quadrants' GPU-graph documentation.
    while qd.graph_do_while(graph_counter):
        # Line search: pick and apply the step size alpha.
        _func_linesearch(entities_info, dofs_state, constraint_state, rigid_global_info, static_rigid_sim_config)
        # CG only (compile-time branch): snapshot prev_grad / prev_Mgrad.
        if qd.static(static_rigid_sim_config.solver_type == gs.constraint_solver.CG):
            _func_cg_only_save_prev_grad(constraint_state, static_rigid_sim_config)
        # Refresh active flags / efc_force, then qfrc_constraint = J^T @ efc_force, then cost.
        _func_update_constraint_forces(constraint_state, static_rigid_sim_config)
        _func_update_constraint_qfrc(constraint_state, static_rigid_sim_config)
        _func_update_constraint_cost(dofs_state, constraint_state, static_rigid_sim_config)
        # Newton only (compile-time branch): Hessian assembly + Cholesky factorization.
        if qd.static(static_rigid_sim_config.solver_type == gs.constraint_solver.Newton):
            _func_newton_only_nt_hessian(constraint_state, rigid_global_info, static_rigid_sim_config)
        # Gradient and search-direction updates (steps 5-6 of the decomposition).
        _func_update_gradient(entities_info, dofs_state, constraint_state, rigid_global_info, static_rigid_sim_config)
        _func_update_search_direction(constraint_state, rigid_global_info, static_rigid_sim_config)
        # Consume one iteration; zero the counter early when everything converged.
        _func_check_early_exit(constraint_state, graph_counter)


@solver.func_solve_body.register(
is_compatible=lambda entities_info,
dofs_state,
constraint_state,
rigid_global_info,
static_rigid_sim_config,
_n_iterations: not static_rigid_sim_config.requires_grad and static_rigid_sim_config.backend != gs.cpu,
)
def func_solve_decomposed(
entities_info,
dofs_state,
Expand All @@ -225,53 +267,22 @@ def func_solve_decomposed(
_n_iterations,
):
"""
Uses separate kernels for each solver step per iteration.
GPU graph accelerated solver loop with GPU-side iteration via graph_do_while.

This maximizes kernel granularity, potentially allowing better GPU scheduling
and more flexibility in execution, at the cost of more Python→C++ boundary crossings.
On CUDA SM 9.0+ (Hopper), the entire iteration loop runs on the GPU with no host involvement. On older CUDA GPUs,
falls back to a host-side do-while loop that still benefits from CUDA graph kernel launch batching. On other GPUs,
falls back to a host-side C++-side loop, that still reduces python launch overhead.

Early exits when all batch elements have converged (no improved[i_b] is True).
"""
# _n_iterations is a Python-native int to avoid CPU-GPU sync (vs rigid_global_info.iterations[None])
for _it in range(_n_iterations):
_kernel_linesearch(
entities_info,
dofs_state,
constraint_state,
rigid_global_info,
static_rigid_sim_config,
)
if static_rigid_sim_config.solver_type == gs.constraint_solver.CG:
_kernel_cg_only_save_prev_grad(
constraint_state,
static_rigid_sim_config,
)
_kernel_update_constraint_forces(
constraint_state,
static_rigid_sim_config,
)
_kernel_update_constraint_qfrc(
constraint_state,
static_rigid_sim_config,
)
_kernel_update_constraint_cost(
dofs_state,
constraint_state,
static_rigid_sim_config,
)
if static_rigid_sim_config.solver_type == gs.constraint_solver.Newton:
_kernel_newton_only_nt_hessian(
constraint_state,
rigid_global_info,
static_rigid_sim_config,
)
_kernel_update_gradient(
entities_info,
dofs_state,
constraint_state,
rigid_global_info,
static_rigid_sim_config,
)
_kernel_update_search_direction(
constraint_state,
rigid_global_info,
static_rigid_sim_config,
)
if _n_iterations <= 0:
return
constraint_state.graph_counter.from_numpy(np.array(_n_iterations, dtype=np.int32))
_kernel_solve_gpu_graph(
entities_info,
dofs_state,
constraint_state,
rigid_global_info,
static_rigid_sim_config,
constraint_state.graph_counter,
)
5 changes: 5 additions & 0 deletions genesis/utils/array_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,9 @@ class StructConstraintState(metaclass=BASE_METACLASS):
bw_w: V_ANNOTATION
# Timers for profiling
timers: V_ANNOTATION
# Always ndarray (not field): graph_do_while requires the same physical ndarray on every call.
graph_counter: qd.types.ndarray()
early_exit_flag: V_ANNOTATION


def get_constraint_state(constraint_solver, solver):
Expand Down Expand Up @@ -396,6 +399,8 @@ def get_constraint_state(constraint_solver, solver):
bw_w=V(dtype=gs.qd_float, shape=maybe_shape((len_constraints_, _B), solver._requires_grad)),
# Timers
timers=V(dtype=qd.i64 if gs.backend != gs.metal else qd.i32, shape=(10, _B)),
graph_counter=qd.ndarray(qd.i32, shape=()),
early_exit_flag=V(dtype=qd.i32, shape=()),
)


Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ readme = "README.md"
requires-python = ">=3.10,<3.14"
dependencies = [
"psutil",
"quadrants==0.4.5",
"quadrants==0.5.0",
"pydantic>=2.11.0",
"numpy>=1.26.4",
"frozendict",
Expand Down
Loading