"""Minimal Llama-style chat runner built on tinygrad, loading weights from a GGUF file.

Contains a greedy longest-match tokenizer, a RoPE helper, a transformer with a
per-block KV cache, a GGUF loader and an interactive prompt loop.
"""
from __future__ import annotations

import argparse
import sys

from tinygrad import Tensor, nn, UOp, TinyJit, getenv


class SimpleTokenizer:
    """Greedy longest-match tokenizer over a fixed vocabulary list.

    Spaces and newlines are replaced by placeholder characters before lookup
    and restored again on decode.
    """

    def __init__(self, vocab: list[str]):
        self.vocab: list[str] = vocab
        # length of the longest vocabulary entry; bounds the match window in encode()
        self.biggest_token: int = max(map(len, vocab))
        self.token_to_id: dict[str, int] = {tok: i for i, tok in enumerate(vocab)}
        # Byte-level BPE vocabularies spell " " as U+0120 ("Ġ") and "\n" as U+010A ("Ċ").
        # Written as escapes so the file's encoding can never corrupt them again.
        self.replace_space = "\u0120"
        self.replace_newline = "\u010a"

    def encode(self, text: str) -> list[int]:
        """Convert ``text`` to token ids by repeatedly taking the longest matching vocab entry.

        Raises:
            RuntimeError: if no vocabulary entry matches at some position.
        """
        s = text.replace(" ", self.replace_space).replace("\n", self.replace_newline)
        out: list[int] = []
        i = 0
        while i < len(s):
            j = min(i + self.biggest_token, len(s))
            # initialised so the failure path below is well defined even if the loop body never runs
            tid = None
            # shrink the candidate window from the right until it is a known token
            while i < j and (tid := self.token_to_id.get(s[i:j])) is None:
                j -= 1
            if tid is None:
                raise RuntimeError(f"token not found in {s}")
            out.append(tid)
            i = j
        return out

    def decode(self, ids: list[int]) -> str:
        """Concatenate the vocab strings for ``ids`` and restore spaces and newlines."""
        joined = ''.join(self.vocab[tid] for tid in ids)
        return joined.replace(self.replace_space, " ").replace(self.replace_newline, "\n")

    def role(self, role: str):
        """Return the token ids for a chat header announcing ``role``."""
        # llama style
        return [t for x in ["<|start_header_id|>", role, "<|end_header_id|>\n\n"] for t in self.encode(x)]


def apply_rope(x: Tensor, start_pos: int | UOp, base: float = 10000):
    """Apply rotary position embeddings to ``x`` of shape (B, H, T, Hd).

    Adjacent pairs along the last axis are rotated by ``pos * base**(-k)``, where
    ``pos`` runs over ``start_pos .. start_pos+T-1`` and ``k`` steps through
    ``[0, 1)`` in increments of ``2/Hd``.
    """
    B, H, T, Hd = x.shape
    # NOTE: this is usually in a RoPE cache, but tinygrad JIT should prune it outside the kernel
    # TODO: make it do that
    freq = base ** (-Tensor.arange(0, 1, 2/Hd, dtype='float32'))
    angles = Tensor.arange(start_pos, start_pos+T, dtype='float32')[None, None, :, None] * freq
    cos, sin = angles.cos(), angles.sin()
    x = x.reshape(B, H, T, Hd // 2, 2)  # split into pairs
    y1 = x[..., 0] * cos - x[..., 1] * sin
    y2 = x[..., 0] * sin + x[..., 1] * cos
    return Tensor.stack(y1, y2, dim=-1).reshape(B, H, T, Hd)


class TransformerBlock:
    """One pre-norm transformer layer: attention with a KV cache, then a gated feed-forward.

    ``rope_theta`` is the RoPE base handed to ``apply_rope``; it defaults to 10000 so
    existing positional callers keep their previous behaviour.
    """

    def __init__(self, dim: int, hidden_dim: int, n_heads: int, n_kv_heads: int, norm_eps: float,
                 max_context: int = 0, rope_theta: float = 10000.0):
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = dim // n_heads
        self.max_context = max_context
        self.rope_theta = rope_theta

        # --- attention projections (all linear, bias-free) ------------------
        kv_proj_out = self.head_dim * n_kv_heads  # Llama-3 uses the same dim for K/V
        self.attn_q = nn.Linear(dim, dim, bias=False)
        self.attn_k = nn.Linear(dim, kv_proj_out, bias=False)
        self.attn_v = nn.Linear(dim, kv_proj_out, bias=False)
        self.attn_output = nn.Linear(dim, dim, bias=False)

        # --- RMSNorms --------------------------------------------------------
        self.attn_norm = nn.RMSNorm(dim, norm_eps)
        self.ffn_norm = nn.RMSNorm(dim, norm_eps)

        # --- feed-forward ----------------------------------------------------
        self.ffn_gate = nn.Linear(dim, hidden_dim, bias=False)
        self.ffn_up = nn.Linear(dim, hidden_dim, bias=False)
        self.ffn_down = nn.Linear(hidden_dim, dim, bias=False)

    def _attention(self, x: Tensor, start_pos: int | UOp) -> Tensor:
        """Return ``x`` plus the attention output, writing this step's K/V into the cache."""
        x_norm = self.attn_norm(x)  # (B,T,D)
        q, k, v = self.attn_q(x_norm), self.attn_k(x_norm), self.attn_v(x_norm)

        B, T, _ = x.shape
        q = q.reshape(B, T, self.n_heads, self.head_dim).transpose(1, 2)  # (B,H,T,Hd)
        k = k.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)  # (B,KvH,T,Hd)
        v = v.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)  # (B,KvH,T,Hd)

        q = apply_rope(q, start_pos, self.rope_theta)
        k = apply_rope(k, start_pos, self.rope_theta)

        # TODO: remove these kv cache realizes
        # the cache is allocated lazily on first use, so batch size, dtype and device follow the first call
        if not hasattr(self, "cache_kv"):
            self.cache_kv = Tensor.zeros(
                2, B, self.n_kv_heads, self.max_context, self.head_dim, dtype=k.dtype, device=k.device
            ).contiguous().realize()
        self.cache_kv[:, :, :, start_pos:start_pos+T, :].assign(Tensor.stack(k, v)).realize()  # type: ignore
        k = self.cache_kv[0, :, :, 0:start_pos+T, :]
        v = self.cache_kv[1, :, :, 0:start_pos+T, :]

        # NOTE: this mask is causal_lower_right, not the causal_upper_left generated by is_causal = True
        mask = (
            Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1)
            if T > 1 else None
        )
        attn = q.scaled_dot_product_attention(k, v, attn_mask=mask, enable_gqa=True)  # (B,H,T,Hd)
        attn = attn.transpose(1, 2).reshape(B, T, -1)  # back to (B,T,D)
        attn = self.attn_output(attn)
        return x + attn

    def _feed_forward(self, h: Tensor) -> Tensor:
        """Return ``h`` plus the SiLU-gated feed-forward of its normalised value."""
        h_norm = self.ffn_norm(h)
        gated = self.ffn_gate(h_norm).silu() * self.ffn_up(h_norm)
        return h + self.ffn_down(gated)

    def __call__(self, x: Tensor, start_pos: int | UOp):
        return self._feed_forward(self._attention(x, start_pos))


class Transformer:
    """Stack of ``TransformerBlock`` layers with token embedding, final norm and output head.

    ``rope_theta`` is forwarded to every block and defaults to 10000.
    """

    def __init__(self, *, num_blocks, dim, hidden_dim, n_heads, n_kv_heads, norm_eps, vocab_size, max_context,
                 rope_theta=10000.0):
        self.blk = [
            TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, rope_theta=rope_theta)
            for _ in range(num_blocks)
        ]
        self.token_embd = nn.Embedding(vocab_size, dim)
        self.output_norm = nn.RMSNorm(dim, norm_eps)
        self.output = nn.Linear(dim, vocab_size, bias=False)
        self.max_context = max_context
        # JIT is used if T=1 and start_pos is a UOp.
        # TODO: make this not needed by including T in the JIT and making start_pos always a UOp
        self.forward_jit = TinyJit(self.forward)

    def forward(self, tokens: Tensor, start_pos: int | UOp) -> Tensor:
        """Run all blocks and return the argmax token id at the last position, shape (B, 1)."""
        x = self.token_embd(tokens)  # (B, T, D)
        for block in self.blk:
            x = block(x, start_pos)
        # TODO: add temperature
        return self.output(self.output_norm(x))[:, -1, :].softmax(-1).argmax(-1, keepdim=True)

    def __call__(self, tokens: Tensor, start_pos: int | UOp = 0) -> Tensor:
        # the jitted path is only taken for single-token steps with a symbolic start_pos
        use_jit = getenv("JIT", 1) and tokens.shape[1] == 1 and isinstance(start_pos, UOp)
        return (self.forward_jit if use_jit else self.forward)(tokens, start_pos)

    @staticmethod
    def from_gguf(gguf: Tensor, max_context: int | None = None) -> tuple[Transformer, dict]:
        """Build a model from a GGUF tensor and return it together with the GGUF metadata dict.

        ``max_context`` is clamped to the model's ``context_length``; when it is None
        the model's own ``context_length`` is used.
        """
        # TODO: remove the need for copy to default device
        kv, state_dict = nn.state.gguf_load(gguf.to(None))

        # all state items should be float16, not float32
        state_dict = {k: v.cast('float16') for k, v in state_dict.items()}

        # some models like Llama 3.2 don't have an output.weight, they just tie to the token_embd.weight
        if 'output.weight' not in state_dict:
            state_dict['output.weight'] = state_dict['token_embd.weight']

        arch = kv['general.architecture']
        model_context = kv[f'{arch}.context_length']
        max_context = min(max_context, model_context) if max_context is not None else model_context

        model = Transformer(
            num_blocks=kv[f'{arch}.block_count'],
            dim=kv[f'{arch}.embedding_length'],
            hidden_dim=kv[f'{arch}.feed_forward_length'],
            n_heads=kv[f'{arch}.attention.head_count'],
            n_kv_heads=kv[f'{arch}.attention.head_count_kv'],
            norm_eps=kv[f'{arch}.attention.layer_norm_rms_epsilon'],
            vocab_size=len(kv['tokenizer.ggml.tokens']),
            max_context=max_context,
            # use the RoPE base stored in the file; fall back to 10000 when the key is absent
            rope_theta=kv.get(f'{arch}.rope.freq_base', 10000.0),
        )
        nn.state.load_state_dict(model, state_dict, verbose=False, consume=True, realize=False)  # NOTE: rope_freqs.weight (32,) is unused
        return model, kv

    def generate(self, tokens: list[int], start_pos=0):
        """Yield greedily decoded token ids, appending each one to ``tokens`` in place.

        Generation stops once ``len(tokens)`` reaches ``max_context``; stopping on an
        end-of-sequence token is left to the caller.
        """
        v_start_pos = UOp.variable("start_pos", 1, self.max_context-1)
        # NOTE(review): the caller-supplied start_pos is overridden here, so every call feeds the
        # whole token list again from position 0. Presumably a workaround related to the reset
        # below -- confirm before honouring the argument.
        start_pos = 0
        t = Tensor([tokens[start_pos:]], dtype="int32")
        self.forward_jit.reset()  # TODO: why is this required? root cause the issue and make it not be needed
        while len(tokens) < self.max_context:
            # a symbolic start_pos is only used for single-token steps after the first position
            use_sym = getenv("SYM", 1) and start_pos != 0 and t.shape[-1] == 1
            t = self(t, v_start_pos.bind(start_pos) if use_sym else start_pos)
            next_id = int(t.item())
            tokens.append(next_id)
            start_pos = len(tokens) - 1
            yield next_id


# download URLs of the supported GGUF checkpoints, keyed by the --size choice
models = {
    "1B": "https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q6_K.gguf",
    "3B": "https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/Llama-3.2-3B-Instruct-Q6_K.gguf",
    "3B_f16": "https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/Llama-3.2-3B-Instruct-f16.gguf",
    "8B": "https://huggingface.co/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/resolve/main/Meta-Llama-3.1-8B-Instruct-Q8_0.gguf",
}

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--size", choices=list(models), default=next(iter(models)), help="Model size")
    parser.add_argument("--max_context", type=int, default=4096, help="Max Context Length")
    args = parser.parse_args()

    # load the model
    model, kv = Transformer.from_gguf(Tensor.from_url(models[args.size]), args.max_context)

    # extract some metadata
    tok = SimpleTokenizer(kv["tokenizer.ggml.tokens"])
    bos_id: int = kv['tokenizer.ggml.bos_token_id']
    eos_id: int = kv['tokenizer.ggml.eos_token_id']

    ids: list[int] = [bos_id]
    while 1:
        start_pos = len(ids) - 1
        try:
            ids += tok.role("user") + tok.encode(input('>>> ')) + [eos_id] + tok.role("assistant")
        except EOFError:
            # end of input (e.g. Ctrl-D) ends the chat
            break
        for next_id in model.generate(ids, start_pos):
            sys.stdout.write(tok.decode([next_id]) if next_id != eos_id else "\n\n")
            sys.stdout.flush()
            if next_id == eos_id:
                break