diff --git a/alphatrion/experiment/base.py b/alphatrion/experiment/base.py index b8e6e6b..1daf057 100644 --- a/alphatrion/experiment/base.py +++ b/alphatrion/experiment/base.py @@ -37,7 +37,6 @@ async def __aenter__(self): if exp is None: raise RuntimeError(f"Experiment {self._id} not found in the database.") - # Use weakref to avoid circular reference self._runtime.current_exp = self return self diff --git a/alphatrion/metadata/sql.py b/alphatrion/metadata/sql.py index 0638d86..3369658 100644 --- a/alphatrion/metadata/sql.py +++ b/alphatrion/metadata/sql.py @@ -9,6 +9,7 @@ Experiment, Metrics, Model, + Run, Trial, TrialStatus, ) @@ -200,6 +201,17 @@ def update_trial(self, trial_id: uuid.UUID, **kwargs): session.commit() session.close() + def create_run(self, trial_id: uuid.UUID) -> uuid.UUID: + session = self._session() + new_run = Run( + trial_id=trial_id, + ) + session.add(new_run) + session.commit() + run_id = new_run.uuid + session.close() + return run_id + def create_metric(self, trial_id: uuid.UUID, key: str, value: float, step: int): session = self._session() new_metric = Metrics( diff --git a/alphatrion/metadata/sql_models.py b/alphatrion/metadata/sql_models.py index 5cd6b1d..8b53824 100644 --- a/alphatrion/metadata/sql_models.py +++ b/alphatrion/metadata/sql_models.py @@ -61,6 +61,19 @@ class Trial(Base): ) +class Run(Base): + __tablename__ = "runs" + + uuid = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + trial_id = Column(UUID(as_uuid=True), nullable=False) + # artifact_path = Column(String, nullable=False) + + created_at = Column(DateTime(timezone=True), default=datetime.now(UTC)) + updated_at = Column( + DateTime(timezone=True), default=datetime.now(UTC), onupdate=datetime.now(UTC) + ) + + class Model(Base): __tablename__ = "models" @@ -88,15 +101,3 @@ class Metrics(Base): # TODO: do we need? step = Column(Integer, nullable=False, default=0) created_at = Column(DateTime(timezone=True), default=datetime.now(UTC)) - - -# class Traces(Base): -# __tablename__ = "traces" - -# uuid = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) -# trial_id = Column(UUID(as_uuid=True), nullable=False) -# run_id = Column(UUID(as_uuid=True), nullable=False) -# message = Column(String, nullable=False) -# level = Column(String, nullable=False, default="INFO") -# trial_id = Column(Integer, nullable=False) -# created_at = Column(DateTime(timezone=True), default=datetime.now(UTC)) diff --git a/alphatrion/run/__init__.py b/alphatrion/run/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/alphatrion/run/run.py b/alphatrion/run/run.py new file mode 100644 index 0000000..148597e --- /dev/null +++ b/alphatrion/run/run.py @@ -0,0 +1,16 @@ +import uuid + +from alphatrion.runtime.runtime import global_runtime + + +class Run: + def __init__(self, trial_id: uuid.UUID): + self._runtime = global_runtime() + self._trial_id = trial_id + + @property + def id(self) -> uuid.UUID: + return self._id + + def _start(self): + self._id = self._runtime._metadb.create_run(trial_id=self._trial_id) diff --git a/alphatrion/trial/trial.py b/alphatrion/trial/trial.py index 70db666..4e16300 100644 --- a/alphatrion/trial/trial.py +++ b/alphatrion/trial/trial.py @@ -1,3 +1,4 @@ +import asyncio import contextvars import os import uuid @@ -6,6 +7,7 @@ from pydantic import BaseModel, Field, model_validator from alphatrion.metadata.sql_models import COMPLETED_STATUS, TrialStatus +from alphatrion.run.run import Run from alphatrion.runtime.runtime import global_runtime from alphatrion.utils.context import Context @@ -83,6 +85,8 @@ class Trial: "_context", "_token", "_meta", + "_runs", + "_running_tasks", ) def __init__(self, exp_id: int, config: TrialConfig | None = None): @@ -98,12 +102,20 @@ def __init__(self, exp_id: int, config: TrialConfig | None = None): # _meta stores the runtime meta information of the trial, # like the metric max/min values. self._construct_meta() + self._runs = dict() + self._running_tasks = dict() async def __aenter__(self): return self async def __aexit__(self, exc_type, exc_val, exc_tb): self.cancel() + if self._token: + current_trial_id.reset(self._token) + + @property + def id(self) -> uuid.UUID: + return self._id def _construct_meta(self): self._meta = dict() @@ -168,14 +180,17 @@ def stopped(self) -> bool: return self._context.cancelled() async def wait(self): - await self._context.wait_cancelled() + await self._context.wait() + + def cancelled(self) -> bool: + return self._context.cancelled() def _start( self, description: str | None = None, meta: dict | None = None, params: dict | None = None, - ) -> uuid.UUID: + ): self._id = self._runtime._metadb.create_trial( exp_id=self._exp_id, description=description, @@ -188,11 +203,6 @@ def _start( # each trial runs in its own context. self._token = current_trial_id.set(self._id) self._context.start() - return self._id - - @property - def id(self) -> uuid.UUID: - return self._id # cancel function should be called manually as a pair of start def cancel(self): @@ -209,6 +219,10 @@ def _stop(self): ) self._runtime.current_exp.unregister_trial(self._id) + self._runs.clear() + for task in self._running_tasks.values(): + task.cancel() + self._running_tasks.clear() def _get_obj(self): return self._runtime._metadb.get_trial(trial_id=self._id) @@ -216,3 +230,18 @@ def _get_obj(self): def increment_step(self) -> int: self._step += 1 return self._step + + # start_run should accept a lambda function to create the run task. + def start_run(self, call_func: callable) -> Run: + run = Run(trial_id=self._id) + run._start() + self._runs[run.id] = run + + # the created task will also inherit the current context, + # including the current_trial_id context var. + task = asyncio.create_task(call_func()) + self._running_tasks[run.id] = task + task.add_done_callback(lambda t: self._running_tasks.pop(run.id, None)) + task.add_done_callback(lambda t: self._runs.pop(run.id, None)) + + return run diff --git a/alphatrion/utils/context.py b/alphatrion/utils/context.py index 0f946ff..4a898a7 100644 --- a/alphatrion/utils/context.py +++ b/alphatrion/utils/context.py @@ -33,5 +33,7 @@ def cancel(self): def cancelled(self): return self._cancel_event.is_set() - async def wait_cancelled(self): + # TODO: wait will not wait for all the coroutines to finish, + # it will return as soon as the context is cancelled. + async def wait(self): await self._cancel_event.wait() diff --git a/tests/integration/test_tracing.py b/tests/integration/test_tracing.py index 7f4c405..c0ed450 100644 --- a/tests/integration/test_tracing.py +++ b/tests/integration/test_tracing.py @@ -56,7 +56,3 @@ def test_workflow(): pirate_joke, signature = joke_workflow() assert pirate_joke is not None assert signature is not None - - -if __name__ == "__main__": - test_workflow() diff --git a/tests/unit/experiment/test_craft_exp.py b/tests/unit/experiment/test_craft_exp.py index b3f6831..9f98bd8 100644 --- a/tests/unit/experiment/test_craft_exp.py +++ b/tests/unit/experiment/test_craft_exp.py @@ -63,7 +63,6 @@ async def fake_work(trial: Trial): async with CraftExperiment.start(name="context_exp") as exp: async with exp.start_trial(description="First trial") as trial: trial_id = current_trial_id.get() - start_time = datetime.now() asyncio.create_task(fake_work(trial)) @@ -76,6 +75,34 @@ async def fake_work(trial: Trial): assert trial_obj.status == TrialStatus.FINISHED +@pytest.mark.asyncio +async def test_create_experiment_with_run(): + init(project_id="test_project", artifact_insecure=True) + + async def fake_work(cancel_func: callable): + await asyncio.sleep(3) + cancel_func() + + async with ( + CraftExperiment.start(name="context_exp") as exp, + exp.start_trial(description="First trial") as trial, + ): + start_time = datetime.now() + + trial.start_run(lambda: fake_work(trial.cancel)) + assert len(trial._running_tasks) == 1 + assert len(trial._runs) == 1 + + trial.start_run(lambda: fake_work(trial.cancel)) + assert len(trial._running_tasks) == 2 + assert len(trial._runs) == 2 + + await trial.wait() + assert datetime.now() - start_time >= timedelta(seconds=3) + assert len(trial._running_tasks) == 0 + assert len(trial._runs) == 0 + + @pytest.mark.asyncio async def test_craft_experiment_with_context(): init(project_id="test_project", artifact_insecure=True) diff --git a/tests/unit/utils/test_context.py b/tests/unit/utils/test_context.py index 6ff8c1b..ff4adde 100644 --- a/tests/unit/utils/test_context.py +++ b/tests/unit/utils/test_context.py @@ -14,7 +14,7 @@ async def test_context_no_timeout(): # double cancel should be no-op ctx.cancel() assert ctx.cancelled() - await ctx.wait_cancelled() + await ctx.wait() @pytest.mark.asyncio @@ -24,7 +24,7 @@ async def test_context_with_timeout(): assert not ctx.cancelled() await asyncio.sleep(0.2) assert ctx.cancelled() - await ctx.wait_cancelled() + await ctx.wait() @pytest.mark.asyncio @@ -34,7 +34,7 @@ async def test_context_manual_cancel(): assert not ctx.cancelled() ctx.cancel() assert ctx.cancelled() - await ctx.wait_cancelled() + await ctx.wait() @pytest.mark.asyncio @@ -43,7 +43,7 @@ async def test_context_wait_cancelled(): ctx.start() async def waiter(): - await ctx.wait_cancelled() + await ctx.wait() return True task = asyncio.create_task(waiter()) @@ -62,7 +62,7 @@ async def test_context_multiple_waiters(): results = [] async def waiter(idx): - await ctx.wait_cancelled() + await ctx.wait() results.append(idx) tasks = [asyncio.create_task(waiter(i)) for i in range(5)]