openpilot is an open source driver assistance system. openpilot performs the functions of Automated Lane Centering and Adaptive Cruise Control for over 200 supported car makes and models.
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

94 lines
3.9 KiB

#include "snpemodel.h"
void PrintErrorStringAndExit() {
const char* const errStr = zdl::DlSystem::getLastErrorString();
std::cerr << zdl::DlSystem::getLastErrorString() << std::endl;
std::exit(EXIT_FAILURE);
}
SNPEModel::SNPEModel(const uint8_t *model_data, const size_t model_size, float *output, size_t output_size) {
assert(zdl::SNPE::SNPEFactory::isRuntimeAvailable(zdl::DlSystem::Runtime_t::GPU));
// load model
std::unique_ptr<zdl::DlContainer::IDlContainer> container = zdl::DlContainer::IDlContainer::open(model_data, model_size);
if (!container) { PrintErrorStringAndExit(); }
printf("loaded model with size: %d\n", model_size);
// create model runner
zdl::SNPE::SNPEBuilder snpeBuilder(container.get());
while (!snpe) {
snpe = snpeBuilder.setOutputLayers({})
.setRuntimeProcessor(zdl::DlSystem::Runtime_t::GPU)
.setUseUserSuppliedBuffers(true)
.setPerformanceProfile(zdl::DlSystem::PerformanceProfile_t::HIGH_PERFORMANCE)
.build();
if (!snpe) std::cerr << zdl::DlSystem::getLastErrorString() << std::endl;
}
// get input and output names
const auto &strListi_opt = snpe->getInputTensorNames();
if (!strListi_opt) throw std::runtime_error("Error obtaining Input tensor names");
const auto &strListi = *strListi_opt;
//assert(strListi.size() == 1);
const char *input_tensor_name = strListi.at(0);
const auto &strListo_opt = snpe->getOutputTensorNames();
if (!strListo_opt) throw std::runtime_error("Error obtaining Output tensor names");
const auto &strListo = *strListo_opt;
assert(strListo.size() == 1);
const char *output_tensor_name = strListo.at(0);
printf("model: %s -> %s\n", input_tensor_name, output_tensor_name);
zdl::DlSystem::UserBufferEncodingFloat userBufferEncodingFloat;
zdl::DlSystem::IUserBufferFactory& ubFactory = zdl::SNPE::SNPEFactory::getUserBufferFactory();
// create input buffer
{
const auto &inputDims_opt = snpe->getInputDimensions(input_tensor_name);
const zdl::DlSystem::TensorShape& bufferShape = *inputDims_opt;
std::vector<size_t> strides(bufferShape.rank());
strides[strides.size() - 1] = sizeof(float);
size_t product = 1;
for (size_t i = 0; i < bufferShape.rank(); i++) product *= bufferShape[i];
size_t stride = strides[strides.size() - 1];
for (size_t i = bufferShape.rank() - 1; i > 0; i--) {
stride *= bufferShape[i];
strides[i-1] = stride;
}
printf("input product is %u\n", product);
inputBuffer = ubFactory.createUserBuffer(NULL, product*sizeof(float), strides, &userBufferEncodingFloat);
inputMap.add(input_tensor_name, inputBuffer.get());
}
// create output buffer
{
std::vector<size_t> outputStrides = {output_size * sizeof(float), sizeof(float)};
outputBuffer = ubFactory.createUserBuffer(output, output_size * sizeof(float), outputStrides, &userBufferEncodingFloat);
outputMap.add(output_tensor_name, outputBuffer.get());
}
}
void SNPEModel::addRecurrent(float *state, int state_size) {
// get input and output names
const auto &strListi_opt = snpe->getInputTensorNames();
if (!strListi_opt) throw std::runtime_error("Error obtaining Input tensor names");
const auto &strListi = *strListi_opt;
const char *input_tensor_name = strListi.at(1);
printf("adding recurrent: %s\n", input_tensor_name);
zdl::DlSystem::UserBufferEncodingFloat userBufferEncodingFloat;
zdl::DlSystem::IUserBufferFactory& ubFactory = zdl::SNPE::SNPEFactory::getUserBufferFactory();
std::vector<size_t> recurrentStrides = {state_size * sizeof(float), sizeof(float)};
recurrentBuffer = ubFactory.createUserBuffer(state, state_size * sizeof(float), recurrentStrides, &userBufferEncodingFloat);
inputMap.add(input_tensor_name, recurrentBuffer.get());
}
void SNPEModel::execute(float *net_input_buf) {
assert(inputBuffer->setBufferAddress(net_input_buf));
if (!snpe->execute(inputMap, outputMap)) {
PrintErrorStringAndExit();
}
}