Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions src/transformers/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3529,9 +3529,8 @@ def _prepare_debugging_info(test_info, info):
"""Combine the information about the test and the call information to a patched function/method within it."""

info = f"{test_info}\n\n{info}"
p = os.path.join(os.environ.get("_PATCHED_TESTING_METHODS_OUTPUT_DIR", ""), "captured_info.txt")
# TODO (ydshieh): This is not safe when we use pytest-xdist with more than 1 worker.
with open(p, "a") as fp:
output_path = _get_patched_testing_methods_output_file()
with output_path.open("a") as fp:
fp.write(f"{info}\n\n{'=' * 120}\n\n")

return info
Expand Down Expand Up @@ -3754,15 +3753,35 @@ def _parse_call_info(func, args, kwargs, call_argument_expressions, target_args)
return info


def _get_patched_testing_methods_output_file() -> Path:
"""Return the output file used by patched assertion methods.

Under `pytest-xdist`, workers run in separate processes but can share the same output directory. Using a worker-
specific file avoids concurrent writes and resets clobbering each other's captured debugging information.
"""

output_dir = Path(os.environ.get("_PATCHED_TESTING_METHODS_OUTPUT_DIR", ""))
worker_id = os.environ.get("PYTEST_XDIST_WORKER")
filename = f"captured_info_{worker_id}.txt" if worker_id else "captured_info.txt"
return output_dir / filename


def _reset_patched_testing_methods_output_file() -> Path:
"""Clear the output file used by patched assertion methods and return its path."""

output_path = _get_patched_testing_methods_output_file()
output_path.unlink(missing_ok=True)
return output_path


def patch_testing_methods_to_collect_info():
"""
Patch some methods (`torch.testing.assert_close`, `unittest.case.TestCase.assertEqual`, etc).

This will allow us to collect the call information, e.g. the argument names and values, also the literal expressions
passed as the arguments.
"""
p = os.path.join(os.environ.get("_PATCHED_TESTING_METHODS_OUTPUT_DIR", ""), "captured_info.txt")
Path(p).unlink(missing_ok=True)
_reset_patched_testing_methods_output_file()

if is_torch_available():
import torch
Expand Down
86 changes: 86 additions & 0 deletions tests/utils/test_testing_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
import unittest
from pathlib import Path
from unittest import mock

from transformers import testing_utils


class PatchedTestingMethodsOutputFileTest(unittest.TestCase):
def test_get_output_file_without_xdist_worker(self):
with (
tempfile.TemporaryDirectory() as tmpdir,
mock.patch.dict(os.environ, {"_PATCHED_TESTING_METHODS_OUTPUT_DIR": tmpdir}, clear=True),
):
output_path = testing_utils._get_patched_testing_methods_output_file()

self.assertEqual(output_path, Path(tmpdir) / "captured_info.txt")

def test_get_output_file_with_xdist_worker(self):
with (
tempfile.TemporaryDirectory() as tmpdir,
mock.patch.dict(
os.environ,
{
"_PATCHED_TESTING_METHODS_OUTPUT_DIR": tmpdir,
"PYTEST_XDIST_WORKER": "gw2",
},
clear=True,
),
):
output_path = testing_utils._get_patched_testing_methods_output_file()

self.assertEqual(output_path, Path(tmpdir) / "captured_info_gw2.txt")

def test_prepare_debugging_info_writes_worker_specific_file(self):
with (
tempfile.TemporaryDirectory() as tmpdir,
mock.patch.dict(
os.environ,
{
"_PATCHED_TESTING_METHODS_OUTPUT_DIR": tmpdir,
"PYTEST_XDIST_WORKER": "gw1",
},
clear=True,
),
):
output_path = Path(tmpdir) / "captured_info_gw1.txt"
rendered_info = testing_utils._prepare_debugging_info("test-info", "payload")
self.assertEqual(rendered_info, "test-info\n\npayload")
self.assertTrue(output_path.exists())
self.assertIn("test-info\n\npayload", output_path.read_text())

def test_reset_only_clears_current_worker_file(self):
with tempfile.TemporaryDirectory() as tmpdir:
current_worker_path = Path(tmpdir) / "captured_info_gw0.txt"
other_worker_path = Path(tmpdir) / "captured_info_gw1.txt"
current_worker_path.write_text("current worker")
other_worker_path.write_text("other worker")

with mock.patch.dict(
os.environ,
{
"_PATCHED_TESTING_METHODS_OUTPUT_DIR": tmpdir,
"PYTEST_XDIST_WORKER": "gw0",
},
clear=True,
):
output_path = testing_utils._reset_patched_testing_methods_output_file()
self.assertEqual(output_path, current_worker_path)
self.assertFalse(current_worker_path.exists())
self.assertTrue(other_worker_path.exists())
Loading