You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							398 lines
						
					
					
						
							14 KiB
						
					
					
				
			
		
		
	
	
							398 lines
						
					
					
						
							14 KiB
						
					
					
				| # Preprocessing of downloaded text from Wikipedia for MLPerf BERT training
 | |
| # This is a modified version of the original script:
 | |
| # https://github.com/mlcommons/training/blob/master/language_model/tensorflow/bert/cleanup_scripts/create_pretraining_data.py
 | |
| # ENV VARS:
 | |
| # MAX_SEQ_LENGTH          - Maximum sequence length
 | |
| # MAX_PREDICTIONS_PER_SEQ - Maximum number of masked LM predictions per sequence
 | |
| # RANDOM_SEED             - Random seed
 | |
| # DUPE_FACTOR             - Number of times to duplicate the input data with different masks
 | |
| # MASKED_LM_PROB          - Probability of masking a token
 | |
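| # SHORT_SEQ_PROB          - Probability of picking a sequence shorter than MAX_SEQ_LENGTH
| # NUM_WORKERS             - Number of worker processes for "pre-train all" (default: min(cpu_count, 32))
| # BASEDIR                 - Directory containing vocab.txt and results4/; generated pickles are written here (default: ./wiki next to this script)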
 | |
| 
 | |
| import os, sys, pickle, random, unicodedata
 | |
| from pathlib import Path
 | |
| import numpy as np
 | |
| from tqdm import tqdm
 | |
| from tqdm.contrib.concurrent import process_map
 | |
| 
 | |
| from tinygrad.helpers import diskcache, getenv
 | |
| 
 | |
# Root directory holding vocab.txt, the results4/ text shards and the generated pickles;
# override with the BASEDIR environment variable.
BASEDIR = getenv('BASEDIR', Path(__file__).parent.joinpath("wiki"))
 | |
| 
 | |
| ################### Tokenization #####################
 | |
| 
 | |
| def _is_whitespace(char:str) -> bool:
 | |
|   if char == " " or char == "\t" or char == "\n" or char == "\r":
 | |
|     return True
 | |
|   return unicodedata.category(char) == "Zs"
 | |
| 
 | |
| def _is_control(char:str) -> bool:
 | |
|   if char == "\t" or char == "\n" or char == "\r":
 | |
|     return False
 | |
|   return unicodedata.category(char).startswith("C")
 | |
| 
 | |
| def _is_punctuation(char:str) -> bool:
 | |
|   # range(33, 48) -> ! " # $ % & ' ( ) * + , - . /
 | |
|   # range(58, 65) -> : ; < = > ? @
 | |
|   # range(91, 97) -> [ \ ] ^ _
 | |
|   # range(123, 127) -> { | } ~
 | |
|   if (cp := ord(char)) in range(33, 48) or cp in range(58, 65) or cp in range(91, 97) or cp in range(123, 127):
 | |
|     return True
 | |
|   return unicodedata.category(char).startswith("P")
 | |
| 
 | |
| def _is_chinese_char(cp:int) -> bool:
 | |
|   if ((cp >= 0x4E00 and cp <= 0x9FFF) or
 | |
|       (cp >= 0x3400 and cp <= 0x4DBF) or
 | |
|       (cp >= 0x20000 and cp <= 0x2A6DF) or
 | |
|       (cp >= 0x2A700 and cp <= 0x2B73F) or
 | |
|       (cp >= 0x2B740 and cp <= 0x2B81F) or
 | |
|       (cp >= 0x2B820 and cp <= 0x2CEAF) or
 | |
|       (cp >= 0xF900 and cp <= 0xFAFF) or
 | |
|       (cp >= 0x2F800 and cp <= 0x2FA1F)):
 | |
|     return True
 | |
|   return False
 | |
| 
 | |
def _run_split_on_punc(text:str) -> list[str]:
  """Split ``text`` so that each punctuation character becomes its own piece.

  Runs of non-punctuation characters are kept together. The special tokens [UNK], [SEP],
  [PAD], [CLS] and [MASK] are returned whole instead of being split at their brackets.
  """
  if text in ("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"):
    return [text]
  pieces: list[str] = []
  run: list[str] = []
  for ch in text:
    if not _is_punctuation(ch):
      run.append(ch)
      continue
    # Close the current word (if any) before emitting the punctuation mark alone.
    if run:
      pieces.append("".join(run))
      run = []
    pieces.append(ch)
  if run:
    pieces.append("".join(run))
  return pieces
 | |
| 
 | |
| def _run_strip_accents(text:str) -> str:
 | |
|   output = []
 | |
|   for char in unicodedata.normalize("NFD", text):
 | |
|     if unicodedata.category(char) != "Mn":
 | |
|       output.append(char)
 | |
|   return "".join(output)
 | |
| 
 | |
def _clean_text(text:str) -> str:
  """Drop NUL, U+FFFD and control characters, and map every whitespace character to a plain space."""
  kept: list[str] = []
  for ch in text:
    code = ord(ch)
    if code == 0 or code == 0xfffd or _is_control(ch):
      continue
    kept.append(" " if _is_whitespace(ch) else ch)
  return "".join(kept)
 | |
| 
 | |
def _tokenize_chinese_chars(text:str) -> str:
  """Surround every CJK ideograph with spaces so a later whitespace split isolates it."""
  return "".join(f" {ch} " if _is_chinese_char(ord(ch)) else ch for ch in text)
 | |
| 
 | |
def whitespace_tokenize(text):
  """Split ``text`` on runs of whitespace; blank or empty input yields an empty list."""
  return text.strip().split()
 | |
| 
 | |
| def _wordpiece_tokenize(text:str, vocab:dict[str, int]) -> list[str]:
 | |
|   text = text.decode("utf-8", "ignore") if isinstance(text, bytes) else text
 | |
|   output_tokens = []
 | |
|   for token in text.strip().split():
 | |
|     chars = list(token)
 | |
|     if len(chars) > 200:
 | |
|       output_tokens.append("[UNK]")
 | |
|       continue
 | |
| 
 | |
|     is_bad = False
 | |
|     start = 0
 | |
|     sub_tokens = []
 | |
|     while start < len(chars):
 | |
|       end = len(chars)
 | |
|       cur_substr = None
 | |
|       while start < end:
 | |
|         substr = "".join(chars[start:end])
 | |
|         if start > 0: substr = "##" + substr
 | |
|         if substr in vocab:
 | |
|           cur_substr = substr
 | |
|           break
 | |
|         end -= 1
 | |
|       if cur_substr is None:
 | |
|         is_bad = True
 | |
|         break
 | |
|       sub_tokens.append(cur_substr)
 | |
|       start = end
 | |
| 
 | |
|     if is_bad: output_tokens.append("[UNK]")
 | |
|     else: output_tokens.extend(sub_tokens)
 | |
|   return output_tokens
 | |
| 
 | |
class Tokenizer:
  """Uncased BERT tokenizer: text cleanup, accent stripping and punctuation/CJK splitting,
  followed by greedy WordPiece against a one-token-per-line vocab file."""

  def __init__(self, vocab_file):
    # token -> id; ids are assigned densely in file order (0, 1, 2, ...).
    self.vocab: dict[str, int] = {}
    # Decode explicitly as UTF-8, dropping undecodable bytes: the default text encoding is
    # locale-dependent and may not be UTF-8, while the vocab contains non-ASCII tokens.
    with open(vocab_file, encoding="utf-8", errors="ignore") as f:
      for line in f:
        # Skip blank lines; keep only the first occurrence of a duplicated token.
        if (token := line.strip()) and token not in self.vocab: self.vocab[token] = len(self.vocab)
    # id -> token, the inverse of self.vocab.
    self.inv_vocab = {v: k for k, v in self.vocab.items()}

  def tokenize(self, text:str) -> list[str]:
    """Tokenize ``text`` (str, or bytes decoded as UTF-8) into WordPiece tokens from the vocab."""
    text = text.decode("utf-8", "ignore") if isinstance(text, bytes) else text
    # Basic tokenization: clean, isolate CJK chars, split on whitespace, lowercase,
    # strip accents, then split punctuation off into separate tokens.
    split_tokens = []
    for token in whitespace_tokenize(_tokenize_chinese_chars(_clean_text(text))):
      split_tokens.extend(_run_split_on_punc(_run_strip_accents(token.lower())))
    split_tokens = " ".join(split_tokens).strip().split()
    # WordPiece tokenization of each basic token.
    tokens = []
    for token in split_tokens:
      tokens.extend(_wordpiece_tokenize(token, self.vocab))
    return tokens

  def convert_tokens_to_ids(self, tokens:list[str]) -> list[int]: return [self.vocab[token] for token in tokens]
  def convert_ids_to_tokens(self, ids:list[int]) -> list[str]: return [self.inv_vocab[i] for i in ids]
 | |
| 
 | |
| ##################### Feature transformation #####################
 | |
| 
 | |
def truncate_seq_pair(tokens_a:list[str], tokens_b:list[str], max_num_tokens:int, rng:random.Random) -> None:
  """Trim the pair in place until ``len(tokens_a) + len(tokens_b) <= max_num_tokens``.

  Each step removes one token from the longer list (``tokens_b`` on a tie), taken from
  the front or the back with equal probability according to ``rng``.
  """
  while len(tokens_a) + len(tokens_b) > max_num_tokens:
    longer = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
    assert len(longer) >= 1
    # Randomly trim either end so neither the start nor the end is systematically lost.
    if rng.random() < 0.5:
      longer.pop(0)
    else:
      longer.pop()
 | |
| 
 | |
def create_masked_lm_predictions(tokens:list[str], tokenizer:Tokenizer, rng:random.Random, vocab_words:list[str]) -> tuple[list[str], list[int], list[str]]:
  """Choose masked-LM targets in ``tokens`` and return (masked tokens, positions, labels).

  Every token except [CLS]/[SEP] is a candidate. Up to
  min(MAX_PREDICTIONS_PER_SEQ, max(1, round(len(tokens) * MASKED_LM_PROB))) candidates are
  picked in random order. Each picked token is replaced by [MASK] with probability 0.8;
  otherwise it is kept or replaced by a random vocab word with equal probability.
  Positions are returned in ascending order, and labels hold the original tokens.
  ``tokens`` itself is not modified.
  """
  candidates = [pos for pos, tok in enumerate(tokens) if tok not in ("[CLS]", "[SEP]")]
  rng.shuffle(candidates)
  output_tokens = list(tokens)
  max_predictions = getenv('MAX_PREDICTIONS_PER_SEQ', 76)
  masked_lm_prob = getenv("MASKED_LM_PROB", 0.15)
  num_to_predict = min(max_predictions, max(1, int(round(len(tokens) * masked_lm_prob))))

  picked: list[tuple[int, str]] = []
  seen: set[int] = set()
  for pos in candidates:
    if len(picked) >= num_to_predict:
      break
    if pos in seen:
      continue
    seen.add(pos)

    if rng.random() < 0.8:
      replacement = "[MASK]"
    elif rng.random() < 0.5:
      replacement = tokens[pos]
    else:
      replacement = vocab_words[rng.randint(0, len(tokenizer.vocab) - 1)]

    output_tokens[pos] = replacement
    picked.append((pos, tokens[pos]))

  picked.sort(key=lambda item: item[0])
  positions = [pos for pos, _ in picked]
  labels = [label for _, label in picked]
  return output_tokens, positions, labels
 | |
| 
 | |
def create_instances_from_document(rng:random.Random, tokenizer:Tokenizer, doc:list[list[str]], di:int, documents:list[list[list[str]]]) -> list[dict]:
  """Build masked-LM + next-sentence-prediction instances from one document.

  ``doc`` is a list of segments (each a list of WordPiece tokens), ``documents`` is the
  whole corpus in the same format, and ``di`` is ``doc``'s index in ``documents``.
  Segments are accumulated into a chunk until it reaches ``target_seq_length`` tokens (or
  the document ends). The chunk is split into an A part and a B part. B is a random span
  from another document when the chunk has a single segment, and otherwise with
  probability 0.5; in all other cases B is the rest of the chunk. Each pair is truncated,
  wrapped as ``[CLS] A [SEP] B [SEP]`` and masked with ``create_masked_lm_predictions``.

  Reads MAX_SEQ_LENGTH (default 512) and SHORT_SEQ_PROB (default 0.1) via ``getenv``.
  Returns dicts with keys tokens, segment_ids, masked_lm_positions, masked_lm_labels and
  is_random_next.
  """
  max_num_tokens = getenv('MAX_SEQ_LENGTH', 512) - 3 # [CLS] + 2 * [SEP]
  # The vocab does not change here: build the id-ordered word list once per call instead
  # of once per instance (it is O(vocab size) to build).
  vocab_words = list(tokenizer.vocab.keys())

  # Occasionally aim for a shorter sequence so short inputs are also represented.
  target_seq_length = max_num_tokens
  if rng.random() < getenv("SHORT_SEQ_PROB", 0.1):
    target_seq_length = rng.randint(2, max_num_tokens)

  instances = []
  current_chunk: list[list[str]] = []
  current_length = 0
  i = 0
  while i < len(doc):
    segment = doc[i]
    current_chunk.append(segment)
    current_length += len(segment)
    if i == len(doc) - 1 or current_length >= target_seq_length:
      # Number of leading chunk segments that form sentence A.
      a_end = rng.randint(1, len(current_chunk) - 1) if len(current_chunk) >= 2 else 1
      tokens_a = [tok for seg in current_chunk[:a_end] for tok in seg]

      tokens_b: list[str] = []
      is_random_next = len(current_chunk) == 1 or rng.random() < 0.5
      if is_random_next:
        target_b_length = target_seq_length - len(tokens_a)
        # Try up to 10 times to pick a different document; after that the last draw is
        # used even if it is this document.
        for _ in range(10):
          random_document_index = rng.randint(0, len(documents) - 1)
          if random_document_index != di:
            break
        random_document = documents[random_document_index]
        random_start = rng.randint(0, len(random_document) - 1)
        for seg in random_document[random_start:]:
          tokens_b.extend(seg)
          if len(tokens_b) >= target_b_length:
            break
        # The chunk's segments after a_end were not used; rewind so they start the next chunk.
        i -= len(current_chunk) - a_end
      else:
        tokens_b = [tok for seg in current_chunk[a_end:] for tok in seg]
      truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)

      assert len(tokens_a) >= 1
      assert len(tokens_b) >= 1

      tokens = ["[CLS]", *tokens_a, "[SEP]", *tokens_b, "[SEP]"]
      # Segment 0 covers [CLS] A [SEP]; segment 1 covers B [SEP].
      segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)

      tokens, masked_lm_positions, masked_lm_labels = create_masked_lm_predictions(tokens, tokenizer, rng, vocab_words)
      instances.append({
        "tokens": tokens,
        "segment_ids": segment_ids,
        "masked_lm_positions": masked_lm_positions,
        "masked_lm_labels": masked_lm_labels,
        "is_random_next": is_random_next
      })
      current_chunk = []
      current_length = 0
    i += 1
  return instances
 | |
| 
 | |
def get_documents(rng:random.Random, tokenizer:Tokenizer, fn:str) -> list[list[list[str]]]:
  """Read ``BASEDIR / fn`` and return its documents, tokenized and shuffled with ``rng``.

  Each non-blank line is tokenized into one segment (a list of WordPiece tokens), and
  blank lines separate documents. Lines that tokenize to nothing are dropped, as are
  documents left without segments.
  """
  documents: list[list[list[str]]] = [[]]
  # Decode explicitly as UTF-8, dropping undecodable bytes: the default text encoding is
  # locale-dependent and may not be UTF-8, while Wikipedia text is full of non-ASCII.
  with open(BASEDIR / fn, encoding="utf-8", errors="ignore") as f:
    # Stream the file instead of readlines() so all raw lines are never held at once.
    for line in f:
      if not (line := line.strip()):
        documents.append([])  # blank line: start a new document
        continue
      if (tokens := tokenizer.tokenize(line)): documents[-1].append(tokens)
  documents = [doc for doc in documents if doc]
  rng.shuffle(documents)
  return documents
 | |
| 
 | |
def get_instances(rng:random.Random, tokenizer:Tokenizer, documents:list[list[str]]) -> list[dict]:
  """Create instances from every document DUPE_FACTOR times (default 10), then shuffle them.

  Each pass draws fresh randomness from the shared ``rng``, so repeated passes over the
  same document produce differently masked and split instances.
  """
  dupe_factor = getenv('DUPE_FACTOR', 10)
  all_instances: list[dict] = []
  for _ in range(dupe_factor):
    for doc_index, document in enumerate(documents):
      all_instances += create_instances_from_document(rng, tokenizer, document, doc_index, documents)
  rng.shuffle(all_instances)
  return all_instances
 | |
| 
 | |
def instance_to_features(instance:dict, tokenizer:Tokenizer) -> dict:
  """Convert one pretraining instance into numpy feature arrays, each with a leading axis of size 1.

  Token ids, input mask and segment ids are zero-padded to MAX_SEQ_LENGTH (default 512).
  Masked-LM positions, ids and weights are zero-padded to MAX_PREDICTIONS_PER_SEQ
  (default 76); a weight of 1.0 marks a real prediction and 0.0 marks padding.
  ``next_sentence_labels`` is 1 when segment B came from a random document.

  ``instance`` is not modified, so the same instance can safely be converted more than once.
  """
  max_seq_length = getenv('MAX_SEQ_LENGTH', 512)
  max_predictions = getenv("MAX_PREDICTIONS_PER_SEQ", 76)

  input_ids = tokenizer.convert_tokens_to_ids(instance["tokens"])
  assert len(input_ids) <= max_seq_length
  seq_pad = max_seq_length - len(input_ids)
  input_mask = [1] * len(input_ids) + [0] * seq_pad
  input_ids = input_ids + [0] * seq_pad
  # Build new lists instead of appending to the instance's own lists. Padding in place
  # mutated the caller's instance, and converting the same instance twice (e.g. when eval
  # sampling picks it twice) then produced masked_lm_ids/weights shorter than
  # masked_lm_positions.
  segment_ids = list(instance["segment_ids"]) + [0] * seq_pad
  assert len(input_ids) == max_seq_length
  assert len(input_mask) == max_seq_length
  assert len(segment_ids) == max_seq_length

  masked_lm_positions = list(instance["masked_lm_positions"])
  masked_lm_ids = tokenizer.convert_tokens_to_ids(instance["masked_lm_labels"])
  masked_lm_weights = [1.0] * len(masked_lm_ids)
  pred_pad = max(0, max_predictions - len(masked_lm_positions))
  masked_lm_positions += [0] * pred_pad
  masked_lm_ids += [0] * pred_pad
  masked_lm_weights += [0.0] * pred_pad

  next_sentence_label = 1 if instance["is_random_next"] else 0

  def as_batch(values, dtype): return np.expand_dims(np.array(values, dtype=dtype), 0)

  return {
    "input_ids": as_batch(input_ids, np.int32),
    "input_mask": as_batch(input_mask, np.int32),
    "segment_ids": as_batch(segment_ids, np.int32),
    "masked_lm_positions": as_batch(masked_lm_positions, np.int32),
    "masked_lm_ids": as_batch(masked_lm_ids, np.int32),
    "masked_lm_weights": as_batch(masked_lm_weights, np.float32),
    "next_sentence_labels": as_batch([next_sentence_label], np.int32),
  }
 | |
| 
 | |
def process_part(part:int):
  """Generate training features for one input shard and pickle them to BASEDIR/train/<part>.pkl.

  Does nothing if the output already exists. The pickle is first written to a ``.tmp``
  file and then renamed into place with ``os.replace``, so an interrupted run cannot leave
  behind a truncated .pkl that a later run would skip as already done.
  """
  out_path = BASEDIR / f"train/{part}.pkl"
  os.makedirs(BASEDIR / "train", exist_ok=True)
  if os.path.exists(out_path): return

  # Only load the vocab once we know this part actually needs processing.
  tokenizer = Tokenizer(BASEDIR / "vocab.txt")
  features = get_features_from_part(tokenizer, val=False, part=part)

  tmp_path = out_path.with_name(out_path.name + ".tmp")  # not matched by the "*.pkl" glob
  with open(tmp_path, "wb") as f:
    pickle.dump(features, f)
  os.replace(tmp_path, out_path)
 | |
| 
 | |
def get_features_from_part(tokenizer:Tokenizer, val:bool=False, part:int=0) -> list[dict]:
  """Convert raw text into masked-LM / NSP feature dicts.

  The RNG is seeded from RANDOM_SEED (default 12345), so output is reproducible. For
  training, all instances of results4/part-<part>-of-00500 are converted. For validation,
  instances are built from results4/eval.txt and 10000 of them are picked at evenly
  spaced indices.
  """
  rng = random.Random(getenv('RANDOM_SEED', 12345))

  if not val:
    train_docs = get_documents(rng, tokenizer, f"results4/part-{part:05d}-of-00500")
    return [instance_to_features(inst, tokenizer) for inst in get_instances(rng, tokenizer, train_docs)]

  tqdm.write("Getting samples from dataset")
  eval_docs = get_documents(rng, tokenizer, "results4/eval.txt")
  eval_instances = get_instances(rng, tokenizer, eval_docs)

  num_samples = 10000
  tqdm.write(f"There are {len(eval_instances)} samples in the dataset")
  tqdm.write(f"Picking {num_samples} samples")

  stride = len(eval_instances) / num_samples
  return [instance_to_features(eval_instances[int(k * stride)], tokenizer) for k in range(num_samples)]
 | |
| 
 | |
| ##################### Load files #####################
 | |
| 
 | |
@diskcache
def get_wiki_train_files():
  """Return the sorted paths of the generated training shards (BASEDIR/train/*.pkl).

  NOTE(review): the result is memoized by tinygrad's ``diskcache``; if that cache persists
  across runs, shards generated after the first call would not appear until it is cleared
  — confirm.
  """
  train_dir = BASEDIR / "train"
  return sorted(train_dir.glob("*.pkl"))
 | |
| 
 | |
if __name__ == "__main__":
  usage = "Usage: python wikipedia.py pre-eval | pre-train <part> | pre-train all"
  # Validate argv up front so a missing part or an unknown command exits with the usage
  # message instead of raising IndexError or silently doing nothing.
  if len(sys.argv) < 2: sys.exit(usage)

  if sys.argv[1] == "pre-eval": # Generate 10000 eval samples
    # Only this branch needs a tokenizer here; pre-train builds its own inside process_part.
    tokenizer = Tokenizer(BASEDIR / "vocab.txt")
    with open(BASEDIR / "eval.pkl", "wb") as f:
      pickle.dump(get_features_from_part(tokenizer, val=True), f)
  elif sys.argv[1] == "pre-train":
    if len(sys.argv) < 3: sys.exit(usage)
    if sys.argv[2] == "all": # Use all 500 parts for training generation
      # os.cpu_count() may return None; fall back to a single worker in that case.
      workers = getenv('NUM_WORKERS', min(os.cpu_count() or 1, 32))
      process_map(process_part, list(range(500)), max_workers=workers, chunksize=1)
    else: # Use a specific part for training generation
      part = sys.argv[2]
      print(f"Processing part {part}...")
      process_part(int(part))
  else:
    sys.exit(usage)
 | |
| 
 |