
[BUG] Accelerate backend with dp fails to load cache #1102

@Aceticia

Description

Describe the bug

In SampleCache, get_cache_path doesn't take the process rank into account and doesn't apply any lock when saving or loading parquet files. As a result, the parquet caches can be corrupted when multiple ranks are launched with the accelerate backend and all of them call cache_samples.

Calls to get_samples_from_cache then fail, because the parquet load in this block raises (cache_management.py, line 283):

        for task_id in task_ids:
            if task_id.sampling_method != sampling_method:
                continue
            cache_file = self.get_cache_path(task_id)
            try:
                dataset = load_dataset("parquet", data_files=str(cache_file), split="train")
                dataset_df = dataset.to_pandas().set_index("sample_id")
                task_datasets[task_id] = dataset_df
            except Exception as e:
                logger.warning(f"Error loading prediction cache for {str(task_id)}: {e}")
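One way to keep readers from ever observing a half-written parquet file is an atomic write: write to a temporary file in the same directory, then rename it into place with os.replace. This is only a sketch of the idea, not the project's actual code; `atomic_write` and the callback shape are hypothetical names.

```python
import os
import tempfile

def atomic_write(path, write_fn):
    """Write a file atomically: write_fn writes to a temporary path, then
    os.replace swaps it in. Concurrent readers see either the old file or
    the new one, never a partial write."""
    # The temp file must live in the same directory (same filesystem),
    # otherwise os.replace is not atomic.
    dir_name = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=dir_name, suffix=".tmp")
    os.close(fd)
    try:
        write_fn(tmp_path)
        os.replace(tmp_path, path)  # atomic rename on POSIX
    except BaseException:
        # Clean up the temp file if the write or rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

# With pandas this would look something like:
# atomic_write(cache_file, lambda p: dataset_df.to_parquet(p))
```

Atomic replacement alone doesn't serialize concurrent writers (the last rank to finish wins), so it would still need to be combined with a rank check or a file lock, but it does prevent the corrupted-parquet failure seen above.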

To Reproduce

I have a fairly complicated setup, so I can't provide a simple repro directly here. I can try to produce one if it's absolutely necessary.

Expected behavior

Only rank 0 should write the sample cache, and caching should work correctly when multiple ranks are launched.
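A minimal sketch of that expected behavior, assuming the launcher sets the usual torch distributed RANK environment variable (which accelerate does); the helper names here are hypothetical, not the project's API:

```python
import os

def is_main_process():
    """Rank-0 check based on the RANK env var set by distributed
    launchers; single-process runs leave it unset and count as rank 0."""
    return int(os.environ.get("RANK", "0")) == 0

def maybe_cache_samples(save_fn):
    """Hypothetical guard: only rank 0 writes the parquet cache;
    other ranks skip the write and rely on rank 0's copy."""
    if is_main_process():
        save_fn()
```

With the accelerate API itself, the equivalent would presumably be guarding the save with `accelerator.is_main_process` and calling `accelerator.wait_for_everyone()` before any rank reads the cache back, so non-zero ranks don't race ahead of the write.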

Version

I'm using this branch: #1083, but the caching code there is identical to main.
