You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
			
				
					118 lines
				
				2.9 KiB
			
		
		
			
		
	
	
					118 lines
				
				2.9 KiB
			| 
											5 years ago
										 | #include "onnxmodel.h"
 | ||
| 
											6 years ago
										 | #include <stdio.h>
 | ||
| 
											6 years ago
										 | #include <string>
 | ||
|  | #include <string.h>
 | ||
| 
											5 years ago
										 | #include <poll.h>
 | ||
| 
											6 years ago
										 | #include <signal.h>
 | ||
|  | #include <unistd.h>
 | ||
| 
											6 years ago
										 | #include <stdlib.h>
 | ||
|  | #include <stdexcept>
 | ||
|  | #include "common/util.h"
 | ||
| 
											6 years ago
										 | #include "common/utilpp.h"
 | ||
| 
											6 years ago
										 | #include "common/swaglog.h"
 | ||
|  | #include <cassert>
 | ||
|  | 
 | ||
|  | 
 | ||
| 
											5 years ago
										 | ONNXModel::ONNXModel(const char *path, float *_output, size_t _output_size, int runtime) {
 | ||
| 
											6 years ago
										 |   output = _output;
 | ||
|  |   output_size = _output_size;
 | ||
| 
											6 years ago
										 | 
 | ||
| 
											6 years ago
										 |   char tmp[1024];
 | ||
|  |   strncpy(tmp, path, sizeof(tmp));
 | ||
|  |   strstr(tmp, ".dlc")[0] = '\0';
 | ||
| 
											5 years ago
										 |   strcat(tmp, ".onnx");
 | ||
| 
											6 years ago
										 |   LOGD("loading model %s", tmp);
 | ||
|  | 
 | ||
|  |   assert(pipe(pipein) == 0);
 | ||
|  |   assert(pipe(pipeout) == 0);
 | ||
|  | 
 | ||
|  |   std::string exe_dir = util::dir_name(util::readlink("/proc/self/exe"));
 | ||
| 
											5 years ago
										 |   std::string onnx_runner = exe_dir + "/runners/onnx_runner.py";
 | ||
| 
											6 years ago
										 | 
 | ||
|  |   proc_pid = fork();
 | ||
|  |   if (proc_pid == 0) {
 | ||
| 
											5 years ago
										 |     LOGD("spawning onnx process %s", onnx_runner.c_str());
 | ||
|  |     char *argv[] = {(char*)onnx_runner.c_str(), tmp, NULL};
 | ||
| 
											6 years ago
										 |     dup2(pipein[0], 0);
 | ||
|  |     dup2(pipeout[1], 1);
 | ||
|  |     close(pipein[0]);
 | ||
|  |     close(pipein[1]);
 | ||
|  |     close(pipeout[0]);
 | ||
|  |     close(pipeout[1]);
 | ||
| 
											5 years ago
										 |     execvp(onnx_runner.c_str(), argv);
 | ||
| 
											6 years ago
										 |   }
 | ||
|  | 
 | ||
| 
											6 years ago
										 |   // parent
 | ||
|  |   close(pipein[0]);
 | ||
|  |   close(pipeout[1]);
 | ||
|  | }
 | ||
| 
											6 years ago
										 | 
 | ||
| 
											5 years ago
										 | ONNXModel::~ONNXModel() {
 | ||
| 
											6 years ago
										 |   close(pipein[1]);
 | ||
|  |   close(pipeout[0]);
 | ||
|  |   kill(proc_pid, SIGTERM);
 | ||
|  | }
 | ||
| 
											6 years ago
										 | 
 | ||
| 
											5 years ago
										 | void ONNXModel::pwrite(float *buf, int size) {
 | ||
| 
											6 years ago
										 |   char *cbuf = (char *)buf;
 | ||
|  |   int tw = size*sizeof(float);
 | ||
|  |   while (tw > 0) {
 | ||
|  |     int err = write(pipein[1], cbuf, tw);
 | ||
|  |     //printf("host write %d\n", err);
 | ||
|  |     assert(err >= 0);
 | ||
|  |     cbuf += err;
 | ||
|  |     tw -= err;
 | ||
| 
											6 years ago
										 |   }
 | ||
| 
											5 years ago
										 |   LOGD("host write of size %d done", size);
 | ||
| 
											6 years ago
										 | }
 | ||
|  | 
 | ||
| 
											5 years ago
										 | void ONNXModel::pread(float *buf, int size) {
 | ||
| 
											6 years ago
										 |   char *cbuf = (char *)buf;
 | ||
|  |   int tr = size*sizeof(float);
 | ||
| 
											5 years ago
										 |   struct pollfd fds[1];
 | ||
|  |   fds[0].fd = pipeout[0];
 | ||
|  |   fds[0].events = POLLIN;
 | ||
| 
											6 years ago
										 |   while (tr > 0) {
 | ||
| 
											5 years ago
										 |     int err;
 | ||
|  |     err = poll(fds, 1, 10000);  // 10 second timeout
 | ||
|  |     assert(err == 1 || (err == -1 && errno == EINTR));
 | ||
|  |     LOGD("host read remaining %d/%d poll %d", tr, size*sizeof(float), err);
 | ||
|  |     err = read(pipeout[0], cbuf, tr);
 | ||
|  |     assert(err > 0 || (err == 0 && errno == EINTR));
 | ||
| 
											6 years ago
										 |     cbuf += err;
 | ||
|  |     tr -= err;
 | ||
|  |   }
 | ||
| 
											5 years ago
										 |   LOGD("host read done");
 | ||
| 
											6 years ago
										 | }
 | ||
|  | 
 | ||
| 
											5 years ago
										 | void ONNXModel::addRecurrent(float *state, int state_size) {
 | ||
| 
											6 years ago
										 |   rnn_input_buf = state;
 | ||
| 
											6 years ago
										 |   rnn_state_size = state_size;
 | ||
| 
											6 years ago
										 | }
 | ||
|  | 
 | ||
| 
											5 years ago
										 | void ONNXModel::addDesire(float *state, int state_size) {
 | ||
| 
											6 years ago
										 |   desire_input_buf = state;
 | ||
| 
											6 years ago
										 |   desire_state_size = state_size;
 | ||
| 
											6 years ago
										 | }
 | ||
|  | 
 | ||
| 
											5 years ago
										 | void ONNXModel::addTrafficConvention(float *state, int state_size) {
 | ||
| 
											6 years ago
										 |   traffic_convention_input_buf = state;
 | ||
|  |   traffic_convention_size = state_size;
 | ||
|  | }
 | ||
|  | 
 | ||
| 
											5 years ago
										 | void ONNXModel::execute(float *net_input_buf, int buf_size) {
 | ||
| 
											6 years ago
										 |   // order must be this
 | ||
|  |   pwrite(net_input_buf, buf_size);
 | ||
| 
											6 years ago
										 |   if (desire_input_buf != NULL) {
 | ||
|  |     pwrite(desire_input_buf, desire_state_size);
 | ||
|  |   }
 | ||
| 
											6 years ago
										 |   if (traffic_convention_input_buf != NULL) {
 | ||
|  |     pwrite(traffic_convention_input_buf, traffic_convention_size);
 | ||
|  |   }
 | ||
| 
											5 years ago
										 |   if (rnn_input_buf != NULL) {
 | ||
|  |     pwrite(rnn_input_buf, rnn_state_size);
 | ||
|  |   }
 | ||
| 
											6 years ago
										 |   pread(output, output_size);
 | ||
| 
											6 years ago
										 | }
 | ||
|  | 
 |