2323from google .api_core import exceptions
2424from google .auth .credentials import AnonymousCredentials
2525from google .cloud import automl_v1beta1
26- from google .cloud .automl_v1beta1 .proto import data_types_pb2
26+ from google .cloud .automl_v1beta1 .proto import data_types_pb2 , data_items_pb2
27+ from google .protobuf import struct_pb2
2728
2829PROJECT = "project"
2930REGION = "region"
@@ -1116,9 +1117,10 @@ def test_predict_from_array(self):
11161117 model .configure_mock (tables_model_metadata = model_metadata , name = "my_model" )
11171118 client = self .tables_client ({"get_model.return_value" : model }, {})
11181119 client .predict (["1" ], model_name = "my_model" )
1119- client . prediction_client . predict . assert_called_with (
1120- "my_model" , { " row" : { " values" : [{ " string_value" : "1" }]}}, None
1120+ payload = data_items_pb2 . ExamplePayload (
1121+ row = data_items_pb2 . Row ( values = [ struct_pb2 . Value ( string_value = "1" )])
11211122 )
1123+ client .prediction_client .predict .assert_called_with ("my_model" , payload , None )
11221124
11231125 def test_predict_from_dict (self ):
11241126 data_type = mock .Mock (type_code = data_types_pb2 .CATEGORY )
@@ -1131,11 +1133,15 @@ def test_predict_from_dict(self):
11311133 model .configure_mock (tables_model_metadata = model_metadata , name = "my_model" )
11321134 client = self .tables_client ({"get_model.return_value" : model }, {})
11331135 client .predict ({"a" : "1" , "b" : "2" }, model_name = "my_model" )
1134- client .prediction_client .predict .assert_called_with (
1135- "my_model" ,
1136- {"row" : {"values" : [{"string_value" : "1" }, {"string_value" : "2" }]}},
1137- None ,
1136+ payload = data_items_pb2 .ExamplePayload (
1137+ row = data_items_pb2 .Row (
1138+ values = [
1139+ struct_pb2 .Value (string_value = "1" ),
1140+ struct_pb2 .Value (string_value = "2" ),
1141+ ]
1142+ )
11381143 )
1144+ client .prediction_client .predict .assert_called_with ("my_model" , payload , None )
11391145
11401146 def test_predict_from_dict_with_feature_importance (self ):
11411147 data_type = mock .Mock (type_code = data_types_pb2 .CATEGORY )
@@ -1150,10 +1156,16 @@ def test_predict_from_dict_with_feature_importance(self):
11501156 client .predict (
11511157 {"a" : "1" , "b" : "2" }, model_name = "my_model" , feature_importance = True
11521158 )
1159+ payload = data_items_pb2 .ExamplePayload (
1160+ row = data_items_pb2 .Row (
1161+ values = [
1162+ struct_pb2 .Value (string_value = "1" ),
1163+ struct_pb2 .Value (string_value = "2" ),
1164+ ]
1165+ )
1166+ )
11531167 client .prediction_client .predict .assert_called_with (
1154- "my_model" ,
1155- {"row" : {"values" : [{"string_value" : "1" }, {"string_value" : "2" }]}},
1156- {"feature_importance" : "true" },
1168+ "my_model" , payload , {"feature_importance" : "true" }
11571169 )
11581170
11591171 def test_predict_from_dict_missing (self ):
@@ -1167,18 +1179,32 @@ def test_predict_from_dict_missing(self):
11671179 model .configure_mock (tables_model_metadata = model_metadata , name = "my_model" )
11681180 client = self .tables_client ({"get_model.return_value" : model }, {})
11691181 client .predict ({"a" : "1" }, model_name = "my_model" )
1170- client .prediction_client .predict .assert_called_with (
1171- "my_model" ,
1172- {"row" : {"values" : [{"string_value" : "1" }, {"null_value" : 0 }]}},
1173- None ,
1182+ payload = data_items_pb2 .ExamplePayload (
1183+ row = data_items_pb2 .Row (
1184+ values = [
1185+ struct_pb2 .Value (string_value = "1" ),
1186+ struct_pb2 .Value (null_value = struct_pb2 .NullValue .NULL_VALUE ),
1187+ ]
1188+ )
11741189 )
1190+ client .prediction_client .predict .assert_called_with ("my_model" , payload , None )
11751191
11761192 def test_predict_all_types (self ):
11771193 float_type = mock .Mock (type_code = data_types_pb2 .FLOAT64 )
11781194 timestamp_type = mock .Mock (type_code = data_types_pb2 .TIMESTAMP )
11791195 string_type = mock .Mock (type_code = data_types_pb2 .STRING )
1180- array_type = mock .Mock (type_code = data_types_pb2 .ARRAY )
1181- struct_type = mock .Mock (type_code = data_types_pb2 .STRUCT )
1196+ array_type = mock .Mock (
1197+ type_code = data_types_pb2 .ARRAY ,
1198+ list_element_type = mock .Mock (type_code = data_types_pb2 .FLOAT64 ),
1199+ )
1200+ struct = data_types_pb2 .StructType ()
1201+ struct .fields ["a" ].CopyFrom (
1202+ data_types_pb2 .DataType (type_code = data_types_pb2 .CATEGORY )
1203+ )
1204+ struct .fields ["b" ].CopyFrom (
1205+ data_types_pb2 .DataType (type_code = data_types_pb2 .CATEGORY )
1206+ )
1207+ struct_type = mock .Mock (type_code = data_types_pb2 .STRUCT , struct_type = struct )
11821208 category_type = mock .Mock (type_code = data_types_pb2 .CATEGORY )
11831209 column_spec_float = mock .Mock (display_name = "float" , data_type = float_type )
11841210 column_spec_timestamp = mock .Mock (
@@ -1211,29 +1237,33 @@ def test_predict_all_types(self):
12111237 "timestamp" : "EST" ,
12121238 "string" : "text" ,
12131239 "array" : [1 ],
1214- "struct" : {"a" : "b " },
1240+ "struct" : {"a" : "label_a" , "b" : "label_b " },
12151241 "category" : "a" ,
12161242 "null" : None ,
12171243 },
12181244 model_name = "my_model" ,
12191245 )
1220- client .prediction_client .predict .assert_called_with (
1221- "my_model" ,
1222- {
1223- "row" : {
1224- "values" : [
1225- {"number_value" : 1.0 },
1226- {"string_value" : "EST" },
1227- {"string_value" : "text" },
1228- {"list_value" : [1 ]},
1229- {"struct_value" : {"a" : "b" }},
1230- {"string_value" : "a" },
1231- {"null_value" : 0 },
1232- ]
1233- }
1234- },
1235- None ,
1246+ struct = struct_pb2 .Struct ()
1247+ struct .fields ["a" ].CopyFrom (struct_pb2 .Value (string_value = "label_a" ))
1248+ struct .fields ["b" ].CopyFrom (struct_pb2 .Value (string_value = "label_b" ))
1249+ payload = data_items_pb2 .ExamplePayload (
1250+ row = data_items_pb2 .Row (
1251+ values = [
1252+ struct_pb2 .Value (number_value = 1.0 ),
1253+ struct_pb2 .Value (string_value = "EST" ),
1254+ struct_pb2 .Value (string_value = "text" ),
1255+ struct_pb2 .Value (
1256+ list_value = struct_pb2 .ListValue (
1257+ values = [struct_pb2 .Value (number_value = 1.0 )]
1258+ )
1259+ ),
1260+ struct_pb2 .Value (struct_value = struct ),
1261+ struct_pb2 .Value (string_value = "a" ),
1262+ struct_pb2 .Value (null_value = struct_pb2 .NullValue .NULL_VALUE ),
1263+ ]
1264+ )
12361265 )
1266+ client .prediction_client .predict .assert_called_with ("my_model" , payload , None )
12371267
12381268 def test_predict_from_array_missing (self ):
12391269 data_type = mock .Mock (type_code = data_types_pb2 .CATEGORY )
0 commit comments