Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion pytd/tests/test_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,25 @@ def test_perform_wait_callback_parameter(self):

# Check that perform was called with the wait_callback parameter
mock_bulk_import.perform.assert_called_with(
wait=True, wait_callback=callback_func
wait=True, timeout=None, wait_callback=callback_func
)

def test_perform_timeout_parameter(self):
    """Verify write_dataframe hands perform_timeout to bulk_import.perform.

    The value given as ``perform_timeout`` must arrive at ``perform`` as its
    ``timeout`` keyword argument, alongside ``wait=True`` and a ``None``
    ``wait_callback`` (no callback is supplied in this test).
    """
    five_minutes = 300  # timeout in seconds
    frame = pd.DataFrame([[1, 2], [3, 4]])

    # Replace perform on the bulk import object returned by the mocked API
    # client so the keyword arguments of the call can be inspected afterwards.
    api_client = self.table.client.api_client
    bulk_import = api_client.create_bulk_import.return_value
    bulk_import.perform = MagicMock()

    self.writer.write_dataframe(
        frame, self.table, "overwrite", perform_timeout=five_minutes
    )

    # The most recent perform call must carry the timeout unchanged.
    perform_spy = bulk_import.perform
    perform_spy.assert_called_with(
        wait=True, timeout=five_minutes, wait_callback=None
    )


Expand Down
15 changes: 14 additions & 1 deletion pytd/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ def write_dataframe(
show_progress=False,
bulk_import_name=None,
commit_timeout=None,
perform_timeout=None,
perform_wait_callback=None,
):
"""Write a given DataFrame to a Treasure Data table.
Expand Down Expand Up @@ -453,6 +454,10 @@ def write_dataframe(
Timeout in seconds for the bulk import commit operation. If None,
no timeout is applied.

perform_timeout : int, optional, default: None
Timeout in seconds for the bulk import perform operation. If None,
no timeout is applied.

perform_wait_callback : callable, optional, default: None
A callable to be called on every tick of wait interval during
bulk import job execution.
Expand Down Expand Up @@ -539,6 +544,7 @@ def write_dataframe(
show_progress=show_progress,
bulk_import_name=bulk_import_name,
commit_timeout=commit_timeout,
perform_timeout=perform_timeout,
perform_wait_callback=perform_wait_callback,
)
stack.close()
Expand All @@ -553,6 +559,7 @@ def _bulk_import(
show_progress=False,
bulk_import_name=None,
commit_timeout=None,
perform_timeout=None,
perform_wait_callback=None,
):
"""Write a specified CSV file to a Treasure Data table.
Expand Down Expand Up @@ -594,6 +601,10 @@ def _bulk_import(
Timeout in seconds for the bulk import commit operation. If None,
no timeout is applied.

perform_timeout : int, optional, default: None
Timeout in seconds for the bulk import perform operation. If None,
no timeout is applied.

perform_wait_callback : callable, optional, default: None
A callable to be called on every tick of wait interval during
bulk import job execution.
Expand Down Expand Up @@ -659,7 +670,9 @@ def _bulk_import(
logger.debug(f"uploaded data in {time.time() - s_time:.2f} sec")

logger.info("performing a bulk import job")
job = bulk_import.perform(wait=True, wait_callback=perform_wait_callback)
job = bulk_import.perform(
wait=True, timeout=perform_timeout, wait_callback=perform_wait_callback
)

if 0 < bulk_import.error_records:
logger.warning(
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ install_requires =
trino>=0.334.0
pandas>=2.1.0
numpy>=1.25.2
td-client>=1.1.0
td-client>=1.5.0
pytz>=2018.5
tqdm>=4.60.0

Expand Down
Loading