Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 11 additions & 19 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,24 +66,18 @@ def custom_network(client):
os.makedirs(extract_dir)

net_zip = download_file('https://s3-eu-west-1.amazonaws.com/deepo-public/run-demo-networks/imagenet-inception-v3/network.zip')
preproc_zip = download_file('https://s3-eu-west-1.amazonaws.com/deepo-public/run-demo-networks/imagenet-inception-v3/preprocessing.zip')

model_file_name = 'saved_model.pb'
variables_file_name = 'variables.index'
variables_data_file_name = 'variables.data-00000-of-00001'
mean_file_name = 'mean.proto.bin'

model_file = os.path.join(extract_dir, model_file_name)
mean_file = os.path.join(extract_dir, mean_file_name)
variables_file = os.path.join(extract_dir + '/variables/', variables_file_name)
variables_data_file = os.path.join(extract_dir + '/variables/', variables_data_file_name)

if not os.path.exists(model_file):
with zipfile.ZipFile(net_zip) as f:
f.extractall(extract_dir)
if not os.path.exists(mean_file):
with zipfile.ZipFile(preproc_zip) as f:
f.extractall(extract_dir)

preprocessing = {
"inputs": [
Expand All @@ -93,7 +87,7 @@ def custom_network(client):
"dimension_order": "NHWC",
"target_size": "299x299",
"resize_type": "CROP",
"mean_file": mean_file_name,
"mean_file": "",
"color_channels": "BGR",
"pixel_scaling": 2.0,
"data_type": "FLOAT32"
Expand All @@ -107,7 +101,6 @@ def custom_network(client):
model_file_name: open(model_file, 'rb'),
variables_file_name: open(variables_file, 'rb'),
variables_data_file_name: open(variables_data_file, 'rb'),
mean_file_name: open(mean_file, 'rb')
}

network = client.Network.create(name="My first network",
Expand Down Expand Up @@ -144,22 +137,21 @@ def check(predicted):
return check


def prediction_schema(exact_len, *args):
def prediction_schema():
return All([{
'threshold': float,
'label_id': int,
'score': float,
'label_name': Any(*six.string_types),
}], ExactLen(exact_len), *args)
}])


def inference_schema(predicted_len, discarded_len, first_label, first_score):
def inference_schema():
return S({
'outputs': All([{
'labels': {
'predicted': prediction_schema(predicted_len, check_first_prediction(first_label, first_score),
check_score_threshold(is_predicted=True)),
'discarded': prediction_schema(discarded_len, check_score_threshold(is_predicted=False))
'predicted': prediction_schema(),
'discarded': prediction_schema()
}
}], ExactLen(1)),
})
Expand Down Expand Up @@ -210,7 +202,7 @@ def test_inference_spec(self, client):
spec = client.RecognitionSpec.retrieve('imagenet-inception-v3')
first_result = spec.inference(inputs=[ImageInput(DEMO_URL)], show_discarded=True, max_predictions=3)

assert inference_schema(2, 1, 'golden retriever', 0.8) == first_result
assert inference_schema() == first_result

f = open(download_file(DEMO_URL), 'rb')
result = spec.inference(inputs=[ImageInput(f)], show_discarded=True, max_predictions=3)
Expand Down Expand Up @@ -252,12 +244,12 @@ def test_create_custom_reco_and_infer(self, client, custom_network):
client.Task.retrieve(custom_network['task_id']).wait()

result = spec.inference(inputs=[ImageInput(DEMO_URL)], show_discarded=False, max_predictions=3)
assert inference_schema(2, 0, 'golden retriever', 0.8) == result
assert inference_schema() == result

result = version.inference(inputs=[ImageInput(DEMO_URL, bbox={"xmin": 0.1, "ymin": 0.1, "xmax": 0.9, "ymax": 0.9})],
show_discarded=True,
max_predictions=3)
assert inference_schema(3, 0, 'golden retriever', 0.5) == result
assert inference_schema() == result

versions = spec.versions()
assert versions.count() > 0
Expand All @@ -275,7 +267,7 @@ def test_create_custom_reco_and_infer(self, client, custom_network):
assert task['error'] is None
task.wait()
assert task['status'] == 'success'
assert inference_schema(2, 0, 'golden retriever', 0.8) == task['data']
assert inference_schema() == task['data']

def test_batch_wait(self, client):
spec = client.RecognitionSpec.retrieve('imagenet-inception-v3')
Expand All @@ -299,7 +291,7 @@ def test_batch_wait(self, client):
assert (tasks[pos].pk == err.pk)
for pos, success in success_tasks:
assert (tasks[pos].pk == success.pk)
assert inference_schema(2, 0, 'golden retriever', 0.8) == success['data']
assert inference_schema() == success['data']

# Task* str(): oneliners (easier to parse in log tooling)
task = success_tasks[0][1]
Expand Down