Make SQL_LastInsertId and SQL_GetAffectedRows work on query handles, allowing their use with threaded queries

This commit is contained in:
John Schoenick 2011-05-14 20:21:37 -07:00
parent 5b45e29533
commit e8c141d775
9 changed files with 153 additions and 26 deletions

View File

@ -36,9 +36,15 @@
#include "PluginSys.h"
#include "sm_stringutil.h"
HandleType_t hQueryType;
HandleType_t hStmtType;
HandleType_t hCombinedQueryType;
typedef struct
{
IQuery *query;
IDatabase *db;
} CombinedQuery;
class DatabaseHelpers :
public SMGlobalClass,
public IHandleTypeDispatch
@ -57,22 +63,23 @@ public:
g_HandleSys.InitAccessDefaults(&tacc, NULL);
tacc.ident = g_pCoreIdent;
hQueryType = g_HandleSys.CreateType("IQuery", this, 0, &tacc, &acc, g_pCoreIdent, NULL);
hStmtType = g_HandleSys.CreateType("IPreparedQuery", this, hQueryType, &tacc, &acc, g_pCoreIdent, NULL);
hCombinedQueryType = g_HandleSys.CreateType("IQuery", this, 0, &tacc, &acc, g_pCoreIdent, NULL);
hStmtType = g_HandleSys.CreateType("IPreparedQuery", this, 0, &tacc, &acc, g_pCoreIdent, NULL);
}
virtual void OnSourceModShutdown()
{
g_HandleSys.RemoveType(hStmtType, g_pCoreIdent);
g_HandleSys.RemoveType(hQueryType, g_pCoreIdent);
g_HandleSys.RemoveType(hCombinedQueryType, g_pCoreIdent);
}
virtual void OnHandleDestroy(HandleType_t type, void *object)
{
if (type == hQueryType)
if (type == hCombinedQueryType)
{
IQuery *query = (IQuery *)object;
query->Destroy();
CombinedQuery *combined = (CombinedQuery *)object;
combined->query->Destroy();
delete combined;
} else if (type == hStmtType) {
IPreparedQuery *query = (IPreparedQuery *)object;
query->Destroy();
@ -84,10 +91,37 @@ public:
inline HandleError ReadQueryHndl(Handle_t hndl, IPluginContext *pContext, IQuery **query)
{
HandleSecurity sec;
CombinedQuery *c;
sec.pOwner = pContext->GetIdentity();
sec.pIdentity = g_pCoreIdent;
return g_HandleSys.ReadHandle(hndl, hQueryType, &sec, (void **)query);
HandleError ret;
if ((ret = g_HandleSys.ReadHandle(hndl, hStmtType, &sec, (void **)query)) != HandleError_None)
{
ret = g_HandleSys.ReadHandle(hndl, hCombinedQueryType, &sec, (void **)&c);
if (ret == HandleError_None)
{
*query = c->query;
}
}
return ret;
}
inline HandleError ReadQueryAndDbHndl(Handle_t hndl, IPluginContext *pContext, IQuery **query, IDatabase **db)
{
HandleSecurity sec;
CombinedQuery *c;
sec.pOwner = pContext->GetIdentity();
sec.pIdentity = g_pCoreIdent;
HandleError ret = g_HandleSys.ReadHandle(hndl, hCombinedQueryType, &sec, (void **)&c);
if (ret == HandleError_None)
{
*query = c->query;
*db = c->db;
}
return ret;
}
inline HandleError ReadStmtHndl(Handle_t hndl, IPluginContext *pContext, IPreparedQuery **query)
@ -192,12 +226,17 @@ public:
if (m_pQuery)
{
qh = g_HandleSys.CreateHandle(hQueryType, m_pQuery, me->GetIdentity(), g_pCoreIdent, NULL);
CombinedQuery *c = new CombinedQuery;
c->query = m_pQuery;
c->db = m_pDatabase;
qh = g_HandleSys.CreateHandle(hCombinedQueryType, c, me->GetIdentity(), g_pCoreIdent, NULL);
if (qh != BAD_HANDLE)
{
m_pQuery = NULL;
} else {
UTIL_Format(error, sizeof(error), "Could not alloc handle");
delete c;
}
}
@ -537,42 +576,59 @@ static cell_t SQL_GetAffectedRows(IPluginContext *pContext, const cell_t *params
{
IDatabase *db = NULL;
IPreparedQuery *stmt = NULL;
IQuery *query = NULL;
HandleError err;
if ((err = ReadDbOrStmtHndl(params[1], pContext, &db, &stmt)) != HandleError_None)
if (((err = ReadDbOrStmtHndl(params[1], pContext, &db, &stmt)) != HandleError_None)
&& ((err = ReadQueryAndDbHndl(params[1], pContext, &query, &db)) != HandleError_None))
{
return pContext->ThrowNativeError("Invalid statement or db Handle %x (error: %d)", params[1], err);
return pContext->ThrowNativeError("Invalid statement, db, or query Handle %x (error: %d)", params[1], err);
}
if (db)
if (stmt)
{
return db->GetAffectedRows();
} else if (stmt) {
return stmt->GetAffectedRows();
}
else if (query)
{
return db->GetAffectedRowsForQuery(query);
}
else if (db)
{
return db->GetAffectedRows();
}
return pContext->ThrowNativeError("Unknown error reading db/stmt handles");
return pContext->ThrowNativeError("Unknown error reading db/stmt/query handles");
}
static cell_t SQL_GetInsertId(IPluginContext *pContext, const cell_t *params)
{
IDatabase *db = NULL;
IQuery *query = NULL;
IPreparedQuery *stmt = NULL;
HandleError err;
if ((err = ReadDbOrStmtHndl(params[1], pContext, &db, &stmt)) != HandleError_None)
if (((err = ReadDbOrStmtHndl(params[1], pContext, &db, &stmt)) != HandleError_None)
&& ((err = ReadQueryAndDbHndl(params[1], pContext, &query, &db)) != HandleError_None))
{
return pContext->ThrowNativeError("Invalid statement or db Handle %x (error: %d)", params[1], err);
return pContext->ThrowNativeError("Invalid statement, db, or query Handle %x (error: %d)", params[1], err);
}
if (db)
if (query)
{
return db->GetInsertIDForQuery(query);
}
else if (db)
{
return db->GetInsertID();
} else if (stmt) {
}
else if (stmt)
{
return stmt->GetInsertID();
}
return pContext->ThrowNativeError("Unknown error reading db/stmt handles");
return pContext->ThrowNativeError("Unknown error reading db/stmt/query handles");
}
static cell_t SQL_GetError(IPluginContext *pContext, const cell_t *params)
@ -682,10 +738,14 @@ static cell_t SQL_Query(IPluginContext *pContext, const cell_t *params)
return BAD_HANDLE;
}
Handle_t hndl = g_HandleSys.CreateHandle(hQueryType, qr, pContext->GetIdentity(), g_pCoreIdent, NULL);
CombinedQuery *c = new CombinedQuery;
c->query = qr;
c->db = db;
Handle_t hndl = g_HandleSys.CreateHandle(hCombinedQueryType, c, pContext->GetIdentity(), g_pCoreIdent, NULL);
if (hndl == BAD_HANDLE)
{
qr->Destroy();
delete c;
return BAD_HANDLE;
}

View File

@ -307,6 +307,8 @@ MyQuery::MyQuery(MyDatabase *db, MYSQL_RES *res)
: m_pParent(db), m_rs(res)
{
m_pParent->IncReferenceCount();
m_InsertID = m_pParent->GetInsertID();
m_AffectedRows = m_pParent->GetAffectedRows();
}
IResultSet *MyQuery::GetResultSet()
@ -319,6 +321,16 @@ IResultSet *MyQuery::GetResultSet()
return &m_rs;
}
unsigned int MyQuery::GetInsertID()
{
return m_InsertID;
}
unsigned int MyQuery::GetAffectedRows()
{
return m_AffectedRows;
}
bool MyQuery::FetchMoreResults()
{
if (m_rs.m_pRes == NULL)

View File

@ -87,9 +87,14 @@ public:
IResultSet *GetResultSet();
bool FetchMoreResults();
void Destroy();
public: // Used by the driver to implement GetInsertIDForQuery()/GetAffectedRowsForQuery()
unsigned int GetInsertID();
unsigned int GetAffectedRows();
private:
MyDatabase *m_pParent;
MyBasicResults m_rs;
unsigned int m_InsertID;
unsigned int m_AffectedRows;
};
#endif //_INCLUDE_SM_MYSQL_BASIC_RESULTS_H_

View File

@ -257,6 +257,16 @@ IQuery *MyDatabase::DoQueryEx(const char *query, size_t len)
return new MyQuery(this, res);
}
unsigned int MyDatabase::GetAffectedRowsForQuery(IQuery *query)
{
return static_cast<MyQuery*>(query)->GetAffectedRows();
}
unsigned int MyDatabase::GetInsertIDForQuery(IQuery *query)
{
return static_cast<MyQuery*>(query)->GetInsertID();
}
IPreparedQuery *MyDatabase::PrepareQuery(const char *query, char *error, size_t maxlength, int *errCode)
{
MYSQL_STMT *stmt = mysql_stmt_init(m_mysql);

View File

@ -60,6 +60,8 @@ public: //IDatabase
IDBDriver *GetDriver();
bool DoSimpleQueryEx(const char *query, size_t len);
IQuery *DoQueryEx(const char *query, size_t len);
unsigned int GetAffectedRowsForQuery(IQuery *query);
unsigned int GetInsertIDForQuery(IQuery *query);
public:
const DatabaseInfo &GetInfo();
private:

View File

@ -186,6 +186,15 @@ IQuery *SqDatabase::DoQueryEx(const char *query, size_t len)
return pQuery;
}
unsigned int SqDatabase::GetAffectedRowsForQuery(IQuery *query)
{
return static_cast<SqQuery*>(query)->GetAffectedRows();
}
unsigned int SqDatabase::GetInsertIDForQuery(IQuery *query)
{
return static_cast<SqQuery*>(query)->GetInsertID();
}
IPreparedQuery *SqDatabase::PrepareQuery(const char *query,
char *error,
size_t maxlength,

View File

@ -56,6 +56,8 @@ public:
IDBDriver *GetDriver();
bool DoSimpleQueryEx(const char *query, size_t len);
IQuery *DoQueryEx(const char *query, size_t len);
unsigned int GetAffectedRowsForQuery(IQuery *query);
unsigned int GetInsertIDForQuery(IQuery *query);
public:
sqlite3 *GetDb();
private:

View File

@ -261,20 +261,20 @@ native SQL_GetAffectedRows(Handle:hndl);
/**
* Returns the last query's insertion id.
*
* @param hndl A database OR statement Handle.
* @param hndl A database, query, OR statement Handle.
* @return Last query's insertion id.
* @error Invalid database or statement Handle.
* @error Invalid database, query, or statement Handle.
*/
native SQL_GetInsertId(Handle:hndl);
/**
* Returns the error reported by the last query.
*
* @param hndl A database OR statement Handle.
* @param hndl A database, query, OR statement Handle.
* @param error Error buffer.
* @param maxlength Maximum length of the buffer.
* @return True if there was an error, false otherwise.
* @error Invalid database or statement Handle.
* @error Invalid database, query, or statement Handle.
*/
native bool:SQL_GetError(Handle:hndl, String:error[], maxlength);

View File

@ -42,7 +42,7 @@
*/
#define SMINTERFACE_DBI_NAME "IDBI"
#define SMINTERFACE_DBI_VERSION 7
#define SMINTERFACE_DBI_VERSION 8
namespace SourceMod
{
@ -573,6 +573,33 @@ namespace SourceMod
* @return IQuery pointer on success, NULL otherwise.
*/
virtual IQuery *DoQueryEx(const char *query, size_t len) =0;
/**
* @brief Retrieves the number of affected rows from the last execute of
* the given query
*
* Note: This can only accept queries from this driver.
*
* This function is not thread safe and must be included in any locks.
*
* @param query IQuery object from this driver
* @return Rows affected from last execution of this query,
* if applicable.
*/
virtual unsigned int GetAffectedRowsForQuery(IQuery *query) =0;
/**
* @brief Retrieves the last insert id of the given query
*
* Note: This can only accept queries from this driver.
*
* This function is not thread safe and must be included in any locks.
*
* @param query IQuery object from this driver
* @return Insert Id from the last execution of this query,
* if applicable.
*/
virtual unsigned int GetInsertIDForQuery(IQuery *query) =0;
};
/**