You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
			
				
					63 lines
				
				3.3 KiB
			
		
		
			
		
	
	
					63 lines
				
				3.3 KiB
			| 
											4 days ago
										 | import functools, argparse, pathlib
 | ||
|  | from tinygrad import Tensor, nn, Device, GlobalCounters, Variable
 | ||
|  | from tinygrad.helpers import Timing, Profiling, CI, tqdm
 | ||
|  | from tinygrad.nn.state import torch_load, get_state_dict
 | ||
|  | from extra.models.llama import FeedForward, Transformer
 | ||
|  | from extra.bench_log import BenchEvent, WallTimeEvent
 | ||
|  | 
 | ||
class MixtureFeedForward:
  """Feed-forward layer that routes the input through the two highest-scoring experts.

  A bias-free linear gate scores every expert; the two experts with the largest
  exponentiated gate value are evaluated and their outputs are blended with
  weights normalised so the pair sums to one. Only batch size 1 is supported.
  """
  def __init__(self, num_experts:int, dim:int, hidden_dim:int, linear=nn.Linear):
    """Build the gate and `num_experts` independent FeedForward experts.

    `linear` is forwarded to each FeedForward; the gate itself always uses nn.Linear.
    """
    self.gate = nn.Linear(dim, num_experts, bias=False)
    self.experts = [FeedForward(dim, hidden_dim, linear) for _ in range(num_experts)]

  def __call__(self, x:Tensor) -> Tensor:
    """Return the weighted sum of the top-2 experts applied to `x`."""
    assert x.shape[0] == 1, "only BS=1"
    # exponentiated gate scores, pulled back into a Python list; only the entry
    # at index [0][0] is consulted when picking experts
    gate_scores = self.gate(x).float().exp().data().tolist()[0][0]
    # (expert index, score) pairs ordered by descending score
    ranked = sorted(enumerate(gate_scores), key=lambda pair: -pair[1])
    (first_idx, first_score), (second_idx, second_score) = ranked[0], ranked[1]
    total = first_score + second_score
    first_expert, second_expert = self.experts[first_idx], self.experts[second_idx]
    mix = Tensor([first_score/total, second_score/total])
    # each expert runs on the device holding its own weights; the result is moved back to x's device
    first_out = first_expert(x.to(first_expert.w1.weight.device)).to(x.device) * mix[0]
    second_out = second_expert(x.to(second_expert.w1.weight.device)).to(x.device) * mix[1]
    return first_out + second_out
 | ||
|  | 
 | ||
if __name__ == "__main__":
  # ---- command line ----
  default_weights = (pathlib.Path(__file__).parent.parent / "weights/mixtral-8x7b-32kseqlen").as_posix()
  parser = argparse.ArgumentParser(description="Run Mixtral in tinygrad", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument("--count", type=int, default=30, help="Max number of tokens to generate")
  parser.add_argument("--temperature", type=float, default=0.7, help="Temperature in the softmax")
  parser.add_argument("--timing", action="store_true", help="Print timing per token")
  parser.add_argument("--profile", action="store_true", help="Profile generation")
  parser.add_argument("--weights", type=str, default=default_weights, help="Path to the downloaded weights")
  args = parser.parse_args()

  # ---- load the checkpoint and copy every tensor into the model ----
  with WallTimeEvent(BenchEvent.LOAD_WEIGHTS):
    state = torch_load(args.weights + "/consolidated.00.pth.b")
    # 8 experts per feed-forward layer, bound through functools.partial
    expert_ffn = functools.partial(MixtureFeedForward, 8)
    model = Transformer(n_layers=32, dim=4096, hidden_dim=14336, n_heads=32, n_kv_heads=8, norm_eps=1e-5, vocab_size=32000,
                        feed_forward=expert_ffn, jit=False)
    model_state_dict = get_state_dict(model)

    progress = tqdm(state, disable=CI)
    for name in progress:
      marker = 'feed_forward.experts.'
      if marker in name:
        # expert weights are spread two experts per device, starting at device index 1
        expert_no = int(name.split(marker)[1].split('.')[0])
        device = f"{Device.DEFAULT}:{(expert_no//2)+1}"
      else:
        # everything that is not an expert weight stays on the default device
        device = Device.DEFAULT
      progress.set_description(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, loading {name} to {device}")
      loaded = state[name].to(device).half()
      model_state_dict[name].replace(loaded).realize()
  if CI: print(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB")

  # ---- tokenizer ----
  from sentencepiece import SentencePieceProcessor
  spp = SentencePieceProcessor(model_file=args.weights + "/tokenizer.model")

  # ---- generation: start from BOS and append one token per step ----
  toks = [spp.bos_id()]
  for start_pos in range(args.count):
    GlobalCounters.reset()
    with Profiling(sort="time", frac=0.1, enabled=args.profile):
      with Timing("total ", enabled=args.timing, on_exit=lambda ns: f", {1e9/ns:.2f} tok/sec"):
        with WallTimeEvent(BenchEvent.STEP):
          # only the tokens from start_pos onward are fed to the model
          new_input = Tensor([toks[start_pos:]])
          # the first step passes a plain 0; later steps pass a bound symbolic position
          position = 0 if start_pos == 0 else Variable("start_pos", 1, 1024-1).bind(start_pos)
          tok = model(new_input, position, args.temperature).item()
    toks.append(tok)
    print(spp.decode(toks))
 |