You cannot select more than 25 topics.
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
			
				
					62 lines
				
				2.7 KiB
			
		
		
			
		
	
	
					62 lines
				
				2.7 KiB
			| 
											5 days ago
										 | from tinygrad import Tensor
 | ||
|  | 
 | ||
class TransformerBlock:
  """One multi-head self-attention + feed-forward transformer layer.

  Args:
    embed_dim: model width; must be divisible by ``num_heads``.
    num_heads: number of attention heads.
    ff_dim: hidden width of the two-layer feed-forward network.
    prenorm: if True, layernorm is applied to each sublayer's input (pre-LN);
      otherwise it is applied after each residual add (post-LN).
    act: activation applied between the two feed-forward projections.
    dropout: value passed to ``Tensor.dropout`` on each sublayer's output
      before it is added back to the residual stream.
  """

  def __init__(self, embed_dim, num_heads, ff_dim, prenorm=False, act=lambda x: x.relu(), dropout=0.1):
    assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

    self.num_heads = num_heads
    self.head_size = embed_dim // num_heads
    self.prenorm, self.act = prenorm, act
    self.dropout = dropout

    def dense(n_in, n_out):
      # (weight, bias) pair in the form unpacked into Tensor.linear
      return (Tensor.scaled_uniform(n_in, n_out), Tensor.zeros(n_out))

    # Attention projections. Construction order is significant: each
    # scaled_uniform call draws from the global RNG in sequence.
    self.query = dense(embed_dim, embed_dim)
    self.key = dense(embed_dim, embed_dim)
    self.value = dense(embed_dim, embed_dim)
    self.out = dense(embed_dim, embed_dim)

    # Feed-forward network: embed_dim -> ff_dim -> embed_dim.
    self.ff1 = dense(embed_dim, ff_dim)
    self.ff2 = dense(ff_dim, embed_dim)

    # Layernorm affine parameters (scale, shift), applied via Tensor.linear.
    self.ln1 = (Tensor.ones(embed_dim), Tensor.zeros(embed_dim))
    self.ln2 = (Tensor.ones(embed_dim), Tensor.zeros(embed_dim))

  def attn(self, x):
    """Multi-head scaled dot-product self-attention.

    x: (bs, time, embed_dim) -> (bs, time, embed_dim)
    """
    bs = x.shape[0]

    def split_heads(proj):
      # project, then (bs, time, embed_dim) -> (bs, num_heads, time, head_size)
      return x.linear(*proj).reshape(shape=(bs, -1, self.num_heads, self.head_size)).transpose(1, 2)

    q, k, v = split_heads(self.query), split_heads(self.key), split_heads(self.value)
    # swap heads and time back: (bs, time, num_heads, head_size)
    merged = Tensor.scaled_dot_product_attention(q, k, v).transpose(1, 2)
    # concatenate heads, then apply the output projection
    return merged.reshape(shape=(bs, -1, self.num_heads * self.head_size)).linear(*self.out)

  def __call__(self, x):
    """Apply the attention and feed-forward sublayers, each with a residual add."""
    if not self.prenorm:
      # post-LN: add the sublayer output to the residual, then normalize
      x = (x + self.attn(x).dropout(self.dropout)).layernorm().linear(*self.ln1)
      ff = self.act(x.linear(*self.ff1)).linear(*self.ff2)
      return (x + ff.dropout(self.dropout)).layernorm().linear(*self.ln2)

    # pre-LN: normalize only the sublayer input; the residual stream is not normalized
    h = x.layernorm().linear(*self.ln1)
    x = x + self.attn(h).dropout(self.dropout)
    h = x.layernorm().linear(*self.ln2)
    ff = self.act(h.linear(*self.ff1)).linear(*self.ff2)
    return x + ff.dropout(self.dropout)
 | ||
|  | 
 | ||
class Transformer:
  """Stack of TransformerBlocks over one-hot position + symbol embeddings.

  Args:
    syms: vocabulary size; inputs are integer symbol ids, one-hot encoded
      with ``syms`` classes.
    maxlen: maximum sequence length (number of position embedding rows).
    layers: number of TransformerBlocks.
    embed_dim, num_heads, ff_dim: forwarded to every TransformerBlock.
  """

  def __init__(self, syms, maxlen, layers, embed_dim, num_heads, ff_dim):
    self.maxlen, self.syms = maxlen, syms
    # Rows [0, maxlen) embed positions; rows [maxlen, maxlen+syms) embed symbols.
    # Created with requires_grad=False, so this table receives no gradient.
    self.embed = Tensor.scaled_uniform(maxlen+syms, embed_dim, requires_grad=False)
    self.tbs = [TransformerBlock(embed_dim, num_heads, ff_dim) for _ in range(layers)]
    self.final = Tensor.scaled_uniform(embed_dim, syms)

  def forward(self, x):
    """Return per-position log-probabilities over the ``syms`` symbols.

    x: (bs, seq_len) symbol ids with seq_len <= maxlen -> (bs, seq_len, syms)

    Raises:
      ValueError: if seq_len exceeds ``maxlen``.
    """
    bs, seq_len = x.shape[0], x.shape[1]
    if seq_len > self.maxlen:
      raise ValueError(f"sequence length {seq_len} exceeds maxlen {self.maxlen}")

    # one-hot position of each token: (bs, seq_len, seq_len)
    pos_onehot = Tensor.eye(seq_len).unsqueeze(0).expand([bs, seq_len, seq_len])
    # one-hot symbol of each token: (bs, seq_len, syms)
    sym_onehot = x.int().one_hot(self.syms)
    # (bs*seq_len, seq_len+syms)
    onehot = pos_onehot.cat(sym_onehot, dim=2).flatten(end_dim=1)

    # The one-hot only addresses the first seq_len position rows. Select those
    # plus the symbol rows so the matmul shapes also agree when seq_len < maxlen.
    if seq_len == self.maxlen:
      table = self.embed
    else:
      table = self.embed[:seq_len].cat(self.embed[self.maxlen:], dim=0)

    x = onehot.dot(table).reshape((bs, seq_len, -1))
    x = x.sequential(self.tbs)
    # project to symbol logits and normalize over the symbol axis
    x = x.reshape((-1, x.shape[-1])).dot(self.final).log_softmax()
    return x.reshape((bs, -1, x.shape[-1]))

  def __call__(self, x):
    """Alias for :meth:`forward`, so the model is callable like TransformerBlock."""
    return self.forward(x)
 |