275 changes: 73 additions & 202 deletions main_supcon.py
@@ -1,245 +1,138 @@
from __future__ import print_function

import os
import sys
import argparse
import time
import math

import tensorboard_logger as tb_logger
import torch
import torch.backends.cudnn as cudnn
from torchvision import transforms, datasets

from util import TwoCropTransform, AverageMeter
from util import adjust_learning_rate, warmup_learning_rate
from util import set_optimizer, save_model
from networks.resnet_big import SupConResNet
from datasets import load_dataset
from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader
import tensorboard_logger as tb_logger
from util import TextAugment, AverageMeter, adjust_learning_rate, warmup_learning_rate, set_optimizer, save_model
from networks.xlmr_supcon import SupConXLMRLarge
from losses import SupConLoss

try:
import apex
from apex import amp, optimizers
except ImportError:
pass


def parse_option():
parser = argparse.ArgumentParser('argument for training')

parser.add_argument('--print_freq', type=int, default=10,
help='print frequency')
parser.add_argument('--save_freq', type=int, default=50,
help='save frequency')
parser.add_argument('--batch_size', type=int, default=256,
help='batch_size')
parser.add_argument('--num_workers', type=int, default=16,
help='num of workers to use')
parser.add_argument('--epochs', type=int, default=1000,
help='number of training epochs')

# optimization
parser.add_argument('--learning_rate', type=float, default=0.05,
help='learning rate')
parser.add_argument('--lr_decay_epochs', type=str, default='700,800,900',
help='where to decay lr, can be a list')
parser.add_argument('--lr_decay_rate', type=float, default=0.1,
help='decay rate for learning rate')
parser.add_argument('--weight_decay', type=float, default=1e-4,
help='weight decay')
parser.add_argument('--momentum', type=float, default=0.9,
help='momentum')

# model dataset
parser.add_argument('--model', type=str, default='resnet50')
parser.add_argument('--dataset', type=str, default='cifar10',
choices=['cifar10', 'cifar100', 'path'], help='dataset')
parser.add_argument('--mean', type=str, help='mean of dataset in path in form of str tuple')
parser.add_argument('--std', type=str, help='std of dataset in path in form of str tuple')
parser.add_argument('--print_freq', type=int, default=10, help='print frequency')
parser.add_argument('--save_freq', type=int, default=50, help='save frequency')
parser.add_argument('--batch_size', type=int, default=64, help='batch_size')
parser.add_argument('--num_workers', type=int, default=4, help='num of workers to use')
parser.add_argument('--epochs', type=int, default=100, help='number of training epochs')
parser.add_argument('--learning_rate', type=float, default=5e-5, help='learning rate')
parser.add_argument('--lr_decay_epochs', type=str, default='60,80', help='where to decay lr')
parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='decay rate for learning rate')
parser.add_argument('--weight_decay', type=float, default=1e-4, help='weight decay')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--dataset', type=str, default='path', help='dataset name or path')
parser.add_argument('--data_folder', type=str, default=None, help='path to custom dataset')
parser.add_argument('--size', type=int, default=32, help='parameter for RandomResizedCrop')

# method
parser.add_argument('--method', type=str, default='SupCon',
choices=['SupCon', 'SimCLR'], help='choose method')

# temperature
parser.add_argument('--temp', type=float, default=0.07,
help='temperature for loss function')

# other setting
parser.add_argument('--cosine', action='store_true',
help='using cosine annealing')
parser.add_argument('--syncBN', action='store_true',
help='using synchronized batch normalization')
parser.add_argument('--warm', action='store_true',
help='warm-up for large batch training')
parser.add_argument('--trial', type=str, default='0',
help='id for recording multiple runs')
parser.add_argument('--method', type=str, default='SupCon', choices=['SupCon', 'SimCLR'], help='choose method')
parser.add_argument('--temp', type=float, default=0.07, help='temperature for loss function')
parser.add_argument('--cosine', action='store_true', help='using cosine annealing')
parser.add_argument('--syncBN', action='store_true', help='using synchronized batch normalization')
parser.add_argument('--warm', action='store_true', help='warm-up for large batch training')
parser.add_argument('--trial', type=str, default='0', help='id for recording multiple runs')

opt = parser.parse_args()

# check if dataset is path that passed required arguments
if opt.dataset == 'path':
assert opt.data_folder is not None \
and opt.mean is not None \
and opt.std is not None

# set the path according to the environment
if opt.data_folder is None:
opt.data_folder = './datasets/'
opt.data_folder = './datasets/' if opt.data_folder is None else opt.data_folder
opt.model_path = './save/SupCon/{}_models'.format(opt.dataset)
opt.tb_path = './save/SupCon/{}_tensorboard'.format(opt.dataset)

iterations = opt.lr_decay_epochs.split(',')
opt.lr_decay_epochs = list([])
for it in iterations:
opt.lr_decay_epochs.append(int(it))

opt.model_name = '{}_{}_{}_lr_{}_decay_{}_bsz_{}_temp_{}_trial_{}'.\
format(opt.method, opt.dataset, opt.model, opt.learning_rate,
opt.weight_decay, opt.batch_size, opt.temp, opt.trial)

opt.lr_decay_epochs = [int(it) for it in iterations]
opt.model_name = '{}_{}_lr_{}_decay_{}_bsz_{}_temp_{}_trial_{}'.format(
opt.method, opt.dataset, opt.learning_rate, opt.weight_decay, opt.batch_size, opt.temp, opt.trial)
if opt.cosine:
opt.model_name = '{}_cosine'.format(opt.model_name)

# warm-up for large-batch training,
if opt.batch_size > 256:
if opt.batch_size > 64:
opt.warm = True
if opt.warm:
opt.model_name = '{}_warm'.format(opt.model_name)
opt.warmup_from = 0.01
opt.warm_epochs = 10
opt.warmup_from = 1e-6
opt.warm_epochs = 5
if opt.cosine:
eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * (
1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2
opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * (1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2
else:
opt.warmup_to = opt.learning_rate

opt.tb_folder = os.path.join(opt.tb_path, opt.model_name)
if not os.path.isdir(opt.tb_folder):
os.makedirs(opt.tb_folder)

opt.save_folder = os.path.join(opt.model_path, opt.model_name)
if not os.path.isdir(opt.save_folder):
os.makedirs(opt.save_folder)

return opt

class TextDataset(Dataset):
def __init__(self, dataset, transform=None):
self.dataset = dataset
self.transform = transform
self.tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-large")

def set_loader(opt):
# construct data loader
if opt.dataset == 'cifar10':
mean = (0.4914, 0.4822, 0.4465)
std = (0.2023, 0.1994, 0.2010)
elif opt.dataset == 'cifar100':
mean = (0.5071, 0.4867, 0.4408)
std = (0.2675, 0.2565, 0.2761)
elif opt.dataset == 'path':
mean = eval(opt.mean)
std = eval(opt.std)
else:
raise ValueError('dataset not supported: {}'.format(opt.dataset))
normalize = transforms.Normalize(mean=mean, std=std)

train_transform = transforms.Compose([
transforms.RandomResizedCrop(size=opt.size, scale=(0.2, 1.)),
transforms.RandomHorizontalFlip(),
transforms.RandomApply([
transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
], p=0.8),
transforms.RandomGrayscale(p=0.2),
transforms.ToTensor(),
normalize,
])

if opt.dataset == 'cifar10':
train_dataset = datasets.CIFAR10(root=opt.data_folder,
transform=TwoCropTransform(train_transform),
download=True)
elif opt.dataset == 'cifar100':
train_dataset = datasets.CIFAR100(root=opt.data_folder,
transform=TwoCropTransform(train_transform),
download=True)
elif opt.dataset == 'path':
train_dataset = datasets.ImageFolder(root=opt.data_folder,
transform=TwoCropTransform(train_transform))
else:
raise ValueError(opt.dataset)
def __len__(self):
return len(self.dataset)

train_sampler = None
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=opt.batch_size, shuffle=(train_sampler is None),
num_workers=opt.num_workers, pin_memory=True, sampler=train_sampler)
def __getitem__(self, idx):
text = self.dataset[idx]['text']
label = self.dataset[idx]['label']
if self.transform:
text1, text2 = self.transform(text)
else:
text1, text2 = text, text
encoding1 = self.tokenizer(text1, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
encoding2 = self.tokenizer(text2, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
return {
'input_ids': [encoding1['input_ids'].squeeze(), encoding2['input_ids'].squeeze()],
'attention_mask': [encoding1['attention_mask'].squeeze(), encoding2['attention_mask'].squeeze()],
'labels': torch.tensor(label, dtype=torch.long)
}

def set_loader(opt):
train_dataset = load_dataset(opt.dataset, split='train') if opt.dataset != 'path' else load_dataset('csv', data_files=opt.data_folder)['train']
train_transform = TextAugment()
train_dataset = TextDataset(train_dataset, transform=train_transform)
train_loader = DataLoader(
train_dataset, batch_size=opt.batch_size, shuffle=True,
num_workers=opt.num_workers, pin_memory=True)
return train_loader


def set_model(opt):
model = SupConResNet(name=opt.model)
model = SupConXLMRLarge()
criterion = SupConLoss(temperature=opt.temp)

# enable synchronized Batch Normalization
if opt.syncBN:
model = apex.parallel.convert_syncbn_model(model)

if torch.cuda.is_available():
if torch.cuda.device_count() > 1:
model.encoder = torch.nn.DataParallel(model.encoder)
model = model.cuda()
criterion = criterion.cuda()
cudnn.benchmark = True

return model, criterion


def train(train_loader, model, criterion, optimizer, epoch, opt):
"""one epoch training"""
model.train()

batch_time = AverageMeter()
data_time = AverageMeter()
losses = AverageMeter()

end = time.time()
for idx, (images, labels) in enumerate(train_loader):
for idx, batch in enumerate(train_loader):
data_time.update(time.time() - end)

images = torch.cat([images[0], images[1]], dim=0)
if torch.cuda.is_available():
images = images.cuda(non_blocking=True)
labels = labels.cuda(non_blocking=True)
input_ids = torch.stack(batch['input_ids'], dim=1).cuda(non_blocking=True) # [bsz, 2, seq_len]
attention_mask = torch.stack(batch['attention_mask'], dim=1).cuda(non_blocking=True)
labels = batch['labels'].cuda(non_blocking=True)
bsz = labels.shape[0]

# warm-up learning rate
warmup_learning_rate(opt, epoch, idx, len(train_loader), optimizer)

# compute loss
features = model(images)
f1, f2 = torch.split(features, [bsz, bsz], dim=0)
features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)
if opt.method == 'SupCon':
loss = criterion(features, labels)
elif opt.method == 'SimCLR':
loss = criterion(features)
else:
raise ValueError('contrastive method not supported: {}'.
format(opt.method))

# update metric
features = []
for i in range(2): # Process two views
feat = model(input_ids[:, i], attention_mask[:, i])
features.append(feat)
features = torch.stack(features, dim=1) # [bsz, 2, feat_dim]
loss = criterion(features, labels) if opt.method == 'SupCon' else criterion(features)
losses.update(loss.item(), bsz)

# SGD
optimizer.zero_grad()
loss.backward()
optimizer.step()

# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()

# print info
if (idx + 1) % opt.print_freq == 0:
print('Train: [{0}][{1}/{2}]\t'
'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
@@ -248,49 +141,27 @@ def train(train_loader, model, criterion, optimizer, epoch, opt):
epoch, idx + 1, len(train_loader), batch_time=batch_time,
data_time=data_time, loss=losses))
sys.stdout.flush()

return losses.avg


def main():
opt = parse_option()

# build data loader
train_loader = set_loader(opt)

# build model and criterion
model, criterion = set_model(opt)

# build optimizer
optimizer = set_optimizer(opt, model)

# tensorboard
logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

# training routine
for epoch in range(1, opt.epochs + 1):
adjust_learning_rate(opt, optimizer, epoch)

# train for one epoch
time1 = time.time()
loss = train(train_loader, model, criterion, optimizer, epoch, opt)
time2 = time.time()
print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

# tensorboard logger
logger.log_value('loss', loss, epoch)
logger.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch)

if epoch % opt.save_freq == 0:
save_file = os.path.join(
opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
save_file = os.path.join(opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
save_model(model, optimizer, opt, epoch, save_file)

# save the last model
save_file = os.path.join(
opt.save_folder, 'last.pth')
save_file = os.path.join(opt.save_folder, 'last.pth')
save_model(model, optimizer, opt, opt.epochs, save_file)


if __name__ == '__main__':
main()
main()
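
Note: main_supcon.py now imports TextAugment from util, but util.py is not part of this diff, so the augmentation that produces the two text views is not visible here. Below is a minimal sketch of what such a two-view transform could look like; the class name mirrors the import, while the random word-dropout strategy and the p_drop parameter are assumptions for illustration, not taken from this PR.

import random

class TextAugment:
    """Return two independently augmented views of an input string.

    Sketch only: simple random word dropout; the actual util.TextAugment
    in this PR may use a different augmentation.
    """
    def __init__(self, p_drop=0.1):
        self.p_drop = p_drop  # probability of dropping each word

    def _augment(self, text):
        words = text.split()
        kept = [w for w in words if random.random() > self.p_drop]
        return ' '.join(kept) if kept else text  # never return an empty string

    def __call__(self, text):
        return self._augment(text), self._augment(text)
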
26 changes: 26 additions & 0 deletions networks/xlmr_supcon.py
@@ -0,0 +1,26 @@
import torch
import torch.nn as nn
from transformers import AutoModel

class SupConXLMRLarge(nn.Module):
"""XLM-RoBERTa-large backbone + projection head"""
def __init__(self, head='mlp', feat_dim=128):
super(SupConXLMRLarge, self).__init__()
self.encoder = AutoModel.from_pretrained("xlm-roberta-large")
dim_in = 1024 # XLM-RoBERTa-large embedding size
if head == 'linear':
self.head = nn.Linear(dim_in, feat_dim)
elif head == 'mlp':
self.head = nn.Sequential(
nn.Linear(dim_in, dim_in),
nn.ReLU(inplace=True),
nn.Linear(dim_in, feat_dim)
)
else:
raise NotImplementedError('head not supported: {}'.format(head))

def forward(self, input_ids, attention_mask=None):
outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
feat = outputs.last_hidden_state[:, 0, :] # Use the <s> token (XLM-R's [CLS] equivalent) as the sentence representation
feat = nn.functional.normalize(self.head(feat), dim=1) # L2 normalize
return feat
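
For reference, a small usage sketch showing how the encoder output is arranged into the [bsz, n_views, feat_dim] tensor that SupConLoss expects, mirroring the two-view loop in train(). The sentences and labels are made up for illustration, and running it requires the xlm-roberta-large weights to be available.

import torch
from transformers import AutoTokenizer
from networks.xlmr_supcon import SupConXLMRLarge
from losses import SupConLoss

tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-large")
model = SupConXLMRLarge()
criterion = SupConLoss(temperature=0.07)

# two "views" of the same two sentences (contents and labels are illustrative)
view1 = ["a clearly positive review", "a clearly negative review"]
view2 = ["a very positive review", "quite a negative review"]
labels = torch.tensor([1, 0])

feats = []
for texts in (view1, view2):
    enc = tokenizer(texts, padding='max_length', truncation=True,
                    max_length=128, return_tensors='pt')
    feats.append(model(enc['input_ids'], enc['attention_mask']))
features = torch.stack(feats, dim=1)  # [bsz, 2, feat_dim]

loss = criterion(features, labels)  # SupCon; omit labels for SimCLR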