ROCm Support for AI-Toolkit #563
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
- Add device_map=None to prevent automatic CPU placement
- Use torch.device context to hint GPU loading
- Immediately move transformers to GPU after loading
- Keep quantized blocks on GPU for ROCm (only move to CPU in low_vram mode)
- Fix device consistency in the quantization process
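A minimal Python sketch of the loading pattern this commit describes, assuming a `from_pretrained`-style loader; `load_fn` and the helper itself are illustrative, not the toolkit's actual API:

```python
import torch

def load_on_gpu(load_fn, device_str: str = "cuda"):
    """Load without accelerate's automatic placement, then move to GPU explicitly.

    `load_fn` stands in for any from_pretrained-style callable. device_map=None
    keeps accelerate from pinning weights to the CPU, the torch.device context
    (PyTorch >= 2.0) hints GPU placement during init, and the final .to(device)
    mirrors the "move transformers to GPU immediately after loading" fix.
    """
    device = torch.device(device_str)  # ROCm builds of PyTorch expose HIP GPUs as "cuda"
    with torch.device(device):
        model = load_fn(device_map=None)
    return model.to(device)
```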
- Split 36-channel latents into 16-channel latents and 20-channel conditioning
- Add defensive checks to ensure consistent channel counts throughout the pipeline
- Move conditioning to device once before the loop to avoid repeated transfers
- Fix mask application with proper shape matching
- Remove the zero-conditioning fallback that produced garbage output
- Add clear error messages for missing i2v conditioning
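A hedged sketch of the channel split, assuming the 16 + 20 layout from the commit message (the real pipeline code lives in the wan22 extension):

```python
import torch

def split_i2v_latents(latents: torch.Tensor):
    """Split packed 36-channel i2v latents into denoised latents and conditioning.

    The 16/20 split follows the commit message; treat the exact layout as an
    assumption about WAN 2.2 i2v.
    """
    if latents.shape[1] != 36:  # defensive channel-count check
        raise ValueError(f"expected 36 channels, got {latents.shape[1]}")
    noise_latents = latents[:, :16]  # channels the model denoises
    conditioning = latents[:, 16:]   # 20 channels of i2v conditioning
    return noise_latents, conditioning
```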
- Add first frame caching to LatentCachingMixin when do_i2v is enabled
- Store first frames separately from latents for i2v conditioning
- Add get_first_frame_path() and get_first_frame() methods
- Add first_frame_tensor to DataLoaderBatchDTO for cached first frames
- Support both disk and memory caching for first frames
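A hypothetical usage sketch of these hooks; `batch` and `file_item` are placeholders, and the fallback order (memory first, then disk) is inferred from the commit message:

```python
# Prefer the in-memory cache on the batch, then fall back to the disk cache.
first_frame = batch.first_frame_tensor             # set when memory caching is enabled
if first_frame is None and file_item.get_first_frame_path() is not None:
    first_frame = file_item.get_first_frame()      # load the disk-cached frame
```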
- Ensure the VAE is on the correct device before encoding first frames
- Temporarily move the VAE to GPU when needed, even if latents are cached
- Support cached first frames when batch.tensor is None
- Add clear error messages for missing first frames
- Fix device management for both 14B and 5B models
- Add ROCm/AMD GPU detection via rocm-smi alongside NVIDIA support
- Update the GPU API to detect and return both NVIDIA and ROCm GPUs
- Fix the jobs route to use gpu_ids from the request body instead of hardcoding '0'
- Update the GPUMonitor component to display ROCm GPUs
- Add Windows support for ROCm detection
- Parse rocm-smi CSV output for GPU stats (temperature, utilization, memory, power, clocks)
- Use Card SKU as the primary name source (more descriptive than hex IDs)
- Fall back to a descriptive 'AMD GPU N' format if the SKU is hex or missing
- Prevents unstable/incorrect GPU names in the UI
- Add a dedicated section for AMD GPU installation using ROCm
- Document uv virtual environment setup (.venv)
- Include the specific ROCm PyTorch installation command for gfx1151
- Add a note about different GPU architectures and how to check them
- Clarify CUDA vs. ROCm installation paths
- Generate a dummy solid-gray first frame when ctrl_img is None
- Ensures the model always receives the expected 36 channels (16 latent + 20 conditioning)
- Allows baseline sampling to work without requiring a control image in the config
- Fixes a device mismatch by creating the tensor on CPU first, then moving it to the target device
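A minimal sketch of the dummy-frame fallback; the (1, 3, H, W) layout and [-1, 1] value range are assumptions:

```python
import torch

def make_dummy_first_frame(height: int, width: int, device: torch.device) -> torch.Tensor:
    """Solid mid-gray frame used when ctrl_img is None.

    Built on CPU first and only then moved, matching the device-mismatch fix.
    """
    frame = torch.zeros(1, 3, height, width)  # 0.0 == mid-gray in a [-1, 1] range
    return frame.to(device)
```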
- Encode only the first frame instead of the full video sequence with zeros
- Replicate the encoded latent to match the required number of frames
- Significantly reduces memory usage and computation
- Add explicit device management and synchronization points
- Ensure the VAE is in eval mode before encoding
- Move the VAE back to CPU after encoding if it was originally there
- Add better error handling and synchronization for ROCm/HIP compatibility
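A sketch of the encode-once-and-replicate idea, assuming a diffusers-style VAE with `encode().latent_dist`; the real WAN VAE layout may differ:

```python
import torch

@torch.no_grad()
def encode_first_frame(vae, first_frame, num_frames, device):
    """Encode one frame and tile its latent across time, instead of encoding a
    zero-padded full video sequence."""
    was_on_cpu = next(vae.parameters()).device.type == "cpu"
    vae.to(device).eval()                      # ensure eval mode before encoding

    latent = vae.encode(first_frame.to(device)).latent_dist.sample()
    latent = latent.unsqueeze(2).repeat(1, 1, num_frames, 1, 1)  # (B, C, T, H, W)

    if torch.cuda.is_available():
        torch.cuda.synchronize()               # explicit sync surfaces HIP errors early
    if was_on_cpu:
        vae.to("cpu")                          # restore original placement
    return latent
```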
- Implement WanAttnProcessor2_0Flash for optimized attention computation
- Support both the Flash Attention 2 and 3 APIs
- Fall back to PyTorch SDP when Flash Attention is not available
- Fix ROCm compatibility by checking for the 'hip' device type
- Enable flash attention on all transformer blocks (including dual-transformer WAN 2.2)
- Add a use_flash_attention option to the model config
- Add an sdp: true option to the training config for the PyTorch flash SDP backend
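The fallback chain in a minimal form; the real processor is `WanAttnProcessor2_0Flash` in the PR, while this stand-in only shows the dispatch (tensor layouts and the `flash_attn` import are environment assumptions):

```python
import torch
import torch.nn.functional as F

try:
    from flash_attn import flash_attn_func  # Flash Attention 2/3 entry point
    HAS_FLASH = True
except ImportError:
    HAS_FLASH = False

def dispatch_attention(q, k, v):
    """q, k, v arrive in the SDPA layout (batch, heads, seq, head_dim)."""
    if HAS_FLASH and q.device.type == "cuda":  # ROCm PyTorch also reports "cuda" here
        # flash_attn_func expects (batch, seq, heads, head_dim)
        out = flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
        return out.transpose(1, 2)
    return F.scaled_dot_product_attention(q, k, v)  # PyTorch SDP fallback
```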
- Set AMD_SERIALIZE_KERNEL=3 for better error reporting
- Set TORCH_USE_HIP_DSA=1 for device-side assertions
- Set HSA_ENABLE_SDMA=0 for APU compatibility
- Set PYTORCH_ROCM_ALLOC_CONF to reduce VRAM fragmentation
- Apply ROCm variables in run.py before torch imports
- Pass ROCm variables when launching jobs from the UI
- Improve HIP error detection and debugging capabilities
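The ordering constraint matters: HIP reads these at library load time, so they must be in the environment before `import torch`. A sketch of the run.py pattern (the `PYTORCH_ROCM_ALLOC_CONF` value shown is illustrative, not from the PR):

```python
import os

ROCM_ENV = {
    "AMD_SERIALIZE_KERNEL": "3",   # serialize kernel launches for clearer error reports
    "TORCH_USE_HIP_DSA": "1",      # device-side assertions
    "HSA_ENABLE_SDMA": "0",        # APU compatibility
    "PYTORCH_ROCM_ALLOC_CONF": "expandable_segments:True",  # value is an assumption
}
for key, value in ROCM_ENV.items():
    os.environ.setdefault(key, value)  # don't clobber user overrides

import torch  # noqa: E402 -- deliberately imported after the environment is set
```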
- Prioritize Card Model over Card SKU when the SKU is a hex ID
- Provide more descriptive GPU names for AMD GPUs
- Fall back to generic names when model/SKU are not available
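A Python sketch of the naming priority (the actual logic is in the TypeScript GPU route; the hex check is a rough assumption):

```python
def pick_gpu_name(card_model, card_sku, index: int) -> str:
    """Prefer Card Model, then a non-hex Card SKU, then a generic fallback."""
    def looks_like_hex_id(s: str) -> bool:
        return s.lower().startswith("0x") or all(c in "0123456789abcdef" for c in s.lower())

    for candidate in (card_model, card_sku):
        if candidate and not looks_like_hex_id(candidate):
            return candidate
    return f"AMD GPU {index}"  # stable, descriptive fallback name
```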
- Create start_toolkit.sh for Linux with modes: setup, train, gradio, ui
- Create start_toolkit.ps1 for Windows with the same functionality
- Create setup.sh for automated Linux installation
- Create setup.ps1 for automated Windows installation
- Auto-detect the virtual environment (uv .venv or standard venv)
- Auto-detect the GPU backend (ROCm or CUDA)
- Set ROCm environment variables automatically
- Verify dependencies before running
- Support training options: --recover, --name, --log
- Support UI options: --port, --dev (development mode)
- Add a Quick Start section using the setup scripts
- Keep manual installation instructions for advanced users
- Add detailed running instructions for both Linux and Windows
- Document startup script modes and options
- Include examples for all common use cases
- Improve organization and readability
- Update wan22_14b_model.py with minor improvements
- Update the UI package-lock.json with dependency changes
- Disable ROCBLAS_USE_HIPBLASLT by default to prevent HIPBLAS_STATUS_INTERNAL_ERROR with quantized models
- Fix UI mode to not set ROCm env vars (let run.py handle them when jobs spawn)
- Only pass through ROCm env vars in startJob.ts if already set in the parent process
- Add ROCBLAS_LOG_LEVEL=0 to reduce logging overhead
- Unset problematic ROCm vars in UI mode to prevent freezing during model loading

This fixes job freezing when launched from the UI and prevents crashes during quantized model training.
…clamping
- Implement validation and clamping for GPU stats in `getNvidiaGpuStats` and `getRocmGpuStats` to prevent invalid values
- Update GPU utilization and memory percentage calculations in `GPUWidget` to handle edge cases and ensure proper display
- Ensure non-negative values for power draw, clock speeds, and fan speeds, improving the overall robustness of GPU data handling
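A Python sketch of the validation/clamping idea (the real code is TypeScript in `getNvidiaGpuStats`/`getRocmGpuStats`; the bounds are caller-supplied assumptions):

```python
import math

def clamp_gpu_stat(value, low: float, high: float, default: float = 0.0) -> float:
    """Reject None/NaN and clamp into a plausible range."""
    if value is None or (isinstance(value, float) and math.isnan(value)):
        return default
    return max(low, min(high, float(value)))

print(clamp_gpu_stat(-3.0, 0.0, 100.0))   # 0.0   -- utilization can't be negative
print(clamp_gpu_stat(101.5, 0.0, 100.0))  # 100.0 -- memory percent capped at 100
```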
…r experience
- Enhance parsing logic in `getRocmGpuStats` to handle edge cases, ensuring valid temperature, power draw, and clock speed values
- Update GPUWidget to display 'N/A' for invalid or zero values, improving clarity in the UI
- Implement additional safety checks to prevent numeric values from being displayed as GPU names, ensuring more descriptive outputs
- Introduce an optional `hasAmdSmi` field in `GPUApiResponse` to indicate AMD GPU support
- Implement a `checkAmdSmi` function to verify the presence of the AMD SMI tool for GPU detection
- Update GPU stats retrieval logic to include AMD GPU metrics, ensuring comprehensive support alongside NVIDIA and ROCm GPUs
- Enhance error handling and logging for improved debugging and user experience
FYI, the https://rocm.nightlies.amd.com/v2/ site does not have a directory for each GPU arch under its exact name; you can browse it to see the actual directory layout and adjust line 103 in setup.sh accordingly. Did not get around to testing everything else yet.

[Later edit]: After manually installing npm and self-compiling bitsandbytes, the training functionality works fine.
Thank you for the excellent feedback! I'm not sure why, but the PR isn't picking up my later changes to the branch. I've addressed the rocm-smi GPU issues but will want to look at a better way to address your hardware-specific issues. Thank you again for the kind words and the time you've spent testing the effort 🙏
- Update README.md to clarify GPU architecture detection and mapping for ROCm installations
- Implement GPU architecture mapping functions in setup.ps1, setup.sh, and start_toolkit.sh to streamline the installation process
- Improve error handling and user prompts for GPU architecture detection, so users can easily specify or auto-detect their GPU architecture
- Add detailed notes on common GPU architectures and their corresponding ROCm directory names for better user guidance
Hi, I'm trying this with an AMD RYZEN AI MAX+ 395 w/ Radeon 8060S (gfx1151) and getting this error:

```
(venv) alessio@fedora ~/dati/ai-toolkit (pr-563) $ ./start_toolkit.sh ui
[INFO] Virtual environment already active: /home/alessio/dati/ai-toolkit/venv
[INFO] Starting in mode: ui
[INFO] Launching web UI on port 8675...

added 490 packages, and audited 491 packages in 4s
67 packages are looking for funding
3 vulnerabilities (2 high, 1 critical)

up to date, audited 491 packages in 779ms
67 packages are looking for funding
3 vulnerabilities (2 high, 1 critical)

Prisma schema loaded from prisma/schema.prisma
✔ Generated Prisma Client (v6.3.1) to ./node_modules/@prisma/client in 26ms
Start by importing your Prisma Client (See: https://pris.ly/d/importing-client)
Tip: Want real-time updates to your database without manual polling? Discover how with Pulse: https://pris.ly/tip-0-pulse

Prisma schema loaded from prisma/schema.prisma
SQLite database aitk_db.db created at file:../../aitk_db.db
🚀 Your database is now in sync with your Prisma schema. Done in 26ms
✔ Generated Prisma Client (v6.3.1) to ./node_modules/@prisma/client in 27ms

▲ Next.js 15.1.7
Creating an optimized production build ...
./node_modules/systeminformation/lib/cpu.js
Import trace for requested module:
✓ Compiled successfully
Route (app)                 Size    First Load JS
ƒ Middleware                32.1 kB
○ (Static) prerendered as static content

[WORKER] TOOLKIT_ROOT: /home/alessio/dati/ai-toolkit
```

Any idea what to try? Thanks
Addresses feedback from PR ostris#563. Removed standalone 'self.accelerator' line at line 2248 that served no purpose.
Addresses feedback from PR ostris#563. Replaced data reassignment with explicit String() conversion for better null/undefined handling when processing caption data from API responses.
Addresses feedback from PR ostris#563. Added try-catch block to properly handle database errors and return appropriate error responses instead of letting errors propagate unhandled.
Addresses feedback from PR ostris#563. Fixed several issues with ROCm GPU stats parsing:
- Fixed hardcoded field indices in error messages to use dynamic field indices
- Made field parsing errors conditional on development mode to reduce log noise
- Improved power parsing to detect and skip clock-like values (e.g., "(1472Mhz)")
- Added automatic Hz-to-MHz conversion for clock values that are far too high
- Better handling of edge cases in power and clock value parsing

This fixes the recurring errors:
- Could not parse power from field[9]="(1472Mhz)"
- Invalid graphics clock 51834MHz from field[5]="51834"
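The parsing fixes, sketched in Python for illustration (the actual implementation is TypeScript; the 10,000 MHz plausibility threshold is an assumption):

```python
import re

def parse_power_watts(field: str):
    """Skip clock-like strings such as '(1472Mhz)' that leak into the power column."""
    if re.search(r"mhz", field, re.IGNORECASE):
        return None  # a clock value, not power; ignore it
    match = re.search(r"[\d.]+", field)
    return float(match.group()) if match else None

def normalize_clock_mhz(value: float) -> float:
    """Treat implausibly high clock readings as Hz and convert to MHz."""
    return value / 1_000_000 if value > 10_000 else value
```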
- Make amd-smi the primary tool for AMD GPU monitoring, with rocm-smi as fallback
- Improve field detection to handle multiple possible field names (temperature, power, etc.)
- Enhance GPU name detection to try multiple paths in the static data structure
- Ensure all metrics are properly mapped: temperature, fan speed, GPU load, memory, clock speeds, and names
- Fall back to rocm-smi only if amd-smi is unavailable or returns insufficient data
bab8bec and 2f0eb5f hopefully fixed your issue. Thank you for pasting the log!
cool, thanks @iGavroche. I'll be able to test this in a few days and will let you know here for sure.
asoldano left a comment
@iGavroche , I've tested your changes and I confirm the warnings are gone. The AMD Strix Halo is also properly detected and shown in the UI.
Having said that, before you fixed this properly, I worked around the problem by asking Cline+GLM4.7 to come up with a solution; I'm mentioning this because in both cases, the training is still actually failing to start. In stderr.log I can see:

```
terminate called after throwing an instance of 'std::bad_alloc'
  what():  std::bad_alloc
```
I've dug into the problem a bit (again with Cline, so this might not be super accurate) and eventually discovered that torchaudio is crashing with the std::bad_alloc when imported. Getting rid of the only torchaudio import in config_modules.py (or going with something like asoldano@58514f8) allows me to start and complete a LoRA training, though I'm not really sure whether this is a problem with my environment only or not (btw, I'm using Python 3.13; I also had to use a more recent scipy version). Hope this feedback is of some use.
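For reference, a minimal sketch of the deferred-import workaround described above (the helper name is hypothetical); note that a plain try/except around the import would not help here, since the C++ `std::bad_alloc` aborts the process:

```python
def load_audio(path: str):
    # Deferred import: configs that never touch audio never trigger the crash.
    import torchaudio
    return torchaudio.load(path)
```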
@asoldano thanks for the feedback. I believe the issue is with the --pre TheRock builds. Probably
Add ROCm/AMD GPU Support and Enhancements
This PR adds comprehensive ROCm/AMD GPU support to the AI Toolkit, along with significant improvements to WAN model handling, UI enhancements, and a better developer experience.
Major Features
ROCm/AMD GPU Support
- Detect AMD GPUs via `rocm-smi`, alongside existing NVIDIA support
- Jobs route now uses `gpu_ids` from the request body instead of hardcoded values

Setup and Startup Scripts
- `setup.sh` (Linux) and `setup.ps1` (Windows) for automated installation
- `start_toolkit.sh` (Linux) and `start_toolkit.ps1` (Windows) with multiple modes:
  - `setup`: Install dependencies
  - `train`: Run training jobs
  - `gradio`: Launch Gradio interface
  - `ui`: Launch web UI
- Auto-detect the virtual environment (uv `.venv` or standard venv) and GPU backend (ROCm or CUDA)
- Training options: `--recover`, `--name`, `--log` flags
- UI options: `--port` and `--dev` (development mode) flags

WAN Model Improvements
Image-to-Video (i2v) Enhancements

- Split 36-channel latents into 16-channel latents and 20-channel conditioning, with defensive channel-count checks throughout the pipeline
- First frame caching (disk and memory) for i2v conditioning
- A dummy solid-gray first frame is generated when no control image is configured, so baseline sampling still works
- Only the first frame is encoded and its latent replicated across frames, significantly reducing memory usage
Flash Attention Support
- `WanAttnProcessor2_0Flash` for optimized attention computation
- Supports both the Flash Attention 2 and 3 APIs, falling back to PyTorch SDP when Flash Attention is not available
- `use_flash_attention` option in the model config and `sdp: true` in the training config

Device Management

- `device_map=None` on load to prevent automatic CPU placement; transformers are moved to GPU immediately after loading
- Quantized blocks stay on GPU for ROCm (moved to CPU only in low_vram mode)
- The VAE is moved to the correct device before encoding first frames and restored afterwards
UI Enhancements
GPU Monitoring
- `GPUMonitor` component displays ROCm GPUs alongside NVIDIA
- `amd-smi` is used as the primary AMD monitoring tool, with `rocm-smi` as fallback
- GPU stats are validated and clamped so invalid values never reach the UI

Job Management
- `ROCBLAS_USE_HIPBLASLT` is disabled by default to prevent crashes with quantized models

Environment Variables and Configuration
ROCm Environment Variables
- `AMD_SERIALIZE_KERNEL=3` for better error reporting
- `TORCH_USE_HIP_DSA=1` for device-side assertions
- `HSA_ENABLE_SDMA=0` for APU compatibility
- `PYTORCH_ROCM_ALLOC_CONF` to reduce VRAM fragmentation
- `ROCBLAS_LOG_LEVEL=0` to reduce logging overhead
- Variables are applied in `run.py` before torch imports and passed through when launching jobs from the UI (the UI itself does not set them; `run.py` handles them when jobs spawn)
- README documents AMD GPU installation with ROCm, including virtual environment setup with uv
- Notes cover GPU architecture detection and mapping to ROCm directory names
- Quick Start section based on the setup scripts, with manual installation instructions kept for advanced users

Technical Details
Key Files Modified
- `run.py`: ROCm environment variable setup
- `ui/src/app/api/gpu/route.ts`: ROCm GPU detection and stats
- `ui/src/components/GPUMonitor.tsx` & `GPUWidget.tsx`: ROCm GPU display
- `toolkit/models/wan21/wan_attn_flash.py`: Flash Attention implementation
- `extensions_built_in/diffusion_models/wan22/*`: WAN model improvements
- `toolkit/dataloader_mixins.py`: First frame caching
- `start_toolkit.sh` & `start_toolkit.ps1`: Startup scripts
- `setup.sh` & `setup.ps1`: Setup scripts

Testing Considerations
Bug Fix Commits
Migration Notes
- Startup scripts (`start_toolkit.sh` / `start_toolkit.ps1`) are recommended but not required
- ROCm environment variables are applied automatically by `run.py`