Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@ build
correlation.egg-info
logs
checkpoints*
log_dir_negroni
wandb
8 changes: 4 additions & 4 deletions configs/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
_CN.suffix ='arxiv2'
_CN.gamma = 0.8
_CN.max_flow = 400
_CN.batch_size = 8
_CN.sum_freq = 100
_CN.val_freq = 5000
_CN.batch_size = 12
_CN.sum_freq = 50
_CN.val_freq = 1500
_CN.image_size = [368, 496]
_CN.add_noise = True
_CN.critical_params = []
Expand Down Expand Up @@ -59,7 +59,7 @@
_CN.trainer.canonical_lr = 25e-5
_CN.trainer.adamw_decay = 1e-4
_CN.trainer.clip = 1.0
_CN.trainer.num_steps = 120000
_CN.trainer.num_steps = 50000
_CN.trainer.epsilon = 1e-8
_CN.trainer.anneal_strategy = 'linear'
def get_cfg():
Expand Down
2 changes: 1 addition & 1 deletion core/FlowFormer/LatentCostFormer/gma.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(

self.to_qk = nn.Conv2d(dim, inner_dim * 2, 1, bias=False)

self.pos_emb = RelPosEmb(max_pos_size, dim_head)
# self.pos_emb = RelPosEmb(max_pos_size, dim_head)

def forward(self, fmap):
heads, b, c, h, w = self.heads, *fmap.shape
Expand Down
1 change: 1 addition & 0 deletions core/FlowFormer/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def __init__(self, pretrained=True):
del self.svt.blocks[2]
del self.svt.pos_block[2]
del self.svt.pos_block[2]
del self.svt.norm

def forward(self, x, data=None, layer=2):
B = x.shape[0]
Expand Down
43 changes: 42 additions & 1 deletion core/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,48 @@ def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
train_dataset = KITTI(aug_params, split='training')

train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
pin_memory=False, shuffle=True, num_workers=128, drop_last=True)
pin_memory=False, shuffle=True, num_workers=4, drop_last=True)

print('Training with %d image pairs' % len(train_dataset))
return train_loader




def fetch_dataset(args, TRAIN_DS='C+T+K+S+H'):
""" Create the data loader for the corresponding trainign set """

if args.stage == 'chairs':
aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
train_dataset = FlyingChairs(aug_params, split='training')

elif args.stage == 'things':
aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
train_dataset = clean_dataset + final_dataset

elif args.stage == 'sintel':
aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
sintel_final = MpiSintel(aug_params, split='training', dstype='final')

if TRAIN_DS == 'C+T+K+S+H':
kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things

elif TRAIN_DS == 'C+T+K/S':
train_dataset = 100*sintel_clean + 100*sintel_final + things

elif args.stage == 'kitti':
aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
train_dataset = KITTI(aug_params, split='training')

# train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
# pin_memory=False, shuffle=True, num_workers=4, drop_last=True)

print('Training with %d image pairs' % len(train_dataset))
return train_dataset

40 changes: 36 additions & 4 deletions core/utils/logger.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from torch.utils.tensorboard import SummaryWriter
from loguru import logger as loguru_logger

import os
import wandb
import numpy as np

class Logger:
def __init__(self, model, scheduler, cfg):
self.model = model
Expand All @@ -10,6 +14,17 @@ def __init__(self, model, scheduler, cfg):
self.writer = None
self.cfg = cfg

self.tboard_log_dir = os.path.join(cfg.log_dir, 'tboard_logs')
os.makedirs(self.tboard_log_dir,exist_ok=True)

def _log_metric(self, tag, scalar_value, global_step, tboard_writer, cfg=None):
if cfg is None or cfg.log_in_wandb:
try:
wandb.log({tag: scalar_value}, step=global_step)
except Exception as e:
print(f"WandB log failed for tag '{tag}': {e}")
tboard_writer.add_scalar(tag=tag, scalar_value=scalar_value, global_step=global_step)

def _print_training_status(self):
metrics_data = [self.running_loss[k]/self.cfg.sum_freq for k in sorted(self.running_loss.keys())]
training_str = "[{:6d}, {}] ".format(self.total_steps+1, self.scheduler.get_last_lr())
Expand All @@ -25,10 +40,12 @@ def _print_training_status(self):
self.writer = SummaryWriter(self.cfg.log_dir)

for k in self.running_loss:
self.writer.add_scalar(k, self.running_loss[k]/self.cfg.sum_freq, self.total_steps)
# self.writer.add_scalar(k, self.running_loss[k]/self.cfg.sum_freq, self.total_steps)
self._log_metric(tag=k, scalar_value=self.running_loss[k]/self.cfg.sum_freq, global_step=self.total_steps, tboard_writer=self.writer, cfg=self.cfg)

self.running_loss[k] = 0.0

def push(self, metrics):
def push(self, metrics, model, current_lr, train_loss, curr_epoch):
self.total_steps += 1

for key in metrics:
Expand All @@ -41,12 +58,27 @@ def push(self, metrics):
self._print_training_status()
self.running_loss = {}

def write_dict(self, results):
var_sum = [var.sum().item() for var in model.parameters() if var.requires_grad]
var_cnt = len(var_sum)
var_sum = np.sum(var_sum)
var_avg = var_sum.item()/var_cnt
# by suraj
var_norm = [var.norm().item() for var in model.parameters() if var.requires_grad]

self._log_metric(tag="Training loss", scalar_value=train_loss, global_step=self.total_steps, tboard_writer=self.writer, cfg=self.cfg)
self._log_metric(tag='Learning Rate', scalar_value=current_lr, global_step=self.total_steps, tboard_writer=self.writer, cfg=self.cfg)
self._log_metric(tag='var sum average', scalar_value=var_avg, global_step=self.total_steps, tboard_writer=self.writer, cfg=self.cfg)
self._log_metric(tag='var norm average', scalar_value=np.mean(var_norm), global_step=self.total_steps, tboard_writer=self.writer, cfg=self.cfg)
self._log_metric(tag='Epoch', scalar_value=curr_epoch, global_step=self.total_steps, tboard_writer=self.writer, cfg=self.cfg)

def write_dict(self, results, val_dataset, cfg):
if self.writer is None:
self.writer = SummaryWriter()

for key in results:
self.writer.add_scalar(key, results[key], self.total_steps)
# self.writer.add_scalar(key, results[key], self.total_steps)
self._log_metric(tag=f"Val_{val_dataset}_{key}", scalar_value=results[key], global_step=self.total_steps, tboard_writer=self.writer, cfg=self.cfg)


def close(self):
self.writer.close()
Expand Down
103 changes: 78 additions & 25 deletions evaluate_FlowFormer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,57 +14,94 @@
from configs.small_things_eval import get_cfg as get_small_things_cfg
from core.utils.misc import process_cfg
import datasets
from utils import flow_viz
from utils import frame_utils
from core.utils import flow_viz
from core.utils import frame_utils

# from FlowFormer import FlowFormer
from core.FlowFormer import build_flowformer
from raft import RAFT
from core.raft import RAFT
from core.datasets import FlyingChairs, MpiSintel

from utils.utils import InputPadder, forward_interpolate

from core.utils.utils import InputPadder, forward_interpolate
from tqdm import tqdm
import vpd_utils

@torch.no_grad()
def validate_chairs(model):
def validate_chairs(model, args=None):
""" Perform evaluation on the FlyingChairs (test) split """
args.val_workers=2
model.eval()
epe_list = []

val_dataset = datasets.FlyingChairs(split='validation')
for val_id in range(len(val_dataset)):
image1, image2, flow_gt, _ = val_dataset[val_id]
image1 = image1[None].cuda()
image2 = image2[None].cuda()
epe_list = []
ddp_logger = vpd_utils.MetricLogger()

val_dataset = FlyingChairs(split='validation')
sampler_val = torch.utils.data.DistributedSampler(
val_dataset, num_replicas=args.world_size, rank=args.rank, shuffle=False)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, sampler=sampler_val, num_workers=args.val_workers, pin_memory=True)
device = torch.device(args.rank)

# for val_id in range(len(val_dataset)):
for batch in tqdm(val_loader):
image1, image2, flow_gt, _ = batch
image1 = image1.to(device)
image2 = image2.to(device)
flow_pre, _ = model(image1, image2)

epe = torch.sum((flow_pre[0].cpu() - flow_gt)**2, dim=0).sqrt()
epe_list.append(epe.view(-1).numpy())

epe = np.mean(np.concatenate(epe_list))
# epe = np.mean(np.concatenate(epe_list))
epe_all = np.concatenate(epe_list)
epe = np.mean(epe_all)
px1 = np.mean(epe_all<1)
px3 = np.mean(epe_all<3)
px5 = np.mean(epe_all<5)

if args.rank == 0:
print("Validation:- EPE: %f, 1px: %f, 3px: %f, 5px: %f" % (epe, px1, px3, px5))

ddp_logger.update(epe=float(epe))
ddp_logger.synchronize_between_processes()
epe = ddp_logger.meters['epe'].global_avg

print("Validation Chairs EPE: %f" % epe)
return {'chairs': epe}


@torch.no_grad()
def validate_sintel(model):
def validate_sintel(model, args=None):

""" Peform validation using the Sintel (train) split """
args.val_workers=2
model.eval()
results = {}
for dstype in ['clean', 'final']:
val_dataset = datasets.MpiSintel(split='training', dstype=dstype)
print("Validating on %s" % dstype)
val_dataset = MpiSintel(split='training', dstype=dstype)


sampler_val = torch.utils.data.DistributedSampler(
val_dataset, num_replicas=args.world_size, rank=args.rank, shuffle=False)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, sampler=sampler_val, num_workers=args.val_workers, pin_memory=True)
device = torch.device(args.rank)

epe_list = []
ddp_logger = vpd_utils.MetricLogger()

for batch in tqdm(val_loader):
# for val_id in tqdm(range(len(val_dataset))):
image1, image2, flow_gt, _ = batch
image1 = image1.to(device)
image2 = image2.to(device)

for val_id in range(len(val_dataset)):
image1, image2, flow_gt, _ = val_dataset[val_id]
image1 = image1[None].cuda()
image2 = image2[None].cuda()
padder = InputPadder(image1.shape)
image1, image2 = padder.pad(image1, image2)

flow_pre = model(image1, image2)

flow_pre = padder.unpad(flow_pre[0]).cpu()[0]
flow_pr = model(image1, image2)
flow = padder.unpad(flow_pr[0]).cpu()[0]

epe = torch.sum((flow_pre - flow_gt)**2, dim=0).sqrt()
epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
epe_list.append(epe.view(-1).numpy())

epe_all = np.concatenate(epe_list)
Expand All @@ -73,8 +110,24 @@ def validate_sintel(model):
px3 = np.mean(epe_all<3)
px5 = np.mean(epe_all<5)

print("Validation (%s) EPE: %f, 1px: %f, 3px: %f, 5px: %f" % (dstype, epe, px1, px3, px5))
results[dstype] = np.mean(epe_list)
ddp_logger.update(epe=float(epe))
ddp_logger.update(px1=float(px1))
ddp_logger.update(px3=float(px3))
ddp_logger.update(px5=float(px5))
ddp_logger.synchronize_between_processes()
epe = ddp_logger.meters['epe'].global_avg
px1 = ddp_logger.meters['px1'].global_avg
px3 = ddp_logger.meters['px3'].global_avg
px5 = ddp_logger.meters['px5'].global_avg

if args.rank == 0:
print("Validation (%s) EPE: %f, 1px: %f, 3px: %f, 5px: %f" % (dstype, epe, px1, px3, px5))
# import ipdb;ipdb.set_trace()
# results[dstype] = np.mean(epe_list)
results[dstype+"_epe"] = epe
results[dstype+"_1px"] = px1
results[dstype+"_3px"] = px3
results[dstype+"_5px"] = px5

return results

Expand Down
12 changes: 8 additions & 4 deletions run_train.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
mkdir -p checkpoints
python -u train_FlowFormer.py --name chairs --stage chairs --validation chairs
python -u train_FlowFormer.py --name things --stage things --validation sintel
python -u train_FlowFormer.py --name sintel --stage sintel --validation sintel
python -u train_FlowFormer.py --name kitti --stage kitti --validation kitti

NUM_GPUS=6
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5

python -u train_FlowFormer.py --name chairs --stage chairs --validation chairs --num_gpus=${NUM_GPUS}
python -u train_FlowFormer.py --name things --stage things --validation sintel --num_gpus=${NUM_GPUS}
python -u train_FlowFormer.py --name sintel --stage sintel --validation sintel --num_gpus=${NUM_GPUS}
python -u train_FlowFormer.py --name kitti --stage kitti --validation kitti --num_gpus=${NUM_GPUS}
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading