You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
279 lines
10 KiB
279 lines
10 KiB
from tinygrad.runtime.ops_dsp import DSPDevice
|
|
|
|
# C source for `r_64_4_4_64_4_4_4`, a `__attribute__((noinline))` function that
# nothing in this file calls: under __main__ it is only prepended to `entry` and
# compiled, so the second program carries it as unused code.
#
# What the C does: for every (ridx0 < 64, ridx1 < 4, ridx2 < 4) it keeps 64 float
# accumulators (acc0..acc63). Each sums data1[...] * data2[...] over ridx3 < 64.
# It then writes acc * 0.125f + data3[...] into data0. All four pointers are
# `restrict` and declared 128-byte aligned via `align_value(128)`.
# NOTE(review): 0.125 == 1/sqrt(64), the reduction length. This looks like scaled
# dot-product scores plus an additive term -- TODO confirm against its origin.
# Highest float indices touched: data0[65535], data1/data2[262143], data3[255].
# data1/data2 therefore exceed the 0x60000-byte (98304-float) buffers allocated
# under __main__, so the kernel must not be launched with those buffers.
# The literal's exact bytes are what gets compiled; edit with care.
kernel = """__attribute__((noinline)) void r_64_4_4_64_4_4_4(float* restrict __attribute__((align_value(128))) data0, const float* restrict __attribute__((align_value(128))) data1, const float* restrict __attribute__((align_value(128))) data2, const float* restrict __attribute__((align_value(128))) data3) {
|
|
for (int ridx0 = 0; ridx0 < 64; ridx0++) {
|
|
int alu0 = (ridx0*4096);
|
|
for (int ridx1 = 0; ridx1 < 4; ridx1++) {
|
|
int alu1 = (ridx1*64);
|
|
for (int ridx2 = 0; ridx2 < 4; ridx2++) {
|
|
int alu2 = (ridx2*4);
|
|
int alu3 = ((ridx0*1024)+alu1+alu2);
|
|
int alu4 = (alu1+alu2);
|
|
float val0 = data3[alu4+1];
|
|
float val1 = data3[alu4+2];
|
|
float val2 = data3[alu4+3];
|
|
float val3 = data3[alu4+16];
|
|
float val4 = data3[alu4+17];
|
|
float val5 = data3[alu4+18];
|
|
float val6 = data3[alu4+19];
|
|
float val7 = data3[alu4+32];
|
|
float val8 = data3[alu4+33];
|
|
float val9 = data3[alu4+34];
|
|
float val10 = data3[alu4+35];
|
|
float val11 = data3[alu4+48];
|
|
float val12 = data3[alu4+49];
|
|
float val13 = data3[alu4+50];
|
|
float val14 = data3[alu4+51];
|
|
float val15 = data3[alu4];
|
|
float acc0 = 0.0f;
|
|
float acc1 = 0.0f;
|
|
float acc2 = 0.0f;
|
|
float acc3 = 0.0f;
|
|
float acc4 = 0.0f;
|
|
float acc5 = 0.0f;
|
|
float acc6 = 0.0f;
|
|
float acc7 = 0.0f;
|
|
float acc8 = 0.0f;
|
|
float acc9 = 0.0f;
|
|
float acc10 = 0.0f;
|
|
float acc11 = 0.0f;
|
|
float acc12 = 0.0f;
|
|
float acc13 = 0.0f;
|
|
float acc14 = 0.0f;
|
|
float acc15 = 0.0f;
|
|
float acc16 = 0.0f;
|
|
float acc17 = 0.0f;
|
|
float acc18 = 0.0f;
|
|
float acc19 = 0.0f;
|
|
float acc20 = 0.0f;
|
|
float acc21 = 0.0f;
|
|
float acc22 = 0.0f;
|
|
float acc23 = 0.0f;
|
|
float acc24 = 0.0f;
|
|
float acc25 = 0.0f;
|
|
float acc26 = 0.0f;
|
|
float acc27 = 0.0f;
|
|
float acc28 = 0.0f;
|
|
float acc29 = 0.0f;
|
|
float acc30 = 0.0f;
|
|
float acc31 = 0.0f;
|
|
float acc32 = 0.0f;
|
|
float acc33 = 0.0f;
|
|
float acc34 = 0.0f;
|
|
float acc35 = 0.0f;
|
|
float acc36 = 0.0f;
|
|
float acc37 = 0.0f;
|
|
float acc38 = 0.0f;
|
|
float acc39 = 0.0f;
|
|
float acc40 = 0.0f;
|
|
float acc41 = 0.0f;
|
|
float acc42 = 0.0f;
|
|
float acc43 = 0.0f;
|
|
float acc44 = 0.0f;
|
|
float acc45 = 0.0f;
|
|
float acc46 = 0.0f;
|
|
float acc47 = 0.0f;
|
|
float acc48 = 0.0f;
|
|
float acc49 = 0.0f;
|
|
float acc50 = 0.0f;
|
|
float acc51 = 0.0f;
|
|
float acc52 = 0.0f;
|
|
float acc53 = 0.0f;
|
|
float acc54 = 0.0f;
|
|
float acc55 = 0.0f;
|
|
float acc56 = 0.0f;
|
|
float acc57 = 0.0f;
|
|
float acc58 = 0.0f;
|
|
float acc59 = 0.0f;
|
|
float acc60 = 0.0f;
|
|
float acc61 = 0.0f;
|
|
float acc62 = 0.0f;
|
|
float acc63 = 0.0f;
|
|
for (int ridx3 = 0; ridx3 < 64; ridx3++) {
|
|
int alu5 = (alu0+(ridx2*256)+ridx3);
|
|
float val16 = data2[alu5+64];
|
|
float val17 = data2[alu5+128];
|
|
float val18 = data2[alu5+192];
|
|
float val19 = data2[alu5+1024];
|
|
float val20 = data2[alu5+1088];
|
|
float val21 = data2[alu5+1152];
|
|
float val22 = data2[alu5+1216];
|
|
float val23 = data2[alu5+2048];
|
|
float val24 = data2[alu5+2112];
|
|
float val25 = data2[alu5+2176];
|
|
float val26 = data2[alu5+2240];
|
|
float val27 = data2[alu5+3072];
|
|
float val28 = data2[alu5+3136];
|
|
float val29 = data2[alu5+3200];
|
|
float val30 = data2[alu5+3264];
|
|
float val31 = data2[alu5];
|
|
int alu6 = (alu0+(ridx1*256)+ridx3);
|
|
float val32 = data1[alu6+64];
|
|
float val33 = data1[alu6+128];
|
|
float val34 = data1[alu6+192];
|
|
float val35 = data1[alu6+1024];
|
|
float val36 = data1[alu6+1088];
|
|
float val37 = data1[alu6+1152];
|
|
float val38 = data1[alu6+1216];
|
|
float val39 = data1[alu6+2048];
|
|
float val40 = data1[alu6+2112];
|
|
float val41 = data1[alu6+2176];
|
|
float val42 = data1[alu6+2240];
|
|
float val43 = data1[alu6+3072];
|
|
float val44 = data1[alu6+3136];
|
|
float val45 = data1[alu6+3200];
|
|
float val46 = data1[alu6+3264];
|
|
float val47 = data1[alu6];
|
|
acc0 = (acc0+(val47*val31));
|
|
acc1 = (acc1+(val35*val19));
|
|
acc2 = (acc2+(val39*val23));
|
|
acc3 = (acc3+(val43*val27));
|
|
acc4 = (acc4+(val32*val31));
|
|
acc5 = (acc5+(val36*val19));
|
|
acc6 = (acc6+(val40*val23));
|
|
acc7 = (acc7+(val44*val27));
|
|
acc8 = (acc8+(val33*val31));
|
|
acc9 = (acc9+(val37*val19));
|
|
acc10 = (acc10+(val41*val23));
|
|
acc11 = (acc11+(val45*val27));
|
|
acc12 = (acc12+(val34*val31));
|
|
acc13 = (acc13+(val38*val19));
|
|
acc14 = (acc14+(val42*val23));
|
|
acc15 = (acc15+(val46*val27));
|
|
acc16 = (acc16+(val47*val16));
|
|
acc17 = (acc17+(val35*val20));
|
|
acc18 = (acc18+(val39*val24));
|
|
acc19 = (acc19+(val43*val28));
|
|
acc20 = (acc20+(val32*val16));
|
|
acc21 = (acc21+(val36*val20));
|
|
acc22 = (acc22+(val40*val24));
|
|
acc23 = (acc23+(val44*val28));
|
|
acc24 = (acc24+(val33*val16));
|
|
acc25 = (acc25+(val37*val20));
|
|
acc26 = (acc26+(val41*val24));
|
|
acc27 = (acc27+(val45*val28));
|
|
acc28 = (acc28+(val34*val16));
|
|
acc29 = (acc29+(val38*val20));
|
|
acc30 = (acc30+(val42*val24));
|
|
acc31 = (acc31+(val46*val28));
|
|
acc32 = (acc32+(val47*val17));
|
|
acc33 = (acc33+(val35*val21));
|
|
acc34 = (acc34+(val39*val25));
|
|
acc35 = (acc35+(val43*val29));
|
|
acc36 = (acc36+(val32*val17));
|
|
acc37 = (acc37+(val36*val21));
|
|
acc38 = (acc38+(val40*val25));
|
|
acc39 = (acc39+(val44*val29));
|
|
acc40 = (acc40+(val33*val17));
|
|
acc41 = (acc41+(val37*val21));
|
|
acc42 = (acc42+(val41*val25));
|
|
acc43 = (acc43+(val45*val29));
|
|
acc44 = (acc44+(val34*val17));
|
|
acc45 = (acc45+(val38*val21));
|
|
acc46 = (acc46+(val42*val25));
|
|
acc47 = (acc47+(val46*val29));
|
|
acc48 = (acc48+(val47*val18));
|
|
acc49 = (acc49+(val35*val22));
|
|
acc50 = (acc50+(val39*val26));
|
|
acc51 = (acc51+(val43*val30));
|
|
acc52 = (acc52+(val32*val18));
|
|
acc53 = (acc53+(val36*val22));
|
|
acc54 = (acc54+(val40*val26));
|
|
acc55 = (acc55+(val44*val30));
|
|
acc56 = (acc56+(val33*val18));
|
|
acc57 = (acc57+(val37*val22));
|
|
acc58 = (acc58+(val41*val26));
|
|
acc59 = (acc59+(val45*val30));
|
|
acc60 = (acc60+(val34*val18));
|
|
acc61 = (acc61+(val38*val22));
|
|
acc62 = (acc62+(val42*val26));
|
|
acc63 = (acc63+(val46*val30));
|
|
}
|
|
data0[alu3] = ((acc0*0.125f)+val15);
|
|
data0[alu3+256] = ((acc1*0.125f)+val15);
|
|
data0[alu3+512] = ((acc2*0.125f)+val15);
|
|
data0[alu3+768] = ((acc3*0.125f)+val15);
|
|
data0[alu3+16] = ((acc4*0.125f)+val3);
|
|
data0[alu3+272] = ((acc5*0.125f)+val3);
|
|
data0[alu3+528] = ((acc6*0.125f)+val3);
|
|
data0[alu3+784] = ((acc7*0.125f)+val3);
|
|
data0[alu3+32] = ((acc8*0.125f)+val7);
|
|
data0[alu3+288] = ((acc9*0.125f)+val7);
|
|
data0[alu3+544] = ((acc10*0.125f)+val7);
|
|
data0[alu3+800] = ((acc11*0.125f)+val7);
|
|
data0[alu3+48] = ((acc12*0.125f)+val11);
|
|
data0[alu3+304] = ((acc13*0.125f)+val11);
|
|
data0[alu3+560] = ((acc14*0.125f)+val11);
|
|
data0[alu3+816] = ((acc15*0.125f)+val11);
|
|
data0[alu3+1] = ((acc16*0.125f)+val0);
|
|
data0[alu3+257] = ((acc17*0.125f)+val0);
|
|
data0[alu3+513] = ((acc18*0.125f)+val0);
|
|
data0[alu3+769] = ((acc19*0.125f)+val0);
|
|
data0[alu3+17] = ((acc20*0.125f)+val4);
|
|
data0[alu3+273] = ((acc21*0.125f)+val4);
|
|
data0[alu3+529] = ((acc22*0.125f)+val4);
|
|
data0[alu3+785] = ((acc23*0.125f)+val4);
|
|
data0[alu3+33] = ((acc24*0.125f)+val8);
|
|
data0[alu3+289] = ((acc25*0.125f)+val8);
|
|
data0[alu3+545] = ((acc26*0.125f)+val8);
|
|
data0[alu3+801] = ((acc27*0.125f)+val8);
|
|
data0[alu3+49] = ((acc28*0.125f)+val12);
|
|
data0[alu3+305] = ((acc29*0.125f)+val12);
|
|
data0[alu3+561] = ((acc30*0.125f)+val12);
|
|
data0[alu3+817] = ((acc31*0.125f)+val12);
|
|
data0[alu3+2] = ((acc32*0.125f)+val1);
|
|
data0[alu3+258] = ((acc33*0.125f)+val1);
|
|
data0[alu3+514] = ((acc34*0.125f)+val1);
|
|
data0[alu3+770] = ((acc35*0.125f)+val1);
|
|
data0[alu3+18] = ((acc36*0.125f)+val5);
|
|
data0[alu3+274] = ((acc37*0.125f)+val5);
|
|
data0[alu3+530] = ((acc38*0.125f)+val5);
|
|
data0[alu3+786] = ((acc39*0.125f)+val5);
|
|
data0[alu3+34] = ((acc40*0.125f)+val9);
|
|
data0[alu3+290] = ((acc41*0.125f)+val9);
|
|
data0[alu3+546] = ((acc42*0.125f)+val9);
|
|
data0[alu3+802] = ((acc43*0.125f)+val9);
|
|
data0[alu3+50] = ((acc44*0.125f)+val13);
|
|
data0[alu3+306] = ((acc45*0.125f)+val13);
|
|
data0[alu3+562] = ((acc46*0.125f)+val13);
|
|
data0[alu3+818] = ((acc47*0.125f)+val13);
|
|
data0[alu3+3] = ((acc48*0.125f)+val2);
|
|
data0[alu3+259] = ((acc49*0.125f)+val2);
|
|
data0[alu3+515] = ((acc50*0.125f)+val2);
|
|
data0[alu3+771] = ((acc51*0.125f)+val2);
|
|
data0[alu3+19] = ((acc52*0.125f)+val6);
|
|
data0[alu3+275] = ((acc53*0.125f)+val6);
|
|
data0[alu3+531] = ((acc54*0.125f)+val6);
|
|
data0[alu3+787] = ((acc55*0.125f)+val6);
|
|
data0[alu3+35] = ((acc56*0.125f)+val10);
|
|
data0[alu3+291] = ((acc57*0.125f)+val10);
|
|
data0[alu3+547] = ((acc58*0.125f)+val10);
|
|
data0[alu3+803] = ((acc59*0.125f)+val10);
|
|
data0[alu3+51] = ((acc60*0.125f)+val14);
|
|
data0[alu3+307] = ((acc61*0.125f)+val14);
|
|
data0[alu3+563] = ((acc62*0.125f)+val14);
|
|
data0[alu3+819] = ((acc63*0.125f)+val14);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
"""
|
|
|
|
# C source compiled and run under __main__, both on its own and with `kernel`
# prepended. It defines
# `int entry(unsigned long long handle, unsigned int sc, void* pra)`. Presumably
# the DSP runtime invokes this symbol; verify against DSPDevice.
# HAP_perf_get_time_us is only declared here; its definition comes from outside
# this file (presumably the Hexagon SDK HAP library -- confirm).
# The function returns 4 only if that call returns exactly 1, otherwise 0.
# NOTE(review): the comparison looks like a way to force a real call to an
# external symbol while still returning 0 in practice -- TODO confirm intent.
# The literal's exact bytes are what gets compiled; edit with care.
entry = """unsigned long long HAP_perf_get_time_us(void);
|
|
int entry(unsigned long long handle, unsigned int sc, void* pra) {
|
|
return HAP_perf_get_time_us() == 1 ? 4 : 0;
|
|
}
|
|
"""
|
|
|
|
if __name__ == "__main__":
  # Build and run two programs on the DSP, in order:
  #   1. `entry` on its own;
  #   2. the same `entry` with the unused `kernel` source prepended.
  # Both calls receive the same four 0x60000-byte scratch buffers.
  device = DSPDevice()
  scratch = [device.allocator.alloc(0x60000) for _ in range(4)]

  # Every loaded program is kept referenced until the script ends, so the first
  # one stays alive while the second is created and run.
  programs = []
  for source in (entry, kernel + "\n" + entry):
    lib = device.compiler.compile(source)
    programs.append(device.runtime("test", lib))
    result = programs[-1](*scratch)
|
|
|