diff --git a/pytd/tests/test_writer.py b/pytd/tests/test_writer.py index 6e2cda9..2ef675d 100644 --- a/pytd/tests/test_writer.py +++ b/pytd/tests/test_writer.py @@ -446,7 +446,25 @@ def test_perform_wait_callback_parameter(self): # Check that perform was called with the wait_callback parameter mock_bulk_import.perform.assert_called_with( - wait=True, wait_callback=callback_func + wait=True, timeout=None, wait_callback=callback_func + ) + + def test_perform_timeout_parameter(self): + """Test that perform_timeout parameter is passed correctly""" + df = pd.DataFrame([[1, 2], [3, 4]]) + timeout_value = 300 # 5 minutes + + # Mock the bulk_import.perform method to check if timeout is passed + mock_bulk_import = self.table.client.api_client.create_bulk_import.return_value + mock_bulk_import.perform = MagicMock() + + self.writer.write_dataframe( + df, self.table, "overwrite", perform_timeout=timeout_value + ) + + # Check that perform was called with the timeout parameter + mock_bulk_import.perform.assert_called_with( + wait=True, timeout=timeout_value, wait_callback=None ) diff --git a/pytd/writer.py b/pytd/writer.py index e1fe28f..f5278dc 100644 --- a/pytd/writer.py +++ b/pytd/writer.py @@ -335,6 +335,7 @@ def write_dataframe( show_progress=False, bulk_import_name=None, commit_timeout=None, + perform_timeout=None, perform_wait_callback=None, ): """Write a given DataFrame to a Treasure Data table. @@ -453,6 +454,10 @@ def write_dataframe( Timeout in seconds for the bulk import commit operation. If None, no timeout is applied. + perform_timeout : int, optional, default: None + Timeout in seconds for the bulk import perform operation. If None, + no timeout is applied. + perform_wait_callback : callable, optional, default: None A callable to be called on every tick of wait interval during bulk import job execution. @@ -539,6 +544,7 @@ def write_dataframe( show_progress=show_progress, bulk_import_name=bulk_import_name, commit_timeout=commit_timeout, + perform_timeout=perform_timeout, perform_wait_callback=perform_wait_callback, ) stack.close() @@ -553,6 +559,7 @@ def _bulk_import( show_progress=False, bulk_import_name=None, commit_timeout=None, + perform_timeout=None, perform_wait_callback=None, ): """Write a specified CSV file to a Treasure Data table. @@ -594,6 +601,10 @@ def _bulk_import( Timeout in seconds for the bulk import commit operation. If None, no timeout is applied. + perform_timeout : int, optional, default: None + Timeout in seconds for the bulk import perform operation. If None, + no timeout is applied. + perform_wait_callback : callable, optional, default: None A callable to be called on every tick of wait interval during bulk import job execution. @@ -659,7 +670,9 @@ def _bulk_import( logger.debug(f"uploaded data in {time.time() - s_time:.2f} sec") logger.info("performing a bulk import job") - job = bulk_import.perform(wait=True, wait_callback=perform_wait_callback) + job = bulk_import.perform( + wait=True, timeout=perform_timeout, wait_callback=perform_wait_callback + ) if 0 < bulk_import.error_records: logger.warning( diff --git a/setup.cfg b/setup.cfg index 1f7793b..572e152 100644 --- a/setup.cfg +++ b/setup.cfg @@ -33,7 +33,7 @@ install_requires = trino>=0.334.0 pandas>=2.1.0 numpy>=1.25.2 - td-client>=1.1.0 + td-client>=1.5.0 pytz>=2018.5 tqdm>=4.60.0