diff --git a/Training_&_Testing_InfNet_&_SemiInfNet.ipynb b/Training_&_Testing_InfNet_&_SemiInfNet.ipynb new file mode 100644 index 00000000..760ac323 --- /dev/null +++ b/Training_&_Testing_InfNet_&_SemiInfNet.ipynb @@ -0,0 +1,500 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU", + "gpuClass": "standard" + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "**Pooyan Rezaeipour Lasaki**\n", + "\n", + "e-mails:\n", + "rezaeipourpooyan@gmail.com &\n", + "pooyan_rezaeipour@elec.iust.ac.ir\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "metadata": { + "id": "NBzk2gT4cbWc" + } + }, + { + "cell_type": "markdown", + "source": [ + "**Mohsen Safaei**\n", + "\n", + "e-mails: mfsafaei78@gmail.com & mo_safaei@elec.iust.ac.ir" + ], + "metadata": { + "id": "4sF28w9ooSmx" + } + }, + { + "cell_type": "markdown", + "source": [ + "**Saeed Chamani**\n", + "\n", + "e-mails: saeed.chamani10@gmail.com\n", + "& saeed_chamani@elec.iust.ac.ir\n" + ], + "metadata": { + "id": "2DbUtbJ2ocOP" + } + }, + { + "cell_type": "markdown", + "source": [ + "\n", + "*Biomedical Engineering Department, School of Electrical Engineering, \"Iran University of Science and Technology\", Tehran*" + ], + "metadata": { + "id": "6Bz-krJCoYBE" + } + }, + { + "cell_type": "markdown", + "source": [ + "*To find the files which you have uploaded in your \"Google Drive\"*" + ], + "metadata": { + "id": "Nfvxi52_PdCu" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "GCH1VY7EPBeM" + }, + "outputs": [], + "source": [ + "from google.colab import drive\n", + "drive.mount('/content/gdrive')" + ] + }, + { + "cell_type": "markdown", + "source": [ + "*Unzipping the mentioned file which is named \"Inf-Net-master\"*" + ], + "metadata": { + "id": "lrOeyY2zPzx5" + } + }, + { + "cell_type": 
"code", + "source": [ + "!unzip \"/content/gdrive/MyDrive/Inf-Net-master.zip\"" + ], + "metadata": { + "id": "Y60Fu4r3Pxh3" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "*Install the package named \"thop\"*" + ], + "metadata": { + "id": "DlXTRYWLQVp5" + } + }, + { + "cell_type": "code", + "source": [ + "pip install thop" + ], + "metadata": { + "id": "7wKqn7SMQRo-" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "*For recognizing the module named \"Code\"*" + ], + "metadata": { + "id": "nBLraA_eQy4q" + } + }, + { + "cell_type": "code", + "source": [ + "cd '/content/Inf-Net-master'" + ], + "metadata": { + "id": "3KWrSufZQvXw" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "import sys\n", + "sys.path.append('/content/Inf-Net-master')" + ], + "metadata": { + "id": "EaTK03_mQv99" + }, + "execution_count": 5, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "*For solving the error of \"ipykernel_launcher.py: error: unrecognized arguments\"*" + ], + "metadata": { + "id": "GQCjBsepRp3m" + } + }, + { + "cell_type": "code", + "source": [ + "!pip install ipykernel" + ], + "metadata": { + "id": "yY9orVlvRlZU" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "import sys\n", + "sys.argv=['-f']\n", + "del sys" + ], + "metadata": { + "id": "cDl9trNnRoIt" + }, + "execution_count": 7, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "*Training \"MyTrain_LungInf.py\" for Inf-Net & Semi-Inf-Net*\n", + "\n", + "*The following code is adjusted for Inf-Net. 
If you want to change it to Semi-Inf-Net, follow the [README.md](https://github.com/DengPingFan/Inf-Net/blob/master/README.md)*" + ], + "metadata": { + "id": "FLBiAGX4SHdy" + } + }, + { + "cell_type": "code", + "source": [ + "# -*- coding: utf-8 -*-\n", + "\n", + "\"\"\"Preview\n", + "Code for 'Inf-Net: Automatic COVID-19 Lung Infection Segmentation from CT Scans'\n", + "submit to Transactions on Medical Imaging, 2020.\n", + "\n", + "1st Version: Created on 2020-05-13 (@author: Ge-Peng Ji)\n", + "2nd Version: Fix some bugs caused by THOP on 2020-06-10 (@author: Ge-Peng Ji)\n", + "\"\"\"\n", + "\n", + "import torch\n", + "from torch.autograd import Variable\n", + "import os\n", + "import argparse\n", + "from datetime import datetime\n", + "from Code.utils.dataloader_LungInf import get_loader\n", + "from Code.utils.utils import clip_gradient, adjust_lr, AvgMeter\n", + "import torch.nn.functional as F\n", + "\n", + "\n", + "def joint_loss(pred, mask):\n", + " weit = 1 + 5*torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)\n", + " wbce = F.binary_cross_entropy_with_logits(pred, mask, reduce='none')\n", + " wbce = (weit*wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))\n", + "\n", + " pred = torch.sigmoid(pred)\n", + " inter = ((pred * mask)*weit).sum(dim=(2, 3))\n", + " union = ((pred + mask)*weit).sum(dim=(2, 3))\n", + " wiou = 1 - (inter + 1)/(union - inter+1)\n", + " return (wbce + wiou).mean()\n", + "\n", + "\n", + "def train(train_loader, model, optimizer, epoch, train_save):\n", + " model.train()\n", + " # ---- multi-scale training ----\n", + " size_rates = [0.75, 1, 1.25] # replace your desired scale, try larger scale for better accuracy in small object\n", + " loss_record1, loss_record2, loss_record3, loss_record4, loss_record5 = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()\n", + " for i, pack in enumerate(train_loader, start=1):\n", + " for rate in size_rates:\n", + " optimizer.zero_grad()\n", + " # ---- data prepare ----\n", + 
" images, gts, edges = pack\n", + " images = Variable(images).cuda()\n", + " gts = Variable(gts).cuda()\n", + " edges = Variable(edges).cuda()\n", + " # ---- rescaling the inputs (img/gt/edge) ----\n", + " trainsize = int(round(opt.trainsize*rate/32)*32)\n", + " if rate != 1:\n", + " images = F.upsample(images, size=(trainsize, trainsize), mode='bilinear', align_corners=True)\n", + " gts = F.upsample(gts, size=(trainsize, trainsize), mode='bilinear', align_corners=True)\n", + " edges = F.upsample(edges, size=(trainsize, trainsize), mode='bilinear', align_corners=True)\n", + "\n", + " # ---- forward ----\n", + " lateral_map_5, lateral_map_4, lateral_map_3, lateral_map_2, lateral_edge = model(images)\n", + " # ---- loss function ----\n", + " loss5 = joint_loss(lateral_map_5, gts)\n", + " loss4 = joint_loss(lateral_map_4, gts)\n", + " loss3 = joint_loss(lateral_map_3, gts)\n", + " loss2 = joint_loss(lateral_map_2, gts)\n", + " loss1 = BCE(lateral_edge, edges)\n", + " loss = loss1 + loss2 + loss3 + loss4 + loss5\n", + " # ---- backward ----\n", + " loss.backward()\n", + " clip_gradient(optimizer, opt.clip)\n", + " optimizer.step()\n", + " # ---- recording loss ----\n", + " if rate == 1:\n", + " loss_record1.update(loss1.data, opt.batchsize)\n", + " loss_record2.update(loss2.data, opt.batchsize)\n", + " loss_record3.update(loss3.data, opt.batchsize)\n", + " loss_record4.update(loss4.data, opt.batchsize)\n", + " loss_record5.update(loss5.data, opt.batchsize)\n", + " # ---- train logging ----\n", + " if i % 20 == 0 or i == total_step:\n", + " print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], [lateral-edge: {:.4f}, '\n", + " 'lateral-2: {:.4f}, lateral-3: {:0.4f}, lateral-4: {:0.4f}, lateral-5: {:0.4f}]'.\n", + " format(datetime.now(), epoch, opt.epoch, i, total_step, loss_record1.show(),\n", + " loss_record2.show(), loss_record3.show(), loss_record4.show(), loss_record5.show()))\n", + " # ---- save model_lung_infection ----\n", + " save_path = 
'./Snapshots/save_weights/{}/'.format(train_save)\n", + " os.makedirs(save_path, exist_ok=True)\n", + "\n", + " if (epoch+1) % 10 == 0:\n", + " torch.save(model.state_dict(), save_path + 'Inf-Net-%d.pth' % (epoch+1))\n", + " print('[Saving Snapshot:]', save_path + 'Inf-Net-%d.pth' % (epoch+1))\n", + "\n", + "\n", + "if __name__ == '__main__':\n", + " parser = argparse.ArgumentParser()\n", + " # hyper-parameters\n", + " parser.add_argument('--epoch', type=int, default=100,\n", + " help='epoch number')\n", + " parser.add_argument('--lr', type=float, default=1e-4,\n", + " help='learning rate')\n", + " parser.add_argument('--batchsize', type=int, default=24,\n", + " help='training batch size')\n", + " parser.add_argument('--trainsize', type=int, default=352,\n", + " help='set the size of training sample')\n", + " parser.add_argument('--clip', type=float, default=0.5,\n", + " help='gradient clipping margin')\n", + " parser.add_argument('--decay_rate', type=float, default=0.1,\n", + " help='decay rate of learning rate')\n", + " parser.add_argument('--decay_epoch', type=int, default=50,\n", + " help='every n epochs decay learning rate')\n", + " parser.add_argument('--is_thop', type=bool, default=False,\n", + " help='whether calculate FLOPs/Params (Thop)')\n", + " parser.add_argument('--gpu_device', type=int, default=0,\n", + " help='choose which GPU device you want to use')\n", + " parser.add_argument('--num_workers', type=int, default=8,\n", + " help='number of workers in dataloader. 
In windows, set num_workers=0')\n", + " # model_lung_infection parameters\n", + " parser.add_argument('--net_channel', type=int, default=32,\n", + " help='internal channel numbers in the Inf-Net, default=32, try larger for better accuracy')\n", + " parser.add_argument('--n_classes', type=int, default=1,\n", + " help='binary segmentation when n_classes=1')\n", + " parser.add_argument('--backbone', type=str, default='Res2Net50',\n", + " help='change different backbone, choice: VGGNet16, ResNet50, Res2Net50')\n", + " # training dataset\n", + " parser.add_argument('--train_path', type=str,\n", + " default='./Dataset/COVID-SemiSeg/Dataset/TrainingSet/LungInfection-Train/Doctor-label')\n", + " parser.add_argument('--is_semi', type=bool, default=False,\n", + " help='if True, you will turn on the mode of `Semi-Inf-Net`')\n", + " parser.add_argument('--is_pseudo', type=bool, default=False,\n", + " help='if True, you will train the model on pseudo-label')\n", + " parser.add_argument('--train_save', type=str, default=None,\n", + " help='If you use custom save path, please edit `--is_semi=True` and `--is_pseudo=True`')\n", + "\n", + " opt = parser.parse_args()\n", + "\n", + " # ---- build models ----\n", + " torch.cuda.set_device(opt.gpu_device)\n", + " # - please asign your prefer backbone in opt.\n", + " if opt.backbone == 'Res2Net50':\n", + " print('Backbone loading: Res2Net50')\n", + " from Code.model_lung_infection.InfNet_Res2Net import Inf_Net\n", + " elif opt.backbone == 'ResNet50':\n", + " print('Backbone loading: ResNet50')\n", + " from Code.model_lung_infection.InfNet_ResNet import Inf_Net\n", + " elif opt.backbone == 'VGGNet16':\n", + " print('Backbone loading: VGGNet16')\n", + " from Code.model_lung_infection.InfNet_VGGNet import Inf_Net\n", + " else:\n", + " raise ValueError('Invalid backbone parameters: {}'.format(opt.backbone))\n", + " model = Inf_Net(channel=opt.net_channel, n_class=opt.n_classes).cuda()\n", + "\n", + " # ---- load pre-trained weights 
(mode=Semi-Inf-Net) ----\n", + " # - See Sec.2.3 of `README.md` to learn how to generate your own img/pseudo-label from scratch.\n", + " if opt.is_semi and opt.backbone == 'Res2Net50':\n", + " print('Loading weights from weights file trained on pseudo label')\n", + " model.load_state_dict(torch.load('./Snapshots/save_weights/Inf-Net_Pseduo/Inf-Net_pseudo_100.pth'))\n", + " else:\n", + " print('Not loading weights from weights file')\n", + "\n", + " # weights file save path\n", + " if opt.is_pseudo and (not opt.is_semi):\n", + " train_save = 'Inf-Net_Pseudo'\n", + " elif (not opt.is_pseudo) and opt.is_semi:\n", + " train_save = 'Semi-Inf-Net'\n", + " elif (not opt.is_pseudo) and (not opt.is_semi):\n", + " train_save = 'Inf-Net'\n", + " else:\n", + " print('Use custom save path')\n", + " train_save = opt.train_save\n", + "\n", + " # ---- calculate FLOPs and Params ----\n", + " if opt.is_thop:\n", + " from Code.utils.utils import CalParams\n", + " x = torch.randn(1, 3, opt.trainsize, opt.trainsize).cuda()\n", + " CalParams(model, x)\n", + "\n", + " # ---- load training sub-modules ----\n", + " BCE = torch.nn.BCEWithLogitsLoss()\n", + "\n", + " params = model.parameters()\n", + " optimizer = torch.optim.Adam(params, opt.lr)\n", + "\n", + " image_root = '{}/Imgs/'.format(opt.train_path)\n", + " gt_root = '{}/GT/'.format(opt.train_path)\n", + " edge_root = '{}/Edge/'.format(opt.train_path)\n", + "\n", + " train_loader = get_loader(image_root, gt_root, edge_root,\n", + " batchsize=opt.batchsize, trainsize=opt.trainsize, num_workers=opt.num_workers)\n", + " total_step = len(train_loader)\n", + "\n", + " # ---- start !! -----\n", + " print(\"#\"*20, \"\\nStart Training (Inf-Net-{})\\n{}\\nThis code is written for 'Inf-Net: Automatic COVID-19 Lung \"\n", + " \"Infection Segmentation from CT Scans', 2020, TMI.\\n\"\n", + " \"----\\nPlease cite the paper if you use this code and dataset. 
\"\n", + " \"And any questions feel free to contact me \"\n", + " \"via E-mail (gepengai.ji@gmail.com)\\n----\\n\".format(opt.backbone, opt), \"#\"*20)\n", + "\n", + " for epoch in range(1, opt.epoch):\n", + " adjust_lr(optimizer, opt.lr, epoch, opt.decay_rate, opt.decay_epoch)\n", + " train(train_loader, model, optimizer, epoch, train_save)\n" + ], + "metadata": { + "id": "NVj5QcZwQev5" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "*Testing \"MyTest_LungInf.py\" for Inf-Net & Semi-Inf-Net*\n", + "\n", + "*The following code is adjusted for Inf-Net. If you want to change it to Semi-Inf-Net, follow the [README.md](https://github.com/DengPingFan/Inf-Net/blob/master/README.md)*" + ], + "metadata": { + "id": "UnzEsna6UHe3" + } + }, + { + "cell_type": "code", + "source": [ + "# -*- coding: utf-8 -*-\n", + "\n", + "\"\"\"Preview\n", + "Code for 'Inf-Net: Automatic COVID-19 Lung Infection Segmentation from CT Scans'\n", + "submit to Transactions on Medical Imaging, 2020.\n", + "\n", + "First Version: Created on 2020-05-13 (@author: Ge-Peng Ji)\n", + "\"\"\"\n", + "\n", + "import torch\n", + "import torch.nn.functional as F\n", + "import numpy as np\n", + "import os\n", + "import argparse\n", + "from scipy import misc\n", + "from Code.model_lung_infection.InfNet_Res2Net import Inf_Net as Network\n", + "from Code.utils.dataloader_LungInf import test_dataset\n", + "\n", + "\n", + "def inference():\n", + " parser = argparse.ArgumentParser()\n", + " parser.add_argument('--testsize', type=int, default=352, help='testing size')\n", + " parser.add_argument('--data_path', type=str, default='./Dataset/COVID-SemiSeg/Dataset/TestingSet/LungInfection-Test/',\n", + " help='Path to test data')\n", + " parser.add_argument('--pth_path', type=str, default='./Snapshots/save_weights/Inf-Net/Inf-Net-100.pth',\n", + " help='Path to weights file. 
If `semi-sup`, edit it to `Semi-Inf-Net/Semi-Inf-Net-100.pth`')\n", + " parser.add_argument('--save_path', type=str, default='./Results/Lung infection segmentation/Inf-Net/',\n", + " help='Path to save the predictions. if `semi-sup`, edit it to `Semi-Inf-Net`')\n", + " opt = parser.parse_args()\n", + "\n", + " print(\"#\" * 20, \"\\nStart Testing (Inf-Net)\\n{}\\nThis code is written for 'Inf-Net: Automatic COVID-19 Lung \"\n", + " \"Infection Segmentation from CT Scans', 2020, TMI.\\n\"\n", + " \"----\\nPlease cite the paper if you use this code and dataset. \"\n", + " \"And any questions feel free to contact me \"\n", + " \"via E-mail (gepengai.ji@gamil.com)\\n----\\n\".format(opt), \"#\" * 20)\n", + "\n", + " model = Network()\n", + " # model = torch.nn.DataParallel(model, device_ids=[0, 1]) # uncomment it if you have multiply GPUs.\n", + " model.load_state_dict(torch.load(opt.pth_path, map_location={'cuda:1':'cuda:0'}))\n", + " model.cuda()\n", + " model.eval()\n", + "\n", + " image_root = '{}/Imgs/'.format(opt.data_path)\n", + " # gt_root = '{}/GT/'.format(opt.data_path)\n", + " test_loader = test_dataset(image_root, opt.testsize)\n", + " os.makedirs(opt.save_path, exist_ok=True)\n", + "\n", + " for i in range(test_loader.size):\n", + " image, name = test_loader.load_data()\n", + "\n", + " image = image.cuda()\n", + "\n", + " lateral_map_5, lateral_map_4, lateral_map_3, lateral_map_2, lateral_edge = model(image)\n", + "\n", + " res = lateral_map_2\n", + " # res = F.upsample(res, size=(ori_size[1],ori_size[0]), mode='bilinear', align_corners=False)\n", + " res = res.sigmoid().data.cpu().numpy().squeeze()\n", + " res = (res - res.min()) / (res.max() - res.min() + 1e-8)\n", + " #misc.imsave(opt.save_path + name, res)\n", + " import imageio\n", + " imageio.imwrite(opt.save_path + name, res)\n", + "\n", + " print('Test Done!')\n", + "\n", + "\n", + "if __name__ == \"__main__\":\n", + " inference()\n" + ], + "metadata": { + "id": "8EUSmTJwUPI2" + }, + 
"execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/Training_&_Testing_SemiInfNet_+_MulticlassUNet.ipynb b/Training_&_Testing_SemiInfNet_+_MulticlassUNet.ipynb new file mode 100644 index 00000000..885701a6 --- /dev/null +++ b/Training_&_Testing_SemiInfNet_+_MulticlassUNet.ipynb @@ -0,0 +1,359 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU", + "gpuClass": "standard" + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "**Pooyan Rezaeipour Lasaki**\n", + "\n", + "e-mails:\n", + "rezaeipourpooyan@gmail.com &\n", + "pooyan_rezaeipour@elec.iust.ac.ir\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "metadata": { + "id": "NBzk2gT4cbWc" + } + }, + { + "cell_type": "markdown", + "source": [ + "**Mohsen Safaei**\n", + "\n", + "e-mails: mfsafaei78@gmail.com & mo_safaei@elec.iust.ac.ir" + ], + "metadata": { + "id": "4sF28w9ooSmx" + } + }, + { + "cell_type": "markdown", + "source": [ + "**Saeed Chamani**\n", + "\n", + "e-mails: saeed.chamani10@gmail.com\n", + "& saeed_chamani@elec.iust.ac.ir\n" + ], + "metadata": { + "id": "2DbUtbJ2ocOP" + } + }, + { + "cell_type": "markdown", + "source": [ + "\n", + "*Biomedical Engineering Department, School of Electrical Engineering, \"Iran University of Science and Technology\", Tehran*" + ], + "metadata": { + "id": "6Bz-krJCoYBE" + } + }, + { + "cell_type": "markdown", + "source": [ + "*To find the files which you have uploaded in your \"Google Drive\"*" + ], + "metadata": { + "id": "Q3zewxs-bj_v" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ixa4_nkha_gh" + }, + "outputs": [], + "source": [ + "from google.colab import drive\n", + "drive.mount('/content/gdrive')" + ] + }, + { + "cell_type": "markdown", + "source": [ + "*Unzipping the mentioned file which 
is named \"Inf-Net-master\"*" + ], + "metadata": { + "id": "_K4pW3h1bnkX" + } + }, + { + "cell_type": "code", + "source": [ + "!unzip \"/content/gdrive/MyDrive/Inf-Net-master.zip\"" + ], + "metadata": { + "id": "1YG1syJcbfBI" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "*For recognizing the module named \"Code\"*" + ], + "metadata": { + "id": "nT6Yvu7_cT-7" + } + }, + { + "cell_type": "code", + "source": [ + "cd '/content/Inf-Net-master'" + ], + "metadata": { + "id": "q9A5tTYGb55F" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "import sys\n", + "sys.path.append('/content/Inf-Net-master')" + ], + "metadata": { + "id": "_CjMZsNTb9TS" + }, + "execution_count": 5, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "*Training \"MyTrain_MulClsLungInf_UNet\" for Semi-Inf-Net + Multi-class UNet*" + ], + "metadata": { + "id": "IF1yiJRNcuBa" + } + }, + { + "cell_type": "markdown", + "source": [ + "**Note** that if you run this file uninterruptedly after Semi-Inf-net(second method), you do not face a problem because the folder named \"Snapshots\" was created automatically.\n", + "\n", + " But for running this file seperately(which we do below), you must create a folder in \"Snapshots\" which is named \"save_weights\",then create a folder in \"save_weights\" which is named \"{}\".\n", + "\n", + "Also change the last line which is \" save_path='Semi-Inf-Net_UNet' \" to \" save_path='{}' \"" + ], + "metadata": { + "id": "QkjXlXO8dJJJ" + } + }, + { + "cell_type": "code", + "source": [ + "# -*- coding: utf-8 -*-\n", + "\n", + "\"\"\"Preview\n", + "Code for 'Inf-Net: Automatic COVID-19 Lung Infection Segmentation from CT Scans'\n", + "submit to Transactions on Medical Imaging, 2020.\n", + "\n", + "First Version: Created on 2020-05-13 (@author: Ge-Peng Ji)\n", + "\"\"\"\n", + "\n", + "import os\n", + "import numpy as np\n", + "import torch.optim as optim\n", + 
"from Code.utils.dataloader_MulClsLungInf_UNet import LungDataset\n", + "from torchvision import transforms\n", + "# from LungData import test_dataloader, train_dataloader # pls change batch_size\n", + "from torch.utils.data import DataLoader\n", + "from Code.model_lung_infection.InfNet_UNet import *\n", + "\n", + "\n", + "def train(epo_num, num_classes, input_channels, batch_size, lr, save_path):\n", + " train_dataset = LungDataset(\n", + " imgs_path='./Dataset/COVID-SemiSeg/Dataset/TrainingSet/MultiClassInfection-Train/Imgs/',\n", + " # NOTES: prior is borrowed from the object-level label of train split\n", + " pseudo_path='./Dataset/COVID-SemiSeg/Dataset/TrainingSet/MultiClassInfection-Train/Prior/',\n", + " label_path='./Dataset/COVID-SemiSeg/Dataset/TrainingSet/MultiClassInfection-Train/GT/',\n", + " transform=transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]))\n", + " train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)\n", + "\n", + " device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "\n", + " lung_model = Inf_Net_UNet(input_channels, num_classes) # input_channels=3, n_class=3\n", + " print(lung_model)\n", + " lung_model = lung_model.to(device)\n", + "\n", + " criterion = nn.BCELoss().to(device)\n", + " optimizer = optim.SGD(lung_model.parameters(), lr=lr, momentum=0.7)\n", + "\n", + " print(\"#\" * 20, \"\\nStart Training (Inf-Net)\\nThis code is written for 'Inf-Net: Automatic COVID-19 Lung \"\n", + " \"Infection Segmentation from CT Scans', 2020, TMI.\\n\"\n", + " \"----\\nPlease cite the paper if you use this code and dataset. 
\"\n", + " \"And any questions feel free to contact me \"\n", + " \"via E-mail (gepengai.ji@gmail.com)\\n----\\n\", \"#\" * 20)\n", + "\n", + " for epo in range(epo_num):\n", + "\n", + " train_loss = 0\n", + " lung_model.train()\n", + "\n", + " for index, (img, pseudo, img_mask, _) in enumerate(train_dataloader):\n", + "\n", + " img = img.to(device)\n", + " pseudo = pseudo.to(device)\n", + " img_mask = img_mask.to(device)\n", + "\n", + " optimizer.zero_grad()\n", + " output = lung_model(torch.cat((img, pseudo), dim=1))\n", + "\n", + " output = torch.sigmoid(output) # output.shape is torch.Size([4, 2, 160, 160])\n", + " loss = criterion(output, img_mask)\n", + "\n", + " loss.backward()\n", + " iter_loss = loss.item()\n", + " train_loss += iter_loss\n", + " optimizer.step()\n", + "\n", + " if np.mod(index, 20) == 0:\n", + " print('Epoch: {}/{}, Step: {}/{}, Train loss is {}'.format(epo, epo_num, index, len(train_dataloader), iter_loss))\n", + "\n", + " os.makedirs('./checkpoints//UNet_Multi-Class-Semi', exist_ok=True)\n", + " if np.mod(epo+1, 10) == 0:\n", + " torch.save(lung_model.state_dict(),\n", + " './Snapshots/save_weights/{}/unet_model_{}.pkl'.format(save_path, epo+1))\n", + " print('Saving checkpoints: unet_model_{}.pkl'.format(epo+1))\n", + "\n", + "\n", + "if __name__ == \"__main__\":\n", + " train(epo_num=200,\n", + " num_classes=3,\n", + " input_channels=6,\n", + " batch_size=16,\n", + " lr=1e-2,\n", + " save_path='{}')\n" + ], + "metadata": { + "id": "qsEQFhNebvFK" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "*Testing \"MyTest_MulClsLungInf_UNet\" for Semi-Inf-Net + Multi-class UNet*" + ], + "metadata": { + "id": "dww4tHi5hGdS" + } + }, + { + "cell_type": "markdown", + "source": [ + "For running this file seperately from second method(which we do below), change the one before the last line which is \" snapshot_dir='./Snapshots/save_weights/Semi-Inf-Net_UNet/unet_model_200.pkl' \" to \n", + " \" 
snapshot_dir='./Snapshots/save_weights/{}/unet_model_200.pkl' \"" + ], + "metadata": { + "id": "Wq3I-cmghukX" + } + }, + { + "cell_type": "code", + "source": [ + "# -*- coding: utf-8 -*-\n", + "\n", + "\"\"\"Preview\n", + "Code for 'Inf-Net: Automatic COVID-19 Lung Infection Segmentation from CT Scans'\n", + "submit to Transactions on Medical Imaging, 2020.\n", + "\n", + "First Version: Created on 2020-05-13 (@author: Ge-Peng Ji)\n", + "\"\"\"\n", + "\n", + "import os\n", + "import numpy as np\n", + "from Code.utils.dataloader_MulClsLungInf_UNet import LungDataset\n", + "from torchvision import transforms\n", + "from torch.utils.data import DataLoader\n", + "from Code.model_lung_infection.InfNet_UNet import * # use U-Net for multi-class segmentation\n", + "from scipy import misc\n", + "from Code.utils.split_class import split_class\n", + "import shutil\n", + "\n", + "\n", + "def inference(num_classes, input_channels, snapshot_dir, save_path):\n", + " test_dataset = LungDataset(\n", + " imgs_path='./Dataset/COVID-SemiSeg/Dataset/TestingSet/MultiClassInfection-Test/Imgs/',\n", + " pseudo_path='./Dataset/COVID-SemiSeg/Results/Lung infection segmentation/Semi-Inf-Net/', # NOTES: generated from `Semi-Inf-Net`\n", + " label_path='./Dataset/COVID-SemiSeg/Dataset/TestingSet/MultiClassInfection-Test/GT/',\n", + " transform=transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize(mean=[0.485, 0.456, 0.406], \n", + " std=[0.229, 0.224, 0.225])]),\n", + " is_test=True\n", + " )\n", + " test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)\n", + "\n", + " device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "\n", + " lung_model = Inf_Net_UNet(input_channels, num_classes).cuda()\n", + " print(lung_model)\n", + " lung_model.load_state_dict(torch.load(snapshot_dir))\n", + " lung_model.eval()\n", + "\n", + " for index, (img, pseudo, img_mask, name) in enumerate(test_dataloader):\n", + " img = 
img.to(device)\n", + " pseudo = pseudo.to(device)\n", + " img_mask = img_mask.to(device)\n", + "\n", + " output = lung_model(torch.cat((img, pseudo), dim=1))\n", + " output = torch.sigmoid(output) # output.shape is torch.Size([4, 2, 160, 160])\n", + " b, _, w, h = output.size()\n", + " _, _, w_gt, h_gt = img_mask.size()\n", + "\n", + " # output b*n_class*h*w -- > b*h*w\n", + " pred = output.cpu().permute(0, 2, 3, 1).contiguous().view(-1, num_classes).max(1)[1].view(b, w, h).numpy().squeeze()\n", + " print('Class numbers of prediction in total:', np.unique(pred))\n", + " # pred = misc.imresize(pred, size=(w_gt, h_gt))\n", + " os.makedirs(save_path, exist_ok=True)\n", + " #misc.imsave(save_path + name[0].replace('.jpg', '.png'), pred)\n", + " import imageio\n", + " imageio.imwrite(save_path + name[0].replace('.jpg', '.png'), pred)\n", + " split_class(save_path, name[0].replace('.jpg', '.png'), w_gt, h_gt)\n", + "\n", + " shutil.rmtree(save_path)\n", + " print('Test done!')\n", + "\n", + "\n", + "if __name__ == \"__main__\":\n", + " inference(num_classes=3,\n", + " input_channels=6,\n", + " snapshot_dir='./Snapshots/save_weights/{}/unet_model_200.pkl',\n", + " save_path='./Results/Multi-class lung infection segmentation/class_12/'\n", + " )\n" + ], + "metadata": { + "id": "o9b4KsLBhRqi" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/training_&_testing_infnet_&_semiinfnet.py b/training_&_testing_infnet_&_semiinfnet.py new file mode 100644 index 00000000..f01bc044 --- /dev/null +++ b/training_&_testing_infnet_&_semiinfnet.py @@ -0,0 +1,326 @@ +# -*- coding: utf-8 -*- +"""Training_&_Testing_InfNet_&_SemiInfNet.ipynb + +Automatically generated by Colaboratory. 
+ +Original file is located at + https://colab.research.google.com/drive/1ja8JaidGb1zgkL_aTEnpwEAN7VRBo0CI + +**Pooyan Rezaeipour Lasaki** + +e-mails: +rezaeipourpooyan@gmail.com & +pooyan_rezaeipour@elec.iust.ac.ir + +**Mohsen Safaei** + +e-mails: mfsafaei78@gmail.com & mo_safaei@elec.iust.ac.ir + +**Saeed Chamani** + +e-mails: saeed.chamani10@gmail.com +& saeed_chamani@elec.iust.ac.ir + +*Biomedical Engineering Department, School of Electrical Engineering, "Iran University of Science and Technology", Tehran* + +*To find the files which you have uploaded in your "Google Drive"* +""" + +from google.colab import drive +drive.mount('/content/gdrive') + +"""*Unzipping the mentioned file which is named "Inf-Net-master"*""" + +!unzip "/content/gdrive/MyDrive/Inf-Net-master.zip" + +"""*Install the package named "thop"*""" + +pip install thop + +"""*For recognizing the module named "Code"*""" + +cd '/content/Inf-Net-master' + +import sys +sys.path.append('/content/Inf-Net-master') + +"""*For solving the error of "ipykernel_launcher.py: error: unrecognized arguments"*""" + +!pip install ipykernel + +import sys +sys.argv=['-f'] +del sys + +"""*Training "MyTrain_LungInf.py" for Inf-Net & Semi-Inf-Net* + +*The following code is adjusted for Inf-Net. If you want to change it to Semi-Inf-Net, follow the [README.md](https://github.com/DengPingFan/Inf-Net/blob/master/README.md)* +""" + +# -*- coding: utf-8 -*- + +"""Preview +Code for 'Inf-Net: Automatic COVID-19 Lung Infection Segmentation from CT Scans' +submit to Transactions on Medical Imaging, 2020. 
# -*- coding: utf-8 -*-
"""Train Inf-Net / Semi-Inf-Net for COVID-19 lung-infection segmentation.

Code for 'Inf-Net: Automatic COVID-19 Lung Infection Segmentation from
CT Scans', TMI 2020.

1st Version: Created on 2020-05-13 (@author: Ge-Peng Ji)
2nd Version: Fix some bugs caused by THOP on 2020-06-10 (@author: Ge-Peng Ji)
"""

import argparse
import os
from datetime import datetime

import torch
import torch.nn.functional as F

from Code.utils.dataloader_LungInf import get_loader
from Code.utils.utils import clip_gradient, adjust_lr, AvgMeter


def _str2bool(v):
    """argparse-friendly bool converter.

    The original flags used ``type=bool``, for which any non-empty string
    (including ``"False"``) parses as True. This accepts the usual spellings
    while staying backward-compatible with ``--is_semi True``.
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got %r' % v)


def joint_loss(pred, mask):
    """Edge-aware weighted BCE + weighted IoU loss.

    Args:
        pred: raw (pre-sigmoid) logits, shape (B, 1, H, W).
        mask: binary ground-truth map, same shape as ``pred``.

    Returns:
        Scalar tensor: batch mean of (weighted BCE + weighted IoU).
    """
    # Pixels that differ from their 31x31 neighbourhood mean (i.e. object
    # boundaries) receive up to 6x weight.
    weit = 1 + 5 * torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)
    # BUGFIX: the original passed ``reduce='none'``. ``reduce`` is a deprecated
    # *boolean* argument and a non-empty string is truthy, so the loss was
    # silently reduced to a scalar *before* the per-pixel weighting below.
    # ``reduction='none'`` keeps the per-pixel loss map as intended.
    wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')
    wbce = (weit * wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))

    pred = torch.sigmoid(pred)
    inter = ((pred * mask) * weit).sum(dim=(2, 3))
    union = ((pred + mask) * weit).sum(dim=(2, 3))
    wiou = 1 - (inter + 1) / (union - inter + 1)
    return (wbce + wiou).mean()


def train(train_loader, model, optimizer, epoch, train_save):
    """Run one training epoch with multi-scale inputs; periodically checkpoint.

    Relies on the module-level globals ``opt``, ``BCE`` and ``total_step``
    created in the ``__main__`` section below (as in the original script).
    """
    model.train()
    # ---- multi-scale training ----
    size_rates = [0.75, 1, 1.25]  # try larger scales for better accuracy on small objects
    loss_record1, loss_record2, loss_record3, loss_record4, loss_record5 = \
        AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
    for i, pack in enumerate(train_loader, start=1):
        for rate in size_rates:
            optimizer.zero_grad()
            # ---- data prepare ----
            # NOTE: the deprecated Variable wrapper is a no-op since
            # PyTorch 0.4 and has been dropped.
            images, gts, edges = pack
            images = images.cuda()
            gts = gts.cuda()
            edges = edges.cuda()
            # ---- rescale the inputs (img/gt/edge); sizes snap to multiples of 32 ----
            trainsize = int(round(opt.trainsize * rate / 32) * 32)
            if rate != 1:
                # F.interpolate replaces the deprecated F.upsample.
                images = F.interpolate(images, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
                gts = F.interpolate(gts, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
                edges = F.interpolate(edges, size=(trainsize, trainsize), mode='bilinear', align_corners=True)

            # ---- forward ----
            lateral_map_5, lateral_map_4, lateral_map_3, lateral_map_2, lateral_edge = model(images)
            # ---- loss function ----
            loss5 = joint_loss(lateral_map_5, gts)
            loss4 = joint_loss(lateral_map_4, gts)
            loss3 = joint_loss(lateral_map_3, gts)
            loss2 = joint_loss(lateral_map_2, gts)
            loss1 = BCE(lateral_edge, edges)  # edge branch uses plain BCE-with-logits
            loss = loss1 + loss2 + loss3 + loss4 + loss5
            # ---- backward ----
            loss.backward()
            clip_gradient(optimizer, opt.clip)
            optimizer.step()
            # ---- record losses at the native scale only ----
            if rate == 1:
                loss_record1.update(loss1.data, opt.batchsize)
                loss_record2.update(loss2.data, opt.batchsize)
                loss_record3.update(loss3.data, opt.batchsize)
                loss_record4.update(loss4.data, opt.batchsize)
                loss_record5.update(loss5.data, opt.batchsize)
        # ---- train logging ----
        if i % 20 == 0 or i == total_step:
            print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], [lateral-edge: {:.4f}, '
                  'lateral-2: {:.4f}, lateral-3: {:0.4f}, lateral-4: {:0.4f}, lateral-5: {:0.4f}]'.
                  format(datetime.now(), epoch, opt.epoch, i, total_step, loss_record1.show(),
                         loss_record2.show(), loss_record3.show(), loss_record4.show(), loss_record5.show()))
    # ---- save model_lung_infection ----
    save_path = './Snapshots/save_weights/{}/'.format(train_save)
    os.makedirs(save_path, exist_ok=True)

    if (epoch + 1) % 10 == 0:
        torch.save(model.state_dict(), save_path + 'Inf-Net-%d.pth' % (epoch + 1))
        print('[Saving Snapshot:]', save_path + 'Inf-Net-%d.pth' % (epoch + 1))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # hyper-parameters
    parser.add_argument('--epoch', type=int, default=100,
                        help='epoch number')
    parser.add_argument('--lr', type=float, default=1e-4,
                        help='learning rate')
    parser.add_argument('--batchsize', type=int, default=24,
                        help='training batch size')
    parser.add_argument('--trainsize', type=int, default=352,
                        help='set the size of training sample')
    parser.add_argument('--clip', type=float, default=0.5,
                        help='gradient clipping margin')
    parser.add_argument('--decay_rate', type=float, default=0.1,
                        help='decay rate of learning rate')
    parser.add_argument('--decay_epoch', type=int, default=50,
                        help='every n epochs decay learning rate')
    parser.add_argument('--is_thop', type=_str2bool, default=False,
                        help='whether calculate FLOPs/Params (Thop)')
    parser.add_argument('--gpu_device', type=int, default=0,
                        help='choose which GPU device you want to use')
    parser.add_argument('--num_workers', type=int, default=8,
                        help='number of workers in dataloader. In windows, set num_workers=0')
    # model_lung_infection parameters
    parser.add_argument('--net_channel', type=int, default=32,
                        help='internal channel numbers in the Inf-Net, default=32, try larger for better accuracy')
    parser.add_argument('--n_classes', type=int, default=1,
                        help='binary segmentation when n_classes=1')
    parser.add_argument('--backbone', type=str, default='Res2Net50',
                        help='change different backbone, choice: VGGNet16, ResNet50, Res2Net50')
    # training dataset
    parser.add_argument('--train_path', type=str,
                        default='./Dataset/COVID-SemiSeg/Dataset/TrainingSet/LungInfection-Train/Doctor-label')
    parser.add_argument('--is_semi', type=_str2bool, default=False,
                        help='if True, you will turn on the mode of `Semi-Inf-Net`')
    parser.add_argument('--is_pseudo', type=_str2bool, default=False,
                        help='if True, you will train the model on pseudo-label')
    parser.add_argument('--train_save', type=str, default=None,
                        help='If you use custom save path, please edit `--is_semi=True` and `--is_pseudo=True`')

    opt = parser.parse_args()

    # ---- build models ----
    torch.cuda.set_device(opt.gpu_device)
    # - please assign your preferred backbone in opt.
    if opt.backbone == 'Res2Net50':
        print('Backbone loading: Res2Net50')
        from Code.model_lung_infection.InfNet_Res2Net import Inf_Net
    elif opt.backbone == 'ResNet50':
        print('Backbone loading: ResNet50')
        from Code.model_lung_infection.InfNet_ResNet import Inf_Net
    elif opt.backbone == 'VGGNet16':
        print('Backbone loading: VGGNet16')
        from Code.model_lung_infection.InfNet_VGGNet import Inf_Net
    else:
        raise ValueError('Invalid backbone parameters: {}'.format(opt.backbone))
    model = Inf_Net(channel=opt.net_channel, n_class=opt.n_classes).cuda()

    # ---- load pre-trained weights (mode=Semi-Inf-Net) ----
    # - See Sec.2.3 of `README.md` to learn how to generate your own img/pseudo-label from scratch.
    if opt.is_semi and opt.backbone == 'Res2Net50':
        print('Loading weights from weights file trained on pseudo label')
        # NOTE(review): 'Pseduo' spelling matches the checkpoint folder shipped
        # with the original Inf-Net repo; do not "fix" it without renaming the
        # folder on disk as well.
        model.load_state_dict(torch.load('./Snapshots/save_weights/Inf-Net_Pseduo/Inf-Net_pseudo_100.pth'))
    else:
        print('Not loading weights from weights file')

    # weights file save path
    if opt.is_pseudo and (not opt.is_semi):
        train_save = 'Inf-Net_Pseudo'
    elif (not opt.is_pseudo) and opt.is_semi:
        train_save = 'Semi-Inf-Net'
    elif (not opt.is_pseudo) and (not opt.is_semi):
        train_save = 'Inf-Net'
    else:
        print('Use custom save path')
        train_save = opt.train_save
        if train_save is None:
            # Fail fast instead of silently creating a folder named "None".
            raise ValueError('--train_save must be set when both --is_semi and --is_pseudo are True')

    # ---- calculate FLOPs and Params ----
    if opt.is_thop:
        from Code.utils.utils import CalParams
        x = torch.randn(1, 3, opt.trainsize, opt.trainsize).cuda()
        CalParams(model, x)

    # ---- load training sub-modules ----
    BCE = torch.nn.BCEWithLogitsLoss()

    params = model.parameters()
    optimizer = torch.optim.Adam(params, opt.lr)

    image_root = '{}/Imgs/'.format(opt.train_path)
    gt_root = '{}/GT/'.format(opt.train_path)
    edge_root = '{}/Edge/'.format(opt.train_path)

    train_loader = get_loader(image_root, gt_root, edge_root,
                              batchsize=opt.batchsize, trainsize=opt.trainsize, num_workers=opt.num_workers)
    total_step = len(train_loader)

    # ---- start !! -----
    print("#"*20, "\nStart Training (Inf-Net-{})\n{}\nThis code is written for 'Inf-Net: Automatic COVID-19 Lung "
                  "Infection Segmentation from CT Scans', 2020, TMI.\n"
                  "----\nPlease cite the paper if you use this code and dataset. "
                  "And any questions feel free to contact me "
                  "via E-mail (gepengai.ji@gmail.com)\n----\n".format(opt.backbone, opt), "#"*20)

    # NOTE: range(1, opt.epoch) runs opt.epoch-1 iterations; the checkpoint
    # names (epoch+1) then run up to Inf-Net-<opt.epoch>.pth, matching the
    # filenames expected by the test script. Kept as-is for compatibility.
    for epoch in range(1, opt.epoch):
        adjust_lr(optimizer, opt.lr, epoch, opt.decay_rate, opt.decay_epoch)
        train(train_loader, model, optimizer, epoch, train_save)
# -*- coding: utf-8 -*-
"""Inference script for Inf-Net / Semi-Inf-Net lung-infection segmentation.

Code for 'Inf-Net: Automatic COVID-19 Lung Infection Segmentation from
CT Scans', TMI 2020 (@author: Ge-Peng Ji).

Writes one normalized probability map per test image into ``--save_path``.
"""

import argparse
import os

import imageio  # hoisted: the original imported this inside the per-image loop
import torch

from Code.model_lung_infection.InfNet_Res2Net import Inf_Net as Network
from Code.utils.dataloader_LungInf import test_dataset


def inference():
    """Run Inf-Net on the test split and save per-image predictions.

    Loads the checkpoint given by ``--pth_path``, runs the network on every
    image under ``<data_path>/Imgs/`` and writes a min-max-normalized
    sigmoid map of the final decoder output (``lateral_map_2``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--testsize', type=int, default=352, help='testing size')
    parser.add_argument('--data_path', type=str,
                        default='./Dataset/COVID-SemiSeg/Dataset/TestingSet/LungInfection-Test/',
                        help='Path to test data')
    parser.add_argument('--pth_path', type=str, default='./Snapshots/save_weights/Inf-Net/Inf-Net-100.pth',
                        help='Path to weights file. If `semi-sup`, edit it to `Semi-Inf-Net/Semi-Inf-Net-100.pth`')
    parser.add_argument('--save_path', type=str, default='./Results/Lung infection segmentation/Inf-Net/',
                        help='Path to save the predictions. if `semi-sup`, edit it to `Semi-Inf-Net`')
    opt = parser.parse_args()

    # BUGFIX: contact address previously read 'gamil.com'.
    print("#" * 20, "\nStart Testing (Inf-Net)\n{}\nThis code is written for 'Inf-Net: Automatic COVID-19 Lung "
                    "Infection Segmentation from CT Scans', 2020, TMI.\n"
                    "----\nPlease cite the paper if you use this code and dataset. "
                    "And any questions feel free to contact me "
                    "via E-mail (gepengai.ji@gmail.com)\n----\n".format(opt), "#" * 20)

    model = Network()
    # model = torch.nn.DataParallel(model, device_ids=[0, 1])  # uncomment it if you have multiple GPUs.
    # map_location remaps checkpoints saved on cuda:1 onto cuda:0.
    model.load_state_dict(torch.load(opt.pth_path, map_location={'cuda:1': 'cuda:0'}))
    model.cuda()
    model.eval()

    image_root = '{}/Imgs/'.format(opt.data_path)
    test_loader = test_dataset(image_root, opt.testsize)
    os.makedirs(opt.save_path, exist_ok=True)

    # no_grad: inference only — skips autograd bookkeeping and saves memory.
    with torch.no_grad():
        for _ in range(test_loader.size):
            image, name = test_loader.load_data()
            image = image.cuda()

            lateral_map_5, lateral_map_4, lateral_map_3, lateral_map_2, lateral_edge = model(image)

            # The last decoder stage (lateral_map_2) is the final prediction.
            res = lateral_map_2
            res = res.sigmoid().data.cpu().numpy().squeeze()
            # Min-max normalize to [0, 1]; epsilon guards an all-constant map.
            res = (res - res.min()) / (res.max() - res.min() + 1e-8)
            imageio.imwrite(opt.save_path + name, res)

    print('Test Done!')


if __name__ == "__main__":
    inference()
+ +Original file is located at + https://colab.research.google.com/drive/1bmII5H1AHL4MQDGr6vsOCDi-BFtD_j1m + +**Pooyan Rezaeipour Lasaki** + +e-mails: +rezaeipourpooyan@gmail.com & +pooyan_rezaeipour@elec.iust.ac.ir + +**Mohsen Safaei** + +e-mails: mfsafaei78@gmail.com & mo_safaei@elec.iust.ac.ir + +**Saeed Chamani** + +e-mails: saeed.chamani10@gmail.com +& saeed_chamani@elec.iust.ac.ir + +*Biomedical Engineering Department, School of Electrical Engineering, "Iran University of Science and Technology", Tehran* + +*To find the files which you have uploaded in your "Google Drive"* +""" + +from google.colab import drive +drive.mount('/content/gdrive') + +"""*Unzipping the mentioned file which is named "Inf-Net-master"*""" + +!unzip "/content/gdrive/MyDrive/Inf-Net-master.zip" + +"""*For recognizing the module named "Code"*""" + +cd '/content/Inf-Net-master' + +import sys +sys.path.append('/content/Inf-Net-master') + +"""*Training "MyTrain_MulClsLungInf_UNet" for Semi-Inf-Net + Multi-class UNet* + +**Note** that if you run this file uninterruptedly after Semi-Inf-net(second method), you do not face a problem because the folder named "Snapshots" was created automatically. + + But for running this file seperately(which we do below), you must create a folder in "Snapshots" which is named "save_weights",then create a folder in "save_weights" which is named "{}". + +Also change the last line which is " save_path='Semi-Inf-Net_UNet' " to " save_path='{}' " +""" + +# -*- coding: utf-8 -*- + +"""Preview +Code for 'Inf-Net: Automatic COVID-19 Lung Infection Segmentation from CT Scans' +submit to Transactions on Medical Imaging, 2020. 
# -*- coding: utf-8 -*-
"""Train the multi-class UNet on images concatenated with Semi-Inf-Net priors.

Code for 'Inf-Net: Automatic COVID-19 Lung Infection Segmentation from
CT Scans', TMI 2020 (@author: Ge-Peng Ji).
"""

import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms

from Code.utils.dataloader_MulClsLungInf_UNet import LungDataset
from Code.model_lung_infection.InfNet_UNet import *


def train(epo_num, num_classes, input_channels, batch_size, lr, save_path):
    """Train Inf_Net_UNet for multi-class infection segmentation.

    Args:
        epo_num: number of training epochs.
        num_classes: output classes (3: background + two infection types).
        input_channels: input channels (6 = 3 RGB + 3 prior channels).
        batch_size: mini-batch size.
        lr: SGD learning rate.
        save_path: folder name under ./Snapshots/save_weights/ for checkpoints.
    """
    train_dataset = LungDataset(
        imgs_path='./Dataset/COVID-SemiSeg/Dataset/TrainingSet/MultiClassInfection-Train/Imgs/',
        # NOTES: prior is borrowed from the object-level label of train split
        pseudo_path='./Dataset/COVID-SemiSeg/Dataset/TrainingSet/MultiClassInfection-Train/Prior/',
        label_path='./Dataset/COVID-SemiSeg/Dataset/TrainingSet/MultiClassInfection-Train/GT/',
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]))
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    lung_model = Inf_Net_UNet(input_channels, num_classes)  # input_channels=6, n_class=3
    print(lung_model)
    lung_model = lung_model.to(device)

    # Sigmoid is applied to the network output below, so plain BCELoss is used.
    criterion = nn.BCELoss().to(device)
    optimizer = optim.SGD(lung_model.parameters(), lr=lr, momentum=0.7)

    # BUGFIX: the original created './checkpoints//UNet_Multi-Class-Semi' but
    # saved into './Snapshots/save_weights/<save_path>/', so torch.save failed
    # unless the user created that folder by hand. Create the real target.
    save_dir = './Snapshots/save_weights/{}'.format(save_path)
    os.makedirs(save_dir, exist_ok=True)

    print("#" * 20, "\nStart Training (Inf-Net)\nThis code is written for 'Inf-Net: Automatic COVID-19 Lung "
                    "Infection Segmentation from CT Scans', 2020, TMI.\n"
                    "----\nPlease cite the paper if you use this code and dataset. "
                    "And any questions feel free to contact me "
                    "via E-mail (gepengai.ji@gmail.com)\n----\n", "#" * 20)

    for epo in range(epo_num):

        train_loss = 0
        lung_model.train()

        for index, (img, pseudo, img_mask, _) in enumerate(train_dataloader):

            img = img.to(device)
            pseudo = pseudo.to(device)
            img_mask = img_mask.to(device)

            optimizer.zero_grad()
            # The prior (pseudo label) is concatenated with the image along
            # the channel axis: (B, 3+3, H, W).
            output = lung_model(torch.cat((img, pseudo), dim=1))

            output = torch.sigmoid(output)
            loss = criterion(output, img_mask)

            loss.backward()
            iter_loss = loss.item()
            train_loss += iter_loss
            optimizer.step()

            if np.mod(index, 20) == 0:
                print('Epoch: {}/{}, Step: {}/{}, Train loss is {}'.format(
                    epo, epo_num, index, len(train_dataloader), iter_loss))

        if np.mod(epo + 1, 10) == 0:
            torch.save(lung_model.state_dict(),
                       '{}/unet_model_{}.pkl'.format(save_dir, epo + 1))
            print('Saving checkpoints: unet_model_{}.pkl'.format(epo + 1))


if __name__ == "__main__":
    train(epo_num=200,
          num_classes=3,
          input_channels=6,
          batch_size=16,
          lr=1e-2,
          # NOTE(review): the literal '{}' folder name is deliberate here — the
          # companion test script loads from the same literal path. Rename both
          # together (e.g. to 'Semi-Inf-Net_UNet') if desired.
          save_path='{}')
# -*- coding: utf-8 -*-
"""Multi-class inference: Semi-Inf-Net priors + UNet, then per-class split.

Code for 'Inf-Net: Automatic COVID-19 Lung Infection Segmentation from
CT Scans', TMI 2020 (@author: Ge-Peng Ji).
"""

import os
import shutil

import imageio  # hoisted: the original imported this inside the test loop
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms

from Code.utils.dataloader_MulClsLungInf_UNet import LungDataset
from Code.model_lung_infection.InfNet_UNet import *  # use U-Net for multi-class segmentation
from Code.utils.split_class import split_class


def inference(num_classes, input_channels, snapshot_dir, save_path):
    """Predict multi-class infection maps and split them into per-class files.

    Args:
        num_classes: number of output classes (argmax over this axis).
        input_channels: input channels (6 = 3 RGB + 3 prior channels).
        snapshot_dir: path to the trained UNet checkpoint (.pkl).
        save_path: temporary folder for combined label maps; it is removed
            after ``split_class`` has produced the per-class results.
    """
    test_dataset = LungDataset(
        imgs_path='./Dataset/COVID-SemiSeg/Dataset/TestingSet/MultiClassInfection-Test/Imgs/',
        # NOTES: priors are generated beforehand by `Semi-Inf-Net`
        pseudo_path='./Dataset/COVID-SemiSeg/Results/Lung infection segmentation/Semi-Inf-Net/',
        label_path='./Dataset/COVID-SemiSeg/Dataset/TestingSet/MultiClassInfection-Test/GT/',
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])]),
        is_test=True
    )
    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ROBUSTNESS: the original hard-coded .cuda() and torch.load() without
    # map_location, which crashes on CPU-only hosts even though the inputs
    # below already use `device`. Keep model and checkpoint on `device`.
    lung_model = Inf_Net_UNet(input_channels, num_classes).to(device)
    print(lung_model)
    lung_model.load_state_dict(torch.load(snapshot_dir, map_location=device))
    lung_model.eval()

    os.makedirs(save_path, exist_ok=True)  # hoisted out of the loop

    # no_grad: inference only — skips autograd bookkeeping and saves memory.
    with torch.no_grad():
        for index, (img, pseudo, img_mask, name) in enumerate(test_dataloader):
            img = img.to(device)
            pseudo = pseudo.to(device)
            img_mask = img_mask.to(device)

            output = lung_model(torch.cat((img, pseudo), dim=1))
            output = torch.sigmoid(output)
            b, _, w, h = output.size()
            _, _, w_gt, h_gt = img_mask.size()

            # argmax over classes: (B, C, H, W) -> (H, W) integer label map.
            pred = output.cpu().permute(0, 2, 3, 1).contiguous().view(-1, num_classes) \
                         .max(1)[1].view(b, w, h).numpy().squeeze()
            print('Class numbers of prediction in total:', np.unique(pred))
            out_name = name[0].replace('.jpg', '.png')
            # uint8 keeps class ids (0..num_classes-1) intact; imageio rejects
            # raw int64 label arrays.
            imageio.imwrite(save_path + out_name, pred.astype(np.uint8))
            split_class(save_path, out_name, w_gt, h_gt)

    # The combined maps are intermediate artifacts; only the per-class
    # outputs produced by split_class are kept.
    shutil.rmtree(save_path)
    print('Test done!')


if __name__ == "__main__":
    inference(num_classes=3,
              input_channels=6,
              # NOTE(review): the literal '{}' matches the training script's
              # default save_path; change both together if you rename it.
              snapshot_dir='./Snapshots/save_weights/{}/unet_model_200.pkl',
              save_path='./Results/Multi-class lung infection segmentation/class_12/'
              )