Conversation
btaba
added a commit
that referenced
this pull request
Mar 4, 2026
* Update Python versions in CI to 3.11, 3.12, and 3.13 * Update ci.yml
AIRJASON50
added a commit
to wuji-technology/wujimjx-brax
that referenced
this pull request
Mar 9, 2026
* Replace unicode escaped characters in ipynb files PiperOrigin-RevId: 856196218 Change-Id: I42b3faac6a8f923078c55fc431b526656c19cbfd * Add soft-sign clipping (mean_clip_scale) and configurable mean_kernel_init to PolicyModuleWithStd Add two new optional parameters to PolicyModuleWithStd: - mean_clip_scale: Applies softsign clipping to mean output: scale * (mean / (1 + |mean|)) - mean_kernel_init: Configurable kernel initializer for the final mean Dense layer Also adds policy_network_kwargs to make_ppo_networks for generic pass-through of additional options. PiperOrigin-RevId: 862772852 Change-Id: I063c99d08ce7fc6ce1c5c6b94cb636c7d4cf4bbf * Flatten params. PiperOrigin-RevId: 862801302 Change-Id: If14899543ab36dc983235ffdd6f5b76493324538 * Add mean_kernel_init_fn to checkpoint serialization keywords PiperOrigin-RevId: 862855364 Change-Id: I0fad9dac158a7be2215b1db0c1ce38931be018f0 * Check more general exception type in assert. PiperOrigin-RevId: 866629105 Change-Id: I6a7e2fb51af09589c106268eba29ba48dce86bb9 * [pmap] Prepare brax ES agent for jax_pmap_shmap_merge=True. PiperOrigin-RevId: 867693155 Change-Id: Ibf11ef57bcd9588d1cd55f1124d2c90b809ed1f3 * Fix checkpoint.save PiperOrigin-RevId: 868221106 Change-Id: I81a24e66f1fac351553bbc79c955c0c06a5b76ae * Fix for brax es train, which was broken by cl/869007637. PiperOrigin-RevId: 869379212 Change-Id: Idacecde1a97a0612fd475e79a861d0086f252652 * Bump brax to 0.14.1 PiperOrigin-RevId: 869389467 Change-Id: Icac5463f0f2c6c4499a724cd70c16f6bec112bd1 * [pmap] Remove `jax.config.pmap_shmap_merge`. `jax.config.pmap_shmap_merge` was deprecated as of JAX v0.9.0 in January 2025 and will be removed in JAX v0.10.0 in April 2025. PiperOrigin-RevId: 875621716 Change-Id: I50df76fe83f1f5ee0cc69ced2ddb79b6a27ade97 * Vision update (google#662) * Add configurable CNN * Fix spatial softmax * Fix spatial softmax shape * Fix leading batch dimensions * Remove spatial softmax Change-Id: I00d018ca08226e60ba4ddff03a1a3df48e69dced * Fix Brax external tests by updating jax.sharding.PartitionSpec PiperOrigin-RevId: 878160775 Change-Id: I822c76523cff1d72a7c2135ac9332d39f158434e * Import Brax PR google#662: Vision update PiperOrigin-RevId: 878174910 Change-Id: I8088a5aee048b07566ce65e9da6f580814630e60 * Import Brax PR google#662: Vision update PiperOrigin-RevId: 878202976 Change-Id: I8c81f958868a9cc4f3d218b8df1a0e8df8b39d51 * Update Python CI versions (google#663) * Update Python versions in CI to 3.11, 3.12, and 3.13 * Update ci.yml * Add CNN kernel init and spatial softmax (h/t zakka) PiperOrigin-RevId: 879785561 Change-Id: I0e9dc87128ecad6f63953f06b533713f4e6994c8 * Add modifications for WujiHand training - training.py: Add curriculum support - acting.py: Add full_reset support - wrappers/training.py: Inherit info across episodes - base.py: Minor adjustments * Fix curriculum param naming and remove unused metric - Rename global_step to total_training_steps for clarity - Remove eval/epoch_eval_time metric (redundant with eval/sps) * Add CLAUDE.md documenting wuji-custom branch modifications Documents the curriculum learning kwargs passthrough architecture, modified files, data flow, and maintenance guide for future rebases. * Add merge rules and known limitations to CLAUDE.md * Fix CI: make curriculum_params injection opt-in Add inject_curriculum_params parameter (default False) to train(). Previously step_kwargs were always injected, breaking envs that don't accept **kwargs (e.g. InvertedPendulum in tests). Now only injected when explicitly enabled by the caller. * Update CLAUDE.md: document inject_curriculum_params opt-in behavior * Fix CI: avoid closure capture in vmap when kwargs is empty When no kwargs are passed, use the original jax.vmap path without closure to preserve identical numerical behavior with upstream. The closure-based path is only used when curriculum_params is injected. Fixes: training_test.py::test_domain_randomization_wrapper (255 != 256) * fix: separate code paths in DomainRandomizationVmapWrapper for empty kwargs Previous fix used conditional inside closure (kwargs still captured in scope). Now define completely separate step_fn closures - the no-kwargs path never references kwargs, producing identical JAX trace to upstream. * fix: eliminate kwargs closure capture in all wrapper step methods The previous fix only addressed DomainRandomizationVmapWrapper, but the test call chain is AutoResetWrapper → EpisodeWrapper → DRVmapWrapper. EpisodeWrapper's jax.lax.scan closure and AutoResetWrapper's env.step call also captured the empty kwargs dict, changing JAX trace behavior on Python 3.11. Now ALL wrappers use `if not kwargs` guard to ensure the no-kwargs path never references kwargs at all, producing identical traces to upstream. * feat(brax): add bounds_loss to PPO loss function Soft quadratic penalty on |mu| > 1.1 (rl_games industry standard). Controlled by bounds_loss_coef parameter (default 0.0 = off). Uses dist.loc to extract mean, compatible with both NormalDistribution and NormalTanhDistribution. Metrics: bounds_loss, bounds_loss_scaled, bounds_violation_rate. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * feat(brax): thread bounds_loss_coef through PPO train() Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * test(brax): add bounds_loss unit tests for PPO Tests cover pure math verification, distribution compatibility (normal and tanh_normal), gradient finiteness, and smoke training integration across distribution types and coefficient values. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * feat(brax): add period-level episode metrics for curriculum decisions - logger.py: accumulate episode length/reward per eval period, expose pop_curriculum_summary() for one-shot consumption - train.py: log training-side episode metrics when log_training_metrics enabled Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix: CLAUDE.md wrapper ordering + checkpoint.py None guard for load_config - Fix wrapper call chain diagram: AutoResetWrapper is outermost, not EpisodeWrapper - Add None guard in load_config() for kernel init fn fields to prevent KeyError when loading checkpoints saved with None kernel init fn Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * style: CLAUDE.md nitpicks — remove hardcoded line number, add code block lang Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * docs: CLAUDE.md add eval path curriculum_params note Document that curriculum_params injection only occurs in training rollouts, not in eval path (brax/training/acting.py eval does not pass step_kwargs). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * docs: document multi-host curriculum sync limitation Curriculum update runs only on process_id==0. Single-host multi-GPU (our current setup) is unaffected. Multi-host would need broadcast. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> --------- Co-authored-by: Matej Aleksandrov <maleksandrov@google.com> Co-authored-by: Baruch Tabanpour <btaba@google.com> Co-authored-by: Erik Frey <erikfrey@google.com> Co-authored-by: Daniel Suo <dsuo@google.com> Co-authored-by: Taylor Howell <taylorhowell@google.com> Co-authored-by: Brax Team <no-reply@google.com> Co-authored-by: Mustafa H <34825877+StafaH@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Drop 3.10 and add 3.12, 3.13 to matrix.