You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
183 lines
8.9 KiB
183 lines
8.9 KiB
1 day ago
|
from __future__ import annotations
|
||
|
import sys, argparse
|
||
|
from tinygrad import Tensor, nn, UOp, TinyJit, getenv
|
||
|
|
||
|
class SimpleTokenizer:
  """Greedy longest-match tokenizer over a fixed list of string tokens.

  A token's id is its index in `vocab`. Before lookup, spaces and newlines in
  the input are swapped for the placeholder characters "Ġ" and "Ċ"; decoding
  swaps them back.
  """

  def __init__(self, vocab: list[str]):
    self.vocab: list[str] = vocab
    # longest token length bounds the match window in encode(); default=0 keeps an empty vocab constructible
    self.biggest_token: int = max(map(len, vocab), default=0)
    self.token_to_id: dict[str, int] = {tok: i for i, tok in enumerate(vocab)}
    self.replace_space = "Ġ"
    self.replace_newline = "Ċ"

  def encode(self, text:str) -> list[int]:
    """Return the token ids for `text`, always taking the longest vocab entry that matches at the current position.

    Raises:
      RuntimeError: if no vocab entry matches at some position of the text.
    """
    s = text.replace(" ", self.replace_space).replace("\n", self.replace_newline)
    out: list[int] = []
    i = 0
    while i < len(s):
      j = min(i+self.biggest_token, len(s))
      # pre-bind so an empty match window (i == j) reports "not found" instead of hitting an unbound name
      tid = None
      # shrink the candidate s[i:j] from the right until it is a known token or becomes empty
      while i < j and (tid:=self.token_to_id.get(s[i:j])) is None: j -= 1
      if tid is None: raise RuntimeError(f"token not found in {s}")
      out.append(tid)
      i = j
    return out

  def decode(self, ids: list[int]) -> str:
    """Concatenate the tokens for `ids` and restore spaces and newlines from their placeholders."""
    return ''.join(self.vocab[tid] for tid in ids).replace(self.replace_space, " ").replace(self.replace_newline, "\n")

  def role(self, role:str):
    """Return the token ids of the header that opens a chat turn for `role`."""
    return [t for x in ["<|start_header_id|>", role, "<|end_header_id|>\n\n"] for t in self.encode(x)]  # llama style
|
||
|
|
||
|
def apply_rope(x:Tensor, start_pos:int|UOp, base:int=10000):
  """Apply rotary position embedding to `x` of shape (B, H, T, Hd).

  The T positions are start_pos .. start_pos+T-1. Each adjacent pair of
  features along the last axis is rotated by position * base**(-e), where e
  steps from 0 towards 1 in increments of 2/Hd. The result has the same shape
  as the input.
  """
  batch, heads, seqlen, head_dim = x.shape
  # NOTE: this is usually in a RoPE cache, but tinygrad JIT should prune it outside the kernel
  # TODO: make it do that
  exponents = Tensor.arange(0, 1, 2/head_dim, dtype='float32')
  inv_freq = base ** (-exponents)
  positions = Tensor.arange(start_pos, start_pos+seqlen, dtype='float32')
  theta = positions[None, None, :, None] * inv_freq
  cos_t, sin_t = theta.cos(), theta.sin()

  # view the last axis as (Hd/2) pairs and rotate every pair by its angle
  pairs = x.reshape(batch, heads, seqlen, head_dim // 2, 2)
  first, second = pairs[..., 0], pairs[..., 1]
  rot_first = first * cos_t - second * sin_t
  rot_second = first * sin_t + second * cos_t
  return Tensor.stack(rot_first, rot_second, dim=-1).reshape(batch, heads, seqlen, head_dim)
|
||
|
|
||
|
class TransformerBlock:
  """One decoder layer: normed attention with a residual add, then a normed gated feed-forward with a residual add.

  Keys and values are kept in `self.cache_kv`, which is allocated on the first
  attention call using that call's batch size and `max_context` positions.
  """

  def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, max_context:int=0):
    self.n_heads, self.n_kv_heads = n_heads, n_kv_heads
    self.head_dim = dim // n_heads
    self.max_context = max_context

    # attention projections, all bias-free linears
    kv_dim = self.head_dim * n_kv_heads  # Llama-3 uses the same dim for K/V
    self.attn_q = nn.Linear(dim, dim, bias=False)
    self.attn_k = nn.Linear(dim, kv_dim, bias=False)
    self.attn_v = nn.Linear(dim, kv_dim, bias=False)
    self.attn_output = nn.Linear(dim, dim, bias=False)

    # one RMSNorm in front of each sub-layer
    self.attn_norm = nn.RMSNorm(dim, norm_eps)
    self.ffn_norm = nn.RMSNorm(dim, norm_eps)

    # gated feed-forward
    self.ffn_gate = nn.Linear(dim, hidden_dim, bias=False)
    self.ffn_up = nn.Linear(dim, hidden_dim, bias=False)
    self.ffn_down = nn.Linear(hidden_dim, dim, bias=False)

  def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
    """Attend over the cached positions 0..start_pos+T-1 and return `x` plus the attention output."""
    bsz, seqlen, _ = x.shape
    end_pos = start_pos + seqlen
    normed = self.attn_norm(x)  # (B,T,D)

    def split_heads(t:Tensor, n:int) -> Tensor:
      # (B,T,n*Hd) -> (B,n,T,Hd)
      return t.reshape(bsz, seqlen, n, self.head_dim).transpose(1, 2)

    q = split_heads(self.attn_q(normed), self.n_heads)     # (B,H,T,Hd)
    k = split_heads(self.attn_k(normed), self.n_kv_heads)  # (B,KvH,T,Hd)
    v = split_heads(self.attn_v(normed), self.n_kv_heads)  # (B,KvH,T,Hd)

    # only queries and keys are rotated
    q = apply_rope(q, start_pos)
    k = apply_rope(k, start_pos)

    # TODO: remove these kv cache realizes
    if not hasattr(self, "cache_kv"):
      self.cache_kv = Tensor.zeros(2, bsz, self.n_kv_heads, self.max_context, self.head_dim, dtype=k.dtype, device=k.device).contiguous().realize()
    # write the new keys/values into their slots, then read back everything up to end_pos
    self.cache_kv[:, :, :, start_pos:end_pos, :].assign(Tensor.stack(k, v)).realize()  # type: ignore
    keys = self.cache_kv[0, :, :, 0:end_pos, :]
    values = self.cache_kv[1, :, :, 0:end_pos, :]

    # NOTE: this mask is causal_lower_right, not the causal_upper_left generated by is_causal = True
    if seqlen > 1:
      mask = Tensor.full((1, 1, seqlen, end_pos), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1)
    else:
      mask = None  # a single query position needs no mask

    ctx = q.scaled_dot_product_attention(keys, values, attn_mask=mask, enable_gqa=True)  # (B,H,T,Hd)
    merged = ctx.transpose(1, 2).reshape(bsz, seqlen, -1)  # back to (B,T,D)
    return x + self.attn_output(merged)

  def _feed_forward(self, h: Tensor) -> Tensor:
    """Return `h` plus ffn_down(silu(ffn_gate(n)) * ffn_up(n)), where n is the normed `h`."""
    normed = self.ffn_norm(h)
    gate = self.ffn_gate(normed).silu()
    up = self.ffn_up(normed)
    return h + self.ffn_down(gate * up)

  def __call__(self, x: Tensor, start_pos: int|UOp):
    """Run the attention sub-layer followed by the feed-forward sub-layer."""
    attended = self._attention(x, start_pos)
    return self._feed_forward(attended)
|
||
|
|
||
|
class Transformer:
  """Stack of TransformerBlocks with a token embedding, final RMSNorm and output projection.

  `forward` returns the argmax token id for the last position only, so
  decoding is greedy.
  """

  def __init__(self, *, num_blocks, dim, hidden_dim, n_heads, n_kv_heads, norm_eps, vocab_size, max_context):
    self.blk = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context) for _ in range(num_blocks)]
    self.token_embd = nn.Embedding(vocab_size, dim)
    self.output_norm = nn.RMSNorm(dim, norm_eps)
    self.output = nn.Linear(dim, vocab_size, bias=False)
    self.max_context = max_context
    # JIT is used if T=1 and start_pos is a UOp. TODO: make this not needed by including T in the JIT and making start_pos always a UOp
    self.forward_jit = TinyJit(self.forward)

  def forward(self, tokens:Tensor, start_pos:int|UOp) -> Tensor:
    """Run every block over `tokens` and return the most likely next token id for the last position, keeping the reduced dim."""
    h = self.token_embd(tokens)  # (B, T, D)
    for layer in self.blk: h = layer(h, start_pos)
    logits = self.output(self.output_norm(h))
    # TODO: add temperature
    return logits[:, -1, :].softmax(-1).argmax(-1, keepdim=True)

  def __call__(self, tokens:Tensor, start_pos:int|UOp=0) -> Tensor:
    """Dispatch to the jitted forward for single-token input with a symbolic start_pos (unless JIT=0), else to the plain forward."""
    use_jit = getenv("JIT", 1) and tokens.shape[1] == 1 and isinstance(start_pos, UOp)
    fwd = self.forward_jit if use_jit else self.forward
    return fwd(tokens, start_pos)

  @staticmethod
  def from_gguf(gguf:Tensor, max_context:int|None=None) -> tuple[Transformer, dict]:
    """Build a Transformer from a GGUF tensor and return it together with the GGUF key/value metadata.

    `max_context` is capped at the context length recorded in the metadata;
    when it is None the recorded context length is used.
    """
    # TODO: remove the need for copy to default device
    kv, state_dict = nn.state.gguf_load(gguf.to(None))

    # all state items should be float16, not float32
    state_dict = {name:weight.cast('float16') for name,weight in state_dict.items()}

    # some models like Llama 3.2 don't have an output.weight, they just tie to the token_embd.weight
    if 'output.weight' not in state_dict: state_dict['output.weight'] = state_dict['token_embd.weight']

    arch = kv['general.architecture']
    model_ctx = kv[f'{arch}.context_length']
    max_context = model_ctx if max_context is None else min(max_context, model_ctx)

    model = Transformer(
      num_blocks=kv[f'{arch}.block_count'],
      dim=kv[f'{arch}.embedding_length'],
      hidden_dim=kv[f'{arch}.feed_forward_length'],
      n_heads=kv[f'{arch}.attention.head_count'],
      n_kv_heads=kv[f'{arch}.attention.head_count_kv'],
      norm_eps=kv[f'{arch}.attention.layer_norm_rms_epsilon'],
      vocab_size=len(kv['tokenizer.ggml.tokens']),
      max_context=max_context,
    )
    nn.state.load_state_dict(model, state_dict, verbose=False, consume=True, realize=False)  # NOTE: rope_freqs.weight (32,) is unused
    return model, kv

  def generate(self, tokens:list[int], start_pos=0):
    """Yield next-token ids one at a time, appending each to `tokens` in place, until `tokens` reaches max_context.

    NOTE: the `start_pos` argument is overwritten with 0 below, so the whole
    token list is fed through the model on the first step of every call.
    """
    v_start_pos = UOp.variable("start_pos", 1, self.max_context-1)
    start_pos = 0
    t = Tensor([tokens[start_pos:]], dtype="int32")
    self.forward_jit.reset()  # TODO: why is this required? root cause the issue and make it not be needed
    while len(tokens) < self.max_context:
      # bind the symbolic position only for single-token steps past position 0 (and only when SYM is enabled)
      symbolic = getenv("SYM", 1) and start_pos != 0 and t.shape[-1] == 1
      pos = v_start_pos.bind(start_pos) if symbolic else start_pos
      t = self(t, pos)
      next_id = int(t.item())
      tokens.append(next_id)
      start_pos = len(tokens) - 1
      yield next_id
|
||
|
|
||
|
# Download URLs of the GGUF checkpoints, keyed by the names offered as --size choices.
# The first key is the default --size.
models = {
  "1B": "https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q6_K.gguf",
  "3B": "https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/Llama-3.2-3B-Instruct-Q6_K.gguf",
  "3B_f16": "https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF/resolve/main/Llama-3.2-3B-Instruct-f16.gguf",
  "8B": "https://huggingface.co/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/resolve/main/Meta-Llama-3.1-8B-Instruct-Q8_0.gguf",
}
|
||
|
|
||
|
if __name__ == "__main__":
  # interactive chat loop: read a user turn, stream the model's reply, repeat until EOF on stdin
  parser = argparse.ArgumentParser()
  sizes = list(models.keys())
  parser.add_argument("--size", choices=sizes, default=sizes[0], help="Model size")
  parser.add_argument("--max_context", type=int, default=4096, help="Max Context Length")
  args = parser.parse_args()

  # load the model
  model, kv = Transformer.from_gguf(Tensor.from_url(models[args.size]), args.max_context)

  # extract some metadata
  tok = SimpleTokenizer(kv["tokenizer.ggml.tokens"])
  bos_id: int = kv['tokenizer.ggml.bos_token_id']
  eos_id: int = kv['tokenizer.ggml.eos_token_id']

  # the running conversation; generate() appends the reply tokens to this same list
  ids: list[int] = [bos_id]
  while True:
    start_pos = len(ids) - 1
    try:
      user_turn = tok.role("user") + tok.encode(input('>>> ')) + [eos_id]
      ids += user_turn + tok.role("assistant")
    except EOFError:
      break
    for next_id in model.generate(ids, start_pos):
      if next_id == eos_id:
        # end of the reply: print a blank line and return to the prompt
        sys.stdout.write("\n\n")
        sys.stdout.flush()
        break
      sys.stdout.write(tok.decode([next_id]))
      sys.stdout.flush()
|