
Add L2 hinge loss #398

Merged: longjon merged 8 commits into BVLC:dev from sguada:L2_hinge_loss on Jun 21, 2014

Conversation

@sguada (Contributor) commented May 7, 2014

This PR extends #303 by adding a C param to the loss and an L2 norm option to hinge_loss. These should help to implement L1 (L2) SVMs with L1 (L2) regularization.

@longjon could you take a look, and let me know if your examples are still working.

By default it uses C=1.0 and hinge_norm=L1, so it behaves as hinge loss.
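For reference, a sketch of the two hinge variants in standard form (notation is mine, not taken from the PR code; f(x_i) is the predicted score and y_i ∈ {−1, +1}):

$$\ell_{L1} = C \sum_i \max\big(0,\; 1 - y_i f(x_i)\big), \qquad \ell_{L2} = C \sum_i \max\big(0,\; 1 - y_i f(x_i)\big)^2.$$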

@longjon (Contributor) commented May 8, 2014

Cool. Isn't C redundant with the weight_decay multipliers? Have you tried training LeNet with L2 hinge?

Contributor:

I'm a big fan of the 1/2 on the regularizer, but I don't usually see it on the squared hinge term (cf. liblinear, for example).

Contributor Author:

Yeah, it will be great to have L1, L2 and L1/L2 regularizers in caffe.

Contributor:

Er, not sure I communicated what I meant here... What I mean is: I usually see the L2 hinge loss formulated as (lambda/2) ||w||^2 + sum_i xi_i^2, but you seem to have implemented (lambda/2) ||w||^2 + (1/2) sum_i xi_i^2 (where the regularization is implicit in the weight decay; sometimes C is used in place of lambda, and sometimes the 1/2 is omitted on the first term, although I don't like that). So, unless you are following a convention I'm not aware of, I would drop the multiplication by Dtype(0.5) and add a factor of 2 to the gradient.
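Side by side (my transcription of the formulas above):

$$\text{usual:}\quad \frac{\lambda}{2}\lVert w\rVert^2 + \sum_i \xi_i^2 \qquad\qquad \text{implemented:}\quad \frac{\lambda}{2}\lVert w\rVert^2 + \frac{1}{2}\sum_i \xi_i^2.$$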

Contributor Author:

I was trying to separate the L2 hinge loss from the L2 weight regularization. One could use L2 hinge loss without any weight regularization or with L1 weight regularization. So I was trying to separate C from lambda, although it is not clear if it is the right approach.

The idea of multiplying the loss by 0.5 instead of multiplying the gradient by 2 was to simplify the gradient computation and reduce rounding errors (this is similar to the Caffe implementation of dropout). But I would be happy to change it to a more common formulation.
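Spelled out (my notation, with ξ the hinge margin):

$$\frac{d}{d\xi}\Big(\tfrac{1}{2}\xi^2\Big) = \xi \qquad \text{vs.} \qquad \frac{d}{d\xi}\,\xi^2 = 2\xi,$$

so with the ½ in place the backward pass can simply copy the stored margins.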
I was following this implementation of L2 SVM loss:

% Build a +/-1 label matrix: Y(m, k) = +1 if sample m belongs to class k, else -1.
Y = bsxfun(@(y, ypos) 2 * (y == ypos) - 1, y, 1:K);

% Per-sample, per-class hinge margins (an M x K matrix).
margin = max(0, 1 - Y .* (X * theta));
if reg
    % Squared hinge loss plus an L2 regularizer on the weights.
    loss = (0.5 * sum(theta.^2)) + C * mean(margin.^2);
else
    loss = C * mean(margin.^2);
end
loss = sum(loss);  % sum the per-class losses into a scalar
if reg
    % Gradient: regularizer term plus the squared-hinge term
    % (M is the number of samples, i.e. the rows of X).
    g = theta - 2 * C / M * (X' * (margin .* Y));
else
    g = -2 * C / M * (X' * (margin .* Y));
end

@sguada (Contributor Author) commented May 8, 2014

I'm not sure C and weight_decay behave in the same way here, since C is not regularizing the weights, just scaling the gradients.

@longjon many thanks for your comments. I will incorporate them.

I haven't tried LeNet yet, but I should, as a sanity check.

@s9xie commented May 8, 2014

@sguada I think C and weight_decay won't play the same role in general cases, but you can always set a per-layer learning rate (as in my own implementation and experiments) to get exactly the same effect as C. But anyway, a separate C parameter would be preferable to avoid misunderstanding.

@longjon (Contributor) commented May 8, 2014

@s9xie (cc @sguada):
As you note, without the explicit tradeoff parameter C, you can always achieve whatever objective you want by setting weight decays (note that these can be per-layer) and learning rates. In fact, this is how all existing loss layers are implemented (see, e.g., EuclideanLossLayer). I think adding an explicit tradeoff parameter will actually muddle things up, because

  • objectives will not have a unique specification in terms of C and weight decay, whereas before they had a unique specification in terms of weight decay only, and
  • the C set here will not be the usual SVM C, since that will also depend on the weight decay, whereas omitting C makes the usual SVM lambda simply the weight decay.

There certainly could be improvements made to the way parameters are regularized (e.g., weight decay could be upgraded to a general function of parameters, thus supporting things like L1 regularization), but I think this should be done uniformly across layers.

@s9xie commented May 8, 2014

@longjon Totally agreed.
You are right, I didn't realize adding C would affect lambda, which again would be problematic.

My point was that when seeing a loss function with specific parameters, it would be great if one could easily write it down in a prototxt without reading the code and struggling to figure out the mess of setting lr, lambda, weight decay, C, etc. But I do agree this decoupling could easily be achieved once we have an explicit regularizer layer.

@sguada (Contributor Author) commented May 8, 2014

@longjon If C=1.0 is used with the current random weight initialization, the loss and the gradients become very big and training doesn't converge. The possible solutions would be to reduce the learning rate of the inner product layer, to initialize its weights to 0, or to reduce C to 0.1. The first two imply changing a different layer, which works fine for other losses; the last just changes the new loss.

@sguada (Contributor Author) commented May 9, 2014

@longjon I have moved the 0.5 from the loss into the 2 for the gradient, but for now I'm going to leave the C, since even if one doesn't want to regularize the weights, one can still specify it. I would let users decide how they want to use it.

@winstywang commented:

Any successful attempts on the ImageNet dataset? It seems the parameters for softmax and hinge loss differ greatly.

@sguada (Contributor Author) commented May 9, 2014

@winstywang try C=0.1 to scale down the gradients; initially the loss is too big. Or initialize the weights in the previous fully connected layer to constant 0.

@winstywang commented:

@sguada It seems scaling it down by a factor of the number of classes makes more sense. I'm still experimenting...

@sguada (Contributor Author) commented May 10, 2014

@winstywang has it started converging? The training loss should start going down after a few thousand iterations.

@winstywang commented:

@sguada Still running... only 1000 iterations so far. I manually checked the initial gradients of softmax and L2 hinge loss, and I find that this setting can at least match the gradient magnitudes. This setting works on CIFAR-10. Let's wait and see...

@sergeyk (Contributor) commented May 21, 2014

Let's get this in?

@sguada (Contributor Author) commented May 21, 2014

Yeah, I don't like to merge my own PR. But if you have tested it, then go ahead and merge it.


@shelhamer (Member) commented:

Wait, was @longjon's concern about C ever resolved?

@sergeyk (Contributor) commented May 21, 2014

Well the last few comments are kind of unresolved. Did stuff converge?


@winstywang commented:

@sergeyk I could get a reasonable result on CIFAR, but not on ImageNet. For ImageNet, the training loss does decrease, but at a significantly slower rate. BTW, I always set C to the ratio of positive to negative samples.

@sguada (Contributor Author) commented May 21, 2014

Yeah, it converges if one sets C=0.1; with C=1, the default value, it does not. So I think we should leave the C parameter and just document it.


@shelhamer (Member) commented:

Better to finish the conversation and rebase for a clean merge than just cross our fingers it's ok.

@sergeyk (Contributor) commented May 21, 2014

I would vote for no C, and @sguada should add documentation of this layer in the new convention; then we can merge.


@longjon (Contributor) commented May 21, 2014

To be clear, the gradients we compute are momentum * (last gradient) - lr * weight_decay * w - lr * d(loss)/dw. Thus, (I can't check this right now, but) scaling up the loss should be exactly the same as scaling up base_lr and scaling down weight_decay by the same factor.
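In symbols (a transcription of the update above; μ is the momentum, η the learning rate, λ the weight decay):

$$\Delta w_t = \mu\,\Delta w_{t-1} - \eta\,\lambda\,w - \eta\,\frac{\partial L}{\partial w}.$$

Scaling the loss, L → cL, turns the last term into −cη ∂L/∂w; substituting η → cη and λ → λ/c instead gives −(cη)(λ/c)w − (cη) ∂L/∂w, which is the same term by term, so the two trajectories coincide at every step.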

So, if we do include scaling on (any) loss, it is for the convenience of adjusting one parameter in the net prototxt instead of two in the solver prototxt, and comes at the cost of a bit of redundancy in the parameter space. (If we do decide we want that, I would be in favor of (@sguada's offline suggestion of) naming the scale factor something other than C, to avoid confusion with SVM C values, and I would also be in favor of an implementation that works across all loss layers (though not in this PR).)

@sguada (Contributor Author) commented May 21, 2014

@longjon agreed, I would rename C to loss_scale and rebase again, so other loss layers could incorporate that if needed.

Actually, one would have to change the weight_decay and blobs_lr in the previous layer, not in the solver prototxt, to get the same behaviour, although in that case one could still get overflows in the loss. I don't like having to change the parameters of other layers just because I changed the loss function, but it could be an option.

@longjon (Contributor) commented May 21, 2014

@sguada, is what I said above not true? I think this confusion is precisely why I don't find omitting the loss scaling parameter as onerous as you do. Isn't it true that if you scale L, you scale dL/dw for all w (i.e., the scale parameter propagates all the way down in backprop), so that it suffices to adjust exactly base_lr and weight_decay in the solver prototxt to achieve the same effect?

I agree that having to finagle layer multipliers to get the right tradeoff is rather obnoxious, but I don't see why that should be necessary here. I'll check this to be sure when I have a chance...

@sergeyk (Contributor) commented Jun 20, 2014

@longjon true, but when the learning rate is updated, the factor by which weight_decay was initially updated to scale the loss becomes incorrect, no?

@longjon (Contributor) commented Jun 20, 2014

@sergeyk, unless I've missed something, I see no problem with updating the learning rate. Learning rate schedules always have the form lr = base_lr * (some function of iteration number), so scaling base_lr should scale the learning rate by the same factor at every iteration (and weight_decay may be inversely scaled to scale only the loss).
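That is, with f the schedule (same notation as the sketch above):

$$\eta_t = \eta_{\mathrm{base}} \cdot f(t) \quad\Longrightarrow\quad \eta_{\mathrm{base}} \to c\,\eta_{\mathrm{base}} \;\text{ gives }\; \eta_t \to c\,\eta_t \;\text{ for every iteration } t.$$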

Viewed another way, weight_decay picks the amount of L2 regularization, and base_lr sets the scale of the gradients; there are no more degrees of freedom.

@sguada (Contributor Author) commented Jun 20, 2014

@longjon Agreed, one would need to keep updating the weight_decay at the same time the lr is updated to keep things aligned with the intended loss scale.
At this point there is no simple way to update weight_decay in Caffe, but maybe it would be a good idea to add one.
So if you don't oppose, I will just change C to loss_scale, rebase, and hope to merge it soon.
I still think that having a way to scale the losses would be useful, especially when one could have different losses computed at the same time, i.e. a multi-task problem.

@longjon (Contributor) commented Jun 20, 2014

@sguada: No, I am saying exactly the opposite. Setting weight_decay and base_lr once at the beginning should have exactly the same effect as scaling the loss, learning rate schedules notwithstanding.

I'll try this empirically by tomorrow and post the results here.

You're quite right that we need a way to trade off multiple losses. But I think it would be better to do that generically in another PR, since all losses will need that option. In the end, the user will have the option to not use the loss scale with a single loss, gaining unique specification of the objective, or to scale the loss directly just as you've done here.

@sguada (Contributor Author) commented Jun 21, 2014

@longjon okay, then I will remove the loss_scale from the L2_hinge_loss and leave it for a more general PR about scaling losses.

sguada added 6 commits June 20, 2014 18:26
(rebases resolving repeated conflicts in src/caffe/layers/loss_layer.cpp, src/caffe/proto/caffe.proto, include/caffe/vision_layers.hpp, src/caffe/test/test_hinge_loss_layer.cpp, and src/caffe/test/test_l2_hinge_loss_layer.cpp)
@sguada (Contributor Author) commented Jun 21, 2014

@longjon could you take a look and merge?

@longjon (Contributor) commented Jun 21, 2014

For the record, I can confirm empirically what I said above: one gets the same performance on LeNet either by setting (the old) c_param, or by scaling base_lr up and weight_decay down by the same factor, to within random error. Since we've converged for now, I won't worry about fixing the seed and checking that the weights are the same, unless someone wants me to.

I'll review the code as it is now and make sure it works for me by tomorrow.

@sguada (Contributor Author) commented Jun 21, 2014

@longjon many thanks for verifying empirically that scaling the loss has the same effect as changing the base_lr and weight_decay appropriately. This will be very useful for cases where the L2 diverges.

@longjon (Contributor) commented Jun 21, 2014

Okay, one big thing:

  • it seems not to build, due to lingering references to C_

and a few finicky things:

@sguada (Contributor Author) commented Jun 21, 2014

@longjon It seems that I ran the tests on a different branch. It should be okay now.

@longjon longjon merged commit 60ff8ba into BVLC:dev Jun 21, 2014
longjon added a commit that referenced this pull request Jun 21, 2014
@longjon (Contributor) commented Jun 21, 2014

Merged. Thanks @sguada!

There was a lint error (trailing whitespace), but I went ahead and fixed it myself to avoid another iteration. I also added "L1" to the names of the original tests to be explicit.

@sguada (Contributor Author) commented Jun 21, 2014

Thanks @longjon for fixing the last bits. :)


@HoldenCaulfieldRye commented:

Has anyone managed to obtain good performance on ImageNet using this L2 SVM as the final layer? Most applications of transfer learning in computer vision replace the top layer of AlexNet with an L2 SVM to achieve state of the art, so it would be awesome to have this working in Caffe!

@sguada in your first post you say that hinge_loss should help to implement an L2 SVM. Do you mean that one can set the hyper-parameters of a hinge_loss layer to obtain an L2 SVM, or do you mean that the code could be re-used by someone wishing to write an implementation of an L2 SVM?

@s9xie commented Aug 13, 2014

@HoldenCaulfieldRye

In your citation (as well as many others), they trained a linear SVM on dumped CNN features. This is of course not the same as replacing the softmax with a hinge loss layer, though an SVM is actually a shallow network that can be trained with Caffe.

People have discussed this (Charlie Tang's paper and R-CNN). Based on my experience on ImageNet-2012, one-vs-all SVM does not work very well for datasets with a large number of classes (say, 100 and above), and it is extremely difficult to tune the hyperparameters. This might be due to the class imbalance problem.

If anyone has tried one-vs-all hinge loss instead of softmax on ImageNet and seen a performance boost, I'd appreciate hearing about it.

@HoldenCaulfieldRye commented:

@s9xie thanks for the reply, that's taught me a few things!

These points are off-topic, but:

  1. how would one emulate a linear SVM in Caffe? I'd say maybe a single layer on dumped CNN features, with Euclidean loss, but what activation functions? If by any chance you have such a prototxt lying around, it would be great to check it out.

  2. can't the class imbalance problem be deftly tackled by renormalising the cost function by the prior?

mitmul pushed a commit to mitmul/caffe that referenced this pull request Sep 30, 2014
@jainanshul commented:

I have been looking for a prototxt file explaining how to use Caffe to train a linear SVM. Does anyone have an example I could use?
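For what it's worth, a minimal sketch of such a net in the layer format of the time: a single inner product layer over dumped features, with the hinge loss from this PR on top. All names, sources, and sizes here are illustrative assumptions, not a tested configuration:

# linear_svm.prototxt -- hypothetical sketch, untested
name: "linear_svm"
layers {
  name: "data"
  type: HDF5_DATA                 # assumes features were dumped to HDF5
  top: "feat"
  top: "label"
  hdf5_data_param { source: "train_feats.txt" batch_size: 100 }
}
layers {
  name: "fc"
  type: INNER_PRODUCT             # the linear model: score = Wx + b
  bottom: "feat"
  top: "score"
  blobs_lr: 1
  blobs_lr: 2                     # learning rates for weights and bias
  weight_decay: 1
  weight_decay: 0                 # L2-regularize the weights, not the bias
  inner_product_param { num_output: 1000 }  # one score per class
}
layers {
  name: "loss"
  type: HINGE_LOSS                # norm: L2 selects the squared hinge from this PR
  bottom: "score"
  bottom: "label"
  hinge_loss_param { norm: L2 }
}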
