Add "elu" as a selectable activation in the network template builders, with
an optional "elu_alpha" parameter (default 1.0). The ELU alpha is written to
the layer's elu_param only when it differs from 1.0, and it is applied to both
the training net layer (lparam) and the matching deploy net layer (dlparam).

diff --git a/src/caffelib.cc b/src/caffelib.cc
index aa1757945..774aa3a84 100644
--- a/src/caffelib.cc
+++ b/src/caffelib.cc
@@ -160,6 +160,7 @@ namespace dd
   {
     std::vector<int> layers = {50};
     std::string activation = "ReLU";
+    double elu_alpha = 1.0;
     double dropout = 0.5;
     if (ad.has("layers"))
       layers = ad.get("layers").get<std::vector<int>>();
@@ -170,6 +171,12 @@ namespace dd
           activation = "ReLU";
         else if (dd_utils::iequals(activation,"prelu"))
           activation = "PReLU";
+        else if (dd_utils::iequals(activation,"elu"))
+          {
+            activation = "ELU";
+            if (ad.has("elu_alpha"))
+              elu_alpha = ad.get("elu_alpha").get<double>();
+          }
         else if (dd_utils::iequals(activation,"sigmoid"))
           activation = "Sigmoid";
         else if (dd_utils::iequals(activation,"tanh"))
@@ -308,6 +315,8 @@ namespace dd
         else lparam = net_param.add_layer();
         lparam->set_name("act"+std::to_string(l));
         lparam->set_type(activation);
+        if (activation == "ELU" && elu_alpha != 1.0)
+          lparam->mutable_elu_param()->set_alpha(elu_alpha);
         lparam->add_bottom(last_ip);
         lparam->add_top(last_ip);
         ++rl;
@@ -325,6 +334,8 @@ namespace dd
         else dlparam = deploy_net_param.add_layer();
         dlparam->set_name("act"+std::to_string(l));
         dlparam->set_type(activation);
+        if (activation == "ELU" && elu_alpha != 1.0)
+          dlparam->mutable_elu_param()->set_alpha(elu_alpha);
         dlparam->add_bottom(last_ip);
         dlparam->add_top(last_ip);
         ++drl;
@@ -478,6 +489,7 @@ namespace dd
     //- get relevant configuration elements
     std::vector layers;
     std::string activation = "ReLU";
+    double elu_alpha = 1.0;
     double dropout = 0.5;
     if (ad.has("layers"))
       try
@@ -495,6 +507,12 @@ namespace dd
           activation = "ReLU";
         else if (dd_utils::iequals(activation,"prelu"))
           activation = "PReLU";
+        else if (dd_utils::iequals(activation,"elu"))
+          {
+            activation = "ELU";
+            if (ad.has("elu_alpha"))
+              elu_alpha = ad.get("elu_alpha").get<double>();
+          }
         else if (dd_utils::iequals(activation,"sigmoid"))
           activation = "Sigmoid";
         else if (dd_utils::iequals(activation,"tanh"))
@@ -680,6 +698,8 @@ namespace dd
             else lparam = net_param.add_layer();
             lparam->set_name("act"+std::to_string(ccount));
             lparam->set_type(activation);
+            if (activation == "ELU" && elu_alpha != 1.0)
+              lparam->mutable_elu_param()->set_alpha(elu_alpha);
             lparam->add_bottom("conv"+std::to_string(ccount));
             lparam->add_top("conv"+std::to_string(ccount));
             ++rl;
@@ -697,6 +717,8 @@ namespace dd
             else dlparam = deploy_net_param.add_layer();
             dlparam->set_name("act"+std::to_string(ccount));
             dlparam->set_type(activation);
+            if (activation == "ELU" && elu_alpha != 1.0)
+              dlparam->mutable_elu_param()->set_alpha(elu_alpha);
             dlparam->add_bottom("conv"+std::to_string(ccount));
             dlparam->add_top("conv"+std::to_string(ccount));
             ++drl;
@@ -817,6 +839,8 @@ namespace dd
         else lparam = net_param.add_layer();
         lparam->set_name("act"+std::to_string(cact));
         lparam->set_type(activation);
+        if (activation == "ELU" && elu_alpha != 1.0)
+          lparam->mutable_elu_param()->set_alpha(elu_alpha);
         lparam->add_bottom(last_ip);
         lparam->add_top(last_ip);
         ++rl;
@@ -833,6 +857,8 @@ namespace dd
         else dlparam = deploy_net_param.add_layer();
         dlparam->set_name("act"+std::to_string(cact));
         dlparam->set_type(activation);
+        if (activation == "ELU" && elu_alpha != 1.0)
+          dlparam->mutable_elu_param()->set_alpha(elu_alpha);
         dlparam->add_bottom(last_ip);
         dlparam->add_top(last_ip);
         ++drl;