-
Notifications
You must be signed in to change notification settings - Fork 18
feat: add precision checker with hook system and command-line control #102
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
This PR introduces a comprehensive precision checking system for debugging numerical accuracy issues in distributed training: **Core Features:** - Two-level precision checking (module-level and function-level) - Command-line flags: --precision_check, --precision_check_all_ranks - Extensible hook system for Functions, Modules, and Tensors - Automatic FP32 reference computation for validation **Hook System:** - Forward/backward pre/post hooks for Functions and Modules - Tensor gradient hooks for inspection - Unified hook type definitions to reduce code duplication **Implementation:** - PrecisionChecker utility with configurable check levels - Integration with autograd Function and nn::Module - Support for distributed training (per-rank checking) - Detailed logging to precision_check_rank_[N].log files **Documentation:** - docs/hook_mechanism.md - Hook system architecture - docs/precision_checker_guide.md - Usage guide **Testing:** - test/hook/test_hook.cc - Hook functionality tests - test/hook/test_precision_check.cc - Precision checker tests Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
…omprehensive docs - Add PrecisionCheckConfig and PrecisionCheckContext for better state management - Refactor precision checker to use context-based architecture - Add comprehensive documentation (hook_mechanism.md, precision_checker_guide.md) - Add test cases for hook system and precision checking - Update CMakeLists.txt to include new test targets - Improve command-line flag handling in examples Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
- Unify Function and Module hook infrastructure into common/hook.h - Remove duplicated HookHandle and HookHandleImpl classes - Update precision_checker_guide.md and hook_mechanism.md
This commit fixes the issue where only rank 0 generated precision check log files when running with tensor parallelism. The root cause was that GetLogStream() used process-global static variables, causing all threads in a single process to share the same log file handle. Changes: - Add thread_global_rank thread-local variable to track per-thread rank - Convert GetLogStream() and TableHeaderPrinted() to use thread_local storage - Set thread_global_rank in Train() function for each thread - Move baseline output (key|md5 format) into table format branch to avoid duplicate output in simple format - Add directory creation and error handling for log file opening With these changes, each thread now creates its own log file based on its global rank (process_rank * nthread_per_process + thread_rank). Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
d35e92a to
a7806d9
Compare
Add tools/compare_loss.py to automate end-to-end loss comparison between two log directories, eliminating manual verification overhead as test cases scale up. Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
| namespace { | ||
|
|
||
| // Simple MD5 implementation | ||
| class MD5 { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
直接比 tensor 的 md5,而不是看 abs/rel diff 且留一个阈值范围,是不是很难完全一致,而且无法看出差距多大?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
与下个PR一起提交此功能
| utils::PrecisionChecker::RegisterForModule(this); | ||
| precision_check_registered_ = true; | ||
| } | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我们应当有一个能够注册全局 module hook 的机制,目前 precision_checker 本质上是注册了一个全局的 module hook,应当是 precision_checker 直接调用注册全局 module hook 的接口(例如在 InitAllEnv 里根据传入的 precision 参数决定是否注册全局 precision_check hook)(可以等下次 pr 再改)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
单独提PR修改
- Move PrecisionCheckEnv from nn::parallel::global to utils namespace - Create separate .cc files for precision_check_config and precision_check_context - Move struct/class implementations from headers to source files - Add const qualifiers to local variables in precision checker - Add UNLIKELY macro for branch prediction optimization in Module::operator() Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This PR introduces a comprehensive precision checking system for debugging numerical accuracy issues in distributed training:
Core Features:
Hook System:
Implementation:
Documentation:
Testing: