diff --git a/gluonfr/loss.py b/gluonfr/loss.py index 678d40b..68583e8 100644 --- a/gluonfr/loss.py +++ b/gluonfr/loss.py @@ -23,6 +23,7 @@ import math import numpy as np from mxnet import nd, init +from mxnet.gluon.nn import HybridBlock from mxnet.gluon.loss import Loss, SoftmaxCrossEntropyLoss __all__ = ["get_loss", "SoftmaxCrossEntropyLoss", "ArcLoss", "TripletLoss", "RingLoss", @@ -794,6 +795,10 @@ def hybrid_forward(self, F, pred, label, sample_weight=None): return super().hybrid_forward(F, pred=fc, label=label, sample_weight=sample_weight) +class CircleLoss(HybridBlock): + pass + + _losses = { 'softmax': SoftmaxCrossEntropyLoss, 'arcface': ArcLoss,