diff --git a/.gitignore b/.gitignore
index 3089ebe..88e6d11 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,3 +8,5 @@ build
 correlation.egg-info
 logs
 checkpoints*
+log_dir_negroni
+wandb
\ No newline at end of file
diff --git a/configs/default.py b/configs/default.py
index 2cfe165..656fe66 100644
--- a/configs/default.py
+++ b/configs/default.py
@@ -5,9 +5,9 @@
 _CN.suffix ='arxiv2'
 _CN.gamma = 0.8
 _CN.max_flow = 400
-_CN.batch_size = 8
-_CN.sum_freq = 100
-_CN.val_freq = 5000
+_CN.batch_size = 12
+_CN.sum_freq = 50
+_CN.val_freq = 1500
 _CN.image_size = [368, 496]
 _CN.add_noise = True
 _CN.critical_params = []
@@ -59,7 +59,7 @@
 _CN.trainer.canonical_lr = 25e-5
 _CN.trainer.adamw_decay = 1e-4
 _CN.trainer.clip = 1.0
-_CN.trainer.num_steps = 120000
+_CN.trainer.num_steps = 50000
 _CN.trainer.epsilon = 1e-8
 _CN.trainer.anneal_strategy = 'linear'
 def get_cfg():
diff --git a/core/FlowFormer/LatentCostFormer/gma.py b/core/FlowFormer/LatentCostFormer/gma.py
index dd712e8..cee57ce 100644
--- a/core/FlowFormer/LatentCostFormer/gma.py
+++ b/core/FlowFormer/LatentCostFormer/gma.py
@@ -49,7 +49,7 @@ def __init__(
 
         self.to_qk = nn.Conv2d(dim, inner_dim * 2, 1, bias=False)
 
-        self.pos_emb = RelPosEmb(max_pos_size, dim_head)
+        # self.pos_emb = RelPosEmb(max_pos_size, dim_head)
 
     def forward(self, fmap):
         heads, b, c, h, w = self.heads, *fmap.shape
diff --git a/core/FlowFormer/encoders.py b/core/FlowFormer/encoders.py
index f132aaf..c83865e 100644
--- a/core/FlowFormer/encoders.py
+++ b/core/FlowFormer/encoders.py
@@ -15,6 +15,7 @@ def __init__(self, pretrained=True):
         del self.svt.blocks[2]
         del self.svt.pos_block[2]
         del self.svt.pos_block[2]
+        del self.svt.norm
 
     def forward(self, x, data=None, layer=2):
         B = x.shape[0]
diff --git a/core/datasets.py b/core/datasets.py
index 5fca782..09b85f3 100644
--- a/core/datasets.py
+++ b/core/datasets.py
@@ -229,7 +229,48 @@ def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
         train_dataset = KITTI(aug_params, split='training')
 
     train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
-        pin_memory=False, shuffle=True, num_workers=128, drop_last=True)
+        pin_memory=False, shuffle=True, num_workers=4, drop_last=True)
 
     print('Training with %d image pairs' % len(train_dataset))
     return train_loader
+
+
+
+
+def fetch_dataset(args, TRAIN_DS='C+T+K+S+H'):
+    """ Create the dataset for the corresponding training set """
+
+    if args.stage == 'chairs':
+        aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
+        train_dataset = FlyingChairs(aug_params, split='training')
+
+    elif args.stage == 'things':
+        aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
+        clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
+        final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
+        train_dataset = clean_dataset + final_dataset
+
+    elif args.stage == 'sintel':
+        aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
+        things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
+        sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
+        sintel_final = MpiSintel(aug_params, split='training', dstype='final')
+
+        if TRAIN_DS == 'C+T+K+S+H':
+            kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
+            hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
+            train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things
+
+        elif TRAIN_DS == 'C+T+K/S':
+            train_dataset = 100*sintel_clean + 100*sintel_final + things
+
+    elif args.stage == 'kitti':
+        aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
+        train_dataset = KITTI(aug_params, split='training')
+
+    # train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
+    #     pin_memory=False, shuffle=True, num_workers=4, drop_last=True)
+
+    print('Training with %d image pairs' % len(train_dataset))
+    return train_dataset
+
diff --git a/core/utils/logger.py b/core/utils/logger.py
index 64c1c5f..da19062 100644
--- a/core/utils/logger.py
+++ b/core/utils/logger.py
@@ -1,6 +1,10 @@
 from torch.utils.tensorboard import SummaryWriter
 from loguru import logger as loguru_logger
 
+import os
+import wandb
+import numpy as np
+
 class Logger:
     def __init__(self, model, scheduler, cfg):
         self.model = model
@@ -10,6 +14,17 @@ def __init__(self, model, scheduler, cfg):
         self.writer = None
         self.cfg = cfg
 
+        self.tboard_log_dir = os.path.join(cfg.log_dir, 'tboard_logs')
+        os.makedirs(self.tboard_log_dir, exist_ok=True)
+
+    def _log_metric(self, tag, scalar_value, global_step, tboard_writer, cfg=None):
+        if cfg is None or cfg.log_in_wandb:
+            try:
+                wandb.log({tag: scalar_value}, step=global_step)
+            except Exception as e:
+                print(f"WandB log failed for tag '{tag}': {e}")
+        tboard_writer.add_scalar(tag=tag, scalar_value=scalar_value, global_step=global_step)
+
     def _print_training_status(self):
         metrics_data = [self.running_loss[k]/self.cfg.sum_freq for k in sorted(self.running_loss.keys())]
         training_str = "[{:6d}, {}] ".format(self.total_steps+1, self.scheduler.get_last_lr())
@@ -25,10 +40,12 @@ def _print_training_status(self):
             self.writer = SummaryWriter(self.cfg.log_dir)
 
         for k in self.running_loss:
-            self.writer.add_scalar(k, self.running_loss[k]/self.cfg.sum_freq, self.total_steps)
+            # self.writer.add_scalar(k, self.running_loss[k]/self.cfg.sum_freq, self.total_steps)
+            self._log_metric(tag=k, scalar_value=self.running_loss[k]/self.cfg.sum_freq, global_step=self.total_steps, tboard_writer=self.writer, cfg=self.cfg)
+
             self.running_loss[k] = 0.0
 
-    def push(self, metrics):
+    def push(self, metrics, model, current_lr, train_loss, curr_epoch):
         self.total_steps += 1
 
         for key in metrics:
@@ -41,12 +58,27 @@ def push(self, metrics):
             self._print_training_status()
             self.running_loss = {}
 
-    def write_dict(self, results):
+        var_sum = [var.sum().item() for var in model.parameters() if var.requires_grad]
+        var_cnt = len(var_sum)
+        var_sum = np.sum(var_sum)
+        var_avg = var_sum.item()/var_cnt
+        # by suraj
+        var_norm = [var.norm().item() for var in model.parameters() if var.requires_grad]
+
+        self._log_metric(tag="Training loss", scalar_value=train_loss, global_step=self.total_steps, tboard_writer=self.writer, cfg=self.cfg)
+        self._log_metric(tag='Learning Rate', scalar_value=current_lr, global_step=self.total_steps, tboard_writer=self.writer, cfg=self.cfg)
+        self._log_metric(tag='var sum average', scalar_value=var_avg, global_step=self.total_steps, tboard_writer=self.writer, cfg=self.cfg)
+        self._log_metric(tag='var norm average', scalar_value=np.mean(var_norm), global_step=self.total_steps, tboard_writer=self.writer, cfg=self.cfg)
+        self._log_metric(tag='Epoch', scalar_value=curr_epoch, global_step=self.total_steps, tboard_writer=self.writer, cfg=self.cfg)
+
+    def write_dict(self, results, val_dataset, cfg):
         if self.writer is None:
             self.writer = SummaryWriter()
 
         for key in results:
-            self.writer.add_scalar(key, results[key], self.total_steps)
+            # self.writer.add_scalar(key, results[key], self.total_steps)
+            self._log_metric(tag=f"Val_{val_dataset}_{key}", scalar_value=results[key], global_step=self.total_steps, tboard_writer=self.writer, cfg=self.cfg)
+
     def close(self):
         self.writer.close()
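The logger rework above routes every scalar through the single _log_metric helper, so the W&B and TensorBoard streams stay in step and a W&B failure is downgraded to a printed warning instead of killing training. A minimal standalone sketch of that dual-sink pattern (the use_wandb flag and the log directory below are illustrative stand-ins for cfg.log_in_wandb and cfg.log_dir):

import os
from torch.utils.tensorboard import SummaryWriter
import wandb

def log_scalar(writer, tag, value, step, use_wandb=True):
    # Try W&B first; a failure there must not interrupt the training loop.
    if use_wandb:
        try:
            wandb.log({tag: value}, step=step)
        except Exception as e:
            print(f"WandB log failed for tag '{tag}': {e}")
    # TensorBoard writes are local and always happen.
    writer.add_scalar(tag=tag, scalar_value=value, global_step=step)

writer = SummaryWriter(os.path.join("log_dir_negroni", "tboard_logs"))  # stand-in for cfg.log_dir
log_scalar(writer, "Training loss", 0.42, step=1, use_wandb=False)  # use_wandb=True requires wandb.init()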
diff --git a/evaluate_FlowFormer.py b/evaluate_FlowFormer.py
index 792b322..a96b6e3 100644
--- a/evaluate_FlowFormer.py
+++ b/evaluate_FlowFormer.py
@@ -14,57 +14,94 @@
 from configs.small_things_eval import get_cfg as get_small_things_cfg
 from core.utils.misc import process_cfg
 import datasets
-from utils import flow_viz
-from utils import frame_utils
+from core.utils import flow_viz
+from core.utils import frame_utils
 # from FlowFormer import FlowFormer
 from core.FlowFormer import build_flowformer
-from raft import RAFT
+from core.raft import RAFT
+from core.datasets import FlyingChairs, MpiSintel
 
-from utils.utils import InputPadder, forward_interpolate
+
+from core.utils.utils import InputPadder, forward_interpolate
+from tqdm import tqdm
+import vpd_utils
 
 @torch.no_grad()
-def validate_chairs(model):
+def validate_chairs(model, args=None):
     """ Perform evaluation on the FlyingChairs (test) split """
+    args.val_workers = 2
     model.eval()
-    epe_list = []
 
-    val_dataset = datasets.FlyingChairs(split='validation')
-    for val_id in range(len(val_dataset)):
-        image1, image2, flow_gt, _ = val_dataset[val_id]
-        image1 = image1[None].cuda()
-        image2 = image2[None].cuda()
+    epe_list = []
+    ddp_logger = vpd_utils.MetricLogger()
+
+    val_dataset = FlyingChairs(split='validation')
+    sampler_val = torch.utils.data.DistributedSampler(
+        val_dataset, num_replicas=args.world_size, rank=args.rank, shuffle=False)
+    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, sampler=sampler_val, num_workers=args.val_workers, pin_memory=True)
+    device = torch.device(args.rank)
+
+    # for val_id in range(len(val_dataset)):
+    for batch in tqdm(val_loader):
+        image1, image2, flow_gt, _ = batch
+        image1 = image1.to(device)
+        image2 = image2.to(device)
 
         flow_pre, _ = model(image1, image2)
 
         epe = torch.sum((flow_pre[0].cpu() - flow_gt)**2, dim=0).sqrt()
         epe_list.append(epe.view(-1).numpy())
 
-    epe = np.mean(np.concatenate(epe_list))
+    # epe = np.mean(np.concatenate(epe_list))
+    epe_all = np.concatenate(epe_list)
+    epe = np.mean(epe_all)
+    px1 = np.mean(epe_all < 1)
+    px3 = np.mean(epe_all < 3)
+    px5 = np.mean(epe_all < 5)
+
+    if args.rank == 0:
+        print("Validation (chairs) EPE: %f, 1px: %f, 3px: %f, 5px: %f" % (epe, px1, px3, px5))
+
+    ddp_logger.update(epe=float(epe))
+    ddp_logger.synchronize_between_processes()
+    epe = ddp_logger.meters['epe'].global_avg
+    print("Validation Chairs EPE: %f" % epe)
     return {'chairs': epe}
 
-
 @torch.no_grad()
-def validate_sintel(model):
+def validate_sintel(model, args=None):
+    """ Perform validation using the Sintel (train) split """
+    args.val_workers = 2
     model.eval()
     results = {}
     for dstype in ['clean', 'final']:
-        val_dataset = datasets.MpiSintel(split='training', dstype=dstype)
+        print("Validating on %s" % dstype)
+        val_dataset = MpiSintel(split='training', dstype=dstype)
+
+        sampler_val = torch.utils.data.DistributedSampler(
+            val_dataset, num_replicas=args.world_size, rank=args.rank, shuffle=False)
+        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, sampler=sampler_val, num_workers=args.val_workers, pin_memory=True)
+        device = torch.device(args.rank)
+
         epe_list = []
+        ddp_logger = vpd_utils.MetricLogger()
 
-        for val_id in range(len(val_dataset)):
-            image1, image2, flow_gt, _ = val_dataset[val_id]
-            image1 = image1[None].cuda()
-            image2 = image2[None].cuda()
+        for batch in tqdm(val_loader):
+            # for val_id in tqdm(range(len(val_dataset))):
+            image1, image2, flow_gt, _ = batch
+            image1 = image1.to(device)
+            image2 = image2.to(device)
 
             padder = InputPadder(image1.shape)
             image1, image2 = padder.pad(image1, image2)
 
-            flow_pre = model(image1, image2)
-
-            flow_pre = padder.unpad(flow_pre[0]).cpu()[0]
+            flow_pr = model(image1, image2)
+            flow = padder.unpad(flow_pr[0]).cpu()[0]
 
-            epe = torch.sum((flow_pre - flow_gt)**2, dim=0).sqrt()
+            epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
             epe_list.append(epe.view(-1).numpy())
 
         epe_all = np.concatenate(epe_list)
@@ -73,8 +110,24 @@ def validate_sintel(model):
         px3 = np.mean(epe_all<3)
         px5 = np.mean(epe_all<5)
 
-        print("Validation (%s) EPE: %f, 1px: %f, 3px: %f, 5px: %f" % (dstype, epe, px1, px3, px5))
-        results[dstype] = np.mean(epe_list)
+        ddp_logger.update(epe=float(epe))
+        ddp_logger.update(px1=float(px1))
+        ddp_logger.update(px3=float(px3))
+        ddp_logger.update(px5=float(px5))
+        ddp_logger.synchronize_between_processes()
+        epe = ddp_logger.meters['epe'].global_avg
+        px1 = ddp_logger.meters['px1'].global_avg
+        px3 = ddp_logger.meters['px3'].global_avg
+        px5 = ddp_logger.meters['px5'].global_avg
+
+        if args.rank == 0:
+            print("Validation (%s) EPE: %f, 1px: %f, 3px: %f, 5px: %f" % (dstype, epe, px1, px3, px5))
+        # import ipdb;ipdb.set_trace()
+        # results[dstype] = np.mean(epe_list)
+        results[dstype+"_epe"] = epe
+        results[dstype+"_1px"] = px1
+        results[dstype+"_3px"] = px3
+        results[dstype+"_5px"] = px5
 
     return results
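Both validators above shard the validation set with a DistributedSampler and then rely on vpd_utils.MetricLogger to merge per-rank results: each SmoothedValue keeps a (count, total) pair and synchronize_between_processes all-reduces it, so global_avg weights every update by its count rather than averaging per-rank means. A single-process sketch of that bookkeeping (the merge method stands in for the dist.all_reduce over a [count, total] tensor):

import torch

class Meter:
    def __init__(self):
        self.count, self.total = 0, 0.0

    def update(self, value, n=1):
        self.count += n
        self.total += value * n

    def merge(self, others):
        # Stand-in for dist.all_reduce over a [count, total] tensor.
        t = torch.tensor([[m.count, m.total] for m in [self] + list(others)],
                         dtype=torch.float64).sum(dim=0)
        self.count, self.total = int(t[0]), float(t[1])

    @property
    def global_avg(self):
        return self.total / self.count

# Two "ranks" that saw different numbers of samples still combine correctly.
a, b = Meter(), Meter()
a.update(1.0, n=3)   # rank 0: three samples with mean EPE 1.0
b.update(2.0, n=1)   # rank 1: one sample with EPE 2.0
a.merge([b])
assert abs(a.global_avg - 1.25) < 1e-9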
diff --git a/run_train.sh b/run_train.sh
index 06ab841..ebc2005 100755
--- a/run_train.sh
+++ b/run_train.sh
@@ -1,5 +1,9 @@
 mkdir -p checkpoints
-python -u train_FlowFormer.py --name chairs --stage chairs --validation chairs
-python -u train_FlowFormer.py --name things --stage things --validation sintel
-python -u train_FlowFormer.py --name sintel --stage sintel --validation sintel
-python -u train_FlowFormer.py --name kitti --stage kitti --validation kitti
\ No newline at end of file
+
+NUM_GPUS=6
+export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5
+
+python -u train_FlowFormer.py --name chairs --stage chairs --validation chairs --num_gpus=${NUM_GPUS}
+python -u train_FlowFormer.py --name things --stage things --validation sintel --num_gpus=${NUM_GPUS}
+python -u train_FlowFormer.py --name sintel --stage sintel --validation sintel --num_gpus=${NUM_GPUS}
+python -u train_FlowFormer.py --name kitti --stage kitti --validation kitti --num_gpus=${NUM_GPUS}
\ No newline at end of file
diff --git a/train_FlowFormer.py b/train_FlowFormer.py
index 366a8a1..572aeef 100644
--- a/train_FlowFormer.py
+++ b/train_FlowFormer.py
@@ -8,6 +8,8 @@
 import time
 import numpy as np
 import matplotlib.pyplot as plt
+import wandb
+from tqdm import tqdm
 
 from pathlib import Path
 import torch
@@ -31,6 +33,29 @@
 # from core.FlowFormer import FlowFormer
 from core.FlowFormer import build_flowformer
 
+# DDP training
+import torch.multiprocessing as mp
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP
+import vpd_utils
+
+
+def get_lr(optimizer):
+    for param_group in optimizer.param_groups:
+        return param_group['lr']
+
+
+def str2bool(v):
+    if isinstance(v, bool):
+        return v
+    if v.lower() in ('yes', 'true', 't', 'y', '1'):
+        return True
+    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+        return False
+    else:
+        raise argparse.ArgumentTypeError('Boolean value expected.')
+
+
 try:
     from torch.cuda.amp import GradScaler
 except:
@@ -53,30 +78,64 @@ def update(self):
 def count_parameters(model):
     return sum(p.numel() for p in model.parameters() if p.requires_grad)
 
-def train(cfg):
-    model = nn.DataParallel(build_flowformer(cfg))
-    loguru_logger.info("Parameter Count: %d" % count_parameters(model))
+def setup_ddp(gpu, args):
+    dist.init_process_group(
+        backend='nccl',
+        init_method='env://',
+        world_size=args.world_size,
+        rank=gpu)
+
+    torch.manual_seed(0)
+    torch.cuda.set_device(gpu)
+
+def train(gpu, cfg):
+
+    cfg.rank = gpu
+    vpd_utils.setup_for_distributed(cfg.rank == 0)
+    print("gpu = ", gpu)
+
+    # coordinate multiple GPUs
+    setup_ddp(gpu, cfg)
+    rng = np.random.default_rng(12345)
+
+    model = build_flowformer(cfg)
+    device = torch.device(gpu)
+    loguru_logger.info("Parameter Count: %d" % count_parameters(model))
 
     if cfg.restore_ckpt is not None:
         print("[Loading ckpt from {}]".format(cfg.restore_ckpt))
         model.load_state_dict(torch.load(cfg.restore_ckpt), strict=True)
 
-    model.cuda()
+    model.to(device)
     model.train()
 
-    train_loader = datasets.fetch_dataloader(cfg)
+    model = DDP(model, device_ids=[gpu], find_unused_parameters=False)
+
+    # train_loader = datasets.fetch_dataloader(cfg)
+    db = datasets.fetch_dataset(cfg)
+    train_sampler = torch.utils.data.distributed.DistributedSampler(
+        db, shuffle=True, num_replicas=cfg.world_size, rank=cfg.rank)
+
+    train_loader = DataLoader(db, batch_size=cfg.batch_size, sampler=train_sampler, num_workers=2)
+
     optimizer, scheduler = fetch_optimizer(model, cfg.trainer)
 
     total_steps = 0
     scaler = GradScaler(enabled=cfg.mixed_precision)
-    logger = Logger(model, scheduler, cfg)
+    if cfg.rank == 0:
+        logger = Logger(model, scheduler, cfg)
 
     add_noise = False
 
+    if cfg.log_in_wandb and cfg.rank == 0:
+        wandb.init(project='1. Improving the ConvGRU of RAFT', name="")
+
     should_keep_training = True
     while should_keep_training:
 
-        for i_batch, data_blob in enumerate(train_loader):
+        for i_batch, data_blob in enumerate(tqdm(train_loader)):
             optimizer.zero_grad()
             image1, image2, flow, valid = [x.cuda() for x in data_blob]
 
@@ -97,24 +156,35 @@
             scaler.update()
 
             metrics.update(output)
-            logger.push(metrics)
+            # logger.push(metrics)
+            curr_epoch = total_steps // len(train_loader) + 1
+            if cfg.rank == 0:
+                logger.push(metrics, model, get_lr(optimizer), loss.item(), curr_epoch)
 
+            ### change evaluate to functions
             if total_steps % cfg.val_freq == cfg.val_freq - 1:
-                PATH = '%s/%d_%s.pth' % (cfg.log_dir, total_steps+1, cfg.name)
-                # torch.save(model.state_dict(), PATH)
+                if cfg.rank == 0:
+                    print("Doing validation: ")
+                    PATH = '%s/%d_%s.pth' % (cfg.log_dir, total_steps+1, cfg.name)
+                    torch.save(model.state_dict(), PATH)
 
+                all_results = {}
                 results = {}
-                for val_dataset in cfg.validation:
-                    if val_dataset == 'chairs':
-                        results.update(evaluate.validate_chairs(model.module))
-                    elif val_dataset == 'sintel':
-                        results.update(evaluate.validate_sintel(model.module))
-                    elif val_dataset == 'kitti':
-                        results.update(evaluate.validate_kitti(model.module))
-
-                logger.write_dict(results)
+                # for val_dataset in args.validation:
+                if cfg.validation == 'chairs':
+                    results.update(evaluate.validate_chairs(model.module, args=cfg))
+                elif cfg.validation == 'sintel':
+                    results.update(evaluate.validate_sintel(model.module, args=cfg))
+                elif cfg.validation == 'kitti':
+                    results.update(evaluate.validate_kitti(model.module))
+
+                all_results[cfg.validation] = results
+                if cfg.rank == 0:
+                    logger.write_dict(results, cfg.validation, cfg)
+
+                # logger.write_dict(results)
 
                 model.train()
 
@@ -123,25 +193,37 @@
             if total_steps > cfg.trainer.num_steps:
                 should_keep_training = False
                 break
+
+    if cfg.rank == 0:
+        logger.close()
 
-    logger.close()
-    PATH = cfg.log_dir + '/final'
-    torch.save(model.state_dict(), PATH)
+        os.makedirs(cfg.log_dir, exist_ok=True)
+        PATH = cfg.log_dir + '/final'
+        torch.save(model.state_dict(), PATH)
 
-    PATH = f'checkpoints/{cfg.stage}.pth'
-    torch.save(model.state_dict(), PATH)
+        PATH = f'{cfg.log_dir}/checkpoints/{cfg.stage}.pth'
+        os.makedirs(os.path.dirname(PATH), exist_ok=True)
+        torch.save(model.state_dict(), PATH)
 
     return PATH
 
 if __name__ == '__main__':
+    # os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"  # GPU selection is handled by run_train.sh's export
     parser = argparse.ArgumentParser()
-    parser.add_argument('--name', default='flowformer', help="name your experiment")
-    parser.add_argument('--stage', help="determines which dataset to use for training")
-    parser.add_argument('--validation', type=str, nargs='+')
+
+    parser.add_argument('--name', default='chairs', help="name your experiment")
+    parser.add_argument('--stage', default='chairs', help="determines which dataset to use for training")
+    parser.add_argument('--validation', default='chairs', type=str)
     parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
+    parser.add_argument('--log_in_wandb', default="true", type=str2bool)
+    parser.add_argument('--num_gpus', type=int, default=2)
+    parser.add_argument('--port', type=str, default="29500")
+    parser.add_argument('--rank', type=int)
+
     args = parser.parse_args()
+    args.world_size = args.num_gpus
 
     if args.stage == 'chairs':
         from configs.default import get_cfg
@@ -157,6 +239,7 @@
     cfg = get_cfg()
     cfg.update(vars(args))
     process_cfg(cfg)
+    cfg.log_dir = "log_dir_negroni"
     loguru_logger.add(str(Path(cfg.log_dir) / 'log.txt'), encoding="utf8")
     loguru_logger.info(cfg)
 
@@ -166,4 +249,7 @@
     if not os.path.isdir('checkpoints'):
         os.mkdir('checkpoints')
 
-    train(cfg)
+    os.environ['MASTER_ADDR'] = 'localhost'
+    os.environ['MASTER_PORT'] = args.port
+    mp.spawn(train, nprocs=args.num_gpus, args=(cfg,))
+    # train(cfg)
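train_FlowFormer.py now launches one worker per GPU with mp.spawn; each worker receives its process index as the first argument (the gpu parameter of train) and joins the NCCL process group through the env:// rendezvous configured by MASTER_ADDR/MASTER_PORT. A minimal self-contained sketch of that launch pattern; the gloo backend and the toy _worker are stand-ins so it also runs on a CPU-only machine:

import os
import torch.distributed as dist
import torch.multiprocessing as mp

def _worker(rank, world_size):
    # mp.spawn passes the process index as the first argument, like train(gpu, cfg).
    dist.init_process_group(backend="gloo", init_method="env://",
                            world_size=world_size, rank=rank)
    print(f"rank {rank}/{world_size} joined the process group")
    dist.destroy_process_group()

if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    world_size = 2
    mp.spawn(_worker, nprocs=world_size, args=(world_size,))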
encoding="utf8") loguru_logger.info(cfg) @@ -166,4 +249,7 @@ def train(cfg): if not os.path.isdir('checkpoints'): os.mkdir('checkpoints') - train(cfg) + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = args.port + mp.spawn(train, nprocs=args.num_gpus, args=(cfg,)) + # train(cfg) diff --git a/vpd_utils.py b/vpd_utils.py new file mode 100755 index 0000000..9b33cd9 --- /dev/null +++ b/vpd_utils.py @@ -0,0 +1,589 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import os +import math +import time +from collections import defaultdict, deque +import datetime +import numpy as np +from timm.utils import get_state_dict + +from pathlib import Path + +import torch +import torch.distributed as dist +# from torch._six import inf +from torch import inf +from tensorboardX import SummaryWriter +from torchvision import transforms +import cv2 +import random +import os +import numpy as np +import torch +import matplotlib.pyplot as plt +import matplotlib.cm as cm + +def seed_everything_for_reproducibility(seed=42): + """ + For REPRODUCIBILITY + Official source: https://pytorch.org/docs/stable/notes/randomness.html#reproducibility + """ + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + # to + torch.backends.cudnn.deterministic = True + # To NOT randomly choose which algo to use for CUDNN operations like convolution,etc. + # benchmark=True will improve training performance but will loose reproducibility. + torch.backends.cudnn.benchmark = False + torch.use_deterministic_algorithms(True, warn_only=True) + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! 
+ """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if v is None: + continue + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + log_msg = [ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ] + if torch.cuda.is_available(): + log_msg.append('max mem: {memory:.0f}') + log_msg = self.delimiter.join(log_msg) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time / len(iterable))) + + +class TensorboardLogger(object): + def __init__(self, log_dir): + self.writer = SummaryWriter(logdir=log_dir) + self.step = 0 + + def set_step(self, step=None): + if step is not None: + self.step = step + else: + self.step += 1 + + def update(self, head='scalar', step=None, **kwargs): + for k, v in kwargs.items(): + if v is None: + continue + if isinstance(v, torch.Tensor): + v = 
+            assert isinstance(v, (float, int))
+            self.writer.add_scalar(head + "/" + k, v, self.step if step is None else step)
+
+    def flush(self):
+        self.writer.flush()
+
+
+class WandbLogger(object):
+    def __init__(self, args):
+        self.args = args
+
+        try:
+            import wandb
+            self._wandb = wandb
+        except ImportError:
+            raise ImportError(
+                "To use the Weights and Biases Logger please install wandb."
+                "Run `pip install wandb` to install it."
+            )
+
+        # Initialize a W&B run
+        if self._wandb.run is None:
+            self._wandb.init(
+                project=args.project,
+                config=args
+            )
+
+    def log_epoch_metrics(self, metrics, commit=True):
+        """
+        Log train/test metrics onto W&B.
+        """
+        # Log number of model parameters as W&B summary
+        self._wandb.summary['n_parameters'] = metrics.get('n_parameters', None)
+        metrics.pop('n_parameters', None)
+
+        # Log current epoch
+        self._wandb.log({'epoch': metrics.get('epoch')}, commit=False)
+        metrics.pop('epoch')
+
+        for k, v in metrics.items():
+            if 'train' in k:
+                self._wandb.log({f'Global Train/{k}': v}, commit=False)
+            elif 'test' in k:
+                self._wandb.log({f'Global Test/{k}': v}, commit=False)
+
+        self._wandb.log({})
+
+    def log_checkpoints(self):
+        output_dir = self.args.output_dir
+        model_artifact = self._wandb.Artifact(
+            self._wandb.run.id + "_model", type="model"
+        )
+
+        model_artifact.add_dir(output_dir)
+        self._wandb.log_artifact(model_artifact, aliases=["latest", "best"])
+
+    def set_steps(self):
+        # Set global training step
+        self._wandb.define_metric('Rank-0 Batch Wise/*', step_metric='Rank-0 Batch Wise/global_train_step')
+        # Set epoch-wise step
+        self._wandb.define_metric('Global Train/*', step_metric='epoch')
+        self._wandb.define_metric('Global Test/*', step_metric='epoch')
+
+
+def setup_for_distributed(is_master):
+    """
+    This function disables printing when not in master process
+    """
+    import builtins as __builtin__
+    builtin_print = __builtin__.print
+
+    def print(*args, **kwargs):
+        force = kwargs.pop('force', False)
+        if is_master or force:
+            builtin_print(*args, **kwargs)
+
+    __builtin__.print = print
+
+
+def is_dist_avail_and_initialized():
+    if not dist.is_available():
+        return False
+    if not dist.is_initialized():
+        return False
+    return True
+
+
+def get_world_size():
+    if not is_dist_avail_and_initialized():
+        return 1
+    return dist.get_world_size()
+
+
+def get_rank():
+    if not is_dist_avail_and_initialized():
+        return 0
+    return dist.get_rank()
+
+
+def is_main_process():
+    return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+    if is_main_process():
+        torch.save(*args, **kwargs)
+
+
+def init_distributed_mode(args):
+
+    if args.dist_on_itp:
+        args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
+        args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
+        args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
+        args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
+        os.environ['LOCAL_RANK'] = str(args.gpu)
+        os.environ['RANK'] = str(args.rank)
+        os.environ['WORLD_SIZE'] = str(args.world_size)
+        # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
+    elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+        args.rank = int(os.environ["RANK"])
+        args.world_size = int(os.environ['WORLD_SIZE'])
+        args.gpu = int(os.environ['LOCAL_RANK'])
+    elif 'SLURM_PROCID' in os.environ:
+        args.rank = int(os.environ['SLURM_PROCID'])
+        args.gpu = args.rank % torch.cuda.device_count()
+
+        os.environ['RANK'] = str(args.rank)
+        os.environ['LOCAL_RANK'] = str(args.gpu)
+        os.environ['WORLD_SIZE'] = str(args.world_size)
+    else:
+        print('Not using distributed mode')
+        args.distributed = False
+        return
+
+    args.distributed = True
+
+    torch.cuda.set_device(args.gpu)
+    args.dist_backend = 'nccl'
+    print('| distributed init (rank {}): {}, gpu {}'.format(
+        args.rank, args.dist_url, args.gpu), flush=True)
+    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+                                         world_size=args.world_size, rank=args.rank)
+    torch.distributed.barrier()
+    setup_for_distributed(args.rank == 0)
+
+
+def init_distributed_mode_simple(args):
+
+    args.rank = int(os.environ["RANK"])
+    args.world_size = int(os.environ['WORLD_SIZE'])
+    args.gpu = int(os.environ['LOCAL_RANK'])
+    args.dist_url = 'env://'
+
+    args.distributed = True
+
+    torch.cuda.set_device(args.gpu)
+    args.dist_backend = 'nccl'
+    print('| distributed init (rank {}): {}, gpu {}'.format(
+        args.rank, args.dist_url, args.gpu), flush=True)
+    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+                                         world_size=args.world_size, rank=args.rank)
+    torch.distributed.barrier()
+    setup_for_distributed(args.rank == 0)
+
+def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"):
+    missing_keys = []
+    unexpected_keys = []
+    error_msgs = []
+    # copy state_dict so _load_from_state_dict can modify it
+    metadata = getattr(state_dict, '_metadata', None)
+    state_dict = state_dict.copy()
+    if metadata is not None:
+        state_dict._metadata = metadata
+
+    def load(module, prefix=''):
+        local_metadata = {} if metadata is None else metadata.get(
+            prefix[:-1], {})
+        module._load_from_state_dict(
+            state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
+        for name, child in module._modules.items():
+            if child is not None:
+                load(child, prefix + name + '.')
+
+    load(model, prefix=prefix)
+
+    warn_missing_keys = []
+    ignore_missing_keys = []
+    for key in missing_keys:
+        keep_flag = True
+        for ignore_key in ignore_missing.split('|'):
+            if ignore_key in key:
+                keep_flag = False
+                break
+        if keep_flag:
+            warn_missing_keys.append(key)
+        else:
+            ignore_missing_keys.append(key)
+
+    missing_keys = warn_missing_keys
+
+    if len(missing_keys) > 0:
+        print("Weights of {} not initialized from pretrained model: {}".format(
+            model.__class__.__name__, missing_keys))
+    if len(unexpected_keys) > 0:
+        print("Weights from pretrained model not used in {}: {}".format(
+            model.__class__.__name__, unexpected_keys))
+    if len(ignore_missing_keys) > 0:
+        print("Ignored weights of {} not initialized from pretrained model: {}".format(
+            model.__class__.__name__, ignore_missing_keys))
+    if len(error_msgs) > 0:
+        print('\n'.join(error_msgs))
+
+
+class NativeScalerWithGradNormCount:
+    state_dict_key = "amp_scaler"
+
+    def __init__(self):
+        self._scaler = torch.cuda.amp.GradScaler()
+
+    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
+        self._scaler.scale(loss).backward(create_graph=create_graph)
+        if update_grad:
+            if clip_grad is not None:
+                assert parameters is not None
+                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
+                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
+            else:
+                self._scaler.unscale_(optimizer)
+                norm = get_grad_norm_(parameters)
+            self._scaler.step(optimizer)
+            self._scaler.update()
+        else:
+            norm = None
+        return norm
+
+    def state_dict(self):
+        return self._scaler.state_dict()
+
+    def load_state_dict(self, state_dict):
+        self._scaler.load_state_dict(state_dict)
+
+
+def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    parameters = [p for p in parameters if p.grad is not None]
+    norm_type = float(norm_type)
+    if len(parameters) == 0:
+        return torch.tensor(0.)
+    device = parameters[0].grad.device
+    if norm_type == inf:
+        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
+    else:
+        total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
+    return total_norm
+
+
+def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0,
+                     start_warmup_value=0, warmup_steps=-1):
+    warmup_schedule = np.array([])
+    warmup_iters = warmup_epochs * niter_per_ep
+    if warmup_steps > 0:
+        warmup_iters = warmup_steps
+    print("Set warmup steps = %d" % warmup_iters)
+    if warmup_epochs > 0:
+        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
+
+    iters = np.arange(epochs * niter_per_ep - warmup_iters)
+    schedule = np.array(
+        [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters])
+
+    schedule = np.concatenate((warmup_schedule, schedule))
+
+    assert len(schedule) == epochs * niter_per_ep
+    return schedule
+
+def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, model_ema=None):
+    output_dir = Path(args.output_dir)
+    epoch_name = str(epoch)
+    checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
+    for checkpoint_path in checkpoint_paths:
+        to_save = {
+            'model': model_without_ddp.state_dict(),
+            'optimizer': optimizer.state_dict(),
+            'epoch': epoch,
+            'scaler': loss_scaler.state_dict(),
+            'args': args,
+        }
+
+        if model_ema is not None:
+            to_save['model_ema'] = get_state_dict(model_ema)
+
+        save_on_master(to_save, checkpoint_path)
+
+    if is_main_process() and isinstance(epoch, int):
+        to_del = epoch - args.save_ckpt_num * args.save_ckpt_freq
+        old_ckpt = output_dir / ('checkpoint-%s.pth' % to_del)
+        if os.path.exists(old_ckpt):
+            os.remove(old_ckpt)
+
+
+def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None):
+    output_dir = Path(args.output_dir)
+    if args.auto_resume and len(args.resume) == 0:
+        import glob
+        all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth'))
+        latest_ckpt = -1
+        for ckpt in all_checkpoints:
+            t = ckpt.split('-')[-1].split('.')[0]
+            if t.isdigit():
+                latest_ckpt = max(int(t), latest_ckpt)
+        if latest_ckpt >= 0:
+            args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt)
+        print("Auto resume checkpoint: %s" % args.resume)
+
+    if args.resume:
+        if args.resume.startswith('https'):
+            checkpoint = torch.hub.load_state_dict_from_url(
+                args.resume, map_location='cpu', check_hash=True)
+        else:
+            checkpoint = torch.load(args.resume, map_location='cpu')
+        model_without_ddp.load_state_dict(checkpoint['model'])
+        print("Resume checkpoint %s" % args.resume)
+        if 'optimizer' in checkpoint and 'epoch' in checkpoint:
+            optimizer.load_state_dict(checkpoint['optimizer'])
+            if not isinstance(checkpoint['epoch'], str):  # does not support resuming with 'best', 'best-ema'
+                args.start_epoch = checkpoint['epoch'] + 1
+            else:
+                assert args.eval, 'Does not support resuming with checkpoint-best'
+            if hasattr(args, 'model_ema') and args.model_ema:
+                if 'model_ema' in checkpoint.keys():
+                    model_ema.ema.load_state_dict(checkpoint['model_ema'])
+                else:
+                    model_ema.ema.load_state_dict(checkpoint['model'])
+            if 'scaler' in checkpoint:
+                loss_scaler.load_state_dict(checkpoint['scaler'])
+            print("With optim & sched!")
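cosine_scheduler above precomputes one learning-rate value per training iteration: a linear warmup followed by a cosine decay from base_value down to final_value. A small usage sketch (the hyperparameter values are illustrative); inside a training loop you would assign sched[global_step] to each optimizer param group's lr:

from vpd_utils import cosine_scheduler

# 10 epochs of 100 iterations each, 1 warmup epoch, peak LR 1e-4 decaying to 1e-6.
sched = cosine_scheduler(base_value=1e-4, final_value=1e-6,
                         epochs=10, niter_per_ep=100, warmup_epochs=1)
assert len(sched) == 10 * 100
print(sched[0], sched[99], sched[-1])  # 0.0 at step 0, peak 1e-4 at step 99, ~1e-6 at the end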
+
+
+def colorize_depth(depth, cmap="magma_r", vmin=None, vmax=None):
+
+    vmin = np.min(depth) if vmin is None else vmin
+    vmax = np.max(depth) if vmax is None else vmax
+    depth = np.clip(depth, vmin, vmax)
+    depth[0, 0] = vmin
+    depth[-1, -1] = vmax
+    colormap = cm.get_cmap(cmap)
+    # First map depth to [0, 1], then scale to 0-255 below.
+    colorized_depth = colormap((depth - vmin) / (vmax - vmin))
+    # colored_depth = colormap(depth)
+    colorized_depth = (colorized_depth[:, :, :3] * 255).astype(np.uint8)
+    return colorized_depth
+
+def cosine_annealing(global_step, tot_iterations, n_cycles, max_lr, suraj_modification=False):
+    """Cosine annealing learning rate schedule with warm restarts,
+    expressed in terms of steps (iterations) rather than epochs."""
+    iters_per_cycle = math.floor(tot_iterations/n_cycles)
+    cos_inner = (math.pi * (global_step % iters_per_cycle)) / (iters_per_cycle)
+    if suraj_modification:
+        curr_cycle = global_step // iters_per_cycle
+        max_lr = max_lr * (0.5 ** curr_cycle)
+    return max_lr/2 * (math.cos(cos_inner) + 1)
+
+def visualize_garg_crop_rectangle(depth):
+    # Define the slice coordinates
+    gt_height, gt_width, _ = depth.shape
+    # garg crop
+    slice_top = int(0.40810811 * gt_height)
+    slice_bottom = int(0.99189189 * gt_height)
+    slice_left = int(0.03594771 * gt_width)
+    slice_right = int(0.96405229 * gt_width)
+
+    # Draw a rectangle on the image
+    # import ipdb;ipdb.set_trace()
+    depth = cv2.rectangle(depth, (slice_left, slice_top), (slice_right, slice_bottom), color=(250, 250, 0), thickness=2)
+    return depth
\ No newline at end of file
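The cosine_annealing helper above implements per-step cosine annealing with warm restarts; with suraj_modification=True the peak learning rate is additionally halved at every restart. A short sketch of how it would be queried each step (the total iterations, cycle count, and peak LR are illustrative):

from vpd_utils import cosine_annealing

total_iters, n_cycles, max_lr = 1000, 4, 2.5e-4  # 4 cycles of 250 steps each
for step in (0, 125, 249, 250):
    lr = cosine_annealing(step, total_iters, n_cycles, max_lr, suraj_modification=True)
    print(step, lr)
# Steps 0 and 250 both sit at a restart peak, but the second cycle's peak
# is halved: max_lr * 0.5 ** 1.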