5 changes: 3 additions & 2 deletions flame/data.py
@@ -557,8 +557,9 @@ def build_dataset(
     color = utils.Color
     min_num_shards = dp_degree * num_workers if dp_degree else None
     if len(dataset.split(',')) == 1:
+        dataset_path = dataset

medium

Introducing `dataset_path` correctly fixes the bug. On a related note for maintainability, the `load_dataset` call on the next line is duplicated in the resharding logic (lines 588-597). Consider refactoring these calls to reduce duplication, for example by using a shared dictionary of arguments.
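One way the shared-dictionary idea could look is sketched below. This is an illustration only, not the code in `flame/data.py`: `make_load_kwargs`, `load_then_maybe_reload`, `needs_reload`, and `reload_overrides` are hypothetical names, the loader is passed in as `load_fn` so the sketch runs without the `datasets` package, and the condition that triggers the second load is left to the caller because the diff does not show it.

```python
from typing import Any, Callable, Dict, Optional


def make_load_kwargs(
    dataset_path: str,
    dataset_name: Optional[str] = None,
    dataset_split: Optional[str] = None,
    data_dir: Optional[str] = None,
) -> Dict[str, Any]:
    """Collect the arguments shared by both load calls in one place."""
    return {
        "path": dataset_path,
        "name": dataset_name,
        "split": dataset_split,
        "data_dir": data_dir,
    }


def load_then_maybe_reload(
    load_fn: Callable[..., Any],
    load_kwargs: Dict[str, Any],
    needs_reload: Callable[[Any], bool],
    reload_overrides: Optional[Dict[str, Any]] = None,
) -> Any:
    """Load once; if the caller's check asks for it, load again.

    Both calls start from the same dictionary, so `path` cannot drift
    between them. Only keys in `reload_overrides` (hypothetical) differ
    on the second call.
    """
    dataset = load_fn(**load_kwargs)
    if needs_reload(dataset):
        dataset = load_fn(**{**load_kwargs, **(reload_overrides or {})})
    return dataset
```

In `build_dataset`, `load_fn` would presumably be `load_dataset` and `needs_reload` would wrap the existing shard-count check; whatever arguments differ in the resharding call would go into `reload_overrides`.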

         dataset = load_dataset(
-            path=dataset,
+            path=dataset_path,
             name=dataset_name,
             split=dataset_split,
             data_dir=data_dir,
@@ -585,7 +586,7 @@ def build_dataset(
 f"{color.reset}"
 )
 dataset = load_dataset(
-    path=dataset,
+    path=dataset_path,
     name=dataset_name,
     split=dataset_split,
     data_dir=data_dir,