Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions scripts/connection_methods.json
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@
"fetch_record_batch",
"arrow"
],

"function": "FetchRecordBatchReader",
"docs": "Fetch an Arrow RecordBatchReader following execute()",
"args": [
Expand Down Expand Up @@ -992,7 +992,7 @@
"args": [
{
"name": "file_globs",
"type": "str"
"type": "List[str]"
},
{
"name": "binary_as_string",
Expand Down
19 changes: 10 additions & 9 deletions scripts/generate_connection_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,12 @@ def create_arguments(arguments) -> list:
result.append(argument)
return result

def create_definition(name, method) -> str:
definition = f"def {name}("
def create_definition(name, method, overloaded: bool) -> str:
if overloaded:
definition: str = "@overload\n"
else:
definition: str = ""
definition += f"def {name}("
arguments = ['self']
if 'args' in method:
arguments.extend(create_arguments(method['args']))
Expand All @@ -65,20 +69,17 @@ def create_definition(name, method) -> str:
definition += f" -> {method['return']}: ..."
return definition

# We have "duplicate" methods, which are overloaded
# maybe we should add @overload to these instead, but this is easier
written_methods = set()
# We have "duplicate" methods, which are overloaded.
# We keep note of them to add the @overload decorator.
overloaded_methods: set[str] = {m for m in connection_methods if isinstance(m['name'], list)}

for method in connection_methods:
if isinstance(method['name'], list):
names = method['name']
else:
names = [method['name']]
for name in names:
if name in written_methods:
continue
body.append(create_definition(name, method))
written_methods.add(name)
body.append(create_definition(name, method, name in overloaded_methods))

# ---- End of generation code ----

Expand Down
19 changes: 10 additions & 9 deletions scripts/generate_connection_wrapper_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,12 @@ def create_arguments(arguments) -> list:
result.append(argument)
return result

def create_definition(name, method) -> str:
definition = f"def {name}("
def create_definition(name, method, overloaded: bool) -> str:
if overloaded:
definition: str = "@overload\n"
else:
definition: str = ""
definition += f"def {name}("
arguments = []
if name in SPECIAL_METHOD_NAMES:
arguments.append('df: pandas.DataFrame')
Expand All @@ -82,9 +86,9 @@ def create_definition(name, method) -> str:
definition += f" -> {method['return']}: ..."
return definition

# We have "duplicate" methods, which are overloaded
# maybe we should add @overload to these instead, but this is easier
written_methods = set()
# We have "duplicate" methods, which are overloaded.
# We keep note of them to add the @overload decorator.
overloaded_methods: set[str] = {m for m in connection_methods if isinstance(m['name'], list)}

body = []
for method in methods:
Expand All @@ -99,10 +103,7 @@ def create_definition(name, method) -> str:
method['kwargs'].append({'name': 'connection', 'type': 'DuckDBPyConnection', 'default': '...'})

for name in names:
if name in written_methods:
continue
body.append(create_definition(name, method))
written_methods.add(name)
body.append(create_definition(name, method, name in overloaded_methods))

# ---- End of generation code ----

Expand Down