diff --git a/genesis/engine/solvers/rigid/constraint/solver.py b/genesis/engine/solvers/rigid/constraint/solver.py index 974b1dde0..21cf8f643 100644 --- a/genesis/engine/solvers/rigid/constraint/solver.py +++ b/genesis/engine/solvers/rigid/constraint/solver.py @@ -2185,6 +2185,7 @@ def func_ls_init_and_eval_p0_opt( constraint_state.jv[i_c, i_b] = jv # -- quad_gauss (same as original func_ls_init) -- + quad_gauss_0 = constraint_state.gauss[i_b] quad_gauss_1 = gs.qd_float(0.0) quad_gauss_2 = gs.qd_float(0.0) for i_d in range(n_dofs): @@ -2193,12 +2194,9 @@ def func_ls_init_and_eval_p0_opt( - constraint_state.search[i_d, i_b] * dofs_state.force[i_d, i_b] ) quad_gauss_2 = quad_gauss_2 + 0.5 * constraint_state.search[i_d, i_b] * constraint_state.mv[i_d, i_b] - constraint_state.quad_gauss[0, i_b] = constraint_state.gauss[i_b] - constraint_state.quad_gauss[1, i_b] = quad_gauss_1 - constraint_state.quad_gauss[2, i_b] = quad_gauss_2 # -- Compute quad per constraint and accumulate by type -- - quad_total_0 = constraint_state.gauss[i_b] + quad_total_0 = quad_gauss_0 quad_total_1 = quad_gauss_1 quad_total_2 = quad_gauss_2 eq_sum_0 = gs.qd_float(0.0) @@ -2244,10 +2242,11 @@ def func_ls_init_and_eval_p0_opt( quad_total_1 = quad_total_1 + qf_1 * active quad_total_2 = quad_total_2 + qf_2 * active - # Write eq_sum to global for subsequent calls - constraint_state.eq_sum[0, i_b] = eq_sum_0 - constraint_state.eq_sum[1, i_b] = eq_sum_1 - constraint_state.eq_sum[2, i_b] = eq_sum_2 + # Thread the equality-included quadratic base through subsequent evaluations + # as locals instead of storing to global fields and reloading every time. + base_0 = quad_gauss_0 + eq_sum_0 + base_1 = quad_gauss_1 + eq_sum_1 + base_2 = quad_gauss_2 + eq_sum_2 # Return p0 result (alpha=0) cost = quad_total_0 @@ -2256,22 +2255,159 @@ def func_ls_init_and_eval_p0_opt( if hess <= 0.0: hess = rigid_global_info.EPS[None] - constraint_state.ls_it[i_b] = 1 + ls_it = gs.qd_int(1) - return gs.qd_float(0.0), cost, grad, hess + return gs.qd_float(0.0), cost, grad, hess, base_0, base_1, base_2, ls_it + + +@qd.func +def func_ls_init_and_eval_p0_search_lds( + i_b, + entities_info: array_class.EntitiesInfo, + dofs_state: array_class.DofsState, + constraint_state: array_class.ConstraintState, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: qd.template(), +): + """Fused linesearch initialization and first evaluation point (alpha=0) for a single environment. + + Merges init (computing mv, jv, quad_gauss) and alpha=0 evaluation into a single pass, and pre-computes eq_sum + (the summed quadratic coefficients for always-active equality constraints) for reuse by subsequent evaluation calls. + + Bandwidth optimization: quad coefficients (D*Ja*Ja, D*jv*Ja, D*jv*jv) are recomputed on the fly from Jaref, jv, + and efc_D (~8 FLOPs per constraint) instead of being precomputed and stored to a separate quad array. At 0.2% + compute utilization (0.40 FLOPs/byte, 147x below roofline), this trades negligible compute for eliminating 3 global + memory writes per constraint during init and 3 reads per constraint in every subsequent evaluation call — a 40% + bandwidth reduction for contacts (5→3 loads) and 29% for friction (7→5 loads) in the hottest loop.""" + n_dofs = constraint_state.search.shape[0] + n_entities = static_rigid_sim_config.n_entities_ + ne = constraint_state.n_constraints_equality[i_b] + nef = ne + constraint_state.n_constraints_frictionloss[i_b] + n_con = constraint_state.n_constraints[i_b] + + # BLOCK_DIM tracks the kernel's block_dim. On AMDGPU (wave64) we run the + # constraint solver with block_dim=64 (see func_solve_body_monolith and + # the B3/B4 kernels in solver_amdgpu.py); the LDS tile is sized accordingly. + BLOCK_DIM = qd.static(64) + N_DOFS = qd.static(static_rigid_sim_config.tiled_n_dofs) + tid = i_b & (BLOCK_DIM - 1) + search_lds = qd.simt.block.SharedArray((N_DOFS, BLOCK_DIM), gs.qd_float) + for i_d in range(n_dofs): + search_lds[i_d, tid] = constraint_state.search[i_d, i_b] + + # -- mv and jv (same as original func_ls_init) -- + for i_e in range(n_entities): + for i_d1 in range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]): + mv = gs.qd_float(0.0) + for i_d2 in range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]): + mv = mv + rigid_global_info.mass_mat[i_d1, i_d2, i_b] * search_lds[i_d2, tid] + constraint_state.mv[i_d1, i_b] = mv + + for i_c in range(n_con): + jv = gs.qd_float(0.0) + if qd.static(static_rigid_sim_config.sparse_solve): + # C4: read from compact (N_c, max_nz, B) storage (parity with + # func_ls_init_and_eval_p0_opt) — bounds operands by max_nz per + # row and improves L2 hit-rate vs the dense (N_c, N_d, B) jac. + for i_d_ in range(constraint_state.jac_n_relevant_dofs[i_c, i_b]): + i_d = constraint_state.jac_relevant_dofs[i_c, i_d_, i_b] + jv = jv + constraint_state.jac_compact_values[i_c, i_d_, i_b] * search_lds[i_d, tid] + else: + for i_d in range(n_dofs): + jv = jv + constraint_state.jac[i_c, i_d, i_b] * search_lds[i_d, tid] + constraint_state.jv[i_c, i_b] = jv + + # -- quad_gauss (same as original func_ls_init) -- + quad_gauss_0 = constraint_state.gauss[i_b] + quad_gauss_1 = gs.qd_float(0.0) + quad_gauss_2 = gs.qd_float(0.0) + for i_d in range(n_dofs): + quad_gauss_1 = quad_gauss_1 + ( + search_lds[i_d, tid] * constraint_state.Ma[i_d, i_b] + - search_lds[i_d, tid] * dofs_state.force[i_d, i_b] + ) + quad_gauss_2 = quad_gauss_2 + 0.5 * search_lds[i_d, tid] * constraint_state.mv[i_d, i_b] + + # -- Compute quad per constraint and accumulate by type -- + quad_total_0 = quad_gauss_0 + quad_total_1 = quad_gauss_1 + quad_total_2 = quad_gauss_2 + eq_sum_0 = gs.qd_float(0.0) + eq_sum_1 = gs.qd_float(0.0) + eq_sum_2 = gs.qd_float(0.0) + + # Recompute quad on the fly from Jaref, jv, efc_D — avoids writing/reading the quad array entirely. + # 3 loads per constraint (Jaref, jv, D) + ~8 FLOPs, vs 3 writes + 3 reads through global memory. + for i_c in range(n_con): + Jaref_c = constraint_state.Jaref[i_c, i_b] + jv_c = constraint_state.jv[i_c, i_b] + D = constraint_state.efc_D[i_c, i_b] + qf_0 = D * (0.5 * Jaref_c * Jaref_c) + qf_1 = D * (jv_c * Jaref_c) + qf_2 = D * (0.5 * jv_c * jv_c) + + if i_c < ne: + # Equality: always active + eq_sum_0 = eq_sum_0 + qf_0 + eq_sum_1 = eq_sum_1 + qf_1 + eq_sum_2 = eq_sum_2 + qf_2 + quad_total_0 = quad_total_0 + qf_0 + quad_total_1 = quad_total_1 + qf_1 + quad_total_2 = quad_total_2 + qf_2 + elif i_c < nef: + # Friction: check linear regime at x=Jaref (alpha=0) + f = constraint_state.efc_frictionloss[i_c, i_b] + r = constraint_state.diag[i_c, i_b] + rf = r * f + linear_neg = Jaref_c <= -rf + linear_pos = Jaref_c >= rf + if linear_neg or linear_pos: + qf_0 = linear_neg * f * (-0.5 * rf - Jaref_c) + linear_pos * f * (-0.5 * rf + Jaref_c) + qf_1 = linear_neg * (-f * jv_c) + linear_pos * (f * jv_c) + qf_2 = 0.0 + quad_total_0 = quad_total_0 + qf_0 + quad_total_1 = quad_total_1 + qf_1 + quad_total_2 = quad_total_2 + qf_2 + else: + # Contact: check Jaref < 0 + active = Jaref_c < 0 + quad_total_0 = quad_total_0 + qf_0 * active + quad_total_1 = quad_total_1 + qf_1 * active + quad_total_2 = quad_total_2 + qf_2 * active + + # Thread the equality-included quadratic base through subsequent evaluations + # as locals instead of storing to global fields and reloading every time. + base_0 = quad_gauss_0 + eq_sum_0 + base_1 = quad_gauss_1 + eq_sum_1 + base_2 = quad_gauss_2 + eq_sum_2 + + # Return p0 result (alpha=0) + cost = quad_total_0 + grad = quad_total_1 + hess = 2 * quad_total_2 + if hess <= 0.0: + hess = rigid_global_info.EPS[None] + + ls_it = gs.qd_int(1) + + return gs.qd_float(0.0), cost, grad, hess, base_0, base_1, base_2, ls_it @qd.func def func_ls_point_fn_opt( i_b, alpha, + base_0, + base_1, + base_2, + ls_it, constraint_state: array_class.ConstraintState, rigid_global_info: array_class.RigidGlobalInfo, ): """Evaluate linesearch cost, gradient, and curvature at a single candidate alpha. Iterates over only friction and contact constraints — equality constraints are skipped by initializing accumulators - from quad_gauss + eq_sum (pre-computed during init). + from base_{0,1,2} = quad_gauss + eq_sum (pre-computed during init). Quad coefficients are recomputed on the fly from Jaref, jv, efc_D rather than read from a precomputed quad array. This reduces per-constraint loads from 5 to 3 (contacts) and 7 to 5 (friction), a 40%/29% bandwidth reduction. @@ -2280,10 +2416,9 @@ def func_ls_point_fn_opt( nef = ne + constraint_state.n_constraints_frictionloss[i_b] n_con = constraint_state.n_constraints[i_b] - # Start from quad_gauss + eq_sum (skips ne equality constraints) - quad_total_0 = constraint_state.quad_gauss[0, i_b] + constraint_state.eq_sum[0, i_b] - quad_total_1 = constraint_state.quad_gauss[1, i_b] + constraint_state.eq_sum[1, i_b] - quad_total_2 = constraint_state.quad_gauss[2, i_b] + constraint_state.eq_sum[2, i_b] + quad_total_0 = base_0 + quad_total_1 = base_1 + quad_total_2 = base_2 # Friction constraints [ne, nef): 5 loads (Jaref, jv, D, f, diag) + recompute quad for i_c in range(ne, nef): @@ -2327,9 +2462,9 @@ def func_ls_point_fn_opt( if hess <= 0.0: hess = rigid_global_info.EPS[None] - constraint_state.ls_it[i_b] = constraint_state.ls_it[i_b] + 1 + ls_it = ls_it + 1 - return alpha, cost, grad, hess + return alpha, cost, grad, hess, ls_it @qd.func @@ -2338,13 +2473,17 @@ def func_ls_point_fn_3alphas_opt( alpha_0, alpha_1, alpha_2, + base_0, + base_1, + base_2, + ls_it, constraint_state: array_class.ConstraintState, rigid_global_info: array_class.RigidGlobalInfo, ): """Evaluate linesearch cost, gradient, and curvature at three candidate alphas in a single constraint loop pass. Batches three candidate step sizes into one loop, amortizing per-constraint loads (Jaref, jv, efc_D, etc.) across - all three evaluations. Equality constraints are skipped via quad_gauss + eq_sum. + all three evaluations. Equality constraints are skipped via base_{0,1,2} = quad_gauss + eq_sum. Quad coefficients are recomputed on the fly from Jaref, jv, efc_D — same bandwidth optimization as func_ls_point_fn_opt (3 loads per contact instead of 5, 5 per friction instead of 7). Combined with 3-alpha @@ -2353,11 +2492,6 @@ def func_ls_point_fn_3alphas_opt( nef = ne + constraint_state.n_constraints_frictionloss[i_b] n_con = constraint_state.n_constraints[i_b] - # Start from quad_gauss + eq_sum for all 3 - base_0 = constraint_state.quad_gauss[0, i_b] + constraint_state.eq_sum[0, i_b] - base_1 = constraint_state.quad_gauss[1, i_b] + constraint_state.eq_sum[1, i_b] - base_2 = constraint_state.quad_gauss[2, i_b] + constraint_state.eq_sum[2, i_b] - t0_0, t0_1, t0_2 = base_0, base_1, base_2 t1_0, t1_1, t1_2 = base_0, base_1, base_2 t2_0, t2_1, t2_2 = base_0, base_1, base_2 @@ -2456,12 +2590,12 @@ def func_ls_point_fn_3alphas_opt( if hess_2 <= 0.0: hess_2 = EPS - constraint_state.ls_it[i_b] = constraint_state.ls_it[i_b] + 3 + ls_it = ls_it + 3 costs = qd.Vector([cost_0, cost_1, cost_2]) grads = qd.Vector([grad_0, grad_1, grad_2]) hess = qd.Vector([hess_0, hess_1, hess_2]) - return costs, grads, hess + return costs, grads, hess, ls_it @qd.func @@ -2546,14 +2680,23 @@ def _func_linesearch_phase3_batch( p2_cost, p2_deriv_0, p2_deriv_1, + base_0, + base_1, + base_2, + ls_it, + ls_result, constraint_state: array_class.ConstraintState, rigid_global_info: array_class.RigidGlobalInfo, ): # B5: Phase 3 (refinement with batched 3-alpha evaluation) extracted from - # func_linesearch_batch. The narrow input list scopes Phase 1/2 locals + # the linesearch outer loop. The narrow input list scopes Phase 1/2 locals # (p0_alpha/deriv_*, snorm, scale, direction, p2update, done, ...) to # before this call, narrowing the live-range graph the AMDGPU register # allocator has to work with through the bracketing loop. + # + # Linesearch state (base_0/1/2 = quad_gauss + eq_sum, ls_it, ls_result) + # is threaded as locals/return values rather than via constraint_state + # field round-trips, eliminating per-call HBM stores+loads on these fields. alpha_0 = p1_alpha - p1_deriv_0 / p1_deriv_1 # Newton from p1 alpha_1 = p1_alpha # p2_next (= current p1) alpha_2 = (p1_alpha + p2_alpha) * 0.5 # midpoint @@ -2561,9 +2704,18 @@ def _func_linesearch_phase3_batch( res_alpha = gs.qd_float(0.0) done = False - while constraint_state.ls_it[i_b] < rigid_global_info.ls_iterations[None]: - costs, grads, hess = func_ls_point_fn_3alphas_opt( - i_b, alpha_0, alpha_1, alpha_2, constraint_state, rigid_global_info + while ls_it < rigid_global_info.ls_iterations[None]: + costs, grads, hess, ls_it = func_ls_point_fn_3alphas_opt( + i_b, + alpha_0, + alpha_1, + alpha_2, + base_0, + base_1, + base_2, + ls_it, + constraint_state, + rigid_global_info, ) alphas = qd.Vector([alpha_0, alpha_1, alpha_2]) @@ -2620,9 +2772,9 @@ def _func_linesearch_phase3_batch( if b1 == 0 and b2 == 0: if costs[2] < p0_cost: - constraint_state.ls_result[i_b] = 0 + ls_result = 0 else: - constraint_state.ls_result[i_b] = 7 + ls_result = 7 res_alpha = alpha_2 done = True @@ -2636,20 +2788,20 @@ def _func_linesearch_phase3_batch( if not done: if p1_cost <= p2_cost and p1_cost < p0_cost: - constraint_state.ls_result[i_b] = 4 + ls_result = 4 res_alpha = p1_alpha elif p2_cost <= p1_cost and p2_cost < p0_cost: - constraint_state.ls_result[i_b] = 4 + ls_result = 4 res_alpha = p2_alpha else: - constraint_state.ls_result[i_b] = 5 + ls_result = 5 res_alpha = 0.0 - return res_alpha + return res_alpha, ls_result @qd.func -def func_linesearch_batch( +def func_linesearch_batch_global( i_b, entities_info: array_class.EntitiesInfo, dofs_state: array_class.DofsState, @@ -2665,20 +2817,139 @@ def func_linesearch_batch( snorm = qd.sqrt(snorm) scale = rigid_global_info.meaninertia[i_b] * qd.max(1, n_dofs) gtol = rigid_global_info.tolerance[None] * rigid_global_info.ls_tolerance[None] * snorm * scale - constraint_state.gtol[i_b] = gtol + ls_it = gs.qd_int(0) + ls_result = gs.qd_int(0) - constraint_state.ls_it[i_b] = 0 - constraint_state.ls_result[i_b] = 0 + res_alpha = gs.qd_float(0.0) + done = False + + if snorm < rigid_global_info.EPS[None]: + ls_result = 1 + res_alpha = 0.0 + else: + # Phase 1: Init + p0 + p1 + p0_alpha, p0_cost, p0_deriv_0, p0_deriv_1, base_0, base_1, base_2, ls_it = func_ls_init_and_eval_p0_opt( + i_b, + entities_info=entities_info, + dofs_state=dofs_state, + constraint_state=constraint_state, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + ) + p1_alpha, p1_cost, p1_deriv_0, p1_deriv_1, ls_it = func_ls_point_fn_opt( + i_b, + p0_alpha - p0_deriv_0 / p0_deriv_1, + base_0, + base_1, + base_2, + ls_it, + constraint_state, + rigid_global_info, + ) + + if p0_cost < p1_cost: + p1_alpha, p1_cost, p1_deriv_0, p1_deriv_1 = p0_alpha, p0_cost, p0_deriv_0, p0_deriv_1 + + if qd.abs(p1_deriv_0) < gtol: + if qd.abs(p1_alpha) < rigid_global_info.EPS[None]: + ls_result = 2 + else: + ls_result = 0 + res_alpha = p1_alpha + else: + # Phase 2: Bracketing + direction = (p1_deriv_0 < 0) * 2 - 1 + p2update = 0 + p2_alpha, p2_cost, p2_deriv_0, p2_deriv_1 = p1_alpha, p1_cost, p1_deriv_0, p1_deriv_1 + while ( + p1_deriv_0 * direction <= -gtol and ls_it < rigid_global_info.ls_iterations[None] + ): + p2_alpha, p2_cost, p2_deriv_0, p2_deriv_1 = p1_alpha, p1_cost, p1_deriv_0, p1_deriv_1 + p2update = 1 + + p1_alpha, p1_cost, p1_deriv_0, p1_deriv_1, ls_it = func_ls_point_fn_opt( + i_b, + p1_alpha - p1_deriv_0 / p1_deriv_1, + base_0, + base_1, + base_2, + ls_it, + constraint_state, + rigid_global_info, + ) + if qd.abs(p1_deriv_0) < gtol: + res_alpha = p1_alpha + done = True + break + if not done: + if ls_it >= rigid_global_info.ls_iterations[None]: + ls_result = 3 + res_alpha = p1_alpha + done = True + + if not p2update and not done: + ls_result = 6 + res_alpha = p1_alpha + done = True + + if not done: + # B5: Phase 3 split into _func_linesearch_phase3_batch so + # Phase 1/2 locals (p0_alpha/deriv, snorm, scale, direction, + # p2update, done, ...) end their live ranges before the + # call site, shrinking the AMDGPU register-pressure footprint + # of the bracketing loop body. + res_alpha, ls_result = _func_linesearch_phase3_batch( + i_b, + gtol, + p0_cost, + p1_alpha, + p1_cost, + p1_deriv_0, + p1_deriv_1, + p2_alpha, + p2_cost, + p2_deriv_0, + p2_deriv_1, + base_0, + base_1, + base_2, + ls_it, + ls_result, + constraint_state, + rigid_global_info, + ) + return res_alpha + + +@qd.func +def func_linesearch_batch_search_lds( + i_b, + entities_info: array_class.EntitiesInfo, + dofs_state: array_class.DofsState, + rigid_global_info: array_class.RigidGlobalInfo, + constraint_state: array_class.ConstraintState, + static_rigid_sim_config: qd.template(), +): + n_dofs = constraint_state.search.shape[0] + ## use adaptive linesearch tolerance + snorm = gs.qd_float(0.0) + for jd in range(n_dofs): + snorm = snorm + constraint_state.search[jd, i_b] ** 2 + snorm = qd.sqrt(snorm) + scale = rigid_global_info.meaninertia[i_b] * qd.max(1, n_dofs) + gtol = rigid_global_info.tolerance[None] * rigid_global_info.ls_tolerance[None] * snorm * scale + ls_it = gs.qd_int(0) + ls_result = gs.qd_int(0) res_alpha = gs.qd_float(0.0) done = False if snorm < rigid_global_info.EPS[None]: - constraint_state.ls_result[i_b] = 1 + ls_result = 1 res_alpha = 0.0 else: # Phase 1: Init + p0 + p1 - p0_alpha, p0_cost, p0_deriv_0, p0_deriv_1 = func_ls_init_and_eval_p0_opt( + p0_alpha, p0_cost, p0_deriv_0, p0_deriv_1, base_0, base_1, base_2, ls_it = func_ls_init_and_eval_p0_search_lds( i_b, entities_info=entities_info, dofs_state=dofs_state, @@ -2686,8 +2957,15 @@ def func_linesearch_batch( rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, ) - p1_alpha, p1_cost, p1_deriv_0, p1_deriv_1 = func_ls_point_fn_opt( - i_b, p0_alpha - p0_deriv_0 / p0_deriv_1, constraint_state, rigid_global_info + p1_alpha, p1_cost, p1_deriv_0, p1_deriv_1, ls_it = func_ls_point_fn_opt( + i_b, + p0_alpha - p0_deriv_0 / p0_deriv_1, + base_0, + base_1, + base_2, + ls_it, + constraint_state, + rigid_global_info, ) if p0_cost < p1_cost: @@ -2695,9 +2973,9 @@ def func_linesearch_batch( if qd.abs(p1_deriv_0) < gtol: if qd.abs(p1_alpha) < rigid_global_info.EPS[None]: - constraint_state.ls_result[i_b] = 2 + ls_result = 2 else: - constraint_state.ls_result[i_b] = 0 + ls_result = 0 res_alpha = p1_alpha else: # Phase 2: Bracketing @@ -2705,26 +2983,33 @@ def func_linesearch_batch( p2update = 0 p2_alpha, p2_cost, p2_deriv_0, p2_deriv_1 = p1_alpha, p1_cost, p1_deriv_0, p1_deriv_1 while ( - p1_deriv_0 * direction <= -gtol and constraint_state.ls_it[i_b] < rigid_global_info.ls_iterations[None] + p1_deriv_0 * direction <= -gtol and ls_it < rigid_global_info.ls_iterations[None] ): p2_alpha, p2_cost, p2_deriv_0, p2_deriv_1 = p1_alpha, p1_cost, p1_deriv_0, p1_deriv_1 p2update = 1 - p1_alpha, p1_cost, p1_deriv_0, p1_deriv_1 = func_ls_point_fn_opt( - i_b, p1_alpha - p1_deriv_0 / p1_deriv_1, constraint_state, rigid_global_info + p1_alpha, p1_cost, p1_deriv_0, p1_deriv_1, ls_it = func_ls_point_fn_opt( + i_b, + p1_alpha - p1_deriv_0 / p1_deriv_1, + base_0, + base_1, + base_2, + ls_it, + constraint_state, + rigid_global_info, ) if qd.abs(p1_deriv_0) < gtol: res_alpha = p1_alpha done = True break if not done: - if constraint_state.ls_it[i_b] >= rigid_global_info.ls_iterations[None]: - constraint_state.ls_result[i_b] = 3 + if ls_it >= rigid_global_info.ls_iterations[None]: + ls_result = 3 res_alpha = p1_alpha done = True if not p2update and not done: - constraint_state.ls_result[i_b] = 6 + ls_result = 6 res_alpha = p1_alpha done = True @@ -2734,7 +3019,7 @@ def func_linesearch_batch( # p2update, done, ...) end their live ranges before the # call site, shrinking the AMDGPU register-pressure footprint # of the bracketing loop body. - res_alpha = _func_linesearch_phase3_batch( + res_alpha, ls_result = _func_linesearch_phase3_batch( i_b, gtol, p0_cost, @@ -2746,12 +3031,46 @@ def func_linesearch_batch( p2_cost, p2_deriv_0, p2_deriv_1, + base_0, + base_1, + base_2, + ls_it, + ls_result, constraint_state, rigid_global_info, ) return res_alpha +@qd.func +def func_linesearch_batch( + i_b, + entities_info: array_class.EntitiesInfo, + dofs_state: array_class.DofsState, + rigid_global_info: array_class.RigidGlobalInfo, + constraint_state: array_class.ConstraintState, + static_rigid_sim_config: qd.template(), +): + if qd.static(static_rigid_sim_config.backend == gs.amdgpu): + return func_linesearch_batch_search_lds( + i_b, + entities_info=entities_info, + dofs_state=dofs_state, + rigid_global_info=rigid_global_info, + constraint_state=constraint_state, + static_rigid_sim_config=static_rigid_sim_config, + ) + else: + return func_linesearch_batch_global( + i_b, + entities_info=entities_info, + dofs_state=dofs_state, + rigid_global_info=rigid_global_info, + constraint_state=constraint_state, + static_rigid_sim_config=static_rigid_sim_config, + ) + + # ===================================================================================================================== # ================================================= Solving Algorithm ================================================= # ===================================================================================================================== @@ -2994,6 +3313,7 @@ def func_update_gradient( @qd.func def func_terminate_or_update_descent_batch( i_b, + prev_cost, constraint_state: array_class.ConstraintState, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: qd.template(), @@ -3002,7 +3322,7 @@ def func_terminate_or_update_descent_batch( # Check convergence, i.e. whether the cost function is not longer decreasing or the gradient is flat tol_scaled = (rigid_global_info.meaninertia[i_b] * qd.max(1, n_dofs)) * rigid_global_info.tolerance[None] - improvement = constraint_state.prev_cost[i_b] - constraint_state.cost[i_b] + improvement = prev_cost - constraint_state.cost[i_b] grad_norm = gs.qd_float(0.0) for i_d in range(n_dofs): grad_norm = grad_norm + constraint_state.grad[i_d, i_b] * constraint_state.grad[i_d, i_b] @@ -3250,6 +3570,7 @@ def func_solve_iter( constraint_state.cg_prev_grad[i_d, i_b] = constraint_state.grad[i_d, i_b] constraint_state.cg_prev_Mgrad[i_d, i_b] = constraint_state.Mgrad[i_d, i_b] + prev_cost = constraint_state.cost[i_b] func_update_constraint_batch( i_b, qacc=constraint_state.qacc, @@ -3288,6 +3609,7 @@ def func_solve_iter( func_terminate_or_update_descent_batch( i_b, + prev_cost, constraint_state=constraint_state, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, @@ -3318,6 +3640,12 @@ def func_solve_iter_post_linesearch( constraint_state.cg_prev_grad[i_d, i_b] = constraint_state.grad[i_d, i_b] constraint_state.cg_prev_Mgrad[i_d, i_b] = constraint_state.Mgrad[i_d, i_b] + # Capture prev_cost as a local before func_update_constraint_batch + # mutates constraint_state.cost[i_b]. Threading it into + # func_terminate_or_update_descent_batch as an arg avoids the + # constraint_state.prev_cost field round-trip (mirrors the same + # optimization in the inlined func_solve_iter path). + prev_cost = constraint_state.cost[i_b] func_update_constraint_batch( i_b, qacc=constraint_state.qacc, @@ -3356,6 +3684,7 @@ def func_solve_iter_post_linesearch( func_terminate_or_update_descent_batch( i_b, + prev_cost, constraint_state=constraint_state, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, diff --git a/genesis/engine/solvers/rigid/constraint/solver_breakdown.py b/genesis/engine/solvers/rigid/constraint/solver_breakdown.py index f6de5be01..2568ca6d9 100644 --- a/genesis/engine/solvers/rigid/constraint/solver_breakdown.py +++ b/genesis/engine/solvers/rigid/constraint/solver_breakdown.py @@ -213,8 +213,17 @@ def _kernel_update_search_direction( ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL, block_dim=32) for i_b in range(_B): if constraint_state.n_constraints[i_b] > 0 and constraint_state.improved[i_b]: + # The decomposed CUDA path splits cost-update (Step 4) and the + # convergence check (Step 6) into separate kernels, so prev_cost + # has to be carried via constraint_state.prev_cost (written in + # _kernel_update_constraint_cost). Read it once and pass as a + # local arg to match func_terminate_or_update_descent_batch's + # new signature, where the inlined paths thread it as a local + # to avoid the field round-trip entirely. + prev_cost = constraint_state.prev_cost[i_b] solver.func_terminate_or_update_descent_batch( i_b, + prev_cost, rigid_global_info=rigid_global_info, constraint_state=constraint_state, static_rigid_sim_config=static_rigid_sim_config,