You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
94 lines
3.9 KiB
94 lines
3.9 KiB
#include "snpemodel.h"
|
|
|
|
void PrintErrorStringAndExit() {
|
|
const char* const errStr = zdl::DlSystem::getLastErrorString();
|
|
std::cerr << zdl::DlSystem::getLastErrorString() << std::endl;
|
|
std::exit(EXIT_FAILURE);
|
|
}
|
|
|
|
SNPEModel::SNPEModel(const uint8_t *model_data, const size_t model_size, float *output, size_t output_size) {
|
|
assert(zdl::SNPE::SNPEFactory::isRuntimeAvailable(zdl::DlSystem::Runtime_t::GPU));
|
|
|
|
// load model
|
|
std::unique_ptr<zdl::DlContainer::IDlContainer> container = zdl::DlContainer::IDlContainer::open(model_data, model_size);
|
|
if (!container) { PrintErrorStringAndExit(); }
|
|
printf("loaded model with size: %d\n", model_size);
|
|
|
|
// create model runner
|
|
zdl::SNPE::SNPEBuilder snpeBuilder(container.get());
|
|
while (!snpe) {
|
|
snpe = snpeBuilder.setOutputLayers({})
|
|
.setRuntimeProcessor(zdl::DlSystem::Runtime_t::GPU)
|
|
.setUseUserSuppliedBuffers(true)
|
|
.setPerformanceProfile(zdl::DlSystem::PerformanceProfile_t::HIGH_PERFORMANCE)
|
|
.build();
|
|
if (!snpe) std::cerr << zdl::DlSystem::getLastErrorString() << std::endl;
|
|
}
|
|
|
|
// get input and output names
|
|
const auto &strListi_opt = snpe->getInputTensorNames();
|
|
if (!strListi_opt) throw std::runtime_error("Error obtaining Input tensor names");
|
|
const auto &strListi = *strListi_opt;
|
|
//assert(strListi.size() == 1);
|
|
const char *input_tensor_name = strListi.at(0);
|
|
|
|
const auto &strListo_opt = snpe->getOutputTensorNames();
|
|
if (!strListo_opt) throw std::runtime_error("Error obtaining Output tensor names");
|
|
const auto &strListo = *strListo_opt;
|
|
assert(strListo.size() == 1);
|
|
const char *output_tensor_name = strListo.at(0);
|
|
|
|
printf("model: %s -> %s\n", input_tensor_name, output_tensor_name);
|
|
|
|
zdl::DlSystem::UserBufferEncodingFloat userBufferEncodingFloat;
|
|
zdl::DlSystem::IUserBufferFactory& ubFactory = zdl::SNPE::SNPEFactory::getUserBufferFactory();
|
|
|
|
// create input buffer
|
|
{
|
|
const auto &inputDims_opt = snpe->getInputDimensions(input_tensor_name);
|
|
const zdl::DlSystem::TensorShape& bufferShape = *inputDims_opt;
|
|
std::vector<size_t> strides(bufferShape.rank());
|
|
strides[strides.size() - 1] = sizeof(float);
|
|
size_t product = 1;
|
|
for (size_t i = 0; i < bufferShape.rank(); i++) product *= bufferShape[i];
|
|
size_t stride = strides[strides.size() - 1];
|
|
for (size_t i = bufferShape.rank() - 1; i > 0; i--) {
|
|
stride *= bufferShape[i];
|
|
strides[i-1] = stride;
|
|
}
|
|
printf("input product is %u\n", product);
|
|
inputBuffer = ubFactory.createUserBuffer(NULL, product*sizeof(float), strides, &userBufferEncodingFloat);
|
|
|
|
inputMap.add(input_tensor_name, inputBuffer.get());
|
|
}
|
|
|
|
// create output buffer
|
|
{
|
|
std::vector<size_t> outputStrides = {output_size * sizeof(float), sizeof(float)};
|
|
outputBuffer = ubFactory.createUserBuffer(output, output_size * sizeof(float), outputStrides, &userBufferEncodingFloat);
|
|
outputMap.add(output_tensor_name, outputBuffer.get());
|
|
}
|
|
}
|
|
|
|
void SNPEModel::addRecurrent(float *state, int state_size) {
|
|
// get input and output names
|
|
const auto &strListi_opt = snpe->getInputTensorNames();
|
|
if (!strListi_opt) throw std::runtime_error("Error obtaining Input tensor names");
|
|
const auto &strListi = *strListi_opt;
|
|
const char *input_tensor_name = strListi.at(1);
|
|
printf("adding recurrent: %s\n", input_tensor_name);
|
|
|
|
zdl::DlSystem::UserBufferEncodingFloat userBufferEncodingFloat;
|
|
zdl::DlSystem::IUserBufferFactory& ubFactory = zdl::SNPE::SNPEFactory::getUserBufferFactory();
|
|
std::vector<size_t> recurrentStrides = {state_size * sizeof(float), sizeof(float)};
|
|
recurrentBuffer = ubFactory.createUserBuffer(state, state_size * sizeof(float), recurrentStrides, &userBufferEncodingFloat);
|
|
inputMap.add(input_tensor_name, recurrentBuffer.get());
|
|
}
|
|
|
|
void SNPEModel::execute(float *net_input_buf) {
|
|
assert(inputBuffer->setBufferAddress(net_input_buf));
|
|
if (!snpe->execute(inputMap, outputMap)) {
|
|
PrintErrorStringAndExit();
|
|
}
|
|
}
|
|
|
|
|