From 3cfb9aa2633e7fd001f7faf04bb3c32b7cb97b5f Mon Sep 17 00:00:00 2001 From: David de la Iglesia Castro Date: Thu, 13 Apr 2023 15:59:14 +0000 Subject: [PATCH] Use log_artifact in notebook. --- .../code/notebooks/TrainSegModel.ipynb | 48 +++++++++++-------- example-get-started-experiments/generate.sh | 1 + 2 files changed, 29 insertions(+), 20 deletions(-) diff --git a/example-get-started-experiments/code/notebooks/TrainSegModel.ipynb b/example-get-started-experiments/code/notebooks/TrainSegModel.ipynb index ce5a6eb9..091df3fd 100644 --- a/example-get-started-experiments/code/notebooks/TrainSegModel.ipynb +++ b/example-get-started-experiments/code/notebooks/TrainSegModel.ipynb @@ -158,7 +158,28 @@ " intersection = 2.0 * np.sum(y_true * y_pred)\n", " dice = intersection / (np.sum(y_true) + np.sum(y_pred) + eps)\n", " dice_list.append(dice)\n", - " return np.mean(dice_list)" + " return np.mean(dice_list)\n", + "\n", + "def evaluate(learn):\n", + " test_img_fpaths = get_files(DATA / \"test_data\", extensions=\".jpg\")\n", + " test_dl = learn.dls.test_dl(test_img_fpaths)\n", + " preds, _ = learn.get_preds(dl=test_dl)\n", + " masks_pred = np.array(preds[:, 1, :] > 0.5, dtype=np.uint8)\n", + " test_mask_fpaths = [\n", + " get_mask_path(fpath, DATA / \"test_data\") for fpath in test_img_fpaths\n", + " ]\n", + " masks_true = [Image.open(mask_path) for mask_path in test_mask_fpaths]\n", + " dice_multi = 0.0\n", + " for ii in range(len(masks_true)):\n", + " mask_pred, mask_true = masks_pred[ii], masks_true[ii]\n", + " width, height = mask_true.shape[1], mask_true.shape[0]\n", + " mask_pred = np.array(\n", + " Image.fromarray(mask_pred).resize((width, height)),\n", + " dtype=int\n", + " )\n", + " mask_true = np.array(mask_true, dtype=int)\n", + " dice_multi += dice(mask_true, mask_pred) / len(masks_true)\n", + " return dice_multi" ] }, { @@ -168,6 +189,8 @@ "outputs": [], "source": [ "train_arch = 'resnet18'\n", + "models_dir = ROOT / \"models\"\n", + "models_dir.mkdir(exist_ok=True)\n", "\n", "for base_lr in [0.001, 0.005, 0.01]:\n", " with Live(str(ROOT / \"results\" / \"train\"), save_dvc_exp=True) as live:\n", @@ -185,26 +208,11 @@ " **fine_tune_args,\n", " cbs=[DVCLiveCallback(live=live)])\n", "\n", + " learn.export(fname=(models_dir / \"model.pkl\").absolute())\n", + "\n", + " live.summary[\"evaluate/dice_multi\"] = evaluate(learn)\n", "\n", - " test_img_fpaths = get_files(DATA / \"test_data\", extensions=\".jpg\")\n", - " test_dl = learn.dls.test_dl(test_img_fpaths)\n", - " preds, _ = learn.get_preds(dl=test_dl)\n", - " masks_pred = np.array(preds[:, 1, :] > 0.5, dtype=np.uint8)\n", - " test_mask_fpaths = [\n", - " get_mask_path(fpath, DATA / \"test_data\") for fpath in test_img_fpaths\n", - " ]\n", - " masks_true = [Image.open(mask_path) for mask_path in test_mask_fpaths]\n", - " dice_multi = 0.0\n", - " for ii in range(len(masks_true)):\n", - " mask_pred, mask_true = masks_pred[ii], masks_true[ii]\n", - " width, height = mask_true.shape[1], mask_true.shape[0]\n", - " mask_pred = np.array(\n", - " Image.fromarray(mask_pred).resize((width, height)),\n", - " dtype=int\n", - " )\n", - " mask_true = np.array(mask_true, dtype=int)\n", - " dice_multi += dice(mask_true, mask_pred) / len(masks_true)\n", - " live.summary[\"evaluate/dice_multi\"] = dice_multi" + " live.log_artifact(str(models_dir / \"model.pkl\"))" ] }, { diff --git a/example-get-started-experiments/generate.sh b/example-get-started-experiments/generate.sh index bf6848be..3787b405 100755 --- a/example-get-started-experiments/generate.sh +++ b/example-get-started-experiments/generate.sh @@ -95,6 +95,7 @@ cp -r $HERE/code/src . cp $HERE/code/params.yaml . sed -e "s/base_lr: 0.01/base_lr: $BEST_EXP_BASE_LR/" -i".bkp" params.yaml rm params.yaml.bkp +dvc remove models/model.pkl.dvc dvc stage add -n data_split \ -p base,data_split \