Merged
1 change: 1 addition & 0 deletions .gitignore
@@ -164,3 +164,4 @@ coverage.xml
 applications/ColossalChat/logs
 applications/ColossalChat/tests/logs
 applications/ColossalChat/wandb
+applications/ColossalChat/model
2 changes: 0 additions & 2 deletions applications/ColossalChat/coati/distributed/consumer.py
@@ -54,7 +54,6 @@ def __init__(

         self.model_config = model_config
         self.plugin_config = plugin_config
-        assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now"

         self.device = get_current_device()
         self.lr_scheduler = None
@@ -95,7 +94,6 @@ def loop(self) -> None:
        i = 0
        for _ in range(self.num_recv_per_update):
            # receive data from producers
-
            for r in range(self.num_producers):
                print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
                self.buffer.extend(
306 changes: 207 additions & 99 deletions applications/ColossalChat/coati/distributed/grpo_consumer.py
Large diffs are not rendered by default.
2 changes: 2 additions & 0 deletions applications/ColossalChat/coati/distributed/launch.py
@@ -47,6 +47,7 @@ def launch_distributed(
     master_addr: str = "localhost",
     master_port: int = 29500,
     core_algo: str = "GRPO",
+    project_name: Optional[str] = None,
 ):

     if core_algo not in ALGO_MAP:
@@ -108,6 +109,7 @@
             "train_microbatch_size": train_microbatch_size,
         },
         num_generations=num_generations,
+        project_name=project_name,
     )
     procs.append(consumer)
     ray.get([p.setup.remote() for p in procs])
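A note on the new project_name parameter: it is threaded from the CLI (the -p/--project flag added in rl_example.py below) through launch_distributed into the consumer. Its consumer-side use lives in grpo_consumer.py, whose large diff is not rendered above; a plausible sketch, assuming a wandb tracker (the repo's .gitignore already excludes ColossalChat's wandb directory):

import wandb

# Assumption: the consumer initializes its experiment tracker with the
# forwarded project name; the exact call is in the unrendered
# grpo_consumer.py diff, so this is illustrative only.
run = wandb.init(project=project_name or "GRPO")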
63 changes: 49 additions & 14 deletions applications/ColossalChat/rl_example.py
@@ -10,13 +10,44 @@
 parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
 parser.add_argument("-t", "--num-trainers", type=int, default=2)
 parser.add_argument("-i", "--num-inferencer", type=int, default=2)
-parser.add_argument("-g", "--num-generations", type=int, default=8)
-parser.add_argument("-ibs", "--inference-batch-size", type=int, default=64)
-parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=8)
-parser.add_argument("-tbs", "--train-batch-size", type=int, default=32)
-parser.add_argument("-tMbs", "--train-minibatch-size", type=int, default=1)
-parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2)
-parser.add_argument("-b", "--backend", type=str, default="transformers")
+parser.add_argument("-g", "--num-generations", type=int, default=8, help="Number of generations per prompt.")
+parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
+parser.add_argument(
+    "-ibs",
+    "--inference-batch-size",
+    type=int,
+    default=64,
+    help="Number of prompts to generate per inference step. It should be divisible by tbs, and the weights on the inference backend will be synced every ibs/tbs training steps of the policy model.",
+)
+parser.add_argument(
+    "-imbs",
+    "--inference-microbatch-size",
+    type=int,
+    default=8,
+    help="Effective batch size for the inference backend to run generation. Please select based on memory constraints.",
+)
+parser.add_argument(
+    "-tbs",
+    "--train-batch-size",
+    type=int,
+    default=32,
+    help="Number of unique prompts used to update the policy model per step per dp group. Gradients are accumulated across tbs * dp_size unique prompts, equivalently tbs * g * dp_size samples.",
+)
+parser.add_argument(
+    "-tMbs",
+    "--train-minibatch-size",
+    type=int,
+    default=1,
+    help="Number of unique prompts in each training batch per dp group. The inference backend must generate tMbs * g * dp_size samples before each forward pass; tMbs * g must be >= tmbs.",
+)
+parser.add_argument(
+    "-tmbs",
+    "--train-microbatch-size",
+    type=int,
+    default=2,
+    help="Effective batch size per dp group for the forward and backward passes. Please select based on the available memory.",
+)
+parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
 parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
 args = parser.parse_args()
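The help strings above encode several coupling constraints between the flags. A minimal runnable sketch (not part of the PR) that checks them against the parser defaults; dp_size is an assumed placeholder for the data-parallel degree:

# Minimal sketch, assuming only the constraints stated in the help strings.
ibs, tbs, tMbs, tmbs, g = 64, 32, 1, 2, 8  # the parser defaults

assert ibs % tbs == 0, "ibs must be divisible by tbs"
assert tMbs * g >= tmbs, "tMbs * g must cover one training microbatch"

sync_interval = ibs // tbs  # inference weights re-sync every ibs/tbs policy steps
print(f"inference backend re-synced every {sync_interval} training steps")

dp_size = 2  # assumed data-parallel degree, for illustration only
samples_per_step = tbs * g * dp_size  # samples accumulated per optimizer step
print(f"{samples_per_step} samples contribute to each policy update")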

@@ -29,11 +60,7 @@
 ray.init(address="local", namespace="ray-example")

 inference_model_config = dict(path=args.model)
-train_model_config = dict(
-    path=args.model,
-    # use_flash_attention_2=True,
-    # use_cache=False
-)
+train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
 generate_config = dict(top_k=50, top_p=0.75, temperature=0.9)

 if args.backend == "transformers":
@@ -91,9 +118,17 @@
     generate_config=generate_config,
     num_generations=args.num_generations,
     train_model_config=train_model_config,
-    plugin_config={},
+    # plugin_config={}, # for zero
+    plugin_config={
+        "pp_size": 2,
+        "tp_size": 1,
+        "microbatch_size": args.train_microbatch_size // 2,
+        "zero_stage": 0,
+        "max_norm": 1.0,
+    }, # for pp
     inference_backend=args.backend,
     master_addr="localhost",
-    master_port=29505,
+    master_port=29506,
     core_algo=args.algo,
+    project_name=args.project,
 )
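The switch from an empty plugin_config (the ZeRO path, kept as a comment) to an explicit pipeline-parallel config pairs with the pp_size assertion removed from consumer.py above. With pp_size=2 the policy model is split across two pipeline stages, and microbatch_size is set to half of tmbs. A minimal sketch of the arithmetic, assuming the plugin keys shown in the diff and standard pipeline-microbatch semantics:

# Minimal sketch, with the parser default for -tmbs assumed.
train_microbatch_size = 2
pp_size = 2

plugin_config = {
    "pp_size": pp_size,  # number of pipeline stages
    "tp_size": 1,        # no tensor parallelism
    "microbatch_size": train_microbatch_size // pp_size,
    "zero_stage": 0,     # ZeRO disabled on the pipeline path
    "max_norm": 1.0,     # gradient clipping threshold
}

# Each forward/backward over tmbs samples is scheduled as
# tmbs / microbatch_size pipeline microbatches:
print(train_microbatch_size // plugin_config["microbatch_size"])  # 2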
6 changes: 4 additions & 2 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1411,8 +1411,10 @@ def execute_pipeline(
         )

         # run with gradients accumulation
-        if model.require_grad_sync == False or (
-            isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
+        if (
+            not torch.is_grad_enabled()
+            or model.require_grad_sync == False
+            or (isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False)
         ):
             return outputs

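The new not torch.is_grad_enabled() clause makes execute_pipeline safe for inference-only passes, presumably the reference-model or evaluation forwards the GRPO consumer performs: when autograd is off, it returns its outputs before any gradient-sync logic runs. A small runnable illustration of the mode the guard inspects:

import torch

# torch.is_grad_enabled() reflects the ambient autograd mode, which is
# exactly what the patched guard checks.
print(torch.is_grad_enabled())      # True inside a normal training step

with torch.no_grad():
    # An evaluation or reference-model forward typically runs here; the
    # patched execute_pipeline now returns early in this mode instead of
    # trying to synchronize gradients that were never produced.
    print(torch.is_grad_enabled())  # False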
1 change: 1 addition & 0 deletions colossalai/shardformer/modeling/qwen2.py
@@ -284,6 +284,7 @@ def qwen2_for_causal_lm_forward(
     hidden_states: Optional[torch.FloatTensor] = None,
     stage_index: Optional[List[int]] = None,
     shard_config: ShardConfig = None,
+    **kwargs,
 ):
     r"""
     Args:
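Accepting **kwargs makes the sharded Qwen2 forward tolerant of extra keyword arguments that upstream callers (for instance, the pipeline schedule) may pass along; previously any unrecognized keyword raised a TypeError. A self-contained illustration of the pattern, with extra_flag as an invented keyword:

# Hypothetical signatures illustrating the change; `extra_flag` is not a
# real argument in the codebase.
def forward_strict(input_ids=None, shard_config=None):
    return input_ids

def forward_tolerant(input_ids=None, shard_config=None, **kwargs):
    return input_ids  # unexpected keywords are absorbed by **kwargs

forward_tolerant(input_ids=[1, 2], extra_flag=True)   # fine
# forward_strict(input_ids=[1, 2], extra_flag=True)   # would raise TypeError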