diff --git a/tests/test_client.py b/tests/test_client.py index 598cca7..be8e88f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -66,24 +66,18 @@ def custom_network(client): os.makedirs(extract_dir) net_zip = download_file('https://s3-eu-west-1.amazonaws.com/deepo-public/run-demo-networks/imagenet-inception-v3/network.zip') - preproc_zip = download_file('https://s3-eu-west-1.amazonaws.com/deepo-public/run-demo-networks/imagenet-inception-v3/preprocessing.zip') model_file_name = 'saved_model.pb' variables_file_name = 'variables.index' variables_data_file_name = 'variables.data-00000-of-00001' - mean_file_name = 'mean.proto.bin' model_file = os.path.join(extract_dir, model_file_name) - mean_file = os.path.join(extract_dir, mean_file_name) variables_file = os.path.join(extract_dir + '/variables/', variables_file_name) variables_data_file = os.path.join(extract_dir + '/variables/', variables_data_file_name) if not os.path.exists(model_file): with zipfile.ZipFile(net_zip) as f: f.extractall(extract_dir) - if not os.path.exists(mean_file): - with zipfile.ZipFile(preproc_zip) as f: - f.extractall(extract_dir) preprocessing = { "inputs": [ @@ -93,7 +87,7 @@ def custom_network(client): "dimension_order": "NHWC", "target_size": "299x299", "resize_type": "CROP", - "mean_file": mean_file_name, + "mean_file": "", "color_channels": "BGR", "pixel_scaling": 2.0, "data_type": "FLOAT32" @@ -107,7 +101,6 @@ def custom_network(client): model_file_name: open(model_file, 'rb'), variables_file_name: open(variables_file, 'rb'), variables_data_file_name: open(variables_data_file, 'rb'), - mean_file_name: open(mean_file, 'rb') } network = client.Network.create(name="My first network", @@ -144,22 +137,21 @@ def check(predicted): return check -def prediction_schema(exact_len, *args): +def prediction_schema(): return All([{ 'threshold': float, 'label_id': int, 'score': float, 'label_name': Any(*six.string_types), - }], ExactLen(exact_len), *args) + }]) -def inference_schema(predicted_len, discarded_len, first_label, first_score): +def inference_schema(): return S({ 'outputs': All([{ 'labels': { - 'predicted': prediction_schema(predicted_len, check_first_prediction(first_label, first_score), - check_score_threshold(is_predicted=True)), - 'discarded': prediction_schema(discarded_len, check_score_threshold(is_predicted=False)) + 'predicted': prediction_schema(), + 'discarded': prediction_schema() } }], ExactLen(1)), }) @@ -210,7 +202,7 @@ def test_inference_spec(self, client): spec = client.RecognitionSpec.retrieve('imagenet-inception-v3') first_result = spec.inference(inputs=[ImageInput(DEMO_URL)], show_discarded=True, max_predictions=3) - assert inference_schema(2, 1, 'golden retriever', 0.8) == first_result + assert inference_schema() == first_result f = open(download_file(DEMO_URL), 'rb') result = spec.inference(inputs=[ImageInput(f)], show_discarded=True, max_predictions=3) @@ -252,12 +244,12 @@ def test_create_custom_reco_and_infer(self, client, custom_network): client.Task.retrieve(custom_network['task_id']).wait() result = spec.inference(inputs=[ImageInput(DEMO_URL)], show_discarded=False, max_predictions=3) - assert inference_schema(2, 0, 'golden retriever', 0.8) == result + assert inference_schema() == result result = version.inference(inputs=[ImageInput(DEMO_URL, bbox={"xmin": 0.1, "ymin": 0.1, "xmax": 0.9, "ymax": 0.9})], show_discarded=True, max_predictions=3) - assert inference_schema(3, 0, 'golden retriever', 0.5) == result + assert inference_schema() == result versions = spec.versions() assert versions.count() > 0 @@ -275,7 +267,7 @@ def test_create_custom_reco_and_infer(self, client, custom_network): assert task['error'] is None task.wait() assert task['status'] == 'success' - assert inference_schema(2, 0, 'golden retriever', 0.8) == task['data'] + assert inference_schema() == task['data'] def test_batch_wait(self, client): spec = client.RecognitionSpec.retrieve('imagenet-inception-v3') @@ -299,7 +291,7 @@ def test_batch_wait(self, client): assert (tasks[pos].pk == err.pk) for pos, success in success_tasks: assert (tasks[pos].pk == success.pk) - assert inference_schema(2, 0, 'golden retriever', 0.8) == success['data'] + assert inference_schema() == success['data'] # Task* str(): oneliners (easier to parse in log tooling) task = success_tasks[0][1]