diff --git a/pandas/io/sql.py b/pandas/io/sql.py index 68e4bdfcc0a36..0ae82bec38e26 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -588,7 +588,7 @@ def pandasSQL_builder(con, flavor=None, schema=None, meta=None, class SQLTable(PandasObject): """ For mapping Pandas tables to SQL tables. - Uses fact that table is reflected by SQLAlchemy to do type conversions. + Uses fact that table is reflected by SQLAlchemy to do better type convertions. Also holds various flags needed to avoid having to pass them between functions all the time. @@ -608,7 +608,7 @@ def __init__(self, name, pandas_sql_engine, frame=None, index=True, if frame is not None: # We want to initialize based on a dataframe - self.table = self.pd_sql._create_table_setup(self) + self.table = self._create_table_setup() else: # no data provided, read-only mode self.table = self.pd_sql.get_table(self.name, self.schema) @@ -619,20 +619,32 @@ def __init__(self, name, pandas_sql_engine, frame=None, index=True, def exists(self): return self.pd_sql.has_table(self.name, self.schema) + def sql_schema(self): + from sqlalchemy.schema import CreateTable + return str(CreateTable(self.table)) + + def _execute_create(self): + # Inserting table into database, add to MetaData object + self.table = self.table.tometadata(self.pd_sql.meta) + self.table.create() + def create(self): if self.exists(): if self.if_exists == 'fail': raise ValueError("Table '%s' already exists." 
% self.name) elif self.if_exists == 'replace': self.pd_sql.drop_table(self.name, self.schema) - self.pd_sql._execute_create(self) + self._execute_create() elif self.if_exists == 'append': pass else: raise ValueError( "'{0}' is not valid for if_exists".format(self.if_exists)) else: - self.pd_sql._execute_create(self) + self._execute_create() + + def insert_statement(self): + return self.table.insert() def insert_data(self): if self.index is not None: @@ -670,6 +682,10 @@ def insert_data(self): return column_names, data_list + def _execute_insert(self, conn, keys, data_iter): + data = [dict((k, v) for k, v in zip(keys, row)) for row in data_iter] + conn.execute(self.insert_statement(), data) + def insert(self, chunksize=None): keys, data_list = self.insert_data() @@ -693,7 +709,7 @@ def insert(self, chunksize=None): break chunk_iter = zip(*[arr[start_i:end_i] for arr in data_list]) - self.pd_sql._execute_insert(conn, self, keys, chunk_iter) + self._execute_insert(conn, keys, chunk_iter) def _query_iterator(self, result, chunksize, columns, coerce_float=True, parse_dates=None): @@ -792,6 +808,28 @@ def _get_column_names_and_types(self, dtype_mapper): return column_names_and_types + def _create_table_setup(self): + from sqlalchemy import Table, Column, PrimaryKeyConstraint + + column_names_and_types = \ + self._get_column_names_and_types(self._sqlalchemy_type) + + columns = [Column(name, typ, index=is_index) + for name, typ, is_index in column_names_and_types] + + if self.keys is not None: + pkc = PrimaryKeyConstraint(self.keys, name=self.name + '_pk') + columns.append(pkc) + + schema = self.schema or self.pd_sql.meta.schema + + # At this point, attach to new metadata, only attach to self.meta + # once table is created. 
+ from sqlalchemy.schema import MetaData + meta = MetaData(self.pd_sql, schema=schema) + + return Table(self.name, meta, *columns, schema=schema) + def _harmonize_columns(self, parse_dates=None): """ Make the DataFrame's column types align with the SQL table @@ -843,6 +881,35 @@ def _harmonize_columns(self, parse_dates=None): except KeyError: pass # this column not in results + def _sqlalchemy_type(self, col): + from sqlalchemy.types import (BigInteger, Float, Text, Boolean, + DateTime, Date, Time) + + if com.is_datetime64_dtype(col): + try: + tz = col.tzinfo + return DateTime(timezone=True) + except: + return DateTime + if com.is_timedelta64_dtype(col): + warnings.warn("the 'timedelta' type is not supported, and will be " + "written as integer values (ns frequency) to the " + "database.", UserWarning) + return BigInteger + elif com.is_float_dtype(col): + return Float + elif com.is_integer_dtype(col): + # TODO: Refine integer size. + return BigInteger + elif com.is_bool_dtype(col): + return Boolean + inferred = lib.infer_dtype(com._ensure_object(col)) + if inferred == 'date': + return Date + if inferred == 'time': + return Time + return Text + def _numpy_type(self, sqltype): from sqlalchemy.types import Integer, Float, Boolean, DateTime, Date @@ -898,12 +965,17 @@ class SQLDatabase(PandasSQL): def __init__(self, engine, schema=None, meta=None): self.engine = engine if not meta: - meta = self.get_meta(schema) + from sqlalchemy.schema import MetaData + meta = MetaData(self.engine, schema=schema) self.meta = meta - def get_meta_schema(self): - return self.meta.schema + def run_transaction(self): + return self.engine.begin() + + def execute(self, *args, **kwargs): + """Simple passthrough to SQLAlchemy engine""" + return self.engine.execute(*args, **kwargs) def read_table(self, table_name, index_col=None, coerce_float=True, parse_dates=None, columns=None, schema=None, @@ -954,29 +1026,20 @@ def read_table(self, table_name, index_col=None, coerce_float=True, 
parse_dates=parse_dates, columns=columns, chunksize=chunksize) - def _query_iterator(self, result, chunksize, columns, index_col=None, + @staticmethod + def _query_iterator(result, chunksize, columns, index_col=None, coerce_float=True, parse_dates=None): """Return generator through chunked result set""" while True: data = result.fetchmany(chunksize) if not data: - self._close_result(result) break else: yield _wrap_result(data, columns, index_col=index_col, coerce_float=coerce_float, parse_dates=parse_dates) - def _get_result_columns(self, result): - return result.keys() - - def _fetchall_as_list(self, result): - return result.fetchall() - - def _close_result(self, result): - pass - def read_query(self, sql, index_col=None, coerce_float=True, parse_dates=None, params=None, chunksize=None): """Read SQL query into a DataFrame. @@ -1019,7 +1082,7 @@ def read_query(self, sql, index_col=None, coerce_float=True, args = _convert_params(sql, params) result = self.execute(*args) - columns = self._get_result_columns(result) + columns = result.keys() if chunksize is not None: return self._query_iterator(result, chunksize, columns, @@ -1027,9 +1090,7 @@ def read_query(self, sql, index_col=None, coerce_float=True, coerce_float=coerce_float, parse_dates=parse_dates) else: - data = self._fetchall_as_list(result) - self._close_result(result) - + data = result.fetchall() frame = _wrap_result(data, columns, index_col=index_col, coerce_float=coerce_float, parse_dates=parse_dates) @@ -1060,7 +1121,7 @@ def to_sql(self, frame, name, if_exists='fail', index=True, schema : string, default None Name of SQL schema in database to write to (if database flavor supports this). If specified, this overwrites the default - schema of the SQLDatabase object. Uses the meta schema if not provided. + schema of the SQLDatabase object. chunksize : int, default None If not None, then rows will be written in batches of this size at a time. If None, all rows will be written at once. 
@@ -1072,7 +1133,7 @@ def to_sql(self, frame, name, if_exists='fail', index=True, table.create() table.insert(chunksize) # check for potentially case sensitivity issues (GH7815) - if not self.has_table(name, schema=schema or self.get_meta_schema()): + if name not in self.engine.table_names(schema=schema or self.meta.schema): warnings.warn("The provided table name '{0}' is not found exactly " "as such in the database after writing the table, " "possibly due to case sensitivity issues. Consider " @@ -1083,17 +1144,17 @@ def tables(self): return self.meta.tables def has_table(self, name, schema=None): - return self.engine.has_table(name, schema or self.get_meta_schema()) + return self.engine.has_table(name, schema or self.meta.schema) def get_table(self, table_name, schema=None): - schema = schema or self.get_meta_schema() + schema = schema or self.meta.schema if schema: return self.meta.tables.get('.'.join([schema, table_name])) else: return self.meta.tables.get(table_name) def drop_table(self, table_name, schema=None): - schema = schema or self.get_meta_schema() + schema = schema or self.meta.schema if self.engine.has_table(table_name, schema): self.meta.reflect(only=[table_name], schema=schema) self.get_table(table_name, schema).drop() @@ -1101,86 +1162,7 @@ def drop_table(self, table_name, schema=None): def _create_sql_schema(self, frame, table_name, keys=None): table = SQLTable(table_name, self, frame=frame, index=False, keys=keys) - return str(self.sql_schema(table)) - - def get_meta(self, schema): - from sqlalchemy.schema import MetaData - return MetaData(self.engine, schema=schema) - - def execute(self, *args, **kwargs): - """Simple passthrough to SQLAlchemy engine""" - return self.engine.execute(*args, **kwargs) - - def run_transaction(self): - return self.engine.begin() - - def _create_table_setup(self, table): - from sqlalchemy import Table, Column, PrimaryKeyConstraint - - column_names_and_types = \ - 
table._get_column_names_and_types(self._sqlalchemy_type) - - columns = [Column(name, typ, index=is_index) - for name, typ, is_index in column_names_and_types] - - if table.keys is not None: - pkc = PrimaryKeyConstraint(table.keys, name=self.name + '_pk') - columns.append(pkc) - - schema = table.schema or self.get_meta_schema() - - # At this point, attach to new metadata, only attach to self.meta - # once table is created. - from sqlalchemy.schema import MetaData - meta = MetaData(self.engine, schema=schema) - - return Table(table.name, meta, *columns, schema=schema) - - def sql_schema(self, table): - from sqlalchemy.schema import CreateTable - return str(CreateTable(table.table)) - - def insert_statement(self, table): - return table.table.insert() - - def _execute_insert(self, conn, table, keys, data_iter): - data = [dict((k, v) for k, v in zip(keys, row)) for row in data_iter] - conn.execute(self.insert_statement(table), data) - - def _execute_create(self, table): - # Inserting table into database, add to MetaData object - table.table = table.table.tometadata(self.meta) - table.table.create() - - def _sqlalchemy_type(self, col): - from sqlalchemy.types import (BigInteger, Float, Text, Boolean, - DateTime, Date, Time) - - if com.is_datetime64_dtype(col): - try: - tz = col.tzinfo - return DateTime(timezone=True) - except: - return DateTime - if com.is_timedelta64_dtype(col): - warnings.warn("the 'timedelta' type is not supported, and will be " - "written as integer values (ns frequency) to the " - "database.", UserWarning) - return BigInteger - elif com.is_float_dtype(col): - return Float - elif com.is_integer_dtype(col): - # TODO: Refine integer size. 
- return BigInteger - elif com.is_bool_dtype(col): - return Boolean - inferred = lib.infer_dtype(com._ensure_object(col)) - if inferred == 'date': - return Date - if inferred == 'time': - return Time - return Text - + return str(table.sql_schema()) # ---- SQL without SQLAlchemy --- @@ -1238,8 +1220,110 @@ def _sqlalchemy_type(self, col): "underscores.") +class SQLiteTable(SQLTable): + """ + Patch the SQLTable for fallback support. + Instead of a table variable just use the Create Table statement. + """ + + def sql_schema(self): + return str(";\n".join(self.table)) + + def _execute_create(self): + with self.pd_sql.run_transaction() as conn: + for stmt in self.table: + conn.execute(stmt) + + def insert_statement(self): + names = list(map(str, self.frame.columns)) + flv = self.pd_sql.flavor + br_l = _SQL_SYMB[flv]['br_l'] # left val quote char + br_r = _SQL_SYMB[flv]['br_r'] # right val quote char + wld = _SQL_SYMB[flv]['wld'] # wildcard char + + if self.index is not None: + [names.insert(0, idx) for idx in self.index[::-1]] + + bracketed_names = [br_l + column + br_r for column in names] + col_names = ','.join(bracketed_names) + wildcards = ','.join([wld] * len(names)) + insert_statement = 'INSERT INTO %s (%s) VALUES (%s)' % ( + self.name, col_names, wildcards) + return insert_statement + + def _execute_insert(self, conn, keys, data_iter): + data_list = list(data_iter) + conn.executemany(self.insert_statement(), data_list) + + def _create_table_setup(self): + """ + Return a list of SQL statement that create a table reflecting the + structure of a DataFrame. 
The first entry will be a CREATE TABLE + statement while the rest will be CREATE INDEX statements + """ + column_names_and_types = \ + self._get_column_names_and_types(self._sql_type_name) + + pat = re.compile('\s+') + column_names = [col_name for col_name, _, _ in column_names_and_types] + if any(map(pat.search, column_names)): + warnings.warn(_SAFE_NAMES_WARNING) + + flv = self.pd_sql.flavor + + br_l = _SQL_SYMB[flv]['br_l'] # left val quote char + br_r = _SQL_SYMB[flv]['br_r'] # right val quote char -class SQLiteDatabase(SQLDatabase): + create_tbl_stmts = [(br_l + '%s' + br_r + ' %s') % (cname, ctype) + for cname, ctype, _ in column_names_and_types] + if self.keys is not None and len(self.keys): + cnames_br = ",".join([br_l + c + br_r for c in self.keys]) + create_tbl_stmts.append( + "CONSTRAINT {tbl}_pk PRIMARY KEY ({cnames_br})".format( + tbl=self.name, cnames_br=cnames_br)) + + create_stmts = ["CREATE TABLE " + self.name + " (\n" + + ',\n '.join(create_tbl_stmts) + "\n)"] + + ix_cols = [cname for cname, _, is_index in column_names_and_types + if is_index] + if len(ix_cols): + cnames = "_".join(ix_cols) + cnames_br = ",".join([br_l + c + br_r for c in ix_cols]) + create_stmts.append( + "CREATE INDEX ix_{tbl}_{cnames} ON {tbl} ({cnames_br})".format( + tbl=self.name, cnames=cnames, cnames_br=cnames_br)) + + return create_stmts + + def _sql_type_name(self, col): + pytype = col.dtype.type + pytype_name = "text" + if issubclass(pytype, np.floating): + pytype_name = "float" + elif com.is_timedelta64_dtype(pytype): + warnings.warn("the 'timedelta' type is not supported, and will be " + "written as integer values (ns frequency) to the " + "database.", UserWarning) + pytype_name = "int" + elif issubclass(pytype, np.integer): + pytype_name = "int" + elif issubclass(pytype, np.datetime64) or pytype is datetime: + # Caution: np.datetime64 is also a subclass of np.number. 
+ pytype_name = "datetime" + elif issubclass(pytype, np.bool_): + pytype_name = "bool" + elif issubclass(pytype, np.object): + pytype = lib.infer_dtype(com._ensure_object(col)) + if pytype == "date": + pytype_name = "date" + elif pytype == "time": + pytype_name = "time" + + return _SQL_TYPES[pytype_name][self.pd_sql.flavor] + + +class SQLiteDatabase(PandasSQL): """ Version of SQLDatabase to support sqlite connections (fallback without sqlalchemy). This should only be used internally. @@ -1263,10 +1347,6 @@ def __init__(self, con, flavor, is_cursor=False): else: self.flavor = flavor - def get_meta_schema(self): - # sqlite fallback does not support schemas - return None - @contextmanager def run_transaction(self): cur = self.con.cursor() @@ -1301,17 +1381,79 @@ def execute(self, *args, **kwargs): ex = DatabaseError("Execution failed on sql '%s': %s" % (args[0], exc)) raise_with_traceback(ex) - def _get_result_columns(self, result): - return [col_desc[0] for col_desc in result.description] - + @staticmethod + def _query_iterator(cursor, chunksize, columns, index_col=None, + coerce_float=True, parse_dates=None): + """Return generator through chunked result set""" + + while True: + data = cursor.fetchmany(chunksize) + if not data: + cursor.close() + break + else: + yield _wrap_result(data, columns, index_col=index_col, + coerce_float=coerce_float, + parse_dates=parse_dates) + + def read_query(self, sql, index_col=None, coerce_float=True, params=None, + parse_dates=None, chunksize=None): + + args = _convert_params(sql, params) + cursor = self.execute(*args) + columns = [col_desc[0] for col_desc in cursor.description] + + if chunksize is not None: + return self._query_iterator(cursor, chunksize, columns, + index_col=index_col, + coerce_float=coerce_float, + parse_dates=parse_dates) + else: + data = self._fetchall_as_list(cursor) + cursor.close() + + frame = _wrap_result(data, columns, index_col=index_col, + coerce_float=coerce_float, + parse_dates=parse_dates) + return frame + def _fetchall_as_list(self, 
cur): result = cur.fetchall() if not isinstance(result, list): result = list(result) return result - def _close_result(self, cur): - cur.close() + def to_sql(self, frame, name, if_exists='fail', index=True, + index_label=None, schema=None, chunksize=None): + """ + Write records stored in a DataFrame to a SQL database. + + Parameters + ---------- + frame: DataFrame + name: name of SQL table + if_exists: {'fail', 'replace', 'append'}, default 'fail' + fail: If table exists, do nothing. + replace: If table exists, drop it, recreate it, and insert data. + append: If table exists, insert data. Create if does not exist. + index : boolean, default True + Write DataFrame index as a column + index_label : string or sequence, default None + Column label for index column(s). If None is given (default) and + `index` is True, then the index names are used. + A sequence should be given if the DataFrame uses MultiIndex. + schema : string, default None + Ignored parameter included for compatability with SQLAlchemy + version of ``to_sql``. + chunksize : int, default None + If not None, then rows will be written in batches of this + size at a time. If None, all rows will be written at once. 
+ + """ + table = SQLiteTable(name, self, frame=frame, index=index, + if_exists=if_exists, index_label=index_label) + table.create() + table.insert(chunksize) def has_table(self, name, schema=None): flavor_map = { @@ -1329,101 +1471,10 @@ def drop_table(self, name, schema=None): drop_sql = "DROP TABLE %s" % name self.execute(drop_sql) - def sql_schema(self, table): - return str(";\n".join(table.table)) - - def _execute_create(self, table): - with self.run_transaction() as conn: - for stmt in table.table: - conn.execute(stmt) - - def insert_statement(self, table): - names = list(map(str, table.frame.columns)) - flv = self.flavor - br_l = _SQL_SYMB[flv]['br_l'] # left val quote char - br_r = _SQL_SYMB[flv]['br_r'] # right val quote char - wld = _SQL_SYMB[flv]['wld'] # wildcard char - - if table.index is not None: - [names.insert(0, idx) for idx in table.index[::-1]] - - bracketed_names = [br_l + column + br_r for column in names] - col_names = ','.join(bracketed_names) - wildcards = ','.join([wld] * len(names)) - insert_statement = 'INSERT INTO %s (%s) VALUES (%s)' % ( - table.name, col_names, wildcards) - return insert_statement - - def _execute_insert(self, conn, table, keys, data_iter): - data_list = list(data_iter) - conn.executemany(self.insert_statement(table), data_list) - - def _create_table_setup(self, table): - """ - Return a list of SQL statement that create a table reflecting the - structure of a DataFrame. 
The first entry will be a CREATE TABLE - statement while the rest will be CREATE INDEX statements - """ - column_names_and_types = \ - table._get_column_names_and_types(self._sql_type_name) - - pat = re.compile('\s+') - column_names = [col_name for col_name, _, _ in column_names_and_types] - if any(map(pat.search, column_names)): - warnings.warn(_SAFE_NAMES_WARNING) - - flv = self.flavor - - br_l = _SQL_SYMB[flv]['br_l'] # left val quote char - br_r = _SQL_SYMB[flv]['br_r'] # right val quote char - - create_tbl_stmts = [(br_l + '%s' + br_r + ' %s') % (cname, ctype) - for cname, ctype, _ in column_names_and_types] - if table.keys is not None and len(table.keys): - cnames_br = ",".join([br_l + c + br_r for c in table.keys]) - create_tbl_stmts.append( - "CONSTRAINT {tbl}_pk PRIMARY KEY ({cnames_br})".format( - tbl=table.name, cnames_br=cnames_br)) - - create_stmts = ["CREATE TABLE " + table.name + " (\n" + - ',\n '.join(create_tbl_stmts) + "\n)"] - - ix_cols = [cname for cname, _, is_index in column_names_and_types - if is_index] - if len(ix_cols): - cnames = "_".join(ix_cols) - cnames_br = ",".join([br_l + c + br_r for c in ix_cols]) - create_stmts.append( - "CREATE INDEX ix_{tbl}_{cnames} ON {tbl} ({cnames_br})".format( - tbl=table.name, cnames=cnames, cnames_br=cnames_br)) - - return create_stmts - - def _sql_type_name(self, col): - pytype = col.dtype.type - pytype_name = "text" - if issubclass(pytype, np.floating): - pytype_name = "float" - elif com.is_timedelta64_dtype(pytype): - warnings.warn("the 'timedelta' type is not supported, and will be " - "written as integer values (ns frequency) to the " - "database.", UserWarning) - pytype_name = "int" - elif issubclass(pytype, np.integer): - pytype_name = "int" - elif issubclass(pytype, np.datetime64) or pytype is datetime: - # Caution: np.datetime64 is also a subclass of np.number. 
- pytype_name = "datetime" - elif issubclass(pytype, np.bool_): - pytype_name = "bool" - elif issubclass(pytype, np.object): - pytype = lib.infer_dtype(com._ensure_object(col)) - if pytype == "date": - pytype_name = "date" - elif pytype == "time": - pytype_name = "time" - - return _SQL_TYPES[pytype_name][self.flavor] + def _create_sql_schema(self, frame, table_name, keys=None): + table = SQLiteTable(table_name, self, frame=frame, index=False, + keys=keys) + return str(table.sql_schema()) def get_schema(frame, name, flavor='sqlite', keys=None, con=None):