|
|
|
@ -9,8 +9,9 @@ void PrintErrorStringAndExit() { |
|
|
|
|
std::exit(EXIT_FAILURE); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
SNPEModel::SNPEModel(const char *path, float *loutput, size_t output_size, int runtime) { |
|
|
|
|
SNPEModel::SNPEModel(const char *path, float *loutput, size_t loutput_size, int runtime) { |
|
|
|
|
output = loutput; |
|
|
|
|
output_size = loutput_size; |
|
|
|
|
#ifdef QCOM |
|
|
|
|
if (runtime==USE_GPU_RUNTIME) { |
|
|
|
|
Runtime = zdl::DlSystem::Runtime_t::GPU; |
|
|
|
@ -102,6 +103,7 @@ SNPEModel::SNPEModel(const char *path, float *loutput, size_t output_size, int r |
|
|
|
|
|
|
|
|
|
void SNPEModel::addRecurrent(float *state, int state_size) { |
|
|
|
|
recurrent = state; |
|
|
|
|
recurrent_size = state_size; |
|
|
|
|
recurrentBuffer = this->addExtra(state, state_size, 3); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
@ -134,21 +136,37 @@ std::unique_ptr<zdl::DlSystem::IUserBuffer> SNPEModel::addExtra(float *state, in |
|
|
|
|
void SNPEModel::execute(float *net_input_buf, int buf_size) { |
|
|
|
|
#ifdef USE_THNEED |
|
|
|
|
if (Runtime == zdl::DlSystem::Runtime_t::GPU) { |
|
|
|
|
float *inputs[4] = {recurrent, trafficConvention, desire, net_input_buf}; |
|
|
|
|
if (thneed == NULL) { |
|
|
|
|
assert(inputBuffer->setBufferAddress(net_input_buf)); |
|
|
|
|
if (!snpe->execute(inputMap, outputMap)) { |
|
|
|
|
PrintErrorStringAndExit(); |
|
|
|
|
} |
|
|
|
|
memset(recurrent, 0, recurrent_size*sizeof(float)); |
|
|
|
|
thneed = new Thneed(); |
|
|
|
|
//thneed->record = 3;
|
|
|
|
|
if (!snpe->execute(inputMap, outputMap)) { |
|
|
|
|
PrintErrorStringAndExit(); |
|
|
|
|
} |
|
|
|
|
thneed->stop(); |
|
|
|
|
//thneed->record = 2;
|
|
|
|
|
printf("thneed cached\n"); |
|
|
|
|
|
|
|
|
|
// doing self test
|
|
|
|
|
float *outputs_golden = (float *)malloc(output_size*sizeof(float)); |
|
|
|
|
memcpy(outputs_golden, output, output_size*sizeof(float)); |
|
|
|
|
memset(output, 0, output_size*sizeof(float)); |
|
|
|
|
memset(recurrent, 0, recurrent_size*sizeof(float)); |
|
|
|
|
thneed->execute(inputs, output); |
|
|
|
|
|
|
|
|
|
if (memcmp(output, outputs_golden, output_size*sizeof(float)) == 0) { |
|
|
|
|
printf("thneed selftest passed\n"); |
|
|
|
|
} else { |
|
|
|
|
for (int i = 0; i < output_size; i++) { |
|
|
|
|
printf("mismatch %3d: %f %f\n", i, output[i], outputs_golden[i]); |
|
|
|
|
} |
|
|
|
|
assert(false); |
|
|
|
|
} |
|
|
|
|
free(outputs_golden); |
|
|
|
|
} else { |
|
|
|
|
float *inputs[4] = {recurrent, trafficConvention, desire, net_input_buf}; |
|
|
|
|
thneed->execute(inputs, output); |
|
|
|
|
} |
|
|
|
|
} else { |
|
|
|
|