diff --git a/src/MySQLdb/_mysql.c b/src/MySQLdb/_mysql.c index 30b111e5..929bab08 100644 --- a/src/MySQLdb/_mysql.c +++ b/src/MySQLdb/_mysql.c @@ -84,6 +84,9 @@ typedef struct { bool open; bool reconnect; PyObject *converter; +#ifndef Py_GIL_DISABLED + PyThread_type_lock lock; +#endif } _mysql_ConnectionObject; #define check_connection(c) \ @@ -109,6 +112,57 @@ typedef struct { extern PyTypeObject _mysql_ResultObject_Type; +#ifdef Py_GIL_DISABLED +#define DECLARE_CONNECTION_GUARD PyCriticalSection _mysql_cs = {0} +#define BEGIN_CONNECTION_LOCK(c) \ + PyCriticalSection_Begin(&_mysql_cs, (PyObject *)(c)) +#define END_CONNECTION_LOCK(c) PyCriticalSection_End(&_mysql_cs) +#else +#define DECLARE_CONNECTION_GUARD +static int +_mysql_ConnectionObject_AllocateLock(_mysql_ConnectionObject *self) +{ + self->lock = PyThread_allocate_lock(); + if (self->lock == NULL) { + PyErr_NoMemory(); + return -1; + } + return 0; +} + +static void +_mysql_ConnectionObject_Lock(_mysql_ConnectionObject *self) +{ + Py_BEGIN_ALLOW_THREADS + PyThread_acquire_lock(self->lock, WAIT_LOCK); + Py_END_ALLOW_THREADS +} + +static void +_mysql_ConnectionObject_Unlock(_mysql_ConnectionObject *self) +{ + PyThread_release_lock(self->lock); +} + +#define BEGIN_CONNECTION_LOCK(c) _mysql_ConnectionObject_Lock(c) +#define END_CONNECTION_LOCK(c) _mysql_ConnectionObject_Unlock(c) +#endif + +#define BEGIN_RESULT_CONNECTION_LOCK(r) \ + BEGIN_CONNECTION_LOCK(result_connection(r)) +#define END_RESULT_CONNECTION_LOCK(r) \ + END_CONNECTION_LOCK(result_connection(r)) +#define BEGIN_CONNECTION_OPERATION(c, on_closed) \ + do { \ + BEGIN_CONNECTION_LOCK(c); \ + if (!(c)->open) { \ + END_CONNECTION_LOCK(c); \ + on_closed; \ + } \ + } while (0) +#define BEGIN_RESULT_OPERATION(r, on_closed) \ + BEGIN_CONNECTION_OPERATION(result_connection(r), on_closed) + PyObject * _mysql_Exception(_mysql_ConnectionObject *c) @@ -266,6 +320,7 @@ _mysql_ResultObject_Initialize( PyObject *conv=NULL; int n, i; MYSQL_FIELD *fields; + DECLARE_CONNECTION_GUARD; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!|iO", kwlist, &_mysql_ConnectionObject_Type, &conn, &use, &conv)) @@ -274,6 +329,12 @@ _mysql_ResultObject_Initialize( self->conn = (PyObject *) conn; Py_INCREF(conn); self->use = use; + BEGIN_CONNECTION_LOCK(conn); + if (!conn->open) { + END_CONNECTION_LOCK(conn); + _mysql_Exception(conn); + return -1; + } Py_BEGIN_ALLOW_THREADS ; if (use) result = mysql_use_result(&(conn->connection)); @@ -284,6 +345,7 @@ _mysql_ResultObject_Initialize( Py_END_ALLOW_THREADS ; self->encoding = _get_encoding(&(conn->connection)); + END_CONNECTION_LOCK(conn); //fprintf(stderr, "encoding=%s\n", self->encoding); if (!result) { if (mysql_errno(&(conn->connection))) { @@ -458,6 +520,9 @@ _mysql_ConnectionObject_Initialize( self->converter = NULL; self->open = false; self->reconnect = false; +#ifndef Py_GIL_DISABLED + self->lock = NULL; +#endif if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|ssssisOiiisssiOsiiissss:connect", @@ -479,6 +544,12 @@ _mysql_ConnectionObject_Initialize( )) return -1; +#ifndef Py_GIL_DISABLED + if (_mysql_ConnectionObject_AllocateLock(self)) { + return -1; + } +#endif + #ifndef HAVE_MYSQL_SERVER_PUBLIC_KEY if (server_public_key_path) { PyErr_SetString(_mysql_NotSupportedError, "server_public_key_path is not supported"); @@ -744,12 +815,14 @@ _mysql_ConnectionObject_close( _mysql_ConnectionObject *self, PyObject *noargs) { - check_connection(self); + DECLARE_CONNECTION_GUARD; + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); Py_BEGIN_ALLOW_THREADS mysql_close(&(self->connection)); Py_END_ALLOW_THREADS self->open = false; _mysql_ConnectionObject_clear(self); + END_CONNECTION_LOCK(self); Py_RETURN_NONE; } @@ -764,8 +837,10 @@ _mysql_ConnectionObject_affected_rows( PyObject *noargs) { my_ulonglong ret; - check_connection(self); + DECLARE_CONNECTION_GUARD; + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); ret = mysql_affected_rows(&(self->connection)); + END_CONNECTION_LOCK(self); if (ret == (my_ulonglong)-1) return PyLong_FromLong(-1); return PyLong_FromUnsignedLongLong(ret); @@ -800,11 +875,17 @@ _mysql_ConnectionObject_dump_debug_info( PyObject *noargs) { int err; - check_connection(self); + DECLARE_CONNECTION_GUARD; + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); Py_BEGIN_ALLOW_THREADS err = mysql_dump_debug_info(&(self->connection)); Py_END_ALLOW_THREADS - if (err) return _mysql_Exception(self); + if (err) { + PyObject *ret = _mysql_Exception(self); + END_CONNECTION_LOCK(self); + return ret; + } + END_CONNECTION_LOCK(self); Py_RETURN_NONE; } @@ -817,12 +898,18 @@ _mysql_ConnectionObject_autocommit( PyObject *args) { int flag, err; + DECLARE_CONNECTION_GUARD; if (!PyArg_ParseTuple(args, "i", &flag)) return NULL; - check_connection(self); + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); Py_BEGIN_ALLOW_THREADS err = mysql_autocommit(&(self->connection), flag); Py_END_ALLOW_THREADS - if (err) return _mysql_Exception(self); + if (err) { + PyObject *ret = _mysql_Exception(self); + END_CONNECTION_LOCK(self); + return ret; + } + END_CONNECTION_LOCK(self); Py_RETURN_NONE; } @@ -834,10 +921,13 @@ _mysql_ConnectionObject_get_autocommit( _mysql_ConnectionObject *self, PyObject *args) { - check_connection(self); + DECLARE_CONNECTION_GUARD; + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); if (self->connection.server_status & SERVER_STATUS_AUTOCOMMIT) { + END_CONNECTION_LOCK(self); Py_RETURN_TRUE; } + END_CONNECTION_LOCK(self); Py_RETURN_FALSE; } @@ -850,11 +940,17 @@ _mysql_ConnectionObject_commit( PyObject *noargs) { int err; - check_connection(self); + DECLARE_CONNECTION_GUARD; + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); Py_BEGIN_ALLOW_THREADS err = mysql_commit(&(self->connection)); Py_END_ALLOW_THREADS - if (err) return _mysql_Exception(self); + if (err) { + PyObject *ret = _mysql_Exception(self); + END_CONNECTION_LOCK(self); + return ret; + } + END_CONNECTION_LOCK(self); Py_RETURN_NONE; } @@ -867,11 +963,17 @@ _mysql_ConnectionObject_rollback( PyObject *noargs) { int err; - check_connection(self); + DECLARE_CONNECTION_GUARD; + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); Py_BEGIN_ALLOW_THREADS err = mysql_rollback(&(self->connection)); Py_END_ALLOW_THREADS - if (err) return _mysql_Exception(self); + if (err) { + PyObject *ret = _mysql_Exception(self); + END_CONNECTION_LOCK(self); + return ret; + } + END_CONNECTION_LOCK(self); Py_RETURN_NONE; } @@ -894,11 +996,17 @@ _mysql_ConnectionObject_next_result( PyObject *noargs) { int err; - check_connection(self); + DECLARE_CONNECTION_GUARD; + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); Py_BEGIN_ALLOW_THREADS err = mysql_next_result(&(self->connection)); Py_END_ALLOW_THREADS - if (err > 0) return _mysql_Exception(self); + if (err > 0) { + PyObject *ret = _mysql_Exception(self); + END_CONNECTION_LOCK(self); + return ret; + } + END_CONNECTION_LOCK(self); return PyLong_FromLong(err); } @@ -915,13 +1023,19 @@ _mysql_ConnectionObject_set_server_option( PyObject *args) { int err, flags=0; + DECLARE_CONNECTION_GUARD; if (!PyArg_ParseTuple(args, "i", &flags)) return NULL; - check_connection(self); + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); Py_BEGIN_ALLOW_THREADS err = mysql_set_server_option(&(self->connection), flags); Py_END_ALLOW_THREADS - if (err) return _mysql_Exception(self); + if (err) { + PyObject *ret = _mysql_Exception(self); + END_CONNECTION_LOCK(self); + return ret; + } + END_CONNECTION_LOCK(self); return PyLong_FromLong(err); } @@ -942,8 +1056,12 @@ _mysql_ConnectionObject_sqlstate( _mysql_ConnectionObject *self, PyObject *noargs) { - check_connection(self); - return PyUnicode_FromString(mysql_sqlstate(&(self->connection))); + PyObject *ret; + DECLARE_CONNECTION_GUARD; + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); + ret = PyUnicode_FromString(mysql_sqlstate(&(self->connection))); + END_CONNECTION_LOCK(self); + return ret; } static char _mysql_ConnectionObject_warning_count__doc__[] = @@ -957,8 +1075,12 @@ _mysql_ConnectionObject_warning_count( _mysql_ConnectionObject *self, PyObject *noargs) { - check_connection(self); - return PyLong_FromLong(mysql_warning_count(&(self->connection))); + unsigned int count; + DECLARE_CONNECTION_GUARD; + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); + count = mysql_warning_count(&(self->connection)); + END_CONNECTION_LOCK(self); + return PyLong_FromLong(count); } static char _mysql_ConnectionObject_errno__doc__[] = @@ -972,8 +1094,12 @@ _mysql_ConnectionObject_errno( _mysql_ConnectionObject *self, PyObject *noargs) { - check_connection(self); - return PyLong_FromLong((long)mysql_errno(&(self->connection))); + unsigned int err; + DECLARE_CONNECTION_GUARD; + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); + err = mysql_errno(&(self->connection)); + END_CONNECTION_LOCK(self); + return PyLong_FromLong((long)err); } static char _mysql_ConnectionObject_error__doc__[] = @@ -987,8 +1113,12 @@ _mysql_ConnectionObject_error( _mysql_ConnectionObject *self, PyObject *noargs) { - check_connection(self); - return PyUnicode_FromString(mysql_error(&(self->connection))); + PyObject *ret; + DECLARE_CONNECTION_GUARD; + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); + ret = PyUnicode_FromString(mysql_error(&(self->connection))); + END_CONNECTION_LOCK(self); + return ret; } static char _mysql_escape_string__doc__[] = @@ -1008,6 +1138,8 @@ _mysql_escape_string( char *in, *out; unsigned long len; Py_ssize_t size; + int use_connection = 0; + DECLARE_CONNECTION_GUARD; if (!PyArg_ParseTuple(args, "s#:escape_string", &in, &size)) return NULL; str = PyBytes_FromStringAndSize((char *) NULL, size*2+1); if (!str) return PyErr_NoMemory(); @@ -1015,7 +1147,11 @@ _mysql_escape_string( if (self && PyModule_Check((PyObject*)self)) self = NULL; - if (self && self->open) { + if (self) { + BEGIN_CONNECTION_LOCK(self); + use_connection = self->open; + } + if (use_connection) { #if MYSQL_VERSION_ID >= 50707 && !defined(MARIADB_BASE_VERSION) && !defined(MARIADB_VERSION_ID) len = mysql_real_escape_string_quote(&(self->connection), out, in, size, '\''); #else @@ -1024,6 +1160,9 @@ _mysql_escape_string( } else { len = mysql_escape_string(out, in, size); } + if (self) { + END_CONNECTION_LOCK(self); + } if (_PyBytes_Resize(&str, len) < 0) return NULL; return (str); } @@ -1044,9 +1183,16 @@ _mysql_string_literal( PyObject *o) { PyObject *s; // input string or bytes. need to decref. + int use_connection = 0; + PyObject *str = NULL; + DECLARE_CONNECTION_GUARD; if (self && PyModule_Check((PyObject*)self)) self = NULL; + if (self) { + BEGIN_CONNECTION_LOCK(self); + use_connection = self->open; + } if (PyBytes_Check(o)) { s = o; @@ -1054,9 +1200,9 @@ _mysql_string_literal( } else { PyObject *t = PyObject_Str(o); - if (!t) return NULL; + if (!t) goto error; - const char *encoding = (self && self->open) ? + const char *encoding = use_connection ? _get_encoding(&self->connection) : utf8; if (encoding == utf8) { s = t; @@ -1064,7 +1210,7 @@ _mysql_string_literal( else { s = PyUnicode_AsEncodedString(t, encoding, "strict"); Py_DECREF(t); - if (!s) return NULL; + if (!s) goto error; } } @@ -1073,6 +1219,10 @@ _mysql_string_literal( Py_ssize_t size; if (PyUnicode_Check(s)) { in = PyUnicode_AsUTF8AndSize(s, &size); + if (!in) { + Py_DECREF(s); + goto error; + } } else { assert(PyBytes_Check(s)); in = PyBytes_AsString(s); @@ -1080,16 +1230,17 @@ _mysql_string_literal( } // Prepare output buffer (str, out) - PyObject *str = PyBytes_FromStringAndSize((char *) NULL, size*2+3); + str = PyBytes_FromStringAndSize((char *) NULL, size*2+3); if (!str) { Py_DECREF(s); - return PyErr_NoMemory(); + PyErr_NoMemory(); + goto error; } char *out = PyBytes_AS_STRING(str); // escape unsigned long len; - if (self && self->open) { + if (use_connection) { #if MYSQL_VERSION_ID >= 50707 && !defined(MARIADB_BASE_VERSION) && !defined(MARIADB_VERSION_ID) len = mysql_real_escape_string_quote(&(self->connection), out+1, in, size, '\''); #else @@ -1100,12 +1251,20 @@ _mysql_string_literal( } Py_DECREF(s); + if (self) { + END_CONNECTION_LOCK(self); + } *out = *(out+len+1) = '\''; if (_PyBytes_Resize(&str, len+2) < 0) { Py_DECREF(str); return NULL; } return str; +error: + if (self) { + END_CONNECTION_LOCK(self); + } + return NULL; } static PyObject * @@ -1142,6 +1301,8 @@ _mysql_escape( PyObject *args) { PyObject *o=NULL, *d=NULL; + PyObject *converter = NULL; + DECLARE_CONNECTION_GUARD; if (!PyArg_ParseTuple(args, "O|O:escape", &o, &d)) return NULL; if (d) { @@ -1152,13 +1313,21 @@ _mysql_escape( } return _escape_item(self, o, d); } else { - if (!self) { + if (!self || PyModule_Check(self)) { PyErr_SetString(PyExc_TypeError, "argument 2 must be a mapping"); return NULL; } - return _escape_item(self, o, - ((_mysql_ConnectionObject *) self)->converter); + BEGIN_CONNECTION_LOCK((_mysql_ConnectionObject *)self); + converter = ((_mysql_ConnectionObject *) self)->converter; + Py_XINCREF(converter); + END_CONNECTION_LOCK((_mysql_ConnectionObject *)self); + if (!converter) { + return _mysql_Exception((_mysql_ConnectionObject *)self); + } + PyObject *ret = _escape_item(self, o, converter); + Py_DECREF(converter); + return ret; } } @@ -1175,8 +1344,9 @@ _mysql_ResultObject_describe( PyObject *d; MYSQL_FIELD *fields; unsigned int i, n; + DECLARE_CONNECTION_GUARD; - check_result_connection(self); + BEGIN_RESULT_OPERATION(self, return _mysql_Exception(result_connection(self))); n = mysql_num_fields(self->result); fields = mysql_fetch_fields(self->result); @@ -1204,8 +1374,10 @@ _mysql_ResultObject_describe( if (!t) goto error; PyTuple_SET_ITEM(d, i, t); } + END_RESULT_CONNECTION_LOCK(self); return d; error: + END_RESULT_CONNECTION_LOCK(self); Py_XDECREF(d); return NULL; } @@ -1222,7 +1394,8 @@ _mysql_ResultObject_field_flags( PyObject *d; MYSQL_FIELD *fields; unsigned int i, n; - check_result_connection(self); + DECLARE_CONNECTION_GUARD; + BEGIN_RESULT_OPERATION(self, return _mysql_Exception(result_connection(self))); n = mysql_num_fields(self->result); fields = mysql_fetch_fields(self->result); if (!(d = PyTuple_New(n))) return NULL; @@ -1231,8 +1404,10 @@ _mysql_ResultObject_field_flags( if (!(f = PyLong_FromLong((long)fields[i].flags))) goto error; PyTuple_SET_ITEM(d, i, f); } + END_RESULT_CONNECTION_LOCK(self); return d; error: + END_RESULT_CONNECTION_LOCK(self); Py_XDECREF(d); return NULL; } @@ -1535,12 +1710,14 @@ _mysql_ResultObject_fetch_row( static char *kwlist[] = {"maxrows", "how", NULL }; int maxrows=1, how=0; PyObject *r=NULL; + DECLARE_CONNECTION_GUARD; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|ii:fetch_row", kwlist, &maxrows, &how)) return NULL; - check_result_connection(self); + BEGIN_RESULT_OPERATION(self, return _mysql_Exception(result_connection(self))); if (how >= (int)(sizeof(row_converters) / sizeof(row_converters[0]))) { + END_RESULT_CONNECTION_LOCK(self); PyErr_SetString(PyExc_ValueError, "how out of range"); return NULL; } @@ -1561,8 +1738,10 @@ _mysql_ResultObject_fetch_row( */ PyObject *t = PyList_AsTuple(r); Py_DECREF(r); + END_RESULT_CONNECTION_LOCK(self); return t; error: + END_RESULT_CONNECTION_LOCK(self); Py_XDECREF(r); return NULL; } @@ -1575,7 +1754,8 @@ _mysql_ResultObject_discard( _mysql_ResultObject *self, PyObject *noargs) { - check_result_connection(self); + DECLARE_CONNECTION_GUARD; + BEGIN_RESULT_OPERATION(self, return _mysql_Exception(result_connection(self))); MYSQL_ROW row; Py_BEGIN_ALLOW_THREADS @@ -1585,8 +1765,11 @@ _mysql_ResultObject_discard( Py_END_ALLOW_THREADS _mysql_ConnectionObject *conn = (_mysql_ConnectionObject *)self->conn; if (mysql_errno(&conn->connection)) { - return _mysql_Exception(conn); + PyObject *ret = _mysql_Exception(conn); + END_RESULT_CONNECTION_LOCK(self); + return ret; } + END_RESULT_CONNECTION_LOCK(self); Py_RETURN_NONE; } @@ -1613,6 +1796,7 @@ _mysql_ConnectionObject_change_user( PyObject *args, PyObject *kwargs) { + DECLARE_CONNECTION_GUARD; char *user, *pwd=NULL, *db=NULL; int r; static char *kwlist[] = { "user", "passwd", "db", NULL } ; @@ -1620,11 +1804,16 @@ _mysql_ConnectionObject_change_user( if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s|ss:change_user", kwlist, &user, &pwd, &db)) return NULL; - check_connection(self); + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); Py_BEGIN_ALLOW_THREADS r = mysql_change_user(&(self->connection), user, pwd, db); Py_END_ALLOW_THREADS - if (r) return _mysql_Exception(self); + if (r) { + PyObject *ret = _mysql_Exception(self); + END_CONNECTION_LOCK(self); + return ret; + } + END_CONNECTION_LOCK(self); Py_RETURN_NONE; } @@ -1638,10 +1827,15 @@ _mysql_ConnectionObject_character_set_name( _mysql_ConnectionObject *self, PyObject *noargs) { + DECLARE_CONNECTION_GUARD; const char *s; - check_connection(self); + PyObject *ret; + + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); s = mysql_character_set_name(&(self->connection)); - return PyUnicode_FromString(s); + ret = PyUnicode_FromString(s); + END_CONNECTION_LOCK(self); + return ret; } static char _mysql_ConnectionObject_set_character_set__doc__[] = @@ -1654,14 +1848,21 @@ _mysql_ConnectionObject_set_character_set( _mysql_ConnectionObject *self, PyObject *args) { + DECLARE_CONNECTION_GUARD; const char *s; int err; + if (!PyArg_ParseTuple(args, "s", &s)) return NULL; - check_connection(self); + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); Py_BEGIN_ALLOW_THREADS err = mysql_set_character_set(&(self->connection), s); Py_END_ALLOW_THREADS - if (err) return _mysql_Exception(self); + if (err) { + PyObject *ret = _mysql_Exception(self); + END_CONNECTION_LOCK(self); + return ret; + } + END_CONNECTION_LOCK(self); Py_RETURN_NONE; } @@ -1692,12 +1893,17 @@ _mysql_ConnectionObject_get_character_set_info( _mysql_ConnectionObject *self, PyObject *noargs) { + DECLARE_CONNECTION_GUARD; PyObject *result; MY_CHARSET_INFO cs; - check_connection(self); + + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); mysql_get_character_set_info(&(self->connection), &cs); - if (!(result = PyDict_New())) return NULL; + if (!(result = PyDict_New())) { + END_CONNECTION_LOCK(self); + return NULL; + } if (cs.csname) PyDict_SetItemString(result, "name", PyUnicode_FromString(cs.csname)); if (cs.name) @@ -1708,6 +1914,7 @@ _mysql_ConnectionObject_get_character_set_info( PyDict_SetItemString(result, "dir", PyUnicode_FromString(cs.dir)); PyDict_SetItemString(result, "mbminlen", PyLong_FromLong(cs.mbminlen)); PyDict_SetItemString(result, "mbmaxlen", PyLong_FromLong(cs.mbmaxlen)); + END_CONNECTION_LOCK(self); return result; } #endif @@ -1726,10 +1933,13 @@ _mysql_ConnectionObject_get_native_connection( _mysql_ConnectionObject *self, PyObject *noargs) { + DECLARE_CONNECTION_GUARD; PyObject *result; - check_connection(self); + + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); result = PyCapsule_New(&(self->connection), "_mysql.connection.native_connection", NULL); + END_CONNECTION_LOCK(self); return result; } @@ -1755,8 +1965,13 @@ _mysql_ConnectionObject_get_host_info( _mysql_ConnectionObject *self, PyObject *noargs) { - check_connection(self); - return PyUnicode_FromString(mysql_get_host_info(&(self->connection))); + DECLARE_CONNECTION_GUARD; + PyObject *ret; + + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); + ret = PyUnicode_FromString(mysql_get_host_info(&(self->connection))); + END_CONNECTION_LOCK(self); + return ret; } static char _mysql_ConnectionObject_get_proto_info__doc__[] = @@ -1769,8 +1984,13 @@ _mysql_ConnectionObject_get_proto_info( _mysql_ConnectionObject *self, PyObject *noargs) { - check_connection(self); - return PyLong_FromLong((long)mysql_get_proto_info(&(self->connection))); + DECLARE_CONNECTION_GUARD; + unsigned int proto; + + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); + proto = mysql_get_proto_info(&(self->connection)); + END_CONNECTION_LOCK(self); + return PyLong_FromLong((long)proto); } static char _mysql_ConnectionObject_get_server_info__doc__[] = @@ -1783,8 +2003,13 @@ _mysql_ConnectionObject_get_server_info( _mysql_ConnectionObject *self, PyObject *noargs) { - check_connection(self); - return PyUnicode_FromString(mysql_get_server_info(&(self->connection))); + DECLARE_CONNECTION_GUARD; + PyObject *ret; + + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); + ret = PyUnicode_FromString(mysql_get_server_info(&(self->connection))); + END_CONNECTION_LOCK(self); + return ret; } static char _mysql_ConnectionObject_info__doc__[] = @@ -1798,10 +2023,18 @@ _mysql_ConnectionObject_info( _mysql_ConnectionObject *self, PyObject *noargs) { + DECLARE_CONNECTION_GUARD; const char *s; - check_connection(self); + PyObject *ret; + + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); s = mysql_info(&(self->connection)); - if (s) return PyUnicode_FromString(s); + if (s) { + ret = PyUnicode_FromString(s); + END_CONNECTION_LOCK(self); + return ret; + } + END_CONNECTION_LOCK(self); Py_RETURN_NONE; } @@ -1831,9 +2064,12 @@ _mysql_ConnectionObject_insert_id( _mysql_ConnectionObject *self, PyObject *noargs) { + DECLARE_CONNECTION_GUARD; my_ulonglong r; - check_connection(self); + + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); r = mysql_insert_id(&(self->connection)); + END_CONNECTION_LOCK(self); return PyLong_FromUnsignedLongLong(r); } @@ -1846,16 +2082,23 @@ _mysql_ConnectionObject_kill( _mysql_ConnectionObject *self, PyObject *args) { + DECLARE_CONNECTION_GUARD; unsigned long pid; int r; char query[50]; + if (!PyArg_ParseTuple(args, "k:kill", &pid)) return NULL; - check_connection(self); + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); snprintf(query, 50, "KILL %lu", pid); Py_BEGIN_ALLOW_THREADS r = mysql_query(&(self->connection), query); Py_END_ALLOW_THREADS - if (r) return _mysql_Exception(self); + if (r) { + PyObject *ret = _mysql_Exception(self); + END_CONNECTION_LOCK(self); + return ret; + } + END_CONNECTION_LOCK(self); Py_RETURN_NONE; } @@ -1870,8 +2113,13 @@ _mysql_ConnectionObject_field_count( _mysql_ConnectionObject *self, PyObject *noargs) { - check_connection(self); - return PyLong_FromLong((long)mysql_field_count(&(self->connection))); + DECLARE_CONNECTION_GUARD; + unsigned int count; + + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); + count = mysql_field_count(&(self->connection)); + END_CONNECTION_LOCK(self); + return PyLong_FromLong((long)count); } static char _mysql_ConnectionObject_fileno__doc__[] = @@ -1884,8 +2132,13 @@ _mysql_ConnectionObject_fileno( _mysql_ConnectionObject *self, PyObject *noargs) { - check_connection(self); - return PyLong_FromLong(self->connection.net.fd); + DECLARE_CONNECTION_GUARD; + int fd; + + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); + fd = self->connection.net.fd; + END_CONNECTION_LOCK(self); + return PyLong_FromLong(fd); } static char _mysql_ResultObject_num_fields__doc__[] = @@ -1896,8 +2149,13 @@ _mysql_ResultObject_num_fields( _mysql_ResultObject *self, PyObject *noargs) { - check_result_connection(self); - return PyLong_FromLong((long)mysql_num_fields(self->result)); + DECLARE_CONNECTION_GUARD; + unsigned int fields; + + BEGIN_RESULT_OPERATION(self, return _mysql_Exception(result_connection(self))); + fields = mysql_num_fields(self->result); + END_RESULT_CONNECTION_LOCK(self); + return PyLong_FromLong((long)fields); } static char _mysql_ResultObject_num_rows__doc__[] = @@ -1911,8 +2169,13 @@ _mysql_ResultObject_num_rows( _mysql_ResultObject *self, PyObject *noargs) { - check_result_connection(self); - return PyLong_FromUnsignedLongLong(mysql_num_rows(self->result)); + DECLARE_CONNECTION_GUARD; + my_ulonglong rows; + + BEGIN_RESULT_OPERATION(self, return _mysql_Exception(result_connection(self))); + rows = mysql_num_rows(self->result); + END_RESULT_CONNECTION_LOCK(self); + return PyLong_FromUnsignedLongLong(rows); } static char _mysql_ConnectionObject_ping__doc__[] = @@ -1939,9 +2202,11 @@ _mysql_ConnectionObject_ping( _mysql_ConnectionObject *self, PyObject *args) { + DECLARE_CONNECTION_GUARD; int reconnect = 0; + if (!PyArg_ParseTuple(args, "|p", &reconnect)) return NULL; - check_connection(self); + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); if (reconnect != (self->reconnect == true)) { // libmysqlclient show warning to stderr when MYSQL_OPT_RECONNECT is used. // so we avoid using it as possible for now. @@ -1956,7 +2221,12 @@ _mysql_ConnectionObject_ping( Py_BEGIN_ALLOW_THREADS r = mysql_ping(&(self->connection)); Py_END_ALLOW_THREADS - if (r) return _mysql_Exception(self); + if (r) { + PyObject *ret = _mysql_Exception(self); + END_CONNECTION_LOCK(self); + return ret; + } + END_CONNECTION_LOCK(self); Py_RETURN_NONE; } @@ -1971,16 +2241,23 @@ _mysql_ConnectionObject_query( _mysql_ConnectionObject *self, PyObject *args) { + DECLARE_CONNECTION_GUARD; char *query; Py_ssize_t len; int r; + if (!PyArg_ParseTuple(args, "s#:query", &query, &len)) return NULL; - check_connection(self); + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); Py_BEGIN_ALLOW_THREADS r = mysql_real_query(&(self->connection), query, len); Py_END_ALLOW_THREADS - if (r) return _mysql_Exception(self); + if (r) { + PyObject *ret = _mysql_Exception(self); + END_CONNECTION_LOCK(self); + return ret; + } + END_CONNECTION_LOCK(self); Py_RETURN_NONE; } @@ -1994,17 +2271,24 @@ _mysql_ConnectionObject_send_query( _mysql_ConnectionObject *self, PyObject *args) { + DECLARE_CONNECTION_GUARD; char *query; Py_ssize_t len; int r; MYSQL *mysql = &(self->connection); + if (!PyArg_ParseTuple(args, "s#:query", &query, &len)) return NULL; - check_connection(self); + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); Py_BEGIN_ALLOW_THREADS r = mysql_send_query(mysql, query, len); Py_END_ALLOW_THREADS - if (r) return _mysql_Exception(self); + if (r) { + PyObject *ret = _mysql_Exception(self); + END_CONNECTION_LOCK(self); + return ret; + } + END_CONNECTION_LOCK(self); Py_RETURN_NONE; } @@ -2017,14 +2301,21 @@ _mysql_ConnectionObject_read_query_result( _mysql_ConnectionObject *self, PyObject *noargs) { + DECLARE_CONNECTION_GUARD; int r; MYSQL *mysql = &(self->connection); - check_connection(self); + + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); Py_BEGIN_ALLOW_THREADS r = (int)mysql_read_query_result(mysql); Py_END_ALLOW_THREADS - if (r) return _mysql_Exception(self); + if (r) { + PyObject *ret = _mysql_Exception(self); + END_CONNECTION_LOCK(self); + return ret; + } + END_CONNECTION_LOCK(self); Py_RETURN_NONE; } @@ -2045,14 +2336,21 @@ _mysql_ConnectionObject_select_db( _mysql_ConnectionObject *self, PyObject *args) { + DECLARE_CONNECTION_GUARD; char *db; int r; + if (!PyArg_ParseTuple(args, "s:select_db", &db)) return NULL; - check_connection(self); + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); Py_BEGIN_ALLOW_THREADS r = mysql_select_db(&(self->connection), db); Py_END_ALLOW_THREADS - if (r) return _mysql_Exception(self); + if (r) { + PyObject *ret = _mysql_Exception(self); + END_CONNECTION_LOCK(self); + return ret; + } + END_CONNECTION_LOCK(self); Py_RETURN_NONE; } @@ -2066,12 +2364,19 @@ _mysql_ConnectionObject_shutdown( _mysql_ConnectionObject *self, PyObject *noargs) { + DECLARE_CONNECTION_GUARD; int r; - check_connection(self); + + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); Py_BEGIN_ALLOW_THREADS r = mysql_query(&(self->connection), "SHUTDOWN"); Py_END_ALLOW_THREADS - if (r) return _mysql_Exception(self); + if (r) { + PyObject *ret = _mysql_Exception(self); + END_CONNECTION_LOCK(self); + return ret; + } + END_CONNECTION_LOCK(self); Py_RETURN_NONE; } @@ -2087,13 +2392,22 @@ _mysql_ConnectionObject_stat( _mysql_ConnectionObject *self, PyObject *noargs) { + DECLARE_CONNECTION_GUARD; const char *s; - check_connection(self); + PyObject *ret; + + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); Py_BEGIN_ALLOW_THREADS s = mysql_stat(&(self->connection)); Py_END_ALLOW_THREADS - if (!s) return _mysql_Exception(self); - return PyUnicode_FromString(s); + if (!s) { + ret = _mysql_Exception(self); + END_CONNECTION_LOCK(self); + return ret; + } + ret = PyUnicode_FromString(s); + END_CONNECTION_LOCK(self); + return ret; } static char _mysql_ConnectionObject_store_result__doc__[] = @@ -2107,11 +2421,14 @@ _mysql_ConnectionObject_store_result( _mysql_ConnectionObject *self, PyObject *noargs) { + DECLARE_CONNECTION_GUARD; PyObject *arglist=NULL, *kwarglist=NULL, *result=NULL; _mysql_ResultObject *r=NULL; - check_connection(self); + + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); arglist = Py_BuildValue("(OiO)", self, 0, self->converter); + END_CONNECTION_LOCK(self); if (!arglist) goto error; kwarglist = PyDict_New(); if (!kwarglist) goto error; @@ -2149,9 +2466,12 @@ _mysql_ConnectionObject_thread_id( _mysql_ConnectionObject *self, PyObject *noargs) { + DECLARE_CONNECTION_GUARD; unsigned long pid; - check_connection(self); + + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); pid = mysql_thread_id(&(self->connection)); + END_CONNECTION_LOCK(self); return PyLong_FromLong((long)pid); } @@ -2166,11 +2486,14 @@ _mysql_ConnectionObject_use_result( _mysql_ConnectionObject *self, PyObject *noargs) { + DECLARE_CONNECTION_GUARD; PyObject *arglist=NULL, *kwarglist=NULL, *result=NULL; _mysql_ResultObject *r=NULL; - check_connection(self); + + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); arglist = Py_BuildValue("(OiO)", self, 1, self->converter); + END_CONNECTION_LOCK(self); if (!arglist) return NULL; kwarglist = PyDict_New(); if (!kwarglist) goto error; @@ -2201,31 +2524,32 @@ _mysql_ConnectionObject_discard_result( _mysql_ConnectionObject *self, PyObject *noargs) { - check_connection(self); + DECLARE_CONNECTION_GUARD; + MYSQL_RES *res; + MYSQL_ROW row; + int err = 0; + + BEGIN_CONNECTION_OPERATION(self, return _mysql_Exception(self)); MYSQL *conn = &(self->connection); Py_BEGIN_ALLOW_THREADS; - MYSQL_RES *res = mysql_use_result(conn); - if (res == NULL) { - Py_BLOCK_THREADS; - if (mysql_errno(conn) != 0) { - // fprintf(stderr, "mysql_use_result failed: %s\n", mysql_error(conn)); - return _mysql_Exception(self); + res = mysql_use_result(conn); + if (res != NULL) { + while (NULL != (row = mysql_fetch_row(res))) { + // do nothing. } - Py_RETURN_NONE; - } - - MYSQL_ROW row; - while (NULL != (row = mysql_fetch_row(res))) { - // do nothing. + mysql_free_result(res); } - mysql_free_result(res); Py_END_ALLOW_THREADS; - if (mysql_errno(conn)) { + err = mysql_errno(conn); + if (err) { // fprintf(stderr, "mysql_free_result failed: %s\n", mysql_error(conn)); - return _mysql_Exception(self); + PyObject *ret = _mysql_Exception(self); + END_CONNECTION_LOCK(self); + return ret; } + END_CONNECTION_LOCK(self); Py_RETURN_NONE; } @@ -2239,6 +2563,12 @@ _mysql_ConnectionObject_dealloc( self->open = false; } Py_CLEAR(self->converter); +#ifndef Py_GIL_DISABLED + if (self->lock != NULL) { + PyThread_free_lock(self->lock); + self->lock = NULL; + } +#endif MyFree(self); } @@ -2262,10 +2592,13 @@ _mysql_ResultObject_data_seek( _mysql_ResultObject *self, PyObject *args) { + DECLARE_CONNECTION_GUARD; unsigned int row; + if (!PyArg_ParseTuple(args, "i:data_seek", &row)) return NULL; - check_result_connection(self); + BEGIN_RESULT_OPERATION(self, return _mysql_Exception(result_connection(self))); mysql_data_seek(self->result, row); + END_RESULT_CONNECTION_LOCK(self); Py_RETURN_NONE; } @@ -2273,8 +2606,16 @@ static void _mysql_ResultObject_dealloc( _mysql_ResultObject *self) { + DECLARE_CONNECTION_GUARD; + PyObject *conn = self->conn; PyObject_GC_UnTrack((PyObject *)self); + if (conn != NULL) { + BEGIN_RESULT_CONNECTION_LOCK(self); + } mysql_free_result(self->result); + if (conn != NULL) { + END_RESULT_CONNECTION_LOCK(self); + } _mysql_ResultObject_clear(self); MyFree(self); } diff --git a/tests/test_connection.py b/tests/test_connection.py index 960de572..005c47f6 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,3 +1,6 @@ +import threading +import time + import pytest from MySQLdb._exceptions import ProgrammingError @@ -24,3 +27,70 @@ def test_multi_statements_false(): cursor.execute("select 17") rows = cursor.fetchall() assert rows == ((17,),) + + +def _assert_thread_id_blocked_during_operation(conn, func, min_wait=0.15): + error = None + done = threading.Event() + thread_id_done = threading.Event() + results = {} + + def run(): + nonlocal error + try: + func() + except Exception as exc: # pragma: no cover - error checked below + error = exc + finally: + done.set() + + def read_thread_id(): + try: + results["thread_id"] = conn.thread_id() + except Exception as exc: # pragma: no cover - assertion checked below + results["error"] = exc + finally: + thread_id_done.set() + + thread = threading.Thread(target=run) + thread.start() + time.sleep(0.05) + assert not done.is_set() + + blocker = threading.Thread(target=read_thread_id) + blocker.start() + + assert not thread_id_done.wait(min_wait) + thread.join() + blocker.join() + assert error is None + assert done.is_set() + assert "error" not in results + assert isinstance(results["thread_id"], int) + + +def test_connection_methods_are_serialized(): + conn = connection_factory() + try: + def run_query(): + conn.query("SELECT SLEEP(0.2)") + result = conn.store_result() + assert result.fetch_row() == ((0,),) + + _assert_thread_id_blocked_during_operation(conn, run_query) + finally: + conn.close() + + +def test_result_methods_share_connection_lock(): + conn = connection_factory() + try: + conn.query("SELECT 1 UNION ALL SELECT SLEEP(0.2)") + result = conn.use_result() + + def fetch_all_rows(): + assert result.fetch_row(maxrows=0) == ((1,), (0,)) + + _assert_thread_id_blocked_during_operation(conn, fetch_all_rows) + finally: + conn.close()