Created
December 16, 2018 14:44
-
-
Save YashasSamaga/1725320041b5579fdf531c8ac39a2dc8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include "main.h"

#include <array>
#include <cstring>
#include <utility>
#include <vector>

#include <dlib/dnn.h>
#include <sampml/svm_classifier.hpp>
#include "tools/fixed_thread_pool.hpp"
#include "iscript.hpp"
#include "classifier.hpp"
#include "transform.hpp"
#include "dnn.hpp"
namespace classifier { | |
using sample_type = output_vector; | |
double test_vector_svm(const sample_type& sample) { | |
static thread_local sampml::trainer::svm_classifier<sample_type> svm; | |
static thread_local bool loaded = false; | |
if (loaded == false) { | |
svm.deserialize("models/svm_classifier.dat"); | |
loaded = true; | |
} | |
return svm.test(sample); | |
} | |
double test_vector_dnn(const sample_type& sample) { | |
static thread_local aa_network_type net; | |
static thread_local bool loaded = false; | |
if (loaded == false) { | |
dlib::deserialize("models/dnn_classifier.dat") >> net; | |
loaded = true; | |
} | |
return net(sample); | |
} | |
struct queue_item_tag_t { | |
int playerid; | |
AMX *amx; | |
}; | |
struct input_queue_item_t { | |
sample_type sample; | |
queue_item_tag_t tag; | |
}; | |
struct output_queue_item_t { | |
float probabilities[2]; | |
queue_item_tag_t tag; | |
}; | |
struct process_functor_t { | |
void operator()(input_queue_item_t& input, output_queue_item_t& output) { | |
sample_type& sample = input.sample; | |
output.probabilities[0] = test_vector_svm(sample); | |
output.probabilities[1] = test_vector_dnn(sample); | |
output.tag = input.tag; | |
} | |
}; | |
fixed_thread_pool<input_queue_item_t, output_queue_item_t, process_functor_t> pool(2); | |
void ProcessTick() { | |
std::vector<output_queue_item_t> results; | |
pool.deqeue_all(results); | |
for (auto&& item : results) { | |
AMX* amx = item.tag.amx; | |
if (iscript::IsValidAmx(amx)) { | |
int cb_idx = -1; | |
if (amx_FindPublic(amx, "OnPlayerSuspectedForAimbot", &cb_idx) != AMX_ERR_NONE || cb_idx < 0) { | |
// OnPlayerSuspectedForAimbot(playerid, Float:probablities[2]) | |
cell probablities[2] = { amx_ftoc(item.probabilities[0]), amx_ftoc(item.probabilities[1]) }; | |
cell amx_addr, *phys_addr; | |
amx_Allot(amx, sizeof(probablities) / sizeof(cell), &amx_addr, &phys_addr); | |
memcpy(phys_addr, probablities, sizeof(probablities)); | |
amx_Push(amx, amx_addr); | |
amx_Push(amx, item.tag.playerid); | |
amx_Exec(amx, NULL, cb_idx); | |
} | |
} | |
} | |
} | |
namespace natives { | |
static input_vector pawn_array_to_vector(cell data[]) { | |
input_vector vector; | |
for (int i = 0; i < input_vector::NR; i++) { | |
switch (i) { | |
case bHit: | |
case iShooterCameraMode: | |
case iShooterState: | |
case iShooterSpecialAction: | |
case bShooterInVehicle: | |
case bShooterSurfingVehicle: | |
case bShooterSurfingObject: | |
case iShooterWeaponID: | |
case iShooterSkinID: | |
case iShooterID: | |
case iVictimState: | |
case iVictimSpecialAction: | |
case bVictimInVehicle: | |
case bVictimSurfingVehicle: | |
case bVictimSurfingObject: | |
case iVictimWeaponID: | |
case iVictimSkinID: | |
case iVictimID: | |
case iHitType: | |
case iShooterPing: | |
case iVictimPing: | |
case iSecond: | |
case iTick: | |
vector(i) = (data[i]); | |
break; | |
default: | |
vector(i) = amx_ctof(data[i]); | |
} | |
} | |
return vector; | |
} | |
cell AMX_NATIVE_CALL submit_vector(AMX * amx, cell* params) | |
{ | |
cell playerid = params[1]; | |
cell *data; | |
amx_GetAddr(amx, params[2], &data); | |
auto vector = pawn_array_to_vector(data); | |
static std::array<transformer, MAX_PLAYERS> transformers; | |
transformers[playerid].submit(vector); | |
if (transformers[playerid].pool.size()) { | |
input_queue_item_t item; | |
item.sample = transformers[playerid].pool.back(); | |
item.tag.amx = amx; | |
item.tag.playerid = playerid; | |
transformers[playerid].pool.pop_back(); | |
return true; | |
} | |
return false; | |
} | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// Fixed-size pool of worker threads that turns InputTaskItems into OutputTaskItems.
///
/// Producers call enqueue(). Each worker pops one input, calls its own default-constructed
/// TaskFunction as `fn(input, output)`, then publishes `output`. Consumers collect results
/// with dequeue() / deqeue_all(). Both queues are guarded by one mutex, so every public
/// member may be called from any thread.
///
/// Requirements: TaskFunction is default-constructible and callable as
/// `void(InputTaskItem&, OutputTaskItem&)`; OutputTaskItem is default-constructible;
/// InputTaskItem is move-constructible.
template <class InputTaskItem, class OutputTaskItem, class TaskFunction>
class fixed_thread_pool {
public:
    /// Starts `num` worker threads.
    fixed_thread_pool(int num) { start(num); }

    /// Signals the workers to stop and joins them. Inputs no worker has picked up yet are discarded.
    ~fixed_thread_pool() { stop(); }

    // Workers capture `this`, so the pool must never be copied.
    fixed_thread_pool(const fixed_thread_pool&) = delete;
    fixed_thread_pool& operator=(const fixed_thread_pool&) = delete;

    /// Returns true when no processed result is waiting to be dequeued.
    bool empty() {
        std::lock_guard<std::mutex> lock(queue_lock);
        return output_queue.empty();
    }

    /// Queues `task` for processing and wakes one idle worker.
    void enqueue(InputTaskItem&& task) {
        {
            std::lock_guard<std::mutex> lock(queue_lock);
            input_queue.push(std::move(task));
        }
        // Notify after unlocking so the woken worker does not immediately block on the mutex.
        queue_not_empty.notify_one();
    }

    /// Removes and returns the oldest processed result.
    /// Precondition: !empty() (checked only by assert).
    OutputTaskItem dequeue() {
        std::lock_guard<std::mutex> lock(queue_lock);
        assert(!output_queue.empty());
        OutputTaskItem item = std::move(output_queue.front());
        output_queue.pop();
        return item;
    }

    /// Appends every processed result (oldest first) to `results`, leaving the output queue empty.
    /// The misspelled name is kept for existing callers; dequeue_all() is an alias.
    void deqeue_all(std::vector<OutputTaskItem>& results) {
        std::lock_guard<std::mutex> lock(queue_lock);
        while (!output_queue.empty()) {
            results.push_back(std::move(output_queue.front()));
            output_queue.pop();
        }
    }

    /// Correctly spelled alias of deqeue_all().
    void dequeue_all(std::vector<OutputTaskItem>& results) { deqeue_all(results); }

private:
    std::vector<std::thread> threads;
    std::queue<InputTaskItem> input_queue;
    std::queue<OutputTaskItem> output_queue;
    std::condition_variable queue_not_empty;
    std::mutex queue_lock;
    bool stop_threads = false;  // guarded by queue_lock

    void start(int num = 2) {
        stop_threads = false;
        for (int i = 0; i < num; i++)
            threads.emplace_back([this] { worker_loop(); });
    }

    /// Body of each worker: wait for input, process it outside the lock, publish the result.
    void worker_loop() {
        TaskFunction process;
        for (;;) {
            std::unique_lock<std::mutex> lock(queue_lock);
            queue_not_empty.wait(lock, [this] { return stop_threads || !input_queue.empty(); });
            if (stop_threads)
                return;
            InputTaskItem input = std::move(input_queue.front());
            input_queue.pop();
            lock.unlock();  // run the (possibly slow) task without blocking producers/consumers

            OutputTaskItem output{};
            try {
                process(input, output);
            } catch (const std::exception& e) {
                // An exception escaping a thread calls std::terminate; drop this item instead.
                std::cerr << "fixed_thread_pool: task failed: " << e.what() << '\n';
                continue;
            }

            lock.lock();
            output_queue.push(std::move(output));
        }
    }

    void stop() {
        {
            std::lock_guard<std::mutex> lock(queue_lock);
            stop_threads = true;
        }
        queue_not_empty.notify_all();
        // Join without holding queue_lock: workers need it to observe stop_threads and exit.
        for (auto& thread : threads)
            if (thread.joinable())
                thread.join();
        threads.clear();
    }
};
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment