@@ -1840,6 +1840,26 @@ def test_to_arrow(self):
         self.assertIsInstance(tbl, pyarrow.Table)
         self.assertEqual(tbl.num_rows, 0)
 
+    @mock.patch("google.cloud.bigquery.table.pyarrow", new=None)
+    def test_to_arrow_iterable_error_if_pyarrow_is_none(self):
+        row_iterator = self._make_one()
+        with self.assertRaises(ValueError):
+            row_iterator.to_arrow_iterable()
+
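+    # With pyarrow available, even an empty result set should yield one empty RecordBatch.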
+    @unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
+    def test_to_arrow_iterable(self):
+        row_iterator = self._make_one()
+        arrow_iter = row_iterator.to_arrow_iterable()
+
+        result = list(arrow_iter)
+
+        self.assertEqual(len(result), 1)
+        record_batch = result[0]
+        self.assertIsInstance(record_batch, pyarrow.RecordBatch)
+        self.assertEqual(record_batch.num_rows, 0)
+        self.assertEqual(record_batch.num_columns, 0)
+
     @mock.patch("google.cloud.bigquery.table.pandas", new=None)
     def test_to_dataframe_error_if_pandas_is_none(self):
         row_iterator = self._make_one()
@@ -2151,6 +2170,213 @@ def test__validate_bqstorage_returns_false_w_warning_if_obsolete_version(self):
         ]
         assert matching_warnings, "Obsolete dependency warning not raised."
 
+    @unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
+    def test_to_arrow_iterable(self):
+        from google.cloud.bigquery.schema import SchemaField
+
+        schema = [
+            SchemaField("name", "STRING", mode="REQUIRED"),
+            SchemaField("age", "INTEGER", mode="REQUIRED"),
+            SchemaField(
+                "child",
+                "RECORD",
+                mode="REPEATED",
+                fields=[
+                    SchemaField("name", "STRING", mode="REQUIRED"),
+                    SchemaField("age", "INTEGER", mode="REQUIRED"),
+                ],
+            ),
+        ]
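+        # The nested REPEATED RECORD field exercises list-of-struct conversion.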
+        rows = [
+            {
+                "f": [
+                    {"v": "Bharney Rhubble"},
+                    {"v": "33"},
+                    {
+                        "v": [
+                            {"v": {"f": [{"v": "Whamm-Whamm Rhubble"}, {"v": "3"}]}},
+                            {"v": {"f": [{"v": "Hoppy"}, {"v": "1"}]}},
+                        ]
+                    },
+                ]
+            },
+            {
+                "f": [
+                    {"v": "Wylma Phlyntstone"},
+                    {"v": "29"},
+                    {
+                        "v": [
+                            {"v": {"f": [{"v": "Bepples Phlyntstone"}, {"v": "0"}]}},
+                            {"v": {"f": [{"v": "Dino"}, {"v": "4"}]}},
+                        ]
+                    },
+                ]
+            },
+        ]
+        path = "/foo"
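+        # Simulate two pages from the REST API; the first response carries a pageToken.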
+        api_request = mock.Mock(
+            side_effect=[
+                {"rows": [rows[0]], "pageToken": "NEXTPAGE"},
+                {"rows": [rows[1]]},
+            ]
+        )
+        row_iterator = self._make_one(
+            _mock_client(), api_request, path, schema, page_size=1, max_results=5
+        )
+
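+        # to_arrow_iterable is lazy: it returns a generator yielding one RecordBatch per page.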
+        record_batches = row_iterator.to_arrow_iterable()
+        self.assertIsInstance(record_batches, types.GeneratorType)
+        record_batches = list(record_batches)
+        self.assertEqual(len(record_batches), 2)
+
+        # Check the schema.
+        for record_batch in record_batches:
+            self.assertIsInstance(record_batch, pyarrow.RecordBatch)
+            self.assertEqual(record_batch.schema[0].name, "name")
+            self.assertTrue(pyarrow.types.is_string(record_batch.schema[0].type))
+            self.assertEqual(record_batch.schema[1].name, "age")
+            self.assertTrue(pyarrow.types.is_int64(record_batch.schema[1].type))
+            child_field = record_batch.schema[2]
+            self.assertEqual(child_field.name, "child")
+            self.assertTrue(pyarrow.types.is_list(child_field.type))
+            self.assertTrue(pyarrow.types.is_struct(child_field.type.value_type))
+            self.assertEqual(child_field.type.value_type[0].name, "name")
+            self.assertEqual(child_field.type.value_type[1].name, "age")
+
+        # Check the data.
+        record_batch_1 = record_batches[0].to_pydict()
+        names = record_batch_1["name"]
+        ages = record_batch_1["age"]
+        children = record_batch_1["child"]
+        self.assertEqual(names, ["Bharney Rhubble"])
+        self.assertEqual(ages, [33])
+        self.assertEqual(
+            children,
+            [
+                [
+                    {"name": "Whamm-Whamm Rhubble", "age": 3},
+                    {"name": "Hoppy", "age": 1},
+                ],
+            ],
+        )
+
+        record_batch_2 = record_batches[1].to_pydict()
+        names = record_batch_2["name"]
+        ages = record_batch_2["age"]
+        children = record_batch_2["child"]
+        self.assertEqual(names, ["Wylma Phlyntstone"])
+        self.assertEqual(ages, [29])
+        self.assertEqual(
+            children,
+            [[{"name": "Bepples Phlyntstone", "age": 0}, {"name": "Dino", "age": 4}]],
+        )
+
+    @mock.patch("google.cloud.bigquery.table.pyarrow", new=None)
+    def test_to_arrow_iterable_error_if_pyarrow_is_none(self):
+        from google.cloud.bigquery.schema import SchemaField
+
+        schema = [
+            SchemaField("name", "STRING", mode="REQUIRED"),
+            SchemaField("age", "INTEGER", mode="REQUIRED"),
+        ]
+        rows = [
+            {"f": [{"v": "Phred Phlyntstone"}, {"v": "32"}]},
+            {"f": [{"v": "Bharney Rhubble"}, {"v": "33"}]},
+        ]
+        path = "/foo"
+        api_request = mock.Mock(return_value={"rows": rows})
+        row_iterator = self._make_one(_mock_client(), api_request, path, schema)
+
+        with pytest.raises(ValueError, match="pyarrow"):
+            row_iterator.to_arrow_iterable()
+
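+    # When a BigQuery Storage client is supplied, record batches should be read from its streams.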
+    @unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
+    @unittest.skipIf(
+        bigquery_storage is None, "Requires `google-cloud-bigquery-storage`"
+    )
+    def test_to_arrow_iterable_w_bqstorage(self):
+        from google.cloud.bigquery import schema
+        from google.cloud.bigquery import table as mut
+        from google.cloud.bigquery_storage_v1 import reader
+
+        bqstorage_client = mock.create_autospec(bigquery_storage.BigQueryReadClient)
+        bqstorage_client._transport = mock.create_autospec(
+            big_query_read_grpc_transport.BigQueryReadGrpcTransport
+        )
+        streams = [
+            # Use two streams so we can check that frames are read from each one.
+            {"name": "/projects/proj/dataset/dset/tables/tbl/streams/1234"},
+            {"name": "/projects/proj/dataset/dset/tables/tbl/streams/5678"},
+        ]
+        session = bigquery_storage.types.ReadSession(streams=streams)
+        arrow_schema = pyarrow.schema(
+            [
+                pyarrow.field("colA", pyarrow.int64()),
+                # Not alphabetical to test column order.
+                pyarrow.field("colC", pyarrow.float64()),
+                pyarrow.field("colB", pyarrow.string()),
+            ]
+        )
+        session.arrow_schema.serialized_schema = arrow_schema.serialize().to_pybytes()
+        bqstorage_client.create_read_session.return_value = session
+
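+        # Wire up the mock stream: read_rows() -> rows() -> pages, where each page converts to a RecordBatch.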
+        mock_rowstream = mock.create_autospec(reader.ReadRowsStream)
+        bqstorage_client.read_rows.return_value = mock_rowstream
+
+        mock_rows = mock.create_autospec(reader.ReadRowsIterable)
+        mock_rowstream.rows.return_value = mock_rows
+        expected_num_rows = 2
+        page_items = [
+            pyarrow.array([1, -1]),
+            pyarrow.array([2.0, 4.0]),
+            pyarrow.array(["abc", "def"]),
+        ]
+
+        expected_record_batch = pyarrow.RecordBatch.from_arrays(
+            page_items, schema=arrow_schema
+        )
+        expected_num_record_batches = 3
+
+        mock_page = mock.create_autospec(reader.ReadRowsPage)
+        mock_page.to_arrow.return_value = expected_record_batch
+        mock_pages = (mock_page,) * expected_num_record_batches
+        type(mock_rows).pages = mock.PropertyMock(return_value=mock_pages)
+
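+        # Rebinding "schema" shadows the imported module; safe because the list is built first.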
+        schema = [
+            schema.SchemaField("colA", "INTEGER"),
+            schema.SchemaField("colC", "FLOAT"),
+            schema.SchemaField("colB", "STRING"),
+        ]
+
+        row_iterator = mut.RowIterator(
+            _mock_client(),
+            None,  # api_request: ignored
+            None,  # path: ignored
+            schema,
+            table=mut.TableReference.from_string("proj.dset.tbl"),
+            selected_fields=schema,
+        )
+
+        record_batches = list(
+            row_iterator.to_arrow_iterable(bqstorage_client=bqstorage_client)
+        )
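+        # Two streams times three pages each should yield six record batches.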
+        total_record_batches = len(streams) * len(mock_pages)
+        self.assertEqual(len(record_batches), total_record_batches)
+
+        for record_batch in record_batches:
+            # Are the record batches returned as expected?
+            self.assertEqual(record_batch, expected_record_batch)
+
+        # Don't close the client if it was passed in.
+        bqstorage_client._transport.grpc_channel.close.assert_not_called()
+
     @unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
     def test_to_arrow(self):
         from google.cloud.bigquery.schema import SchemaField