from tinygrad import Tensor, dtypes
from tinygrad.helpers import fetch
from tinygrad.nn import Linear, LayerNorm, Embedding, Conv2d
from typing import List, Optional, Union, Tuple, Dict
from abc import ABC, abstractmethod
from functools import lru_cache
from PIL import Image
import numpy as np
import re, gzip

@lru_cache()
def default_bpe():
  # Clip tokenizer, taken from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py (MIT license)
  return fetch("https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz", "bpe_simple_vocab_16e6.txt.gz")

class Tokenizer:
  """
  Namespace for CLIP Text Tokenizer components.
  """
  @staticmethod
  def get_pairs(word):
    """
    Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    return set(zip(word, word[1:]))

  @staticmethod
  def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text

  @staticmethod
  def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
      if b not in bs:
        bs.append(b)
        cs.append(2**8+n)
        n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))

  class ClipTokenizer:
    def __init__(self):
      self.byte_encoder = Tokenizer.bytes_to_unicode()
      merges = gzip.open(default_bpe()).read().decode("utf-8").split('\n')
      merges = merges[1:49152-256-2+1]
      merges = [tuple(merge.split()) for merge in merges]
      vocab = list(Tokenizer.bytes_to_unicode().values())
      vocab = vocab + [v+'</w>' for v in vocab]
      for merge in merges:
        vocab.append(''.join(merge))
      vocab.extend(['<|startoftext|>', '<|endoftext|>'])
      self.encoder = dict(zip(vocab, range(len(vocab))))
      self.bpe_ranks = dict(zip(merges, range(len(merges))))
      self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
      self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[^\s]+""", re.IGNORECASE)

    def bpe(self, token):
      if token in self.cache:
        return self.cache[token]
      # the last symbol of a word carries the end-of-word marker </w>
      word = tuple(token[:-1]) + ( token[-1] + '</w>',)
      pairs = Tokenizer.get_pairs(word)

      if not pairs:
        return token+'</w>'

      while True:
        bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
        if bigram not in self.bpe_ranks:
          break
        first, second = bigram
        new_word = []
        i = 0
        while i < len(word):
          try:
            j = word.index(first, i)
            new_word.extend(word[i:j])
            i = j
          except Exception:
            new_word.extend(word[i:])
            break

          if word[i] == first and i < len(word)-1 and word[i+1] == second:
            new_word.append(first+second)
            i += 2
          else:
            new_word.append(word[i])
            i += 1
        new_word = tuple(new_word)
        word = new_word
        if len(word) == 1:
          break
        pairs = Tokenizer.get_pairs(word)
      word = ' '.join(word)
      self.cache[token] = word
      return word

    def encode(self, text:str, pad_with_zeros:bool=False) -> List[int]:
      bpe_tokens: List[int] = []
      text = Tokenizer.whitespace_clean(text.strip()).lower()
      for token in re.findall(self.pat, text):
        token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
        bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
      # Truncation, keeping two slots for start and end tokens.
      if len(bpe_tokens) > 75:
        bpe_tokens = bpe_tokens[:75]
      return [49406] + bpe_tokens + [49407] + ([0] if pad_with_zeros else [49407]) * (77 - len(bpe_tokens) - 2)
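
# Usage sketch (added for illustration, not part of the upstream module): every
# encoding is exactly 77 ids, bracketed by <|startoftext|> (49406) and
# <|endoftext|> (49407). The first call fetches the BPE vocab over the network.
def _demo_tokenizer() -> List[int]:
  tokenizer = Tokenizer.ClipTokenizer()
  ids = tokenizer.encode("a photo of a cat")
  assert len(ids) == 77 and ids[0] == 49406 and ids[-1] == 49407
  return ids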

class Embedder(ABC):
  input_key: str
  @abstractmethod
  def __call__(self, x:Union[str,List[str],Tensor]) -> Union[Tensor,Tuple[Tensor,...]]:
    pass

class Closed:
  """
  Namespace for OpenAI CLIP model components.
  """
  class ClipMlp:
    def __init__(self):
      self.fc1 = Linear(768, 3072)
      self.fc2 = Linear(3072, 768)

    def __call__(self, h:Tensor) -> Tensor:
      h = self.fc1(h)
      h = h.quick_gelu()
      h = self.fc2(h)
      return h

  class ClipAttention:
    def __init__(self):
      self.embed_dim = 768
      self.num_heads = 12
      self.head_dim = self.embed_dim // self.num_heads
      self.k_proj = Linear(self.embed_dim, self.embed_dim)
      self.v_proj = Linear(self.embed_dim, self.embed_dim)
      self.q_proj = Linear(self.embed_dim, self.embed_dim)
      self.out_proj = Linear(self.embed_dim, self.embed_dim)

    def __call__(self, hidden_states:Tensor, causal_attention_mask:Tensor) -> Tensor:
      bsz, tgt_len, embed_dim = hidden_states.shape
      q,k,v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
      q,k,v = [x.reshape(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) for x in (q,k,v)]
      attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=causal_attention_mask)
      return self.out_proj(attn_output.transpose(1, 2).reshape(bsz, tgt_len, embed_dim))

  class ClipEncoderLayer:
    def __init__(self):
      self.self_attn   = Closed.ClipAttention()
      self.layer_norm1 = LayerNorm(768)
      self.mlp         = Closed.ClipMlp()
      self.layer_norm2 = LayerNorm(768)

    def __call__(self, hidden_states:Tensor, causal_attention_mask:Tensor) -> Tensor:
      residual = hidden_states
      hidden_states = self.layer_norm1(hidden_states)
      hidden_states = self.self_attn(hidden_states, causal_attention_mask)
      hidden_states = residual + hidden_states

      residual = hidden_states
      hidden_states = self.layer_norm2(hidden_states)
      hidden_states = self.mlp(hidden_states)
      hidden_states = residual + hidden_states

      return hidden_states

  class ClipTextEmbeddings:
    def __init__(self):
      self.token_embedding    = Embedding(49408, 768)
      self.position_embedding = Embedding(77, 768)

    def __call__(self, input_ids:Tensor, position_ids:Tensor) -> Tensor:
      return self.token_embedding(input_ids) + self.position_embedding(position_ids)

  class ClipEncoder:
    def __init__(self, layer_count:int=12):
      self.layers = [Closed.ClipEncoderLayer() for _ in range(layer_count)]

    def __call__(self, x:Tensor, causal_attention_mask:Tensor, ret_layer_idx:Optional[int]=None) -> Tensor:
      # the indexing of layers is NOT off by 1, the original code considers the "input" as the first hidden state
      layers = self.layers if ret_layer_idx is None else self.layers[:ret_layer_idx]
      for l in layers:
        x = l(x, causal_attention_mask)
      return x

  class ClipTextTransformer:
    def __init__(self, ret_layer_idx:Optional[int]=None):
      self.embeddings       = Closed.ClipTextEmbeddings()
      self.encoder          = Closed.ClipEncoder()
      self.final_layer_norm = LayerNorm(768)
      self.ret_layer_idx    = ret_layer_idx

    def __call__(self, input_ids:Tensor) -> Tensor:
      x = self.embeddings(input_ids, Tensor.arange(input_ids.shape[1]).reshape(1, -1))
      x = self.encoder(x, Tensor.full((1, 1, 77, 77), float("-inf")).triu(1), self.ret_layer_idx)
      return self.final_layer_norm(x) if (self.ret_layer_idx is None) else x

  class ClipTextModel:
    def __init__(self, ret_layer_idx:Optional[int]):
      self.text_model = Closed.ClipTextTransformer(ret_layer_idx=ret_layer_idx)

# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L331
class FrozenClosedClipEmbedder(Embedder):
  def __init__(self, ret_layer_idx:Optional[int]=None):
    self.tokenizer   = Tokenizer.ClipTokenizer()
    self.transformer = Closed.ClipTextModel(ret_layer_idx)
    self.input_key   = "txt"

  def __call__(self, texts:Union[str,List[str],Tensor]) -> Union[Tensor,Tuple[Tensor,...]]:
    if isinstance(texts, str): texts = [texts]
    assert isinstance(texts, (list,tuple)), f"expected list of strings, got {type(texts).__name__}"
    tokens = Tensor.cat(*[Tensor(self.tokenizer.encode(text)) for text in texts], dim=0)
    return self.transformer.text_model(tokens.reshape(len(texts),-1))
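
# Usage sketch (illustrative only; a freshly constructed model has random
# weights until a CLIP checkpoint is loaded into `transformer`): a batch of
# prompts maps to per-token hidden states of shape (batch, 77, 768).
def _demo_closed_clip() -> Tensor:
  embedder = FrozenClosedClipEmbedder()
  out = embedder(["a photo of a cat", "a photo of a dog"])
  assert out.shape == (2, 77, 768)
  return out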

class Open:
  """
  Namespace for OpenCLIP model components.
  """
  class MultiheadAttention:
    def __init__(self, dims:int, n_heads:int):
      self.dims    = dims
      self.n_heads = n_heads
      self.d_head  = self.dims // self.n_heads

      self.in_proj_bias   = Tensor.empty(3*dims)
      self.in_proj_weight = Tensor.empty(3*dims, dims)
      self.out_proj = Linear(dims, dims)

    def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None) -> Tensor:
      T,B,C = x.shape

      # fused qkv projection, then split into per-head (B, n_heads, T, d_head) tensors
      proj = x.linear(self.in_proj_weight.T, self.in_proj_bias)
      proj = proj.unflatten(-1, (3,C)).unsqueeze(0).transpose(0, -2)
      q,k,v = [y.reshape(T, B*self.n_heads, self.d_head).transpose(0, 1).reshape(B, self.n_heads, T, self.d_head) for y in proj.chunk(3)]

      attn_output = Tensor.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
      attn_output = attn_output.permute(2, 0, 1, 3).reshape(T*B, C)

      attn_output = self.out_proj(attn_output)
      attn_output = attn_output.reshape(T, B, C)

      return attn_output

  class Mlp:
    def __init__(self, dims, hidden_dims):
      self.c_fc   = Linear(dims, hidden_dims)
      self.c_proj = Linear(hidden_dims, dims)

    def __call__(self, x:Tensor) -> Tensor:
      return x.sequential([self.c_fc, Tensor.gelu, self.c_proj])

  # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L210
  class ResidualAttentionBlock:
    def __init__(self, dims:int, n_heads:int, mlp_ratio:float):
      self.ln_1 = LayerNorm(dims)
      self.attn = Open.MultiheadAttention(dims, n_heads)

      self.ln_2 = LayerNorm(dims)
      self.mlp  = Open.Mlp(dims, int(dims * mlp_ratio))

    def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None, transpose:bool=False) -> Tensor:
      q_x = self.ln_1(x)
      attn_out = self.attn(q_x.transpose(0, 1) if transpose else q_x, attn_mask=attn_mask)
      attn_out = attn_out.transpose(0, 1) if transpose else attn_out
      x = x + attn_out
      x = x + self.mlp(self.ln_2(x))
      return x

  # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L317
  class ClipTransformer:
    def __init__(self, dims:int, layers:int, n_heads:int, mlp_ratio:float=4.0):
      self.resblocks = [
        Open.ResidualAttentionBlock(dims, n_heads, mlp_ratio) for _ in range(layers)
      ]

    def __call__(self, x:Tensor, attn_mask:Optional[Tensor]=None) -> Tensor:
      for r in self.resblocks:
        x = r(x, attn_mask=attn_mask, transpose=True)
      return x

  # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/model.py#L220
  # https://github.com/mlfoundations/open_clip/blob/58e4e39aaabc6040839b0d2a7e8bf20979e4558a/src/open_clip/transformer.py#L661
  class ClipTextTransformer:
    def __init__(self, width:int, n_heads:int, layers:int, vocab_size:int=49408, ctx_length:int=77):
      self.token_embedding      = Embedding(vocab_size, width)
      self.positional_embedding = Tensor.empty(ctx_length, width)
      self.transformer          = Open.ClipTransformer(width, layers, n_heads)
      self.ln_final             = LayerNorm(width)
      self.text_projection      = Tensor.empty(width, width)
      self.attn_mask            = Tensor.full((77, 77), float("-inf")).triu(1).realize()

    def __call__(self, text:Tensor) -> Tensor:
      seq_len = text.shape[1]

      x = self.token_embedding(text)
      x = x + self.positional_embedding[:seq_len]
      x = self.transformer(x, attn_mask=self.attn_mask)
      x = self.ln_final(x)

      # pool at the highest token id, i.e. the <|endoftext|> position
      pooled = x[:, text.argmax(axis=-1)] @ self.text_projection
      return pooled

  class ClipVisionTransformer:
    def __init__(self, width:int, layers:int, d_head:int, image_size:int, patch_size:int):
      grid_size = image_size // patch_size
      n_heads = width // d_head
      assert n_heads * d_head == width

      self.conv1 = Conv2d(3, width, kernel_size=patch_size, stride=patch_size, bias=False)

      self.class_embedding      = Tensor.empty(width)
      self.positional_embedding = Tensor.empty(grid_size * grid_size + 1, width)
      self.transformer = Open.ClipTransformer(width, layers, n_heads)
      self.ln_pre  = LayerNorm(width)
      self.ln_post = LayerNorm(width)
      self.proj = Tensor.empty(width, 1024)

    def __call__(self, x:Tensor) -> Tensor:
      x = self.conv1(x)
      x = x.reshape(x.shape[0], x.shape[1], -1).permute(0, 2, 1)
      x = self.class_embedding.reshape(1, 1, -1).expand(x.shape[0], 1, -1).cat(x, dim=1)
      x = x + self.positional_embedding

      x = self.ln_pre(x)
      x = self.transformer(x)
      x = self.ln_post(x)

      pooled = x[:, 0] @ self.proj
      return pooled
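
# Shape-convention sketch (added note, not upstream code): Open.MultiheadAttention
# takes sequence-first (T, B, C) input, mirroring torch.nn.MultiheadAttention's
# default layout; batch-first callers go through ResidualAttentionBlock with
# transpose=True, which transposes around the attention call.
def _demo_open_attention() -> Tensor:
  attn = Open.MultiheadAttention(dims=64, n_heads=4)  # projection weights are uninitialized placeholders
  x = Tensor.randn(77, 2, 64)                         # (seq, batch, channels)
  return attn(x)                                      # same (77, 2, 64) shape out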

# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L396
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/encoders/modules.py#L498
class FrozenOpenClipEmbedder(Embedder):
  def __init__(self, dims:int, n_heads:int, layers:int, return_pooled:bool, ln_penultimate:bool=False):
    self.tokenizer = Tokenizer.ClipTokenizer()
    self.model = Open.ClipTextTransformer(dims, n_heads, layers)
    self.return_pooled = return_pooled
    self.input_key = "txt"
    self.ln_penultimate = ln_penultimate

  def tokenize(self, text:str, device:Optional[str]=None) -> Tensor:
    return Tensor(self.tokenizer.encode(text, pad_with_zeros=True), dtype=dtypes.int64, device=device).reshape(1,-1)

  def text_transformer_forward(self, x:Tensor, attn_mask:Optional[Tensor]=None):
    # keep the hidden state before the last block as the penultimate output
    for r in self.model.transformer.resblocks:
      x, penultimate = r(x, attn_mask=attn_mask), x
    return x.permute(1, 0, 2), penultimate.permute(1, 0, 2)

  def embed_tokens(self, tokens:Tensor) -> Union[Tensor,Tuple[Tensor,...]]:
    x = self.model.token_embedding(tokens).add(self.model.positional_embedding).permute(1,0,2)
    x, penultimate = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)

    if self.ln_penultimate:
      penultimate = self.model.ln_final(penultimate)

    if self.return_pooled:
      x = self.model.ln_final(x)
      index = tokens.argmax(axis=-1).reshape(-1,1,1).expand(x.shape[0],1,x.shape[-1])
      pooled = x.gather(1, index).squeeze(1) @ self.model.text_projection
      return penultimate, pooled
    else:
      return penultimate

  def __call__(self, texts:Union[str,List[str],Tensor]) -> Union[Tensor,Tuple[Tensor,...]]:
    if isinstance(texts, str): texts = [texts]
    assert isinstance(texts, (list,tuple)), f"expected list of strings, got {type(texts).__name__}"
    tokens = Tensor.cat(*[self.tokenize(text) for text in texts], dim=0)
    return self.embed_tokens(tokens)

clip_configs: Dict = {
  "ViT-H-14": {
    "dims": 1024,
    "vision_cfg": {
      "width": 1280,
      "layers": 32,
      "d_head": 80,
      "image_size": 224,
      "patch_size": 14,
    },
    "text_cfg": {
      "width": 1024,
      "n_heads": 16,
      "layers": 24,
      "ctx_length": 77,
      "vocab_size": 49408,
    },
    "return_pooled": False,
    "ln_penultimate": True,
  }
}
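
# Usage sketch (illustrative; weights are random until a checkpoint is loaded):
# with return_pooled=False the embedder returns the penultimate hidden states,
# here built from the ViT-H-14 text settings above.
def _demo_open_clip_embedder() -> Tensor:
  cfg = clip_configs["ViT-H-14"]
  embedder = FrozenOpenClipEmbedder(dims=cfg["text_cfg"]["width"], n_heads=cfg["text_cfg"]["n_heads"],
                                    layers=cfg["text_cfg"]["layers"], return_pooled=cfg["return_pooled"],
                                    ln_penultimate=cfg["ln_penultimate"])
  out = embedder("a photo of a cat")
  assert out.shape == (1, 77, 1024)
  return out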
"text_cfg": { "width": 1024, "n_heads": 16, "layers": 24, "ctx_length": 77, "vocab_size": 49408, }, "return_pooled": False, "ln_penultimate": True, } } class OpenClipEncoder: def __init__(self, dims:int, text_cfg:Dict, vision_cfg:Dict, **_): self.visual = Open.ClipVisionTransformer(**vision_cfg) text = Open.ClipTextTransformer(**text_cfg) self.transformer = text.transformer self.token_embedding = text.token_embedding self.positional_embedding = text.positional_embedding self.ln_final = text.ln_final self.text_projection = text.text_projection self.attn_mask = Tensor.full((77, 77), float("-inf")).triu(1).realize() self.mean = Tensor([0.48145466, 0.45782750, 0.40821073]).reshape(-1, 1, 1) self.std = Tensor([0.26862954, 0.26130258, 0.27577711]).reshape(-1, 1, 1) # TODO: # Should be doable in pure tinygrad, would just require some work and verification. # This is very desirable since it would allow for full generation->evaluation in a single JIT call. def prepare_image(self, image:Image.Image) -> Tensor: SIZE = 224 w, h = image.size scale = min(SIZE / h, SIZE / w) image = image.resize((max(int(w*scale),SIZE),max(int(h*scale),SIZE)), Image.Resampling.BICUBIC) w, h = image.size if w > SIZE: left = (w - SIZE) // 2 image = image.crop((left, left+SIZE, 0, SIZE)) elif h > SIZE: top = (h - SIZE) // 2 image = image.crop((0, SIZE, top, top+SIZE)) x = Tensor(np.array(image.convert('RGB')), device=self.std.device) x = x.permute(2, 0, 1).cast(dtypes.float32) / 255.0 return (x - self.mean) / self.std def encode_tokens(self, tokens:Tensor) -> Tensor: x = self.token_embedding(tokens) x = x + self.positional_embedding x = self.transformer(x, attn_mask=self.attn_mask) x = self.ln_final(x) x = x[:, tokens.argmax(axis=-1)] x = x @ self.text_projection return x def get_clip_score(self, tokens:Tensor, image:Tensor) -> Tensor: image_features: Tensor = self.visual(image) image_features /= image_features.square().sum(-1, keepdim=True).sqrt() # Frobenius Norm text_features = self.encode_tokens(tokens) text_features /= text_features.square().sum(-1, keepdim=True).sqrt() # Frobenius Norm return (image_features * text_features).sum(axis=-1)