Closed
Description
I'm having difficulty getting floating-point calibration working for quantization. The test `test_pass_auto_quantize.py` (conv2d quantization) works fine, but my own test for dense-layer quantization calibration does not. Inside `_calibrate.py` I print `graph` right after the line `graph, lib, params = _build_module.build(func, target=target)`, and regardless of the input I only see the following output:
"nodes": [],
"arg_nodes": [],
"heads": [],
"attrs": {
"dltype": [
"list_str",
[]
],
"shape": [
"list_shape",
[]
],
"storage_id": [
"list_int",
[]
]
},
"node_row_ptr": [0]
}
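The JSON above has no nodes and no output heads, so the compiled graph computes nothing. One way to catch this before creating the runtime is to parse the graph JSON string returned by the build and check those two fields. A minimal sketch using only the standard library; `graph_is_empty` is a hypothetical helper, not a TVM API:

```python
import json


def graph_is_empty(graph_json: str) -> bool:
    """Return True if a graph JSON string has no nodes or no output heads."""
    graph = json.loads(graph_json)
    # "nodes" lists every input/operator node; "heads" lists the graph outputs.
    return not graph.get("nodes") or not graph.get("heads")


empty = '{"nodes": [], "arg_nodes": [], "heads": [], "attrs": {}, "node_row_ptr": [0]}'
print(graph_is_empty(empty))  # True
```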
`func` still appears to contain the correct main function for my network, but a second, empty function has been added to the module. It looks like this:
```
v0.0.4
fn () {
  ()
}
v0.0.4
def @main(...) {
  ...
  ... # my main function
  ...
}
```
This happens after the line `func = _quantize.CreateStatsCollector(func)` in `collect_stats`:
```python
def collect_stats(mod, dataset):
    """Given an annotated graph, create a profile graph to collect profile data from the
    calibration dataset. This pass collects simulated_quantize op input into a tuple.
    Simulated_quantize ops are rewritten to identity mode. The tuple is the output of the profile
    graph.

    Parameters
    ----------
    mod: Module
        The simulation graph after annotation.

    Returns
    -------
    ret: list of ndarray
        List of output data of each layer
    """
    logging.info("collecting statistics for calibration...")
    func = mod['main']
    func = _quantize.CreateStatsCollector(func)
    if tvm.target.current_target():
        target = tvm.target.current_target()
        ctx = tvm.context(target.target_name)
    else:
        target = 'llvm'
        ctx = tvm.context(target)
    with _transform.build_config(opt_level=3):
        graph, lib, params = _build_module.build(func, target=target)
    outputs = []
    runtime = graph_runtime.create(graph, lib, ctx)
    runtime.set_input(**params)
    num_outputs = runtime.get_num_outputs()
    outputs = [[] for i in range(num_outputs)]
    for batch in dataset:
        runtime.set_input(**batch)
        runtime.run()
        for i in range(num_outputs):
            output = runtime.get_output(i).asnumpy()
            outputs[i].append(output)
    for i in range(num_outputs):
        outputs[i] = np.concatenate(outputs[i]).reshape(-1)
    return outputs
```
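The docstring above describes the stats-collector step: it records the input of every `simulated_quantize` op, rewrites those ops to identity, and returns the recorded inputs as a tuple. The following is a toy model of that rewrite in plain Python. The nested-tuple expression encoding and the `toy_stats_collector` name are hypothetical; this is not TVM's Relay IR or its C++ `CreateStatsCollector` implementation. Under this model, a graph containing no `simulated_quantize` ops collects nothing and yields an empty tuple, which would be consistent with the empty `fn () { () }` function shown earlier.

```python
from typing import List, Tuple, Union

# Hypothetical toy IR: a variable is a str; a call is (op_name, [args]).
Expr = Union[str, Tuple[str, list]]


def toy_stats_collector(expr: Expr) -> Tuple[Expr, Tuple[Expr, ...]]:
    """Rewrite simulated_quantize calls to identity and record their inputs.

    Returns the rewritten expression and the tuple of recorded inputs,
    which plays the role of the profile graph's output.
    """
    profile_data: List[Expr] = []

    def visit(e: Expr) -> Expr:
        if isinstance(e, str):
            return e  # variables are left unchanged
        op, args = e
        new_args = [visit(a) for a in args]
        if op == "simulated_quantize":
            # Identity mode: record the input and pass it through unchanged.
            profile_data.append(new_args[0])
            return new_args[0]
        return (op, new_args)

    body = visit(expr)
    return body, tuple(profile_data)


# Annotated conv2d: both inputs go through simulated_quantize.
conv = ("conv2d", [("simulated_quantize", ["x"]), ("simulated_quantize", ["w"])])
conv_body, conv_stats = toy_stats_collector(conv)
print(conv_stats)  # ('x', 'w')

# A dense layer with no simulated_quantize ops.
dense = ("nn.dense", ["x", "w"])
dense_body, dense_stats = toy_stats_collector(dense)
print(dense_stats)  # ()
```

In this toy model, the number of profiled tensors equals the number of `simulated_quantize` ops found, so an unannotated layer contributes no statistics.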