<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>tinygrad has WebGPU</title>
<style>
  /* General Reset */
  * {
    margin: 0;
    padding: 0;
    box-sizing: border-box;
  }

  /* Page layout: a single centered column */
  body {
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
    background: #f4f7fb;
    color: #333;
    display: flex;
    justify-content: center;
    align-items: center;
    min-height: 100vh;
    padding: 20px;
    flex-direction: column;
    text-align: center;
  }

  h1 {
    font-size: 2.2rem;
    margin-bottom: 20px;
    color: #4A90E2;
  }

  /* Hidden by default; the page script switches it to display: block */
  #wgpuError {
    color: red;
    font-size: 1.2rem;
    margin-top: 20px;
    display: none;
  }

  #sdTitle {
    font-size: 1.5rem;
    margin-bottom: 30px;
  }

  /* Card holding the prompt form and the progress rows */
  #mybox {
    background: #ffffff;
    padding: 15px;
    border-radius: 10px;
    box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
    width: 120%;
    max-width: 550px;
    margin-bottom: 10px;
  }

  input[type="text"] {
    width: 100%;
    padding: 12px;
    margin-bottom: 20px;
    font-size: 1rem;
    border: 1px solid #ccc;
    border-radius: 8px;
    outline: none;
    transition: all 0.3s ease;
  }

  input[type="text"]:focus {
    border-color: #4A90E2;
  }

  label {
    display: flex;
    justify-content: space-between;
    font-size: 1rem;
    margin-bottom: 15px;
    align-items: center;
  }

  input[type="range"] {
    width: 100%;
    margin-left: 10px;
    -webkit-appearance: none;
    appearance: none;
    height: 8px;
    border-radius: 4px;
    background: #ddd;
    outline: none;
    transition: background 0.3s ease;
  }

  input[type="range"]:focus {
    background: #4A90E2;
  }

  #stepRange,
  #guidanceRange {
    width: 80%;
  }

  span {
    font-size: 1.1rem;
    font-weight: 600;
    color: #333;
  }

  input[type="button"] {
    padding: 12px 25px;
    background-color: #4A90E2;
    color: #fff;
    font-size: 1.2rem;
    border: none;
    border-radius: 8px;
    cursor: pointer;
    transition: background-color 0.3s ease;
    width: 100%;
    margin-bottom: 20px;
  }

  input[type="button"]:hover {
    background-color: #357ABD;
  }

  /* Declared after :hover (same specificity) so a disabled button stays grey while hovered */
  input[type="button"]:disabled {
    background-color: #ccc;
    cursor: not-allowed;
  }

  #divModelDl, #divStepProgress {
    display: flex;
    justify-content: space-between;
    align-items: center;
    margin-bottom: 20px;
  }

  #modelDlProgressBar,
  #progressBar {
    width: 80%;
    height: 12px;
    border-radius: 6px;
    background-color: #e0e0e0;
  }

  #modelDlProgressBar::-webkit-progress-bar,
  #progressBar::-webkit-progress-bar {
    border-radius: 6px;
  }

  #modelDlProgressValue, #progressFraction {
    font-size: 1rem;
    font-weight: 600;
    color: #333;
  }

  canvas {
    max-width: 100%;
    max-height: 450px;
    margin-top: 10px;
    border-radius: 8px;
    border: 1px solid #ddd;
  }
</style>
<!-- Module script: publishes a ClipTokenizer instance as window.clipTokenizer for the classic script in <body>. -->
<script type="module">
  import ClipTokenizer from './clip_tokenizer.js';
  window.clipTokenizer = new ClipTokenizer();
</script>
<!-- Classic script loaded before the inline page script, which calls names not defined in this file. -->
<script src="./net.js"></script>
</head>
<body>
<!-- Error heading: hidden by default, shown by the page script when navigator.gpu is missing. -->
<h1 id="wgpuError" style="display: none;">WebGPU is not supported in this browser</h1>
<h1 id="sdTitle">StableDiffusion powered by <a href="https://github.com/tinygrad/tinygrad" target="_blank" rel="noopener noreferrer" style="color: #4A90E2;">tinygrad</a></h1>
<a href="https://github.com/tinygrad/tinygrad" target="_blank" rel="noopener noreferrer" style="position: absolute; top: 20px; right: 20px;">
  <img src="https://upload.wikimedia.org/wikipedia/commons/9/91/Octicons-mark-github.svg"
       alt="GitHub Logo"
       style="width: 32px; height: 32px;">
</a>
<div id="mybox">
  <!-- Prompt and sampling controls; every id below is looked up by the page script. -->
  <form id="promptForm">
    <!-- aria-label gives the field an accessible name; the placeholder alone is not a label. -->
    <input id="promptText" type="text" placeholder="Enter your prompt here" value="a human standing on the surface of mars" aria-label="Prompt">
    <label>
      Steps: <span id="stepValue">9</span>
      <input id="stepRange" type="range" min="5" max="20" value="9" step="1">
    </label>
    <label>
      Guidance: <span id="guidanceValue">8.0</span>
      <input id="guidanceRange" type="range" min="3" max="15" value="8.0" step="0.1">
    </label>
    <!-- Starts disabled; the page script enables it once the model has been set up. -->
    <input id="btnRunNet" type="button" value="Run" disabled>
  </form>
  <!-- Model download / decompression / compilation status row. -->
  <div id="divModelDl" style="display: flex; align-items: center; width: 100%; gap: 10px;">
    <span id="modelDlTitle">Downloading model</span>
    <progress id="modelDlProgressBar" value="0" max="100" style="flex-grow: 1;"></progress>
    <span id="modelDlProgressValue"></span>
  </div>
  <!-- Per-step progress row; hidden until the page script switches it to display: flex. -->
  <div id="divStepProgress" style="display: none; align-items: center; width: 100%; gap: 10px;">
    <progress id="progressBar" value="0" max="100" style="flex-grow: 1;"></progress>
    <span id="progressFraction"></span>
  </div>
  <!-- Step timing readout; hidden until the page script switches it to display: block. -->
  <div id="divStepTime" style="display: none; align-items: center; width: 100%; gap: 10px;">
    <span id="stepTimeValue">0 ms</span>
  </div>
</div>
<!-- Output surface; the page script draws a 512x512 RGBA image here. -->
<canvas id="canvas" width="512" height="512"></canvas>
<script>
// Assigned in loadNet to the result of f16tof32().setup(...); awaited once per chunk in
// getAndDecompressF16Safetensors. Stays null until loadNet has run.
let f16decomp = null;
/**
 * Opens version 1 of the 'tinydb' IndexedDB database, creating the 'tensors'
 * object store (keyPath 'id') during the upgrade step when it is missing.
 *
 * Best-effort: the returned promise never rejects.
 *
 * @returns {Promise<IDBDatabase|null>} the opened database, or null when it
 *   cannot be opened (request error, or open() itself throwing).
 */
function initDb() {
  return new Promise((resolve) => {
    let request;
    try {
      // open() can throw synchronously (or indexedDB may be undefined); treat
      // that like any other open failure instead of rejecting the promise.
      request = indexedDB.open('tinydb', 1);
    } catch (error) {
      console.error('Database error:', error);
      resolve(null);
      return;
    }
    request.onerror = (event) => {
      console.error('Database error:', event.target.error);
      resolve(null);
    };
    request.onsuccess = (event) => {
      console.log("Db initialized.");
      resolve(event.target.result);
    };
    request.onupgradeneeded = (event) => {
      const db = event.target.result;
      if (!db.objectStoreNames.contains('tensors')) {
        db.createObjectStore('tensors', { keyPath: 'id' });
      }
    };
  });
}
/**
 * Stores `tensor` under `id` in the 'tensors' object store unless a record
 * with that id already exists.
 *
 * Best-effort: the returned promise never rejects.
 *
 * @param {IDBDatabase|null} db      database handle; null makes this a no-op
 * @param {string}           id      record key
 * @param {*}                tensor  value stored as the record's `content`
 * @returns {Promise<null|undefined>} undefined after a successful put, null
 *   when nothing was written (already present, no db, or a failure).
 */
function saveTensorToDb(db, id, tensor) {
  return readTensorFromDb(db, id).then((existing) => {
    if (existing || db == null) {
      return null;
    }
    // Return the inner promise so callers can await the write and so that a
    // throwing db.transaction() is handled by the .catch below.
    return new Promise((resolve) => {
      const transaction = db.transaction(['tensors'], 'readwrite');
      const store = transaction.objectStore('tensors');
      const request = store.put({ id: id, content: tensor });
      transaction.onabort = (event) => {
        console.log("Transaction error while saving tensor: " + event.target.error);
        resolve(null);
      };
      request.onsuccess = () => {
        console.log('Tensor saved successfully.');
        resolve();
      };
      request.onerror = (event) => {
        console.error('Tensor save failed:', event.target.error);
        resolve(null);
      };
    });
  }).catch(() => null);
}
/**
 * Looks up the record stored under `id` in the 'tensors' object store.
 *
 * Best-effort: the returned promise never rejects.
 *
 * @param {IDBDatabase|null} db  database handle; null resolves to null
 * @param {string}           id  record key
 * @returns {Promise<object|null>} the stored record, or null when it is
 *   absent, there is no db, or the lookup fails.
 */
function readTensorFromDb(db, id) {
  return new Promise((resolve) => {
    if (db == null) {
      resolve(null);
      return;
    }
    try {
      const transaction = db.transaction(['tensors'], 'readonly');
      const store = transaction.objectStore('tensors');
      const request = store.get(id);
      transaction.onabort = (event) => {
        console.log("Transaction error while reading tensor: " + event.target.error);
        resolve(null);
      };
      request.onsuccess = (event) => {
        resolve(event.target.result || null);
      };
      request.onerror = (event) => {
        console.error('Tensor retrieve failed: ', event.target.error);
        resolve(null);
      };
    } catch (error) {
      // transaction()/objectStore()/get() can throw synchronously; report it
      // the same way as an asynchronous request error.
      console.error('Tensor retrieve failed: ', error);
      resolve(null);
    }
  });
}
window.addEventListener('load', async function() {
if (!navigator.gpu) {
document.getElementById("wgpuError").style.display = "block";
document.getElementById("sdTitle").style.display = "none";
return;
}
let db = await initDb();
const ctx = document.getElementById("canvas").getContext("2d", { willReadFrequently: true });
let labels, nets, safetensorParts;
/**
 * Requests a WebGPU device with the "shader-f16" feature and with
 * maxBufferSize / maxStorageBufferBindingSize raised to 1 GiB (1073741824).
 *
 * @returns {Promise<GPUDevice>}
 * @throws {Error} when no adapter is available; requestDevice() itself
 *   rejects when the adapter cannot satisfy the feature or limits.
 */
const getDevice = async () => {
  // powerPreference is an adapter-request option; it has no effect when
  // placed in the requestDevice() descriptor.
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
  if (!adapter) {
    // requestAdapter() resolves to null when no suitable adapter exists.
    throw new Error("No suitable WebGPU adapter found");
  }
  const maxBufferSizeInSDModel = 1073741824;
  const requiredLimits = {
    maxStorageBufferBindingSize: maxBufferSizeInSDModel,
    maxBufferSize: maxBufferSizeInSDModel,
  };
  return await adapter.requestDevice({
    requiredLimits,
    requiredFeatures: ["shader-f16"],
  });
};
/**
 * Awaits `func()`, logs how long it took as "<ms> ms <label>" (one decimal
 * place), and passes the result through.
 *
 * @param {Function} func   callable whose (possibly async) result is awaited
 * @param {string}   label  text appended to the log line
 * @returns {Promise<*>} whatever `func()` resolved to
 */
const timer = async (func, label = "") => {
  const startedAt = performance.now();
  const result = await func();
  const elapsedMs = performance.now() - startedAt;
  console.log(`${elapsedMs.toFixed(1)} ms ${label}`);
  return result;
};
/**
 * Downloads `part` and reports progress for every chunk received.
 *
 * @param {string}   part              URL to fetch
 * @param {Function} progressCallback  called as (part, chunkByteLength, total)
 *   per chunk; `total` is the Content-Length header value, or 0 when the
 *   header is missing or unparsable
 * @returns {Promise<ArrayBuffer>} the complete response body
 * @throws {Error} when the server answers with a non-2xx status
 */
const getProgressDlForPart = async (part, progressCallback) => {
  const response = await fetch(part);
  if (!response.ok) {
    // Without this check an HTTP error page would be returned as model data.
    throw new Error(`Failed to download ${part}: ${response.status} ${response.statusText}`);
  }
  const contentLength = response.headers.get('content-length');
  const parsedLength = parseInt(contentLength, 10);
  // A missing header parses to NaN, which would poison any running total.
  const total = Number.isFinite(parsedLength) ? parsedLength : 0;
  const res = new Response(new ReadableStream({
    async start(controller) {
      const reader = response.body.getReader();
      for (;;) {
        const { done, value } = await reader.read();
        if (done) break;
        progressCallback(part, value.byteLength, total);
        controller.enqueue(value);
      }
      controller.close();
    },
  }));
  return res.arrayBuffer();
};
/**
 * Loads the five model files (IndexedDB cache first, network otherwise),
 * expands the four "net_part*" files through `f16decomp`, and lays the result
 * out across four output buffers.
 *
 * @param {GPUDevice} device    not used by this function; kept for callers
 * @param {Function}  progress  called as (done, total) during both the
 *   download and the chunked decompression
 * @returns {Promise<Uint8Array[]>} four buffers sized by the hard-coded
 *   `partOffsets` table
 */
const getAndDecompressF16Safetensors = async (device, progress) => {
  let totalLoaded = 0;
  let totalSize = 0;
  const partSize = {};
  // Keys served from the cache; they need no write-back.
  const cachedKeys = new Set();

  const progressCallback = (part, loaded, total) => {
    totalLoaded += loaded;
    if (!partSize[part]) {
      // Count each part's total only once, on its first chunk.
      totalSize += total;
      partSize[part] = true;
    }
    progress(totalLoaded, totalSize);
  };

  const getPart = async (key) => {
    const part = await readTensorFromDb(db, key);
    if (part) {
      console.log(`Cache hit: ${key}`);
      cachedKeys.add(key);
      return part.content;
    }
    console.log(`Cache miss: ${key}`);
    return getProgressDlForPart(`${window.MODEL_BASE_URL}/${key}.safetensors`, progressCallback);
  };

  // Cached parts are already Uint8Array (that is what gets stored below);
  // wrapping one in new Uint8Array(...) would copy it.
  const asBytes = (data) => (data instanceof Uint8Array ? data : new Uint8Array(data));

  const netKeys = ["net_part0", "net_part1", "net_part2", "net_part3", "net_textmodel"];
  const f16PartCount = 4;
  const buffers = await Promise.all(netKeys.map((key) => getPart(key)));

  // Combine everything except for text model, since that's already f32
  const totalLength = buffers
    .slice(0, f16PartCount)
    .reduce((acc, buffer) => acc + buffer.byteLength, 0);
  let combinedBuffer = new Uint8Array(totalLength);
  let offset = 0;
  buffers.forEach((buffer, index) => {
    const bytes = asBytes(buffer);
    if (!cachedKeys.has(netKeys[index])) {
      // Fire-and-forget: caching must not delay or fail the load.
      saveTensorToDb(db, netKeys[index], bytes);
    }
    if (index < f16PartCount) {
      combinedBuffer.set(bytes, offset);
      offset += bytes.byteLength;
    }
  });
  const textModelU8 = asBytes(buffers[4]);

  document.getElementById("modelDlTitle").innerHTML = "Decompressing model";

  // NOTE(review): model-specific constant; presumably the text model's byte
  // offset within the decompressed data section. Confirm against the exporter.
  const textModelOffset = 3772703308;
  // The combined data starts with an 8-byte little-endian header length
  // followed by a JSON header of that length.
  const metadataLength = Number(new DataView(combinedBuffer.buffer).getBigUint64(0, true));
  // Not used further below; parsing still fails fast on a corrupt header.
  const metadata = JSON.parse(new TextDecoder("utf8").decode(combinedBuffer.subarray(8, 8 + metadataLength)));
  const allToDecomp = combinedBuffer.byteLength - (8 + metadataLength);
  const decodeChunkSize = 8388480;
  const numChunks = Math.ceil(allToDecomp / decodeChunkSize);
  console.log(allToDecomp + " bytes to decompress");
  console.log("Will be decompressed in " + numChunks + " chunks");

  // Hard-coded [start, end) byte ranges of the decompressed output that each
  // returned part covers.
  const partOffsets = [{start: 0, end: 1131408336}, {start: 1131408336, end: 2227518416}, {start: 2227518416, end: 3308987856}, {start: 3308987856, end: 4265298864}];
  const parts = partOffsets.map((range) => new Uint8Array(range.end - range.start));

  // Part 0 begins with the same length prefix and JSON header as the input.
  parts[0].set(new Uint8Array(new BigUint64Array([BigInt(metadataLength)]).buffer), 0);
  parts[0].set(combinedBuffer.subarray(8, 8 + metadataLength), 8);
  // The text model bytes are copied verbatim into the last part.
  parts[3].set(textModelU8, textModelOffset + 8 + metadataLength - partOffsets[3].start);

  const start = Date.now();
  let cursor = 0;
  for (let i = 0; i < numChunks; i++) {
    progress(i, numChunks);
    const chunkStartF16 = 8 + metadataLength + (decodeChunkSize * i);
    const chunkEndF16 = chunkStartF16 + decodeChunkSize;
    const chunk = combinedBuffer.subarray(chunkStartF16, chunkEndF16);
    const uint32Chunk = new Uint32Array(chunk.buffer, chunk.byteOffset, chunk.byteLength / 4);
    let result = await f16decomp(uint32Chunk);
    let resultUint8 = new Uint8Array(result.buffer);
    // Each input chunk lands at twice its data-relative offset in the output.
    const chunkStartF32 = 8 + metadataLength + (decodeChunkSize * i * 2);
    const chunkEndF32 = chunkStartF32 + resultUint8.byteLength;
    const offsetInPart = chunkStartF32 - partOffsets[cursor].start;
    if (chunkEndF32 < partOffsets[cursor].end || cursor === parts.length - 1) {
      parts[cursor].set(resultUint8, offsetInPart);
    } else {
      // The chunk straddles a part boundary: fill the rest of the current
      // part, then spill the remainder into the start of the next one.
      const spaceLeftInCurrentPart = partOffsets[cursor].end - chunkStartF32;
      parts[cursor].set(resultUint8.subarray(0, spaceLeftInCurrentPart), offsetInPart);
      cursor++;
      if (cursor < parts.length) {
        parts[cursor].set(resultUint8.subarray(spaceLeftInCurrentPart), 0);
      }
    }
    // Drop references to the large intermediates before the next iteration.
    resultUint8 = null;
    result = null;
  }
  combinedBuffer = null;

  const end = Date.now();
  console.log("Decoding took: " + ((end - start) / 1000) + " s");
  console.log("Avarage " + ((end - start) / numChunks) + " ms per chunk");
  return parts;
};
/**
 * Acquires the GPU device, sets up the f16 decompressor, loads the weights,
 * sets up the three networks into `nets`, and enables the Run button.
 * Progress and status text are reflected in the model-download row.
 *
 * @returns {Promise<void>} rejects with the original error after marking the
 *   status row as failed.
 */
const loadNet = async () => {
  const modelDlTitle = document.getElementById("modelDlTitle");
  const modelDlProgressBar = document.getElementById("modelDlProgressBar");
  const modelDlProgressValue = document.getElementById("modelDlProgressValue");

  const progress = (loaded, total) => {
    const percent = (loaded / total) * 100;
    // total can be 0/NaN when a size is unknown; assigning a non-finite
    // number to <progress>.value throws, so skip such updates.
    if (!Number.isFinite(percent)) return;
    modelDlProgressBar.value = percent;
    modelDlProgressValue.innerHTML = Math.trunc(percent) + "%";
  };

  try {
    const device = await getDevice();
    // safetensorParts is still unassigned here; it is passed through unchanged.
    f16decomp = await f16tof32().setup(device, safetensorParts);
    safetensorParts = await getAndDecompressF16Safetensors(device, progress);

    modelDlTitle.innerHTML = "Compiling model";
    nets = await timer(async () => {
      const [textModelNet, diffusorNet, decoderNet] = await Promise.all([
        textModel().setup(device, safetensorParts),
        diffusor().setup(device, safetensorParts),
        decoder().setup(device, safetensorParts)
      ]);
      return { textModel: textModelNet, diffusor: diffusorNet, decoder: decoderNet };
    }, "(compilation)");

    progress(1, 1);
    modelDlTitle.innerHTML = "Model ready";
    // Leave the 100% bar visible briefly, then swap in the step-progress row.
    setTimeout(() => {
      modelDlProgressBar.style.display = "none";
      modelDlProgressValue.style.display = "none";
      document.getElementById("divStepProgress").style.display = "flex";
    }, 1000);
    document.getElementById("btnRunNet").disabled = false;
  } catch (error) {
    // Surface the failure instead of leaving the status stuck mid-load.
    console.error("Model loading failed:", error);
    modelDlTitle.innerHTML = "Model loading failed";
    throw error;
  }
};
/**
 * Runs text encoding, the denoising loop and the final decode, updating the
 * step progress bar and per-step timing readout along the way.
 *
 * @param {string}        prompt    text to condition on
 * @param {number|string} steps     number of denoising steps (must be > 0)
 * @param {number|string} guidance  guidance value handed to the diffusor
 * @param {?Function}     showStep  optional callback given the decoded latent
 *   after every step; pass null to skip intermediate decodes
 * @returns {Promise<*>} the decoder output for the final latent
 * @throws {RangeError} (as a rejection) when `steps` is not a positive number
 */
async function runStableDiffusion(prompt, steps, guidance, showStep) {
  // A non-positive step size would make the timestep loop below never end.
  if (!(Number(steps) > 0)) {
    throw new RangeError("steps must be a positive number");
  }

  const context = await timer(() => nets["textModel"](clipTokenizer.encodeForCLIP(prompt)));
  const unconditional_context = await timer(() => nets["textModel"](clipTokenizer.encodeForCLIP("")));

  const timesteps = [];
  for (let i = 1; i < 1000; i += (1000 / steps)) {
    timesteps.push(i);
  }
  console.log("Timesteps: " + timesteps);

  const alphasCumprod = getWeight(safetensorParts, "alphas_cumprod");
  const alphas = timesteps.map((t) => alphasCumprod[Math.floor(t)]);
  // alphas shifted right by one, with 1.0 in front.
  const alphas_prev = [1.0, ...alphas.slice(0, -1)];

  const inpSize = 4 * 64 * 64;
  let latent = new Float32Array(inpSize);
  for (let i = 0; i < inpSize; i++) {
    // Box-Muller transform. 1 - Math.random() lies in (0, 1], so the log
    // argument can never be 0.
    latent[i] = Math.sqrt(-2.0 * Math.log(1 - Math.random())) * Math.cos(2.0 * Math.PI * Math.random());
  }

  for (let i = timesteps.length - 1; i >= 0; i--) {
    const timestep = new Float32Array([timesteps[i]]);
    const start = performance.now();
    const x_prev = await nets["diffusor"](unconditional_context, context, latent, timestep, new Float32Array([alphas[i]]), new Float32Array([alphas_prev[i]]), new Float32Array([guidance]));
    document.getElementById("divStepTime").style.display = "block";
    document.getElementById("stepTimeValue").innerText = `${(performance.now() - start).toFixed(1)} ms / step`;
    latent = x_prev;
    if (showStep != null) {
      showStep(await nets["decoder"](latent));
    }
    document.getElementById("progressBar").value = ((steps - i) / steps) * 100;
    document.getElementById("progressFraction").innerHTML = (steps - i) + "/" + steps;
  }

  return timer(() => nets["decoder"](latent));
}
/**
 * Draws a 512x512 image onto the canvas context `ctx`.
 *
 * @param {ArrayLike<number>} image  values read three at a time (one triple
 *   per pixel); each pixel is written out with alpha 255
 */
function renderImage(image) {
  const width = 512;
  const height = 512;
  const pixelCount = width * height;
  // Fill the RGBA buffer directly rather than growing a plain array first.
  const rgba = new Uint8ClampedArray(pixelCount * 4);
  for (let pixel = 0, src = 0, dst = 0; pixel < pixelCount; pixel++, src += 3, dst += 4) {
    rgba[dst] = image[src];
    rgba[dst + 1] = image[src + 1];
    rgba[dst + 2] = image[src + 2];
    rgba[dst + 3] = 255;
  }
  ctx.putImageData(new ImageData(rgba, width, height), 0, 0);
}
/**
 * Run-button handler: locks the button, clears the canvas, runs the pipeline
 * with the current form values, draws the result, then restores the status
 * title and re-enables the button whether the run succeeded or failed.
 */
const handleRunNetAndRenderResult = () => {
  const byId = (elementId) => document.getElementById(elementId);

  byId("btnRunNet").disabled = true;

  const canvas = byId("canvas");
  ctx.clearRect(0, 0, canvas.width, canvas.height);

  const prevTitleValue = byId("modelDlTitle").innerHTML;
  byId("modelDlTitle").innerHTML = "Running model";

  const restoreUi = () => {
    byId("modelDlTitle").innerHTML = prevTitleValue;
    byId("btnRunNet").disabled = false;
  };

  const promptValue = byId("promptText").value;
  const stepsValue = byId("stepRange").value;
  const guidanceValue = byId("guidanceRange").value;

  // No per-step callback is supplied, so only the final image is drawn.
  const pendingImage = runStableDiffusion(promptValue, stepsValue, guidanceValue, null);
  pendingImage
    .then((image) => {
      renderImage(image);
    })
    .finally(restoreUi);
};
// Clicking Run starts a generation.
document.getElementById("btnRunNet").addEventListener("click", handleRunNetAndRenderResult, false);

// Submitting the form does the same, but only while the Run button is enabled.
document.getElementById("promptForm").addEventListener("submit", (event) => {
  event.preventDefault();
  const runIsDisabled = document.getElementById("btnRunNet").disabled;
  if (!runIsDisabled) {
    handleRunNetAndRenderResult();
  }
});

// Keep the numeric readout next to each slider in sync with the slider value.
const mirrorSliderValue = (sliderId, readoutId) => {
  const slider = document.getElementById(sliderId);
  const readout = document.getElementById(readoutId);
  slider.addEventListener('input', () => {
    readout.textContent = slider.value;
  });
};
mirrorSliderValue('stepRange', 'stepValue');
mirrorSliderValue('guidanceRange', 'guidanceValue');

// Start loading the model as soon as the handlers are wired up.
loadNet();
});
</script>
</body>
</html>