You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
			
				
					48 lines
				
				2.0 KiB
			
		
		
			
		
	
	
					48 lines
				
				2.0 KiB
			| 
											4 days ago
										 | # NOTE: this is probably the wrong backend to work on, you want backend.py
 | ||
|  | # discussion about this in #pytorch-backend on our Discord
 | ||
|  | # this is using high level debug features from torch and we should integrate deeper
 | ||
|  | 
 | ||
|  | from tinygrad import Tensor
 | ||
|  | import torch, contextlib
 | ||
|  | from torch.utils._python_dispatch import TorchDispatchMode
 | ||
|  | from tinygrad.dtype import _from_torch_dtype
 | ||
|  | 
 | ||
def empty_memory_format(size, dtype=None, layout=None, device=None, pin_memory=False, memory_format=None):
  """Backend for ``aten.empty.memory_format``: allocate an uninitialized tensor.

  Args:
    size: iterable of dimension sizes; unpacked into ``Tensor.empty``.
    dtype: torch dtype of the result. ``None`` means "use torch's global default
      dtype", matching the documented behaviour of ``torch.empty``.
    layout, device, pin_memory, memory_format: accepted so the signature matches
      the aten op, but currently ignored by this backend.

  Returns:
    A ``TTensor`` wrapping a freshly created tinygrad ``Tensor``.
  """
  # torch.empty(...) called without dtype= reaches us with dtype=None. None is not a
  # torch dtype, so resolve it before handing it to the tinygrad dtype converter.
  if dtype is None: dtype = torch.get_default_dtype()
  return TTensor(Tensor.empty(*size, dtype=_from_torch_dtype(dtype)))
 | ||
|  | 
 | ||
|  | # NOTE: if we have a way to change wrap/unwrap, these can be the same methods from backend.py
 | ||
# Dispatch table: str(aten op) -> callable that runs the op on the wrapped tinygrad
# tensors (`.tiny`) and re-wraps the result in a TTensor.
tiny_backend = {
  # creation / shape
  "aten.empty.memory_format": empty_memory_format,
  "aten.view.default": lambda x, sz: TTensor(x.tiny.reshape(sz)),
  # elementwise unary
  "aten.abs.default": lambda x: TTensor(x.tiny.abs()),
  # elementwise binary; the *.Tensor variants unwrap both operands
  "aten.eq.Tensor": lambda x, y: TTensor(x.tiny == y.tiny),
  "aten.bitwise_and.Tensor": lambda x, y: TTensor(x.tiny & y.tiny),
  # the Scalar variant compares against `y` as-is (no unwrapping)
  "aten.ne.Scalar": lambda x, y: TTensor(x.tiny != y),
  "aten.mul.Tensor": lambda x, y: TTensor(x.tiny * y.tiny),
  # masked_select goes through numpy indexing and builds a new tinygrad Tensor from the result
  "aten.masked_select.default": lambda x, y: TTensor(Tensor(x.tiny.numpy()[y.tiny.numpy()])),
}
 | ||
|  | 
 | ||
class TTensor(torch.Tensor):
  """torch.Tensor wrapper subclass whose actual data is a tinygrad ``Tensor`` stored in ``.tiny``.

  Every aten op that reaches ``__torch_dispatch__`` is looked up by name in
  ``tiny_backend`` and executed there; unknown ops raise ``NotImplementedError``.
  """
  tiny: Tensor
  context = contextlib.nullcontext

  @staticmethod
  def __new__(cls, tiny, *args, **kwargs):
    # Build a storage-less wrapper; only the shape of `tiny` is forwarded to torch here.
    out = torch.Tensor._make_wrapper_subclass(cls, tiny.shape)
    # NOTE(review): presumably makes mutable data_ptr access on the wrapper raise, since
    # there is no real torch storage behind it -- confirm against torch internals.
    torch._C._set_throw_on_mutable_data_ptr(out)
    out.tiny = tiny
    return out

  def __repr__(self): return super().__repr__(tensor_contents=f"{self.tiny}")

  def __torch_dispatch__(cls, func, types, args, kwargs=None):
    """Route ``func`` to its tinygrad implementation in ``tiny_backend``.

    Deliberately left as a plain function (no ``@classmethod``) so that
    ``Dispatcher`` can reuse it as its own ``__torch_dispatch__``.

    Raises:
      NotImplementedError: if ``str(func)`` has no entry in ``tiny_backend``.
    """
    # kwargs defaults to None; normalise before the log line below calls .keys() on it.
    kwargs = kwargs or {}
    print(f"Dispatch Log: {func}(*{[type(x) for x in args]}, **{kwargs.keys()})")
    #print(f"Dispatch Log: {func}(*{args}, **{kwargs})")
    new_func = tiny_backend.get(str(func), None)
    if new_func is None: raise NotImplementedError(f"add support for {func}")
    return new_func(*args, **kwargs)
 | ||
|  | 
 | ||
class Dispatcher(TorchDispatchMode):
  """Dispatch mode that reuses TTensor's handler, so ops are routed to tiny_backend
  while the mode is active, not only when an argument is already a TTensor."""
  __torch_dispatch__ = TTensor.__torch_dispatch__

# Enter the mode at import time; it is never exited.
Dispatcher().__enter__()
 | ||
|  | 
 | ||
|  | if __name__ == "__main__":
 | ||
|  |   a = torch.empty((4,), dtype=torch.int)
 |