diff --git a/.gitmodules b/.gitmodules index 36fd3ef55..9aa63ca53 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,6 @@ -[submodule "third_party/models"] +[submodule "tftrt/examples/third_party/models"] path = tftrt/examples/third_party/models - url = https://github.com/tensorflow/models -[submodule "third_party/cocoapi"] + url = https://github.com/tensorflow/models.git +[submodule "tftrt/examples/third_party/cocoapi"] path = tftrt/examples/third_party/cocoapi - url = https://github.com/cocodataset/cocoapi + url = https://github.com/cocodataset/cocoapi.git diff --git a/tftrt/examples/image-classification/image_classification.py b/tftrt/examples/image-classification/image_classification.py index d7891364f..a02a0f059 100644 --- a/tftrt/examples/image-classification/image_classification.py +++ b/tftrt/examples/image-classification/image_classification.py @@ -431,7 +431,7 @@ def get_frozen_graph( use_synthetic=False, cache=False, default_models_dir='./data', - max_workspace_size=(2<<32)-1000): + max_workspace_size=(1<<32)): """Retreives a frozen GraphDef from model definitions in classification.py and applies TF-TRT model: str, the model name (see NETS table in classification.py) @@ -442,6 +442,7 @@ def get_frozen_graph( """ num_nodes = {} times = {} + graph_sizes = {} # Load from pb file if frozen graph was already created and cached if cache: @@ -456,11 +457,13 @@ def get_frozen_graph( times['loading_frozen_graph'] = time.time() - start_time num_nodes['loaded_frozen_graph'] = len(frozen_graph.node) num_nodes['trt_only'] = len([1 for n in frozen_graph.node if str(n.op)=='TRTEngineOp']) - return frozen_graph, num_nodes, times + graph_sizes['loaded_frozen_graph'] = len(frozen_graph.SerializeToString()) + return frozen_graph, num_nodes, times, graph_sizes # Build graph and load weights frozen_graph = build_classification_graph(model, model_dir, default_models_dir) num_nodes['native_tf'] = len(frozen_graph.node) + graph_sizes['native_tf'] = len(frozen_graph.SerializeToString()) # Convert to TensorRT graph if use_trt: @@ -477,9 +480,11 @@ def get_frozen_graph( times['trt_conversion'] = time.time() - start_time num_nodes['tftrt_total'] = len(frozen_graph.node) num_nodes['trt_only'] = len([1 for n in frozen_graph.node if str(n.op)=='TRTEngineOp']) + graph_sizes['trt'] = len(frozen_graph.SerializeToString()) if precision == 'int8': calib_graph = frozen_graph + graph_sizes['calib'] = len(calib_graph.SerializeToString()) # INT8 calibration step print('Calibrating INT8...') start_time = time.time() @@ -490,6 +495,8 @@ def get_frozen_graph( start_time = time.time() frozen_graph = trt.calib_graph_to_infer_graph(calib_graph) times['trt_int8_conversion'] = time.time() - start_time + # This is already set but overwriting it here to ensure the right size + graph_sizes['trt'] = len(frozen_graph.SerializeToString()) del calib_graph print('INT8 graph created.') @@ -506,7 +513,7 @@ def get_frozen_graph( f.write(frozen_graph.SerializeToString()) times['saving_frozen_graph'] = time.time() - start_time - return frozen_graph, num_nodes, times + return frozen_graph, num_nodes, times, graph_sizes if __name__ == '__main__': parser = argparse.ArgumentParser(description='Evaluate model') @@ -547,7 +554,7 @@ def get_frozen_graph( parser.add_argument('--num_calib_inputs', type=int, default=500, help='Number of inputs (e.g. images) used for calibration ' '(last batch is skipped in case it is not full)') - parser.add_argument('--max_workspace_size', type=int, default=(2<<32)-1000, + parser.add_argument('--max_workspace_size', type=int, default=(1<<32), help='workspace size in bytes') parser.add_argument('--cache', action='store_true', help='If set, graphs will be saved to disk after conversion. If a converted graph is present on disk, it will be loaded instead of building the graph again.') @@ -577,7 +584,7 @@ def get_files(data_dir, filename_pattern): calib_files = get_files(args.calib_data_dir, 'train*') # Retreive graph using NETS table in graph.py - frozen_graph, num_nodes, times = get_frozen_graph( + frozen_graph, num_nodes, times, graph_sizes = get_frozen_graph( model=args.model, model_dir=args.model_dir, use_trt=args.use_trt, @@ -592,16 +599,15 @@ def get_files(data_dir, filename_pattern): default_models_dir=args.default_models_dir, max_workspace_size=args.max_workspace_size) - def print_dict(input_dict, str=''): + def print_dict(input_dict, str='', scale=None): for k, v in sorted(input_dict.items()): headline = '{}({}): '.format(str, k) if str else '{}: '.format(k) + v = v * scale if scale else v print('{}{}'.format(headline, '%.1f'%v if type(v)==float else v)) - serialized_graph = frozen_graph.SerializeToString() - print('frozen graph size: {}'.format(len(serialized_graph))) - print_dict(vars(args)) print_dict(num_nodes, str='num_nodes') + print_dict(graph_sizes, str='graph_size(MB)', scale=1./(1<<20)) print_dict(times, str='time(s)') # Evaluate model diff --git a/tftrt/examples/object_detection/third_party/cocoapi b/tftrt/examples/third_party/cocoapi similarity index 100% rename from tftrt/examples/object_detection/third_party/cocoapi rename to tftrt/examples/third_party/cocoapi diff --git a/tftrt/examples/object_detection/third_party/models b/tftrt/examples/third_party/models similarity index 100% rename from tftrt/examples/object_detection/third_party/models rename to tftrt/examples/third_party/models