1919from pandas .core .base import PandasObject
2020from pandas .tseries .tools import to_datetime
2121
22+ from contextlib import contextmanager
2223
2324class SQLAlchemyRequired (ImportError ):
2425 pass
@@ -645,13 +646,9 @@ def insert_data(self):
645646
646647 return column_names , data_list
647648
648- def get_session (self ):
649- con = self .pd_sql .engine .connect ()
650- return con .begin ()
651-
652- def _execute_insert (self , trans , keys , data_iter ):
649+ def _execute_insert (self , conn , keys , data_iter ):
653650 data = [dict ( (k , v ) for k , v in zip (keys , row ) ) for row in data_iter ]
654- trans . connection .execute (self .insert_statement (), data )
651+ conn .execute (self .insert_statement (), data )
655652
656653 def insert (self , chunksize = None ):
657654 keys , data_list = self .insert_data ()
@@ -661,15 +658,15 @@ def insert(self, chunksize=None):
661658 chunksize = nrows
662659 chunks = int (nrows / chunksize ) + 1
663660
664- with self .get_session () as trans :
661+ with self .pd_sql . run_transaction () as conn :
665662 for i in range (chunks ):
666663 start_i = i * chunksize
667664 end_i = min ((i + 1 ) * chunksize , nrows )
668665 if start_i >= end_i :
669666 break
670667
671668 chunk_iter = zip (* [arr [start_i :end_i ] for arr in data_list ])
672- self ._execute_insert (trans , keys , chunk_iter )
669+ self ._execute_insert (conn , keys , chunk_iter )
673670
674671 def read (self , coerce_float = True , parse_dates = None , columns = None ):
675672
@@ -892,6 +889,9 @@ def __init__(self, engine, schema=None, meta=None):
892889
893890 self .meta = meta
894891
892+ def run_transaction (self ):
893+ return self .engine .begin ()
894+
895895 def execute (self , * args , ** kwargs ):
896896 """Simple passthrough to SQLAlchemy engine"""
897897 return self .engine .execute (* args , ** kwargs )
@@ -1025,9 +1025,9 @@ def sql_schema(self):
10251025 return str (";\n " .join (self .table ))
10261026
10271027 def _execute_create (self ):
1028- with self .get_session () :
1028+ with self .pd_sql . run_transaction () as conn :
10291029 for stmt in self .table :
1030- self . pd_sql .execute (stmt )
1030+ conn .execute (stmt )
10311031
10321032 def insert_statement (self ):
10331033 names = list (map (str , self .frame .columns ))
@@ -1046,12 +1046,9 @@ def insert_statement(self):
10461046 self .name , col_names , wildcards )
10471047 return insert_statement
10481048
1049- def get_session (self ):
1050- return self .pd_sql .con
1051-
1052- def _execute_insert (self , trans , keys , data_iter ):
1049+ def _execute_insert (self , conn , keys , data_iter ):
10531050 data_list = list (data_iter )
1054- trans .executemany (self .insert_statement (), data_list )
1051+ conn .executemany (self .insert_statement (), data_list )
10551052
10561053 def _create_table_setup (self ):
10571054 """Return a list of SQL statement that create a table reflecting the
@@ -1133,6 +1130,17 @@ def __init__(self, con, flavor, is_cursor=False):
11331130 else :
11341131 self .flavor = flavor
11351132
1133+ @contextmanager
1134+ def run_transaction (self ):
1135+ cur = self .con .cursor ()
1136+ try :
1137+ yield cur
1138+ self .con .commit ()
1139+ except :
1140+ self .con .rollback ()
1141+ finally :
1142+ cur .close ()
1143+
11361144 def execute (self , * args , ** kwargs ):
11371145 if self .is_cursor :
11381146 cur = self .con
0 commit comments