Thneed load/save (#19700)
* start thneed load/save
* compiling
* fix loading
* build thneed model in scons
* don't hardcode /data/openpilot
* release files
* those too
* support for loading/saving binary kernels
* save binaries out of json band
* make binary a command line flag to the compiler
* need include assert
* fix shadowed common in SConscript
* cleanup run.h
* hmm, the recurrent buffer wasn't 0ed
* ugh, unique ptr
* remove power constraint, refactor record
* Revert "remove power constraint, refactor record"
This reverts commit bb6fa52db6df59cd9d6420a6f630430e35af8a5e.
* print on thneed stop
* fingers crossed for this one
* recorded
* just curious
* okay okay, pass tests?
* cleanups
* refactor wait
Co-authored-by: Comma Device <device@comma.ai>
Co-authored-by: Adeeb Shihadeh <adeebshihadeh@gmail.com>
old-commit-hash: 59fac9fdc6
commatwo_master
parent
5c2c6691b2
commit
3484683199
12 changed files with 561 additions and 97 deletions
@ -0,0 +1,41 @@ |
||||
#include "thneedmodel.h" |
||||
#include <assert.h> |
||||
|
||||
ThneedModel::ThneedModel(const char *path, float *loutput, size_t loutput_size, int runtime) { |
||||
thneed = new Thneed(true); |
||||
thneed->record = 0; |
||||
thneed->load(path); |
||||
thneed->clexec(); |
||||
thneed->find_inputs_outputs(); |
||||
|
||||
recorded = false; |
||||
output = loutput; |
||||
} |
||||
|
||||
void ThneedModel::addRecurrent(float *state, int state_size) { |
||||
recurrent = state; |
||||
} |
||||
|
||||
void ThneedModel::addTrafficConvention(float *state, int state_size) { |
||||
trafficConvention = state; |
||||
} |
||||
|
||||
void ThneedModel::addDesire(float *state, int state_size) { |
||||
desire = state; |
||||
} |
||||
|
||||
void ThneedModel::execute(float *net_input_buf, int buf_size) { |
||||
float *inputs[4] = {recurrent, trafficConvention, desire, net_input_buf}; |
||||
if (!recorded) { |
||||
thneed->record = THNEED_RECORD; |
||||
thneed->copy_inputs(inputs); |
||||
thneed->clexec(); |
||||
thneed->copy_output(output); |
||||
thneed->stop(); |
||||
|
||||
recorded = true; |
||||
} else { |
||||
thneed->execute(inputs, output); |
||||
} |
||||
} |
||||
|
@ -0,0 +1,24 @@ |
||||
#pragma once |
||||
|
||||
#include "runmodel.h" |
||||
#include "thneed/thneed.h" |
||||
|
||||
class ThneedModel : public RunModel { |
||||
public: |
||||
ThneedModel(const char *path, float *loutput, size_t loutput_size, int runtime); |
||||
void addRecurrent(float *state, int state_size); |
||||
void addTrafficConvention(float *state, int state_size); |
||||
void addDesire(float *state, int state_size); |
||||
void execute(float *net_input_buf, int buf_size); |
||||
private: |
||||
Thneed *thneed = NULL; |
||||
bool recorded; |
||||
|
||||
float *output; |
||||
|
||||
// recurrent and desire
|
||||
float *recurrent; |
||||
float *trafficConvention; |
||||
float *desire; |
||||
}; |
||||
|
@ -0,0 +1,34 @@ |
||||
#include <string.h> |
||||
#include "thneed.h" |
||||
#include "../runners/snpemodel.h" |
||||
|
||||
#define TEMPORAL_SIZE 512 |
||||
#define DESIRE_LEN 8 |
||||
#define TRAFFIC_CONVENTION_LEN 2 |
||||
|
||||
// TODO: This should probably use SNPE directly.
|
||||
int main(int argc, char* argv[]) { |
||||
#define OUTPUT_SIZE 0x10000 |
||||
float *output = (float*)calloc(OUTPUT_SIZE, sizeof(float)); |
||||
SNPEModel mdl(argv[1], output, 0, USE_GPU_RUNTIME); |
||||
|
||||
float state[TEMPORAL_SIZE] = {0}; |
||||
float desire[DESIRE_LEN] = {0}; |
||||
float traffic_convention[TRAFFIC_CONVENTION_LEN] = {0}; |
||||
float *input = (float*)calloc(0x1000000, sizeof(float));; |
||||
|
||||
mdl.addRecurrent(state, TEMPORAL_SIZE); |
||||
mdl.addDesire(desire, DESIRE_LEN); |
||||
mdl.addTrafficConvention(traffic_convention, TRAFFIC_CONVENTION_LEN); |
||||
|
||||
// first run
|
||||
printf("************** execute 1 **************\n"); |
||||
memset(output, 0, OUTPUT_SIZE * sizeof(float)); |
||||
mdl.execute(input, 0); |
||||
|
||||
// save model
|
||||
bool save_binaries = (argc > 3) && (strcmp(argv[3], "--binary") == 0); |
||||
mdl.thneed->save(argv[2], save_binaries); |
||||
return 0; |
||||
} |
||||
|
@ -0,0 +1,290 @@ |
||||
#include <set> |
||||
#include <assert.h> |
||||
#include "thneed.h" |
||||
#include "json11.hpp" |
||||
using namespace json11; |
||||
|
||||
extern map<cl_program, string> g_program_source; |
||||
|
||||
void Thneed::load(const char *filename) { |
||||
printf("Thneed::load: loading from %s\n", filename); |
||||
|
||||
FILE *f = fopen(filename, "rb"); |
||||
fseek(f, 0L, SEEK_END); |
||||
int sz = ftell(f); |
||||
fseek(f, 0L, SEEK_SET); |
||||
char *buf = (char*)malloc(sz); |
||||
fread(buf, 1, sz, f); |
||||
fclose(f); |
||||
|
||||
int jsz = *(int *)buf; |
||||
string jj(buf+4, jsz); |
||||
string err; |
||||
Json jdat = Json::parse(jj, err); |
||||
|
||||
map<cl_mem, cl_mem> real_mem; |
||||
real_mem[NULL] = NULL; |
||||
|
||||
int ptr = 4+jsz; |
||||
for (auto &obj : jdat["objects"].array_items()) { |
||||
auto mobj = obj.object_items(); |
||||
int sz = mobj["size"].int_value(); |
||||
cl_mem clbuf = NULL; |
||||
|
||||
if (mobj["buffer_id"].string_value().size() > 0) { |
||||
// image buffer must already be allocated
|
||||
clbuf = real_mem[*(cl_mem*)(mobj["buffer_id"].string_value().data())]; |
||||
assert(mobj["needs_load"].bool_value() == false); |
||||
} else { |
||||
if (mobj["needs_load"].bool_value()) { |
||||
//printf("loading %p %d @ 0x%X\n", clbuf, sz, ptr);
|
||||
clbuf = clCreateBuffer(context, CL_MEM_COPY_HOST_PTR | CL_MEM_READ_WRITE, sz, &buf[ptr], NULL); |
||||
ptr += sz; |
||||
} else { |
||||
clbuf = clCreateBuffer(context, CL_MEM_READ_WRITE, sz, NULL, NULL); |
||||
} |
||||
} |
||||
assert(clbuf != NULL); |
||||
|
||||
if (mobj["arg_type"] == "image2d_t" || mobj["arg_type"] == "image1d_t") { |
||||
cl_image_desc desc = {0}; |
||||
desc.image_type = (mobj["arg_type"] == "image2d_t") ? CL_MEM_OBJECT_IMAGE2D : CL_MEM_OBJECT_IMAGE1D_BUFFER; |
||||
desc.image_width = mobj["width"].int_value(); |
||||
desc.image_height = mobj["height"].int_value(); |
||||
desc.image_row_pitch = mobj["row_pitch"].int_value(); |
||||
desc.buffer = clbuf; |
||||
|
||||
cl_image_format format; |
||||
format.image_channel_order = CL_RGBA; |
||||
format.image_channel_data_type = CL_HALF_FLOAT; |
||||
|
||||
clbuf = clCreateImage(context, CL_MEM_READ_WRITE, &format, &desc, NULL, NULL); |
||||
assert(clbuf != NULL); |
||||
} |
||||
|
||||
real_mem[*(cl_mem*)(mobj["id"].string_value().data())] = clbuf; |
||||
} |
||||
|
||||
map<string, cl_program> g_programs; |
||||
for (auto &obj : jdat["programs"].object_items()) { |
||||
const char *srcs[1]; |
||||
srcs[0] = (const char *)obj.second.string_value().c_str(); |
||||
size_t length = obj.second.string_value().size(); |
||||
|
||||
if (record & THNEED_DEBUG) printf("building %s with size %zu\n", obj.first.c_str(), length); |
||||
|
||||
cl_program program = clCreateProgramWithSource(context, 1, srcs, &length, NULL); |
||||
int err = clBuildProgram(program, 1, &device_id, "", NULL, NULL); |
||||
if (err != 0) { |
||||
printf("got err %d\n", err); |
||||
size_t length; |
||||
char buffer[2048]; |
||||
clGetProgramBuildInfo(program, device_id, CL_PROGRAM_BUILD_LOG, sizeof(buffer), buffer, &length); |
||||
buffer[length] = '\0'; |
||||
printf("%s\n", buffer); |
||||
} |
||||
assert(err == 0); |
||||
|
||||
g_programs[obj.first] = program; |
||||
} |
||||
|
||||
for (auto &obj : jdat["binaries"].array_items()) { |
||||
string name = obj["name"].string_value(); |
||||
size_t length = obj["length"].int_value(); |
||||
const unsigned char *srcs[1]; |
||||
srcs[0] = (const unsigned char *)&buf[ptr]; |
||||
ptr += length; |
||||
|
||||
if (record & THNEED_DEBUG) printf("binary %s with size %zu\n", name.c_str(), length); |
||||
|
||||
cl_int err; |
||||
cl_program program = clCreateProgramWithBinary(context, 1, &device_id, &length, srcs, NULL, &err); |
||||
assert(program != NULL && err == CL_SUCCESS); |
||||
err = clBuildProgram(program, 1, &device_id, "", NULL, NULL); |
||||
assert(err == CL_SUCCESS); |
||||
|
||||
g_programs[name] = program; |
||||
} |
||||
|
||||
for (auto &obj : jdat["kernels"].array_items()) { |
||||
auto gws = obj["global_work_size"]; |
||||
auto lws = obj["local_work_size"]; |
||||
auto kk = shared_ptr<CLQueuedKernel>(new CLQueuedKernel(this)); |
||||
|
||||
kk->name = obj["name"].string_value(); |
||||
kk->program = g_programs[kk->name]; |
||||
kk->work_dim = obj["work_dim"].int_value(); |
||||
for (int i = 0; i < kk->work_dim; i++) { |
||||
kk->global_work_size[i] = gws[i].int_value(); |
||||
kk->local_work_size[i] = lws[i].int_value(); |
||||
} |
||||
kk->num_args = obj["num_args"].int_value(); |
||||
for (int i = 0; i < kk->num_args; i++) { |
||||
string arg = obj["args"].array_items()[i].string_value(); |
||||
int arg_size = obj["args_size"].array_items()[i].int_value(); |
||||
kk->args_size.push_back(arg_size); |
||||
if (arg_size == 8) { |
||||
cl_mem val = *(cl_mem*)(arg.data()); |
||||
val = real_mem[val]; |
||||
kk->args.push_back(string((char*)&val, sizeof(val))); |
||||
} else { |
||||
kk->args.push_back(arg); |
||||
} |
||||
} |
||||
kq.push_back(kk); |
||||
} |
||||
|
||||
free(buf); |
||||
clFinish(command_queue); |
||||
} |
||||
|
||||
void Thneed::save(const char *filename, bool save_binaries) { |
||||
printf("Thneed::save: saving to %s\n", filename); |
||||
|
||||
// get kernels
|
||||
std::vector<Json> kernels; |
||||
std::set<string> saved_objects; |
||||
std::vector<Json> objects; |
||||
std::map<string, string> programs; |
||||
std::map<string, string> binaries; |
||||
|
||||
for (auto &k : kq) { |
||||
kernels.push_back(k->to_json()); |
||||
|
||||
// check args for objects
|
||||
int i = 0; |
||||
for (auto &a : k->args) { |
||||
if (a.size() == 8) { |
||||
if (saved_objects.find(a) == saved_objects.end()) { |
||||
saved_objects.insert(a); |
||||
cl_mem val = *(cl_mem*)(a.data()); |
||||
if (val != NULL) { |
||||
bool needs_load = k->arg_names[i] == "weights" || k->arg_names[i] == "biases"; |
||||
|
||||
auto jj = Json::object({ |
||||
{"id", a}, |
||||
{"arg_type", k->arg_types[i]}, |
||||
}); |
||||
|
||||
if (k->arg_types[i] == "image2d_t" || k->arg_types[i] == "image1d_t") { |
||||
cl_mem buf; |
||||
clGetImageInfo(val, CL_IMAGE_BUFFER, sizeof(buf), &buf, NULL); |
||||
string aa = string((char *)&buf, sizeof(buf)); |
||||
jj["buffer_id"] = aa; |
||||
|
||||
size_t width, height, row_pitch; |
||||
clGetImageInfo(val, CL_IMAGE_WIDTH, sizeof(width), &width, NULL); |
||||
clGetImageInfo(val, CL_IMAGE_HEIGHT, sizeof(height), &height, NULL); |
||||
clGetImageInfo(val, CL_IMAGE_ROW_PITCH, sizeof(row_pitch), &row_pitch, NULL); |
||||
jj["width"] = (int)width; |
||||
jj["height"] = (int)height; |
||||
jj["row_pitch"] = (int)row_pitch; |
||||
jj["size"] = (int)(height * row_pitch); |
||||
jj["needs_load"] = false; |
||||
|
||||
if (saved_objects.find(aa) == saved_objects.end()) { |
||||
saved_objects.insert(aa); |
||||
size_t sz; |
||||
clGetMemObjectInfo(buf, CL_MEM_SIZE, sizeof(sz), &sz, NULL); |
||||
// save the buffer
|
||||
objects.push_back(Json::object({ |
||||
{"id", aa}, |
||||
{"arg_type", "<image buffer>"}, |
||||
{"needs_load", needs_load}, |
||||
{"size", (int)sz} |
||||
})); |
||||
if (needs_load) assert(sz == height * row_pitch); |
||||
} |
||||
} else { |
||||
size_t sz = 0; |
||||
clGetMemObjectInfo(val, CL_MEM_SIZE, sizeof(sz), &sz, NULL); |
||||
jj["size"] = (int)sz; |
||||
jj["needs_load"] = needs_load; |
||||
} |
||||
|
||||
objects.push_back(jj); |
||||
} |
||||
} |
||||
} |
||||
i++; |
||||
} |
||||
|
||||
if (save_binaries) { |
||||
int err; |
||||
size_t binary_size = 0; |
||||
err = clGetProgramInfo(k->program, CL_PROGRAM_BINARY_SIZES, sizeof(binary_size), &binary_size, NULL); |
||||
assert(err == 0); |
||||
assert(binary_size > 0); |
||||
string sv(binary_size, '\x00'); |
||||
|
||||
uint8_t* bufs[1] = { (uint8_t*)sv.data(), }; |
||||
err = clGetProgramInfo(k->program, CL_PROGRAM_BINARIES, sizeof(bufs), &bufs, NULL); |
||||
assert(err == 0); |
||||
|
||||
binaries[k->name] = sv; |
||||
} else { |
||||
programs[k->name] = g_program_source[k->program]; |
||||
} |
||||
} |
||||
|
||||
vector<string> saved_buffers; |
||||
for (auto &obj : objects) { |
||||
auto mobj = obj.object_items(); |
||||
cl_mem val = *(cl_mem*)(mobj["id"].string_value().data()); |
||||
int sz = mobj["size"].int_value(); |
||||
if (mobj["needs_load"].bool_value()) { |
||||
char *buf = (char *)malloc(sz); |
||||
if (mobj["arg_type"] == "image2d_t" || mobj["arg_type"] == "image1d_t") { |
||||
assert(false); |
||||
} else { |
||||
// buffers alloced with CL_MEM_HOST_WRITE_ONLY, hence this hack
|
||||
//hexdump((uint32_t*)val, 0x100);
|
||||
|
||||
// the worst hack in thneed, the flags are at 0x14
|
||||
((uint32_t*)val)[0x14] &= ~CL_MEM_HOST_WRITE_ONLY; |
||||
cl_int ret = clEnqueueReadBuffer(command_queue, val, CL_TRUE, 0, sz, buf, 0, NULL, NULL); |
||||
assert(ret == CL_SUCCESS); |
||||
} |
||||
//printf("saving buffer: %d %p %s\n", sz, buf, mobj["arg_type"].string_value().c_str());
|
||||
saved_buffers.push_back(string(buf, sz)); |
||||
free(buf); |
||||
} |
||||
} |
||||
|
||||
std::vector<Json> jbinaries; |
||||
for (auto &obj : binaries) { |
||||
jbinaries.push_back(Json::object({{"name", obj.first}, {"length", (int)obj.second.size()}})); |
||||
saved_buffers.push_back(obj.second); |
||||
} |
||||
|
||||
Json jdat = Json::object({ |
||||
{"kernels", kernels}, |
||||
{"objects", objects}, |
||||
{"programs", programs}, |
||||
{"binaries", jbinaries}, |
||||
}); |
||||
|
||||
string str = jdat.dump(); |
||||
int jsz = str.length(); |
||||
|
||||
FILE *f = fopen(filename, "wb"); |
||||
fwrite(&jsz, 1, sizeof(jsz), f); |
||||
fwrite(str.data(), 1, jsz, f); |
||||
for (auto &s : saved_buffers) { |
||||
fwrite(s.data(), 1, s.length(), f); |
||||
} |
||||
fclose(f); |
||||
} |
||||
|
||||
Json CLQueuedKernel::to_json() const { |
||||
return Json::object { |
||||
{ "name", name }, |
||||
{ "work_dim", (int)work_dim }, |
||||
{ "global_work_size", Json::array { (int)global_work_size[0], (int)global_work_size[1], (int)global_work_size[2] } }, |
||||
{ "local_work_size", Json::array { (int)local_work_size[0], (int)local_work_size[1], (int)local_work_size[2] } }, |
||||
{ "num_args", (int)num_args }, |
||||
{ "args", args }, |
||||
{ "args_size", args_size }, |
||||
}; |
||||
} |
||||
|
Loading…
Reference in new issue