You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
40 lines
1.2 KiB
40 lines
1.2 KiB
2 months ago
|
from functools import lru_cache
|
||
|
from .tensor import Device, Function, register
|
||
|
|
||
|
@lru_cache
def compile_wrapper(ane, dat):
  """Compile `dat` with `ane.compile`, memoizing the result.

  Both arguments form the cache key, so they must be hashable (callers in
  this file pass `dat` as `bytes`). Repeated calls with the same pair
  return the cached result without invoking `ane.compile` again.
  """
  compiled = ane.compile(dat)
  return compiled
|
||
|
|
||
|
def roundup(x, v):
  """Return `x` rounded up to the nearest multiple of `v`.

  A value that is already a multiple of `v` is returned unchanged.
  Raises ZeroDivisionError when `v` is 0 (from the modulo).
  """
  # Distance from x to the next multiple of v; 0 when x is already aligned.
  padding = (v - x) % v
  return x + padding
|
||
|
|
||
|
@lru_cache
def compile_relu(ane, sz):
  """Build and compile the ReLU program for inputs of width `sz`.

  Loads the `relu.hwx` template, patches its stride and width fields for
  `sz`, and passes the patched bytes to `compile_wrapper`. Results are
  memoized on `(ane, sz)`, so both must be hashable.

  The template path is relative, so it is resolved against the current
  working directory; a missing file raises the usual `OSError` from `open`.
  """
  # Read through a context manager so the file handle is closed right away
  # instead of lingering until garbage collection.
  with open("accel/ane/ops/relu.hwx", "rb") as f:
    dat = list(f.read())
  # TODO: make this all nice and once
  # number of engines? (max 0x100)
  # NOTE(review): sz*2 presumably means 2 bytes per element (fp16) - confirm.
  l2_stride = max(0x100, roundup(sz*2, 0x10))
  # 0x1ec = L2.SourceChannelStride.Stride, 0x1f0 = L2.SourceRowStride.Stride
  # 0x1f4, 0x1f8?
  # 0x214 = L2.ResultBase.Addr
  dat = ane.fill(dat, [0x1ec, 0x1f0, 0x1f4, 0x1f8, 0x214], "I", l2_stride)
  # Row/plane/depth strides all share one value, aligned to 0x40.
  stride = roundup(sz*2, 0x40)
  dat = ane.filln(dat, {
    "NeuronType": 0x11,  # 0x10 makes this a copy, 0x11 = ReLU, 0x12 = crash
    "InputWidth": sz, "OutputWidth": sz,
    "InputRowStride": stride, "InputPlaneStride": stride, "InputDepthStride": stride,
    "OutputRowStride": stride, "OutputPlaneStride": stride, "OutputDepthStride": stride,
  })
  # bytes() gives compile_wrapper's cache a hashable key.
  return compile_wrapper(ane, bytes(dat))
|
||
|
|
||
|
class ReLU(Function):
  """ReLU op that runs through the `ane` object carried on `ctx`.

  NOTE(review): both methods take `ctx` as their first parameter rather
  than `self`; presumably the `Function` base class invokes them that
  way - confirm against `.tensor`.
  """

  def forward(ctx, input):
    """Run the compiled ReLU program on `input` and return a new tensor."""
    # Output buffer with the same shape as the input.
    out = ctx.ane.tensor(input.shape)
    # The program is cached per (ane, size) by compile_relu.
    program = compile_relu(ctx.ane, input.sz)
    ctx.ane.run(program, input, out)
    return out

  def backward(ctx, grad_output):
    """Gradient placeholder: always returns 0 and ignores `grad_output`."""
    # NOTE(review): this looks like a stub, not a real ReLU gradient.
    return 0
|
||
|
|
||
|
# Register ReLU under the name 'relu' for Device.ANE.
# NOTE(review): presumably this exposes it as a tensor op on that device - confirm in .tensor.
register('relu', ReLU, device=Device.ANE)
|