11 changes: 6 additions & 5 deletions tutorials/llm/mamba/mamba.rst
@@ -28,6 +28,7 @@ In order to proceed, ensure that you have met the following requirements:

* A Docker-enabled environment, with `NVIDIA Container Runtime <https://developer.nvidia.com/container-runtime>`_ installed, which will make the container GPU-aware.


* `Authenticate with NVIDIA NGC <https://docs.nvidia.com/nim/large-language-models/latest/getting-started.html#ngc-authentication>`_, generate an API key from `NGC <https://org.ngc.nvidia.com/setup>`__, add the key to your credentials by following the instructions in `this guide <https://docs.nvidia.com/launchpad/ai/base-command-coe/latest/bc-coe-docker-basics-step-02.html>`__, and start the NVIDIA NeMo dev container ``nvcr.io/nvidia/nemo:dev`` (a minimal Docker sketch appears below).

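For reference, a minimal sketch of the container setup described in the requirements might look like the following. The ``$oauthtoken`` username and the NGC API key password follow standard NGC practice; the mounted working directory is an arbitrary example, not something the tutorial prescribes.

.. code:: bash

   # Log the Docker client in to NGC; when prompted, the password is your NGC API key.
   docker login nvcr.io --username '$oauthtoken'

   # Launch the NeMo dev container with GPU access, mounting the current directory as a workspace.
   docker run --gpus all -it --rm \
       -v "$(pwd)":/workspace \
       nvcr.io/nvidia/nemo:dev
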
Step-by-step Guide for Fine-Tuning
@@ -51,13 +52,13 @@ Convert the Pytorch Checkpoint to a NeMo Checkpoint

.. code:: bash

- CUDA_VISIBLE_DEVICES="0" python /NeMo/scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py \
+ CUDA_VISIBLE_DEVICES="0" python /opt/NeMo/scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py \
--input_name_or_path <path to the source pytorch model> \
--output_path <path to target .nemo model> \
- --ngroups_mamba 8 \
+ --mamba_ssm_ngroups 8 \
--precision bf16

- * Note: the ``ngroups_mamba`` parameter should be 1 for the Mamba2 models from the `Transformers are SSMs paper <https://arxiv.org/pdf/2405.21060>`__ (130m, 370m, 780m, 1.3b, and 2.7b) and 8 for the Mamba2 and Mamba2-Hybrid models by `NVIDIA <https://arxiv.org/pdf/2406.07887>`__ (both 8b).
+ * Note: the ``mamba_ssm_ngroups`` parameter should be 1 for the Mamba2 models from the `Transformers are SSMs paper <https://arxiv.org/pdf/2405.21060>`__ (130m, 370m, 780m, 1.3b, and 2.7b) and 8 for the Mamba2 and Mamba2-Hybrid models by `NVIDIA <https://arxiv.org/pdf/2406.07887>`__ (both 8b).

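As a concrete illustration of the note above, converting the NVIDIA Mamba2-Hybrid 8b checkpoint could look like this; the input and output paths are placeholders, and only the flags shown in the tutorial are used.

.. code:: bash

   # Placeholder checkpoint paths; substitute your own locations.
   CUDA_VISIBLE_DEVICES="0" python /opt/NeMo/scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py \
     --input_name_or_path /checkpoints/mamba2-hybrid-8b.pt \
     --output_path /checkpoints/mamba2-hybrid-8b.nemo \
     --mamba_ssm_ngroups 8 \
     --precision bf16

For the smaller Mamba2 models from the Transformers-are-SSMs paper, the same command applies with ``--mamba_ssm_ngroups 1``.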
Model (Tensor) Parallelism for the 8b Models
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -106,8 +107,8 @@ Run Fine-Tuning
export NVTE_FUSED_ATTN=1
export NVTE_FLASH_ATTN=0

- MASTER_PORT=15008 torchrun --nproc_per_node=${NUM_DEVICES} \
- /home/ataghibakhsh/NeMo/examples/nlp/language_modeling/tuning/megatron_mamba_finetuning.py \
+ torchrun --nproc_per_node=${NUM_DEVICES} \
+ /opt/NeMo/examples/nlp/language_modeling/tuning/megatron_mamba_finetuning.py \
--config-path=${CONFIG_PATH} \
--config-name=${CONFIG_NAME} \
trainer.devices=${NUM_DEVICES} \
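The fine-tuning command above (shown only partially in this hunk) assumes that ``NUM_DEVICES``, ``CONFIG_PATH``, and ``CONFIG_NAME`` have been exported earlier in the tutorial. A minimal sketch of such a setup, with illustrative values that may not match the tutorial's own defaults, might be:

.. code:: bash

   # Illustrative values only; use the paths and config name defined by the tutorial.
   export NUM_DEVICES=8                 # number of GPUs passed to torchrun and trainer.devices
   export CONFIG_PATH=/opt/NeMo/examples/nlp/language_modeling/tuning/conf
   export CONFIG_NAME=megatron_mamba_finetuning_config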