diff --git a/DynamicHooks/hook.cpp b/DynamicHooks/hook.cpp index 0564bad..26f7757 100644 --- a/DynamicHooks/hook.cpp +++ b/DynamicHooks/hook.cpp @@ -53,6 +53,12 @@ CHook::CHook(void* pFunc, ICallingConvention* pConvention) m_pRegisters = new CRegisters(pConvention->GetRegisters()); m_pCallingConvention = pConvention; + if (!m_hookHandler.init()) + return; + + if (!m_RetAddr.init()) + return; + unsigned char* pTarget = (unsigned char *) pFunc; // Determine the number of bytes we need to copy @@ -103,37 +109,56 @@ void CHook::AddCallback(HookType_t eHookType, HookHandlerFn* pCallback) if (!pCallback) return; - if (!IsCallbackRegistered(eHookType, pCallback)) - m_hookHandler[eHookType].push_back(pCallback); + HookTypeMap::Insert i = m_hookHandler.findForAdd(eHookType); + if (!i.found()) { + HookHandlerSet set; + set.init(); + m_hookHandler.add(i, eHookType, ke::Move(set)); + } + + i->value.add(pCallback); } void CHook::RemoveCallback(HookType_t eHookType, HookHandlerFn* pCallback) { - if (IsCallbackRegistered(eHookType, pCallback)) - m_hookHandler[eHookType].remove(pCallback); + HookTypeMap::Result r = m_hookHandler.find(eHookType); + if (!r.found()) + return; + + r->value.removeIfExists(pCallback); } bool CHook::IsCallbackRegistered(HookType_t eHookType, HookHandlerFn* pCallback) { - std::list callbacks = m_hookHandler[eHookType]; - for(std::list::iterator it=callbacks.begin(); it != callbacks.end(); it++) - { - if (*it == pCallback) - return true; - } - return false; + HookTypeMap::Result r = m_hookHandler.find(eHookType); + if (!r.found()) + return false; + + return r->value.has(pCallback); } bool CHook::AreCallbacksRegistered() { - return !m_hookHandler[HOOKTYPE_PRE].empty() || !m_hookHandler[HOOKTYPE_POST].empty(); + HookTypeMap::Result r = m_hookHandler.find(HOOKTYPE_PRE); + if (r.found() && r->value.elements() > 0) + return true; + + r = m_hookHandler.find(HOOKTYPE_POST); + if (r.found() && r->value.elements() > 0) + return true; + + return false; } bool CHook::HookHandler(HookType_t eHookType) { bool bOverride = false; - std::list callbacks = this->m_hookHandler[eHookType]; - for(std::list::iterator it=callbacks.begin(); it != callbacks.end(); it++) + HookTypeMap::Result r = m_hookHandler.find(eHookType); + if (!r.found()) + return bOverride; + + HookHandlerSet &callbacks = r->value; + for(HookHandlerSet::iterator it=callbacks.iter(); !it.empty(); it.next()) { bool result = ((HookHandlerFn) *it)(eHookType, this); if (result) @@ -144,15 +169,17 @@ bool CHook::HookHandler(HookType_t eHookType) void* __cdecl CHook::GetReturnAddress(void* pESP) { - if (m_RetAddr.count(pESP) == 0) + ReturnAddressMap::Result r = m_RetAddr.find(pESP); + if (!r.found()) puts("ESP not present."); - return m_RetAddr[pESP]; + return r->value; } void __cdecl CHook::SetReturnAddress(void* pRetAddr, void* pESP) { - m_RetAddr[pESP] = pRetAddr; + ReturnAddressMap::Insert i = m_RetAddr.findForAdd(pESP); + m_RetAddr.add(i, pESP, pRetAddr); } void* CHook::CreateBridge() diff --git a/DynamicHooks/hook.h b/DynamicHooks/hook.h index 5053074..fba0c5f 100644 --- a/DynamicHooks/hook.h +++ b/DynamicHooks/hook.h @@ -34,11 +34,10 @@ // ============================================================================ // >> INCLUDES // ============================================================================ -#include -#include - #include "registers.h" #include "convention.h" +#include +#include // ============================================================================ // >> HookType_t @@ -63,6 +62,20 @@ typedef bool (*HookHandlerFn)(HookType_t, CHook*); #define __cdecl #endif +struct IntegerPolicy +{ + static inline uint32_t hash(size_t i) { + return ke::HashInteger(i); + } + static inline bool matches(size_t i1, size_t i2) { + return i1 == i2; + } +}; + +typedef ke::HashSet> HookHandlerSet; +typedef ke::HashMap HookTypeMap; +typedef ke::HashMap> ReturnAddressMap; + namespace sp { class MacroAssembler; @@ -164,12 +177,12 @@ private: void __cdecl SetReturnAddress(void* pRetAddr, void* pESP); public: - std::map > m_hookHandler; + + HookTypeMap m_hookHandler; // Address of the original function void* m_pFunc; - ICallingConvention* m_pCallingConvention; // Address of the bridge @@ -184,7 +197,7 @@ public: // New return address void* m_pNewRetAddr; - std::map m_RetAddr; + ReturnAddressMap m_RetAddr; }; #endif // _HOOK_H \ No newline at end of file