Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion alphatrion/experiment/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ async def __aenter__(self):
if exp is None:
raise RuntimeError(f"Experiment {self._id} not found in the database.")

# Use weakref to avoid circular reference
self._runtime.current_exp = self
return self

Expand Down
12 changes: 12 additions & 0 deletions alphatrion/metadata/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Experiment,
Metrics,
Model,
Run,
Trial,
TrialStatus,
)
Expand Down Expand Up @@ -200,6 +201,17 @@ def update_trial(self, trial_id: uuid.UUID, **kwargs):
session.commit()
session.close()

def create_run(self, trial_id: uuid.UUID) -> uuid.UUID:
session = self._session()
new_run = Run(
trial_id=trial_id,
)
session.add(new_run)
session.commit()
run_id = new_run.uuid
session.close()
return run_id

def create_metric(self, trial_id: uuid.UUID, key: str, value: float, step: int):
session = self._session()
new_metric = Metrics(
Expand Down
25 changes: 13 additions & 12 deletions alphatrion/metadata/sql_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,19 @@ class Trial(Base):
)


class Run(Base):
__tablename__ = "runs"

uuid = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
trial_id = Column(UUID(as_uuid=True), nullable=False)
# artifact_path = Column(String, nullable=False)

created_at = Column(DateTime(timezone=True), default=datetime.now(UTC))
updated_at = Column(
DateTime(timezone=True), default=datetime.now(UTC), onupdate=datetime.now(UTC)
)


class Model(Base):
__tablename__ = "models"

Expand Down Expand Up @@ -88,15 +101,3 @@ class Metrics(Base):
# TODO: do we need?
step = Column(Integer, nullable=False, default=0)
created_at = Column(DateTime(timezone=True), default=datetime.now(UTC))


# class Traces(Base):
# __tablename__ = "traces"

# uuid = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# trial_id = Column(UUID(as_uuid=True), nullable=False)
# run_id = Column(UUID(as_uuid=True), nullable=False)
# message = Column(String, nullable=False)
# level = Column(String, nullable=False, default="INFO")
# trial_id = Column(Integer, nullable=False)
# created_at = Column(DateTime(timezone=True), default=datetime.now(UTC))
Empty file added alphatrion/run/__init__.py
Empty file.
16 changes: 16 additions & 0 deletions alphatrion/run/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import uuid

from alphatrion.runtime.runtime import global_runtime


class Run:
def __init__(self, trial_id: uuid.UUID):
self._runtime = global_runtime()
self._trial_id = trial_id

@property
def id(self) -> uuid.UUID:
return self._id

def _start(self):
self._id = self._runtime._metadb.create_run(trial_id=self._trial_id)
43 changes: 36 additions & 7 deletions alphatrion/trial/trial.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import contextvars
import os
import uuid
Expand All @@ -6,6 +7,7 @@
from pydantic import BaseModel, Field, model_validator

from alphatrion.metadata.sql_models import COMPLETED_STATUS, TrialStatus
from alphatrion.run.run import Run
from alphatrion.runtime.runtime import global_runtime
from alphatrion.utils.context import Context

Expand Down Expand Up @@ -83,6 +85,8 @@ class Trial:
"_context",
"_token",
"_meta",
"_runs",
"_running_tasks",
)

def __init__(self, exp_id: int, config: TrialConfig | None = None):
Expand All @@ -98,12 +102,20 @@ def __init__(self, exp_id: int, config: TrialConfig | None = None):
# _meta stores the runtime meta information of the trial,
# like the metric max/min values.
self._construct_meta()
self._runs = dict()
self._running_tasks = dict()

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
self.cancel()
if self._token:
current_trial_id.reset(self._token)

@property
def id(self) -> uuid.UUID:
return self._id

def _construct_meta(self):
self._meta = dict()
Expand Down Expand Up @@ -168,14 +180,17 @@ def stopped(self) -> bool:
return self._context.cancelled()

async def wait(self):
await self._context.wait_cancelled()
await self._context.wait()

def cancelled(self) -> bool:
return self._context.cancelled()

def _start(
self,
description: str | None = None,
meta: dict | None = None,
params: dict | None = None,
) -> uuid.UUID:
):
self._id = self._runtime._metadb.create_trial(
exp_id=self._exp_id,
description=description,
Expand All @@ -188,11 +203,6 @@ def _start(
# each trial runs in its own context.
self._token = current_trial_id.set(self._id)
self._context.start()
return self._id

@property
def id(self) -> uuid.UUID:
return self._id

# cancel function should be called manually as a pair of start
def cancel(self):
Expand All @@ -209,10 +219,29 @@ def _stop(self):
)

self._runtime.current_exp.unregister_trial(self._id)
self._runs.clear()
for task in self._running_tasks.values():
task.cancel()
self._running_tasks.clear()

def _get_obj(self):
return self._runtime._metadb.get_trial(trial_id=self._id)

def increment_step(self) -> int:
self._step += 1
return self._step

# start_run should accept a lambda function to create the run task.
def start_run(self, call_func: callable) -> Run:
run = Run(trial_id=self._id)
run._start()
self._runs[run.id] = run

# the created task will also inherit the current context,
# including the current_trial_id context var.
task = asyncio.create_task(call_func())
self._running_tasks[run.id] = task
task.add_done_callback(lambda t: self._running_tasks.pop(run.id, None))
task.add_done_callback(lambda t: self._runs.pop(run.id, None))

return run
4 changes: 3 additions & 1 deletion alphatrion/utils/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,7 @@ def cancel(self):
def cancelled(self):
return self._cancel_event.is_set()

async def wait_cancelled(self):
# TODO: wait will not wait for all the coroutines to finish,
# it will return as soon as the context is cancelled.
async def wait(self):
await self._cancel_event.wait()
4 changes: 0 additions & 4 deletions tests/integration/test_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,3 @@ def test_workflow():
pirate_joke, signature = joke_workflow()
assert pirate_joke is not None
assert signature is not None


if __name__ == "__main__":
test_workflow()
29 changes: 28 additions & 1 deletion tests/unit/experiment/test_craft_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ async def fake_work(trial: Trial):
async with CraftExperiment.start(name="context_exp") as exp:
async with exp.start_trial(description="First trial") as trial:
trial_id = current_trial_id.get()

start_time = datetime.now()

asyncio.create_task(fake_work(trial))
Expand All @@ -76,6 +75,34 @@ async def fake_work(trial: Trial):
assert trial_obj.status == TrialStatus.FINISHED


@pytest.mark.asyncio
async def test_create_experiment_with_run():
init(project_id="test_project", artifact_insecure=True)

async def fake_work(cancel_func: callable):
await asyncio.sleep(3)
cancel_func()

async with (
CraftExperiment.start(name="context_exp") as exp,
exp.start_trial(description="First trial") as trial,
):
start_time = datetime.now()

trial.start_run(lambda: fake_work(trial.cancel))
assert len(trial._running_tasks) == 1
assert len(trial._runs) == 1

trial.start_run(lambda: fake_work(trial.cancel))
assert len(trial._running_tasks) == 2
assert len(trial._runs) == 2

await trial.wait()
assert datetime.now() - start_time >= timedelta(seconds=3)
assert len(trial._running_tasks) == 0
assert len(trial._runs) == 0


@pytest.mark.asyncio
async def test_craft_experiment_with_context():
init(project_id="test_project", artifact_insecure=True)
Expand Down
10 changes: 5 additions & 5 deletions tests/unit/utils/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ async def test_context_no_timeout():
# double cancel should be no-op
ctx.cancel()
assert ctx.cancelled()
await ctx.wait_cancelled()
await ctx.wait()


@pytest.mark.asyncio
Expand All @@ -24,7 +24,7 @@ async def test_context_with_timeout():
assert not ctx.cancelled()
await asyncio.sleep(0.2)
assert ctx.cancelled()
await ctx.wait_cancelled()
await ctx.wait()


@pytest.mark.asyncio
Expand All @@ -34,7 +34,7 @@ async def test_context_manual_cancel():
assert not ctx.cancelled()
ctx.cancel()
assert ctx.cancelled()
await ctx.wait_cancelled()
await ctx.wait()


@pytest.mark.asyncio
Expand All @@ -43,7 +43,7 @@ async def test_context_wait_cancelled():
ctx.start()

async def waiter():
await ctx.wait_cancelled()
await ctx.wait()
return True

task = asyncio.create_task(waiter())
Expand All @@ -62,7 +62,7 @@ async def test_context_multiple_waiters():
results = []

async def waiter(idx):
await ctx.wait_cancelled()
await ctx.wait()
results.append(idx)

tasks = [asyncio.create_task(waiter(i)) for i in range(5)]
Expand Down
Loading