Refactor model runners (#28598)
* Started work on model runner refactor
* Fixed some compile errors
* everything compiles
* Fixed bug in SNPEModel
* updateInput -> setInputBuffer
* I understand nothing
* whoops lol
* use std::string instead of char*
* Move common logic into RunModel
* formatting fix
old-commit-hash: c9f00678af
beeps
parent
fdc8876745
commit
95051090a1
10 changed files with 190 additions and 414 deletions
@ -1,18 +1,45 @@ |
|||||||
#pragma once |
#pragma once |
||||||
|
|
||||||
|
#include <string> |
||||||
|
#include <vector> |
||||||
|
#include <memory> |
||||||
|
#include <cassert> |
||||||
|
|
||||||
#include "common/clutil.h" |
#include "common/clutil.h" |
||||||
|
#include "common/swaglog.h" |
||||||
|
|
||||||
|
struct ModelInput { |
||||||
|
const std::string name; |
||||||
|
float *buffer; |
||||||
|
int size; |
||||||
|
|
||||||
|
ModelInput(const std::string _name, float *_buffer, int _size) : name(_name), buffer(_buffer), size(_size) {} |
||||||
|
virtual void setBuffer(float *_buffer, int _size) { |
||||||
|
assert(size == _size || size == 0); |
||||||
|
buffer = _buffer; |
||||||
|
size = _size; |
||||||
|
} |
||||||
|
}; |
||||||
|
|
||||||
class RunModel { |
class RunModel { |
||||||
public: |
public: |
||||||
|
std::vector<std::unique_ptr<ModelInput>> inputs; |
||||||
|
|
||||||
virtual ~RunModel() {} |
virtual ~RunModel() {} |
||||||
virtual void addRecurrent(float *state, int state_size) {} |
|
||||||
virtual void addDesire(float *state, int state_size) {} |
|
||||||
virtual void addNavFeatures(float *state, int state_size) {} |
|
||||||
virtual void addDrivingStyle(float *state, int state_size) {} |
|
||||||
virtual void addTrafficConvention(float *state, int state_size) {} |
|
||||||
virtual void addCalib(float *state, int state_size) {} |
|
||||||
virtual void addImage(float *image_buf, int buf_size) {} |
|
||||||
virtual void addExtra(float *image_buf, int buf_size) {} |
|
||||||
virtual void execute() {} |
virtual void execute() {} |
||||||
virtual void* getInputBuf() { return nullptr; } |
virtual void* getCLBuffer(const std::string name) { return nullptr; } |
||||||
virtual void* getExtraBuf() { return nullptr; } |
|
||||||
}; |
|
||||||
|
|
||||||
|
virtual void addInput(const std::string name, float *buffer, int size) { |
||||||
|
inputs.push_back(std::unique_ptr<ModelInput>(new ModelInput(name, buffer, size))); |
||||||
|
} |
||||||
|
virtual void setInputBuffer(const std::string name, float *buffer, int size) { |
||||||
|
for (auto &input : inputs) { |
||||||
|
if (name == input->name) { |
||||||
|
input->setBuffer(buffer, size); |
||||||
|
return; |
||||||
|
} |
||||||
|
} |
||||||
|
LOGE("Tried to update input `%s` but no input with this name exists", name.c_str()); |
||||||
|
assert(false); |
||||||
|
} |
||||||
|
}; |
||||||
|
@ -1,78 +1,56 @@ |
|||||||
#include "selfdrive/modeld/runners/thneedmodel.h" |
#include "selfdrive/modeld/runners/thneedmodel.h" |
||||||
|
|
||||||
#include <cassert> |
#include "common/swaglog.h" |
||||||
|
|
||||||
ThneedModel::ThneedModel(const char *path, float *loutput, size_t loutput_size, int runtime, bool luse_extra, bool luse_tf8, cl_context context) { |
ThneedModel::ThneedModel(const std::string path, float *_output, size_t _output_size, int runtime, bool luse_tf8, cl_context context) { |
||||||
thneed = new Thneed(true, context); |
thneed = new Thneed(true, context); |
||||||
thneed->load(path); |
thneed->load(path.c_str()); |
||||||
thneed->clexec(); |
thneed->clexec(); |
||||||
|
|
||||||
recorded = false; |
recorded = false; |
||||||
output = loutput; |
output = _output; |
||||||
use_extra = luse_extra; |
|
||||||
} |
} |
||||||
|
|
||||||
void ThneedModel::addRecurrent(float *state, int state_size) { |
void* ThneedModel::getCLBuffer(const std::string name) { |
||||||
recurrent = state; |
int index = -1; |
||||||
} |
for (int i = 0; i < inputs.size(); i++) { |
||||||
|
if (name == inputs[i]->name) { |
||||||
void ThneedModel::addTrafficConvention(float *state, int state_size) { |
index = i; |
||||||
trafficConvention = state; |
break; |
||||||
} |
} |
||||||
|
} |
||||||
void ThneedModel::addDesire(float *state, int state_size) { |
|
||||||
desire = state; |
|
||||||
} |
|
||||||
|
|
||||||
void ThneedModel::addDrivingStyle(float *state, int state_size) { |
|
||||||
drivingStyle = state; |
|
||||||
} |
|
||||||
|
|
||||||
void ThneedModel::addNavFeatures(float *state, int state_size) { |
|
||||||
navFeatures = state; |
|
||||||
} |
|
||||||
|
|
||||||
void ThneedModel::addImage(float *image_input_buf, int buf_size) { |
|
||||||
input = image_input_buf; |
|
||||||
} |
|
||||||
|
|
||||||
void ThneedModel::addExtra(float *extra_input_buf, int buf_size) { |
|
||||||
extra = extra_input_buf; |
|
||||||
} |
|
||||||
|
|
||||||
void* ThneedModel::getInputBuf() { |
if (index == -1) { |
||||||
if (use_extra && thneed->input_clmem.size() > 5) return &(thneed->input_clmem[5]); |
LOGE("Tried to get CL buffer for input `%s` but no input with this name exists", name.c_str()); |
||||||
else if (!use_extra && thneed->input_clmem.size() > 4) return &(thneed->input_clmem[4]); |
assert(false); |
||||||
else return nullptr; |
} |
||||||
} |
|
||||||
|
|
||||||
void* ThneedModel::getExtraBuf() { |
if (thneed->input_clmem.size() >= inputs.size()) { |
||||||
if (thneed->input_clmem.size() > 4) return &(thneed->input_clmem[4]); |
return &thneed->input_clmem[inputs.size() - index - 1]; |
||||||
else return nullptr; |
} else { |
||||||
|
return nullptr; |
||||||
|
} |
||||||
} |
} |
||||||
|
|
||||||
void ThneedModel::execute() { |
void ThneedModel::execute() { |
||||||
if (!recorded) { |
if (!recorded) { |
||||||
thneed->record = true; |
thneed->record = true; |
||||||
if (use_extra) { |
float *input_buffers[inputs.size()]; |
||||||
float *inputs[6] = {recurrent, navFeatures, trafficConvention, desire, extra, input}; |
for (int i = 0; i < inputs.size(); i++) { |
||||||
thneed->copy_inputs(inputs); |
input_buffers[inputs.size() - i - 1] = inputs[i]->buffer; |
||||||
} else { |
|
||||||
float *inputs[5] = {recurrent, navFeatures, trafficConvention, desire, input}; |
|
||||||
thneed->copy_inputs(inputs); |
|
||||||
} |
} |
||||||
|
|
||||||
|
thneed->copy_inputs(input_buffers); |
||||||
thneed->clexec(); |
thneed->clexec(); |
||||||
thneed->copy_output(output); |
thneed->copy_output(output); |
||||||
thneed->stop(); |
thneed->stop(); |
||||||
|
|
||||||
recorded = true; |
recorded = true; |
||||||
} else { |
} else { |
||||||
if (use_extra) { |
float *input_buffers[inputs.size()]; |
||||||
float *inputs[6] = {recurrent, navFeatures, trafficConvention, desire, extra, input}; |
for (int i = 0; i < inputs.size(); i++) { |
||||||
thneed->execute(inputs, output); |
input_buffers[inputs.size() - i - 1] = inputs[i]->buffer; |
||||||
} else { |
|
||||||
float *inputs[5] = {recurrent, navFeatures, trafficConvention, desire, input}; |
|
||||||
thneed->execute(inputs, output); |
|
||||||
} |
} |
||||||
|
thneed->execute(input_buffers, output); |
||||||
} |
} |
||||||
} |
} |
||||||
|
Loading…
Reference in new issue