Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ jobs:
. ./ci-support-v0

build_py_project_in_venv
python -m mypy codepy test
python -m mypy codepy

docs:
name: Documentation
Expand Down
180 changes: 113 additions & 67 deletions codepy/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,25 +27,34 @@
"""

import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import NamedTuple
from types import ModuleType
from typing import Any, NamedTuple

from codepy import CompileError
from codepy.toolchain import GCCLikeToolchain, Toolchain


logger = logging.getLogger(__name__)


def _erase_dir(dir):
def _erase_dir(dir: str) -> None:
from os import listdir, rmdir, unlink
from os.path import join

for name in listdir(dir):
unlink(join(dir, name))

rmdir(dir)


def extension_file_from_string(toolchain, ext_file, source_string,
source_name="module.cpp", debug=False):
def extension_file_from_string(
toolchain: Toolchain,
ext_file: str,
source_string: str,
source_name: str = "module.cpp",
debug: bool = False) -> None:
"""Using *toolchain*, build the extension file named *ext_file*
from the source code in *source_string*, which is saved to a
temporary file named *source_name*. Raise :exc:`CompileError` in
Expand All @@ -57,56 +66,65 @@ def extension_file_from_string(toolchain, ext_file, source_string,
src_dir = mkdtemp()

from os.path import join

source_file = join(src_dir, source_name)
outf = open(source_file, "w")
outf.write(str(source_string))
outf.close()
with open(source_file, "w") as outf:
outf.write(str(source_string))

try:
toolchain.build_extension(ext_file, [source_file], debug=debug)
finally:
_erase_dir(src_dir)


class CleanupBase:
pass
class CleanupBase(ABC):
@abstractmethod
def clean_up(self) -> None:
pass

@abstractmethod
def error_clean_up(self) -> None:
pass


class CleanupManager(CleanupBase):
def __init__(self):
self.cleanups = []
def __init__(self) -> None:
self.cleanups: list[CleanupBase] = []

def register(self, c):
def register(self, c: CleanupBase) -> None:
self.cleanups.insert(0, c)

def clean_up(self):
def clean_up(self) -> None:
for c in self.cleanups:
c.clean_up()

def error_clean_up(self):
def error_clean_up(self) -> None:
for c in self.cleanups:
c.error_clean_up()


class TempDirManager(CleanupBase):
def __init__(self, cleanup_m):
def __init__(self, cleanup_m: CleanupManager) -> None:
from tempfile import mkdtemp
self.path = mkdtemp()
cleanup_m.register(self)

def sub(self, n):
def sub(self, n: str) -> str:
from os.path import join
return join(self.path, n)

def clean_up(self):
def clean_up(self) -> None:
_erase_dir(self.path)

def error_clean_up(self):
def error_clean_up(self) -> None:
pass


class CacheLockManager(CleanupBase):
def __init__(self, cleanup_m, cache_dir, sleep_delay=1):
def __init__(self,
cleanup_m: CleanupManager,
cache_dir: str,
sleep_delay: int = 1) -> None:
import os

if cache_dir is not None:
Expand All @@ -133,17 +151,17 @@ def __init__(self, cleanup_m, cache_dir, sleep_delay=1):

cleanup_m.register(self)

def clean_up(self):
def clean_up(self) -> None:
import os
os.close(self.fd)
os.unlink(self.lock_file)

def error_clean_up(self):
def error_clean_up(self) -> None:
pass


class ModuleCacheDirManager(CleanupBase):
def __init__(self, cleanup_m, path):
def __init__(self, cleanup_m: CleanupManager, path: str) -> None:
from os import mkdir

self.path = path
Expand All @@ -154,26 +172,32 @@ def __init__(self, cleanup_m, path):
except OSError:
self.existed = True

def sub(self, n):
def sub(self, n: str) -> str:
from os.path import join
return join(self.path, n)

def reset(self):
def reset(self) -> None:
import os
_erase_dir(self.path)
os.mkdir(self.path)

def clean_up(self):
def clean_up(self) -> None:
pass

def error_clean_up(self):
def error_clean_up(self) -> None:
_erase_dir(self.path)


def extension_from_string(toolchain, name, source_string,
source_name="module.cpp", cache_dir=None,
debug=False, wait_on_error=None,
debug_recompile=True, sleep_delay=1):
def extension_from_string(
toolchain: Toolchain,
name: str,
source_string: str | list[str],
source_name: str | list[str] = "module.cpp",
cache_dir: str | None = None,
debug: bool = False,
wait_on_error: bool | None = None,
debug_recompile: bool = True,
sleep_delay: int = 1) -> ModuleType:
"""Return a reference to the extension module *name*, which can be built
from the source code in *source_string* if necessary. Raise
:exc:`CompileError` in case of error.
Expand Down Expand Up @@ -204,7 +228,7 @@ def extension_from_string(toolchain, name, source_string,
_checksum, mod_name, ext_file, _recompiled = (
compile_from_string(toolchain, name, source_string, source_name,
cache_dir, debug, wait_on_error, debug_recompile,
False, sleep_delay=sleep_delay))
object=False, sleep_delay=sleep_delay))

# try loading it
from codepy.tools import load_dynamic
Expand All @@ -217,28 +241,37 @@ class _InvalidInfoFileError(RuntimeError):

class _Dependency(NamedTuple):
name: str
mtime: int
mtime: float
md5: str


@dataclass(frozen=True)
class _SourceInfo:
dependencies: list[_Dependency]
source_name: str


def compile_from_string(toolchain, name, source_string,
source_name=None, cache_dir=None,
debug=False, wait_on_error=None, debug_recompile=True,
object=False, source_is_binary=False, sleep_delay=1):
"""Returns a tuple: mod_name, file_name, recompiled.
mod_name is the name of the module represented by a compiled object,
file_name is the name of the compiled object, which can be built from the
source_name: list[str]


def compile_from_string(
toolchain: Toolchain,
name: str,
source_string: str | bytes | list[str] | list[bytes],
source_name: str | list[str] | None = None,
cache_dir: str | None = None,
debug: bool = False,
wait_on_error: bool | None = None,
debug_recompile: bool = True,
object: bool = False,
source_is_binary: bool = False,
sleep_delay: int = 1) -> tuple[str, str, str, bool]:
"""Returns a tuple: ``(checksum, mod_name, file_name, recompiled)``.
*mod_name* is the name of the module represented by a compiled object,
*file_name* is the name of the compiled object, which can be built from the
source code(s) in *source_strings* if necessary,
recompiled is True if the object had to be recompiled, False if the cache
*recompiled* is *True* if the object had to be recompiled, *False* if the cache
is hit.
Raise :exc:`CompileError` in case of error. The mod_name and file_name
are designed to be used with load_dynamic to load a python module from

Raises :exc:`CompileError` in case of error. The *mod_name* and *file_name*
are designed to be used with ``load_dynamic`` to load a Python module from
this object, if desired.

Compiled code is cached in *cache_dir* and available immediately if it has
Expand Down Expand Up @@ -267,12 +300,17 @@ def compile_from_string(toolchain, name, source_string,
If *source_is_binary*, the source string is a compile object file and
should be treated as binary for read/write purposes
"""
if not isinstance(toolchain, GCCLikeToolchain):
raise TypeError(f"Unsupported toolchain type: {type(toolchain)}")

if source_name is None:
source_name = ["module.cpp"]

# first ensure that source strings and names are lists
if isinstance(source_string, str) \
or (source_is_binary and isinstance(source_string, bytes)):
if isinstance(source_string, str):
source_string = [source_string]

if source_is_binary and isinstance(source_string, bytes):
source_string = [source_string]

if isinstance(source_name, str):
Expand Down Expand Up @@ -302,40 +340,40 @@ def compile_from_string(toolchain, name, source_string,
if e.errno != EEXIST:
raise

def get_file_md5sum(fname):
def get_file_md5sum(fname: str) -> str:
import hashlib
checksum = hashlib.md5()

inf = open(fname, "rb")
checksum.update(inf.read())
with open(fname, "rb") as inf:
checksum.update(inf.read())

inf.close()
return checksum.hexdigest()

def get_dep_structure(source_paths):
def get_dep_structure(source_paths: list[str]) -> list[_Dependency]:
deps = toolchain.get_dependencies(source_paths)
return [_Dependency(dep, os.stat(dep).st_mtime, get_file_md5sum(dep))
for dep in sorted(deps) if dep not in source_paths]

def write_source(name):
def write_source(name: list[str]) -> None:
for i, source in enumerate(source_string):
outf = open(name[i], "w" if not source_is_binary else "wb")
outf.write(source)
outf.close()
with open(name[i], "w" if not source_is_binary else "wb") as outf:
outf.write(source)

def calculate_hex_checksum():
def calculate_hex_checksum() -> str:
import hashlib
checksum = hashlib.md5()

for source in source_string:
if source_is_binary:
assert isinstance(source, bytes)
checksum.update(source)
else:
assert isinstance(source, str)
checksum.update(source.encode("utf-8"))
checksum.update(str(toolchain.abi_id()).encode("utf-8"))
return checksum.hexdigest()

def load_info(info_path):
def load_info(info_path: str) -> Any:
import pickle

try:
Expand All @@ -350,7 +388,7 @@ def load_info(info_path):
finally:
info_file.close()

def check_deps(deps):
def check_deps(deps: list[_Dependency]) -> bool:
for name, date, md5sum in deps:
try:
possibly_updated = os.stat(name).st_mtime != date
Expand All @@ -370,7 +408,7 @@ def check_deps(deps):

return True

def check_source(source_path):
def check_source(source_path: list[str]) -> bool:
valid = True
for i, path in enumerate(source_path):
source = source_string[i]
Expand Down Expand Up @@ -440,11 +478,10 @@ def check_source(source_path):
if info_path is not None:
import pickle

info_file = open(info_path, "wb")
pickle.dump(_SourceInfo(
dependencies=get_dep_structure(source_paths),
source_name=source_name), info_file)
info_file.close()
with open(info_path, "wb") as info_file:
pickle.dump(_SourceInfo(
dependencies=get_dep_structure(source_paths),
source_name=source_name), info_file)

return hex_checksum, mod_name, ext_file, True
except Exception:
Expand All @@ -454,8 +491,16 @@ def check_source(source_path):
cleanup_m.clean_up()


def link_extension(toolchain, objects, mod_name, cache_dir=None,
debug=False, wait_on_error=True):
def link_extension(
toolchain: Toolchain,
objects: list[str],
mod_name: str,
cache_dir: str | None = None,
debug: bool = False,
wait_on_error: bool = True) -> ModuleType:
if not isinstance(toolchain, GCCLikeToolchain):
raise TypeError(f"Unsupported toolchain type: {type(toolchain)}")

import os.path
if cache_dir is not None:
destination = os.path.join(cache_dir, mod_name + toolchain.so_ext)
Expand All @@ -465,6 +510,7 @@ def link_extension(toolchain, objects, mod_name, cache_dir=None,
destination = os.path.join(
destination_base,
mod_name + toolchain.so_ext)

try:
toolchain.link_extension(destination, objects, debug=debug)
except CompileError:
Expand Down
Loading
Loading