@@ -187,7 +187,39 @@ def _verify_entity(self, entity, name, entity_type, wiki_url, salience):
187187 self .assertEqual (entity .salience , salience )
188188 self .assertEqual (entity .mentions , [name ])
189189
190+ @staticmethod
191+ def _expected_data (content , encoding_type = None ,
192+ extract_sentiment = False ,
193+ extract_entities = False ,
194+ extract_syntax = False ):
195+ from google .cloud .language .document import DEFAULT_LANGUAGE
196+ from google .cloud .language .document import Document
197+
198+ expected = {
199+ 'document' : {
200+ 'language' : DEFAULT_LANGUAGE ,
201+ 'type' : Document .PLAIN_TEXT ,
202+ 'content' : content ,
203+ },
204+ }
205+ if encoding_type is not None :
206+ expected ['encodingType' ] = encoding_type
207+ if extract_sentiment :
208+ features = expected .setdefault ('features' , {})
209+ features ['extractDocumentSentiment' ] = True
210+ if extract_entities :
211+ features = expected .setdefault ('features' , {})
212+ features ['extractEntities' ] = True
213+ if extract_syntax :
214+ features = expected .setdefault ('features' , {})
215+ features ['extractSyntax' ] = True
216+ return expected
217+
190218 def test_analyze_entities (self ):
219+ import mock
220+ from google .cloud .language .connection import Connection
221+ from google .cloud .language .client import Client
222+ from google .cloud .language .document import Encoding
191223 from google .cloud .language .entity import EntityType
192224
193225 name1 = 'R-O-C-K'
@@ -229,8 +261,9 @@ def test_analyze_entities(self):
229261 ],
230262 'language' : 'en-US' ,
231263 }
232- connection = _Connection (response )
233- client = _Client (connection = connection )
264+ connection = mock .Mock (spec = Connection )
265+ connection .api_request .return_value = response
266+ client = mock .Mock (connection = connection , spec = Client )
234267 document = self ._make_one (client , content )
235268
236269 entities = document .analyze_entities ()
@@ -243,10 +276,10 @@ def test_analyze_entities(self):
243276 wiki2 , salience2 )
244277
245278 # Verify the request.
246- self .assertEqual ( len ( connection . _requested ), 1 )
247- req = connection . _requested [ 0 ]
248- self . assertEqual ( req [ 'path' ], 'analyzeEntities' )
249- self . assertEqual ( req [ 'method' ], 'POST' )
279+ expected = self ._expected_data (
280+ content , encoding_type = Encoding . UTF8 )
281+ connection . api_request . assert_called_once_with (
282+ path = 'analyzeEntities' , method = 'POST' , data = expected )
250283
251284 def _verify_sentiment (self , sentiment , polarity , magnitude ):
252285 from google .cloud .language .sentiment import Sentiment
@@ -256,6 +289,10 @@ def _verify_sentiment(self, sentiment, polarity, magnitude):
256289 self .assertEqual (sentiment .magnitude , magnitude )
257290
258291 def test_analyze_sentiment (self ):
292+ import mock
293+ from google .cloud .language .connection import Connection
294+ from google .cloud .language .client import Client
295+
259296 content = 'All the pretty horses.'
260297 polarity = 1
261298 magnitude = 0.6
@@ -266,18 +303,18 @@ def test_analyze_sentiment(self):
266303 },
267304 'language' : 'en-US' ,
268305 }
269- connection = _Connection (response )
270- client = _Client (connection = connection )
306+ connection = mock .Mock (spec = Connection )
307+ connection .api_request .return_value = response
308+ client = mock .Mock (connection = connection , spec = Client )
271309 document = self ._make_one (client , content )
272310
273311 sentiment = document .analyze_sentiment ()
274312 self ._verify_sentiment (sentiment , polarity , magnitude )
275313
276314 # Verify the request.
277- self .assertEqual (len (connection ._requested ), 1 )
278- req = connection ._requested [0 ]
279- self .assertEqual (req ['path' ], 'analyzeSentiment' )
280- self .assertEqual (req ['method' ], 'POST' )
315+ expected = self ._expected_data (content )
316+ connection .api_request .assert_called_once_with (
317+ path = 'analyzeSentiment' , method = 'POST' , data = expected )
281318
282319 def _verify_sentences (self , include_syntax , annotations ):
283320 from google .cloud .language .syntax import Sentence
@@ -306,7 +343,12 @@ def _verify_tokens(self, annotations, token_info):
306343
307344 def _annotate_text_helper (self , include_sentiment ,
308345 include_entities , include_syntax ):
346+ import mock
347+
348+ from google .cloud .language .connection import Connection
349+ from google .cloud .language .client import Client
309350 from google .cloud .language .document import Annotations
351+ from google .cloud .language .document import Encoding
310352 from google .cloud .language .entity import EntityType
311353
312354 token_info , sentences = _get_token_and_sentences (include_syntax )
@@ -324,8 +366,9 @@ def _annotate_text_helper(self, include_sentiment,
324366 'magnitude' : ANNOTATE_MAGNITUDE ,
325367 }
326368
327- connection = _Connection (response )
328- client = _Client (connection = connection )
369+ connection = mock .Mock (spec = Connection )
370+ connection .api_request .return_value = response
371+ client = mock .Mock (connection = connection , spec = Client )
329372 document = self ._make_one (client , ANNOTATE_CONTENT )
330373
331374 annotations = document .annotate_text (
@@ -352,16 +395,13 @@ def _annotate_text_helper(self, include_sentiment,
352395 self .assertEqual (annotations .entities , [])
353396
354397 # Verify the request.
355- self .assertEqual (len (connection ._requested ), 1 )
356- req = connection ._requested [0 ]
357- self .assertEqual (req ['path' ], 'annotateText' )
358- self .assertEqual (req ['method' ], 'POST' )
359- features = req ['data' ]['features' ]
360- self .assertEqual (features .get ('extractDocumentSentiment' , False ),
361- include_sentiment )
362- self .assertEqual (features .get ('extractEntities' , False ),
363- include_entities )
364- self .assertEqual (features .get ('extractSyntax' , False ), include_syntax )
398+ expected = self ._expected_data (
399+ ANNOTATE_CONTENT , encoding_type = Encoding .UTF8 ,
400+ extract_sentiment = include_sentiment ,
401+ extract_entities = include_entities ,
402+ extract_syntax = include_syntax )
403+ connection .api_request .assert_called_once_with (
404+ path = 'annotateText' , method = 'POST' , data = expected )
365405
366406 def test_annotate_text (self ):
367407 self ._annotate_text_helper (True , True , True )
@@ -374,20 +414,3 @@ def test_annotate_text_entities_only(self):
374414
375415 def test_annotate_text_syntax_only (self ):
376416 self ._annotate_text_helper (False , False , True )
377-
378-
379- class _Connection (object ):
380-
381- def __init__ (self , response ):
382- self ._response = response
383- self ._requested = []
384-
385- def api_request (self , ** kwargs ):
386- self ._requested .append (kwargs )
387- return self ._response
388-
389-
390- class _Client (object ):
391-
392- def __init__ (self , connection = None ):
393- self .connection = connection
0 commit comments