diff --git a/.gitignore b/.gitignore index 42bbfb5d..a5797f8e 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,4 @@ build/ dist/ MySQLdb/release.py .coverage +.idea diff --git a/MySQLdb/_mysql.c b/MySQLdb/_mysql.c index 1556fda3..4f93329d 100644 --- a/MySQLdb/_mysql.c +++ b/MySQLdb/_mysql.c @@ -1841,6 +1841,338 @@ _mysql_ConnectionObject_read_query_result( Py_RETURN_NONE; } +#if MYSQL_VERSION_ID >= 50707 +static char _mysql_ConnectionObject_session_state_changed__doc__[] = +"If the `session_track_state_change` system variable (global or session) \n\ +is `ON`, you can call this method to determine whether the session state \n\ +has changed. To get details of those changes, see the `get_session_*` \n\ +methods. Returns a Boolean.\n\ +"; + +static PyObject * +_mysql_ConnectionObject_session_state_changed( + _mysql_ConnectionObject *self, + PyObject *noargs) +{ + int r; + const char *data; + size_t length; + + Py_BEGIN_ALLOW_THREADS + r = mysql_session_track_get_first(&(self->connection), SESSION_TRACK_STATE_CHANGE, &data, &length); + Py_END_ALLOW_THREADS + + if (r != 0) { + Py_RETURN_FALSE; + } + + if (length != 1) { + PyErr_SetString( + PyExc_ValueError, + "_mysql: SESSION_TRACK_STATE_CHANGE returned value of length different than 1." + ); + return NULL; + } + + if (data[0] == '1') { + Py_RETURN_TRUE; + } + + Py_RETURN_FALSE; +} + +static char _mysql_ConnectionObject_get_session_gtids__doc__[] = +"If the `session_track_gtids` system variable (global or session) is \n\ +set to something other than 'OFF', you can call this method to \n\ +retrieve all GTIDs created by the session. Returns a list of byte strings.\n\ +"; + +static PyObject * +_mysql_ConnectionObject_get_session_gtids( + _mysql_ConnectionObject *self, + PyObject *noargs) +{ + int r; + const char *data; + size_t length; + + Py_BEGIN_ALLOW_THREADS + r = mysql_session_track_get_first(&(self->connection), SESSION_TRACK_GTIDS, &data, &length); + Py_END_ALLOW_THREADS + + PyObject *gtids = PyList_New(0); + + while (r == 0) + { + PyObject *gtid = PyBytes_FromStringAndSize(data, length); + if (gtid == NULL) + { + Py_DECREF(gtids); + return NULL; + } + + PyList_Append(gtids, gtid); + + Py_BEGIN_ALLOW_THREADS + r = mysql_session_track_get_next(&(self->connection), SESSION_TRACK_GTIDS, &data, &length); + Py_END_ALLOW_THREADS + } + + return gtids; +} + +static char _mysql_ConnectionObject_get_session_new_schema__doc__[] = +"If the `session_track_schema` system variable (global or session) is \n\ +set to 'ON', you can call this method to determine if the selected \n\ +schema has changed. Returns the byte string new schema name or None if \n\ +the schema has not changed.\n\ +"; + +static PyObject * +_mysql_ConnectionObject_get_session_new_schema( + _mysql_ConnectionObject *self, + PyObject *noargs) +{ + + int r; + const char *data; + size_t length; + + Py_BEGIN_ALLOW_THREADS + r = mysql_session_track_get_first(&(self->connection), SESSION_TRACK_SCHEMA, &data, &length); + Py_END_ALLOW_THREADS + + if (r != 0) { + Py_RETURN_NONE; + } + + PyObject *schema = PyBytes_FromStringAndSize(data, length); + + return schema; +} + +static PyTypeObject SessionVariableChangeType = {0, 0, 0, 0, 0, 0}; +static PyStructSequence_Field _session_variable_change_type_fields[] = { + {"variable_name", "The name of the variable that changed"}, + {"value", "The new value of the variable"}, + {NULL} +}; +static PyStructSequence_Desc _session_variable_change_type_fields_desc = { + "SessionVariableChange", + "Holds information about a changed session variable discovered through session tracking.", + _session_variable_change_type_fields, + 2 +}; + +static PyObject* SessionVariableChange( + PyObject *variable_name, + PyObject *value) +{ + PyObject *change = PyStructSequence_New(&SessionVariableChangeType); + PyStructSequence_SetItem(change, 0, variable_name); + PyStructSequence_SetItem(change, 1, value); + return change; +} + +static char _mysql_ConnectionObject_get_session_changed_variables__doc__[] = +"If the `session_track_system_variables` system variable (global or session) \n\ +is set to a non-blank string, you can call this method to determine if the \n\ +system variables listed in `session_track_system_variables` have changed. \n\ +Returns a possibly-empty list of `SessionVariableChange` objects.\n\ +"; + +static PyObject * +_mysql_ConnectionObject_get_session_changed_variables( + _mysql_ConnectionObject *self, + PyObject *noargs) +{ + int r; + const char *data; + size_t length; + + PyObject *variables = PyList_New(0); + + Py_BEGIN_ALLOW_THREADS + r = mysql_session_track_get_first(&(self->connection), SESSION_TRACK_SYSTEM_VARIABLES, &data, &length); + Py_END_ALLOW_THREADS + + while (r == 0) + { + PyObject *variable_name = PyBytes_FromStringAndSize(data, length); + if (variable_name == NULL) + { + Py_DECREF(variables); + return NULL; + } + + Py_BEGIN_ALLOW_THREADS + r = mysql_session_track_get_next(&(self->connection), SESSION_TRACK_SYSTEM_VARIABLES, &data, &length); + Py_END_ALLOW_THREADS + + if (r != 0) + { + Py_DECREF(variables); + PyErr_SetString( + PyExc_ValueError, + "_mysql: SESSION_TRACK_SYSTEM_VARIABLES returned a variable name but not a value." + ); + return NULL; + } + + PyObject *variable_value = PyBytes_FromStringAndSize(data, length); + if (variable_value == NULL) + { + Py_DECREF(variables); + return NULL; + } + + PyList_Append(variables, SessionVariableChange(variable_name, variable_value)); + + Py_BEGIN_ALLOW_THREADS + r = mysql_session_track_get_next(&(self->connection), SESSION_TRACK_SYSTEM_VARIABLES, &data, &length); + Py_END_ALLOW_THREADS + } + + return variables; +} + +static PyTypeObject TransactionStateType = {0, 0, 0, 0, 0, 0}; +static PyStructSequence_Field _transaction_state_type_fields[] = { + {"transaction_active", "Whether a transaction is active"}, + {"transaction_explicit", "Whether the transaction is explicit"}, + {"nontransactional_tables_read", "Whether unsafe reads occurred"}, + {"transactional_tables_read", "Whether safe reads occurred"}, + {"unsafe_writes", "Whether unsafe writes occurred"}, + {"safe_writes", "Whether safe writes occurred"}, + {"unsafe_statements", "Whether non-deterministic statements were executed"}, + {"result_sent", "Whether a result set was sent to the client"}, + {"locked_tables", "Whether LOCK TABLES was explicitly used"}, + {NULL} +}; +static PyStructSequence_Desc _transaction_state_type_fields_desc = { + "TransactionState", + "Holds information (Booleans) about the state of the current transaction.", + _transaction_state_type_fields, + 9 +}; + +static PyObject* TransactionState( + PyObject *transaction_active, + PyObject *transaction_explicit, + PyObject *nontransactional_tables_read, + PyObject *transactional_tables_read, + PyObject *unsafe_writes, + PyObject *safe_writes, + PyObject *unsafe_statements, + PyObject *result_sent, + PyObject *locked_tables) +{ + PyObject* state = PyStructSequence_New(&TransactionStateType); + PyStructSequence_SetItem(state, 0, transaction_active); + PyStructSequence_SetItem(state, 1, transaction_explicit); + PyStructSequence_SetItem(state, 2, nontransactional_tables_read); + PyStructSequence_SetItem(state, 3, transactional_tables_read); + PyStructSequence_SetItem(state, 4, unsafe_writes); + PyStructSequence_SetItem(state, 5, safe_writes); + PyStructSequence_SetItem(state, 6, unsafe_statements); + PyStructSequence_SetItem(state, 7, result_sent); + PyStructSequence_SetItem(state, 8, locked_tables); + return state; +} + +static char _mysql_ConnectionObject_get_session_transaction_state__doc__[] = +"If the `session_track_transaction_info` system variable (global or session) \n\ +is set to either `STATE` or `CHARACTERISTICS`, you can call this method to \n\ +determine the current transaction state. Returns a `TransactionState` object \n\ +representing the current transaction state, or None if no transaction state \n\ +is available.\n\ +"; + +static PyObject * +_mysql_ConnectionObject_get_session_transaction_state( + _mysql_ConnectionObject *self, + PyObject *noargs) +{ + int r; + const char *data; + size_t length; + + Py_BEGIN_ALLOW_THREADS + r = mysql_session_track_get_first(&(self->connection), SESSION_TRACK_TRANSACTION_STATE, &data, &length); + Py_END_ALLOW_THREADS + + if (r != 0) + { + Py_RETURN_NONE; + } + + if (length != 8) + { + PyErr_SetString( + PyExc_ValueError, + "_mysql: SESSION_TRACK_TRANSACTION_STATE returned a length other than the expected 8." + ); + return NULL; + } + + PyObject *transaction_active = data[0] == '_' ? Py_False : Py_True; + PyObject *transaction_explicit = data[0] == 'T' ? Py_True : Py_False; + PyObject *nontransactional_tables_read = data[1] == 'r' ? Py_True : Py_False; + PyObject *transactional_tables_read = data[2] == 'R' ? Py_True : Py_False; + PyObject *unsafe_writes = data[3] == 'w' ? Py_True : Py_False; + PyObject *safe_writes = data[4] == 'W' ? Py_True : Py_False; + PyObject *unsafe_statements = data[5] == 's' ? Py_True : Py_False; + PyObject *result_sent = data[6] == 'S' ? Py_True : Py_False; + PyObject *locked_tables = data[7] == 'L' ? Py_True : Py_False; + + PyObject *state = TransactionState( + transaction_active, + transaction_explicit, + nontransactional_tables_read, + transactional_tables_read, + unsafe_writes, + safe_writes, + unsafe_statements, + result_sent, + locked_tables + ); + + return state; +} + +static char _mysql_ConnectionObject_get_session_transaction_characteristics__doc__[] = +"If the `session_track_transaction_info` system variable (global or session) \n\ +is set to `CHARACTERISTICS`, you can call this method to obtain the SQL \n\ +statements necessary to restart a transaction with the same characteristics. \n\ +Returns a byte string of semicolon-separated SQL statements, or None if no \n\ +transaction characteristics are available.\n\ +"; + +static PyObject * +_mysql_ConnectionObject_get_session_transaction_characteristics( + _mysql_ConnectionObject *self, + PyObject *noargs) +{ + int r; + const char *data; + size_t length; + + Py_BEGIN_ALLOW_THREADS + r = mysql_session_track_get_first(&(self->connection), SESSION_TRACK_TRANSACTION_CHARACTERISTICS, &data, &length); + Py_END_ALLOW_THREADS + + if (r != 0 || length == 0) + { + Py_RETURN_NONE; + } + + PyObject *characteristics = PyBytes_FromStringAndSize(data, length); + + return characteristics; +} + +#endif + static char _mysql_ConnectionObject_select_db__doc__[] = "Causes the database specified by db to become the default\n\ (current) database on the connection specified by mysql. In subsequent\n\ @@ -2261,6 +2593,44 @@ static PyMethodDef _mysql_ConnectionObject_methods[] = { METH_NOARGS, _mysql_ConnectionObject_read_query_result__doc__, }, +#if MYSQL_VERSION_ID >= 50707 + { + "session_state_changed", + (PyCFunction)_mysql_ConnectionObject_session_state_changed, + METH_NOARGS, + _mysql_ConnectionObject_session_state_changed__doc__, + }, + { + "get_session_gtids", + (PyCFunction)_mysql_ConnectionObject_get_session_gtids, + METH_NOARGS, + _mysql_ConnectionObject_get_session_gtids__doc__, + }, + { + "get_session_new_schema", + (PyCFunction)_mysql_ConnectionObject_get_session_new_schema, + METH_NOARGS, + _mysql_ConnectionObject_get_session_new_schema__doc__, + }, + { + "get_session_changed_variables", + (PyCFunction)_mysql_ConnectionObject_get_session_changed_variables, + METH_NOARGS, + _mysql_ConnectionObject_get_session_changed_variables__doc__, + }, + { + "get_session_transaction_state", + (PyCFunction)_mysql_ConnectionObject_get_session_transaction_state, + METH_NOARGS, + _mysql_ConnectionObject_get_session_transaction_state__doc__, + }, + { + "get_session_transaction_characteristics", + (PyCFunction)_mysql_ConnectionObject_get_session_transaction_characteristics, + METH_NOARGS, + _mysql_ConnectionObject_get_session_transaction_characteristics__doc__, + }, +#endif { "select_db", (PyCFunction)_mysql_ConnectionObject_select_db, @@ -2678,6 +3048,16 @@ PyInit__mysql(void) module = PyModule_Create(&_mysqlmodule); if (!module) return module; /* this really should never happen */ + if (SessionVariableChangeType.tp_name == 0) + PyStructSequence_InitType(&SessionVariableChangeType, &_session_variable_change_type_fields_desc); + Py_INCREF((PyObject *) &SessionVariableChangeType); + PyModule_AddObject(module, "session_variable_change_type", (PyObject *) &SessionVariableChangeType); + + if (TransactionStateType.tp_name == 0) + PyStructSequence_InitType(&TransactionStateType, &_transaction_state_type_fields_desc); + Py_INCREF((PyObject *) &TransactionStateType); + PyModule_AddObject(module, "transaction_state_type", (PyObject *) &TransactionStateType); + if (!(dict = PyModule_GetDict(module))) goto error; if (PyDict_SetItemString(dict, "version_info", PyRun_String(QUOTE(version_info), Py_eval_input, diff --git a/MySQLdb/connections.py b/MySQLdb/connections.py index 8e226ffe..a10c6a38 100644 --- a/MySQLdb/connections.py +++ b/MySQLdb/connections.py @@ -357,4 +357,7 @@ def show_warnings(self): NotSupportedError = NotSupportedError +SessionVariableChange = _mysql.session_variable_change_type +TransactionState = _mysql.transaction_state_type + # vim: colorcolumn=100 diff --git a/MySQLdb/constants/CLIENT.py b/MySQLdb/constants/CLIENT.py index 35f578cc..a2841e58 100644 --- a/MySQLdb/constants/CLIENT.py +++ b/MySQLdb/constants/CLIENT.py @@ -25,3 +25,4 @@ SECURE_CONNECTION = 32768 MULTI_STATEMENTS = 65536 MULTI_RESULTS = 131072 +SESSION_TRACK = 8388608 # 1 << 23