
HingeLossLayer #303

Merged
jeffdonahue merged 3 commits into BVLC:dev from longjon:hinge-loss-layer
Apr 26, 2014

Conversation

@longjon (Contributor) commented Apr 8, 2014

  • Currently, HingeLossLayer only provides a CPU implementation
  • Tests are not included
  • I intend eventually to rectify the above, but these commits should be usable now
  • Adding L2 regularization directly to InnerProductLayer might seem heavy-handed (although the implementation is simple)
  • AFAICT Implement regularizers #258 does not address regularization of parameters, so it will not make 86ef499 go away, but future work might provide a more general solution
  • Although this means one can use caffe to train linear SVMs, this is clumsy without the convex optimization smarts provided by SVM packages; however, caffe can do end-to-end training of a nonconvex network with a max-margin objective

@jeffdonahue (Contributor) commented:

Why L2 regularization in InnerProductLayer? Should be equivalent to weight decay, no? (Though your implementation does save an axpy if using lambda instead of weight_decay, with weight_decay set to 0, but seems potentially hazardous if we're not going to remove weight_decay altogether and do something similar in all layers with parameters imho.)
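To spell out the equivalence being discussed (a standard derivation, not taken verbatim from the thread): an L2 penalty added inside the layer,

$$
R(W) = \frac{\lambda}{2}\,\lVert W \rVert_2^2
\quad\Longrightarrow\quad
\frac{\partial R}{\partial W} = \lambda W,
$$

contributes the same term to the parameter gradient that a standard weight-decay update adds, so an in-layer lambda and the solver-level weight_decay play the same role.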

@longjon changed the title from "HingeLossLayer and L2 regularization in InnerProductLayer" to "HingeLossLayer" on Apr 8, 2014
@longjon (Contributor, Author) commented Apr 8, 2014

Good point; I wasn't thinking about that. One might conceivably want a different tradeoff parameter at the HingeLossLayer, but that could be dealt with in a different way. In fact lambda = 0 does very well on LeNet/hinge, so I'm removing that commit.

@jeffdonahue (Contributor) commented:

Inside each LayerParameter you can specify a weight_decay multiplier to the global weight_decay in SolverParameter.

HingeLossLayer looks good to me; I'd merge it once it has basic unit tests.
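To illustrate that mechanism, a hypothetical layer definition in the dev-branch prototxt format of the time might look like the sketch below. The layer/blob names and values are assumptions modeled on the stock LeNet example, not taken from this thread; the solver's global weight_decay is scaled by each per-blob multiplier.

```
# Hypothetical sketch (not from this PR): per-blob multipliers on an
# inner product layer. The global SolverParameter weight_decay is
# multiplied by each weight_decay entry below.
layers {
  name: "ip2"
  type: INNER_PRODUCT
  bottom: "ip1"
  top: "ip2"
  blobs_lr: 1      # learning-rate multiplier for the weights
  blobs_lr: 2      # learning-rate multiplier for the bias
  weight_decay: 1  # decay multiplier for the weights
  weight_decay: 0  # no decay on the bias
  inner_product_param {
    num_output: 10
  }
}
```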

Review comment from a Contributor on the diff:

caffe_copy(count, bottom_data, bottom_diff)

@s9xie commented Apr 10, 2014

@longjon can you share the prototxt for LeNet/hinge? I've got numerical overflow on gradient computations with your loss...

@longjon (Contributor, Author) commented Apr 10, 2014

My apologies, @s9xie, I accidentally clobbered the working commit with a broken one (an errant minus sign). I've put up a fixed version that gets (e.g.) 0.9921 accuracy after 10k iterations. The only change to the prototxt is to replace SOFTMAX_LOSS with HINGE_LOSS. (The default learning rate is fine.)
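Concretely, the swapped loss layer would look roughly like the sketch below; the layer and blob names ("loss", "ip2", "label") are assumed from the stock LeNet example prototxt rather than taken from this thread.

```
# Hypothetical sketch: the LeNet loss layer with SOFTMAX_LOSS replaced by
# HINGE_LOSS; layer/blob names follow the standard LeNet example.
layers {
  name: "loss"
  type: HINGE_LOSS
  bottom: "ip2"
  bottom: "label"
}
```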

@zgxiangyang commented:

I'm new to hinge loss; how can it be applied to a multi-class problem? For a single output the loss is l = max(0, 1 - y*x) with label y = -1/+1, but then what is the whole loss?

@longjon (Contributor, Author) commented Apr 13, 2014

@zgxiangyang, this layer implements one-vs-all hinge loss, so the loss for each example is the hinge loss for the binary problem of separating the true class of that example from all other classes. There is also (not implemented here) a different multiclass hinge loss, the Crammer and Singer version, that some feel is more natural (and that extends naturally to structured prediction problems); one-vs-all is, however, more common in practice.
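To make the contrast explicit (standard formulations, not quoted from the thread), with $x_{ij}$ the score of class $j$ for example $i$ and $l_i$ the true label:

$$
L_{\text{one-vs-all}} = \frac{1}{N}\sum_{i=1}^{N}\sum_{j=1}^{K}\max\bigl(0,\;1 - y_{ij}\,x_{ij}\bigr),
\qquad
y_{ij} = \begin{cases} +1, & j = l_i \\ -1, & j \neq l_i \end{cases}
$$

$$
L_{\text{Crammer-Singer}} = \frac{1}{N}\sum_{i=1}^{N}\max\Bigl(0,\;1 + \max_{j \neq l_i} x_{ij} - x_{i\,l_i}\Bigr)
$$

The first is what this layer computes (matching the commit message below); the second penalizes only the largest margin violation per example.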

@zgxiangyang commented:

@longjon thanks!

longjon added 3 commits on April 25, 2014 at 16:38:

1. This layer implements a "one-vs-all" hinge loss, (1/n) sum_ij max(0, 1 - y_ij x_ij), with bottom blob x_ij (i ranging over examples and j over classes) and y_ij = +1/-1 indicating the label. No regularization is included, since regularization is done via weight decay or using the parameters of another layer. The gradient is taken to be zero at the hinge point. This commit only provides the CPU implementation.
2. In theory, layer functions could be nonsmooth anywhere; in all cases in use so far, they are nonsmooth at either zero or at +1 and -1. In the future, it might be necessary to generalize the kink mechanism beyond this stopgap measure.
3. Based on SoftmaxWithLossLayerTest.
@longjon (Contributor, Author) commented Apr 26, 2014

Now using caffe_copy; tests are added and passing (note the change I made to how kink works), and lint passes.

Can someone confirm that it is okay to store intermediate computations in diff during Forward?

Other than that, this is ready for review.

@jeffdonahue (Contributor) commented:

Looks great, thanks Jon!

jeffdonahue added a commit that referenced this pull request Apr 26, 2014
@jeffdonahue merged commit 4a31964 into BVLC:dev on Apr 26, 2014
@sguada mentioned this pull request on May 7, 2014
@sguada (Contributor) commented May 8, 2014

@longjon I think in the long run we probably don't want to use diff to store intermediate results, but it is fine for now.

@shelhamer mentioned this pull request on May 20, 2014
mitmul pushed a commit to mitmul/caffe that referenced this pull request on Sep 30, 2014
@longjon deleted the hinge-loss-layer branch on December 30, 2014