diff --git a/.gitignore b/.gitignore
index 0dd11a8..07377ae 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,5 +1,5 @@
 # Ignore example dataset storage generated in tutorial
-pod_data/
+**/pod_data/
 
 # Autogenerated version file
 _version.py
diff --git a/pytest.ini b/pytest.ini
new file mode 100644
index 0000000..ec6206f
--- /dev/null
+++ b/pytest.ini
@@ -0,0 +1,6 @@
+[pytest]
+testpaths = tests
+python_files = test_*.py
+python_classes = Test*
+python_functions = test_*
+addopts = -v --cov=src --cov-report=term-missing
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index c6aff8e..ad1779c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,2 +1,5 @@
 numpy
-xxhash
\ No newline at end of file
+xxhash
+# TODO: separate dev and prod requirements
+pytest>=7.4.0
+pytest-cov>=4.1.0
\ No newline at end of file
diff --git a/src/orcabridge/__init__.py b/src/orcabridge/__init__.py
index 6d2c89d..8533ffa 100644
--- a/src/orcabridge/__init__.py
+++ b/src/orcabridge/__init__.py
@@ -30,9 +30,11 @@
     "DirDataStore",
     "SafeDirDataStore",
     "DEFAULT_TRACKER",
+    "SyncStreamFromLists",
 ]
 
 from .mapper import MapTags, MapPackets, Join, tag, packet
 from .pod import FunctionPod, function_pod
 from .source import GlobSource
-from .store import DirDataStore, SafeDirDataStore
\ No newline at end of file
+from .store import DirDataStore, SafeDirDataStore
+from .stream import SyncStreamFromLists
diff --git a/src/orcabridge/base.py b/src/orcabridge/base.py
index b770d07..c855c1a 100644
--- a/src/orcabridge/base.py
+++ b/src/orcabridge/base.py
@@ -1,5 +1,5 @@
 from orcabridge.hashing import HashableMixin
-from .types import Tag, Packet
+from orcabridge.types import Tag, Packet
 from typing import (
     Optional,
     Tuple,
@@ -10,7 +10,7 @@
     Iterator,
 )
 from collections.abc import Collection
-from typing import Any, List, Tuple
+import threading
 
 
 class Operation(HashableMixin):
@@ -40,9 +40,12 @@ def keys(self, *streams: "SyncStream") -> Tuple[List[str], List[str]]:
     @property
     def label(self) -> str:
         """
-        Overwrite this method to attain a custom label logic for the operation.
+        Returns a human-readable label for this operation.
+        The default implementation returns the provided label, or the class name if no label was provided.
""" - return self._label + if self._label: + return self._label + return self.__class__.__name__ @label.setter def label(self, value: str) -> None: @@ -58,17 +61,17 @@ def identity_structure(self, *streams: "SyncStream") -> Any: def __call__(self, *streams: "SyncStream", **kwargs) -> "SyncStream": # trigger call on source if passed as stream - streams = [stream() if isinstance(stream, Source) else stream for stream in streams] + streams = [ + stream() if isinstance(stream, Source) else stream + for stream in streams + ] output_stream = self.forward(*streams, **kwargs) # create an invocation instance invocation = Invocation(self, streams) # label the output_stream with the invocation information output_stream.invocation = invocation - # delay import to avoid circular import - from .tracker import Tracker - - # reg + # register the invocation with active trackers active_trackers = Tracker.get_active_trackers() for tracker in active_trackers: tracker.record(invocation) @@ -78,16 +81,64 @@ def __call__(self, *streams: "SyncStream", **kwargs) -> "SyncStream": def __repr__(self): return self.__class__.__name__ + def __str__(self): + if self._label is not None: + return f"{self.__class__.__name__}({self._label})" + return self.__class__.__name__ + def forward(self, *streams: "SyncStream") -> "SyncStream": ... +class Tracker: + """ + A tracker is a class that can track the invocations of operations. Only "active" trackers + participate in tracking and its `record` method gets called on each invocation of an operation. + Multiple trackers can be active at any time. + """ + + _local = threading.local() + + @classmethod + def get_active_trackers(cls) -> List["Tracker"]: + if hasattr(cls._local, "active_trackers"): + return cls._local.active_trackers + return [] + + def __init__(self): + self.active = False + + def activate(self) -> None: + """ + Activate the tracker. This is a no-op if the tracker is already active. + """ + if not self.active: + if not hasattr(self._local, "active_trackers"): + self._local.active_trackers = [] + self._local.active_trackers.append(self) + self.active = True + + def deactivate(self) -> None: + # Remove this tracker from active trackers + if hasattr(self._local, "active_trackers") and self.active: + self._local.active_trackers.remove(self) + self.active = False + + def __enter__(self): + self.activate() + return self + + def __exit__(self, exc_type, exc_val, ext_tb): + self.deactivate() + + def record(self, invocation: "Invocation") -> None: ... + + class Invocation(HashableMixin): """ This class represents an invocation of an operation on a collection of streams. - It contains the operation, the invocation ID, and the streams that were used - in the invocation. - The invocation ID is a unique identifier for the invocation and is used to - track the invocation in the tracker. + It contains the operation and the streams that were used in the invocation. + Note that the collection of streams may be empty, in which case the invocation + likely corresponds to a source operation. """ def __init__( @@ -108,8 +159,8 @@ def keys(self) -> Tuple[Collection[str], Collection[str]]: return self.operation.keys(*self.streams) def identity_structure(self) -> int: - # default implementation is streams order sensitive. 
If an operation does - # not depend on the order of the streams, it should override this method + # Identity of an invocation is entirely dependend on + # the operation's identity structure upon invocation return self.operation.identity_structure(*self.streams) def __eq__(self, other: Any) -> bool: @@ -136,15 +187,37 @@ class Stream(HashableMixin): This may be None if the stream is not generated by an operation. """ - def __init__(self, **kwargs) -> None: + def __init__(self, label: Optional[str] = None, **kwargs) -> None: super().__init__(**kwargs) self._invocation: Optional[Invocation] = None + self._label = label def identity_structure(self) -> Any: + """ + Identity structure of a stream is deferred to the identity structure + of the associated invocation, if present. + A bare stream without invocation has no well-defined identity structure. + """ if self.invocation is not None: return self.invocation.identity_structure() return super().identity_structure() + @property + def label(self) -> str: + """ + Returns a human-readable label for this stream. + If no label is provided and the stream is generated by an operation, + the label of the operation is used. + Otherwise, the class name is used as the label. + """ + if self._label is None: + if self.invocation is not None: + # use the invocation operation label + return self.invocation.operation.label + else: + return self.__class__.__name__ + return self._label + @property def invocation(self) -> Optional[Invocation]: return self._invocation @@ -152,32 +225,11 @@ def invocation(self) -> Optional[Invocation]: @invocation.setter def invocation(self, value: Invocation) -> None: if not isinstance(value, Invocation): - raise TypeError("invocation field must be an instance of Invocation") + raise TypeError( + "invocation field must be an instance of Invocation" + ) self._invocation = value - def __iter__(self) -> Iterator[Tuple[Tag, Packet]]: - raise NotImplementedError("Subclasses must implement __iter__ method") - - def flow(self) -> Collection[Tuple[Tag, Packet]]: - """ - Flow everything through the stream, returning the entire collection of - (Tag, Packet) as a collection. This will tigger any upstream computation of the stream. - """ - return list(self) - - -class SyncStream(Stream): - """ - A stream that will complete in a fixed amount of time. It is suitable for synchronous operations that - will have to wait for the stream to finish before proceeding. - """ - - def content_hash(self) -> str: - if self.invocation is not None: # and hasattr(self.invocation, "invocation_id"): - # use the invocation ID as the hash - return self.invocation.content_hash() - return super().content_hash() - def keys(self) -> Tuple[Collection[str], Collection[str]]: """ Returns the keys of the stream. @@ -194,9 +246,28 @@ def keys(self) -> Tuple[Collection[str], Collection[str]]: if tag_keys is not None and packet_keys is not None: return tag_keys, packet_keys # otherwise, use the keys from the first packet in the stream + # note that this may be computationally expensive tag, packet = next(iter(self)) return list(tag.keys()), list(packet.keys()) + def __iter__(self) -> Iterator[Tuple[Tag, Packet]]: + raise NotImplementedError("Subclasses must implement __iter__ method") + + def flow(self) -> Collection[Tuple[Tag, Packet]]: + """ + Flow everything through the stream, returning the entire collection of + (Tag, Packet) as a collection. This will tigger any upstream computation of the stream. 
+ """ + return list(self) + + +class SyncStream(Stream): + """ + A stream that will complete in a fixed amount of time. + It is suitable for synchronous operations that + will have to wait for the stream to finish before proceeding. + """ + def head(self, n: int = 5) -> None: """ Print the first n elements of the stream. @@ -223,9 +294,9 @@ def __rshift__(self, transformer: Any) -> "SyncStream": The mapping is applied to each packet in the stream and the resulting packets are returned in a new stream. """ + # TODO: remove just in time import from .mapper import MapPackets - # TODO: extend to generic mapping if isinstance(transformer, dict): return MapPackets(transformer)(self) elif isinstance(transformer, Callable): @@ -235,6 +306,7 @@ def __mul__(self, other: "SyncStream") -> "SyncStream": """ Returns a new stream that is the result joining with the other stream """ + # TODO: remove just in time import from .mapper import Join if not isinstance(other, SyncStream): @@ -242,6 +314,13 @@ def __mul__(self, other: "SyncStream") -> "SyncStream": return Join()(self, other) +class Mapper(Operation): + """ + A Mapper is an operation that does NOT generate new file content. + It is used to control the flow of data in the pipeline without modifying or creating new data (file). + """ + + class Source(Operation, SyncStream): """ A base class for all sources in the system. A source can be seen as a special diff --git a/src/orcabridge/dj/source.py b/src/orcabridge/dj/source.py index ab5f598..aad4229 100644 --- a/src/orcabridge/dj/source.py +++ b/src/orcabridge/dj/source.py @@ -1,5 +1,5 @@ from ..source import Source -from .stream import QueryStream, TableCachedStream +from .stream import QueryStream, TableCachedStream, TableStream from .operation import QueryOperation from ..stream import SyncStream from datajoint import Table @@ -7,6 +7,11 @@ from datajoint import Schema import datajoint as dj from ..utils.name import pascal_to_snake, snake_to_pascal +from ..utils.stream_utils import common_elements +import logging +from ..hashing import hash_to_uuid + +logger = logging.getLogger(__name__) class QuerySource(Source, QueryOperation): @@ -200,10 +205,9 @@ def label(self) -> str: return self.source.label return self._label - def compile( - self, tag_keys: Collection[str], packet_keys: Collection[str] - ) -> None: + def compile(self) -> None: # create a table to store the cached packets + tag_keys, packet_keys = self.source().keys() key_fields = "\n".join([f"{k}: varchar(255)" for k in tag_keys]) output_fields = "\n".join([f"{k}: varchar(255)" for k in packet_keys]) @@ -241,7 +245,7 @@ def forward( raise ValueError("No streams should be passed to TableCachedSource") if self.table is None: - self.compile(*self.source().keys()) + self.compile() return TableCachedStream( self.table, self.source(), @@ -256,50 +260,174 @@ class MergedQuerySource(QuerySource): """ def __init__( - self, *sources: QuerySource, label: Optional[str] = None + self, + *streams: QueryStream, + schema: Schema, + table_name: str = None, + table_postfix: str = "", + label: Optional[str] = None, + lazy_build: bool = True, ) -> None: super().__init__(label=label) - self.sources = sources + self.streams = streams + self.schema = schema + self.table = None + if table_name is None: + table_name = self.label if self.label is not None else "MergedData" + + self.table_name = pascal_to_snake(table_name) + ( + f"_{table_postfix}" if table_postfix else "" + ) + if not lazy_build: + self.compile() + + @property + def label(self) -> str: + if 
self._label is None: + return "_".join([stream.label for stream in self.streams]) + return self._label def identity_structure(self, *streams): return ( self.__class__.__name__, - str(self.sources), + self.streams, ) + tuple(streams) def forward(self, *streams: SyncStream) -> QueryStream: if len(streams) > 0: - raise NotImplementedError( - "Passing streams through MergedQuerySource is not implemented yet" + logger.warning( + "Handling multiple streams in forward is not implemented yet in " + "MergedQuerySource and this will be silently ignored" ) + if self.table is None: + self.compile() + + return TableStream(self.table) + + def compile(self) -> None: + + part_tag_keys = [] + part_packet_keys = [] + for stream in self.streams: + tag_key, packet_key = stream.keys() + part_tag_keys.append(tag_key) + part_packet_keys.append(packet_key) + + # find common keys among all tags and use that as primary key + common_tag_keys = common_elements(*part_tag_keys) + common_packet_keys = common_elements(*part_packet_keys) + + use_uuid = True + if all([len(k) == len(common_tag_keys) for k in part_tag_keys]): + # if all tags have the same number of keys, it is not necessary + # to include an additional UUID + use_uuid = False - def compile( - self, tag_keys: Collection[str], packet_keys: Collection[str] - ) -> None: # create a table to store the cached packets - key_fields = "\n".join([f"{k}: varchar(255)" for k in tag_keys]) - output_fields = "\n".join([f"{k}: varchar(255)" for k in packet_keys]) + key_fields = "\n".join([f"{k}: varchar(255)" for k in common_tag_keys]) + output_fields = "\n".join( + [f"{k}: varchar(255)" for k in common_packet_keys] + ) + table_field = f"{self.table_name}_part" + uuid_field = f"{self.table_name}_uuid" if use_uuid else "" + table_entry = f"{table_field}: varchar(255)" + uuid_entry = f"{uuid_field}: uuid" if use_uuid else "" - class CachedTable(dj.Manual): - source = self # this refers to the outer class instance + class MergedTable(dj.Manual): + source = self definition = f""" - # {self.table_name} outputs + # {self.table_name} inputs {key_fields} + {table_entry} + {uuid_entry} --- {output_fields} """ - def populate( - self, batch_size: int = 10, use_skip_duplicates: bool = False - ) -> int: - return sum( - 1 - for _ in self.operation( - batch_size=batch_size, - use_skip_duplicates=use_skip_duplicates, - ) + for stream in self.streams: + if not isinstance(stream, QueryStream): + raise ValueError( + f"Stream {stream} is not a QueryStream. " + "Please use a QueryStream as input." 
) + part_table = make_part_table( + stream, + common_tag_keys, + common_packet_keys, + table_field, + uuid_field, + ) + setattr(MergedTable, snake_to_pascal(stream.label), part_table) + + MergedTable.__name__ = snake_to_pascal(self.table_name) + MergedTable = self.schema(MergedTable) + self.table = MergedTable + + # class CachedTable(dj.Manual): + # source = self # this refers to the outer class instance + # definition = f""" + # # {self.table_name} outputs + # {key_fields} + # --- + # {output_fields} + # """ + + # def populate( + # self, batch_size: int = 10, use_skip_duplicates: bool = False + # ) -> int: + # return sum( + # 1 + # for _ in self.operation( + # batch_size=batch_size, + # use_skip_duplicates=use_skip_duplicates, + # ) + # ) + + # CachedTable.__name__ = snake_to_pascal(self.table_name) + # CachedTable = self.schema(CachedTable) + # self.table = CachedTable + + +def make_part_table( + stream: QueryStream, + common_tag_keys, + common_packet_keys, + table_field, + uuid_field, +) -> type[dj.Part]: + upstreams = "\n".join( + f"-> self.upstream_tables[{i}]" + for i in range(len(stream.upstream_tables)) + ) + + tag_keys, packet_keys = stream.keys() + + extra_packet_keys = [k for k in packet_keys if k not in common_packet_keys] + + extra_output_fields = "\n".join( + [f"{k}: varchar(255)" for k in extra_packet_keys] + ) + + class PartTable(dj.Part, dj.Computed): + upstream_tables = stream.upstream_tables + definition = f""" + -> master + --- + {upstreams} + {extra_output_fields} + """ - CachedTable.__name__ = snake_to_pascal(self.table_name) - CachedTable = self.schema(CachedTable) - self.table = CachedTable + @property + def key_source(self): + return stream.query + + def make(self, key): + content = (stream.query & key).fetch1() + content[table_field] = self.__class__.__name__ + if uuid_field: + content[uuid_field] = hash_to_uuid(key) + self.master.insert1(content, ignore_extra_fields=True) + self.insert1(content, ignore_extra_fields=True) + + PartTable.__name__ = snake_to_pascal(stream.label) + return PartTable diff --git a/src/orcabridge/dj/tracker.py b/src/orcabridge/dj/tracker.py index ee29a13..1b89d24 100644 --- a/src/orcabridge/dj/tracker.py +++ b/src/orcabridge/dj/tracker.py @@ -1,15 +1,15 @@ -from ..tracker import Tracker +from orcabridge.tracker import GraphTracker from datajoint import Schema from typing import List, Collection, Tuple, Optional, Any from types import ModuleType import networkx as nx -from ..base import Operation, Source -from ..mapper import Mapper -from ..pod import FunctionPod +from orcabridge.base import Operation, Source +from orcabridge.mapper import Mapper, Merge +from orcabridge.pod import FunctionPod from .stream import QueryStream -from .source import TableCachedSource +from .source import TableCachedSource, MergedQuerySource from .operation import QueryOperation from .pod import TableCachedPod from .mapper import convert_to_query_mapper @@ -57,6 +57,17 @@ def convert_to_query_operation( True, ) + if isinstance(operation, Merge): + return ( + MergedQuerySource( + *upstreams, + schema=schema, + table_name=table_name, + table_postfix=table_postfix, + ), + True, + ) + if isinstance(operation, Mapper): return convert_to_query_mapper(operation), True @@ -64,7 +75,7 @@ def convert_to_query_operation( raise ValueError(f"Unsupported operation for DJ conversion: {operation}") -class QueryTracker(Tracker): +class QueryTracker(GraphTracker): """ Query-specific tracker that tracks the invocations of operations and their associated streams. 
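
# Editor's note: a minimal usage sketch (not part of the patch) of the
# context-manager protocol that GraphTracker and QueryTracker now inherit from
# the relocated orcabridge.base.Tracker. The "data" directory and "*.txt"
# pattern are illustrative assumptions.
from orcabridge.tracker import GraphTracker
from orcabridge.source import GlobSource

with GraphTracker() as tracker:  # activate() on __enter__, deactivate() on __exit__
    stream = GlobSource("data", "*.txt")()  # each Operation.__call__ is recorded
print(tracker.invocation_lut)  # maps each Operation to its recorded Invocations
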
diff --git a/src/orcabridge/hashing.py b/src/orcabridge/hashing.py
index 9fa77af..2a97e14 100644
--- a/src/orcabridge/hashing.py
+++ b/src/orcabridge/hashing.py
@@ -6,7 +6,6 @@
 suitable for arbitrarily nested data structures and custom objects via HashableMixin.
 """
 
-from ast import Str
 import hashlib
 import inspect
 import json
@@ -26,10 +25,10 @@
 )
 from pathlib import Path
 from os import PathLike
-import os
 import xxhash
 import zlib
-from .types import PathSet, Packet
+from orcabridge.types import PathSet, Packet
+from orcabridge.utils.name import find_noncolliding_name
 
 # Configure logging with __name__ for proper hierarchy
 logger = logging.getLogger(__name__)
@@ -827,6 +826,70 @@ def stable_hash(s: Any) -> int:
     return hash_to_int(s)
 
 
+# Hashing of packets and PathSet
+
+
+class PathSetHasher:
+    def __init__(self, char_count=32):
+        self.char_count = char_count
+
+    def hash_pathset(self, pathset: PathSet) -> str:
+        if isinstance(pathset, str) or isinstance(pathset, PathLike):
+            pathset = Path(pathset)
+            if not pathset.exists():
+                raise FileNotFoundError(f"Path {pathset} does not exist")
+            if pathset.is_dir():
+                # iterate over all entries in the directory, including subdirectories (single step)
+                hash_dict = {}
+                for entry in pathset.iterdir():
+                    file_name = find_noncolliding_name(entry.name, hash_dict)
+                    hash_dict[file_name] = self.hash_pathset(entry)
+                return hash_to_hex(hash_dict, char_count=self.char_count)
+            else:
+                # it's a file, hash it directly via the hasher-specific method
+                return self.hash_file(pathset)
+
+        if isinstance(pathset, Collection):
+            hash_dict = {}
+            for path in pathset:
+                file_name = find_noncolliding_name(Path(path).name, hash_dict)
+                hash_dict[file_name] = self.hash_pathset(path)
+            return hash_to_hex(hash_dict, char_count=self.char_count)
+
+        raise ValueError(f"PathSet of type {type(pathset)} is not supported")
+
+    def hash_file(self, filepath) -> str: ...
+
+    def id(self) -> str: ...
+
+
+def hash_packet_with_psh(
+    packet: Packet, algo: PathSetHasher, prefix_algorithm: bool = True
+) -> str:
+    """
+    Generate a hash for a packet based on its content.
+
+    Args:
+        packet: The packet to hash
+        algo: The PathSetHasher to use for hashing the packet's pathsets
+        prefix_algorithm: Whether to prefix the hash with the hasher's ID
+
+    Returns:
+        A hexadecimal digest of the packet's content
+    """
+    hash_results = {}
+    for key, pathset in packet.items():
+        hash_results[key] = algo.hash_pathset(pathset)
+
+    packet_hash = hash_to_hex(hash_results)
+
+    if prefix_algorithm:
+        # Prefix the hash with the hasher's ID
+        packet_hash = f"{algo.id()}-{packet_hash}"
+
+    return packet_hash
+
+
 def hash_packet(
     packet: Packet,
     algorithm: str = "sha256",
@@ -879,7 +942,7 @@ def hash_pathset(
         # iterate over all entries in the directory include subdirectory (single step)
         hash_dict = {}
         for entry in pathset.iterdir():
-            file_name = entry.name
+            file_name = find_noncolliding_name(entry.name, hash_dict)
             hash_dict[file_name] = hash_pathset(
                 entry,
                 algorithm=algorithm,
@@ -896,7 +959,7 @@ def hash_pathset(
     if isinstance(pathset, Collection):
         hash_dict = {}
         for path in pathset:
-            file_name = Path(path).name
+            file_name = find_noncolliding_name(Path(path).name, hash_dict)
             hash_dict[file_name] = hash_pathset(
                 path,
                 algorithm=algorithm,
diff --git a/src/orcabridge/mapper.py b/src/orcabridge/mapper.py
index 777a4b3..d17dc9a 100644
--- a/src/orcabridge/mapper.py
+++ b/src/orcabridge/mapper.py
@@ -1,8 +1,17 @@
-from typing import Callable, Dict, Optional, List, Sequence
-
-from .stream import SyncStream, SyncStreamFromGenerator
-from .base import Operation
-from .utils.stream_utils import (
+from typing import (
+    Callable,
+    Dict,
+    Optional,
+    List,
+    Sequence,
+    Tuple,
+    Iterator,
+    Collection,
+    Any,
+)
+from orcabridge.base import Operation, SyncStream, Mapper
+from orcabridge.stream import SyncStreamFromGenerator
+from orcabridge.utils.stream_utils import (
     join_tags,
     check_packet_compatibility,
     batch_tag,
@@ -10,17 +19,9 @@
 )
 from .hashing import hash_function
 from .types import Tag, Packet
-from typing import Iterator, Tuple, Any, Collection
 from itertools import chain
 
 
-class Mapper(Operation):
-    """
-    A Mapper is an operation that does NOT generate new file content.
-    It is used to control the flow of data in the pipeline without modifying or creating new data (file).
-    """
-
-
 class Repeat(Mapper):
     """
     A Mapper that repeats the packets in the stream a specified number of times.
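
# Editor's note: a sketch (not part of the patch) of the collision handling that
# find_noncolliding_name introduces in the hash_pathset hunks above. Previously,
# entries sharing a basename silently overwrote one another in hash_dict; now
# each duplicate receives a numeric suffix. The file names are illustrative.
from orcabridge.utils.name import find_noncolliding_name

hash_dict = {}
for name in ["data.csv", "data.csv", "data.csv"]:
    key = find_noncolliding_name(name, hash_dict)
    hash_dict[key] = "<pathset hash>"  # stand-in for the real hash value
print(sorted(hash_dict))  # ['data.csv', 'data.csv_1', 'data.csv_2']
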
@@ -599,16 +600,14 @@ def __init__(self) -> None: self.is_cached = False def forward(self, *streams: SyncStream) -> SyncStream: - if len(streams) != 1: + if not self.is_cached and len(streams) != 1: raise ValueError( "CacheStream operation requires exactly one stream" ) - stream = streams[0] - def generator() -> Iterator[Tuple[Tag, Packet]]: if not self.is_cached: - for tag, packet in stream: + for tag, packet in streams[0]: self.cache.append((tag, packet)) yield tag, packet self.is_cached = True diff --git a/src/orcabridge/pod.py b/src/orcabridge/pod.py index 6a021fa..ba15918 100644 --- a/src/orcabridge/pod.py +++ b/src/orcabridge/pod.py @@ -2,9 +2,7 @@ logger = logging.getLogger(__name__) -from pathlib import Path from typing import ( - List, Optional, Tuple, Iterator, @@ -13,14 +11,12 @@ Literal, Any, ) -from .hashing import hash_function, get_function_signature -from .base import Operation -from .mapper import Join -from .stream import SyncStream, SyncStreamFromGenerator -from .types import Tag, Packet, PodFunction -from .store import DataStore, NoOpDataStore -import json -import shutil +from orcabridge.types import Tag, Packet, PodFunction +from orcabridge.hashing import hash_function, get_function_signature +from orcabridge.base import Operation +from orcabridge.stream import SyncStream, SyncStreamFromGenerator +from orcabridge.mapper import Join +from orcabridge.store import DataStore, NoOpDataStore import functools import warnings @@ -120,6 +116,7 @@ def __init__( custom_hash: Optional[int] = None, label: Optional[str] = None, force_computation: bool = False, + skip_cache_lookup: bool = False, skip_memoization: bool = False, error_handling: Literal["raise", "ignore", "warn"] = "raise", _hash_function_kwargs: Optional[dict] = None, @@ -139,6 +136,7 @@ def __init__( self.function_hash_mode = function_hash_mode self.custom_hash = custom_hash self.force_computation = force_computation + self.skip_cache_lookup = skip_cache_lookup self.skip_memoization = skip_memoization self.error_handling = error_handling self._hash_function_kwargs = _hash_function_kwargs @@ -168,15 +166,21 @@ def generator() -> Iterator[Tuple[Tag, Packet]]: n_computed = 0 for tag, packet in stream: try: - memoized_packet = self.data_store.retrieve_memoized( - self.store_name, - self.content_hash(char_count=16), - packet, - ) + if not self.skip_cache_lookup: + memoized_packet = self.data_store.retrieve_memoized( + self.store_name, + self.content_hash(char_count=16), + packet, + ) + else: + memoized_packet = None if ( not self.force_computation and memoized_packet is not None ): + logger.info( + "Memoized packet found, skipping computation" + ) yield tag, memoized_packet continue values = self.function(**packet) @@ -214,7 +218,7 @@ def generator() -> Iterator[Tuple[Tag, Packet]]: # e.g. 
if the output is a file, the path may be changed output_packet = self.data_store.memoize( self.store_name, - self.content_hash(), + self.content_hash(), # identity of this function pod packet, output_packet, ) diff --git a/src/orcabridge/source.py b/src/orcabridge/source.py index 65646bb..3071b89 100644 --- a/src/orcabridge/source.py +++ b/src/orcabridge/source.py @@ -1,10 +1,19 @@ -from .base import Source -from .stream import SyncStream, SyncStreamFromGenerator -from .types import Tag, Packet -from typing import Iterator, Tuple, Optional, Callable, Any, Collection, Literal +from orcabridge.types import Tag, Packet +from orcabridge.hashing import hash_function +from orcabridge.base import Source +from orcabridge.stream import SyncStream, SyncStreamFromGenerator +from typing import ( + Iterator, + Tuple, + Optional, + Callable, + Any, + Collection, + Literal, + Union, +) from os import PathLike from pathlib import Path -from .hashing import hash_function class LoadFromSource(Source): @@ -27,9 +36,11 @@ class GlobSource(Source): The directory path to search for files pattern : str, default='*' The glob pattern to match files against - tag_function : Optional[Callable[[PathLike], Tag]], default=None + tag_key : Optional[Union[str, Callable[[PathLike], Tag]]], default=None Optional function to generate a tag from a file path. If None, uses the file's - stem name (without extension) in a dict with key 'file_name' + stem name (without extension) in a dict with key 'file_name'. If only string is + provided, it will be used as the key for the tag. If a callable is provided, it + should accept a file path and return a dictionary of tags. Examples -------- @@ -48,7 +59,7 @@ def __init__( file_path: PathLike, pattern: str = "*", label: Optional[str] = None, - tag_function: Optional[Callable[[PathLike], Tag]] = None, + tag_function: Optional[Union[str, Callable[[PathLike], Tag]]] = None, tag_function_hash_mode: Literal[ "content", "signature", "name" ] = "name", @@ -60,9 +71,13 @@ def __init__( self.file_path = file_path self.pattern = pattern self.expected_tag_keys = expected_tag_keys + if self.expected_tag_keys is None and isinstance(tag_function, str): + self.expected_tag_keys = [tag_function] if tag_function is None: - # extract the file name without extension tag_function = self.__class__.default_tag_function + elif isinstance(tag_function, str): + tag_key = tag_function + tag_function = lambda f: {tag_key: Path(f).stem} self.tag_function = tag_function self.tag_function_hash_mode = tag_function_hash_mode diff --git a/src/orcabridge/stream.py b/src/orcabridge/stream.py index a4c697a..3dc8f7a 100644 --- a/src/orcabridge/stream.py +++ b/src/orcabridge/stream.py @@ -1,6 +1,52 @@ -from typing import Generator, Tuple, Dict, Any, Callable, Iterator, Optional, List -from .types import Tag, Packet -from .base import SyncStream +from typing import ( + Generator, + Tuple, + Dict, + Any, + Callable, + Iterator, + Optional, + List, + Collection, +) +from orcabridge.types import Tag, Packet +from orcabridge.base import SyncStream + + +class SyncStreamFromLists(SyncStream): + def __init__( + self, + tags: Optional[Collection[Tag]] = None, + packets: Optional[Collection[Packet]] = None, + paired: Optional[Collection[Tuple[Tag, Packet]]] = None, + tag_keys: Optional[List[str]] = None, + packet_keys: Optional[List[str]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.tag_keys = tag_keys + self.packet_keys = packet_keys + if tags is not None and packets is not None: + if len(tags) != 
len(packets): + raise ValueError( + "tags and packets must have the same length if both are provided" + ) + self.paired = list(zip(tags, packets)) + elif paired is not None: + self.paired = list(paired) + else: + raise ValueError( + "Either tags and packets or paired must be provided to SyncStreamFromLists" + ) + + def keys(self) -> Tuple[List[str], List[str]]: + if self.tag_keys is None or self.packet_keys is None: + return super().keys() + # If the keys are already set, return them + return self.tag_keys.copy(), self.packet_keys.copy() + + def __iter__(self) -> Iterator[Tuple[Tag, Packet]]: + yield from self.paired class SyncStreamFromGenerator(SyncStream): diff --git a/src/orcabridge/tracker.py b/src/orcabridge/tracker.py index b612ba9..2d2b72f 100644 --- a/src/orcabridge/tracker.py +++ b/src/orcabridge/tracker.py @@ -1,21 +1,26 @@ import threading from typing import Dict, Collection, List import networkx as nx -from .base import Operation, Invocation +from orcabridge.base import Operation, Invocation, Tracker import matplotlib.pyplot as plt -class Tracker: +class GraphTracker(Tracker): + """ + A tracker that records the invocations of operations and generates a graph + of the invocations and their dependencies. + """ # Thread-local storage to track active trackers - _local = threading.local() def __init__(self) -> None: - self.active = False + super().__init__() self.invocation_lut: Dict[Operation, Collection[Invocation]] = {} def record(self, invocation: Invocation) -> None: - invocation_list = self.invocation_lut.setdefault(invocation.operation, []) + invocation_list = self.invocation_lut.setdefault( + invocation.operation, [] + ) if invocation not in invocation_list: invocation_list.append(invocation) @@ -43,22 +48,6 @@ def generate_namemap(self) -> Dict[Invocation, str]: namemap[invocation] = f"{node_label}_{idx}" return namemap - def activate(self) -> None: - """ - Activate the tracker. This is a no-op if the tracker is already active. 
- """ - if not self.active: - if not hasattr(self._local, "active_trackers"): - self._local.active_trackers = [] - self._local.active_trackers.append(self) - self.active = True - - def deactivate(self) -> None: - # Remove this tracker from active trackers - if hasattr(self._local, "active_trackers") and self.active: - self._local.active_trackers.remove(self) - self.active = False - def generate_graph(self): G = nx.DiGraph() @@ -90,16 +79,3 @@ def draw_graph(self): arrowsize=20, ) plt.tight_layout() - - def __enter__(self): - self.activate() - return self - - def __exit__(self, exc_type, exc_val, ext_tb): - self.deactivate() - - @classmethod - def get_active_trackers(cls) -> List["Tracker"]: - if hasattr(cls._local, "active_trackers"): - return cls._local.active_trackers - return [] diff --git a/src/orcabridge/types.py b/src/orcabridge/types.py index b7ae366..55af3d3 100644 --- a/src/orcabridge/types.py +++ b/src/orcabridge/types.py @@ -1,19 +1,22 @@ -from typing import Union, List, Tuple, Protocol, Mapping, Collection -from anyio import Path +from typing import Union, Tuple, Protocol, Mapping, Collection, Optional +from pathlib import Path from typing_extensions import TypeAlias import os +# Convenience alias for anything pathlike PathLike = Union[str, bytes, os.PathLike] # arbitrary depth of nested list of strings or None -L: TypeAlias = Collection[Union[str, None, "L"]] +L: TypeAlias = Union[str, None, Collection[Optional[str]]] + # the top level tag is a mapping from string keys to values that can be a string or # an arbitrary depth of nested list of strings or None Tag: TypeAlias = Mapping[str, Union[str, L]] + # a pathset is a path or an arbitrary depth of nested list of paths -PathSet: TypeAlias = Union[PathLike, Collection[PathLike]] +PathSet: TypeAlias = Union[PathLike, Collection[Optional[PathLike]]] # a packet is a mapping from string keys to pathsets Packet: TypeAlias = Mapping[str, PathSet] diff --git a/src/orcabridge/utils/name.py b/src/orcabridge/utils/name.py index f3e7b21..c5b952f 100644 --- a/src/orcabridge/utils/name.py +++ b/src/orcabridge/utils/name.py @@ -9,6 +9,17 @@ import ast +def find_noncolliding_name(name: str, lut: dict) -> str: + if name not in lut: + return name + + suffix = 1 + while f"{name}_{suffix}" in lut: + suffix += 1 + + return f"{name}_{suffix}" + + def pascal_to_snake(name: str) -> str: # Convert PascalCase to snake_case # if already in snake_case, return as is diff --git a/src/orcabridge/utils/stream_utils.py b/src/orcabridge/utils/stream_utils.py index 87c882d..ada32e1 100644 --- a/src/orcabridge/utils/stream_utils.py +++ b/src/orcabridge/utils/stream_utils.py @@ -3,14 +3,42 @@ """ from _collections_abc import dict_keys -from typing import List, Dict, Optional, Any, TypeVar, Set, Union, Sequence, Mapping +from typing import ( + List, + Dict, + Optional, + Any, + TypeVar, + Set, + Union, + Sequence, + Mapping, + Collection, +) from ..types import Tag, Packet K = TypeVar("K") V = TypeVar("V") -def join_tags(tag1: Mapping[K, V], tag2: Mapping[K, V]) -> Optional[Mapping[K, V]]: +def common_elements(*values) -> Collection[str]: + """ + Returns the common keys between all lists of values. 
The identified common elements are + order preserved with respect to the first list of values + """ + if len(values) == 0: + return [] + common_keys = set(values[0]) + for tag in values[1:]: + common_keys.intersection_update(tag) + # Preserve the order of the first list of values + common_keys = [k for k in values[0] if k in common_keys] + return common_keys + + +def join_tags( + tag1: Mapping[K, V], tag2: Mapping[K, V] +) -> Optional[Mapping[K, V]]: """ Joins two tags together. If the tags have the same key, the value must be the same or None will be returned. """ @@ -42,14 +70,20 @@ def batch_tag(all_tags: Sequence[Tag]) -> Tag: all_keys: Set[str] = set() for tag in all_tags: all_keys.update(tag.keys()) - batch_tag = {key: [] for key in all_keys} # Initialize batch_tag with all keys + batch_tag = { + key: [] for key in all_keys + } # Initialize batch_tag with all keys for tag in all_tags: for k in all_keys: - batch_tag[k].append(tag.get(k, None)) # Append the value or None if the key is not present + batch_tag[k].append( + tag.get(k, None) + ) # Append the value or None if the key is not present return batch_tag -def batch_packet(all_packets: Sequence[Packet], drop_missing_keys: bool = True) -> Packet: +def batch_packet( + all_packets: Sequence[Packet], drop_missing_keys: bool = True +) -> Packet: """ Batches the packets together. Grouping values under the same key into a list. If all packets do not have the same key, raise an error unless drop_missing_keys is True diff --git a/tests/test_basic_hashing.py b/tests/test_basic_hashing.py new file mode 100644 index 0000000..10bf379 --- /dev/null +++ b/tests/test_basic_hashing.py @@ -0,0 +1,132 @@ +import pytest +from orcabridge.hashing import ( + hash_to_hex, + hash_to_int, + hash_to_uuid, + HashableMixin, + hash_dict, + stable_hash, +) + + +def test_hash_to_hex(): + # Test with string + # Should be equivalent to hashing b'"test"' + assert ( + hash_to_hex("test", None) + == "4d967a30111bf29f0eba01c448b375c1629b2fed01cdfcc3aed91f1b57d5dd5e" + ) + + # Test with integer + # Should be equivalent to hashing b'42' + assert ( + hash_to_hex(42, None) + == "73475cb40a568e8da8a045ced110137e159f890ac4da883b6b17dc651b3a8049" + ) + + assert ( + hash_to_hex(True, None) + == "b5bea41b6c623f7c09f1bf24dcae58ebab3c0cdd90ad966bc43a45b44867e12b" + ) + + assert ( + hash_to_hex(0.256, None) + == "79308bed382bc45abbb1297149dda93e29d676aff0b366bc5f2bb932a4ff55ca" + ) + + # equivalent to hashing b'null' + assert ( + hash_to_hex(None, None) + == "74234e98afe7498fb5daf1f36ac2d78acc339464f950703b8c019892f982b90b" + ) + + # Hash structure + assert ( + hash_to_hex(["a", "b", "c"], None) + == "fa1844c2988ad15ab7b49e0ece09684500fad94df916859fb9a43ff85f5bb477" + ) + + # hash set + assert ( + hash_to_hex(set([1, 2, 3]), None) + == "a615eeaee21de5179de080de8c3052c8da901138406ba71c38c032845f7d54f4" + ) + + # Test with custom char_count + assert len(hash_to_hex("test", char_count=16)) == 16 + + assert len(hash_to_hex("test", char_count=0)) == 0 + + +def test_structure_equivalence(): + # identical content should yield the same hash + assert hash_to_hex(["a", "b", "c"], None) == hash_to_hex( + ["a", "b", "c"], None + ) + # list should be order dependent + assert hash_to_hex(["a", "b", "c"], None) != hash_to_hex( + ["a", "c", "b"], None + ) + + # dict should be order independent + assert hash_to_hex({"a": 1, "b": 2, "c": 3}, None) == hash_to_hex( + {"c": 3, "b": 2, "a": 1}, None + ) + + # set should be order independent + assert hash_to_hex(set([1, 2, 3]), None) == hash_to_hex( + 
set([3, 2, 1]), None + ) + + # equivalence under nested structure + assert hash_to_hex( + set([("a", "b", "c"), ("d", "e", "f")]), None + ) == hash_to_hex(set([("d", "e", "f"), ("a", "b", "c")]), None) + + +def test_hash_to_int(): + # Test with string + assert isinstance(hash_to_int("test"), int) + + # Test with custom hexdigits + result = hash_to_int("test", hexdigits=8) + assert result < 16**8 # Should be less than max value for 8 hex digits + + +def test_hash_to_uuid(): + # Test with string + uuid = hash_to_uuid("test") + assert str(uuid).count("-") == 4 # Valid UUID format + + # Test with integer + uuid = hash_to_uuid(42) + assert str(uuid).count("-") == 4 # Valid UUID format + + +class TestHashableMixin(HashableMixin): + def __init__(self, value): + self.value = value + + def identity_structure(self): + return {"value": self.value} + + +def test_hash_dict(): + test_dict = {"a": 1, "b": "test", "c": {"nested": True}} + + # Test that it returns a UUID + result = hash_dict(test_dict) + assert str(result).count("-") == 4 + + +def test_stable_hash(): + # Test that same input gives same output + assert stable_hash("test") == stable_hash("test") + + # Test that different inputs give different outputs + assert stable_hash("test1") != stable_hash("test2") + + # Test with different types + assert isinstance(stable_hash(42), int) + assert isinstance(stable_hash("string"), int) + assert isinstance(stable_hash([1, 2, 3]), int)
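
# Editor's note: a possible companion test (a sketch, not part of this diff) for
# the SyncStreamFromLists class added in src/orcabridge/stream.py. The tag and
# packet values are illustrative.
from orcabridge.stream import SyncStreamFromLists


def test_sync_stream_from_lists():
    tags = [{"file_name": "a"}, {"file_name": "b"}]
    packets = [{"path": "a.txt"}, {"path": "b.txt"}]

    # iteration yields the paired (tag, packet) tuples in order
    stream = SyncStreamFromLists(tags=tags, packets=packets)
    assert list(stream) == list(zip(tags, packets))

    # explicitly provided keys short-circuit inference from the first element
    keyed = SyncStreamFromLists(
        tags=tags,
        packets=packets,
        tag_keys=["file_name"],
        packet_keys=["path"],
    )
    assert keyed.keys() == (["file_name"], ["path"])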