diff --git a/cppwg/generators.py b/cppwg/generators.py index 70d369e..488a562 100644 --- a/cppwg/generators.py +++ b/cppwg/generators.py @@ -1,200 +1,348 @@ import os -import logging +import re import fnmatch -import ntpath +import logging +import subprocess + +from pathlib import Path +from typing import Dict, List, Optional + +from pygccxml import __version__ as pygccxml_version +from pygccxml.declarations.namespace import namespace_t +from cppwg.input.class_info import CppClassInfo +from cppwg.input.free_function_info import CppFreeFunctionInfo from cppwg.input.info_helper import CppInfoHelper from cppwg.input.package_info import PackageInfo -from cppwg.input.free_function_info import CppFreeFunctionInfo -from cppwg.input.class_info import CppClassInfo + from cppwg.parsers.package_info import PackageInfoParser from cppwg.parsers.source_parser import CppSourceParser + from cppwg.writers.header_collection_writer import CppHeaderCollectionWriter from cppwg.writers.module_writer import CppModuleWrapperWriter -import cppwg.templates.pybind11_default as wrapper_templates +from cppwg.templates import pybind11_default as wrapper_templates +from cppwg.utils.constants import CPPWG_EXT, CPPWG_HEADER_COLLECTION_FILENAME -class CppWrapperGenerator(object): - def __init__(self, source_root, - source_includes=None, - wrapper_root=None, - castxml_binary='castxml', - package_info_path='package_info.yaml'): +class CppWrapperGenerator: + """ + Main class for generating C++ wrappers + Attributes + ---------- + source_root : str + The root directory of the C++ source code + source_includes : List[str] + The list of source include paths + wrapper_root : str + The output directory for the wrapper code + castxml_binary : str + The path to the CastXML binary + castxml_cflags : str + Optional cflags to be passed to CastXML e.g. "-std=c++17" + package_info_path : str + The path to the package info yaml config file; defaults to "package_info.yaml" + source_hpp_files : List[str] + The list of C++ source header files + source_ns : namespace_t + The namespace containing C++ declarations parsed from the source tree + package_info : PackageInfo + A data structure containing the information parsed from package_info_path + """ + + def __init__( + self, + source_root: str, + source_includes: Optional[List[str]] = None, + wrapper_root: Optional[str] = None, + castxml_binary: Optional[str] = "castxml", + package_info_path: Optional[str] = None, + castxml_cflags: Optional[str] = "-std=c++17", + ): + logging.basicConfig( + format="%(levelname)s %(message)s", + handlers=[logging.FileHandler("filename.log"), logging.StreamHandler()], + ) logger = logging.getLogger() logger.setLevel(logging.INFO) - self.source_root = os.path.realpath(source_root) - self.source_includes = source_includes - self.wrapper_root = wrapper_root - self.castxml_binary = castxml_binary - self.package_info_path = package_info_path - self.source_hpp_files = [] - self.global_ns = None - self.source_ns = None + # Sanitize source_root + self.source_root: str = os.path.abspath(source_root) + if not os.path.isdir(self.source_root): + logger.error(f"Could not find source root directory: {source_root}") + raise FileNotFoundError() + + # Sanitize wrapper_root + self.wrapper_root: str # type hinting + if wrapper_root: + # Create the specified wrapper root directory if it doesn't exist + self.wrapper_root = os.path.abspath(wrapper_root) - if self.wrapper_root is None: + if not os.path.isdir(self.wrapper_root): + logger.info( + f"Could not find wrapper root directory - creating it at {self.wrapper_root}" + ) + os.makedirs(self.wrapper_root) + else: self.wrapper_root = self.source_root + logger.info( + "Wrapper root not specified - using source_root: {self.source_root}" + ) + + # Sanitize source_includes + self.source_includes: List[str] # type hinting + if source_includes: + self.source_includes = [ + os.path.abspath(include_path) for include_path in source_includes + ] - if self.source_includes is None: + for include_path in self.source_includes: + if not os.path.isdir(include_path): + logger.error( + f"Could not find source include directory: {include_path}" + ) + raise FileNotFoundError() + else: self.source_includes = [self.source_root] - # If we suspect that a valid info file has not been supplied - # fall back to the default behaviour - path_is_default = (self.package_info_path == 'package_info.yaml') - file_exists = os.path.exists(self.package_info_path) - if path_is_default and (not file_exists): - logger.info('YAML package info file not found. Using default info.') - self.package_info_path = None + # Sanitize package_info_path + self.package_info_path: Optional[str] = None + if package_info_path: + # If a package info config file is specified, check that it exists + self.package_info_path = package_info_path + if not os.path.isfile(package_info_path): + logger.error(f"Could not find package info file: {package_info_path}") + raise FileNotFoundError() + else: + # If no package info config file has been supplied, check the default + default_package_info_file = os.path.abspath("./package_info.yaml") + if os.path.isfile(default_package_info_file): + self.package_info_path = default_package_info_file + logger.info( + f"Package info file not specified - using {default_package_info_file}" + ) + else: + logger.warning("No package info file found - using default settings.") + + # Check castxml and pygccxml versions + self.castxml_binary: str = castxml_binary + castxml_version: str = ( + subprocess.check_output([self.castxml_binary, "--version"]) + .decode("ascii") + .strip() + ) + castxml_version = re.search( + r"castxml version \d+\.\d+\.\d+", castxml_version + ).group(0) + logger.info(castxml_version) + logger.info(f"pygccxml version {pygccxml_version}") + + self.castxml_cflags: str = castxml_cflags - def collect_source_hpp_files(self): + # Initialize remaining attributes + self.source_hpp_files: List[str] = [] + self.source_ns: Optional[namespace_t] = None + + self.package_info: Optional[PackageInfo] = None + + self.header_collection_filepath: str = os.path.join( + self.wrapper_root, CPPWG_HEADER_COLLECTION_FILENAME + ) + + def collect_source_hpp_files(self) -> None: """ - Walk through the source root and add any files matching the provided patterns. - Keep the wrapper root out of the search path to avoid pollution. + Walk through the source root and add any files matching the provided + patterns e.g. "*.hpp". Skip the wrapper root and wrappers to + avoid pollution. """ + for root, _, filenames in os.walk(self.source_root, followlinks=True): for pattern in self.package_info.source_hpp_patterns: for filename in fnmatch.filter(filenames, pattern): - if "cppwg" not in filename: - self.package_info.source_hpp_files.append(os.path.join(root, filename)) - self.package_info.source_hpp_files = [path for path in self.package_info.source_hpp_files - if self.wrapper_root not in path] + filepath = os.path.abspath(os.path.join(root, filename)) - def generate_header_collection(self): + # Skip files in wrapper root dir + if Path(self.wrapper_root) in Path(filepath).parents: + continue + # Skip files with the extensions like .cppwg.hpp + suffix = os.path.splitext(os.path.splitext(filename)[0])[1] + if suffix == CPPWG_EXT: + continue + + self.package_info.source_hpp_files.append(filepath) + + def extract_templates_from_source(self) -> None: """ - Write the header collection to file + Extract template arguments for each class from the associated source file """ - - header_collection_writer = CppHeaderCollectionWriter(self.package_info, - self.wrapper_root) - header_collection_writer.write() - header_collection_path = self.wrapper_root + "/" - header_collection_path += header_collection_writer.header_file_name - return header_collection_path + for module_info in self.package_info.module_info_collection: + info_helper = CppInfoHelper(module_info) + for class_info in module_info.class_info_collection: + info_helper.extract_templates_from_source(class_info) + + def map_classes_to_hpp_files(self) -> None: + """ + Attempt to map source file paths to each class, assuming the containing + file name is the class name + """ + for module_info in self.package_info.module_info_collection: + for class_info in module_info.class_info_collection: + for hpp_file_path in self.package_info.source_hpp_files: + hpp_file_name = os.path.basename(hpp_file_path) + if class_info.name == os.path.splitext(hpp_file_name)[0]: + class_info.source_file_full_path = hpp_file_path + if class_info.source_file is None: + class_info.source_file = hpp_file_name + + def parse_header_collection(self) -> None: + """ + Parse the headers with pygccxml and CastXML to populate the source + namespace with C++ declarations collected from the source tree + """ - def parse_header_collection(self, header_collection_path): + source_parser = CppSourceParser( + self.source_root, + self.header_collection_filepath, + self.castxml_binary, + self.source_includes, + self.castxml_cflags, + ) + self.source_ns = source_parser.parse() + def parse_package_info(self): """ - Parse the header collection with pygccxml and Castxml - to population the global and source namespaces + Parse the package info file to create a PackageInfo object """ - - source_parser = CppSourceParser(self.source_root, - header_collection_path, - self.castxml_binary, - self.source_includes) - source_parser.parse() - self.global_ns = source_parser.global_ns - self.source_ns = source_parser.source_ns - def get_wrapper_template(self): - + if self.package_info_path: + # If a package info file exists, parse it to create a PackageInfo object + info_parser = PackageInfoParser(self.package_info_path, self.source_root) + self.package_info = info_parser.parse() + + else: + # If no package info file exists, create a PackageInfo object with default settings + self.package_info = PackageInfo("cppwg_package", self.source_root) + + def update_class_info(self) -> None: """ - Return the string templates for the wrappers + Update the class info with class declarations parsed by pygccxml from + the C++ source code. """ - - return wrapper_templates.template_collection - def update_free_function_info(self): - + for module_info in self.package_info.module_info_collection: + if module_info.use_all_classes: + # Create class info objects for all class declarations found + # from parsing the source code with pygccxml. + # Note: as module_info.use_all_classes == True, no class info + # objects were created while parsing the package info yaml file. + class_decls = self.source_ns.classes(allow_empty=True) + for class_decl in class_decls: + if module_info.is_decl_in_source_path(class_decl): + class_info = CppClassInfo(class_decl.name) + class_info.module_info = module_info + class_info.decl = class_decl + module_info.class_info_collection.append(class_info) + + else: + # As module_info.use_all_classes == False, class info objects + # have already been created while parsing the package info file. + # We only need to add the decl from pygccxml's output. + for class_info in module_info.class_info_collection: + class_decls = self.source_ns.classes( + class_info.name, allow_empty=True + ) + if len(class_decls) == 1: + class_info.decl = class_decls[0] + + def update_free_function_info(self) -> None: """ - Update the free function info pased on pygccxml output + Update the free function info with declarations parsed by pygccxml from + the C++ source code. """ - - for eachModule in self.package_info.module_info: - if eachModule.use_all_free_functions: + + for module_info in self.package_info.module_info_collection: + if module_info.use_all_free_functions: + # Create free function info objects for all free function + # declarations found from parsing the source code with pygccxml. + # Note: as module_info.use_all_free_functions == True, no class info + # objects were created while parsing the package info yaml file. free_functions = self.source_ns.free_functions(allow_empty=True) - for eachFunction in free_functions: - if eachModule.is_decl_in_source_path(eachFunction): - function_info = CppFreeFunctionInfo(eachFunction.name) - function_info.module_info = eachModule - function_info.decl = eachFunction - eachModule.free_function_info.append(function_info) + for free_function in free_functions: + if module_info.is_decl_in_source_path(free_function): + function_info = CppFreeFunctionInfo(free_function.name) + function_info.module_info = module_info + function_info.decl = free_function + module_info.free_function_info_collection.append(function_info) else: - for eachFunction in eachModule.free_function_info: - functions = self.source_ns.free_functions(eachFunction.name, - allow_empty=True) - if len(functions) == 1: - eachFunction.decl = functions[0] - - def update_class_info(self): - - """ - Update the class info pased on pygccxml output - """ - - for eachModule in self.package_info.module_info: - if eachModule.use_all_classes: - classes = self.source_ns.classes(allow_empty=True) - for eachClass in classes: - if eachModule.is_decl_in_source_path(eachClass): - class_info = CppClassInfo(eachClass.name) - class_info.module_info = eachModule - class_info.decl = eachClass - eachModule.class_info.append(class_info) - else: - for eachClass in eachModule.class_info: - classes = self.source_ns.classes(eachClass.name, - allow_empty=True) - if len(classes) == 1: - eachClass.decl = classes[0] + # As module_info.use_all_free_functions == False, free function + # info objects have already been created while parsing the + # package info file. We only need to add the decl from pygccxml's output. + for free_function_info in module_info.free_function_info_collection: + free_functions = self.source_ns.free_functions( + free_function_info.name, allow_empty=True + ) + if len(free_functions) == 1: + free_function_info.decl = free_functions[0] - def generate_wrapper(self): - + def write_header_collection(self) -> None: """ - Main method for wrapper generation + Write the header collection to file """ - # If there is an input file, parse it - if self.package_info_path is not None: - info_parser = PackageInfoParser(self.package_info_path, - self.source_root) - info_parser.parse() - self.package_info = info_parser.package_info - else: - self.package_info = PackageInfo("cppwg_package", self.source_root) + header_collection_writer = CppHeaderCollectionWriter( + self.package_info, + self.wrapper_root, + self.header_collection_filepath, + ) + header_collection_writer.write() + + def write_wrappers(self) -> None: + """ + Write all the wrappers required for the package + """ + for module_info in self.package_info.module_info_collection: + module_writer = CppModuleWrapperWriter( + self.source_ns, + module_info, + wrapper_templates.template_collection, + self.wrapper_root, + ) + module_writer.write() + + def generate_wrapper(self) -> None: + """ + Main method for generating all the wrappers + """ - # Generate a header collection + # Parse the input yaml for package, module, and class information + self.parse_package_info() + + # Search for header files in the source root self.collect_source_hpp_files() - # Attempt to assign source paths to each class, assuming the containing - # file name is the class name - for eachModule in self.package_info.module_info: - for eachClass in eachModule.class_info: - for eachPath in self.package_info.source_hpp_files: - base = ntpath.basename(eachPath) - if eachClass.name == base.split('.')[0]: - eachClass.source_file_full_path = eachPath - if eachClass.source_file is None: - eachClass.source_file = base - - # Attempt to automatically generate template args for each class - for eachModule in self.package_info.module_info: - info_genenerator = CppInfoHelper(eachModule) - for eachClass in eachModule.class_info: - info_genenerator.expand_templates(eachClass, "class") - - # Generate the header collection - header_collection_path = self.generate_header_collection() - - # Parse the header collection - self.parse_header_collection(header_collection_path) - - # Update the Class and Free Function Info from the parsed code + # Map each class to a header file + self.map_classes_to_hpp_files() + + # Attempt to extract templates for each class from the source files + self.extract_templates_from_source() + + # Write the header collection to file + self.write_header_collection() + + # Parse the headers with pygccxml and CastXML + self.parse_header_collection() + + # Update the Class Info from the parsed code self.update_class_info() + + # Update the Free Function Info from the parsed code self.update_free_function_info() - # Write the modules - for eachModule in self.package_info.module_info: - module_writer = CppModuleWrapperWriter(self.global_ns, - self.source_ns, - eachModule, - self.get_wrapper_template(), - self.wrapper_root) - module_writer.write() + # Write all the wrappers required + self.write_wrappers() diff --git a/cppwg/input/base_info.py b/cppwg/input/base_info.py index 09becfc..24cfa2a 100644 --- a/cppwg/input/base_info.py +++ b/cppwg/input/base_info.py @@ -1,89 +1,139 @@ -""" -Generic information structure for packages, modules and cpp types -""" +from typing import Any, Dict, List, Optional -class BaseInfo(object): - +class BaseInfo: """ - :param: name - the feature name, as it appears in its definition - :param: source_includes - a list of source files to be included with the feature - :param: calldef_excludes - do not include calldefs matching these patterns - :param: smart_ptr_type - handle classes with this smart pointer type - :param: template_substitutions - a list of template substitution sequences - :param: pointer_call_policy - the default pointer call policy - :param: reference_call_policy - the default reference call policy - :param: extra_code - any extra wrapper code for the feature - :param: prefix_code - any wrapper code that precedes the feature - :param: excluded_methods - do not include these methods - :param: excluded_variables - do not include these variables - :param: constructor_arg_type_excludes - list of exlude patterns for ctors - :param: return_type_exludes - list of exlude patterns for return types - :param: arg_type_excludes - list of exlude patterns for arg types + Generic information structure for features (i.e packages, modules, classes, + free functions, etc.) + + Attributes + ---------- + name : str + The feature name, as it appears in its definition. + source_includes : List[str] + A list of source files to be included with the feature. + calldef_excludes : List[str] + Do not include calldefs matching these patterns. + smart_ptr_type : str, optional + Handle classes with this smart pointer type. + template_substitutions : Dict[str, List[Any]] + A list of template substitution sequences. + pointer_call_policy : str, optional + The default pointer call policy. + reference_call_policy : str, optional + The default reference call policy. + extra_code : List[str] + Any extra wrapper code for the feature. + prefix_code : List[str] + Any wrapper code that precedes the feature. + custom_generator : str, optional + A custom generator for the feature. + excluded_methods : List[str] + Do not include these methods. + excluded_variables : List[str] + Do not include these variables. + constructor_arg_type_excludes : List[str] + List of exclude patterns for ctors. + return_type_excludes : List[str] + List of exclude patterns for return types. + arg_type_excludes : List[str] + List of exclude patterns for arg types. + name_replacements : Dict[str, str] + A dictionary of name replacements e.g. {"double":"Double", "unsigned int":"Unsigned"} """ def __init__(self, name): - - self.name = name - self.source_includes = [] - self.calldef_excludes = [] - self.smart_ptr_type = None - self.template_substitutions = [] - self.pointer_call_policy = None - self.reference_call_policy = None - self.extra_code = [] - self.prefix_code = [] - self.custom_generator = None - self.excluded_methods = None - self.excluded_variables = None - self.constructor_arg_type_excludes = [] - self.return_type_excludes = [] - self.arg_type_excludes = [] - self.name_replacements = {"double": "Double", - "unsigned int": "Unsigned", - "Unsigned int": "Unsigned", - "unsigned": "Unsigned", - "double": "Double", - "std::vector": "Vector", - "std::pair": "Pair", - "std::map": "Map", - "std::string": "String", - "boost::shared_ptr": "SharedPtr", - "*": "Ptr", - "c_vector": "CVector", - "std::set": "Set"} - + self.name: str = name + self.source_includes: List[str] = [] + self.calldef_excludes: List[str] = [] + self.smart_ptr_type: Optional[str] = None + self.template_substitutions: Dict[str, List[Any]] = [] + self.pointer_call_policy: Optional[str] = None + self.reference_call_policy: Optional[str] = None + self.extra_code: List[str] = [] + self.prefix_code: List[str] = [] + self.custom_generator: Optional[str] = None + self.excluded_methods: List[str] = [] + self.excluded_variables: List[str] = [] + self.constructor_arg_type_excludes: List[str] = [] + self.return_type_excludes: List[str] = [] + self.arg_type_excludes: List[str] = [] + self.name_replacements: Dict[str, str] = { + "double": "Double", + "unsigned int": "Unsigned", + "Unsigned int": "Unsigned", + "unsigned": "Unsigned", + "double": "Double", + "std::vector": "Vector", + "std::pair": "Pair", + "std::map": "Map", + "std::string": "String", + "boost::shared_ptr": "SharedPtr", + "*": "Ptr", + "c_vector": "CVector", + "std::set": "Set", + } + @property - def parent(self): + def parent(self) -> Optional["BaseInfo"]: + """ + Return the parent object of the feature in the hierarchy. This is + overriden in subclasses e.g. ModuleInfo returns a PackageInfo, ClassInfo + returns a ModuleInfo, etc. + + Returns + ------- + Optional[BaseInfo] + The parent object. + """ return None - - def hierarchy_attribute(self, attribute_name): - + + def hierarchy_attribute(self, attribute_name: str) -> Any: """ - For the supplied attribute iterate through parents until a non None - value is found. If the tope level parent attribute is None, return None. + For the supplied attribute, iterate through parent objects until a non-None + value is found. If the top level parent (i.e. package) attribute is + None, return None. + + Parameters + ---------- + attribute_name : str + The attribute name to search for. + + Returns + ------- + Any + The attribute value. """ - + if hasattr(self, attribute_name) and getattr(self, attribute_name) is not None: return getattr(self, attribute_name) - else: - if hasattr(self, 'parent') and self.parent is not None: - return self.parent.hierarchy_attribute(attribute_name) - else: - return None - - def hierarchy_attribute_gather(self, attribute_name): - + + if hasattr(self, "parent") and self.parent is not None: + return self.parent.hierarchy_attribute(attribute_name) + + return None + + def hierarchy_attribute_gather(self, attribute_name: str) -> List[Any]: """ - For the supplied attribute iterate through parents gathering list entries. + For the supplied attribute, iterate through parent objects gathering list entries. + + Parameters + ---------- + attribute_name : str + The attribute name to search for. + + Returns + ------- + List[Any] + The list of attribute values. """ - - att_list = [] + + att_list: List[Any] = [] + if hasattr(self, attribute_name) and getattr(self, attribute_name) is not None: att_list.extend(getattr(self, attribute_name)) - if hasattr(self, 'parent') and self.parent is not None: - att_list.extend(self.parent.hierarchy_attribute_gather(attribute_name)) - else: - if hasattr(self, 'parent') and self.parent is not None: - att_list.extend(self.parent.hierarchy_attribute_gather(attribute_name)) - return att_list \ No newline at end of file + + if hasattr(self, "parent") and self.parent is not None: + att_list.extend(self.parent.hierarchy_attribute_gather(attribute_name)) + + return att_list diff --git a/cppwg/input/class_info.py b/cppwg/input/class_info.py index 96c399c..865d36e 100644 --- a/cppwg/input/class_info.py +++ b/cppwg/input/class_info.py @@ -1,17 +1,20 @@ -""" -Information structure common to C++ classes -""" +from typing import Any, Dict, Optional -from cppwg.input import cpp_type_info +from cppwg.input.cpp_type_info import CppTypeInfo -class CppClassInfo(cpp_type_info.CppTypeInfo): +class CppClassInfo(CppTypeInfo): + """ + This class holds information for individual C++ classes to be wrapped + """ + def __init__(self, name: str, class_config: Optional[Dict[str, Any]] = None): + + super(CppClassInfo, self).__init__(name, class_config) - def __init__(self, name, type_info_dict = None): - - super(CppClassInfo, self).__init__(name, type_info_dict) - @property - def parent(self): + def parent(self) -> "ModuleInfo": + """ + Returns the parent module info object + """ return self.module_info diff --git a/cppwg/input/cpp_type_info.py b/cppwg/input/cpp_type_info.py index cea3eb6..2832b8a 100644 --- a/cppwg/input/cpp_type_info.py +++ b/cppwg/input/cpp_type_info.py @@ -1,136 +1,171 @@ -""" -Information structure common to C++ variables, functions, methods and classes -""" +from typing import Any, Dict, List, Optional -from cppwg.input import base_info +from cppwg.input.base_info import BaseInfo +from pygccxml.declarations import declaration_t -class CppTypeInfo(base_info.BaseInfo): +class CppTypeInfo(BaseInfo): """ - :param: module_info - info of the owning module - :param: decl - the pygccxml version of the declaration - :param: source_file - over-ridden feature source file - :param: name_override - feature name override + This class holds information for C++ types including classes, free functions etc. + + Attributes + ---------- + module_info : ModuleInfo + The module info parent object associated with this type + source_file : str + The source file containing the type + source_file_full_path : str + The full path to the source file containing the type + name_override : str + The name override specified in config e.g. "CustomFoo" -> "Foo" + template_arg_lists : List[List[Any]] + List of template replacement arguments for the type e.g. [[2, 2], [3, 3]] + decl : declaration_t + The pygccxml declaration associated with this type """ - def __init__(self, name, type_info_dict = None): - + def __init__(self, name: str, type_config: Optional[Dict[str, Any]] = None): + super(CppTypeInfo, self).__init__(name) - - self.module_info = None - self.source_file_full_path = None - self.source_file = None - self.name_override = None - self.template_args = None - self.decl = None - if type_info_dict is not None: - for key in type_info_dict: - setattr(self, key, type_info_dict[key]) - - - def get_short_names(self): + self.module_info: Optional["ModuleInfo"] = None + self.source_file_full_path: Optional[str] = None + self.source_file: Optional[str] = None + self.name_override: Optional[str] = None + self.template_arg_lists: Optional[list[List[Any]]] = None + self.decl: Optional[declaration_t] = None + + if type_config: + for key, value in type_config.items(): + setattr(self, key, value) + + # TODO: Consider setting short and full names on init as read-only properties + def get_short_names(self) -> List[str]: """ Return the name of the class as it will appear on the Python side. This - collapses template arguements, separating them by underscores and removes - special characters. The return type is a list, as a class can have multiple - names if it is templated. + collapses template arguments, separating them by underscores and removes + special characters. The return type is a list, as a class can have + multiple names if it is templated. For example, a class "Foo" with + template arguments [[2, 2], [3, 3]] will have a short name list + ["Foo2_2", "Foo3_3"]. + + Returns + ------- + List[str] + The list of short names """ - if self.template_args is None: + # Handles untemplated classes + if self.template_arg_lists is None: if self.name_override is None: return [self.name] - else: - return [self.name_override] + return [self.name_override] + + short_names = [] + + # Table of special characters for removal + rm_chars = {"<": None, ">": None, ",": None, " ": None} + rm_table = str.maketrans(rm_chars) + + # Clean the type name + type_name = self.name + if self.name_override is not None: + type_name = self.name_override + + # Do standard name replacements e.g. "unsigned int" -> "Unsigned" + for name, replacement in self.name_replacements.items(): + type_name = type_name.replace(name, replacement) + + # Remove special characters + type_name = type_name.translate(rm_table) + + # Capitalize the first letter e.g. "foo" -> "Foo" + if len(type_name) > 1: + type_name = type_name[0].capitalize() + type_name[1:] + + # Create a string of template args separated by "_" e.g. 2_2 + for template_arg_list in self.template_arg_lists: + # Example template_arg_list : [2, 2] - names = [] - for eachTemplateArg in self.template_args: template_string = "" - for idx, eachTemplateEntry in enumerate(eachTemplateArg): - - # Do standard translations - current_name = str(eachTemplateEntry) - for eachReplacementString in self.name_replacements.keys(): - replacement = self.name_replacements[eachReplacementString] - current_name = current_name.replace(eachReplacementString, - replacement) - - table = current_name.maketrans(dict.fromkeys('<>:,')) - cleaned_entry = current_name.translate(table) - cleaned_entry = cleaned_entry.replace(" ", "") - if len(cleaned_entry) > 1: - first_letter = cleaned_entry[0].capitalize() - cleaned_entry = first_letter + cleaned_entry[1:] - template_string += str(cleaned_entry) - if(idx != len(eachTemplateArg)-1): - template_string += "_" + for idx, arg in enumerate(template_arg_list): + + # Do standard name replacements + arg_str = str(arg) + for name, replacement in self.name_replacements.items(): + arg_str = arg_str.replace(name, replacement) - current_name = self.name - if self.name_override is not None: - current_name = self.name_override + # Remove special characters + arg_str = arg_str.translate(rm_table) - # Do standard translations - for eachReplacementString in self.name_replacements.keys(): - replacement = self.name_replacements[eachReplacementString] - current_name = current_name.replace(eachReplacementString, - replacement) + # Capitalize the first letter + if len(arg_str) > 1: + arg_str = arg_str[0].capitalize() + arg_str[1:] - # Strip templates and scopes - table = current_name.maketrans(dict.fromkeys('<>:,')) - cleaned_name = current_name.translate(table) - cleaned_name = cleaned_name.replace(" ", "") - if len(cleaned_name) > 1: - cleaned_name = cleaned_name[0].capitalize()+cleaned_name[1:] - names.append(cleaned_name+template_string) - return names + # Add "_" between template arguments + template_string += arg_str + if idx < len(template_arg_list) - 1: + template_string += "_" - def get_full_names(self): + short_names.append(type_name + template_string) + return short_names + + def get_full_names(self) -> List[str]: """ - Return the name (declaration) of the class as it appears on the C++ side. - The return type is a list, as a class can have multiple - names (declarations) if it is templated. + Return the name (declaration) of the class as it appears on the C++ + side. The return type is a list, as a class can have multiple names + (declarations) if it is templated. For example, a class "Foo" with + template arguments [[2, 2], [3, 3]] will have a full name list + ["Foo<2,2 >", "Foo<3,3 >"]. + + Returns + ------- + List[str] + The list of full names """ - if self.template_args is None: + # Handles untemplated classes + if self.template_arg_lists is None: return [self.name] - names = [] - for eachTemplateArg in self.template_args: - template_string = "<" - for idx, eachTemplateEntry in enumerate(eachTemplateArg): - template_string += str(eachTemplateEntry) - if(idx == len(eachTemplateArg)-1): - template_string += " >" - else: - template_string += "," - names.append(self.name + template_string) - return names + full_names = [] + for template_arg_list in self.template_arg_lists: + # Create template string from arg list e.g. [2, 2] -> "<2,2 >" + template_string = ",".join([str(arg) for arg in template_arg_list]) + template_string = "<" + template_string + " >" - def needs_header_file_instantiation(self): + # Join full name e.g. "Foo<2,2 >" + full_names.append(self.name + template_string) + return full_names + + # TODO: This method is not used, remove it? + def needs_header_file_instantiation(self): """ Does this class need to be instantiated in the header file """ - return ((self.template_args is not None) and - (not self.include_file_only) and - (self.needs_instantiation)) + return ( + (self.template_arg_lists is not None) + and (not self.include_file_only) + and (self.needs_instantiation) + ) + # TODO: This method is not used, remove it? def needs_header_file_typdef(self): - """ Does this type need to be typdef'd with a nicer name in the header file. All template classes need this. """ - return (self.template_args is not None) and (not self.include_file_only) + return (self.template_arg_lists is not None) and (not self.include_file_only) + # TODO: This method is not used, remove it? def needs_auto_wrapper_generation(self): - """ Does this class need a wrapper to be autogenerated. """ - return not self.include_file_only \ No newline at end of file + return not self.include_file_only diff --git a/cppwg/input/free_function_info.py b/cppwg/input/free_function_info.py index 1eaacb3..fe008bd 100644 --- a/cppwg/input/free_function_info.py +++ b/cppwg/input/free_function_info.py @@ -1,20 +1,22 @@ -""" -Information for free functions -""" +from typing import Any, Dict, Optional -from cppwg.input import cpp_type_info +from cppwg.input.cpp_type_info import CppTypeInfo -class CppFreeFunctionInfo(cpp_type_info.CppTypeInfo): - +class CppFreeFunctionInfo(CppTypeInfo): """ - A container for free function types to be wrapped + This class holds information for individual free functions to be wrapped """ - def __init__(self, name, type_info_dict = None): - - super(CppFreeFunctionInfo, self).__init__(name, type_info_dict) + def __init__( + self, name: str, free_function_config: Optional[Dict[str, Any]] = None + ): + + super(CppFreeFunctionInfo, self).__init__(name, free_function_config) @property - def parent(self): - return self.module_info \ No newline at end of file + def parent(self) -> "ModuleInfo": + """ + Returns the parent module info object + """ + return self.module_info diff --git a/cppwg/input/info_helper.py b/cppwg/input/info_helper.py index 333fd6a..0acb7d3 100644 --- a/cppwg/input/info_helper.py +++ b/cppwg/input/info_helper.py @@ -1,70 +1,135 @@ -""" -Helper class for generating extra feature information based -on simple analysis of the source tree. -""" - import os +import re +import logging + +from typing import Any +from cppwg.input.base_info import BaseInfo +from cppwg.input.class_info import CppClassInfo +from cppwg.input.module_info import ModuleInfo -class CppInfoHelper(object): +class CppInfoHelper: """ - This attempts to automatically fill in some class info based on - simple analysis of the source tree. + Helper class that attempts to automatically fill in extra feature + information based on simple analysis of the source tree. + + Attributes + __________ + module_info : ModuleInfo + The module info object that this helper is working with. + class_dict : dict + A dictionary of class info objects keyed by class name. """ - def __init__(self, module_info): + def __init__(self, module_info: ModuleInfo): - self.module_info = module_info - - self.class_dict = {} - self.setup_class_dict() - - def setup_class_dict(self): + self.module_info: ModuleInfo = module_info - # For convenience collect class info in a dict keyed by name - for eachClassInfo in self.module_info.class_info: - self.class_dict[eachClassInfo.name] = eachClassInfo + # For convenience, collect class info in a dict keyed by name + self.class_dict: Dict[str, CppClassInfo] = { + class_info.name: class_info + for class_info in module_info.class_info_collection + } - def expand_templates(self, feature_info, feature_type): + def extract_templates_from_source(self, feature_info: BaseInfo) -> None: + """ + Extract template arguments for a feature from the associated source + file. - template_substitutions = feature_info.hierarchy_attribute_gather('template_substitutions') - - if len(template_substitutions) == 0: + Parameters + __________ + feature_info : BaseInfo + The feature info object to expand. + """ + logger = logging.getLogger() + + if isinstance(feature_info, CppClassInfo): + feature_type = "class" + else: + logger.error(f"Unsupported feature type: {type(feature_info)}") + raise TypeError() + + # Skip if there are pre-defined template args + if feature_info.template_arg_lists: return - # Skip any features with pre-defined template args - no_template = feature_info.template_args is None + # Skip if there is no source file source_path = feature_info.source_file_full_path - if not (no_template and source_path is not None): + if not source_path: return - if not os.path.exists(source_path): + + if not os.path.isfile(source_path): + logger.error(f"Could not find source file: {source_path}") + raise FileNotFoundError() + + # Get list of template substitutions from this feature and its parents + # e.g. {"signature":"","replacement":[[2,2], [3,3]]} + template_substitutions: List[Dict[str, Any]] = ( + feature_info.hierarchy_attribute_gather("template_substitutions") + ) + + # Skip if there are no template substitutions + if len(template_substitutions) == 0: return - f = open(source_path) - lines = (line.rstrip() for line in f) # Remove blank lines - - lines = list(line for line in lines if line) - for idx, eachLine in enumerate(lines): - stripped_line = eachLine.replace(" ", "") - if idx+1 < len(lines): - stripped_next = lines[idx+1].replace(" ", "") - else: - continue - - for idx, eachSub in enumerate(template_substitutions): - template_args = eachSub['replacement'] - template_string = eachSub['signature'] - cleaned_string = template_string.replace(" ", "") - if cleaned_string in stripped_line: - feature_string = feature_type + feature_info.name - feature_decl_next = feature_string + ":" in stripped_next - feature_decl_whole = feature_string == stripped_next - if feature_decl_next or feature_decl_whole: - feature_info.template_args = template_args - break - f.close() + # Remove spaces from template substitution signatures + # e.g. -> + for tpl_sub in template_substitutions: + tpl_sub["signature"] = tpl_sub["signature"].replace(" ", "") + + # Remove whitespaces, blank lines, and directives from the source file + whitespace_regex = re.compile(r"\s+") + with open(source_path, "r") as in_file: + lines = [re.sub(whitespace_regex, "", line) for line in in_file] + lines = [line for line in lines if line and not line.startswith("#")] - def do_custom_template_substitution(self, feature_info): + # Search for template signatures in the source file lines + for idx in range(len(lines) - 1): + curr_line = lines[idx] + next_line = lines[idx + 1] - pass \ No newline at end of file + for template_substitution in template_substitutions: + # e.g. template + signature: str = "template" + template_substitution["signature"] + + # e.g. [[2,2], [3,3]] + replacement: List[List[Any]] = template_substitution["replacement"] + + if signature in curr_line: + feature_string = feature_type + feature_info.name # e.g. "classFoo" + + declaration_found = False + + if feature_string == next_line: + # template + # classFoo + declaration_found = True + + elif next_line.startswith(feature_string + "{"): + # template + # classFoo{ + declaration_found = True + + elif next_line.startswith(feature_string + ":"): + # template + # classFoo:publicBar + declaration_found = True + + elif curr_line == signature + feature_string: + # templateclassFoo + declaration_found = True + + elif curr_line.startswith(signature + feature_string + "{"): + # templateclassFoo{ + declaration_found = True + + elif curr_line.startswith(signature + feature_string + ":"): + # templateclassFoo:publicBar + declaration_found = True + + # TODO: Add support for more cases, or find a better way e.g. regex or castxml? + + if declaration_found: + feature_info.template_arg_lists = replacement + break diff --git a/cppwg/input/method_info.py b/cppwg/input/method_info.py index 1bd3def..18b7529 100644 --- a/cppwg/input/method_info.py +++ b/cppwg/input/method_info.py @@ -1,25 +1,27 @@ -""" -Information for methods -""" +from typing import Optional -from cppwg.input import cpp_type_info +from cppwg.input.cpp_type_info import CppTypeInfo -class CppMethodInfo(cpp_type_info.CppTypeInfo): - +class CppMethodInfo(CppTypeInfo): """ - A container for method types to be wrapped + This class holds information for individual methods to be wrapped + + Attributes + ---------- + class_info : CppClassInfo + The class info parent object associated with this method """ - def __init__(self, name, _): - + def __init__(self, name: str, _): + super(CppMethodInfo, self).__init__(name) - - self.class_info = None - + + self.class_info: Optional["CppClassInfo"] = None + @property - def parent(self): + def parent(self) -> "CppClassInfo": + """ + Returns the parent class info object + """ return self.class_info - - - diff --git a/cppwg/input/module_info.py b/cppwg/input/module_info.py index 3cefa08..75a2b97 100644 --- a/cppwg/input/module_info.py +++ b/cppwg/input/module_info.py @@ -1,46 +1,81 @@ -""" -Information for individual modules -""" +import os -from cppwg.input import base_info +from typing import Any, Dict, List, Optional -class ModuleInfo(base_info.BaseInfo): +from cppwg.input.base_info import BaseInfo +from pygccxml.declarations import declaration_t + + +class ModuleInfo(BaseInfo): """ - Information for individual modules + This class holds information for individual modules + + Attributes + ---------- + package_info : PackageInfo + The package info parent object associated with this module + source_locations : List[str] + A list of source locations for this module + class_info_collection : List[CppClassInfo] + A list of class info objects associated with this module + free_function_info_collection : List[CppFreeFunctionInfo] + A list of free function info objects associated with this module + variable_info_collection : List[CppFreeFunctionInfo] + A list of variable info objects associated with this module + use_all_classes : bool + Use all classes in the module + use_all_free_functions : bool + Use all free functions in the module + use_all_variables : bool + Use all variables in the module """ - def __init__(self, name, type_info_dict = None): - + def __init__(self, name: str, module_config: Optional[Dict[str, Any]] = None): + super(ModuleInfo, self).__init__(name) - self.package_info = None - self.source_locations = None - self.class_info = [] - self.free_function_info = [] - self.variable_info = [] - self.use_all_classes = False - self.use_all_free_functions = False - - if type_info_dict is not None: - for key in type_info_dict: - setattr(self, key, type_info_dict[key]) - + self.package_info: Optional["PackageInfo"] = None + self.source_locations: List[str] = None + self.class_info_collection: List["CppClassInfo"] = [] + self.free_function_info_collection: List["CppFreeFunctionInfo"] = [] + self.variable_info_collection: List["CppFreeFunctionInfo"] = [] + self.use_all_classes: bool = False + self.use_all_free_functions: bool = False + self.use_all_variables: bool = False + + if module_config: + for key, value in module_config.items(): + setattr(self, key, value) + @property - def parent(self): + def parent(self) -> "PackageInfo": + """ + Returns the parent package info object + """ return self.package_info - def is_decl_in_source_path(self, decl): - + def is_decl_in_source_path(self, decl: declaration_t) -> bool: """ - Return is the declaration associated with a file in the current source path + Check if the declaration is associated with a file in a specified source path + + Parameters + ---------- + decl : declaration_t + The declaration to check + + Returns + ------- + bool + True if the declaration is associated with a file in a specified source path """ if self.source_locations is None: return True - for eachSourceLocation in self.source_locations: - location = self.package_info.source_root + "/" + eachSourceLocation + "/" - if location in decl.location.file_name: + for source_location in self.source_locations: + full_path = os.path.join(self.package_info.source_root, source_location) + if full_path in decl.location.file_name: return True + return False diff --git a/cppwg/input/package_info.py b/cppwg/input/package_info.py index c418504..2c030cc 100644 --- a/cppwg/input/package_info.py +++ b/cppwg/input/package_info.py @@ -1,47 +1,62 @@ -""" -Information for the package -""" +from typing import Any, Dict, List, Optional -from cppwg.input import base_info +from cppwg.input.base_info import BaseInfo -class PackageInfo(base_info.BaseInfo): - +class PackageInfo(BaseInfo): """ - Information for individual modules + This class holds the package information + + Attributes + ---------- + name : str + The name of the package + source_locations : List[str] + A list of source locations for this package + module_info_collection : List[ModuleInfo] + A list of module info objects associated with this package + source_root : str + The root directory of the C++ source code + source_hpp_patterns : List[str] + A list of source file patterns to include + source_hpp_files : List[str] + A list of source file names to include + common_include_file : bool + Use a common include file for all source files """ - def __init__(self, name, source_root, type_info_dict = None): - + def __init__( + self, + name: str, + source_root: str, + package_config: Optional[Dict[str, Any]] = None, + ): + """ + Parameters: + ----------- + name : str + source_root : str + package_config : Dict[str, Any] + A dictionary of package configuration settings + """ + super(PackageInfo, self).__init__(name) - self.name = name - self.source_locations = None - self.module_info = [] - self.source_root = source_root - self.source_hpp_patterns = ["*.hpp"] - self.source_hpp_files = [] - self.common_include_file = False - - if type_info_dict is not None: - for key in type_info_dict: - setattr(self, key, type_info_dict[key]) - - @property - def parent(self): - return None - - def is_decl_in_source_path(self, decl): + self.name: str = name + self.source_locations: List[str] = None + self.module_info_collection: List["ModuleInfo"] = [] + self.source_root: str = source_root + self.source_hpp_patterns: List[str] = ["*.hpp"] + self.source_hpp_files: List[str] = [] + self.common_include_file: bool = False + if package_config: + for key, value in package_config.items(): + setattr(self, key, value) + + @property + def parent(self) -> None: """ - Return is the declaration associated with a file in the current source path + Returns None as this is the top level object in the hierarchy """ - - if self.source_locations is None: - return True - - for eachSourceLocation in self.source_locations: - location = self.package_info.source_root + "/" + eachSourceLocation + "/" - if location in decl.location.file_name: - return True - return False + return None diff --git a/cppwg/input/variable_info.py b/cppwg/input/variable_info.py index 77f2c5a..97b35a7 100644 --- a/cppwg/input/variable_info.py +++ b/cppwg/input/variable_info.py @@ -1,17 +1,13 @@ -""" -Information for variables -""" +from typing import Any, Optional -from cppwg.input import cpp_type_info +from cppwg.input.cpp_type_info import CppTypeInfo -class CppVariableInfo(cpp_type_info.CppTypeInfo): - +class CppVariableInfo(CppTypeInfo): """ - A container for variable types to be wrapped + This class holds information for individual variables to be wrapped """ - def __init__(self, name, type_info_dict = None): - - super(CppVariableInfo, self).__init__(name, type_info_dict) + def __init__(self, name: str, variable_config: Optional[Dict[str, Any]] = None): + super(CppVariableInfo, self).__init__(name, variable_config) diff --git a/cppwg/parsers/package_info.py b/cppwg/parsers/package_info.py index b75204b..252e2a9 100644 --- a/cppwg/parsers/package_info.py +++ b/cppwg/parsers/package_info.py @@ -1,146 +1,267 @@ import os -import imp -import ntpath +import importlib.util +import logging +import sys import yaml -from cppwg.input.package_info import PackageInfo -from cppwg.input.module_info import ModuleInfo +from typing import Any, Optional + +import cppwg.templates.custom + +from cppwg.input.base_info import BaseInfo from cppwg.input.class_info import CppClassInfo from cppwg.input.free_function_info import CppFreeFunctionInfo +from cppwg.input.module_info import ModuleInfo +from cppwg.input.package_info import PackageInfo + +from cppwg.utils import utils +from cppwg.utils.constants import CPPWG_SOURCEROOT_STRING + + +class PackageInfoParser: + """ + Parse for the package info yaml file + + Attributes + ---------- + input_filepath : str + The path to the package info yaml file + source_root : str + The root directory of the C++ source code + raw_package_info : Dict[str, Any] + Raw info from the yaml file + package_info : Optional[PackageInfo] + The parsed package info + """ + + def __init__(self, input_filepath: str, source_root: str): + """ + Parameters + ---------- + input_filepath : str + The path to the package info yaml file. + source_root : str + The root directory of the C++ source code. + """ + + self.input_filepath: str = input_filepath + self.source_root: str = source_root + + # For holding raw info from the yaml file + self.raw_package_info: Dict[str, Any] = {} + + # The parsed package info + self.package_info: Optional[PackageInfo] = None + + def check_for_custom_generators(self, info: BaseInfo) -> None: + """ + Check if a custom generator is specified and load it into a module. + + Parameters + ---------- + info : BaseInfo + The info object to check for a custom generator - might be info + about a package, module, class, or free function. + """ + logger = logging.getLogger() + + if not info.custom_generator: + return + + # Replace the `CPPWG_SOURCEROOT` placeholder in the custom generator + # string if needed. For example, a custom generator might be specified + # as `custom_generator: CPPWG_SOURCEROOT/path/to/CustomGenerator.py` + filepath: str = info.custom_generator.replace( + CPPWG_SOURCEROOT_STRING, self.source_root + ) + filepath = os.path.abspath(filepath) + + # Verify that the custom generator file exists + if not os.path.isfile(filepath): + logger.error( + f"Could not find specified custom generator for {info.name}: {filepath}" + ) + raise FileNotFoundError() + + logger.info(f"Custom generator for {info.name}: {filepath}") + + # Load the custom generator as a module + module_name: str = os.path.splitext(filepath)[0] # /path/to/CustomGenerator + class_name: str = os.path.basename(module_name) # CustomGenerator + + spec = importlib.util.spec_from_file_location(module_name, filepath) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + + # Get the custom generator class from the loaded module. + # Note: The custom generator class name must match the filename. + CustomGeneratorClass: cppwg.templates.custom.Custom = getattr( + module, class_name + ) + + # Replace the `info.custom_generator` string with a new object created + # from the provided custom generator class + info.custom_generator = CustomGeneratorClass() + + def parse(self) -> PackageInfo: + """ + Parse the package info yaml file to extract information about the + package, modules, classes, and free functions. + + Returns + ------- + PackageInfo + The object holding data from the parsed package info yaml file. + """ + logger = logging.getLogger() + logger.info("Parsing package info file.") + + with open(self.input_filepath, "r") as input_filepath: + self.raw_package_info = yaml.safe_load(input_filepath) + # Default config options that apply to the package, modules, classes, and free functions + global_config: Dict[str, Any] = { + "source_includes": [], + "smart_ptr_type": None, + "calldef_excludes": None, + "return_type_excludes": None, + "template_substitutions": [], + "pointer_call_policy": None, + "reference_call_policy": None, + "constructor_arg_type_excludes": None, + "excluded_methods": [], + "excluded_variables": [], + "custom_generator": None, + "prefix_code": [], + } -class PackageInfoParser(object): - - def __init__(self, input_file, source_root): - - self.input_file = input_file - self.raw_info = {} - self.package_info = None - self.source_root = source_root - - def subsititute_bool_string(self, option, input_dict, on_string="ON", off_string="OFF"): - - is_string = isinstance(input_dict[option], str) - if is_string and input_dict[option].strip().upper() == off_string: - input_dict[option] = False - elif is_string and input_dict[option].strip().upper() == on_string: - input_dict[option] = True - - def is_option_ALL(self, option, input_dict, check_string = "CPPWG_ALL"): - - is_string = isinstance(input_dict[option], str) - return is_string and input_dict[option].upper() == check_string - - def check_for_custom_generators(self, feature_info): - - # Replace source root if needed - if feature_info.custom_generator is not None: - path = feature_info.custom_generator.replace("CPPWG_SOURCEROOT", self.source_root) - path = os.path.realpath(path) - print (feature_info.name, path) - if os.path.isfile(path): - module_name = ntpath.basename(path).split(".")[0] - custom_module = imp.load_source(os.path.splitext(path)[0], path) - feature_info.custom_generator = getattr(custom_module, module_name)() - - def parse(self): - - with open(self.input_file, 'r') as inpfile: - data = inpfile.read() - - self.raw_info = yaml.safe_load(data) - - global_defaults = {'source_includes': [], - 'smart_ptr_type': None, - 'calldef_excludes': None, - 'return_type_excludes': None, - 'template_substitutions': [], - 'pointer_call_policy': None, - 'reference_call_policy': None, - 'constructor_arg_type_excludes': None, - 'excluded_methods': [], - 'excluded_variables': [], - 'custom_generator' : None, - 'prefix_code': []} - - # Parse package data - package_defaults = {'name': 'cppwg_package', - 'common_include_file': True, - 'source_hpp_patterns': ["*.hpp"]} - package_defaults.update(global_defaults) - for eachEntry in package_defaults.keys(): - if eachEntry in self.raw_info: - package_defaults[eachEntry] = self.raw_info[eachEntry] - self.subsititute_bool_string('common_include_file', package_defaults) - - self.package_info = PackageInfo(package_defaults['name'], self.source_root, package_defaults) + # Get package config from the raw package info + package_config: Dict[str, Any] = { + "name": "cppwg_package", + "common_include_file": True, + "source_hpp_patterns": ["*.hpp"], + } + package_config.update(global_config) + + for key in package_config.keys(): + if key in self.raw_package_info: + package_config[key] = self.raw_package_info[key] + utils.substitute_bool_for_string(package_config, "common_include_file") + + # Create the PackageInfo object from the package config dict + self.package_info = PackageInfo( + package_config["name"], self.source_root, package_config + ) self.check_for_custom_generators(self.package_info) - # Parse module data - for eachModule in self.raw_info['modules']: - module_defaults = {'name':'cppwg_module', - 'source_locations': None, - 'classes': [], - 'free_functions': [], - 'variables': [], - 'use_all_classes': False, - 'use_all_free_functions': False} - module_defaults.update(global_defaults) - - for eachEntry in module_defaults.keys(): - if eachEntry in eachModule: - module_defaults[eachEntry] = eachModule[eachEntry] - - # Do classes - class_info_collection = [] - module_defaults['use_all_classes'] = self.is_option_ALL('classes', module_defaults) - if not module_defaults['use_all_classes']: - if module_defaults['classes'] is not None: - for eachClass in module_defaults['classes']: - class_defaults = { 'name_override': None, - 'source_file': None} - class_defaults.update(global_defaults) - - for eachEntry in class_defaults.keys(): - if eachEntry in eachClass: - class_defaults[eachEntry] = eachClass[eachEntry] - class_info = CppClassInfo(eachClass['name'], class_defaults) - self.check_for_custom_generators(class_info) - class_info_collection.append(class_info) - - # Do functions - function_info_collection = [] - module_defaults['use_all_free_functions'] = self.is_option_ALL('free_functions', - module_defaults) - if not module_defaults['use_all_free_functions']: - if module_defaults['free_functions'] is not None: - for _ in module_defaults['free_functions']: - ff_defaults = { 'name_override': None, - 'source_file': None} - ff_defaults.update(global_defaults) - function_info = CppFreeFunctionInfo(ff_defaults['name'], ff_defaults) - function_info_collection.append(function_info) - - variable_collection = [] - use_all_variables = self.is_option_ALL('variables', module_defaults) - if not use_all_variables: - for _ in module_defaults['variables']: - variable_defaults = { 'name_override': None, - 'source_file': None} - variable_defaults.update(global_defaults) - variable_info = CppFreeFunctionInfo(variable_defaults['name'], variable_defaults) - variable_collection.append(variable_info) - - module_info = ModuleInfo(module_defaults['name'], module_defaults) - module_info.class_info = class_info_collection - module_info.free_function_info = function_info_collection - module_info.variable_info = variable_collection - for class_info in module_info.class_info: - class_info.module_info = module_info - for free_function_info in module_info.free_function_info: - free_function_info.module_info = module_info - for variable_info in module_info.variable_info: - variable_info.module_info = module_info - self.package_info.module_info.append(module_info) - module_info.package_info = self.package_info + # Parse the module data + for raw_module_info in self.raw_package_info["modules"]: + # Get module config from the raw module info + module_config = { + "name": "cppwg_module", + "source_locations": None, + "classes": [], + "free_functions": [], + "variables": [], + "use_all_classes": False, + "use_all_free_functions": False, + } + module_config.update(global_config) + + for key in module_config.keys(): + if key in raw_module_info: + module_config[key] = raw_module_info[key] + + module_config["use_all_classes"] = utils.is_option_ALL( + module_config["classes"] + ) + + module_config["use_all_free_functions"] = utils.is_option_ALL( + module_config["free_functions"] + ) + + module_config["use_all_variables"] = utils.is_option_ALL( + module_config["variables"] + ) + + # Create the ModuleInfo object from the module config dict + module_info = ModuleInfo(module_config["name"], module_config) self.check_for_custom_generators(module_info) + + # Connect the module to the package + module_info.package_info = self.package_info + self.package_info.module_info_collection.append(module_info) + + # Parse the class data and create class info objects. + # Note: if module_config["use_all_classes"] == True, class info + # objects will be added later after parsing the C++ source code. + if not module_config["use_all_classes"]: + if module_config["classes"]: + for raw_class_info in module_config["classes"]: + # Get class config from the raw class info + class_config = {"name_override": None, "source_file": None} + class_config.update(global_config) + + for key in class_config.keys(): + if key in raw_class_info: + class_config[key] = raw_class_info[key] + + # Create the CppClassInfo object from the class config dict + class_info = CppClassInfo(raw_class_info["name"], class_config) + self.check_for_custom_generators(class_info) + + # Connect the class to the module + class_info.module_info = module_info + module_info.class_info_collection.append(class_info) + + + # Parse the free function data and create free function info objects. + # Note: if module_config["use_all_free_functions"] == True, free function + # info objects will be added later after parsing the C++ source code. + if not module_config["use_all_free_functions"]: + if module_config["free_functions"]: + for raw_free_function_info in module_config["free_functions"]: + # Get free function config from the raw free function info + free_function_config = { + "name_override": None, + "source_file": None, + } + free_function_config.update(global_config) + + for key in free_function_config.keys(): + if key in raw_free_function_info: + free_function_config[key] = raw_free_function_info[key] + + # Create the CppFreeFunctionInfo object from the free function config dict + free_function_info = CppFreeFunctionInfo( + free_function_config["name"], free_function_config + ) + + # Connect the free function to the module + free_function_info.module_info = module_info + module_info.free_function_info_collection.append( + free_function_info + ) + + # Parse the variable data + if not module_config["use_all_variables"]: + for raw_variable_info in module_config["variables"]: + # Get variable config from the raw variable info + variable_config = {"name_override": None, "source_file": None} + variable_config.update(global_config) + + for key in variable_config.keys(): + if key in raw_variable_info: + variable_config[key] = raw_variable_info[key] + + # Create the CppFreeFunctionInfo object from the variable config dict + variable_info = CppFreeFunctionInfo( + variable_config["name"], variable_config + ) + + # Connect the variable to the module + variable_info.module_info = module_info + module_info.variable_info_collection.append(variable_info) + + return self.package_info diff --git a/cppwg/parsers/source_parser.py b/cppwg/parsers/source_parser.py index 29066b5..4f263b9 100644 --- a/cppwg/parsers/source_parser.py +++ b/cppwg/parsers/source_parser.py @@ -1,49 +1,116 @@ -#!/usr/bin/env python +import logging -""" -Parse the single header file using CastXML and pygccxml -""" +from pathlib import Path +from typing import List, Optional from pygccxml import parser, declarations +from pygccxml.declarations import declaration_t -class CppSourceParser(): +# declaration_t is the base type for all declarations in pygccxml including: +# - class_declaration_t (pygccxml.declarations.class_declaration.class_declaration_t) +# - class_t (pygccxml.declarations.class_declaration.class_t) +# - constructor_t (pygccxml.declarations.calldef_members.constructor_t) +# - destructor_t (pygccxml.declarations.calldef_members.destructor_t) +# - free_function_t (pygccxml.declarations.free_calldef.free_function_t) +# - free_operator_t (pygccxml.declarations.free_calldef.free_operator_t) +# - member_function_t (pygccxml.declarations.calldef_members.member_function_t) +# - member_operator_t (pygccxml.declarations.calldef_members.member_operator_t) +# - typedef_t (pygccxml.declarations.typedef.typedef_t) +# - variable_t (pygccxml.declarations.variable.variable_t) - def __init__(self, source_root, wrapper_header_collection, - castxml_binary, source_includes): +from pygccxml.declarations.mdecl_wrapper import mdecl_wrapper_t +from pygccxml.declarations.namespace import namespace_t - self.source_root = source_root - self.wrapper_header_collection = wrapper_header_collection - self.castxml_binary = castxml_binary - self.source_includes = source_includes - self.global_ns = None - self.source_ns = None - def parse(self): +class CppSourceParser: + """ + Parser for C++ source code - xml_generator_config = parser.xml_generator_configuration_t(xml_generator_path=self.castxml_binary, - xml_generator="castxml", - cflags="-std=c++11", - include_paths=self.source_includes) + Attributes + ---------- + source_root : str + The root directory of the source code + wrapper_header_collection : str + The path to the header collection file + castxml_binary : str + The path to the CastXML binary + source_includes : List[str] + The list of source include paths + castxml_cflags : str + Optional cflags to be passed to CastXML e.g. "-std=c++17" + global_ns : namespace_t + The namespace containing all parsed C++ declarations + source_ns : namespace_t + The namespace containing C++ declarations from the source tree + """ - print ("INFO: Parsing Code") - decls = parser.parse([self.wrapper_header_collection], xml_generator_config, - compilation_mode=parser.COMPILATION_MODE.ALL_AT_ONCE) + def __init__( + self, + source_root: str, + wrapper_header_collection: str, + castxml_binary: str, + source_includes: List[str], + castxml_cflags: str = "", + ): + self.source_root: str = source_root + self.wrapper_header_collection: str = wrapper_header_collection + self.castxml_binary: str = castxml_binary + self.source_includes: List[str] = source_includes + self.castxml_cflags: str = castxml_cflags + + self.source_ns: Optional[namespace_t] = None + self.global_ns: Optional[namespace_t] = None + + def parse(self) -> namespace_t: + """ + Parses C++ source code from the header collection using CastXML and pygccxml. + + Returns + ------- + namespace_t + The namespace containing C++ declarations from the source tree + """ + logger = logging.getLogger() + + # Configure the XML generator (CastXML) + xml_generator_config = parser.xml_generator_configuration_t( + xml_generator_path=self.castxml_binary, + xml_generator="castxml", + cflags=self.castxml_cflags, + include_paths=self.source_includes, + ) + + # Parse all the C++ source code to extract declarations + logger.info("Parsing source code for declarations.") + decls: List[declaration_t] = parser.parse( + files=[self.wrapper_header_collection], + config=xml_generator_config, + compilation_mode=parser.COMPILATION_MODE.ALL_AT_ONCE, + ) # Get access to the global namespace - self.global_ns = declarations.get_global_namespace(decls) + self.global_ns: namespace_t = declarations.get_global_namespace(decls) - # Clean decls to only include those for which file locations exist - print ("INFO: Cleaning Decls") + # Filter declarations for which files exist + logger.info("Filtering source declarations.") query = declarations.custom_matcher_t(lambda decl: decl.location is not None) - decls_loc_not_none = self.global_ns.decls(function=query) + filtered_decls: mdecl_wrapper_t = self.global_ns.decls(function=query) + + # Filter declarations in our source tree; include declarations from the + # wrapper_header_collection file for explicit instantiations, typedefs etc. + source_decls: List[declaration_t] = [ + decl + for decl in filtered_decls + if Path(self.source_root) in Path(decl.location.file_name).parents + or decl.location.file_name == self.wrapper_header_collection + ] + + # Create a source namespace module for the filtered declarations + self.source_ns = namespace_t(name="source", declarations=source_decls) - # Identify decls in our source tree - def check_loc(loc): - return (self.source_root in loc) or ("wrapper_header_collection" in loc) - - source_decls = [decl for decl in decls_loc_not_none if check_loc(decl.location.file_name)] - self.source_ns = declarations.namespace_t("source", source_decls) + # Initialise the source namespace's internal type hash tables for faster queries + logger.info("Optimizing source declaration queries.") + self.source_ns.init_optimizer() - print ("INFO: Optimizing Decls") - self.source_ns.init_optimizer() \ No newline at end of file + return self.source_ns diff --git a/cppwg/templates/custom.py b/cppwg/templates/custom.py index 0e5a372..e350ad0 100644 --- a/cppwg/templates/custom.py +++ b/cppwg/templates/custom.py @@ -4,7 +4,7 @@ custom code generators. """ -class Custom(object): +class Custom: def __init__(self): diff --git a/cppwg/templates/pybind11_default.py b/cppwg/templates/pybind11_default.py index d59ba1d..56f6ffa 100644 --- a/cppwg/templates/pybind11_default.py +++ b/cppwg/templates/pybind11_default.py @@ -6,7 +6,8 @@ #include "{class_short_name}.cppwg.hpp" namespace py = pybind11; -typedef {class_full_name} {class_short_name};{smart_ptr_handle} +typedef {class_full_name} {class_short_name}; +{smart_ptr_handle}; """ class_cpp_header_chaste = """\ @@ -19,7 +20,8 @@ namespace py = pybind11; //PYBIND11_CVECTOR_TYPECASTER2(); //PYBIND11_CVECTOR_TYPECASTER3(); -typedef {class_full_name} {class_short_name};{smart_ptr_handle} +typedef {class_full_name} {class_short_name}; +{smart_ptr_handle}; """ class_hpp_header = """\ @@ -46,7 +48,7 @@ class {class_short_name}_Overloads : public {class_short_name}{{ method_virtual_override = """\ {return_type} {method_name}({arg_string}){const_adorn} override {{ - PYBIND11_OVERLOAD{overload_adorn}( + PYBIND11_OVERRIDE{overload_adorn}( {tidy_method_name}, {short_class_name}, {method_name}, diff --git a/cppwg/utils/__init__.py b/cppwg/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cppwg/utils/constants.py b/cppwg/utils/constants.py new file mode 100644 index 0000000..a81f002 --- /dev/null +++ b/cppwg/utils/constants.py @@ -0,0 +1,12 @@ +""" +Constants for the cppwg package +""" + +CPPWG_SOURCEROOT_STRING = "CPPWG_SOURCEROOT" +CPPWG_ALL_STRING = "CPPWG_ALL" + +CPPWG_EXT = "cppwg" +CPPWG_HEADER_COLLECTION_FILENAME = "wrapper_header_collection.hpp" + +CPPWG_TRUE_STRINGS = ["ON", "YES", "Y", "TRUE", "T"] +CPPWG_FALSE_STRINGS = ["OFF", "NO", "N", "FALSE", "F"] diff --git a/cppwg/utils/utils.py b/cppwg/utils/utils.py new file mode 100644 index 0000000..b568112 --- /dev/null +++ b/cppwg/utils/utils.py @@ -0,0 +1,52 @@ +""" +Utility functions for the cppwg package +""" + +from typing import Any, Dict + +from cppwg.utils.constants import CPPWG_ALL_STRING +from cppwg.utils.constants import CPPWG_TRUE_STRINGS, CPPWG_FALSE_STRINGS + + +def is_option_ALL(input_obj: Any, option_ALL_string: str = CPPWG_ALL_STRING) -> bool: + """ + Check if the input is a string that matches the "ALL" indicator e.g. "CPPWG_ALL" + + Parameters + ---------- + input_obj : Any + The object to check + option_ALL_string : str + The string to check against + + Returns + ------- + bool + True if the input is a string that matches the "ALL" indicator + """ + return isinstance(input_obj, str) and input_obj.upper() == option_ALL_string + + +def substitute_bool_for_string(input_dict: Dict[Any, Any], key: Any) -> None: + """ + Substitute a string in the input dictionary with a boolean value if the + string is a boolean indicator e.g. "ON", "OFF", "YES", "NO", "TRUE", "FALSE" + + Parameters + ---------- + input_dict : Dict[Any, Any] + The input dictionary + key : Any + The key to check + """ + + if not isinstance(input_dict[key], str): + return + + caps_string = input_dict[key].strip().upper() + + if caps_string in CPPWG_TRUE_STRINGS: + input_dict[key] = True + + elif caps_string in CPPWG_FALSE_STRINGS: + input_dict[key] = False diff --git a/cppwg/writers/base_writer.py b/cppwg/writers/base_writer.py index da8bb11..aba0190 100644 --- a/cppwg/writers/base_writer.py +++ b/cppwg/writers/base_writer.py @@ -1,36 +1,83 @@ -import collections +from collections import OrderedDict +from typing import Dict, List +from pygccxml.declarations import free_function_t -class CppBaseWrapperWriter(object): +class CppBaseWrapperWriter: """ Base class for wrapper writers + + Attributes + ---------- + wrapper_templates : Dict[str, str] + String templates with placeholders for generating wrapper code + tidy_replacements : OrderedDict[str, str] + A dictionary of replacements to use when tidying up C++ declarations """ - def __init__(self, wrapper_templates): + def __init__(self, wrapper_templates: Dict[str, str]): self.wrapper_templates = wrapper_templates - self.tidy_replacements = collections.OrderedDict([(", ", "_"), ("<", "_lt_"), - (">", "_gt_"), ("::", "_"), - ("*", "Ptr"), ("&", "Ref"), - ("-", "neg")]) + self.tidy_replacements = OrderedDict( + [ + (" ", ""), + (",", "_"), + ("<", "_lt_"), + (">", "_gt_"), + ("::", "_"), + ("*", "Ptr"), + ("&", "Ref"), + ("-", "neg"), + ] + ) - def tidy_name(self, name): - + def tidy_name(self, name: str) -> str: """ - This method replaces full c++ declarations with a simple version for use + This method replaces full C++ declarations with a simple version for use in typedefs + + Example: + "::foo::bar" -> "_foo_bar_lt_double_2_gt_" + + Parameters + ---------- + name : str + The C++ declaration to tidy up + + Returns + ------- + str + The tidied up C++ declaration """ for key, value in self.tidy_replacements.items(): name = name.replace(key, value) - return name.replace(" ", "") - def exclusion_critera(self, decl, exclusion_args): - + return name + + # TODO: Consider moving this implementation of exclusion_criteria to the + # free function writer it is only used there. exclusion_criteria is + # currently overriden in method writer and constructor writer. + def exclusion_criteria( + self, decl: free_function_t, exclusion_args: List[str] + ) -> bool: + """ + Checks if any of the types in the function declaration appear in the + exclusion args. + + Parameters + ---------- + decl : free_function_t + The declaration of the function or class + exclusion_args : List[str] + A list of arguments to exclude from the wrapper code + + Returns + ------- + bool + True if the function should be excluded from the wrapper code """ - Fails if any of the types in the declaration appear in the exclusion args - """ # Are any return types not wrappable return_type = decl.return_type.decl_string.replace(" ", "") @@ -38,12 +85,22 @@ def exclusion_critera(self, decl, exclusion_args): return True # Are any arguments not wrappable - for eachArg in decl.argument_types: - arg_type = eachArg.decl_string.split()[0].replace(" ", "") + for decl_arg_type in decl.argument_types: + arg_type = decl_arg_type.decl_string.split()[0].replace(" ", "") if arg_type in exclusion_args: return True + return False - def default_arg_exclusion_criteria(self): + # TODO: This method is currently a placeholder. Consider implementing or removing. + def default_arg_exclusion_criteria(self) -> bool: + """ + Check if default arguments should be excluded from the wrapper code + + Returns + ------- + bool + True if the default arguments should be excluded + """ return False diff --git a/cppwg/writers/class_writer.py b/cppwg/writers/class_writer.py index 1077171..e42570f 100644 --- a/cppwg/writers/class_writer.py +++ b/cppwg/writers/class_writer.py @@ -1,157 +1,255 @@ -import ntpath +import os +import logging + +from typing import Dict, List from pygccxml import declarations +from pygccxml.declarations.calldef_members import member_function_t +from pygccxml.declarations.class_declaration import class_t + +from cppwg.input.class_info import CppClassInfo -from cppwg.writers import base_writer -from cppwg.writers import method_writer -from cppwg.writers import constructor_writer +from cppwg.writers.base_writer import CppBaseWrapperWriter +from cppwg.writers.method_writer import CppMethodWrapperWriter +from cppwg.writers.constructor_writer import CppConstructorWrapperWriter +from cppwg.utils.constants import CPPWG_EXT, CPPWG_HEADER_COLLECTION_FILENAME -class CppClassWrapperWriter(base_writer.CppBaseWrapperWriter): +class CppClassWrapperWriter(CppBaseWrapperWriter): """ - This class generates wrapper code for Cpp classes + This class generates wrapper code for C++ classes + + Attributes + ---------- + class_info : CppClassInfo + The class information + wrapper_templates : Dict[str, str] + String templates with placeholders for generating wrapper code + exposed_class_full_names : List[str] + A list of full names for all classes in the module + class_full_names : List[str] + A list of full names for this class e.g. ["Foo<2,2>", "Foo<3,3>"] + class_short_names : List[str] + A list of short names for this class e.g. ["Foo2_2", "Foo3_3"] + class_decls : List[class_t] + A list of class declarations associated with the class + has_shared_ptr : bool + Whether the class uses shared pointers + is_abstract : bool + Whether the class is abstract + hpp_string : str + The hpp wrapper code + cpp_string : str + The cpp wrapper code """ - def __init__(self, class_info, wrapper_templates): - + def __init__( + self, + class_info: CppClassInfo, + wrapper_templates: Dict[str, str], + exposed_class_full_names: List[str], + ): + logger = logging.getLogger() + super(CppClassWrapperWriter, self).__init__(wrapper_templates) - self.hpp_string = "" - self.cpp_string = "" - self.class_info = class_info - self.class_decls = [] - self.exposed_class_full_names = [] - self.class_full_names = self.class_info.get_full_names() - self.class_short_names = self.class_info.get_short_names() - self.has_shared_ptr = True - self.is_abstract = False - - if(len(self.class_full_names) != len(self.class_short_names)): - message = 'Full and short name lists should be the same length' - raise ValueError(message) - - def write_files(self, work_dir, class_short_name): + self.class_info: CppClassInfo = class_info - """ - Write the hpp and cpp wrapper codes to file - """ + # Class full names eg. ["Foo<2,2>", "Foo<3,3>"] + self.class_full_names: List[str] = self.class_info.get_full_names() + + # Class short names eg. ["Foo2_2", "Foo3_3"] + self.class_short_names: List[str] = self.class_info.get_short_names() + + if len(self.class_full_names) != len(self.class_short_names): + logger.error("Full and short name lists should be the same length") + raise AssertionError() - path = work_dir + "/" + class_short_name - hpp_file = open(path + ".cppwg.hpp", "w") - hpp_file.write(self.hpp_string) - hpp_file.close() + self.exposed_class_full_names: List[str] = exposed_class_full_names - cpp_file = open(path + ".cppwg.cpp", "w") - cpp_file.write(self.cpp_string) - cpp_file.close() + self.class_decls: List[class_t] = [] + self.has_shared_ptr: bool = True + self.is_abstract: bool = False # TODO: Consider removing unused attribute - def add_hpp(self, class_short_name): - + self.hpp_string: str = "" + self.cpp_string: str = "" + + def add_hpp(self, class_short_name: str) -> None: """ - Add the class wrapper hpp file + Fill the class hpp string for a single class using the wrapper template + + Parameters + ---------- + class_short_name: str + The short name of the class e.g. Foo2_2 """ - wrapper_dict = {'class_short_name': class_short_name} - self.hpp_string += self.wrapper_templates['class_hpp_header'].format(**wrapper_dict) + class_hpp_dict = {"class_short_name": class_short_name} - def add_cpp_header(self, class_full_name, class_short_name): - + self.hpp_string += self.wrapper_templates["class_hpp_header"].format( + **class_hpp_dict + ) + + def add_cpp_header(self, class_full_name: str, class_short_name: str) -> None: """ - Add the 'top' of the class wrapper cpp file + Add the 'top' of the class wrapper cpp file for a single class + + Parameters + ---------- + class_full_name : str + The full name of the class e.g. Foo<2,2> + class_short_name : str + The short name of the class e.g. Foo2_2 """ - header = "wrapper_header_collection" - - # Check for custom smart pointers - smart_ptr_handle = "" - smart_pointer_handle = self.class_info.hierarchy_attribute('smart_ptr_type') - if smart_pointer_handle != None: - smart_ptr_template = self.wrapper_templates["smart_pointer_holder"] - smart_ptr_handle = "\n" + smart_ptr_template.format(smart_pointer_handle) + ";" - - header_dict = {'wrapper_header_collection': header, - 'class_short_name': class_short_name, - 'class_full_name': class_full_name, - 'smart_ptr_handle': smart_ptr_handle, - 'includes': '#include "' + header +'.hpp"\n'} - extra_include_string = "" - common_include_file = self.class_info.hierarchy_attribute('common_include_file') - - source_includes = self.class_info.hierarchy_attribute_gather('source_includes') - - if not common_include_file: - for eachInclude in source_includes: - if eachInclude[0] != "<": - extra_include_string += '#include "' + eachInclude +'"\n' + # Add the includes for this class + includes = "" + + if self.class_info.hierarchy_attribute("common_include_file"): + includes += f'#include "{CPPWG_HEADER_COLLECTION_FILENAME}"\n' + + else: + source_includes = self.class_info.hierarchy_attribute_gather( + "source_includes" + ) + + for source_include in source_includes: + if source_include[0] == "<": + # e.g. #include + includes += f"#include {source_include}\n" else: - extra_include_string += '#include ' + eachInclude +'\n' - if self.class_info.source_file is not None: - extra_include_string += '#include "' + self.class_info.source_file +'"\n' - else: - include_name = ntpath.basename(self.class_info.decl.location.file_name) - extra_include_string += '#include "' + include_name +'"\n' - header_dict['includes'] = extra_include_string - - header_string = self.wrapper_templates["class_cpp_header"].format(**header_dict) - self.cpp_string += header_string - - for eachLine in self.class_info.prefix_code: - self.cpp_string += eachLine + "\n" - - # Any custom generators - if self.class_info.custom_generator is not None: - self.cpp_string += self.class_info.custom_generator.get_class_cpp_pre_code(class_short_name) - - def add_virtual_overides(self, class_decl, short_class_name): - + # e.g. #include "Foo.hpp" + includes += f'#include "{source_include}"\n' + + source_file = self.class_info.source_file + if not source_file: + source_file = os.path.basename(self.class_info.decl.location.file_name) + includes += f'#include "{source_file}"\n' + + # Check for custom smart pointers e.g. "boost::shared_ptr" + smart_ptr_type: str = self.class_info.hierarchy_attribute("smart_ptr_type") + + smart_ptr_handle = "" + if smart_ptr_type: + # Adds e.g. "PYBIND11_DECLARE_HOLDER_TYPE(T, boost::shared_ptr)" + smart_ptr_handle = self.wrapper_templates["smart_pointer_holder"].format( + smart_ptr_type + ) + + # Fill in the cpp header template + header_dict = { + "includes": includes, + "class_short_name": class_short_name, + "class_full_name": class_full_name, + "smart_ptr_handle": smart_ptr_handle, + } + + self.cpp_string += self.wrapper_templates["class_cpp_header"].format( + **header_dict + ) + + # Add any specified custom prefix code + for code_line in self.class_info.prefix_code: + self.cpp_string += code_line + "\n" + + # Run any custom generators to add additional prefix code + if self.class_info.custom_generator: + self.cpp_string += self.class_info.custom_generator.get_class_cpp_pre_code( + class_short_name + ) + + def add_virtual_overrides( + self, class_decl: class_t, short_class_name: str + ) -> List[member_function_t]: """ - Virtual over-rides if neeeded + Identify any methods needing overrides (i.e. any that are virtual in the + current class or in a parent), and add the overrides to the cpp string. + + Parameters + ---------- + class_decl : class_t + The class declaration + short_class_name : str + The short name of the class e.g. Foo2_2 + + Returns + ------- + list[member_function_t]: A list of member functions needing override """ - # Identify any methods needing over-rides, i.e. any that are virtual - # here or in a parent. - methods_needing_override = [] - return_types = [] - for eachMemberFunction in class_decl.member_functions(allow_empty=True): - is_pure_virtual = eachMemberFunction.virtuality == "pure virtual" - is_virtual = eachMemberFunction.virtuality == "virtual" + methods_needing_override: List[member_function_t] = [] + return_types: List[str] = [] # e.g. ["void", "unsigned int", "::Bar<2> *"] + + # Collect all virtual methods and their return types + for member_function in class_decl.member_functions(allow_empty=True): + is_pure_virtual = member_function.virtuality == "pure virtual" + is_virtual = member_function.virtuality == "virtual" if is_pure_virtual or is_virtual: - methods_needing_override.append(eachMemberFunction) - return_types.append(eachMemberFunction.return_type.decl_string) + methods_needing_override.append(member_function) + return_types.append(member_function.return_type.decl_string) if is_pure_virtual: self.is_abstract = True - for eachReturnString in return_types: - if eachReturnString != self.tidy_name(eachReturnString): - typdef_string = "typedef {full_name} {tidy_name};\n" - typdef_dict = {'full_name': eachReturnString, - 'tidy_name': self.tidy_name(eachReturnString)} - self.cpp_string += typdef_string.format(**typdef_dict) + # Add typedefs for return types with special characters + # e.g. typedef ::Bar<2> * _Bar_lt_2_gt_Ptr; + for return_type in return_types: + if return_type != self.tidy_name(return_type): + typedef_template = "typedef {full_name} {tidy_name};\n" + typedef_dict = { + "full_name": return_type, + "tidy_name": self.tidy_name(return_type), + } + self.cpp_string += typedef_template.format(**typedef_dict) self.cpp_string += "\n" - needs_override = len(methods_needing_override) > 0 - if needs_override: - over_ride_dict = {'class_short_name': short_class_name, - 'class_base_name': self.class_info.name} - override_template = self.wrapper_templates['class_virtual_override_header'] - self.cpp_string += override_template.format(**over_ride_dict) - - for eachMethod in methods_needing_override: - writer = method_writer.CppMethodWrapperWriter(self.class_info, - eachMethod, - class_decl, - self.wrapper_templates, - short_class_name) - self.cpp_string = writer.add_override(self.cpp_string) + # Override virtual methods + if methods_needing_override: + # Add virtual override class, e.g.: + # class Foo_Overloads : public Foo{{ + # public: + # using Foo::Foo; + override_header_dict = { + "class_short_name": short_class_name, + "class_base_name": self.class_info.name, + } + + self.cpp_string += self.wrapper_templates[ + "class_virtual_override_header" + ].format(**override_header_dict) + + # Override each method, e.g.: + # void bar(int a, bool b) override {{ + # PYBIND11_OVERRIDE(void, Foo, bar, a, b); + for method in methods_needing_override: + method_writer = CppMethodWrapperWriter( + self.class_info, + method, + class_decl, + self.wrapper_templates, + short_class_name, + ) + # TODO: Consider returning the override string instead + self.cpp_string = method_writer.add_override(self.cpp_string) + self.cpp_string += "\n};\n" + return methods_needing_override - def write(self, work_dir): + def write(self, work_dir: str) -> None: + """ + Write the hpp and cpp wrapper codes to file + + Parameters + ---------- + work_dir : str + The directory to write the files to + """ + logger = logging.getLogger() - if(len(self.class_decls) != len(self.class_full_names)): - message = 'Not enough class decls added to do write.' - raise ValueError(message) + if len(self.class_decls) != len(self.class_full_names): + logger.error("Not enough class decls added to do write.") + raise AssertionError() for idx, full_name in enumerate(self.class_full_names): short_name = self.class_short_names[idx] @@ -161,103 +259,148 @@ def write(self, work_dir): # Add the cpp file header self.add_cpp_header(full_name, short_name) - - # Check for struct-enum pattern + + # Check for struct-enum pattern. For example: + # struct Foo{ + # enum Value{A, B, C}; + # }; + # TODO: Consider moving some parts into templates if declarations.is_struct(class_decl): enums = class_decl.enumerations(allow_empty=True) - if len(enums)==1: - replacements = {'class': class_decl.name, 'enum': enums[0].name} - self.cpp_string += 'void register_{class}_class(py::module &m){{\n'.format(**replacements) - self.cpp_string += ' py::class_<{class}> myclass(m, "{class}");\n'.format(**replacements) - self.cpp_string += ' py::enum_<{class}::{enum}>(myclass, "{enum}")\n'.format(**replacements) - for eachval in enums[0].values: - replacements = {'class': class_decl.name, - 'enum': enums[0].name, - 'val': eachval[0]} - self.cpp_string += ' .value("{val}", {class}::{enum}::{val})\n'.format(**replacements) + + if len(enums) == 1: + enum_tpl = "void register_{class}_class(py::module &m){{\n" + enum_tpl += ' py::class_<{class}> myclass(m, "{class}");\n' + enum_tpl += ' py::enum_<{class}::{enum}>(myclass, "{enum}")\n' + + replacements = {"class": class_decl.name, "enum": enums[0].name} + self.cpp_string += enum_tpl.format(**replacements) + + value_tpl = ' .value("{val}", {class}::{enum}::{val})\n' + for value in enums[0].values: + replacements["val"] = value[0] + self.cpp_string += value_tpl.format(**replacements) + self.cpp_string += " .export_values();\n}\n" - + # Set up the hpp self.add_hpp(short_name) - - # Do the write + + # Write the struct cpp and hpp files self.write_files(work_dir, short_name) continue - # Define any virtual function overloads - methods_needing_override = self.add_virtual_overides(class_decl, short_name) + # Find and define virtual function "trampoline" overrides + methods_needing_override: List[member_function_t] = ( + self.add_virtual_overrides(class_decl, short_name) + ) - # Add overrides if needed + # Add the virtual "trampoline" overrides from "Foo_Overloads" to + # the "Foo" wrapper class definition if needed + # e.g. py::class_(m, "Foo") overrides_string = "" - if len(methods_needing_override)>0: - overrides_string = ', ' + short_name + '_Overloads' + if methods_needing_override: + # TODO: Assign the "_Overloads" literal to a constant + overrides_string = f", {short_name}_Overloads" - # Add smart ptr support if needed - smart_pointer_handle = self.class_info.hierarchy_attribute('smart_ptr_type') + # Add smart pointer support to the wrapper class definition if needed + # e.g. py::class_ >(m, "Foo") + smart_ptr_type: str = self.class_info.hierarchy_attribute("smart_ptr_type") ptr_support = "" - if self.has_shared_ptr and smart_pointer_handle is not None: - ptr_support = ', ' + smart_pointer_handle + '<' + short_name + ' > ' + if self.has_shared_ptr and smart_ptr_type: + ptr_support = f", {smart_ptr_type}<{short_name} > " - # Add base classes if needed + # Add base classes to the wrapper class definition if needed + # e.g. py::class_(m, "Foo") bases = "" - for eachBase in class_decl.bases: - cleaned_base = eachBase.related_class.name.replace(" ","") - exposed = any(cleaned_base in t.replace(" ","") for t in self.exposed_class_full_names) - public = not eachBase.access_type == "private" - if exposed and public: - bases += ', ' + eachBase.related_class.name + " " - - # Add the class refistration - class_definition_dict = {'short_name': short_name, - 'overrides_string': overrides_string, - 'ptr_support': ptr_support, - 'bases': bases} + + for base in class_decl.bases: # type(base) -> hierarchy_info_t + # Check that the base class is not private + if base.access_type == "private": + continue + + # Check if the base class is exposed (i.e. to be wrapped in the module) + base_class_name: str = base.related_class.name.replace(" ", "") + if base_class_name in self.exposed_class_full_names: + bases += f", {base.related_class.name} " + + # Add the class registration + class_definition_dict = { + "short_name": short_name, + "overrides_string": overrides_string, + "ptr_support": ptr_support, + "bases": bases, + } class_definition_template = self.wrapper_templates["class_definition"] self.cpp_string += class_definition_template.format(**class_definition_dict) - # Add constructors - #if not self.is_abstract and not class_decl.is_abstract: - # No constructors for classes with private pure virtual methods! - ppv_class = False - for eachMemberFunction in class_decl.member_functions(allow_empty=True): - if eachMemberFunction.virtuality == "pure virtual" and eachMemberFunction.access_type == "private": - ppv_class = True - break - - if not ppv_class: - query = declarations.access_type_matcher_t('public') - for eachConstructor in class_decl.constructors(function=query, - allow_empty=True): - writer = constructor_writer.CppConsturctorWrapperWriter(self.class_info, - eachConstructor, - class_decl, - self.wrapper_templates, - short_name) - self.cpp_string = writer.add_self(self.cpp_string) + # Add public constructors + query = declarations.access_type_matcher_t("public") + for constructor in class_decl.constructors( + function=query, allow_empty=True + ): + constructor_writer = CppConstructorWrapperWriter( + self.class_info, + constructor, + class_decl, + self.wrapper_templates, + short_name, + ) + # TODO: Consider returning the constructor string instead + self.cpp_string = constructor_writer.add_self(self.cpp_string) # Add public member functions - query = declarations.access_type_matcher_t('public') - for eachMemberFunction in class_decl.member_functions(function=query, allow_empty=True): - exlcuded = False - if self.class_info.excluded_methods is not None: - exlcuded = (eachMemberFunction.name in self.class_info.excluded_methods) - if not exlcuded: - writer = method_writer.CppMethodWrapperWriter(self.class_info, - eachMemberFunction, - class_decl, - self.wrapper_templates, - short_name) - self.cpp_string = writer.add_self(self.cpp_string) - - # Any custom generators - if self.class_info.custom_generator is not None: - self.cpp_string += self.class_info.custom_generator.get_class_cpp_def_code(short_name) + query = declarations.access_type_matcher_t("public") + for member_function in class_decl.member_functions( + function=query, allow_empty=True + ): + if self.class_info.excluded_methods: + # Skip excluded methods + if member_function.name in self.class_info.excluded_methods: + continue + + method_writer = CppMethodWrapperWriter( + self.class_info, + member_function, + class_decl, + self.wrapper_templates, + short_name, + ) + # TODO: Consider returning the member string instead + self.cpp_string = method_writer.add_self(self.cpp_string) + + # Run any custom generators to add additional class code + if self.class_info.custom_generator: + self.cpp_string += ( + self.class_info.custom_generator.get_class_cpp_def_code(short_name) + ) # Close the class definition - self.cpp_string += ' ;\n}\n' + self.cpp_string += " ;\n}\n" # Set up the hpp self.add_hpp(short_name) - # Do the write + # Write the class cpp and hpp files self.write_files(work_dir, short_name) + + def write_files(self, work_dir: str, class_short_name: str) -> None: + """ + Write the hpp and cpp wrapper code to file + + Parameters + ---------- + work_dir : str + The directory to write the files to + class_short_name : str + The short name of the class e.g. Foo2_2 + """ + + hpp_filepath = os.path.join(work_dir, f"{class_short_name}.{CPPWG_EXT}.hpp") + cpp_filepath = os.path.join(work_dir, f"{class_short_name}.{CPPWG_EXT}.cpp") + + with open(hpp_filepath, "w") as hpp_file: + hpp_file.write(self.hpp_string) + + with open(cpp_filepath, "w") as cpp_file: + cpp_file.write(self.cpp_string) diff --git a/cppwg/writers/constructor_writer.py b/cppwg/writers/constructor_writer.py index 1b5db81..f981565 100644 --- a/cppwg/writers/constructor_writer.py +++ b/cppwg/writers/constructor_writer.py @@ -1,77 +1,159 @@ +from typing import Dict, Optional + from pygccxml import declarations +from pygccxml.declarations.class_declaration import class_t +from pygccxml.declarations.calldef_members import constructor_t -from cppwg.writers import base_writer +from cppwg.input.class_info import CppClassInfo +from cppwg.writers.base_writer import CppBaseWrapperWriter -class CppConsturctorWrapperWriter(base_writer.CppBaseWrapperWriter): +class CppConstructorWrapperWriter(CppBaseWrapperWriter): """ Manage addition of constructor wrapper code + + Attributes + ---------- + class_info : ClassInfo + The class information for the class containing the constructor + ctor_decl : constructor_t + The pygccxml declaration object for the constructor + class_decl : class_t + The class declaration for the class containing the constructor + wrapper_templates : Dict[str, str] + String templates with placeholders for generating wrapper code + class_short_name : Optional[str] + The short name of the class e.g. 'Foo2_2' """ - def __init__(self, class_info, - ctor_decl, - class_decl, - wrapper_templates, - class_short_name=None): - - super(CppConsturctorWrapperWriter, self).__init__(wrapper_templates) + def __init__( + self, + class_info: CppClassInfo, + ctor_decl: constructor_t, + class_decl: class_t, + wrapper_templates: Dict[str, str], + class_short_name: Optional[str] = None, + ): - self.class_info = class_info - self.ctor_decl = ctor_decl - self.class_decl = class_decl + super(CppConstructorWrapperWriter, self).__init__(wrapper_templates) + + self.class_info: CppClassInfo = class_info + self.ctor_decl: constructor_t = ctor_decl + self.class_decl: class_t = class_decl self.class_short_name = class_short_name if self.class_short_name is None: self.class_short_name = self.class_decl.name - def exclusion_critera(self): - - # Check for exclusions - exclusion_args = self.class_info.hierarchy_attribute_gather('calldef_excludes') - ctor_arg_exludes = self.class_info.hierarchy_attribute_gather('constructor_arg_type_excludes') - - for eachArg in self.ctor_decl.argument_types: - if eachArg.decl_string.replace(" ", "") in exclusion_args: - return True - - for eachExclude in ctor_arg_exludes: - if eachExclude in eachArg.decl_string: - return True + def exclusion_criteria(self) -> bool: + """ + Check if the constructor should be excluded from the wrapper code + + Returns + ------- + bool + True if the constructor should be excluded, False otherwise + """ + + # Exclude constructors for classes with private pure virtual methods + if any( + mf.virtuality == "pure virtual" and mf.access_type == "private" + for mf in self.class_decl.member_functions(allow_empty=True) + ): + return True - for eachArg in self.ctor_decl.argument_types: - if "iterator" in eachArg.decl_string.lower(): + # Exclude constructors for abstract classes inheriting from abstract bases + if self.class_decl.is_abstract and len(self.class_decl.recursive_bases) > 0: + if any( + base.related_class.is_abstract + for base in self.class_decl.recursive_bases + ): return True + # Exclude sub class (e.g. iterator) constructors such as: + # class Foo { + # public: + # class FooIterator { if self.ctor_decl.parent != self.class_decl: return True - if self.ctor_decl.is_artificial and declarations.is_copy_constructor(self.ctor_decl): + # Exclude default copy constructors e.g. Foo::Foo(Foo const & foo) + if ( + declarations.is_copy_constructor(self.ctor_decl) + and self.ctor_decl.is_artificial + ): return True - - if self.class_decl.is_abstract and len(self.class_decl.recursive_bases)>0: - if any(t.related_class.is_abstract for t in self.class_decl.recursive_bases): - return True + + # Check for excluded argument patterns + calldef_excludes = [ + x.replace(" ", "") + for x in self.class_info.hierarchy_attribute_gather("calldef_excludes") + ] + + ctor_arg_type_excludes = [ + x.replace(" ", "") + for x in self.class_info.hierarchy_attribute_gather( + "constructor_arg_type_excludes" + ) + ] + + for arg_type in self.ctor_decl.argument_types: + # e.g. ::std::vector const & -> ::std::vectorconst& + arg_type_str = arg_type.decl_string.replace(" ", "") + + # Exclude constructors with "iterator" in args + if "iterator" in arg_type_str.lower(): + return True + + # Exclude constructors with args matching calldef_excludes + if arg_type_str in calldef_excludes: + return True + + # Exclude constructurs with args matching constructor_arg_type_excludes + for excluded_type in ctor_arg_type_excludes: + if excluded_type in arg_type_str: + return True return False - def add_self(self, output_string): + def add_self(self, cpp_string: str) -> str: + """ + Add the constructor wrapper code to the input string for example: + .def(py::init(), py::arg("i") = 1, py::arg("b") = false) + + Parameters + ---------- + cpp_string : str + The input string containing current wrapper code + + Returns + ------- + str + The input string with the constructor wrapper code added + """ - if self.exclusion_critera(): - return output_string + # Skip excluded constructors + if self.exclusion_criteria(): + return cpp_string - output_string += " "*8 + '.def(py::init<' - num_arg_types = len(self.ctor_decl.argument_types) - for idx, eachArg in enumerate(self.ctor_decl.argument_types): - output_string += eachArg.decl_string - if idx < num_arg_types-1: - output_string += ", " - output_string += ' >()' + # Get the arg signature e.g. "int, bool" + cpp_string += " .def(py::init<" + arg_types = [t.decl_string for t in self.ctor_decl.argument_types] + cpp_string += ", ".join(arg_types) + + cpp_string += " >()" + + # Default args e.g. py::arg("i") = 1 default_args = "" if not self.default_arg_exclusion_criteria(): - for eachArg in self.ctor_decl.arguments: - default_args += ', py::arg("{}")'.format(eachArg.name) - if eachArg.default_value is not None: - default_args += ' = ' + eachArg.default_value - output_string += default_args + ')\n' - return output_string + for arg in self.ctor_decl.arguments: + default_args += f', py::arg("{arg.name}")' + + if arg.default_value is not None: + # TODO: Fix in default args (see method_writer) + default_args += f" = {arg.default_value}" + + cpp_string += default_args + ")\n" + + return cpp_string diff --git a/cppwg/writers/free_function_writer.py b/cppwg/writers/free_function_writer.py index fe37447..2a417b2 100644 --- a/cppwg/writers/free_function_writer.py +++ b/cppwg/writers/free_function_writer.py @@ -1,49 +1,79 @@ -from cppwg.writers import base_writer +from cppwg.input.free_function_info import CppFreeFunctionInfo +from cppwg.writers.base_writer import CppBaseWrapperWriter -class CppFreeFunctionWrapperWriter(base_writer.CppBaseWrapperWriter): +class CppFreeFunctionWrapperWriter(CppBaseWrapperWriter): """ Manage addition of free function wrapper code + + Attributes + ---------- + free_function_info : CppFreeFunctionInfo + The free function information to generate Python bindings for + wrapper_templates : Dict[str, str] + String templates with placeholders for generating wrapper code + exclusion_args : List[str] + A list of argument types to exclude from the wrapper code """ def __init__(self, free_function_info, wrapper_templates): - + super(CppFreeFunctionWrapperWriter, self).__init__(wrapper_templates) - self.free_function_info = free_function_info - self.wrapper_templates = wrapper_templates - self.exclusion_args = [] + self.free_function_info: CppFreeFunctionInfo = free_function_info + self.wrapper_templates: Dict[str, str] = wrapper_templates + self.exclusion_args: List[str] = [] + + def add_self(self, wrapper_string) -> str: + """ + Add the free function wrapper code to the wrapper code string - def add_self(self, output_string): + Parameters + ---------- + wrapper_string : str + String containing the current C++ wrapper code - # Check for exclusions - if self.exclusion_critera(self.free_function_info.decl, self.exclusion_args): - return output_string + Returns + ------- + str + The updated C++ wrapper code string + """ - # Which definition type + # Skip this free function if it uses any excluded arg types or return types + if self.exclusion_criteria(self.free_function_info.decl, self.exclusion_args): + return wrapper_string + + # Pybind11 def type e.g. "_static" for def_static() def_adorn = "" + # TODO: arg_signature isn't used. Remove? # Get the arg signature arg_signature = "" arg_types = self.free_function_info.decl.argument_types num_arg_types = len(arg_types) for idx, eachArg in enumerate(arg_types): arg_signature += eachArg.decl_string - if idx < num_arg_types-1: + if idx < num_arg_types - 1: arg_signature += ", " - # Default args + # Pybind11 arg string with or without default values. + # e.g. without default values: ', py::arg("foo"), py::arg("bar")' + # e.g. with default values: ', py::arg("foo") = 1, py::arg("bar") = 2' default_args = "" if not self.default_arg_exclusion_criteria(): - for eachArg in self.free_function_info.decl.arguments: - default_args += ', py::arg("{}")'.format(eachArg.name) - if eachArg.default_value is not None: - default_args += ' = ' + eachArg.default_value - - method_dict = {'def_adorn': def_adorn, - 'function_name': self.free_function_info.decl.name, - 'function_docs': '" "', - 'default_args': default_args} - output_string += self.wrapper_templates["free_function"].format(**method_dict) - return output_string + for argument in self.free_function_info.decl.arguments: + default_args += f', py::arg("{argument.name}")' + if argument.default_value is not None: + default_args += f" = {argument.default_value}" + + # Add the free function wrapper code to the wrapper string + func_dict = { + "def_adorn": def_adorn, + "function_name": self.free_function_info.decl.name, + "function_docs": '" "', + "default_args": default_args, + } + wrapper_string += self.wrapper_templates["free_function"].format(**func_dict) + + return wrapper_string diff --git a/cppwg/writers/header_collection_writer.py b/cppwg/writers/header_collection_writer.py index 3ffbaac..ed1234f 100644 --- a/cppwg/writers/header_collection_writer.py +++ b/cppwg/writers/header_collection_writer.py @@ -1,131 +1,171 @@ -#!/usr/bin/env python - -""" -Generate the file classes_to_be_wrapped.hpp, which contains includes, -instantiation and naming typedefs for all classes that are to be -automatically wrapped. -""" - import os -import ntpath - -class CppHeaderCollectionWriter(): +from cppwg.input.class_info import CppClassInfo +from cppwg.input.free_function_info import CppFreeFunctionInfo +from cppwg.input.package_info import PackageInfo +class CppHeaderCollectionWriter: """ - This class manages generation of the header collection file for - parsing by CastXML + This class manages the generation of the header collection file, which + includes all the headers to be parsed by CastXML. The header collection file + also contains explicit template instantiations and their corresponding + typedefs (e.g. typedef Foo<2,2> Foo2_2) for all + classes that are to be automatically wrapped. + + Attributes + ---------- + package_info : PackageInfo + The package information + wrapper_root : str + The output directory for the generated wrapper code + hpp_collection_filepath : str + The path to save the header collection file to + hpp_collection_string : str + The output string that gets written to the header collection file + class_dict : Dict[str, CppClassInfo] + A dictionary of all class info objects + free_func_dict : Dict[str, CppFreeFunctionInfo] + A dictionary of all free function info objects """ - def __init__(self, package_info, wrapper_root): + def __init__( + self, + package_info: PackageInfo, + wrapper_root: str, + hpp_collection_filepath: str, + ): - self.wrapper_root = wrapper_root - self.package_info = package_info - self.header_file_name = "wrapper_header_collection.hpp" - self.hpp_string = "" - self.class_dict = {} - self.free_func_dict = {} + self.package_info: PackageInfo = package_info + self.wrapper_root: str = wrapper_root + self.hpp_collection_filepath: str = hpp_collection_filepath + self.hpp_collection_string: str = "" - for eachModule in self.package_info.module_info: - for eachClassInfo in eachModule.class_info: - self.class_dict[eachClassInfo.name] = eachClassInfo - - for eachFuncInfo in eachModule.free_function_info: - self.free_func_dict[eachFuncInfo.name] = eachFuncInfo - - def add_custom_header_code(self): - - """ - Any custom header code goes here - """ + # For convenience, collect all class and free function info into dicts keyed by name + self.class_dict: Dict[str, CppClassInfo] = {} + self.free_func_dict: Dict[str, CppFreeFunctionInfo] = {} - pass + for module_info in self.package_info.module_info_collection: + for class_info in module_info.class_info_collection: + self.class_dict[class_info.name] = class_info - def write_file(self): + for free_function_info in module_info.free_function_info_collection: + self.free_func_dict[free_function_info.name] = free_function_info + def should_include_all(self) -> bool: """ - The actual write - """ - - if not os.path.exists(self.wrapper_root + "/"): - os.makedirs(self.wrapper_root + "/") - file_path = self.wrapper_root + "/" + self.header_file_name - hpp_file = open(file_path, 'w') - hpp_file.write(self.hpp_string) - hpp_file.close() - - def should_include_all(self): - + Return whether all source files in the module source locations should be included + + Returns + ------- + bool """ - Return whether all source files in the module source locs should be included - """ - for eachModule in self.package_info.module_info: - if eachModule.use_all_classes or eachModule.use_all_free_functions: - return True + # True if any module uses all classes or all free functions + for module_info in self.package_info.module_info_collection: + if module_info.use_all_classes or module_info.use_all_free_functions: + return True return False - def write(self): - + def write(self) -> None: """ - Main method for generating the header file output string + Generate the header file output string and write it to file """ - hpp_header_dict = {'package_name': self.package_info.name} - hpp_header_template = """\ -#ifndef {package_name}_HEADERS_HPP_ -#define {package_name}_HEADERS_HPP_ + # Add opening header guard + self.hpp_collection_string = f"#ifndef {self.package_info.name}_HEADERS_HPP_\n" + self.hpp_collection_string += f"#define {self.package_info.name}_HEADERS_HPP_\n" -// Includes -""" - self.hpp_string = hpp_header_template.format(**hpp_header_dict) + self.hpp_collection_string += "\n// Includes\n" + + included_files = set() # Keep track of included files to avoid duplicates - # Now our own includes if self.should_include_all(): - for eachFile in self.package_info.source_hpp_files: - include_name = ntpath.basename(eachFile) - self.hpp_string += '#include "' + include_name + '"\n' + # Include all the headers + for hpp_filepath in self.package_info.source_hpp_files: + hpp_filename = os.path.basename(hpp_filepath) + + if hpp_filename not in included_files: + self.hpp_collection_string += f'#include "{hpp_filename}"\n' + included_files.add(hpp_filename) + else: - for eachModule in self.package_info.module_info: - for eachClassInfo in eachModule.class_info: - if eachClassInfo.source_file is not None: - self.hpp_string += '#include "' + eachClassInfo.source_file + '"\n' - elif eachClassInfo.source_file_full_path is not None: - include_name = ntpath.basename(eachClassInfo.source_file_full_path) - self.hpp_string += '#include "' + include_name + '"\n' - for eachFuncInfo in eachModule.free_function_info: - if eachFuncInfo.source_file_full_path is not None: - include_name = ntpath.basename(eachFuncInfo.source_file_full_path) - self.hpp_string += '#include "' + include_name + '"\n' - - # Add the template instantiations - self.hpp_string += "\n// Instantiate Template Classes \n" - for eachModule in self.package_info.module_info: - for eachClassInfo in eachModule.class_info: - full_names = eachClassInfo.get_full_names() - if len(full_names) == 1: - continue - prefix = "template class " - for eachTemplateName in full_names: - self.hpp_string += prefix + eachTemplateName.replace(" ","") + ";\n" - - # Add typdefs for nice naming - self.hpp_string += "\n// Typedef for nicer naming\n" - self.hpp_string += "namespace cppwg{ \n" - for eachModule in self.package_info.module_info: - for eachClassInfo in eachModule.class_info: - full_names = eachClassInfo.get_full_names() - if len(full_names) == 1: - continue - - short_names = eachClassInfo.get_short_names() - for idx, eachTemplateName in enumerate(full_names): - short_name = short_names[idx] - typdef_prefix = "typedef " + eachTemplateName.replace(" ","") + " " - self.hpp_string += typdef_prefix + short_name + ";\n" - self.hpp_string += "}\n" - - self.add_custom_header_code() - self.hpp_string += "\n#endif // {}_HEADERS_HPP_\n".format(self.package_info.name) - - self.write_file() + # Include specific headers needed by classes + for module_info in self.package_info.module_info_collection: + for class_info in module_info.class_info_collection: + hpp_filename = None + + if class_info.source_file: + hpp_filename = class_info.source_file + + elif class_info.source_file_full_path: + hpp_filename = os.path.basename( + class_info.source_file_full_path + ) + + if hpp_filename and hpp_filename not in included_files: + self.hpp_collection_string += f'#include "{hpp_filename}"\n' + included_files.add(hpp_filename) + + # Include specific headers needed by free functions + for free_function_info in module_info.free_function_info_collection: + if free_function_info.source_file_full_path: + hpp_filename = os.path.basename( + free_function_info.source_file_full_path + ) + + if hpp_filename not in included_files: + self.hpp_collection_string += f'#include "{hpp_filename}"\n' + included_files.add(hpp_filename) + + # Add the template instantiations e.g. `template class Foo<2,2>;` + self.hpp_collection_string += "\n// Instantiate Template Classes \n" + + for module_info in self.package_info.module_info_collection: + for class_info in module_info.class_info_collection: + # Class full names eg. ["Foo<2,2>", "Foo<3,3>"] + full_names = class_info.get_full_names() + + # TODO: What if the class is templated but has only one template instantiation? + # See https://github.com/Chaste/cppwg/issues/2 + if len(full_names) < 2: + continue # Skip if the class is untemplated + + for template_name in full_names: + clean_template_name = template_name.replace(" ", "") + self.hpp_collection_string += ( + f"template class {clean_template_name};\n" + ) + + # Add typdefs for nice naming e.g. `typedef Foo<2,2> Foo2_2` + self.hpp_collection_string += "\n// Typedefs for nicer naming\n" + self.hpp_collection_string += "namespace cppwg{ \n" + + for module_info in self.package_info.module_info_collection: + for class_info in module_info.class_info_collection: + # Class full names eg. ["Foo<2,2>", "Foo<3,3>"] + full_names = class_info.get_full_names() + + # TODO: What if the class is templated but has only one template instantiation? + # See https://github.com/Chaste/cppwg/issues/2 + if len(full_names) < 2: + continue # Skip if the class is untemplated + + # Class short names eg. ["Foo2_2", "Foo3_3"] + short_names = class_info.get_short_names() + + for short_name, template_name in zip(short_names, full_names): + clean_template_name = template_name.replace(" ", "") + self.hpp_collection_string += ( + f"typedef {clean_template_name} {short_name};\n" + ) + + self.hpp_collection_string += "} // namespace cppwg\n" + + # Add closing header guard + self.hpp_collection_string += ( + f"\n#endif // {self.package_info.name}_HEADERS_HPP_\n" + ) + + # Write the header collection string to file + with open(self.hpp_collection_filepath, "w") as hpp_file: + hpp_file.write(self.hpp_collection_string) diff --git a/cppwg/writers/method_writer.py b/cppwg/writers/method_writer.py index 53aad63..9cb3671 100644 --- a/cppwg/writers/method_writer.py +++ b/cppwg/writers/method_writer.py @@ -1,167 +1,251 @@ +from typing import Dict, Optional + from pygccxml import declarations +from pygccxml.declarations.class_declaration import class_t +from pygccxml.declarations.calldef_members import member_function_t -from cppwg.writers import base_writer +from cppwg.input.class_info import CppClassInfo +from cppwg.writers.base_writer import CppBaseWrapperWriter -class CppMethodWrapperWriter(base_writer.CppBaseWrapperWriter): +class CppMethodWrapperWriter(CppBaseWrapperWriter): """ Manage addition of method wrapper code + + Attributes + ---------- + class_info : ClassInfo + The class information for the class containing the method + method_decl : member_function_t + The pygccxml declaration object for the method + class_decl : class_t + The class declaration for the class containing the method + wrapper_templates : Dict[str, str] + String templates with placeholders for generating wrapper code + class_short_name : Optional[str] + The short name of the class e.g. 'Foo2_2' """ - def __init__(self, class_info, - method_decl, - class_decl, - wrapper_templates, - class_short_name=None): + def __init__( + self, + class_info: CppClassInfo, + method_decl: member_function_t, + class_decl: class_t, + wrapper_templates: Dict[str, str], + class_short_name: Optional[str] = None, + ): super(CppMethodWrapperWriter, self).__init__(wrapper_templates) - self.class_info = class_info - self.method_decl = method_decl - self.class_decl = class_decl - self.class_short_name = class_short_name + self.class_info: CppClassInfo = class_info + self.method_decl: member_function_t = method_decl + self.class_decl: class_t = class_decl + + self.class_short_name: str = class_short_name if self.class_short_name is None: - self.class_short_name = self.class_decl + self.class_short_name = self.class_decl.name - def exclusion_critera(self): + def exclusion_criteria(self) -> bool: + """ + Check if the method should be excluded from the wrapper code - # Are any return types not wrappable - exclusion_args = self.class_info.hierarchy_attribute_gather('calldef_excludes') - return_excludes = self.class_info.hierarchy_attribute_gather('return_type_excludes') - - return_type = self.method_decl.return_type.decl_string.replace(" ", "") - if return_type in exclusion_args or return_type in return_excludes: + Returns + ------- + bool + True if the method should be excluded, False otherwise + """ + + # Exclude private methods without over-rides + if self.method_decl.access_type == "private": return True - # Don't include sub class (e.g. iterator) methods + # Exclude sub class (e.g. iterator) methods such as: + # class Foo { + # public: + # class FooIterator { if self.method_decl.parent != self.class_decl: return True - # No private methods without over-rides - if self.method_decl.access_type == "private": + # Check for excluded return types + calldef_excludes = [ + x.replace(" ", "") + for x in self.class_info.hierarchy_attribute_gather("calldef_excludes") + ] + + return_type_excludes = [ + x.replace(" ", "") + for x in self.class_info.hierarchy_attribute_gather("return_type_excludes") + ] + + return_type = self.method_decl.return_type.decl_string.replace(" ", "") + if return_type in calldef_excludes or return_type in return_type_excludes: return True - # Are any arguments not wrappable - tidied_excl = [x.replace(" ", "") for x in exclusion_args] - for eachArg in self.method_decl.argument_types: - arg_type = eachArg.decl_string.split()[0].replace(" ", "") - if arg_type in tidied_excl: + # Check for excluded argument patterns + for argument_type in self.method_decl.argument_types: + # e.g. ::std::vector const & -> ::std::vector const & -> ::std::vectorconst& + arg_type_full = argument_type.decl_string.replace(" ", "") + if arg_type_full in calldef_excludes: return True - arg_type_full = eachArg.decl_string.replace(" ", "") - if arg_type_full in tidied_excl: - print (arg_type_full) - return True + return False - def add_self(self, output_string): + def add_self(self, cpp_string) -> str: + """ + Add the method wrapper code to the input string. For example: + .def("bar", (void(Foo::*)(double)) &Foo::bar, " ", py::arg("d") = 1.0) + + Parameters + ---------- + cpp_string : str + The input string containing current wrapper code - # Check for exclusions - if self.exclusion_critera(): - return output_string + Returns + ------- + str + The input string with the method wrapper code added + """ - # Which definition type + # Skip excluded methods + if self.exclusion_criteria(): + return cpp_string + + # Pybind11 def type e.g. "_static" for def_static() def_adorn = "" if self.method_decl.has_static: def_adorn += "_static" # How to point to class - if not self.method_decl.has_static: - self_ptr = self.class_short_name + "::*" - else: + if self.method_decl.has_static: self_ptr = "*" - - # Get the arg signature - arg_signature = "" - num_arg_types = len(self.method_decl.argument_types) - for idx, eachArg in enumerate(self.method_decl.argument_types): - arg_signature += eachArg.decl_string - if idx < num_arg_types-1: - arg_signature += ", " + else: + # e.g. Foo2_2::* + self_ptr = self.class_short_name + "::*" # Const-ness const_adorn = "" if self.method_decl.has_const: - const_adorn = ' const ' + const_adorn = " const " - # Default args + # Get the arg signature e.g. "int, bool" + arg_types = [t.decl_string for t in self.method_decl.argument_types] + arg_signature = ", ".join(arg_types) + + # Default args e.g. py::arg("d") = 1.0 default_args = "" if not self.default_arg_exclusion_criteria(): - arg_types = self.method_decl.argument_types - for idx, eachArg in enumerate(self.method_decl.arguments): - default_args += ', py::arg("{}")'.format(eachArg.name) - if eachArg.default_value is not None: - + for arg, arg_type in zip( + self.method_decl.arguments, self.method_decl.argument_types + ): + default_args += f', py::arg("{arg.name}")' + + if arg.default_value is not None: + default_value = str(arg.default_value) + # Hack for missing template in default args - repl_value = str(eachArg.default_value) - if "" in repl_value: - if "<2>" in str(arg_types[idx]).replace(" ", ""): - repl_value = repl_value.replace("","<2>") - elif "<3>" in str(arg_types[idx]).replace(" ", ""): - repl_value= repl_value.replace("","<3>") - default_args += ' = ' + repl_value - - # Call policy - pointer_call_policy = self.class_info.hierarchy_attribute('pointer_call_policy') - reference_call_policy = self.class_info.hierarchy_attribute('reference_call_policy') - + # e.g. Foo<2>::bar(Bar<2> const & b = Bar()) + # TODO: Make more robust + arg_type_str = str(arg_type).replace(" ", "") + if "" in default_value: + if "<2>" in arg_type_str: + default_value = default_value.replace("", "<2>") + elif "<3>" in arg_type_str: + default_value = default_value.replace("", "<3>") + + default_args += f" = {default_value}" + + # Call policy, e.g. "py::return_value_policy::reference" call_policy = "" - is_ptr = declarations.is_pointer(self.method_decl.return_type) - if pointer_call_policy is not None and is_ptr: - call_policy = ", py::return_value_policy::" + pointer_call_policy - is_ref = declarations.is_reference(self.method_decl.return_type) - if reference_call_policy is not None and is_ref: - call_policy = ", py::return_value_policy::" + reference_call_policy - - method_dict = {'def_adorn': def_adorn, - 'method_name': self.method_decl.name, - 'return_type': self.method_decl.return_type.decl_string, - 'self_ptr': self_ptr, - 'arg_signature': arg_signature, - 'const_adorn': const_adorn, - 'class_short_name': self.class_short_name, - 'method_docs': '" "', - 'default_args': default_args, - 'call_policy': call_policy} - template = self.wrapper_templates["class_method"] - output_string += template.format(**method_dict) - return output_string - - def add_override(self, output_string): - + if declarations.is_pointer(self.method_decl.return_type): + ptr_policy = self.class_info.hierarchy_attribute("pointer_call_policy") + if ptr_policy: + call_policy = f", py::return_value_policy::{ptr_policy}" + + elif declarations.is_reference(self.method_decl.return_type): + ref_policy = self.class_info.hierarchy_attribute("reference_call_policy") + if ref_policy: + call_policy = f", py::return_value_policy::{ref_policy}" + + method_dict = { + "def_adorn": def_adorn, + "method_name": self.method_decl.name, + "return_type": self.method_decl.return_type.decl_string, + "self_ptr": self_ptr, + "arg_signature": arg_signature, + "const_adorn": const_adorn, + "class_short_name": self.class_short_name, + "method_docs": '" "', + "default_args": default_args, + "call_policy": call_policy, + } + class_method_template = self.wrapper_templates["class_method"] + cpp_string += class_method_template.format(**method_dict) + + return cpp_string + + def add_override(self, cpp_string) -> str: + """ + Add overrides for virtual methods to the input string. + + Parameters + ---------- + cpp_string : str + The input string containing current wrapper code + + Returns + ------- + str + The input string with the virtual override wrapper code added + """ + + # Skip private methods if self.method_decl.access_type == "private": - return output_string + return cpp_string + + # Get list of arguments and types + arg_list = [] + arg_name_list = [] + + for arg, arg_type in zip( + self.method_decl.arguments, self.method_decl.argument_types + ): + arg_list.append(f"{arg_type.decl_string} {arg.name}") + arg_name_list.append(f" {arg.name}") - arg_string = "" - num_arg_types = len(self.method_decl.argument_types) - args = self.method_decl.arguments - for idx, eachArg in enumerate(self.method_decl.argument_types): - arg_string += eachArg.decl_string + " " + args[idx].name - if idx < num_arg_types-1: - arg_string += ", " + arg_string = ", ".join(arg_list) # e.g. "int a, bool b, double c" + arg_name_string = ",\n".join(arg_name_list) # e.g. "a,\n b,\n c" + # Const-ness const_adorn = "" if self.method_decl.has_const: const_adorn = " const " + # For pure virtual methods, use PYBIND11_OVERRIDE_PURE overload_adorn = "" if self.method_decl.virtuality == "pure virtual": - overload_adorn = "_PURE" - - all_args_string = "" - for idx, eachArg in enumerate(self.method_decl.argument_types): - all_args_string += ""*8 + args[idx].name - if idx < num_arg_types-1: - all_args_string += ", \n" + overload_adorn = "_PURE" + # Get the return type e.g. "void" return_string = self.method_decl.return_type.decl_string - override_dict = {'return_type': return_string, - 'method_name': self.method_decl.name, - 'arg_string': arg_string, - 'const_adorn': const_adorn, - 'overload_adorn': overload_adorn, - 'tidy_method_name': self.tidy_name(return_string), - 'short_class_name': self.class_short_name, - 'args_string': all_args_string, - } - output_string += self.wrapper_templates["method_virtual_override"].format(**override_dict) - return output_string + + # Add the override code from the template + override_dict = { + "return_type": return_string, + "method_name": self.method_decl.name, + "arg_string": arg_string, + "const_adorn": const_adorn, + "overload_adorn": overload_adorn, + "tidy_method_name": self.tidy_name(return_string), + "short_class_name": self.class_short_name, + "args_string": arg_name_string, + } + cpp_string += self.wrapper_templates["method_virtual_override"].format( + **override_dict + ) + + return cpp_string diff --git a/cppwg/writers/module_writer.py b/cppwg/writers/module_writer.py index d7cdd2c..aa4ef34 100644 --- a/cppwg/writers/module_writer.py +++ b/cppwg/writers/module_writer.py @@ -1,114 +1,172 @@ -#!/usr/bin/env python - -""" -This scipt automatically generates Python bindings using a rule based approach -""" - import os - -from cppwg.writers import free_function_writer -from cppwg.writers import class_writer - - -class CppModuleWrapperWriter(object): - - def __init__(self, global_ns, - source_ns, - module_info, - wrapper_templates, - wrapper_root, - package_license=None): - - self.global_ns = global_ns - self.source_ns = source_ns - self.module_info = module_info - self.wrapper_root = wrapper_root - self.exposed_class_full_names = [] - self.wrapper_templates = wrapper_templates - self.license = package_license - - self.exposed_class_full_names = [] - - def generate_main_cpp(self): - +import logging + +from typing import Dict + +from pygccxml.declarations.class_declaration import class_t +from pygccxml.declarations.namespace import namespace_t + +from cppwg.input.module_info import ModuleInfo + +from cppwg.writers.free_function_writer import CppFreeFunctionWrapperWriter +from cppwg.writers.class_writer import CppClassWrapperWriter + +from cppwg.utils.constants import CPPWG_EXT +from cppwg.utils.constants import CPPWG_HEADER_COLLECTION_FILENAME + + +class CppModuleWrapperWriter: + """ + This class automatically generates Python bindings using a rule based approach + + Attributes + ---------- + source_ns : namespace_t + The pygccxml namespace containing declarations from the source code + module_info : ModuleInfo + The module information to generate Python bindings for + wrapper_templates : Dict[str, str] + String templates with placeholders for generating wrapper code + wrapper_root : str + The output directory for the generated wrapper code + package_license : str + The license to include in the generated wrapper code + exposed_class_full_names : List[str] + A list of full names of all classes to be wrapped in the module + """ + + def __init__( + self, + source_ns: namespace_t, + module_info: ModuleInfo, + wrapper_templates: Dict[str, str], + wrapper_root: str, + package_license: str = "", + ): + self.source_ns: namespace_t = source_ns + self.module_info: ModuleInfo = module_info + self.wrapper_templates: Dict[str, str] = wrapper_templates + self.wrapper_root: str = wrapper_root + self.package_license: str = ( + package_license # TODO: use this in the generated wrappers + ) + + # For convenience, create a list of all classes to be wrapped in the module + # e.g. ['Foo', 'Bar<2>', 'Bar<3>'] + self.exposed_class_full_names: List[str] = [] + + for class_info in self.module_info.class_info_collection: + for full_name in class_info.get_full_names(): + self.exposed_class_full_names.append(full_name.replace(" ", "")) + + def write_module_wrapper(self) -> None: """ - Generate the main cpp for the module + Generate the contents of the main cpp file for the module and write it + to modulename.main.cpp. This file contains the pybind11 module + definition. Within the module definition, the module's free functions + and classes are registered. + + For example, the generated file might look like this: + + ``` + #include + #include "Foo.cppwg.hpp" + + PYBIND11_MODULE(_packagename_modulename, m) + { + register_Foo_class(m); + } + ``` """ - # Generate the main cpp file - module_name = self.module_info.name - full_module_name = "_" + self.module_info.package_info.name + "_" + module_name + # Add top level includes + cpp_string = "#include \n" - cpp_string = "" - cpp_string += '#include \n' - - if self.module_info.package_info.common_include_file: - cpp_string += '#include "wrapper_header_collection.hpp"\n' - - # Custom code - if self.module_info.custom_generator is not None: + cpp_string += f'#include "{CPPWG_HEADER_COLLECTION_FILENAME}"\n' + + # Add outputs from running custom generator code + if self.module_info.custom_generator: cpp_string += self.module_info.custom_generator.get_module_pre_code() - # Add includes - for eachClass in self.module_info.class_info: - for short_name in eachClass.get_short_names(): - cpp_string += '#include "' + short_name + '.cppwg.hpp"\n' - cpp_string += '\nnamespace py = pybind11;\n\n' - cpp_string += 'PYBIND11_MODULE(' + full_module_name + ', m)\n{\n' + # Add includes for class wrappers in the module + for class_info in self.module_info.class_info_collection: + for short_name in class_info.get_short_names(): + # Example: #include "Foo2_2.cppwg.hpp" + cpp_string += f'#include "{short_name}.{CPPWG_EXT}.hpp"\n' + + # Format module name as _packagename_modulename + full_module_name = ( + "_" + self.module_info.package_info.name + "_" + self.module_info.name + ) + + # Create the pybind11 module + cpp_string += "\nnamespace py = pybind11;\n" + cpp_string += f"\nPYBIND11_MODULE({full_module_name}, m)\n" + cpp_string += "{\n" # Add free functions - for eachFunction in self.module_info.free_function_info: - writer = free_function_writer.CppFreeFunctionWrapperWriter(eachFunction, - self.wrapper_templates) - cpp_string = writer.add_self(cpp_string) - - # Add viable classes - for eachClass in self.module_info.class_info: - for short_name in eachClass.get_short_names(): - cpp_string += ' register_' + short_name + '_class(m);\n' - - # Add any custom code - if self.module_info.custom_generator is not None: + for free_function_info in self.module_info.free_function_info_collection: + function_writer = CppFreeFunctionWrapperWriter( + free_function_info, self.wrapper_templates + ) + # TODO: Consider returning the function string instead + cpp_string = function_writer.add_self(cpp_string) + + # Add classes + for class_info in self.module_info.class_info_collection: + for short_name in class_info.get_short_names(): + # Example: register_Foo2_2_class(m);" + cpp_string += f" register_{short_name}_class(m);\n" + + # Add code from the module's custom generator + if self.module_info.custom_generator: cpp_string += self.module_info.custom_generator.get_module_code() - output_dir = self.wrapper_root + "/" + self.module_info.name + "/" - if not os.path.exists(output_dir): - os.makedirs(output_dir) - main_cpp_file = open(output_dir + self.module_info.name + ".main.cpp", "w") - main_cpp_file.write(cpp_string + '}\n') - main_cpp_file.close() + cpp_string += "}\n" # End of the pybind11 module - def get_class_writer(self, class_info): - - """ - Return the class writer, override for custom writers - """ + # Write to /path/to/wrapper_root/modulename/modulename.main.cpp + module_dir = os.path.join(self.wrapper_root, self.module_info.name) + if not os.path.isdir(module_dir): + os.makedirs(module_dir) - this_class_writer = class_writer.CppClassWrapperWriter(class_info, self.wrapper_templates) - return this_class_writer + module_cpp_file = os.path.join(module_dir, self.module_info.name + ".main.cpp") - def write(self): - + with open(module_cpp_file, "w") as out_file: + out_file.write(cpp_string) + + def write_class_wrappers(self) -> None: """ - Main method for writing the module + Write wrappers for classes in the module """ + logger = logging.getLogger() + + for class_info in self.module_info.class_info_collection: + logger.info(f"Generating wrapper for class {class_info.name}") - print ('Generating Wrapper Code for: ' + self.module_info.name + ' Module.') + class_writer = CppClassWrapperWriter( + class_info, self.wrapper_templates, self.exposed_class_full_names + ) - self.generate_main_cpp() + # Get the declaration for each class and add it to the class writer + # TODO: Consider using class_info.decl instead + for full_name in class_info.get_full_names(): + name = full_name.replace(" ", "") # e.g. Foo<2,2> - # Generate class files - for eachClassInfo in self.module_info.class_info: - self.exposed_class_full_names.extend(eachClassInfo.get_full_names()) + class_decl: class_t = self.source_ns.class_(name) + class_writer.class_decls.append(class_decl) + + # Write the class wrappers into /path/to/wrapper_root/modulename/ + module_dir = os.path.join(self.wrapper_root, self.module_info.name) + class_writer.write(module_dir) - for eachClassInfo in self.module_info.class_info: + def write(self) -> None: + """ + Main method for writing the module + """ + logger = logging.getLogger() - print ('Generating Wrapper Code for: ' + eachClassInfo.name + ' Class.') + logger.info(f"Generating wrappers for module {self.module_info.name}") - class_writer = self.get_class_writer(eachClassInfo) - class_writer.exposed_class_full_names = self.exposed_class_full_names - for fullName in eachClassInfo.get_full_names(): - class_decl = self.source_ns.class_(fullName.replace(" ","")) - class_writer.class_decls.append(class_decl) - class_writer.write(self.wrapper_root + "/" + self.module_info.name + "/") + self.write_module_wrapper() + self.write_class_wrappers() diff --git a/tests/test_shapes.py b/tests/test_shapes.py index e070953..e142389 100644 --- a/tests/test_shapes.py +++ b/tests/test_shapes.py @@ -2,8 +2,10 @@ import subprocess import unittest +from typing import List -def get_file_lines(file_path: str) -> "list[str]": + +def get_file_lines(file_path: str) -> List[str]: """ Load a file into a list of lines @@ -14,7 +16,7 @@ def get_file_lines(file_path: str) -> "list[str]": Returns ------- - list[str] + List[str] A list of lines read from the file, with excess whitespace and empty lines removed """ @@ -52,7 +54,7 @@ def compare_files(file_path_a: str, file_path_b: str) -> bool: class TestShapes(unittest.TestCase): - def test_wrapper_generation(self): + def test_wrapper_generation(self) -> None: """ Generate wrappers and compare with the reference wrappers. """