Make SQL_LastInsertId and SQL_GetAffectedRows work on query handles, allowing their use with threaded queries
This commit is contained in:
parent
5b45e29533
commit
e8c141d775
@ -36,9 +36,15 @@
|
||||
#include "PluginSys.h"
|
||||
#include "sm_stringutil.h"
|
||||
|
||||
HandleType_t hQueryType;
|
||||
HandleType_t hStmtType;
|
||||
|
||||
HandleType_t hCombinedQueryType;
|
||||
typedef struct
|
||||
{
|
||||
IQuery *query;
|
||||
IDatabase *db;
|
||||
} CombinedQuery;
|
||||
|
||||
class DatabaseHelpers :
|
||||
public SMGlobalClass,
|
||||
public IHandleTypeDispatch
|
||||
@ -57,22 +63,23 @@ public:
|
||||
g_HandleSys.InitAccessDefaults(&tacc, NULL);
|
||||
tacc.ident = g_pCoreIdent;
|
||||
|
||||
hQueryType = g_HandleSys.CreateType("IQuery", this, 0, &tacc, &acc, g_pCoreIdent, NULL);
|
||||
hStmtType = g_HandleSys.CreateType("IPreparedQuery", this, hQueryType, &tacc, &acc, g_pCoreIdent, NULL);
|
||||
hCombinedQueryType = g_HandleSys.CreateType("IQuery", this, 0, &tacc, &acc, g_pCoreIdent, NULL);
|
||||
hStmtType = g_HandleSys.CreateType("IPreparedQuery", this, 0, &tacc, &acc, g_pCoreIdent, NULL);
|
||||
}
|
||||
|
||||
virtual void OnSourceModShutdown()
|
||||
{
|
||||
g_HandleSys.RemoveType(hStmtType, g_pCoreIdent);
|
||||
g_HandleSys.RemoveType(hQueryType, g_pCoreIdent);
|
||||
g_HandleSys.RemoveType(hCombinedQueryType, g_pCoreIdent);
|
||||
}
|
||||
|
||||
virtual void OnHandleDestroy(HandleType_t type, void *object)
|
||||
{
|
||||
if (type == hQueryType)
|
||||
if (type == hCombinedQueryType)
|
||||
{
|
||||
IQuery *query = (IQuery *)object;
|
||||
query->Destroy();
|
||||
CombinedQuery *combined = (CombinedQuery *)object;
|
||||
combined->query->Destroy();
|
||||
delete combined;
|
||||
} else if (type == hStmtType) {
|
||||
IPreparedQuery *query = (IPreparedQuery *)object;
|
||||
query->Destroy();
|
||||
@ -84,10 +91,37 @@ public:
|
||||
inline HandleError ReadQueryHndl(Handle_t hndl, IPluginContext *pContext, IQuery **query)
|
||||
{
|
||||
HandleSecurity sec;
|
||||
CombinedQuery *c;
|
||||
sec.pOwner = pContext->GetIdentity();
|
||||
sec.pIdentity = g_pCoreIdent;
|
||||
|
||||
return g_HandleSys.ReadHandle(hndl, hQueryType, &sec, (void **)query);
|
||||
HandleError ret;
|
||||
|
||||
if ((ret = g_HandleSys.ReadHandle(hndl, hStmtType, &sec, (void **)query)) != HandleError_None)
|
||||
{
|
||||
ret = g_HandleSys.ReadHandle(hndl, hCombinedQueryType, &sec, (void **)&c);
|
||||
if (ret == HandleError_None)
|
||||
{
|
||||
*query = c->query;
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
inline HandleError ReadQueryAndDbHndl(Handle_t hndl, IPluginContext *pContext, IQuery **query, IDatabase **db)
|
||||
{
|
||||
HandleSecurity sec;
|
||||
CombinedQuery *c;
|
||||
sec.pOwner = pContext->GetIdentity();
|
||||
sec.pIdentity = g_pCoreIdent;
|
||||
|
||||
HandleError ret = g_HandleSys.ReadHandle(hndl, hCombinedQueryType, &sec, (void **)&c);
|
||||
if (ret == HandleError_None)
|
||||
{
|
||||
*query = c->query;
|
||||
*db = c->db;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
inline HandleError ReadStmtHndl(Handle_t hndl, IPluginContext *pContext, IPreparedQuery **query)
|
||||
@ -192,12 +226,17 @@ public:
|
||||
|
||||
if (m_pQuery)
|
||||
{
|
||||
qh = g_HandleSys.CreateHandle(hQueryType, m_pQuery, me->GetIdentity(), g_pCoreIdent, NULL);
|
||||
CombinedQuery *c = new CombinedQuery;
|
||||
c->query = m_pQuery;
|
||||
c->db = m_pDatabase;
|
||||
|
||||
qh = g_HandleSys.CreateHandle(hCombinedQueryType, c, me->GetIdentity(), g_pCoreIdent, NULL);
|
||||
if (qh != BAD_HANDLE)
|
||||
{
|
||||
m_pQuery = NULL;
|
||||
} else {
|
||||
UTIL_Format(error, sizeof(error), "Could not alloc handle");
|
||||
delete c;
|
||||
}
|
||||
}
|
||||
|
||||
@ -537,42 +576,59 @@ static cell_t SQL_GetAffectedRows(IPluginContext *pContext, const cell_t *params
|
||||
{
|
||||
IDatabase *db = NULL;
|
||||
IPreparedQuery *stmt = NULL;
|
||||
IQuery *query = NULL;
|
||||
HandleError err;
|
||||
|
||||
if ((err = ReadDbOrStmtHndl(params[1], pContext, &db, &stmt)) != HandleError_None)
|
||||
if (((err = ReadDbOrStmtHndl(params[1], pContext, &db, &stmt)) != HandleError_None)
|
||||
&& ((err = ReadQueryAndDbHndl(params[1], pContext, &query, &db)) != HandleError_None))
|
||||
{
|
||||
return pContext->ThrowNativeError("Invalid statement or db Handle %x (error: %d)", params[1], err);
|
||||
return pContext->ThrowNativeError("Invalid statement, db, or query Handle %x (error: %d)", params[1], err);
|
||||
}
|
||||
|
||||
if (db)
|
||||
|
||||
if (stmt)
|
||||
{
|
||||
return db->GetAffectedRows();
|
||||
} else if (stmt) {
|
||||
return stmt->GetAffectedRows();
|
||||
}
|
||||
else if (query)
|
||||
{
|
||||
return db->GetAffectedRowsForQuery(query);
|
||||
}
|
||||
else if (db)
|
||||
{
|
||||
return db->GetAffectedRows();
|
||||
}
|
||||
|
||||
return pContext->ThrowNativeError("Unknown error reading db/stmt handles");
|
||||
return pContext->ThrowNativeError("Unknown error reading db/stmt/query handles");
|
||||
}
|
||||
|
||||
static cell_t SQL_GetInsertId(IPluginContext *pContext, const cell_t *params)
|
||||
{
|
||||
IDatabase *db = NULL;
|
||||
IQuery *query = NULL;
|
||||
IPreparedQuery *stmt = NULL;
|
||||
HandleError err;
|
||||
|
||||
if ((err = ReadDbOrStmtHndl(params[1], pContext, &db, &stmt)) != HandleError_None)
|
||||
if (((err = ReadDbOrStmtHndl(params[1], pContext, &db, &stmt)) != HandleError_None)
|
||||
&& ((err = ReadQueryAndDbHndl(params[1], pContext, &query, &db)) != HandleError_None))
|
||||
{
|
||||
return pContext->ThrowNativeError("Invalid statement or db Handle %x (error: %d)", params[1], err);
|
||||
return pContext->ThrowNativeError("Invalid statement, db, or query Handle %x (error: %d)", params[1], err);
|
||||
}
|
||||
|
||||
if (db)
|
||||
if (query)
|
||||
{
|
||||
return db->GetInsertIDForQuery(query);
|
||||
}
|
||||
else if (db)
|
||||
{
|
||||
return db->GetInsertID();
|
||||
} else if (stmt) {
|
||||
}
|
||||
else if (stmt)
|
||||
{
|
||||
return stmt->GetInsertID();
|
||||
}
|
||||
|
||||
return pContext->ThrowNativeError("Unknown error reading db/stmt handles");
|
||||
return pContext->ThrowNativeError("Unknown error reading db/stmt/query handles");
|
||||
}
|
||||
|
||||
static cell_t SQL_GetError(IPluginContext *pContext, const cell_t *params)
|
||||
@ -682,10 +738,14 @@ static cell_t SQL_Query(IPluginContext *pContext, const cell_t *params)
|
||||
return BAD_HANDLE;
|
||||
}
|
||||
|
||||
Handle_t hndl = g_HandleSys.CreateHandle(hQueryType, qr, pContext->GetIdentity(), g_pCoreIdent, NULL);
|
||||
CombinedQuery *c = new CombinedQuery;
|
||||
c->query = qr;
|
||||
c->db = db;
|
||||
Handle_t hndl = g_HandleSys.CreateHandle(hCombinedQueryType, c, pContext->GetIdentity(), g_pCoreIdent, NULL);
|
||||
if (hndl == BAD_HANDLE)
|
||||
{
|
||||
qr->Destroy();
|
||||
delete c;
|
||||
return BAD_HANDLE;
|
||||
}
|
||||
|
||||
|
@ -307,6 +307,8 @@ MyQuery::MyQuery(MyDatabase *db, MYSQL_RES *res)
|
||||
: m_pParent(db), m_rs(res)
|
||||
{
|
||||
m_pParent->IncReferenceCount();
|
||||
m_InsertID = m_pParent->GetInsertID();
|
||||
m_AffectedRows = m_pParent->GetAffectedRows();
|
||||
}
|
||||
|
||||
IResultSet *MyQuery::GetResultSet()
|
||||
@ -319,6 +321,16 @@ IResultSet *MyQuery::GetResultSet()
|
||||
return &m_rs;
|
||||
}
|
||||
|
||||
unsigned int MyQuery::GetInsertID()
|
||||
{
|
||||
return m_InsertID;
|
||||
}
|
||||
|
||||
unsigned int MyQuery::GetAffectedRows()
|
||||
{
|
||||
return m_AffectedRows;
|
||||
}
|
||||
|
||||
bool MyQuery::FetchMoreResults()
|
||||
{
|
||||
if (m_rs.m_pRes == NULL)
|
||||
|
@ -87,9 +87,14 @@ public:
|
||||
IResultSet *GetResultSet();
|
||||
bool FetchMoreResults();
|
||||
void Destroy();
|
||||
public: // Used by the driver to implement GetInsertIDForQuery()/GetAffectedRowsForQuery()
|
||||
unsigned int GetInsertID();
|
||||
unsigned int GetAffectedRows();
|
||||
private:
|
||||
MyDatabase *m_pParent;
|
||||
MyBasicResults m_rs;
|
||||
unsigned int m_InsertID;
|
||||
unsigned int m_AffectedRows;
|
||||
};
|
||||
|
||||
#endif //_INCLUDE_SM_MYSQL_BASIC_RESULTS_H_
|
||||
|
@ -257,6 +257,16 @@ IQuery *MyDatabase::DoQueryEx(const char *query, size_t len)
|
||||
return new MyQuery(this, res);
|
||||
}
|
||||
|
||||
unsigned int MyDatabase::GetAffectedRowsForQuery(IQuery *query)
|
||||
{
|
||||
return static_cast<MyQuery*>(query)->GetAffectedRows();
|
||||
}
|
||||
|
||||
unsigned int MyDatabase::GetInsertIDForQuery(IQuery *query)
|
||||
{
|
||||
return static_cast<MyQuery*>(query)->GetInsertID();
|
||||
}
|
||||
|
||||
IPreparedQuery *MyDatabase::PrepareQuery(const char *query, char *error, size_t maxlength, int *errCode)
|
||||
{
|
||||
MYSQL_STMT *stmt = mysql_stmt_init(m_mysql);
|
||||
|
@ -60,6 +60,8 @@ public: //IDatabase
|
||||
IDBDriver *GetDriver();
|
||||
bool DoSimpleQueryEx(const char *query, size_t len);
|
||||
IQuery *DoQueryEx(const char *query, size_t len);
|
||||
unsigned int GetAffectedRowsForQuery(IQuery *query);
|
||||
unsigned int GetInsertIDForQuery(IQuery *query);
|
||||
public:
|
||||
const DatabaseInfo &GetInfo();
|
||||
private:
|
||||
|
@ -186,6 +186,15 @@ IQuery *SqDatabase::DoQueryEx(const char *query, size_t len)
|
||||
return pQuery;
|
||||
}
|
||||
|
||||
unsigned int SqDatabase::GetAffectedRowsForQuery(IQuery *query)
|
||||
{
|
||||
return static_cast<SqQuery*>(query)->GetAffectedRows();
|
||||
}
|
||||
unsigned int SqDatabase::GetInsertIDForQuery(IQuery *query)
|
||||
{
|
||||
return static_cast<SqQuery*>(query)->GetInsertID();
|
||||
}
|
||||
|
||||
IPreparedQuery *SqDatabase::PrepareQuery(const char *query,
|
||||
char *error,
|
||||
size_t maxlength,
|
||||
|
@ -56,6 +56,8 @@ public:
|
||||
IDBDriver *GetDriver();
|
||||
bool DoSimpleQueryEx(const char *query, size_t len);
|
||||
IQuery *DoQueryEx(const char *query, size_t len);
|
||||
unsigned int GetAffectedRowsForQuery(IQuery *query);
|
||||
unsigned int GetInsertIDForQuery(IQuery *query);
|
||||
public:
|
||||
sqlite3 *GetDb();
|
||||
private:
|
||||
|
@ -261,20 +261,20 @@ native SQL_GetAffectedRows(Handle:hndl);
|
||||
/**
|
||||
* Returns the last query's insertion id.
|
||||
*
|
||||
* @param hndl A database OR statement Handle.
|
||||
* @param hndl A database, query, OR statement Handle.
|
||||
* @return Last query's insertion id.
|
||||
* @error Invalid database or statement Handle.
|
||||
* @error Invalid database, query, or statement Handle.
|
||||
*/
|
||||
native SQL_GetInsertId(Handle:hndl);
|
||||
|
||||
/**
|
||||
* Returns the error reported by the last query.
|
||||
*
|
||||
* @param hndl A database OR statement Handle.
|
||||
* @param hndl A database, query, OR statement Handle.
|
||||
* @param error Error buffer.
|
||||
* @param maxlength Maximum length of the buffer.
|
||||
* @return True if there was an error, false otherwise.
|
||||
* @error Invalid database or statement Handle.
|
||||
* @error Invalid database, query, or statement Handle.
|
||||
*/
|
||||
native bool:SQL_GetError(Handle:hndl, String:error[], maxlength);
|
||||
|
||||
|
@ -42,7 +42,7 @@
|
||||
*/
|
||||
|
||||
#define SMINTERFACE_DBI_NAME "IDBI"
|
||||
#define SMINTERFACE_DBI_VERSION 7
|
||||
#define SMINTERFACE_DBI_VERSION 8
|
||||
|
||||
namespace SourceMod
|
||||
{
|
||||
@ -573,6 +573,33 @@ namespace SourceMod
|
||||
* @return IQuery pointer on success, NULL otherwise.
|
||||
*/
|
||||
virtual IQuery *DoQueryEx(const char *query, size_t len) =0;
|
||||
|
||||
/**
|
||||
* @brief Retrieves the number of affected rows from the last execute of
|
||||
* the given query
|
||||
*
|
||||
* Note: This can only accept queries from this driver.
|
||||
*
|
||||
* This function is not thread safe and must be included in any locks.
|
||||
*
|
||||
* @param query IQuery object from this driver
|
||||
* @return Rows affected from last execution of this query,
|
||||
* if applicable.
|
||||
*/
|
||||
virtual unsigned int GetAffectedRowsForQuery(IQuery *query) =0;
|
||||
|
||||
/**
|
||||
* @brief Retrieves the last insert id of the given query
|
||||
*
|
||||
* Note: This can only accept queries from this driver.
|
||||
*
|
||||
* This function is not thread safe and must be included in any locks.
|
||||
*
|
||||
* @param query IQuery object from this driver
|
||||
* @return Insert Id from the last execution of this query,
|
||||
* if applicable.
|
||||
*/
|
||||
virtual unsigned int GetInsertIDForQuery(IQuery *query) =0;
|
||||
};
|
||||
|
||||
/**
|
||||
|
Loading…
Reference in New Issue
Block a user