class ToTensorFn(Protocol):
    """Structural type for the ``to_tensor_fn`` conversion callback.

    Any callable whose signature is compatible with :meth:`__call__`
    satisfies this protocol; it does not need to subclass it.
    """

    def __call__(
        self,
        batch: Union[pa.RecordBatch, Dict[str, Any]],
        *,
        hf_converter: Optional[dict] = None,
        use_blob_api: bool = False,
        **kwargs: Any,
    ) -> Union[dict[str, torch.Tensor], torch.Tensor]:
        """Convert one batch into torch tensors.

        Parameters
        ----------
        batch : pa.RecordBatch or dict
            The batch to convert, passed as the first argument.
        hf_converter : dict, optional
            Keyword-only; defaults to ``None``.
            NOTE(review): presumably carries HuggingFace conversion
            settings -- confirm against the default converter.
        use_blob_api : bool, optional
            Keyword-only; defaults to ``False``.
            NOTE(review): presumably toggles blob-based column reads --
            confirm against the default converter.
        **kwargs
            Any further keyword arguments the implementation accepts.

        Returns
        -------
        dict of str to torch.Tensor, or torch.Tensor
            Either a mapping of names to tensors or a single tensor.
        """
        ...