From b73326b772e181ad19c2e18652f570843c30bd7c Mon Sep 17 00:00:00 2001
From: hoangpdh
Date: Wed, 11 Jun 2025 18:43:15 +0700
Subject: [PATCH] Change XLM-R, text embeddings

---
 main_supcon.py          | 275 +++++++++++-----------------------------
 networks/xlmr_supcon.py |  26 ++++
 util.py                 |  16 ++-
 3 files changed, 114 insertions(+), 203 deletions(-)
 create mode 100644 networks/xlmr_supcon.py

diff --git a/main_supcon.py b/main_supcon.py
index ea6a625f..0f18f600 100644
--- a/main_supcon.py
+++ b/main_supcon.py
@@ -1,245 +1,138 @@
 from __future__ import print_function
-
 import os
 import sys
 import argparse
 import time
 import math
-
-import tensorboard_logger as tb_logger
 import torch
 import torch.backends.cudnn as cudnn
-from torchvision import transforms, datasets
-
-from util import TwoCropTransform, AverageMeter
-from util import adjust_learning_rate, warmup_learning_rate
-from util import set_optimizer, save_model
-from networks.resnet_big import SupConResNet
+from datasets import load_dataset
+from transformers import AutoTokenizer
+from torch.utils.data import Dataset, DataLoader
+import tensorboard_logger as tb_logger
+from util import TextAugment, AverageMeter, adjust_learning_rate, warmup_learning_rate, set_optimizer, save_model
+from networks.xlmr_supcon import SupConXLMRLarge
 from losses import SupConLoss
 
-try:
-    import apex
-    from apex import amp, optimizers
-except ImportError:
-    pass
-
-
 def parse_option():
     parser = argparse.ArgumentParser('argument for training')
-
-    parser.add_argument('--print_freq', type=int, default=10,
-                        help='print frequency')
-    parser.add_argument('--save_freq', type=int, default=50,
-                        help='save frequency')
-    parser.add_argument('--batch_size', type=int, default=256,
-                        help='batch_size')
-    parser.add_argument('--num_workers', type=int, default=16,
-                        help='num of workers to use')
-    parser.add_argument('--epochs', type=int, default=1000,
-                        help='number of training epochs')
-
-    # optimization
-    parser.add_argument('--learning_rate', type=float, default=0.05,
-                        help='learning rate')
-    parser.add_argument('--lr_decay_epochs', type=str, default='700,800,900',
-                        help='where to decay lr, can be a list')
-    parser.add_argument('--lr_decay_rate', type=float, default=0.1,
-                        help='decay rate for learning rate')
-    parser.add_argument('--weight_decay', type=float, default=1e-4,
-                        help='weight decay')
-    parser.add_argument('--momentum', type=float, default=0.9,
-                        help='momentum')
-
-    # model dataset
-    parser.add_argument('--model', type=str, default='resnet50')
-    parser.add_argument('--dataset', type=str, default='cifar10',
-                        choices=['cifar10', 'cifar100', 'path'], help='dataset')
-    parser.add_argument('--mean', type=str, help='mean of dataset in path in form of str tuple')
-    parser.add_argument('--std', type=str, help='std of dataset in path in form of str tuple')
+    parser.add_argument('--print_freq', type=int, default=10, help='print frequency')
+    parser.add_argument('--save_freq', type=int, default=50, help='save frequency')
+    parser.add_argument('--batch_size', type=int, default=64, help='batch_size')
+    parser.add_argument('--num_workers', type=int, default=4, help='num of workers to use')
+    parser.add_argument('--epochs', type=int, default=100, help='number of training epochs')
+    parser.add_argument('--learning_rate', type=float, default=5e-5, help='learning rate')
+    parser.add_argument('--lr_decay_epochs', type=str, default='60,80', help='where to decay lr')
+    parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='decay rate for learning rate')
+    parser.add_argument('--weight_decay', type=float, default=1e-4, help='weight decay')
+    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
+    parser.add_argument('--dataset', type=str, default='path', help='dataset name or path')
     parser.add_argument('--data_folder', type=str, default=None, help='path to custom dataset')
-    parser.add_argument('--size', type=int, default=32, help='parameter for RandomResizedCrop')
-
-    # method
-    parser.add_argument('--method', type=str, default='SupCon',
-                        choices=['SupCon', 'SimCLR'], help='choose method')
-
-    # temperature
-    parser.add_argument('--temp', type=float, default=0.07,
-                        help='temperature for loss function')
-
-    # other setting
-    parser.add_argument('--cosine', action='store_true',
-                        help='using cosine annealing')
-    parser.add_argument('--syncBN', action='store_true',
-                        help='using synchronized batch normalization')
-    parser.add_argument('--warm', action='store_true',
-                        help='warm-up for large batch training')
-    parser.add_argument('--trial', type=str, default='0',
-                        help='id for recording multiple runs')
+    parser.add_argument('--method', type=str, default='SupCon', choices=['SupCon', 'SimCLR'], help='choose method')
+    parser.add_argument('--temp', type=float, default=0.07, help='temperature for loss function')
+    parser.add_argument('--cosine', action='store_true', help='using cosine annealing')
+    parser.add_argument('--syncBN', action='store_true', help='using synchronized batch normalization')
+    parser.add_argument('--warm', action='store_true', help='warm-up for large batch training')
+    parser.add_argument('--trial', type=str, default='0', help='id for recording multiple runs')
 
     opt = parser.parse_args()
-
-    # check if dataset is path that passed required arguments
-    if opt.dataset == 'path':
-        assert opt.data_folder is not None \
-            and opt.mean is not None \
-            and opt.std is not None
-
-    # set the path according to the environment
-    if opt.data_folder is None:
-        opt.data_folder = './datasets/'
+    opt.data_folder = './datasets/' if opt.data_folder is None else opt.data_folder
 
     opt.model_path = './save/SupCon/{}_models'.format(opt.dataset)
     opt.tb_path = './save/SupCon/{}_tensorboard'.format(opt.dataset)
-
     iterations = opt.lr_decay_epochs.split(',')
-    opt.lr_decay_epochs = list([])
-    for it in iterations:
-        opt.lr_decay_epochs.append(int(it))
-
-    opt.model_name = '{}_{}_{}_lr_{}_decay_{}_bsz_{}_temp_{}_trial_{}'.\
-        format(opt.method, opt.dataset, opt.model, opt.learning_rate,
-               opt.weight_decay, opt.batch_size, opt.temp, opt.trial)
-
+    opt.lr_decay_epochs = [int(it) for it in iterations]
+    opt.model_name = '{}_{}_lr_{}_decay_{}_bsz_{}_temp_{}_trial_{}'.format(
+        opt.method, opt.dataset, opt.learning_rate, opt.weight_decay, opt.batch_size, opt.temp, opt.trial)
     if opt.cosine:
         opt.model_name = '{}_cosine'.format(opt.model_name)
-
-    # warm-up for large-batch training,
-    if opt.batch_size > 256:
+    if opt.batch_size > 64:
         opt.warm = True
     if opt.warm:
         opt.model_name = '{}_warm'.format(opt.model_name)
-        opt.warmup_from = 0.01
-        opt.warm_epochs = 10
+        opt.warmup_from = 1e-6
+        opt.warm_epochs = 5
         if opt.cosine:
             eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
-            opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * (
-                    1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2
+            opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * (1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2
         else:
             opt.warmup_to = opt.learning_rate
-
     opt.tb_folder = os.path.join(opt.tb_path, opt.model_name)
     if not os.path.isdir(opt.tb_folder):
         os.makedirs(opt.tb_folder)
-
     opt.save_folder = os.path.join(opt.model_path, opt.model_name)
     if not os.path.isdir(opt.save_folder):
         os.makedirs(opt.save_folder)
-
     return opt
 
+class TextDataset(Dataset):
+    def __init__(self, dataset, transform=None):
+        self.dataset = dataset
+        self.transform = transform
+        self.tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-large")
 
-def set_loader(opt):
-    # construct data loader
-    if opt.dataset == 'cifar10':
-        mean = (0.4914, 0.4822, 0.4465)
-        std = (0.2023, 0.1994, 0.2010)
-    elif opt.dataset == 'cifar100':
-        mean = (0.5071, 0.4867, 0.4408)
-        std = (0.2675, 0.2565, 0.2761)
-    elif opt.dataset == 'path':
-        mean = eval(opt.mean)
-        std = eval(opt.std)
-    else:
-        raise ValueError('dataset not supported: {}'.format(opt.dataset))
-    normalize = transforms.Normalize(mean=mean, std=std)
-
-    train_transform = transforms.Compose([
-        transforms.RandomResizedCrop(size=opt.size, scale=(0.2, 1.)),
-        transforms.RandomHorizontalFlip(),
-        transforms.RandomApply([
-            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
-        ], p=0.8),
-        transforms.RandomGrayscale(p=0.2),
-        transforms.ToTensor(),
-        normalize,
-    ])
-
-    if opt.dataset == 'cifar10':
-        train_dataset = datasets.CIFAR10(root=opt.data_folder,
-                                         transform=TwoCropTransform(train_transform),
-                                         download=True)
-    elif opt.dataset == 'cifar100':
-        train_dataset = datasets.CIFAR100(root=opt.data_folder,
-                                          transform=TwoCropTransform(train_transform),
-                                          download=True)
-    elif opt.dataset == 'path':
-        train_dataset = datasets.ImageFolder(root=opt.data_folder,
-                                             transform=TwoCropTransform(train_transform))
-    else:
-        raise ValueError(opt.dataset)
+    def __len__(self):
+        return len(self.dataset)
 
-    train_sampler = None
-    train_loader = torch.utils.data.DataLoader(
-        train_dataset, batch_size=opt.batch_size, shuffle=(train_sampler is None),
-        num_workers=opt.num_workers, pin_memory=True, sampler=train_sampler)
+    def __getitem__(self, idx):
+        text = self.dataset[idx]['text']
+        label = self.dataset[idx]['label']
+        if self.transform:
+            text1, text2 = self.transform(text)
+        else:
+            text1, text2 = text, text
+        encoding1 = self.tokenizer(text1, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
+        encoding2 = self.tokenizer(text2, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
+        return {
+            'input_ids': [encoding1['input_ids'].squeeze(), encoding2['input_ids'].squeeze()],
+            'attention_mask': [encoding1['attention_mask'].squeeze(), encoding2['attention_mask'].squeeze()],
+            'labels': torch.tensor(label, dtype=torch.long)
+        }
 
+def set_loader(opt):
+    train_dataset = load_dataset(opt.dataset, split='train') if opt.dataset != 'path' else load_dataset('csv', data_files=opt.data_folder)['train']
+    train_transform = TextAugment()
+    train_dataset = TextDataset(train_dataset, transform=train_transform)
+    train_loader = DataLoader(
+        train_dataset, batch_size=opt.batch_size, shuffle=True,
+        num_workers=opt.num_workers, pin_memory=True)
     return train_loader
 
-
 def set_model(opt):
-    model = SupConResNet(name=opt.model)
+    model = SupConXLMRLarge()
     criterion = SupConLoss(temperature=opt.temp)
-
-    # enable synchronized Batch Normalization
-    if opt.syncBN:
-        model = apex.parallel.convert_syncbn_model(model)
-
     if torch.cuda.is_available():
         if torch.cuda.device_count() > 1:
            model.encoder = torch.nn.DataParallel(model.encoder)
        model = model.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True
-
    return model, criterion
 
-
 def train(train_loader, model, criterion, optimizer, epoch, opt):
-    """one epoch training"""
     model.train()
-
     batch_time = AverageMeter()
     data_time = AverageMeter()
     losses = AverageMeter()
-
     end = time.time()
-    for idx, (images, labels) in enumerate(train_loader):
+    for idx, batch in enumerate(train_loader):
         data_time.update(time.time() - end)
-
-        images = torch.cat([images[0], images[1]], dim=0)
-        if torch.cuda.is_available():
-            images = images.cuda(non_blocking=True)
-            labels = labels.cuda(non_blocking=True)
+        input_ids = torch.stack(batch['input_ids'], dim=1).cuda(non_blocking=True)  # [bsz, 2, seq_len]
+        attention_mask = torch.stack(batch['attention_mask'], dim=1).cuda(non_blocking=True)
+        labels = batch['labels'].cuda(non_blocking=True)
         bsz = labels.shape[0]
-
-        # warm-up learning rate
         warmup_learning_rate(opt, epoch, idx, len(train_loader), optimizer)
-
-        # compute loss
-        features = model(images)
-        f1, f2 = torch.split(features, [bsz, bsz], dim=0)
-        features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)
-        if opt.method == 'SupCon':
-            loss = criterion(features, labels)
-        elif opt.method == 'SimCLR':
-            loss = criterion(features)
-        else:
-            raise ValueError('contrastive method not supported: {}'.
-                             format(opt.method))
-
-        # update metric
+        features = []
+        for i in range(2):  # Process two views
+            feat = model(input_ids[:, i], attention_mask[:, i])
+            features.append(feat)
+        features = torch.stack(features, dim=1)  # [bsz, 2, feat_dim]
+        loss = criterion(features, labels) if opt.method == 'SupCon' else criterion(features)
         losses.update(loss.item(), bsz)
-
-        # SGD
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
-
-        # measure elapsed time
         batch_time.update(time.time() - end)
         end = time.time()
-
-        # print info
         if (idx + 1) % opt.print_freq == 0:
             print('Train: [{0}][{1}/{2}]\t'
                   'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
@@ -248,49 +141,27 @@ def train(train_loader, model, criterion, optimizer, epoch, opt):
                    epoch, idx + 1, len(train_loader), batch_time=batch_time,
                    data_time=data_time, loss=losses))
             sys.stdout.flush()
-
     return losses.avg
 
-
 def main():
     opt = parse_option()
-
-    # build data loader
     train_loader = set_loader(opt)
-
-    # build model and criterion
     model, criterion = set_model(opt)
-
-    # build optimizer
     optimizer = set_optimizer(opt, model)
-
-    # tensorboard
     logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)
-
-    # training routine
     for epoch in range(1, opt.epochs + 1):
         adjust_learning_rate(opt, optimizer, epoch)
-
-        # train for one epoch
         time1 = time.time()
         loss = train(train_loader, model, criterion, optimizer, epoch, opt)
         time2 = time.time()
         print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))
-
-        # tensorboard logger
         logger.log_value('loss', loss, epoch)
         logger.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch)
-
         if epoch % opt.save_freq == 0:
-            save_file = os.path.join(
-                opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
+            save_file = os.path.join(opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
             save_model(model, optimizer, opt, epoch, save_file)
-
-    # save the last model
-    save_file = os.path.join(
-        opt.save_folder, 'last.pth')
+    save_file = os.path.join(opt.save_folder, 'last.pth')
     save_model(model, optimizer, opt, opt.epochs, save_file)
 
-
 if __name__ == '__main__':
-    main()
+    main()
\ No newline at end of file
diff --git a/networks/xlmr_supcon.py b/networks/xlmr_supcon.py
new file mode 100644
index 00000000..057bba0a
--- /dev/null
+++ b/networks/xlmr_supcon.py
@@ -0,0 +1,26 @@
+import torch
+import torch.nn as nn
+from transformers import AutoModel
+
+class SupConXLMRLarge(nn.Module):
+    """XLM-RoBERTa-large backbone + projection head"""
+    def __init__(self, head='mlp', feat_dim=128):
+        super(SupConXLMRLarge, self).__init__()
+        self.encoder = AutoModel.from_pretrained("xlm-roberta-large")
+        dim_in = 1024  # XLM-RoBERTa-large embedding size
+        if head == 'linear':
+            self.head = nn.Linear(dim_in, feat_dim)
+        elif head == 'mlp':
+            self.head = nn.Sequential(
+                nn.Linear(dim_in, dim_in),
+                nn.ReLU(inplace=True),
+                nn.Linear(dim_in, feat_dim)
+            )
+        else:
+            raise NotImplementedError('head not supported: {}'.format(head))
+
+    def forward(self, input_ids, attention_mask=None):
+        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
+        feat = outputs.last_hidden_state[:, 0, :]  # Use the <s> (CLS-equivalent) token
+        feat = nn.functional.normalize(self.head(feat), dim=1)  # L2 normalize
+        return feat
\ No newline at end of file
diff --git a/util.py b/util.py
index b6323530..e9e28d01 100644
--- a/util.py
+++ b/util.py
@@ -4,6 +4,8 @@
 import numpy as np
 import torch
 import torch.optim as optim
+import random
+import nlpaug.augmenter.word as naw
 
 
 class TwoCropTransform:
@@ -14,7 +16,19 @@ def __init__(self, transform):
     def __call__(self, x):
         return [self.transform(x), self.transform(x)]
 
-
+class TextAugment:
+    """Create two augmented versions of the same text"""
+    def __init__(self):
+        self.aug = naw.WordEmbsAug(
+            model_type='fasttext',
+            model_path='cc.de.300.bin',
+            action="substitute",
+            aug_p=0.3
+        )  # Synonym replacement augmentation
+
+    def __call__(self, text):
+        return [self.aug.augment(text)[0], self.aug.augment(text)[0]]
+
 class AverageMeter(object):
     """Computes and stores the average and current value"""
     def __init__(self):
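
A minimal smoke test of the pieces this patch adds (a sketch, not part of the patch itself). It assumes the upstream SupContrast losses.SupConLoss signature of features shaped [bsz, n_views, feat_dim] plus integer labels, that the xlm-roberta-large weights and tokenizer can be downloaded, and it reuses the same text for both "views" so no nlpaug fasttext model is required.

import torch
from transformers import AutoTokenizer

from losses import SupConLoss
from networks.xlmr_supcon import SupConXLMRLarge

tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-large")
model = SupConXLMRLarge().eval()
criterion = SupConLoss(temperature=0.07)

texts = ["ein kurzer Beispielsatz", "another short example sentence"]
labels = torch.tensor([0, 1])

# Tokenize once and reuse the encoding for both views (stand-in for TextAugment).
enc = tokenizer(texts, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
with torch.no_grad():
    feat = model(enc['input_ids'], enc['attention_mask'])  # [bsz, feat_dim], L2-normalized
features = torch.stack([feat, feat], dim=1)                # [bsz, 2, feat_dim], as in train()
loss = criterion(features, labels)                         # SupCon; use criterion(features) for SimCLR
print(features.shape, loss.item())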