You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
46 lines
1.3 KiB
46 lines
1.3 KiB
|
3 days ago
|
from tinygrad.dtype import AddrSpace
|
||
|
|
|
||
|
|
from extra.thunder.tiny.tk import WARP_THREADS
|
||
|
|
|
||
|
|
class GL:
  """Wrapper around a `ker.alloc` result requested with `AddrSpace.GLOBAL`.

  Attributes:
    shape: the `shape` argument, stored unchanged.
    dtype: the `dtype` argument, stored unchanged.
    _uop: whatever `ker.alloc(shape, dtype, AddrSpace.GLOBAL)` returned.
  """

  def __init__(self, shape, dtype, ker):
    # Keep the caller's shape and dtype exactly as given.
    self.shape = shape
    self.dtype = dtype
    # Allocation is delegated entirely to the kernel object.
    self._uop = ker.alloc(shape, dtype, AddrSpace.GLOBAL)
|
||
|
|
|
||
|
|
class ST:
  """Wrapper around a `ker.alloc` result requested with `AddrSpace.LOCAL`.

  Attributes:
    shape: the `shape` argument, stored unchanged.
    dtype: the `dtype` argument, stored unchanged.
    _uop: whatever `ker.alloc(shape, dtype, AddrSpace.LOCAL)` returned.
  """

  def __init__(self, shape, dtype, ker):
    # Keep the caller's shape and dtype exactly as given.
    self.shape = shape
    self.dtype = dtype
    # Allocation is delegated entirely to the kernel object.
    self._uop = ker.alloc(shape, dtype, AddrSpace.LOCAL)
|
||
|
|
|
||
|
|
class RT:
  """Tiled allocation requested from `ker.alloc` with `AddrSpace.REG`.

  The constructor takes a 2-D `shape` whose two extents must each be a
  multiple of the 16x16 base tile. The stored `shape` is not the input
  shape: it is `(shape[0] // 16, shape[1] // 16, BASE_TILE_NEPT)`.
  """

  # Extents of one base tile.
  TILE_ROW_DIM = 16
  TILE_COL_DIM = 16
  # Entries in one base tile (16 * 16 = 256).
  BASE_TILE_NE = TILE_ROW_DIM * TILE_COL_DIM
  # Base-tile entries divided by WARP_THREADS
  # (presumably the per-thread share -- confirm against WARP_THREADS' definition).
  BASE_TILE_NEPT = BASE_TILE_NE // WARP_THREADS

  def __init__(self, shape, dtype, ker):
    # Only 2-D shapes are accepted, and both extents must be tile-aligned.
    assert len(shape) == 2
    rows, cols = shape[0], shape[1]
    assert rows % RT.TILE_ROW_DIM == 0
    assert cols % RT.TILE_COL_DIM == 0

    # Count of base tiles along each axis, followed by the per-tile entry count.
    tiled_shape = (rows // RT.TILE_ROW_DIM, cols // RT.TILE_COL_DIM, self.BASE_TILE_NEPT)

    self.shape = tiled_shape
    self.dtype = dtype
    self._uop = ker.alloc(tiled_shape, dtype, AddrSpace.REG)
|
||
|
|
|
||
|
|
class RV:
  """Vector allocation requested from `ker.alloc` with `AddrSpace.REG`.

  `length` is floor-divided by `RT.TILE_ROW_DIM` to get a tile count; no
  divisibility check is performed. The stored `shape` is
  `(outer_dim, inner_dim, 2)`, where the dims depend on `layout`:

    "naive": inner_dim = 1, outer_dim = (tiles + 1) // 2
    "ortho": inner_dim = 1, outer_dim = tiles

  Raises:
    NotImplementedError: for any other `layout` value.
  """

  def __init__(self, length, dtype, layout, ker):
    tiles = length // RT.TILE_ROW_DIM

    if layout == "naive":
      # Half as many outer entries as tiles, rounded up.
      inner_dim, outer_dim = 1, (tiles + 1) // 2
    elif layout == "ortho":
      # One outer entry per tile.
      inner_dim, outer_dim = 1, tiles
    else:
      raise NotImplementedError(f"rv layout {layout} not implemented")

    vec_shape = (outer_dim, inner_dim, 2)
    self.shape = vec_shape
    self.dtype = dtype
    self._uop = ker.alloc(vec_shape, dtype, AddrSpace.REG)
|