You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
50 lines
1.2 KiB
50 lines
1.2 KiB
4 months ago
|
#pragma once
|
||
|
|
||
|
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "common/clutil.h"
#include "common/swaglog.h"
|
||
|
|
||
|
// Execution-backend selectors. This header defines only their numeric
// values (0, 1, 2); the code that consumes them is not visible here.
// NOTE(review): presumably passed to concrete RunModel implementations to
// choose CPU / GPU / DSP execution -- confirm against callers.
// Kept as macros (not constexpr/enum) in case they are tested in
// preprocessor conditionals elsewhere.
#define USE_CPU_RUNTIME 0
#define USE_GPU_RUNTIME 1
#define USE_DSP_RUNTIME 2
|
||
|
|
||
|
/// A named model input that refers to a caller-provided float buffer.
///
/// The buffer is NOT owned: `buffer` is only stored and replaced by this
/// struct, never freed. The struct is polymorphic (setBuffer is virtual) and
/// is meant to be held through a base pointer, so it has a virtual destructor.
struct ModelInput {
  const std::string name;  // lookup key; immutable after construction
  float *buffer;           // non-owning pointer to the input data
  int size;                // size paired with `buffer`; NOTE(review): presumably a float count -- verify

  /// @param _name   identifier for this input (moved into `name`).
  /// @param _buffer data pointer (not owned, may be null).
  /// @param _size   size of `_buffer`; 0 means "not fixed yet" for setBuffer.
  ModelInput(std::string _name, float *_buffer, int _size)
      : name(std::move(_name)), buffer(_buffer), size(_size) {}

  // Inputs are stored as std::unique_ptr<ModelInput> and may be subclasses
  // (setBuffer is virtual). Deleting a derived object through a base pointer
  // without a virtual destructor is undefined behavior and skips the derived
  // members' destructors.
  virtual ~ModelInput() = default;

  // Declaring the destructor suppresses the implicit move constructor (and
  // deprecates the implicit copy), so keep both constructions explicit.
  // Assignment was never possible because `name` is const.
  ModelInput(const ModelInput &) = default;
  ModelInput(ModelInput &&) = default;
  ModelInput &operator=(const ModelInput &) = delete;
  ModelInput &operator=(ModelInput &&) = delete;

  /// Point this input at a new buffer.
  /// @param _buffer new data pointer (not owned).
  /// @param _size   must match the current size unless the current size is 0.
  virtual void setBuffer(float *_buffer, int _size) {
    // The size may be established once (from 0); after that it must not change.
    assert(size == _size || size == 0);
    buffer = _buffer;
    size = _size;
  }
};
|
||
|
|
||
|
class RunModel {
|
||
|
public:
|
||
|
std::vector<std::unique_ptr<ModelInput>> inputs;
|
||
|
|
||
|
virtual ~RunModel() {}
|
||
|
virtual void execute() {}
|
||
|
virtual void* getCLBuffer(const std::string name) { return nullptr; }
|
||
|
|
||
|
virtual void addInput(const std::string name, float *buffer, int size) {
|
||
|
inputs.push_back(std::unique_ptr<ModelInput>(new ModelInput(name, buffer, size)));
|
||
|
}
|
||
|
virtual void setInputBuffer(const std::string name, float *buffer, int size) {
|
||
|
for (auto &input : inputs) {
|
||
|
if (name == input->name) {
|
||
|
input->setBuffer(buffer, size);
|
||
|
return;
|
||
|
}
|
||
|
}
|
||
|
LOGE("Tried to update input `%s` but no input with this name exists", name.c_str());
|
||
|
assert(false);
|
||
|
}
|
||
|
};
|