From 88eddf86674302ad07d8daf906e4a1874eec3d22 Mon Sep 17 00:00:00 2001 From: Daniel Bogdoll Date: Fri, 22 Nov 2024 11:22:03 -0500 Subject: [PATCH 1/4] Option to set 'non_blocking' for to(device) operation for performance improvements. Defaults to 'false', thus no behavioral changes. --- src/transformers/tokenization_utils_base.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 03df02d21ff3..094bfef71b2e 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -798,12 +798,13 @@ def as_tensor(value, dtype=None): return self - def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding": + def to(self, device: Union[str, "torch.device"], non_blocking: bool = False) -> "BatchEncoding": """ - Send all values to device by calling `v.to(device)` (PyTorch only). + Send all values to device by calling `v.to(device, non_blocking=non_blocking)` (PyTorch only). Args: device (`str` or `torch.device`): The device to put the tensors on. + non_blocking (`bool`): Whether to perform the copy asynchronously. Returns: [`BatchEncoding`]: The same instance after modification. @@ -815,7 +816,10 @@ def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding": # Otherwise it passes the casts down and casts the LongTensor containing the token idxs # into a HalfTensor if isinstance(device, str) or is_torch_device(device) or isinstance(device, int): - self.data = {k: v.to(device=device) if isinstance(v, torch.Tensor) else v for k, v in self.data.items()} + self.data = { + k: v.to(device=device, non_blocking=non_blocking) if isinstance(v, torch.Tensor) else v + for k, v in self.data.items() + } else: logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.") return self From 08f5d4b30119c5bc2a06da3ee5604ad492990ba6 Mon Sep 17 00:00:00 2001 From: Daniel Bogdoll Date: Fri, 22 Nov 2024 13:50:00 -0500 Subject: [PATCH 2/4] Enabling non_blocking in to() operation of BatchFeature. --- src/transformers/feature_extraction_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py index f3cde8180c1b..6c9e4bc7998d 100644 --- a/src/transformers/feature_extraction_utils.py +++ b/src/transformers/feature_extraction_utils.py @@ -222,6 +222,7 @@ def to(self, *args, **kwargs) -> "BatchFeature": new_data = {} device = kwargs.get("device") + non_blocking = kwargs.get("non_blocking", False) # Check if the args are a device or a dtype if device is None and len(args) > 0: # device should be always the first argument @@ -241,7 +242,7 @@ def to(self, *args, **kwargs) -> "BatchFeature": # cast and send to device new_data[k] = v.to(*args, **kwargs) elif isinstance(v, torch.Tensor) and device is not None: - new_data[k] = v.to(device=device) + new_data[k] = v.to(device=device, non_blocking=non_blocking) else: new_data[k] = v self.data = new_data From 9f465d7ca836c3836b02cb1642033427d56bf6d5 Mon Sep 17 00:00:00 2001 From: Daniel Bogdoll Date: Fri, 22 Nov 2024 14:00:29 -0500 Subject: [PATCH 3/4] Improved docstring on utilization of non_blocking --- src/transformers/feature_extraction_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py index 6c9e4bc7998d..6e8007edbc0b 100644 --- a/src/transformers/feature_extraction_utils.py +++ b/src/transformers/feature_extraction_utils.py @@ -213,6 +213,7 @@ def to(self, *args, **kwargs) -> "BatchFeature": Will be passed to the `to(...)` function of the tensors. kwargs (`Dict`, *optional*): Will be passed to the `to(...)` function of the tensors. + To enable asynchronous data transfer, set the `non_blocking` flag in `kwargs` (defaults to `False`). Returns: [`BatchFeature`]: The same instance after modification. From 4e8e786d120286f88006c852c35611cfde5bcc1f Mon Sep 17 00:00:00 2001 From: Daniel Bogdoll Date: Mon, 25 Nov 2024 09:37:56 -0500 Subject: [PATCH 4/4] Force non_blocking as keyword argument Co-authored-by: Pavel Iakubovskii --- src/transformers/tokenization_utils_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 094bfef71b2e..709845ac84e8 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -798,7 +798,7 @@ def as_tensor(value, dtype=None): return self - def to(self, device: Union[str, "torch.device"], non_blocking: bool = False) -> "BatchEncoding": + def to(self, device: Union[str, "torch.device"], *, non_blocking: bool = False) -> "BatchEncoding": """ Send all values to device by calling `v.to(device, non_blocking=non_blocking)` (PyTorch only).