Merge pull request #476 from alliedmodders/query-format

Implement an auto-escaping Format native for SQL query construction
This commit is contained in:
Asher Baker 2016-10-03 16:20:01 +01:00 committed by GitHub
commit 47dd2870d9
8 changed files with 208 additions and 92 deletions

View File

@ -32,6 +32,7 @@
#include "common_logic.h" #include "common_logic.h"
#include "Database.h" #include "Database.h"
#include "ExtensionSys.h" #include "ExtensionSys.h"
#include "sprintf.h"
#include "stringutil.h" #include "stringutil.h"
#include "ISourceMod.h" #include "ISourceMod.h"
#include "AutoHandleRooter.h" #include "AutoHandleRooter.h"
@ -732,6 +733,24 @@ static cell_t SQL_QuoteString(IPluginContext *pContext, const cell_t *params)
return s ? 1 : 0; return s ? 1 : 0;
} }
static cell_t SQL_FormatQuery(IPluginContext *pContext, const cell_t *params)
{
IDatabase *db = NULL;
HandleError err;
if ((err = g_DBMan.ReadHandle(params[1], DBHandle_Database, (void **)&db))
!= HandleError_None)
{
return pContext->ThrowNativeError("Invalid database Handle %x (error: %d)", params[1], err);
}
g_FormatEscapeDatabase = db;
cell_t result = InternalFormat(pContext, params, 1);
g_FormatEscapeDatabase = NULL;
return result;
}
static cell_t SQL_FastQuery(IPluginContext *pContext, const cell_t *params) static cell_t SQL_FastQuery(IPluginContext *pContext, const cell_t *params)
{ {
IDatabase *db = NULL; IDatabase *db = NULL;
@ -1814,6 +1833,7 @@ REGISTER_NATIVES(dbNatives)
{"Database.Driver.get", Database_Driver_get}, {"Database.Driver.get", Database_Driver_get},
{"Database.SetCharset", SQL_SetCharset}, {"Database.SetCharset", SQL_SetCharset},
{"Database.Escape", SQL_QuoteString}, {"Database.Escape", SQL_QuoteString},
{"Database.Format", SQL_FormatQuery},
{"Database.IsSameConnection", SQL_IsSameConnection}, {"Database.IsSameConnection", SQL_IsSameConnection},
{"Database.Execute", SQL_ExecuteTransaction}, {"Database.Execute", SQL_ExecuteTransaction},
@ -1827,6 +1847,7 @@ REGISTER_NATIVES(dbNatives)
{"SQL_Connect", SQL_Connect}, {"SQL_Connect", SQL_Connect},
{"SQL_ConnectEx", SQL_ConnectEx}, {"SQL_ConnectEx", SQL_ConnectEx},
{"SQL_EscapeString", SQL_QuoteString}, {"SQL_EscapeString", SQL_QuoteString},
{"SQL_FormatQuery", SQL_FormatQuery},
{"SQL_Execute", SQL_Execute}, {"SQL_Execute", SQL_Execute},
{"SQL_FastQuery", SQL_FastQuery}, {"SQL_FastQuery", SQL_FastQuery},
{"SQL_FetchFloat", SQL_FetchFloat}, {"SQL_FetchFloat", SQL_FetchFloat},

View File

@ -194,80 +194,12 @@ static cell_t sm_formatex(IPluginContext *pCtx, const cell_t *params)
return static_cast<cell_t>(res); return static_cast<cell_t>(res);
} }
class StaticCharBuf
{
char *buffer;
size_t max_size;
public:
StaticCharBuf() : buffer(NULL), max_size(0)
{
}
~StaticCharBuf()
{
free(buffer);
}
char* GetWithSize(size_t len)
{
if (len > max_size)
{
buffer = (char *)realloc(buffer, len);
max_size = len;
}
return buffer;
}
};
static char g_formatbuf[2048];
static StaticCharBuf g_extrabuf;
static cell_t sm_format(IPluginContext *pCtx, const cell_t *params) static cell_t sm_format(IPluginContext *pCtx, const cell_t *params)
{ {
char *buf, *fmt, *destbuf; return InternalFormat(pCtx, params, 0);
cell_t start_addr, end_addr, maxparam;
size_t res, maxlen;
int arg = 4;
bool copy = false;
char *__copy_buf;
pCtx->LocalToString(params[1], &destbuf);
pCtx->LocalToString(params[3], &fmt);
maxlen = static_cast<size_t>(params[2]);
start_addr = params[1];
end_addr = params[1] + maxlen;
maxparam = params[0];
for (cell_t i=3; i<=maxparam; i++)
{
if ((params[i] >= start_addr) && (params[i] <= end_addr))
{
copy = true;
break;
}
}
if (copy)
{
if (maxlen > sizeof(g_formatbuf))
{
__copy_buf = g_extrabuf.GetWithSize(maxlen);
}
else
{
__copy_buf = g_formatbuf;
}
}
buf = (copy) ? __copy_buf : destbuf;
res = atcprintf(buf, maxlen, fmt, pCtx, params, &arg);
if (copy)
{
memcpy(destbuf, __copy_buf, res+1);
}
return static_cast<cell_t>(res);
} }
static char g_vformatbuf[2048];
static cell_t sm_vformat(IPluginContext *pContext, const cell_t *params) static cell_t sm_vformat(IPluginContext *pContext, const cell_t *params)
{ {
int vargPos = static_cast<int>(params[4]); int vargPos = static_cast<int>(params[4]);
@ -301,7 +233,7 @@ static cell_t sm_vformat(IPluginContext *pContext, const cell_t *params)
if (copy) if (copy)
{ {
destination = g_formatbuf; destination = g_vformatbuf;
} else { } else {
pContext->LocalToString(params[1], &destination); pContext->LocalToString(params[1], &destination);
} }
@ -313,7 +245,7 @@ static cell_t sm_vformat(IPluginContext *pContext, const cell_t *params)
/* Perform copy-on-write if we need to */ /* Perform copy-on-write if we need to */
if (copy) if (copy)
{ {
pContext->StringToLocal(params[1], maxlen, g_formatbuf); pContext->StringToLocal(params[1], maxlen, g_vformatbuf);
} }
return total; return total;

View File

@ -30,15 +30,19 @@
#include "sprintf.h" #include "sprintf.h"
#include <am-float.h> #include <am-float.h>
#include <am-string.h> #include <am-string.h>
#include <IDBDriver.h>
#include <ITranslator.h> #include <ITranslator.h>
#include <bridge/include/IScriptManager.h> #include <bridge/include/IScriptManager.h>
#include <bridge/include/CoreProvider.h> #include <bridge/include/CoreProvider.h>
using namespace SourceMod; using namespace SourceMod;
#define LADJUST 0x00000004 /* left adjustment */ IDatabase *g_FormatEscapeDatabase = NULL;
#define ZEROPAD 0x00000080 /* zero (as opposed to blank) pad */
#define UPPERDIGITS 0x00000200 /* make alpha digits uppercase */ #define LADJUST 0x00000001 /* left adjustment */
#define ZEROPAD 0x00000002 /* zero (as opposed to blank) pad */
#define UPPERDIGITS 0x00000004 /* make alpha digits uppercase */
#define NOESCAPE 0x00000008 /* do not escape strings (they are only escaped if a database connection is provided) */
#define to_digit(c) ((c) - '0') #define to_digit(c) ((c) - '0')
#define is_digit(c) ((unsigned)to_digit(c) <= 9) #define is_digit(c) ((unsigned)to_digit(c) <= 9)
@ -87,7 +91,7 @@ try_serverlang:
} }
else else
{ {
pCtx->ThrowNativeErrorEx(SP_ERROR_PARAM, "Translation failed: invalid client index %d", target); pCtx->ThrowNativeErrorEx(SP_ERROR_PARAM, "Translation failed: invalid client index %d (arg %d)", target, *arg);
goto error_out; goto error_out;
} }
@ -102,13 +106,13 @@ try_serverlang:
{ {
if (pPhrases->FindTranslation(key, SOURCEMOD_LANGUAGE_ENGLISH, &pTrans) != Trans_Okay) if (pPhrases->FindTranslation(key, SOURCEMOD_LANGUAGE_ENGLISH, &pTrans) != Trans_Okay)
{ {
pCtx->ThrowNativeErrorEx(SP_ERROR_PARAM, "Language phrase \"%s\" not found", key); pCtx->ThrowNativeErrorEx(SP_ERROR_PARAM, "Language phrase \"%s\" not found (arg %d)", key, *arg);
goto error_out; goto error_out;
} }
} }
else else
{ {
pCtx->ThrowNativeErrorEx(SP_ERROR_PARAM, "Language phrase \"%s\" not found", key); pCtx->ThrowNativeErrorEx(SP_ERROR_PARAM, "Language phrase \"%s\" not found (arg %d)", key, *arg);
goto error_out; goto error_out;
} }
} }
@ -123,9 +127,8 @@ try_serverlang:
if ((*arg) + (max_params - 1) > (size_t)params[0]) if ((*arg) + (max_params - 1) > (size_t)params[0])
{ {
pCtx->ThrowNativeErrorEx(SP_ERROR_PARAMS_MAX, pCtx->ThrowNativeErrorEx(SP_ERROR_PARAMS_MAX,
"Translation string formatted incorrectly - missing at least %d parameters", "Translation string formatted incorrectly - missing at least %d parameters (arg %d)",
((*arg + (max_params - 1)) - params[0]) ((*arg + (max_params - 1)) - params[0]), *arg);
);
goto error_out; goto error_out;
} }
@ -147,7 +150,7 @@ error_out:
return 0; return 0;
} }
void AddString(char **buf_p, size_t &maxlen, const char *string, int width, int prec) bool AddString(char **buf_p, size_t &maxlen, const char *string, int width, int prec, int flags)
{ {
int size = 0; int size = 0;
char *buf; char *buf;
@ -159,6 +162,7 @@ void AddString(char **buf_p, size_t &maxlen, const char *string, int width, int
{ {
string = nlstr; string = nlstr;
prec = -1; prec = -1;
flags |= NOESCAPE;
} }
if (prec >= 0) if (prec >= 0)
@ -182,12 +186,44 @@ void AddString(char **buf_p, size_t &maxlen, const char *string, int width, int
size = maxlen; size = maxlen;
} }
maxlen -= size;
width -= size; width -= size;
while (size--) if (g_FormatEscapeDatabase && (flags & NOESCAPE) == 0)
{ {
*buf++ = *string++; char *tempBuffer = NULL;
if (prec != -1)
{
// I doubt anyone will ever do this, so just allocate.
tempBuffer = new char[maxlen + 1];
memcpy(tempBuffer, string, size);
tempBuffer[size] = '\0';
}
size_t newSize;
bool ret = g_FormatEscapeDatabase->QuoteString(tempBuffer ? tempBuffer : string, buf, maxlen + 1, &newSize);
if (tempBuffer)
{
delete[] tempBuffer;
}
if (!ret)
{
return false;
}
maxlen -= newSize;
buf += newSize;
size = 0; // Consistency.
}
else
{
maxlen -= size;
while (size--)
{
*buf++ = *string++;
}
} }
while ((width-- > 0) && maxlen) while ((width-- > 0) && maxlen)
@ -197,6 +233,8 @@ void AddString(char **buf_p, size_t &maxlen, const char *string, int width, int
} }
*buf_p = buf; *buf_p = buf;
return true;
} }
void AddFloat(char **buf_p, size_t &maxlen, double fval, int width, int prec, int flags) void AddFloat(char **buf_p, size_t &maxlen, double fval, int width, int prec, int flags)
@ -212,7 +250,7 @@ void AddFloat(char **buf_p, size_t &maxlen, double fval, int width, int prec, in
if (ke::IsNaN(fval)) if (ke::IsNaN(fval))
{ {
AddString(buf_p, maxlen, "NaN", width, prec); AddString(buf_p, maxlen, "NaN", width, prec, flags | NOESCAPE);
return; return;
} }
@ -750,7 +788,7 @@ reswitch:
} }
const char *str = (const char *)params[curparam]; const char *str = (const char *)params[curparam];
curparam++; curparam++;
AddString(&buf_p, llen, str, width, prec); AddString(&buf_p, llen, str, width, prec, flags);
arg++; arg++;
break; break;
} }
@ -1030,6 +1068,11 @@ reswitch:
flags |= LADJUST; flags |= LADJUST;
goto rflag; goto rflag;
} }
case '!':
{
flags |= NOESCAPE;
goto rflag;
}
case '.': case '.':
{ {
n = 0; n = 0;
@ -1127,7 +1170,7 @@ reswitch:
const char *auth; const char *auth;
int userid; int userid;
if (!bridge->DescribePlayer(*value, &name, &auth, &userid)) if (!bridge->DescribePlayer(*value, &name, &auth, &userid))
return pCtx->ThrowNativeError("Client index %d is invalid", *value); return pCtx->ThrowNativeError("Client index %d is invalid (arg %d)", *value, arg);
ke::SafeSprintf(buffer, ke::SafeSprintf(buffer,
sizeof(buffer), sizeof(buffer),
"%s<%d><%s><>", "%s<%d><%s><>",
@ -1141,7 +1184,8 @@ reswitch:
sizeof(buffer), sizeof(buffer),
"Console<0><Console><Console>"); "Console<0><Console><Console>");
} }
AddString(&buf_p, llen, buffer, width, prec); if (!AddString(&buf_p, llen, buffer, width, prec, flags))
return pCtx->ThrowNativeError("Escaped string would be truncated (arg %d)", arg);
arg++; arg++;
break; break;
} }
@ -1154,9 +1198,10 @@ reswitch:
const char *name = "Console"; const char *name = "Console";
if (*value) { if (*value) {
if (!bridge->DescribePlayer(*value, &name, nullptr, nullptr)) if (!bridge->DescribePlayer(*value, &name, nullptr, nullptr))
return pCtx->ThrowNativeError("Client index %d is invalid", *value); return pCtx->ThrowNativeError("Client index %d is invalid (arg %d)", *value, arg);
} }
AddString(&buf_p, llen, name, width, prec); if (!AddString(&buf_p, llen, name, width, prec, flags))
return pCtx->ThrowNativeError("Escaped string would be truncated (arg %d)", arg);
arg++; arg++;
break; break;
} }
@ -1165,7 +1210,8 @@ reswitch:
CHECK_ARGS(0); CHECK_ARGS(0);
char *str; char *str;
pCtx->LocalToString(params[arg], &str); pCtx->LocalToString(params[arg], &str);
AddString(&buf_p, llen, str, width, prec); if (!AddString(&buf_p, llen, str, width, prec, flags))
return pCtx->ThrowNativeError("Escaped string would be truncated (arg %d)", arg);
arg++; arg++;
break; break;
} }

View File

@ -30,6 +30,7 @@
#include <sp_vm_api.h> #include <sp_vm_api.h>
namespace SourceMod { namespace SourceMod {
class IDatabase;
class IPhraseCollection; class IPhraseCollection;
} }
@ -57,4 +58,6 @@ bool gnprintf(char *buffer,
size_t *pOutLength, size_t *pOutLength,
const char **pFailPhrase); const char **pFailPhrase);
extern SourceMod::IDatabase *g_FormatEscapeDatabase;
#endif // _include_sourcemod_core_logic_sprintf_h_ #endif // _include_sourcemod_core_logic_sprintf_h_

View File

@ -29,11 +29,13 @@
* Version: $Id$ * Version: $Id$
*/ */
#include "common_logic.h"
#include <string.h> #include <string.h>
#include <stdio.h> #include <stdio.h>
#include <ctype.h> #include <ctype.h>
#include <sm_platform.h> #include <sm_platform.h>
#include "stringutil.h" #include "stringutil.h"
#include "sprintf.h"
#include <am-string.h> #include <am-string.h>
#include "TextParsers.h" #include "TextParsers.h"
@ -363,3 +365,76 @@ char *UTIL_TrimWhitespace(char *str, size_t &len)
return str; return str;
} }
class StaticCharBuf
{
char *buffer;
size_t max_size;
public:
StaticCharBuf() : buffer(NULL), max_size(0)
{
}
~StaticCharBuf()
{
free(buffer);
}
char* GetWithSize(size_t len)
{
if (len > max_size)
{
buffer = (char *)realloc(buffer, len);
max_size = len;
}
return buffer;
}
};
static char g_formatbuf[2048];
static StaticCharBuf g_extrabuf;
cell_t InternalFormat(IPluginContext *pCtx, const cell_t *params, int start)
{
char *buf, *fmt, *destbuf;
cell_t start_addr, end_addr, maxparam;
size_t res, maxlen;
int arg = start + 4;
bool copy = false;
char *__copy_buf;
pCtx->LocalToString(params[start + 1], &destbuf);
pCtx->LocalToString(params[start + 3], &fmt);
maxlen = static_cast<size_t>(params[start + 2]);
start_addr = params[start + 1];
end_addr = params[start + 1] + maxlen;
maxparam = params[0];
for (cell_t i = (start + 3); i <= maxparam; i++)
{
if ((params[i] >= start_addr) && (params[i] <= end_addr))
{
copy = true;
break;
}
}
if (copy)
{
if (maxlen > sizeof(g_formatbuf))
{
__copy_buf = g_extrabuf.GetWithSize(maxlen);
}
else
{
__copy_buf = g_formatbuf;
}
}
buf = (copy) ? __copy_buf : destbuf;
res = atcprintf(buf, maxlen, fmt, pCtx, params, &arg);
if (copy)
{
memcpy(destbuf, __copy_buf, res+1);
}
return static_cast<cell_t>(res);
}

View File

@ -43,5 +43,9 @@ size_t UTIL_DecodeHexString(unsigned char *buffer, size_t maxlength, const char
void UTIL_StripExtension(const char *in, char *out, int outSize); void UTIL_StripExtension(const char *in, char *out, int outSize);
char *UTIL_TrimWhitespace(char *str, size_t &len); char *UTIL_TrimWhitespace(char *str, size_t &len);
// Internal copying Format helper, expects (char[] buffer, int maxlength, const char[] format, any ...) starting at |start|
// i.e. you can stuff your own params before |buffer|.
cell_t InternalFormat(IPluginContext *pCtx, const cell_t *params, int start);
#endif /* _INCLUDE_SOURCEMOD_COMMON_STRINGUTIL_H_ */ #endif /* _INCLUDE_SOURCEMOD_COMMON_STRINGUTIL_H_ */

View File

@ -84,6 +84,18 @@ IDBDriver *SqDatabase::GetDriver()
bool SqDatabase::QuoteString(const char *str, char buffer[], size_t maxlen, size_t *newSize) bool SqDatabase::QuoteString(const char *str, char buffer[], size_t maxlen, size_t *newSize)
{ {
unsigned long size = static_cast<unsigned long>(strlen(str));
unsigned long needed = size * 2 + 1;
if (maxlen < needed)
{
if (newSize != NULL)
{
*newSize = (size_t)needed;
}
return false;
}
char *res = sqlite3_snprintf(static_cast<int>(maxlen), buffer, "%q", str); char *res = sqlite3_snprintf(static_cast<int>(maxlen), buffer, "%q", str);
if (res != NULL && newSize != NULL) if (res != NULL && newSize != NULL)

View File

@ -380,6 +380,16 @@ methodmap Database < Handle
// The buffer must be at least 2*strlen(string)+1. // The buffer must be at least 2*strlen(string)+1.
public native bool Escape(const char[] string, char[] buffer, int maxlength, int &written=0); public native bool Escape(const char[] string, char[] buffer, int maxlength, int &written=0);
// Formats a string according to the SourceMod format rules (see documentation).
// All format specifiers are escaped (see SQL_EscapeString) unless the '!' flag is used.
//
// @param buffer Destination string buffer.
// @param maxlength Maximum length of output string buffer.
// @param format Formatting rules.
// @param ... Variable number of format parameters.
// @return Number of cells written.
public native int Format(const char[] buffer, int maxlength, const char[] format, any ...);
// Returns whether a database is the same connection as another database. // Returns whether a database is the same connection as another database.
public native bool IsSameConnection(Database other); public native bool IsSameConnection(Database other);
@ -639,6 +649,19 @@ native bool SQL_EscapeString(Handle database,
int maxlength, int maxlength,
int &written=0); int &written=0);
/**
* Formats a string according to the SourceMod format rules (see documentation).
* All format specifiers are escaped (see SQL_EscapeString) unless the '!' flag is used.
*
* @param database A database Handle.
* @param buffer Destination string buffer.
* @param maxlength Maximum length of output string buffer.
* @param format Formatting rules.
* @param ... Variable number of format parameters.
* @return Number of cells written.
*/
native int SQL_FormatQuery(Handle database, const char[] buffer, int maxlength, const char[] format, any ...);
/** /**
* This is a backwards compatibility stock. You should use SQL_EscapeString() * This is a backwards compatibility stock. You should use SQL_EscapeString()
* instead, as this function will probably be deprecated in SourceMod 1.1. * instead, as this function will probably be deprecated in SourceMod 1.1.