Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions applications/Chat/coati/dataset/sft_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,29 +53,25 @@ class SFTDataset(Dataset):

def __init__(self, dataset, tokenizer: Callable, max_length: int = 512) -> None:
super().__init__()
# self.prompts = []
self.input_ids = []

for data in tqdm(dataset, disable=not is_rank_0()):
prompt = data['prompt'] + data['completion'] + "<|endoftext|>"
prompt = data['prompt'] + data['completion'] + tokenizer.eos_token
prompt_token = tokenizer(prompt,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")

# self.prompts.append(prompt_token)s
self.input_ids.append(prompt_token)
self.labels = copy.deepcopy(self.input_ids)
self.input_ids.append(prompt_token['input_ids'][0])
self.labels = copy.deepcopy(self.input_ids)

def __len__(self):
length = len(self.prompts)
length = len(self.input_ids)
return length

def __getitem__(self, idx):
# dict(input_ids=self.input_ids[i], labels=self.labels[i])
return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
# return dict(self.prompts[idx], self.prompts[idx])


def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, max_length: int) -> Dict:
Expand Down
15 changes: 8 additions & 7 deletions applications/Chat/coati/trainer/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def fit(self, logger, log_interval=10):
loss = outputs.loss
prompt_logits = outputs.logits

if loss >= 2.5:
if loss >= 2.5 and is_rank_0():
logger.warning(f"batch_id:{batch_id}, abnormal loss: {loss}")

loss = loss / self.accimulation_steps
Expand All @@ -110,12 +110,13 @@ def fit(self, logger, log_interval=10):
self.strategy.optimizer_step(self.optimizer)
self.optimizer.zero_grad()
self.scheduler.step()
wandb.log({
"loss": total_loss / self.accimulation_steps,
"lr": self.scheduler.get_last_lr()[0],
"epoch": epoch,
"batch_id": batch_id
})
if is_rank_0():
wandb.log({
"loss": total_loss / self.accimulation_steps,
"lr": self.scheduler.get_last_lr()[0],
"epoch": epoch,
"batch_id": batch_id
})
total_loss = 0
step_bar.update()

Expand Down
2 changes: 1 addition & 1 deletion applications/Chat/examples/train_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def train(args):
max_datasets_size=args.max_datasets_size,
max_length=max_len)
eval_dataset = None
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)

if dist.is_initialized() and dist.get_world_size() > 1:
train_sampler = DistributedSampler(train_dataset,
Expand Down