Skip to content
This repository was archived by the owner on Feb 3, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[submodule "third_party/models"]
[submodule "tftrt/examples/third_party/models"]
path = tftrt/examples/third_party/models
url = https://github.com/tensorflow/models
[submodule "third_party/cocoapi"]
url = https://github.com/tensorflow/models.git
[submodule "tftrt/examples/third_party/cocoapi"]
path = tftrt/examples/third_party/cocoapi
url = https://github.com/cocodataset/cocoapi
url = https://github.com/cocodataset/cocoapi.git
24 changes: 15 additions & 9 deletions tftrt/examples/image-classification/image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ def get_frozen_graph(
use_synthetic=False,
cache=False,
default_models_dir='./data',
max_workspace_size=(2<<32)-1000):
max_workspace_size=(1<<32)):
"""Retrieves a frozen GraphDef from model definitions in classification.py and applies TF-TRT

model: str, the model name (see NETS table in classification.py)
Expand All @@ -442,6 +442,7 @@ def get_frozen_graph(
"""
num_nodes = {}
times = {}
graph_sizes = {}

# Load from pb file if frozen graph was already created and cached
if cache:
Expand All @@ -456,11 +457,13 @@ def get_frozen_graph(
times['loading_frozen_graph'] = time.time() - start_time
num_nodes['loaded_frozen_graph'] = len(frozen_graph.node)
num_nodes['trt_only'] = len([1 for n in frozen_graph.node if str(n.op)=='TRTEngineOp'])
return frozen_graph, num_nodes, times
graph_sizes['loaded_frozen_graph'] = len(frozen_graph.SerializeToString())
return frozen_graph, num_nodes, times, graph_sizes

# Build graph and load weights
frozen_graph = build_classification_graph(model, model_dir, default_models_dir)
num_nodes['native_tf'] = len(frozen_graph.node)
graph_sizes['native_tf'] = len(frozen_graph.SerializeToString())

# Convert to TensorRT graph
if use_trt:
Expand All @@ -477,9 +480,11 @@ def get_frozen_graph(
times['trt_conversion'] = time.time() - start_time
num_nodes['tftrt_total'] = len(frozen_graph.node)
num_nodes['trt_only'] = len([1 for n in frozen_graph.node if str(n.op)=='TRTEngineOp'])
graph_sizes['trt'] = len(frozen_graph.SerializeToString())

if precision == 'int8':
calib_graph = frozen_graph
graph_sizes['calib'] = len(calib_graph.SerializeToString())
# INT8 calibration step
print('Calibrating INT8...')
start_time = time.time()
Expand All @@ -490,6 +495,8 @@ def get_frozen_graph(
start_time = time.time()
frozen_graph = trt.calib_graph_to_infer_graph(calib_graph)
times['trt_int8_conversion'] = time.time() - start_time
# This is already set but overwriting it here to ensure the right size
graph_sizes['trt'] = len(frozen_graph.SerializeToString())

del calib_graph
print('INT8 graph created.')
Expand All @@ -506,7 +513,7 @@ def get_frozen_graph(
f.write(frozen_graph.SerializeToString())
times['saving_frozen_graph'] = time.time() - start_time

return frozen_graph, num_nodes, times
return frozen_graph, num_nodes, times, graph_sizes

if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Evaluate model')
Expand Down Expand Up @@ -547,7 +554,7 @@ def get_frozen_graph(
parser.add_argument('--num_calib_inputs', type=int, default=500,
help='Number of inputs (e.g. images) used for calibration '
'(last batch is skipped in case it is not full)')
parser.add_argument('--max_workspace_size', type=int, default=(2<<32)-1000,
parser.add_argument('--max_workspace_size', type=int, default=(1<<32),
help='workspace size in bytes')
parser.add_argument('--cache', action='store_true',
help='If set, graphs will be saved to disk after conversion. If a converted graph is present on disk, it will be loaded instead of building the graph again.')
Expand Down Expand Up @@ -577,7 +584,7 @@ def get_files(data_dir, filename_pattern):
calib_files = get_files(args.calib_data_dir, 'train*')

# Retrieve graph using NETS table in graph.py
frozen_graph, num_nodes, times = get_frozen_graph(
frozen_graph, num_nodes, times, graph_sizes = get_frozen_graph(
model=args.model,
model_dir=args.model_dir,
use_trt=args.use_trt,
Expand All @@ -592,16 +599,15 @@ def get_files(data_dir, filename_pattern):
default_models_dir=args.default_models_dir,
max_workspace_size=args.max_workspace_size)

def print_dict(input_dict, str=''):
def print_dict(input_dict, str='', scale=None):
    """Print each key/value pair of input_dict, one per line, sorted by key.

    input_dict: dict whose entries are printed
    str: optional label; when non-empty each line reads 'label(key): value',
         otherwise 'key: value'
    scale: optional multiplier applied to each value before printing
           (e.g. 1./(1<<20) to report byte counts in MB)

    Float values are formatted with one decimal place; all other values use
    their default string form.
    """
    # NOTE: the parameter name 'str' shadows the builtin; it is kept as-is
    # for backward compatibility with callers passing it by keyword.
    for k, v in sorted(input_dict.items()):
        headline = '{}({}): '.format(str, k) if str else '{}: '.format(k)
        v = v * scale if scale else v
        # isinstance is the idiomatic type check; type(v)==float would miss
        # float subclasses.
        print('{}{}'.format(headline, '%.1f' % v if isinstance(v, float) else v))

serialized_graph = frozen_graph.SerializeToString()
print('frozen graph size: {}'.format(len(serialized_graph)))

print_dict(vars(args))
print_dict(num_nodes, str='num_nodes')
print_dict(graph_sizes, str='graph_size(MB)', scale=1./(1<<20))
print_dict(times, str='time(s)')

# Evaluate model
Expand Down