Initial implementation.

This commit is contained in:
David Anderson 2014-06-04 22:45:26 -07:00
parent 88af898285
commit d7b49f7706
2 changed files with 397 additions and 26 deletions

View File

@ -1,8 +1,8 @@
/** /**
* vim: set ts=4 : * vim: set ts=4 sw=4 tw=99 noet :
* ============================================================================= * =============================================================================
* SourceMod * SourceMod
* Copyright (C) 2004-2008 AlliedModders LLC. All rights reserved. * Copyright (C) 2004-2014 AlliedModders LLC. All rights reserved.
* ============================================================================= * =============================================================================
* *
* This program is free software; you can redistribute it and/or modify it under * This program is free software; you can redistribute it and/or modify it under
@ -34,15 +34,36 @@
#include "ExtensionSys.h" #include "ExtensionSys.h"
#include "stringutil.h" #include "stringutil.h"
#include "ISourceMod.h" #include "ISourceMod.h"
#include "AutoHandleRooter.h"
#include "am-string.h"
#include "am-vector.h"
#include "am-refcounting.h"
HandleType_t hStmtType; HandleType_t hStmtType;
HandleType_t hCombinedQueryType; HandleType_t hCombinedQueryType;
typedef struct HandleType_t hTransactionType;
struct CombinedQuery
{ {
IQuery *query; IQuery *query;
IDatabase *db; IDatabase *db;
} CombinedQuery;
CombinedQuery(IQuery *query, IDatabase *db)
: query(query), db(db)
{
}
};
struct Transaction
{
struct Entry
{
ke::AString query;
cell_t data;
};
ke::Vector<Entry> entries;
};
class DatabaseHelpers : class DatabaseHelpers :
public SMGlobalClass, public SMGlobalClass,
@ -51,23 +72,23 @@ class DatabaseHelpers :
public: public:
virtual void OnSourceModAllInitialized() virtual void OnSourceModAllInitialized()
{ {
HandleAccess acc;
/* Disable cloning */ /* Disable cloning */
HandleAccess acc;
handlesys->InitAccessDefaults(NULL, &acc); handlesys->InitAccessDefaults(NULL, &acc);
acc.access[HandleAccess_Clone] = HANDLE_RESTRICT_OWNER|HANDLE_RESTRICT_IDENTITY; acc.access[HandleAccess_Clone] = HANDLE_RESTRICT_OWNER|HANDLE_RESTRICT_IDENTITY;
TypeAccess tacc; TypeAccess tacc;
handlesys->InitAccessDefaults(&tacc, NULL); handlesys->InitAccessDefaults(&tacc, NULL);
tacc.ident = g_pCoreIdent; tacc.ident = g_pCoreIdent;
hCombinedQueryType = handlesys->CreateType("IQuery", this, 0, &tacc, &acc, g_pCoreIdent, NULL); hCombinedQueryType = handlesys->CreateType("IQuery", this, 0, &tacc, &acc, g_pCoreIdent, NULL);
hStmtType = handlesys->CreateType("IPreparedQuery", this, 0, &tacc, &acc, g_pCoreIdent, NULL); hStmtType = handlesys->CreateType("IPreparedQuery", this, 0, &tacc, &acc, g_pCoreIdent, NULL);
hTransactionType = handlesys->CreateType("Transaction", this, 0, &tacc, &acc, g_pCoreIdent, NULL);
} }
virtual void OnSourceModShutdown() virtual void OnSourceModShutdown()
{ {
handlesys->RemoveType(hTransactionType, g_pCoreIdent);
handlesys->RemoveType(hStmtType, g_pCoreIdent); handlesys->RemoveType(hStmtType, g_pCoreIdent);
handlesys->RemoveType(hCombinedQueryType, g_pCoreIdent); handlesys->RemoveType(hCombinedQueryType, g_pCoreIdent);
} }
@ -82,10 +103,25 @@ public:
} else if (type == hStmtType) { } else if (type == hStmtType) {
IPreparedQuery *query = (IPreparedQuery *)object; IPreparedQuery *query = (IPreparedQuery *)object;
query->Destroy(); query->Destroy();
} else if (type == hTransactionType) {
delete (Transaction *)object;
} }
} }
} s_DatabaseNativeHelpers; } s_DatabaseNativeHelpers;
// Create a handle that can only be closed locally. That's the intent, at
// least. Since its callers pass the plugin's identity, the plugin can just
// close it anyway.
static inline Handle_t CreateLocalHandle(HandleType_t type, void *object, const HandleSecurity *sec)
{
HandleAccess access;
handlesys->InitAccessDefaults(NULL, &access);
access.access[HandleAccess_Delete] = HANDLE_RESTRICT_IDENTITY|HANDLE_RESTRICT_OWNER;
return handlesys->CreateHandleEx(type, object, sec, &access, NULL);
}
//is this safe for stmt handles? i think since it's single inheritance, it always will be. //is this safe for stmt handles? i think since it's single inheritance, it always will be.
inline HandleError ReadQueryHndl(Handle_t hndl, IPluginContext *pContext, IQuery **query) inline HandleError ReadQueryHndl(Handle_t hndl, IPluginContext *pContext, IQuery **query)
{ {
@ -157,18 +193,8 @@ public:
*/ */
m_pDatabase->IncReferenceCount(); m_pDatabase->IncReferenceCount();
/* Now create our own Handle such that it can only be closed by us.
* We allow cloning just in case someone wants to hold onto it.
*/
HandleSecurity sec(me->GetIdentity(), g_pCoreIdent); HandleSecurity sec(me->GetIdentity(), g_pCoreIdent);
HandleAccess access; m_MyHandle = CreateLocalHandle(g_DBMan.GetDatabaseType(), m_pDatabase, &sec);
handlesys->InitAccessDefaults(NULL, &access);
access.access[HandleAccess_Delete] = HANDLE_RESTRICT_IDENTITY|HANDLE_RESTRICT_OWNER;
m_MyHandle = handlesys->CreateHandleEx(g_DBMan.GetDatabaseType(),
db,
&sec,
&access,
NULL);
} }
~TQueryOp() ~TQueryOp()
{ {
@ -225,9 +251,7 @@ public:
if (m_pQuery) if (m_pQuery)
{ {
CombinedQuery *c = new CombinedQuery; CombinedQuery *c = new CombinedQuery(m_pQuery, m_pDatabase);
c->query = m_pQuery;
c->db = m_pDatabase;
qh = handlesys->CreateHandle(hCombinedQueryType, c, me->GetIdentity(), g_pCoreIdent, NULL); qh = handlesys->CreateHandle(hCombinedQueryType, c, me->GetIdentity(), g_pCoreIdent, NULL);
if (qh != BAD_HANDLE) if (qh != BAD_HANDLE)
@ -737,9 +761,7 @@ static cell_t SQL_Query(IPluginContext *pContext, const cell_t *params)
return BAD_HANDLE; return BAD_HANDLE;
} }
CombinedQuery *c = new CombinedQuery; CombinedQuery *c = new CombinedQuery(qr, db);
c->query = qr;
c->db = db;
Handle_t hndl = handlesys->CreateHandle(hCombinedQueryType, c, pContext->GetIdentity(), g_pCoreIdent, NULL); Handle_t hndl = handlesys->CreateHandle(hCombinedQueryType, c, pContext->GetIdentity(), g_pCoreIdent, NULL);
if (hndl == BAD_HANDLE) if (hndl == BAD_HANDLE)
{ {
@ -1416,6 +1438,286 @@ static cell_t SQL_SetCharset(IPluginContext *pContext, const cell_t *params)
return db->SetCharacterSet(characterset); return db->SetCharacterSet(characterset);
} }
static cell_t SQL_CreateTransaction(IPluginContext *pContext, const cell_t *params)
{
Transaction *txn = new Transaction();
Handle_t handle = handlesys->CreateHandle(hTransactionType, txn, pContext->GetIdentity(), g_pCoreIdent, NULL);
if (!handle)
{
delete txn;
return BAD_HANDLE;
}
return handle;
}
static cell_t SQL_AddQuery(IPluginContext *pContext, const cell_t *params)
{
HandleSecurity sec(pContext->GetIdentity(), g_pCoreIdent);
Transaction *txn;
Handle_t handle = params[1];
HandleError err = handlesys->ReadHandle(handle, hTransactionType, &sec, (void **)&txn);
if (err != HandleError_None)
return pContext->ThrowNativeError("Invalid handle %x (error %d)", handle, err);
char *query;
pContext->LocalToString(params[2], &query);
Transaction::Entry entry;
entry.query = query;
entry.data = params[3];
txn->entries.append(ke::Move(entry));
return cell_t(txn->entries.length() - 1);
}
class TTransactOp : public IDBThreadOperation
{
public:
TTransactOp(
IDatabase *db,
Transaction *txn,
Handle_t txnHandle,
IdentityToken_t *ident,
IPluginFunction *onSuccess,
IPluginFunction *onError,
cell_t data)
:
db_(db),
txn_(txn),
ident_(ident),
success_(onSuccess),
failure_(onError),
data_(data),
autoHandle_(txnHandle),
failIndex_(-1)
{
}
IdentityToken_t *GetOwner()
{
return ident_;
}
IDBDriver *GetDriver()
{
return db_->GetDriver();
}
void Destroy()
{
delete this;
}
private:
bool Succeeded() const
{
return error_.length() > 0;
}
void SetDbError()
{
const char *error = db_->GetError();
if (!error || strlen(error) == 0)
error_ = "unknown error";
else
error_ = error;
}
IQuery *Exec(const char *query)
{
IQuery *result = db_->DoQuery(query);
if (!result)
{
SetDbError();
db_->DoSimpleQuery("ROLLBACK");
return NULL;
}
return result;
}
void ExecuteTransaction()
{
if (!db_->DoSimpleQuery("BEGIN"))
{
SetDbError();
return;
}
for (size_t i = 0; i < txn_->entries.length(); i++)
{
Transaction::Entry &entry = txn_->entries[i];
ke::AutoPtr<IQuery> result(db_->DoQuery(entry.query.chars()));
if (!result)
{
failIndex_ = (cell_t)i;
return;
}
results_.append(ke::Move(result));
}
if (!db_->DoSimpleQuery("COMMIT"))
{
SetDbError();
db_->DoSimpleQuery("ROLLBACK");
}
}
public:
void RunThreadPart()
{
db_->LockForFullAtomicOperation();
ExecuteTransaction();
db_->UnlockFromFullAtomicOperation();
}
void CancelThinkPart()
{
if (Succeeded())
error_ = "Driver is unloading";
RunThinkPart();
}
private:
bool CallSuccess()
{
HandleSecurity sec(ident_, g_pCoreIdent);
// Allocate all the handles for calling the success callback.
Handle_t dbh = CreateLocalHandle(g_DBMan.GetDatabaseType(), db_, &sec);
if (dbh == BAD_HANDLE)
{
error_ = "unable to allocate handle";
return false;
}
assert(results_.length() == txn_->entries.length());
ke::AutoArray<cell_t> data(new cell_t[results_.length()]);
ke::AutoArray<cell_t> handles(new cell_t[results_.length()]);
for (size_t i = 0; i < results_.length(); i++)
{
CombinedQuery *obj = new CombinedQuery(results_[i], db_);
Handle_t rh = CreateLocalHandle(hCombinedQueryType, obj, &sec);
if (rh == BAD_HANDLE)
{
delete obj;
for (size_t iter = 0; iter < i; iter++)
handlesys->FreeHandle(handles[iter], &sec);
handlesys->FreeHandle(dbh, &sec);
error_ = "unable to allocate handle";
return false;
}
handles[i] = rh;
data[i] = txn_->entries[i].data;
}
success_->PushCell(dbh);
success_->PushCell(data_);
success_->PushCell(results_.length());
success_->PushArray(handles, results_.length());
success_->PushArray(data, results_.length());
success_->Execute(NULL);
// Cleanup.
for (size_t i = 0; i < results_.length(); i++)
handlesys->FreeHandle(handles[i], &sec);
handlesys->FreeHandle(dbh, &sec);
return true;
}
public:
void RunThinkPart()
{
if (!success_ || !failure_)
return;
if (Succeeded() && success_)
{
if (CallSuccess())
return;
}
if (!Succeeded() && failure_)
{
HandleSecurity sec(ident_, g_pCoreIdent);
ke::AutoArray<cell_t> data(new cell_t[results_.length()]);
for (size_t i = 0; i < txn_->entries.length(); i++)
data[i] = txn_->entries[i].data;
Handle_t dbh = CreateLocalHandle(g_DBMan.GetDatabaseType(), db_, &sec);
failure_->PushCell(dbh);
failure_->PushCell(data_);
failure_->PushCell(results_.length());
failure_->PushString(error_.chars());
failure_->PushCell(failIndex_);
failure_->PushArray(data, txn_->entries.length());
failure_->Execute(NULL);
}
}
private:
ke::Ref<IDatabase> db_;
Transaction *txn_;
IdentityToken_t *ident_;
IPluginFunction *success_;
IPluginFunction *failure_;
cell_t data_;
AutoHandleRooter autoHandle_;
ke::AString error_;
ke::Vector<ke::AutoPtr<IQuery> > results_;
cell_t failIndex_;
};
static cell_t SQL_ExecuteTransaction(IPluginContext *pContext, const cell_t *params)
{
HandleSecurity sec(pContext->GetIdentity(), g_pCoreIdent);
IDatabase *db = NULL;
HandleError err = g_DBMan.ReadHandle(params[1], DBHandle_Database, (void **)&db);
if (err != HandleError_None)
return pContext->ThrowNativeError("Invalid database handle %x (error: %d)", params[1], err);
Transaction *txn;
if ((err = handlesys->ReadHandle(params[2], hTransactionType, &sec, (void **)&txn)) != HandleError_None)
return pContext->ThrowNativeError("Invalid transaction handle %x (error %d)", params[2], err);
if (!db->GetDriver()->IsThreadSafe())
return pContext->ThrowNativeError("Driver \"%s\" is not thread safe!", db->GetDriver()->GetIdentifier());
IPluginFunction *onSuccess = NULL;
IPluginFunction *onError = NULL;
if (params[3] != -1 && ((onSuccess = pContext->GetFunctionById(params[3])) == NULL))
return pContext->ThrowNativeError("Function id %x is invalid", params[3]);
if (params[4] != -1 && ((onError = pContext->GetFunctionById(params[4])) == NULL))
return pContext->ThrowNativeError("Function id %x is invalid", params[4]);
cell_t data = params[5];
PrioQueueLevel priority = PrioQueue_Normal;
if (params[6] == (cell_t)PrioQueue_High)
priority = PrioQueue_High;
else if (params[6] == (cell_t)PrioQueue_Low)
priority = PrioQueue_Low;
TTransactOp *op = new TTransactOp(db, txn, params[2], pContext->GetIdentity(), onSuccess, onError, data);
// The handle has been cloned in |op|. Close the original.
handlesys->FreeHandle(params[2], &sec);
IPlugin *pPlugin = scripts->FindPluginByContext(pContext->GetContext());
if (pPlugin->GetProperty("DisallowDBThreads", NULL) || !g_DBMan.AddToThreadQueue(op, priority))
{
// Do everything right now.
op->RunThreadPart();
op->RunThinkPart();
op->Destroy();
}
return 0;
}
REGISTER_NATIVES(dbNatives) REGISTER_NATIVES(dbNatives)
{ {
{"SQL_BindParamInt", SQL_BindParamInt}, {"SQL_BindParamInt", SQL_BindParamInt},
@ -1457,7 +1759,10 @@ REGISTER_NATIVES(dbNatives)
{"SQL_TQuery", SQL_TQuery}, {"SQL_TQuery", SQL_TQuery},
{"SQL_UnlockDatabase", SQL_UnlockDatabase}, {"SQL_UnlockDatabase", SQL_UnlockDatabase},
{"SQL_ConnectCustom", SQL_ConnectCustom}, {"SQL_ConnectCustom", SQL_ConnectCustom},
{"SQL_SetCharset", SQL_SetCharset}, {"SQL_SetCharset", SQL_SetCharset},
{"SQL_CreateTransaction", SQL_CreateTransaction},
{"SQL_AddQuery", SQL_AddQuery},
{"SQL_ExecuteTransaction", SQL_ExecuteTransaction},
{NULL, NULL}, {NULL, NULL},
}; };

View File

@ -694,3 +694,69 @@ native SQL_TConnect(SQLTCallback:callback, const String:name[]="default", any:da
* @error Invalid database Handle. * @error Invalid database Handle.
*/ */
native SQL_TQuery(Handle:database, SQLTCallback:callback, const String:query[], any:data=0, DBPriority:prio=DBPrio_Normal); native SQL_TQuery(Handle:database, SQLTCallback:callback, const String:query[], any:data=0, DBPriority:prio=DBPrio_Normal);
/**
* Creates a new transaction object. A transaction object is a list of queries
* that can be sent to the database thread and executed as a single transaction.
*
* @return A transaction handle.
*/
native Handle:SQL_CreateTransaction();
/**
* Adds a query to a transaction object.
*
* @param txn A transaction handle.
* @param query Query string.
* @param data Extra data value to pass to the final callback.
* @return The index of the query in the transaction's query list.
* @error Invalid transaction handle.
*/
native Handle:SQL_AddQuery(Handle:txn, const String:query[], any:data=0);
/**
* Callback for a successful transaction.
*
* @param db Database handle.
* @param data Data value passed to SQL_ExecuteTransaction().
* @param numQueries Number of queries executed in the transaction.
* @param results An array of Query handle results, one for each of numQueries. They are closed automatically.
* @param queryData An array of each data value passed to SQL_AddQuery().
* @noreturn
*/
functag public SQLTxnSuccess(Handle:db, any:data, numQueries, Handle:results[], any:queryData[]);
/**
* Callback for a failed transaction.
*
* @param db Database handle.
* @param data Data value passed to SQL_ExecuteTransaction().
* @param numQueries Number of queries executed in the transaction.
* @param error Error message.
* @param failIndex Index of the query that failed, or -1 if something else.
* @param queryData An array of each data value passed to SQL_AddQuery().
* @noreturn
*/
functag public SQLTxnFailure(Handle:db, any:data, numQueries, const String:error[], failIndex, any:queryData[]);
/**
* Sends a transaction to the database thread. The transaction handle is
* automatically closed. When the transaction completes, the optional
* callback is invoked.
*
* @param db A database handle.
* @param txn A transaction handle.
* @param onSuccess An optional callback to receive a successful transaction.
* @param onError An optional callback to receive an error message.
* @param data An optional value to pass to callbacks.
* @param prio Priority queue to use.
* @noreturn
* @error An invalid handle.
*/
native SQL_ExecuteTransaction(
Handle:db,
Handle:txn,
SQLTxnSuccess:onSuccess=SQLTxnSuccess:-1,
SQLTxnFailure:onError=SQLTxnFailure:-1,
any:data=0,
DBPriority:priority=DBPrio_Normal);