diff --git a/include/af/autograd/Functions.hpp b/include/af/autograd/Functions.hpp index 5ca6ff7..e272f60 100644 --- a/include/af/autograd/Functions.hpp +++ b/include/af/autograd/Functions.hpp @@ -17,16 +17,30 @@ namespace af { Variable operator *(const Variable &lhs, const Variable &rhs); Variable operator -(const Variable &lhs, const Variable &rhs); Variable operator /(const Variable &lhs, const Variable &rhs); + Variable operator >(const Variable &lhs, const Variable &rhs); + Variable operator <(const Variable &lhs, const Variable &rhs); + Variable operator >=(const Variable &lhs, const Variable &rhs); + Variable operator <=(const Variable &lhs, const Variable &rhs); Variable operator +(const double &lhs, const Variable &rhs); Variable operator *(const double &lhs, const Variable &rhs); Variable operator -(const double &lhs, const Variable &rhs); Variable operator /(const double &lhs, const Variable &rhs); + Variable operator >(const double &lhs, const Variable &rhs); + Variable operator <(const double &lhs, const Variable &rhs); + Variable operator >=(const double &lhs, const Variable &rhs); + Variable operator <=(const double &lhs, const Variable &rhs); Variable operator +(const Variable &lhs, const double &rhs); Variable operator *(const Variable &lhs, const double &rhs); Variable operator -(const Variable &lhs, const double &rhs); Variable operator /(const Variable &lhs, const double &rhs); + Variable operator >(const Variable &lhs, const double &rhs); + Variable operator <(const Variable &lhs, const double &rhs); + Variable operator >=(const Variable &lhs, const double &rhs); + Variable operator <=(const Variable &lhs, const double &rhs); + + Variable operator !(const Variable &input); Variable negate(const Variable &input); Variable reciprocal(const Variable &input); @@ -41,6 +55,10 @@ namespace af { Variable max(const Variable &lhs, const double &rhs); Variable max(const double &lhs, const Variable &rhs); + Variable min(const Variable &lhs, const Variable &rhs); + Variable min(const Variable &lhs, const double &rhs); + Variable min(const double &lhs, const Variable &rhs); + Variable transpose(const Variable &input); Variable expandAs(const Variable &input, const Variable &reference); Variable reduceAs(const Variable &input, const Variable &reference); diff --git a/include/af/nn/Modules/Activations.hpp b/include/af/nn/Modules/Activations.hpp index 95beab9..2d00a90 100644 --- a/include/af/nn/Modules/Activations.hpp +++ b/include/af/nn/Modules/Activations.hpp @@ -48,5 +48,37 @@ namespace af autograd::Variable forward(const autograd::Variable &input); }; + + class PReLU : public Module + { + public: + PReLU(int size, double spread = 1.0); + PReLU(const autograd::Variable &w); + + autograd::Variable forward(const autograd::Variable &input); + }; + + class ELU : public Module + { + private: + double m_alpha; + public: + ELU(double alpha = 1.0); + + autograd::Variable forward(const autograd::Variable &input); + }; + + class ThresholdReLU : public Module + { + private: + double m_threshold; + public: + ThresholdReLU(double threshold = 1.0); + + autograd::Variable forward(const autograd::Variable &input); + }; + + + } } diff --git a/src/autograd/Functions.cpp b/src/autograd/Functions.cpp index 87e0f7f..2be0d86 100644 --- a/src/autograd/Functions.cpp +++ b/src/autograd/Functions.cpp @@ -61,12 +61,26 @@ namespace af { return Variable(result, false); } + Variable operator <(const Variable &lhs, const Variable &rhs) + { + auto result = lhs.array() < rhs.array(); + return Variable(result, false); + } + + Variable operator >=(const Variable &lhs, const Variable &rhs) + { + auto result = lhs.array() >= rhs.array(); + return Variable(result, false); + } + Variable operator <=(const Variable &lhs, const Variable &rhs) { auto result = lhs.array() <= rhs.array(); return Variable(result, false); } + + #define INSTANTIATE_OPERATOR(OP) \ Variable operator OP(const double &lhs_val, const Variable &rhs) \ { \ @@ -91,6 +105,8 @@ namespace af { INSTANTIATE_OPERATOR(*) INSTANTIATE_OPERATOR(/) INSTANTIATE_OPERATOR(>) + INSTANTIATE_OPERATOR(<) + INSTANTIATE_OPERATOR(>=) INSTANTIATE_OPERATOR(<=) #undef INSTANTIATE_OPERATOR @@ -103,14 +119,26 @@ namespace af { Variable max(const Variable &lhs, const Variable &rhs) { - auto mask = lhs > rhs; - auto result = max(lhs.array(), rhs.array()); - - auto grad_func = [](std::vector &inputs, const Variable &grad_output) { - inputs[0].addGrad( inputs[2] * grad_output); - inputs[1].addGrad(!inputs[2] * grad_output); - }; - return Variable(result, {lhs, rhs, mask}, grad_func); + auto mask = lhs > rhs; + auto result = max(lhs.array(), rhs.array()); + + auto grad_func = [](std::vector &inputs, const Variable &grad_output) { + inputs[0].addGrad( inputs[2] * grad_output); + inputs[1].addGrad(!inputs[2] * grad_output); + }; + return Variable(result, {lhs, rhs, mask}, grad_func); + } + + Variable min(const Variable &lhs, const Variable &rhs) + { + auto mask = lhs < rhs; + auto result = min(lhs.array(), rhs.array()); + + auto grad_func = [](std::vector &inputs, const Variable &grad_output) { + inputs[0].addGrad( inputs[2] * grad_output); + inputs[1].addGrad(!inputs[2] * grad_output); + }; + return Variable(result, {lhs, rhs, mask}, grad_func); } #define INSTANTIATE_FUNCTION(FN) \ @@ -134,6 +162,7 @@ namespace af { INSTANTIATE_FUNCTION(max); + INSTANTIATE_FUNCTION(min); #undef INSTANTIATE_FUNCTION diff --git a/src/nn/Modules/Activations.cpp b/src/nn/Modules/Activations.cpp index 05b9510..adedf26 100644 --- a/src/nn/Modules/Activations.cpp +++ b/src/nn/Modules/Activations.cpp @@ -9,7 +9,7 @@ #include #include - +#include namespace af { namespace nn @@ -46,5 +46,45 @@ namespace af { return max(input, m_slope * input); } + + PReLU::PReLU(int size, double spread) + { + auto w = nn::weight(size, 1, spread); + setParams({w}); + } + + PReLU::PReLU(const Variable &w) : + Module({w}) + { + } + + Variable PReLU::forward(const Variable &input) + { + auto mask = input >= 0.0; + return (input * mask) + (input * !mask * expandAs(m_parameters[0],input)); + } + + ELU::ELU(double alpha) : + m_alpha(alpha) + { + } + + Variable ELU::forward(const Variable &input) + { + auto mask = input >= 0.0; + return (mask * input) + (!mask * m_alpha * (exp(input)-1)); + } + + ThresholdReLU::ThresholdReLU(double threshold) : + m_threshold(threshold) + { + } + + Variable ThresholdReLU::forward(const Variable &input) + { + auto mask = input >= m_threshold; + return input * mask; + } + } }