77import nose
88import numpy as np
99
10- from pandas import DataFrame , Series
10+ from pandas import DataFrame , Series , MultiIndex
1111from pandas .compat import range , lrange , iteritems
1212#from pandas.core.datetools import format as date_format
1313
@@ -266,7 +266,7 @@ def _roundtrip(self):
266266 self .pandasSQL .to_sql (self .test_frame1 , 'test_frame_roundtrip' )
267267 result = self .pandasSQL .read_sql ('SELECT * FROM test_frame_roundtrip' )
268268
269- result .set_index ('pandas_index ' , inplace = True )
269+ result .set_index ('level_0 ' , inplace = True )
270270 # result.index.astype(int)
271271
272272 result .index .name = None
@@ -391,7 +391,7 @@ def test_roundtrip(self):
391391
392392 # HACK!
393393 result .index = self .test_frame1 .index
394- result .set_index ('pandas_index ' , inplace = True )
394+ result .set_index ('level_0 ' , inplace = True )
395395 result .index .astype (int )
396396 result .index .name = None
397397 tm .assert_frame_equal (result , self .test_frame1 )
@@ -460,7 +460,9 @@ def test_date_and_index(self):
460460 issubclass (df .IntDateCol .dtype .type , np .datetime64 ),
461461 "IntDateCol loaded with incorrect type" )
462462
463+
463464class TestSQLApi (_TestSQLApi ):
465+
464466 """Test the public API as it would be used directly
465467 """
466468 flavor = 'sqlite'
@@ -474,10 +476,10 @@ def connect(self):
474476 def test_to_sql_index_label (self ):
475477 temp_frame = DataFrame ({'col1' : range (4 )})
476478
477- # no index name, defaults to 'pandas_index '
479+ # no index name, defaults to 'index '
478480 sql .to_sql (temp_frame , 'test_index_label' , self .conn )
479481 frame = sql .read_table ('test_index_label' , self .conn )
480- self .assertEqual (frame .columns [0 ], 'pandas_index ' )
482+ self .assertEqual (frame .columns [0 ], 'index ' )
481483
482484 # specifying index_label
483485 sql .to_sql (temp_frame , 'test_index_label' , self .conn ,
@@ -487,11 +489,11 @@ def test_to_sql_index_label(self):
487489 "Specified index_label not written to database" )
488490
489491 # using the index name
490- temp_frame .index .name = 'index '
492+ temp_frame .index .name = 'index_name '
491493 sql .to_sql (temp_frame , 'test_index_label' , self .conn ,
492494 if_exists = 'replace' )
493495 frame = sql .read_table ('test_index_label' , self .conn )
494- self .assertEqual (frame .columns [0 ], 'index ' ,
496+ self .assertEqual (frame .columns [0 ], 'index_name ' ,
495497 "Index name not written to database" )
496498
497499 # has index name, but specifying index_label
@@ -501,8 +503,74 @@ def test_to_sql_index_label(self):
501503 self .assertEqual (frame .columns [0 ], 'other_label' ,
502504 "Specified index_label not written to database" )
503505
506+ def test_to_sql_index_label_multiindex (self ):
507+ temp_frame = DataFrame ({'col1' : range (4 )},
508+ index = MultiIndex .from_product ([('A0' , 'A1' ), ('B0' , 'B1' )]))
509+
510+ # no index name, defaults to 'level_0' and 'level_1'
511+ sql .to_sql (temp_frame , 'test_index_label' , self .conn )
512+ frame = sql .read_table ('test_index_label' , self .conn )
513+ self .assertEqual (frame .columns [0 ], 'level_0' )
514+ self .assertEqual (frame .columns [1 ], 'level_1' )
515+
516+ # specifying index_label
517+ sql .to_sql (temp_frame , 'test_index_label' , self .conn ,
518+ if_exists = 'replace' , index_label = ['A' , 'B' ])
519+ frame = sql .read_table ('test_index_label' , self .conn )
520+ self .assertEqual (frame .columns [:2 ].tolist (), ['A' , 'B' ],
521+ "Specified index_labels not written to database" )
522+
523+ # using the index name
524+ temp_frame .index .names = ['A' , 'B' ]
525+ sql .to_sql (temp_frame , 'test_index_label' , self .conn ,
526+ if_exists = 'replace' )
527+ frame = sql .read_table ('test_index_label' , self .conn )
528+ self .assertEqual (frame .columns [:2 ].tolist (), ['A' , 'B' ],
529+ "Index names not written to database" )
530+
531+ # has index name, but specifying index_label
532+ sql .to_sql (temp_frame , 'test_index_label' , self .conn ,
533+ if_exists = 'replace' , index_label = ['C' , 'D' ])
534+ frame = sql .read_table ('test_index_label' , self .conn )
535+ self .assertEqual (frame .columns [:2 ].tolist (), ['C' , 'D' ],
536+ "Specified index_labels not written to database" )
537+
538+ # wrong length of index_label
539+ self .assertRaises (ValueError , sql .to_sql , temp_frame ,
540+ 'test_index_label' , self .conn , if_exists = 'replace' ,
541+ index_label = 'C' )
542+
543+ def test_read_table_columns (self ):
544+ # test columns argument in read_table
545+ sql .to_sql (self .test_frame1 , 'test_frame' , self .conn )
546+
547+ cols = ['A' , 'B' ]
548+ result = sql .read_table ('test_frame' , self .conn , columns = cols )
549+ self .assertEqual (result .columns .tolist (), cols ,
550+ "Columns not correctly selected" )
551+
552+ def test_read_table_index_col (self ):
553+ # test columns argument in read_table
554+ sql .to_sql (self .test_frame1 , 'test_frame' , self .conn )
555+
556+ result = sql .read_table ('test_frame' , self .conn , index_col = "index" )
557+ self .assertEqual (result .index .names , ["index" ],
558+ "index_col not correctly set" )
559+
560+ result = sql .read_table ('test_frame' , self .conn , index_col = ["A" , "B" ])
561+ self .assertEqual (result .index .names , ["A" , "B" ],
562+ "index_col not correctly set" )
563+
564+ result = sql .read_table ('test_frame' , self .conn , index_col = ["A" , "B" ],
565+ columns = ["C" , "D" ])
566+ self .assertEqual (result .index .names , ["A" , "B" ],
567+ "index_col not correctly set" )
568+ self .assertEqual (result .columns .tolist (), ["C" , "D" ],
569+ "columns not set correctly whith index_col" )
570+
504571
505572class TestSQLLegacyApi (_TestSQLApi ):
573+
506574 """Test the public legacy API
507575 """
508576 flavor = 'sqlite'
@@ -554,6 +622,23 @@ def test_sql_open_close(self):
554622
555623 tm .assert_frame_equal (self .test_frame2 , result )
556624
625+ def test_roundtrip (self ):
626+ # this test otherwise fails, Legacy mode still uses 'pandas_index'
627+ # as default index column label
628+ sql .to_sql (self .test_frame1 , 'test_frame_roundtrip' ,
629+ con = self .conn , flavor = 'sqlite' )
630+ result = sql .read_sql (
631+ 'SELECT * FROM test_frame_roundtrip' ,
632+ con = self .conn ,
633+ flavor = 'sqlite' )
634+
635+ # HACK!
636+ result .index = self .test_frame1 .index
637+ result .set_index ('pandas_index' , inplace = True )
638+ result .index .astype (int )
639+ result .index .name = None
640+ tm .assert_frame_equal (result , self .test_frame1 )
641+
557642
558643class _TestSQLAlchemy (PandasSQLTest ):
559644 """
@@ -776,6 +861,16 @@ def setUp(self):
776861
777862 self ._load_test1_data ()
778863
864+ def _roundtrip (self ):
865+ # overwrite parent function (level_0 -> pandas_index in legacy mode)
866+ self .drop_table ('test_frame_roundtrip' )
867+ self .pandasSQL .to_sql (self .test_frame1 , 'test_frame_roundtrip' )
868+ result = self .pandasSQL .read_sql ('SELECT * FROM test_frame_roundtrip' )
869+ result .set_index ('pandas_index' , inplace = True )
870+ result .index .name = None
871+
872+ tm .assert_frame_equal (result , self .test_frame1 )
873+
779874 def test_invalid_flavor (self ):
780875 self .assertRaises (
781876 NotImplementedError , sql .PandasSQLLegacy , self .conn , 'oracle' )
0 commit comments