1 change: 1 addition & 0 deletions .github/workflows/run_chatgpt_examples.yml
@@ -61,5 +61,6 @@ jobs:
       PRETRAINED_MODEL_PATH: ./models
       SFT_DATASET: ./sft_data
       PROMPT_DATASET: ./prompt_data
+      PROMPT_RLVR_DATASET: ./prompt_data
       PREFERENCE_DATASET: ./preference_data
       KTO_DATASET: ./kto_data
1 change: 1 addition & 0 deletions applications/ColossalChat/.gitignore
@@ -158,6 +158,7 @@ temp/
 applications/ColossalChat/logs
 applications/ColossalChat/models
 applications/ColossalChat/sft_data
+applications/ColossalChat/kto_data
 applications/ColossalChat/prompt_data
 applications/ColossalChat/preference_data
 applications/ColossalChat/temp
2 changes: 1 addition & 1 deletion applications/ColossalChat/coati/dataset/conversation.py
@@ -141,7 +141,7 @@ def setup_conversation_template(
         pass
     except ValueError as e:
         raise ValueError(e)
-    if not dist.is_initialized() or dist.get_rank() == 0:
+    if save_path is not None and (not dist.is_initialized() or dist.get_rank() == 0):
         os.makedirs(os.path.dirname(save_path), exist_ok=True)
         with open(save_path, "w", encoding="utf8") as f:
             logger.info(f"Successfully generated a conversation template config, saved to {save_path}.")
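With this change, callers can omit save_path entirely: the template is still configured in memory, and the file write (which previously ran unconditionally and would crash in os.path.dirname(None)) only happens when a destination is given, and only once per job. A minimal sketch of the guard in isolation (maybe_save and config_text are illustrative names, not repo API):

import os

import torch.distributed as dist


def maybe_save(config_text: str, save_path: str = None) -> None:
    # Write only when a destination was given, and only once per job:
    # on rank 0, or unconditionally when torch.distributed is not initialized.
    if save_path is not None and (not dist.is_initialized() or dist.get_rank() == 0):
        # `or "."` guards bare filenames; an extra safety net beyond the repo code.
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        with open(save_path, "w", encoding="utf8") as f:
            f.write(config_text)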
3 changes: 2 additions & 1 deletion applications/ColossalChat/coati/dataset/loader.py
@@ -155,13 +155,14 @@ def __call__(self, instances: Sequence[Dict[str, List[int]]]) -> Dict[str, torch
             `input_ids`: `torch.Tensor` of shape (bsz, max_len);
             `attention_mask`: `torch.BoolTensor` of shape (bsz, max_len);
         """
+        gt_answer = [ins.get("gt_answer", None) for ins in instances]
         instances = [{"input_ids": ins["input_ids"], "labels": ins["input_ids"]} for ins in instances]
         ret = super().__call__(instances=instances)
         input_ids = F.pad(
             ret["input_ids"], (self.max_length - ret["input_ids"].size(1), 0), value=self.tokenizer.pad_token_id
         )
         attention_mask = F.pad(ret["attention_mask"], (self.max_length - ret["attention_mask"].size(1), 0), value=False)
-        return {"input_ids": input_ids, "attention_mask": attention_mask}
+        return {"input_ids": input_ids, "attention_mask": attention_mask, "gt_answer": gt_answer}


 @dataclass
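The collator change threads ground-truth answers through batching untouched: gt_answer is collected before the instances are rewritten for the parent collator, and it is returned as a plain Python list rather than a tensor, since answers are per-sample metadata. A self-contained sketch of the same pass-through under toy shapes (collate_prompts, the pad id, and max_length are illustrative stand-ins, not the repo's real collator):

import torch
import torch.nn.functional as F


def collate_prompts(instances, max_length=8, pad_token_id=0):
    # Grab gt_answer first; missing keys become None, mirroring ins.get("gt_answer", None).
    gt_answer = [ins.get("gt_answer", None) for ins in instances]
    longest = max(len(ins["input_ids"]) for ins in instances)
    # Left-pad each sample to the longest sample in the batch.
    input_ids = torch.stack(
        [
            F.pad(torch.tensor(ins["input_ids"]), (longest - len(ins["input_ids"]), 0), value=pad_token_id)
            for ins in instances
        ]
    )
    attention_mask = input_ids.ne(pad_token_id)
    # Then left-pad the whole batch out to max_length, as the diff does.
    input_ids = F.pad(input_ids, (max_length - input_ids.size(1), 0), value=pad_token_id)
    attention_mask = F.pad(attention_mask, (max_length - attention_mask.size(1), 0), value=False)
    return {"input_ids": input_ids, "attention_mask": attention_mask, "gt_answer": gt_answer}


batch = collate_prompts([{"input_ids": [5, 6, 7], "gt_answer": "42"}, {"input_ids": [8, 9]}])
print(batch["input_ids"].shape)  # torch.Size([2, 8])
print(batch["gt_answer"])        # ['42', None]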
23 changes: 15 additions & 8 deletions applications/ColossalChat/coati/dataset/tokenization_utils.py
@@ -147,7 +147,6 @@ def tokenize_prompt(
         ignore_index: the ignore index when calculating loss during training
         max_length: the maximum context length
     """
-
     messages = data_point["messages"]
     template = deepcopy(conversation_template)
     template.messages = []
@@ -167,7 +166,6 @@
     if len(template.messages) % 2 != 1:
         # exclude the answer if provided. keep only the prompt
         template.messages = template.messages[:-1]
-
     # Prepare data
     prompt = template.get_prompt(length=len(template.messages), add_generation_prompt=True)
     tokenized = tokenizer([prompt], add_special_tokens=False)["input_ids"][0]
@@ -185,12 +183,21 @@
     )

     # `inputs_decode` can be used to check whether the tokenization method is true.
-    return dict(
-        input_ids=tokenized,
-        inputs_decode=prompt,
-        seq_length=len(tokenized),
-        seq_category=data_point["category"] if "category" in data_point else "None",
-    )
+    if "gt_answer" in data_point:
+        return dict(
+            input_ids=tokenized,
+            inputs_decode=prompt,
+            seq_length=len(tokenized),
+            seq_category=data_point["category"] if "category" in data_point else "None",
+            gt_answer=data_point["gt_answer"],
+        )
+    else:
+        return dict(
+            input_ids=tokenized,
+            inputs_decode=prompt,
+            seq_length=len(tokenized),
+            seq_category=data_point["category"] if "category" in data_point else "None",
+        )


 def apply_rlhf_data_format(template: Conversation, tokenizer: Any):
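tokenize_prompt now forwards an optional gt_answer field from the raw data point; this is the value the prompt collator above carries into each batch, presumably so a verifiable-reward (RLVR) loop can score generations against it. A hypothetical record shape, with illustrative field values and a message schema that should be checked against the repo's conversation template:

# Hypothetical RLVR-style prompt record: "messages" is required,
# "category" and "gt_answer" are optional; only "gt_answer" is new here.
data_point = {
    "messages": [
        {"from": "user", "content": "What is 6 * 7?"},
    ],
    "category": "math",
    "gt_answer": "42",
}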
10 changes: 9 additions & 1 deletion applications/ColossalChat/coati/experience_buffer/naive.py
@@ -27,6 +27,8 @@ def __init__(self, sample_batch_size: int, limit: int = 0, cpu_offload: bool = T
         self.target_device = torch.device(f"cuda:{torch.cuda.current_device()}")
         # TODO(ver217): add prefetch
         self.items: List[BufferItem] = []
+        self.rng_sequence = []
+        self.ptr = 0

     @torch.no_grad()
     def append(self, experience: Experience) -> None:
@@ -40,6 +42,9 @@ def append(self, experience: Experience) -> None:
         if samples_to_remove > 0:
             logger.warning(f"Experience buffer is full. Removing {samples_to_remove} samples.")
             self.items = self.items[samples_to_remove:]
+        self.rng_sequence = [i for i in range(len(self.items))]
+        random.shuffle(self.rng_sequence)
+        self.ptr = 0

     def clear(self) -> None:
         self.items.clear()
@@ -52,7 +57,10 @@ def sample(self) -> Experience:
         Returns:
             A batch of sampled experiences.
         """
-        items = random.sample(self.items, self.sample_batch_size)
+        items = []
+        for _ in range(self.sample_batch_size):
+            self.ptr = (self.ptr + 1) % len(self.items)
+            items.append(self.items[self.rng_sequence[self.ptr]])
         experience = make_experience_batch(items)
         if self.cpu_offload:
             experience.to_device(self.target_device)
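Taken together, the buffer changes replace independent random.sample draws with a single shuffled visiting order: append reshuffles rng_sequence over the current items, and sample walks that order with a pointer that wraps, so every item is visited about equally often across batches instead of by chance. A self-contained sketch of the same scheme (class and variable names are illustrative):

import random


class RoundRobinSampler:
    # Mirrors the buffer's rng_sequence/ptr scheme on a plain list.
    def __init__(self, items, batch_size):
        self.items = items
        self.batch_size = batch_size
        self.order = list(range(len(items)))
        random.shuffle(self.order)  # fixed shuffled visiting order
        self.ptr = 0

    def sample(self):
        batch = []
        for _ in range(self.batch_size):
            # Advance first, wrapping at the end, exactly as the diff does
            # (so order[0] is only reached after the first wrap).
            self.ptr = (self.ptr + 1) % len(self.items)
            batch.append(self.items[self.order[self.ptr]])
        return batch


sampler = RoundRobinSampler(list("abcdef"), batch_size=4)
print(sampler.sample())  # e.g. ['c', 'f', 'a', 'e'], four distinct items
print(sampler.sample())  # continues from the pointer, wrapping around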