diff --git a/scripts/connection_methods.json b/scripts/connection_methods.json index 27705d6a..a87b992f 100644 --- a/scripts/connection_methods.json +++ b/scripts/connection_methods.json @@ -412,7 +412,7 @@ "fetch_record_batch", "arrow" ], - + "function": "FetchRecordBatchReader", "docs": "Fetch an Arrow RecordBatchReader following execute()", "args": [ @@ -992,7 +992,7 @@ "args": [ { "name": "file_globs", - "type": "str" + "type": "List[str]" }, { "name": "binary_as_string", diff --git a/scripts/generate_connection_stubs.py b/scripts/generate_connection_stubs.py index 563ade3d..fbb66c21 100644 --- a/scripts/generate_connection_stubs.py +++ b/scripts/generate_connection_stubs.py @@ -51,8 +51,12 @@ def create_arguments(arguments) -> list: result.append(argument) return result - def create_definition(name, method) -> str: - definition = f"def {name}(" + def create_definition(name, method, overloaded: bool) -> str: + if overloaded: + definition: str = "@overload\n" + else: + definition: str = "" + definition += f"def {name}(" arguments = ['self'] if 'args' in method: arguments.extend(create_arguments(method['args'])) @@ -65,9 +69,9 @@ def create_definition(name, method) -> str: definition += f" -> {method['return']}: ..." return definition - # We have "duplicate" methods, which are overloaded - # maybe we should add @overload to these instead, but this is easier - written_methods = set() + # We have "duplicate" methods, which are overloaded. + # We keep note of them to add the @overload decorator. + overloaded_methods: set[str] = {m for m in connection_methods if isinstance(m['name'], list)} for method in connection_methods: if isinstance(method['name'], list): @@ -75,10 +79,7 @@ def create_definition(name, method) -> str: else: names = [method['name']] for name in names: - if name in written_methods: - continue - body.append(create_definition(name, method)) - written_methods.add(name) + body.append(create_definition(name, method, name in overloaded_methods)) # ---- End of generation code ---- diff --git a/scripts/generate_connection_wrapper_stubs.py b/scripts/generate_connection_wrapper_stubs.py index 94b0e0ee..62c60a84 100644 --- a/scripts/generate_connection_wrapper_stubs.py +++ b/scripts/generate_connection_wrapper_stubs.py @@ -66,8 +66,12 @@ def create_arguments(arguments) -> list: result.append(argument) return result - def create_definition(name, method) -> str: - definition = f"def {name}(" + def create_definition(name, method, overloaded: bool) -> str: + if overloaded: + definition: str = "@overload\n" + else: + definition: str = "" + definition += f"def {name}(" arguments = [] if name in SPECIAL_METHOD_NAMES: arguments.append('df: pandas.DataFrame') @@ -82,9 +86,9 @@ def create_definition(name, method) -> str: definition += f" -> {method['return']}: ..." return definition - # We have "duplicate" methods, which are overloaded - # maybe we should add @overload to these instead, but this is easier - written_methods = set() + # We have "duplicate" methods, which are overloaded. + # We keep note of them to add the @overload decorator. + overloaded_methods: set[str] = {m for m in connection_methods if isinstance(m['name'], list)} body = [] for method in methods: @@ -99,10 +103,7 @@ def create_definition(name, method) -> str: method['kwargs'].append({'name': 'connection', 'type': 'DuckDBPyConnection', 'default': '...'}) for name in names: - if name in written_methods: - continue - body.append(create_definition(name, method)) - written_methods.add(name) + body.append(create_definition(name, method, name in overloaded_methods)) # ---- End of generation code ----