diff --git a/tftrt/examples/image-classification/image_classification.py b/tftrt/examples/image-classification/image_classification.py index a02a0f059..6fdc6a2ad 100644 --- a/tftrt/examples/image-classification/image_classification.py +++ b/tftrt/examples/image-classification/image_classification.py @@ -202,16 +202,16 @@ def get_netdef(model): 'resnet_v1_50': NetDef( name='resnet_v1_50', - url='http://download.tensorflow.org/models/official/20180601_resnet_v1_imagenet_checkpoint.tar.gz', - model_dir_in_archive='20180601_resnet_v1_imagenet_checkpoint', + url='http://download.tensorflow.org/models/official/20181001_resnet/checkpoints/resnet_imagenet_v1_fp32_20181001.tar.gz', + model_dir_in_archive='resnet_imagenet_v1_fp32_20181001', slim=False, preprocess='vgg', model_fn=official.resnet.imagenet_main.ImagenetModel(resnet_size=50, resnet_version=1)), 'resnet_v2_50': NetDef( name='resnet_v2_50', - url='http://download.tensorflow.org/models/official/20180601_resnet_v2_imagenet_checkpoint.tar.gz', - model_dir_in_archive='20180601_resnet_v2_imagenet_checkpoint', + url='http://download.tensorflow.org/models/official/20181001_resnet/checkpoints/resnet_imagenet_v2_fp32_20181001.tar.gz', + model_dir_in_archive='resnet_imagenet_v2_fp32_20181001', slim=False, preprocess='vgg', model_fn=official.resnet.imagenet_main.ImagenetModel(resnet_size=50, resnet_version=2)), @@ -398,7 +398,17 @@ def find_checkpoint_in_dir(model_dir): checkpoint_path = '.'.join(parts[:ckpt_index+1]) return checkpoint_path + def download_checkpoint(model, destination_path): + #copy files from source to destination (without any directories) + def copy_files(source, destination): + try: + shutil.copy2(source, destination) + except OSError as e: + pass + except shutil.Error as e: + pass + # Make directories if they don't exist. if not os.path.exists(destination_path): os.makedirs(destination_path) @@ -416,7 +426,7 @@ def download_checkpoint(model, destination_path): get_netdef(model).model_dir_in_archive, '*') for f in glob.glob(source_files): - shutil.copy2(f, destination_path) + copy_files(f, destination_path) def get_frozen_graph( model,