MLflow logging of drift metrics #85

@pluflou

I see a couple of issues with the way MLflow logging is set up. They don't seem to appear with the provided examples, but they do with the example set up in the slac-fel branch (maybe because of the higher number of global steps and inner CL loops?).

In that branch I committed a fix that was needed to see drift.detected on MLflow/WandB at all: logging the values as ints instead of bools.
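A minimal sketch of that kind of fix, assuming the drift metrics are collected in a flat dict of metric names to values before logging; the helper name `to_loggable_metrics` is hypothetical and not taken from the slac-fel branch. Because `bool` is a subclass of `int` in Python, the bool check has to be explicit:

```python
from typing import Mapping, Union

Number = Union[int, float]


def to_loggable_metrics(metrics: Mapping[str, Union[bool, Number]]) -> dict[str, Number]:
    """Return a copy of `metrics` with bool values cast to 0/1 ints.

    Hypothetical helper: flags such as drift.detected are converted to int
    before being handed to the tracking backend.
    """
    out: dict[str, Number] = {}
    for key, value in metrics.items():
        # isinstance(True, int) is also True, so test for bool specifically.
        if isinstance(value, bool):
            out[key] = int(value)
        else:
            out[key] = value
    return out


# Usage with MLflow (not executed here):
#   mlflow.log_metrics(to_loggable_metrics(metrics), step=global_step)
```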

It's also possible that drift.detected should use drift.step (the per-stage step) as its x-axis instead of the global step. That way, the chart shows consistent progression regardless of how many eval/CL steps happen in between. I did not implement this because I am not sure it resolves the full issue described below.
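A rough sketch of what a per-stage x-axis could look like, assuming a single `drift.step` counter that advances once per drift check. `DriftStepLogger` and the injected `log_metric` callable are hypothetical; with MLflow one could pass `mlflow.log_metric`, which takes a `step` argument per metric, and in W&B the analogous setting would be `wandb.define_metric("drift.detected", step_metric="drift.step")`:

```python
from typing import Callable


class DriftStepLogger:
    """Log drift.detected against a per-stage drift.step rather than the global step.

    Hypothetical helper: `log_metric` is any callable accepting
    (key, value, step=...), e.g. mlflow.log_metric.
    """

    def __init__(self, log_metric: Callable[..., None]) -> None:
        self._log_metric = log_metric
        self.drift_step = 0  # advances once per drift check, not per global step

    def log_check(self, detected: bool) -> None:
        # Cast the flag to a number and use the per-stage counter as the x value,
        # so eval/CL steps between checks do not stretch the chart.
        self._log_metric("drift.detected", float(detected), step=self.drift_step)
        self.drift_step += 1
```

Each check then lands at consecutive x positions 0, 1, 2, and so on, however many global steps separate the checks.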

Main issue
During a run, the dashboard does not show all the expected drift.detected events (there should be 4):

[Screenshot: MLflow dashboard during the run, showing fewer than 4 drift.detected events]

Moreover, after the run ends, MLflow seems to remove some rows from the data, and since there are very few 1s compared with 0s, the chart ends up as a flat line at 0:

[Screenshot: drift.detected chart after the run ends, a flat line at 0]
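One possible explanation, which is an assumption and not something confirmed here: if the chart or export thins out a long metric history before plotting, a series that is almost all 0s can lose its rare 1s. The sketch below compares naive stride sampling with a bucket-max downsample on a synthetic 0/1 series. Both functions are hypothetical, the `max_points=300` budget and the detection steps are arbitrary illustrative values, and none of this reproduces MLflow's actual sampling logic:

```python
from typing import List


def stride_sample(values: List[int], max_points: int) -> List[int]:
    """Keep every k-th point so that at most max_points remain."""
    stride = max(1, -(-len(values) // max_points))  # ceiling division
    return values[::stride]


def bucket_max_sample(values: List[int], max_points: int) -> List[int]:
    """Split into at most max_points buckets and keep each bucket's max."""
    stride = max(1, -(-len(values) // max_points))
    return [max(values[i:i + stride]) for i in range(0, len(values), stride)]


# Synthetic drift.detected history: 5000 steps with four isolated detections.
series = [0] * 5000
for step in (1203, 2451, 3707, 4501):
    series[step] = 1

print(sum(stride_sample(series, 300)))      # → 0 (every 1 falls between samples)
print(sum(bucket_max_sample(series, 300)))  # → 4 (each bucket keeps its max)
```

If something like stride sampling were involved, preserving per-bucket maxima (or logging drift.detected only on the steps where a check actually runs) would keep the spikes visible.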

In the CSV exported for drift.detected in this example, we expected a 1 around step 4500, but it has been removed:

[Screenshot: exported CSV for drift.detected, with no 1 around step 4500]

I haven't had time to create a simple reproducer to post here. Perhaps @anagainaru or I can share the data with whoever is assigned to this (I am also going to upload it to Perlmutter tomorrow).
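Until the real data is shared, a synthetic reproducer along these lines might be enough to check whether the points are lost in MLflow's stored history or only in the chart. This is a sketch under assumptions: the step count, the detection steps, and the helper names are made up. With MLflow installed one would pass `mlflow.log_metric` inside `mlflow.start_run()` and read the stored points back with `MlflowClient().get_metric_history(run_id, "drift.detected")`, whose entries expose `.step` and `.value`:

```python
from typing import Callable, Iterable, Sequence, Tuple

# Hypothetical values loosely mimicking the run described above.
NUM_STEPS = 5000
DETECTION_STEPS = (1200, 2450, 3700, 4500)


def log_synthetic_drift(log_metric: Callable[..., None],
                        num_steps: int = NUM_STEPS,
                        detection_steps: Sequence[int] = DETECTION_STEPS) -> int:
    """Log a 0/1 drift.detected series and return how many 1s were logged."""
    hits = {s for s in detection_steps if 0 <= s < num_steps}
    for step in range(num_steps):
        log_metric("drift.detected", float(step in hits), step=step)
    return len(hits)


def count_detections(history: Iterable[Tuple[int, float]]) -> int:
    """Count (step, value) points whose value is 1."""
    return sum(1 for _, value in history if value == 1)


# With MLflow (not executed here):
#   with mlflow.start_run() as run:
#       expected = log_synthetic_drift(mlflow.log_metric)
#   history = MlflowClient().get_metric_history(run.info.run_id, "drift.detected")
#   stored = count_detections((m.step, m.value) for m in history)
#   print(expected, stored)
```

If `stored` matches `expected` while the chart still shows a flat line, that would suggest the problem is in how the history is displayed or exported rather than in what was logged.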

FWIW, the saved CSV looks fine (4 drifts detected), attached here.
slac-fel.csv
