diff --git a/pandas/io/sql.py b/pandas/io/sql.py index 437e279e909790..cb9dcccea76bde 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -572,8 +572,11 @@ def create(self): else: self._execute_create() - def insert_statement(self): - return self.table.insert() + def insert_statement(self, data, conn): + dialect = getattr(conn, 'dialect', None) + if dialect and getattr(dialect, 'supports_multivalues_insert', False): + return (self.table.insert(data),) + return (self.table.insert(), data) def insert_data(self): if self.index is not None: @@ -613,7 +616,7 @@ def insert_data(self): def _execute_insert(self, conn, keys, data_iter): data = [{k: v for k, v in zip(keys, row)} for row in data_iter] - conn.execute(self.insert_statement(), data) + conn.execute(*self.insert_statement(data, conn)) def insert(self, chunksize=None): keys, data_list = self.insert_data() diff --git a/pandas/tests/io/test_sql.py b/pandas/tests/io/test_sql.py index f3ab74d37a2bc9..31833097bae7af 100644 --- a/pandas/tests/io/test_sql.py +++ b/pandas/tests/io/test_sql.py @@ -479,6 +479,24 @@ def _transaction_test(self): res2 = self.pandasSQL.read_query('SELECT * FROM test_trans') assert len(res2) == 1 + def _test_insert_multivalues(self): + db = sql.SQLDatabase(self.conn) + df = DataFrame({'A': [1, 0, 0], 'B': [1.1, 0.2, 4.3]}) + table = sql.SQLTable("test_table", db, frame=df) + data = [ + {'A': 1, 'B': 0.46}, + {'A': 0, 'B': -2.06} + ] + statement = table.insert_statement(data, conn=self.conn)[0] + dialect = getattr(self.conn, 'dialect', None) + if dialect and getattr(dialect, 'supports_multivalues_insert', False): + assert statement.parameters == data, ( + 'insert statement should be multivalues' + ) + else: + assert statement.parameters is None, ( + 'insert statement should not be multivalues' + ) # ----------------------------------------------------------------------------- # -- Testing the public API @@ -1665,6 +1683,9 @@ class Temporary(Base): tm.assert_frame_equal(df, expected) + def test_insert_multivalues(self): + self._test_insert_multivalues() + class _TestSQLAlchemyConn(_EngineToConnMixin, _TestSQLAlchemy):