-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfully_connect_filter.hpp
More file actions
40 lines (33 loc) · 973 Bytes
/
fully_connect_filter.hpp
File metadata and controls
40 lines (33 loc) · 973 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#ifndef FULLYCONNECT_HPP
#define FULLYCONNECT_HPP
#include <cassert>
#include <vector>

#include "block.hpp"
#include "math.h"
#include "filter.hpp"
#include "filler_filter.hpp"
namespace fool {
/// @brief Fully-connected filter built on top of Filter<Dtype>.
///
/// This header only declares the Reshape / FilterInitialize / Forward_cpu /
/// Backward_cpu overrides; their definitions live elsewhere.
///
/// @tparam Dtype element type forwarded to Filter<Dtype> and Block<Dtype>.
template<typename Dtype>
class FullyConnectFilter: public Filter<Dtype>{
public:
    /// @brief Builds the filter from the shapes of its parameter blobs.
    ///
    /// The shape list is forwarded unchanged to Filter<Dtype>. m_K is taken
    /// from blob_shapes[0][0] and m_N from blob_shapes[1][0]; m_M is left at 0.
    /// NOTE(review): presumably blob_shapes[0] describes the weights and
    /// blob_shapes[1] the bias — TODO confirm against Filter's constructor.
    ///
    /// @param blob_shapes at least two shape vectors, the first two non-empty.
    explicit FullyConnectFilter(const std::vector<std::vector<int>>& blob_shapes)
        : Filter<Dtype>(blob_shapes) {
        // The indexing below is unchecked; catch malformed shape lists in
        // debug builds instead of reading out of bounds.
        assert(blob_shapes.size() >= 2);
        assert(!blob_shapes[0].empty() && !blob_shapes[1].empty());
        m_K = blob_shapes[0][0];
        m_N = blob_shapes[1][0];
    }

    virtual void Reshape(const std::vector<Block<Dtype>*>& inputs,
                         const std::vector<Block<Dtype>*>& outputs);
    virtual void FilterInitialize();
    virtual void Forward_cpu(const std::vector<Block<Dtype>*>& inputs,
                             const std::vector<Block<Dtype>*>& outputs);
    // Note the argument order: outputs first, then inputs.
    virtual void Backward_cpu(const std::vector<Block<Dtype>*>& outputs,
                              const std::vector<Block<Dtype>*>& inputs);

    // input C*H*W
    int m_K = 0;
    // output C*H*W
    int m_N = 0;
    // batch size; not set by the constructor (presumably assigned in
    // Reshape — TODO confirm). Initialized so it is never indeterminate.
    int m_M = 0;
    // bias_term in output
    Block<Dtype> m_output_bias;
};
}
#endif // FULLYCONNECT_HPP