diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py index 9c61097251a..56a033df410 100644 --- a/python/pyarrow/tests/test_flight.py +++ b/python/pyarrow/tests/test_flight.py @@ -355,20 +355,23 @@ def slow_stream(): class ErrorFlightServer(FlightServerBase): """A Flight server that uses all the Flight-specific errors.""" - errors = { - "internal": flight.FlightInternalError, - "timedout": flight.FlightTimedOutError, - "cancel": flight.FlightCancelledError, - "unauthenticated": flight.FlightUnauthenticatedError, - "unauthorized": flight.FlightUnauthorizedError, - "notimplemented": NotImplementedError, - "invalid": pa.ArrowInvalid, - "key": KeyError, - } + @staticmethod + def error_cases(): + return { + "internal": flight.FlightInternalError, + "timedout": flight.FlightTimedOutError, + "cancel": flight.FlightCancelledError, + "unauthenticated": flight.FlightUnauthenticatedError, + "unauthorized": flight.FlightUnauthorizedError, + "notimplemented": NotImplementedError, + "invalid": pa.ArrowInvalid, + "key": KeyError, + } def do_action(self, context, action): - if action.type in self.errors: - raise self.errors[action.type]("foo") + error_cases = ErrorFlightServer.error_cases() + if action.type in error_cases: + raise error_cases[action.type]("foo") elif action.type == "protobuf": err_msg = b'this is an error message' raise flight.FlightUnauthorizedError("foo", err_msg) @@ -1564,7 +1567,7 @@ def test_roundtrip_errors(): with ErrorFlightServer() as server, \ FlightClient(('localhost', server.port)) as client: - for arg, exc_type in ErrorFlightServer.errors.items(): + for arg, exc_type in ErrorFlightServer.error_cases().items(): with pytest.raises(exc_type, match=".*foo.*"): list(client.do_action(flight.Action(arg, b""))) with pytest.raises(flight.FlightInternalError, match=".*foo.*"):