Refactor model runners (#28598)
* Started work on model runner refactor
* Fixed some compile errors
* everything compiles
* Fixed bug in SNPEModel
* updateInput -> setInputBuffer
* I understand nothing
* whoops lol
* use std::string instead of char*
* Move common logic into RunModel
* formatting fix
old-commit-hash: c9f00678af
beeps
parent
fdc8876745
commit
95051090a1
10 changed files with 190 additions and 414 deletions
@ -1,18 +1,45 @@ |
||||
#pragma once |
||||
|
||||
#include <string> |
||||
#include <vector> |
||||
#include <memory> |
||||
#include <cassert> |
||||
|
||||
#include "common/clutil.h" |
||||
#include "common/swaglog.h" |
||||
|
||||
struct ModelInput { |
||||
const std::string name; |
||||
float *buffer; |
||||
int size; |
||||
|
||||
ModelInput(const std::string _name, float *_buffer, int _size) : name(_name), buffer(_buffer), size(_size) {} |
||||
virtual void setBuffer(float *_buffer, int _size) { |
||||
assert(size == _size || size == 0); |
||||
buffer = _buffer; |
||||
size = _size; |
||||
} |
||||
}; |
||||
|
||||
class RunModel { |
||||
public: |
||||
std::vector<std::unique_ptr<ModelInput>> inputs; |
||||
|
||||
virtual ~RunModel() {} |
||||
virtual void addRecurrent(float *state, int state_size) {} |
||||
virtual void addDesire(float *state, int state_size) {} |
||||
virtual void addNavFeatures(float *state, int state_size) {} |
||||
virtual void addDrivingStyle(float *state, int state_size) {} |
||||
virtual void addTrafficConvention(float *state, int state_size) {} |
||||
virtual void addCalib(float *state, int state_size) {} |
||||
virtual void addImage(float *image_buf, int buf_size) {} |
||||
virtual void addExtra(float *image_buf, int buf_size) {} |
||||
virtual void execute() {} |
||||
virtual void* getInputBuf() { return nullptr; } |
||||
virtual void* getExtraBuf() { return nullptr; } |
||||
}; |
||||
virtual void* getCLBuffer(const std::string name) { return nullptr; } |
||||
|
||||
virtual void addInput(const std::string name, float *buffer, int size) { |
||||
inputs.push_back(std::unique_ptr<ModelInput>(new ModelInput(name, buffer, size))); |
||||
} |
||||
virtual void setInputBuffer(const std::string name, float *buffer, int size) { |
||||
for (auto &input : inputs) { |
||||
if (name == input->name) { |
||||
input->setBuffer(buffer, size); |
||||
return; |
||||
} |
||||
} |
||||
LOGE("Tried to update input `%s` but no input with this name exists", name.c_str()); |
||||
assert(false); |
||||
} |
||||
}; |
||||
|
@ -1,78 +1,56 @@ |
||||
#include "selfdrive/modeld/runners/thneedmodel.h" |
||||
|
||||
#include <cassert> |
||||
#include "common/swaglog.h" |
||||
|
||||
ThneedModel::ThneedModel(const char *path, float *loutput, size_t loutput_size, int runtime, bool luse_extra, bool luse_tf8, cl_context context) { |
||||
ThneedModel::ThneedModel(const std::string path, float *_output, size_t _output_size, int runtime, bool luse_tf8, cl_context context) { |
||||
thneed = new Thneed(true, context); |
||||
thneed->load(path); |
||||
thneed->load(path.c_str()); |
||||
thneed->clexec(); |
||||
|
||||
recorded = false; |
||||
output = loutput; |
||||
use_extra = luse_extra; |
||||
output = _output; |
||||
} |
||||
|
||||
void ThneedModel::addRecurrent(float *state, int state_size) { |
||||
recurrent = state; |
||||
} |
||||
|
||||
void ThneedModel::addTrafficConvention(float *state, int state_size) { |
||||
trafficConvention = state; |
||||
} |
||||
|
||||
void ThneedModel::addDesire(float *state, int state_size) { |
||||
desire = state; |
||||
} |
||||
|
||||
void ThneedModel::addDrivingStyle(float *state, int state_size) { |
||||
drivingStyle = state; |
||||
} |
||||
|
||||
void ThneedModel::addNavFeatures(float *state, int state_size) { |
||||
navFeatures = state; |
||||
} |
||||
|
||||
void ThneedModel::addImage(float *image_input_buf, int buf_size) { |
||||
input = image_input_buf; |
||||
} |
||||
|
||||
void ThneedModel::addExtra(float *extra_input_buf, int buf_size) { |
||||
extra = extra_input_buf; |
||||
} |
||||
void* ThneedModel::getCLBuffer(const std::string name) { |
||||
int index = -1; |
||||
for (int i = 0; i < inputs.size(); i++) { |
||||
if (name == inputs[i]->name) { |
||||
index = i; |
||||
break; |
||||
} |
||||
} |
||||
|
||||
void* ThneedModel::getInputBuf() { |
||||
if (use_extra && thneed->input_clmem.size() > 5) return &(thneed->input_clmem[5]); |
||||
else if (!use_extra && thneed->input_clmem.size() > 4) return &(thneed->input_clmem[4]); |
||||
else return nullptr; |
||||
} |
||||
if (index == -1) { |
||||
LOGE("Tried to get CL buffer for input `%s` but no input with this name exists", name.c_str()); |
||||
assert(false); |
||||
} |
||||
|
||||
void* ThneedModel::getExtraBuf() { |
||||
if (thneed->input_clmem.size() > 4) return &(thneed->input_clmem[4]); |
||||
else return nullptr; |
||||
if (thneed->input_clmem.size() >= inputs.size()) { |
||||
return &thneed->input_clmem[inputs.size() - index - 1]; |
||||
} else { |
||||
return nullptr; |
||||
} |
||||
} |
||||
|
||||
void ThneedModel::execute() { |
||||
if (!recorded) { |
||||
thneed->record = true; |
||||
if (use_extra) { |
||||
float *inputs[6] = {recurrent, navFeatures, trafficConvention, desire, extra, input}; |
||||
thneed->copy_inputs(inputs); |
||||
} else { |
||||
float *inputs[5] = {recurrent, navFeatures, trafficConvention, desire, input}; |
||||
thneed->copy_inputs(inputs); |
||||
float *input_buffers[inputs.size()]; |
||||
for (int i = 0; i < inputs.size(); i++) { |
||||
input_buffers[inputs.size() - i - 1] = inputs[i]->buffer; |
||||
} |
||||
|
||||
thneed->copy_inputs(input_buffers); |
||||
thneed->clexec(); |
||||
thneed->copy_output(output); |
||||
thneed->stop(); |
||||
|
||||
recorded = true; |
||||
} else { |
||||
if (use_extra) { |
||||
float *inputs[6] = {recurrent, navFeatures, trafficConvention, desire, extra, input}; |
||||
thneed->execute(inputs, output); |
||||
} else { |
||||
float *inputs[5] = {recurrent, navFeatures, trafficConvention, desire, input}; |
||||
thneed->execute(inputs, output); |
||||
float *input_buffers[inputs.size()]; |
||||
for (int i = 0; i < inputs.size(); i++) { |
||||
input_buffers[inputs.size() - i - 1] = inputs[i]->buffer; |
||||
} |
||||
thneed->execute(input_buffers, output); |
||||
} |
||||
} |
||||
|
Loading…
Reference in new issue