
Conversation

@eqy (Contributor) commented Jun 5, 2019

Support for features described in #2651. Currently this PR focuses mainly on vision models where there is a well-known dataset (ImageNet).

@ZihengJiang self-assigned this Jun 6, 2019
@eqy mentioned this pull request Jun 6, 2019
return graph

def _evaluate(val_data, batch_fn, graph, lib, params, ctx, free_vars=[], config=[], num_classes=1000, early_stopping=32, log_iter=2):
import mxnet as mx

Contributor: We should not assume MXNet is installed.
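
A minimal way to avoid a hard MXNet dependency is to defer the import inside the helper and fail with a clear message; this is just a sketch of the idea, not code from the PR:

    def _evaluate(val_data, batch_fn, graph, lib, params, ctx, **kwargs):
        # Defer the MXNet import so the module can be imported without MXNet;
        # only the paths that really need it will raise.
        try:
            import mxnet as mx
        except ImportError as err:
            raise ImportError("MXNet is required by this evaluation helper; "
                              "install mxnet or pass plain numpy batches.") from err
        ...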

# setup evaluation metric
acc_top1 = mx.metric.Accuracy()
acc_top5 = mx.metric.TopKAccuracy(5)
val_data.reset()

Contributor: We should not assume calibration data comes from an MXNet iterator. I think a Python list of tuples is enough, as we only need around a thousand images for calibration.

@eqy (Contributor Author), Jun 10, 2019:
In this case we only need ~32 images for calibration, but I agree that passing lists or tuples of NDArrays can work.
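
A framework-agnostic calibration set could then be built as a plain Python list of (data, label) numpy tuples; a rough sketch (helper name and shapes are illustrative, not the PR's API):

    import numpy as np

    def make_calibration_set(batches, num_samples=32):
        # Collect a handful of (data, label) numpy tuples from any batch source.
        calib = []
        for data, label in batches:
            calib.append((np.asarray(data, dtype="float32"), np.asarray(label)))
            if len(calib) >= num_samples:
                break
        return calib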

if data_scale.shape[0] == 1:
assert weight_scale.shape[0] == 1
else:
assert weight_scale.shape[0] == data_scale.shape[0] or\

Contributor: This assert is wrong under group convolution; reproduce by quantizing a ResNeXt.


@eqy (Contributor Author): Thanks for the tip; can you point to an off-the-shelf definition of ResNeXt? I'm currently using the Gluon model zoo to easily import graphs and parameters; it would be great if something similar exists for ResNeXt.


Contributor:
def load_gluon_cv():
  #block = gluoncv.model_zoo.resnet.resnet18_v1(pretrained=True)
  block = gluoncv.model_zoo.resnext.resnext50_32x4d(pretrained=True)
  #block = gluoncv.model_zoo.mobilenet.mobilenet_v2_1_0(pretrained=True)
  net, params = relay.frontend.from_mxnet(block,
                                          shape={"data": (1, 3, 224, 224)})
  return net, params

@eqy (Contributor Author):
OK, I can reproduce this issue.
It seems handling general grouped convolutions is a little trickier than the pure depthwise case, because we currently view per-channel scales from the "input" perspective, but it should be doable.

Vanilla conv2d:
- During calibration, data and weight scales must be matched so that the accumulation scales match.
- During realize, we can just multiply the scale tensors (vectors) together to get the result.

Depthwise conv2d:
- During calibration, data and weight scales do not need to be matched because there is no accumulation across channels.
- During realize, we can just multiply the scales together and broadcasting will yield the correct result.

Grouped conv2d:
- During calibration, data and weight scales must be matched locally within each group, but we can only manipulate the data scales to match the accumulation scale if we wish to avoid any sophisticated co-optimization between data and weight scales.
- During realize, we cannot just broadcast the scale tensors themselves, but we might be able to do something like data_scale * weight_scale[0] -> downsample to get the output scale (see the sketch below).

I will try this to see if it works.
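
A rough numpy sketch of the grouped case described above, assuming calibration has already matched the data scales within each group (my own illustration, not the PR's code):

    import numpy as np

    def grouped_output_scales(data_scale, weight_scale, groups):
        # data_scale: (C_in,) per-input-channel scales, equal within each group
        # weight_scale: (C_out,) per-output-channel scales
        c_in, c_out = data_scale.size, weight_scale.size
        # "Downsample" the data scales to one representative value per group.
        group_data_scale = data_scale.reshape(groups, c_in // groups)[:, 0]
        # Each output channel inherits the data scale of the group it reads from.
        per_out_data_scale = np.repeat(group_data_scale, c_out // groups)
        return per_out_data_scale * weight_scale

    # Example: 8 input channels, 8 output channels, 4 groups.
    print(grouped_output_scales(np.full(8, 0.5), np.full(8, 0.25), groups=4))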

Contributor: I have a temp solution running: antinucleon@2035899

ResNeXt will drop 1% (which I am not too worried about yet, because once the full pipeline is able to run, we can tune it later).

@eqy (Contributor Author):
Ok, I have something similar implemented and I am evaluating the current accuracy. Looks like I get 76.4% for per-channel with 64 calibration samples and per-layer will be > 77% (still running).

else:
assert weight_scale.shape[0] == data_scale.shape[0] or\
weight_scale.shape[0] == 1 # depthwise, no need to unify scales
if weight_scale.shape[0] != 1:

Contributor: Redundant condition.


CHECK(dom_scale.as<ConstantNode>());
CHECK(dom_scale.as<ConstantNode>()->tensor_type()->shape.size() == 1);
} else {

Contributor: Add support for group conv.

float cur_output_scale =\
reinterpret_cast<float*>(data_scale_tensor->data->data)[i]*\
reinterpret_cast<float*>(weight_scale_tensor->data->data)[i];
CHECK(cur_output_scale == max_output_scale);

Contributor: We can't compare two float numbers for equality this way. Also, what is the meaning of this check?

@eqy (Contributor Author):
This really checks that two float values are strictly bitwise equal (all bytes equal), under the assumption that each scale is a power of two so that the multiplication is exact. You are right that this is not generally valid if we allow arbitrary non-power-of-2 scales.
The meaning of the check is to verify that the scales match in the middle of the accumulation during convolution. max_output_scale is just a crude heuristic: we use the max of the possible accumulation scales across all channels during calibration. This check tests whether the match_scales pass during calibration worked correctly. There is likely more fine-tuning potential here, but we found that max gives the most consistent results.
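
To make the power-of-two assumption concrete, here is a small self-contained check (my own sketch, not code from the PR): with power-of-two scales the product is exact in floating point, so a bitwise comparison is well defined.

    import math

    def is_power_of_two(x):
        # x == m * 2**e with 0.5 <= m < 1; a power of two has mantissa exactly 0.5
        m, _ = math.frexp(x)
        return m == 0.5

    data_scale, weight_scale = 1.0 / 16, 1.0 / 8
    assert is_power_of_two(data_scale) and is_power_of_two(weight_scale)
    # The product is exact, so comparing it bitwise against a precomputed
    # max_output_scale is meaningful; with arbitrary scales it would not be.
    assert data_scale * weight_scale == 1.0 / 128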

@antinucleon (Contributor) commented Jun 10, 2019

I hit a max_pool2D realize issue with this IR:

fn (%data: Tensor[(1, 3, 224, 224), float32]) -> Tensor[(1, 1000), float32] {
  %0 = relay.op.annotation.simulated_quantize(%data, meta[relay.Constant][0], -127f, 127f, kind=1, passthrough=1, granularity="channel", layout="NCHW")
  %1 = relay.op.annotation.simulated_quantize(meta[relay.Constant][1] /* ty=Tensor[(64, 3, 7, 7), float32] */ /* ty=Tensor[(64, 3, 7, 7), float32] */, meta[relay.Constant][2], -127f, 127f, kind=2, granularity="channel", layout="OIHW")
  %2 = nn.conv2d(%0, %1, strides=[2, 2], padding=[3, 3], channels=64, kernel_size=[7, 7])
  %3 = relay.op.annotation.simulated_quantize(meta[relay.Constant][3] /* ty=Tensor[(64, 1, 1), float32] */ /* ty=Tensor[(64, 1, 1), float32] */, meta[relay.Constant][4], -127f, 127f, kind=2, granularity="channel", layout="NCHW", op_hint="broadcastable_add")
  %4 = add(%2, %3)
  %5 = nn.relu(%4)
  %6 = relay.op.annotation.simulated_quantize(%5, meta[relay.Constant][5], -127f, 127f, kind=1, passthrough=1, granularity="channel", layout="NCHW")
  %7 = nn.max_pool2d(%6, pool_size=[3, 3], strides=[2, 2], padding=[1, 1])
  %8 = relay.op.annotation.simulated_quantize(%7, meta[relay.Constant][6], -127f, 127f, kind=1, passthrough=1, granularity="channel", layout="NCHW", op_hint="nn.max_pool2d")
  %9 = relay.op.annotation.simulated_quantize(meta[relay.Constant][7] /* ty=Tensor[(64, 1, 1), float32] */ /* ty=Tensor[(64, 1, 1), float32] */, meta[relay.Constant][8], -127f, 127f, kind=2, granularity="channel", layout="NCHW", op_hint="broadcastable_mul")
  %10 = multiply(%8, %9)

When realize %10, %8 is int8 while %9 is float32, so realize will fail.

@eqy (Contributor Author) commented Jun 10, 2019

> I hit a max_pool2D realize issue with this IR: [IR quoted above]
> When realize %10, %8 is int8 while %9 is int32, so realize will fail.

Do you know why %9 is int32 in this case? It looks like its (KIND) is weight, which we usually set to int8 quantization.

@antinucleon (Contributor) commented Jun 10, 2019

> When realize %10, %8 is int8 while %9 is int32, so realize will fail.
>
> Do you know why %9 is int32 in this case? It looks like its (KIND) is weight, which we usually set to int8 quantization.

I am also trying to debug it and have no clue yet. I double-checked: %9 is float32 but %8 is int8, so it fails in the check at quantize.cc:715 (dtype is int32):

    if (lhs->dtype == Float(32)) {
      ldata = Cast(ldata, dtype);
    } else {
      CHECK_EQ(lhs->dtype, dtype);
    }

It would be great if we can solve this together.

Reproduction steps:

  1. Download pretrained model from: https://dl.fbaipublicfiles.com/octconv/ablation/a02_resnet-26_alpha-0.500.tar
  2. untar
  3. loading code:
import mxnet as mx
from tvm import relay

def load(prefix, epoch=0):
  sym, arg_params, aux_params = \
    mx.model.load_checkpoint(prefix, epoch)
  net, params = relay.frontend.from_mxnet(symbol=sym,
                          shape={"data": (1, 3, 224, 224)},
                          arg_params=arg_params,
                          aux_params=aux_params)
  return net, params

net, params = load("./model/a02_resnet-26_alpha-0.500/checkpoint-0", 0)

@antinucleon (Contributor)

> (quoting the max_pool2d realize issue and reproduction steps above)

Thanks to @ajtulloch for finding a solution to this bug.
Problem: when a bn follows a max_pool, MulRealize fails to handle the int8 output from max_pool.

diff --git a/src/relay/pass/quantize.cc b/src/relay/pass/quantize.cc
index 3a2e54c8..4059dc3a 100644
--- a/src/relay/pass/quantize.cc
+++ b/src/relay/pass/quantize.cc
@@ -340,18 +340,9 @@ Expr MulRealize(const Call& ref_call,
     const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
     Expr ldata = lhs->data;
     Expr rdata = rhs->data;
-
     DataType dtype = cfg->dtype_activation;
-    if (lhs->dtype == Float(32)) {
-      ldata = Cast(ldata, dtype);
-    } else {
-      CHECK_EQ(lhs->dtype, dtype);
-    }
-    if (rhs->dtype == Float(32)) {
-      rdata = Cast(rdata, dtype);
-    } else {
-      CHECK_EQ(rhs->dtype, dtype);
-    }
+    ldata = Cast(ldata, dtype);
+    rdata = Cast(rdata, dtype);
 
     Expr ret = ForwardOp(ref_call, {ldata, rdata});
     Expr dom_scale = FoldConstant(Multiply(lhs->dom_scale, rhs->dom_scale));

CHECK_EQ(rhs->dtype, dtype);
}
ldata = Cast(ldata, dtype);
rdata = Cast(rdata, dtype);

Contributor: I discussed with @ajtulloch offline. We may have an overflow issue when casting fp32 directly, so maybe we should add a range check before the cast.

@eqy (Contributor Author):
What is the failure mode in this case (e.g., if there is overflow, is it considered fatal or do we replace with a clamp)?
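
For reference, the "replace with a clamp" option could look like the following Relay-level sketch (the realize pass itself is in C++, so this only illustrates the saturating behaviour; it is not the PR's implementation):

    from tvm import relay

    def saturating_cast(x, dtype="int8"):
        # Clamp to the representable range of the target dtype before casting,
        # so out-of-range fp32 values saturate instead of silently overflowing.
        bounds = {"int8": (-128, 127), "int32": (-(2 ** 31), 2 ** 31 - 1)}
        a_min, a_max = bounds[dtype]
        return relay.cast(relay.clip(x, a_min=a_min, a_max=a_max), dtype)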

@antinucleon (Contributor)

In order to run the quantized model with AVX-512, we have to set the data type to uint8. Currently the type_solver is not able to handle this: int32 = conv(uint8, int8).

A dirty way to make the pipeline run is:

diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc
index 84f72e0d..c083f1e1 100644
--- a/src/relay/pass/type_solver.cc
+++ b/src/relay/pass/type_solver.cc
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -116,6 +116,9 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
       return lhs->resolved_type;
     } else {
       Type resolved = this->VisitType(lhs->resolved_type, rhs->resolved_type);
+      if (!resolved.defined()) {
+        return rhs->resolved_type;
+      }
       CHECK(resolved.defined())
         << "Unable to unify parent types: "
         << lhs->resolved_type << " and " << rhs->resolved_type;

We have to make the type_solver work in a clean way to handle the quantized types.

@tqchen (Member) commented Jul 10, 2019

Given the importance of quantization, it would be great if we can work to merge one version of the PR before we iterate.

@eqy, can you work to get one version of the PR ready so we can merge it soon? Then @vinx13, @antinucleon, and @ZihengJiang can help iterate.

Thanks

@XiaotaoChen commented Jul 12, 2019

Hi @tqchen @eqy @ZihengJiang @vinx13, I tried this PR for channel quantization. It works on ResNet (following the code from @antinucleon above), but it fails on my own detection model, which works on master and on my old version. The error info is below. Do you have any idea about this error? Under what cases does it occur? I hope you can provide some clues. Thanks.

[11:58:41] /mnt/tscpfs/xiaotao.chen/Repositories/tvm_channel/src/relay/pass/pass_manager.cc:377: Executing module pass : InferType with opt level: 0

[11:58:41] /mnt/tscpfs/xiaotao.chen/Repositories/tvm_channel/src/relay/pass/pass_manager.cc:377: Executing module pass : InferType with opt level: 0

Traceback (most recent call last):
  File "core/test.py", line 76, in <module>
    quantize_method=pFramework.quantize_method)
  File "/mnt/tscpfs/xiaotao.chen/Repositories/tvm_usage_summary/core/TVMPipeline.py", line 68, in __init__
    granularity='layer', gpuid=self.device_id)
  File "/mnt/tscpfs/xiaotao.chen/Repositories/tvm_channel/python/tvm/relay/quantize/quantize.py", line 909, in autoquantize
    graph, lib, params = relay.build(graph, target)
  File "/mnt/tscpfs/xiaotao.chen/Repositories/tvm_channel/python/tvm/relay/build_module.py", line 196, in build
    params)
  File "/mnt/tscpfs/xiaotao.chen/Repositories/tvm_channel/python/tvm/relay/build_module.py", line 107, in build
    self._build(func, target, target_host)
  File "/mnt/tscpfs/xiaotao.chen/Repositories/tvm_channel/python/tvm/_ffi/_ctypes/function.py", line 209, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (8) /mnt/tscpfs/xiaotao.chen/Repositories/tvm_channel/build/libtvm.so(+0xbcfe6b) [0x7f6589e4ae6b]
  [bt] (7) /mnt/tscpfs/xiaotao.chen/Repositories/tvm_channel/build/libtvm.so(+0xbcebab) [0x7f6589e49bab]
  [bt] (6) /mnt/tscpfs/xiaotao.chen/Repositories/tvm_channel/build/libtvm.so(+0xbb5ebd) [0x7f6589e30ebd]
  [bt] (5) /mnt/tscpfs/xiaotao.chen/Repositories/tvm_channel/build/libtvm.so(+0xbafa3e) [0x7f6589e2aa3e]
  [bt] (4) /mnt/tscpfs/xiaotao.chen/Repositories/tvm_channel/build/libtvm.so(+0xbb4fcf) [0x7f6589e2ffcf]
  [bt] (3) /mnt/tscpfs/xiaotao.chen/Repositories/tvm_channel/build/libtvm.so(+0xbb1908) [0x7f6589e2c908]
  [bt] (2) /mnt/tscpfs/xiaotao.chen/Repositories/tvm_channel/build/libtvm.so(+0xc43e73) [0x7f6589ebee73]
  [bt] (1) /mnt/tscpfs/xiaotao.chen/Repositories/tvm_channel/build/libtvm.so(+0xbb84d0) [0x7f6589e334d0]
  [bt] (0) /mnt/tscpfs/xiaotao.chen/Repositories/tvm_channel/build/libtvm.so(+0xbb3db9) [0x7f6589e2edb9]
  File "/mnt/tscpfs/xiaotao.chen/Repositories/tvm_channel/src/relay/backend/graph_plan_memory.cc", line 87
TVMError: Check failed: tok.size() == 1U (3 vs. 1) : 

@vinx13 (Member) commented Jul 12, 2019

> (quoting the detection-model error and traceback above)

Looks like you have a nested tuple, which is not yet supported by the graph memory planner.

@XiaotaoChen commented Jul 12, 2019

> Looks like you have a nested tuple, which is not yet supported by the graph memory planner.

Thanks for your reply. The detection model has 3 outputs, grouped by mxnet.sym.Group(). Maybe the error is caused by the code below from this PR, which wraps graph.body into a tuple. The model works normally with the graph memory planner on master and on my old version. Do you have any suggestion for fixing this in my case, or how to avoid it by modifying my model?

        graph = relay.Function(graph.params,
                                relay.expr.Tuple([graph.body]+additional_outputs))
        target = 'llvm -mcpu=core-avx2'
        #target = 'cuda'
        with relay.build_config(opt_level=0):
            graph, lib, params = relay.build(graph, target)
            ctx = tvm.nd.context(target)

By the way, the goal of the code above is to collect all internal outputs for activation calibration. Building in the way below works normally in my case. Can this error be fixed by replacing the code with the version below?

    quantize_op = _op.get("relay.op.annotation.simulated_quantize")
    quantized_exprs = []

    print('calibrate graph')
    print(graph)
    def visit_func(expr):
        """Internal visit function"""
        if isinstance(expr, _expr.Call) and expr.op == quantize_op and expr.attrs.kind not in [QAnnotateKind.WEIGHT, QAnnotateKind.BIAS]:
            quantized_exprs.append(expr.args[0])

    _ir_pass.post_order_visit(graph, visit_func)
    if len(quantized_exprs) == 0:
        return []
    graph = _expr.Function(graph.params, _expr.Tuple(quantized_exprs))

    graph_json, lib, params = _build.build(graph, 'cuda')

@vinx13 (Member) commented Jul 13, 2019

@XiaotaoChen You can try flattening the nested tuple here:

graph = relay.Function(graph.params, relay.expr.Tuple([graph.body]+additional_outputs))
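
For example, flattening could splice the fields of a tuple body instead of nesting it (a sketch reusing graph and additional_outputs from the snippet above):

    from tvm import relay

    body = graph.body
    # If the model already returns a tuple (e.g. 3 detection outputs),
    # splice its fields instead of wrapping the tuple inside another tuple.
    fields = list(body.fields) if isinstance(body, relay.expr.Tuple) else [body]
    graph = relay.Function(graph.params,
                           relay.expr.Tuple(fields + additional_outputs))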

@vinx13 (Member) commented Jul 13, 2019

@traveller59 let's move discussion to https://discuss.tvm.ai

t1 = time.time()
top1, outputs = _evaluate(tr_data, tr_batch_fn, graph, lib, params, ctx, free_vars, early_stopping=64)
for i, output in enumerate(outputs):
config.append(_mse_chooser(output, granularity, metadata[i][-1]))

@XiaotaoChen, Jul 15, 2019:

metadata includes the input meta, but outputs ignores the input data, so an output may be mismatched with its metadata. Maybe we should replace _mse_chooser(output, granularity, metadata[i][-1]) with _mse_chooser(output, granularity, metadata[i+1][-1]).
Another question: should we ignore the input data for quantization? The INPUT type is included in the quantization kinds.

@eqy (Contributor Author):
We really only care that the length of profile_data matches the length of metadata.
The type also does not specify parameters such as the scale used for calibration.

@XiaotaoChen: Thanks for your reply. Maybe I haven't fully understood the quantization pass; that is what caused my doubt.

additional_outputs.append(data)
metadata.append((hint, granularity, layout))
graph = relay.Function(graph.params,
relay.expr.Tuple([graph.body]+additional_outputs))

@XiaotaoChen: Why should we add the final output (graph.body) to the graph again? I think graph.body is already included in additional_outputs, which contains all quantization ops with quantize_kind != WEIGHT. Am I misunderstanding?

@eqy (Contributor Author):
This is not used for the final result, just to profile intermediate activations during calibration.

@XiaotaoChen: I mean that additional_outputs includes all internal outputs, because profile_data collects all internal activation outputs via visit_func. graph.body only contains the final output, and additional_outputs also contains the final output. Is there a redundancy for the final output?

@eqy (Contributor Author) commented Jul 16, 2019

I started taking a look at this PR and it looks like some of the pass infra has changed since it was opened, so I'm going to incrementally see what breaks as I try to backport the new changes. The first issue seems to be annotate being removed in the new infra.
@jroesch @ZihengJiang @vinx13 @zhiics

@XiaotaoChen commented Jul 18, 2019

Hi @eqy @vinx13, I tried this PR on my detection model with weight/act granularity=layer, and removed graph.body in the calibration stage. In the realize pass there is a bug in UnifyDTypeScale, called by ConcatenateRealize, on this line: Expr dom_scale = ChooseDomScale(nptrs, min); the scalar of dom_scale is 0, which causes the later error. I then checked the inputs->dom_scale used by the concatenate op and found that lhs_expr->dom_scale = 0 while rhs_expr->dom_scale is normal. So I went back and checked all dom_scales of quantize ops in the calibration stage, but all activation scales were > 0, e.g. [4], [8]. I don't know why the dom_scale of a quantize op can change; I hope you can give some clues. Thanks.

Feedback: I found that an upsampling op comes before the concat op, and upsampling is not handled in the annotate and realize passes. My model works normally when I add an identity_rewrite for upsampling in the annotate and realize passes. But I don't know why my model can run on master TVM in quantized int8 mode, even though master does not support upsampling quantization either.


def visit_func(expr):
if isinstance(expr, _expr.Call):
if expr.op == conv2d_op:

@XiaotaoChen, Jul 22, 2019:
In my test cases (resnet50_v2, mobilenet1.0), when skip_k_conv > 0 this breaks down; maybe we should skip the first k conv ops.

@eqy (Contributor Author):
We will remove skip_k_conv, as it was superseded by the more general skip_conv_layers (see #3173 and the usage sketch below).
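
A rough usage sketch, assuming net and params were loaded as in the earlier snippets (check the qconfig options on your TVM version):

    from tvm import relay

    # skip_conv_layers takes the indices of conv layers to leave unquantized;
    # [0] keeps only the first conv in float, generalizing the old skip_k_conv.
    with relay.quantize.qconfig(skip_conv_layers=[0]):
        qnet = relay.quantize.quantize(net, params=params)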

@masahi (Member) commented Nov 28, 2019

@vinx13, what is the relation between this PR and your #3538? @tqchen said issue #2651 was closed by #3538.

@vinx13 (Member) commented Nov 28, 2019

@masahi #3538 computes scales directly to minimize KL divergence; this PR selects scales from predefined candidates by mean squared error (sketched below).
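
To illustrate the candidate-search idea, a minimal sketch of MSE-based scale selection over power-of-two candidates (my own illustration, not the PR's exact code):

    import numpy as np

    def mse_scale(activations, num_bits=8, candidates=2.0 ** np.arange(-10, 11)):
        # Pick the power-of-two scale whose simulated quantization has the
        # lowest mean squared error against the float activations.
        qmax = 2 ** (num_bits - 1) - 1
        best_scale, best_err = None, np.inf
        for s in candidates:
            dequant = np.clip(np.round(activations / s), -qmax, qmax) * s
            err = np.mean((activations - dequant) ** 2)
            if err < best_err:
                best_scale, best_err = s, err
        return best_scale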

@masahi (Member) commented Nov 28, 2019

@vinx13 thanks, good to know we have multiple methods available.

@tqchen closed this Dec 22, 2019
@tqchen (Member) commented Dec 22, 2019

Closing for now due to inactive status; it would be great if we can follow up, @vinx13.

