I see a couple of issues with the way MLflow logging is set up. They don't seem to occur with the provided examples, but they do with the example set up in the slac-fel branch (maybe due to the higher number of global steps and inner CL loops?).
In that branch I committed a fix that was required to see drift.detected on MLflow/WandB at all: logging the values as ints instead of bools.
It's also possible that drift.detected should use drift.step (the per-stage step) as its x-axis rather than the global step; that way the chart shows consistent progression regardless of how many eval/CL steps happen in between. I did not implement that because I am not sure it resolves the main issue described below.
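For context, the int-vs-bool fix amounts to something like the sketch below. The helper name and the indirection through `log_metric_fn` are mine, just to keep the snippet self-contained; the real call would go straight to `mlflow.log_metric` (or `wandb.log`):

```python
# Sketch of the fix from the slac-fel branch: cast the bool to int
# before handing it to the tracking backend, so the metric shows up
# as an explicit 0/1 series instead of being dropped or coerced.
# `log_metric_fn` stands in for mlflow.log_metric here.

def log_drift_detected(log_metric_fn, detected: bool, step: int) -> None:
    # bool is a subclass of int in Python, but the explicit int()
    # guarantees the logged value is a plain 0 or 1
    log_metric_fn("drift.detected", int(detected), step=step)

# With MLflow this would be called as (not executed here):
#   import mlflow
#   log_drift_detected(mlflow.log_metric, detected, step=step)
```

The same helper could take drift.step instead of the global step if we decide to switch the x-axis.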
Main issue
During a run, the dashboard does not show all the expected drift.detected events (there should be 4):
Moreover, after the run ends, MLflow seems to remove some rows from the data, and since there are very few 1s compared to 0s, we end up with a flat 0 line:
When exporting the CSV data for drift.detected in this example, we expected a 1 around step 4500, but it has been removed.
I haven't had the time to create a simple reproducer to put here. Perhaps I or @anagainaru can share the data with whoever is assigned to this (I am also going to upload it to Perlmutter tomorrow).
FWIW, the saved CSV looks fine (4 drifts detected); attached here.
slac-fel.csv
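For reference, this is roughly how I counted the drift events in the exported CSV. The column names here are an assumption, and the sample data is mocked up rather than taken from the attachment:

```python
import csv
import io

# Mocked-up sample standing in for the exported slac-fel.csv;
# the real file's columns are assumed to be "step" and "value".
csv_text = """step,value
1000,0
2000,0
3000,1
4500,1
6000,1
7500,1
"""

rows = list(csv.DictReader(io.StringIO(csv_text)))
# Sum the drift.detected values; the saved CSV should give 4.
n_drifts = sum(int(row["value"]) for row in rows)
print(n_drifts)  # 4 on this sample
```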