From 9ef1606dedb0525d86c1c544fb793294c7cb0bf1 Mon Sep 17 00:00:00 2001 From: AndreaBozzo Date: Wed, 24 Dec 2025 20:33:32 +0100 Subject: [PATCH 1/3] fix(python): correct type hint for to_tensor_fn parameter Fixes #3129 The type hint for `to_tensor_fn` in `__init__` was declaring a single-argument callable, but the function is actually called with additional keyword arguments (`hf_converter` and `use_blob_api`). Changes: - Update type hint to use `Callable[..., ...]` to allow arbitrary args - Enhance docstring to document the expected function signature --- python/python/lance/torch/data.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/python/lance/torch/data.py b/python/python/lance/torch/data.py index fd2be0da161..4b314c47c68 100644 --- a/python/python/lance/torch/data.py +++ b/python/python/lance/torch/data.py @@ -193,7 +193,9 @@ def __init__( shard_granularity: Optional[Literal["fragment", "batch"]] = None, batch_readahead: int = 16, to_tensor_fn: Optional[ - Callable[[pa.RecordBatch], Union[dict[str, torch.Tensor], torch.Tensor]] + Callable[ + ..., Union[dict[str, torch.Tensor], torch.Tensor] + ] ] = _to_tensor, sampler: Optional[Sampler] = None, auto_detect_rank: bool = True, @@ -236,6 +238,9 @@ def __init__( A function that samples the dataset. to_tensor_fn : callable, optional A function that converts a pyarrow RecordBatch to torch.Tensor. + Should accept a batch (RecordBatch or Dict[str, pa.Array]) as the first + argument, plus optional keyword arguments ``hf_converter`` and + ``use_blob_api``. auto_detect_rank: bool = True, optional If set true, the rank and world_size will be detected automatically. 
""" From 20887fd05c424dc6cf6b0103d5f77c51bca1edf1 Mon Sep 17 00:00:00 2001 From: AndreaBozzo Date: Fri, 2 Jan 2026 09:14:04 +0100 Subject: [PATCH 2/3] fix(python): update type hint for to_tensor_fn parameter to be less generic, ruff --- python/python/lance/torch/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/python/lance/torch/data.py b/python/python/lance/torch/data.py index 4b314c47c68..626b9b14f9a 100644 --- a/python/python/lance/torch/data.py +++ b/python/python/lance/torch/data.py @@ -194,7 +194,7 @@ def __init__( batch_readahead: int = 16, to_tensor_fn: Optional[ Callable[ - ..., Union[dict[str, torch.Tensor], torch.Tensor] + [pa.RecordBatch, ...], Union[dict[str, torch.Tensor], torch.Tensor] ] ] = _to_tensor, sampler: Optional[Sampler] = None, From f21e7785382fa050a06950f6e951c0badf64e446 Mon Sep 17 00:00:00 2001 From: AndreaBozzo Date: Mon, 12 Jan 2026 21:06:01 +0100 Subject: [PATCH 3/3] update to_tensor_fn type hint to use Protocol for better type safety --- python/python/lance/torch/data.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/python/python/lance/torch/data.py b/python/python/lance/torch/data.py index 626b9b14f9a..d5adcbbfe19 100644 --- a/python/python/lance/torch/data.py +++ b/python/python/lance/torch/data.py @@ -11,7 +11,16 @@ import math import warnings from pathlib import Path -from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Union +from typing import ( + Any, + Dict, + Iterable, + List, + Literal, + Optional, + Protocol, + Union, +) import pyarrow as pa @@ -32,6 +41,17 @@ __all__ = ["LanceDataset", "SafeLanceDataset", "get_safe_loader"] +class ToTensorFn(Protocol): + def __call__( + self, + batch: Union[pa.RecordBatch, Dict[str, Any]], + *, + hf_converter: Optional[dict] = None, + use_blob_api: bool = False, + **kwargs: Any, + ) -> Union[dict[str, torch.Tensor], torch.Tensor]: ... 
+ + # Convert an Arrow FSL array into a 2D torch tensor def _fsl_to_tensor(arr: pa.FixedSizeListArray, dimension: int) -> torch.Tensor: # Note: FixedSizeListArray.values does not take offset/len into account and @@ -192,11 +212,7 @@ def __init__( world_size: Optional[int] = None, shard_granularity: Optional[Literal["fragment", "batch"]] = None, batch_readahead: int = 16, - to_tensor_fn: Optional[ - Callable[ - [pa.RecordBatch, ...], Union[dict[str, torch.Tensor], torch.Tensor] - ] - ] = _to_tensor, + to_tensor_fn: Optional[ToTensorFn] = _to_tensor, sampler: Optional[Sampler] = None, auto_detect_rank: bool = True, **kwargs,