Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def compute_loss(
dice_map = torch.softmax(out_map, dim=1)
else:
# compute binary map instead
dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))
dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map))) # type: ignore[assignment]
# Class reduced
inter = (seg_mask * dice_map * seg_target).sum((0, 2, 3))
cardinality = (seg_mask * (dice_map + seg_target)).sum((0, 2, 3))
Expand Down
8 changes: 6 additions & 2 deletions references/detection/train_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import psutil
import torch
import wandb
from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR, OneCycleLR
from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR, OneCycleLR, PolynomialLR
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torchvision.transforms.v2 import Compose, GaussianBlur, Normalize, RandomGrayscale, RandomPhotometricDistort
from tqdm.auto import tqdm
Expand Down Expand Up @@ -335,6 +335,8 @@ def main(args):
scheduler = CosineAnnealingLR(optimizer, args.epochs * len(train_loader), eta_min=args.lr / 25e4)
elif args.sched == "onecycle":
scheduler = OneCycleLR(optimizer, args.lr, args.epochs * len(train_loader))
elif args.sched == "poly":
scheduler = PolynomialLR(optimizer, args.epochs * len(train_loader))

# Training monitoring
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
Expand Down Expand Up @@ -442,7 +444,9 @@ def parse_args():
action="store_true",
help="metrics evaluation with straight boxes instead of polygons to save time + memory",
)
parser.add_argument("--sched", type=str, default="onecycle", help="scheduler to use")
parser.add_argument(
"--sched", type=str, default="poly", choices=["cosine", "onecycle", "poly"], help="scheduler to use"
)
parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true")
parser.add_argument("--find-lr", action="store_true", help="Gridsearch the optimal LR")
parser.add_argument("--early-stop", action="store_true", help="Enable early stopping")
Expand Down
28 changes: 21 additions & 7 deletions references/detection/train_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,14 +268,25 @@ def main(args):
plot_samples(x, target)
return

# Scheduler
if args.sched == "exponential":
scheduler = tf.keras.optimizers.schedules.ExponentialDecay(
args.lr,
decay_steps=args.epochs * len(train_loader),
decay_rate=1 / (25e4), # final lr as a fraction of initial lr
staircase=False,
name="ExponentialDecay",
)
elif args.sched == "poly":
scheduler = tf.keras.optimizers.schedules.PolynomialDecay(
args.lr,
decay_steps=args.epochs * len(train_loader),
end_learning_rate=1e-7,
power=1.0,
cycle=False,
name="PolynomialDecay",
)
# Optimizer
scheduler = tf.keras.optimizers.schedules.ExponentialDecay(
args.lr,
decay_steps=args.epochs * len(train_loader),
decay_rate=1 / (25e4), # final lr as a fraction of initial lr
staircase=False,
name="ExponentialDecay",
)
optimizer = tf.keras.optimizers.Adam(learning_rate=scheduler, beta_1=0.95, beta_2=0.99, epsilon=1e-6, clipnorm=5)
if args.amp:
optimizer = mixed_precision.LossScaleOptimizer(optimizer)
Expand Down Expand Up @@ -409,6 +420,9 @@ def parse_args():
action="store_true",
help="metrics evaluation with straight boxes instead of polygons to save time + memory",
)
parser.add_argument(
"--sched", type=str, default="poly", choices=["exponential", "poly"], help="scheduler to use"
)
parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true")
parser.add_argument("--find-lr", action="store_true", help="Gridsearch the optimal LR")
parser.add_argument("--early-stop", action="store_true", help="Enable early stopping")
Expand Down