diff --git a/.gitmodules b/.gitmodules index 2688d24bc..36fd3ef55 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,6 @@ [submodule "third_party/models"] - path = tftrt/examples/object_detection/third_party/models + path = tftrt/examples/third_party/models url = https://github.com/tensorflow/models [submodule "third_party/cocoapi"] - path = tftrt/examples/object_detection/third_party/cocoapi + path = tftrt/examples/third_party/cocoapi url = https://github.com/cocodataset/cocoapi diff --git a/tftrt/examples/image-classification/image_classification.py b/tftrt/examples/image-classification/image_classification.py index 0090b802a..88893a28e 100644 --- a/tftrt/examples/image-classification/image_classification.py +++ b/tftrt/examples/image-classification/image_classification.py @@ -431,7 +431,8 @@ def get_frozen_graph( num_calib_inputs=None, use_synthetic=False, cache=False, - default_models_dir='./data'): + default_models_dir='./data', + max_workspace_size=(2<<32)-1000): """Retreives a frozen GraphDef from model definitions in classification.py and applies TF-TRT model: str, the model name (see NETS table in classification.py) @@ -469,7 +470,7 @@ def get_frozen_graph( input_graph_def=frozen_graph, outputs=['logits', 'classes'], max_batch_size=batch_size, - max_workspace_size_bytes=(4096<<20)-1000, + max_workspace_size_bytes=max_workspace_size, precision_mode=precision, minimum_segment_size=minimum_segment_size, is_dynamic_op=use_dynamic_op @@ -547,6 +548,8 @@ def get_frozen_graph( parser.add_argument('--num_calib_inputs', type=int, default=500, help='Number of inputs (e.g. images) used for calibration ' '(last batch is skipped in case it is not full)') + parser.add_argument('--max_workspace_size', type=int, default=(2<<32)-1000, + help='workspace size in bytes') parser.add_argument('--cache', action='store_true', help='If set, graphs will be saved to disk after conversion. If a converted graph is present on disk, it will be loaded instead of building the graph again.') args = parser.parse_args() @@ -575,12 +578,17 @@ def get_frozen_graph( num_calib_inputs=args.num_calib_inputs, use_synthetic=args.use_synthetic, cache=args.cache, - default_models_dir=args.default_models_dir) + default_models_dir=args.default_models_dir, + max_workspace_size=args.max_workspace_size) def print_dict(input_dict, str=''): for k, v in sorted(input_dict.items()): headline = '{}({}): '.format(str, k) if str else '{}: '.format(k) print('{}{}'.format(headline, '%.1f'%v if type(v)==float else v)) + + serialized_graph = frozen_graph.SerializeToString() + print('frozen graph size: {}'.format(len(serialized_graph))) + print_dict(vars(args)) print_dict(num_nodes, str='num_nodes') print_dict(times, str='time(s)') diff --git a/tftrt/examples/object_detection/install_dependencies.sh b/tftrt/examples/object_detection/install_dependencies.sh index 5d7a0e61a..a1041f450 100755 --- a/tftrt/examples/object_detection/install_dependencies.sh +++ b/tftrt/examples/object_detection/install_dependencies.sh @@ -17,10 +17,10 @@ # ============================================================================= echo Setup local variables... -TF_MODELS_DIR=third_party/models +TF_MODELS_DIR=../third_party/models RESEARCH_DIR=$TF_MODELS_DIR/research SLIM_DIR=$RESEARCH_DIR/slim -COCO_API_DIR=third_party/cocoapi +COCO_API_DIR=../third_party/cocoapi PYCOCO_DIR=$COCO_API_DIR/PythonAPI PROTO_BASE_URL="https://github.com/google/protobuf/releases/download/v3.5.1/" PROTOC_DIR=$PWD/protoc