diff --git a/recognition/ISICs_UNet/README.md b/recognition/ISICs_UNet/README.md
index f2c009212e..788ea17b79 100644
--- a/recognition/ISICs_UNet/README.md
+++ b/recognition/ISICs_UNet/README.md
@@ -1,52 +1,101 @@
-# Segmenting ISICs with U-Net
+# Segmenting the ISICs data set with U-Net
-COMP3710 Report recognition problem 3 (Segmenting ISICs data set with U-Net) solved in TensorFlow
+## Project Overview
+This project aims to segment skin lesions (the ISIC 2018 data set) using U-Net, with all labels achieving a minimum Dice similarity coefficient of 0.7 on the test set (Task 3).
-Created by Christopher Bailey (45576430)
+## ISIC2018
+
-## The problem and algorithm
-The problem solved by this program is binary segmentation of the ISICs skin lesion data set. Segmentation is a way to label pixels in an image according to some grouping, in this case lesion or non-lesion. This translates images of skin to masks representing areas of concern for skin lesions.
+Skin Lesion Analysis towards Melanoma Detection
-U-Net is a form of autoencoder where the downsampling path is expected to learn the features of the image and the upsampling path learns how to recreate the masks. Long skip connections between downpooling and upsampling layers are utilised to overcome the bottleneck in traditional autoencoders allowing feature representations to be recreated.
+The task is described at https://challenge2018.isic-archive.com/
-## How it works
-A four layer padded U-Net is used, preserving skin features and mask resolution. The implementation utilises Adam as the optimizer and implements Dice distance as the loss function as this appeared to give quicker convergence than other methods (eg. binary cross-entropy).
-The utilised metric is a Dice coefficient implementation. My initial implementation appeared faulty and was replaced with a 3rd party implementation which appears correct. 3 epochs was observed to be generally sufficient to observe Dice coefficients of 0.8+ on test datasets but occasional non-convergence was observed and could be curbed by increasing the number of epochs. Visualisation of predictions is also implemented and shows reasonable correspondence. Orange bandaids represent an interesting challenge for the implementation as presented.
+## U-Net
+
-### Training, validation and testing split
-Training, validation and testing uses a respective 60:20:20 split, a commonly assumed starting point suggested by course staff. U-Net in particular was developed to work "with very few training images" (Ronneberger et al, 2015) The input data for this problem consists of 2594 images and masks. This split appears to provide satisfactory results.
+U-Net is a popular image segmentation architecture, used mostly for biomedical applications. The name comes from its architecture: a contracting path and an expansive path that together form a U shape. The architecture is designed to produce good results even with a small number of training images.
-## Using the model
-### Dependencies required
-* Python3 (tested with 3.8)
-* TensorFlow 2.x (tested with 2.3)
-* glob (used to load filenames)
-* matplotlib (used for visualisations, tested with 3.3)
+## Data Set Structure
-### Parameter tuning
-The model was developed on a GTX 1660 TI (6GB VRAM) and certain values (notably batch size and image resolution) were set lower than might otherwise be ideal on more capable hardware. This is commented in the relevant code.
+The data set folder needs to be stored in the same directory, with the structure shown below:
+```bash
+ISIC2018
+ |_ ISIC2018_Task1-2_Training_Input_x2
+ |_ ISIC_0000000
+ |_ ISIC_0000001
+ |_ ...
+ |_ ISIC2018_Task1_Training_GroundTruth_x2
+ |_ ISIC_0000000_segmentation
+ |_ ISIC_0000001_segmentation
+ |_ ...
+```
-### Running the model
-The model is executed via the main.py script.
+## Dice Coefficient
-### Example output
-Given a batch size of 1 and 3 epochs the following output was observed on a single run:
-Era | Loss | Dice coefficient
---- | ---- | ----------------
-Epoch 1 | 0.7433 | 0.2567
-Epoch 2 | 0.3197 | 0.6803
-Epoch 3 | 0.2657 | 0.7343
-Testing | 0.1820 | 0.8180
+The Sørensen–Dice coefficient is a statistic used to gauge the similarity of two samples.
+Further information: https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient
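+
+A minimal sketch of a Dice coefficient metric and the corresponding Dice loss in TensorFlow/Keras; the exact implementation used in this project may differ:
+
+```python
+import tensorflow as tf
+
+def dice_coefficient(y_true, y_pred, smooth=1.0):
+    # Dice = 2*|A ∩ B| / (|A| + |B|), computed on the flattened masks
+    y_true_f = tf.reshape(tf.cast(y_true, tf.float32), [-1])
+    y_pred_f = tf.reshape(y_pred, [-1])
+    intersection = tf.reduce_sum(y_true_f * y_pred_f)
+    return (2.0 * intersection + smooth) / (
+        tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)
+
+def dice_loss(y_true, y_pred):
+    return 1.0 - dice_coefficient(y_true, y_pred)
+```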
-### Figure 1 - example visualisation plot
-Skin images in left column, true mask middle, predicted mask right column
-
+## Dependencies
-## References
-Segments of code in this assignment were used from or based on the following sources:
-1. COMP3710-demo-code.ipynb from Guest Lecture
-1. https://www.tensorflow.org/tutorials/load_data/images
-1. https://www.tensorflow.org/guide/gpu
-1. Karan Jakhar (2019) https://medium.com/@karan_jakhar/100-days-of-code-day-7-84e4918cb72c
+- python 3
+- tensorflow 2.1.0
+- pandas 1.1.4
+- numpy 1.19.2
+- matplotlib 3.3.2
+- scikit-learn 0.23.2
+- pillow 8.0.1
+
+
+## Usage
+
+- Run `train.py` to train the U-Net on the ISIC data.
+- Run `evaluation.py` to evaluate the model and present example cases.
+
+## Advanced
+
+- Modify `setting.py` for custom settings, such as a different batch size.
+- Modify `unet.py` for a custom U-Net, such as a different kernel size.
+
+## Algorithm
+
+- data set:
+  - The data set used is the training set of the ISIC 2018 challenge, which has segmentation labels.
+  - Training : Validation : Test = 1660 : 415 : 519 = 0.64 : 0.16 : 0.2 (Training : Test = 4 : 1, and the training part is further split 4 : 1 into Training : Validation); a sketch of this split is shown after this list.
+  - Training data augmentations: rescale, rotate, shift, zoom, grayscale.
+- model:
+  - Original U-Net with padding, which keeps the output the same shape as the input.
+  - The first convolutional layer has 16 output channels.
+  - The activation function of all convolutional layers is ELU.
+  - No batch normalization layers.
+  - The input shape is (384, 512, 1).
+  - The output shape is (384, 512, 1), after a sigmoid activation.
+  - Optimizer: Adam, lr = 1e-4.
+  - Loss: Dice coefficient loss.
+  - Metrics: accuracy & Dice coefficient.
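+
+A minimal sketch of this split using scikit-learn, assuming lists of image and mask file paths read from the folder structure above (the project code may implement the split differently):
+
+```python
+import glob
+from sklearn.model_selection import train_test_split
+
+images = sorted(glob.glob('ISIC2018/ISIC2018_Task1-2_Training_Input_x2/*'))
+masks = sorted(glob.glob('ISIC2018/ISIC2018_Task1_Training_GroundTruth_x2/*'))
+
+# 4:1 train/test split, then a further 4:1 train/validation split on the training part
+x_train, x_test, y_train, y_test = train_test_split(images, masks, test_size=0.2, random_state=42)
+x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=0.2, random_state=42)
+# 2594 image/mask pairs -> roughly 1660 training, 415 validation, 519 test
+```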
+
+## Results
+
+The evaluation Dice coefficient is 0.8053.
+
+Plot of the training/validation Dice coefficient:
+
+
+
+Example cases:
+
+
+
+## References
+Manna, S. (2020). K-Fold Cross Validation for Deep Learning using Keras. [online] Medium. Available at: https://medium.com/the-owl/k-fold-cross-validation-in-keras-3ec4a3a00538 [Accessed 24 Nov. 2020].
+
+zhixuhao (2020). zhixuhao/unet. [online] GitHub. Available at: https://github.com/zhixuhao/unet.
+
+GitHub. (n.d.). NifTK/NiftyNet. [online] Available at: https://github.com/NifTK/NiftyNet/blob/a383ba342e3e38a7ad7eed7538bfb34960f80c8d/niftynet/layer/loss_segmentation.py [Accessed 24 Nov. 2020].
+
+Team, K. (n.d.). Keras documentation: Losses. [online] keras.io. Available at: https://keras.io/api/losses/#creating-custom-losses [Accessed 24 Nov. 2020].
+
+262588213843476 (n.d.). unet.py. [online] Gist. Available at: https://gist.github.com/abhinavsagar/fe0c900133cafe93194c069fe655ef6e [Accessed 24 Nov. 2020].
+
+Stack Overflow. (n.d.). python - Disable Tensorflow debugging information. [online] Available at: https://stackoverflow.com/questions/35911252/disable-tensorflow-debugging-information [Accessed 24 Nov. 2020].
diff --git a/recognition/SG_45762402/Dataset.py b/recognition/SG_45762402/Dataset.py
new file mode 100644
index 0000000000..7666d0c855
--- /dev/null
+++ b/recognition/SG_45762402/Dataset.py
@@ -0,0 +1,66 @@
+#!/usr/bin/env python
+# coding: utf-8
+
+
+
+#!pip install lmdb
+
+
+
+
+from io import BytesIO
+
+import lmdb
+from PIL import Image
+from torch.utils.data import Dataset
+
+
+
+
+'''
+LMDB-backed Dataset for training.
+
+path: path to the LMDB database (created by prepare_data.py)
+transform: torchvision transform applied to each image
+resolution: resolution (in pixels) of the images to load
+'''
+class MultiResolutionDataset(Dataset):
+ def __init__(self, path, transform, resolution=8):
+ self.env = lmdb.open(
+ path,
+ max_readers=32,
+ readonly=True,
+ lock=False,
+ readahead=False,
+ meminit=False,
+ )
+
+ if not self.env:
+ raise IOError('Cannot open lmdb dataset', path)
+
+ with self.env.begin(write=False) as txn:
+ self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))
+
+ self.resolution = resolution
+ self.transform = transform
+
+ def __len__(self):
+ return self.length
+
+ def __getitem__(self, index):
+ with self.env.begin(write=False) as txn:
+ key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8')
+ img_bytes = txn.get(key)
+
+ buffer = BytesIO(img_bytes)
+ img = Image.open(buffer)
+ img = self.transform(img)
+
+ return img
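+
+
+# Example usage (a sketch, assuming an LMDB database created with prepare_data.py):
+#
+#   from torchvision import transforms
+#   transform = transforms.Compose([transforms.ToTensor()])
+#   dataset = MultiResolutionDataset('LMDB_PATH', transform, resolution=64)
+#   img = dataset[0]  # a 64 x 64 image tensor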
+
+
+
+
+
+
diff --git a/recognition/SG_45762402/Images/ADAIN.png b/recognition/SG_45762402/Images/ADAIN.png
new file mode 100644
index 0000000000..56c69ccef3
Binary files /dev/null and b/recognition/SG_45762402/Images/ADAIN.png differ
diff --git a/recognition/SG_45762402/Images/Structure.png b/recognition/SG_45762402/Images/Structure.png
new file mode 100644
index 0000000000..95d0ab3c84
Binary files /dev/null and b/recognition/SG_45762402/Images/Structure.png differ
diff --git a/recognition/SG_45762402/Images/sample-iter20500.png b/recognition/SG_45762402/Images/sample-iter20500.png
new file mode 100644
index 0000000000..28a22d92b9
Binary files /dev/null and b/recognition/SG_45762402/Images/sample-iter20500.png differ
diff --git a/recognition/SG_45762402/Images/sample-iter32500.png b/recognition/SG_45762402/Images/sample-iter32500.png
new file mode 100644
index 0000000000..e14537bc2a
Binary files /dev/null and b/recognition/SG_45762402/Images/sample-iter32500.png differ
diff --git a/recognition/SG_45762402/Images/sample-iter5000.png b/recognition/SG_45762402/Images/sample-iter5000.png
new file mode 100644
index 0000000000..964a45863d
Binary files /dev/null and b/recognition/SG_45762402/Images/sample-iter5000.png differ
diff --git a/recognition/SG_45762402/Images/sample-iter55000.png b/recognition/SG_45762402/Images/sample-iter55000.png
new file mode 100644
index 0000000000..a8e5a1d26f
Binary files /dev/null and b/recognition/SG_45762402/Images/sample-iter55000.png differ
diff --git a/recognition/SG_45762402/Images/size_128.png b/recognition/SG_45762402/Images/size_128.png
new file mode 100644
index 0000000000..07ecc891ef
Binary files /dev/null and b/recognition/SG_45762402/Images/size_128.png differ
diff --git a/recognition/SG_45762402/Images/size_256.png b/recognition/SG_45762402/Images/size_256.png
new file mode 100644
index 0000000000..b33dc70ed3
Binary files /dev/null and b/recognition/SG_45762402/Images/size_256.png differ
diff --git a/recognition/SG_45762402/Images/size_64.png b/recognition/SG_45762402/Images/size_64.png
new file mode 100644
index 0000000000..51f7af6f8d
Binary files /dev/null and b/recognition/SG_45762402/Images/size_64.png differ
diff --git a/recognition/SG_45762402/Images/size_8.png b/recognition/SG_45762402/Images/size_8.png
new file mode 100644
index 0000000000..e98d855f52
Binary files /dev/null and b/recognition/SG_45762402/Images/size_8.png differ
diff --git a/recognition/SG_45762402/Model1.py b/recognition/SG_45762402/Model1.py
new file mode 100644
index 0000000000..6fad041a8e
--- /dev/null
+++ b/recognition/SG_45762402/Model1.py
@@ -0,0 +1,357 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Sat Oct 30 21:28:48 2021
+
+
+reference: https://github.com/rosinality/style-based-gan-pytorch.git
+"""
+
+import torch
+
+from torch import nn
+from torch.nn import init
+from torch.nn import functional as F
+from torch.autograd import Function
+
+from math import sqrt
+
+import random
+
+
+def init_linear(linear):
+    init.xavier_normal_(linear.weight)
+ linear.bias.data.zero_()
+
+
+def init_conv(conv, glu=True):
+    init.kaiming_normal_(conv.weight)
+ if conv.bias is not None:
+ conv.bias.data.zero_()
+
+'''
+The equalized learning rate involves:
+
+1. Initializing all weights (linear and conv) from a plain normal distribution, with no fancy init.
+2. Scaling all weights at runtime by the per-layer normalization constant from Kaiming He initialization.
+
+'''
+
+class EqualLR:
+ def __init__(self, name):
+ self.name = name
+
+ def compute_weight(self, module):
+ weight = getattr(module, self.name + '_orig')
+ fan_in = weight.data.size(1) * weight.data[0][0].numel()
+
+ return weight * sqrt(2 / fan_in)
+
+ @staticmethod
+ def apply(module, name):
+ fn = EqualLR(name)
+
+ weight = getattr(module, name)
+ del module._parameters[name]
+ module.register_parameter(name + '_orig', nn.Parameter(weight.data))
+ module.register_forward_pre_hook(fn)
+
+ return fn
+
+ def __call__(self, module, input):
+ weight = self.compute_weight(module)
+ setattr(module, self.name, weight)
+
+
+def equal_lr(module, name='weight'):
+ EqualLR.apply(module, name)
+
+ return module
+
+
+'''
+The upsampling is achieved by deconvolution:
+F.conv_transpose2d(input, weight, self.bias, stride=2, padding=self.pad)
+and the downsampling is achieved by ordinary convolution:
+out = F.conv2d(input, weight, self.bias, stride=2, padding=self.pad).
+
+'''
+
+
+class FusedUpsample(nn.Module):
+ def __init__(self, in_channel, out_channel, kernel_size, padding=0):
+ super().__init__()
+
+ weight = torch.randn(in_channel, out_channel, kernel_size, kernel_size)
+ bias = torch.zeros(out_channel)
+
+ fan_in = in_channel * kernel_size * kernel_size
+ self.multiplier = sqrt(2 / fan_in)
+
+ self.weight = nn.Parameter(weight)
+ self.bias = nn.Parameter(bias)
+
+ self.pad = padding
+
+ def forward(self, input):
+ weight = F.pad(self.weight * self.multiplier, [1, 1, 1, 1])
+ weight = (
+ weight[:, :, 1:, 1:]
+ + weight[:, :, :-1, 1:]
+ + weight[:, :, 1:, :-1]
+ + weight[:, :, :-1, :-1]
+ ) / 4
+
+ out = F.conv_transpose2d(input, weight, self.bias, stride=2, padding=self.pad)
+
+ return out
+
+
+class FusedDownsample(nn.Module):
+ def __init__(self, in_channel, out_channel, kernel_size, padding=0):
+ super().__init__()
+
+ weight = torch.randn(out_channel, in_channel, kernel_size, kernel_size)
+ bias = torch.zeros(out_channel)
+
+ fan_in = in_channel * kernel_size * kernel_size
+ self.multiplier = sqrt(2 / fan_in)
+
+ self.weight = nn.Parameter(weight)
+ self.bias = nn.Parameter(bias)
+
+ self.pad = padding
+
+ def forward(self, input):
+ weight = F.pad(self.weight * self.multiplier, [1, 1, 1, 1])
+ weight = (
+ weight[:, :, 1:, 1:]
+ + weight[:, :, :-1, 1:]
+ + weight[:, :, 1:, :-1]
+ + weight[:, :, :-1, :-1]
+ ) / 4
+
+ out = F.conv2d(input, weight, self.bias, stride=2, padding=self.pad)
+
+ return out
+
+'''
+class PixelNorm(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, input):
+ return input / torch.sqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
+'''
+
+class PixelNorm(nn.Module):
+ def __init__(self, epsilon=1e-8):
+ super().__init__()
+ self.epsilon = epsilon
+ def forward(self, x):
+ return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + self.epsilon)
+
+
+class BlurFunctionBackward(Function):
+ @staticmethod
+ def forward(ctx, grad_output, kernel, kernel_flip):
+ ctx.save_for_backward(kernel, kernel_flip)
+
+ grad_input = F.conv2d(
+ grad_output, kernel_flip, padding=1, groups=grad_output.shape[1]
+ )
+
+ return grad_input
+
+ @staticmethod
+ def backward(ctx, gradgrad_output):
+ kernel, kernel_flip = ctx.saved_tensors
+
+ grad_input = F.conv2d(
+ gradgrad_output, kernel, padding=1, groups=gradgrad_output.shape[1]
+ )
+
+ return grad_input, None, None
+
+
+class BlurFunction(Function):
+ @staticmethod
+ def forward(ctx, input, kernel, kernel_flip):
+ ctx.save_for_backward(kernel, kernel_flip)
+
+ output = F.conv2d(input, kernel, padding=1, groups=input.shape[1])
+
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ kernel, kernel_flip = ctx.saved_tensors
+
+ grad_input = BlurFunctionBackward.apply(grad_output, kernel, kernel_flip)
+
+ return grad_input, None, None
+
+
+blur = BlurFunction.apply
+
+
+class Blur(nn.Module):
+ def __init__(self, channel):
+ super().__init__()
+
+ weight = torch.tensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]], dtype=torch.float32)
+ weight = weight.view(1, 1, 3, 3)
+ weight = weight / weight.sum()
+ weight_flip = torch.flip(weight, [2, 3])
+
+ self.register_buffer('weight', weight.repeat(channel, 1, 1, 1))
+ self.register_buffer('weight_flip', weight_flip.repeat(channel, 1, 1, 1))
+
+ def forward(self, input):
+ return blur(input, self.weight, self.weight_flip)
+ # return F.conv2d(input, self.weight, padding=1, groups=input.shape[1])
+
+
+class EqualConv2d(nn.Module):
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+
+ conv = nn.Conv2d(*args, **kwargs)
+ conv.weight.data.normal_()
+ conv.bias.data.zero_()
+ self.conv = equal_lr(conv)
+
+ def forward(self, input):
+ return self.conv(input)
+
+
+class EqualLinear(nn.Module):
+ def __init__(self, in_dim, out_dim):
+ super().__init__()
+
+ linear = nn.Linear(in_dim, out_dim)
+ linear.weight.data.normal_()
+ linear.bias.data.zero_()
+
+ self.linear = equal_lr(linear)
+
+ def forward(self, input):
+ return self.linear(input)
+
+
+class ConvBlock(nn.Module):
+ def __init__(
+ self,
+ in_channel,
+ out_channel,
+ kernel_size,
+ padding,
+ kernel_size2=None,
+ padding2=None,
+ downsample=False,
+ fused=False,
+ ):
+ super().__init__()
+
+ pad1 = padding
+ pad2 = padding
+ if padding2 is not None:
+ pad2 = padding2
+
+ kernel1 = kernel_size
+ kernel2 = kernel_size
+ if kernel_size2 is not None:
+ kernel2 = kernel_size2
+
+ self.conv1 = nn.Sequential(
+ EqualConv2d(in_channel, out_channel, kernel1, padding=pad1),
+ nn.LeakyReLU(0.2),
+ )
+
+ if downsample:
+ if fused:
+ self.conv2 = nn.Sequential(
+ Blur(out_channel),
+ FusedDownsample(out_channel, out_channel, kernel2, padding=pad2),
+ nn.LeakyReLU(0.2),
+ )
+
+ else:
+ self.conv2 = nn.Sequential(
+ Blur(out_channel),
+ EqualConv2d(out_channel, out_channel, kernel2, padding=pad2),
+ nn.AvgPool2d(2),
+ nn.LeakyReLU(0.2),
+ )
+
+ else:
+ self.conv2 = nn.Sequential(
+ EqualConv2d(out_channel, out_channel, kernel2, padding=pad2),
+ nn.LeakyReLU(0.2),
+ )
+
+ def forward(self, input):
+ out = self.conv1(input)
+ out = self.conv2(out)
+
+ return out
+
+# AdaIn (AdaptiveInstanceNorm)
+class AdaptiveInstanceNorm(nn.Module):
+ def __init__(self, in_channel, style_dim):
+ super().__init__()
+
+ self.norm = nn.InstanceNorm2d(in_channel)
+ self.style = EqualLinear(style_dim, in_channel * 2)
+
+ self.style.linear.bias.data[:in_channel] = 1
+ self.style.linear.bias.data[in_channel:] = 0
+
+ def forward(self, input, style):
+ style = self.style(style).unsqueeze(2).unsqueeze(3)
+ gamma, beta = style.chunk(2, 1)
+
+ out = self.norm(input)
+ out = gamma * out + beta
+
+ return out
+
+
+class NoiseInjection(nn.Module):
+ def __init__(self, channel):
+ super().__init__()
+
+ self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1))
+
+ def forward(self, image, noise):
+ return image + self.weight * noise
+
+
+'''
+class NoiseInjection(nn.Module):
+ #adds noise. noise is per pixel (constant over channels) with per-channel weight
+ def __init__(self, channel):
+ super().__init__()
+
+ self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1))
+
+ def forward(self, image, noise=None):
+ # here is a little trick: if you get all the noiselayers and set each
+ # modules .noise attribute, you can have pre-defined noise.
+ # Very useful for analysis
+ if noise is None and self.noise is None:
+ noise = torch.randn(image.size(0), 1, image.size(2), image.size(3), device=image.device, dtype=image.dtype)
+ elif noise is None:
+ return image + self.weight * noise
+'''
+
+class ConstantInput(nn.Module):
+ def __init__(self, channel, size=4):
+ super().__init__()
+
+ self.input = nn.Parameter(torch.randn(1, channel, size, size))
+
+ def forward(self, input):
+ batch = input.shape[0]
+ out = self.input.repeat(batch, 1, 1, 1)
+
+ return out
diff --git a/recognition/SG_45762402/Model2.py b/recognition/SG_45762402/Model2.py
new file mode 100644
index 0000000000..056b389ed9
--- /dev/null
+++ b/recognition/SG_45762402/Model2.py
@@ -0,0 +1,288 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Sun Oct 31 01:04:17 2021
+
+"""
+from Model1 import *
+
+'''
+Each StyledConvBlock consists of:
+
+The first convolutional layer (responsible for upsampling when required);
+The first noise injection layer, which injects random factors;
+A LeakyReLU activation;
+The first adaptive instance normalization (AdaIN) layer, which injects the style;
+The second convolutional layer;
+The second noise injection layer;
+A LeakyReLU activation;
+The second adaptive instance normalization (AdaIN) layer.
+'''
+class StyledConvBlock(nn.Module):
+ def __init__(
+ self,
+ in_channel,
+ out_channel,
+ kernel_size=3,
+ padding=1,
+ style_dim=512,
+ initial=False,
+ upsample=False,
+ fused=False,
+ ):
+ super().__init__()
+
+ if initial:
+ self.conv1 = ConstantInput(in_channel)
+
+ else:
+ if upsample:
+ if fused:
+ self.conv1 = nn.Sequential(
+ FusedUpsample(
+ in_channel, out_channel, kernel_size, padding=padding
+ ),
+ Blur(out_channel),
+ )
+
+ else:
+ self.conv1 = nn.Sequential(
+ nn.Upsample(scale_factor=2, mode='nearest'),
+ EqualConv2d(
+ in_channel, out_channel, kernel_size, padding=padding
+ ),
+ Blur(out_channel),
+ )
+
+ else:
+ self.conv1 = EqualConv2d(
+ in_channel, out_channel, kernel_size, padding=padding
+ )
+
+ self.noise1 = equal_lr(NoiseInjection(out_channel))
+ self.adain1 = AdaptiveInstanceNorm(out_channel, style_dim)
+ self.lrelu1 = nn.LeakyReLU(0.2)
+
+ self.conv2 = EqualConv2d(out_channel, out_channel, kernel_size, padding=padding)
+ self.noise2 = equal_lr(NoiseInjection(out_channel))
+ self.adain2 = AdaptiveInstanceNorm(out_channel, style_dim)
+ self.lrelu2 = nn.LeakyReLU(0.2)
+
+ def forward(self, input, style, noise):
+ out = self.conv1(input)
+ out = self.noise1(out, noise)
+ out = self.lrelu1(out)
+ out = self.adain1(out, style)
+
+ out = self.conv2(out)
+ out = self.noise2(out, noise)
+ out = self.lrelu2(out)
+ out = self.adain2(out, style)
+
+ return out
+
+
+class Generator(nn.Module):
+ def __init__(self, code_dim, fused=True):
+ super().__init__()
+
+ self.progression = nn.ModuleList(
+ [
+ StyledConvBlock(512, 512, 3, 1, initial=True), # 4
+ StyledConvBlock(512, 512, 3, 1, upsample=True), # 8
+ StyledConvBlock(512, 512, 3, 1, upsample=True), # 16
+ StyledConvBlock(512, 512, 3, 1, upsample=True), # 32
+ StyledConvBlock(512, 256, 3, 1, upsample=True), # 64
+ StyledConvBlock(256, 128, 3, 1, upsample=True, fused=fused), # 128
+ StyledConvBlock(128, 64, 3, 1, upsample=True, fused=fused), # 256
+ StyledConvBlock(64, 32, 3, 1, upsample=True, fused=fused), # 512
+ StyledConvBlock(32, 16, 3, 1, upsample=True, fused=fused), # 1024
+ ]
+ )
+
+ self.to_rgb = nn.ModuleList(
+ [
+ EqualConv2d(512, 3, 1),
+ EqualConv2d(512, 3, 1),
+ EqualConv2d(512, 3, 1),
+ EqualConv2d(512, 3, 1),
+ EqualConv2d(256, 3, 1),
+ EqualConv2d(128, 3, 1),
+ EqualConv2d(64, 3, 1),
+ EqualConv2d(32, 3, 1),
+ EqualConv2d(16, 3, 1),
+ ]
+ )
+
+ # self.blur = Blur()
+
+ def forward(self, style, noise, step=0, alpha=-1, mixing_range=(-1, -1)):
+ out = noise[0]
+
+ if len(style) < 2:
+ inject_index = [len(self.progression) + 1]
+
+ else:
+ inject_index = sorted(random.sample(list(range(step)), len(style) - 1))
+
+ crossover = 0
+
+ for i, (conv, to_rgb) in enumerate(zip(self.progression, self.to_rgb)):
+ if mixing_range == (-1, -1):
+ if crossover < len(inject_index) and i > inject_index[crossover]:
+ crossover = min(crossover + 1, len(style))
+
+ style_step = style[crossover]
+
+ else:
+ if mixing_range[0] <= i <= mixing_range[1]:
+ style_step = style[1]
+
+ else:
+ style_step = style[0]
+
+ if i > 0 and step > 0:
+ out_prev = out
+
+ out = conv(out, style_step, noise[i])
+
+ if i == step:
+ out = to_rgb(out)
+
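+                # during progressive growing, blend the upsampled previous-resolution RGB output
+                # with the current resolution's output, controlled by alpha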
+ if i > 0 and 0 <= alpha < 1:
+ skip_rgb = self.to_rgb[i - 1](out_prev)
+ skip_rgb = F.interpolate(skip_rgb, scale_factor=2, mode='nearest')
+ out = (1 - alpha) * skip_rgb + alpha * out
+
+ break
+
+ return out
+
+
+class StyledGenerator(nn.Module):
+ def __init__(self, code_dim=512, n_mlp=8):
+ super().__init__()
+
+ self.generator = Generator(code_dim)
+
+ layers = [PixelNorm()]
+ for i in range(n_mlp):
+ layers.append(EqualLinear(code_dim, code_dim))
+ layers.append(nn.LeakyReLU(0.2))
+
+ self.style = nn.Sequential(*layers)
+
+ def forward(
+ self,
+ input,
+        noise=None,       # TODO: support input noise
+        step=0,           # step: how many layers (counting from 4 x 4) are used in training
+        alpha=-1,         # alpha: parameter for the smooth transition between resolutions
+        mean_style=None,  # TODO: support mean_style
+ style_weight=0,
+ mixing_range=(-1, -1),
+ ):
+ styles = []
+ if type(input) not in (list, tuple):
+ input = [input]
+
+ for i in input:
+ styles.append(self.style(i))
+
+ batch = input[0].shape[0]
+
+ if noise is None:
+ noise = []
+
+ for i in range(step + 1):
+ size = 4 * 2 ** i
+ noise.append(torch.randn(batch, 1, size, size, device=input[0].device))
+
+ if mean_style is not None:
+ styles_norm = []
+
+ for style in styles:
+ styles_norm.append(mean_style + style_weight * (style - mean_style))
+
+ styles = styles_norm
+
+ return self.generator(styles, noise, step, alpha, mixing_range=mixing_range)
+
+ def mean_style(self, input):
+ style = self.style(input).mean(0, keepdim=True)
+
+ return style
+
+
+class Discriminator(nn.Module):
+ def __init__(self, fused=True, from_rgb_activate=False):
+ super().__init__()
+
+ self.progression = nn.ModuleList(
+ [
+ ConvBlock(16, 32, 3, 1, downsample=True, fused=fused), # 512
+ ConvBlock(32, 64, 3, 1, downsample=True, fused=fused), # 256
+ ConvBlock(64, 128, 3, 1, downsample=True, fused=fused), # 128
+ ConvBlock(128, 256, 3, 1, downsample=True, fused=fused), # 64
+ ConvBlock(256, 512, 3, 1, downsample=True), # 32
+ ConvBlock(512, 512, 3, 1, downsample=True), # 16
+ ConvBlock(512, 512, 3, 1, downsample=True), # 8
+ ConvBlock(512, 512, 3, 1, downsample=True), # 4
+ ConvBlock(513, 512, 3, 1, 4, 0),
+ ]
+ )
+
+ def make_from_rgb(out_channel):
+ if from_rgb_activate:
+ return nn.Sequential(EqualConv2d(3, out_channel, 1), nn.LeakyReLU(0.2))
+
+ else:
+ return EqualConv2d(3, out_channel, 1)
+
+ self.from_rgb = nn.ModuleList(
+ [
+ make_from_rgb(16),
+ make_from_rgb(32),
+ make_from_rgb(64),
+ make_from_rgb(128),
+ make_from_rgb(256),
+ make_from_rgb(512),
+ make_from_rgb(512),
+ make_from_rgb(512),
+ make_from_rgb(512),
+ ]
+ )
+
+ # self.blur = Blur()
+
+ self.n_layer = len(self.progression)
+
+ self.linear = EqualLinear(512, 1)
+
+ def forward(self, input, step=0, alpha=-1):
+ for i in range(step, -1, -1):
+ index = self.n_layer - i - 1
+
+ if i == step:
+ out = self.from_rgb[index](input)
+
+ if i == 0:
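+                # minibatch standard deviation: append the mean feature std as an extra channel
+                # (hence the 513 input channels of the final ConvBlock)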
+ out_std = torch.sqrt(out.var(0, unbiased=False) + 1e-8)
+ mean_std = out_std.mean()
+ mean_std = mean_std.expand(out.size(0), 1, 4, 4)
+ out = torch.cat([out, mean_std], 1)
+
+ out = self.progression[index](out)
+
+ if i > 0:
+ if i == step and 0 <= alpha < 1:
+ skip_rgb = F.avg_pool2d(input, 2)
+ skip_rgb = self.from_rgb[index + 1](skip_rgb)
+
+ out = (1 - alpha) * skip_rgb + alpha * out
+
+ out = out.squeeze(2).squeeze(2)
+ # print(input.size(), out.size(), step)
+ out = self.linear(out)
+
+ return out
+
+
diff --git a/recognition/SG_45762402/README.md b/recognition/SG_45762402/README.md
new file mode 100644
index 0000000000..bdd5ff2cd1
--- /dev/null
+++ b/recognition/SG_45762402/README.md
@@ -0,0 +1,126 @@
+# StyleGan_Oasis
+
+This project is an unofficial implementation of StyleGAN trained on the OASIS brain dataset.
+
+Reference: https://github.com/rosinality/style-based-gan-pytorch.git , https://github.com/SiskonEmilia/StyleGAN-PyTorch.git
+
+Original Paper: https://arxiv.org/abs/1812.04948
+
+## Introduction
+
+"Style" in StyleGAN here refers to the main attributes of the brain region in the data set, where Style refers to the style of the brain structure, such as brain gray matter, brain lesions, etc.
+
+StyleGAN uses **style** to affect rough information such as brain size and shape, and uses **noise** to affect the details of the cross-section of the brain.
+
+## Dataset
+
+The OASIS datasets hosted by central.xnat.org provide the community with open access to a significant database of neuroimaging and processed imaging data across a broad demographic, cognitive, and genetic spectrum, together with an easily accessible platform for use in neuroimaging, clinical, and cognitive research on normal aging and cognitive decline. All data is available via [www.oasis-brains.org](https://www.oasis-brains.org/).
+
+## Requirements
+
+- Python 3
+- PyTorch >= 1.0.0
+- torchvision
+- matplotlib
+- lmdb
+- tqdm
+
+## Usage
+
+- ### Prepare the data
+
+ ```
+ !python prepare_data.py --out LMDB_PATH --n_worker N DATAPATH
+ ```
+
+  This step generates an LMDB database for training.
+
+- ### Training StyleGAN (driver script)
+
+  To start training, set `Path=LMDB_PATH` in `train_stylegan.py`.
+
+ ```
+ !python train_stylegan.py
+ ```
+
+- ### Generate Pictures
+
+  Run `generate_mixing.py`; note that the default parameters may need to be changed:
+
+ `size=64`
+ `n_row=1`
+ `n_col=5`
+ `path=model.path`
+ `mixing=True`
+ `num_mixing=20`
+
+ ```
+ !python generate_mixing.py
+ ```
+
+- ### Train from a pre-trained model
+
+  To continue training from a previously saved model, change the default `ckpt=pre-trained.model` in `train_stylegan.py`.
+
+- ### Examples During the Training Process
+
+
+
+ 


+
+
+
+- ### Some Generated Samples -- Style Mixing
+
+ #### 64*64 images
+
+
+
+ #### 128*128 images
+
+
+
+
+
+ #### 256*256 images
+
+
+
+
+
+## Model Structure
+
+### Overview
+
+
+
+**Mapping network** --- **latent code**
+
+The mapping network transforms the latent code **z** into **w**. In a standard GAN, z is a random vector drawn from a uniform or Gaussian distribution; here the mapping network, which consists of 8 fully connected layers, maps z to w. Through learned affine transformations, w is then converted into a style Y = (Ys, Yb), which is applied using the AdaIN (adaptive instance normalization) operation:
+
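+The AdaIN operation, as defined in the original StyleGAN paper, is:
+
+```math
+\mathrm{AdaIN}(x_i, y) = y_{s,i}\,\frac{x_i - \mu(x_i)}{\sigma(x_i)} + y_{b,i}
+```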
+
+
+**Style-mixing**
+
+The style-mixing part is the biggest difference compared to a standard GAN. The specific method is to feed two different latent codes ***z1*** and ***z2*** into the mapping network to obtain ***w1*** and ***w2***, which represent two different styles, and then randomly select a crossover point in the synthesis network. ***w1*** is used for the layers before the crossover point and ***w2*** for the layers after it. The generated image then combines features from both source A and source B.
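+
+A minimal sketch of how style mixing can be invoked through the `mixing_range` argument of `StyledGenerator` (assuming a trained `generator`; `step=4` corresponds to 64x64 output, as in `generate_mixing.py`):
+
+```python
+import torch
+
+device = next(generator.parameters()).device
+z1 = torch.randn(1, 512, device=device)  # first latent code  -> w1
+z2 = torch.randn(1, 512, device=device)  # second latent code -> w2
+# layers whose index falls inside mixing_range use w2; all other layers use w1
+mixed = generator([z1, z2], step=4, alpha=1, mixing_range=(0, 1))
+```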
+
+### Parameters
+
+| Parameter | Description |
+| :-------------: | :----------------------------------------------------------: |
+| n_gpu | number of GPUs used to train the model |
+| device | default device to create and save tensors |
+| learning_rate | a dict to indicate learning rate at different stage of training |
+| batch_size | a dict to indicate batch size at different stage of training |
+| mini_batch_size | minimal batch size |
+| n_fc | number of layers in the full-connected mapping network |
+| dim_latent | dimension of latent space |
+| dim_input | size of the first layer of generator |
+| n_sample | how many samples will be used to train a single layer |
+| step | which layer to start training |
+| ckpt | checkpoint model file |
+|      Path       |               path to the LMDB data directory               |
+|    max_step     |      maximum resolution of images is 2 ^ (max_step + 2)     |
+|      loss       |      loss type; options: wgan-gp, r1 (default: wgan-gp)     |
+|     mixing      |      whether to use mixing regularization (default: True)   |
+| n_critic | How many times the discriminator is updated every time the generator is updated |
+
diff --git a/recognition/SG_45762402/generate_mixing.py b/recognition/SG_45762402/generate_mixing.py
new file mode 100644
index 0000000000..1e2f6f927f
--- /dev/null
+++ b/recognition/SG_45762402/generate_mixing.py
@@ -0,0 +1,152 @@
+# -*- coding: utf-8 -*-
+"""Generate_mixing.ipynb
+
+Automatically generated by Colaboratory.
+
+Original file is located at
+ https://colab.research.google.com/drive/1fLvHVr-RXe6X3aYCR7XIAgth-c54qP2_
+
+StyleGan_pytorch_generator
+"""
+
+#!cd /content/drive/MyDrive/Stylegan_shang/Model2.py
+
+from Model2 import StyledGenerator
+
+#import os
+
+#os.chdir('/content/drive/MyDrive/Stylegan_shang')  # change the current working directory
+
+import argparse
+import math
+
+import torch
+from torchvision import utils
+
+device= torch.device('cuda:0')
+#generator = StyledGenerator().to(device)
+
+def sample(generator, step, mean_style, n_sample):
+ image = generator(
+ torch.randn(n_sample, 512).to(device),
+ step=step,
+ alpha=1,
+ mean_style=mean_style,
+ style_weight=0.7,
+ )
+
+ return image
+
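+# Estimate the mean style vector w by averaging generator.mean_style over 10 batches of random latents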
+def get_mean_style(generator, device):
+ mean_style = None
+
+ for i in range(10):
+ style = generator.mean_style(torch.randn(1024, 512).to(device))
+
+ if mean_style is None:
+ mean_style = style
+
+ else:
+ mean_style += style
+
+ mean_style /= 10
+ return mean_style
+
+def style_mixing(generator, step, mean_style, n_source, n_target, device):
+ source_code = torch.randn(n_source, 512).to(device)
+ target_code = torch.randn(n_target, 512).to(device)
+
+ shape = 4 * 2 ** step
+ alpha = 1
+
+ images = [torch.ones(1, 3, shape, shape).to(device) * -1]
+
+ source_image = generator(
+ source_code, step=step, alpha=alpha, mean_style=mean_style, style_weight=0.7
+ )
+ target_image = generator(
+ target_code, step=step, alpha=alpha, mean_style=mean_style, style_weight=0.7
+ )
+
+ images.append(source_image)
+
+ for i in range(n_target):
+ image = generator(
+ [target_code[i].unsqueeze(0).repeat(n_source, 1), source_code],
+ step=step,
+ alpha=alpha,
+ mean_style=mean_style,
+ style_weight=0.7,
+ mixing_range=(0, 1),
+ )
+ images.append(target_image[i].unsqueeze(0))
+ images.append(image)
+
+ images = torch.cat(images, 0)
+
+ return images
+
+'''
+Note that:
+train_step-2: generates size=8
+train_step-3: generates size=16
+train_step-4: generates size=32
+train_step-5: generates size=64
+and so on
+'''
+size=64
+n_row=1
+n_col=5
+path='/content/drive/MyDrive/StyleGAN_rosin/checkpoint/train_step-5.model'
+mixing=True
+num_mixing=20
+
+
+
+from train_stylegan import imshow
+
+if __name__ == '__main__':
+
+
+ #512
+ generator = StyledGenerator().to(device)
+ generator.load_state_dict(torch.load(path)['g_running'])
+
+ #generator.eval()
+ generator.train()
+
+ mean_style = get_mean_style(generator, device)
+
+ step = int(math.log(size, 2)) - 2
+
+ img = sample(generator, step, mean_style, n_row * n_col)
+ utils.save_image(img, 'g_sample.png', nrow=n_col, normalize=True, range=(-1, 1))
+
+
+ if mixing==True:
+ for j in range(num_mixing):
+ img = style_mixing(generator, step, mean_style, n_col, n_row, device)
+ #imshow(img,j)
+ utils.save_image(
+ img, f'sample_mixing_{j}.png', nrow=n_col + 1, normalize=True, range=(-1, 1)
+ )
+
+
+
+
+
+
+
+
+
+
+
+ '''
+
+ for j in range(20):
+ img = style_mixing(generator, step, mean_style, args.n_col, args.n_row, device)
+ utils.save_image(
+ img, f'sample_mixing_{j}.png', nrow=args.n_col + 1, normalize=True, range=(-1, 1)
+ )
+
+ '''
\ No newline at end of file
diff --git a/recognition/SG_45762402/prepare_data.py b/recognition/SG_45762402/prepare_data.py
new file mode 100644
index 0000000000..6e617452b0
--- /dev/null
+++ b/recognition/SG_45762402/prepare_data.py
@@ -0,0 +1,76 @@
+# -*- coding: utf-8 -*-
+
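+"""
+Prepare a multi-resolution LMDB database from an image folder.
+
+Each image is resized and centre-cropped to every size in (8, ..., 1024),
+encoded as JPEG bytes and stored under the key '{size}-{index:05d}';
+the total number of images is stored under the key 'length'.
+"""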
+
+
+import argparse
+from io import BytesIO
+import multiprocessing
+from functools import partial
+
+from PIL import Image
+import lmdb
+from tqdm import tqdm
+from torchvision import datasets
+from torchvision.transforms import functional as trans_fn
+
+
+def resize_and_convert(img, size, quality=100):
+ img = trans_fn.resize(img, size, Image.LANCZOS)
+ img = trans_fn.center_crop(img, size)
+ buffer = BytesIO()
+ img.save(buffer, format='jpeg', quality=quality)
+ val = buffer.getvalue()
+
+ return val
+
+
+def resize_multiple(img, sizes=(8, 16, 32, 64, 128, 256, 512, 1024), quality=100):
+ imgs = []
+
+ for size in sizes:
+ imgs.append(resize_and_convert(img, size, quality))
+
+ return imgs
+
+
+def resize_worker(img_file, sizes):
+ i, file = img_file
+ img = Image.open(file)
+ img = img.convert('RGB')
+ out = resize_multiple(img, sizes=sizes)
+
+ return i, out
+
+
+def prepare(transaction, dataset, n_worker, sizes=(8, 16, 32, 64, 128, 256, 512, 1024)):
+ resize_fn = partial(resize_worker, sizes=sizes)
+
+ files = sorted(dataset.imgs, key=lambda x: x[0])
+ files = [(i, file) for i, (file, label) in enumerate(files)]
+ total = 0
+
+ with multiprocessing.Pool(n_worker) as pool:
+ for i, imgs in tqdm(pool.imap_unordered(resize_fn, files)):
+ for size, img in zip(sizes, imgs):
+ key = f'{size}-{str(i).zfill(5)}'.encode('utf-8')
+ transaction.put(key, img)
+
+ total += 1
+
+ transaction.put('length'.encode('utf-8'), str(total).encode('utf-8'))
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--out', type=str)
+ parser.add_argument('--n_worker', type=int, default=8)
+ parser.add_argument('path', type=str)
+
+ args = parser.parse_args()
+
+ imgset = datasets.ImageFolder(args.path)
+
+ with lmdb.open(args.out, map_size=1024 ** 4, readahead=False) as env:
+ with env.begin(write=True) as txn:
+ prepare(txn, imgset, args.n_worker)
+
diff --git a/recognition/SG_45762402/sample/sample b/recognition/SG_45762402/sample/sample
new file mode 100644
index 0000000000..fd452327cc
--- /dev/null
+++ b/recognition/SG_45762402/sample/sample
@@ -0,0 +1 @@
+#sample
diff --git a/recognition/SG_45762402/train_stylegan.py b/recognition/SG_45762402/train_stylegan.py
new file mode 100644
index 0000000000..0b21a60135
--- /dev/null
+++ b/recognition/SG_45762402/train_stylegan.py
@@ -0,0 +1,421 @@
+# -*- coding: utf-8 -*-
+"""Stylegan_rosinality
+
+Automatically generated by Colaboratory.
+
+Original file is located at
+ https://colab.research.google.com/drive/1z2Vt8eaB3pyMR9xx6spcpFYFaCk1XLdz
+"""
+
+
+
+"""Test"""
+
+#!python train.py --ckpt /content/drive/MyDrive/StyleGAN_rosin/checkpoint/train_step-4.model --mixing /content/drive/MyDrive/StyleGAN_rosin/LMDB_PATH
+
+
+#os.chdir('/content/drive/MyDrive/Stylegan_shang')  # change the current working directory
+
+#!python train_SG.py --mixing /content/drive/MyDrive/StyleGAN_rosin/LMDB_PATH
+
+import argparse
+import random
+import math
+
+from tqdm import tqdm
+import numpy as np
+from PIL import Image
+
+import torch
+from torch import nn, optim
+from torch.nn import functional as F
+from torch.autograd import Variable, grad
+from torch.utils.data import DataLoader
+from torchvision import datasets, transforms, utils
+
+from Dataset import MultiResolutionDataset
+from Model2 import StyledGenerator, Discriminator
+import matplotlib.pyplot as plt
+
+# use an idle GPU
+# it's better to use an environment variable
+# if using multiple GPUs, please
+# modify the hyperparameters at the same time
+# and make sure your PyTorch version >= 1.0.1
+import os
+os.environ['CUDA_VISIBLE_DEVICES']='0'
+n_gpu = 1
+device = torch.device('cuda:0')
+Path='/content/drive/MyDrive/StyleGAN_rosin/LMDB_PATH'
+ckpt=None
+
+
+
+
+#learning_rate = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
+batch_size_1gpu = {4: 128, 8: 128, 16: 64, 32: 32, 64: 16, 128: 16}
+mini_batch_size_1 = 8
+#batch_size = {4: 256, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16}
+mini_batch_size = 8
+batch_size_4gpus = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32}
+mini_batch_size_4 = 16
+batch_size_8gpus = {4: 512, 8: 256, 16: 128, 32: 64}
+mini_batch_size_8 = 32
+n_fc = 8
+dim_latent = 512
+dim_input = 4
+n_sample = 120000  # number of samples used for each training phase
+DGR = 1
+n_show_loss = 500
+step = 1 # Train from (8 * 8)
+max_step = 8 # Maximum step (8 for 1024^2)
+#style_mixing = [] # Waiting to implement
+image_folder_path = '/content/drive/MyDrive/Dataset_brain/keras_png_slices_data'
+save_folder_path = '/content/drive/MyDrive/Stylegan_shang/results'
+
+low_steps = [0, 1, 2]
+# style_mixing += low_steps
+mid_steps = [3, 4, 5]
+# style_mixing += mid_steps
+hig_steps = [6, 7, 8]
+# style_mixing += hig_steps
+
+# Used to continue training from last checkpoint
+startpoint = 0
+used_sample = 0
+alpha = 0
+
+# Mode: Evaluate? Train?
+is_train = True
+
+# How to start training?
+# True for start from saved model
+# False for retrain from the very beginning
+is_continue = True
+d_losses = [float('inf')]
+g_losses = [float('inf')]
+inputs, outputs = [], []
+
+def set_grad_flag(module, flag=True):
+ for p in module.parameters():
+ p.requires_grad = flag
+
+def reset_LR(optimizer, lr):
+ for pam_group in optimizer.param_groups:
+ mul = pam_group.get('mul', 1)
+ pam_group['lr'] = lr * mul
+
+
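+# Keep an exponential moving average of the generator weights (used to update g_running)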
+def accumulate(model1, model2, decay=0.999):
+ par1 = dict(model1.named_parameters())
+ par2 = dict(model2.named_parameters())
+
+ for k in par1.keys():
+        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)
+
+
+# Gain sample
+def gain_sample(dataset, batch_size, image_size=4):
+ dataset.resolution = image_size
+ loader = DataLoader(dataset, shuffle=True, batch_size=batch_size, num_workers=1, drop_last=True)
+
+ return loader
+
+def imshow(tensor, i):
+ grid = tensor[0]
+ grid.clamp_(-1, 1).add_(1).div_(2)
+ # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
+ ndarr = grid.mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
+ img = Image.fromarray(ndarr)
+ #img.save(f'{save_folder_path}sample-iter{i}.png')
+ plt.imshow(img)
+ plt.show()
+
+init_size=8 #Initial image size,default=8
+batch_default=32
+max_size=1024 #Max image size,default=1024
+ckpt=None
+loss='wgan-gp' #options:wgan-gp,r1
+#gen_sample = {512: (8, 4), 1024: (4, 2)}
+mixing=True
+no_from_rgb_activate=True
+n_critic=1
+
+#learning_rate.get(8,0.001)
+
+def train(dataset, generator, discriminator,loss):
+ step = int(math.log2(init_size)) - 2
+
+ resolution = 4 * 2 ** step
+ loader = gain_sample(
+ dataset, batch_size.get(resolution, batch_default), resolution
+ )
+ data_loader = iter(loader)
+
+ reset_LR(g_optimizer, learning_rate.get(resolution, 0.001))
+ reset_LR(d_optimizer, learning_rate.get(resolution, 0.001))
+
+ #Epoch=1,000,000
+
+ #pbar = tqdm(range(1000000))
+ pbar = tqdm(range(startpoint + 1, n_sample * 5))
+
+ set_grad_flag(generator, False)
+ set_grad_flag(discriminator, True)
+
+ #Initializing
+ disc_loss_val = 0
+ gen_loss_val = 0
+ grad_loss_val = 0
+
+ alpha = 0
+ used_sample = 0
+
+ max_step = int(math.log2(max_size)) - 2
+ final_progress = False
+
+ for i in pbar:
+ discriminator.zero_grad()
+
+ #alpha = min(1, 1 / n_sample * (used_sample + 1))
+ alpha = min(1, alpha + batch_size.get(resolution, mini_batch_size) / (n_sample * 2))
+
+ if (resolution == init_size and ckpt is None) or final_progress:
+ alpha = 1
+
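+        # progressive growing: move to the next resolution once enough samples have been seen at the current one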
+ if used_sample > n_sample * 2:
+ used_sample = 0
+ step += 1
+
+ if step > max_step:
+ step = max_step
+ final_progress = True
+ ckpt_step = step + 1
+
+ else:
+ alpha = 0
+ ckpt_step = step
+
+ resolution = 4 * 2 ** step
+
+ loader = gain_sample(
+ dataset, batch_size.get(resolution, batch_default), resolution
+ )
+ data_loader = iter(loader)
+
+
+ torch.save(
+ {
+ 'generator': generator.module.state_dict(),
+ 'discriminator': discriminator.module.state_dict(),
+ 'g_optimizer': g_optimizer.state_dict(),
+ 'd_optimizer': d_optimizer.state_dict(),
+ 'g_running': g_running.state_dict(),
+ },
+ f'checkpoint/train_step-{ckpt_step}.pth',
+ )
+
+ reset_LR(g_optimizer, learning_rate.get(resolution, 0.001))
+ reset_LR(d_optimizer, learning_rate.get(resolution, 0.001))
+
+ try:
+ real_image = next(data_loader)
+
+ except (OSError, StopIteration):
+ data_loader = iter(loader)
+ real_image = next(data_loader)
+
+ used_sample += real_image.shape[0]
+
+ b_size = real_image.size(0)
+ real_image = real_image.cuda()
+
+ #Loss function of discriminator
+ if loss == 'wgan-gp':
+ real_predict = discriminator(real_image, step=step, alpha=alpha)
+ real_predict = real_predict.mean() - 0.001 * (real_predict ** 2).mean()
+ (-real_predict).backward()
+
+ elif loss == 'r1':
+ real_image.requires_grad= True
+ real_scores = discriminator(real_image, step=step, alpha=alpha)
+ real_predict = F.softplus(-real_scores).mean()
+ real_predict.backward(retain_graph=True)
+
+ grad_real = grad(
+ outputs=real_scores.sum(), inputs=real_image, create_graph=True
+ )[0]
+ grad_penalty = (
+ grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2
+ ).mean()
+ grad_penalty = 10 / 2 * grad_penalty
+ grad_penalty.backward()
+ if i%10 == 0:
+ grad_loss_val = grad_penalty.item()
+
+ if mixing==True and random.random() < 0.9:
+ gen_in11, gen_in12, gen_in21, gen_in22 = torch.randn(
+ 4, b_size, dim_latent, device='cuda'
+ ).chunk(4, 0)
+ gen_in1 = [gen_in11.squeeze(0), gen_in12.squeeze(0)]
+ gen_in2 = [gen_in21.squeeze(0), gen_in22.squeeze(0)]
+
+ else:
+ gen_in1, gen_in2 = torch.randn(2, b_size, dim_latent, device='cuda').chunk(
+ 2, 0
+ )
+ gen_in1 = gen_in1.squeeze(0)
+ gen_in2 = gen_in2.squeeze(0)
+
+ fake_image = generator(gen_in1, step=step, alpha=alpha)
+ fake_predict = discriminator(fake_image, step=step, alpha=alpha)
+
+ if loss == 'wgan-gp':
+ fake_predict = fake_predict.mean()
+ fake_predict.backward()
+
+ eps = torch.rand(b_size, 1, 1, 1).cuda()
+ x_hat = eps * real_image.data + (1 - eps) * fake_image.data
+ x_hat.requires_grad= True
+ hat_predict = discriminator(x_hat, step=step, alpha=alpha)
+ grad_x_hat = grad(
+ outputs=hat_predict.sum(), inputs=x_hat, create_graph=True
+ )[0]
+ grad_penalty = (
+ (grad_x_hat.view(grad_x_hat.size(0), -1).norm(2, dim=1) - 1) ** 2
+ ).mean()
+ grad_penalty = 10 * grad_penalty
+ grad_penalty.backward()
+ if i%10 == 0:
+ grad_loss_val = grad_penalty.item()
+ disc_loss_val = (-real_predict + fake_predict).item()
+
+ elif loss == 'r1':
+ fake_predict = F.softplus(fake_predict).mean()
+ fake_predict.backward()
+ if i%10 == 0:
+ disc_loss_val = (real_predict + fake_predict).item()
+
+ d_optimizer.step()
+
+ #Loss function of generator
+ if (i + 1) % n_critic == 0:
+ generator.zero_grad()
+
+ set_grad_flag(generator, True)
+ set_grad_flag(discriminator, False)
+
+ fake_image = generator(gen_in2, step=step, alpha=alpha)
+
+ predict = discriminator(fake_image, step=step, alpha=alpha)
+
+            # use a separate variable for the generator loss so the `loss` mode string is not overwritten
+            if loss == 'wgan-gp':
+                g_loss = -predict.mean()
+
+            elif loss == 'r1':
+                g_loss = F.softplus(-predict).mean()
+
+            if i%10 == 0:
+                gen_loss_val = g_loss.item()
+
+            g_loss.backward(retain_graph=True)
+ g_optimizer.step()
+ accumulate(g_running, generator.module)
+
+ set_grad_flag(generator, False)
+ set_grad_flag(discriminator, True)
+
+ if (i + 1) % 100 == 0:
+ images = []
+
+ gen_i, gen_j = gen_sample.get(resolution, (10, 5))
+
+ with torch.no_grad():
+ for _ in range(gen_i):
+ images.append(
+ g_running(
+ torch.randn(gen_j, dim_latent).cuda(), step=step, alpha=alpha
+ ).data.cpu()
+ )
+
+ utils.save_image(
+ torch.cat(images, 0),
+ f'sample/{str(i + 1).zfill(6)}.png',
+ nrow=gen_i,
+ normalize=True,
+ range=(-1, 1),
+ )
+ imshow(torch.cat(images, 0), i)
+
+ if (i + 1) % 10000 == 0:
+ torch.save(
+ g_running.state_dict(), f'checkpoint/{str(i + 1).zfill(6)}.pth'
+ )
+
+ state_msg = (
+ f'Size: {4 * 2 ** step}; G: {gen_loss_val:.3f}; D: {disc_loss_val:.3f};'
+ f' Grad: {grad_loss_val:.3f}; Alpha: {alpha:.5f}'
+ )
+
+ pbar.set_description(state_msg)
+
+no_from_rgb_activate=True
+
+if __name__ == '__main__':
+
+ generator = nn.DataParallel(StyledGenerator(dim_latent)).cuda()
+ discriminator = nn.DataParallel(
+ Discriminator(from_rgb_activate=not no_from_rgb_activate)
+ ).cuda()
+
+ g_running = StyledGenerator(dim_latent).cuda()
+ g_running.train(False)
+
+ g_optimizer = optim.Adam(
+ generator.module.generator.parameters(), lr=0.001, betas=(0.0, 0.99)
+ )
+
+ g_optimizer.add_param_group(
+ {
+ 'params': generator.module.style.parameters(),
+ 'lr': 0.001 * 0.01,
+ 'mult': 0.01,
+ }
+ )
+
+ d_optimizer = optim.Adam(discriminator.parameters(), lr=0.001, betas=(0.0, 0.99))
+ accumulate(g_running, generator.module, 0)
+
+
+#Load pre-trained models
+ if ckpt is not None:
+ ckpt = torch.load(ckpt)
+
+ generator.module.load_state_dict(ckpt['generator'])
+ discriminator.module.load_state_dict(ckpt['discriminator'])
+ g_running.load_state_dict(ckpt['g_running'])
+ g_optimizer.load_state_dict(ckpt['g_optimizer'])
+ d_optimizer.load_state_dict(ckpt['d_optimizer'])
+
+ transform = transforms.Compose(
+ [
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
+ ]
+ )
+
+ dataset = MultiResolutionDataset(Path, transform)
+
+
+ learning_rate = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
+ batch_size = {4: 256, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16}
+
+
+ gen_sample = {512: (8, 4), 1024: (4, 2)}
+
+ batch_default = 32
+
+ loss='wgan-gp'
+
+ train(dataset, generator, discriminator,loss)