You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
102 lines
3.4 KiB
102 lines
3.4 KiB
14 hours ago
|
import ctypes, ctypes.util, time, os, builtins, fcntl
|
||
|
from tinygrad.runtime.support.hcq import HWInterface
|
||
|
from test.mockgpu.nv.nvdriver import NVDriver
|
||
|
from test.mockgpu.amd.amddriver import AMDDriver
|
||
|
# Timestamp captured once, at import time.
start = time.perf_counter()

# *** ioctl lib ***
# Load the C library and declare the prototype of
# mmap(addr, length, prot, flags, fd, offset) for ctypes.
# With a c_void_p restype, ctypes hands back the mapped address as a Python int (or None for NULL).
libc = ctypes.CDLL(ctypes.util.find_library("c"))
libc.mmap.restype = ctypes.c_void_p
libc.mmap.argtypes = [
  ctypes.c_void_p,  # addr
  ctypes.c_size_t,  # length
  ctypes.c_int,     # prot
  ctypes.c_int,     # flags
  ctypes.c_int,     # fd
  ctypes.c_long,    # offset
]
# Mock drivers, constructed at import time; the open/memoryview hooks in this module
# consult them in this order and stop at the first match.
drivers: list = [AMDDriver(), NVDriver()]
# fd number -> driver file object, for every fd handed out by a mock driver's open().
tracked_fds: dict = {}

# The genuine builtin memoryview, saved before this module replaces builtins.memoryview.
# (The misspelled name is kept because other code in this module refers to it.)
orignal_memoryview = builtins.memoryview
class TrackedMemoryView:
  """Wrapper around a real memoryview that reports item accesses to mock-driver callbacks.

  rcb(mv, index) is called before every item read and wcb(mv, index) after every item write.
  `mv` is the wrapped real memoryview; `index` is exactly the key that was passed to [].
  """
  def __init__(self, data, rcb, wcb):
    # Use the saved builtin: builtins.memoryview is replaced in this module, and going through
    # the replacement here could wrap the data again.
    self.mv = orignal_memoryview(data)
    self.rcb, self.wcb = rcb, wcb

  def __getitem__(self, index):
    # Notify first, so the read callback runs before the value is fetched.
    self.rcb(self.mv, index)
    return self.mv[index]

  def __setitem__(self, index, value):
    self.mv[index] = value
    # Notify after the store, so `mv` already holds the new value when the callback sees it.
    self.wcb(self.mv, index)

  def cast(self, new_type, **kwargs):
    """Return a tracked view of the same memory reinterpreted with memoryview.cast.

    Like memoryview.cast, this produces a NEW view and leaves the receiver's format untouched.
    (Previously the receiver itself was re-pointed at the cast view and returned, so every other
    holder of that object silently switched format too.)
    """
    return TrackedMemoryView(self.mv.cast(new_type, **kwargs), self.rcb, self.wcb)

  @property
  def nbytes(self): return self.mv.nbytes
  def __len__(self): return len(self.mv)
  def __repr__(self): return repr(self.mv)
def _memoryview(cls, mem):
  """Stand-in constructor for builtins.memoryview (installed as __new__ below).

  For an int, the int itself is the address. For a ctypes array, the address is
  ctypes.addressof(array). If that address lies in a driver's tracked range [lo, hi]
  (inclusive; only the start address is checked), the result is a TrackedMemoryView wired to
  that range's callbacks. Everything else gets a plain builtin memoryview. `cls` is unused.

  NOTE(review): an int `mem` is handed to the real memoryview on both paths, which requires a
  buffer-protocol object — confirm whether int inputs are actually expected here.
  """
  if isinstance(mem, (int, ctypes.Array)):
    addr = mem if isinstance(mem, int) else ctypes.addressof(mem)
    for drv in drivers:
      for lo, hi, on_read, on_write in drv.tracked_addresses:
        if lo <= addr <= hi:
          return TrackedMemoryView(mem, on_read, on_write)
  return orignal_memoryview(mem)

# Replace the builtin with a class whose __new__ is _memoryview, so memoryview(...) calls that
# resolve through builtins go through the tracking check above.
builtins.memoryview = type("memoryview", (), {'__new__': _memoryview}) # type: ignore
def _open(path, flags):
  """Open `path`, routing it to a mock driver when one tracks it.

  Returns:
    - the virtual fd for a driver-tracked path, after recording its file object in tracked_fds;
    - a real OS fd for an untracked path that exists on disk;
    - None when the path is neither tracked nor present.
  The first matching tracked file (drivers in order) wins.
  """
  match = next(((drv, entry) for drv in drivers for entry in drv.tracked_files if path == entry.path), None)
  if match is None:
    if not os.path.exists(path):
      return None
    return os.open(path, flags, 0o777)
  drv, entry = match
  virt = drv.open(path, flags, 0o777, entry)
  tracked_fds[virt.fd] = virt
  return virt.fd
class MockHWInterface(HWInterface):
  """HWInterface for the mock GPU.

  An fd registered in tracked_fds (by _open) is served by the mock driver's file object.
  Any other fd falls through to the real os / fcntl / libc calls.
  """
  def __init__(self, path:str="", flags:int=os.O_RDONLY, fd:int|None=None):
    self.path = path
    # Test against None rather than truthiness, so an explicit fd of 0 is kept instead of
    # silently opening `path`.
    # NOTE(review): self.fd is None when _open finds neither a tracked file nor an on-disk path.
    self.fd = fd if fd is not None else _open(path, flags)

  def __del__(self):
    # Nothing to release if __init__ raised before assigning fd, or if _open returned None.
    # Previously these cases hit an AttributeError or os.close(None) during finalization.
    fd = getattr(self, "fd", None)
    if fd is None: return
    if fd in tracked_fds:
      tracked_fds[fd].close(fd)
      tracked_fds.pop(fd)
    else: os.close(fd)

  def ioctl(self, request, arg):
    """Issue ioctl `request`.

    A tracked file receives ctypes.addressof(arg); a real fd goes through fcntl.ioctl with `arg`
    as given.
    """
    if self.fd in tracked_fds:
      return tracked_fds[self.fd].ioctl(self.fd, request, ctypes.addressof(arg))
    return fcntl.ioctl(self.fd, request, arg)

  def mmap(self, start, sz, prot, flags, offset):
    """Forward an mmap request to the tracked file's mmap(), or to libc.mmap for a real fd."""
    if self.fd in tracked_fds:
      return tracked_fds[self.fd].mmap(start, sz, prot, flags, self.fd, offset)
    return libc.mmap(start, sz, prot, flags, self.fd, offset)

  def read(self, size=None, binary=False, offset=None):
    """Read contents; binary reads are not supported.

    A tracked file answers via read_contents(size). A real fd is read in text mode, up to
    `size` characters (everything when None).

    NOTE(review): `offset` is accepted for interface compatibility but is not used here.
    """
    if binary: raise NotImplementedError()
    if self.fd in tracked_fds:
      return tracked_fds[self.fd].read_contents(size)
    # binary was rejected above, so the stream is always opened in text mode.
    with open(self.fd, "r", closefd=False) as file:
      # If the fd's offset is already at or past the file size, rewind so the contents are read again.
      if file.tell() >= os.fstat(self.fd).st_size: file.seek(0)
      return file.read(size)

  def listdir(self):
    """List entries via the tracked file's list_contents(), else os.listdir(self.path)."""
    if self.fd in tracked_fds:
      return tracked_fds[self.fd].list_contents()
    return os.listdir(self.path)

  def write(self, content, binary=False, offset=None): raise NotImplementedError()

  def seek(self, offset):
    """Seek this fd.

    A tracked file handles seek(offset) itself. A real fd is moved by `offset` relative to its
    current position (os.SEEK_CUR).

    NOTE(review): confirm the tracked file's seek() uses the same relative semantics.
    """
    if self.fd in tracked_fds:
      tracked_fds[self.fd].seek(offset)
    else:
      os.lseek(self.fd, offset, os.SEEK_CUR)

  @staticmethod
  def exists(path):
    """Return True if `path` is a mock-tracked file or exists on disk.

    The probe goes through _open, so it follows the same routing as a real open; errors raised
    while opening propagate. The probe fd is released again, mirroring __del__. Previously each
    call leaked one OS descriptor or one tracked_fds entry.
    """
    fd = _open(path, os.O_RDONLY)
    if fd is None: return False
    if fd in tracked_fds:
      tracked_fds[fd].close(fd)
      tracked_fds.pop(fd)
    else: os.close(fd)
    return True

  @staticmethod
  def readlink(path): raise NotImplementedError()

  @staticmethod
  def eventfd(initval, flags=None):
    # The exception was previously constructed but never raised, so callers silently got None.
    raise NotImplementedError()