"""WebGPU runtime backend: ctypes glue around the autogenerated ``webgpu`` bindings.

Provides buffer helpers, a program launcher (``WebGPUProgram``), an allocator
(``WebGpuAllocator``) and the device entry point (``WebGpuDevice``).
"""
import ctypes
import functools
import os
import struct
from typing import Any, List

from tinygrad.device import Compiled, Allocator, Compiler
from tinygrad.helpers import round_up
from tinygrad.renderer.wgsl import WGSLRenderer
from tinygrad.runtime.autogen import webgpu

# Inverse of webgpu.WGPUBackendType__enumvalues: maps the enum's name to its numeric value.
backend_types = {v: k for k, v in webgpu.WGPUBackendType__enumvalues.items()}

# Module-wide WebGPU instance, created with timedWaitAnyEnable so _wait() can block on futures.
instance = webgpu.wgpuCreateInstance(
    webgpu.WGPUInstanceDescriptor(features=webgpu.WGPUInstanceFeatures(timedWaitAnyEnable=True)))


def to_c_string(_str):
    """Return a ctypes char buffer holding ``_str`` encoded as UTF-8 (NUL-terminated)."""
    return ctypes.create_string_buffer(_str.encode('utf-8'))


def from_wgpu_str(string_view):
    """Decode an object exposing ``data`` (pointer) and ``length`` into a Python ``str``."""
    return ctypes.string_at(string_view.data, string_view.length).decode("utf-8")


def to_wgpu_str(_str):
    """Wrap ``_str`` in a ``webgpu.WGPUStringView`` pointing at its UTF-8 bytes."""
    c_buf = to_c_string(_str)
    # The view length is a byte count. len(_str) counts code points and would truncate any
    # non-ASCII text; len(c_buf) includes the trailing NUL that create_string_buffer appends,
    # which is not part of the view, hence the -1.
    return webgpu.WGPUStringView(data=ctypes.cast(ctypes.pointer(c_buf), ctypes.POINTER(ctypes.c_char)),
                                 length=len(c_buf) - 1)


def _wait(future):
    """Block until ``future`` completes; fail with AssertionError("Future failed") otherwise."""
    # The call is kept outside the assert so that running under `python -O` (which strips
    # assert statements) still performs the wait.
    status = webgpu.wgpuInstanceWaitAny(instance, 1, webgpu.WGPUFutureWaitInfo(future=future), 2**64-1)
    assert status == webgpu.WGPUWaitStatus_Success, "Future failed"


def write_buffer(device, buf, offset, src):
    """Write the bytes of ``src`` into ``buf`` at ``offset`` through the device queue."""
    # from_buffer needs a writable buffer, so always work on a bytearray copy.
    src = bytearray(src)
    webgpu.wgpuQueueWriteBuffer(webgpu.wgpuDeviceGetQueue(device), buf, offset,
                                (ctypes.c_uint8 * len(src)).from_buffer(src), len(src))


def _run(async_fun, cb_info_type, cb_type, status_enum, res_idx, msg_idx, *params):
    """Call an async WebGPU function, wait for its callback and return one callback argument.

    ``async_fun(*params, cb_info)`` is invoked with a callback-info struct of type
    ``cb_info_type`` whose callback (wrapped in ``cb_type``) records its arguments.

    Args:
        status_enum: optional mapping from status value to name, used in the error message.
        res_idx: index of the callback argument to return; a falsy value returns ``None``.
        msg_idx: index of a string-view callback argument to decode; a falsy value
            (``None`` or ``0``) means there is none.

    Raises:
        RuntimeError: if the first callback argument is not ``1``.
    """
    result: List[Any] = []

    def cb(*cb_args):
        result[:] = cb_args
        if msg_idx: result[msg_idx] = from_wgpu_str(result[msg_idx])

    cb_info = cb_info_type(nextInChain=None, mode=webgpu.WGPUCallbackMode_WaitAnyOnly, callback=cb_type(cb))
    _wait(async_fun(*params, cb_info))

    if result[0] != 1:
        raise RuntimeError(f"[{status_enum[result[0]] if status_enum else 'ERROR'}]{result[msg_idx] if msg_idx else ''}")
    return result[res_idx] if res_idx else None


def copy_buffer_to_buffer(dev, src, src_offset, dst, dst_offset, size):
    """Encode and submit a copy of ``size`` bytes from ``src`` to ``dst`` on ``dev``'s queue."""
    encoder = webgpu.wgpuDeviceCreateCommandEncoder(dev, webgpu.WGPUCommandEncoderDescriptor())
    webgpu.wgpuCommandEncoderCopyBufferToBuffer(encoder, src, src_offset, dst, dst_offset, size)
    cb = webgpu.wgpuCommandEncoderFinish(encoder, webgpu.WGPUCommandBufferDescriptor())
    webgpu.wgpuQueueSubmit(webgpu.wgpuDeviceGetQueue(dev), 1, (webgpu.WGPUCommandBuffer*1)(cb))
    webgpu.wgpuCommandBufferRelease(cb)
    webgpu.wgpuCommandEncoderRelease(encoder)


def read_buffer(dev, buf):
    """Return the full contents of ``buf`` as a byte ``memoryview`` over a host-side copy.

    The data is copied into a temporary CopyDst|MapRead buffer, mapped for reading,
    copied into a ``bytearray``, and the temporary buffer is then unmapped and destroyed.
    """
    size = webgpu.wgpuBufferGetSize(buf)
    tmp_buffer = webgpu.wgpuDeviceCreateBuffer(dev, webgpu.WGPUBufferDescriptor(
        size=size, usage=webgpu.WGPUBufferUsage_CopyDst | webgpu.WGPUBufferUsage_MapRead, mappedAtCreation=False))
    copy_buffer_to_buffer(dev, buf, 0, tmp_buffer, 0, size)
    _run(webgpu.wgpuBufferMapAsync2, webgpu.WGPUBufferMapCallbackInfo2, webgpu.WGPUBufferMapCallback2,
         webgpu.WGPUBufferMapAsyncStatus__enumvalues, None, 0, tmp_buffer, webgpu.WGPUMapMode_Read, 0, size)
    void_ptr = ctypes.cast(webgpu.wgpuBufferGetConstMappedRange(tmp_buffer, 0, size), ctypes.c_void_p)
    # bytearray(...) copies the mapped bytes, so the result stays valid after the unmap below.
    buf_copy = bytearray((ctypes.c_uint8 * size).from_address(void_ptr.value))
    webgpu.wgpuBufferUnmap(tmp_buffer)
    webgpu.wgpuBufferDestroy(tmp_buffer)
    return memoryview(buf_copy).cast("B")


def pop_error(device):
    """Pop the current error scope of ``device`` and return its decoded message (callback arg 2)."""
    return _run(webgpu.wgpuDevicePopErrorScopeF, webgpu.WGPUPopErrorScopeCallbackInfo,
                webgpu.WGPUPopErrorScopeCallback, None, 2, 2, device)


def create_uniform(wgpu_device, val):
    """Create a 4-byte Uniform|CopyDst buffer initialised with ``val`` and return it.

    Ints are written as 4 little-endian bytes; any other value is packed with ``struct``.
    """
    buf = webgpu.wgpuDeviceCreateBuffer(wgpu_device, webgpu.WGPUBufferDescriptor(
        size=4, usage=webgpu.WGPUBufferUsage_Uniform | webgpu.WGPUBufferUsage_CopyDst))
    # NOTE(review): the original text was lost from just after `struct.pack('` up to
    # `>= len(tmp_bufs)` inside WebGPUProgram.__call__ (it looks like a `<...>` span was
    # stripped). The '<f' format and `return buf` here, WebGPUProgram.__init__, and the head
    # of WebGPUProgram.__call__ (down to the matching "end of reconstructed region" marker)
    # are a RECONSTRUCTION - confirm against version control before merging.
    payload = val.to_bytes(4, "little") if isinstance(val, int) else struct.pack('<f', val)
    write_buffer(wgpu_device, buf, 0, payload)
    return buf


class WebGPUProgram:
    """A compiled WGSL shader module that can be dispatched as a compute kernel."""

    def __init__(self, dev, name: str, lib: bytes):
        # NOTE(review): reconstructed (see note in create_uniform). `dev` is the
        # (device, timestamp_supported) tuple bound by functools.partial in WebGpuDevice.
        (self.dev, self.timestamp_supported) = dev

        # Creating shader module
        shader = webgpu.WGPUShaderModuleWGSLDescriptor(
            code=to_wgpu_str(lib.decode()), chain=webgpu.WGPUChainedStruct(sType=webgpu.WGPUSType_ShaderSourceWGSL))
        module = webgpu.WGPUShaderModuleDescriptor()
        module.nextInChain = ctypes.cast(ctypes.pointer(shader), ctypes.POINTER(webgpu.struct_WGPUChainedStruct))

        # Check compiler error
        webgpu.wgpuDevicePushErrorScope(self.dev, webgpu.WGPUErrorFilter_Validation)
        shader_module = webgpu.wgpuDeviceCreateShaderModule(self.dev, module)
        if err := pop_error(self.dev): raise RuntimeError(f"Shader compilation failed: {err}")

        self.name, self.lib, self.prg = name, lib, shader_module

    def __call__(self, *bufs, global_size=(1, 1, 1), local_size=(1, 1, 1), vals=(), wait=False):
        """Dispatch the kernel over ``global_size`` workgroups with ``bufs`` and ``vals`` bound.

        Binding 0 is a uniform holding ``float('inf')``; bindings 1.. are the buffers
        (storage) followed by one 4-byte uniform per entry of ``vals``.

        Returns:
            Elapsed time in seconds computed from two timestamp queries when ``wait`` is
            in effect, otherwise ``None``.

        Raises:
            RuntimeError: if creating the bind group layout, pipeline layout, bind group
                or compute pipeline reports an error.
        """
        # NOTE(review): reconstructed region starts here (see note in create_uniform).
        wait = wait and self.timestamp_supported
        tmp_bufs = bufs
        buf_patch = False

        # WebGPU does not allow using the same buffer for input and output
        for i in range(1, len(bufs)):
            if ctypes.addressof(bufs[i]) == ctypes.addressof(bufs[0]):
                tmp_bufs = list(bufs)
                tmp_bufs[0] = webgpu.wgpuDeviceCreateBuffer(self.dev, webgpu.WGPUBufferDescriptor(
                    size=webgpu.wgpuBufferGetSize(bufs[0]), usage=webgpu.wgpuBufferGetUsage(bufs[0])))
                buf_patch = True
                break  # one temporary output buffer is enough however many inputs alias it

        # Creating bind group layout
        binding_layouts = [webgpu.WGPUBindGroupLayoutEntry(
            binding=0, visibility=webgpu.WGPUShaderStage_Compute,
            buffer=webgpu.WGPUBufferBindingLayout(type=webgpu.WGPUBufferBindingType_Uniform))]
        # NOTE(review): end of reconstructed region; everything from `>= len(tmp_bufs)` on is original.
        binding_layouts += [webgpu.WGPUBindGroupLayoutEntry(
            binding=i+1, visibility=webgpu.WGPUShaderStage_Compute,
            buffer=webgpu.WGPUBufferBindingLayout(
                type=webgpu.WGPUBufferBindingType_Uniform if i >= len(tmp_bufs) else webgpu.WGPUBufferBindingType_Storage))
            for i in range(len(tmp_bufs)+len(vals))]

        bl_arr_type = webgpu.WGPUBindGroupLayoutEntry * len(binding_layouts)
        webgpu.wgpuDevicePushErrorScope(self.dev, webgpu.WGPUErrorFilter_Validation)
        bind_group_layouts = [webgpu.wgpuDeviceCreateBindGroupLayout(self.dev, webgpu.WGPUBindGroupLayoutDescriptor(
            entryCount=len(binding_layouts),
            entries=ctypes.cast(bl_arr_type(*binding_layouts), ctypes.POINTER(webgpu.WGPUBindGroupLayoutEntry))))]
        if bg_layout_err := pop_error(self.dev): raise RuntimeError(f"Error creating bind group layout: {bg_layout_err}")

        # Creating pipeline layout
        pipeline_layout_desc = webgpu.WGPUPipelineLayoutDescriptor(
            bindGroupLayoutCount=len(bind_group_layouts),
            bindGroupLayouts=(webgpu.WGPUBindGroupLayout * len(bind_group_layouts))(*bind_group_layouts))
        webgpu.wgpuDevicePushErrorScope(self.dev, webgpu.WGPUErrorFilter_Validation)
        pipeline_layout = webgpu.wgpuDeviceCreatePipelineLayout(self.dev, pipeline_layout_desc)
        if pipe_err := pop_error(self.dev): raise RuntimeError(f"Error creating pipeline layout: {pipe_err}")

        # Creating bind group
        bindings = [webgpu.WGPUBindGroupEntry(binding=0, buffer=create_uniform(self.dev, float('inf')), offset=0, size=4)]
        # Entries past the buffers are scalar values, each uploaded into its own 4-byte uniform.
        bindings += [webgpu.WGPUBindGroupEntry(
            binding=i+1, buffer=create_uniform(self.dev, x) if i >= len(tmp_bufs) else x, offset=0,
            size=4 if i >= len(tmp_bufs) else webgpu.wgpuBufferGetSize(x))
            for i, x in enumerate(tuple(tmp_bufs)+vals)]
        bg_arr_type = webgpu.WGPUBindGroupEntry * len(bindings)
        bind_group_desc = webgpu.WGPUBindGroupDescriptor(
            layout=bind_group_layouts[0], entryCount=len(bindings), entries=bg_arr_type(*bindings))
        webgpu.wgpuDevicePushErrorScope(self.dev, webgpu.WGPUErrorFilter_Validation)
        bind_group = webgpu.wgpuDeviceCreateBindGroup(self.dev, bind_group_desc)
        if bind_err := pop_error(self.dev): raise RuntimeError(f"Error creating bind group: {bind_err}")

        # Creating compute pipeline
        compute_desc = webgpu.WGPUComputePipelineDescriptor(
            layout=pipeline_layout,
            compute=webgpu.WGPUComputeState(module=self.prg, entryPoint=to_wgpu_str(self.name)))
        pipeline_result = _run(webgpu.wgpuDeviceCreateComputePipelineAsync2,
                               webgpu.WGPUCreateComputePipelineAsyncCallbackInfo2,
                               webgpu.WGPUCreateComputePipelineAsyncCallback2,
                               webgpu.WGPUCreatePipelineAsyncStatus__enumvalues, 1, None, self.dev, compute_desc)

        command_encoder = webgpu.wgpuDeviceCreateCommandEncoder(self.dev, webgpu.WGPUCommandEncoderDescriptor())
        comp_pass_desc = webgpu.WGPUComputePassDescriptor(nextInChain=None)

        if wait:
            # Two timestamp queries (begin/end of pass), resolved into a 16-byte buffer.
            query_set = webgpu.wgpuDeviceCreateQuerySet(
                self.dev, webgpu.WGPUQuerySetDescriptor(type=webgpu.WGPUQueryType_Timestamp, count=2))
            query_buf = webgpu.wgpuDeviceCreateBuffer(self.dev, webgpu.WGPUBufferDescriptor(
                size=16, usage=webgpu.WGPUBufferUsage_QueryResolve | webgpu.WGPUBufferUsage_CopySrc))
            comp_pass_desc.timestampWrites = ctypes.pointer(webgpu.WGPUComputePassTimestampWrites(
                querySet=query_set, beginningOfPassWriteIndex=0, endOfPassWriteIndex=1))

        # Begin compute pass
        compute_pass = webgpu.wgpuCommandEncoderBeginComputePass(command_encoder, comp_pass_desc)
        webgpu.wgpuComputePassEncoderSetPipeline(compute_pass, pipeline_result)
        webgpu.wgpuComputePassEncoderSetBindGroup(compute_pass, 0, bind_group, 0, None)
        webgpu.wgpuComputePassEncoderDispatchWorkgroups(compute_pass, *global_size)
        webgpu.wgpuComputePassEncoderEnd(compute_pass)

        if wait: webgpu.wgpuCommandEncoderResolveQuerySet(command_encoder, query_set, 0, 2, query_buf, 0)

        cmd_buf = webgpu.wgpuCommandEncoderFinish(command_encoder, webgpu.WGPUCommandBufferDescriptor())
        webgpu.wgpuQueueSubmit(webgpu.wgpuDeviceGetQueue(self.dev), 1, (webgpu.WGPUCommandBuffer*1)(cmd_buf))

        if buf_patch:
            # Copy the temporary output back into the caller's buffer, then drop the temporary.
            copy_buffer_to_buffer(self.dev, tmp_bufs[0], 0, bufs[0], 0, webgpu.wgpuBufferGetSize(bufs[0]))
            webgpu.wgpuBufferDestroy(tmp_bufs[0])

        if wait:
            timestamps = read_buffer(self.dev, query_buf).cast("Q").tolist()
            elapsed = (timestamps[1] - timestamps[0]) / 1e9
            webgpu.wgpuBufferDestroy(query_buf)
            webgpu.wgpuQuerySetDestroy(query_set)
            return elapsed


class WebGpuAllocator(Allocator):
    """Allocator that backs tinygrad buffers with WebGPU Storage|CopyDst|CopySrc buffers."""

    def __init__(self, dev):
        self.dev = dev

    def _alloc(self, size: int, options):
        """Create a device buffer of ``size`` bytes rounded up to a multiple of 4."""
        # WebGPU buffers have to be 4-byte aligned
        return webgpu.wgpuDeviceCreateBuffer(self.dev, webgpu.WGPUBufferDescriptor(
            size=round_up(size, 4),
            usage=webgpu.WGPUBufferUsage_Storage | webgpu.WGPUBufferUsage_CopyDst | webgpu.WGPUBufferUsage_CopySrc))

    def _copyin(self, dest, src: memoryview):
        """Upload ``src`` into ``dest``, zero-padding the payload to a multiple of 4 bytes."""
        if src.nbytes % 4:
            padded_src = bytearray(round_up(src.nbytes, 4))
            padded_src[:src.nbytes] = src
        write_buffer(self.dev, dest, 0, padded_src if src.nbytes % 4 else src)

    def _copyout(self, dest: memoryview, src):
        """Read ``src`` back into ``dest``, dropping bytes beyond ``dest.nbytes``."""
        buffer_data = read_buffer(self.dev, src)
        # The device buffer may be larger than dest because _alloc rounds sizes up.
        dest[:] = buffer_data[:dest.nbytes] if webgpu.wgpuBufferGetSize(src) > dest.nbytes else buffer_data

    def _free(self, opaque, options):
        """Destroy the device buffer ``opaque``."""
        webgpu.wgpuBufferDestroy(opaque)


class WebGpuDevice(Compiled):
    """tinygrad device that requests a WebGPU adapter and device and wires up the runtime."""

    def __init__(self, device: str):
        # Requesting an adapter
        adapter_res = _run(
            webgpu.wgpuInstanceRequestAdapterF, webgpu.WGPURequestAdapterCallbackInfo,
            webgpu.WGPURequestAdapterCallback, webgpu.WGPURequestAdapterStatus__enumvalues, 1, 2, instance,
            webgpu.WGPURequestAdapterOptions(
                powerPreference=webgpu.WGPUPowerPreference_HighPerformance,
                # WEBGPU_BACKEND selects a backend by enum name; unknown or unset falls back to 0.
                backendType=backend_types.get(os.getenv("WEBGPU_BACKEND", ""), 0)))

        # Get supported features
        supported_features = webgpu.WGPUSupportedFeatures()
        webgpu.wgpuAdapterGetFeatures(adapter_res, supported_features)
        supported = [supported_features.features[i] for i in range(supported_features.featureCount)]
        # Only request the optional features the adapter actually reports.
        features = [feat for feat in [webgpu.WGPUFeatureName_TimestampQuery, webgpu.WGPUFeatureName_ShaderF16]
                    if feat in supported]
        dev_desc = webgpu.WGPUDeviceDescriptor(
            requiredFeatureCount=len(features), requiredFeatures=(webgpu.WGPUFeatureName * len(features))(*features))

        # Limits
        supported_limits = webgpu.WGPUSupportedLimits()
        webgpu.wgpuAdapterGetLimits(
            adapter_res, ctypes.cast(ctypes.pointer(supported_limits), ctypes.POINTER(webgpu.struct_WGPUSupportedLimits)))
        # Request every limit at the value the adapter reports as supported.
        limits = webgpu.WGPURequiredLimits(limits=supported_limits.limits)
        dev_desc.requiredLimits = ctypes.cast(ctypes.pointer(limits), ctypes.POINTER(webgpu.struct_WGPURequiredLimits))

        # Requesting a device
        device_res = _run(webgpu.wgpuAdapterRequestDeviceF, webgpu.WGPURequestDeviceCallbackInfo,
                          webgpu.WGPURequestDeviceCallback, webgpu.WGPURequestDeviceStatus__enumvalues, 1, 2,
                          adapter_res, dev_desc)

        super().__init__(device, WebGpuAllocator(device_res), WGSLRenderer(), Compiler(),
                         functools.partial(WebGPUProgram, (device_res, webgpu.WGPUFeatureName_TimestampQuery in supported)))

    def synchronize(self):
        """Block until all work submitted to the device queue has completed."""
        # self.runtime is the functools.partial built in __init__; args[0][0] is the device handle.
        _run(webgpu.wgpuQueueOnSubmittedWorkDone2, webgpu.WGPUQueueWorkDoneCallbackInfo2,
             webgpu.WGPUQueueWorkDoneCallback2, webgpu.WGPUQueueWorkDoneStatus__enumvalues, None, None,
             webgpu.wgpuDeviceGetQueue(self.runtime.args[0][0]))