ingestify/application/dataset_store.py (+13, -9)
@@ -497,15 +497,19 @@ def update_dataset(

     def invalidate_revision(self, dataset: Dataset, reason: str = ""):
         """Mark the current revision as VALIDATION_FAILED and reset
-        last_modified_at so the dataset is refetched on the next run.
-
-        Args:
-            dataset: Dataset whose current revision should be invalidated
-            reason: Human-readable reason for invalidation
-        """
-        self.dataset_repository.invalidate_revision(dataset)
-
-        self.dispatch(RevisionInvalidated(dataset=dataset, reason=reason))
+        last_modified_at so the dataset is refetched on the next run."""
+        self.invalidate_revisions([dataset], reason=reason)
+
+    def invalidate_revisions(self, datasets: list[Dataset], reason: str = ""):
+        """Invalidate revisions for multiple datasets, batching DB updates
+        and event writes per 1000 datasets."""
+        batch_size = 1000
+        for i in range(0, len(datasets), batch_size):
+            batch = datasets[i : i + batch_size]
+            self.dataset_repository.invalidate_revisions(batch)
+            self.event_bus.dispatch_many(
+                [RevisionInvalidated(dataset=ds, reason=reason) for ds in batch]
+            )
 
     def destroy_dataset(self, dataset: Dataset):
         # TODO: remove files. Now we leave some orphaned files around
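For orientation, a hedged usage sketch of the new batch entry point; the caller and the reason string below are illustrative, while `get_dataset_collection` and `invalidate_revisions` come from this diff:

```python
# Hypothetical caller; `store` is a wired-up DatasetStore as in the tests below.
failed = list(
    store.get_dataset_collection(provider="test_provider", dataset_type="test")
)

# One repository round-trip and one batched event dispatch per 1000 datasets,
# instead of one UPDATE and one event write per dataset.
store.invalidate_revisions(failed, reason="validation checksum mismatch")
```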
ingestify/domain/models/dataset/dataset_repository.py (+6, -1)
@@ -36,10 +36,15 @@ def get_dataset_last_modified_at_map(
         dataset+revision+file graph."""
         return {}
 
-    @abstractmethod
     def invalidate_revision(self, dataset: Dataset):
         """Mark the current revision as VALIDATION_FAILED and reset
         last_modified_at on the dataset."""
+        self.invalidate_revisions([dataset])
 
+    @abstractmethod
+    def invalidate_revisions(self, datasets: list[Dataset]):
+        """Batch invalidate: mark current revisions as VALIDATION_FAILED
+        and reset last_modified_at on the datasets."""
+        pass
 
     @abstractmethod
ingestify/domain/models/event/dispatcher.py (+3, -0)
@@ -6,3 +6,6 @@
 class Dispatcher(Protocol):
     def dispatch(self, event: DomainEvent):
         pass
+
+    def dispatch_many(self, events: list[DomainEvent]):
+        pass
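Because `Dispatcher` is a `typing.Protocol`, conformance is structural; anything with matching `dispatch`/`dispatch_many` methods qualifies. A minimal sketch (the class is illustrative, not part of the codebase, and the import path is an assumption):

```python
from ingestify.domain.models.event import DomainEvent  # assumed import path

class CollectingDispatcher:
    """Satisfies the Dispatcher protocol without inheriting from it."""

    def __init__(self):
        self.seen: list[DomainEvent] = []

    def dispatch(self, event: DomainEvent):
        self.seen.append(event)

    def dispatch_many(self, events: list[DomainEvent]):
        # One bulk operation per batch; a DB-backed dispatcher would issue
        # a single bulk write here instead of len(events) separate ones.
        self.seen.extend(events)
```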
ingestify/domain/models/event/event_bus.py (+12, -0)
@@ -14,6 +14,10 @@ def __init__(self, queue):
     def dispatch(self, event):
         self.queue.put(event)
 
+    def dispatch_many(self, events):
+        for event in events:
+            self.queue.put(event)
+
 
 class EventBus:
     def __init__(self):
@@ -37,3 +41,11 @@ def dispatch(self, event):
             except Exception as e:
                 logger.exception(f"Failed to handle {event}")
                 raise Exception(f"Failed to handle {event}") from e
+
+    def dispatch_many(self, events):
+        for dispatcher in self.dispatchers:
+            try:
+                dispatcher.dispatch_many(events)
+            except Exception as e:
+                logger.exception(f"Failed to handle {len(events)} events")
+                raise Exception(f"Failed to handle {len(events)} events") from e
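Note the failure semantics: `EventBus.dispatch_many` logs and re-raises on the first failing dispatcher, so dispatchers registered after it never see the batch; `Publisher.dispatch_many` below logs and keeps going. A sketch of the fail-fast side, assuming dispatchers are registered by appending to `bus.dispatchers` (the actual registration hook is elided from this diff):

```python
class BoomDispatcher:
    def dispatch_many(self, events):
        raise RuntimeError("boom")

class CountingDispatcher:
    def __init__(self):
        self.count = 0

    def dispatch_many(self, events):
        self.count += len(events)

bus = EventBus()
counting = CountingDispatcher()
bus.dispatchers.append(BoomDispatcher())  # assumed registration mechanism
bus.dispatchers.append(counting)

try:
    bus.dispatch_many(["event-1", "event-2"])
except Exception:
    pass  # EventBus re-raises after logging

assert counting.count == 0  # the later dispatcher never received the batch
```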
ingestify/domain/models/event/publisher.py (+9, -0)
@@ -19,5 +19,14 @@ def dispatch(self, event: DomainEvent):
             except Exception:
                 logger.exception(f"Failed to handle {event} by {subscriber}")
 
+    def dispatch_many(self, events: list[DomainEvent]):
+        for subscriber in self.subscribers:
+            try:
+                subscriber.handle_many(events)
+            except Exception:
+                logger.exception(
+                    f"Failed to handle {len(events)} events by {subscriber}"
+                )
+
     def add_subscriber(self, subscriber: Subscriber):
         self.subscribers.append(subscriber)
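In contrast to the event bus, the publisher isolates subscriber failures: one misbehaving subscriber logs an exception and the rest still receive the whole batch. A self-contained sketch of that contract (both subscriber classes are illustrative):

```python
import logging

logger = logging.getLogger(__name__)

class NoisySubscriber:
    def handle_many(self, events):
        raise RuntimeError("boom")

class QuietSubscriber:
    def __init__(self):
        self.received = []

    def handle_many(self, events):
        self.received.extend(events)

def dispatch_many(subscribers, events):
    # Mirrors Publisher.dispatch_many: log the failure, keep going.
    for subscriber in subscribers:
        try:
            subscriber.handle_many(events)
        except Exception:
            logger.exception("Failed to handle %d events by %s", len(events), subscriber)

quiet = QuietSubscriber()
dispatch_many([NoisySubscriber(), quiet], ["event-1", "event-2"])
assert quiet.received == ["event-1", "event-2"]  # delivery survived the failure
```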
ingestify/domain/models/event/subscriber.py (+5, -0)
@@ -44,3 +44,8 @@ def handle(self, event: DomainEvent):
             self.on_revision_added(event)
         elif isinstance(event, RevisionInvalidated):
             self.on_revision_invalidated(event)
+
+    def handle_many(self, events: list[DomainEvent]):
+        """Handle a batch of events. Override for efficient bulk writes."""
+        for event in events:
+            self.handle(event)
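A sketch of the override the docstring invites: a subscriber backed by a bulk-capable sink can replace N `handle()` calls with one bulk write. Here `sink.write_rows` is a hypothetical API, and bypassing the per-event `on_*` routing is a deliberate trade-off:

```python
class BulkSinkSubscriber(Subscriber):
    """Illustrative subscriber; a `sink` with a `write_rows` method is assumed."""

    def __init__(self, sink):
        self.sink = sink  # hypothetical bulk-capable sink

    def handle_many(self, events: list[DomainEvent]) -> None:
        # One bulk write instead of len(events) individual handle() calls.
        self.sink.write_rows([event.model_dump(mode="json") for event in events])
```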
ingestify/infra/event_log/event_log.py (+17, -9)
@@ -31,16 +31,24 @@ def __init__(self, engine, table_prefix: str = ""):
         self._table.create(engine, checkfirst=True)
 
     def write(self, event: DomainEvent) -> None:
+        self.write_many([event])
+
+    def write_many(self, events: list[DomainEvent]) -> None:
+        if not events:
+            return
+        now = utcnow()
+        rows = [
+            {
+                "event_type": type(event).event_type,
+                "payload_json": event.model_dump(mode="json"),
+                "source": event.dataset.provider,
+                "dataset_id": event.dataset.dataset_id,
+                "created_at": now,
+            }
+            for event in events
+        ]
         with self._engine.connect() as conn:
-            conn.execute(
-                self._table.insert().values(
-                    event_type=type(event).event_type,
-                    payload_json=event.model_dump(mode="json"),
-                    source=event.dataset.provider,
-                    dataset_id=event.dataset.dataset_id,
-                    created_at=utcnow(),
-                )
-            )
+            conn.execute(self._table.insert(), rows)
             conn.commit()
 
     def fetch_batch(self, last_event_id: int, batch_size: int) -> list:
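The rewrite replaces one `insert().values(...)` statement per event with a single `conn.execute(self._table.insert(), rows)`, which SQLAlchemy runs through the driver's `executemany` path: one round-trip per batch rather than per event. A standalone sketch of the pattern, assuming SQLAlchemy 2.x (the table and engine are illustrative):

```python
from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine

engine = create_engine("sqlite:///:memory:")
metadata = MetaData()
events = Table(
    "event_log",
    metadata,
    Column("event_id", Integer, primary_key=True),
    Column("event_type", String),
)
metadata.create_all(engine)

rows = [{"event_type": f"type_{i}"} for i in range(1000)]
with engine.connect() as conn:
    # A list of parameter dicts triggers executemany: one INSERT, many rows.
    conn.execute(events.insert(), rows)
    conn.commit()
```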
ingestify/infra/event_log/subscriber.py (+12, -0)
@@ -34,6 +34,15 @@ def _write(self, event) -> None:
                 event.dataset.dataset_id,
             )
 
+    def _write_many(self, events) -> None:
+        try:
+            self._event_log.write_many(events)
+        except Exception:
+            logger.exception(
+                "EventLogSubscriber: failed to write %d events",
+                len(events),
+            )
+
     def on_dataset_created(self, event) -> None:
         self._write(event)
 
@@ -45,3 +54,6 @@ def on_revision_added(self, event) -> None:

     def on_revision_invalidated(self, event) -> None:
         self._write(event)
+
+    def handle_many(self, events) -> None:
+        self._write_many(events)
ingestify/infra/store/dataset/sqlalchemy/repository.py (+16, -10)
@@ -677,28 +677,34 @@ def _save(self, datasets: list[Dataset]):
             connection.commit()
 
     def invalidate_revision(self, dataset: Dataset):
-        current_revision = dataset.current_revision
+        self.invalidate_revisions([dataset])
+
+    def invalidate_revisions(self, datasets: list[Dataset]):
+        if not datasets:
+            return
+
+        dataset_ids = [d.dataset_id for d in datasets]
+
         with self.connect() as connection:
-            # Set revision state to VALIDATION_FAILED
+            # Batch update revision state
             connection.execute(
                 self.revision_table.update()
-                .where(self.revision_table.c.dataset_id == dataset.dataset_id)
-                .where(
-                    self.revision_table.c.revision_id == current_revision.revision_id
-                )
+                .where(self.revision_table.c.dataset_id.in_(dataset_ids))
                 .values(state=RevisionState.VALIDATION_FAILED)
             )
-            # Reset last_modified_at so the pre-check cache doesn't skip it
+            # Batch reset last_modified_at
            connection.execute(
                 self.dataset_table.update()
-                .where(self.dataset_table.c.dataset_id == dataset.dataset_id)
+                .where(self.dataset_table.c.dataset_id.in_(dataset_ids))
                 .values(last_modified_at=None)
             )
             connection.commit()
 
         # Update in-memory state
-        current_revision.state = RevisionState.VALIDATION_FAILED
-        dataset.last_modified_at = None
+        for dataset in datasets:
+            if dataset.current_revision:
+                dataset.current_revision.state = RevisionState.VALIDATION_FAILED
+                dataset.last_modified_at = None
 
     def destroy(self, dataset: Dataset):
         with self.connect() as connection:
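The batched form collapses N per-dataset UPDATEs into a single `IN`-clause UPDATE per table. One behavioral nuance worth noting: the old code filtered on `revision_id == current_revision.revision_id`, while the batched query marks every revision row whose `dataset_id` matches. A runnable sketch of the `IN`-clause pattern, again assuming SQLAlchemy 2.x with an illustrative table:

```python
from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine

engine = create_engine("sqlite:///:memory:")
metadata = MetaData()
revision = Table(
    "revision",
    metadata,
    Column("revision_id", Integer, primary_key=True),
    Column("dataset_id", String),
    Column("state", String),
)
metadata.create_all(engine)

dataset_ids = ["ds-1", "ds-2", "ds-3"]
with engine.connect() as conn:
    # Emits: UPDATE revision SET state=? WHERE revision.dataset_id IN (?, ?, ?)
    conn.execute(
        revision.update()
        .where(revision.c.dataset_id.in_(dataset_ids))
        .values(state="VALIDATION_FAILED")
    )
    conn.commit()
```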
ingestify/tests/test_refetch_validation_failed.py (+62, -13)
@@ -25,22 +25,27 @@ def counting_loader(file_resource, current_file, **kwargs):
 class SimpleSource(Source):
     provider = "test_provider"
 
+    def __init__(self, name, n_datasets=1):
+        super().__init__(name)
+        self.n_datasets = n_datasets
+
     def find_datasets(
         self, dataset_type, data_spec_versions, dataset_collection_metadata, **kwargs
     ):
-        r = DatasetResource(
-            dataset_resource_id={"item_id": 1},
-            provider=self.provider,
-            dataset_type="test",
-            name="item-1",
-        )
-        r.add_file(
-            last_modified=FIXED_TIME,
-            data_feed_key="f1",
-            data_spec_version="v1",
-            file_loader=counting_loader,
-        )
-        yield r
+        for i in range(self.n_datasets):
+            r = DatasetResource(
+                dataset_resource_id={"item_id": i},
+                provider=self.provider,
+                dataset_type="test",
+                name=f"item-{i}",
+            )
+            r.add_file(
+                last_modified=FIXED_TIME,
+                data_feed_key="f1",
+                data_spec_version="v1",
+                file_loader=counting_loader,
+            )
+            yield r
 
 
 def _setup(engine):
@@ -99,3 +104,47 @@ def test_invalidate_revision_triggers_refetch(engine):
     # Second run: should refetch
     engine.run()
     assert call_count == 2, "Dataset with invalidated revision should be refetched"
+
+
+def test_invalidate_revisions_batch(engine):
+    """invalidate_revisions works on multiple datasets at once."""
+    global call_count
+    call_count = 0
+
+    dsv = DataSpecVersionCollection.from_dict({"default": {"v1"}})
+    engine.add_ingestion_plan(
+        IngestionPlan(
+            source=SimpleSource("s", n_datasets=5),
+            fetch_policy=FetchPolicy(),
+            dataset_type="test",
+            selectors=[Selector.build({}, data_spec_versions=dsv)],
+            data_spec_versions=dsv,
+        )
+    )
+
+    # First run: creates 5 datasets
+    engine.run()
+    assert call_count == 5
+
+    # Batch invalidate all 5
+    datasets = list(
+        engine.store.get_dataset_collection(
+            provider="test_provider", dataset_type="test"
+        )
+    )
+    assert len(datasets) == 5
+    engine.store.invalidate_revisions(datasets, reason="Batch test")
+
+    # Verify all invalidated
+    datasets = list(
+        engine.store.get_dataset_collection(
+            provider="test_provider", dataset_type="test"
+        )
+    )
+    for ds in datasets:
+        assert ds.current_revision.state == RevisionState.VALIDATION_FAILED
+        assert ds.last_modified_at is None
+
+    # Second run: should refetch all 5
+    engine.run()
+    assert call_count == 10, "All 5 invalidated datasets should be refetched"