import os
from extra.export_model import compile_net, jit_model, dtype_to_js_type
from extra.f16_decompress import u32_to_f16
from examples.stable_diffusion import StableDiffusion
from tinygrad.nn.state import get_state_dict, safe_save, safe_load_metadata, torch_load, load_state_dict
from tinygrad.tensor import Tensor
from tinygrad import Device, dtypes
from tinygrad.helpers import fetch
from typing import NamedTuple, Any, List
import requests
import argparse
import numpy as np
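
# Rewrite a safetensors file so everything before first_text_model_offset is stored as fp16;
# the bytes after it (the text model weights) stay fp32. The metadata header is copied through
# unchanged, so data offsets still reflect the original fp32 layout until split_safetensor halves them.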
def convert_f32_to_f16(input_file, output_file):
  with open(input_file, 'rb') as f:
    metadata_length_bytes = f.read(8)
    metadata_length = int.from_bytes(metadata_length_bytes, byteorder='little', signed=False)
    metadata_json_bytes = f.read(metadata_length)
    float32_values = np.fromfile(f, dtype=np.float32)

  first_text_model_offset = 3772703308
  num_elements = int((first_text_model_offset)/4)
  front_float16_values = float32_values[:num_elements].astype(np.float16)
  rest_float32_values = float32_values[num_elements:]

  with open(output_file, 'wb') as f:
    f.write(metadata_length_bytes)
    f.write(metadata_json_bytes)
    front_float16_values.tofile(f)
    rest_float32_values.tofile(f)
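
# Split the converted file into multiple parts (cut at tensor boundaries once chunk_size bytes
# have accumulated) plus a separate file for the fp32 text model. Offsets of tensors before the
# text model are halved first, since those bytes shrank from fp32 to fp16. Returns the end
# offset of each part.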
def split_safetensor(fn):
  _, data_start, metadata = safe_load_metadata(fn)
  text_model_offset = 3772703308
  chunk_size = 536870912

  for k in metadata:
    # safetensor is in fp16, except for the text model
    if (metadata[k]["data_offsets"][0] < text_model_offset):
      metadata[k]["data_offsets"][0] = int(metadata[k]["data_offsets"][0]/2)
      metadata[k]["data_offsets"][1] = int(metadata[k]["data_offsets"][1]/2)

  last_offset = 0
  part_end_offsets = []

  for k in metadata:
    offset = metadata[k]['data_offsets'][0]
    if offset == text_model_offset:
      break
    part_offset = offset - last_offset
    if (part_offset >= chunk_size):
      part_end_offsets.append(data_start + offset)
      last_offset = offset

  text_model_start = int(text_model_offset/2)
  net_bytes = bytes(open(fn, 'rb').read())
  part_end_offsets.append(text_model_start + data_start)
  cur_pos = 0

  for i, end_pos in enumerate(part_end_offsets):
    with open(os.path.join(os.path.dirname(__file__), f'./net_part{i}.safetensors'), "wb+") as f:
      f.write(net_bytes[cur_pos:end_pos])
    cur_pos = end_pos

  with open(os.path.join(os.path.dirname(__file__), './net_textmodel.safetensors'), "wb+") as f:
    f.write(net_bytes[text_model_start + data_start:])

  return part_end_offsets
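
# Download a JS dependency, rewriting the remote BPE vocab URL to the local copy fetched below.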
def fetch_dep(file, url):
  with open(file, "w", encoding="utf-8") as f:
    f.write(requests.get(url).text.replace("https://huggingface.co/wpmed/tinygrad-sd-f16/raw/main/bpe_simple_vocab_16e6.mjs", "./bpe_simple_vocab_16e6.mjs"))
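
# Running this script produces net.js (shared WebGPU helpers plus one compiled module per step)
# and, unless --remoteweights is passed, the split safetensors weight parts next to it.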
if __name__ == "__main__":
  fetch_dep(os.path.join(os.path.dirname(__file__), "clip_tokenizer.js"), "https://huggingface.co/wpmed/tinygrad-sd-f16/raw/main/clip_tokenizer.js")
  fetch_dep(os.path.join(os.path.dirname(__file__), "bpe_simple_vocab_16e6.mjs"), "https://huggingface.co/wpmed/tinygrad-sd-f16/raw/main/bpe_simple_vocab_16e6.mjs")

  parser = argparse.ArgumentParser(description='Compile Stable Diffusion for WebGPU', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument('--remoteweights', action='store_true', help="Use the prebuilt f16 safetensors from Hugging Face instead of generating and splitting them locally")
  args = parser.parse_args()
  Device.DEFAULT = "WEBGPU"

  model = StableDiffusion()

  # load in weights
  load_state_dict(model, torch_load(fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt'))['state_dict'], strict=False)
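
  # Each Step below is JIT-traced with dummy inputs and exported as its own JS module: the CLIP
  # text encoder, the diffusion UNet, the VAE decoder, and a kernel (u32_to_f16 from
  # extra.f16_decompress) used to decompress the u32-packed f16 weight parts on the GPU.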
  class Step(NamedTuple):
    name: str = ""
    input: List[Tensor] = []
    forward: Any = None

  sub_steps = [
    Step(name="textModel", input=[Tensor.randn(1, 77)], forward=model.cond_stage_model.transformer.text_model),
    Step(name="diffusor", input=[Tensor.randn(1, 77, 768), Tensor.randn(1, 77, 768), Tensor.randn(1, 4, 64, 64), Tensor.rand(1), Tensor.randn(1), Tensor.randn(1), Tensor.randn(1)], forward=model),
    Step(name="decoder", input=[Tensor.randn(1, 4, 64, 64)], forward=model.decode),
    Step(name="f16tof32", input=[Tensor.randn(2097120, dtype=dtypes.uint32)], forward=u32_to_f16)
  ]
  prg = ""
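
  # Fix up generated WGSL so each kernel compiles as a standalone shader module: rename the
  # kernel entry point to "main", replace the INFINITY uniform at binding 0 with an inf()
  # helper function, and shift the remaining bindings down by one to close the gap.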
  def fixup_code(code, key):
    code = code.replace(key, 'main')\
      .replace("var<uniform> INFINITY : f32;\n", "fn inf(a: f32) -> f32 { return a/0.0; }\n")\
      .replace("@group(0) @binding(0)", "")\
      .replace("INFINITY", "inf(1.0)")

    for i in range(1, 9): code = code.replace(f"binding({i})", f"binding({i-1})")
    return code
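
  # JIT-trace one step to capture its kernels and buffers, then render everything into a JS
  # module: the WGSL kernel sources, buffer creation (empty buffers for activations, weight
  # buffers backed by the safetensors parts), input staging, and the compute-pass calls.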
  def compile_step(model, step: Step):
    run, special_names = jit_model(step, *step.input)
    functions, statements, bufs, _ = compile_net(run, special_names)
    state = get_state_dict(model)
    weights = {id(x.uop.base.realized): name for name, x in state.items()}
    kernel_code = '\n\n'.join([f"const {key} = `{fixup_code(code, key)}`;" for key, code in functions.items()])
    kernel_names = ', '.join([name for (name, _, _, _) in statements])
    input_names = [name for _, name in special_names.items() if "input" in name]
    output_names = [name for _, name in special_names.items() if "output" in name]
    input_buf_types = [dtype_to_js_type(bufs[inp_name][1]) for inp_name in input_names]
    output_buf_types = [dtype_to_js_type(bufs[out_name][1]) for out_name in output_names]
    kernel_calls = '\n'.join([f"addComputePass(device, commandEncoder, pipelines[{i}], [{', '.join(args)}], {global_size});" for i, (_name, args, global_size, _local_size) in enumerate(statements)])
    exported_bufs = '\n'.join([f"const {name} = " + (f"createEmptyBuf(device, {size})" if _key not in weights else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weights[_key]}'], '{weights[_key]}'))") + ";" for name, (size, dtype, _key) in bufs.items()])
    gpu_write_bufs = '\n'.join([f"const gpuWriteBuffer{i} = device.createBuffer({{size:input{i}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});" for i, (_, value) in enumerate(special_names.items()) if "output" not in value])
    input_writer = '\n'.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\nnew {input_buf_types[i]}(gpuWriteBuffer{i}.getMappedRange()).set(" + f'data{i});' + f"\ngpuWriteBuffer{i}.unmap();\ncommandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, input{i}, 0, gpuWriteBuffer{i}.size);" for i, _ in enumerate(input_names)])
    return f"""\n  var {step.name} = function() {{

    {kernel_code}

    return {{
      "setup": async (device, safetensor) => {{
        const metadata = safetensor ? getTensorMetadata(safetensor[0]) : null;

        {exported_bufs}

        {gpu_write_bufs}

        const gpuReadBuffer = device.createBuffer({{ size: output0.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});

        const kernels = [{kernel_names}];
        const pipelines = await Promise.all(kernels.map(name => device.createComputePipelineAsync({{layout: "auto", compute: {{ module: device.createShaderModule({{ code: name }}), entryPoint: "main" }}}})));

        return async ({",".join([f'data{i}' for i, (k, v) in enumerate(special_names.items()) if v != "output0"])}) => {{
          const commandEncoder = device.createCommandEncoder();

          {input_writer}

          {kernel_calls}
          commandEncoder.copyBufferToBuffer(output0, 0, gpuReadBuffer, 0, output0.size);
          const gpuCommands = commandEncoder.finish();
          device.queue.submit([gpuCommands]);

          await gpuReadBuffer.mapAsync(GPUMapMode.READ);
          const resultBuffer = new {output_buf_types[0]}(gpuReadBuffer.size/{bufs[output_names[0]][1].itemsize});
          resultBuffer.set(new {output_buf_types[0]}(gpuReadBuffer.getMappedRange()));
          gpuReadBuffer.unmap();
          return resultBuffer;
        }}
      }}
    }}
  }}
  """
  for step in sub_steps:
    print(f'Executing step={step.name}')
    prg += compile_step(model, step)

    if step.name == "diffusor":
      if args.remoteweights:
        base_url = "https://huggingface.co/wpmed/stable-diffusion-f16-new/resolve/main"
      else:
        state = get_state_dict(model)
        safe_save(state, os.path.join(os.path.dirname(__file__), "net.safetensors"))
        convert_f32_to_f16(os.path.join(os.path.dirname(__file__), "./net.safetensors"), os.path.join(os.path.dirname(__file__), "./net_conv.safetensors"))
        split_safetensor(os.path.join(os.path.dirname(__file__), "./net_conv.safetensors"))
        os.remove(os.path.join(os.path.dirname(__file__), "net.safetensors"))
        os.remove(os.path.join(os.path.dirname(__file__), "net_conv.safetensors"))
        base_url = "."
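
  # Shared JS helpers emitted once, ahead of the per-step modules: safetensors metadata parsing,
  # tensor lookup across the split weight parts (the hardcoded partStartOffsets should line up
  # with the part boundaries split_safetensor produces), and WebGPU helpers for buffer creation,
  # weight upload, and compute passes.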
  prekernel = f"""
  window.MODEL_BASE_URL = "{base_url}";

  const getTensorMetadata = (safetensorBuffer) => {{
    const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true));
    const metadata = JSON.parse(new TextDecoder("utf8").decode(safetensorBuffer.subarray(8, 8 + metadataLength)));
    return Object.fromEntries(Object.entries(metadata).filter(([k, v]) => k !== "__metadata__").map(([k, v]) => [k, {{...v, data_offsets: v.data_offsets.map(x => 8 + metadataLength + x)}}]));
  }};

  const getTensorBuffer = (safetensorParts, tensorMetadata, key) => {{
    let selectedPart = 0;
    let counter = 0;
    let partStartOffsets = [1131408336, 2227518416, 3308987856, 4265298864];
    let correctedOffsets = tensorMetadata.data_offsets;
    let prev_offset = 0;

    for (let start of partStartOffsets) {{
      prev_offset = (counter == 0) ? 0 : partStartOffsets[counter-1];

      if (tensorMetadata.data_offsets[0] < start) {{
        selectedPart = counter;
        correctedOffsets = [correctedOffsets[0]-prev_offset, correctedOffsets[1]-prev_offset];
        break;
      }}

      counter++;
    }}

    return safetensorParts[selectedPart].subarray(...correctedOffsets);
  }}

  const getWeight = (safetensors, key) => {{
    let uint8Data = getTensorBuffer(safetensors, getTensorMetadata(safetensors[0])[key], key);
    return new Float32Array(uint8Data.buffer, uint8Data.byteOffset, uint8Data.byteLength / Float32Array.BYTES_PER_ELEMENT);
  }}

  const createEmptyBuf = (device, size) => {{
    return device.createBuffer({{size, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST }});
  }};

  const createWeightBuf = (device, size, data) => {{
    const buf = device.createBuffer({{ mappedAtCreation: true, size, usage: GPUBufferUsage.STORAGE }});
    new Uint8Array(buf.getMappedRange()).set(data);
    buf.unmap();
    return buf;
  }};

  const addComputePass = (device, commandEncoder, pipeline, bufs, workgroup) => {{
    const bindGroup = device.createBindGroup({{layout: pipeline.getBindGroupLayout(0), entries: bufs.map((buffer, index) => ({{ binding: index, resource: {{ buffer }} }}))}});
    const passEncoder = commandEncoder.beginComputePass();
    passEncoder.setPipeline(pipeline);
    passEncoder.setBindGroup(0, bindGroup);
    passEncoder.dispatchWorkgroups(...workgroup);
    passEncoder.end();
  }};"""
  with open(os.path.join(os.path.dirname(__file__), "net.js"), "w") as text_file:
    text_file.write(prekernel + prg)