You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
632 lines
22 KiB
632 lines
22 KiB
<!DOCTYPE html>
|
|
<html lang="en">
|
|
|
|
<head>
|
|
<meta charset="UTF-8">
|
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
|
<title>tinygrad has WebGPU</title>
|
|
<style>
|
|
/* General reset: remove default spacing and size every box by its border. */
* {
  margin: 0;
  padding: 0;
  box-sizing: border-box;
}

/* Page shell: one centered column filling at least the viewport height. */
body {
  font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
  background: #f4f7fb;
  color: #333;
  display: flex;
  flex-direction: column;
  justify-content: center;
  align-items: center;
  min-height: 100vh;
  padding: 20px;
  text-align: center;
}

h1 {
  font-size: 2.2rem;
  margin-bottom: 20px;
  color: #4A90E2;
}

/* Hidden by default; the page script reveals it when WebGPU is unavailable. */
#wgpuError {
  color: red;
  font-size: 1.2rem;
  margin-top: 20px;
  display: none;
}

#sdTitle {
  font-size: 1.5rem;
  margin-bottom: 30px;
}

/* Card holding the prompt form and the progress rows. */
#mybox {
  background: #ffffff;
  padding: 15px;
  border-radius: 10px;
  box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
  width: 120%;
  max-width: 550px;
  margin-bottom: 10px;
}

input[type="text"] {
  width: 100%;
  padding: 12px;
  margin-bottom: 20px;
  font-size: 1rem;
  border: 1px solid #ccc;
  border-radius: 8px;
  outline: none;
  transition: all 0.3s ease;
}

/* The border colour change stands in for the removed focus outline. */
input[type="text"]:focus {
  border-color: #4A90E2;
}

/* Each label is a row: caption and current value on the left, slider on the right. */
label {
  display: flex;
  justify-content: space-between;
  align-items: center;
  font-size: 1rem;
  margin-bottom: 15px;
}

input[type="range"] {
  width: 100%;
  margin-left: 10px;
  -webkit-appearance: none;
  appearance: none;
  height: 8px;
  border-radius: 4px;
  background: #ddd;
  outline: none;
  transition: background 0.3s ease;
}

/* The track colour change stands in for the removed focus outline. */
input[type="range"]:focus {
  background: #4A90E2;
}

#stepRange,
#guidanceRange {
  width: 80%;
}

span {
  font-size: 1.1rem;
  font-weight: 600;
  color: #333;
}

input[type="button"] {
  padding: 12px 25px;
  background-color: #4A90E2;
  color: #fff;
  font-size: 1.2rem;
  border: none;
  border-radius: 8px;
  cursor: pointer;
  transition: background-color 0.3s ease;
  width: 100%;
  margin-bottom: 20px;
}

input[type="button"]:disabled {
  background-color: #ccc;
  cursor: not-allowed;
}

/*
 * Restrict the hover colour to enabled buttons. An unqualified :hover rule
 * placed after :disabled has the same specificity and would repaint a
 * disabled button blue while the pointer is over it.
 */
input[type="button"]:hover:not(:disabled) {
  background-color: #357ABD;
}

#divModelDl,
#divStepProgress {
  display: flex;
  justify-content: space-between;
  align-items: center;
  margin-bottom: 20px;
}

#modelDlProgressBar,
#progressBar {
  width: 80%;
  height: 12px;
  border-radius: 6px;
  background-color: #e0e0e0;
}

#modelDlProgressBar::-webkit-progress-bar,
#progressBar::-webkit-progress-bar {
  border-radius: 6px;
}

#modelDlProgressValue,
#progressFraction {
  font-size: 1rem;
  font-weight: 600;
  color: #333;
}

/* Output surface; scaled down to fit narrow or short viewports. */
canvas {
  max-width: 100%;
  max-height: 450px;
  margin-top: 10px;
  border-radius: 8px;
  border: 1px solid #ddd;
}
|
|
</style>
|
|
|
|
<script type="module">
|
|
import ClipTokenizer from './clip_tokenizer.js';
|
|
window.clipTokenizer = new ClipTokenizer();
|
|
</script>
|
|
<script src="./net.js"></script>
|
|
</head>
|
|
<body>
|
|
<!-- Shown instead of the title when the browser has no WebGPU support (toggled by the page script). -->
<h1 id="wgpuError" style="display: none;">WebGPU is not supported in this browser</h1>
<h1 id="sdTitle">StableDiffusion powered by <a href="https://github.com/tinygrad/tinygrad" target="_blank" rel="noopener noreferrer" style="color: #4A90E2;">tinygrad</a></h1>
<a href="https://github.com/tinygrad/tinygrad" target="_blank" rel="noopener noreferrer" style="position: absolute; top: 20px; right: 20px;">
  <img src="https://upload.wikimedia.org/wikipedia/commons/9/91/Octicons-mark-github.svg"
       alt="GitHub Logo"
       width="32" height="32"
       style="width: 32px; height: 32px;">
</a>

<div id="mybox">
  <form id="promptForm">
    <!-- aria-label gives the field an accessible name; the placeholder alone is not a label. -->
    <input id="promptText" type="text" aria-label="Prompt" placeholder="Enter your prompt here" value="a human standing on the surface of mars">

    <label>
      Steps: <span id="stepValue">9</span>
      <input id="stepRange" type="range" min="5" max="20" value="9" step="1">
    </label>

    <label>
      Guidance: <span id="guidanceValue">8.0</span>
      <input id="guidanceRange" type="range" min="3" max="15" value="8.0" step="0.1">
    </label>

    <!-- Stays disabled until the page script has finished loading the model. -->
    <input id="btnRunNet" type="button" value="Run" disabled>
  </form>

  <!-- Model download / decompression / compilation progress. -->
  <div id="divModelDl" style="display: flex; align-items: center; width: 100%; gap: 10px;">
    <span id="modelDlTitle">Downloading model</span>
    <progress id="modelDlProgressBar" value="0" max="100" style="flex-grow: 1;"></progress>
    <span id="modelDlProgressValue"></span>
  </div>

  <!-- Per-run sampling progress; revealed by the page script once the model is ready. -->
  <div id="divStepProgress" style="display: none; align-items: center; width: 100%; gap: 10px;">
    <progress id="progressBar" value="0" max="100" style="flex-grow: 1;"></progress>
    <span id="progressFraction"></span>
  </div>

  <!-- Duration of the most recent sampling step; revealed during a run. -->
  <div id="divStepTime" style="display: none; align-items: center; width: 100%; gap: 10px;">
    <span id="stepTimeValue">0 ms</span>
  </div>
</div>

<canvas id="canvas" width="512" height="512"></canvas>
|
|
|
|
<script>
|
|
// Set by loadNet() to the value returned by f16tof32().setup(device, ...); it is
// awaited once per chunk while the downloaded weights are decompressed.
let f16decomp = null;
|
|
|
|
/**
 * Opens (and, on first use, creates) the 'tinydb' IndexedDB database that
 * caches downloaded weight files in a 'tensors' object store keyed by `id`.
 *
 * The cache is best-effort: every failure path resolves to `null` instead of
 * rejecting, including a missing `indexedDB` global or a synchronous throw
 * from `indexedDB.open`, so callers can carry on without a cache.
 *
 * @returns {Promise<IDBDatabase|null>} the open database, or null when unavailable.
 */
function initDb() {
  return new Promise((resolve) => {
    let request;
    try {
      // Both a missing global and a throwing open() must degrade to "no cache".
      request = indexedDB.open('tinydb', 1);
    } catch (error) {
      console.error('Database error:', error);
      resolve(null);
      return;
    }

    request.onerror = (event) => {
      console.error('Database error:', event.target.error);
      resolve(null);
    };

    request.onsuccess = (event) => {
      console.log("Db initialized.");
      resolve(event.target.result);
    };

    // Runs before onsuccess when the database is new or its version changed.
    request.onupgradeneeded = (event) => {
      const db = event.target.result;
      if (!db.objectStoreNames.contains('tensors')) {
        db.createObjectStore('tensors', { keyPath: 'id' });
      }
    };
  });
}
|
|
|
|
/**
 * Stores `tensor` under `id` in the 'tensors' object store unless a record
 * with that id already exists. Best-effort: never rejects.
 *
 * @param {IDBDatabase|null} db  database handle from initDb(); null disables caching.
 * @param {string} id            record key.
 * @param {*} tensor             value stored as the record's `content`.
 * @returns {Promise<null|undefined>} settles once the write has finished
 *   (undefined), or with null when the write was skipped or failed.
 */
function saveTensorToDb(db, id, tensor) {
  // Without a database there is nothing to check or write.
  if (db == null) {
    return Promise.resolve(null);
  }

  return readTensorFromDb(db, id).then((existing) => {
    if (existing) {
      return null;
    }

    // Return the write promise so callers can await completion.
    return new Promise((resolve) => {
      const transaction = db.transaction(['tensors'], 'readwrite');
      const store = transaction.objectStore('tensors');
      const request = store.put({ id: id, content: tensor });

      transaction.onabort = (event) => {
        console.log("Transaction error while saving tensor: " + event.target.error);
        resolve(null);
      };

      request.onsuccess = () => {
        console.log('Tensor saved successfully.');
        resolve();
      };

      request.onerror = (event) => {
        console.error('Tensor save failed:', event.target.error);
        resolve(null);
      };
    });
  }).catch(() => null);
}
|
|
|
|
/**
 * Looks up the record stored under `id` in the 'tensors' object store.
 * Best-effort: every failure resolves to null, the promise never rejects.
 *
 * @param {IDBDatabase|null} db  database handle from initDb(); null means no cache.
 * @param {string} id            record key.
 * @returns {Promise<object|null>} the stored record, or null when the record is
 *   missing, the database is unavailable, or the read fails.
 */
function readTensorFromDb(db, id) {
  return new Promise((resolve) => {
    if (db == null) {
      resolve(null);
      return; // do not fall through to db.transaction on a null handle
    }

    let request;
    try {
      const transaction = db.transaction(['tensors'], 'readonly');
      transaction.onabort = (event) => {
        console.log("Transaction error while reading tensor: " + event.target.error);
        resolve(null);
      };
      request = transaction.objectStore('tensors').get(id);
    } catch (error) {
      // transaction()/get() can throw synchronously (e.g. missing store, closing db).
      console.error('Tensor retrieve failed: ', error);
      resolve(null);
      return;
    }

    request.onsuccess = (event) => {
      resolve(event.target.result || null);
    };

    request.onerror = (event) => {
      console.error('Tensor retrieve failed: ', event.target.error);
      resolve(null);
    };
  });
}
|
|
|
|
window.addEventListener('load', async function() {
|
|
if (!navigator.gpu) {
|
|
document.getElementById("wgpuError").style.display = "block";
|
|
document.getElementById("sdTitle").style.display = "none";
|
|
return;
|
|
}
|
|
|
|
let db = await initDb();
|
|
|
|
const ctx = document.getElementById("canvas").getContext("2d", { willReadFrequently: true });
|
|
let labels, nets, safetensorParts;
|
|
|
|
/**
 * Requests a WebGPU device able to hold the model's largest buffer and to run
 * f16 shaders.
 *
 * `powerPreference` is an adapter-selection hint, so it is passed to
 * requestAdapter(); it has no effect in the device descriptor.
 *
 * @returns {Promise<GPUDevice>}
 * @throws {Error} when no adapter is available; requestDevice() itself rejects
 *   when the adapter cannot satisfy the required limits or features.
 */
const getDevice = async () => {
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
  if (!adapter) {
    throw new Error("No suitable WebGPU adapter found");
  }

  // Size in bytes (1 GiB) requested for both buffer limits.
  const maxBufferSizeInSDModel = 1073741824;

  return await adapter.requestDevice({
    requiredLimits: {
      maxStorageBufferBindingSize: maxBufferSizeInSDModel,
      maxBufferSize: maxBufferSizeInSDModel,
    },
    requiredFeatures: ["shader-f16"],
  });
};
|
|
|
|
/**
 * Awaits `func()`, logs how long it took in milliseconds (one decimal place)
 * followed by `label`, and hands back whatever `func` produced.
 *
 * @param {Function} func  callable to time; may return a value or a promise.
 * @param {string} [label] text appended to the log line.
 * @returns {Promise<*>} the awaited result of `func`.
 */
const timer = async (func, label = "") => {
  const startedAt = performance.now();
  const result = await func();
  const elapsedMs = performance.now() - startedAt;
  console.log(`${elapsedMs.toFixed(1)} ms ${label}`);
  return result;
};
|
|
|
|
/**
 * Downloads `part` and reports progress as each chunk arrives.
 *
 * @param {string} part  URL to fetch; also passed back to the callback as the part identifier.
 * @param {Function} progressCallback  called as (part, chunkByteLength, totalBytes)
 *   for every received chunk. `totalBytes` comes from the Content-Length header
 *   and is 0 when the header is missing or unparsable, so callers never see NaN.
 * @returns {Promise<ArrayBuffer>} the complete response body.
 * @throws {Error} when the server answers with a non-2xx status.
 */
const getProgressDlForPart = async (part, progressCallback) => {
  const response = await fetch(part);
  if (!response.ok) {
    // Fail loudly instead of caching/decoding an error page as model weights.
    throw new Error(`Failed to download ${part}: HTTP ${response.status}`);
  }

  const declaredLength = parseInt(response.headers.get('content-length'), 10);
  const total = Number.isFinite(declaredLength) ? declaredLength : 0;

  // Re-stream the body through a pass-through stream so each chunk can be counted.
  const res = new Response(new ReadableStream({
    async start(controller) {
      const reader = response.body.getReader();
      for (;;) {
        const { done, value } = await reader.read();
        if (done) break;
        progressCallback(part, value.byteLength, total);
        controller.enqueue(value);
      }

      controller.close();
    },
  }));

  return res.arrayBuffer();
};
|
|
|
|
/**
 * Fetches the five weight files (from the IndexedDB cache when present,
 * otherwise over the network), concatenates the first four, runs their payload
 * through `f16decomp` chunk by chunk and lays the results out into four output
 * parts. The fifth file ("net_textmodel") is copied into the last part as-is.
 *
 * NOTE(review): the byte offsets below (`textModelOffset`, `partOffsets`,
 * `decodeChunkSize`) are hard-coded for one specific model export; the factor
 * 2 in the output offset assumes each decompressed chunk is twice the size of
 * its input — confirm against the generator of net.js before changing them.
 *
 * @param {GPUDevice} device  unused here; kept so the call signature stays stable.
 * @param {Function} progress called as (done, total) during both the download
 *   (bytes) and the decompression (chunks).
 * @returns {Promise<Uint8Array[]>} the four output parts.
 */
const getAndDecompressF16Safetensors = async (device, progress) => {
  let totalLoaded = 0;
  let totalSize = 0;
  let partSize = {};
  // Keys served from the cache; they need no write-back.
  const cachedKeys = new Set();

  // Accumulates bytes across all parallel downloads; each part's total is added once.
  const progressCallback = (part, loaded, total) => {
    totalLoaded += loaded;

    if (!partSize[part]) {
      totalSize += total;
      partSize[part] = true;
    }

    progress(totalLoaded, totalSize);
  };

  const getPart = async (key) => {
    const part = await readTensorFromDb(db, key);

    if (part) {
      console.log(`Cache hit: ${key}`);
      cachedKeys.add(key);
      return part.content;
    }

    console.log(`Cache miss: ${key}`);
    return getProgressDlForPart(`${window.MODEL_BASE_URL}/${key}.safetensors`, progressCallback);
  };

  const netKeys = ["net_part0", "net_part1", "net_part2", "net_part3", "net_textmodel"];
  const buffers = await Promise.all(netKeys.map((key) => getPart(key)));

  // Downloads arrive as ArrayBuffer, cache hits as Uint8Array. Wrap without
  // copying: `new Uint8Array(typedArray)` would duplicate the whole file.
  const toBytes = (data) => (data instanceof Uint8Array ? data : new Uint8Array(data));
  const byteViews = buffers.map(toBytes);

  // Combine everything except for text model, since that's already f32
  const totalLength = byteViews.reduce(
    (acc, view, index) => (index < 4 ? acc + view.byteLength : acc),
    0
  );

  // Declared locally; an undeclared assignment would leak this buffer as a global.
  let combinedBuffer = new Uint8Array(totalLength);
  let offset = 0;
  byteViews.forEach((view, index) => {
    // Fire-and-forget cache write; skipped for data that just came out of the cache.
    if (!cachedKeys.has(netKeys[index])) {
      saveTensorToDb(db, netKeys[index], view);
    }
    if (index < 4) {
      combinedBuffer.set(view, offset);
      offset += view.byteLength;
    }
  });

  const textModelU8 = byteViews[4];
  document.getElementById("modelDlTitle").innerHTML = "Decompressing model";

  const textModelOffset = 3772703308;
  // Header layout read here: 8-byte little-endian length, then that many bytes of JSON.
  const metadataLength = Number(new DataView(combinedBuffer.buffer).getBigUint64(0, true));
  // Parsed but not used further; a malformed header still throws here.
  const metadata = JSON.parse(new TextDecoder("utf8").decode(combinedBuffer.subarray(8, 8 + metadataLength)));

  const allToDecomp = combinedBuffer.byteLength - (8 + metadataLength);
  const decodeChunkSize = 8388480;
  const numChunks = Math.ceil(allToDecomp / decodeChunkSize);

  console.log(allToDecomp + " bytes to decompress");
  console.log("Will be decompressed in " + numChunks + " chunks");

  const partOffsets = [{start: 0, end: 1131408336}, {start: 1131408336, end: 2227518416}, {start: 2227518416, end: 3308987856}, {start: 3308987856, end: 4265298864}];
  const parts = [];

  for (const offsets of partOffsets) {
    parts.push(new Uint8Array(offsets.end - offsets.start));
  }

  // The header (length prefix + JSON) is copied unchanged to the start of part 0.
  parts[0].set(new Uint8Array(new BigUint64Array([BigInt(metadataLength)]).buffer), 0);
  parts[0].set(combinedBuffer.subarray(8, 8 + metadataLength), 8);
  parts[3].set(textModelU8, textModelOffset + 8 + metadataLength - partOffsets[3].start);

  const start = Date.now();
  let cursor = 0; // index of the output part currently being filled

  for (let i = 0; i < numChunks; i++) {
    progress(i, numChunks);
    const chunkStartF16 = 8 + metadataLength + (decodeChunkSize * i);
    const chunkEndF16 = chunkStartF16 + decodeChunkSize;
    const chunk = combinedBuffer.subarray(chunkStartF16, chunkEndF16);
    const uint32Chunk = new Uint32Array(chunk.buffer, chunk.byteOffset, chunk.byteLength / 4);
    let result = await f16decomp(uint32Chunk);
    let resultUint8 = new Uint8Array(result.buffer);
    const chunkStartF32 = 8 + metadataLength + (decodeChunkSize * i * 2);
    const chunkEndF32 = chunkStartF32 + resultUint8.byteLength;
    const offsetInPart = chunkStartF32 - partOffsets[cursor].start;

    if (chunkEndF32 < partOffsets[cursor].end || cursor === parts.length - 1) {
      parts[cursor].set(resultUint8, offsetInPart);
    } else {
      // The chunk straddles a part boundary: fill the current part, spill the rest into the next.
      const spaceLeftInCurrentPart = partOffsets[cursor].end - chunkStartF32;
      parts[cursor].set(resultUint8.subarray(0, spaceLeftInCurrentPart), offsetInPart);

      cursor++;

      if (cursor < parts.length) {
        const nextPartOffset = spaceLeftInCurrentPart;
        const nextPartLength = resultUint8.length - nextPartOffset;
        parts[cursor].set(resultUint8.subarray(nextPartOffset, nextPartOffset + nextPartLength), 0);
      }
    }

    // Drop references early so the chunk buffers can be collected between iterations.
    resultUint8 = null;
    result = null;
  }

  combinedBuffer = null;

  const end = Date.now();
  console.log("Decoding took: " + ((end - start) / 1000) + " s");
  console.log("Avarage " + ((end - start) / numChunks) + " ms per chunk");

  return parts;
};
|
|
|
|
/**
 * Boots the whole pipeline: acquires the GPU device, sets up the weight
 * decompressor, downloads and decompresses the weights, compiles the three
 * networks into `nets`, then enables the Run button.
 *
 * On failure the status title is updated and the error is rethrown so it
 * still appears in the console.
 *
 * @returns {Promise<void>}
 */
const loadNet = async () => {
  const modelDlTitle = document.getElementById("modelDlTitle");
  const dlProgressBar = document.getElementById("modelDlProgressBar");
  const dlProgressValue = document.getElementById("modelDlProgressValue");

  const progress = (loaded, total) => {
    // <progress>.value rejects NaN/Infinity, so an unknown or zero total maps to 0%.
    const percent = total > 0 ? (loaded / total) * 100 : 0;
    dlProgressBar.value = percent;
    dlProgressValue.innerHTML = Math.trunc(percent) + "%";
  };

  try {
    const device = await getDevice();
    // safetensorParts is still undefined at this point; it is passed through as before.
    f16decomp = await f16tof32().setup(device, safetensorParts);
    safetensorParts = await getAndDecompressF16Safetensors(device, progress);

    modelDlTitle.innerHTML = "Compiling model";

    nets = await timer(async () => {
      const [textModelNet, diffusorNet, decoderNet] = await Promise.all([
        textModel().setup(device, safetensorParts),
        diffusor().setup(device, safetensorParts),
        decoder().setup(device, safetensorParts)
      ]);
      return { textModel: textModelNet, diffusor: diffusorNet, decoder: decoderNet };
    }, "(compilation)");
  } catch (error) {
    // Leave a visible trace instead of a progress bar frozen mid-download.
    modelDlTitle.textContent = "Failed to load model: " + (error && error.message ? error.message : error);
    throw error;
  }

  progress(1, 1);

  modelDlTitle.innerHTML = "Model ready";
  // Keep the finished bar visible briefly, then switch to the per-run progress row.
  setTimeout(() => {
    dlProgressBar.style.display = "none";
    dlProgressValue.style.display = "none";
    document.getElementById("divStepProgress").style.display = "flex";
  }, 1000);
  document.getElementById("btnRunNet").disabled = false;
};
|
|
|
|
/**
 * Runs one text-to-image generation: encodes the prompt and an empty prompt,
 * iterates the diffusor over the selected timesteps starting from random
 * noise, and decodes the final latent.
 *
 * Written as an async function so any failure rejects the returned promise
 * (an async Promise executor would swallow it and never settle).
 *
 * @param {string} prompt           text prompt.
 * @param {number|string} steps     number of sampling steps (must be > 0).
 * @param {number|string} guidance  guidance value forwarded to the diffusor.
 * @param {?Function} showStep      optional callback receiving the decoded image after every step.
 * @returns {Promise<*>} the output of the decoder network for the final latent.
 * @throws {RangeError} when `steps` is not a positive number.
 */
async function runStableDiffusion(prompt, steps, guidance, showStep) {
  const stepCount = Number(steps);
  if (!(stepCount > 0)) {
    // A zero, negative or NaN count would make the timestep loop degenerate or endless.
    throw new RangeError("steps must be a positive number, got: " + steps);
  }

  const context = await timer(() => nets["textModel"](clipTokenizer.encodeForCLIP(prompt)));
  const unconditional_context = await timer(() => nets["textModel"](clipTokenizer.encodeForCLIP("")));

  // Evenly spaced (possibly fractional) timesteps in [1, 1000).
  const timesteps = [];
  for (let t = 1; t < 1000; t += (1000 / stepCount)) {
    timesteps.push(t);
  }

  console.log("Timesteps: " + timesteps);

  const alphasCumprod = getWeight(safetensorParts, "alphas_cumprod");
  const alphas = timesteps.map((t) => alphasCumprod[Math.floor(t)]);
  // alphasPrev[i] is the alpha of the previous timestep, with 1.0 for the first one.
  const alphasPrev = [1.0].concat(alphas.slice(0, -1));

  const inpSize = 4 * 64 * 64;
  let latent = new Float32Array(inpSize);

  // Box-Muller transform; 1 - random() lies in (0, 1], which keeps log() finite.
  for (let i = 0; i < inpSize; i++) {
    latent[i] = Math.sqrt(-2.0 * Math.log(1 - Math.random())) * Math.cos(2.0 * Math.PI * Math.random());
  }

  const stepTimeRow = document.getElementById("divStepTime");
  const stepTimeValue = document.getElementById("stepTimeValue");
  const progressBar = document.getElementById("progressBar");
  const progressFraction = document.getElementById("progressFraction");

  // Walk the schedule from the last timestep down to the first.
  for (let i = timesteps.length - 1; i >= 0; i--) {
    const timestep = new Float32Array([timesteps[i]]);
    const start = performance.now();
    const x_prev = await nets["diffusor"](unconditional_context, context, latent, timestep, new Float32Array([alphas[i]]), new Float32Array([alphasPrev[i]]), new Float32Array([guidance]));
    stepTimeRow.style.display = "block";
    stepTimeValue.innerText = `${(performance.now() - start).toFixed(1)} ms / step`;
    latent = x_prev;

    if (showStep != null) {
      showStep(await nets["decoder"](latent));
    }

    progressBar.value = ((stepCount - i) / stepCount) * 100;
    progressFraction.innerHTML = (stepCount - i) + "/" + stepCount;
  }

  return await timer(() => nets["decoder"](latent));
}
|
|
|
|
/**
 * Draws a 512x512 image onto the canvas.
 *
 * @param {ArrayLike<number>} image  interleaved RGB samples, three per pixel,
 *   in row order; each value is clamped/rounded to 0..255 when stored.
 *
 * The RGBA buffer is filled directly instead of pushing ~1M numbers into a
 * plain array and copying it afterwards; the resulting pixels are identical.
 */
function renderImage(image) {
  const width = 512;
  const height = 512;
  const rgba = new Uint8ClampedArray(width * height * 4);

  for (let src = 0, dst = 0; dst < rgba.length; src += 3, dst += 4) {
    rgba[dst] = image[src];
    rgba[dst + 1] = image[src + 1];
    rgba[dst + 2] = image[src + 2];
    rgba[dst + 3] = 255; // fully opaque
  }

  ctx.putImageData(new ImageData(rgba, width, height), 0, 0);
}
|
|
|
|
/**
 * Click/submit handler: locks the Run button, clears the canvas, runs a
 * generation with the current form values and draws the result. The button
 * and status title are restored whether the run succeeds or fails.
 */
const handleRunNetAndRenderResult = () => {
  const runButton = document.getElementById("btnRunNet");
  const statusTitle = document.getElementById("modelDlTitle");
  const canvas = document.getElementById("canvas");

  runButton.disabled = true;
  ctx.clearRect(0, 0, canvas.width, canvas.height);

  const prevTitleValue = statusTitle.innerHTML;
  statusTitle.innerHTML = "Running model";

  runStableDiffusion(
    document.getElementById("promptText").value,
    // Slider values are strings; hand over real numbers.
    Number(document.getElementById("stepRange").value),
    Number(document.getElementById("guidanceRange").value),
    // No per-step callback: only the final latent is decoded.
    null
  ).then((image) => {
    renderImage(image);
  }).catch((error) => {
    // Without this, a failed run or render would surface only as an unhandled rejection.
    console.error("Image generation failed:", error);
  }).finally(() => {
    statusTitle.innerHTML = prevTitleValue;
    runButton.disabled = false;
  });
};
|
|
|
|
// Run button: start a generation on click.
const runNetButton = document.getElementById("btnRunNet");
runNetButton.addEventListener("click", handleRunNetAndRenderResult, false);

// Pressing Enter in the prompt field submits the form; run the model instead
// of navigating, unless the button is currently locked.
document.getElementById("promptForm").addEventListener("submit", function (event) {
  event.preventDefault();
  if (!runNetButton.disabled) {
    handleRunNetAndRenderResult();
  }
});

// Keep the numeric read-outs next to the sliders in sync while dragging.
const mirrorSliderValue = (sliderId, readoutId) => {
  const slider = document.getElementById(sliderId);
  const readout = document.getElementById(readoutId);
  slider.addEventListener('input', function () {
    readout.textContent = slider.value;
  });
};

mirrorSliderValue('stepRange', 'stepValue');
mirrorSliderValue('guidanceRange', 'guidanceValue');

// Start downloading and compiling the model as soon as the page has loaded.
loadNet();
|
|
});
|
|
</script>
|
|
</body>
|
|
</html>
|
|
|