diff --git a/pyextra/acados_template/acados_ocp_solver_pyx.pyx b/pyextra/acados_template/acados_ocp_solver_pyx.pyx index 0fd4ef7262..06669b2936 100644 --- a/pyextra/acados_template/acados_ocp_solver_pyx.pyx +++ b/pyextra/acados_template/acados_ocp_solver_pyx.pyx @@ -37,12 +37,7 @@ # cimport cython - from libc cimport string -from libc cimport bool - -cimport numpy as np -from cpython cimport array cimport acados_solver_common cimport acados_solver @@ -50,11 +45,8 @@ cimport acados_solver import os import json import numpy as np -import array from datetime import datetime -# from .utils import np_array_to_list - cdef class AcadosOcpSolverFast: """ @@ -102,9 +94,7 @@ cdef class AcadosOcpSolverFast: """ Solve the ocp with current input. """ - - status = acados_solver.acados_solve(self.capsule) - return status + return acados_solver.acados_solve(self.capsule) def get_slice(self, int start_stage_, int end_stage_, str field_): field = field_.encode('utf-8') @@ -113,7 +103,7 @@ cdef class AcadosOcpSolverFast: self.fill_in_slice(start_stage_, end_stage_, field_, out) return out - def fill_in_slice(self, int start_stage_, int end_stage_, str field_, np.ndarray[np.float64_t, ndim=2, mode='c'] arr_): + def fill_in_slice(self, int start_stage_, int end_stage_, str field_, double[:,:] arr): out_fields = ['x', 'u', 'z', 'pi', 'lam', 't'] mem_fields = ['sl', 'su'] @@ -129,17 +119,15 @@ cdef class AcadosOcpSolverFast: if start_stage_ < 0 or end_stage_ > self.N + 1: raise Exception('AcadosOcpSolver.get_slice(): stage index must be in [0, N], got: {}.'.format(self.N)) - cdef np.ndarray[np.float64_t, ndim=2, mode='c'] arr = np.ascontiguousarray(arr_, dtype=np.double) - if (field_ in out_fields): - acados_solver_common.ocp_nlp_out_get_slice(self.nlp_config, self.nlp_dims, self.nlp_out, start_stage_, end_stage_, - field, arr.data) + acados_solver_common.ocp_nlp_out_get_slice(self.nlp_config, self.nlp_dims, self.nlp_out, + start_stage_, end_stage_, field, &arr[0][0]) elif field_ in mem_fields: raise NotImplementedError() # acados_solver_common.ocp_nlp_get_at_stage(self.nlp_config, self.nlp_dims, self.nlp_solver, start_stage_, end_stage_, # field, arr.data) - def get(self, stage_, field_): + def get(self, int stage_, field_): return self.get_slice(stage_, stage_ + 1, field_)[0] @@ -305,14 +293,14 @@ cdef class AcadosOcpSolverFast: # compute cost internally acados_solver_common.ocp_nlp_eval_cost(self.nlp_solver, self.nlp_in, self.nlp_out) - # create output array - out = np.ascontiguousarray(np.zeros((1,)), dtype=np.float64) + # create output + cdef double out # call getter field = "cost_value".encode('utf-8') - acados_solver_common.ocp_nlp_get(self.nlp_config, self.nlp_solver, field, out.data) + acados_solver_common.ocp_nlp_get(self.nlp_config, self.nlp_solver, field, &out) - return out[0] + return out def get_residuals(self): @@ -387,9 +375,9 @@ cdef class AcadosOcpSolverFast: msg += 'with dimension {} (you have {})'.format(dims, value_.shape) raise Exception(msg) - cdef np.ndarray[np.float64_t, ndim=1, mode='c'] value = np.ascontiguousarray(value_, dtype=np.double) + cdef double[:] value = np.ascontiguousarray(value_, dtype=np.double) - value_data_p = value.data + value_data_p = &value[0] if field_ in constraints_fields: acados_solver_common.ocp_nlp_constraints_model_set(self.nlp_config, self.nlp_dims, self.nlp_in, stage, field, value_data_p) elif field_ in cost_fields: @@ -401,14 +389,13 @@ cdef class AcadosOcpSolverFast: return - def set_param(self, stage_, np.ndarray[np.float64_t, ndim=1] value_): - cdef np.ndarray[np.float64_t, ndim=1, mode='c'] value = np.ascontiguousarray(value_, dtype=np.double) - acados_solver.acados_update_params(self.capsule, stage_, value_.data, value_.shape[0]) + def set_param(self, int stage_, double[:] value): + acados_solver.acados_update_params(self.capsule, stage_, &value[0], value.shape[0]) - def cost_set(self, start_stage_, field_, value_, api='warn'): + def cost_set(self, int start_stage_, field_, value_, api='warn'): self.cost_set_slice(start_stage_, start_stage_+1, field_, value_[None], api='warn') - def cost_set_slice(self, start_stage_, end_stage_, field_, value_, api='warn'): + def cost_set_slice(self, int start_stage_, int end_stage_, field_, value_, api='warn'): """ Set numerical data in the cost module of the solver. @@ -424,17 +411,17 @@ cdef class AcadosOcpSolverFast: dim = value_.shape[1] value_ = value_[None,:,:] - cdef np.ndarray[np.float64_t, ndim=3, mode='c'] value = np.ascontiguousarray(value_, dtype=np.double) + cdef double[:,:,:] value = np.ascontiguousarray(value_, dtype=np.double) - acados_solver_common.ocp_nlp_cost_model_set_slice(self.nlp_config, self.nlp_dims, self.nlp_in, start_stage_, end_stage_, - field, value.data, dim) + acados_solver_common.ocp_nlp_cost_model_set_slice(self.nlp_config, self.nlp_dims, self.nlp_in, + start_stage_, end_stage_, field, &value[0][0][0], dim) - def constraints_set(self, start_stage_, field_, value_, api='warn'): + def constraints_set(self, int start_stage_, field_, value_, api='warn'): self.constraints_set_slice(start_stage_, start_stage_+1, field_, value_[None], api='warn') - def constraints_set_slice(self, start_stage_, end_stage_, field_, value_, api='warn'): + def constraints_set_slice(self, int start_stage_, int end_stage_, field_, value_, api='warn'): """ Set numerical data in the constraint module of the solver. @@ -450,10 +437,9 @@ cdef class AcadosOcpSolverFast: dim = value_.shape[1] value_ = value_[None,:,:] - cdef np.ndarray[np.float64_t, ndim=3, mode='c'] value = np.ascontiguousarray(value_, dtype=np.double) - - acados_solver_common.ocp_nlp_constraints_model_set_slice(self.nlp_config, self.nlp_dims, self.nlp_in, start_stage_, end_stage_, - field, value.data, dim) + cdef double[:,:,:] value = np.ascontiguousarray(value_, dtype=np.double) + acados_solver_common.ocp_nlp_constraints_model_set_slice(self.nlp_config, self.nlp_dims, self.nlp_in, + start_stage_, end_stage_, field, &value[0][0][0], dim) def dynamics_get(self, int stage, field_): @@ -538,8 +524,3 @@ cdef class AcadosOcpSolverFast: if self.solver_created: acados_solver.acados_free(self.capsule) acados_solver.acados_free_capsule(self.capsule) - - # try: - # self.dlclose(self.shared_lib._handle) - # except: - # pass