diff --git a/selfdrive/modeld/runners/snpemodel.cc b/selfdrive/modeld/runners/snpemodel.cc index 1861494d59..cdaf02b0ad 100644 --- a/selfdrive/modeld/runners/snpemodel.cc +++ b/selfdrive/modeld/runners/snpemodel.cc @@ -14,7 +14,7 @@ void PrintErrorStringAndExit() { std::exit(EXIT_FAILURE); } -SNPEModel::SNPEModel(const char *path, float *loutput, size_t loutput_size, int runtime, bool luse_extra) { +SNPEModel::SNPEModel(const char *path, float *loutput, size_t loutput_size, int runtime, bool luse_extra, bool use_tf8) { output = loutput; output_size = loutput_size; use_extra = luse_extra; @@ -70,14 +70,16 @@ SNPEModel::SNPEModel(const char *path, float *loutput, size_t loutput_size, int printf("model: %s -> %s\n", input_tensor_name, output_tensor_name); zdl::DlSystem::UserBufferEncodingFloat userBufferEncodingFloat; + zdl::DlSystem::UserBufferEncodingTf8 userBufferEncodingTf8(0, 1./255); // network takes 0-1 zdl::DlSystem::IUserBufferFactory& ubFactory = zdl::SNPE::SNPEFactory::getUserBufferFactory(); + size_t size_of_input = use_tf8 ? sizeof(uint8_t) : sizeof(float); // create input buffer { const auto &inputDims_opt = snpe->getInputDimensions(input_tensor_name); const zdl::DlSystem::TensorShape& bufferShape = *inputDims_opt; std::vector strides(bufferShape.rank()); - strides[strides.size() - 1] = sizeof(float); + strides[strides.size() - 1] = size_of_input size_t product = 1; for (size_t i = 0; i < bufferShape.rank(); i++) product *= bufferShape[i]; size_t stride = strides[strides.size() - 1]; @@ -86,7 +88,10 @@ SNPEModel::SNPEModel(const char *path, float *loutput, size_t loutput_size, int strides[i-1] = stride; } printf("input product is %lu\n", product); - inputBuffer = ubFactory.createUserBuffer(NULL, product*sizeof(float), strides, &userBufferEncodingFloat); + inputBuffer = ubFactory.createUserBuffer(NULL, + product*size_of_input, + strides, + use_tf8 ? &userBufferEncodingTf8 : &userBufferEncodingFloat); inputMap.add(input_tensor_name, inputBuffer.get()); } diff --git a/selfdrive/modeld/runners/snpemodel.h b/selfdrive/modeld/runners/snpemodel.h index ba51fdced0..b009cf415a 100644 --- a/selfdrive/modeld/runners/snpemodel.h +++ b/selfdrive/modeld/runners/snpemodel.h @@ -23,7 +23,7 @@ class SNPEModel : public RunModel { public: - SNPEModel(const char *path, float *loutput, size_t loutput_size, int runtime, bool luse_extra = false); + SNPEModel(const char *path, float *loutput, size_t loutput_size, int runtime, bool luse_extra = false, bool use_tf8 = false); void addRecurrent(float *state, int state_size); void addTrafficConvention(float *state, int state_size); void addCalib(float *state, int state_size);