import os
from extra.export_model import compile_net, jit_model, dtype_to_js_type
from extra.f16_decompress import u32_to_f16
from examples.stable_diffusion import StableDiffusion
from tinygrad.nn.state import get_state_dict, safe_save, safe_load_metadata, torch_load, load_state_dict
from tinygrad.tensor import Tensor
from tinygrad import Device, dtypes
from tinygrad.helpers import fetch
from typing import NamedTuple, Any, List
import requests
import argparse
import numpy as np
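
# convert_f32_to_f16 rewrites a float32 safetensors file so that every value stored before
# first_text_model_offset is downcast to float16, while the text-model weights that follow stay
# float32. The 8-byte header and JSON metadata are copied through unchanged; split_safetensor
# below compensates for the halved tensor sizes by rescaling the affected offsets.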
def convert_f32_to_f16(input_file, output_file):
  with open(input_file, 'rb') as f:
    metadata_length_bytes = f.read(8)
    metadata_length = int.from_bytes(metadata_length_bytes, byteorder='little', signed=False)
    metadata_json_bytes = f.read(metadata_length)
    float32_values = np.fromfile(f, dtype=np.float32)

  first_text_model_offset = 3772703308
  num_elements = int((first_text_model_offset)/4)
  front_float16_values = float32_values[:num_elements].astype(np.float16)
  rest_float32_values = float32_values[num_elements:]

  with open(output_file, 'wb') as f:
    f.write(metadata_length_bytes)
    f.write(metadata_json_bytes)
    front_float16_values.tofile(f)
    rest_float32_values.tofile(f)
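
# split_safetensor splits the converted checkpoint into roughly 512 MiB net_part*.safetensors
# chunks, plus a separate net_textmodel.safetensors holding the float32 text model, so the
# browser can fetch the weights as several smaller files. Offsets below text_model_offset are
# halved because those tensors were downcast from f32 to f16 above.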
def split_safetensor(fn):
  _, data_start, metadata = safe_load_metadata(fn)
  text_model_offset = 3772703308
  chunk_size = 536870912

  for k in metadata:
    # safetensor is in fp16, except for the text model
    if (metadata[k]["data_offsets"][0] < text_model_offset):
      metadata[k]["data_offsets"][0] = int(metadata[k]["data_offsets"][0]/2)
      metadata[k]["data_offsets"][1] = int(metadata[k]["data_offsets"][1]/2)

  last_offset = 0
  part_end_offsets = []

  for k in metadata:
    offset = metadata[k]['data_offsets'][0]
    if offset == text_model_offset:
      break
    part_offset = offset - last_offset
    if (part_offset >= chunk_size):
      part_end_offsets.append(data_start+offset)
      last_offset = offset

  text_model_start = int(text_model_offset/2)
  net_bytes = bytes(open(fn, 'rb').read())
  part_end_offsets.append(text_model_start+data_start)
  cur_pos = 0

  for i, end_pos in enumerate(part_end_offsets):
    with open(os.path.join(os.path.dirname(__file__), f'./net_part{i}.safetensors'), "wb+") as f:
      f.write(net_bytes[cur_pos:end_pos])
      cur_pos = end_pos

  with open(os.path.join(os.path.dirname(__file__), f'./net_textmodel.safetensors'), "wb+") as f:
    f.write(net_bytes[text_model_start+data_start:])

  return part_end_offsets
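
# fetch_dep downloads a JS dependency and rewrites the absolute Hugging Face URL of the BPE
# vocab file to a relative path, so the exported page can load it locally.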
def fetch_dep(file, url):
  with open(file, "w", encoding="utf-8") as f:
    f.write(requests.get(url).text.replace("https://huggingface.co/wpmed/tinygrad-sd-f16/raw/main/bpe_simple_vocab_16e6.mjs", "./bpe_simple_vocab_16e6.mjs"))

if __name__ == "__main__":
  fetch_dep(os.path.join(os.path.dirname(__file__), "clip_tokenizer.js"), "https://huggingface.co/wpmed/tinygrad-sd-f16/raw/main/clip_tokenizer.js")
  fetch_dep(os.path.join(os.path.dirname(__file__), "bpe_simple_vocab_16e6.mjs"), "https://huggingface.co/wpmed/tinygrad-sd-f16/raw/main/bpe_simple_vocab_16e6.mjs")

  parser = argparse.ArgumentParser(description='Run Stable Diffusion', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument('--remoteweights', action='store_true', help="Use the prebuilt split safetensors from Hugging Face instead of generating them locally")
  args = parser.parse_args()

  Device.DEFAULT = "WEBGPU"
  Tensor.no_grad = True
  model = StableDiffusion()

  # load in weights
  load_state_dict(model, torch_load(fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt'))['state_dict'], strict=False)
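
  # Each Step below is one jitted sub-network that gets compiled into its own WebGPU module:
  # the CLIP text encoder, the diffusion model forward pass ("diffusor"), the image decoder,
  # and an f16-to-f32 weight decompression kernel (u32_to_f16). `input` lists dummy tensors
  # with the shapes each step will be traced and exported with.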
  class Step(NamedTuple):
    name: str = ""
    input: List[Tensor] = []
    forward: Any = None

  sub_steps = [
    Step(name = "textModel", input = [Tensor.randn(1, 77)], forward = model.cond_stage_model.transformer.text_model),
    Step(name = "diffusor", input = [Tensor.randn(1, 77, 768), Tensor.randn(1, 77, 768), Tensor.randn(1,4,64,64), Tensor.rand(1), Tensor.randn(1), Tensor.randn(1), Tensor.randn(1)], forward = model),
    Step(name = "decoder", input = [Tensor.randn(1,4,64,64)], forward = model.decode),
    Step(name = "f16tof32", input = [Tensor.randn(2097120, dtype=dtypes.uint32)], forward = u32_to_f16)
  ]

  prg = ""
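
  # fixup_code adapts each generated WGSL kernel for standalone use: the kernel's unique name
  # becomes the entry point "main", the INFINITY uniform at binding 0 is replaced with an inf()
  # helper function, and the remaining bindings are shifted down by one to fill the gap.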
  def fixup_code(code, key):
    code = code.replace(key, 'main')\
      .replace("var<uniform> INFINITY : f32;\n", "fn inf(a: f32) -> f32 { return a/0.0; }\n")\
      .replace("@group(0) @binding(0)", "")\
      .replace("INFINITY", "inf(1.0)")

    for i in range(1,9): code = code.replace(f"binding({i})", f"binding({i-1})")
    return code
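
  # compile_step jits one Step, compiles its kernels to WGSL, and emits a JS module exposing a
  # setup(device, safetensor) function. setup allocates the weight and intermediate buffers,
  # builds the compute pipelines, and returns an async function that uploads the inputs,
  # dispatches every kernel, and reads output0 back to the CPU.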
  def compile_step(model, step: Step):
    run, special_names = jit_model(step, *step.input)
    functions, statements, bufs, _ = compile_net(run, special_names)
    state = get_state_dict(model)
    weights = {id(x.lazydata.base.realized): name for name, x in state.items()}
    kernel_code = '\n\n'.join([f"const {key} = `{fixup_code(code, key)}`;" for key, code in functions.items()])
    kernel_names = ', '.join([name for (name, _, _, _) in statements])
    input_names = [name for _,name in special_names.items() if "input" in name]
    output_names = [name for _,name in special_names.items() if "output" in name]
    input_buf_types = [dtype_to_js_type(bufs[inp_name][1]) for inp_name in input_names]
    output_buf_types = [dtype_to_js_type(bufs[out_name][1]) for out_name in output_names]
    kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, pipelines[{i}], [{', '.join(args)}], {global_size});" for i, (_name, args, global_size, _local_size) in enumerate(statements)])
    exported_bufs = '\n '.join([f"const {name} = " + (f"createEmptyBuf(device, {size});" if _key not in weights else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weights[_key]}'], '{weights[_key]}'))") + ";" for name,(size,dtype,_key) in bufs.items()])
    gpu_write_bufs = '\n '.join([f"const gpuWriteBuffer{i} = device.createBuffer({{size:input{i}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});" for i,(_,value) in enumerate(special_names.items()) if "output" not in value])
    input_writer = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new {input_buf_types[i]}(gpuWriteBuffer{i}.getMappedRange()).set(" + f'data{i});' + f"\n gpuWriteBuffer{i}.unmap();\ncommandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, input{i}, 0, gpuWriteBuffer{i}.size);" for i,_ in enumerate(input_names)])

    return f"""\n var {step.name} = function() {{
    {kernel_code}
    return {{
      "setup": async (device, safetensor) => {{
        const metadata = safetensor ? getTensorMetadata(safetensor[0]) : null;
        {exported_bufs}
        {gpu_write_bufs}
        const gpuReadBuffer = device.createBuffer({{ size: output0.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});
        const kernels = [{kernel_names}];
        const pipelines = await Promise.all(kernels.map(name => device.createComputePipelineAsync({{layout: "auto", compute: {{ module: device.createShaderModule({{ code: name }}), entryPoint: "main" }}}})));
        return async ({",".join([f'data{i}' for i,(k,v) in enumerate(special_names.items()) if v != "output0"])}) => {{
          const commandEncoder = device.createCommandEncoder();
          {input_writer}
          {kernel_calls}
          commandEncoder.copyBufferToBuffer(output0, 0, gpuReadBuffer, 0, output0.size);
          const gpuCommands = commandEncoder.finish();
          device.queue.submit([gpuCommands]);
          await gpuReadBuffer.mapAsync(GPUMapMode.READ);
          const resultBuffer = new {output_buf_types[0]}(gpuReadBuffer.size/{bufs[output_names[0]][1].itemsize});
          resultBuffer.set(new {output_buf_types[0]}(gpuReadBuffer.getMappedRange()));
          gpuReadBuffer.unmap();
          return resultBuffer;
        }}
      }}
    }}
  }}
  """

  for step in sub_steps:
    print(f'Executing step={step.name}')
    prg += compile_step(model, step)

    if step.name == "diffusor":
      if args.remoteweights:
        base_url = "https://huggingface.co/wpmed/stable-diffusion-f16-new/resolve/main"
      else:
        # generate the weights locally: save fp32, downcast most of it to f16, then split into parts
        state = get_state_dict(model)
        safe_save(state, os.path.join(os.path.dirname(__file__), "net.safetensors"))
        convert_f32_to_f16(os.path.join(os.path.dirname(__file__), "./net.safetensors"), os.path.join(os.path.dirname(__file__), "./net_conv.safetensors"))
        split_safetensor(os.path.join(os.path.dirname(__file__), "./net_conv.safetensors"))
        os.remove(os.path.join(os.path.dirname(__file__), "net.safetensors"))
        os.remove(os.path.join(os.path.dirname(__file__), "net_conv.safetensors"))
        base_url = "."
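
  # prekernel holds the runtime JS helpers shared by every exported step: parsing safetensors
  # metadata, locating a tensor inside the split weight files, creating empty/weight GPU buffers,
  # and recording a compute pass. Note that getTensorBuffer's partStartOffsets values appear to be
  # the hard-coded end offsets of the split net_part files produced for this checkpoint.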
  prekernel = f"""
    window.MODEL_BASE_URL = "{base_url}";
    const getTensorMetadata = (safetensorBuffer) => {{
      const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true));
      const metadata = JSON.parse(new TextDecoder("utf8").decode(safetensorBuffer.subarray(8, 8 + metadataLength)));
      return Object.fromEntries(Object.entries(metadata).filter(([k, v]) => k !== "__metadata__").map(([k, v]) => [k, {{...v, data_offsets: v.data_offsets.map(x => 8 + metadataLength + x)}}]));
    }};

    const getTensorBuffer = (safetensorParts, tensorMetadata, key) => {{
      let selectedPart = 0;
      let counter = 0;
      let partStartOffsets = [1131408336, 2227518416, 3308987856, 4265298864];
      let correctedOffsets = tensorMetadata.data_offsets;
      let prev_offset = 0;

      for (let start of partStartOffsets) {{
        prev_offset = (counter == 0) ? 0 : partStartOffsets[counter-1];
        if (tensorMetadata.data_offsets[0] < start) {{
          selectedPart = counter;
          correctedOffsets = [correctedOffsets[0]-prev_offset, correctedOffsets[1]-prev_offset];
          break;
        }}
        counter++;
      }}

      return safetensorParts[selectedPart].subarray(...correctedOffsets);
    }}

    const getWeight = (safetensors, key) => {{
      let uint8Data = getTensorBuffer(safetensors, getTensorMetadata(safetensors[0])[key], key);
      return new Float32Array(uint8Data.buffer, uint8Data.byteOffset, uint8Data.byteLength / Float32Array.BYTES_PER_ELEMENT);
    }}

    const createEmptyBuf = (device, size) => {{
      return device.createBuffer({{size, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST }});
    }};

    const createWeightBuf = (device, size, data) => {{
      const buf = device.createBuffer({{ mappedAtCreation: true, size, usage: GPUBufferUsage.STORAGE }});
      new Uint8Array(buf.getMappedRange()).set(data);
      buf.unmap();
      return buf;
    }};

    const addComputePass = (device, commandEncoder, pipeline, bufs, workgroup) => {{
      const bindGroup = device.createBindGroup({{layout: pipeline.getBindGroupLayout(0), entries: bufs.map((buffer, index) => ({{ binding: index, resource: {{ buffer }} }}))}});
      const passEncoder = commandEncoder.beginComputePass();
      passEncoder.setPipeline(pipeline);
      passEncoder.setBindGroup(0, bindGroup);
      passEncoder.dispatchWorkgroups(...workgroup);
      passEncoder.end();
    }};"""

  with open(os.path.join(os.path.dirname(__file__), "net.js"), "w") as text_file:
    text_file.write(prekernel + prg)