import json
import os
from pathlib import Path
from transformers import BertTokenizer
import numpy as np
from tinygrad.helpers import fetch

# Directory that init_dataset creates and fetches the SQuAD dev JSON into: a "squad" folder next to this file.
BASEDIR = Path(__file__).parent / "squad"
def _split_on_whitespace(text):
  """Split text into tokens at runs of space, tab, CR, LF or U+202F (narrow no-break space)."""
  tokens = []
  prev_is_whitespace = True
  for c in text:
    if c in (" ", "\t", "\r", "\n", "\u202f"):
      prev_is_whitespace = True
    elif prev_is_whitespace:
      # First character after whitespace starts a new token.
      tokens.append(c)
      prev_is_whitespace = False
    else:
      tokens[-1] += c
  return tokens

def init_dataset():
  """Fetch and parse the SQuAD v1.1 dev set into one record per question.

  The JSON file is fetched into BASEDIR. Every question of every paragraph becomes a dict with
  the keys:
    "id": the question id,
    "question": the question text,
    "context": the paragraph text split on whitespace (list of str),
    "answers": list of the reference answer strings.
  """
  os.makedirs(BASEDIR, exist_ok=True)
  dataset_path = BASEDIR / "dev-v1.1.json"
  fetch("https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json", dataset_path)
  # JSON text is UTF-8 (RFC 8259). Be explicit so the locale default encoding
  # (e.g. cp1252 on Windows) cannot break decoding.
  with open(dataset_path, encoding="utf-8") as f:
    data = json.load(f)["data"]

  examples = []
  for article in data:
    for paragraph in article["paragraphs"]:
      # Tokenized once per paragraph; the same list object is shared by all of its questions.
      doc_tokens = _split_on_whitespace(paragraph["context"])
      for qa in paragraph["qas"]:
        examples.append({
          "id": qa["id"],
          "question": qa["question"],
          "context": doc_tokens,
          "answers": [answer["text"] for answer in qa["answers"]],
        })
  return examples

def _check_is_max_context(doc_spans, cur_span_index, position):
  best_score, best_span_index = None, None
  for di, (doc_start, doc_length) in enumerate(doc_spans):
    end = doc_start + doc_length - 1
    if position < doc_start:
      continue
    if position > end:
      continue
    num_left_context = position - doc_start
    num_right_context = end - position
    score = min(num_left_context, num_right_context) + 0.01 * doc_length
    if best_score is None or score > best_score:
      best_score = score
      best_span_index = di
  return cur_span_index == best_span_index

def _compute_doc_spans(num_tokens, max_tokens_for_doc, doc_stride):
  """Return (start, length) windows covering num_tokens document tokens.

  Each window holds at most max_tokens_for_doc tokens. Consecutive window starts are at most
  doc_stride apart, so windows overlap when doc_stride < max_tokens_for_doc. Returns [] when
  num_tokens == 0.
  """
  spans = []
  start = 0
  while start < num_tokens:
    length = min(num_tokens - start, max_tokens_for_doc)
    spans.append((start, length))
    if start + length == num_tokens:
      break
    start += min(length, doc_stride)
  return spans

def convert_example_to_features(example, tokenizer, max_seq_length=384, max_query_length=64, doc_stride=128):
  """Convert one SQuAD example into fixed-length BERT input windows.

  Steps:
    1. Tokenize the question and truncate it to max_query_length tokens.
    2. Sub-tokenize the context tokens with `tokenizer`.
    3. If the context does not fit alongside the question, split it into overlapping windows
       whose starts are at most doc_stride tokens apart.
    4. Lay each window out as "[CLS] question [SEP] context-window [SEP]" and zero-pad it to
       max_seq_length.

  Args:
    example: dict with "question" (str) and "context" (list of str tokens), as built by init_dataset.
    tokenizer: object providing tokenize(str) -> list[str] and convert_tokens_to_ids(list[str]).
    max_seq_length: total length of every produced feature, including special tokens and padding.
    max_query_length: maximum number of question sub-tokens kept.
    doc_stride: maximum distance between the starts of consecutive context windows.

  Returns:
    A list with one dict per window, containing:
      "input_ids", "input_mask", "segment_ids": float32 arrays of shape (1, max_seq_length);
      "token_to_orig_map": feature position -> index into example["context"];
      "token_is_max_context": feature position -> bool (see _check_is_max_context);
      "tokens": the unpadded token strings.
    Empty if the context has no tokens.

  Raises:
    ValueError: if the settings leave no room for context tokens, or doc_stride is not positive.
      This cannot happen with the default settings.
  """
  query_tokens = tokenizer.tokenize(example["question"])[:max_query_length]

  # Sub-tokenize the whitespace-split context, remembering which original token each piece came from.
  tok_to_orig_index = []
  all_doc_tokens = []
  for orig_index, token in enumerate(example["context"]):
    for sub_token in tokenizer.tokenize(token):
      tok_to_orig_index.append(orig_index)
      all_doc_tokens.append(sub_token)

  # Three positions are reserved for [CLS] and the two [SEP] markers.
  max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
  # Non-positive window sizes or strides would make the span loop never terminate.
  if max_tokens_for_doc <= 0:
    raise ValueError(f"max_seq_length={max_seq_length} leaves no room for context after a {len(query_tokens)}-token question")
  if doc_stride <= 0:
    raise ValueError(f"doc_stride must be positive, got {doc_stride}")

  doc_spans = _compute_doc_spans(len(all_doc_tokens), max_tokens_for_doc, doc_stride)

  outputs = []
  for span_index, (doc_start, doc_length) in enumerate(doc_spans):
    tokens = ["[CLS]", *query_tokens, "[SEP]"]
    segment_ids = [0] * len(tokens)  # segment 0: [CLS] + question + [SEP]
    token_to_orig_map = {}
    token_is_max_context = {}
    for split_token_index in range(doc_start, doc_start + doc_length):
      position = len(tokens)  # index of this token within the feature
      token_to_orig_map[position] = tok_to_orig_index[split_token_index]
      token_is_max_context[position] = _check_is_max_context(doc_spans, span_index, split_token_index)
      tokens.append(all_doc_tokens[split_token_index])
    tokens.append("[SEP]")
    segment_ids += [1] * (doc_length + 1)  # segment 1: context window + trailing [SEP]

    input_ids = list(tokenizer.convert_tokens_to_ids(tokens))
    input_mask = [1] * len(input_ids)  # 1 = real token, 0 = padding

    padding = [0] * (max_seq_length - len(input_ids))
    input_ids += padding
    input_mask += padding
    segment_ids += padding

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length

    outputs.append({
      "input_ids": np.expand_dims(np.array(input_ids), 0).astype(np.float32),
      "input_mask": np.expand_dims(np.array(input_mask), 0).astype(np.float32),
      "segment_ids": np.expand_dims(np.array(segment_ids), 0).astype(np.float32),
      "token_to_orig_map": token_to_orig_map,
      "token_is_max_context": token_is_max_context,
      "tokens": tokens,
    })

  return outputs

def iterate(tokenizer, start=0):
  """Yield (features, example) pairs for the SQuAD dev set, beginning at example index `start`.

  `features` is the complete list of windows produced by convert_example_to_features for that
  example. All windows are yielded together because the F1 score is the maximum over all of them.
  Work is done lazily, one example per step.
  """
  examples = init_dataset()
  total = len(examples)
  print(f"there are {total} pairs in the dataset")

  pairs = (
    (convert_example_to_features(examples[idx], tokenizer), examples[idx])
    for idx in range(start, total)
  )
  yield from pairs

if __name__ == "__main__":
  # Smoke test: build the tokenizer from the local vocab file, then print the first feature
  # window of the first example, its input_ids shape and the raw example.
  vocab_file = Path(__file__).parents[2] / "weights" / "bert_vocab.txt"
  tokenizer = BertTokenizer(str(vocab_file))

  features, example = next(iterate(tokenizer))
  first_window = features[0]
  print(" ".join(first_window["tokens"]))
  print(first_window["input_ids"].shape, example)