diff --git a/core/smn_database.cpp b/core/smn_database.cpp index 3d884aa1..3c048238 100644 --- a/core/smn_database.cpp +++ b/core/smn_database.cpp @@ -36,9 +36,15 @@ #include "PluginSys.h" #include "sm_stringutil.h" -HandleType_t hQueryType; HandleType_t hStmtType; +HandleType_t hCombinedQueryType; +typedef struct +{ + IQuery *query; + IDatabase *db; +} CombinedQuery; + class DatabaseHelpers : public SMGlobalClass, public IHandleTypeDispatch @@ -57,22 +63,23 @@ public: g_HandleSys.InitAccessDefaults(&tacc, NULL); tacc.ident = g_pCoreIdent; - hQueryType = g_HandleSys.CreateType("IQuery", this, 0, &tacc, &acc, g_pCoreIdent, NULL); - hStmtType = g_HandleSys.CreateType("IPreparedQuery", this, hQueryType, &tacc, &acc, g_pCoreIdent, NULL); + hCombinedQueryType = g_HandleSys.CreateType("IQuery", this, 0, &tacc, &acc, g_pCoreIdent, NULL); + hStmtType = g_HandleSys.CreateType("IPreparedQuery", this, 0, &tacc, &acc, g_pCoreIdent, NULL); } virtual void OnSourceModShutdown() { g_HandleSys.RemoveType(hStmtType, g_pCoreIdent); - g_HandleSys.RemoveType(hQueryType, g_pCoreIdent); + g_HandleSys.RemoveType(hCombinedQueryType, g_pCoreIdent); } virtual void OnHandleDestroy(HandleType_t type, void *object) { - if (type == hQueryType) + if (type == hCombinedQueryType) { - IQuery *query = (IQuery *)object; - query->Destroy(); + CombinedQuery *combined = (CombinedQuery *)object; + combined->query->Destroy(); + delete combined; } else if (type == hStmtType) { IPreparedQuery *query = (IPreparedQuery *)object; query->Destroy(); @@ -84,10 +91,37 @@ public: inline HandleError ReadQueryHndl(Handle_t hndl, IPluginContext *pContext, IQuery **query) { HandleSecurity sec; + CombinedQuery *c; sec.pOwner = pContext->GetIdentity(); sec.pIdentity = g_pCoreIdent; - return g_HandleSys.ReadHandle(hndl, hQueryType, &sec, (void **)query); + HandleError ret; + + if ((ret = g_HandleSys.ReadHandle(hndl, hStmtType, &sec, (void **)query)) != HandleError_None) + { + ret = g_HandleSys.ReadHandle(hndl, hCombinedQueryType, &sec, (void **)&c); + if (ret == HandleError_None) + { + *query = c->query; + } + } + return ret; +} + +inline HandleError ReadQueryAndDbHndl(Handle_t hndl, IPluginContext *pContext, IQuery **query, IDatabase **db) +{ + HandleSecurity sec; + CombinedQuery *c; + sec.pOwner = pContext->GetIdentity(); + sec.pIdentity = g_pCoreIdent; + + HandleError ret = g_HandleSys.ReadHandle(hndl, hCombinedQueryType, &sec, (void **)&c); + if (ret == HandleError_None) + { + *query = c->query; + *db = c->db; + } + return ret; } inline HandleError ReadStmtHndl(Handle_t hndl, IPluginContext *pContext, IPreparedQuery **query) @@ -192,12 +226,17 @@ public: if (m_pQuery) { - qh = g_HandleSys.CreateHandle(hQueryType, m_pQuery, me->GetIdentity(), g_pCoreIdent, NULL); + CombinedQuery *c = new CombinedQuery; + c->query = m_pQuery; + c->db = m_pDatabase; + + qh = g_HandleSys.CreateHandle(hCombinedQueryType, c, me->GetIdentity(), g_pCoreIdent, NULL); if (qh != BAD_HANDLE) { m_pQuery = NULL; } else { UTIL_Format(error, sizeof(error), "Could not alloc handle"); + delete c; } } @@ -537,42 +576,59 @@ static cell_t SQL_GetAffectedRows(IPluginContext *pContext, const cell_t *params { IDatabase *db = NULL; IPreparedQuery *stmt = NULL; + IQuery *query = NULL; HandleError err; - if ((err = ReadDbOrStmtHndl(params[1], pContext, &db, &stmt)) != HandleError_None) + if (((err = ReadDbOrStmtHndl(params[1], pContext, &db, &stmt)) != HandleError_None) + && ((err = ReadQueryAndDbHndl(params[1], pContext, &query, &db)) != HandleError_None)) { - return pContext->ThrowNativeError("Invalid statement or db Handle %x (error: %d)", params[1], err); + return pContext->ThrowNativeError("Invalid statement, db, or query Handle %x (error: %d)", params[1], err); } - if (db) + + if (stmt) { - return db->GetAffectedRows(); - } else if (stmt) { return stmt->GetAffectedRows(); } + else if (query) + { + return db->GetAffectedRowsForQuery(query); + } + else if (db) + { + return db->GetAffectedRows(); + } - return pContext->ThrowNativeError("Unknown error reading db/stmt handles"); + return pContext->ThrowNativeError("Unknown error reading db/stmt/query handles"); } static cell_t SQL_GetInsertId(IPluginContext *pContext, const cell_t *params) { IDatabase *db = NULL; + IQuery *query = NULL; IPreparedQuery *stmt = NULL; HandleError err; - if ((err = ReadDbOrStmtHndl(params[1], pContext, &db, &stmt)) != HandleError_None) + if (((err = ReadDbOrStmtHndl(params[1], pContext, &db, &stmt)) != HandleError_None) + && ((err = ReadQueryAndDbHndl(params[1], pContext, &query, &db)) != HandleError_None)) { - return pContext->ThrowNativeError("Invalid statement or db Handle %x (error: %d)", params[1], err); + return pContext->ThrowNativeError("Invalid statement, db, or query Handle %x (error: %d)", params[1], err); } - if (db) + if (query) + { + return db->GetInsertIDForQuery(query); + } + else if (db) { return db->GetInsertID(); - } else if (stmt) { + } + else if (stmt) + { return stmt->GetInsertID(); } - return pContext->ThrowNativeError("Unknown error reading db/stmt handles"); + return pContext->ThrowNativeError("Unknown error reading db/stmt/query handles"); } static cell_t SQL_GetError(IPluginContext *pContext, const cell_t *params) @@ -682,10 +738,14 @@ static cell_t SQL_Query(IPluginContext *pContext, const cell_t *params) return BAD_HANDLE; } - Handle_t hndl = g_HandleSys.CreateHandle(hQueryType, qr, pContext->GetIdentity(), g_pCoreIdent, NULL); + CombinedQuery *c = new CombinedQuery; + c->query = qr; + c->db = db; + Handle_t hndl = g_HandleSys.CreateHandle(hCombinedQueryType, c, pContext->GetIdentity(), g_pCoreIdent, NULL); if (hndl == BAD_HANDLE) { qr->Destroy(); + delete c; return BAD_HANDLE; } diff --git a/extensions/mysql/mysql/MyBasicResults.cpp b/extensions/mysql/mysql/MyBasicResults.cpp index edd03f1b..41e20dc4 100644 --- a/extensions/mysql/mysql/MyBasicResults.cpp +++ b/extensions/mysql/mysql/MyBasicResults.cpp @@ -307,6 +307,8 @@ MyQuery::MyQuery(MyDatabase *db, MYSQL_RES *res) : m_pParent(db), m_rs(res) { m_pParent->IncReferenceCount(); + m_InsertID = m_pParent->GetInsertID(); + m_AffectedRows = m_pParent->GetAffectedRows(); } IResultSet *MyQuery::GetResultSet() @@ -319,6 +321,16 @@ IResultSet *MyQuery::GetResultSet() return &m_rs; } +unsigned int MyQuery::GetInsertID() +{ + return m_InsertID; +} + +unsigned int MyQuery::GetAffectedRows() +{ + return m_AffectedRows; +} + bool MyQuery::FetchMoreResults() { if (m_rs.m_pRes == NULL) diff --git a/extensions/mysql/mysql/MyBasicResults.h b/extensions/mysql/mysql/MyBasicResults.h index a5173ade..ac418095 100644 --- a/extensions/mysql/mysql/MyBasicResults.h +++ b/extensions/mysql/mysql/MyBasicResults.h @@ -87,9 +87,14 @@ public: IResultSet *GetResultSet(); bool FetchMoreResults(); void Destroy(); +public: // Used by the driver to implement GetInsertIDForQuery()/GetAffectedRowsForQuery() + unsigned int GetInsertID(); + unsigned int GetAffectedRows(); private: MyDatabase *m_pParent; MyBasicResults m_rs; + unsigned int m_InsertID; + unsigned int m_AffectedRows; }; #endif //_INCLUDE_SM_MYSQL_BASIC_RESULTS_H_ diff --git a/extensions/mysql/mysql/MyDatabase.cpp b/extensions/mysql/mysql/MyDatabase.cpp index 5b15d89d..e62ccbfd 100644 --- a/extensions/mysql/mysql/MyDatabase.cpp +++ b/extensions/mysql/mysql/MyDatabase.cpp @@ -257,6 +257,16 @@ IQuery *MyDatabase::DoQueryEx(const char *query, size_t len) return new MyQuery(this, res); } +unsigned int MyDatabase::GetAffectedRowsForQuery(IQuery *query) +{ + return static_cast(query)->GetAffectedRows(); +} + +unsigned int MyDatabase::GetInsertIDForQuery(IQuery *query) +{ + return static_cast(query)->GetInsertID(); +} + IPreparedQuery *MyDatabase::PrepareQuery(const char *query, char *error, size_t maxlength, int *errCode) { MYSQL_STMT *stmt = mysql_stmt_init(m_mysql); diff --git a/extensions/mysql/mysql/MyDatabase.h b/extensions/mysql/mysql/MyDatabase.h index 48b71647..12420ab2 100644 --- a/extensions/mysql/mysql/MyDatabase.h +++ b/extensions/mysql/mysql/MyDatabase.h @@ -60,6 +60,8 @@ public: //IDatabase IDBDriver *GetDriver(); bool DoSimpleQueryEx(const char *query, size_t len); IQuery *DoQueryEx(const char *query, size_t len); + unsigned int GetAffectedRowsForQuery(IQuery *query); + unsigned int GetInsertIDForQuery(IQuery *query); public: const DatabaseInfo &GetInfo(); private: diff --git a/extensions/sqlite/driver/SqDatabase.cpp b/extensions/sqlite/driver/SqDatabase.cpp index 07724238..dfe0325f 100644 --- a/extensions/sqlite/driver/SqDatabase.cpp +++ b/extensions/sqlite/driver/SqDatabase.cpp @@ -186,6 +186,15 @@ IQuery *SqDatabase::DoQueryEx(const char *query, size_t len) return pQuery; } +unsigned int SqDatabase::GetAffectedRowsForQuery(IQuery *query) +{ + return static_cast(query)->GetAffectedRows(); +} +unsigned int SqDatabase::GetInsertIDForQuery(IQuery *query) +{ + return static_cast(query)->GetInsertID(); +} + IPreparedQuery *SqDatabase::PrepareQuery(const char *query, char *error, size_t maxlength, diff --git a/extensions/sqlite/driver/SqDatabase.h b/extensions/sqlite/driver/SqDatabase.h index 591bdd9d..ec016077 100644 --- a/extensions/sqlite/driver/SqDatabase.h +++ b/extensions/sqlite/driver/SqDatabase.h @@ -56,6 +56,8 @@ public: IDBDriver *GetDriver(); bool DoSimpleQueryEx(const char *query, size_t len); IQuery *DoQueryEx(const char *query, size_t len); + unsigned int GetAffectedRowsForQuery(IQuery *query); + unsigned int GetInsertIDForQuery(IQuery *query); public: sqlite3 *GetDb(); private: diff --git a/plugins/include/dbi.inc b/plugins/include/dbi.inc index 34e619d5..71a784bd 100644 --- a/plugins/include/dbi.inc +++ b/plugins/include/dbi.inc @@ -261,20 +261,20 @@ native SQL_GetAffectedRows(Handle:hndl); /** * Returns the last query's insertion id. * - * @param hndl A database OR statement Handle. + * @param hndl A database, query, OR statement Handle. * @return Last query's insertion id. - * @error Invalid database or statement Handle. + * @error Invalid database, query, or statement Handle. */ native SQL_GetInsertId(Handle:hndl); /** * Returns the error reported by the last query. * - * @param hndl A database OR statement Handle. + * @param hndl A database, query, OR statement Handle. * @param error Error buffer. * @param maxlength Maximum length of the buffer. * @return True if there was an error, false otherwise. - * @error Invalid database or statement Handle. + * @error Invalid database, query, or statement Handle. */ native bool:SQL_GetError(Handle:hndl, String:error[], maxlength); diff --git a/public/IDBDriver.h b/public/IDBDriver.h index bf02b7f7..adad660a 100644 --- a/public/IDBDriver.h +++ b/public/IDBDriver.h @@ -42,7 +42,7 @@ */ #define SMINTERFACE_DBI_NAME "IDBI" -#define SMINTERFACE_DBI_VERSION 7 +#define SMINTERFACE_DBI_VERSION 8 namespace SourceMod { @@ -573,6 +573,33 @@ namespace SourceMod * @return IQuery pointer on success, NULL otherwise. */ virtual IQuery *DoQueryEx(const char *query, size_t len) =0; + + /** + * @brief Retrieves the number of affected rows from the last execute of + * the given query + * + * Note: This can only accept queries from this driver. + * + * This function is not thread safe and must be included in any locks. + * + * @param query IQuery object from this driver + * @return Rows affected from last execution of this query, + * if applicable. + */ + virtual unsigned int GetAffectedRowsForQuery(IQuery *query) =0; + + /** + * @brief Retrieves the last insert id of the given query + * + * Note: This can only accept queries from this driver. + * + * This function is not thread safe and must be included in any locks. + * + * @param query IQuery object from this driver + * @return Insert Id from the last execution of this query, + * if applicable. + */ + virtual unsigned int GetInsertIDForQuery(IQuery *query) =0; }; /**