update snpemodel

pull/24762/head
ZwX1616 3 years ago
parent db7e64fcef
commit 749e4593af
  1. 11
      selfdrive/modeld/runners/snpemodel.cc
  2. 2
      selfdrive/modeld/runners/snpemodel.h

@ -14,7 +14,7 @@ void PrintErrorStringAndExit() {
std::exit(EXIT_FAILURE);
}
SNPEModel::SNPEModel(const char *path, float *loutput, size_t loutput_size, int runtime, bool luse_extra) {
SNPEModel::SNPEModel(const char *path, float *loutput, size_t loutput_size, int runtime, bool luse_extra, bool use_tf8) {
output = loutput;
output_size = loutput_size;
use_extra = luse_extra;
@ -70,14 +70,16 @@ SNPEModel::SNPEModel(const char *path, float *loutput, size_t loutput_size, int
printf("model: %s -> %s\n", input_tensor_name, output_tensor_name);
zdl::DlSystem::UserBufferEncodingFloat userBufferEncodingFloat;
zdl::DlSystem::UserBufferEncodingTf8 userBufferEncodingTf8(0, 1./255); // network takes 0-1
zdl::DlSystem::IUserBufferFactory& ubFactory = zdl::SNPE::SNPEFactory::getUserBufferFactory();
size_t size_of_input = use_tf8 ? sizeof(uint8_t) : sizeof(float);
// create input buffer
{
const auto &inputDims_opt = snpe->getInputDimensions(input_tensor_name);
const zdl::DlSystem::TensorShape& bufferShape = *inputDims_opt;
std::vector<size_t> strides(bufferShape.rank());
strides[strides.size() - 1] = sizeof(float);
strides[strides.size() - 1] = size_of_input
size_t product = 1;
for (size_t i = 0; i < bufferShape.rank(); i++) product *= bufferShape[i];
size_t stride = strides[strides.size() - 1];
@ -86,7 +88,10 @@ SNPEModel::SNPEModel(const char *path, float *loutput, size_t loutput_size, int
strides[i-1] = stride;
}
printf("input product is %lu\n", product);
inputBuffer = ubFactory.createUserBuffer(NULL, product*sizeof(float), strides, &userBufferEncodingFloat);
inputBuffer = ubFactory.createUserBuffer(NULL,
product*size_of_input,
strides,
use_tf8 ? &userBufferEncodingTf8 : &userBufferEncodingFloat);
inputMap.add(input_tensor_name, inputBuffer.get());
}

@ -23,7 +23,7 @@
class SNPEModel : public RunModel {
public:
SNPEModel(const char *path, float *loutput, size_t loutput_size, int runtime, bool luse_extra = false);
SNPEModel(const char *path, float *loutput, size_t loutput_size, int runtime, bool luse_extra = false, bool use_tf8 = false);
void addRecurrent(float *state, int state_size);
void addTrafficConvention(float *state, int state_size);
void addCalib(float *state, int state_size);

Loading…
Cancel
Save