Skip to content

Commit f8ceb94

Browse files
authored
Vision support for batch processing part one. (#2967)
* Vision support for batch processing part one.
1 parent 50c8e88 commit f8ceb94

File tree

7 files changed

+88
-37
lines changed

7 files changed

+88
-37
lines changed

vision/google/cloud/vision/_gax.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,21 +39,18 @@ def annotate(self, image, features):
3939
:type features: list
4040
:param features: List of :class:`~google.cloud.vision.feature.Feature`.
4141
42-
:rtype: :class:`~google.cloud.vision.annotations.Annotations`
43-
:returns: Instance of ``Annotations`` with results or ``None``.
42+
:rtype: list
43+
:returns: List of
44+
:class:`~google.cloud.vision.annotations.Annotations`.
4445
"""
4546
gapic_features = [_to_gapic_feature(feature) for feature in features]
4647
gapic_image = _to_gapic_image(image)
4748
request = image_annotator_pb2.AnnotateImageRequest(
4849
image=gapic_image, features=gapic_features)
4950
requests = [request]
5051
annotator_client = self._annotator_client
51-
images = annotator_client.batch_annotate_images(requests)
52-
if len(images.responses) == 1:
53-
return Annotations.from_pb(images.responses[0])
54-
elif len(images.responses) > 1:
55-
raise NotImplementedError(
56-
'Multiple image processing is not yet supported.')
52+
responses = annotator_client.batch_annotate_images(requests).responses
53+
return [Annotations.from_pb(response) for response in responses]
5754

5855

5956
def _to_gapic_feature(feature):

vision/google/cloud/vision/_http.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,20 +41,16 @@ def annotate(self, image, features):
4141
based on the number of Feature Types.
4242
4343
See: https://cloud.google.com/vision/docs/pricing
44-
:rtype: dict
45-
:returns: List of annotations.
44+
:rtype: list
45+
:returns: List of :class:`~googe.cloud.vision.annotations.Annotations`.
4646
"""
4747
request = _make_request(image, features)
4848

4949
data = {'requests': [request]}
5050
api_response = self._connection.api_request(
5151
method='POST', path='/images:annotate', data=data)
52-
images = api_response.get('responses')
53-
if len(images) == 1:
54-
return Annotations.from_api_repr(images[0])
55-
elif len(images) > 1:
56-
raise NotImplementedError(
57-
'Multiple image processing is not yet supported.')
52+
responses = api_response.get('responses')
53+
return [Annotations.from_api_repr(response) for response in responses]
5854

5955

6056
def _make_request(image, features):

vision/google/cloud/vision/image.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def detect_faces(self, limit=10):
134134
"""
135135
features = [Feature(FeatureTypes.FACE_DETECTION, limit)]
136136
annotations = self._detect_annotation(features)
137-
return annotations.faces
137+
return annotations[0].faces
138138

139139
def detect_labels(self, limit=10):
140140
"""Detect labels that describe objects in an image.
@@ -147,7 +147,7 @@ def detect_labels(self, limit=10):
147147
"""
148148
features = [Feature(FeatureTypes.LABEL_DETECTION, limit)]
149149
annotations = self._detect_annotation(features)
150-
return annotations.labels
150+
return annotations[0].labels
151151

152152
def detect_landmarks(self, limit=10):
153153
"""Detect landmarks in an image.
@@ -161,7 +161,7 @@ def detect_landmarks(self, limit=10):
161161
"""
162162
features = [Feature(FeatureTypes.LANDMARK_DETECTION, limit)]
163163
annotations = self._detect_annotation(features)
164-
return annotations.landmarks
164+
return annotations[0].landmarks
165165

166166
def detect_logos(self, limit=10):
167167
"""Detect logos in an image.
@@ -175,7 +175,7 @@ def detect_logos(self, limit=10):
175175
"""
176176
features = [Feature(FeatureTypes.LOGO_DETECTION, limit)]
177177
annotations = self._detect_annotation(features)
178-
return annotations.logos
178+
return annotations[0].logos
179179

180180
def detect_properties(self, limit=10):
181181
"""Detect the color properties of an image.
@@ -189,7 +189,7 @@ def detect_properties(self, limit=10):
189189
"""
190190
features = [Feature(FeatureTypes.IMAGE_PROPERTIES, limit)]
191191
annotations = self._detect_annotation(features)
192-
return annotations.properties
192+
return annotations[0].properties
193193

194194
def detect_safe_search(self, limit=10):
195195
"""Retreive safe search properties from an image.
@@ -203,7 +203,7 @@ def detect_safe_search(self, limit=10):
203203
"""
204204
features = [Feature(FeatureTypes.SAFE_SEARCH_DETECTION, limit)]
205205
annotations = self._detect_annotation(features)
206-
return annotations.safe_searches
206+
return annotations[0].safe_searches
207207

208208
def detect_text(self, limit=10):
209209
"""Detect text in an image.
@@ -217,4 +217,4 @@ def detect_text(self, limit=10):
217217
"""
218218
features = [Feature(FeatureTypes.TEXT_DETECTION, limit)]
219219
annotations = self._detect_annotation(features)
220-
return annotations.texts
220+
return annotations[0].texts

vision/unit_tests/_fixtures.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1688,6 +1688,39 @@
16881688
}
16891689

16901690

1691+
MULTIPLE_RESPONSE = {
1692+
'responses': [
1693+
{
1694+
'labelAnnotations': [
1695+
{
1696+
'mid': '/m/0k4j',
1697+
'description': 'automobile',
1698+
'score': 0.9776855
1699+
},
1700+
{
1701+
'mid': '/m/07yv9',
1702+
'description': 'vehicle',
1703+
'score': 0.947987
1704+
},
1705+
{
1706+
'mid': '/m/07r04',
1707+
'description': 'truck',
1708+
'score': 0.88429511
1709+
},
1710+
],
1711+
},
1712+
{
1713+
'safeSearchAnnotation': {
1714+
'adult': 'VERY_UNLIKELY',
1715+
'spoof': 'UNLIKELY',
1716+
'medical': 'POSSIBLE',
1717+
'violence': 'VERY_UNLIKELY'
1718+
},
1719+
},
1720+
],
1721+
}
1722+
1723+
16911724
SAFE_SEARCH_DETECTION_RESPONSE = {
16921725
'responses': [
16931726
{

vision/unit_tests/test__gax.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,15 @@ def test_annotate_no_results(self):
7878
gax_api._annotator_client = mock.Mock(
7979
spec_set=['batch_annotate_images'], **mock_response)
8080
with mock.patch('google.cloud.vision._gax.Annotations'):
81-
self.assertIsNone(gax_api.annotate(image, [feature]))
81+
response = gax_api.annotate(image, [feature])
82+
self.assertEqual(len(response), 0)
83+
self.assertIsInstance(response, list)
8284

8385
gax_api._annotator_client.batch_annotate_images.assert_called()
8486

8587
def test_annotate_multiple_results(self):
88+
from google.cloud.grpc.vision.v1 import image_annotator_pb2
89+
from google.cloud.vision.annotations import Annotations
8690
from google.cloud.vision.feature import Feature
8791
from google.cloud.vision.feature import FeatureTypes
8892
from google.cloud.vision.image import Image
@@ -95,16 +99,21 @@ def test_annotate_multiple_results(self):
9599
'ImageAnnotatorClient'):
96100
gax_api = self._make_one(client)
97101

98-
mock_response = {
99-
'batch_annotate_images.return_value': mock.Mock(responses=[1, 2]),
100-
}
102+
responses = [
103+
image_annotator_pb2.AnnotateImageResponse(),
104+
image_annotator_pb2.AnnotateImageResponse(),
105+
]
106+
response = image_annotator_pb2.BatchAnnotateImagesResponse(
107+
responses=responses)
101108

102109
gax_api._annotator_client = mock.Mock(
103-
spec_set=['batch_annotate_images'], **mock_response)
104-
with mock.patch('google.cloud.vision._gax.Annotations'):
105-
with self.assertRaises(NotImplementedError):
106-
gax_api.annotate(image, [feature])
110+
spec_set=['batch_annotate_images'])
111+
gax_api._annotator_client.batch_annotate_images.return_value = response
112+
responses = gax_api.annotate(image, [feature])
107113

114+
self.assertEqual(len(responses), 2)
115+
self.assertIsInstance(responses[0], Annotations)
116+
self.assertIsInstance(responses[1], Annotations)
108117
gax_api._annotator_client.batch_annotate_images.assert_called()
109118

110119

vision/unit_tests/test__http.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,16 @@ def test_call_annotate_with_no_results(self):
4444
http_api = self._make_one(client)
4545
http_api._connection = mock.Mock(spec_set=['api_request'])
4646
http_api._connection.api_request.return_value = {'responses': []}
47-
self.assertIsNone(http_api.annotate(image, [feature]))
47+
response = http_api.annotate(image, [feature])
48+
self.assertEqual(len(response), 0)
49+
self.assertIsInstance(response, list)
4850

4951
def test_call_annotate_with_more_than_one_result(self):
5052
from google.cloud.vision.feature import Feature
5153
from google.cloud.vision.feature import FeatureTypes
5254
from google.cloud.vision.image import Image
55+
from google.cloud.vision.likelihood import Likelihood
56+
from unit_tests._fixtures import MULTIPLE_RESPONSE
5357

5458
client = mock.Mock(spec_set=['_connection'])
5559
feature = Feature(FeatureTypes.LABEL_DETECTION, 5)
@@ -58,9 +62,17 @@ def test_call_annotate_with_more_than_one_result(self):
5862

5963
http_api = self._make_one(client)
6064
http_api._connection = mock.Mock(spec_set=['api_request'])
61-
http_api._connection.api_request.return_value = {'responses': [1, 2]}
62-
with self.assertRaises(NotImplementedError):
63-
http_api.annotate(image, [feature])
65+
http_api._connection.api_request.return_value = MULTIPLE_RESPONSE
66+
responses = http_api.annotate(image, [feature])
67+
68+
self.assertEqual(len(responses), 2)
69+
image_one = responses[0]
70+
image_two = responses[1]
71+
self.assertEqual(len(image_one.labels), 3)
72+
self.assertIsInstance(image_one.safe_searches, tuple)
73+
self.assertEqual(image_two.safe_searches.adult,
74+
Likelihood.VERY_UNLIKELY)
75+
self.assertEqual(len(image_two.labels), 0)
6476

6577

6678
class TestVisionRequest(unittest.TestCase):

vision/unit_tests/test_client.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,10 @@ def test_face_annotation(self):
104104
features = [Feature(feature_type=FeatureTypes.FACE_DETECTION,
105105
max_results=3)]
106106
image = client.image(content=IMAGE_CONTENT)
107-
response = client._vision_api.annotate(image, features)
107+
api_response = client._vision_api.annotate(image, features)
108108

109+
self.assertEqual(len(api_response), 1)
110+
response = api_response[0]
109111
self.assertEqual(REQUEST,
110112
client._connection._requested[0]['data'])
111113
self.assertIsInstance(response, Annotations)
@@ -166,8 +168,10 @@ def test_multiple_detection_from_content(self):
166168
logo_feature = Feature(FeatureTypes.LOGO_DETECTION, limit)
167169
features = [label_feature, logo_feature]
168170
image = client.image(content=IMAGE_CONTENT)
169-
items = image.detect(features)
171+
detected_items = image.detect(features)
170172

173+
self.assertEqual(len(detected_items), 1)
174+
items = detected_items[0]
171175
self.assertEqual(len(items.logos), 2)
172176
self.assertEqual(len(items.labels), 3)
173177
first_logo = items.logos[0]

0 commit comments

Comments
 (0)