@ -9,8 +9,9 @@ void PrintErrorStringAndExit() {
std : : exit ( EXIT_FAILURE ) ;
}
SNPEModel : : SNPEModel ( const char * path , float * loutput , size_t output_size , int runtime ) {
SNPEModel : : SNPEModel ( const char * path , float * loutput , size_t l output_size, int runtime ) {
output = loutput ;
output_size = loutput_size ;
# ifdef QCOM
if ( runtime = = USE_GPU_RUNTIME ) {
Runtime = zdl : : DlSystem : : Runtime_t : : GPU ;
@ -102,6 +103,7 @@ SNPEModel::SNPEModel(const char *path, float *loutput, size_t output_size, int r
void SNPEModel : : addRecurrent ( float * state , int state_size ) {
recurrent = state ;
recurrent_size = state_size ;
recurrentBuffer = this - > addExtra ( state , state_size , 3 ) ;
}
@ -134,21 +136,37 @@ std::unique_ptr<zdl::DlSystem::IUserBuffer> SNPEModel::addExtra(float *state, in
void SNPEModel : : execute ( float * net_input_buf , int buf_size ) {
# ifdef USE_THNEED
if ( Runtime = = zdl : : DlSystem : : Runtime_t : : GPU ) {
float * inputs [ 4 ] = { recurrent , trafficConvention , desire , net_input_buf } ;
if ( thneed = = NULL ) {
assert ( inputBuffer - > setBufferAddress ( net_input_buf ) ) ;
if ( ! snpe - > execute ( inputMap , outputMap ) ) {
PrintErrorStringAndExit ( ) ;
}
memset ( recurrent , 0 , recurrent_size * sizeof ( float ) ) ;
thneed = new Thneed ( ) ;
//thneed->record = 3;
if ( ! snpe - > execute ( inputMap , outputMap ) ) {
PrintErrorStringAndExit ( ) ;
}
thneed - > stop ( ) ;
//thneed->record = 2;
printf ( " thneed cached \n " ) ;
// doing self test
float * outputs_golden = ( float * ) malloc ( output_size * sizeof ( float ) ) ;
memcpy ( outputs_golden , output , output_size * sizeof ( float ) ) ;
memset ( output , 0 , output_size * sizeof ( float ) ) ;
memset ( recurrent , 0 , recurrent_size * sizeof ( float ) ) ;
thneed - > execute ( inputs , output ) ;
if ( memcmp ( output , outputs_golden , output_size * sizeof ( float ) ) = = 0 ) {
printf ( " thneed selftest passed \n " ) ;
} else {
for ( int i = 0 ; i < output_size ; i + + ) {
printf ( " mismatch %3d: %f %f \n " , i , output [ i ] , outputs_golden [ i ] ) ;
}
assert ( false ) ;
}
free ( outputs_golden ) ;
} else {
float * inputs [ 4 ] = { recurrent , trafficConvention , desire , net_input_buf } ;
thneed - > execute ( inputs , output ) ;
}
} else {