Merged
4 changes: 2 additions & 2 deletions colossalai/booster/booster.py
@@ -25,11 +25,11 @@ class Booster:
Examples:
```python
colossalai.launch(...)
- plugin = GeminiPlugin(stage=3, ...)
+ plugin = GeminiPlugin(...)
booster = Booster(precision='fp16', plugin=plugin)

model = GPT2()
- optimizer = Adam(model.parameters())
+ optimizer = HybridAdam(model.parameters())
dataloader = Dataloader(Dataset)
lr_scheduler = LinearWarmupScheduler()
criterion = GPTLMLoss()
4 changes: 2 additions & 2 deletions docs/source/en/features/zero_with_chunk.md
@@ -195,7 +195,7 @@ def get_data(batch_size, seq_len, vocab_size):
Finally, we define a model that uses Gemini + ZeRO DDP and write our training loop. Since we pre-train GPT in this example, we simply use a language model loss:

```python
- from torch.optim import Adam
+ from colossalai.nn.optimizer import HybridAdam

from colossalai.booster import Booster
from colossalai.zero import ColoInitContext
@@ -211,7 +211,7 @@ def main():

# build criterion
criterion = GPTLMLoss()
- optimizer = Adam(model.parameters(), lr=0.001)
+ optimizer = HybridAdam(model.parameters(), lr=0.001)

torch.manual_seed(123)
default_pg = ProcessGroup(tp_degree=args.tp_degree)
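For context, the corrected docs snippet slots into a training setup along the following lines. This is only a sketch assuming ColossalAI's documented `Booster`/`GeminiPlugin`/`HybridAdam` API and a CUDA device; `GPT2`, `GPTLMLoss`, and `dataloader` stand in for the tutorial's own helpers and are not defined here:

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

# GPT2, GPTLMLoss, and dataloader are the tutorial's helpers, assumed defined.
plugin = GeminiPlugin()
booster = Booster(precision='fp16', plugin=plugin)

model = GPT2()
# HybridAdam can update both CPU-offloaded and GPU-resident parameters, which
# is why the docs swap it in for torch.optim.Adam when Gemini offloading is on.
optimizer = HybridAdam(model.parameters(), lr=0.001)
criterion = GPTLMLoss()

# boost() wraps the objects with the plugin's distributed/offload machinery.
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

for input_ids, attn_mask in dataloader:
    optimizer.zero_grad()
    outputs = model(input_ids, attn_mask)
    loss = criterion(outputs, input_ids)
    # Gradients must flow through the booster, not plain loss.backward().
    booster.backward(loss, optimizer)
    optimizer.step()
```

The key point of the PR is the optimizer swap: a plain `torch.optim.Adam` cannot step parameters that Gemini has offloaded to CPU, whereas `HybridAdam` handles both placements.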
4 changes: 2 additions & 2 deletions docs/source/zh-Hans/features/zero_with_chunk.md
@@ -197,7 +197,7 @@ def get_data(batch_size, seq_len, vocab_size):
Finally, use the booster to inject the Gemini + ZeRO DDP features and define the training loop. Since we pre-train GPT in this example, we only use a simple language model loss:

```python
- from torch.optim import Adam
+ from colossalai.nn.optimizer import HybridAdam

from colossalai.booster import Booster
from colossalai.zero import ColoInitContext
@@ -213,7 +213,7 @@ def main():

# build criterion
criterion = GPTLMLoss()
- optimizer = Adam(model.parameters(), lr=0.001)
+ optimizer = HybridAdam(model.parameters(), lr=0.001)

torch.manual_seed(123)
default_pg = ProcessGroup(tp_degree=args.tp_degree)