11import logging
22from dataclasses import dataclass
3-
3+ from datetime import datetime
44import requests
55import lz4 .frame
66import threading
77import time
8-
8+ import os
9+ from threading import get_ident
910from databricks .sql .thrift_api .TCLIService .ttypes import TSparkArrowResultLink
1011
1112logger = logging .getLogger (__name__ )
1213
14+ DEFAULT_CLOUD_FILE_TIMEOUT = int (os .getenv ("DATABRICKS_CLOUD_FILE_TIMEOUT" , 60 ))
15+
1316
1417@dataclass
1518class DownloadableResultSettings :
@@ -20,13 +23,17 @@ class DownloadableResultSettings:
2023 is_lz4_compressed (bool): Whether file is expected to be lz4 compressed.
2124 link_expiry_buffer_secs (int): Time in seconds to prevent download of a link before it expires. Default 0 secs.
2225 download_timeout (int): Timeout for download requests. Default 60 secs.
23- max_consecutive_file_download_retries (int): Number of consecutive download retries before shutting down.
26+ max_retries (int): Number of consecutive download retries before shutting down. Default 5.
27+ backoff_factor (int): Factor to increase wait time between retries. Default 2.
29+
2430 """
2531
2632 is_lz4_compressed : bool
2733 link_expiry_buffer_secs : int = 0
28- download_timeout : int = 60
29- max_consecutive_file_download_retries : int = 0
34+ download_timeout : int = DEFAULT_CLOUD_FILE_TIMEOUT
35+ max_retries : int = 5
36+ backoff_factor : int = 2
3037
3138
3239class ResultSetDownloadHandler (threading .Thread ):
@@ -57,16 +64,21 @@ def is_file_download_successful(self) -> bool:
5764 else None
5865 )
5966 try :
67+ logger .debug (
68+ f"waiting for at most { timeout } seconds for download file: startRow { self .result_link .startRowOffset } , rowCount { self .result_link .rowCount } , endRow { self .result_link .startRowOffset + self .result_link .rowCount } "
69+ )
70+
6071 if not self .is_download_finished .wait (timeout = timeout ):
6172 self .is_download_timedout = True
6273 logger .debug (
63- "Cloud fetch download timed out after {} seconds for link representing rows {} to {}" .format (
64- self .settings .download_timeout ,
65- self .result_link .startRowOffset ,
66- self .result_link .startRowOffset + self .result_link .rowCount ,
67- )
74+ f"cloud fetch download timed out after { self .settings .download_timeout } seconds for link representing rows { self .result_link .startRowOffset } to { self .result_link .startRowOffset + self .result_link .rowCount } "
6875 )
69- return False
76+ # there are some weird cases when the is_download_finished is not set, but the file is downloaded successfully
77+ return self .is_file_downloaded_successfully
78+
79+ logger .debug (
80+ f"finish waiting for download file: startRow { self .result_link .startRowOffset } , rowCount { self .result_link .rowCount } , endRow { self .result_link .startRowOffset + self .result_link .rowCount } "
81+ )
7082 except Exception as e :
7183 logger .error (e )
7284 return False
@@ -81,24 +93,36 @@ def run(self):
8193 """
8294 self ._reset ()
8395
84- # Check if link is already expired or is expiring
85- if ResultSetDownloadHandler .check_link_expired (
86- self .result_link , self .settings .link_expiry_buffer_secs
87- ):
88- self .is_link_expired = True
89- return
96+ try :
97+ # Check if link is already expired or is expiring
98+ if ResultSetDownloadHandler .check_link_expired (
99+ self .result_link , self .settings .link_expiry_buffer_secs
100+ ):
101+ self .is_link_expired = True
102+ return
90103
91- session = requests .Session ()
92- session .timeout = self .settings .download_timeout
104+ logger .debug (
105+ f"started to download file: startRow { self .result_link .startRowOffset } , rowCount { self .result_link .rowCount } , endRow { self .result_link .startRowOffset + self .result_link .rowCount } "
106+ )
93107
94- try :
95108 # Get the file via HTTP request
96- response = session .get (self .result_link .fileLink )
109+ response = http_get_with_retry (
110+ url = self .result_link .fileLink ,
111+ max_retries = self .settings .max_retries ,
112+ backoff_factor = self .settings .backoff_factor ,
113+ download_timeout = self .settings .download_timeout ,
114+ )
97115
98- if not response .ok :
99- self .is_file_downloaded_successfully = False
116+ if not response :
117+ logger .error (
118+ f"failed downloading file: startRow { self .result_link .startRowOffset } , rowCount { self .result_link .rowCount } , endRow { self .result_link .startRowOffset + self .result_link .rowCount } "
119+ )
100120 return
101121
122+ logger .debug (
123+ f"success downloading file: startRow { self .result_link .startRowOffset } , rowCount { self .result_link .rowCount } , endRow { self .result_link .startRowOffset + self .result_link .rowCount } "
124+ )
125+
102126 # Save (and decompress if needed) the downloaded file
103127 compressed_data = response .content
104128 decompressed_data = (
@@ -109,15 +133,22 @@ def run(self):
109133 self .result_file = decompressed_data
110134
111135 # The size of the downloaded file should match the size specified from TSparkArrowResultLink
112- self .is_file_downloaded_successfully = (
113- len (self .result_file ) == self .result_link .bytesNum
136+ success = len (self .result_file ) == self .result_link .bytesNum
137+ logger .debug (
138+ f"download successful file: startRow { self .result_link .startRowOffset } , rowCount { self .result_link .rowCount } , endRow { self .result_link .startRowOffset + self .result_link .rowCount } "
114139 )
140+ self .is_file_downloaded_successfully = success
115141 except Exception as e :
142+ logger .debug (
143+ f"exception downloading file: startRow { self .result_link .startRowOffset } , rowCount { self .result_link .rowCount } , endRow { self .result_link .startRowOffset + self .result_link .rowCount } "
144+ )
116145 logger .error (e )
117146 self .is_file_downloaded_successfully = False
118147
119148 finally :
120- session and session .close ()
149+ logger .debug (
150+ f"signal finished file: startRow { self .result_link .startRowOffset } , rowCount { self .result_link .rowCount } , endRow { self .result_link .startRowOffset + self .result_link .rowCount } "
151+ )
121152 # Awaken threads waiting for this to be true which signals the run is complete
122153 self .is_download_finished .set ()
123154
@@ -145,6 +176,7 @@ def check_link_expired(
145176 link .expiryTime < current_time
146177 or link .expiryTime - current_time < expiry_buffer_secs
147178 ):
179+ logger .debug ("link expired" )
148180 return True
149181 return False
150182
@@ -171,3 +203,34 @@ def decompress_data(compressed_data: bytes) -> bytes:
171203 uncompressed_data += data
172204 start += num_bytes
173205 return uncompressed_data
206+
207+
def http_get_with_retry(url, max_retries=5, backoff_factor=2, download_timeout=60):
    """Fetch ``url`` via HTTP GET, retrying failed attempts with exponential backoff.

    Args:
        url (str): The URL to download.
        max_retries (int): Maximum number of attempts before giving up. Default 5.
        backoff_factor (int): Base of the exponential wait between attempts
            (``backoff_factor ** attempt`` seconds). Default 2.
        download_timeout (int): Per-request timeout in seconds. Default 60.

    Returns:
        requests.Response: The successful (2xx) response, or ``None`` if every
        attempt failed or returned a non-success status.
    """
    for attempt in range(max_retries):
        try:
            # NOTE: requests.Session has no `timeout` attribute — assigning one is a
            # silent no-op. The timeout must be passed to the request call itself,
            # otherwise the download can hang indefinitely.
            with requests.Session() as session:
                response = session.get(url, timeout=download_timeout)

            # Accept any 2xx status (e.g. 206 Partial Content), per the stated intent.
            if response.ok:
                return response
            logger.error(
                "http_get_with_retry: unexpected status %s for %s",
                response.status_code,
                url,
            )
        except requests.RequestException as e:
            # Use the module logger, not print, so failures reach the configured sinks.
            logger.error("http_get_with_retry: request failed with exception: %s", e)

        # Back off before the next attempt, but don't sleep after the final one.
        if attempt < max_retries - 1:
            wait_time = backoff_factor**attempt
            logger.info(f"retrying in {wait_time} seconds...")
            time.sleep(wait_time)

    logger.error(
        f"exceeded maximum number of retries ({max_retries}) while downloading result."
    )
    return None
0 commit comments