Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions src/driver/continuous_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,12 @@ def _handle_drift(self, drift_signal: DriftSignal) -> None:
self.logger.info(
f"==== DRIFT DETECTED (Event #{self.drift_event_count})! ====", level=0
)
# Log data timestamp range if the harness tracks it
timerange = getattr(self.modelHarness, "current_window_timerange", None)
if timerange is not None:
self.logger.info(
f"\tData time range: {timerange[0]} → {timerange[1]}", level=1
)
self.logger.info(
f"\tRegime: {drift_signal.regime.value if drift_signal.regime else 'N/A'}",
level=1,
Expand Down Expand Up @@ -328,17 +334,31 @@ def _log_metrics(self, drift_signal: DriftSignal, metric_value: float) -> None:
"""
flops_perf = self.flops_profiler.get_performance()

# Log all drift metrics including performance in a single call
# Log all drift metrics in a single call.
# detected=0 is sampled at 10% to reduce log volume; detected=1 is
# always included so no true drift events are dropped.
log_detected = drift_signal.drift_detected or (np.random.random() <= 0.1)
self.logger.stage("drift")
# Include data timestamp range if available from the harness
timerange = getattr(self.modelHarness, "current_window_timerange", None)
ts_fields = {}
if timerange is not None:
ts_fields["data_time_start"] = timerange[0]
ts_fields["data_time_end"] = timerange[1]
self.logger.log(
{
"detected": drift_signal.drift_detected,
**(
{"detected": int(drift_signal.drift_detected)}
if log_detected
else {}
),
"score": drift_signal.drift_score,
"regime": (drift_signal.regime.value if drift_signal.regime else "N/A"),
"confidence": (
drift_signal.confidence if drift_signal.confidence else "N/A"
),
f"metric_{self.metric_idx}": metric_value,
**ts_fields,
**{f"cperf_{k}": v for k, v in flops_perf.items()},
},
)
23 changes: 20 additions & 3 deletions src/logger/mlflow_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def log(

# Add step metrics
prefixed_metrics["step"] = current_step
if self._current_stage:
if self._current_stage and increment:
prefixed_metrics[f"{self._current_stage}.step"] = self._stage_steps[
self._current_stage
]
Expand All @@ -187,13 +187,30 @@ def log(
if self.enabled and self.run:
mlflow = self._get_mlflow()
# Filter to only numeric values (MLflow only accepts numbers for metrics)
# Exclude bools: bool is a subclass of int in Python, but casting
# True/False to float(1.0/0.0) corrupts categorical semantics.
# Instead, booleans are converted to int (0/1) separately.
numeric_metrics = {
k: float(v)
for k, v in prefixed_metrics.items()
if isinstance(v, (int, float)) and k != "step"
if isinstance(v, (int, float))
and not isinstance(v, bool)
and k != "step"
}
# Convert booleans to int (0/1) so they remain discrete metrics
bool_metrics = {
k: int(v) for k, v in prefixed_metrics.items() if isinstance(v, bool)
}
numeric_metrics.update(bool_metrics)
if numeric_metrics:
mlflow.log_metrics(numeric_metrics, step=current_step)
            # WandB uses the global step as its x-axis, whereas MLflow uses the
            # per-stage step; this keeps per-stage steps dense, so MLflow does not
            # downsample away the rare detected=1 events.
log_step = (
self._stage_steps[self._current_stage]
if self._current_stage
else current_step
)
mlflow.log_metrics(numeric_metrics, step=log_step)

def save(
self,
Expand Down