diff --git a/core/logic/smn_database.cpp b/core/logic/smn_database.cpp index e7865e94..056e7035 100644 --- a/core/logic/smn_database.cpp +++ b/core/logic/smn_database.cpp @@ -32,6 +32,7 @@ #include "common_logic.h" #include "Database.h" #include "ExtensionSys.h" +#include "sprintf.h" #include "stringutil.h" #include "ISourceMod.h" #include "AutoHandleRooter.h" @@ -732,6 +733,24 @@ static cell_t SQL_QuoteString(IPluginContext *pContext, const cell_t *params) return s ? 1 : 0; } +static cell_t SQL_FormatQuery(IPluginContext *pContext, const cell_t *params) +{ + IDatabase *db = NULL; + HandleError err; + + if ((err = g_DBMan.ReadHandle(params[1], DBHandle_Database, (void **)&db)) + != HandleError_None) + { + return pContext->ThrowNativeError("Invalid database Handle %x (error: %d)", params[1], err); + } + + g_FormatEscapeDatabase = db; + cell_t result = InternalFormat(pContext, params, 1); + g_FormatEscapeDatabase = NULL; + + return result; +} + static cell_t SQL_FastQuery(IPluginContext *pContext, const cell_t *params) { IDatabase *db = NULL; @@ -1814,6 +1833,7 @@ REGISTER_NATIVES(dbNatives) {"Database.Driver.get", Database_Driver_get}, {"Database.SetCharset", SQL_SetCharset}, {"Database.Escape", SQL_QuoteString}, + {"Database.Format", SQL_FormatQuery}, {"Database.IsSameConnection", SQL_IsSameConnection}, {"Database.Execute", SQL_ExecuteTransaction}, @@ -1827,6 +1847,7 @@ REGISTER_NATIVES(dbNatives) {"SQL_Connect", SQL_Connect}, {"SQL_ConnectEx", SQL_ConnectEx}, {"SQL_EscapeString", SQL_QuoteString}, + {"SQL_FormatQuery", SQL_FormatQuery}, {"SQL_Execute", SQL_Execute}, {"SQL_FastQuery", SQL_FastQuery}, {"SQL_FetchFloat", SQL_FetchFloat}, diff --git a/core/logic/smn_string.cpp b/core/logic/smn_string.cpp index 2a9402ad..76ff2e05 100644 --- a/core/logic/smn_string.cpp +++ b/core/logic/smn_string.cpp @@ -194,80 +194,12 @@ static cell_t sm_formatex(IPluginContext *pCtx, const cell_t *params) return static_cast(res); } -class StaticCharBuf -{ - char *buffer; - size_t max_size; -public: - StaticCharBuf() : buffer(NULL), max_size(0) - { - } - ~StaticCharBuf() - { - free(buffer); - } - char* GetWithSize(size_t len) - { - if (len > max_size) - { - buffer = (char *)realloc(buffer, len); - max_size = len; - } - return buffer; - } -}; - -static char g_formatbuf[2048]; -static StaticCharBuf g_extrabuf; static cell_t sm_format(IPluginContext *pCtx, const cell_t *params) { - char *buf, *fmt, *destbuf; - cell_t start_addr, end_addr, maxparam; - size_t res, maxlen; - int arg = 4; - bool copy = false; - char *__copy_buf; - - pCtx->LocalToString(params[1], &destbuf); - pCtx->LocalToString(params[3], &fmt); - - maxlen = static_cast(params[2]); - start_addr = params[1]; - end_addr = params[1] + maxlen; - maxparam = params[0]; - - for (cell_t i=3; i<=maxparam; i++) - { - if ((params[i] >= start_addr) && (params[i] <= end_addr)) - { - copy = true; - break; - } - } - - if (copy) - { - if (maxlen > sizeof(g_formatbuf)) - { - __copy_buf = g_extrabuf.GetWithSize(maxlen); - } - else - { - __copy_buf = g_formatbuf; - } - } - - buf = (copy) ? __copy_buf : destbuf; - res = atcprintf(buf, maxlen, fmt, pCtx, params, &arg); - - if (copy) - { - memcpy(destbuf, __copy_buf, res+1); - } - - return static_cast(res); + return InternalFormat(pCtx, params, 0); } +static char g_vformatbuf[2048]; static cell_t sm_vformat(IPluginContext *pContext, const cell_t *params) { int vargPos = static_cast(params[4]); @@ -301,7 +233,7 @@ static cell_t sm_vformat(IPluginContext *pContext, const cell_t *params) if (copy) { - destination = g_formatbuf; + destination = g_vformatbuf; } else { pContext->LocalToString(params[1], &destination); } @@ -313,7 +245,7 @@ static cell_t sm_vformat(IPluginContext *pContext, const cell_t *params) /* Perform copy-on-write if we need to */ if (copy) { - pContext->StringToLocal(params[1], maxlen, g_formatbuf); + pContext->StringToLocal(params[1], maxlen, g_vformatbuf); } return total; diff --git a/core/logic/sprintf.cpp b/core/logic/sprintf.cpp index e00381f8..49a5b190 100644 --- a/core/logic/sprintf.cpp +++ b/core/logic/sprintf.cpp @@ -30,15 +30,19 @@ #include "sprintf.h" #include #include +#include #include #include #include using namespace SourceMod; -#define LADJUST 0x00000004 /* left adjustment */ -#define ZEROPAD 0x00000080 /* zero (as opposed to blank) pad */ -#define UPPERDIGITS 0x00000200 /* make alpha digits uppercase */ +IDatabase *g_FormatEscapeDatabase = NULL; + +#define LADJUST 0x00000001 /* left adjustment */ +#define ZEROPAD 0x00000002 /* zero (as opposed to blank) pad */ +#define UPPERDIGITS 0x00000004 /* make alpha digits uppercase */ +#define NOESCAPE 0x00000008 /* do not escape strings (they are only escaped if a database connection is provided) */ #define to_digit(c) ((c) - '0') #define is_digit(c) ((unsigned)to_digit(c) <= 9) @@ -87,7 +91,7 @@ try_serverlang: } else { - pCtx->ThrowNativeErrorEx(SP_ERROR_PARAM, "Translation failed: invalid client index %d", target); + pCtx->ThrowNativeErrorEx(SP_ERROR_PARAM, "Translation failed: invalid client index %d (arg %d)", target, *arg); goto error_out; } @@ -102,13 +106,13 @@ try_serverlang: { if (pPhrases->FindTranslation(key, SOURCEMOD_LANGUAGE_ENGLISH, &pTrans) != Trans_Okay) { - pCtx->ThrowNativeErrorEx(SP_ERROR_PARAM, "Language phrase \"%s\" not found", key); + pCtx->ThrowNativeErrorEx(SP_ERROR_PARAM, "Language phrase \"%s\" not found (arg %d)", key, *arg); goto error_out; } } else { - pCtx->ThrowNativeErrorEx(SP_ERROR_PARAM, "Language phrase \"%s\" not found", key); + pCtx->ThrowNativeErrorEx(SP_ERROR_PARAM, "Language phrase \"%s\" not found (arg %d)", key, *arg); goto error_out; } } @@ -123,9 +127,8 @@ try_serverlang: if ((*arg) + (max_params - 1) > (size_t)params[0]) { pCtx->ThrowNativeErrorEx(SP_ERROR_PARAMS_MAX, - "Translation string formatted incorrectly - missing at least %d parameters", - ((*arg + (max_params - 1)) - params[0]) - ); + "Translation string formatted incorrectly - missing at least %d parameters (arg %d)", + ((*arg + (max_params - 1)) - params[0]), *arg); goto error_out; } @@ -147,7 +150,7 @@ error_out: return 0; } -void AddString(char **buf_p, size_t &maxlen, const char *string, int width, int prec) +bool AddString(char **buf_p, size_t &maxlen, const char *string, int width, int prec, int flags) { int size = 0; char *buf; @@ -159,6 +162,7 @@ void AddString(char **buf_p, size_t &maxlen, const char *string, int width, int { string = nlstr; prec = -1; + flags |= NOESCAPE; } if (prec >= 0) @@ -182,12 +186,44 @@ void AddString(char **buf_p, size_t &maxlen, const char *string, int width, int size = maxlen; } - maxlen -= size; width -= size; - while (size--) + if (g_FormatEscapeDatabase && (flags & NOESCAPE) == 0) { - *buf++ = *string++; + char *tempBuffer = NULL; + if (prec != -1) + { + // I doubt anyone will ever do this, so just allocate. + tempBuffer = new char[maxlen + 1]; + memcpy(tempBuffer, string, size); + tempBuffer[size] = '\0'; + } + + size_t newSize; + bool ret = g_FormatEscapeDatabase->QuoteString(tempBuffer ? tempBuffer : string, buf, maxlen + 1, &newSize); + + if (tempBuffer) + { + delete[] tempBuffer; + } + + if (!ret) + { + return false; + } + + maxlen -= newSize; + buf += newSize; + size = 0; // Consistency. + } + else + { + maxlen -= size; + + while (size--) + { + *buf++ = *string++; + } } while ((width-- > 0) && maxlen) @@ -197,6 +233,8 @@ void AddString(char **buf_p, size_t &maxlen, const char *string, int width, int } *buf_p = buf; + + return true; } void AddFloat(char **buf_p, size_t &maxlen, double fval, int width, int prec, int flags) @@ -212,7 +250,7 @@ void AddFloat(char **buf_p, size_t &maxlen, double fval, int width, int prec, in if (ke::IsNaN(fval)) { - AddString(buf_p, maxlen, "NaN", width, prec); + AddString(buf_p, maxlen, "NaN", width, prec, flags | NOESCAPE); return; } @@ -750,7 +788,7 @@ reswitch: } const char *str = (const char *)params[curparam]; curparam++; - AddString(&buf_p, llen, str, width, prec); + AddString(&buf_p, llen, str, width, prec, flags); arg++; break; } @@ -1030,6 +1068,11 @@ reswitch: flags |= LADJUST; goto rflag; } + case '!': + { + flags |= NOESCAPE; + goto rflag; + } case '.': { n = 0; @@ -1127,7 +1170,7 @@ reswitch: const char *auth; int userid; if (!bridge->DescribePlayer(*value, &name, &auth, &userid)) - return pCtx->ThrowNativeError("Client index %d is invalid", *value); + return pCtx->ThrowNativeError("Client index %d is invalid (arg %d)", *value, arg); ke::SafeSprintf(buffer, sizeof(buffer), "%s<%d><%s><>", @@ -1141,7 +1184,8 @@ reswitch: sizeof(buffer), "Console<0>"); } - AddString(&buf_p, llen, buffer, width, prec); + if (!AddString(&buf_p, llen, buffer, width, prec, flags)) + return pCtx->ThrowNativeError("Escaped string would be truncated (arg %d)", arg); arg++; break; } @@ -1154,9 +1198,10 @@ reswitch: const char *name = "Console"; if (*value) { if (!bridge->DescribePlayer(*value, &name, nullptr, nullptr)) - return pCtx->ThrowNativeError("Client index %d is invalid", *value); + return pCtx->ThrowNativeError("Client index %d is invalid (arg %d)", *value, arg); } - AddString(&buf_p, llen, name, width, prec); + if (!AddString(&buf_p, llen, name, width, prec, flags)) + return pCtx->ThrowNativeError("Escaped string would be truncated (arg %d)", arg); arg++; break; } @@ -1165,7 +1210,8 @@ reswitch: CHECK_ARGS(0); char *str; pCtx->LocalToString(params[arg], &str); - AddString(&buf_p, llen, str, width, prec); + if (!AddString(&buf_p, llen, str, width, prec, flags)) + return pCtx->ThrowNativeError("Escaped string would be truncated (arg %d)", arg); arg++; break; } diff --git a/core/logic/sprintf.h b/core/logic/sprintf.h index 756c7db4..f586bb72 100644 --- a/core/logic/sprintf.h +++ b/core/logic/sprintf.h @@ -30,6 +30,7 @@ #include namespace SourceMod { +class IDatabase; class IPhraseCollection; } @@ -57,4 +58,6 @@ bool gnprintf(char *buffer, size_t *pOutLength, const char **pFailPhrase); +extern SourceMod::IDatabase *g_FormatEscapeDatabase; + #endif // _include_sourcemod_core_logic_sprintf_h_ diff --git a/core/logic/stringutil.cpp b/core/logic/stringutil.cpp index 3c966cb5..d5be7697 100644 --- a/core/logic/stringutil.cpp +++ b/core/logic/stringutil.cpp @@ -29,11 +29,13 @@ * Version: $Id$ */ +#include "common_logic.h" #include #include #include #include #include "stringutil.h" +#include "sprintf.h" #include #include "TextParsers.h" @@ -363,3 +365,76 @@ char *UTIL_TrimWhitespace(char *str, size_t &len) return str; } +class StaticCharBuf +{ + char *buffer; + size_t max_size; +public: + StaticCharBuf() : buffer(NULL), max_size(0) + { + } + ~StaticCharBuf() + { + free(buffer); + } + char* GetWithSize(size_t len) + { + if (len > max_size) + { + buffer = (char *)realloc(buffer, len); + max_size = len; + } + return buffer; + } +}; + +static char g_formatbuf[2048]; +static StaticCharBuf g_extrabuf; +cell_t InternalFormat(IPluginContext *pCtx, const cell_t *params, int start) +{ + char *buf, *fmt, *destbuf; + cell_t start_addr, end_addr, maxparam; + size_t res, maxlen; + int arg = start + 4; + bool copy = false; + char *__copy_buf; + + pCtx->LocalToString(params[start + 1], &destbuf); + pCtx->LocalToString(params[start + 3], &fmt); + + maxlen = static_cast(params[start + 2]); + start_addr = params[start + 1]; + end_addr = params[start + 1] + maxlen; + maxparam = params[0]; + + for (cell_t i = (start + 3); i <= maxparam; i++) + { + if ((params[i] >= start_addr) && (params[i] <= end_addr)) + { + copy = true; + break; + } + } + + if (copy) + { + if (maxlen > sizeof(g_formatbuf)) + { + __copy_buf = g_extrabuf.GetWithSize(maxlen); + } + else + { + __copy_buf = g_formatbuf; + } + } + + buf = (copy) ? __copy_buf : destbuf; + res = atcprintf(buf, maxlen, fmt, pCtx, params, &arg); + + if (copy) + { + memcpy(destbuf, __copy_buf, res+1); + } + + return static_cast(res); +} diff --git a/core/logic/stringutil.h b/core/logic/stringutil.h index 558ab37c..94b3a688 100644 --- a/core/logic/stringutil.h +++ b/core/logic/stringutil.h @@ -43,5 +43,9 @@ size_t UTIL_DecodeHexString(unsigned char *buffer, size_t maxlength, const char void UTIL_StripExtension(const char *in, char *out, int outSize); char *UTIL_TrimWhitespace(char *str, size_t &len); +// Internal copying Format helper, expects (char[] buffer, int maxlength, const char[] format, any ...) starting at |start| +// i.e. you can stuff your own params before |buffer|. +cell_t InternalFormat(IPluginContext *pCtx, const cell_t *params, int start); + #endif /* _INCLUDE_SOURCEMOD_COMMON_STRINGUTIL_H_ */ diff --git a/extensions/sqlite/driver/SqDatabase.cpp b/extensions/sqlite/driver/SqDatabase.cpp index e0f13e80..a4f03a89 100644 --- a/extensions/sqlite/driver/SqDatabase.cpp +++ b/extensions/sqlite/driver/SqDatabase.cpp @@ -84,6 +84,18 @@ IDBDriver *SqDatabase::GetDriver() bool SqDatabase::QuoteString(const char *str, char buffer[], size_t maxlen, size_t *newSize) { + unsigned long size = static_cast(strlen(str)); + unsigned long needed = size * 2 + 1; + + if (maxlen < needed) + { + if (newSize != NULL) + { + *newSize = (size_t)needed; + } + return false; + } + char *res = sqlite3_snprintf(static_cast(maxlen), buffer, "%q", str); if (res != NULL && newSize != NULL) diff --git a/plugins/include/dbi.inc b/plugins/include/dbi.inc index cd3a8851..55bc6953 100644 --- a/plugins/include/dbi.inc +++ b/plugins/include/dbi.inc @@ -380,6 +380,16 @@ methodmap Database < Handle // The buffer must be at least 2*strlen(string)+1. public native bool Escape(const char[] string, char[] buffer, int maxlength, int &written=0); + // Formats a string according to the SourceMod format rules (see documentation). + // All format specifiers are escaped (see SQL_EscapeString) unless the '!' flag is used. + // + // @param buffer Destination string buffer. + // @param maxlength Maximum length of output string buffer. + // @param format Formatting rules. + // @param ... Variable number of format parameters. + // @return Number of cells written. + public native int Format(const char[] buffer, int maxlength, const char[] format, any ...); + // Returns whether a database is the same connection as another database. public native bool IsSameConnection(Database other); @@ -639,6 +649,19 @@ native bool SQL_EscapeString(Handle database, int maxlength, int &written=0); +/** + * Formats a string according to the SourceMod format rules (see documentation). + * All format specifiers are escaped (see SQL_EscapeString) unless the '!' flag is used. + * + * @param database A database Handle. + * @param buffer Destination string buffer. + * @param maxlength Maximum length of output string buffer. + * @param format Formatting rules. + * @param ... Variable number of format parameters. + * @return Number of cells written. + */ +native int SQL_FormatQuery(Handle database, const char[] buffer, int maxlength, const char[] format, any ...); + /** * This is a backwards compatibility stock. You should use SQL_EscapeString() * instead, as this function will probably be deprecated in SourceMod 1.1.