@@ -110,48 +110,79 @@ def df_full():
                              pd.Timestamp('20130103')]})


-def test_invalid_engine(df_compat):
+def check_round_trip(df, engine=None, path=None,
+                     write_kwargs=None, read_kwargs=None,
+                     expected=None, check_names=True,
+                     repeat=2):
117+ """Verify parquet serializer and deserializer produce the same results.
118+
119+ Performs a pandas to disk and disk to pandas round trip,
120+ then compares the 2 resulting DataFrames to verify equality.
121+
122+ Parameters
123+ ----------
124+ df: Dataframe
125+ engine: str, optional
126+ 'pyarrow' or 'fastparquet'
127+ path: str, optional
128+ write_kwargs: dict of str:str, optional
129+ read_kwargs: dict of str:str, optional
130+ expected: DataFrame, optional
131+ Expected deserialization result, otherwise will be equal to `df`
132+ check_names: list of str, optional
133+ Closed set of column names to be compared
134+ repeat: int, optional
135+ How many times to repeat the test
136+ """
+
+    write_kwargs = write_kwargs or {'compression': None}
+    read_kwargs = read_kwargs or {}
+
+    if expected is None:
+        expected = df
+
+    if engine:
+        write_kwargs['engine'] = engine
+        read_kwargs['engine'] = engine
+
+    def compare(repeat):
+        for _ in range(repeat):
+            df.to_parquet(path, **write_kwargs)
+            actual = read_parquet(path, **read_kwargs)
+            tm.assert_frame_equal(expected, actual,
+                                  check_names=check_names)
+
+    if path is None:
+        with tm.ensure_clean() as path:
+            compare(repeat)
+    else:
+        compare(repeat)

+
+def test_invalid_engine(df_compat):
     with pytest.raises(ValueError):
-        df_compat.to_parquet('foo', 'bar')
+        check_round_trip(df_compat, 'foo', 'bar')


 def test_options_py(df_compat, pa):
     # use the set option

-    df = df_compat
-    with tm.ensure_clean() as path:
-
-        with pd.option_context('io.parquet.engine', 'pyarrow'):
-            df.to_parquet(path)
-
-            result = read_parquet(path)
-            tm.assert_frame_equal(result, df)
+    with pd.option_context('io.parquet.engine', 'pyarrow'):
+        check_round_trip(df_compat)


 def test_options_fp(df_compat, fp):
     # use the set option

-    df = df_compat
-    with tm.ensure_clean() as path:
-
-        with pd.option_context('io.parquet.engine', 'fastparquet'):
-            df.to_parquet(path, compression=None)
-
-            result = read_parquet(path)
-            tm.assert_frame_equal(result, df)
+    with pd.option_context('io.parquet.engine', 'fastparquet'):
+        check_round_trip(df_compat)


 def test_options_auto(df_compat, fp, pa):
+    # use the set option

-    df = df_compat
-    with tm.ensure_clean() as path:
-
-        with pd.option_context('io.parquet.engine', 'auto'):
-            df.to_parquet(path)
-
-            result = read_parquet(path)
-            tm.assert_frame_equal(result, df)
+    with pd.option_context('io.parquet.engine', 'auto'):
+        check_round_trip(df_compat)


 def test_options_get_engine(fp, pa):
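For reference, a minimal sketch of how the new module-level helper above is intended to be called; the frames here are illustrative and not part of the change:

    # default engine resolution via the io.parquet.engine option
    check_round_trip(pd.DataFrame({'a': [1, 2, 3]}))

    # pin an engine and compare only a subset of columns on read
    df = pd.DataFrame({'a': [1, 2, 3], 'b': list('xyz')})
    check_round_trip(df, engine='pyarrow',
                     read_kwargs={'columns': ['a']},
                     expected=df[['a']])

With `path` omitted, both calls write to a temporary file from `tm.ensure_clean()` and run the round trip twice (the `repeat=2` default).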
@@ -228,53 +259,23 @@ def check_error_on_write(self, df, engine, exc):
         with tm.ensure_clean() as path:
             to_parquet(df, path, engine, compression=None)

-    def check_round_trip(self, df, engine, expected=None, path=None,
-                         write_kwargs=None, read_kwargs=None,
-                         check_names=True):
-
-        if write_kwargs is None:
-            write_kwargs = {'compression': None}
-
-        if read_kwargs is None:
-            read_kwargs = {}
-
-        if expected is None:
-            expected = df
-
-        if path is None:
-            with tm.ensure_clean() as path:
-                check_round_trip_equals(df, path, engine,
-                                        write_kwargs=write_kwargs,
-                                        read_kwargs=read_kwargs,
-                                        expected=expected,
-                                        check_names=check_names)
-        else:
-            check_round_trip_equals(df, path, engine,
-                                    write_kwargs=write_kwargs,
-                                    read_kwargs=read_kwargs,
-                                    expected=expected,
-                                    check_names=check_names)
-

 class TestBasic(Base):

     def test_error(self, engine):
-
         for obj in [pd.Series([1, 2, 3]), 1, 'foo', pd.Timestamp('20130101'),
                     np.array([1, 2, 3])]:
             self.check_error_on_write(obj, engine, ValueError)

     def test_columns_dtypes(self, engine):
-
         df = pd.DataFrame({'string': list('abc'),
                            'int': list(range(1, 4))})

         # unicode
         df.columns = [u'foo', u'bar']
-        self.check_round_trip(df, engine)
+        check_round_trip(df, engine)

     def test_columns_dtypes_invalid(self, engine):
-
         df = pd.DataFrame({'string': list('abc'),
                            'int': list(range(1, 4))})

@@ -302,17 +303,16 @@ def test_compression(self, engine, compression):
             pytest.importorskip('brotli')

         df = pd.DataFrame({'A': [1, 2, 3]})
-        self.check_round_trip(df, engine,
-                              write_kwargs={'compression': compression})
+        check_round_trip(df, engine, write_kwargs={'compression': compression})

     def test_read_columns(self, engine):
         # GH18154
         df = pd.DataFrame({'string': list('abc'),
                            'int': list(range(1, 4))})

         expected = pd.DataFrame({'string': list('abc')})
-        self.check_round_trip(df, engine, expected=expected,
-                              read_kwargs={'columns': ['string']})
+        check_round_trip(df, engine, expected=expected,
+                         read_kwargs={'columns': ['string']})

     def test_write_index(self, engine):
         check_names = engine != 'fastparquet'
@@ -323,7 +323,7 @@ def test_write_index(self, engine):
             pytest.skip("pyarrow is < 0.7.0")

         df = pd.DataFrame({'A': [1, 2, 3]})
-        self.check_round_trip(df, engine)
+        check_round_trip(df, engine)

         indexes = [
             [2, 3, 4],
@@ -334,12 +334,12 @@ def test_write_index(self, engine):
         # non-default index
         for index in indexes:
             df.index = index
-            self.check_round_trip(df, engine, check_names=check_names)
+            check_round_trip(df, engine, check_names=check_names)

         # index with meta-data
         df.index = [0, 1, 2]
         df.index.name = 'foo'
-        self.check_round_trip(df, engine)
+        check_round_trip(df, engine)

     def test_write_multiindex(self, pa_ge_070):
         # Not supported in fastparquet as of 0.1.3 or older pyarrow version
@@ -348,7 +348,7 @@ def test_write_multiindex(self, pa_ge_070):
         df = pd.DataFrame({'A': [1, 2, 3]})
         index = pd.MultiIndex.from_tuples([('a', 1), ('a', 2), ('b', 1)])
         df.index = index
-        self.check_round_trip(df, engine)
+        check_round_trip(df, engine)

     def test_write_column_multiindex(self, engine):
         # column multi-index
@@ -357,7 +357,6 @@ def test_write_column_multiindex(self, engine):
         self.check_error_on_write(df, engine, ValueError)

     def test_multiindex_with_columns(self, pa_ge_070):
-
         engine = pa_ge_070
         dates = pd.date_range('01-Jan-2018', '01-Dec-2018', freq='MS')
         df = pd.DataFrame(np.random.randn(2 * len(dates), 3),
@@ -368,14 +367,10 @@ def test_multiindex_with_columns(self, pa_ge_070):
         index2 = index1.copy(names=None)
         for index in [index1, index2]:
             df.index = index
-            with tm.ensure_clean() as path:
-                df.to_parquet(path, engine)
-                result = read_parquet(path, engine)
-                expected = df
-                tm.assert_frame_equal(result, expected)
-                result = read_parquet(path, engine, columns=['A', 'B'])
-                expected = df[['A', 'B']]
-                tm.assert_frame_equal(result, expected)
+
+            check_round_trip(df, engine)
+            check_round_trip(df, engine, read_kwargs={'columns': ['A', 'B']},
+                             expected=df[['A', 'B']])


 class TestParquetPyArrow(Base):
@@ -391,7 +386,7 @@ def test_basic(self, pa, df_full):
                                           tz='Europe/Brussels')
         df['bool_with_none'] = [True, None, True]

-        self.check_round_trip(df, pa)
+        check_round_trip(df, pa)

     @pytest.mark.xfail(reason="pyarrow fails on this (ARROW-1883)")
     def test_basic_subset_columns(self, pa, df_full):
@@ -402,8 +397,8 @@ def test_basic_subset_columns(self, pa, df_full):
         df['datetime_tz'] = pd.date_range('20130101', periods=3,
                                           tz='Europe/Brussels')

-        self.check_round_trip(df, pa, expected=df[['string', 'int']],
-                              read_kwargs={'columns': ['string', 'int']})
+        check_round_trip(df, pa, expected=df[['string', 'int']],
+                         read_kwargs={'columns': ['string', 'int']})

     def test_duplicate_columns(self, pa):
         # not currently able to handle duplicate columns
@@ -433,7 +428,7 @@ def test_categorical(self, pa_ge_070):

         # de-serialized as object
         expected = df.assign(a=df.a.astype(object))
-        self.check_round_trip(df, pa, expected)
+        check_round_trip(df, pa, expected=expected)

     def test_categorical_unsupported(self, pa_lt_070):
         pa = pa_lt_070
@@ -444,20 +439,19 @@ def test_categorical_unsupported(self, pa_lt_070):

     def test_s3_roundtrip(self, df_compat, s3_resource, pa):
         # GH #19134
-        self.check_round_trip(df_compat, pa,
-                              path='s3://pandas-test/pyarrow.parquet')
+        check_round_trip(df_compat, pa,
+                         path='s3://pandas-test/pyarrow.parquet')


 class TestParquetFastParquet(Base):

     def test_basic(self, fp, df_full):
-
         df = df_full

         # additional supported types for fastparquet
         df['timedelta'] = pd.timedelta_range('1 day', periods=3)

-        self.check_round_trip(df, fp)
+        check_round_trip(df, fp)

     @pytest.mark.skip(reason="not supported")
     def test_duplicate_columns(self, fp):
@@ -470,7 +464,7 @@ def test_duplicate_columns(self, fp):
     def test_bool_with_none(self, fp):
         df = pd.DataFrame({'a': [True, None, False]})
         expected = pd.DataFrame({'a': [1.0, np.nan, 0.0]}, dtype='float16')
-        self.check_round_trip(df, fp, expected=expected)
+        check_round_trip(df, fp, expected=expected)

     def test_unsupported(self, fp):

@@ -486,7 +480,7 @@ def test_categorical(self, fp):
         if LooseVersion(fastparquet.__version__) < LooseVersion("0.1.3"):
             pytest.skip("CategoricalDtype not supported for older fp")
         df = pd.DataFrame({'a': pd.Categorical(list('abc'))})
-        self.check_round_trip(df, fp)
+        check_round_trip(df, fp)

     def test_datetime_tz(self, fp):
         # doesn't preserve tz
@@ -495,7 +489,7 @@ def test_datetime_tz(self, fp):

         # warns on the coercion
         with catch_warnings(record=True):
-            self.check_round_trip(df, fp, df.astype('datetime64[ns]'))
+            check_round_trip(df, fp, expected=df.astype('datetime64[ns]'))

     def test_filter_row_groups(self, fp):
         d = {'a': list(range(0, 3))}
@@ -508,5 +502,5 @@ def test_filter_row_groups(self, fp):

     def test_s3_roundtrip(self, df_compat, s3_resource, fp):
         # GH #19134
-        self.check_round_trip(df_compat, fp,
-                              path='s3://pandas-test/fastparquet.parquet')
+        check_round_trip(df_compat, fp,
+                         path='s3://pandas-test/fastparquet.parquet')