diff --git a/flame/data.py b/flame/data.py index c9a5bc6..8feb866 100644 --- a/flame/data.py +++ b/flame/data.py @@ -557,8 +557,9 @@ def build_dataset( color = utils.Color min_num_shards = dp_degree * num_workers if dp_degree else None if len(dataset.split(',')) == 1: + dataset_path = dataset dataset = load_dataset( - path=dataset, + path=dataset_path, name=dataset_name, split=dataset_split, data_dir=data_dir, @@ -585,7 +586,7 @@ def build_dataset( f"{color.reset}" ) dataset = load_dataset( - path=dataset, + path=dataset_path, name=dataset_name, split=dataset_split, data_dir=data_dir,