Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 28 additions & 20 deletions example-get-started-experiments/code/notebooks/TrainSegModel.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,28 @@
" intersection = 2.0 * np.sum(y_true * y_pred)\n",
" dice = intersection / (np.sum(y_true) + np.sum(y_pred) + eps)\n",
" dice_list.append(dice)\n",
" return np.mean(dice_list)"
" return np.mean(dice_list)\n",
"\n",
"def evaluate(learn):\n",
" test_img_fpaths = get_files(DATA / \"test_data\", extensions=\".jpg\")\n",
" test_dl = learn.dls.test_dl(test_img_fpaths)\n",
" preds, _ = learn.get_preds(dl=test_dl)\n",
" masks_pred = np.array(preds[:, 1, :] > 0.5, dtype=np.uint8)\n",
" test_mask_fpaths = [\n",
" get_mask_path(fpath, DATA / \"test_data\") for fpath in test_img_fpaths\n",
" ]\n",
" masks_true = [Image.open(mask_path) for mask_path in test_mask_fpaths]\n",
" dice_multi = 0.0\n",
" for ii in range(len(masks_true)):\n",
" mask_pred, mask_true = masks_pred[ii], masks_true[ii]\n",
" width, height = mask_true.shape[1], mask_true.shape[0]\n",
" mask_pred = np.array(\n",
" Image.fromarray(mask_pred).resize((width, height)),\n",
" dtype=int\n",
" )\n",
" mask_true = np.array(mask_true, dtype=int)\n",
" dice_multi += dice(mask_true, mask_pred) / len(masks_true)\n",
" return dice_multi"
]
},
{
Expand All @@ -168,6 +189,8 @@
"outputs": [],
"source": [
"train_arch = 'resnet18'\n",
"models_dir = ROOT / \"models\"\n",
"models_dir.mkdir(exist_ok=True)\n",
"\n",
"for base_lr in [0.001, 0.005, 0.01]:\n",
" with Live(str(ROOT / \"results\" / \"train\"), save_dvc_exp=True) as live:\n",
Expand All @@ -185,26 +208,11 @@
" **fine_tune_args,\n",
" cbs=[DVCLiveCallback(live=live)])\n",
"\n",
" learn.export(fname=(models_dir / \"model.pkl\").absolute())\n",
"\n",
" live.summary[\"evaluate/dice_multi\"] = evaluate(learn)\n",
"\n",
" test_img_fpaths = get_files(DATA / \"test_data\", extensions=\".jpg\")\n",
" test_dl = learn.dls.test_dl(test_img_fpaths)\n",
" preds, _ = learn.get_preds(dl=test_dl)\n",
" masks_pred = np.array(preds[:, 1, :] > 0.5, dtype=np.uint8)\n",
" test_mask_fpaths = [\n",
" get_mask_path(fpath, DATA / \"test_data\") for fpath in test_img_fpaths\n",
" ]\n",
" masks_true = [Image.open(mask_path) for mask_path in test_mask_fpaths]\n",
" dice_multi = 0.0\n",
" for ii in range(len(masks_true)):\n",
" mask_pred, mask_true = masks_pred[ii], masks_true[ii]\n",
" width, height = mask_true.shape[1], mask_true.shape[0]\n",
" mask_pred = np.array(\n",
" Image.fromarray(mask_pred).resize((width, height)),\n",
" dtype=int\n",
" )\n",
" mask_true = np.array(mask_true, dtype=int)\n",
" dice_multi += dice(mask_true, mask_pred) / len(masks_true)\n",
" live.summary[\"evaluate/dice_multi\"] = dice_multi"
" live.log_artifact(str(models_dir / \"model.pkl\"))"
]
},
{
Expand Down
1 change: 1 addition & 0 deletions example-get-started-experiments/generate.sh
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ cp -r $HERE/code/src .
cp $HERE/code/params.yaml .
sed -e "s/base_lr: 0.01/base_lr: $BEST_EXP_BASE_LR/" -i".bkp" params.yaml
rm params.yaml.bkp
dvc remove models/model.pkl.dvc
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure I understand the logic here, could you describe the workflow? do we transition from dvc add to log_artifact at some stage?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We transition from log_artifact which is dvc add to stage output in the dvc.yaml.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got it ... and what it was before the log_artifact where and how did we save the model?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what it was before the log_artifact where and how did we save the model?

We didn't save it at all during the notebook state, that was the motivation for the P.R.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got it!


dvc stage add -n data_split \
-p base,data_split \
Expand Down