diff --git a/applications/Chat/coati/ray/detached_trainer_base.py b/applications/Chat/coati/ray/detached_trainer_base.py
index 1137d8f7b491..a4f666dc5714 100644
--- a/applications/Chat/coati/ray/detached_trainer_base.py
+++ b/applications/Chat/coati/ray/detached_trainer_base.py
@@ -43,14 +43,15 @@ def __init__(self,
         self.callbacks = callbacks
         self.target_holder_name_list = experience_maker_holder_name_list
         self.target_holder_list = []
-
+        self._is_target_holder_initialized = False
         self._debug = debug
 
-    def update_target_holder_list(self, experience_maker_holder_name_list):
-        self.target_holder_name_list = experience_maker_holder_name_list
-        self.target_holder_list = []
-        for name in self.target_holder_name_list:
-            self.target_holder_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"]))
+    def update_target_holder_list(self):
+        # as the length of target_holder_list may be zero, we need to check it by a bool flag
+        if not self._is_target_holder_initialized:
+            for name in self.target_holder_name_list:
+                self.target_holder_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"]))
+            self._is_target_holder_initialized = True
 
     @abstractmethod
     def _update_remote_makers(self, fully_update: bool = False, **kwargs):
diff --git a/applications/Chat/coati/ray/detached_trainer_ppo.py b/applications/Chat/coati/ray/detached_trainer_ppo.py
index b0630cd0b5ae..c5459c4d96d1 100644
--- a/applications/Chat/coati/ray/detached_trainer_ppo.py
+++ b/applications/Chat/coati/ray/detached_trainer_ppo.py
@@ -17,6 +17,7 @@
     get_actor_from_args,
     get_critic_from_args,
     get_model_numel,
+    get_rank,
     get_strategy_from_args,
     is_rank_0,
     set_dist_env,
@@ -102,38 +103,36 @@ def __init__(
                          dataloader_pin_memory=dataloader_pin_memory,
                          callbacks=callbacks,
                          debug=debug)
+        if self._debug:
+            print(f'[trainer{get_rank()}] will send state dict to {experience_maker_holder_name_list}')
 
     @ray.method(concurrency_group="model_io")
     @torch.no_grad()
     def _update_remote_makers(self, fully_update: bool = False, **config):
         # TODO: balance duties
-        if is_rank_0():
-            self.update_target_holder_list(self.target_holder_name_list)
-            # mark start, ensure order
-            tasks = []
-            for target_holder in self.target_holder_list:
-                tasks.append(target_holder.update_experience_maker.remote(chunk_start=True, fully_update=fully_update))
-            ray.get(tasks)
+        self.update_target_holder_list()
+        # mark start, ensure order
+        tasks = []
+        for target_holder in self.target_holder_list:
+            tasks.append(target_holder.update_experience_maker.remote(chunk_start=True, fully_update=fully_update))
+        ray.get(tasks)
         # sending loop
         tasks = []
         for state_dict_shard in self._get_model_state_dict_shard(self.strategy._unwrap_model(self.actor), **config):
-            if is_rank_0():
-                for target_holder in self.target_holder_list:
-                    tasks.append(
-                        target_holder.update_experience_maker.remote(new_actor_state_dict=state_dict_shard,
-                                                                     fully_update=fully_update))
+            for target_holder in self.target_holder_list:
+                tasks.append(
+                    target_holder.update_experience_maker.remote(new_actor_state_dict=state_dict_shard,
+                                                                 fully_update=fully_update))
         # sending loop
         for state_dict_shard in self._get_model_state_dict_shard(self.strategy._unwrap_critic(self.critic), **config):
-            if is_rank_0():
-                for target_holder in self.target_holder_list:
-                    tasks.append(
-                        target_holder.update_experience_maker.remote(new_critic_state_dict=state_dict_shard,
-                                                                     fully_update=fully_update))
-        ray.get(tasks)
-        if is_rank_0():
-            # mark end
             for target_holder in self.target_holder_list:
-                target_holder.update_experience_maker.remote(chunk_end=True, fully_update=fully_update)
+                tasks.append(
+                    target_holder.update_experience_maker.remote(new_critic_state_dict=state_dict_shard,
+                                                                 fully_update=fully_update))
+        ray.get(tasks)
+        # mark end
+        for target_holder in self.target_holder_list:
+            target_holder.update_experience_maker.remote(chunk_end=True, fully_update=fully_update)
 
     @ray.method(concurrency_group="compute")
     def training_step(self, experience: Experience) -> Dict[str, float]:
diff --git a/applications/Chat/coati/ray/experience_maker_holder.py b/applications/Chat/coati/ray/experience_maker_holder.py
index ebeb58137370..4616c01bdf0f 100644
--- a/applications/Chat/coati/ray/experience_maker_holder.py
+++ b/applications/Chat/coati/ray/experience_maker_holder.py
@@ -19,7 +19,7 @@
 from torch import Tensor
 from tqdm import tqdm
 
-from .utils import get_model_numel, is_rank_0, set_dist_env
+from .utils import get_model_numel, get_rank, get_world_size, is_rank_0, set_dist_env
 
 
 @ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1})
@@ -50,8 +50,8 @@ def __init__(
         if env_info:
             set_dist_env(env_info=env_info)
         self.target_trainer_list = []
-        for name in detached_trainer_name_list:
-            self.target_trainer_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"]))
+        assert len(detached_trainer_name_list) > 0
+        self._detached_trainer_name_list = detached_trainer_name_list
         self.strategy = strategy_fn()
         self.buffer_cpu_offload = buffer_cpu_offload
         self.kl_coef = kl_coef
@@ -81,8 +81,10 @@ def __init__(
 
         self._target_idx = 0
 
-        if self._debug and not self._is_fully_initialized:
-            print('[maker] Waiting for INIT')
+        if self._debug:
+            print(f'[maker{get_rank()}] will send items to {self._detached_trainer_name_list}')
+            if not self._is_fully_initialized:
+                print(f'[maker{get_rank()}] Waiting for INIT')
 
     def _get_ready(self):
         while not self._fully_initialized():
@@ -91,10 +93,11 @@ def _get_ready(self):
     def _fully_initialized(self):
         return self._is_fully_initialized
 
-    def update_target_trainer_list(self, detached_trainer_name_list):
-        self.target_trainer_list = []
-        for name in detached_trainer_name_list:
-            self.target_trainer_list.append(ray.get_actor(name))
+    def _init_target_trainer_list(self):
+        if len(self.target_trainer_list) > 0:
+            return
+        for name in self._detached_trainer_name_list:
+            self.target_trainer_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"]))
 
     # copy from ../trainer/base.py
     @ray.method(concurrency_group="compute")
@@ -106,43 +109,9 @@ def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experien
         else:
             raise ValueError(f'Unsupported input type "{type(inputs)}"')
 
-    # TODO(ver217): remove this method
-    @ray.method(concurrency_group="experience_io")
-    def _send_experience(self, experience):
-        if not self.target_auto_balance:
-            # choose the trainer in polling mannar
-            if not hasattr(self, "_target_idx"):
-                self._target_idx = 0
-            chosen_trainer = self.target_trainer_list[self._target_idx]
-            if self._debug:
-                print(f"[maker] sending exp to {chosen_trainer}")
-            chosen_trainer.buffer_append.remote(experience)
-            self._target_idx = (self._target_idx + 1) % len(self.target_trainer_list)
-        else:
-            # choose a trainer that has the least experience batch in its detached_replay_buffer
-            chosen_trainer = None
-            min_length = None
-            if self._debug:
-                print("[maker] choosing tartget trainer")
-            while chosen_trainer is None:
-                for target_trainer in self.target_trainer_list:
-                    try:
-                        temp_length = ray.get(target_trainer.buffer_get_length.remote(), timeout=0.1)
-                        if min_length is None:
-                            min_length = temp_length
-                            chosen_trainer = target_trainer
-                        else:
-                            if temp_length < min_length:
-                                min_length = temp_length
-                                chosen_trainer = target_trainer
-                    except GetTimeoutError:
-                        pass
-            if self._debug:
-                print(f"[maker] sending exp to {chosen_trainer}")
-            chosen_trainer.buffer_append.remote(experience)
-
     @ray.method(concurrency_group="experience_io")
     def _send_items(self, experience: Experience) -> None:
+        self._init_target_trainer_list()
         items = split_experience_batch(experience)
         items_per_trainer = [[] for _ in range(len(self.target_trainer_list))]
         for item in items:
diff --git a/applications/Chat/coati/ray/utils.py b/applications/Chat/coati/ray/utils.py
index 6e62ba0b4841..bc38bd012d61 100644
--- a/applications/Chat/coati/ray/utils.py
+++ b/applications/Chat/coati/ray/utils.py
@@ -18,6 +18,14 @@ def is_rank_0() -> bool:
     return not dist.is_initialized() or dist.get_rank() == 0
 
 
+def get_rank() -> int:
+    return dist.get_rank() if dist.is_initialized() else 0
+
+
+def get_world_size() -> int:
+    return dist.get_world_size() if dist.is_initialized() else 1
+
+
 def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_rank=0):
     if model == 'gpt2':
         actor = GPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
@@ -76,9 +84,9 @@ def get_strategy_from_args(strategy: str):
     elif strategy == 'colossalai_zero2':
         strategy_ = ColossalAIStrategy(stage=2, placement_policy='cuda')
     elif strategy == 'colossalai_gemini_cpu':
-        strategy = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5)
+        strategy_ = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5)
    elif strategy == 'colossalai_zero2_cpu':
-        strategy = ColossalAIStrategy(stage=2, placement_policy='cpu')
+        strategy_ = ColossalAIStrategy(stage=2, placement_policy='cpu')
     else:
         raise ValueError(f'Unsupported strategy "{strategy}"')
     return strategy_
@@ -126,3 +134,18 @@ def state_dict_to(state_dict: Dict[str, Any],
 def get_model_numel(model: nn.Module) -> int:
     numel = sum(p.numel() for p in model.parameters())
     return numel
+
+
+def get_receivers_per_sender(sender_idx: int, num_senders: int, num_receivers: int, allow_idle_sender: bool) -> list:
+    target_receivers = []
+    if num_senders <= num_receivers or allow_idle_sender:
+        # a sender will send data to one or more than one receivers
+        # a receiver only has one sender
+        for i in range(num_receivers):
+            if i % num_senders == sender_idx:
+                target_receivers.append(i)
+    else:
+        # a sender will send data to one receiver
+        # a receiver may have more than one sender
+        target_receivers.append(sender_idx % num_receivers)
+    return target_receivers
diff --git a/applications/Chat/coati/trainer/strategies/colossalai.py b/applications/Chat/coati/trainer/strategies/colossalai.py
index b809c010247b..d39092d5e7ad 100644
--- a/applications/Chat/coati/trainer/strategies/colossalai.py
+++ b/applications/Chat/coati/trainer/strategies/colossalai.py
@@ -219,4 +219,4 @@ def get_model_state_dict_shard(self, model: nn.Module, **config):
             if isinstance(module, LoraLinear):
                 module.merge_weights = True
                 module.eval()
-        yield from model.state_dict_shard(max_shard_size=1024)
+        yield from model.state_dict_shard(max_shard_size=1024, only_rank_0=False)
diff --git a/applications/Chat/coati/trainer/strategies/ddp.py b/applications/Chat/coati/trainer/strategies/ddp.py
index a0fd3fa27a58..4600c63907e8 100644
--- a/applications/Chat/coati/trainer/strategies/ddp.py
+++ b/applications/Chat/coati/trainer/strategies/ddp.py
@@ -1,13 +1,12 @@
-from typing import Optional
-
 import os
 import random
+from typing import Optional
 
 import numpy as np
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from coati.models.base import LM, Actor, RewardModel, Critic
+from coati.models.base import LM, Actor, Critic, RewardModel
 from coati.models.lora import LoraLinear
 from coati.replay_buffer import ReplayBuffer
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -30,19 +29,8 @@ def __init__(self, seed: int = 42) -> None:
         super().__init__()
 
     def setup_distributed(self) -> None:
-        try:
-            rank = int(os.environ['RANK'])
-            local_rank = int(os.environ['LOCAL_RANK'])
-            world_size = int(os.environ['WORLD_SIZE'])
-            host = os.environ['MASTER_ADDR']
-            port = int(os.environ['MASTER_PORT'])
-        except KeyError as e:
-            raise RuntimeError(
-                f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch"
-            )
-        dist.init_process_group('nccl', init_method=f'tcp://[{host}]:{port}', world_size=world_size, rank=rank)
+        self._try_init_dist(force=True)
         self.set_seed(self.seed)
-        torch.cuda.set_device(local_rank)
 
     def set_seed(self, seed: int) -> None:
         random.seed(seed)
@@ -74,21 +62,25 @@ def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False
     def _unwrap_actor(actor: Actor) -> nn.Module:
         model: DDP = Strategy._unwrap_actor(actor)
         return model.module
-    
+
     @staticmethod
     def _unwrap_critic(critic: Critic) -> nn.Module:
         model: DDP = Strategy._unwrap_critic(critic)
         return model.module
 
-    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
+    def save_model(self,
+                   model: nn.Module,
+                   path: str,
+                   only_rank0: bool = False,
+                   tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
         if only_rank0 and dist.get_rank() != 0:
             return None
-    
+
         for module in model.modules():
             if isinstance(module, LoraLinear):
                 module.merge_weights = True
                 module.eval()
-    
+
         if isinstance(model, RewardModel):
             state_dict = model.state_dict()
             if only_rank0 and dist.get_rank() != 0:
@@ -114,4 +106,3 @@ def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = Fal
 
     def setup_sampler(self, dataset) -> DistributedSampler:
         return DistributedSampler(dataset, dist.get_world_size(), dist.get_rank())
-
diff --git a/applications/Chat/coati/trainer/strategies/naive.py b/applications/Chat/coati/trainer/strategies/naive.py
index a22be1181fb8..a761786e0f21 100644
--- a/applications/Chat/coati/trainer/strategies/naive.py
+++ b/applications/Chat/coati/trainer/strategies/naive.py
@@ -1,11 +1,13 @@
+import os
 from typing import Any, Optional
 
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 import torch.optim as optim
-from coati.replay_buffer import ReplayBuffer
 from coati.models.base import LM, RewardModel
 from coati.models.lora import LoraLinear
+from coati.replay_buffer import ReplayBuffer
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
@@ -25,7 +27,7 @@ def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None:
         optimizer.step()
 
     def setup_distributed(self) -> None:
-        pass
+        self._try_init_dist(force=False)
 
     def setup_model(self, model: nn.Module) -> nn.Module:
         return model
@@ -41,12 +43,16 @@ def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False
                           pin_memory=pin_memory,
                           collate_fn=replay_buffer.collate_fn)
 
-    def save_model(self, model: nn.Module, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
+    def save_model(self,
+                   model: nn.Module,
+                   path: str,
+                   only_rank0: bool = False,
+                   tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
         for module in model.modules():
             if isinstance(module, LoraLinear):
                 module.merge_weights = True
                 module.eval()
-    
+
         if isinstance(model, RewardModel):
             state_dict = model.state_dict()
             torch.save(state_dict, path)
@@ -77,10 +83,28 @@ def get_model_state_dict_shard(self, model: nn.Module, **config):
         # TODO: implement sharding on naive strategy
         state_dict = model.state_dict()
         yield state_dict
-    
+
     def merge_lora_weight(self, model: nn.Module):
         unwrapped_model = self._unwrap_model(model)
         for module in unwrapped_model.modules():
             if isinstance(module, LoraLinear):
                 module.merge_weights = True
-                module.eval()
\ No newline at end of file
+                module.eval()
+
+    def _try_init_dist(self, force: bool = False) -> None:
+        try:
+            rank = int(os.environ['RANK'])
+            local_rank = int(os.environ['LOCAL_RANK'])
+            world_size = int(os.environ['WORLD_SIZE'])
+            host = os.environ['MASTER_ADDR']
+            port = int(os.environ['MASTER_PORT'])
+            dist.init_process_group('nccl', init_method=f'tcp://[{host}]:{port}', world_size=world_size, rank=rank)
+            torch.cuda.set_device(local_rank)
+        except KeyError as e:
+            if force:
+                raise RuntimeError(
+                    f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch"
+                )
+        except Exception as e:
+            if force:
+                raise e
diff --git a/applications/Chat/examples/ray/1mmt_dummy.py b/applications/Chat/examples/ray/1mmt_dummy.py
index d293e6940fbe..d2e820680114 100644
--- a/applications/Chat/examples/ray/1mmt_dummy.py
+++ b/applications/Chat/examples/ray/1mmt_dummy.py
@@ -11,6 +11,7 @@
 from coati.ray.utils import (
     get_actor_from_args,
     get_critic_from_args,
+    get_receivers_per_sender,
     get_reward_model_from_args,
     get_strategy_from_args,
 )
@@ -74,7 +75,7 @@ def model_fn():
         return actor, critic, reward_model, initial_model
 
     # configure Experience Maker
-    experience_holder_ref = ExperienceMakerHolder.options(name="maker1", num_gpus=1, max_concurrency=2).remote(
+    experience_holder_ref = ExperienceMakerHolder.options(name="maker0", num_gpus=1, max_concurrency=2).remote(
         detached_trainer_name_list=[f'trainer{i}' for i in range(args.num_trainers)],
         strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
         model_fn=model_fn,
@@ -102,7 +103,9 @@ def trainer_model_fn():
     # configure Trainer
     trainer_refs = [
         DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote(
-            experience_maker_holder_name_list=["maker1"],
+            experience_maker_holder_name_list=[
+                f'maker{x}' for x in get_receivers_per_sender(i, args.num_trainers, 1, allow_idle_sender=True)
+            ],
             strategy_fn=partial(get_strategy_from_args, args.trainer_strategy),
             model_fn=trainer_model_fn,
             env_info=env_info_trainer,
diff --git a/applications/Chat/examples/ray/mmmt_dummy.py b/applications/Chat/examples/ray/mmmt_dummy.py
new file mode 100644
index 000000000000..767fe37030f6
--- /dev/null
+++ b/applications/Chat/examples/ray/mmmt_dummy.py
@@ -0,0 +1,188 @@
+import argparse
+import os
+import socket
+from functools import partial
+
+import ray
+import torch
+from coati.quant import llama_load_quant, low_resource_init
+from coati.ray.detached_trainer_ppo import DetachedPPOTrainer
+from coati.ray.experience_maker_holder import ExperienceMakerHolder
+from coati.ray.utils import (
+    get_actor_from_args,
+    get_critic_from_args,
+    get_receivers_per_sender,
+    get_reward_model_from_args,
+    get_strategy_from_args,
+)
+from torch.utils.data import DataLoader
+from transformers import AutoConfig, AutoTokenizer
+from transformers.modeling_utils import no_init_weights
+
+
+def get_free_port():
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.bind(('', 0))
+        return s.getsockname()[1]
+
+
+def get_local_ip():
+    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
+        s.connect(('8.8.8.8', 80))
+        return s.getsockname()[0]
+
+
+def main(args):
+    master_addr = str(get_local_ip())
+    # trainer_env_info
+    trainer_port = str(get_free_port())
+    env_info_trainers = [{
+        'local_rank': '0',
+        'rank': str(rank),
+        'world_size': str(args.num_trainers),
+        'master_port': trainer_port,
+        'master_addr': master_addr
+    } for rank in range(args.num_trainers)]
+
+    # maker_env_info
+    maker_port = str(get_free_port())
+    env_info_makers = [{
+        'local_rank': '0',
+        'rank': str(rank),
+        'world_size': str(args.num_makers),
+        'master_port': maker_port,
+        'master_addr': master_addr
+    } for rank in range(args.num_makers)]
+
+    # configure tokenizer
+    tokenizer = AutoTokenizer.from_pretrained(args.pretrain)
+    tokenizer.pad_token = tokenizer.eos_token
+
+    def model_fn():
+        actor_cfg = AutoConfig.from_pretrained(args.pretrain)
+        critic_cfg = AutoConfig.from_pretrained(args.critic_pretrain)
+        actor = get_actor_from_args(args.model, config=actor_cfg).half().cuda()
+        critic = get_critic_from_args(args.critic_model, config=critic_cfg).half().cuda()
+        reward_model = get_reward_model_from_args(args.critic_model, config=critic_cfg).half().cuda()
+        if args.initial_model_quant_ckpt is not None and args.model == 'llama':
+            # quantize initial model
+            with low_resource_init(), no_init_weights():
+                initial_model = get_actor_from_args(args.model, config=actor_cfg)
+            initial_model.model = llama_load_quant(initial_model.model, args.initial_model_quant_ckpt, args.quant_bits,
+                                                   args.quant_group_size).cuda()
+        else:
+            initial_model = get_actor_from_args(args.model, config=actor_cfg).half().cuda()
+        return actor, critic, reward_model, initial_model
+
+    # configure Experience Maker
+    experience_holder_refs = [
+        ExperienceMakerHolder.options(name=f"maker{i}", num_gpus=1, max_concurrency=2).remote(
+            detached_trainer_name_list=[
+                f'trainer{x}'
+                for x in get_receivers_per_sender(i, args.num_makers, args.num_trainers, allow_idle_sender=False)
+            ],
+            strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
+            model_fn=model_fn,
+            env_info=env_info_maker,
+            kl_coef=0.1,
+            debug=args.debug,
+            # sync_models_from_trainers=True,
+            # generation kwargs:
+            max_length=512,
+            do_sample=True,
+            temperature=1.0,
+            top_k=50,
+            pad_token_id=tokenizer.pad_token_id,
+            eos_token_id=tokenizer.eos_token_id,
+            eval_performance=True,
+            use_cache=True,
+        )
+        for i, env_info_maker in enumerate(env_info_makers)
+    ]
+
+    def trainer_model_fn():
+        actor = get_actor_from_args(args.model, config=AutoConfig.from_pretrained(args.pretrain)).half().cuda()
+        critic = get_critic_from_args(args.critic_model,
+                                      config=AutoConfig.from_pretrained(args.critic_pretrain)).half().cuda()
+        return actor, critic
+
+    # configure Trainer
+    trainer_refs = [
+        DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote(
+            experience_maker_holder_name_list=[
+                f"maker{x}"
+                for x in get_receivers_per_sender(i, args.num_trainers, args.num_makers, allow_idle_sender=True)
+            ],
+            strategy_fn=partial(get_strategy_from_args, args.trainer_strategy),
+            model_fn=trainer_model_fn,
+            env_info=env_info_trainer,
+            train_batch_size=args.train_batch_size,
+            buffer_limit=16,
+            eval_performance=True,
+            debug=args.debug,
+        )
+        for i, env_info_trainer in enumerate(env_info_trainers)
+    ]
+
+    dataset_size = args.experience_batch_size * 4
+
+    def data_gen_fn():
+        input_ids = torch.randint(tokenizer.vocab_size, (256,), device=torch.cuda.current_device())
+        attn_mask = torch.ones_like(input_ids)
+        return {'input_ids': input_ids, 'attention_mask': attn_mask}
+
+    def build_dataloader(size):
+        dataset = [data_gen_fn() for _ in range(size)]
+        dataloader = DataLoader(dataset, batch_size=args.experience_batch_size)
+        return dataloader
+
+    # uncomment this function if sync_models_from_trainers is True
+    # ray.get([
+    #     trainer_ref.sync_models_to_remote_makers.remote()
+    #     for trainer_ref in trainer_refs
+    # ])
+
+    wait_tasks = []
+
+    for experience_holder_ref in experience_holder_refs:
+        wait_tasks.append(
+            experience_holder_ref.workingloop.remote(partial(build_dataloader, dataset_size),
+                                                     num_steps=args.experience_steps))
+
+    total_steps = args.experience_batch_size * args.experience_steps * \
+        args.num_makers // (args.num_trainers * args.train_batch_size)
+    for trainer_ref in trainer_refs:
+        wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs))
+
+    ray.get(wait_tasks)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--num_makers', type=int, default=1)
+    parser.add_argument('--num_trainers', type=int, default=1)
+    parser.add_argument('--trainer_strategy',
+                        choices=[
+                            'naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu',
+                            'colossalai_zero2_cpu'
+                        ],
+                        default='naive')
+    parser.add_argument('--maker_strategy', choices=['naive'], default='naive')
+    parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
+    parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
+    parser.add_argument('--pretrain', type=str, default=None)
+    parser.add_argument('--critic_pretrain', type=str, default=None)
+    parser.add_argument('--experience_steps', type=int, default=4)
+    parser.add_argument('--experience_batch_size', type=int, default=8)
+    parser.add_argument('--train_epochs', type=int, default=1)
+    parser.add_argument('--update_steps', type=int, default=2)
+    parser.add_argument('--train_batch_size', type=int, default=8)
+    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
+
+    parser.add_argument('--initial_model_quant_ckpt', type=str, default=None)
+    parser.add_argument('--quant_bits', type=int, default=4)
+    parser.add_argument('--quant_group_size', type=int, default=128)
+    parser.add_argument('--debug', action='store_true')
+    args = parser.parse_args()
+
+    ray.init(namespace=os.environ["RAY_NAMESPACE"])
+    main(args)
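
Note (not part of the patch): a minimal standalone sketch of the sender/receiver pairing computed by the new get_receivers_per_sender helper, which mmmt_dummy.py uses both to pick the trainers each maker sends experience to and to pick the makers each trainer pushes state dict shards to. The function body is copied verbatim from coati/ray/utils.py in the diff above; the counts used below (2 makers, 4 trainers) are hypothetical and only for illustration.

def get_receivers_per_sender(sender_idx: int, num_senders: int, num_receivers: int, allow_idle_sender: bool) -> list:
    # copied from the version added to applications/Chat/coati/ray/utils.py above
    target_receivers = []
    if num_senders <= num_receivers or allow_idle_sender:
        # each receiver is fed by exactly one sender; a sender may feed several receivers
        for i in range(num_receivers):
            if i % num_senders == sender_idx:
                target_receivers.append(i)
    else:
        # each sender feeds exactly one receiver; a receiver may be fed by several senders
        target_receivers.append(sender_idx % num_receivers)
    return target_receivers


# Hypothetical sizes for illustration: 2 makers and 4 trainers.
# Makers (senders of experience) must never idle, so allow_idle_sender=False:
for maker_idx in range(2):
    print(maker_idx, get_receivers_per_sender(maker_idx, 2, 4, allow_idle_sender=False))
# -> 0 [0, 2]
# -> 1 [1, 3]

# Trainers (senders of state dict shards) may idle, so allow_idle_sender=True:
for trainer_idx in range(4):
    print(trainer_idx, get_receivers_per_sender(trainer_idx, 4, 2, allow_idle_sender=True))
# -> 0 [0]
# -> 1 [1]
# -> 2 []
# -> 3 []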