You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
			
				
					631 lines
				
				22 KiB
			
		
		
			
		
	
	
					631 lines
				
				22 KiB
			| 
											8 months ago
										 | <!DOCTYPE html>
 | ||
|  | <html lang="en">
 | ||
|  | 
 | ||
|  | <head>
 | ||
|  |     <meta charset="UTF-8">
 | ||
|  |     <meta name="viewport" content="width=device-width, initial-scale=1.0">
 | ||
|  |     <title>tinygrad has WebGPU</title>
 | ||
<style>
    /* General Reset */
    * {
        margin: 0;
        padding: 0;
        box-sizing: border-box;
    }

    /* Page: everything is stacked in a single centred column. */
    body {
        font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
        background: #f4f7fb;
        color: #333;
        display: flex;
        justify-content: center;
        align-items: center;
        min-height: 100vh;
        padding: 20px;
        flex-direction: column;
        text-align: center;
    }

    h1 {
        font-size: 2.2rem;
        margin-bottom: 20px;
        color: #4A90E2;
    }

    /* Hidden by default; the page script switches it to display: block when WebGPU is missing. */
    #wgpuError {
        color: red;
        font-size: 1.2rem;
        margin-top: 20px;
        display: none;
    }

    #sdTitle {
        font-size: 1.5rem;
        margin-bottom: 30px;
    }

    /* Control card. The width stays within the body so it cannot overflow viewports
       narrower than max-width (the previous 120% pushed it past the page padding). */
    #mybox {
        background: #ffffff;
        padding: 15px;
        border-radius: 10px;
        box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
        width: 100%;
        max-width: 550px;
        margin-bottom: 10px;
    }

    input[type="text"] {
        width: 100%;
        padding: 12px;
        margin-bottom: 20px;
        font-size: 1rem;
        border: 1px solid #ccc;
        border-radius: 8px;
        outline: none;
        transition: all 0.3s ease;
    }

    /* Replaces the removed outline as the focus indicator. */
    input[type="text"]:focus {
        border-color: #4A90E2;
    }

    label {
        display: flex;
        justify-content: space-between;
        font-size: 1rem;
        margin-bottom: 15px;
        align-items: center;
    }

    input[type="range"] {
        width: 100%;
        margin-left: 10px;
        -webkit-appearance: none;
        appearance: none;
        height: 8px;
        border-radius: 4px;
        background: #ddd;
        outline: none;
        transition: background 0.3s ease;
    }

    /* Replaces the removed outline as the focus indicator. */
    input[type="range"]:focus {
        background: #4A90E2;
    }

    #stepRange,
    #guidanceRange {
        width: 80%;
    }

    span {
        font-size: 1.1rem;
        font-weight: 600;
        color: #333;
    }

    input[type="button"] {
        padding: 12px 25px;
        background-color: #4A90E2;
        color: #fff;
        font-size: 1.2rem;
        border: none;
        border-radius: 8px;
        cursor: pointer;
        transition: background-color 0.3s ease;
        width: 100%;
        margin-bottom: 20px;
    }

    input[type="button"]:disabled {
        background-color: #ccc;
        cursor: not-allowed;
    }

    /* Only enabled buttons react to hover; otherwise this rule (same specificity,
       later in the sheet) would repaint a disabled button in the active colour. */
    input[type="button"]:hover:not(:disabled) {
        background-color: #357ABD;
    }

    #divModelDl,
    #divStepProgress {
        display: flex;
        justify-content: space-between;
        align-items: center;
        margin-bottom: 20px;
    }

    #modelDlProgressBar,
    #progressBar {
        width: 80%;
        height: 12px;
        border-radius: 6px;
        background-color: #e0e0e0;
    }

    #modelDlProgressBar::-webkit-progress-bar,
    #progressBar::-webkit-progress-bar {
        border-radius: 6px;
    }

    #modelDlProgressValue,
    #progressFraction {
        font-size: 1rem;
        font-weight: 600;
        color: #333;
    }

    canvas {
        max-width: 100%;
        max-height: 450px;
        margin-top: 10px;
        border-radius: 8px;
        border: 1px solid #ddd;
    }
</style>
 | ||
|  | 
 | ||
<!-- Module scripts have their own scope, so the tokenizer instance is published on
     window; the inline page script reads it as the global `clipTokenizer`. -->
<script type="module">
    import Tokenizer from './clip_tokenizer.js';

    const tokenizer = new Tokenizer();
    window.clipTokenizer = tokenizer;
</script>
<!-- Classic script loaded before the inline page script, which calls names it does
     not define itself (f16tof32, textModel, diffusor, decoder, getWeight). -->
<script src="./net.js"></script>
 | ||
|  | </head>
 | ||
|  | <body>
 | ||
<!-- Shown in place of the title when the load handler finds no navigator.gpu. -->
<h1 id="wgpuError" style="display: none;">WebGPU is not supported in this browser</h1>
<h1 id="sdTitle">StableDiffusion powered by <a href="https://github.com/tinygrad/tinygrad" target="_blank" rel="noopener noreferrer" style="color: #4A90E2;">tinygrad</a></h1>
<a href="https://github.com/tinygrad/tinygrad" target="_blank" rel="noopener noreferrer" style="position: absolute; top: 20px; right: 20px;">
    <!-- The alt text doubles as the link's accessible name, so it names the destination. -->
    <img src="https://upload.wikimedia.org/wikipedia/commons/9/91/Octicons-mark-github.svg"
         alt="tinygrad on GitHub"
         width="32"
         height="32"
         style="width: 32px; height: 32px;">
</a>

<div id="mybox">
    <form id="promptForm">
        <!-- aria-label rather than a visible <label>: the stylesheet lays out every <label> as a flex row. -->
        <input id="promptText" type="text" aria-label="Prompt" placeholder="Enter your prompt here" value="a human standing on the surface of mars">

        <label>
            Steps: <span id="stepValue">9</span>
            <input id="stepRange" type="range" min="5" max="20" value="9" step="1">
        </label>

        <label>
            Guidance: <span id="guidanceValue">8.0</span>
            <input id="guidanceRange" type="range" min="3" max="15" value="8.0" step="0.1">
        </label>

        <!-- Starts disabled; the page script enables it once the model is ready. -->
        <input id="btnRunNet" type="button" value="Run" disabled>
    </form>

    <!-- Model download / decompress / compile status. -->
    <div id="divModelDl" style="display: flex; align-items: center; width: 100%; gap: 10px;">
        <span id="modelDlTitle">Downloading model</span>
        <progress id="modelDlProgressBar" value="0" max="100" aria-labelledby="modelDlTitle" style="flex-grow: 1;"></progress>
        <span id="modelDlProgressValue"></span>
    </div>

    <!-- Per-run denoising progress; revealed by the page script after the model has loaded. -->
    <div id="divStepProgress" style="display: none; align-items: center; width: 100%; gap: 10px;">
        <progress id="progressBar" value="0" max="100" aria-label="Generation progress" style="flex-grow: 1;"></progress>
        <span id="progressFraction"></span>
    </div>

    <!-- Duration of the most recent diffusion step; revealed during a run. -->
    <div id="divStepTime" style="display: none; align-items: center; width: 100%; gap: 10px;">
        <span id="stepTimeValue">0 ms</span>
    </div>
</div>

<canvas id="canvas" width="512" height="512" role="img" aria-label="Generated image">The generated image is drawn here.</canvas>
 | ||
|  | 
 | ||
|  | <script>
 | ||
|  |     let f16decomp = null;
 | ||
|  | 
 | ||
/**
 * Opens the `tinydb` IndexedDB database (version 1), creating the `tensors`
 * object store (keyed by `id`) when the database is first created or upgraded.
 *
 * The returned promise never rejects: callers treat a null handle as
 * "no cache available", so every failure path resolves with null.
 *
 * @returns {Promise<IDBDatabase|null>} the open database, or null on failure.
 */
function initDb() {
    return new Promise((resolve) => {
        let request;
        try {
            request = indexedDB.open('tinydb', 1);
        } catch (error) {
            // A synchronous failure here (IndexedDB missing or refusing access) used to
            // reject the promise and abort the whole page load; treat it like onerror.
            console.error('Database error:', error);
            resolve(null);
            return;
        }

        request.onerror = (event) => {
            console.error('Database error:', event.target.error);
            resolve(null);
        };

        request.onsuccess = (event) => {
            console.log("Db initialized.");
            resolve(event.target.result);
        };

        request.onupgradeneeded = (event) => {
            const db = event.target.result;
            if (!db.objectStoreNames.contains('tensors')) {
                db.createObjectStore('tensors', { keyPath: 'id' });
            }
        };
    });
}
 | ||
|  | 
 | ||
/**
 * Stores `tensor` under `id` in the `tensors` object store unless a record with
 * that id is already present.
 *
 * The returned promise settles only after the write has finished (it previously
 * resolved while the write was still pending). It never rejects: caching is
 * best-effort, so failures are logged and reported as null.
 *
 * @param {IDBDatabase|null} db open database handle, or null when caching is unavailable.
 * @param {string} id record key.
 * @param {*} tensor value stored as the record's `content`.
 * @returns {Promise<undefined|null>} undefined after a successful write; null when
 *     nothing was written (no db, record already present, or the write failed).
 */
function saveTensorToDb(db, id, tensor) {
    return readTensorFromDb(db, id).then((existing) => {
        if (existing || db == null) {
            return null;
        }

        return new Promise((resolve) => {
            const transaction = db.transaction(['tensors'], 'readwrite');
            const store = transaction.objectStore('tensors');
            const request = store.put({ id: id, content: tensor });

            transaction.onabort = (event) => {
                console.log("Transaction error while saving tensor: " + event.target.error);
                resolve(null);
            };

            request.onsuccess = () => {
                console.log('Tensor saved successfully.');
                resolve();
            };

            request.onerror = (event) => {
                console.error('Tensor save failed:', event.target.error);
                resolve(null);
            };
        });
    }).catch((error) => {
        // Also reached when db.transaction()/put() throw synchronously.
        console.error('Tensor save failed:', error);
        return null;
    });
}
 | ||
|  | 
 | ||
/**
 * Looks up the record stored under `id` in the `tensors` object store.
 *
 * The returned promise never rejects: a missing database, a missing record and
 * any IndexedDB failure are all reported as null.
 *
 * @param {IDBDatabase|null} db open database handle, or null when caching is unavailable.
 * @param {string} id record key.
 * @returns {Promise<Object|null>} the stored record, or null.
 */
function readTensorFromDb(db, id) {
    return new Promise((resolve) => {
        if (db == null) {
            resolve(null);
            // Stop here: falling through used to dereference the null handle and only
            // "worked" because a throw after resolve() is ignored.
            return;
        }

        let transaction;
        let request;
        try {
            transaction = db.transaction(['tensors'], 'readonly');
            request = transaction.objectStore('tensors').get(id);
        } catch (error) {
            // Synchronous failures are reported as a miss like every other error path.
            console.error('Tensor retrieve failed: ', error);
            resolve(null);
            return;
        }

        transaction.onabort = (event) => {
            console.log("Transaction error while reading tensor: " + event.target.error);
            resolve(null);
        };

        request.onsuccess = (event) => {
            // get() yields undefined for an unknown key; normalise that to null.
            resolve(event.target.result || null);
        };

        request.onerror = (event) => {
            console.error('Tensor retrieve failed: ', event.target.error);
            resolve(null);
        };
    });
}
 | ||
|  | 
 | ||
|  |     window.addEventListener('load', async function() {
 | ||
|  |         if (!navigator.gpu) {
 | ||
|  |             document.getElementById("wgpuError").style.display = "block";
 | ||
|  |             document.getElementById("sdTitle").style.display = "none";
 | ||
|  |             return;
 | ||
|  |         }
 | ||
|  | 
 | ||
|  |         let db = await initDb();
 | ||
|  | 
 | ||
|  |         const ctx = document.getElementById("canvas").getContext("2d", { willReadFrequently: true });
 | ||
|  |         let labels, nets, safetensorParts;
 | ||
|  | 
 | ||
/**
 * Requests a WebGPU device whose buffer limits are raised to 1 GiB, the size
 * the original code names `maxBufferSizeInSDModel`.
 *
 * @returns {Promise<GPUDevice>} the device.
 * @throws {Error} when no adapter is available; requestDevice() may additionally
 *     reject when the adapter cannot grant the requested limits.
 */
const getDevice = async () => {
    const adapter = await navigator.gpu.requestAdapter();
    if (!adapter) {
        // requestAdapter() resolves with null when no suitable adapter exists; fail with
        // a readable message instead of a TypeError on the next line.
        throw new Error("WebGPU adapter unavailable: navigator.gpu.requestAdapter() returned null");
    }

    const maxBufferSizeInSDModel = 1073741824; // 1 GiB
    return adapter.requestDevice({
        requiredLimits: {
            maxStorageBufferBindingSize: maxBufferSizeInSDModel,
            maxBufferSize: maxBufferSizeInSDModel,
        },
    });
};
 | ||
|  | 
 | ||
/**
 * Awaits `func()`, logs how long it took as "<ms> ms <label>" and passes the
 * result through.
 *
 * @param {Function} func callback to run; its return value is awaited.
 * @param {string} [label=""] text appended to the log line.
 * @returns {Promise<*>} whatever `func` produced.
 */
const timer = async (func, label = "") => {
    const startedAt = performance.now();
    const result = await func();
    const elapsedMs = performance.now() - startedAt;
    console.log(`${elapsedMs.toFixed(1)} ms ${label}`);
    return result;
}
 | ||
|  | 
 | ||
/**
 * Downloads `part` and reports progress for every chunk received.
 *
 * @param {string} part URL to fetch; also passed back to the callback as an identifier.
 * @param {Function} progressCallback called as (part, chunkByteLength, totalBytes)
 *     per received chunk; totalBytes is 0 when the server sent no usable Content-Length.
 * @returns {Promise<ArrayBuffer>} the complete response body.
 * @throws {Error} when the server answers with a non-2xx status.
 */
const getProgressDlForPart = async (part, progressCallback) => {
    const response = await fetch(part);
    if (!response.ok) {
        // Without this check an error page would be returned (and later cached) as weights.
        throw new Error(`Failed to download ${part}: ${response.status} ${response.statusText}`);
    }

    // A missing header parses to NaN, which would poison the caller's running total.
    const parsedLength = parseInt(response.headers.get('content-length'), 10);
    const total = Number.isFinite(parsedLength) ? parsedLength : 0;

    const reader = response.body.getReader();
    // Re-wrap the body so each chunk can be observed before it is collected.
    const res = new Response(new ReadableStream({
        async start(controller) {
            for (;;) {
                const { done, value } = await reader.read();
                if (done) break;
                progressCallback(part, value.byteLength, total);
                controller.enqueue(value);
            }

            controller.close();
        },
    }));

    return res.arrayBuffer();
};
 | ||
|  |   
 | ||
/**
 * Loads the five model files (from the IndexedDB cache when present, otherwise by
 * download), expands the four f16 files to f32 through `f16decomp`, and lays the
 * result out in four output buffers.
 *
 * Output layout, as established by the code below:
 *   - parts[0] starts with the 8-byte little-endian metadata length followed by the
 *     metadata JSON copied from the combined f16 stream;
 *   - decompressed chunks follow, each twice the byte size of its source chunk;
 *   - the text model bytes (not decompressed) are copied into parts[3].
 *
 * @param {GPUDevice} device currently unused; kept for interface compatibility.
 * @param {Function} progress called as (done, total), first with byte counts while
 *     downloading, then with chunk counts while decompressing.
 * @returns {Promise<Uint8Array[]>} the four output buffers.
 */
const getAndDecompressF16Safetensors = async (device, progress) => {
    let totalLoaded = 0;
    let totalSize = 0;
    const partSize = {};
    // Keys that had to be downloaded; only these need to be written to the cache.
    const downloadedKeys = new Set();

    const progressCallback = (part, loaded, total) => {
        totalLoaded += loaded;

        // Add each part's size to the grand total once, on its first chunk.
        if (!partSize[part]) {
            totalSize += total;
            partSize[part] = true;
        }

        progress(totalLoaded, totalSize);
    };

    const getPart = async (key) => {
        const part = await readTensorFromDb(db, key);

        if (part) {
            console.log(`Cache hit: ${key}`);
            return part.content;
        }

        console.log(`Cache miss: ${key}`);
        downloadedKeys.add(key);
        return getProgressDlForPart(`${window.MODEL_BASE_URL}/${key}.safetensors`, progressCallback);
    };

    // Cached parts are Uint8Arrays and downloads are ArrayBuffers; view both as bytes
    // without copying an array that is already a Uint8Array.
    const asBytes = (buffer) => (buffer instanceof Uint8Array ? buffer : new Uint8Array(buffer));

    const netKeys = ["net_part0", "net_part1", "net_part2", "net_part3", "net_textmodel"];
    const buffers = await Promise.all(netKeys.map((key) => getPart(key)));

    // Combine everything except for text model, since that's already f32
    const f16PartCount = 4;
    const totalLength = buffers
        .slice(0, f16PartCount)
        .reduce((acc, buffer) => acc + buffer.byteLength, 0);

    // Declared locally; this used to leak onto the global object.
    const combinedBuffer = new Uint8Array(totalLength);
    let offset = 0;
    buffers.forEach((buffer, index) => {
        const bytes = asBytes(buffer);

        // Fire-and-forget cache write. Cache hits are skipped because saveTensorToDb
        // would re-read the whole record from IndexedDB just to see that it exists.
        if (downloadedKeys.has(netKeys[index])) {
            saveTensorToDb(db, netKeys[index], bytes);
        }

        if (index < f16PartCount) {
            combinedBuffer.set(bytes, offset);
            offset += bytes.byteLength;
        }
    });

    const textModelU8 = asBytes(buffers[4]);
    document.getElementById("modelDlTitle").innerHTML = "Decompressing model";

    // NOTE(review): presumably the text model's byte offset within the f32 data
    // section of the exported model; confirm against the export script.
    const textModelOffset = 3772703308;
    const metadataLength = Number(new DataView(combinedBuffer.buffer).getBigUint64(0, true));
    const metadata = JSON.parse(new TextDecoder("utf8").decode(combinedBuffer.subarray(8, 8 + metadataLength)));

    const allToDecomp = combinedBuffer.byteLength - (8 + metadataLength);
    const decodeChunkSize = 8388480;
    const numChunks = Math.ceil(allToDecomp/decodeChunkSize);

    console.log(allToDecomp + " bytes to decompress");
    console.log("Will be decompressed in " + numChunks+ " chunks");

    // Byte ranges of the four output buffers within the f32 layout.
    // NOTE(review): hard-coded; presumably must match what net.js expects.
    const partOffsets = [{start: 0, end: 1131408336}, {start: 1131408336, end: 2227518416}, {start: 2227518416, end: 3308987856}, {start: 3308987856, end: 4265298864}];
    const parts = partOffsets.map((offsets) => new Uint8Array(offsets.end - offsets.start));

    parts[0].set(new Uint8Array(new BigUint64Array([BigInt(metadataLength)]).buffer), 0);
    parts[0].set(combinedBuffer.subarray(8, 8 + metadataLength), 8);
    parts[3].set(textModelU8, textModelOffset+8+metadataLength - partOffsets[3].start);

    const start = Date.now();
    let cursor = 0;

    for (let i = 0; i < numChunks; i++) {
        progress(i, numChunks);
        const chunkStartF16 = 8 + metadataLength + (decodeChunkSize * i);
        const chunkEndF16 = chunkStartF16 + decodeChunkSize;
        // subarray() clamps the final chunk to the end of the buffer.
        const chunk = combinedBuffer.subarray(chunkStartF16, chunkEndF16);
        const uint32Chunk = new Uint32Array(chunk.buffer, chunk.byteOffset, chunk.byteLength / 4);
        const result = await f16decomp(uint32Chunk);
        const resultUint8 = new Uint8Array(result.buffer);
        // Each f16 chunk lands at twice its source offset in the f32 layout.
        const chunkStartF32 = 8 + metadataLength + (decodeChunkSize * i * 2);
        const chunkEndF32 = chunkStartF32 + resultUint8.byteLength;
        const offsetInPart = chunkStartF32 - partOffsets[cursor].start;

        if (chunkEndF32 < partOffsets[cursor].end || cursor === parts.length - 1) {
            parts[cursor].set(resultUint8, offsetInPart);
        } else {
            // The chunk straddles (or exactly fills) the current output buffer: write
            // what fits, then continue at the start of the next buffer.
            const spaceLeftInCurrentPart = partOffsets[cursor].end - chunkStartF32;
            parts[cursor].set(resultUint8.subarray(0, spaceLeftInCurrentPart), offsetInPart);

            cursor++;

            if (cursor < parts.length) {
                parts[cursor].set(resultUint8.subarray(spaceLeftInCurrentPart), 0);
            }
        }
    }

    const end = Date.now();
    console.log("Decoding took: " + ((end - start) / 1000) + " s");
    console.log("Avarage " + ((end - start) / numChunks) + " ms per chunk");

    return parts;
};
 | ||
|  | 
 | ||
/**
 * Prepares everything needed to generate images: acquires the GPU device, sets
 * up the f16 decompressor, loads and decompresses the weights, compiles the
 * three networks into `nets`, then enables the Run button.
 *
 * Progress and status go to the #modelDl* elements. On failure the title shows
 * an error and the returned promise rejects with the original error.
 *
 * @returns {Promise<void>}
 */
const loadNet = async () => {
    const modelDlTitle = document.getElementById("modelDlTitle");

    const progress = (loaded, total) => {
        // Guard the division: a NaN or Infinity value is rejected by <progress>.value.
        const percent = total > 0 ? (loaded / total) * 100 : 0;
        document.getElementById("modelDlProgressBar").value = percent;
        document.getElementById("modelDlProgressValue").innerHTML = Math.trunc(percent) + "%";
    };

    try {
        const device = await getDevice();
        // safetensorParts is still undefined at this point and is passed through unchanged.
        // NOTE(review): presumably the decompressor needs no weights; confirm in net.js.
        f16decomp = await f16tof32().setup(device, safetensorParts);
        safetensorParts = await getAndDecompressF16Safetensors(device, progress);

        modelDlTitle.innerHTML = "Compiling model";

        nets = await timer(async () => {
            const [textModelNet, diffusorNet, decoderNet] = await Promise.all([
                textModel().setup(device, safetensorParts),
                diffusor().setup(device, safetensorParts),
                decoder().setup(device, safetensorParts)
            ]);
            return { textModel: textModelNet, diffusor: diffusorNet, decoder: decoderNet };
        }, "(compilation)");

        progress(1, 1);

        modelDlTitle.innerHTML = "Model ready";
        // Leave the completed bar visible briefly before switching to the step progress bar.
        setTimeout(() => {
            document.getElementById("modelDlProgressBar").style.display = "none";
            document.getElementById("modelDlProgressValue").style.display = "none";
            document.getElementById("divStepProgress").style.display = "flex";
        }, 1000);
        document.getElementById("btnRunNet").disabled = false;
    } catch (error) {
        // Tell the user instead of leaving the last status text up forever, and keep
        // rejecting so the failure is still reported to the console.
        modelDlTitle.innerHTML = "Failed to load model";
        throw error;
    }
}
 | ||
|  | 
 | ||
/**
 * Generates one image: encodes the prompt and the empty prompt with the text
 * model, runs the diffusor once per timestep from the last timestep down to
 * the first, then decodes the final latent.
 *
 * Being an async function, any failure rejects the returned promise. The
 * previous async Promise executor swallowed errors and left the promise
 * pending forever.
 *
 * @param {string} prompt text to condition on.
 * @param {number|string} steps number of diffusion steps (slider values arrive as strings).
 * @param {number|string} guidance guidance value handed to the diffusor.
 * @param {Function|null} showStep optional callback given the decoded image after each step.
 * @returns {Promise<*>} the output of the decoder network for the final latent.
 */
async function runStableDiffusion(prompt, steps, guidance, showStep) {
    const stepCount = Number(steps);
    const guidanceScale = Number(guidance);

    const context = await timer(() => nets["textModel"](clipTokenizer.encodeForCLIP(prompt)));
    const unconditional_context = await timer(() => nets["textModel"](clipTokenizer.encodeForCLIP("")));

    // Evenly spaced, possibly fractional timesteps starting at 1 and staying below 1000.
    const timesteps = [];
    for (let i = 1; i < 1000; i += (1000/stepCount)) {
        timesteps.push(i);
    }

    console.log("Timesteps: " + timesteps);

    const alphasCumprod = getWeight(safetensorParts,"alphas_cumprod");
    const alphas = timesteps.map((t) => alphasCumprod[Math.floor(t)]);
    // alphasPrev[i] is the alpha of the previous timestep, with 1.0 ahead of the first.
    const alphasPrev = [1.0, ...alphas.slice(0, -1)];

    const inpSize = 4*64*64;
    let latent = new Float32Array(inpSize);

    // Box-Muller transform; 1 - Math.random() lies in (0, 1], so log(0) cannot occur.
    for (let i = 0; i < inpSize; i++) {
        latent[i] = Math.sqrt(-2.0 * Math.log(1 - Math.random())) * Math.cos(2.0 * Math.PI * Math.random());
    }

    for (let i = timesteps.length - 1; i >= 0; i--) {
        const timestep = new Float32Array([timesteps[i]]);
        const start = performance.now();
        latent = await nets["diffusor"](unconditional_context, context, latent, timestep, new Float32Array([alphas[i]]), new Float32Array([alphasPrev[i]]), new Float32Array([guidanceScale]));
        document.getElementById("divStepTime").style.display = "block";
        document.getElementById("stepTimeValue").innerText = `${(performance.now() - start).toFixed(1)} ms / step`;

        if (showStep != null) {
            showStep(await nets["decoder"](latent));
        }

        // i counts down, so (stepCount - i) is the number of completed steps.
        document.getElementById("progressBar").value = ((stepCount - i) / stepCount) * 100;
        document.getElementById("progressFraction").innerHTML = (stepCount - i) + "/" + stepCount;
    }

    return timer(() => nets["decoder"](latent));
}
 | ||
|  | 
 | ||
/**
 * Paints a 512x512 image onto the canvas.
 *
 * `image` is read as consecutive (r, g, b) triples, one per pixel; every pixel
 * is written fully opaque.
 *
 * @param {ArrayLike<number>} image channel values, three per pixel.
 */
function renderImage(image) {
    const side = 512;
    const pixelCount = side * side;
    const rgba = new Uint8ClampedArray(pixelCount * 4);

    for (let px = 0; px < pixelCount; px++) {
        const src = px * 3;
        const dst = px * 4;
        rgba[dst] = image[src];
        rgba[dst + 1] = image[src + 1];
        rgba[dst + 2] = image[src + 2];
        rgba[dst + 3] = 255;
    }

    ctx.putImageData(new ImageData(rgba, side, side), 0, 0);
}
 | ||
|  | 
 | ||
/**
 * Run-button handler: disables the button, clears the canvas, runs the model
 * with the current form values and draws the result. The button and the
 * status title are restored whether the run succeeds or fails.
 */
const handleRunNetAndRenderResult = () => {
    const runButton = document.getElementById("btnRunNet");
    const statusTitle = document.getElementById("modelDlTitle");
    const canvas = document.getElementById("canvas");

    runButton.disabled = true;
    ctx.clearRect(0, 0, canvas.width, canvas.height);

    const prevTitleValue = statusTitle.innerHTML;
    statusTitle.innerHTML = "Running model";

    runStableDiffusion(
        document.getElementById("promptText").value,
        document.getElementById("stepRange").value,
        document.getElementById("guidanceRange").value,
        // No per-step callback: intermediate latents are not decoded or shown.
        null
    ).then((image) => {
        renderImage(image);
    }).catch((error) => {
        // Report the failure instead of leaving an unhandled rejection.
        console.error("Image generation failed:", error);
    }).finally(() => {
        statusTitle.innerHTML = prevTitleValue;
        runButton.disabled = false;
    });
};
 | ||
|  | 
 | ||
// Clicking Run starts a generation.
document.getElementById("btnRunNet").addEventListener("click", handleRunNetAndRenderResult, false);

// Submitting the form does the same, but only while the Run button is enabled.
document.getElementById("promptForm").addEventListener("submit", (event) => {
    event.preventDefault();
    if (!document.getElementById("btnRunNet").disabled) {
        handleRunNetAndRenderResult();
    }
});

// Keeps the number shown next to a slider in sync with the slider's value.
const mirrorSliderValue = (sliderId, readoutId) => {
    const slider = document.getElementById(sliderId);
    const readout = document.getElementById(readoutId);

    slider.addEventListener('input', () => {
        readout.textContent = slider.value;
    });
};

mirrorSliderValue('stepRange', 'stepValue');
mirrorSliderValue('guidanceRange', 'guidanceValue');

// Start loading the model immediately; the Run button is enabled when it finishes.
loadNet();
 | ||
|  |     });
 | ||
|  | </script>
 | ||
|  | </body>
 | ||
|  | </html>
 |