From 3f1a3039ff5bfcac08182422fe6a42b8a6f20d9b Mon Sep 17 00:00:00 2001 From: felix Date: Tue, 30 Jan 2024 11:33:20 +0100 Subject: [PATCH 1/4] add poly scheduler for detection --- references/detection/train_pytorch.py | 4 +++- references/detection/train_tensorflow.py | 26 +++++++++++++++++------- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/references/detection/train_pytorch.py b/references/detection/train_pytorch.py index 6f01f6e9e3..30ed6f403c 100644 --- a/references/detection/train_pytorch.py +++ b/references/detection/train_pytorch.py @@ -17,7 +17,7 @@ import psutil import torch import wandb -from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR, OneCycleLR +from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR, OneCycleLR, PolynomialLR from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from torchvision.transforms.v2 import Compose, GaussianBlur, Normalize, RandomGrayscale, RandomPhotometricDistort from tqdm.auto import tqdm @@ -335,6 +335,8 @@ def main(args): scheduler = CosineAnnealingLR(optimizer, args.epochs * len(train_loader), eta_min=args.lr / 25e4) elif args.sched == "onecycle": scheduler = OneCycleLR(optimizer, args.lr, args.epochs * len(train_loader)) + elif args.sched == "poly": + scheduler = PolynomialLR(optimizer, args.epochs * len(train_loader)) # Training monitoring current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") diff --git a/references/detection/train_tensorflow.py b/references/detection/train_tensorflow.py index a4652affe5..def87ecf75 100644 --- a/references/detection/train_tensorflow.py +++ b/references/detection/train_tensorflow.py @@ -268,14 +268,25 @@ def main(args): plot_samples(x, target) return + # Scheduler + if args.sched == "exponential": + scheduler = tf.keras.optimizers.schedules.ExponentialDecay( + args.lr, + decay_steps=args.epochs * len(train_loader), + decay_rate=1 / (25e4), # final lr as a fraction of initial lr + staircase=False, + name="ExponentialDecay", + ) + elif args.sched == "poly": + scheduler = tf.keras.optimizers.schedules.PolynomialDecay( + args.lr, + decay_steps=args.epochs * len(train_loader), + end_learning_rate=1e-7, + power=1.0, + cycle=False, + name="PolynomialDecay", + ) # Optimizer - scheduler = tf.keras.optimizers.schedules.ExponentialDecay( - args.lr, - decay_steps=args.epochs * len(train_loader), - decay_rate=1 / (25e4), # final lr as a fraction of initial lr - staircase=False, - name="ExponentialDecay", - ) optimizer = tf.keras.optimizers.Adam(learning_rate=scheduler, beta_1=0.95, beta_2=0.99, epsilon=1e-6, clipnorm=5) if args.amp: optimizer = mixed_precision.LossScaleOptimizer(optimizer) @@ -409,6 +420,7 @@ def parse_args(): action="store_true", help="metrics evaluation with straight boxes instead of polygons to save time + memory", ) + parser.add_argument("--sched", type=str, default="exponential", help="scheduler to use") parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true") parser.add_argument("--find-lr", action="store_true", help="Gridsearch the optimal LR") parser.add_argument("--early-stop", action="store_true", help="Enable early stopping") From 159b01ae08bc916188aa1be9e5bcf1bac379ac37 Mon Sep 17 00:00:00 2001 From: felix Date: Tue, 30 Jan 2024 11:38:01 +0100 Subject: [PATCH 2/4] mypy --- doctr/models/detection/differentiable_binarization/pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doctr/models/detection/differentiable_binarization/pytorch.py b/doctr/models/detection/differentiable_binarization/pytorch.py index f11408bd3d..80f3d64bed 100644 --- a/doctr/models/detection/differentiable_binarization/pytorch.py +++ b/doctr/models/detection/differentiable_binarization/pytorch.py @@ -268,7 +268,7 @@ def compute_loss( dice_map = torch.softmax(out_map, dim=1) else: # compute binary map instead - dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map))) + dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map))) # type: ignore[assignment] # Class reduced inter = (seg_mask * dice_map * seg_target).sum((0, 2, 3)) cardinality = (seg_mask * (dice_map + seg_target)).sum((0, 2, 3)) From 591641e872bd1fd5f6489bd1e85c542d3fa8494a Mon Sep 17 00:00:00 2001 From: felix Date: Fri, 2 Feb 2024 08:29:32 +0100 Subject: [PATCH 3/4] choices for scheduler --- references/detection/train_pytorch.py | 4 +++- references/detection/train_tensorflow.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/references/detection/train_pytorch.py b/references/detection/train_pytorch.py index 30ed6f403c..53e336762f 100644 --- a/references/detection/train_pytorch.py +++ b/references/detection/train_pytorch.py @@ -444,7 +444,9 @@ def parse_args(): action="store_true", help="metrics evaluation with straight boxes instead of polygons to save time + memory", ) - parser.add_argument("--sched", type=str, default="onecycle", help="scheduler to use") + parser.add_argument( + "--sched", type=str, default="poly", choices=["cosine", "onecycle", "poly"], help="scheduler to use" + ) parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true") parser.add_argument("--find-lr", action="store_true", help="Gridsearch the optimal LR") parser.add_argument("--early-stop", action="store_true", help="Enable early stopping") diff --git a/references/detection/train_tensorflow.py b/references/detection/train_tensorflow.py index def87ecf75..59d7c2761b 100644 --- a/references/detection/train_tensorflow.py +++ b/references/detection/train_tensorflow.py @@ -420,7 +420,9 @@ def parse_args(): action="store_true", help="metrics evaluation with straight boxes instead of polygons to save time + memory", ) - parser.add_argument("--sched", type=str, default="exponential", help="scheduler to use") + parser.add_argument( + "--sched", type=str, default="exponential", choices=["exponential", "poly"], help="scheduler to use" + ) parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true") parser.add_argument("--find-lr", action="store_true", help="Gridsearch the optimal LR") parser.add_argument("--early-stop", action="store_true", help="Enable early stopping") From c996c1bf3bc1999900cc5a0e8df61340500c8c07 Mon Sep 17 00:00:00 2001 From: felix Date: Fri, 2 Feb 2024 08:31:39 +0100 Subject: [PATCH 4/4] update --- references/detection/train_tensorflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/references/detection/train_tensorflow.py b/references/detection/train_tensorflow.py index 59d7c2761b..1204f42d24 100644 --- a/references/detection/train_tensorflow.py +++ b/references/detection/train_tensorflow.py @@ -421,7 +421,7 @@ def parse_args(): help="metrics evaluation with straight boxes instead of polygons to save time + memory", ) parser.add_argument( - "--sched", type=str, default="exponential", choices=["exponential", "poly"], help="scheduler to use" + "--sched", type=str, default="poly", choices=["exponential", "poly"], help="scheduler to use" ) parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true") parser.add_argument("--find-lr", action="store_true", help="Gridsearch the optimal LR")