
Commit a158e17

feat: add more operators
1 parent 4c397d6 commit a158e17

7 files changed: +459 −32 lines changed

src/orcapod/core/operators/__init__.py

Lines changed: 10 additions & 1 deletion
@@ -2,14 +2,23 @@
 from .semijoin import SemiJoin
 from .mappers import MapTags, MapPackets
 from .batch import Batch
-from .column_selection import DropTagColumns, DropPacketColumns
+from .column_selection import (
+    SelectTagColumns,
+    SelectPacketColumns,
+    DropTagColumns,
+    DropPacketColumns,
+)
+from .filters import PolarsFilter
 
 __all__ = [
     "Join",
     "SemiJoin",
     "MapTags",
     "MapPackets",
     "Batch",
+    "SelectTagColumns",
+    "SelectPacketColumns",
     "DropTagColumns",
     "DropPacketColumns",
+    "PolarsFilter",
 ]
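
With these exports in place, the new operators can be imported directly from orcapod.core.operators. A minimal construction sketch (the column names are hypothetical, not part of this commit; the public path for applying an operator to a stream is outside this diff, which only defines the op_forward hook):

    from orcapod.core.operators import SelectTagColumns, SelectPacketColumns

    # keep only the "subject" tag column
    keep_subject_tag = SelectTagColumns("subject")
    # keep only the "score" packet column; strict=False tolerates its absence
    keep_score = SelectPacketColumns("score", strict=False)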

src/orcapod/core/operators/column_selection.py

Lines changed: 161 additions & 7 deletions
@@ -18,6 +18,155 @@
 logger = logging.getLogger(__name__)
 
 
+class SelectTagColumns(UnaryOperator):
+    """
+    Operator that selects the specified tag columns from a stream.
+    """
+
+    def __init__(self, columns: str | Collection[str], strict: bool = True, **kwargs):
+        if isinstance(columns, str):
+            columns = [columns]
+        self.columns = columns
+        self.strict = strict
+        super().__init__(**kwargs)
+
+    def op_forward(self, stream: cp.Stream) -> cp.Stream:
+        tag_columns, packet_columns = stream.keys()
+        tags_to_drop = [c for c in tag_columns if c not in self.columns]
+        new_tag_columns = [c for c in tag_columns if c not in tags_to_drop]
+
+        if len(new_tag_columns) == len(tag_columns):
+            logger.info("All tag columns are selected. Returning stream unaltered.")
+            return stream
+
+        table = stream.as_table(
+            include_source=True, include_system_tags=True, sort_by_tags=False
+        )
+
+        modified_table = table.drop_columns(list(tags_to_drop))
+
+        return TableStream(
+            modified_table,
+            tag_columns=new_tag_columns,
+            source=self,
+            upstreams=(stream,),
+        )
+
+    def op_validate_inputs(self, stream: cp.Stream) -> None:
+        """
+        Validate the input stream, raising InputValidationError if any column
+        requested for selection is missing while strict mode is enabled.
+        """
+        # TODO: remove redundant logic
+        tag_columns, packet_columns = stream.keys()
+        columns_to_select = self.columns
+        missing_columns = set(columns_to_select) - set(tag_columns)
+        if missing_columns and self.strict:
+            raise InputValidationError(
+                f"Missing tag columns: {missing_columns}. Make sure all specified columns to select are present or use strict=False to ignore missing columns"
+            )
+
+    def op_output_types(
+        self, stream: cp.Stream, include_system_tags: bool = False
+    ) -> tuple[PythonSchema, PythonSchema]:
+        tag_schema, packet_schema = stream.types(
+            include_system_tags=include_system_tags
+        )
+        tag_columns, _ = stream.keys()
+        tags_to_drop = [tc for tc in tag_columns if tc not in self.columns]
+
+        # this ensures all system tag columns are preserved
+        new_tag_schema = {k: v for k, v in tag_schema.items() if k not in tags_to_drop}
+
+        return new_tag_schema, packet_schema
+
+    def op_identity_structure(self, stream: cp.Stream | None = None) -> Any:
+        return (
+            self.__class__.__name__,
+            self.columns,
+            self.strict,
+        ) + ((stream,) if stream is not None else ())
+
+
+class SelectPacketColumns(UnaryOperator):
+    """
+    Operator that selects the specified packet columns from a stream.
+    """
+
+    def __init__(self, columns: str | Collection[str], strict: bool = True, **kwargs):
+        if isinstance(columns, str):
+            columns = [columns]
+        self.columns = columns
+        self.strict = strict
+        super().__init__(**kwargs)
+
+    def op_forward(self, stream: cp.Stream) -> cp.Stream:
+        tag_columns, packet_columns = stream.keys()
+        packet_columns_to_drop = [c for c in packet_columns if c not in self.columns]
+        new_packet_columns = [
+            c for c in packet_columns if c not in packet_columns_to_drop
+        ]
+
+        if len(new_packet_columns) == len(packet_columns):
+            logger.info("All packet columns are selected. Returning stream unaltered.")
+            return stream
+
+        table = stream.as_table(
+            include_source=True, include_system_tags=True, sort_by_tags=False
+        )
+        # make sure to drop associated source fields
+        associated_source_fields = [
+            f"{constants.SOURCE_PREFIX}{c}" for c in packet_columns_to_drop
+        ]
+        packet_columns_to_drop.extend(associated_source_fields)
+
+        modified_table = table.drop_columns(packet_columns_to_drop)
+
+        return TableStream(
+            modified_table,
+            tag_columns=tag_columns,
+            source=self,
+            upstreams=(stream,),
+        )
+
+    def op_validate_inputs(self, stream: cp.Stream) -> None:
+        """
+        Validate the input stream, raising InputValidationError if any column
+        requested for selection is missing while strict mode is enabled.
+        """
+        # TODO: remove redundant logic
+        tag_columns, packet_columns = stream.keys()
+        columns_to_select = self.columns
+        missing_columns = set(columns_to_select) - set(packet_columns)
+        if missing_columns and self.strict:
+            raise InputValidationError(
+                f"Missing packet columns: {missing_columns}. Make sure all specified columns to select are present or use strict=False to ignore missing columns"
+            )
+
+    def op_output_types(
+        self, stream: cp.Stream, include_system_tags: bool = False
+    ) -> tuple[PythonSchema, PythonSchema]:
+        tag_schema, packet_schema = stream.types(
+            include_system_tags=include_system_tags
+        )
+        _, packet_columns = stream.keys()
+        packets_to_drop = [pc for pc in packet_columns if pc not in self.columns]
+
+        # system columns are never candidates for dropping, so they are preserved
+        new_packet_schema = {
+            k: v for k, v in packet_schema.items() if k not in packets_to_drop
+        }
+
+        return tag_schema, new_packet_schema
+
+    def op_identity_structure(self, stream: cp.Stream | None = None) -> Any:
+        return (
+            self.__class__.__name__,
+            self.columns,
+            self.strict,
+        ) + ((stream,) if stream is not None else ())
+
+
 class DropTagColumns(UnaryOperator):
     """
     Operator that drops specified columns from a stream.
@@ -64,11 +213,10 @@ def op_validate_inputs(self, stream: cp.Stream) -> None:
         tag_columns, packet_columns = stream.keys()
         columns_to_drop = self.columns
         missing_columns = set(columns_to_drop) - set(tag_columns)
-        if missing_columns:
-            if self.strict:
-                raise InputValidationError(
-                    f"Missing tag columns: {missing_columns}. Make sure all specified columns to drop are present or use strict=False to ignore missing columns"
-                )
+        if missing_columns and self.strict:
+            raise InputValidationError(
+                f"Missing tag columns: {missing_columns}. Make sure all specified columns to drop are present or use strict=False to ignore missing columns"
+            )
 
     def op_output_types(
         self, stream: cp.Stream, include_system_tags: bool = False
@@ -105,19 +253,25 @@ def __init__(self, columns: str | Collection[str], strict: bool = True, **kwargs
 
     def op_forward(self, stream: cp.Stream) -> cp.Stream:
         tag_columns, packet_columns = stream.keys()
-        columns_to_drop = self.columns
+        columns_to_drop = list(self.columns)
         if not self.strict:
             columns_to_drop = [c for c in columns_to_drop if c in packet_columns]
 
         if len(columns_to_drop) == 0:
             logger.info("No packet columns to drop. Returning stream unaltered.")
             return stream
 
+        # make sure all associated source columns are dropped too
+        associated_source_columns = [
+            f"{constants.SOURCE_PREFIX}{c}" for c in columns_to_drop
+        ]
+        columns_to_drop.extend(associated_source_columns)
+
         table = stream.as_table(
            include_source=True, include_system_tags=True, sort_by_tags=False
         )
 
-        modified_table = table.drop_columns(list(columns_to_drop))
+        modified_table = table.drop_columns(columns_to_drop)
 
         return TableStream(
             modified_table,
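
The notable behavioral change in this file is source-column bookkeeping: whenever a packet column is dropped (or deselected), its companion column prefixed with constants.SOURCE_PREFIX is dropped as well. A standalone pyarrow sketch of that logic, using an illustrative prefix value and made-up column names:

    import pyarrow as pa

    SOURCE_PREFIX = "_source_"  # illustrative only; the real value comes from
                                # orcapod.core.system_constants.constants

    table = pa.table({
        "sample": [1, 2],
        "score": [0.5, 0.9],
        f"{SOURCE_PREFIX}score": ["run_a", "run_b"],
    })

    columns_to_drop = ["score"]
    # every dropped packet column takes its source column with it
    columns_to_drop.extend(f"{SOURCE_PREFIX}{c}" for c in list(columns_to_drop))
    print(table.drop_columns(columns_to_drop).column_names)  # ['sample']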
src/orcapod/core/operators/filters.py

Lines changed: 166 additions & 0 deletions
@@ -0,0 +1,166 @@
+from orcapod.protocols import core_protocols as cp
+from orcapod.core.streams import TableStream
+from orcapod.types import PythonSchema
+from typing import Any, TYPE_CHECKING, TypeAlias
+from orcapod.utils.lazy_module import LazyModule
+from collections.abc import Collection, Mapping
+from orcapod.errors import InputValidationError
+from orcapod.core.system_constants import constants
+from orcapod.core.operators.base import UnaryOperator
+import logging
+from collections.abc import Iterable
+
+
+if TYPE_CHECKING:
+    import pyarrow as pa
+    import polars as pl
+    import polars._typing as pl_type
+    import numpy as np
+else:
+    pa = LazyModule("pyarrow")
+    pl = LazyModule("polars")
+    pl_type = LazyModule("polars._typing")
+
+logger = logging.getLogger(__name__)
+
+polars_predicate: TypeAlias = "pl_type.IntoExprColumn | Iterable[pl_type.IntoExprColumn] | bool | list[bool] | np.ndarray[Any, Any]"
+
+
+class PolarsFilter(UnaryOperator):
+    """
+    Operator that applies Polars filtering to a stream.
+    """
+
+    def __init__(
+        self,
+        predicates: Collection[
+            "pl_type.IntoExprColumn | Iterable[pl_type.IntoExprColumn] | bool | list[bool] | np.ndarray[Any, Any]"
+        ] = (),
+        constraints: Mapping[str, Any] | None = None,
+        **kwargs,
+    ):
+        self.predicates = predicates
+        self.constraints = constraints if constraints is not None else {}
+        super().__init__(**kwargs)
+
+    def op_forward(self, stream: cp.Stream) -> cp.Stream:
+        if len(self.predicates) == 0 and len(self.constraints) == 0:
+            logger.info(
+                "No predicates or constraints specified. Returning stream unaltered."
+            )
+            return stream
+
+        # TODO: improve efficiency here...
+        table = stream.as_table(
+            include_source=True, include_system_tags=True, sort_by_tags=False
+        )
+        df = pl.DataFrame(table)
+        filtered_table = df.filter(*self.predicates, **self.constraints).to_arrow()
+
+        return TableStream(
+            filtered_table,
+            tag_columns=stream.tag_keys(),
+            source=self,
+            upstreams=(stream,),
+        )
+
+    def op_validate_inputs(self, stream: cp.Stream) -> None:
+        """
+        Validate the input stream. Any valid stream is accepted, so no
+        checks are needed here.
+        """
+
+        # Any valid stream would work
+        return
+
+    def op_output_types(
+        self, stream: cp.Stream, include_system_tags: bool = False
+    ) -> tuple[PythonSchema, PythonSchema]:
+        # data types are not modified
+        return stream.types(include_system_tags=include_system_tags)
+
+    def op_identity_structure(self, stream: cp.Stream | None = None) -> Any:
+        return (
+            self.__class__.__name__,
+            self.predicates,
+            self.constraints,
+        ) + ((stream,) if stream is not None else ())
+
+
+class SelectPacketColumns(UnaryOperator):
+    """
+    Operator that selects the specified packet columns from a stream.
+    """
+
+    def __init__(self, columns: str | Collection[str], strict: bool = True, **kwargs):
+        if isinstance(columns, str):
+            columns = [columns]
+        self.columns = columns
+        self.strict = strict
+        super().__init__(**kwargs)
+
+    def op_forward(self, stream: cp.Stream) -> cp.Stream:
+        tag_columns, packet_columns = stream.keys()
+        packet_columns_to_drop = [c for c in packet_columns if c not in self.columns]
+        new_packet_columns = [
+            c for c in packet_columns if c not in packet_columns_to_drop
+        ]
+
+        if len(new_packet_columns) == len(packet_columns):
+            logger.info("All packet columns are selected. Returning stream unaltered.")
+            return stream
+
+        table = stream.as_table(
+            include_source=True, include_system_tags=True, sort_by_tags=False
+        )
+        # make sure to drop associated source fields
+        associated_source_fields = [
+            f"{constants.SOURCE_PREFIX}{c}" for c in packet_columns_to_drop
+        ]
+        packet_columns_to_drop.extend(associated_source_fields)
+
+        modified_table = table.drop_columns(packet_columns_to_drop)
+
+        return TableStream(
+            modified_table,
+            tag_columns=tag_columns,
+            source=self,
+            upstreams=(stream,),
+        )
+
+    def op_validate_inputs(self, stream: cp.Stream) -> None:
+        """
+        Validate the input stream, raising InputValidationError if any column
+        requested for selection is missing while strict mode is enabled.
+        """
+        # TODO: remove redundant logic
+        tag_columns, packet_columns = stream.keys()
+        columns_to_select = self.columns
+        missing_columns = set(columns_to_select) - set(packet_columns)
+        if missing_columns and self.strict:
+            raise InputValidationError(
+                f"Missing packet columns: {missing_columns}. Make sure all specified columns to select are present or use strict=False to ignore missing columns"
+            )
+
+    def op_output_types(
+        self, stream: cp.Stream, include_system_tags: bool = False
+    ) -> tuple[PythonSchema, PythonSchema]:
+        tag_schema, packet_schema = stream.types(
+            include_system_tags=include_system_tags
+        )
+        _, packet_columns = stream.keys()
+        packets_to_drop = [pc for pc in packet_columns if pc not in self.columns]
+
+        # system columns are never candidates for dropping, so they are preserved
+        new_packet_schema = {
+            k: v for k, v in packet_schema.items() if k not in packets_to_drop
+        }
+
+        return tag_schema, new_packet_schema
+
+    def op_identity_structure(self, stream: cp.Stream | None = None) -> Any:
+        return (
+            self.__class__.__name__,
+            self.columns,
+            self.strict,
+        ) + ((stream,) if stream is not None else ())
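
PolarsFilter forwards its predicates and constraints unchanged to Polars' DataFrame.filter, so its filtering semantics are exactly Polars' own: predicates become positional filter expressions and constraints become equality keyword arguments. A self-contained sketch with made-up data:

    import polars as pl

    df = pl.DataFrame({"subject": ["a", "b", "b"], "score": [0.2, 0.8, 0.4]})

    # mirrors PolarsFilter(predicates=[pl.col("score") > 0.3],
    #                      constraints={"subject": "b"}) applied to a stream
    filtered = df.filter(pl.col("score") > 0.3, subject="b")
    print(filtered.to_dicts())  # [{'subject': 'b', 'score': 0.8}]

Note that op_forward materializes the stream into a Polars DataFrame and converts the result back to Arrow, which the TODO comment in the diff flags as a candidate for future optimization.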
