Merged
85 changes: 85 additions & 0 deletions src/operator/nn/dnnl/dnnl_fully_connected-inl.h
@@ -28,6 +28,8 @@

#if MXNET_USE_ONEDNN == 1

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

@@ -41,6 +43,8 @@ struct DNNLFCParam : public dmlc::Parameter<DNNLFCParam> {
bool quantized;
bool enable_float_output;
bool with_eltwise;
bool with_sum;
bool first_quantization_pass; // True for an operator created during the first quantization pass
dmlc::optional<float> min_calib_range; // min float value calculated from calibration dataset
dmlc::optional<float> max_calib_range; // max float value calculated from calibration dataset
dmlc::optional<bool> channel_wise_quantize;
@@ -54,6 +58,10 @@ struct DNNLFCParam : public dmlc::Parameter<DNNLFCParam> {
DMLC_DECLARE_FIELD(with_eltwise)
.set_default(false)
.describe("Whether there's a post with_eltwise after FullyConnected operator");
DMLC_DECLARE_FIELD(with_sum).set_default(false).describe("Add post sum");
DMLC_DECLARE_FIELD(first_quantization_pass)
.set_default(false)
.describe("True for first quantization pass");
DMLC_DECLARE_FIELD(min_calib_range)
.set_default(dmlc::optional<float>())
.describe(
@@ -76,9 +84,86 @@ struct DNNLFCFullParam {
FullyConnectedParam default_param;
DNNLFCParam dnnl_param;
DNNLPostEltwiseParam eltwise_param;
float sum_scale = {1.0f};
std::vector<float> output_scales = {0.0f};
};

static inline size_t GetInSumIndex(const DNNLFCFullParam& param) {
assert(param.dnnl_param.with_sum);
return fullc::kWeight + 1 + (param.default_param.no_bias ? 0 : 1);
}

class FCInputIndex {
public:
explicit FCInputIndex(const DNNLFCFullParam& full_param) {
auto& dnnl_param = full_param.dnnl_param;
const bool has_bias = !full_param.default_param.no_bias;
const bool quantized = dnnl_param.quantized;
const bool sum_input_quantized =
quantized && dnnl_param.with_sum && !dnnl_param.enable_float_output;
const bool channel_wise = quantized && dnnl_param.channel_wise_quantize.has_value() &&
dnnl_param.channel_wise_quantize.value();

// Calculate the position of each input in the input vector:
int index = 0;
data = index++;
weight = index++;
bias = has_bias ? index++ : 0;
sum = dnnl_param.with_sum ? index++ : 0;
num_base = index; // record the number of base inputs

data_min = quantized ? index++ : 0;
data_max = quantized ? index++ : 0;
weight_min = (quantized && !channel_wise) ? index++ : 0;
weight_max = (quantized && !channel_wise) ? index++ : 0;
bias_min = (quantized && !channel_wise && has_bias) ? index++ : 0;
bias_max = (quantized && !channel_wise && has_bias) ? index++ : 0;
sum_min = sum_input_quantized ? index++ : 0;
sum_max = sum_input_quantized ? index++ : 0;
num_total = index; // record the total number of inputs
}

// Returns true if sum input exists
bool IsSumExist() const {
return sum;
}

// Returns true if bias input exists
bool IsBiasExist() const {
return bias;
}

// Returns true if the sum input exists and is a floating-point tensor
bool IsSumInputFloat() const {
return (sum && !sum_min);
}
int GetTotal() const {
return num_total;
}
int GetBase() const {
return num_base;
}

// Indices of the individual inputs in the input vector:
int data;
int weight;
int bias;
int sum;
int data_min;
int data_max;
int weight_min;
int weight_max;
int bias_min;
int bias_max;
int sum_min;
int sum_max;

private:
int num_base; // Number of standard inputs
int num_total; // Number of total inputs: standard + additional needed for
// quantization
};

dnnl::inner_product_forward::primitive_desc GetFCFwdImpl(const DNNLFCFullParam& full_param,
const bool is_train,
const NDArray& data,
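For reference, the layout that `FCInputIndex` computes is easiest to follow in isolation. The sketch below mirrors the same indexing logic with the MXNet parameter structs replaced by plain bools (`Layout` and `ComputeLayout` are hypothetical stand-ins, not part of this PR):

```cpp
#include <iostream>

// Hypothetical standalone mirror of FCInputIndex (illustration only);
// the real class derives these flags from DNNLFCFullParam.
struct Layout {
  int data, weight, bias, sum;  // base inputs; 0 marks an absent optional input
  int data_min, data_max;       // min/max scalars appended when quantized
  int weight_min, weight_max;
  int bias_min, bias_max;
  int sum_min, sum_max;
  int num_base, num_total;
};

Layout ComputeLayout(bool has_bias, bool with_sum, bool quantized,
                     bool channel_wise, bool float_output) {
  const bool sum_quantized = quantized && with_sum && !float_output;
  Layout l{};
  int index = 0;
  l.data = index++;
  l.weight = index++;
  l.bias = has_bias ? index++ : 0;
  l.sum = with_sum ? index++ : 0;
  l.num_base = index;
  l.data_min = quantized ? index++ : 0;
  l.data_max = quantized ? index++ : 0;
  l.weight_min = (quantized && !channel_wise) ? index++ : 0;
  l.weight_max = (quantized && !channel_wise) ? index++ : 0;
  l.bias_min = (quantized && !channel_wise && has_bias) ? index++ : 0;
  l.bias_max = (quantized && !channel_wise && has_bias) ? index++ : 0;
  l.sum_min = sum_quantized ? index++ : 0;
  l.sum_max = sum_quantized ? index++ : 0;
  l.num_total = index;
  return l;
}

int main() {
  // Quantized FC with bias, a fused sum input, and tensor-wise quantization.
  const Layout l = ComputeLayout(/*has_bias=*/true, /*with_sum=*/true,
                                 /*quantized=*/true, /*channel_wise=*/false,
                                 /*float_output=*/false);
  std::cout << "sum index: " << l.sum << ", total inputs: " << l.num_total << "\n";
  // Prints: sum index: 3, total inputs: 12
}
```

For a quantized FC with bias and a fused sum input this yields sum index 3 and 12 total inputs, which agrees with `GetInSumIndex` (`kWeight + 1 + 1 = 3`).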
3 changes: 3 additions & 0 deletions src/operator/nn/dnnl/dnnl_fully_connected.cc
@@ -53,6 +53,9 @@ dnnl::inner_product_forward::primitive_desc GetFCFwdImpl(const DNNLFCFullParam&
full_param.eltwise_param.alpha,
full_param.eltwise_param.beta);
}
if (full_param.dnnl_param.with_sum) {
ops.append_sum(full_param.sum_scale);
}
attr.set_post_ops(ops);

if (full_param.dnnl_param.quantized && full_param.output_scales.size()) {
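For context, this is how the sum post-op composes with the existing eltwise post-op on a `dnnl::primitive_attr` in the oneDNN 2.x API that MXNet used at the time; a minimal sketch (the helper name and the ReLU choice are illustrative only):

```cpp
#include <dnnl.hpp>

// Sketch of building primitive attributes with fused post-ops, mirroring
// GetFCFwdImpl above (oneDNN 2.x signatures; helper name is illustrative).
dnnl::primitive_attr MakeFCAttr(bool with_eltwise, bool with_sum, float sum_scale) {
  dnnl::post_ops ops;
  if (with_eltwise) {
    // Arguments: scale, algorithm, alpha, beta; ReLU chosen only as an example.
    ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_relu, 0.0f, 0.0f);
  }
  if (with_sum) {
    // Accumulates into the destination: dst = sum_scale * dst + fc_output,
    // which fuses a residual/elementwise add into the FC primitive.
    ops.append_sum(sum_scale);
  }
  dnnl::primitive_attr attr;
  attr.set_post_ops(ops);
  return attr;
}
```

Post-ops run in the order they are appended, so the eltwise activation here precedes the sum accumulation, matching the append order in `GetFCFwdImpl`.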
11 changes: 11 additions & 0 deletions src/operator/subgraph/build_subgraph.cc
@@ -749,6 +749,17 @@ void CreateSubgraphNode(nnvm::Graph* g,
for (BiDirectedNode* dest_node : subgraph_nodes) {
sn->outputs.erase(dest_node->node);
}
}
}

// Update producers' outputs maps to match the node's current inputs
for (size_t i = 0; i < n->inputs.size(); ++i) {
auto& e = n->inputs[i];
// update input entries' source simple nodes' outputs map
nnvm::Node* node = e.node.get();
if (indexed_graph.exist(node)) {
const auto nid = indexed_graph.node_id(node);
BiDirectedNode* sn = simple_nodes[nid].get();
sn->outputs[n.get()].push_back(i);
}
}
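The loop added above keeps the producer-side bookkeeping consistent: each input entry of the subgraph node is re-registered in the producing simple node's `outputs` map, keyed by the consumer node and holding the consumer's input slot indices. A minimal sketch of that pattern with stand-in types (everything here is hypothetical; the real code uses `nnvm::IndexedGraph` and a vector of `BiDirectedNode`s indexed by node id):

```cpp
#include <cstddef>
#include <unordered_map>
#include <vector>

// Stand-ins for nnvm::Node and the build_subgraph BiDirectedNode.
struct Node {
  std::vector<Node*> inputs;  // producer of each input slot
};

struct BiDirectedNode {
  Node* node = nullptr;
  // Consumer node -> indices of the consumer's input slots this node feeds.
  std::unordered_map<Node*, std::vector<size_t>> outputs;
};

// Mirrors the added loop: after rewiring a consumer's inputs, record the
// consumer in each producer's outputs map so both directions stay in sync.
void RegisterConsumer(Node* consumer,
                      std::unordered_map<Node*, BiDirectedNode*>& simple_nodes) {
  for (size_t i = 0; i < consumer->inputs.size(); ++i) {
    Node* producer = consumer->inputs[i];
    auto it = simple_nodes.find(producer);  // analogous to indexed_graph.exist()
    if (it != simple_nodes.end()) {
      it->second->outputs[consumer].push_back(i);
    }
  }
}
```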