Refactor model runners (#28598)
	
		
	
				
					
				
			* Started work on model runner refactor
* Fixed some compile errors
* everything compiles
* Fixed bug in SNPEModel
* updateInput -> setInputBuffer
* I understand nothing
* whoops lol
* use std::string instead of char*
* Move common logic into RunModel
* formatting fix
old-commit-hash: c9f00678af
			
			
				beeps
			
			
		
							parent
							
								
									fdc8876745
								
							
						
					
					
						commit
						95051090a1
					
				
				 10 changed files with 190 additions and 414 deletions
			
			
		@ -1,18 +1,45 @@ | 
				
			||||
#pragma once | 
				
			||||
 | 
				
			||||
#include <string> | 
				
			||||
#include <vector> | 
				
			||||
#include <memory> | 
				
			||||
#include <cassert> | 
				
			||||
 | 
				
			||||
#include "common/clutil.h" | 
				
			||||
#include "common/swaglog.h" | 
				
			||||
 | 
				
			||||
struct ModelInput { | 
				
			||||
  const std::string name; | 
				
			||||
  float *buffer; | 
				
			||||
  int size; | 
				
			||||
 | 
				
			||||
  ModelInput(const std::string _name, float *_buffer, int _size) : name(_name), buffer(_buffer), size(_size) {} | 
				
			||||
  virtual void setBuffer(float *_buffer, int _size) { | 
				
			||||
    assert(size == _size || size == 0); | 
				
			||||
    buffer = _buffer; | 
				
			||||
    size = _size; | 
				
			||||
  } | 
				
			||||
}; | 
				
			||||
 | 
				
			||||
class RunModel { | 
				
			||||
public: | 
				
			||||
  std::vector<std::unique_ptr<ModelInput>> inputs; | 
				
			||||
 | 
				
			||||
  virtual ~RunModel() {} | 
				
			||||
  virtual void addRecurrent(float *state, int state_size) {} | 
				
			||||
  virtual void addDesire(float *state, int state_size) {} | 
				
			||||
  virtual void addNavFeatures(float *state, int state_size) {} | 
				
			||||
  virtual void addDrivingStyle(float *state, int state_size) {} | 
				
			||||
  virtual void addTrafficConvention(float *state, int state_size) {} | 
				
			||||
  virtual void addCalib(float *state, int state_size) {} | 
				
			||||
  virtual void addImage(float *image_buf, int buf_size) {} | 
				
			||||
  virtual void addExtra(float *image_buf, int buf_size) {} | 
				
			||||
  virtual void execute() {} | 
				
			||||
  virtual void* getInputBuf() { return nullptr; } | 
				
			||||
  virtual void* getExtraBuf() { return nullptr; } | 
				
			||||
}; | 
				
			||||
  virtual void* getCLBuffer(const std::string name) { return nullptr; } | 
				
			||||
 | 
				
			||||
  virtual void addInput(const std::string name, float *buffer, int size) { | 
				
			||||
    inputs.push_back(std::unique_ptr<ModelInput>(new ModelInput(name, buffer, size))); | 
				
			||||
  } | 
				
			||||
  virtual void setInputBuffer(const std::string name, float *buffer, int size) { | 
				
			||||
    for (auto &input : inputs) { | 
				
			||||
      if (name == input->name) { | 
				
			||||
        input->setBuffer(buffer, size); | 
				
			||||
        return; | 
				
			||||
      } | 
				
			||||
    } | 
				
			||||
    LOGE("Tried to update input `%s` but no input with this name exists", name.c_str()); | 
				
			||||
    assert(false); | 
				
			||||
  } | 
				
			||||
}; | 
				
			||||
 | 
				
			||||
@ -1,78 +1,56 @@ | 
				
			||||
#include "selfdrive/modeld/runners/thneedmodel.h" | 
				
			||||
 | 
				
			||||
#include <cassert> | 
				
			||||
#include "common/swaglog.h" | 
				
			||||
 | 
				
			||||
ThneedModel::ThneedModel(const char *path, float *loutput, size_t loutput_size, int runtime, bool luse_extra, bool luse_tf8, cl_context context) { | 
				
			||||
ThneedModel::ThneedModel(const std::string path, float *_output, size_t _output_size, int runtime, bool luse_tf8, cl_context context) { | 
				
			||||
  thneed = new Thneed(true, context); | 
				
			||||
  thneed->load(path); | 
				
			||||
  thneed->load(path.c_str()); | 
				
			||||
  thneed->clexec(); | 
				
			||||
 | 
				
			||||
  recorded = false; | 
				
			||||
  output = loutput; | 
				
			||||
  use_extra = luse_extra; | 
				
			||||
  output = _output; | 
				
			||||
} | 
				
			||||
 | 
				
			||||
void ThneedModel::addRecurrent(float *state, int state_size) { | 
				
			||||
  recurrent = state; | 
				
			||||
} | 
				
			||||
 | 
				
			||||
void ThneedModel::addTrafficConvention(float *state, int state_size) { | 
				
			||||
  trafficConvention = state; | 
				
			||||
} | 
				
			||||
 | 
				
			||||
void ThneedModel::addDesire(float *state, int state_size) { | 
				
			||||
  desire = state; | 
				
			||||
} | 
				
			||||
 | 
				
			||||
void ThneedModel::addDrivingStyle(float *state, int state_size) { | 
				
			||||
    drivingStyle = state; | 
				
			||||
} | 
				
			||||
 | 
				
			||||
void ThneedModel::addNavFeatures(float *state, int state_size) { | 
				
			||||
  navFeatures = state; | 
				
			||||
} | 
				
			||||
 | 
				
			||||
void ThneedModel::addImage(float *image_input_buf, int buf_size) { | 
				
			||||
  input = image_input_buf; | 
				
			||||
} | 
				
			||||
 | 
				
			||||
void ThneedModel::addExtra(float *extra_input_buf, int buf_size) { | 
				
			||||
  extra = extra_input_buf; | 
				
			||||
} | 
				
			||||
void* ThneedModel::getCLBuffer(const std::string name) { | 
				
			||||
  int index = -1; | 
				
			||||
  for (int i = 0; i < inputs.size(); i++) { | 
				
			||||
    if (name == inputs[i]->name) { | 
				
			||||
      index = i; | 
				
			||||
      break; | 
				
			||||
    } | 
				
			||||
  } | 
				
			||||
 | 
				
			||||
void* ThneedModel::getInputBuf() { | 
				
			||||
  if (use_extra && thneed->input_clmem.size() > 5) return &(thneed->input_clmem[5]); | 
				
			||||
  else if (!use_extra && thneed->input_clmem.size() > 4) return &(thneed->input_clmem[4]); | 
				
			||||
  else return nullptr; | 
				
			||||
} | 
				
			||||
  if (index == -1) { | 
				
			||||
    LOGE("Tried to get CL buffer for input `%s` but no input with this name exists", name.c_str()); | 
				
			||||
    assert(false); | 
				
			||||
  } | 
				
			||||
 | 
				
			||||
void* ThneedModel::getExtraBuf() { | 
				
			||||
  if (thneed->input_clmem.size() > 4) return &(thneed->input_clmem[4]); | 
				
			||||
  else return nullptr; | 
				
			||||
  if (thneed->input_clmem.size() >= inputs.size()) { | 
				
			||||
    return &thneed->input_clmem[inputs.size() - index - 1]; | 
				
			||||
  } else { | 
				
			||||
    return nullptr; | 
				
			||||
  } | 
				
			||||
} | 
				
			||||
 | 
				
			||||
void ThneedModel::execute() { | 
				
			||||
  if (!recorded) { | 
				
			||||
    thneed->record = true; | 
				
			||||
    if (use_extra) { | 
				
			||||
      float *inputs[6] = {recurrent, navFeatures, trafficConvention, desire, extra, input}; | 
				
			||||
      thneed->copy_inputs(inputs); | 
				
			||||
    } else { | 
				
			||||
      float *inputs[5] = {recurrent, navFeatures, trafficConvention, desire, input}; | 
				
			||||
      thneed->copy_inputs(inputs); | 
				
			||||
    float *input_buffers[inputs.size()]; | 
				
			||||
    for (int i = 0; i < inputs.size(); i++) { | 
				
			||||
      input_buffers[inputs.size() - i - 1] = inputs[i]->buffer; | 
				
			||||
    } | 
				
			||||
 | 
				
			||||
    thneed->copy_inputs(input_buffers); | 
				
			||||
    thneed->clexec(); | 
				
			||||
    thneed->copy_output(output); | 
				
			||||
    thneed->stop(); | 
				
			||||
 | 
				
			||||
    recorded = true; | 
				
			||||
  } else { | 
				
			||||
    if (use_extra) { | 
				
			||||
      float *inputs[6] = {recurrent, navFeatures, trafficConvention, desire, extra, input}; | 
				
			||||
      thneed->execute(inputs, output); | 
				
			||||
    } else { | 
				
			||||
      float *inputs[5] = {recurrent, navFeatures, trafficConvention, desire, input}; | 
				
			||||
      thneed->execute(inputs, output); | 
				
			||||
    float *input_buffers[inputs.size()]; | 
				
			||||
    for (int i = 0; i < inputs.size(); i++) { | 
				
			||||
      input_buffers[inputs.size() - i - 1] = inputs[i]->buffer; | 
				
			||||
    } | 
				
			||||
    thneed->execute(input_buffers, output); | 
				
			||||
  } | 
				
			||||
} | 
				
			||||
 | 
				
			||||
					Loading…
					
					
				
		Reference in new issue