68 changes: 59 additions & 9 deletions include/caffe/filler.hpp
@@ -127,17 +127,18 @@ class PositiveUnitballFiller : public Filler<Dtype> {
};

/**
- * @brief Fills a Blob with values @f$ x \sim U(-a, +a) @f$ where @f$ a @f$
- * is set inversely proportional to the number of incoming nodes.
+ * @brief Fills a Blob with values @f$ x \sim U(-a, +a) @f$ where @f$ a @f$ is
+ * set inversely proportional to the number of incoming nodes, outgoing
+ * nodes, or their average.
*
* A Filler based on the paper [Bengio and Glorot 2010]: Understanding
- * the difficulty of training deep feedforward neural networks, but does not
- * use the fan_out value.
+ * the difficulty of training deep feedforward neural networks.
*
- * It fills the incoming matrix by randomly sampling uniform data from
- * [-scale, scale] where scale = sqrt(3 / fan_in) and fan_in is the number
- * of input nodes. You should make sure the input blob has shape (num, a, b, c)
- * where a * b * c = fan_in.
+ * It fills the incoming matrix by randomly sampling uniform data from
+ * [-scale, scale], where scale = sqrt(3 / n) and n is the fan_in, fan_out, or
+ * their average, depending on the variance_norm option. You should make sure
+ * the input blob has shape (num, a, b, c) where a * b * c = fan_in and
+ * num * b * c = fan_out. Note that this is currently not the case for inner
+ * product layers.
[Review comment] #1970 is in so this filler is now right for InnerProduct layers too.
*
* TODO(dox): make notation in above comment consistent with rest & use LaTeX.
*/
@@ -149,14 +149,61 @@ class XavierFiller : public Filler<Dtype> {
  virtual void Fill(Blob<Dtype>* blob) {
    CHECK(blob->count());
    int fan_in = blob->count() / blob->num();
-   Dtype scale = sqrt(Dtype(3) / fan_in);
+   int fan_out = blob->count() / blob->channels();
+   Dtype n = fan_in;  // default to fan_in
+   if (this->filler_param_.variance_norm() ==
+       FillerParameter_VarianceNorm_AVERAGE) {
+     n = (fan_in + fan_out) / Dtype(2);
+   } else if (this->filler_param_.variance_norm() ==
+       FillerParameter_VarianceNorm_FAN_OUT) {
+     n = fan_out;
+   }
+   Dtype scale = sqrt(Dtype(3) / n);
    caffe_rng_uniform<Dtype>(blob->count(), -scale, scale,
        blob->mutable_cpu_data());
    CHECK_EQ(this->filler_param_.sparse(), -1)
        << "Sparsity not supported by this Filler.";
  }
};
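For reference, the scale above follows from the variance of a symmetric uniform distribution: Var[U(-a, a)] = (2a)^2 / 12 = a^2 / 3, so sampling with a = sqrt(3 / n) gives the weights variance 1 / n and standard deviation 1 / sqrt(n), which is exactly the normalization the Glorot scheme calls for.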

/**
* @brief Fills a Blob with values @f$ x \sim N(0, \sigma^2) @f$ where
* @f$ \sigma^2 @f$ is set inversely proportional to the number of incoming
* nodes, outgoing nodes, or their average.
*
* A Filler based on the paper [He, Zhang, Ren and Sun 2015]: Delving Deep into
* Rectifiers, which specifically accounts for ReLU nonlinearities.
*
* It fills the incoming matrix by randomly sampling Gaussian data with std =
* sqrt(2 / n) where n is the fan_in, fan_out, or their average, depending on
* the variance_norm option. You should make sure the input blob has shape
* (num, a, b, c) where a * b * c = fan_in and num * b * c = fan_out. Note that
* this is currently not the case for inner product layers.
*/
template <typename Dtype>
class MSRAFiller : public Filler<Dtype> {
 public:
  explicit MSRAFiller(const FillerParameter& param)
      : Filler<Dtype>(param) {}
  virtual void Fill(Blob<Dtype>* blob) {
    CHECK(blob->count());
    int fan_in = blob->count() / blob->num();
    int fan_out = blob->count() / blob->channels();
    Dtype n = fan_in;  // default to fan_in
    if (this->filler_param_.variance_norm() ==
        FillerParameter_VarianceNorm_AVERAGE) {
      n = (fan_in + fan_out) / Dtype(2);
    } else if (this->filler_param_.variance_norm() ==
        FillerParameter_VarianceNorm_FAN_OUT) {
      n = fan_out;
    }
    Dtype std = sqrt(Dtype(2) / n);
    caffe_rng_gaussian<Dtype>(blob->count(), Dtype(0), std,
        blob->mutable_cpu_data());
    CHECK_EQ(this->filler_param_.sparse(), -1)
        << "Sparsity not supported by this Filler.";
  }
};
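As a quick sanity check of the statistics this filler targets, here is a minimal standalone sketch, assuming nothing beyond the C++ standard library; the shape constants mirror the (1000, 2, 4, 5) blob used by the tests below:

```cpp
#include <cmath>
#include <cstdio>
#include <random>

int main() {
  // Shape mirrors the test blob below: (num=1000, channels=2, height=4, width=5).
  const int num = 1000, channels = 2, height = 4, width = 5;
  const int count = num * channels * height * width;
  const double fan_in = channels * height * width;    // 2*4*5 = 40
  const double target_std = std::sqrt(2.0 / fan_in);  // MSRA std with FAN_IN

  std::mt19937 rng(1);
  std::normal_distribution<double> gauss(0.0, target_std);
  double mean = 0.0, ex2 = 0.0;
  for (int i = 0; i < count; ++i) {
    const double x = gauss(rng);
    mean += x;
    ex2 += x * x;
  }
  mean /= count;
  ex2 /= count;
  // Sample std should land near the target, well within the tests' 0.1 slack.
  std::printf("sample std = %.4f, target std = %.4f\n",
              std::sqrt(ex2 - mean * mean), target_std);
  return 0;
}
```

With fan_in = 40 the target std is sqrt(2 / 40) ≈ 0.224, comfortably inside the 0.1 tolerance the tests below allow.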

/**
* @brief Get a specific filler from the specification given in FillerParameter.
@@ -177,6 +225,8 @@ Filler<Dtype>* GetFiller(const FillerParameter& param) {
    return new UniformFiller<Dtype>(param);
  } else if (type == "xavier") {
    return new XavierFiller<Dtype>(param);
  } else if (type == "msra") {
    return new MSRAFiller<Dtype>(param);
  } else {
    CHECK(false) << "Unknown filler name: " << param.type();
  }
8 changes: 8 additions & 0 deletions src/caffe/proto/caffe.proto
@@ -41,6 +41,14 @@ message FillerParameter {
  // The expected number of non-zero output weights for a given input in
  // Gaussian filler -- the default -1 means don't perform sparsification.
  optional int32 sparse = 7 [default = -1];
  // Normalize the filler variance by fan_in, fan_out, or their average.
  // Applies to 'xavier' and 'msra' fillers.
  enum VarianceNorm {
    FAN_IN = 0;
    FAN_OUT = 1;
    AVERAGE = 2;
  }
  optional VarianceNorm variance_norm = 8 [default = FAN_IN];
}

message NetParameter {
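Tying the two files together, here is a sketch of how the new field and the "msra" type would be used from C++ once caffe.proto is regenerated; `FillExample` and the blob shape are illustrative, not part of the PR:

```cpp
#include "caffe/blob.hpp"
#include "caffe/filler.hpp"

// Illustrative only: fill a weight blob with MSRA values, normalizing the
// variance by the average of fan_in and fan_out.
void FillExample() {
  caffe::FillerParameter param;
  param.set_type("msra");
  param.set_variance_norm(caffe::FillerParameter_VarianceNorm_AVERAGE);

  caffe::Blob<float> weights(1000, 2, 4, 5);  // (num, channels, height, width)
  caffe::Filler<float>* filler = caffe::GetFiller<float>(param);
  filler->Fill(&weights);
  delete filler;
}
```

In a prototxt net definition the same choice reads `weight_filler { type: "msra" variance_norm: AVERAGE }`.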
98 changes: 98 additions & 0 deletions src/caffe/test/test_filler.cpp
@@ -142,4 +142,102 @@ TYPED_TEST(GaussianFillerTest, TestFill) {
  EXPECT_LE(var, target_var * 5.);
}

template <typename Dtype>
class XavierFillerTest : public ::testing::Test {
 protected:
  XavierFillerTest()
      : blob_(new Blob<Dtype>(1000, 2, 4, 5)),
        filler_param_() {
  }
  virtual void test_params(FillerParameter_VarianceNorm variance_norm,
      Dtype n) {
    this->filler_param_.set_variance_norm(variance_norm);
    this->filler_.reset(new XavierFiller<Dtype>(this->filler_param_));
    this->filler_->Fill(blob_);
    EXPECT_TRUE(this->blob_);
    const int count = this->blob_->count();
    const Dtype* data = this->blob_->cpu_data();
    Dtype mean = 0.;
    Dtype ex2 = 0.;
    for (int i = 0; i < count; ++i) {
      mean += data[i];
      ex2 += data[i] * data[i];
    }
    mean /= count;
    ex2 /= count;
    Dtype std = sqrt(ex2 - mean*mean);
    // U(-a, a) with a = sqrt(3 / n) has variance 1 / n, so std is 1 / sqrt(n).
    Dtype target_std = sqrt(1.0 / n);
    EXPECT_NEAR(mean, 0.0, 0.1);
    EXPECT_NEAR(std, target_std, 0.1);
  }
  virtual ~XavierFillerTest() { delete blob_; }
  Blob<Dtype>* const blob_;
  FillerParameter filler_param_;
  shared_ptr<XavierFiller<Dtype> > filler_;
};

TYPED_TEST_CASE(XavierFillerTest, TestDtypes);

TYPED_TEST(XavierFillerTest, TestFillFanIn) {
  TypeParam n = 2*4*5;
  this->test_params(FillerParameter_VarianceNorm_FAN_IN, n);
}
TYPED_TEST(XavierFillerTest, TestFillFanOut) {
  TypeParam n = 1000*4*5;
  this->test_params(FillerParameter_VarianceNorm_FAN_OUT, n);
}
TYPED_TEST(XavierFillerTest, TestFillAverage) {
  TypeParam n = (2*4*5 + 1000*4*5) / 2.0;
  this->test_params(FillerParameter_VarianceNorm_AVERAGE, n);
}
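The n values above (shared by the MSRA tests below) follow directly from the (1000, 2, 4, 5) test blob: fan_in = 2*4*5 = 40, fan_out = 1000*4*5 = 20000, and their average is (40 + 20000) / 2 = 10020.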

template <typename Dtype>
class MSRAFillerTest : public ::testing::Test {
 protected:
  MSRAFillerTest()
      : blob_(new Blob<Dtype>(1000, 2, 4, 5)),
        filler_param_() {
  }
  virtual void test_params(FillerParameter_VarianceNorm variance_norm,
      Dtype n) {
    this->filler_param_.set_variance_norm(variance_norm);
    this->filler_.reset(new MSRAFiller<Dtype>(this->filler_param_));
    this->filler_->Fill(blob_);
    EXPECT_TRUE(this->blob_);
    const int count = this->blob_->count();
    const Dtype* data = this->blob_->cpu_data();
    Dtype mean = 0.;
    Dtype ex2 = 0.;
    for (int i = 0; i < count; ++i) {
      mean += data[i];
      ex2 += data[i] * data[i];
    }
    mean /= count;
    ex2 /= count;
    Dtype std = sqrt(ex2 - mean*mean);
    Dtype target_std = sqrt(2.0 / n);
    EXPECT_NEAR(mean, 0.0, 0.1);
    EXPECT_NEAR(std, target_std, 0.1);
  }
  virtual ~MSRAFillerTest() { delete blob_; }
  Blob<Dtype>* const blob_;
  FillerParameter filler_param_;
  shared_ptr<MSRAFiller<Dtype> > filler_;
};

TYPED_TEST_CASE(MSRAFillerTest, TestDtypes);

TYPED_TEST(MSRAFillerTest, TestFillFanIn) {
  TypeParam n = 2*4*5;
  this->test_params(FillerParameter_VarianceNorm_FAN_IN, n);
}
TYPED_TEST(MSRAFillerTest, TestFillFanOut) {
  TypeParam n = 1000*4*5;
  this->test_params(FillerParameter_VarianceNorm_FAN_OUT, n);
}
TYPED_TEST(MSRAFillerTest, TestFillAverage) {
  TypeParam n = (2*4*5 + 1000*4*5) / 2.0;
  this->test_params(FillerParameter_VarianceNorm_AVERAGE, n);
}
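Assuming a standard Makefile build of Caffe, `make runtest` runs the whole suite; the compiled gtest binary also accepts the standard `--gtest_filter` flag (e.g. `--gtest_filter='*FillerTest*'`) to run just these cases.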

} // namespace caffe