diff --git a/pandas/io/sql.py b/pandas/io/sql.py index 158ef7b7ed791..c18a4aef5355b 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -6,6 +6,7 @@ from datetime import datetime, date, timedelta import warnings +import traceback import itertools import re import numpy as np @@ -97,80 +98,130 @@ def execute(sql, con, cur=None, params=None, flavor='sqlite'): ------- Results Iterable """ - pandas_sql = pandasSQL_builder(con, flavor=flavor) + if cur is None: + pandas_sql = pandasSQL_builder(con, flavor=flavor) + else: + pandas_sql = pandasSQL_builder(cur, flavor=flavor, is_cursor=True) args = _convert_params(sql, params) return pandas_sql.execute(*args) -def tquery(sql, con, cur=None, params=None, flavor='sqlite'): +#------------------------------------------------------------------------------ +#--- Deprecated tquery and uquery + +def _safe_fetch(cur): + try: + result = cur.fetchall() + if not isinstance(result, list): + result = list(result) + return result + except Exception as e: # pragma: no cover + excName = e.__class__.__name__ + if excName == 'OperationalError': + return [] + +def tquery(sql, con=None, cur=None, retry=True): """ - Returns list of tuples corresponding to each row in given sql + DEPRECATED. Returns list of tuples corresponding to each row in given sql query. If only one column selected, then plain list is returned. + To obtain the same result in the future, you can use the following: + + >>> execute(sql, con, params).fetchall() + Parameters ---------- sql: string SQL query to be executed - con: SQLAlchemy engine or DBAPI2 connection (legacy mode) - Using SQLAlchemy makes it possible to use any DB supported by that - library. - If a DBAPI2 object is given, a supported SQL flavor must also be provided + con: DBAPI2 connection cur: depreciated, cursor is obtained from connection - params: list or tuple, optional - List of parameters to pass to execute method. - flavor : string "sqlite", "mysql" - Specifies the flavor of SQL to use. - Ignored when using SQLAlchemy engine. Required when using DBAPI2 - connection. + Returns ------- Results Iterable + """ warnings.warn( - "tquery is depreciated, and will be removed in future versions", - DeprecationWarning) + "tquery is depreciated, and will be removed in future versions. " + "You can use ``execute(...).fetchall()`` instead.", + FutureWarning) - pandas_sql = pandasSQL_builder(con, flavor=flavor) - args = _convert_params(sql, params) - return pandas_sql.tquery(*args) + cur = execute(sql, con, cur=cur) + result = _safe_fetch(cur) + if con is not None: + try: + cur.close() + con.commit() + except Exception as e: + excName = e.__class__.__name__ + if excName == 'OperationalError': # pragma: no cover + print('Failed to commit, may need to restart interpreter') + else: + raise -def uquery(sql, con, cur=None, params=None, engine=None, flavor='sqlite'): + traceback.print_exc() + if retry: + return tquery(sql, con=con, retry=False) + + if result and len(result[0]) == 1: + # python 3 compat + result = list(lzip(*result)[0]) + elif result is None: # pragma: no cover + result = [] + + return result + + +def uquery(sql, con=None, cur=None, retry=True, params=None): """ - Does the same thing as tquery, but instead of returning results, it + DEPRECATED. Does the same thing as tquery, but instead of returning results, it returns the number of rows affected. Good for update queries. + To obtain the same result in the future, you can use the following: + + >>> execute(sql, con).rowcount + Parameters ---------- sql: string SQL query to be executed - con: SQLAlchemy engine or DBAPI2 connection (legacy mode) - Using SQLAlchemy makes it possible to use any DB supported by that - library. - If a DBAPI2 object is given, a supported SQL flavor must also be provided + con: DBAPI2 connection cur: depreciated, cursor is obtained from connection params: list or tuple, optional List of parameters to pass to execute method. - flavor : string "sqlite", "mysql" - Specifies the flavor of SQL to use. - Ignored when using SQLAlchemy engine. Required when using DBAPI2 - connection. + Returns ------- Number of affected rows + """ warnings.warn( - "uquery is depreciated, and will be removed in future versions", - DeprecationWarning) - pandas_sql = pandasSQL_builder(con, flavor=flavor) - args = _convert_params(sql, params) - return pandas_sql.uquery(*args) + "uquery is depreciated, and will be removed in future versions. " + "You can use ``execute(...).rowcount`` instead.", + FutureWarning) + + cur = execute(sql, con, cur=cur, params=params) + + result = cur.rowcount + try: + con.commit() + except Exception as e: + excName = e.__class__.__name__ + if excName != 'OperationalError': + raise + + traceback.print_exc() + if retry: + print('Looks like your connection failed, reconnecting...') + return uquery(sql, con, retry=False) + return result #------------------------------------------------------------------------------ -# Read and write to DataFrames +#--- Read and write to DataFrames def read_sql_table(table_name, con, meta=None, index_col=None, coerce_float=True, parse_dates=None, columns=None): @@ -212,7 +263,7 @@ def read_sql_table(table_name, con, meta=None, index_col=None, -------- read_sql_query : Read SQL query into a DataFrame. read_sql - + """ pandas_sql = PandasSQLAlchemy(con, meta=meta) @@ -322,8 +373,8 @@ def read_sql(sql, con, index_col=None, flavor='sqlite', coerce_float=True, Notes ----- This function is a convenience wrapper around ``read_sql_table`` and - ``read_sql_query`` (and for backward compatibility) and will delegate - to the specific function depending on the provided input (database + ``read_sql_query`` (and for backward compatibility) and will delegate + to the specific function depending on the provided input (database table name or sql query). See also @@ -334,7 +385,19 @@ def read_sql(sql, con, index_col=None, flavor='sqlite', coerce_float=True, """ pandas_sql = pandasSQL_builder(con, flavor=flavor) - if pandas_sql.has_table(sql): + if 'select' in sql.lower(): + try: + if pandas_sql.has_table(sql): + return pandas_sql.read_table( + sql, index_col=index_col, coerce_float=coerce_float, + parse_dates=parse_dates, columns=columns) + except: + pass + + return pandas_sql.read_sql( + sql, index_col=index_col, params=params, + coerce_float=coerce_float, parse_dates=parse_dates) + else: if isinstance(pandas_sql, PandasSQLLegacy): raise ValueError("Reading a table with read_sql is not supported " "for a DBAPI2 connection. Use an SQLAlchemy " @@ -342,10 +405,6 @@ def read_sql(sql, con, index_col=None, flavor='sqlite', coerce_float=True, return pandas_sql.read_table( sql, index_col=index_col, coerce_float=coerce_float, parse_dates=parse_dates, columns=columns) - else: - return pandas_sql.read_sql( - sql, index_col=index_col, params=params, coerce_float=coerce_float, - parse_dates=parse_dates) def to_sql(frame, name, con, flavor='sqlite', if_exists='fail', index=True, @@ -377,6 +436,9 @@ def to_sql(frame, name, con, flavor='sqlite', if_exists='fail', index=True, A sequence should be given if the DataFrame uses MultiIndex. """ + if if_exists not in ('fail', 'replace', 'append'): + raise ValueError("'{0}' is not valid for if_exists".format(if_exists)) + pandas_sql = pandasSQL_builder(con, flavor=flavor) if isinstance(frame, Series): @@ -388,7 +450,7 @@ def to_sql(frame, name, con, flavor='sqlite', if_exists='fail', index=True, index_label=index_label) -def has_table(table_name, con, meta=None, flavor='sqlite'): +def has_table(table_name, con, flavor='sqlite'): """ Check if DataBase has named table. @@ -411,34 +473,37 @@ def has_table(table_name, con, meta=None, flavor='sqlite'): pandas_sql = pandasSQL_builder(con, flavor=flavor) return pandas_sql.has_table(table_name) +table_exists = has_table -def pandasSQL_builder(con, flavor=None, meta=None): + +def pandasSQL_builder(con, flavor=None, meta=None, is_cursor=False): """ Convenience function to return the correct PandasSQL subclass based on the provided parameters """ + # When support for DBAPI connections is removed, + # is_cursor should not be necessary. try: import sqlalchemy if isinstance(con, sqlalchemy.engine.Engine): return PandasSQLAlchemy(con, meta=meta) else: - warnings.warn( - """Not an SQLAlchemy engine, - attempting to use as legacy DBAPI connection""") + warnings.warn("Not an SQLAlchemy engine, " + "attempting to use as legacy DBAPI connection") if flavor is None: raise ValueError( - """PandasSQL must be created with an SQLAlchemy engine - or a DBAPI2 connection and SQL flavour""") + "PandasSQL must be created with an SQLAlchemy engine " + "or a DBAPI2 connection and SQL flavor") else: - return PandasSQLLegacy(con, flavor) + return PandasSQLLegacy(con, flavor, is_cursor=is_cursor) except ImportError: warnings.warn("SQLAlchemy not installed, using legacy mode") if flavor is None: raise SQLAlchemyRequired else: - return PandasSQLLegacy(con, flavor) + return PandasSQLLegacy(con, flavor, is_cursor=is_cursor) class PandasSQLTable(PandasObject): @@ -471,6 +536,9 @@ def __init__(self, name, pandas_sql_engine, frame=None, index=True, self.table = self.pd_sql.get_table(self.name) if self.table is None: self.table = self._create_table_statement() + else: + raise ValueError( + "'{0}' is not valid for if_exists".format(if_exists)) else: self.table = self._create_table_statement() self.create() @@ -485,7 +553,8 @@ def exists(self): return self.pd_sql.has_table(self.name) def sql_schema(self): - return str(self.table.compile()) + from sqlalchemy.schema import CreateTable + return str(CreateTable(self.table)) def create(self): self.table.create() @@ -722,14 +791,6 @@ def execute(self, *args, **kwargs): """Simple passthrough to SQLAlchemy engine""" return self.engine.execute(*args, **kwargs) - def tquery(self, *args, **kwargs): - result = self.execute(*args, **kwargs) - return result.fetchall() - - def uquery(self, *args, **kwargs): - result = self.execute(*args, **kwargs) - return result.rowcount - def read_table(self, table_name, index_col=None, coerce_float=True, parse_dates=None, columns=None): @@ -783,7 +844,7 @@ def drop_table(self, table_name): def _create_sql_schema(self, frame, table_name): table = PandasSQLTable(table_name, self, frame=frame) - return str(table.compile()) + return str(table.sql_schema()) # ---- SQL without SQLAlchemy --- @@ -927,7 +988,8 @@ def _sql_type_name(self, dtype): class PandasSQLLegacy(PandasSQL): - def __init__(self, con, flavor): + def __init__(self, con, flavor, is_cursor=False): + self.is_cursor = is_cursor self.con = con if flavor not in ['sqlite', 'mysql']: raise NotImplementedError @@ -935,8 +997,11 @@ def __init__(self, con, flavor): self.flavor = flavor def execute(self, *args, **kwargs): - try: + if self.is_cursor: + cur = self.con + else: cur = self.con.cursor() + try: if kwargs: cur.execute(*args, **kwargs) else: @@ -953,22 +1018,6 @@ def execute(self, *args, **kwargs): ex = DatabaseError("Execution failed on sql: %s" % args[0]) raise_with_traceback(ex) - def tquery(self, *args): - cur = self.execute(*args) - result = self._fetchall_as_list(cur) - - # This makes into tuples - if result and len(result[0]) == 1: - # python 3 compat - result = list(lzip(*result)[0]) - elif result is None: # pragma: no cover - result = [] - return result - - def uquery(self, *args): - cur = self.execute(*args) - return cur.rowcount - def read_sql(self, sql, index_col=None, coerce_float=True, params=None, parse_dates=None): args = _convert_params(sql, params) @@ -1006,7 +1055,7 @@ def to_sql(self, frame, name, if_exists='fail', index=True, fail: If table exists, do nothing. replace: If table exists, drop it, recreate it, and insert data. append: If table exists, insert data. Create if does not exist. - + """ table = PandasSQLTableLegacy( name, self, frame=frame, index=index, if_exists=if_exists, @@ -1020,7 +1069,7 @@ def has_table(self, name): 'mysql': "SHOW TABLES LIKE '%s'" % name} query = flavor_map.get(self.flavor) - return len(self.tquery(query)) > 0 + return len(self.execute(query).fetchall()) > 0 def get_table(self, table_name): return None # not supported in Legacy mode @@ -1029,32 +1078,90 @@ def drop_table(self, name): drop_sql = "DROP TABLE %s" % name self.execute(drop_sql) + def _create_sql_schema(self, frame, table_name): + table = PandasSQLTableLegacy(table_name, self, frame=frame) + return str(table.sql_schema()) -# legacy names, with depreciation warnings and copied docs -def get_schema(frame, name, con, flavor='sqlite'): + +def get_schema(frame, name, flavor='sqlite', keys=None, con=None): """ Get the SQL db table schema for the given frame Parameters ---------- - frame: DataFrame - name: name of SQL table - con: an open SQL database connection object - engine: an SQLAlchemy engine - replaces connection and flavor - flavor: {'sqlite', 'mysql', 'postgres'}, default 'sqlite' + frame : DataFrame + name : name of SQL table + flavor : {'sqlite', 'mysql'}, default 'sqlite' + keys : columns to use a primary key + con: an open SQL database connection object or an SQLAlchemy engine """ - warnings.warn( - "get_schema is depreciated", DeprecationWarning) + + if con is None: + return _get_schema_legacy(frame, name, flavor, keys) + pandas_sql = pandasSQL_builder(con=con, flavor=flavor) return pandas_sql._create_sql_schema(frame, name) +def _get_schema_legacy(frame, name, flavor, keys=None): + """Old function from 0.13.1. To keep backwards compatibility. + When mysql legacy support is dropped, it should be possible to + remove this code + """ + + def get_sqltype(dtype, flavor): + pytype = dtype.type + pytype_name = "text" + if issubclass(pytype, np.floating): + pytype_name = "float" + elif issubclass(pytype, np.integer): + pytype_name = "int" + elif issubclass(pytype, np.datetime64) or pytype is datetime: + # Caution: np.datetime64 is also a subclass of np.number. + pytype_name = "datetime" + elif pytype is datetime.date: + pytype_name = "date" + elif issubclass(pytype, np.bool_): + pytype_name = "bool" + + return _SQL_TYPES[pytype_name][flavor] + + lookup_type = lambda dtype: get_sqltype(dtype, flavor) + + column_types = lzip(frame.dtypes.index, map(lookup_type, frame.dtypes)) + if flavor == 'sqlite': + columns = ',\n '.join('[%s] %s' % x for x in column_types) + else: + columns = ',\n '.join('`%s` %s' % x for x in column_types) + + keystr = '' + if keys is not None: + if isinstance(keys, string_types): + keys = (keys,) + keystr = ', PRIMARY KEY (%s)' % ','.join(keys) + template = """CREATE TABLE %(name)s ( + %(columns)s + %(keystr)s + );""" + create_statement = template % {'name': name, 'columns': columns, + 'keystr': keystr} + return create_statement + + +# legacy names, with depreciation warnings and copied docs + def read_frame(*args, **kwargs): """DEPRECIATED - use read_sql """ - warnings.warn( - "read_frame is depreciated, use read_sql", DeprecationWarning) + warnings.warn("read_frame is depreciated, use read_sql", FutureWarning) + return read_sql(*args, **kwargs) + + +def frame_query(*args, **kwargs): + """DEPRECIATED - use read_sql + """ + warnings.warn("frame_query is depreciated, use read_sql", FutureWarning) return read_sql(*args, **kwargs) @@ -1092,7 +1199,7 @@ def write_frame(frame, name, con, flavor='sqlite', if_exists='fail', **kwargs): pandas.DataFrame.to_sql """ - warnings.warn("write_frame is depreciated, use to_sql", DeprecationWarning) + warnings.warn("write_frame is depreciated, use to_sql", FutureWarning) # for backwards compatibility, set index=False when not specified index = kwargs.pop('index', False) @@ -1102,3 +1209,4 @@ def write_frame(frame, name, con, flavor='sqlite', if_exists='fail', **kwargs): # Append wrapped function docstrings read_frame.__doc__ += read_sql.__doc__ +frame_query.__doc__ += read_sql.__doc__ diff --git a/pandas/io/tests/test_sql.py b/pandas/io/tests/test_sql.py index ad3fa57ab48a7..9a34e84c153a0 100644 --- a/pandas/io/tests/test_sql.py +++ b/pandas/io/tests/test_sql.py @@ -20,13 +20,18 @@ import sqlite3 import csv import os +import sys import nose +import warnings import numpy as np -from pandas import DataFrame, Series, MultiIndex -from pandas.compat import range -#from pandas.core.datetools import format as date_format +from datetime import datetime + +from pandas import DataFrame, Series, Index, MultiIndex, isnull +import pandas.compat as compat +from pandas.compat import StringIO, range, lrange +from pandas.core.datetools import format as date_format import pandas.io.sql as sql import pandas.util.testing as tm @@ -296,11 +301,6 @@ def _execute_sql(self): row = iris_results.fetchone() tm.equalContents(row, [5.1, 3.5, 1.4, 0.2, 'Iris-setosa']) - def _tquery(self): - iris_results = self.pandasSQL.tquery("SELECT * FROM iris") - row = iris_results[0] - tm.equalContents(row, [5.1, 3.5, 1.4, 0.2, 'Iris-setosa']) - #------------------------------------------------------------------------------ #--- Testing the public API @@ -336,8 +336,9 @@ def test_read_sql_iris(self): self._check_iris_loaded_frame(iris_frame) def test_legacy_read_frame(self): - iris_frame = sql.read_frame( - "SELECT * FROM iris", self.conn, flavor='sqlite') + with tm.assert_produces_warning(FutureWarning): + iris_frame = sql.read_frame( + "SELECT * FROM iris", self.conn, flavor='sqlite') self._check_iris_loaded_frame(iris_frame) def test_to_sql(self): @@ -402,8 +403,10 @@ def test_to_sql_panel(self): def test_legacy_write_frame(self): # Assume that functionality is already tested above so just do # quick check that it basically works - sql.write_frame(self.test_frame1, 'test_frame_legacy', self.conn, - flavor='sqlite') + with tm.assert_produces_warning(FutureWarning): + sql.write_frame(self.test_frame1, 'test_frame_legacy', self.conn, + flavor='sqlite') + self.assertTrue( sql.has_table('test_frame_legacy', self.conn, flavor='sqlite'), 'Table not written to DB') @@ -430,12 +433,6 @@ def test_execute_sql(self): row = iris_results.fetchone() tm.equalContents(row, [5.1, 3.5, 1.4, 0.2, 'Iris-setosa']) - def test_tquery(self): - iris_results = sql.tquery( - "SELECT * FROM iris", con=self.conn, flavor='sqlite') - row = iris_results[0] - tm.equalContents(row, [5.1, 3.5, 1.4, 0.2, 'Iris-setosa']) - def test_date_parsing(self): # Test date parsing in read_sq # No Parsing @@ -555,6 +552,11 @@ def test_integer_col_names(self): sql.to_sql(df, "test_frame_integer_col_names", self.conn, if_exists='replace') + def test_get_schema(self): + create_sql = sql.get_schema(self.test_frame1, 'test', 'sqlite', + con=self.conn) + self.assert_('CREATE' in create_sql) + class TestSQLApi(_TestSQLApi): """ @@ -674,6 +676,22 @@ def test_safe_names_warning(self): sql.to_sql(df, "test_frame3_legacy", self.conn, flavor="sqlite", index=False) + def test_get_schema2(self): + # without providing a connection object (available for backwards comp) + create_sql = sql.get_schema(self.test_frame1, 'test', 'sqlite') + self.assert_('CREATE' in create_sql) + + def test_tquery(self): + with tm.assert_produces_warning(FutureWarning): + iris_results = sql.tquery("SELECT * FROM iris", con=self.conn) + row = iris_results[0] + tm.equalContents(row, [5.1, 3.5, 1.4, 0.2, 'Iris-setosa']) + + def test_uquery(self): + with tm.assert_produces_warning(FutureWarning): + rows = sql.uquery("SELECT * FROM iris LIMIT 1", con=self.conn) + self.assertEqual(rows, -1) + #------------------------------------------------------------------------------ #--- Database flavor specific tests @@ -1043,9 +1061,6 @@ def test_roundtrip(self): def test_execute_sql(self): self._execute_sql() - def test_tquery(self): - self._tquery() - class TestMySQLLegacy(TestSQLiteLegacy): """ @@ -1095,6 +1110,598 @@ def tearDown(self): self.conn.close() +#------------------------------------------------------------------------------ +#--- Old tests from 0.13.1 (before refactor using sqlalchemy) + + +_formatters = { + datetime: lambda dt: "'%s'" % date_format(dt), + str: lambda x: "'%s'" % x, + np.str_: lambda x: "'%s'" % x, + compat.text_type: lambda x: "'%s'" % x, + compat.binary_type: lambda x: "'%s'" % x, + float: lambda x: "%.8f" % x, + int: lambda x: "%s" % x, + type(None): lambda x: "NULL", + np.float64: lambda x: "%.10f" % x, + bool: lambda x: "'%s'" % x, +} + +def format_query(sql, *args): + """ + + """ + processed_args = [] + for arg in args: + if isinstance(arg, float) and isnull(arg): + arg = None + + formatter = _formatters[type(arg)] + processed_args.append(formatter(arg)) + + return sql % tuple(processed_args) + +def _skip_if_no_pymysql(): + try: + import pymysql + except ImportError: + raise nose.SkipTest('pymysql not installed, skipping') + + +class TestXSQLite(tm.TestCase): + + def setUp(self): + self.db = sqlite3.connect(':memory:') + + def test_basic(self): + frame = tm.makeTimeDataFrame() + self._check_roundtrip(frame) + + def test_write_row_by_row(self): + + frame = tm.makeTimeDataFrame() + frame.ix[0, 0] = np.nan + create_sql = sql.get_schema(frame, 'test', 'sqlite') + cur = self.db.cursor() + cur.execute(create_sql) + + cur = self.db.cursor() + + ins = "INSERT INTO test VALUES (%s, %s, %s, %s)" + for idx, row in frame.iterrows(): + fmt_sql = format_query(ins, *row) + sql.tquery(fmt_sql, cur=cur) + + self.db.commit() + + result = sql.read_frame("select * from test", con=self.db) + result.index = frame.index + tm.assert_frame_equal(result, frame) + + def test_execute(self): + frame = tm.makeTimeDataFrame() + create_sql = sql.get_schema(frame, 'test', 'sqlite') + cur = self.db.cursor() + cur.execute(create_sql) + ins = "INSERT INTO test VALUES (?, ?, ?, ?)" + + row = frame.ix[0] + sql.execute(ins, self.db, params=tuple(row)) + self.db.commit() + + result = sql.read_frame("select * from test", self.db) + result.index = frame.index[:1] + tm.assert_frame_equal(result, frame[:1]) + + def test_schema(self): + frame = tm.makeTimeDataFrame() + create_sql = sql.get_schema(frame, 'test', 'sqlite') + lines = create_sql.splitlines() + for l in lines: + tokens = l.split(' ') + if len(tokens) == 2 and tokens[0] == 'A': + self.assert_(tokens[1] == 'DATETIME') + + frame = tm.makeTimeDataFrame() + create_sql = sql.get_schema(frame, 'test', 'sqlite', keys=['A', 'B'],) + lines = create_sql.splitlines() + self.assert_('PRIMARY KEY (A,B)' in create_sql) + cur = self.db.cursor() + cur.execute(create_sql) + + def test_execute_fail(self): + create_sql = """ + CREATE TABLE test + ( + a TEXT, + b TEXT, + c REAL, + PRIMARY KEY (a, b) + ); + """ + cur = self.db.cursor() + cur.execute(create_sql) + + sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)', self.db) + sql.execute('INSERT INTO test VALUES("foo", "baz", 2.567)', self.db) + + try: + sys.stdout = StringIO() + self.assertRaises(Exception, sql.execute, + 'INSERT INTO test VALUES("foo", "bar", 7)', + self.db) + finally: + sys.stdout = sys.__stdout__ + + def test_execute_closed_connection(self): + create_sql = """ + CREATE TABLE test + ( + a TEXT, + b TEXT, + c REAL, + PRIMARY KEY (a, b) + ); + """ + cur = self.db.cursor() + cur.execute(create_sql) + + sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)', self.db) + self.db.close() + try: + sys.stdout = StringIO() + self.assertRaises(Exception, sql.tquery, "select * from test", + con=self.db) + finally: + sys.stdout = sys.__stdout__ + + def test_na_roundtrip(self): + pass + + def _check_roundtrip(self, frame): + sql.write_frame(frame, name='test_table', con=self.db) + result = sql.read_frame("select * from test_table", self.db) + + # HACK! Change this once indexes are handled properly. + result.index = frame.index + + expected = frame + tm.assert_frame_equal(result, expected) + + frame['txt'] = ['a'] * len(frame) + frame2 = frame.copy() + frame2['Idx'] = Index(lrange(len(frame2))) + 10 + sql.write_frame(frame2, name='test_table2', con=self.db) + result = sql.read_frame("select * from test_table2", self.db, + index_col='Idx') + expected = frame.copy() + expected.index = Index(lrange(len(frame2))) + 10 + expected.index.name = 'Idx' + print(expected.index.names) + print(result.index.names) + tm.assert_frame_equal(expected, result) + + def test_tquery(self): + frame = tm.makeTimeDataFrame() + sql.write_frame(frame, name='test_table', con=self.db) + result = sql.tquery("select A from test_table", self.db) + expected = frame.A + result = Series(result, frame.index) + tm.assert_series_equal(result, expected) + + try: + sys.stdout = StringIO() + self.assertRaises(sql.DatabaseError, sql.tquery, + 'select * from blah', con=self.db) + + self.assertRaises(sql.DatabaseError, sql.tquery, + 'select * from blah', con=self.db, retry=True) + finally: + sys.stdout = sys.__stdout__ + + def test_uquery(self): + frame = tm.makeTimeDataFrame() + sql.write_frame(frame, name='test_table', con=self.db) + stmt = 'INSERT INTO test_table VALUES(2.314, -123.1, 1.234, 2.3)' + self.assertEqual(sql.uquery(stmt, con=self.db), 1) + + try: + sys.stdout = StringIO() + + self.assertRaises(sql.DatabaseError, sql.tquery, + 'insert into blah values (1)', con=self.db) + + self.assertRaises(sql.DatabaseError, sql.tquery, + 'insert into blah values (1)', con=self.db, + retry=True) + finally: + sys.stdout = sys.__stdout__ + + def test_keyword_as_column_names(self): + ''' + ''' + df = DataFrame({'From':np.ones(5)}) + sql.write_frame(df, con = self.db, name = 'testkeywords') + + def test_onecolumn_of_integer(self): + # GH 3628 + # a column_of_integers dataframe should transfer well to sql + + mono_df=DataFrame([1 , 2], columns=['c0']) + sql.write_frame(mono_df, con = self.db, name = 'mono_df') + # computing the sum via sql + con_x=self.db + the_sum=sum([my_c0[0] for my_c0 in con_x.execute("select * from mono_df")]) + # it should not fail, and gives 3 ( Issue #3628 ) + self.assertEqual(the_sum , 3) + + result = sql.read_frame("select * from mono_df",con_x) + tm.assert_frame_equal(result,mono_df) + + def test_if_exists(self): + df_if_exists_1 = DataFrame({'col1': [1, 2], 'col2': ['A', 'B']}) + df_if_exists_2 = DataFrame({'col1': [3, 4, 5], 'col2': ['C', 'D', 'E']}) + table_name = 'table_if_exists' + sql_select = "SELECT * FROM %s" % table_name + + def clean_up(test_table_to_drop): + """ + Drops tables created from individual tests + so no dependencies arise from sequential tests + """ + if sql.table_exists(test_table_to_drop, self.db, flavor='sqlite'): + cur = self.db.cursor() + cur.execute("DROP TABLE %s" % test_table_to_drop) + cur.close() + + # test if invalid value for if_exists raises appropriate error + self.assertRaises(ValueError, + sql.write_frame, + frame=df_if_exists_1, + con=self.db, + name=table_name, + flavor='sqlite', + if_exists='notvalidvalue') + clean_up(table_name) + + # test if_exists='fail' + sql.write_frame(frame=df_if_exists_1, con=self.db, name=table_name, + flavor='sqlite', if_exists='fail') + self.assertRaises(ValueError, + sql.write_frame, + frame=df_if_exists_1, + con=self.db, + name=table_name, + flavor='sqlite', + if_exists='fail') + + # test if_exists='replace' + sql.write_frame(frame=df_if_exists_1, con=self.db, name=table_name, + flavor='sqlite', if_exists='replace') + self.assertEqual(sql.tquery(sql_select, con=self.db), + [(1, 'A'), (2, 'B')]) + sql.write_frame(frame=df_if_exists_2, con=self.db, name=table_name, + flavor='sqlite', if_exists='replace') + self.assertEqual(sql.tquery(sql_select, con=self.db), + [(3, 'C'), (4, 'D'), (5, 'E')]) + clean_up(table_name) + + # test if_exists='append' + sql.write_frame(frame=df_if_exists_1, con=self.db, name=table_name, + flavor='sqlite', if_exists='fail') + self.assertEqual(sql.tquery(sql_select, con=self.db), + [(1, 'A'), (2, 'B')]) + sql.write_frame(frame=df_if_exists_2, con=self.db, name=table_name, + flavor='sqlite', if_exists='append') + self.assertEqual(sql.tquery(sql_select, con=self.db), + [(1, 'A'), (2, 'B'), (3, 'C'), (4, 'D'), (5, 'E')]) + clean_up(table_name) + + +class TestXMySQL(tm.TestCase): + + def setUp(self): + _skip_if_no_pymysql() + import pymysql + try: + # Try Travis defaults. + # No real user should allow root access with a blank password. + self.db = pymysql.connect(host='localhost', user='root', passwd='', + db='pandas_nosetest') + except: + pass + else: + return + try: + self.db = pymysql.connect(read_default_group='pandas') + except pymysql.ProgrammingError as e: + raise nose.SkipTest( + "Create a group of connection parameters under the heading " + "[pandas] in your system's mysql default file, " + "typically located at ~/.my.cnf or /etc/.my.cnf. ") + except pymysql.Error as e: + raise nose.SkipTest( + "Cannot connect to database. " + "Create a group of connection parameters under the heading " + "[pandas] in your system's mysql default file, " + "typically located at ~/.my.cnf or /etc/.my.cnf. ") + + def test_basic(self): + _skip_if_no_pymysql() + frame = tm.makeTimeDataFrame() + self._check_roundtrip(frame) + + def test_write_row_by_row(self): + + _skip_if_no_pymysql() + frame = tm.makeTimeDataFrame() + frame.ix[0, 0] = np.nan + drop_sql = "DROP TABLE IF EXISTS test" + create_sql = sql.get_schema(frame, 'test', 'mysql') + cur = self.db.cursor() + cur.execute(drop_sql) + cur.execute(create_sql) + ins = "INSERT INTO test VALUES (%s, %s, %s, %s)" + for idx, row in frame.iterrows(): + fmt_sql = format_query(ins, *row) + sql.tquery(fmt_sql, cur=cur) + + self.db.commit() + + result = sql.read_frame("select * from test", con=self.db) + result.index = frame.index + tm.assert_frame_equal(result, frame) + + def test_execute(self): + _skip_if_no_pymysql() + frame = tm.makeTimeDataFrame() + drop_sql = "DROP TABLE IF EXISTS test" + create_sql = sql.get_schema(frame, 'test', 'mysql') + cur = self.db.cursor() + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "Unknown table.*") + cur.execute(drop_sql) + cur.execute(create_sql) + ins = "INSERT INTO test VALUES (%s, %s, %s, %s)" + + row = frame.ix[0].values.tolist() + sql.execute(ins, self.db, params=tuple(row)) + self.db.commit() + + result = sql.read_frame("select * from test", self.db) + result.index = frame.index[:1] + tm.assert_frame_equal(result, frame[:1]) + + def test_schema(self): + _skip_if_no_pymysql() + frame = tm.makeTimeDataFrame() + create_sql = sql.get_schema(frame, 'test', 'mysql') + lines = create_sql.splitlines() + for l in lines: + tokens = l.split(' ') + if len(tokens) == 2 and tokens[0] == 'A': + self.assert_(tokens[1] == 'DATETIME') + + frame = tm.makeTimeDataFrame() + drop_sql = "DROP TABLE IF EXISTS test" + create_sql = sql.get_schema(frame, 'test', 'mysql', keys=['A', 'B'],) + lines = create_sql.splitlines() + self.assert_('PRIMARY KEY (A,B)' in create_sql) + cur = self.db.cursor() + cur.execute(drop_sql) + cur.execute(create_sql) + + def test_execute_fail(self): + _skip_if_no_pymysql() + drop_sql = "DROP TABLE IF EXISTS test" + create_sql = """ + CREATE TABLE test + ( + a TEXT, + b TEXT, + c REAL, + PRIMARY KEY (a(5), b(5)) + ); + """ + cur = self.db.cursor() + cur.execute(drop_sql) + cur.execute(create_sql) + + sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)', self.db) + sql.execute('INSERT INTO test VALUES("foo", "baz", 2.567)', self.db) + + try: + sys.stdout = StringIO() + self.assertRaises(Exception, sql.execute, + 'INSERT INTO test VALUES("foo", "bar", 7)', + self.db) + finally: + sys.stdout = sys.__stdout__ + + def test_execute_closed_connection(self): + _skip_if_no_pymysql() + drop_sql = "DROP TABLE IF EXISTS test" + create_sql = """ + CREATE TABLE test + ( + a TEXT, + b TEXT, + c REAL, + PRIMARY KEY (a(5), b(5)) + ); + """ + cur = self.db.cursor() + cur.execute(drop_sql) + cur.execute(create_sql) + + sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)', self.db) + self.db.close() + try: + sys.stdout = StringIO() + self.assertRaises(Exception, sql.tquery, "select * from test", + con=self.db) + finally: + sys.stdout = sys.__stdout__ + + def test_na_roundtrip(self): + _skip_if_no_pymysql() + pass + + def _check_roundtrip(self, frame): + _skip_if_no_pymysql() + drop_sql = "DROP TABLE IF EXISTS test_table" + cur = self.db.cursor() + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "Unknown table.*") + cur.execute(drop_sql) + sql.write_frame(frame, name='test_table', con=self.db, flavor='mysql') + result = sql.read_frame("select * from test_table", self.db) + + # HACK! Change this once indexes are handled properly. + result.index = frame.index + result.index.name = frame.index.name + + expected = frame + tm.assert_frame_equal(result, expected) + + frame['txt'] = ['a'] * len(frame) + frame2 = frame.copy() + index = Index(lrange(len(frame2))) + 10 + frame2['Idx'] = index + drop_sql = "DROP TABLE IF EXISTS test_table2" + cur = self.db.cursor() + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "Unknown table.*") + cur.execute(drop_sql) + sql.write_frame(frame2, name='test_table2', con=self.db, flavor='mysql') + result = sql.read_frame("select * from test_table2", self.db, + index_col='Idx') + expected = frame.copy() + + # HACK! Change this once indexes are handled properly. + expected.index = index + expected.index.names = result.index.names + tm.assert_frame_equal(expected, result) + + def test_tquery(self): + try: + import pymysql + except ImportError: + raise nose.SkipTest("no pymysql") + frame = tm.makeTimeDataFrame() + drop_sql = "DROP TABLE IF EXISTS test_table" + cur = self.db.cursor() + cur.execute(drop_sql) + sql.write_frame(frame, name='test_table', con=self.db, flavor='mysql') + result = sql.tquery("select A from test_table", self.db) + expected = frame.A + result = Series(result, frame.index) + tm.assert_series_equal(result, expected) + + try: + sys.stdout = StringIO() + self.assertRaises(sql.DatabaseError, sql.tquery, + 'select * from blah', con=self.db) + + self.assertRaises(sql.DatabaseError, sql.tquery, + 'select * from blah', con=self.db, retry=True) + finally: + sys.stdout = sys.__stdout__ + + def test_uquery(self): + try: + import pymysql + except ImportError: + raise nose.SkipTest("no pymysql") + frame = tm.makeTimeDataFrame() + drop_sql = "DROP TABLE IF EXISTS test_table" + cur = self.db.cursor() + cur.execute(drop_sql) + sql.write_frame(frame, name='test_table', con=self.db, flavor='mysql') + stmt = 'INSERT INTO test_table VALUES(2.314, -123.1, 1.234, 2.3)' + self.assertEqual(sql.uquery(stmt, con=self.db), 1) + + try: + sys.stdout = StringIO() + + self.assertRaises(sql.DatabaseError, sql.tquery, + 'insert into blah values (1)', con=self.db) + + self.assertRaises(sql.DatabaseError, sql.tquery, + 'insert into blah values (1)', con=self.db, + retry=True) + finally: + sys.stdout = sys.__stdout__ + + def test_keyword_as_column_names(self): + ''' + ''' + _skip_if_no_pymysql() + df = DataFrame({'From':np.ones(5)}) + sql.write_frame(df, con = self.db, name = 'testkeywords', + if_exists='replace', flavor='mysql') + + def test_if_exists(self): + _skip_if_no_pymysql() + df_if_exists_1 = DataFrame({'col1': [1, 2], 'col2': ['A', 'B']}) + df_if_exists_2 = DataFrame({'col1': [3, 4, 5], 'col2': ['C', 'D', 'E']}) + table_name = 'table_if_exists' + sql_select = "SELECT * FROM %s" % table_name + + def clean_up(test_table_to_drop): + """ + Drops tables created from individual tests + so no dependencies arise from sequential tests + """ + if sql.table_exists(test_table_to_drop, self.db, flavor='mysql'): + cur = self.db.cursor() + cur.execute("DROP TABLE %s" % test_table_to_drop) + cur.close() + + # test if invalid value for if_exists raises appropriate error + self.assertRaises(ValueError, + sql.write_frame, + frame=df_if_exists_1, + con=self.db, + name=table_name, + flavor='mysql', + if_exists='notvalidvalue') + clean_up(table_name) + + # test if_exists='fail' + sql.write_frame(frame=df_if_exists_1, con=self.db, name=table_name, + flavor='mysql', if_exists='fail') + self.assertRaises(ValueError, + sql.write_frame, + frame=df_if_exists_1, + con=self.db, + name=table_name, + flavor='mysql', + if_exists='fail') + + # test if_exists='replace' + sql.write_frame(frame=df_if_exists_1, con=self.db, name=table_name, + flavor='mysql', if_exists='replace') + self.assertEqual(sql.tquery(sql_select, con=self.db), + [(1, 'A'), (2, 'B')]) + sql.write_frame(frame=df_if_exists_2, con=self.db, name=table_name, + flavor='mysql', if_exists='replace') + self.assertEqual(sql.tquery(sql_select, con=self.db), + [(3, 'C'), (4, 'D'), (5, 'E')]) + clean_up(table_name) + + # test if_exists='append' + sql.write_frame(frame=df_if_exists_1, con=self.db, name=table_name, + flavor='mysql', if_exists='fail') + self.assertEqual(sql.tquery(sql_select, con=self.db), + [(1, 'A'), (2, 'B')]) + sql.write_frame(frame=df_if_exists_2, con=self.db, name=table_name, + flavor='mysql', if_exists='append') + self.assertEqual(sql.tquery(sql_select, con=self.db), + [(1, 'A'), (2, 'B'), (3, 'C'), (4, 'D'), (5, 'E')]) + clean_up(table_name) + + if __name__ == '__main__': nose.runmodule(argv=[__file__, '-vvs', '-x', '--pdb', '--pdb-failure'], exit=False)