Make SQL_LastInsertId and SQL_GetAffectedRows work on query handles, allowing their use with threaded queries

This commit is contained in:
John Schoenick 2011-05-14 20:21:37 -07:00
parent 5b45e29533
commit e8c141d775
9 changed files with 153 additions and 26 deletions

View File

@ -36,9 +36,15 @@
#include "PluginSys.h" #include "PluginSys.h"
#include "sm_stringutil.h" #include "sm_stringutil.h"
HandleType_t hQueryType;
HandleType_t hStmtType; HandleType_t hStmtType;
HandleType_t hCombinedQueryType;
typedef struct
{
IQuery *query;
IDatabase *db;
} CombinedQuery;
class DatabaseHelpers : class DatabaseHelpers :
public SMGlobalClass, public SMGlobalClass,
public IHandleTypeDispatch public IHandleTypeDispatch
@ -57,22 +63,23 @@ public:
g_HandleSys.InitAccessDefaults(&tacc, NULL); g_HandleSys.InitAccessDefaults(&tacc, NULL);
tacc.ident = g_pCoreIdent; tacc.ident = g_pCoreIdent;
hQueryType = g_HandleSys.CreateType("IQuery", this, 0, &tacc, &acc, g_pCoreIdent, NULL); hCombinedQueryType = g_HandleSys.CreateType("IQuery", this, 0, &tacc, &acc, g_pCoreIdent, NULL);
hStmtType = g_HandleSys.CreateType("IPreparedQuery", this, hQueryType, &tacc, &acc, g_pCoreIdent, NULL); hStmtType = g_HandleSys.CreateType("IPreparedQuery", this, 0, &tacc, &acc, g_pCoreIdent, NULL);
} }
virtual void OnSourceModShutdown() virtual void OnSourceModShutdown()
{ {
g_HandleSys.RemoveType(hStmtType, g_pCoreIdent); g_HandleSys.RemoveType(hStmtType, g_pCoreIdent);
g_HandleSys.RemoveType(hQueryType, g_pCoreIdent); g_HandleSys.RemoveType(hCombinedQueryType, g_pCoreIdent);
} }
virtual void OnHandleDestroy(HandleType_t type, void *object) virtual void OnHandleDestroy(HandleType_t type, void *object)
{ {
if (type == hQueryType) if (type == hCombinedQueryType)
{ {
IQuery *query = (IQuery *)object; CombinedQuery *combined = (CombinedQuery *)object;
query->Destroy(); combined->query->Destroy();
delete combined;
} else if (type == hStmtType) { } else if (type == hStmtType) {
IPreparedQuery *query = (IPreparedQuery *)object; IPreparedQuery *query = (IPreparedQuery *)object;
query->Destroy(); query->Destroy();
@ -84,10 +91,37 @@ public:
inline HandleError ReadQueryHndl(Handle_t hndl, IPluginContext *pContext, IQuery **query) inline HandleError ReadQueryHndl(Handle_t hndl, IPluginContext *pContext, IQuery **query)
{ {
HandleSecurity sec; HandleSecurity sec;
CombinedQuery *c;
sec.pOwner = pContext->GetIdentity(); sec.pOwner = pContext->GetIdentity();
sec.pIdentity = g_pCoreIdent; sec.pIdentity = g_pCoreIdent;
return g_HandleSys.ReadHandle(hndl, hQueryType, &sec, (void **)query); HandleError ret;
if ((ret = g_HandleSys.ReadHandle(hndl, hStmtType, &sec, (void **)query)) != HandleError_None)
{
ret = g_HandleSys.ReadHandle(hndl, hCombinedQueryType, &sec, (void **)&c);
if (ret == HandleError_None)
{
*query = c->query;
}
}
return ret;
}
inline HandleError ReadQueryAndDbHndl(Handle_t hndl, IPluginContext *pContext, IQuery **query, IDatabase **db)
{
HandleSecurity sec;
CombinedQuery *c;
sec.pOwner = pContext->GetIdentity();
sec.pIdentity = g_pCoreIdent;
HandleError ret = g_HandleSys.ReadHandle(hndl, hCombinedQueryType, &sec, (void **)&c);
if (ret == HandleError_None)
{
*query = c->query;
*db = c->db;
}
return ret;
} }
inline HandleError ReadStmtHndl(Handle_t hndl, IPluginContext *pContext, IPreparedQuery **query) inline HandleError ReadStmtHndl(Handle_t hndl, IPluginContext *pContext, IPreparedQuery **query)
@ -192,12 +226,17 @@ public:
if (m_pQuery) if (m_pQuery)
{ {
qh = g_HandleSys.CreateHandle(hQueryType, m_pQuery, me->GetIdentity(), g_pCoreIdent, NULL); CombinedQuery *c = new CombinedQuery;
c->query = m_pQuery;
c->db = m_pDatabase;
qh = g_HandleSys.CreateHandle(hCombinedQueryType, c, me->GetIdentity(), g_pCoreIdent, NULL);
if (qh != BAD_HANDLE) if (qh != BAD_HANDLE)
{ {
m_pQuery = NULL; m_pQuery = NULL;
} else { } else {
UTIL_Format(error, sizeof(error), "Could not alloc handle"); UTIL_Format(error, sizeof(error), "Could not alloc handle");
delete c;
} }
} }
@ -537,42 +576,59 @@ static cell_t SQL_GetAffectedRows(IPluginContext *pContext, const cell_t *params
{ {
IDatabase *db = NULL; IDatabase *db = NULL;
IPreparedQuery *stmt = NULL; IPreparedQuery *stmt = NULL;
IQuery *query = NULL;
HandleError err; HandleError err;
if ((err = ReadDbOrStmtHndl(params[1], pContext, &db, &stmt)) != HandleError_None) if (((err = ReadDbOrStmtHndl(params[1], pContext, &db, &stmt)) != HandleError_None)
&& ((err = ReadQueryAndDbHndl(params[1], pContext, &query, &db)) != HandleError_None))
{ {
return pContext->ThrowNativeError("Invalid statement or db Handle %x (error: %d)", params[1], err); return pContext->ThrowNativeError("Invalid statement, db, or query Handle %x (error: %d)", params[1], err);
} }
if (db)
if (stmt)
{ {
return db->GetAffectedRows();
} else if (stmt) {
return stmt->GetAffectedRows(); return stmt->GetAffectedRows();
} }
else if (query)
{
return db->GetAffectedRowsForQuery(query);
}
else if (db)
{
return db->GetAffectedRows();
}
return pContext->ThrowNativeError("Unknown error reading db/stmt handles"); return pContext->ThrowNativeError("Unknown error reading db/stmt/query handles");
} }
static cell_t SQL_GetInsertId(IPluginContext *pContext, const cell_t *params) static cell_t SQL_GetInsertId(IPluginContext *pContext, const cell_t *params)
{ {
IDatabase *db = NULL; IDatabase *db = NULL;
IQuery *query = NULL;
IPreparedQuery *stmt = NULL; IPreparedQuery *stmt = NULL;
HandleError err; HandleError err;
if ((err = ReadDbOrStmtHndl(params[1], pContext, &db, &stmt)) != HandleError_None) if (((err = ReadDbOrStmtHndl(params[1], pContext, &db, &stmt)) != HandleError_None)
&& ((err = ReadQueryAndDbHndl(params[1], pContext, &query, &db)) != HandleError_None))
{ {
return pContext->ThrowNativeError("Invalid statement or db Handle %x (error: %d)", params[1], err); return pContext->ThrowNativeError("Invalid statement, db, or query Handle %x (error: %d)", params[1], err);
} }
if (db) if (query)
{
return db->GetInsertIDForQuery(query);
}
else if (db)
{ {
return db->GetInsertID(); return db->GetInsertID();
} else if (stmt) { }
else if (stmt)
{
return stmt->GetInsertID(); return stmt->GetInsertID();
} }
return pContext->ThrowNativeError("Unknown error reading db/stmt handles"); return pContext->ThrowNativeError("Unknown error reading db/stmt/query handles");
} }
static cell_t SQL_GetError(IPluginContext *pContext, const cell_t *params) static cell_t SQL_GetError(IPluginContext *pContext, const cell_t *params)
@ -682,10 +738,14 @@ static cell_t SQL_Query(IPluginContext *pContext, const cell_t *params)
return BAD_HANDLE; return BAD_HANDLE;
} }
Handle_t hndl = g_HandleSys.CreateHandle(hQueryType, qr, pContext->GetIdentity(), g_pCoreIdent, NULL); CombinedQuery *c = new CombinedQuery;
c->query = qr;
c->db = db;
Handle_t hndl = g_HandleSys.CreateHandle(hCombinedQueryType, c, pContext->GetIdentity(), g_pCoreIdent, NULL);
if (hndl == BAD_HANDLE) if (hndl == BAD_HANDLE)
{ {
qr->Destroy(); qr->Destroy();
delete c;
return BAD_HANDLE; return BAD_HANDLE;
} }

View File

@ -307,6 +307,8 @@ MyQuery::MyQuery(MyDatabase *db, MYSQL_RES *res)
: m_pParent(db), m_rs(res) : m_pParent(db), m_rs(res)
{ {
m_pParent->IncReferenceCount(); m_pParent->IncReferenceCount();
m_InsertID = m_pParent->GetInsertID();
m_AffectedRows = m_pParent->GetAffectedRows();
} }
IResultSet *MyQuery::GetResultSet() IResultSet *MyQuery::GetResultSet()
@ -319,6 +321,16 @@ IResultSet *MyQuery::GetResultSet()
return &m_rs; return &m_rs;
} }
unsigned int MyQuery::GetInsertID()
{
return m_InsertID;
}
unsigned int MyQuery::GetAffectedRows()
{
return m_AffectedRows;
}
bool MyQuery::FetchMoreResults() bool MyQuery::FetchMoreResults()
{ {
if (m_rs.m_pRes == NULL) if (m_rs.m_pRes == NULL)

View File

@ -87,9 +87,14 @@ public:
IResultSet *GetResultSet(); IResultSet *GetResultSet();
bool FetchMoreResults(); bool FetchMoreResults();
void Destroy(); void Destroy();
public: // Used by the driver to implement GetInsertIDForQuery()/GetAffectedRowsForQuery()
unsigned int GetInsertID();
unsigned int GetAffectedRows();
private: private:
MyDatabase *m_pParent; MyDatabase *m_pParent;
MyBasicResults m_rs; MyBasicResults m_rs;
unsigned int m_InsertID;
unsigned int m_AffectedRows;
}; };
#endif //_INCLUDE_SM_MYSQL_BASIC_RESULTS_H_ #endif //_INCLUDE_SM_MYSQL_BASIC_RESULTS_H_

View File

@ -257,6 +257,16 @@ IQuery *MyDatabase::DoQueryEx(const char *query, size_t len)
return new MyQuery(this, res); return new MyQuery(this, res);
} }
unsigned int MyDatabase::GetAffectedRowsForQuery(IQuery *query)
{
return static_cast<MyQuery*>(query)->GetAffectedRows();
}
unsigned int MyDatabase::GetInsertIDForQuery(IQuery *query)
{
return static_cast<MyQuery*>(query)->GetInsertID();
}
IPreparedQuery *MyDatabase::PrepareQuery(const char *query, char *error, size_t maxlength, int *errCode) IPreparedQuery *MyDatabase::PrepareQuery(const char *query, char *error, size_t maxlength, int *errCode)
{ {
MYSQL_STMT *stmt = mysql_stmt_init(m_mysql); MYSQL_STMT *stmt = mysql_stmt_init(m_mysql);

View File

@ -60,6 +60,8 @@ public: //IDatabase
IDBDriver *GetDriver(); IDBDriver *GetDriver();
bool DoSimpleQueryEx(const char *query, size_t len); bool DoSimpleQueryEx(const char *query, size_t len);
IQuery *DoQueryEx(const char *query, size_t len); IQuery *DoQueryEx(const char *query, size_t len);
unsigned int GetAffectedRowsForQuery(IQuery *query);
unsigned int GetInsertIDForQuery(IQuery *query);
public: public:
const DatabaseInfo &GetInfo(); const DatabaseInfo &GetInfo();
private: private:

View File

@ -186,6 +186,15 @@ IQuery *SqDatabase::DoQueryEx(const char *query, size_t len)
return pQuery; return pQuery;
} }
unsigned int SqDatabase::GetAffectedRowsForQuery(IQuery *query)
{
return static_cast<SqQuery*>(query)->GetAffectedRows();
}
unsigned int SqDatabase::GetInsertIDForQuery(IQuery *query)
{
return static_cast<SqQuery*>(query)->GetInsertID();
}
IPreparedQuery *SqDatabase::PrepareQuery(const char *query, IPreparedQuery *SqDatabase::PrepareQuery(const char *query,
char *error, char *error,
size_t maxlength, size_t maxlength,

View File

@ -56,6 +56,8 @@ public:
IDBDriver *GetDriver(); IDBDriver *GetDriver();
bool DoSimpleQueryEx(const char *query, size_t len); bool DoSimpleQueryEx(const char *query, size_t len);
IQuery *DoQueryEx(const char *query, size_t len); IQuery *DoQueryEx(const char *query, size_t len);
unsigned int GetAffectedRowsForQuery(IQuery *query);
unsigned int GetInsertIDForQuery(IQuery *query);
public: public:
sqlite3 *GetDb(); sqlite3 *GetDb();
private: private:

View File

@ -261,20 +261,20 @@ native SQL_GetAffectedRows(Handle:hndl);
/** /**
* Returns the last query's insertion id. * Returns the last query's insertion id.
* *
* @param hndl A database OR statement Handle. * @param hndl A database, query, OR statement Handle.
* @return Last query's insertion id. * @return Last query's insertion id.
* @error Invalid database or statement Handle. * @error Invalid database, query, or statement Handle.
*/ */
native SQL_GetInsertId(Handle:hndl); native SQL_GetInsertId(Handle:hndl);
/** /**
* Returns the error reported by the last query. * Returns the error reported by the last query.
* *
* @param hndl A database OR statement Handle. * @param hndl A database, query, OR statement Handle.
* @param error Error buffer. * @param error Error buffer.
* @param maxlength Maximum length of the buffer. * @param maxlength Maximum length of the buffer.
* @return True if there was an error, false otherwise. * @return True if there was an error, false otherwise.
* @error Invalid database or statement Handle. * @error Invalid database, query, or statement Handle.
*/ */
native bool:SQL_GetError(Handle:hndl, String:error[], maxlength); native bool:SQL_GetError(Handle:hndl, String:error[], maxlength);

View File

@ -42,7 +42,7 @@
*/ */
#define SMINTERFACE_DBI_NAME "IDBI" #define SMINTERFACE_DBI_NAME "IDBI"
#define SMINTERFACE_DBI_VERSION 7 #define SMINTERFACE_DBI_VERSION 8
namespace SourceMod namespace SourceMod
{ {
@ -573,6 +573,33 @@ namespace SourceMod
* @return IQuery pointer on success, NULL otherwise. * @return IQuery pointer on success, NULL otherwise.
*/ */
virtual IQuery *DoQueryEx(const char *query, size_t len) =0; virtual IQuery *DoQueryEx(const char *query, size_t len) =0;
/**
* @brief Retrieves the number of affected rows from the last execute of
* the given query
*
* Note: This can only accept queries from this driver.
*
* This function is not thread safe and must be included in any locks.
*
* @param query IQuery object from this driver
* @return Rows affected from last execution of this query,
* if applicable.
*/
virtual unsigned int GetAffectedRowsForQuery(IQuery *query) =0;
/**
* @brief Retrieves the last insert id of the given query
*
* Note: This can only accept queries from this driver.
*
* This function is not thread safe and must be included in any locks.
*
* @param query IQuery object from this driver
* @return Insert Id from the last execution of this query,
* if applicable.
*/
virtual unsigned int GetInsertIDForQuery(IQuery *query) =0;
}; };
/** /**