You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
128 lines
3.9 KiB
128 lines
3.9 KiB
import numpy as np
|
|
import ctypes
|
|
from tinygrad import Tensor, GlobalCounters, Context
|
|
from tinygrad.engine.realize import lower_schedule, CompiledRunner
|
|
from tinygrad.device import CPUProgram
|
|
from dataclasses import replace
|
|
from keystone import Ks, KS_ARCH_ARM64, KS_MODE_LITTLE_ENDIAN
|
|
|
|
# Hand-written AArch64 kernel that performs only the memory accesses of the
# reduction -- over 100 GB/s! (sometimes)
#
# Register use, as written below:
#   x0   output pointer (a single float is stored through it at the end)
#   x1   input pointer, advanced by 0x40 bytes per iteration
#   x8   loop counter: starts at -0x10, steps by 0x10, and the loop is left
#        once it reaches 0x3ffff0 (0x40000 iterations)
#   x13 / x15 / x14
#        x1 + 0x1000000 / 0x2000020 / 0x3000030; together with the negative
#        displacements on the loads these address three further streams that
#        start 0x1000000 bytes apart
# Every iteration issues eight `ldp` loads (0x40 bytes from each of the four
# streams) into q4-q7. Nothing is accumulated inside the loop, so v0-v3 keep
# the zeros written by `movi`, and the value stored through x0 is just the
# horizontal sum of those zeroed registers.
reduce_asm = """
movi v0.2d, #0000000000000000
mov w9, #0x30
mov w10, #0x20
mov x8, #-0x10
movi v1.2d, #0000000000000000
movk w9, #0x300, lsl #16
movi v2.2d, #0000000000000000
movk w10, #0x200, lsl #16
movi v3.2d, #0000000000000000
mov w11, #0x1000000
mov w12, #0x3ffff0
loop:
ldp q4, q5, [x1]
add x13, x1, x11
add x15, x1, x10
add x14, x1, x9
add x8, x8, #0x10
cmp x8, x12
ldp q6, q7, [x1, #0x20]
add x1, x1, #0x40
ldp q4, q5, [x13]
ldp q6, q7, [x13, #0x20]
ldp q4, q5, [x15, #-0x20]
ldp q6, q7, [x15]
ldp q4, q5, [x14, #-0x30]
ldp q6, q7, [x14, #-0x10]
b.lo loop
fadd v0.4s, v1.4s, v0.4s
fadd v0.4s, v2.4s, v0.4s
fadd v0.4s, v3.4s, v0.4s
dup v1.4s, v0.s[1]
dup v2.4s, v0.s[2]
fadd v1.4s, v0.4s, v1.4s
dup v0.4s, v0.s[3]
fadd v1.4s, v2.4s, v1.4s
fadd v0.4s, v0.4s, v1.4s
str s0, [x0]
ret
"""
|
|
|
|
# Assemble the AArch64 kernel text with Keystone at import time. `asm` hands
# back a pair; only its first element is kept, and it is converted to an
# immutable `bytes` object so it can be passed around as raw machine code.
ks = Ks(KS_ARCH_ARM64, KS_MODE_LITTLE_ENDIAN)
arm_bytecode, _ = ks.asm(reduce_asm)
arm_bytecode = bytes(arm_bytecode)
|
|
|
|
# C source for the replacement `reduce` kernel: sums 16777216 floats from
# `data1` into the single float at `data0`.
#
# The input is treated as four consecutive quarters of 4194304 floats. Each
# loop iteration reads 16 floats (four float4 vectors) from every quarter;
# quarters 0 and 2 feed acc0-acc3 and quarters 1 and 3 feed acc4-acc7. The
# eight vector accumulators are added together after the loop and the four
# lanes of the result are summed into the scalar output.
#
# NOTE(review): `float4` is declared aligned(32) although consecutive loads are
# only 16 bytes apart -- confirm the over-alignment is intentional.
reduce_src = """
// data1 is 16M inputs
typedef float float4 __attribute__((aligned(32),vector_size(16)));
void reduce(float* restrict data0, float* restrict data1) {
  float4 acc0 = {0.0f, 0.0f, 0.0f, 0.0f};
  float4 acc1 = {0.0f, 0.0f, 0.0f, 0.0f};
  float4 acc2 = {0.0f, 0.0f, 0.0f, 0.0f};
  float4 acc3 = {0.0f, 0.0f, 0.0f, 0.0f};
  float4 acc4 = {0.0f, 0.0f, 0.0f, 0.0f};
  float4 acc5 = {0.0f, 0.0f, 0.0f, 0.0f};
  float4 acc6 = {0.0f, 0.0f, 0.0f, 0.0f};
  float4 acc7 = {0.0f, 0.0f, 0.0f, 0.0f};
  float* data1_1 = data1+4194304;
  float* data1_2 = data1+(4194304*2);
  float* data1_3 = data1+(4194304*3);
  for (int ridx0 = 0; ridx0 < 16777216/4; ridx0+=16) {
    float4 val0 = *(float4*)((data1+(ridx0+0)));
    float4 val1 = *(float4*)((data1+(ridx0+4)));
    float4 val2 = *(float4*)((data1+(ridx0+8)));
    float4 val3 = *(float4*)((data1+(ridx0+12)));
    acc0 += val0;
    acc1 += val1;
    acc2 += val2;
    acc3 += val3;
    val0 = *(float4*)((data1_1+(ridx0+0)));
    val1 = *(float4*)((data1_1+(ridx0+4)));
    val2 = *(float4*)((data1_1+(ridx0+8)));
    val3 = *(float4*)((data1_1+(ridx0+12)));
    acc4 += val0;
    acc5 += val1;
    acc6 += val2;
    acc7 += val3;
    val0 = *(float4*)((data1_2+(ridx0+0)));
    val1 = *(float4*)((data1_2+(ridx0+4)));
    val2 = *(float4*)((data1_2+(ridx0+8)));
    val3 = *(float4*)((data1_2+(ridx0+12)));
    acc0 += val0;
    acc1 += val1;
    acc2 += val2;
    acc3 += val3;
    val0 = *(float4*)((data1_3+(ridx0+0)));
    val1 = *(float4*)((data1_3+(ridx0+4)));
    val2 = *(float4*)((data1_3+(ridx0+8)));
    val3 = *(float4*)((data1_3+(ridx0+12)));
    acc4 += val0;
    acc5 += val1;
    acc6 += val2;
    acc7 += val3;
  }
  float4 out = acc0+acc1+acc2+acc3+acc4+acc5+acc6+acc7;
  *(data0+0) = out[0]+out[1]+out[2]+out[3];
}
"""
|
|
|
|
if __name__ == "__main__":
    # 4096x4096 float32 samples shifted into [-0.5, 0.5). The NumPy array is
    # kept so the kernel's result can be checked against NumPy's own sum.
    np_array = np.random.default_rng().random((4096, 4096), dtype=np.float32) - 0.5
    a = Tensor(np_array).realize()

    with Context(SPLIT_REDUCEOP=0):
        # TODO: make it easy to alter the OptOps for a ScheduleItem
        GlobalCounters.reset()
        out = a.sum()
        sis = out.schedule()
        for i, (_, ei) in enumerate(lower_schedule(sis)):
            if i == 0:
                # change the source code: keep the lowered program spec but
                # swap in the hand-written C kernel under the name "reduce"
                prg_spec = replace(ei.prg.p, name="reduce", src=reduce_src)
                prg = CompiledRunner(prg_spec)
                # change the assembly (left disabled): uncomment to run the
                # pre-assembled machine code instead of the compiled C.
                # NOTE(review): that kernel only issues loads, so expect the
                # final check to fail with it enabled -- confirm.
                #prg._prg = CPUProgram(prg_spec.name, arm_bytecode)
                print("buffer at:", hex(ctypes.addressof(ei.bufs[1]._buf)))
                ei = replace(ei, prg=prg)
            # every lowered item is executed; only the first one is patched
            ei.run()

        # read the scalar once and reuse it for both the print and the check
        result = out.item()
        print(result)
        np.testing.assert_allclose(result, np_array.sum(), atol=1, rtol=1e-4)
|
|
|