Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/relay/qnn/op/quantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ bool QuantizeRel(const Array<Type>& types,
const auto* quantize_attrs = attrs.as<QuantizeAttrs>();
const Array<tvm::Expr> oshape = data->shape;
const DataType out_dtype = quantize_attrs->out_dtype;
CHECK(out_dtype == Int(8) || out_dtype == UInt(8))
<< "Output type should be one of [int8, unit8 ] but was " << out_dtype;
CHECK(out_dtype == Int(8) || out_dtype == UInt(8) || out_dtype == Int(32))
<< "Output type should be one of [int8, unit8, int32] but was " << out_dtype;
// assign output type
reporter->Assign(types[1], TensorTypeNode::make(oshape, out_dtype));
return true;
Expand All @@ -72,12 +72,12 @@ Expr MakeQuantize(Expr data,
Expr QuantizeLower(const Expr& input_tensor,
const QuantizeAttrs* attrs) {
const auto out_dtype = attrs->out_dtype;
const auto output_zero_point = MakeConstantScalar(Int(32), attrs->output_zero_point);
const auto output_zero_point = MakeConstantScalar(Float(32), attrs->output_zero_point);
const auto scale = MakeConstantScalar(Float(32), attrs->output_scale);
const int32_t min_val = GetQmin(out_dtype);
const int32_t max_val = GetQmax(out_dtype);
auto scale_data = Cast(Round(Divide(input_tensor, scale)), Int(32));
auto add_zero_point = Add(scale_data, output_zero_point);
auto scale_data = Divide(input_tensor, scale);
auto add_zero_point = Cast(Round(Add(scale_data, output_zero_point)), Int(32));
auto clamped_output = Clip(add_zero_point, min_val, max_val);
auto clamp_out_dtype = Cast(clamped_output, out_dtype);
return clamp_out_dtype;
Expand Down