diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 863242a695c6..920dc884f3e7 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -3529,9 +3529,8 @@ def _prepare_debugging_info(test_info, info): """Combine the information about the test and the call information to a patched function/method within it.""" info = f"{test_info}\n\n{info}" - p = os.path.join(os.environ.get("_PATCHED_TESTING_METHODS_OUTPUT_DIR", ""), "captured_info.txt") - # TODO (ydshieh): This is not safe when we use pytest-xdist with more than 1 worker. - with open(p, "a") as fp: + output_path = _get_patched_testing_methods_output_file() + with output_path.open("a") as fp: fp.write(f"{info}\n\n{'=' * 120}\n\n") return info @@ -3754,6 +3753,27 @@ def _parse_call_info(func, args, kwargs, call_argument_expressions, target_args) return info +def _get_patched_testing_methods_output_file() -> Path: + """Return the output file used by patched assertion methods. + + Under `pytest-xdist`, workers run in separate processes but can share the same output directory. Using a worker- + specific file avoids concurrent writes and resets clobbering each other's captured debugging information. + """ + + output_dir = Path(os.environ.get("_PATCHED_TESTING_METHODS_OUTPUT_DIR", "")) + worker_id = os.environ.get("PYTEST_XDIST_WORKER") + filename = f"captured_info_{worker_id}.txt" if worker_id else "captured_info.txt" + return output_dir / filename + + +def _reset_patched_testing_methods_output_file() -> Path: + """Clear the output file used by patched assertion methods and return its path.""" + + output_path = _get_patched_testing_methods_output_file() + output_path.unlink(missing_ok=True) + return output_path + + def patch_testing_methods_to_collect_info(): """ Patch some methods (`torch.testing.assert_close`, `unittest.case.TestCase.assertEqual`, etc). @@ -3761,8 +3781,7 @@ def patch_testing_methods_to_collect_info(): This will allow us to collect the call information, e.g. the argument names and values, also the literal expressions passed as the arguments. """ - p = os.path.join(os.environ.get("_PATCHED_TESTING_METHODS_OUTPUT_DIR", ""), "captured_info.txt") - Path(p).unlink(missing_ok=True) + _reset_patched_testing_methods_output_file() if is_torch_available(): import torch diff --git a/tests/utils/test_testing_utils.py b/tests/utils/test_testing_utils.py new file mode 100644 index 000000000000..80b06f37159e --- /dev/null +++ b/tests/utils/test_testing_utils.py @@ -0,0 +1,86 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest +from pathlib import Path +from unittest import mock + +from transformers import testing_utils + + +class PatchedTestingMethodsOutputFileTest(unittest.TestCase): + def test_get_output_file_without_xdist_worker(self): + with ( + tempfile.TemporaryDirectory() as tmpdir, + mock.patch.dict(os.environ, {"_PATCHED_TESTING_METHODS_OUTPUT_DIR": tmpdir}, clear=True), + ): + output_path = testing_utils._get_patched_testing_methods_output_file() + + self.assertEqual(output_path, Path(tmpdir) / "captured_info.txt") + + def test_get_output_file_with_xdist_worker(self): + with ( + tempfile.TemporaryDirectory() as tmpdir, + mock.patch.dict( + os.environ, + { + "_PATCHED_TESTING_METHODS_OUTPUT_DIR": tmpdir, + "PYTEST_XDIST_WORKER": "gw2", + }, + clear=True, + ), + ): + output_path = testing_utils._get_patched_testing_methods_output_file() + + self.assertEqual(output_path, Path(tmpdir) / "captured_info_gw2.txt") + + def test_prepare_debugging_info_writes_worker_specific_file(self): + with ( + tempfile.TemporaryDirectory() as tmpdir, + mock.patch.dict( + os.environ, + { + "_PATCHED_TESTING_METHODS_OUTPUT_DIR": tmpdir, + "PYTEST_XDIST_WORKER": "gw1", + }, + clear=True, + ), + ): + output_path = Path(tmpdir) / "captured_info_gw1.txt" + rendered_info = testing_utils._prepare_debugging_info("test-info", "payload") + self.assertEqual(rendered_info, "test-info\n\npayload") + self.assertTrue(output_path.exists()) + self.assertIn("test-info\n\npayload", output_path.read_text()) + + def test_reset_only_clears_current_worker_file(self): + with tempfile.TemporaryDirectory() as tmpdir: + current_worker_path = Path(tmpdir) / "captured_info_gw0.txt" + other_worker_path = Path(tmpdir) / "captured_info_gw1.txt" + current_worker_path.write_text("current worker") + other_worker_path.write_text("other worker") + + with mock.patch.dict( + os.environ, + { + "_PATCHED_TESTING_METHODS_OUTPUT_DIR": tmpdir, + "PYTEST_XDIST_WORKER": "gw0", + }, + clear=True, + ): + output_path = testing_utils._reset_patched_testing_methods_output_file() + self.assertEqual(output_path, current_worker_path) + self.assertFalse(current_worker_path.exists()) + self.assertTrue(other_worker_path.exists())