From 74d091f95ec293c195bd0a5e4886cf5891df2489 Mon Sep 17 00:00:00 2001 From: Dan Allan Date: Thu, 11 Jul 2013 10:01:37 -0400 Subject: [PATCH] ENH #4163 Use SQLAlchemy for DB abstraction TST Import sqlalchemy on Travis. DOC add docstrings to read sql ENH read_sql connects via Connection, Engine, file path, or :memory: string CLN Separate legacy code into new file, and fallback so that all old tests pass. TST to use sqlalchemy syntax in tests CLN sql into classes, legacy passes FIX few engine vs con calls CLN pep8 cleanup add postgres support for pandas.io.sql.get_schema WIP: cleanup of sql io module - imported correct SQLALCHEMY type, delete redundant PandasSQLWithCon TODO: renamed _engine_read_table, need to think of a better name. TODO: clean up get_connection function ENH: cleanup of SQL io TODO: check that legacy mode works TODO: run tests correctly enabled coerce_float option Cleanup and bug-fixing mainly on legacy mode sql. IMPORTANT - changed legacy to require connection rather than cursor. This is still not yet finalized. TODO: tests and doc Added Test coverage for basic functionality using in-memory SQLite database Simplified API by automatically distinguishing between engine and connection. 
Added warnings --- pandas/io/sql.py | 790 ++++++++++++++++++++--------- pandas/io/sql_legacy.py | 332 ++++++++++++ pandas/io/tests/test_sql.py | 757 ++++++++++++++------------- pandas/io/tests/test_sql_legacy.py | 497 ++++++++++++++++++ 4 files changed, 1756 insertions(+), 620 deletions(-) create mode 100644 pandas/io/sql_legacy.py create mode 100644 pandas/io/tests/test_sql_legacy.py diff --git a/pandas/io/sql.py b/pandas/io/sql.py index 5e83a0921189b..25b04a34ba6e4 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -4,20 +4,40 @@ """ from __future__ import print_function from datetime import datetime, date - -from pandas.compat import range, lzip, map, zip +import warnings +from pandas.compat import range, lzip, map, zip, raise_with_traceback import pandas.compat as compat import numpy as np -import traceback -from pandas.core.datetools import format as date_format + from pandas.core.api import DataFrame +from pandas.core.base import PandasObject + + +class SQLAlchemyRequired(ImportError): + pass + + +class LegacyMySQLConnection(Exception): + pass + + +class DatabaseError(IOError): + pass + #------------------------------------------------------------------------------ -# Helper execution function +# Helper execution functions + +def _convert_params(sql, params): + """convert sql and params args to DBAPI2.0 compliant format""" + args = [sql] + if params is not None: + args += list(params) + return args -def execute(sql, con, retry=True, cur=None, params=None): +def execute(sql, con, cur=None, params=[], engine=None, flavor='sqlite'): """ Execute the given SQL query using the provided connection object. @@ -25,52 +45,25 @@ def execute(sql, con, retry=True, cur=None, params=None): ---------- sql: string Query to be executed - con: database connection instance - Database connection. Must implement PEP249 (Database API v2.0). - retry: bool - Not currently implemented - cur: database cursor, optional - Must implement PEP249 (Datbase API v2.0). 
If cursor is not provided, - one will be obtained from the database connection. + con: SQLAlchemy engine or DBAPI2 connection (legacy mode) + Using SQLAlchemy makes it possible to use any DB supported by that + library. + If a DBAPI2 object is given, a supported SQL flavor must also be provided + cur: depreciated, cursor is obtained from connection params: list or tuple, optional List of parameters to pass to execute method. - + flavor : string {sqlite, mysql} specifying the flavor of SQL to use. + Ignored when using SQLAlchemy engine. Required when using DBAPI2 connection. Returns ------- - Cursor object + Results Iterable """ - try: - if cur is None: - cur = con.cursor() - - if params is None: - cur.execute(sql) - else: - cur.execute(sql, params) - return cur - except Exception: - try: - con.rollback() - except Exception: # pragma: no cover - pass - - print('Error on sql %s' % sql) - raise + pandas_sql = pandasSQL_builder(con=con, flavor=flavor) + args = _convert_params(sql, params) + return pandas_sql.execute(*args) -def _safe_fetch(cur): - try: - result = cur.fetchall() - if not isinstance(result, list): - result = list(result) - return result - except Exception as e: # pragma: no cover - excName = e.__class__.__name__ - if excName == 'OperationalError': - return [] - - -def tquery(sql, con=None, cur=None, retry=True): +def tquery(sql, con, cur=None, params=[], engine=None, flavor='sqlite'): """ Returns list of tuples corresponding to each row in given sql query. @@ -81,62 +74,50 @@ def tquery(sql, con=None, cur=None, retry=True): ---------- sql: string SQL query to be executed - con: SQLConnection or DB API 2.0-compliant connection - cur: DB API 2.0 cursor - - Provide a specific connection or a specific cursor if you are executing a - lot of sequential statements and want to commit outside. + con: SQLAlchemy engine or DBAPI2 connection (legacy mode) + Using SQLAlchemy makes it possible to use any DB supported by that + library. 
+ If a DBAPI2 object is given, a supported SQL flavor must also be provided + cur: depreciated, cursor is obtained from connection + params: list or tuple, optional + List of parameters to pass to execute method. + flavor : string {sqlite, mysql} specifying the flavor of SQL to use. + Ignored when using SQLAlchemy engine. Required when using DBAPI2 connection. """ - cur = execute(sql, con, cur=cur) - result = _safe_fetch(cur) + pandas_sql = pandasSQL_builder(con=con, flavor=flavor) + args = _convert_params(sql, params) + return pandas_sql.tquery(*args) - if con is not None: - try: - cur.close() - con.commit() - except Exception as e: - excName = e.__class__.__name__ - if excName == 'OperationalError': # pragma: no cover - print('Failed to commit, may need to restart interpreter') - else: - raise - - traceback.print_exc() - if retry: - return tquery(sql, con=con, retry=False) - - if result and len(result[0]) == 1: - # python 3 compat - result = list(lzip(*result)[0]) - elif result is None: # pragma: no cover - result = [] - - return result - -def uquery(sql, con=None, cur=None, retry=True, params=None): +def uquery(sql, con, cur=None, params=[], engine=None, flavor='sqlite'): """ Does the same thing as tquery, but instead of returning results, it returns the number of rows affected. Good for update queries. + + Parameters + ---------- + sql: string + SQL query to be executed + con: SQLAlchemy engine or DBAPI2 connection (legacy mode) + Using SQLAlchemy makes it possible to use any DB supported by that + library. + If a DBAPI2 object is given, a supported SQL flavor must also be provided + cur: depreciated, cursor is obtained from connection + params: list or tuple, optional + List of parameters to pass to execute method. + flavor : string {sqlite, mysql} specifying the flavor of SQL to use. + Ignored when using SQLAlchemy engine. Required when using DBAPI2 connection. 
""" - cur = execute(sql, con, cur=cur, retry=retry, params=params) + pandas_sql = pandasSQL_builder(con=con, flavor=flavor) + args = _convert_params(sql, params) + return pandas_sql.uquery(*args) - result = cur.rowcount - try: - con.commit() - except Exception as e: - excName = e.__class__.__name__ - if excName != 'OperationalError': - raise - traceback.print_exc() - if retry: - print('Looks like your connection failed, reconnecting...') - return uquery(sql, con, retry=False) - return result +#------------------------------------------------------------------------------ +# Read and write to DataFrames -def read_frame(sql, con, index_col=None, coerce_float=True, params=None): +def read_sql(sql, con, index_col=None, flavor='sqlite', coerce_float=True, params=[]): """ Returns a DataFrame corresponding to the result set of the query string. @@ -148,35 +129,30 @@ def read_frame(sql, con, index_col=None, coerce_float=True, params=None): ---------- sql: string SQL query to be executed - con: DB connection object, optional + con: SQLAlchemy engine or DBAPI2 connection (legacy mode) + Using SQLAlchemy makes it possible to use any DB supported by that + library. + If a DBAPI2 object is given, a supported SQL flavor must also be provided index_col: string, optional column name to use for the returned DataFrame object. + flavor : string specifying the flavor of SQL to use. Ignored when using + SQLAlchemy engine. Required when using DBAPI2 connection. coerce_float : boolean, default True Attempt to convert values to non-string, non-numeric objects (like decimal.Decimal) to floating point, useful for SQL result sets + cur: depreciated, cursor is obtained from connection params: list or tuple, optional List of parameters to pass to execute method. - """ - cur = execute(sql, con, params=params) - rows = _safe_fetch(cur) - columns = [col_desc[0] for col_desc in cur.description] + flavor : string {sqlite, mysql} specifying the flavor of SQL to use. 
+ Ignored when using SQLAlchemy engine. Required when using DBAPI2 connection. - cur.close() - con.commit() - - result = DataFrame.from_records(rows, columns=columns, - coerce_float=coerce_float) - - if index_col is not None: - result = result.set_index(index_col) - - return result + """ -frame_query = read_frame -read_sql = read_frame + pandas_sql = pandasSQL_builder(con=con, flavor=flavor) + return pandas_sql.read_sql(sql, index_col=index_col, params=params, coerce_float=coerce_float) -def write_frame(frame, name, con, flavor='sqlite', if_exists='fail', **kwargs): +def to_sql(frame, name, con, flavor='sqlite', if_exists='fail'): """ Write records stored in a DataFrame to a SQL database. @@ -184,21 +160,58 @@ def write_frame(frame, name, con, flavor='sqlite', if_exists='fail', **kwargs): ---------- frame: DataFrame name: name of SQL table - con: an open SQL database connection object - flavor: {'sqlite', 'mysql', 'oracle'}, default 'sqlite' + con: SQLAlchemy engine or DBAPI2 connection (legacy mode) + Using SQLAlchemy makes it possible to use any DB supported by that + library. + If a DBAPI2 object is given, a supported SQL flavor must also be provided + flavor: {'sqlite', 'mysql', 'postgres'}, default 'sqlite', ignored when using engine if_exists: {'fail', 'replace', 'append'}, default 'fail' fail: If table exists, do nothing. replace: If table exists, drop it, recreate it, and insert data. append: If table exists, insert data. Create if does not exist. """ + pandas_sql = pandasSQL_builder(con=con, flavor=flavor) + pandas_sql.to_sql(frame, name, if_exists=if_exists) + + +# This is an awesome function +def read_table(table_name, con, meta=None, index_col=None, coerce_float=True): + """Given a table name and SQLAlchemy engine, return a DataFrame. + Type convertions will be done automatically + + Parameters + ---------- + table_name: name of SQL table in database + con: SQLAlchemy engine. Legacy mode not supported + meta: SQLAlchemy meta, optional. 
If omitted MetaData is reflected from engine + index_col: column to set as index, optional + coerce_float : boolean, default True + Attempt to convert values to non-string, non-numeric objects (like + decimal.Decimal) to floating point. Can result in loss of Precision. + + """ + pandas_sql = PandasSQLWithEngine(con, meta=meta) + table = pandas_sql.get_table(table_name) - if 'append' in kwargs: - import warnings - warnings.warn("append is deprecated, use if_exists instead", - FutureWarning) - if kwargs['append']: - if_exists = 'append' + if table is not None: + sql_select = table.select() + return pandas_sql.read_sql(sql_select, index_col=index_col, coerce_float=coerce_float) + else: + raise ValueError("Table %s not found with %s." % table_name, con) + + +def pandasSQL_builder(con, flavor=None, meta=None): + """ + Convenience function to return the correct PandasSQL subclass based on the + provided parameters + """ + try: + import sqlalchemy + + if isinstance(con, sqlalchemy.engine.Engine): + return PandasSQLWithEngine(con, meta=meta) else: +<<<<<<< HEAD if_exists = 'fail' if if_exists not in ('fail', 'replace', 'append'): @@ -224,122 +237,435 @@ def write_frame(frame, name, con, flavor='sqlite', if_exists='fail', **kwargs): if create is not None: cur = con.cursor() cur.execute(create) +======= + warnings.warn("Not a valid SQLAlchemy engine, attempting to use as legacy DBAPI connection") + if flavor is None: + raise ValueError("""PandasSQL must be created with an SQLAlchemy engine + or a DBAPI2 connection and SQL flavour""") + else: + return PandasSQLWithCon(con, flavor) + + except ImportError: + warnings.warn("SQLAlchemy not installed, using legacy mode") + if flavor is None: + raise SQLAlchemyRequired + else: + return PandasSQLWithCon(con, flavor) + + +class PandasSQL(PandasObject): + """ + Subclasses Should define read_sql and to_sql + """ + def read_sql(self, *args, **kwargs): + raise ValueError("PandasSQL must be created with an engine," + " connection or 
cursor.") + + def to_sql(self, *args, **kwargs): + raise ValueError("PandasSQL must be created with an engine," + " connection or cursor.") + + def _create_sql_schema(self, frame, name, keys): + raise ValueError("PandasSQL must be created with an engine," + " connection or cursor.") + + def _frame_from_data_and_columns(self, data, columns, index_col=None, coerce_float=True): + df = DataFrame.from_records(data, columns=columns, coerce_float=coerce_float) + if index_col is not None: + df.set_index(index_col, inplace=True) + return df + + def _safe_col_names(self, col_names): + return [s.replace(' ', '_').strip() for s in col_names] # may not be safe enough... + + +class PandasSQLWithEngine(PandasSQL): + """ + This class enables convertion between DataFrame and SQL databases + using SQLAlchemy to handle DataBase abstraction + """ + def __init__(self, engine, meta=None): + self.engine = engine + if not meta: + from sqlalchemy.schema import MetaData + meta = MetaData(self.engine) + meta.reflect(self.engine) + + self.meta = meta + + def execute(self, *args, **kwargs): + """Simple passthrough to SQLAlchemy engine""" + return self.engine.execute(*args, **kwargs) + + def tquery(self, *args, **kwargs): + """Accepts same args as execute""" + result = self.execute(*args, **kwargs) + return result.fetchall() + + def uquery(self, *args, **kwargs): + """Accepts same args as execute""" + result = self.execute(*args, **kwargs) + return result.rowcount + + def read_sql(self, sql, index_col=None, coerce_float=True, params=[]): + args = _convert_params(sql, params) + result = self.execute(*args) + data = result.fetchall() + columns = result.keys() + + return self._frame_from_data_and_columns(data, columns, + index_col=index_col, + coerce_float=coerce_float) + + def to_sql(self, frame, name, if_exists='fail'): + if self.engine.has_table(name): + if if_exists == 'fail': + raise ValueError("Table '%s' already exists." 
% name) + elif if_exists == 'replace': + #TODO: this triggers a full refresh of metadata, could probably avoid this. + self._drop_table(name) + self._create_table(frame, name) + elif if_exists == 'append': + pass # table exists and will automatically be appended to + else: + self._create_table(frame, name) + self._write(frame, name) + + def _write(self, frame, table_name): + table = self.get_table(table_name) + ins = table.insert() + # TODO: do this in one pass + # TODO this should be done globally first (or work out how to pass np + # dtypes to sql) + + def maybe_asscalar(i): + try: + return np.asscalar(i) + except AttributeError: + return i + + for t in frame.iterrows(): + self.engine.execute(ins, **dict((k, maybe_asscalar(v)) + for k, v in t[1].iteritems())) + # TODO more efficient, I'm *sure* this was just working with tuples + + def has_table(self, name): + return self.engine.has_table(name) + + def get_table(self, table_name): + if self.engine.has_table(table_name): + return self.meta.tables[table_name] + else: + return None + + def _drop_table(self, table_name): + if self.engine.has_table(table_name): + self.get_table(table_name).drop() + self.meta.clear() + self.meta.reflect() + #print(table.exists()) + + def _create_table(self, frame, table_name, keys=None): + table = self._create_sqlalchemy_table(frame, table_name, keys) + table.create() + + def _create_sql_schema(self, frame, table_name, keys=None): + table = self._create_sqlalchemy_table(frame, table_name, keys) + return str(table.compile()) + + def _create_sqlalchemy_table(self, frame, table_name, keys=None): + from sqlalchemy import Table, Column + if keys is None: + keys = [] + + safe_columns = self._safe_col_names(frame.dtypes.index) + column_types = map(self._lookup_type, frame.dtypes) + + columns = [(col_name, col_sqltype, col_name in keys) + for col_name, col_sqltype in zip(safe_columns, column_types)] + + columns = [Column(name, typ, primary_key=pk) for name, typ, pk in columns] + + return 
Table(table_name, self.meta, *columns) + + def _lookup_type(self, dtype): + from sqlalchemy.types import Integer, Float, Text, Boolean, DateTime, Date + + pytype = dtype.type + + if issubclass(pytype, np.floating): + return Float + if issubclass(pytype, np.integer): + # TODO: Refine integer size. + return Integer + if issubclass(pytype, np.datetime64) or pytype is datetime: + # Caution: np.datetime64 is also a subclass of np.number. + return DateTime + if pytype is date: + return Date + if issubclass(pytype, np.bool_): + return Boolean + return Text + + +# ---- SQL without SQLAlchemy --- +# Flavour specific sql strings and handler class for access to DBs without SQLAlchemy installed + +# SQL type convertions for each DB +_SQL_TYPES = { + 'text': { + 'mysql': 'VARCHAR (63)', + 'sqlite': 'TEXT', + 'postgres': 'text' + }, + 'float': { + 'mysql': 'FLOAT', + 'sqlite': 'REAL', + 'postgres': 'real' + }, + 'int': { + 'mysql': 'BIGINT', + 'sqlite': 'INTEGER', + 'postgres': 'integer' + }, + 'datetime': { + 'mysql': 'DATETIME', + 'sqlite': 'TIMESTAMP', + 'postgres': 'timestamp' + }, + 'date': { + 'mysql': 'DATE', + 'sqlite': 'TIMESTAMP', + 'postgres': 'date' + }, + 'bool': { + 'mysql': 'BOOLEAN', + 'sqlite': 'INTEGER', + 'postgres': 'boolean' + } +} + +# SQL enquote and wildcard symbols +_SQL_SYMB = { + 'mysql': { + 'br_l': '`', + 'br_r': '`', + 'wld': '%s' + }, + 'sqlite': { + 'br_l': '[', + 'br_r': ']', + 'wld': '?' + }, + 'postgres': { + 'br_l': '', + 'br_r': '', + 'wld': '?' 
+ } +} + + +class PandasSQLWithCon(PandasSQL): + def __init__(self, con, flavor): + self.con = con + if flavor not in ['sqlite', 'mysql', 'postgres']: + raise NotImplementedError + else: + self.flavor = flavor + + def execute(self, *args, **kwargs): + try: + cur = self.con.cursor() + if kwargs: + cur.execute(*args, **kwargs) + else: + cur.execute(*args) + return cur + except Exception as e: + try: + self.con.rollback() + except Exception: # pragma: no cover + ex = DatabaseError( + "Execution failed on sql: %s\n%s\nunable to rollback" % (args[0], e)) + raise_with_traceback(ex) + + ex = DatabaseError("Execution failed on sql: %s" % args[0]) + raise_with_traceback(ex) + + def tquery(self, *args): + cur = self.execute(*args) + result = self._fetchall_as_list(cur) + + # This makes into tuples + if result and len(result[0]) == 1: + # python 3 compat + result = list(lzip(*result)[0]) + elif result is None: # pragma: no cover + result = [] + return result + + def uquery(self, *args): + """ + Does the same thing as tquery, but instead of returning results, it + returns the number of rows affected. Good for update queries. + """ + cur = self.execute(*args) + return cur.rowcount + + def read_sql(self, sql, index_col=None, coerce_float=True, params=[], flavor='sqlite'): + args = _convert_params(sql, params) + cursor = self.execute(*args) + columns = [col_desc[0] for col_desc in cursor.description] + data = self._fetchall_as_list(cursor) + cursor.close() + + return self._frame_from_data_and_columns(data, columns, + index_col=index_col, + coerce_float=coerce_float) + + def to_sql(self, frame, name, con=None, if_exists='fail'): + """ + Write records stored in a DataFrame to a SQL database. + + Parameters + ---------- + frame: DataFrame + name: name of SQL table + con: an open SQL database connection object + flavor: {'sqlite', 'mysql', 'postgres'}, default 'sqlite' + if_exists: {'fail', 'replace', 'append'}, default 'fail' + fail: If table exists, do nothing. 
+ replace: If table exists, drop it, recreate it, and insert data. + append: If table exists, insert data. Create if does not exist. + """ + if self.has_table(name): + if if_exists == 'fail': + raise ValueError("Table '%s' already exists." % name) + elif if_exists == 'replace': + self._drop_table(name) + self._create_table(frame, name) + elif if_exists == "append": + pass # should just add... + else: + self._create_table(frame, name) + + self._write(frame, name) + + def _fetchall_as_list(self, cur): + '''ensures result of fetchall is a list''' + result = cur.fetchall() + if not isinstance(result, list): + result = list(result) + return result + + def _write(self, frame, table_name): + # Replace spaces in DataFrame column names with _. + safe_names = self._safe_col_names(frame.columns) + + br_l = _SQL_SYMB[self.flavor]['br_l'] # left val quote char + br_r = _SQL_SYMB[self.flavor]['br_r'] # right val quote char + wld = _SQL_SYMB[self.flavor]['wld'] # wildcard char + + bracketed_names = [br_l + column + br_r for column in safe_names] + col_names = ','.join(bracketed_names) + wildcards = ','.join([wld] * len(safe_names)) + insert_query = 'INSERT INTO %s (%s) VALUES (%s)' % ( + table_name, col_names, wildcards) + + # pandas types are badly handled if there is only 1 col (Issue #3628) + if len(frame.columns) != 1: + data = [tuple(x) for x in frame.values] + else: + data = [tuple(x) for x in frame.values.tolist()] + + cur = self.con.cursor() + cur.executemany(insert_query, data) +>>>>>>> 1259dca... ENH #4163 Use SQLAlchemy for DB abstraction cur.close() - cur = con.cursor() - # Replace spaces in DataFrame column names with _. 
- safe_names = [s.replace(' ', '_').strip() for s in frame.columns] - flavor_picker = {'sqlite': _write_sqlite, - 'mysql': _write_mysql} - - func = flavor_picker.get(flavor, None) - if func is None: - raise NotImplementedError - func(frame, name, safe_names, cur) - cur.close() - con.commit() - - -def _write_sqlite(frame, table, names, cur): - bracketed_names = ['[' + column + ']' for column in names] - col_names = ','.join(bracketed_names) - wildcards = ','.join(['?'] * len(names)) - insert_query = 'INSERT INTO %s (%s) VALUES (%s)' % ( - table, col_names, wildcards) - # pandas types are badly handled if there is only 1 column ( Issue #3628 ) - if not len(frame.columns) == 1: - data = [tuple(x) for x in frame.values] - else: - data = [tuple(x) for x in frame.values.tolist()] - cur.executemany(insert_query, data) - - -def _write_mysql(frame, table, names, cur): - bracketed_names = ['`' + column + '`' for column in names] - col_names = ','.join(bracketed_names) - wildcards = ','.join([r'%s'] * len(names)) - insert_query = "INSERT INTO %s (%s) VALUES (%s)" % ( - table, col_names, wildcards) - data = [tuple(x) for x in frame.values] - cur.executemany(insert_query, data) - - -def table_exists(name, con, flavor): - flavor_map = { - 'sqlite': ("SELECT name FROM sqlite_master " - "WHERE type='table' AND name='%s';") % name, - 'mysql': "SHOW TABLES LIKE '%s'" % name} - query = flavor_map.get(flavor, None) - if query is None: - raise NotImplementedError - return len(tquery(query, con)) > 0 - - -def get_sqltype(pytype, flavor): - sqltype = {'mysql': 'VARCHAR (63)', - 'sqlite': 'TEXT'} - - if issubclass(pytype, np.floating): - sqltype['mysql'] = 'FLOAT' - sqltype['sqlite'] = 'REAL' - - if issubclass(pytype, np.integer): - #TODO: Refine integer size. - sqltype['mysql'] = 'BIGINT' - sqltype['sqlite'] = 'INTEGER' - - if issubclass(pytype, np.datetime64) or pytype is datetime: - # Caution: np.datetime64 is also a subclass of np.number. 
- sqltype['mysql'] = 'DATETIME' - sqltype['sqlite'] = 'TIMESTAMP' - - if pytype is datetime.date: - sqltype['mysql'] = 'DATE' - sqltype['sqlite'] = 'TIMESTAMP' - - if issubclass(pytype, np.bool_): - sqltype['sqlite'] = 'INTEGER' - - return sqltype[flavor] - - -def get_schema(frame, name, flavor, keys=None): - "Return a CREATE TABLE statement to suit the contents of a DataFrame." - lookup_type = lambda dtype: get_sqltype(dtype.type, flavor) - # Replace spaces in DataFrame column names with _. - safe_columns = [s.replace(' ', '_').strip() for s in frame.dtypes.index] - column_types = lzip(safe_columns, map(lookup_type, frame.dtypes)) - if flavor == 'sqlite': - columns = ',\n '.join('[%s] %s' % x for x in column_types) - else: - columns = ',\n '.join('`%s` %s' % x for x in column_types) - - keystr = '' - if keys is not None: - if isinstance(keys, compat.string_types): - keys = (keys,) - keystr = ', PRIMARY KEY (%s)' % ','.join(keys) - template = """CREATE TABLE %(name)s ( - %(columns)s - %(keystr)s - );""" - create_statement = template % {'name': name, 'columns': columns, - 'keystr': keystr} - return create_statement - - -def sequence2dict(seq): - """Helper function for cx_Oracle. - - For each element in the sequence, creates a dictionary item equal - to the element and keyed by the position of the item in the list. 
- >>> sequence2dict(("Matt", 1)) - {'1': 'Matt', '2': 1} - - Source: - http://www.gingerandjohn.com/archives/2004/02/26/cx_oracle-executemany-example/ + def _create_table(self, frame, name, keys=None): + create_sql = self._create_sql_schema(frame, name, keys) + self.execute(create_sql) + + def has_table(self, name): + flavor_map = { + 'sqlite': ("SELECT name FROM sqlite_master " + "WHERE type='table' AND name='%s';") % name, + 'mysql': "SHOW TABLES LIKE '%s'" % name} + query = flavor_map.get(self.flavor) + if query is None: + raise NotImplementedError + return len(self.tquery(query)) > 0 + + def _drop_table(self, name): + # Previously this worried about connection tp cursor then closing... + drop_sql = "DROP TABLE %s" % name + self.execute(drop_sql) + + def _create_sql_schema(self, frame, table_name, keys=None): + "Return a CREATE TABLE statement to suit the contents of a DataFrame." + + lookup_type = lambda dtype: self._get_sqltype(dtype.type) + # Replace spaces in DataFrame column names with _. 
+ safe_columns = self._safe_col_names(frame.dtypes.index) + + column_types = lzip(safe_columns, map(lookup_type, frame.dtypes)) + + br_l = _SQL_SYMB[self.flavor]['br_l'] # left val quote char + br_r = _SQL_SYMB[self.flavor]['br_r'] # right val quote char + col_template = br_l + '%s' + br_r + ' %s' + columns = ',\n '.join(col_template % x for x in column_types) + + keystr = '' + if keys is not None: + if isinstance(keys, compat.string_types): + keys = (keys,) + keystr = ', PRIMARY KEY (%s)' % ','.join(keys) + template = """CREATE TABLE %(name)s ( + %(columns)s + %(keystr)s + );""" + create_statement = template % {'name': table_name, 'columns': columns, + 'keystr': keystr} + return create_statement + + def _get_sqltype(self, pytype): + pytype_name = "text" + if issubclass(pytype, np.floating): + pytype_name = "float" + elif issubclass(pytype, np.integer): + pytype_name = "int" + elif issubclass(pytype, np.datetime64) or pytype is datetime: + # Caution: np.datetime64 is also a subclass of np.number. 
+ pytype_name = "datetime" + elif pytype is datetime.date: + pytype_name = "date" + elif issubclass(pytype, np.bool_): + pytype_name = "bool" + + return _SQL_TYPES[pytype_name][self.flavor] + + +# legacy names +def get_schema(frame, name, con=None, flavor='sqlite', engine=None): + """ + Get the SQL db table schema for the given frame + + Parameters + ---------- + frame: DataFrame + name: name of SQL table + con: an open SQL database connection object + engine: an SQLAlchemy engine - replaces connection and flavor + flavor: {'sqlite', 'mysql', 'postgres'}, default 'sqlite' + """ - d = {} - for k, v in zip(range(1, 1 + len(seq)), seq): - d[str(k)] = v - return d + pandas_sql = pandasSQL_builder(con=con, flavor=flavor) + return pandas_sql._create_sql_schema() + + + +#TODO: add depreciation warnings +read_frame = read_sql +write_frame = to_sql + diff --git a/pandas/io/sql_legacy.py b/pandas/io/sql_legacy.py new file mode 100644 index 0000000000000..a8a5d968dd02d --- /dev/null +++ b/pandas/io/sql_legacy.py @@ -0,0 +1,332 @@ +""" +Collection of query wrappers / abstractions to both facilitate data +retrieval and to reduce dependency on DB-specific API. +""" +from datetime import datetime, date + +import numpy as np +import traceback + +from pandas.core.datetools import format as date_format +from pandas.core.api import DataFrame, isnull + +#------------------------------------------------------------------------------ +# Helper execution function + + +def execute(sql, con, retry=True, cur=None, params=None): + """ + Execute the given SQL query using the provided connection object. + + Parameters + ---------- + sql: string + Query to be executed + con: database connection instance + Database connection. Must implement PEP249 (Database API v2.0). + retry: bool + Not currently implemented + cur: database cursor, optional + Must implement PEP249 (Datbase API v2.0). If cursor is not provided, + one will be obtained from the database connection. 
+ params: list or tuple, optional + List of parameters to pass to execute method. + + Returns + ------- + Cursor object + """ + try: + if cur is None: + cur = con.cursor() + + if params is None: + cur.execute(sql) + else: + cur.execute(sql, params) + return cur + except Exception: + try: + con.rollback() + except Exception: # pragma: no cover + pass + + print ('Error on sql %s' % sql) + raise + + +def _safe_fetch(cur): + try: + result = cur.fetchall() + if not isinstance(result, list): + result = list(result) + return result + except Exception, e: # pragma: no cover + excName = e.__class__.__name__ + if excName == 'OperationalError': + return [] + + +def tquery(sql, con=None, cur=None, retry=True): + """ + Returns list of tuples corresponding to each row in given sql + query. + + If only one column selected, then plain list is returned. + + Parameters + ---------- + sql: string + SQL query to be executed + con: SQLConnection or DB API 2.0-compliant connection + cur: DB API 2.0 cursor + + Provide a specific connection or a specific cursor if you are executing a + lot of sequential statements and want to commit outside. + """ + cur = execute(sql, con, cur=cur) + result = _safe_fetch(cur) + + if con is not None: + try: + cur.close() + con.commit() + except Exception as e: + excName = e.__class__.__name__ + if excName == 'OperationalError': # pragma: no cover + print ('Failed to commit, may need to restart interpreter') + else: + raise + + traceback.print_exc() + if retry: + return tquery(sql, con=con, retry=False) + + if result and len(result[0]) == 1: + # python 3 compat + result = list(list(zip(*result))[0]) + elif result is None: # pragma: no cover + result = [] + + return result + + +def uquery(sql, con=None, cur=None, retry=True, params=None): + """ + Does the same thing as tquery, but instead of returning results, it + returns the number of rows affected. Good for update queries. 
+ """ + cur = execute(sql, con, cur=cur, retry=retry, params=params) + + result = cur.rowcount + try: + con.commit() + except Exception as e: + excName = e.__class__.__name__ + if excName != 'OperationalError': + raise + + traceback.print_exc() + if retry: + print ('Looks like your connection failed, reconnecting...') + return uquery(sql, con, retry=False) + return result + + +def read_frame(sql, con, index_col=None, coerce_float=True, params=None): + """ + Returns a DataFrame corresponding to the result set of the query + string. + + Optionally provide an index_col parameter to use one of the + columns as the index. Otherwise will be 0 to len(results) - 1. + + Parameters + ---------- + sql: string + SQL query to be executed + con: DB connection object, optional + index_col: string, optional + column name to use for the returned DataFrame object. + coerce_float : boolean, default True + Attempt to convert values to non-string, non-numeric objects (like + decimal.Decimal) to floating point, useful for SQL result sets + params: list or tuple, optional + List of parameters to pass to execute method. + """ + cur = execute(sql, con, params=params) + rows = _safe_fetch(cur) + columns = [col_desc[0] for col_desc in cur.description] + + cur.close() + con.commit() + + result = DataFrame.from_records(rows, columns=columns, + coerce_float=coerce_float) + + if index_col is not None: + result = result.set_index(index_col) + + return result + +frame_query = read_frame +read_sql = read_frame + + +def write_frame(frame, name, con, flavor='sqlite', if_exists='fail', **kwargs): + """ + Write records stored in a DataFrame to a SQL database. + + Parameters + ---------- + frame: DataFrame + name: name of SQL table + con: an open SQL database connection object + flavor: {'sqlite', 'mysql', 'oracle'}, default 'sqlite' + if_exists: {'fail', 'replace', 'append'}, default 'fail' + fail: If table exists, do nothing. + replace: If table exists, drop it, recreate it, and insert data. 
+ append: If table exists, insert data. Create if does not exist. + """ + + if 'append' in kwargs: + import warnings + warnings.warn("append is deprecated, use if_exists instead", + FutureWarning) + if kwargs['append']: + if_exists='append' + else: + if_exists='fail' + exists = table_exists(name, con, flavor) + if if_exists == 'fail' and exists: + raise ValueError, "Table '%s' already exists." % name + + #create or drop-recreate if necessary + create = None + if exists and if_exists == 'replace': + create = "DROP TABLE %s" % name + elif not exists: + create = get_schema(frame, name, flavor) + + if create is not None: + cur = con.cursor() + cur.execute(create) + cur.close() + + cur = con.cursor() + # Replace spaces in DataFrame column names with _. + safe_names = [s.replace(' ', '_').strip() for s in frame.columns] + flavor_picker = {'sqlite' : _write_sqlite, + 'mysql' : _write_mysql} + + func = flavor_picker.get(flavor, None) + if func is None: + raise NotImplementedError + func(frame, name, safe_names, cur) + cur.close() + con.commit() + + +def _write_sqlite(frame, table, names, cur): + bracketed_names = ['[' + column + ']' for column in names] + col_names = ','.join(bracketed_names) + wildcards = ','.join(['?'] * len(names)) + insert_query = 'INSERT INTO %s (%s) VALUES (%s)' % ( + table, col_names, wildcards) + # pandas types are badly handled if there is only 1 column ( Issue #3628 ) + if not len(frame.columns )==1 : + data = [tuple(x) for x in frame.values] + else : + data = [tuple(x) for x in frame.values.tolist()] + cur.executemany(insert_query, data) + + +def _write_mysql(frame, table, names, cur): + bracketed_names = ['`' + column + '`' for column in names] + col_names = ','.join(bracketed_names) + wildcards = ','.join([r'%s'] * len(names)) + insert_query = "INSERT INTO %s (%s) VALUES (%s)" % ( + table, col_names, wildcards) + data = [tuple(x) for x in frame.values] + cur.executemany(insert_query, data) + + +def table_exists(name, con, flavor): + flavor_map 
= { + 'sqlite': ("SELECT name FROM sqlite_master " + "WHERE type='table' AND name='%s';") % name, + 'mysql' : "SHOW TABLES LIKE '%s'" % name} + query = flavor_map.get(flavor, None) + if query is None: + raise NotImplementedError + return len(tquery(query, con)) > 0 + + +def get_sqltype(pytype, flavor): + sqltype = {'mysql': 'VARCHAR (63)', + 'sqlite': 'TEXT'} + + if issubclass(pytype, np.floating): + sqltype['mysql'] = 'FLOAT' + sqltype['sqlite'] = 'REAL' + + if issubclass(pytype, np.integer): + #TODO: Refine integer size. + sqltype['mysql'] = 'BIGINT' + sqltype['sqlite'] = 'INTEGER' + + if issubclass(pytype, np.datetime64) or pytype is datetime: + # Caution: np.datetime64 is also a subclass of np.number. + sqltype['mysql'] = 'DATETIME' + sqltype['sqlite'] = 'TIMESTAMP' + + if pytype is datetime.date: + sqltype['mysql'] = 'DATE' + sqltype['sqlite'] = 'TIMESTAMP' + + if issubclass(pytype, np.bool_): + sqltype['sqlite'] = 'INTEGER' + + return sqltype[flavor] + + +def get_schema(frame, name, flavor, keys=None): + "Return a CREATE TABLE statement to suit the contents of a DataFrame." + lookup_type = lambda dtype: get_sqltype(dtype.type, flavor) + # Replace spaces in DataFrame column names with _. + safe_columns = [s.replace(' ', '_').strip() for s in frame.dtypes.index] + column_types = zip(safe_columns, map(lookup_type, frame.dtypes)) + if flavor == 'sqlite': + columns = ',\n '.join('[%s] %s' % x for x in column_types) + else: + columns = ',\n '.join('`%s` %s' % x for x in column_types) + + keystr = '' + if keys is not None: + if isinstance(keys, basestring): + keys = (keys,) + keystr = ', PRIMARY KEY (%s)' % ','.join(keys) + template = """CREATE TABLE %(name)s ( + %(columns)s + %(keystr)s + );""" + create_statement = template % {'name': name, 'columns': columns, + 'keystr': keystr} + return create_statement + + +def sequence2dict(seq): + """Helper function for cx_Oracle. 
+ + For each element in the sequence, creates a dictionary item equal + to the element and keyed by the position of the item in the list. + >>> sequence2dict(("Matt", 1)) + {'1': 'Matt', '2': 1} + + Source: + http://www.gingerandjohn.com/archives/2004/02/26/cx_oracle-executemany-example/ + """ + d = {} + for k,v in zip(range(1, 1 + len(seq)), seq): + d[str(k)] = v + return d diff --git a/pandas/io/tests/test_sql.py b/pandas/io/tests/test_sql.py index ef9917c9a02f7..ffde3110371c4 100644 --- a/pandas/io/tests/test_sql.py +++ b/pandas/io/tests/test_sql.py @@ -1,212 +1,190 @@ from __future__ import print_function +import unittest import sqlite3 -import sys - -import warnings - -import nose +import csv +import os import numpy as np -from pandas.core.datetools import format as date_format -from pandas.core.api import DataFrame, isnull -from pandas.compat import StringIO, range, lrange -import pandas.compat as compat +#from pandas.core.datetools import format as date_format +from pandas import DataFrame +from pandas.compat import range, lrange, iteritems + import pandas.io.sql as sql import pandas.util.testing as tm -from pandas import Series, Index, DataFrame -from datetime import datetime - -_formatters = { - datetime: lambda dt: "'%s'" % date_format(dt), - str: lambda x: "'%s'" % x, - np.str_: lambda x: "'%s'" % x, - compat.text_type: lambda x: "'%s'" % x, - compat.binary_type: lambda x: "'%s'" % x, - float: lambda x: "%.8f" % x, - int: lambda x: "%s" % x, - type(None): lambda x: "NULL", - np.float64: lambda x: "%.10f" % x, - bool: lambda x: "'%s'" % x, -} - -def format_query(sql, *args): - """ - - """ - processed_args = [] - for arg in args: - if isinstance(arg, float) and isnull(arg): - arg = None - - formatter = _formatters[type(arg)] - processed_args.append(formatter(arg)) - - return sql % tuple(processed_args) - -def _skip_if_no_MySQLdb(): - try: - import MySQLdb - except ImportError: - raise nose.SkipTest('MySQLdb not installed, skipping') - -class 
TestSQLite(tm.TestCase): + +import sqlalchemy + + +class TestSQLAlchemy(unittest.TestCase): + ''' + Test the sqlalchemy backend against an in-memory sqlite database. + Assume that sqlalchemy takes case of the DB specifics + ''' def setUp(self): - self.db = sqlite3.connect(':memory:') - - def test_basic(self): - frame = tm.makeTimeDataFrame() - self._check_roundtrip(frame) - - def test_write_row_by_row(self): - frame = tm.makeTimeDataFrame() - frame.ix[0, 0] = np.nan - create_sql = sql.get_schema(frame, 'test', 'sqlite') - cur = self.db.cursor() - cur.execute(create_sql) - - cur = self.db.cursor() - - ins = "INSERT INTO test VALUES (%s, %s, %s, %s)" - for idx, row in frame.iterrows(): - fmt_sql = format_query(ins, *row) - sql.tquery(fmt_sql, cur=cur) - - self.db.commit() - - result = sql.read_frame("select * from test", con=self.db) - result.index = frame.index - tm.assert_frame_equal(result, frame) - - def test_execute(self): - frame = tm.makeTimeDataFrame() - create_sql = sql.get_schema(frame, 'test', 'sqlite') - cur = self.db.cursor() - cur.execute(create_sql) - ins = "INSERT INTO test VALUES (?, ?, ?, ?)" - - row = frame.ix[0] - sql.execute(ins, self.db, params=tuple(row)) - self.db.commit() - - result = sql.read_frame("select * from test", self.db) - result.index = frame.index[:1] - tm.assert_frame_equal(result, frame[:1]) - - def test_schema(self): - frame = tm.makeTimeDataFrame() - create_sql = sql.get_schema(frame, 'test', 'sqlite') - lines = create_sql.splitlines() - for l in lines: - tokens = l.split(' ') - if len(tokens) == 2 and tokens[0] == 'A': - self.assert_(tokens[1] == 'DATETIME') - - frame = tm.makeTimeDataFrame() - create_sql = sql.get_schema(frame, 'test', 'sqlite', keys=['A', 'B'],) - lines = create_sql.splitlines() - self.assert_('PRIMARY KEY (A,B)' in create_sql) - cur = self.db.cursor() - cur.execute(create_sql) - - def test_execute_fail(self): - create_sql = """ - CREATE TABLE test - ( - a TEXT, - b TEXT, - c REAL, - PRIMARY KEY (a, b) - ); 
- """ - cur = self.db.cursor() - cur.execute(create_sql) - - sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)', self.db) - sql.execute('INSERT INTO test VALUES("foo", "baz", 2.567)', self.db) + self.engine = sqlalchemy.create_engine('sqlite:///:memory:') + self._load_iris_data(self.engine) - try: - sys.stdout = StringIO() - self.assertRaises(Exception, sql.execute, - 'INSERT INTO test VALUES("foo", "bar", 7)', - self.db) - finally: - sys.stdout = sys.__stdout__ + self.test_frame_time = tm.makeTimeDataFrame() + self._load_test1_data() - def test_execute_closed_connection(self): - create_sql = """ - CREATE TABLE test - ( - a TEXT, - b TEXT, - c REAL, - PRIMARY KEY (a, b) - ); - """ - cur = self.db.cursor() - cur.execute(create_sql) - - sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)', self.db) - self.db.close() - try: - sys.stdout = StringIO() - self.assertRaises(Exception, sql.tquery, "select * from test", - con=self.db) - finally: - sys.stdout = sys.__stdout__ + def _load_iris_data(self, engine): + self.dirpath = tm.get_data_path() + iris_csv_file = os.path.join(self.dirpath, 'iris.csv') + engine.execute("""CREATE TABLE iris ( + `SepalLength` REAL, + `SepalWidth` REAL, + `PetalLength` REAL, + `PetalWidth` REAL, + `Name` TEXT + )""") + + with open(iris_csv_file, 'rU') as iris_csv: + r = csv.reader(iris_csv) + next(r) # skip header row + ins = """ + INSERT INTO iris + VALUES(?, ?, ?, ?, ?) 
+ """ + for row in r: + engine.execute(ins, *row) + + def _load_test1_data(self): + test1_csv_file = os.path.join(self.dirpath, 'test1.csv') + + with open(test1_csv_file, 'rU') as test1_csv: + dr = csv.DictReader(test1_csv) + self.test_frame1 = DataFrame(list(dr)) + + def _test_iris_loaded_frame(self, iris_frame): + pytype = iris_frame.dtypes[0].type + row = iris_frame.iloc[0] + + self.assertTrue(issubclass(pytype, np.floating), 'Loaded frame has incorrect type') + tm.equalContents(row.values, [5.1, 3.5, 1.4, 0.2, 'Iris-setosa']) + + def test_read_sql(self): + iris_frame = sql.read_sql("SELECT * FROM iris", con=self.engine) + self._test_iris_loaded_frame(iris_frame) + + def test_read_table(self): + iris_frame = sql.read_table("iris", con=self.engine) + self._test_iris_loaded_frame(iris_frame) + + def test_to_sql(self): + # Nuke table + self.engine.execute("DROP TABLE IF EXISTS test_frame1") + + sql.to_sql(self.test_frame1, 'test_frame1', con=self.engine) + self.assertTrue(self.engine.has_table('test_frame1'), 'Table not written to DB') + + # Nuke table + self.engine.execute("DROP TABLE IF EXISTS test_frame1") + + def test_to_sql_fail(self): + # Nuke table + self.engine.execute("DROP TABLE IF EXISTS test_frame1") + + sql.to_sql(self.test_frame1, 'test_frame1', con=self.engine, if_exists='fail') + self.assertTrue(self.engine.has_table('test_frame1'), 'Table not written to DB') + + self.assertRaises(ValueError, sql.to_sql, self.test_frame1, 'test_frame1', con=self.engine, if_exists='fail') + + # Nuke table + self.engine.execute("DROP TABLE IF EXISTS test_frame1") + + def test_to_sql_replace(self): + # Nuke table just in case + self.engine.execute("DROP TABLE IF EXISTS test_frame1") + sql.to_sql(self.test_frame1, 'test_frame1', con=self.engine, if_exists='fail') + # Add to table again + sql.to_sql(self.test_frame1, 'test_frame1', con=self.engine, if_exists='replace') + self.assertTrue(self.engine.has_table('test_frame1'), 'Table not written to DB') + + num_entries = 
len(self.test_frame1) + + result = self.engine.execute("SELECT count(*) AS count_1 FROM test_frame1").fetchone() + num_rows = result[0] - def test_na_roundtrip(self): - pass + self.assertEqual(num_rows, num_entries, "not the same number of rows as entries") - def _check_roundtrip(self, frame): - sql.write_frame(frame, name='test_table', con=self.db) - result = sql.read_frame("select * from test_table", self.db) + # Nuke table + self.engine.execute("DROP TABLE IF EXISTS test_frame1") - # HACK! Change this once indexes are handled properly. - result.index = frame.index + def test_to_sql_append(self): + # Nuke table just in case + self.engine.execute("DROP TABLE IF EXISTS test_frame1") + sql.to_sql(self.test_frame1, 'test_frame1', con=self.engine, if_exists='fail') + # Add to table again + sql.to_sql(self.test_frame1, 'test_frame1', con=self.engine, if_exists='append') + self.assertTrue(self.engine.has_table('test_frame1'), 'Table not written to DB') - expected = frame - tm.assert_frame_equal(result, expected) + num_entries = 2*len(self.test_frame1) + result = self.engine.execute("SELECT count(*) AS count_1 FROM test_frame1").fetchone() + num_rows = result[0] - frame['txt'] = ['a'] * len(frame) - frame2 = frame.copy() - frame2['Idx'] = Index(lrange(len(frame2))) + 10 - sql.write_frame(frame2, name='test_table2', con=self.db) - result = sql.read_frame("select * from test_table2", self.db, - index_col='Idx') - expected = frame.copy() - expected.index = Index(lrange(len(frame2))) + 10 - expected.index.name = 'Idx' - print(expected.index.names) - print(result.index.names) - tm.assert_frame_equal(expected, result) + self.assertEqual(num_rows, num_entries, "not the same number of rows as entries") + + # Nuke table + self.engine.execute("DROP TABLE IF EXISTS test_frame1") + + def test_create_table(self): + temp_engine = sqlalchemy.create_engine('sqlite:///:memory:') + temp_frame = DataFrame({'one': [1., 2., 3., 4.], 'two': [4., 3., 2., 1.]}) + + pandasSQL = 
sql.PandasSQLWithEngine(temp_engine) + pandasSQL._create_table(temp_frame, 'temp_frame') + + self.assertTrue(temp_engine.has_table('temp_frame'), 'Table not written to DB') + + def test_drop_table(self): + temp_engine = sqlalchemy.create_engine('sqlite:///:memory:') + + temp_frame = DataFrame({'one': [1., 2., 3., 4.], 'two': [4., 3., 2., 1.]}) + + pandasSQL = sql.PandasSQLWithEngine(temp_engine) + pandasSQL._create_table(temp_frame, 'temp_frame') + + self.assertTrue(temp_engine.has_table('temp_frame'), 'Table not written to DB') + + pandasSQL._drop_table('temp_frame') + + self.assertFalse(temp_engine.has_table('temp_frame'), 'Table not deleted from DB') + + def test_roundtrip(self): + #temp_engine = sqlalchemy.create_engine('sqlite:///:memory:') + + sql.to_sql(self.test_frame1, 'test_frame_roundtrip', con=self.engine) + result = sql.read_table('test_frame_roundtrip', con=self.engine) + + # HACK! + result.index = self.test_frame1.index + + tm.assert_frame_equal(result, self.test_frame1) + + def test_execute_sql(self): + # drop_sql = "DROP TABLE IF EXISTS test" # should already be done + iris_results = sql.execute("SELECT * FROM iris", con=self.engine) + row = iris_results.fetchone() + tm.equalContents(row, [5.1, 3.5, 1.4, 0.2, 'Iris-setosa']) def test_tquery(self): - frame = tm.makeTimeDataFrame() - sql.write_frame(frame, name='test_table', con=self.db) - result = sql.tquery("select A from test_table", self.db) - expected = frame.A - result = Series(result, frame.index) - tm.assert_series_equal(result, expected) + iris_results = sql.tquery("SELECT * FROM iris", con=self.engine) + row = iris_results[0] + tm.equalContents(row, [5.1, 3.5, 1.4, 0.2, 'Iris-setosa']) - try: - sys.stdout = StringIO() - self.assertRaises(sqlite3.OperationalError, sql.tquery, - 'select * from blah', con=self.db) +# --- Test SQLITE fallback - self.assertRaises(sqlite3.OperationalError, sql.tquery, - 'select * from blah', con=self.db, retry=True) - finally: - sys.stdout = sys.__stdout__ - def 
test_uquery(self): - frame = tm.makeTimeDataFrame() - sql.write_frame(frame, name='test_table', con=self.db) - stmt = 'INSERT INTO test_table VALUES(2.314, -123.1, 1.234, 2.3)' - self.assertEqual(sql.uquery(stmt, con=self.db), 1) +class TestSQLite(unittest.TestCase): + ''' + Test the sqlalchemy backend against an in-memory sqlite database. + Assume that sqlalchemy takes case of the DB specifics + ''' +<<<<<<< HEAD try: sys.stdout = StringIO() @@ -301,228 +279,216 @@ def clean_up(test_table_to_drop): class TestMySQL(tm.TestCase): - +======= def setUp(self): - _skip_if_no_MySQLdb() - import MySQLdb - try: - # Try Travis defaults. - # No real user should allow root access with a blank password. - self.db = MySQLdb.connect(host='localhost', user='root', passwd='', - db='pandas_nosetest') - except: - pass - else: - return - try: - self.db = MySQLdb.connect(read_default_group='pandas') - except MySQLdb.ProgrammingError as e: - raise nose.SkipTest( - "Create a group of connection parameters under the heading " - "[pandas] in your system's mysql default file, " - "typically located at ~/.my.cnf or /etc/.my.cnf. ") - except MySQLdb.Error as e: - raise nose.SkipTest( - "Cannot connect to database. " - "Create a group of connection parameters under the heading " - "[pandas] in your system's mysql default file, " - "typically located at ~/.my.cnf or /etc/.my.cnf. 
") + self.conn = sqlite3.connect(':memory:') + self.pandasSQL = sql.PandasSQLWithCon(self.conn, 'sqlite') + + self._load_iris_data(self.conn) + + self.test_frame_time = tm.makeTimeDataFrame() + self._load_test1_data() + + def _load_iris_data(self, conn): + self.dirpath = tm.get_data_path() + iris_csv_file = os.path.join(self.dirpath, 'iris.csv') + cur = conn.cursor() + cur.execute("""CREATE TABLE iris ( + `SepalLength` REAL, + `SepalWidth` REAL, + `PetalLength` REAL, + `PetalWidth` REAL, + `Name` TEXT + )""") + + with open(iris_csv_file, 'rU') as iris_csv: + r = csv.reader(iris_csv) + next(r) # skip header row + ins = """ + INSERT INTO iris + VALUES(?, ?, ?, ?, ?) + """ + for row in r: + cur.execute(ins, row) + conn.commit() + + def _load_test1_data(self): + test1_csv_file = os.path.join(self.dirpath, 'test1.csv') + + with open(test1_csv_file, 'rU') as test1_csv: + dr = csv.DictReader(test1_csv) + self.test_frame1 = DataFrame(list(dr)) + + def test_read_sql(self): + iris_frame = sql.read_sql("SELECT * FROM iris", con=self.conn) + pytype = iris_frame.dtypes[0].type + row = iris_frame.iloc[0] + + self.assertTrue(issubclass(pytype, np.floating), 'Loaded frame has incorrect type') + tm.equalContents(row.values, [5.1, 3.5, 1.4, 0.2, 'Iris-setosa']) + + def test_to_sql(self): + # Nuke table + cur = self.conn.cursor() + cur.execute("DROP TABLE IF EXISTS test_frame1") + self.conn.commit() + + sql.to_sql(self.test_frame1, 'test_frame1', con=self.conn, flavor='sqlite') + self.assertTrue(self.pandasSQL.has_table('test_frame1'), 'Table not written to DB') - def test_basic(self): - _skip_if_no_MySQLdb() - frame = tm.makeTimeDataFrame() - self._check_roundtrip(frame) + # Nuke table + cur = self.conn.cursor() + cur.execute("DROP TABLE IF EXISTS test_frame1") + self.conn.commit() + + def test_to_sql_fail(self): + # Nuke table + cur = self.conn.cursor() + cur.execute("DROP TABLE IF EXISTS test_frame1") + self.conn.commit() + sql.to_sql(self.test_frame1, 'test_frame1', 
con=self.conn, if_exists='fail', flavor='sqlite') + self.assertTrue(self.pandasSQL.has_table('test_frame1'), 'Table not written to DB') + + self.assertRaises(ValueError, sql.to_sql, self.test_frame1, 'test_frame1', con=self.conn, if_exists='fail') + + # Nuke table + cur = self.conn.cursor() + cur.execute("DROP TABLE IF EXISTS test_frame1") + self.conn.commit() + + def test_to_sql_replace(self): + # Nuke table just in case + cur = self.conn.cursor() + cur.execute("DROP TABLE IF EXISTS test_frame1") + self.conn.commit() + sql.to_sql(self.test_frame1, 'test_frame1', con=self.conn, if_exists='fail', flavor='sqlite') + # Add to table again + sql.to_sql(self.test_frame1, 'test_frame1', con=self.conn, if_exists='replace') + self.assertTrue(self.pandasSQL.has_table('test_frame1'), 'Table not written to DB') + + num_entries = len(self.test_frame1) + + result = self.conn.execute("SELECT count(*) AS count_1 FROM test_frame1").fetchone() + num_rows = result[0] + + self.assertEqual(num_rows, num_entries, "not the same number of rows as entries") +>>>>>>> 1259dca... 
ENH #4163 Use SQLAlchemy for DB abstraction + + # Nuke table + cur = self.conn.cursor() + cur.execute("DROP TABLE IF EXISTS test_frame1") + self.conn.commit() + + def test_to_sql_append(self): + # Nuke table just in case + cur = self.conn.cursor() + cur.execute("DROP TABLE IF EXISTS test_frame1") + self.conn.commit() + + sql.to_sql(self.test_frame1, 'test_frame1', con=self.conn, if_exists='fail', flavor='sqlite') + + # Add to table again + sql.to_sql(self.test_frame1, 'test_frame1', con=self.conn, if_exists='append') + self.assertTrue(self.pandasSQL.has_table('test_frame1'), 'Table not written to DB') + + num_entries = 2*len(self.test_frame1) + result = self.conn.execute("SELECT count(*) AS count_1 FROM test_frame1").fetchone() + num_rows = result[0] - def test_write_row_by_row(self): - _skip_if_no_MySQLdb() - frame = tm.makeTimeDataFrame() - frame.ix[0, 0] = np.nan - drop_sql = "DROP TABLE IF EXISTS test" - create_sql = sql.get_schema(frame, 'test', 'mysql') - cur = self.db.cursor() - cur.execute(drop_sql) - cur.execute(create_sql) - ins = "INSERT INTO test VALUES (%s, %s, %s, %s)" - for idx, row in frame.iterrows(): - fmt_sql = format_query(ins, *row) - sql.tquery(fmt_sql, cur=cur) - - self.db.commit() - - result = sql.read_frame("select * from test", con=self.db) - result.index = frame.index - tm.assert_frame_equal(result, frame) - - def test_execute(self): - _skip_if_no_MySQLdb() - frame = tm.makeTimeDataFrame() - drop_sql = "DROP TABLE IF EXISTS test" - create_sql = sql.get_schema(frame, 'test', 'mysql') - cur = self.db.cursor() - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", "Unknown table.*") - cur.execute(drop_sql) - cur.execute(create_sql) - ins = "INSERT INTO test VALUES (%s, %s, %s, %s)" - - row = frame.ix[0] - sql.execute(ins, self.db, params=tuple(row)) - self.db.commit() - - result = sql.read_frame("select * from test", self.db) - result.index = frame.index[:1] - tm.assert_frame_equal(result, frame[:1]) - - def test_schema(self): 
- _skip_if_no_MySQLdb() - frame = tm.makeTimeDataFrame() - create_sql = sql.get_schema(frame, 'test', 'mysql') - lines = create_sql.splitlines() - for l in lines: - tokens = l.split(' ') - if len(tokens) == 2 and tokens[0] == 'A': - self.assert_(tokens[1] == 'DATETIME') - - frame = tm.makeTimeDataFrame() - drop_sql = "DROP TABLE IF EXISTS test" - create_sql = sql.get_schema(frame, 'test', 'mysql', keys=['A', 'B'],) - lines = create_sql.splitlines() - self.assert_('PRIMARY KEY (A,B)' in create_sql) - cur = self.db.cursor() - cur.execute(drop_sql) - cur.execute(create_sql) - - def test_execute_fail(self): - _skip_if_no_MySQLdb() - drop_sql = "DROP TABLE IF EXISTS test" - create_sql = """ - CREATE TABLE test - ( - a TEXT, - b TEXT, - c REAL, - PRIMARY KEY (a(5), b(5)) - ); - """ - cur = self.db.cursor() - cur.execute(drop_sql) - cur.execute(create_sql) - - sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)', self.db) - sql.execute('INSERT INTO test VALUES("foo", "baz", 2.567)', self.db) + self.assertEqual(num_rows, num_entries, "not the same number of rows as entries") - try: - sys.stdout = StringIO() - self.assertRaises(Exception, sql.execute, - 'INSERT INTO test VALUES("foo", "bar", 7)', - self.db) - finally: - sys.stdout = sys.__stdout__ + # Nuke table + cur = self.conn.cursor() + cur.execute("DROP TABLE IF EXISTS test_frame1") + self.conn.commit() - def test_execute_closed_connection(self): - _skip_if_no_MySQLdb() - drop_sql = "DROP TABLE IF EXISTS test" - create_sql = """ - CREATE TABLE test - ( - a TEXT, - b TEXT, - c REAL, - PRIMARY KEY (a(5), b(5)) - ); - """ - cur = self.db.cursor() - cur.execute(drop_sql) - cur.execute(create_sql) - - sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)', self.db) - self.db.close() - try: - sys.stdout = StringIO() - self.assertRaises(Exception, sql.tquery, "select * from test", - con=self.db) - finally: - sys.stdout = sys.__stdout__ + def test_create_table(self): + temp_conn = sqlite3.connect(':memory:') + 
temp_frame = DataFrame({'one': [1., 2., 3., 4.], 'two': [4., 3., 2., 1.]}) - def test_na_roundtrip(self): - _skip_if_no_MySQLdb() - pass + pandasSQL = sql.PandasSQLWithCon(temp_conn, 'sqlite') + pandasSQL._create_table(temp_frame, 'temp_frame') - def _check_roundtrip(self, frame): - _skip_if_no_MySQLdb() - drop_sql = "DROP TABLE IF EXISTS test_table" - cur = self.db.cursor() - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", "Unknown table.*") - cur.execute(drop_sql) - sql.write_frame(frame, name='test_table', con=self.db, flavor='mysql') - result = sql.read_frame("select * from test_table", self.db) - - # HACK! Change this once indexes are handled properly. - result.index = frame.index - result.index.name = frame.index.name - - expected = frame - tm.assert_frame_equal(result, expected) - - frame['txt'] = ['a'] * len(frame) - frame2 = frame.copy() - index = Index(lrange(len(frame2))) + 10 - frame2['Idx'] = index - drop_sql = "DROP TABLE IF EXISTS test_table2" - cur = self.db.cursor() - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", "Unknown table.*") - cur.execute(drop_sql) - sql.write_frame(frame2, name='test_table2', con=self.db, flavor='mysql') - result = sql.read_frame("select * from test_table2", self.db, - index_col='Idx') - expected = frame.copy() - - # HACK! Change this once indexes are handled properly. 
- expected.index = index - expected.index.names = result.index.names - tm.assert_frame_equal(expected, result) + self.assertTrue(pandasSQL.has_table('temp_frame'), 'Table not written to DB') + + def test_drop_table(self): + temp_conn = sqlite3.connect(':memory:') + + temp_frame = DataFrame({'one': [1., 2., 3., 4.], 'two': [4., 3., 2., 1.]}) + + pandasSQL = sql.PandasSQLWithCon(temp_conn, 'sqlite') + pandasSQL._create_table(temp_frame, 'temp_frame') + + self.assertTrue(pandasSQL.has_table('temp_frame'), 'Table not written to DB') + + pandasSQL._drop_table('temp_frame') + + self.assertFalse(pandasSQL.has_table('temp_frame'), 'Table not deleted from DB') + + def test_roundtrip(self): + + sql.to_sql(self.test_frame1, 'test_frame_roundtrip', con=self.conn, flavor='sqlite') + result = sql.read_sql('SELECT * FROM test_frame_roundtrip', con=self.conn, flavor='sqlite') + + # HACK! + result.index = self.test_frame1.index + + tm.assert_frame_equal(result, self.test_frame1) + + def test_execute_sql(self): + # drop_sql = "DROP TABLE IF EXISTS test" # should already be done + iris_results = sql.execute("SELECT * FROM iris", con=self.conn, flavor='sqlite') + row = iris_results.fetchone() + tm.equalContents(row, [5.1, 3.5, 1.4, 0.2, 'Iris-setosa']) def test_tquery(self): + iris_results = sql.tquery("SELECT * FROM iris", con=self.conn, flavor='sqlite') + row = iris_results[0] + tm.equalContents(row, [5.1, 3.5, 1.4, 0.2, 'Iris-setosa']) + + + + +""" +class TestSQLA_pymysql(TestSQLAlchemy): + def setUp(self): + raise nose.SkipTest("MySQLdb was not installed") + + def set_flavor_engine(self): + # if can't import should skip all tests try: - import MySQLdb + import pymysql except ImportError: - raise nose.SkipTest("no MySQLdb") - frame = tm.makeTimeDataFrame() - drop_sql = "DROP TABLE IF EXISTS test_table" - cur = self.db.cursor() - cur.execute(drop_sql) - sql.write_frame(frame, name='test_table', con=self.db, flavor='mysql') - result = sql.tquery("select A from test_table", self.db) - 
expected = frame.A - result = Series(result, frame.index) - tm.assert_series_equal(result, expected) + raise nose.SkipTest("pymysql was not installed") try: - sys.stdout = StringIO() - self.assertRaises(MySQLdb.ProgrammingError, sql.tquery, - 'select * from blah', con=self.db) + self.engine = sqlalchemy.create_engine("mysql+pymysql://root:@localhost/pandas_nosetest") + except pymysql.Error as e: + raise nose.SkipTest( + "Cannot connect to database. " + "Create a group of conn parameters under the heading " + "[pandas] in your system's mysql default file, " + "typically located at ~/.my.cnf or /etc/.my.cnf. ") + except pymysql.ProgrammingError as e: + raise nose.SkipTest( + "Create a group of connection parameters under the heading " + "[pandas] in your system's mysql default file, " + "typically located at ~/.my.cnf or /etc/.my.cnf. ") - self.assertRaises(MySQLdb.ProgrammingError, sql.tquery, - 'select * from blah', con=self.db, retry=True) - finally: - sys.stdout = sys.__stdout__ - def test_uquery(self): +class TestSQLA_MySQLdb(TestSQLAlchemy): + def setUp(self): + raise nose.SkipTest("MySQLdb was not installed") + + def set_flavor_engine(self): + # if can't import should skip all tests try: import MySQLdb except ImportError: - raise nose.SkipTest("no MySQLdb") - frame = tm.makeTimeDataFrame() - drop_sql = "DROP TABLE IF EXISTS test_table" - cur = self.db.cursor() - cur.execute(drop_sql) - sql.write_frame(frame, name='test_table', con=self.db, flavor='mysql') - stmt = 'INSERT INTO test_table VALUES(2.314, -123.1, 1.234, 2.3)' - self.assertEqual(sql.uquery(stmt, con=self.db), 1) + raise nose.SkipTest("MySQLdb was not installed") try: +<<<<<<< HEAD sys.stdout = StringIO() self.assertRaises(MySQLdb.ProgrammingError, sql.tquery, @@ -606,3 +572,18 @@ def clean_up(test_table_to_drop): if __name__ == '__main__': nose.runmodule(argv=[__file__, '-vvs', '-x', '--pdb', '--pdb-failure'], exit=False) +======= + self.engine = 
sqlalchemy.create_engine("mysql+mysqldb://root:@localhost/pandas_nosetest") + except MySQLdb.Error: + raise nose.SkipTest( + "Cannot connect to database. " + "Create a group of connection parameters under the heading " + "[pandas] in your system's mysql default file, " + "typically located at ~/.my.cnf or /etc/.my.cnf. ") + except MySQLdb.ProgrammingError: + raise nose.SkipTest( + "Create a group of connection parameters under the heading " + "[pandas] in your system's mysql default file, " + "typically located at ~/.my.cnf or /etc/.my.cnf. ") +""" +>>>>>>> 1259dca... ENH #4163 Use SQLAlchemy for DB abstraction diff --git a/pandas/io/tests/test_sql_legacy.py b/pandas/io/tests/test_sql_legacy.py new file mode 100644 index 0000000000000..3c6e992097d30 --- /dev/null +++ b/pandas/io/tests/test_sql_legacy.py @@ -0,0 +1,497 @@ +from __future__ import with_statement +from pandas.compat import StringIO +import unittest +import sqlite3 +import sys + +import warnings + +import nose + +import numpy as np + +from pandas.core.datetools import format as date_format +from pandas.core.api import DataFrame, isnull +from pandas.compat import StringIO, range, lrange +import pandas.compat as compat + +import pandas.io.sql as sql +from pandas.io.sql import DatabaseError +import pandas.util.testing as tm +from pandas import Series, Index, DataFrame +from datetime import datetime + +_formatters = { + datetime: lambda dt: "'%s'" % date_format(dt), + str: lambda x: "'%s'" % x, + np.str_: lambda x: "'%s'" % x, + compat.text_type: lambda x: "'%s'" % x, + compat.binary_type: lambda x: "'%s'" % x, + float: lambda x: "%.8f" % x, + int: lambda x: "%s" % x, + type(None): lambda x: "NULL", + np.float64: lambda x: "%.10f" % x, + bool: lambda x: "'%s'" % x, +} + +def format_query(sql, *args): + """ + + """ + processed_args = [] + for arg in args: + if isinstance(arg, float) and isnull(arg): + arg = None + + formatter = _formatters[type(arg)] + processed_args.append(formatter(arg)) + + return sql % 
tuple(processed_args) + +def _skip_if_no_MySQLdb(): + try: + import MySQLdb + except ImportError: + raise nose.SkipTest('MySQLdb not installed, skipping') + +class TestSQLite(unittest.TestCase): + + def setUp(self): + self.db = sqlite3.connect(':memory:') + + def test_basic(self): + frame = tm.makeTimeDataFrame() + self._check_roundtrip(frame) + + def test_write_row_by_row(self): + frame = tm.makeTimeDataFrame() + frame.ix[0, 0] = np.nan + create_sql = sql.get_schema(frame, 'test', 'sqlite') + cur = self.db.cursor() + cur.execute(create_sql) + + cur = self.db.cursor() + + ins = "INSERT INTO test VALUES (%s, %s, %s, %s)" + for idx, row in frame.iterrows(): + fmt_sql = format_query(ins, *row) + sql.tquery(fmt_sql, cur=cur) + + self.db.commit() + + result = sql.read_frame("select * from test", con=self.db) + result.index = frame.index + tm.assert_frame_equal(result, frame) + + def test_execute(self): + frame = tm.makeTimeDataFrame() + create_sql = sql.get_schema(frame, 'test', 'sqlite') + cur = self.db.cursor() + cur.execute(create_sql) + ins = "INSERT INTO test VALUES (?, ?, ?, ?)" + + row = frame.ix[0] + sql.execute(ins, self.db, params=tuple(row)) + self.db.commit() + + result = sql.read_frame("select * from test", self.db) + result.index = frame.index[:1] + tm.assert_frame_equal(result, frame[:1]) + + def test_schema(self): + frame = tm.makeTimeDataFrame() + create_sql = sql.get_schema(frame, 'test', 'sqlite') + lines = create_sql.splitlines() + for l in lines: + tokens = l.split(' ') + if len(tokens) == 2 and tokens[0] == 'A': + self.assert_(tokens[1] == 'DATETIME') + + frame = tm.makeTimeDataFrame() + create_sql = sql.get_schema(frame, 'test', 'sqlite', keys=['A', 'B'],) + lines = create_sql.splitlines() + self.assert_('PRIMARY KEY (A,B)' in create_sql) + cur = self.db.cursor() + cur.execute(create_sql) + + def test_execute_fail(self): + create_sql = """ + CREATE TABLE test + ( + a TEXT, + b TEXT, + c REAL, + PRIMARY KEY (a, b) + ); + """ + cur = 
self.db.cursor() + cur.execute(create_sql) + + sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)', self.db) + sql.execute('INSERT INTO test VALUES("foo", "baz", 2.567)', self.db) + + try: + sys.stdout = StringIO() + self.assertRaises(Exception, sql.execute, + 'INSERT INTO test VALUES("foo", "bar", 7)', + self.db) + finally: + sys.stdout = sys.__stdout__ + + def test_execute_closed_connection(self): + create_sql = """ + CREATE TABLE test + ( + a TEXT, + b TEXT, + c REAL, + PRIMARY KEY (a, b) + ); + """ + cur = self.db.cursor() + cur.execute(create_sql) + + sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)', self.db) + self.db.close() + try: + sys.stdout = StringIO() + self.assertRaises(Exception, sql.tquery, "select * from test", + con=self.db) + finally: + sys.stdout = sys.__stdout__ + + def test_na_roundtrip(self): + pass + + def _check_roundtrip(self, frame): + sql.write_frame(frame, name='test_table', con=self.db) + result = sql.read_frame("select * from test_table", self.db) + + # HACK! Change this once indexes are handled properly. 
+ result.index = frame.index + + expected = frame + tm.assert_frame_equal(result, expected) + + frame['txt'] = ['a'] * len(frame) + frame2 = frame.copy() + frame2['Idx'] = Index(lrange(len(frame2))) + 10 + sql.write_frame(frame2, name='test_table2', con=self.db) + result = sql.read_frame("select * from test_table2", self.db, + index_col='Idx') + expected = frame.copy() + expected.index = Index(lrange(len(frame2))) + 10 + expected.index.name = 'Idx' + print(expected.index.names) + print(result.index.names) + tm.assert_frame_equal(expected, result) + + def test_tquery(self): + frame = tm.makeTimeDataFrame() + sql.write_frame(frame, name='test_table', con=self.db) + result = sql.tquery("select A from test_table", self.db) + expected = frame.A + result = Series(result, frame.index) + tm.assert_series_equal(result, expected) + + try: + sys.stdout = StringIO() + self.assertRaises(DatabaseError, sql.tquery, + 'select * from blah', con=self.db) + + self.assertRaises(DatabaseError, sql.tquery, + 'select * from blah', con=self.db, retry=True) + finally: + sys.stdout = sys.__stdout__ + + def test_uquery(self): + frame = tm.makeTimeDataFrame() + sql.write_frame(frame, name='test_table', con=self.db) + stmt = 'INSERT INTO test_table VALUES(2.314, -123.1, 1.234, 2.3)' + self.assertEqual(sql.uquery(stmt, con=self.db), 1) + + try: + sys.stdout = StringIO() + + self.assertRaises(DatabaseError, sql.tquery, + 'insert into blah values (1)', con=self.db) + + self.assertRaises(DatabaseError, sql.tquery, + 'insert into blah values (1)', con=self.db, + retry=True) + finally: + sys.stdout = sys.__stdout__ + + def test_keyword_as_column_names(self): + ''' + ''' + df = DataFrame({'From':np.ones(5)}) + sql.write_frame(df, con = self.db, name = 'testkeywords') + + def test_onecolumn_of_integer(self): + ''' + GH 3628 + a column_of_integers dataframe should transfer well to sql + ''' + mono_df=DataFrame([1 , 2], columns=['c0']) + sql.write_frame(mono_df, con = self.db, name = 'mono_df') + # 
computing the sum via sql + con_x=self.db + the_sum=sum([my_c0[0] for my_c0 in con_x.execute("select * from mono_df")]) + # it should not fail, and gives 3 ( Issue #3628 ) + self.assertEqual(the_sum , 3) + + result = sql.read_frame("select * from mono_df",con_x) + tm.assert_frame_equal(result,mono_df) + + +class TestMySQL(unittest.TestCase): + + def setUp(self): + _skip_if_no_MySQLdb() + import MySQLdb + try: + # Try Travis defaults. + # No real user should allow root access with a blank password. + self.db = MySQLdb.connect(host='localhost', user='root', passwd='', + db='pandas_nosetest') + except: + pass + else: + return + try: + self.db = MySQLdb.connect(read_default_group='pandas') + except MySQLdb.ProgrammingError as e: + raise nose.SkipTest( + "Create a group of connection parameters under the heading " + "[pandas] in your system's mysql default file, " + "typically located at ~/.my.cnf or /etc/.my.cnf. ") + except MySQLdb.Error as e: + raise nose.SkipTest( + "Cannot connect to database. " + "Create a group of connection parameters under the heading " + "[pandas] in your system's mysql default file, " + "typically located at ~/.my.cnf or /etc/.my.cnf. 
") + + def test_basic(self): + _skip_if_no_MySQLdb() + frame = tm.makeTimeDataFrame() + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "For more robust support.*") + self._check_roundtrip(frame) + + def test_write_row_by_row(self): + _skip_if_no_MySQLdb() + frame = tm.makeTimeDataFrame() + frame.ix[0, 0] = np.nan + drop_sql = "DROP TABLE IF EXISTS test" + create_sql = sql.get_schema(frame, 'test', 'mysql') + cur = self.db.cursor() + cur.execute(drop_sql) + cur.execute(create_sql) + ins = "INSERT INTO test VALUES (%s, %s, %s, %s)" + for idx, row in frame.iterrows(): + fmt_sql = format_query(ins, *row) + sql.tquery(fmt_sql, cur=cur) + + self.db.commit() + + result = sql.read_frame("select * from test", con=self.db) + result.index = frame.index + tm.assert_frame_equal(result, frame) + + def test_execute(self): + _skip_if_no_MySQLdb() + frame = tm.makeTimeDataFrame() + drop_sql = "DROP TABLE IF EXISTS test" + create_sql = sql.get_schema(frame, 'test', 'mysql') + cur = self.db.cursor() + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "Unknown table.*") + cur.execute(drop_sql) + cur.execute(create_sql) + ins = "INSERT INTO test VALUES (%s, %s, %s, %s)" + + row = frame.ix[0] + sql.execute(ins, self.db, params=tuple(row)) + self.db.commit() + + result = sql.read_frame("select * from test", self.db) + result.index = frame.index[:1] + tm.assert_frame_equal(result, frame[:1]) + + def test_schema(self): + _skip_if_no_MySQLdb() + frame = tm.makeTimeDataFrame() + create_sql = sql.get_schema(frame, 'test', 'mysql') + lines = create_sql.splitlines() + for l in lines: + tokens = l.split(' ') + if len(tokens) == 2 and tokens[0] == 'A': + self.assert_(tokens[1] == 'DATETIME') + + frame = tm.makeTimeDataFrame() + drop_sql = "DROP TABLE IF EXISTS test" + create_sql = sql.get_schema(frame, 'test', 'mysql', keys=['A', 'B'],) + lines = create_sql.splitlines() + self.assert_('PRIMARY KEY (A,B)' in create_sql) + cur = self.db.cursor() + 
cur.execute(drop_sql) + cur.execute(create_sql) + + def test_execute_fail(self): + _skip_if_no_MySQLdb() + drop_sql = "DROP TABLE IF EXISTS test" + create_sql = """ + CREATE TABLE test + ( + a TEXT, + b TEXT, + c REAL, + PRIMARY KEY (a(5), b(5)) + ); + """ + cur = self.db.cursor() + cur.execute(drop_sql) + cur.execute(create_sql) + + sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)', self.db) + sql.execute('INSERT INTO test VALUES("foo", "baz", 2.567)', self.db) + + try: + sys.stdout = StringIO() + self.assertRaises(Exception, sql.execute, + 'INSERT INTO test VALUES("foo", "bar", 7)', + self.db) + finally: + sys.stdout = sys.__stdout__ + + def test_execute_closed_connection(self): + _skip_if_no_MySQLdb() + drop_sql = "DROP TABLE IF EXISTS test" + create_sql = """ + CREATE TABLE test + ( + a TEXT, + b TEXT, + c REAL, + PRIMARY KEY (a(5), b(5)) + ); + """ + cur = self.db.cursor() + cur.execute(drop_sql) + cur.execute(create_sql) + + sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)', self.db) + self.db.close() + try: + sys.stdout = StringIO() + self.assertRaises(Exception, sql.tquery, "select * from test", + con=self.db) + finally: + sys.stdout = sys.__stdout__ + + def test_na_roundtrip(self): + _skip_if_no_MySQLdb() + pass + + def _check_roundtrip(self, frame): + _skip_if_no_MySQLdb() + drop_sql = "DROP TABLE IF EXISTS test_table" + cur = self.db.cursor() + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "Unknown table.*") + cur.execute(drop_sql) + sql.write_frame(frame, name='test_table', con=self.db, flavor='mysql') + result = sql.read_frame("select * from test_table", self.db) + + # HACK! Change this once indexes are handled properly. 
+ result.index = frame.index + result.index.name = frame.index.name + + expected = frame + tm.assert_frame_equal(result, expected) + + frame['txt'] = ['a'] * len(frame) + frame2 = frame.copy() + index = Index(lrange(len(frame2))) + 10 + frame2['Idx'] = index + drop_sql = "DROP TABLE IF EXISTS test_table2" + cur = self.db.cursor() + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "Unknown table.*") + cur.execute(drop_sql) + sql.write_frame(frame2, name='test_table2', con=self.db, flavor='mysql') + result = sql.read_frame("select * from test_table2", self.db, + index_col='Idx') + expected = frame.copy() + + # HACK! Change this once indexes are handled properly. + expected.index = index + expected.index.names = result.index.names + tm.assert_frame_equal(expected, result) + + def test_tquery(self): + try: + import MySQLdb + except ImportError: + raise nose.SkipTest + frame = tm.makeTimeDataFrame() + drop_sql = "DROP TABLE IF EXISTS test_table" + cur = self.db.cursor() + cur.execute(drop_sql) + sql.write_frame(frame, name='test_table', con=self.db, flavor='mysql') + result = sql.tquery("select A from test_table", self.db) + expected = frame.A + result = Series(result, frame.index) + tm.assert_series_equal(result, expected) + + try: + sys.stdout = StringIO() + self.assertRaises(DatabaseError, sql.tquery, + 'select * from blah', con=self.db) + + self.assertRaises(DatabaseError, sql.tquery, + 'select * from blah', con=self.db, retry=True) + finally: + sys.stdout = sys.__stdout__ + + def test_uquery(self): + try: + import MySQLdb + except ImportError: + raise nose.SkipTest + frame = tm.makeTimeDataFrame() + drop_sql = "DROP TABLE IF EXISTS test_table" + cur = self.db.cursor() + cur.execute(drop_sql) + sql.write_frame(frame, name='test_table', con=self.db, flavor='mysql') + stmt = 'INSERT INTO test_table VALUES(2.314, -123.1, 1.234, 2.3)' + self.assertEqual(sql.uquery(stmt, con=self.db), 1) + + try: + sys.stdout = StringIO() + + 
self.assertRaises(DatabaseError, sql.tquery, + 'insert into blah values (1)', con=self.db) + + self.assertRaises(DatabaseError, sql.tquery, + 'insert into blah values (1)', con=self.db, + retry=True) + finally: + sys.stdout = sys.__stdout__ + + def test_keyword_as_column_names(self): + ''' + ''' + _skip_if_no_MySQLdb() + df = DataFrame({'From':np.ones(5)}) + sql.write_frame(df, name='testkeywords', con=self.db, + if_exists='replace', flavor='mysql') + +if __name__ == '__main__': + # unittest.main() + # nose.runmodule(argv=[__file__,'-vvs','-x', '--pdb-failure'], + # exit=False) + nose.runmodule(argv=[__file__, '-vvs', '-x', '--pdb', '--pdb-failure'], + exit=False)