from tinygrad.runtime.ops_dsp import DSPDevice

kernel = """__attribute__((noinline)) void r_64_4_4_64_4_4_4(float* restrict __attribute__((align_value(128))) data0, const float* restrict __attribute__((align_value(128))) data1, const float* restrict __attribute__((align_value(128))) data2, const float* restrict __attribute__((align_value(128))) data3) {
  for (int ridx0 = 0; ridx0 < 64; ridx0++) {
    int alu0 = (ridx0*4096);
    for (int ridx1 = 0; ridx1 < 4; ridx1++) {
      int alu1 = (ridx1*64);
      for (int ridx2 = 0; ridx2 < 4; ridx2++) {
        int alu2 = (ridx2*4);
        int alu3 = ((ridx0*1024)+alu1+alu2);
        int alu4 = (alu1+alu2);
        float val0 = data3[alu4+1];
        float val1 = data3[alu4+2];
        float val2 = data3[alu4+3];
        float val3 = data3[alu4+16];
        float val4 = data3[alu4+17];
        float val5 = data3[alu4+18];
        float val6 = data3[alu4+19];
        float val7 = data3[alu4+32];
        float val8 = data3[alu4+33];
        float val9 = data3[alu4+34];
        float val10 = data3[alu4+35];
        float val11 = data3[alu4+48];
        float val12 = data3[alu4+49];
        float val13 = data3[alu4+50];
        float val14 = data3[alu4+51];
        float val15 = data3[alu4];
        float acc0 = 0.0f;
        float acc1 = 0.0f;
        float acc2 = 0.0f;
        float acc3 = 0.0f;
        float acc4 = 0.0f;
        float acc5 = 0.0f;
        float acc6 = 0.0f;
        float acc7 = 0.0f;
        float acc8 = 0.0f;
        float acc9 = 0.0f;
        float acc10 = 0.0f;
        float acc11 = 0.0f;
        float acc12 = 0.0f;
        float acc13 = 0.0f;
        float acc14 = 0.0f;
        float acc15 = 0.0f;
        float acc16 = 0.0f;
        float acc17 = 0.0f;
        float acc18 = 0.0f;
        float acc19 = 0.0f;
        float acc20 = 0.0f;
        float acc21 = 0.0f;
        float acc22 = 0.0f;
        float acc23 = 0.0f;
        float acc24 = 0.0f;
        float acc25 = 0.0f;
        float acc26 = 0.0f;
        float acc27 = 0.0f;
        float acc28 = 0.0f;
        float acc29 = 0.0f;
        float acc30 = 0.0f;
        float acc31 = 0.0f;
        float acc32 = 0.0f;
        float acc33 = 0.0f;
        float acc34 = 0.0f;
        float acc35 = 0.0f;
        float acc36 = 0.0f;
        float acc37 = 0.0f;
        float acc38 = 0.0f;
        float acc39 = 0.0f;
        float acc40 = 0.0f;
        float acc41 = 0.0f;
        float acc42 = 0.0f;
        float acc43 = 0.0f;
        float acc44 = 0.0f;
        float acc45 = 0.0f;
        float acc46 = 0.0f;
        float acc47 = 0.0f;
        float acc48 = 0.0f;
        float acc49 = 0.0f;
        float acc50 = 0.0f;
        float acc51 = 0.0f;
        float acc52 = 0.0f;
        float acc53 = 0.0f;
        float acc54 = 0.0f;
        float acc55 = 0.0f;
        float acc56 = 0.0f;
        float acc57 = 0.0f;
        float acc58 = 0.0f;
        float acc59 = 0.0f;
        float acc60 = 0.0f;
        float acc61 = 0.0f;
        float acc62 = 0.0f;
        float acc63 = 0.0f;
        for (int ridx3 = 0; ridx3 < 64; ridx3++) {
          int alu5 = (alu0+(ridx2*256)+ridx3);
          float val16 = data2[alu5+64];
          float val17 = data2[alu5+128];
          float val18 = data2[alu5+192];
          float val19 = data2[alu5+1024];
          float val20 = data2[alu5+1088];
          float val21 = data2[alu5+1152];
          float val22 = data2[alu5+1216];
          float val23 = data2[alu5+2048];
          float val24 = data2[alu5+2112];
          float val25 = data2[alu5+2176];
          float val26 = data2[alu5+2240];
          float val27 = data2[alu5+3072];
          float val28 = data2[alu5+3136];
          float val29 = data2[alu5+3200];
          float val30 = data2[alu5+3264];
          float val31 = data2[alu5];
          int alu6 = (alu0+(ridx1*256)+ridx3);
          float val32 = data1[alu6+64];
          float val33 = data1[alu6+128];
          float val34 = data1[alu6+192];
          float val35 = data1[alu6+1024];
          float val36 = data1[alu6+1088];
          float val37 = data1[alu6+1152];
          float val38 = data1[alu6+1216];
          float val39 = data1[alu6+2048];
          float val40 = data1[alu6+2112];
          float val41 = data1[alu6+2176];
          float val42 = data1[alu6+2240];
          float val43 = data1[alu6+3072];
          float val44 = data1[alu6+3136];
          float val45 = data1[alu6+3200];
          float val46 = data1[alu6+3264];
          float val47 = data1[alu6];
          acc0 = (acc0+(val47*val31));
          acc1 = (acc1+(val35*val19));
          acc2 = (acc2+(val39*val23));
          acc3 = (acc3+(val43*val27));
          acc4 = (acc4+(val32*val31));
          acc5 = (acc5+(val36*val19));
          acc6 = (acc6+(val40*val23));
          acc7 = (acc7+(val44*val27));
          acc8 = (acc8+(val33*val31));
          acc9 = (acc9+(val37*val19));
          acc10 = (acc10+(val41*val23));
          acc11 = (acc11+(val45*val27));
          acc12 = (acc12+(val34*val31));
          acc13 = (acc13+(val38*val19));
          acc14 = (acc14+(val42*val23));
          acc15 = (acc15+(val46*val27));
          acc16 = (acc16+(val47*val16));
          acc17 = (acc17+(val35*val20));
          acc18 = (acc18+(val39*val24));
          acc19 = (acc19+(val43*val28));
          acc20 = (acc20+(val32*val16));
          acc21 = (acc21+(val36*val20));
          acc22 = (acc22+(val40*val24));
          acc23 = (acc23+(val44*val28));
          acc24 = (acc24+(val33*val16));
          acc25 = (acc25+(val37*val20));
          acc26 = (acc26+(val41*val24));
          acc27 = (acc27+(val45*val28));
          acc28 = (acc28+(val34*val16));
          acc29 = (acc29+(val38*val20));
          acc30 = (acc30+(val42*val24));
          acc31 = (acc31+(val46*val28));
          acc32 = (acc32+(val47*val17));
          acc33 = (acc33+(val35*val21));
          acc34 = (acc34+(val39*val25));
          acc35 = (acc35+(val43*val29));
          acc36 = (acc36+(val32*val17));
          acc37 = (acc37+(val36*val21));
          acc38 = (acc38+(val40*val25));
          acc39 = (acc39+(val44*val29));
          acc40 = (acc40+(val33*val17));
          acc41 = (acc41+(val37*val21));
          acc42 = (acc42+(val41*val25));
          acc43 = (acc43+(val45*val29));
          acc44 = (acc44+(val34*val17));
          acc45 = (acc45+(val38*val21));
          acc46 = (acc46+(val42*val25));
          acc47 = (acc47+(val46*val29));
          acc48 = (acc48+(val47*val18));
          acc49 = (acc49+(val35*val22));
          acc50 = (acc50+(val39*val26));
          acc51 = (acc51+(val43*val30));
          acc52 = (acc52+(val32*val18));
          acc53 = (acc53+(val36*val22));
          acc54 = (acc54+(val40*val26));
          acc55 = (acc55+(val44*val30));
          acc56 = (acc56+(val33*val18));
          acc57 = (acc57+(val37*val22));
          acc58 = (acc58+(val41*val26));
          acc59 = (acc59+(val45*val30));
          acc60 = (acc60+(val34*val18));
          acc61 = (acc61+(val38*val22));
          acc62 = (acc62+(val42*val26));
          acc63 = (acc63+(val46*val30));
        }
        data0[alu3] = ((acc0*0.125f)+val15);
        data0[alu3+256] = ((acc1*0.125f)+val15);
        data0[alu3+512] = ((acc2*0.125f)+val15);
        data0[alu3+768] = ((acc3*0.125f)+val15);
        data0[alu3+16] = ((acc4*0.125f)+val3);
        data0[alu3+272] = ((acc5*0.125f)+val3);
        data0[alu3+528] = ((acc6*0.125f)+val3);
        data0[alu3+784] = ((acc7*0.125f)+val3);
        data0[alu3+32] = ((acc8*0.125f)+val7);
        data0[alu3+288] = ((acc9*0.125f)+val7);
        data0[alu3+544] = ((acc10*0.125f)+val7);
        data0[alu3+800] = ((acc11*0.125f)+val7);
        data0[alu3+48] = ((acc12*0.125f)+val11);
        data0[alu3+304] = ((acc13*0.125f)+val11);
        data0[alu3+560] = ((acc14*0.125f)+val11);
        data0[alu3+816] = ((acc15*0.125f)+val11);
        data0[alu3+1] = ((acc16*0.125f)+val0);
        data0[alu3+257] = ((acc17*0.125f)+val0);
        data0[alu3+513] = ((acc18*0.125f)+val0);
        data0[alu3+769] = ((acc19*0.125f)+val0);
        data0[alu3+17] = ((acc20*0.125f)+val4);
        data0[alu3+273] = ((acc21*0.125f)+val4);
        data0[alu3+529] = ((acc22*0.125f)+val4);
        data0[alu3+785] = ((acc23*0.125f)+val4);
        data0[alu3+33] = ((acc24*0.125f)+val8);
        data0[alu3+289] = ((acc25*0.125f)+val8);
        data0[alu3+545] = ((acc26*0.125f)+val8);
        data0[alu3+801] = ((acc27*0.125f)+val8);
        data0[alu3+49] = ((acc28*0.125f)+val12);
        data0[alu3+305] = ((acc29*0.125f)+val12);
        data0[alu3+561] = ((acc30*0.125f)+val12);
        data0[alu3+817] = ((acc31*0.125f)+val12);
        data0[alu3+2] = ((acc32*0.125f)+val1);
        data0[alu3+258] = ((acc33*0.125f)+val1);
        data0[alu3+514] = ((acc34*0.125f)+val1);
        data0[alu3+770] = ((acc35*0.125f)+val1);
        data0[alu3+18] = ((acc36*0.125f)+val5);
        data0[alu3+274] = ((acc37*0.125f)+val5);
        data0[alu3+530] = ((acc38*0.125f)+val5);
        data0[alu3+786] = ((acc39*0.125f)+val5);
        data0[alu3+34] = ((acc40*0.125f)+val9);
        data0[alu3+290] = ((acc41*0.125f)+val9);
        data0[alu3+546] = ((acc42*0.125f)+val9);
        data0[alu3+802] = ((acc43*0.125f)+val9);
        data0[alu3+50] = ((acc44*0.125f)+val13);
        data0[alu3+306] = ((acc45*0.125f)+val13);
        data0[alu3+562] = ((acc46*0.125f)+val13);
        data0[alu3+818] = ((acc47*0.125f)+val13);
        data0[alu3+3] = ((acc48*0.125f)+val2);
        data0[alu3+259] = ((acc49*0.125f)+val2);
        data0[alu3+515] = ((acc50*0.125f)+val2);
        data0[alu3+771] = ((acc51*0.125f)+val2);
        data0[alu3+19] = ((acc52*0.125f)+val6);
        data0[alu3+275] = ((acc53*0.125f)+val6);
        data0[alu3+531] = ((acc54*0.125f)+val6);
        data0[alu3+787] = ((acc55*0.125f)+val6);
        data0[alu3+35] = ((acc56*0.125f)+val10);
        data0[alu3+291] = ((acc57*0.125f)+val10);
        data0[alu3+547] = ((acc58*0.125f)+val10);
        data0[alu3+803] = ((acc59*0.125f)+val10);
        data0[alu3+51] = ((acc60*0.125f)+val14);
        data0[alu3+307] = ((acc61*0.125f)+val14);
        data0[alu3+563] = ((acc62*0.125f)+val14);
        data0[alu3+819] = ((acc63*0.125f)+val14);
      }
    }
  }
}
"""

entry = """unsigned long long HAP_perf_get_time_us(void);
int entry(unsigned long long handle, unsigned int sc, void* pra) {
return HAP_perf_get_time_us() == 1 ? 4 : 0;
}
"""

if __name__ == "__main__":
  dev = DSPDevice()

  bufs = [dev.allocator.alloc(0x60000) for _ in range(4)]

  only_entry = dev.compiler.compile(entry)
  app1 = dev.runtime("test", only_entry)
  x = app1(*bufs)

  entry_n_unsued_code = dev.compiler.compile(kernel + "\n" + entry)
  app2 = dev.runtime("test", entry_n_unsued_code)
  x = app2(*bufs)