You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
			
				
					226 lines
				
				13 KiB
			
		
		
			
		
	
	
					226 lines
				
				13 KiB
			| 
											3 days ago
										 | import functools, struct
 | ||
|  | from tinygrad.device import  Compiled, Allocator, Compiler, BufferSpec
 | ||
|  | from tinygrad.renderer.wgsl import WGSLRenderer
 | ||
|  | from tinygrad.helpers import round_up, suppress_finalizing
 | ||
|  | from tinygrad.runtime.autogen import webgpu
 | ||
|  | from typing import List, Any, TypeAlias
 | ||
|  | import ctypes
 | ||
|  | import os
 | ||
|  | 
 | ||
# Handle types from the autogenerated bindings, aliased so annotations below stay short.
WGPUDevPtr: TypeAlias = webgpu.WGPUDevice # type: ignore
WGPUBufPtr: TypeAlias = webgpu.WGPUBuffer # type: ignore

# Reverse of the generated enum table: backend name -> numeric enum value.
backend_types = {name: code for code, name in webgpu.WGPUBackendType__enumvalues.items()}

# Single module-wide instance; timedWaitAnyEnable is switched on because _wait passes a timeout to wgpuInstanceWaitAny.
instance = webgpu.wgpuCreateInstance(
  webgpu.WGPUInstanceDescriptor(features=webgpu.WGPUInstanceFeatures(timedWaitAnyEnable=True)))
 | ||
|  | 
 | ||
def to_c_string(_str:str) -> ctypes.Array:
  """Return `_str` encoded as UTF-8 inside a NUL-terminated ctypes char buffer."""
  encoded = _str.encode('utf-8')
  return ctypes.create_string_buffer(encoded)
 | ||
|  | 
 | ||
def from_wgpu_str(string_view:webgpu.struct_WGPUStringView) -> str:
  """Decode a WGPUStringView (`data` pointer plus `length` in bytes) into a Python str."""
  raw = ctypes.string_at(string_view.data, string_view.length)
  return raw.decode("utf-8")
 | ||
|  | 
 | ||
def to_wgpu_str(_str:str) -> webgpu.struct_WGPUStringView:
  """Build a WGPUStringView describing the UTF-8 encoding of `_str`.

  `length` is the number of encoded bytes, not the number of characters: the two differ as soon as
  the string contains a non-ASCII character, and using the character count would truncate the view.
  """
  c_buf = to_c_string(_str)
  # to_c_string appends a NUL terminator, which is not part of the string view
  byte_len = len(c_buf) - 1
  return webgpu.WGPUStringView(data=ctypes.cast(ctypes.pointer(c_buf), ctypes.POINTER(ctypes.c_char)), length=byte_len)
 | ||
|  | 
 | ||
def _wait(future:webgpu.struct_WGPUFuture):
  """Block on `future` via wgpuInstanceWaitAny and fail if the wait does not report success.

  The wait is performed as a plain statement rather than inside an `assert`: assert expressions are
  dropped under `python -O`, which would silently skip the wait itself.

  Raises:
    AssertionError: if the wait status is not WGPUWaitStatus_Success (type kept for backward compatibility).
  """
  # 2**64-1 is the largest uint64 timeout value
  status = webgpu.wgpuInstanceWaitAny(instance, 1, webgpu.WGPUFutureWaitInfo(future=future), 2**64-1)
  if status != webgpu.WGPUWaitStatus_Success: raise AssertionError("Future failed")
 | ||
|  | 
 | ||
def write_buffer(device:WGPUDevPtr, buf:WGPUBufPtr, offset:int, src:memoryview|bytearray|bytes):
  """Write the bytes of `src` into `buf` at byte `offset` through `device`'s queue."""
  # from_buffer needs a writable object, so the input is copied into a bytearray first
  staging = bytearray(src)
  nbytes = len(staging)
  queue = webgpu.wgpuDeviceGetQueue(device)
  c_bytes = (ctypes.c_uint8 * nbytes).from_buffer(staging)
  webgpu.wgpuQueueWriteBuffer(queue, buf, offset, c_bytes, nbytes)
 | ||
|  | 
 | ||
def _run(async_fun, cb_info_type, cb_type, status_enum, res_idx:int|None, msg_idx:int|None, *params):
  """Call an asynchronous WebGPU entry point and wait for its callback synchronously.

  Args:
    async_fun: binding called as `async_fun(*params, callback_info)`; its return value is passed to `_wait`.
    cb_info_type: callback-info struct type, built with `nextInChain`, `mode` and `callback` fields.
    cb_type: ctypes function type used to wrap the Python callback.
    status_enum: mapping from status code to name for error messages; a falsy value makes the name 'ERROR'.
    res_idx: position of the result among the callback arguments; a falsy value (None or 0) returns None.
    msg_idx: position of the message among the callback arguments; a falsy value (None or 0) means
      no message is decoded or reported.
    *params: leading arguments forwarded to `async_fun`.

  Returns:
    The callback argument at `res_idx`, or None when `res_idx` is falsy.

  Raises:
    RuntimeError: if the first callback argument (the status) is not 1.
  """
  received: List[Any] = []

  def on_done(*cb_args):
    # keep every callback argument, replacing the message slot by its decoded str
    received[:] = cb_args
    if msg_idx: received[msg_idx] = from_wgpu_str(received[msg_idx])

  c_callback = cb_type(on_done)
  info = cb_info_type(nextInChain=None, mode=webgpu.WGPUCallbackMode_WaitAnyOnly, callback=c_callback)
  future = async_fun(*params, info)
  _wait(future)

  status = received[0]
  if status != 1:
    status_name = status_enum[status] if status_enum else 'ERROR'
    detail = received[msg_idx] if msg_idx else ''
    raise RuntimeError(f"[{status_name}]{detail}")
  if res_idx: return received[res_idx]
  return None
 | ||
|  | 
 | ||
def copy_buffer_to_buffer(dev:WGPUDevPtr, src:WGPUBufPtr, src_offset:int, dst:WGPUBufPtr, dst_offset:int, size:int):
  """Record and submit a single buffer-to-buffer copy of `size` bytes on `dev`'s queue."""
  enc = webgpu.wgpuDeviceCreateCommandEncoder(dev, webgpu.WGPUCommandEncoderDescriptor())
  webgpu.wgpuCommandEncoderCopyBufferToBuffer(enc, src, src_offset, dst, dst_offset, size)
  commands = webgpu.wgpuCommandEncoderFinish(enc, webgpu.WGPUCommandBufferDescriptor())

  queue = webgpu.wgpuDeviceGetQueue(dev)
  submit_list = (webgpu.WGPUCommandBuffer*1)(commands)
  webgpu.wgpuQueueSubmit(queue, 1, submit_list)

  # the command buffer and encoder handles are released right after submission
  webgpu.wgpuCommandBufferRelease(commands)
  webgpu.wgpuCommandEncoderRelease(enc)
 | ||
|  | 
 | ||
def read_buffer(dev:WGPUDevPtr, buf:WGPUBufPtr) -> memoryview:
  """Read the whole contents of `buf` back to the host.

  The data is copied on the GPU into a temporary CopyDst|MapRead buffer, which is then mapped and
  copied into a host bytearray.

  Returns:
    A byte-format memoryview over a host-side copy of the buffer contents.

  Raises:
    RuntimeError: propagated from `_run` if mapping the temporary buffer fails.
  """
  size = webgpu.wgpuBufferGetSize(buf)
  tmp_buffer = webgpu.wgpuDeviceCreateBuffer(dev, webgpu.WGPUBufferDescriptor(size=size,
    usage=webgpu.WGPUBufferUsage_CopyDst | webgpu.WGPUBufferUsage_MapRead, mappedAtCreation=False))
  try:
    copy_buffer_to_buffer(dev, buf, 0, tmp_buffer, 0, size)
    # NOTE(review): msg_idx is 0 here, which _run treats as "no message"; confirm that is intended
    _run(webgpu.wgpuBufferMapAsync2, webgpu.WGPUBufferMapCallbackInfo2, webgpu.WGPUBufferMapCallback2, webgpu.WGPUBufferMapAsyncStatus__enumvalues,
      None, 0, tmp_buffer, webgpu.WGPUMapMode_Read, 0, size)
    void_ptr = ctypes.cast(webgpu.wgpuBufferGetConstMappedRange(tmp_buffer, 0, size), ctypes.c_void_p)
    # bytearray() copies the mapped bytes, so the result stays valid after the unmap below
    buf_copy = bytearray((ctypes.c_uint8 * size).from_address(void_ptr.value))
    webgpu.wgpuBufferUnmap(tmp_buffer)
  finally:
    # destroy the temporary buffer on the failure paths too, so it is not leaked
    webgpu.wgpuBufferDestroy(tmp_buffer)
  return memoryview(buf_copy).cast("B")
 | ||
|  | 
 | ||
def pop_error(device:WGPUDevPtr) -> str:
  """Pop the current error scope of `device` and return the message passed to the callback.

  Callers in this file treat a falsy (empty) return value as "no error".
  """
  # result index and message index are the same slot, so the decoded message is what gets returned
  message_slot = 2
  return _run(webgpu.wgpuDevicePopErrorScopeF, webgpu.WGPUPopErrorScopeCallbackInfo, webgpu.WGPUPopErrorScopeCallback, None,
    message_slot, message_slot, device)
 | ||
|  | 
 | ||
def create_uniform(wgpu_device:WGPUDevPtr, val:int|float) -> WGPUBufPtr:
  """Create a 4-byte uniform buffer initialised with `val`.

  Ints are stored little-endian in 4 bytes: non-negative values as unsigned, negative values as
  two's-complement (previously a negative int raised OverflowError). Floats are stored as
  little-endian 32-bit floats.

  Raises:
    OverflowError: if an int does not fit in 4 bytes.
  """
  # encode first, so a value that does not fit cannot leave an orphaned GPU buffer behind
  if isinstance(val, int): payload = val.to_bytes(4, "little", signed=val < 0)
  else: payload = struct.pack('<f', val)
  buf = webgpu.wgpuDeviceCreateBuffer(wgpu_device,
    webgpu.WGPUBufferDescriptor(size=4, usage=webgpu.WGPUBufferUsage_Uniform | webgpu.WGPUBufferUsage_CopyDst))
  write_buffer(wgpu_device, buf, 0, payload)
  return buf
 | ||
|  | 
 | ||
class WebGPUProgram:
  """A WGSL compute shader that can be dispatched on a WebGPU device.

  The shader module is created and validated once in `__init__`. Every call rebuilds the bind group
  layout, pipeline layout, bind group and compute pipeline, then records and submits one compute pass.
  """
  def __init__(self, dev:tuple[WGPUDevPtr, bool], name:str, lib:bytes):
    """Create the shader module for `lib`.

    Args:
      dev: `(device handle, timestamp_supported)` pair.
      name: entry point name used when the compute pipeline is created.
      lib: WGSL source as bytes.

    Raises:
      RuntimeError: if the validation error scope reports an error for the shader module.
    """
    (self.dev, self.timestamp_supported) = dev

    # Creating shader module
    shader = webgpu.WGPUShaderModuleWGSLDescriptor(code=to_wgpu_str(lib.decode()),
      chain=webgpu.WGPUChainedStruct(sType=webgpu.WGPUSType_ShaderSourceWGSL))
    module = webgpu.WGPUShaderModuleDescriptor()
    module.nextInChain = ctypes.cast(ctypes.pointer(shader), ctypes.POINTER(webgpu.struct_WGPUChainedStruct))

    # Check compiler error
    webgpu.wgpuDevicePushErrorScope(self.dev, webgpu.WGPUErrorFilter_Validation)
    shader_module = webgpu.wgpuDeviceCreateShaderModule(self.dev, module)

    if err := pop_error(self.dev): raise RuntimeError(f"Shader compilation failed: {err}")

    self.name, self.lib, self.prg = name, lib, shader_module

  def __call__(self, *bufs:WGPUBufPtr, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1),
               vals:tuple[int, ...]=(), wait=False) -> float|None:
    """Dispatch the shader once.

    Args:
      *bufs: buffers bound as storage buffers at bindings 1..len(bufs).
      global_size: workgroup counts passed to DispatchWorkgroups.
      local_size: accepted for interface compatibility; not used in this method.
      vals: scalar values, each uploaded to its own 4-byte uniform buffer bound after the storage buffers.
      wait: request timing; only honoured when the device reported timestamp support.

    Returns:
      The difference between the end-of-pass and beginning-of-pass timestamps divided by 1e9 when
      timing is active, otherwise None.

    Raises:
      RuntimeError: if creating the bind group layout, pipeline layout, bind group or pipeline fails.
    """
    wait = wait and self.timestamp_supported
    tmp_bufs = [*bufs]
    buf_patch = False

    # WebGPU does not allow using the same buffer for input and output
    # A single scratch buffer is enough no matter how many later arguments alias bufs[0];
    # allocating one per alias would leak all but the last.
    if any(ctypes.addressof(other) == ctypes.addressof(bufs[0]) for other in bufs[1:]):
      tmp_bufs[0] = webgpu.wgpuDeviceCreateBuffer(self.dev,
        webgpu.WGPUBufferDescriptor(size=webgpu.wgpuBufferGetSize(bufs[0]), usage=webgpu.wgpuBufferGetUsage(bufs[0])))
      buf_patch = True

    # Creating bind group layout
    # binding 0 is always a uniform; bindings 1.. are storage buffers followed by one uniform per value
    binding_layouts = [webgpu.WGPUBindGroupLayoutEntry(binding=0, visibility= webgpu.WGPUShaderStage_Compute,
      buffer=webgpu.WGPUBufferBindingLayout(type=webgpu.WGPUBufferBindingType_Uniform))]
    binding_layouts += [webgpu.WGPUBindGroupLayoutEntry(binding=i+1, visibility=webgpu.WGPUShaderStage_Compute,
      buffer=webgpu.WGPUBufferBindingLayout(type=webgpu.WGPUBufferBindingType_Uniform if i >= len(tmp_bufs)
      else webgpu.WGPUBufferBindingType_Storage)) for i in range(len(tmp_bufs)+len(vals))]

    bl_arr_type = webgpu.WGPUBindGroupLayoutEntry * len(binding_layouts)
    webgpu.wgpuDevicePushErrorScope(self.dev, webgpu.WGPUErrorFilter_Validation)
    bind_group_layouts = [webgpu.wgpuDeviceCreateBindGroupLayout(self.dev, webgpu.WGPUBindGroupLayoutDescriptor(
      entryCount=len(binding_layouts), entries=ctypes.cast(bl_arr_type(*binding_layouts), ctypes.POINTER(webgpu.WGPUBindGroupLayoutEntry))))]

    if bg_layout_err := pop_error(self.dev): raise RuntimeError(f"Error creating bind group layout: {bg_layout_err}")

    # Creating pipeline layout
    pipeline_layout_desc = webgpu.WGPUPipelineLayoutDescriptor(bindGroupLayoutCount=len(bind_group_layouts),
      bindGroupLayouts = (webgpu.WGPUBindGroupLayout * len(bind_group_layouts))(*bind_group_layouts))

    webgpu.wgpuDevicePushErrorScope(self.dev, webgpu.WGPUErrorFilter_Validation)
    pipeline_layout = webgpu.wgpuDeviceCreatePipelineLayout(self.dev, pipeline_layout_desc)

    if pipe_err := pop_error(self.dev): raise RuntimeError(f"Error creating pipeline layout: {pipe_err}")

    # Creating bind group
    # binding 0 gets a 4-byte uniform holding +inf
    # NOTE(review): the uniform buffers created here are never destroyed in this method
    bindings = [webgpu.WGPUBindGroupEntry(binding=0, buffer=create_uniform(self.dev, float('inf')), offset=0, size=4)]
    bindings += [webgpu.WGPUBindGroupEntry(binding=i+1, buffer=create_uniform(self.dev, x) if i >= len(tmp_bufs) else x, offset=0,
      size=4 if i >= len(tmp_bufs) else webgpu.wgpuBufferGetSize(x)) for i,x in enumerate(tuple(tmp_bufs)+vals)]

    bg_arr_type = webgpu.WGPUBindGroupEntry * len(bindings)
    bind_group_desc = webgpu.WGPUBindGroupDescriptor(layout=bind_group_layouts[0], entryCount=len(bindings), entries=bg_arr_type(*bindings))
    webgpu.wgpuDevicePushErrorScope(self.dev, webgpu.WGPUErrorFilter_Validation)
    bind_group = webgpu.wgpuDeviceCreateBindGroup(self.dev, bind_group_desc)

    if bind_err := pop_error(self.dev): raise RuntimeError(f"Error creating bind group: {bind_err}")

    # Creating compute pipeline
    compute_desc = webgpu.WGPUComputePipelineDescriptor(layout=pipeline_layout,
      compute=webgpu.WGPUComputeState(module=self.prg, entryPoint=to_wgpu_str(self.name)))
    pipeline_result = _run(webgpu.wgpuDeviceCreateComputePipelineAsync2, webgpu.WGPUCreateComputePipelineAsyncCallbackInfo2,
      webgpu.WGPUCreateComputePipelineAsyncCallback2, webgpu.WGPUCreatePipelineAsyncStatus__enumvalues, 1, None, self.dev, compute_desc)

    command_encoder = webgpu.wgpuDeviceCreateCommandEncoder(self.dev, webgpu.WGPUCommandEncoderDescriptor())
    comp_pass_desc = webgpu.WGPUComputePassDescriptor(nextInChain=None)

    if wait:
      # two timestamp queries (begin/end of pass), resolved into a 16-byte buffer
      query_set = webgpu.wgpuDeviceCreateQuerySet(self.dev, webgpu.WGPUQuerySetDescriptor(type=webgpu.WGPUQueryType_Timestamp, count=2))
      query_buf = webgpu.wgpuDeviceCreateBuffer(self.dev,
        webgpu.WGPUBufferDescriptor(size=16, usage=webgpu.WGPUBufferUsage_QueryResolve | webgpu.WGPUBufferUsage_CopySrc))
      comp_pass_desc.timestampWrites = ctypes.pointer(webgpu.WGPUComputePassTimestampWrites(
        querySet=query_set, beginningOfPassWriteIndex=0, endOfPassWriteIndex=1))

    # Begin compute pass
    compute_pass = webgpu.wgpuCommandEncoderBeginComputePass(command_encoder, comp_pass_desc)
    webgpu.wgpuComputePassEncoderSetPipeline(compute_pass, pipeline_result)
    webgpu.wgpuComputePassEncoderSetBindGroup(compute_pass, 0, bind_group, 0, None)
    webgpu.wgpuComputePassEncoderDispatchWorkgroups(compute_pass, *global_size)
    webgpu.wgpuComputePassEncoderEnd(compute_pass)

    if wait: webgpu.wgpuCommandEncoderResolveQuerySet(command_encoder, query_set, 0, 2, query_buf, 0)

    cmd_buf = webgpu.wgpuCommandEncoderFinish(command_encoder, webgpu.WGPUCommandBufferDescriptor())
    webgpu.wgpuQueueSubmit(webgpu.wgpuDeviceGetQueue(self.dev), 1, (webgpu.WGPUCommandBuffer*1)(cmd_buf))

    if buf_patch:
      # copy the scratch output back into the real bufs[0], then drop the scratch buffer
      copy_buffer_to_buffer(self.dev, tmp_bufs[0], 0, bufs[0], 0, webgpu.wgpuBufferGetSize(bufs[0]))
      webgpu.wgpuBufferDestroy(tmp_bufs[0])

    if wait:
      # the query buffer holds two unsigned 64-bit timestamps: [begin, end]
      time = ((timestamps:=read_buffer(self.dev, query_buf).cast("Q").tolist())[1] - timestamps[0]) / 1e9
      webgpu.wgpuBufferDestroy(query_buf)
      webgpu.wgpuQuerySetDestroy(query_set)
      return time
    return None
 | ||
|  | 
 | ||
class WebGpuAllocator(Allocator['WGPUDevPtr']):
  """Allocator that backs buffers with WebGPU storage buffers on `self.dev`."""
  def _alloc(self, size:int, options:BufferSpec) -> WGPUBufPtr:
    """Create a Storage|CopyDst|CopySrc buffer whose size is `size` rounded up to a multiple of 4."""
    # WebGPU buffers have to be 4-byte aligned
    usage_flags = webgpu.WGPUBufferUsage_Storage | webgpu.WGPUBufferUsage_CopyDst | webgpu.WGPUBufferUsage_CopySrc
    descriptor = webgpu.WGPUBufferDescriptor(size=round_up(size, 4), usage=usage_flags)
    return webgpu.wgpuDeviceCreateBuffer(self.dev, descriptor)

  def _copyin(self, dest:WGPUBufPtr, src:memoryview):
    """Upload `src` into `dest`, zero-padding the payload up to a multiple of 4 bytes when needed."""
    payload = src
    if src.nbytes % 4:
      # bytearray(n) is zero-filled, so the tail past src.nbytes stays zero
      payload = bytearray(round_up(src.nbytes, 4))
      payload[:src.nbytes] = src
    write_buffer(self.dev, dest, 0, payload)

  def _copyout(self, dest:memoryview, src:WGPUBufPtr):
    """Read `src` back and store it in `dest`, dropping trailing bytes when the buffer is larger than `dest`."""
    contents = read_buffer(self.dev, src)
    if webgpu.wgpuBufferGetSize(src) > dest.nbytes: contents = contents[:dest.nbytes]
    dest[:] = contents

  @suppress_finalizing
  def _free(self, opaque:WGPUBufPtr, options:BufferSpec):
    """Destroy the WebGPU buffer behind `opaque`."""
    webgpu.wgpuBufferDestroy(opaque)
 | ||
|  | 
 | ||
class WebGpuDevice(Compiled):
  """tinygrad device that requests a WebGPU adapter and device and wires up the allocator, renderer and runtime."""
  def __init__(self, device:str):
    """Request an adapter and a device, enabling the optional features the adapter reports as supported.

    The backend is taken from the WEBGPU_BACKEND environment variable (looked up in `backend_types`);
    an unset or unknown name falls back to 0.
    """
    # Requesting an adapter
    backend = backend_types.get(os.getenv("WEBGPU_BACKEND", ""), 0)
    adapter_opts = webgpu.WGPURequestAdapterOptions(powerPreference=webgpu.WGPUPowerPreference_HighPerformance,
      backendType=backend)
    adapter_res = _run(webgpu.wgpuInstanceRequestAdapterF, webgpu.WGPURequestAdapterCallbackInfo, webgpu.WGPURequestAdapterCallback,
      webgpu.WGPURequestAdapterStatus__enumvalues, 1, 2, instance, adapter_opts)

    # Get supported features
    feature_info = webgpu.WGPUSupportedFeatures()
    webgpu.wgpuAdapterGetFeatures(adapter_res, feature_info)
    supported = [feature_info.features[idx] for idx in range(feature_info.featureCount)]
    # only ask for the optional features this adapter actually offers
    wanted = [webgpu.WGPUFeatureName_TimestampQuery, webgpu.WGPUFeatureName_ShaderF16]
    features = [feat for feat in wanted if feat in supported]
    feature_arr = (webgpu.WGPUFeatureName * len(features))(*features)
    dev_desc = webgpu.WGPUDeviceDescriptor(requiredFeatureCount=len(features),requiredFeatures=feature_arr)

    # Limits
    # request exactly the limits the adapter reports
    adapter_limits = webgpu.WGPUSupportedLimits()
    webgpu.wgpuAdapterGetLimits(adapter_res, ctypes.cast(ctypes.pointer(adapter_limits),ctypes.POINTER(webgpu.struct_WGPUSupportedLimits)))
    required = webgpu.WGPURequiredLimits(limits=adapter_limits.limits)
    dev_desc.requiredLimits = ctypes.cast(ctypes.pointer(required),ctypes.POINTER(webgpu.struct_WGPURequiredLimits))

    # Requesting a device
    device_res = _run(webgpu.wgpuAdapterRequestDeviceF, webgpu.WGPURequestDeviceCallbackInfo, webgpu.WGPURequestDeviceCallback,
      webgpu.WGPURequestDeviceStatus__enumvalues, 1, 2, adapter_res, dev_desc)

    allocator = WebGpuAllocator(device_res)
    has_timestamps = webgpu.WGPUFeatureName_TimestampQuery in supported
    # programs receive the (device, timestamp_supported) pair as their first argument
    runtime = functools.partial(WebGPUProgram, (device_res, has_timestamps))
    super().__init__(device, allocator, [(WGSLRenderer, Compiler)], runtime)

  def synchronize(self):
    """Wait until the work submitted to the device queue is reported as done."""
    # the device handle is read from the (device, timestamp_supported) tuple bound into the runtime partial
    # (assumes Compiled keeps that partial as self.runtime; verify against the base class)
    queue = webgpu.wgpuDeviceGetQueue(self.runtime.args[0][0])
    _run(webgpu.wgpuQueueOnSubmittedWorkDone2, webgpu.WGPUQueueWorkDoneCallbackInfo2, webgpu.WGPUQueueWorkDoneCallback2,
      webgpu.WGPUQueueWorkDoneStatus__enumvalues, None, None, queue)
 |