From cdb3861fe7439b0ea294a67bd521c01136dba035 Mon Sep 17 00:00:00 2001 From: Marek Drozdowski Date: Tue, 12 Feb 2019 15:18:01 -0800 Subject: [PATCH 1/2] update resnet_v1 and resnet_v2 --- .../image_classification.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/tftrt/examples/image-classification/image_classification.py b/tftrt/examples/image-classification/image_classification.py index a02a0f059..6080295d6 100644 --- a/tftrt/examples/image-classification/image_classification.py +++ b/tftrt/examples/image-classification/image_classification.py @@ -202,16 +202,16 @@ def get_netdef(model): 'resnet_v1_50': NetDef( name='resnet_v1_50', - url='http://download.tensorflow.org/models/official/20180601_resnet_v1_imagenet_checkpoint.tar.gz', - model_dir_in_archive='20180601_resnet_v1_imagenet_checkpoint', + url='http://download.tensorflow.org/models/official/20181001_resnet/checkpoints/resnet_imagenet_v1_fp32_20181001.tar.gz', + model_dir_in_archive='resnet_imagenet_v1_fp32_20181001', slim=False, preprocess='vgg', model_fn=official.resnet.imagenet_main.ImagenetModel(resnet_size=50, resnet_version=1)), 'resnet_v2_50': NetDef( name='resnet_v2_50', - url='http://download.tensorflow.org/models/official/20180601_resnet_v2_imagenet_checkpoint.tar.gz', - model_dir_in_archive='20180601_resnet_v2_imagenet_checkpoint', + url='http://download.tensorflow.org/models/official/20181001_resnet/checkpoints/resnet_imagenet_v2_fp32_20181001.tar.gz', + model_dir_in_archive='resnet_imagenet_v2_fp32_20181001', slim=False, preprocess='vgg', model_fn=official.resnet.imagenet_main.ImagenetModel(resnet_size=50, resnet_version=2)), @@ -398,6 +398,14 @@ def find_checkpoint_in_dir(model_dir): checkpoint_path = '.'.join(parts[:ckpt_index+1]) return checkpoint_path +def copy(source, destination): + try: + shutil.copy2(source, destination) + except OSError as e: + pass + except shutil.Error as e: + pass + def download_checkpoint(model, destination_path): # Make directories if they don't exist. if not os.path.exists(destination_path): @@ -416,7 +424,7 @@ def download_checkpoint(model, destination_path): get_netdef(model).model_dir_in_archive, '*') for f in glob.glob(source_files): - shutil.copy2(f, destination_path) + copy(f, destination_path) def get_frozen_graph( model, From b5c5a123a30bad0b1d8d5314ed7da6074b9c2cdc Mon Sep 17 00:00:00 2001 From: Marek Drozdowski Date: Wed, 27 Feb 2019 15:56:56 -0800 Subject: [PATCH 2/2] put copy into download_model function --- .../image_classification.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/tftrt/examples/image-classification/image_classification.py b/tftrt/examples/image-classification/image_classification.py index 6080295d6..6fdc6a2ad 100644 --- a/tftrt/examples/image-classification/image_classification.py +++ b/tftrt/examples/image-classification/image_classification.py @@ -398,15 +398,17 @@ def find_checkpoint_in_dir(model_dir): checkpoint_path = '.'.join(parts[:ckpt_index+1]) return checkpoint_path -def copy(source, destination): - try: - shutil.copy2(source, destination) - except OSError as e: - pass - except shutil.Error as e: - pass def download_checkpoint(model, destination_path): + #copy files from source to destination (without any directories) + def copy_files(source, destination): + try: + shutil.copy2(source, destination) + except OSError as e: + pass + except shutil.Error as e: + pass + # Make directories if they don't exist. if not os.path.exists(destination_path): os.makedirs(destination_path) @@ -424,7 +426,7 @@ def download_checkpoint(model, destination_path): get_netdef(model).model_dir_in_archive, '*') for f in glob.glob(source_files): - copy(f, destination_path) + copy_files(f, destination_path) def get_frozen_graph( model,