diff --git a/toolshed/check_cython_abi.py b/toolshed/check_cython_abi.py index 2199d526a3..32b5e6be11 100644 --- a/toolshed/check_cython_abi.py +++ b/toolshed/check_cython_abi.py @@ -6,12 +6,17 @@ """ Tool to check for Cython ABI changes in a given package. -There are different types of ABI changes, only one of which is covered by this tool: +Cython must be installed in your venv to run this script. + +There are different types of ABI changes, some of which are covered by this tool: - cdef function signatures (capsule strings) — covered here -- cdef class struct size (tp_basicsize) — not covered -- cdef class vtable layout / method reordering — not covered, and this one fails as silent UB rather than an import-time error -- Fused specialization ordering — partially covered (reorders manifest as capsule-name deltas, but the mapping is non-obvious) +- cdef class struct size (tp_basicsize) — covered here +- cdef struct / ctypedef struct field layout — covered here (via .pxd parsing) +- cdef class vtable layout / method reordering — not covered, and this one fails + as silent UB rather than an import-time error +- Fused specialization ordering — partially covered (reorders manifest as + capsule-name deltas, but the mapping is non-obvious) The workflow is basically: @@ -21,13 +26,13 @@ package is installed), where `package_name` is the import path to the package, e.g. `cuda.bindings`: - python check_cython_abi.py generate + python check_cython_abi.py generate 3) Checkout a version with the changes to be tested, and build and install. 
4) Check the ABI against the previously generated files by running: - python check_cython_abi.py check + python check_cython_abi.py check """ import ctypes @@ -35,8 +40,14 @@ import json import sys import sysconfig +from io import StringIO from pathlib import Path +from Cython.Compiler import Parsing +from Cython.Compiler.Scanning import FileSourceDescriptor, PyrexScanner +from Cython.Compiler.Symtab import ModuleScope +from Cython.Compiler.TreeFragment import StringParseContext + EXT_SUFFIX = sysconfig.get_config_var("EXT_SUFFIX") ABI_SUFFIX = ".abi.json" @@ -66,12 +77,12 @@ def import_from_path(root_package: str, root_dir: Path, path: Path) -> object: def so_path_to_abi_path(so_path: Path, build_dir: Path, abi_dir: Path) -> Path: - abi_name = short_stem(so_path.name) + ABI_SUFFIX + abi_name = f"{short_stem(so_path.name)}{ABI_SUFFIX}" return abi_dir / so_path.parent.relative_to(build_dir) / abi_name def abi_path_to_so_path(abi_path: Path, build_dir: Path, abi_dir: Path) -> Path: - so_name = short_stem(abi_path.name) + EXT_SUFFIX + so_name = f"{short_stem(abi_path.name)}{EXT_SUFFIX}" return build_dir / abi_path.parent.relative_to(abi_dir) / so_name @@ -80,16 +91,244 @@ def is_cython_module(module: object) -> bool: return hasattr(module, "__pyx_capi__") -def module_to_json(module: object) -> dict: - """ - Converts extracts information about a Cython-compiled .so into JSON-serializable information. 
+###################################################################################### +# STRUCTS + + +def get_cdef_classes(module: object) -> dict: + """Extract cdef class (extension type) basicsize from a compiled Cython module.""" + result = {} + module_name = module.__name__ + for name in sorted(dir(module)): + obj = getattr(module, name, None) + if isinstance(obj, type) and getattr(obj, "__module__", None) == module_name and hasattr(obj, "__basicsize__"): + result[name] = {"basicsize": obj.__basicsize__} + return result + + +def _format_base_type_name(bt: object) -> str: + """Format a Cython base type AST node into a type name string.""" + cls = type(bt).__name__ + if cls == "CSimpleBaseTypeNode": + return bt.name + if cls == "CComplexBaseTypeNode": + inner = _format_base_type_name(bt.base_type) + return _unwrap_declarator(inner, bt.declarator)[0] + return cls + + +def _unwrap_declarator(type_str: str, decl: object) -> tuple[str, str]: + """Unwrap nested Cython declarator nodes to get (type_string, field_name).""" + cls = type(decl).__name__ + if cls == "CNameDeclaratorNode": + return type_str, decl.name + if cls == "CPtrDeclaratorNode": + return _unwrap_declarator(f"{type_str}*", decl.base) + if cls == "CReferenceDeclaratorNode": + return _unwrap_declarator(f"{type_str}&", decl.base) + if cls == "CArrayDeclaratorNode": + dim = getattr(decl, "dimension", None) + size = getattr(dim, "value", "") if dim is not None else "" + return _unwrap_declarator(f"{type_str}[{size}]", decl.base) + return type_str, "" + + +def _extract_fields_from_cvardef(node: object) -> list: + """Extract [type, name] pairs from a CVarDefNode.""" + results = [] + for d in node.declarators: + type_str, name = _unwrap_declarator(_format_base_type_name(node.base_type), d) + if name: + results.append([type_str, name]) + return results + + +def _collect_cvardef_fields(node: object) -> list: + """Recursively collect CVarDefNode fields, skipping nested struct/class/func defs.""" + fields = [] + 
if type(node).__name__ == "CVarDefNode": + fields.extend(_extract_fields_from_cvardef(node)) + skip = ("CStructOrUnionDefNode", "CClassDefNode", "CFuncDefNode") + for attr_name in getattr(node, "child_attrs", []): + child = getattr(node, attr_name, None) + if child is None: + continue + if isinstance(child, list): + for item in child: + if hasattr(item, "child_attrs") and type(item).__name__ not in skip: + fields.extend(_collect_cvardef_fields(item)) + elif hasattr(child, "child_attrs") and type(child).__name__ not in skip: + fields.extend(_collect_cvardef_fields(child)) + return fields + + +def _collect_structs_from_tree(node: object) -> dict: + """Walk a Cython AST and collect struct/class field definitions.""" + result = {} + cls = type(node).__name__ + + if cls == "CStructOrUnionDefNode": + fields = [] + for attr in node.attributes: + if type(attr).__name__ == "CVarDefNode": + fields.extend(_extract_fields_from_cvardef(attr)) + if fields: + result[node.name] = {"fields": fields} + + elif cls == "CClassDefNode": + fields = _collect_cvardef_fields(node.body) + if fields: + result[node.class_name] = {"fields": fields} + + for attr_name in getattr(node, "child_attrs", []): + child = getattr(node, attr_name, None) + if child is None: + continue + if isinstance(child, list): + for item in child: + if hasattr(item, "child_attrs"): + result.update(_collect_structs_from_tree(item)) + elif hasattr(child, "child_attrs"): + result.update(_collect_structs_from_tree(child)) + + return result + + +class _PxdParseContext(StringParseContext): + """Parse context that resolves includes via real paths and ignores unknown cimports.""" + + def find_module( + self, + module_name, + from_module=None, # noqa: ARG002 + pos=None, # noqa: ARG002 + need_pxd=1, # noqa: ARG002 + absolute_fallback=True, # noqa: ARG002 + relative_import=False, # noqa: ARG002 + ): + return ModuleScope(module_name, parent_module=None, context=self) + + +def parse_pxd_structs(pxd_path: Path) -> dict: + """Parse 
struct and cdef class field definitions from a .pxd file. + + Uses Cython's own parser (in .pxd mode) for reliable extraction. + cimport lines in the top-level file are stripped since they are + unresolvable without the full compilation context; included files + are handled via a lenient context that returns dummy scopes. + + Returns a dict mapping struct/class name to {"fields": [[type, name], ...]}. """ - # Sort the dictionary by keys to make diffs in the JSON files smaller - pyx_capi = module.__pyx_capi__ + text = pxd_path.read_text(encoding="utf-8") + + # Strip cimport lines (unresolvable without full compilation context) + lines = text.splitlines() + cleaned = "\n".join("" if (" cimport " in ln or ln.lstrip().startswith("cimport ")) else ln for ln in lines) + + name = pxd_path.stem + context = _PxdParseContext(name, include_directories=[str(pxd_path.parent)]) + code_source = FileSourceDescriptor(str(pxd_path)) + scope = context.find_module(name, pos=(code_source, 1, 0), need_pxd=False) + + scanner = PyrexScanner( + StringIO(cleaned), + code_source, + source_encoding="UTF-8", + scope=scope, + context=context, + initial_pos=(code_source, 1, 0), + ) + tree = Parsing.p_module(scanner, pxd=1, full_module_name=name) + tree.scope = scope + + return _collect_structs_from_tree(tree) + + +def get_structs(module: object) -> dict: + # Extract cdef class basicsize from compiled module (primary) + structs = get_cdef_classes(module) + so_path = Path(module.__file__) + + # Parse neighboring .pxd file for struct/class field layout (fallback complement) + if so_path is not None: + pxd_path = so_path.parent / f"{short_stem(so_path.name)}.pxd" + if pxd_path.is_file(): + pxd_structs = parse_pxd_structs(pxd_path) + for name, info in pxd_structs.items(): + if name in structs: + structs[name].update(info) + else: + structs[name] = info + + return dict(sorted(structs.items())) + + +def _report_field_changes(name: str, expected_fields: list, found_fields: list) -> None: + """Print 
detailed field-level differences for a struct.""" + expected_dict = {f[1]: f[0] for f in expected_fields} + found_dict = {f[1]: f[0] for f in found_fields} + + for field_name, field_type in expected_dict.items(): + if field_name not in found_dict: + print(f" Struct {name}: removed field '{field_name}'") + elif found_dict[field_name] != field_type: + print( + f" Struct {name}: field '{field_name}' type changed from '{field_type}' to '{found_dict[field_name]}'" + ) + for field_name in found_dict: + if field_name not in expected_dict: + print(f" Struct {name}: added field '{field_name}'") + + expected_common = [f[1] for f in expected_fields if f[1] in found_dict] + found_common = [f[1] for f in found_fields if f[1] in expected_dict] + if expected_common != found_common: + print(f" Struct {name}: fields were reordered") + + +def check_structs(expected: dict, found: dict) -> tuple[bool, bool]: + has_errors = False + has_allowed_changes = False + + for name, expected_info in expected.items(): + if name not in found: + print(f" Missing struct/class: {name}") + has_errors = True + continue + found_info = found[name] - return { - "functions": {k: get_capsule_name(pyx_capi[k]) for k in sorted(pyx_capi.keys())}, - } + if "basicsize" in expected_info: + if "basicsize" not in found_info: + print(f" Struct {name}: basicsize no longer available") + has_errors = True + elif found_info["basicsize"] != expected_info["basicsize"]: + print( + f" Struct {name}: basicsize changed from {expected_info['basicsize']} to {found_info['basicsize']}" + ) + has_errors = True + + if "fields" in expected_info: + if "fields" not in found_info: + print(f" Struct {name}: field information no longer available") + has_errors = True + elif found_info["fields"] != expected_info["fields"]: + _report_field_changes(name, expected_info["fields"], found_info["fields"]) + has_errors = True + + for name in found: + if name not in expected: + print(f" Added struct/class: {name}") + has_allowed_changes = True + + 
return has_errors, has_allowed_changes + + +###################################################################################### +# FUNCTIONS + + +def get_functions(module: object) -> dict: +    pyx_capi = module.__pyx_capi__ +    return {k: get_capsule_name(pyx_capi[k]) for k in sorted(pyx_capi.keys())} def check_functions(expected: dict[str, str], found: dict[str, str]) -> tuple[bool, bool]: @@ -109,17 +348,29 @@ def check_functions(expected: dict[str, str], found: dict[str, str]) -> tuple[bo return has_errors, has_allowed_changes +###################################################################################### +# MAIN + + def compare(expected: dict, found: dict) -> tuple[bool, bool]: has_errors = False has_allowed_changes = False - errors, allowed_changes = check_functions(expected["functions"], found["functions"]) - has_errors |= errors - has_allowed_changes |= allowed_changes + for func, name in [(check_functions, "functions"), (check_structs, "structs")]: + errors, allowed_changes = func(expected.get(name, {}), found[name]) + has_errors |= errors + has_allowed_changes |= allowed_changes return has_errors, has_allowed_changes +def module_to_json(module: object) -> dict: + """ + Extracts information about a Cython-compiled .so into JSON-serializable information. + """ + return {"functions": get_functions(module), "structs": get_structs(module)} + + def check(package: str, abi_dir: Path) -> bool: build_dir = get_package_path(package) @@ -168,7 +419,7 @@ def check(package: str, abi_dir: Path) -> bool: return False -def regenerate(package: str, abi_dir: Path) -> bool: +def generate(package: str, abi_dir: Path) -> bool: if abi_dir.is_dir(): print(f"ABI directory {abi_dir} already exists. 
Please remove it before regenerating.") + return True @@ -199,10 +450,10 @@ def regenerate(package: str, abi_dir: Path) -> bool: subparsers = parser.add_subparsers() - regen_parser = subparsers.add_parser("generate", help="Regenerate the ABI files") - regen_parser.set_defaults(func=regenerate) - regen_parser.add_argument("package", help="Python package to collect data from") - regen_parser.add_argument("dir", help="Output directory to save data to") + gen_parser = subparsers.add_parser("generate", help="Generate the ABI files") + gen_parser.set_defaults(func=generate) + gen_parser.add_argument("package", help="Python package to collect data from") + gen_parser.add_argument("dir", help="Output directory to save data to") check_parser = subparsers.add_parser("check", help="Check the API against existing ABI files") check_parser.set_defaults(func=check)