Skip to content

Commit

Permalink
BUG: enable multivalues insert
Browse files Browse the repository at this point in the history
  • Loading branch information
danfrankj committed Feb 18, 2018
1 parent 569bc7a commit f7f1c3d
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
9 changes: 6 additions & 3 deletions pandas/io/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,8 +572,11 @@ def create(self):
else:
self._execute_create()

def insert_statement(self):
return self.table.insert()
def insert_statement(self, data, conn):
dialect = getattr(conn, 'dialect', None)
if dialect and getattr(dialect, 'supports_multivalues_insert', False):
return (self.table.insert(data),)
return (self.table.insert(), data)

def insert_data(self):
if self.index is not None:
Expand Down Expand Up @@ -613,7 +616,7 @@ def insert_data(self):

def _execute_insert(self, conn, keys, data_iter):
data = [{k: v for k, v in zip(keys, row)} for row in data_iter]
conn.execute(self.insert_statement(), data)
conn.execute(*self.insert_statement(data, conn))

def insert(self, chunksize=None):
keys, data_list = self.insert_data()
Expand Down
21 changes: 21 additions & 0 deletions pandas/tests/io/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,24 @@ def _transaction_test(self):
res2 = self.pandasSQL.read_query('SELECT * FROM test_trans')
assert len(res2) == 1

def _test_insert_multivalues(self):
db = sql.SQLDatabase(self.conn)
df = DataFrame({'A': [1, 0, 0], 'B': [1.1, 0.2, 4.3]})
table = sql.SQLTable("test_table", db, frame=df)
data = [
{'A': 1, 'B': 0.46},
{'A': 0, 'B': -2.06}
]
statement = table.insert_statement(data, conn=self.conn)[0]
dialect = getattr(self.conn, 'dialect', None)
if dialect and getattr(dialect, 'supports_multivalues_insert', False):
assert statement.parameters == data, (
'insert statement should be multivalues'
)
else:
assert statement.parameters is None, (
'insert statement should not be multivalues'
)

# -----------------------------------------------------------------------------
# -- Testing the public API
Expand Down Expand Up @@ -1665,6 +1683,9 @@ class Temporary(Base):

tm.assert_frame_equal(df, expected)

def test_insert_multivalues(self):
self._test_insert_multivalues()


class _TestSQLAlchemyConn(_EngineToConnMixin, _TestSQLAlchemy):

Expand Down

0 comments on commit f7f1c3d

Please sign in to comment.