# test cases are modified from pytorch test_indexing.py https://github.com/pytorch/pytorch/blob/597d3fb86a2f3b8d6d8ee067e769624dcca31cdb/test/test_indexing.py

import unittest, random, copy, warnings
import numpy as np

from tinygrad import Tensor, dtypes, Device, TinyJit
from tinygrad.device import is_dtype_supported
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.helpers import CI, all_same, prod

random.seed(42)

def numpy_testing_assert_equal_helper(a, b):
  if isinstance(a, Tensor): a = a.numpy()
  if isinstance(b, Tensor): b = b.numpy()
  np.testing.assert_equal(a, b)

def consec(shape, start=1):
  return Tensor.arange(prod(shape)).reshape(shape)+start

# creates strided tensor with base set to reference tensor's base, equivalent to torch.set_()
def set_(reference: Tensor, shape, strides, offset):
  if reference.lazydata.base.realized is None: reference.realize()
  assert reference.lazydata.base.realized, "base has to be realized before setting it to strided's base"
  strided = Tensor(reference.lazydata.view(ShapeTracker((View.create(shape=shape, strides=strides, offset=offset),))))
  assert strided.lazydata.st.real_strides() == strides, "real_strides should equal strides for strided"
  return strided

def clone(original:Tensor): return copy.copy(original)
def copy_(src:Tensor, other:Tensor) -> Tensor: return copy.copy(src)
# this is fine for tested usecases since as geohotstan understands,
# data_ptr is used to compare if operations needed between tensors is the same
def data_ptr(tensor:Tensor): return tensor.lazydata

# https://pytorch.org/docs/stable/generated/torch.Tensor.index_put_.html
def index_put_(tensor:Tensor, indices, values, accumulate) -> Tensor:
  if accumulate: tensor[indices] += values
  else: tensor[indices] = values

# https://pytorch.org/docs/stable/generated/torch.argsort.html
def argsort(tensor:Tensor) -> Tensor:
  pass

# https://pytorch.org/docs/stable/generated/torch.all.html
def all_(tensor:Tensor) -> Tensor:
  return tensor != 0

# https://pytorch.org/docs/stable/generated/torch.diagonal.html
def diagonal(tensor:Tensor) -> Tensor:
  assert tensor.ndim == 2 and all_same(tensor.shape), 'only support 2 ndim square tensors'
  return (Tensor.eye(tensor.shape[0]) * tensor).sum(0)

# https://numpy.org/doc/stable/reference/generated/numpy.unravel_index.html
def unravel_index(tensor, shape):
  pass

# https://github.com/pytorch/pytorch/blob/79811e765c23242210ebdc623539d2103a166463/torch/testing/_creation.py#L38
def make_tensor(shape, dtype:dtypes, noncontiguous) -> Tensor:
  r"""Creates a tensor with the given :attr:`shape`, :attr:`device`, and :attr:`dtype`, and filled with
  values uniformly drawn from ``[low, high)``.

  If :attr:`low` or :attr:`high` are specified and are outside the range of the :attr:`dtype`'s representable
  finite values then they are clamped to the lowest or highest representable finite value, respectively.
  If ``None``, then the following table describes the default values for :attr:`low` and :attr:`high`,
  which depend on :attr:`dtype`.

  +---------------------------+------------+----------+
  | ``dtype``                 | ``low``    | ``high`` |
  +===========================+============+==========+
  | boolean type              | ``0``      | ``2``    |
  +---------------------------+------------+----------+
  | unsigned integral type    | ``0``      | ``10``   |
  +---------------------------+------------+----------+
  | signed integral types     | ``-9``     | ``10``   |
  +---------------------------+------------+----------+
  | floating types            | ``-9``     | ``9``    |
  +---------------------------+------------+----------+
  | complex types             | ``-9``     | ``9``    |
  +---------------------------+------------+----------+
  """
  contiguous = not noncontiguous
  if dtype == dtypes.bool: return Tensor.randint(shape=shape, low=0, high=2, contiguous=contiguous).cast(dtypes.bool)
  elif dtype.is_unsigned(): return Tensor.randint(shape=shape, low=0, high=10, contiguous=contiguous).cast(dtype)
  elif dtype.is_int(): return Tensor.randint(shape=shape, low=-9, high=10, contiguous=contiguous).cast(dtype) # signed int
  elif dtype.is_float(): return Tensor.rand(shape=shape, low=-9, high=9, dtype=dtype, contiguous=contiguous)
  else: raise NotImplementedError(f"{dtype} not implemented")

class TestIndexing(unittest.TestCase):
  def test_index(self):

    reference = consec((3, 3, 3))

    numpy_testing_assert_equal_helper(reference[0], consec((3, 3)))
    numpy_testing_assert_equal_helper(reference[1], consec((3, 3), 10))
    numpy_testing_assert_equal_helper(reference[2], consec((3, 3), 19))
    numpy_testing_assert_equal_helper(reference[0, 1], consec((3,), 4))
    numpy_testing_assert_equal_helper(reference[0:2], consec((2, 3, 3)))
    numpy_testing_assert_equal_helper(reference[2, 2, 2], 27)
    numpy_testing_assert_equal_helper(reference[:], consec((3, 3, 3)))

    # indexing with Ellipsis
    numpy_testing_assert_equal_helper(reference[..., 2], np.array([[3., 6., 9.],[12., 15., 18.],[21., 24., 27.]]))
    numpy_testing_assert_equal_helper(reference[0, ..., 2], np.array([3., 6., 9.]))
    numpy_testing_assert_equal_helper(reference[..., 2], reference[:, :, 2])
    numpy_testing_assert_equal_helper(reference[0, ..., 2], reference[0, :, 2])
    numpy_testing_assert_equal_helper(reference[0, 2, ...], reference[0, 2])
    numpy_testing_assert_equal_helper(reference[..., 2, 2, 2], 27)
    numpy_testing_assert_equal_helper(reference[2, ..., 2, 2], 27)
    numpy_testing_assert_equal_helper(reference[2, 2, ..., 2], 27)
    numpy_testing_assert_equal_helper(reference[2, 2, 2, ...], 27)
    numpy_testing_assert_equal_helper(reference[...], reference)

    reference_5d = consec((3, 3, 3, 3, 3))
    numpy_testing_assert_equal_helper(reference_5d[..., 1, 0], reference_5d[:, :, :, 1, 0])
    numpy_testing_assert_equal_helper(reference_5d[2, ..., 1, 0], reference_5d[2, :, :, 1, 0])
    numpy_testing_assert_equal_helper(reference_5d[2, 1, 0, ..., 1], reference_5d[2, 1, 0, :, 1])
    numpy_testing_assert_equal_helper(reference_5d[...], reference_5d)

    # None indexing
    numpy_testing_assert_equal_helper(reference[2, None], reference[2].unsqueeze(0))
    numpy_testing_assert_equal_helper(reference[2, None, None], reference[2].unsqueeze(0).unsqueeze(0))
    numpy_testing_assert_equal_helper(reference[2:4, None], reference[2:4].unsqueeze(1))
    numpy_testing_assert_equal_helper(reference[None, 2, None, None], reference.unsqueeze(0)[:, 2].unsqueeze(0).unsqueeze(0))
    numpy_testing_assert_equal_helper(reference[None, 2:5, None, None], reference.unsqueeze(0)[:, 2:5].unsqueeze(2).unsqueeze(2))

    # indexing 0-length slice
    numpy_testing_assert_equal_helper(np.empty((0, 3, 3)), reference[slice(0)])
    numpy_testing_assert_equal_helper(np.empty((0, 3)), reference[slice(0), 2])
    numpy_testing_assert_equal_helper(np.empty((0, 3)), reference[2, slice(0)])
    numpy_testing_assert_equal_helper(np.empty([]), reference[2, 1:1, 2])

    # indexing with step
    reference = consec((10, 10, 10))
    numpy_testing_assert_equal_helper(reference[1:5:2], Tensor.stack(reference[1], reference[3], dim=0))
    numpy_testing_assert_equal_helper(reference[1:6:2], Tensor.stack(reference[1], reference[3], reference[5], dim=0))
    numpy_testing_assert_equal_helper(reference[1:9:4], Tensor.stack(reference[1], reference[5], dim=0))
    numpy_testing_assert_equal_helper(reference[2:4, 1:5:2], Tensor.stack(reference[2:4, 1], reference[2:4, 3], dim=1))
    numpy_testing_assert_equal_helper(reference[3, 1:6:2], Tensor.stack(reference[3, 1], reference[3, 3], reference[3, 5], dim=0))
    numpy_testing_assert_equal_helper(reference[None, 2, 1:9:4], Tensor.stack(reference[2, 1], reference[2, 5], dim=0).unsqueeze(0))
    numpy_testing_assert_equal_helper(reference[:, 2, 1:6:2], Tensor.stack(reference[:, 2, 1], reference[:, 2, 3], reference[:, 2, 5], dim=1))

    lst = [list(range(i, i+10)) for i in range(0, 100, 10)]
    tensor = Tensor(lst)
    for _ in range(100):
      idx1_start = random.randrange(10)
      idx1_end = idx1_start + random.randrange(1, 10 - idx1_start + 1)
      idx1_step = random.randrange(1, 8)
      idx1 = slice(idx1_start, idx1_end, idx1_step)
      if random.randrange(2) == 0:
        idx2_start = random.randrange(10)
        idx2_end = idx2_start + random.randrange(1, 10 - idx2_start + 1)
        idx2_step = random.randrange(1, 8)
        idx2 = slice(idx2_start, idx2_end, idx2_step)
        lst_indexed = [l[idx2] for l in lst[idx1]]
        tensor_indexed = tensor[idx1, idx2]
      else:
        lst_indexed = lst[idx1]
        tensor_indexed = tensor[idx1]
      numpy_testing_assert_equal_helper(tensor_indexed, np.array(lst_indexed))

    self.assertRaises(ValueError, lambda: reference[1:9:0])
    # NOTE torch doesn't support this but numpy does so we should too. Torch raises ValueError
    # see test_slice_negative_strides in test_ops.py
    # self.assertRaises(ValueError, lambda: reference[1:9:-1])

    self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1])
    self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1:1])
    self.assertRaises(IndexError, lambda: reference[3, 3, 3, 3, 3, 3, 3, 3])

    self.assertRaises(IndexError, lambda: reference[0.0])
    self.assertRaises(TypeError, lambda: reference[0.0:2.0])
    self.assertRaises(IndexError, lambda: reference[0.0, 0.0:2.0])
    self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0:2.0])
    self.assertRaises(IndexError, lambda: reference[0.0, ..., 0.0:2.0])
    self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0])

    # TODO: delitem
    # def delitem(): del reference[0]
    # self.assertRaises(TypeError, delitem)

  # TODO: LLVM is quite fast, why are other compiled backends slow?
  @unittest.skipIf(CI and Device.DEFAULT in ["CLANG", "GPU", "METAL", "NV", "AMD"], "slow")
  def test_advancedindex(self):
    # integer array indexing

    # pick a random valid indexer type
    def ri(indices):
      choice = random.randint(0, 2)
      if choice == 0: return Tensor(indices)
      if choice == 1: return list(indices)
      return tuple(indices)

    def validate_indexing(x):
      numpy_testing_assert_equal_helper(x[[0]], consec((1,)))
      numpy_testing_assert_equal_helper(x[ri([0]),], consec((1,)))
      numpy_testing_assert_equal_helper(x[ri([3]),], consec((1,), 4))
      numpy_testing_assert_equal_helper(x[[2, 3, 4]], consec((3,), 3))
      numpy_testing_assert_equal_helper(x[ri([2, 3, 4]),], consec((3,), 3))
      numpy_testing_assert_equal_helper(x[ri([0, 2, 4]),], np.array([1, 3, 5]))

    def validate_setting(x):
      x[[0]] = -2
      numpy_testing_assert_equal_helper(x[[0]], np.array([-2]))
      x[[0]] = -1
      numpy_testing_assert_equal_helper(x[ri([0]), ], np.array([-1]))
      x[[2, 3, 4]] = 4
      numpy_testing_assert_equal_helper(x[[2, 3, 4]], np.array([4, 4, 4]))
      x[ri([2, 3, 4]), ] = 3
      numpy_testing_assert_equal_helper(x[ri([2, 3, 4]), ], np.array([3, 3, 3]))
      x[ri([0, 2, 4]), ] = Tensor([5, 4, 3])
      numpy_testing_assert_equal_helper(x[ri([0, 2, 4]), ], np.array([5, 4, 3]))

    # Case 1: Purely Integer Array Indexing
    reference = consec((10,))
    validate_indexing(reference)
    # setting values
    validate_setting(reference)

    # Tensor with stride != 1
    # strided is [1, 3, 5, 7]

    reference = consec((10,))
    strided = set_(reference, (4,), (2,), 0)

    numpy_testing_assert_equal_helper(strided[[0]], np.array([1]))
    numpy_testing_assert_equal_helper(strided[ri([0]), ], np.array([1]))
    numpy_testing_assert_equal_helper(strided[ri([3]), ], np.array([7]))
    numpy_testing_assert_equal_helper(strided[[1, 2]], np.array([3, 5]))
    numpy_testing_assert_equal_helper(strided[ri([1, 2]), ], np.array([3, 5]))
    numpy_testing_assert_equal_helper(strided[ri([[2, 1], [0, 3]]), ],
                      np.array([[5, 3], [1, 7]]))

    # stride is [4, 8]

    strided = set_(reference, (2,), (4,), offset=4)

    numpy_testing_assert_equal_helper(strided[[0]], np.array([5]))
    numpy_testing_assert_equal_helper(strided[ri([0]), ], np.array([5]))
    numpy_testing_assert_equal_helper(strided[ri([1]), ], np.array([9]))
    numpy_testing_assert_equal_helper(strided[[0, 1]], np.array([5, 9]))
    numpy_testing_assert_equal_helper(strided[ri([0, 1]), ], np.array([5, 9]))
    numpy_testing_assert_equal_helper(strided[ri([[0, 1], [1, 0]]), ],
                      np.array([[5, 9], [9, 5]]))

    # reference is 1 2
    #              3 4
    #              5 6
    reference = consec((3, 2))
    numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([0])], np.array([1, 3, 5]))
    numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([1])], np.array([2, 4, 6]))
    numpy_testing_assert_equal_helper(reference[ri([0]), ri([0])], consec((1,)))
    numpy_testing_assert_equal_helper(reference[ri([2]), ri([1])], consec((1,), 6))
    numpy_testing_assert_equal_helper(reference[[ri([0, 0]), ri([0, 1])]], np.array([1, 2]))
    numpy_testing_assert_equal_helper(reference[[ri([0, 1, 1, 0, 2]), ri([1])]], np.array([2, 4, 4, 2, 6]))
    numpy_testing_assert_equal_helper(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], np.array([1, 2, 3, 3]))

    rows = ri([[0, 0],
               [1, 2]])
    columns = [0],
    numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[1, 1],
                                                                          [3, 5]]))

    rows = ri([[0, 0],
               [1, 2]])
    columns = ri([1, 0])
    numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[2, 1],
                                                                          [4, 5]]))
    rows = ri([[0, 0],
               [1, 2]])
    columns = ri([[0, 1],
                  [1, 0]])
    numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[1, 2],
                                                                          [4, 5]]))

    # setting values
    reference[ri([0]), ri([1])] = -1
    numpy_testing_assert_equal_helper(reference[ri([0]), ri([1])], np.array([-1]))
    reference[ri([0, 1, 2]), ri([0])] = Tensor([-1, 2, -4])
    numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([0])],
                      np.array([-1, 2, -4]))
    reference[rows, columns] = Tensor([[4, 6], [2, 3]])
    numpy_testing_assert_equal_helper(reference[rows, columns],
                      np.array([[4, 6], [2, 3]]))

    # Verify still works with Transposed (i.e. non-contiguous) Tensors
    reference = Tensor([[0, 1, 2, 3],
                        [4, 5, 6, 7],
                        [8, 9, 10, 11]]).T

    # Transposed: [[0, 4, 8],
    #              [1, 5, 9],
    #              [2, 6, 10],
    #              [3, 7, 11]]

    numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([0])], np.array([0, 1, 2]))
    numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([1])], np.array([4, 5, 6]))
    numpy_testing_assert_equal_helper(reference[ri([0]), ri([0])], np.array([0]))
    numpy_testing_assert_equal_helper(reference[ri([2]), ri([1])], np.array([6]))
    numpy_testing_assert_equal_helper(reference[[ri([0, 0]), ri([0, 1])]], np.array([0, 4]))
    numpy_testing_assert_equal_helper(reference[[ri([0, 1, 1, 0, 3]), ri([1])]], np.array([4, 5, 5, 4, 7]))
    numpy_testing_assert_equal_helper(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], np.array([0, 4, 1, 1]))

    rows = ri([[0, 0],
               [1, 2]])
    columns = [0],
    numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[0, 0], [1, 2]]))

    rows = ri([[0, 0],
               [1, 2]])
    columns = ri([1, 0])
    numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[4, 0], [5, 2]]))
    rows = ri([[0, 0],
               [1, 3]])
    columns = ri([[0, 1],
                  [1, 2]])
    numpy_testing_assert_equal_helper(reference[rows, columns], np.array([[0, 4], [5, 11]]))

    # TODO: non contiguous setitem
    '''
    # setting values
    reference[ri([0]), ri([1])] = -1
    numpy_testing_assert_equal_helper(reference[ri([0]), ri([1])],
                      np.array([-1]))
    reference[ri([0, 1, 2]), ri([0])] = np.array([-1, 2, -4])
    numpy_testing_assert_equal_helper(reference[ri([0, 1, 2]), ri([0])],
                      np.array([-1, 2, -4]))
    reference[rows, columns] = np.array([[4, 6], [2, 3]])
    numpy_testing_assert_equal_helper(reference[rows, columns],
                      np.array([[4, 6], [2, 3]]))
    '''

    # stride != 1

    # strided is [[1 3 5 7],
    #             [9 11 13 15]]

    reference = Tensor.arange(0., 24).reshape(3, 8)
    strided = set_(reference, (2,4), (8,2), 1)

    numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([0])],
                      np.array([1, 9]))
    numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([1])],
                      np.array([3, 11]))
    numpy_testing_assert_equal_helper(strided[ri([0]), ri([0])],
                      np.array([1]))
    numpy_testing_assert_equal_helper(strided[ri([1]), ri([3])],
                      np.array([15]))
    numpy_testing_assert_equal_helper(strided[[ri([0, 0]), ri([0, 3])]],
                      np.array([1, 7]))
    numpy_testing_assert_equal_helper(strided[[ri([1]), ri([0, 1, 1, 0, 3])]],
                      np.array([9, 11, 11, 9, 15]))
    numpy_testing_assert_equal_helper(strided[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]],
                      np.array([1, 3, 9, 9]))

    rows = ri([[0, 0],
                [1, 1]])
    columns = [0],
    numpy_testing_assert_equal_helper(strided[rows, columns],
                      np.array([[1, 1], [9, 9]]))

    rows = ri([[0, 1],
                [1, 0]])
    columns = ri([1, 2])
    numpy_testing_assert_equal_helper(strided[rows, columns],
                      np.array([[3, 13], [11, 5]]))
    rows = ri([[0, 0],
                [1, 1]])
    columns = ri([[0, 1],
                  [1, 2]])
    numpy_testing_assert_equal_helper(strided[rows, columns],
                      np.array([[1, 3], [11, 13]]))

    # setting values

    # strided is [[10, 11],
    #             [17, 18]]

    reference = Tensor.arange(0., 24).reshape(3, 8)
    strided = set_(reference, (2,2), (7,1), 10)

    numpy_testing_assert_equal_helper(strided[ri([0]), ri([1])],
                      np.array([11]))
    # TODO non contiguous setitem
    '''
    strided[ri([0]), ri([1])] = -1
    numpy_testing_assert_equal_helper(strided[ri([0]), ri([1])],
                      Tensor([-1]))
    '''
    reference = Tensor.arange(0., 24).reshape(3, 8)
    strided = set_(reference, (2,2), (7,1), 10)

    numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([1, 0])],
                      np.array([11, 17]))
    # TODO non contiguous setitem
    '''
    strided[ri([0, 1]), ri([1, 0])] = Tensor([-1, 2])
    numpy_testing_assert_equal_helper(strided[ri([0, 1]), ri([1, 0])],
                      Tensor([-1, 2]))
    '''

    reference = Tensor.arange(0., 24).realize().reshape(3, 8)
    strided = set_(reference, (2,2), (7,1), 10)

    rows = ri([[0],
                [1]])
    columns = ri([[0, 1],
                  [0, 1]])
    numpy_testing_assert_equal_helper(strided[rows, columns],
                      np.array([[10, 11], [17, 18]]))
    # TODO non contiguous setitem
    '''
    strided[rows, columns] = Tensor([[4, 6], [2, 3]])
    numpy_testing_assert_equal_helper(strided[rows, columns],
                      Tensor([[4, 6], [2, 3]]))
    '''

    # Tests using less than the number of dims, and ellipsis

    # reference is 1 2
    #              3 4
    #              5 6
    reference = consec((3, 2))
    numpy_testing_assert_equal_helper(reference[ri([0, 2]),], np.array([[1, 2], [5, 6]]))
    numpy_testing_assert_equal_helper(reference[ri([1]), ...], np.array([[3, 4]]))
    numpy_testing_assert_equal_helper(reference[..., ri([1])], np.array([[2], [4], [6]]))

    # verify too many indices fails
    with self.assertRaises(IndexError): reference[ri([1]), ri([0, 2]), ri([3])]

    # test invalid index fails
    reference = Tensor.empty(10)
    for err_idx in (10, -11):
      with self.assertRaises(IndexError):
        reference[err_idx]
      # NOTE cannot check for out of bounds with Tensor indexing
      # see tensor.py: __getitem__ (Tiny Things)
      '''
      with self.assertRaises(IndexError):
        reference[Tensor([err_idx], dtype=dtypes.int64)]
      with self.assertRaises(IndexError):
        reference[[err_idx]]
      '''

    def tensor_indices_to_np(tensor: Tensor, indices):
      npt = tensor.numpy()
      idxs = tuple(i.numpy().tolist() if isinstance(i, Tensor) and i.dtype == dtypes.int64 else
                  i for i in indices)
      return npt, idxs

    def get_numpy(tensor, indices):
      npt, idxs = tensor_indices_to_np(tensor, indices)
      return Tensor(npt[idxs])

    def set_numpy(tensor:Tensor, indices, value):
      if not isinstance(value, int):
        value = value.numpy()
      npt, idxs = tensor_indices_to_np(tensor, indices)
      npt[idxs] = value
      return npt

    def assert_get_eq(tensor, indexer):
      numpy_testing_assert_equal_helper(tensor[indexer], get_numpy(tensor, indexer))

    def assert_set_eq(tensor: Tensor, indexer, val):
      pyt = clone(tensor)
      numt = clone(tensor)
      pyt[indexer] = val
      numt = set_numpy(numt, indexer, val)
      numpy_testing_assert_equal_helper(pyt, numt)

    # NOTE: torch initiates the gradients using g0cpu (rand as gradients)
    def assert_backward_eq(tensor: Tensor, indexer):
      cpu = clone(tensor.float())
      cpu.requires_grad = True
      outcpu = cpu[indexer].sum()
      outcpu.backward()
      dev = cpu.detach()
      dev.requires_grad = True
      outdev = dev[indexer].sum()
      outdev.backward()
      numpy_testing_assert_equal_helper(cpu.grad, dev.grad)

    def get_set_tensor(indexed: Tensor, indexer):
      set_size = indexed[indexer].shape
      set_count = indexed[indexer].numel()
      set_tensor = Tensor.randint(set_count, high=set_count).reshape(set_size) #.cast(dtypes.float64)
      return set_tensor

    # Tensor is  0  1  2  3  4
    #            5  6  7  8  9
    #           10 11 12 13 14
    #           15 16 17 18 19
    reference = Tensor.arange(0., 20).reshape(4, 5)

    indices_to_test = [
      # grab the second, fourth columns
      [slice(None), [1, 3]],

      # first, third rows,
      [[0, 2], slice(None)],

      # weird shape
      [slice(None), [[0, 1],
                      [2, 3]]],
      # negatives
      [[-1], [0]],
      [[0, 2], [-1]],
      [slice(None), [-1]],
    ]

    # only test dupes on gets
    get_indices_to_test = indices_to_test + [[slice(None), [0, 1, 1, 2, 2]]]

    for indexer in get_indices_to_test:
      assert_get_eq(reference, indexer)
      assert_backward_eq(reference, indexer)

    for indexer in indices_to_test:
      assert_set_eq(reference, indexer, 44)
      assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))

    reference = Tensor.arange(0., 160).reshape(4, 8, 5)

    indices_to_test = [
      [slice(None), slice(None), [0, 3, 4]],
      [slice(None), [2, 4, 5, 7], slice(None)],
      [[2, 3], slice(None), slice(None)],
      [slice(None), [0, 2, 3], [1, 3, 4]],
      [slice(None), [0], [1, 2, 4]],
      [slice(None), [0, 1, 3], [4]],
      [slice(None), [[0, 1], [1, 0]], [[2, 3]]],
      [slice(None), [[0, 1], [2, 3]], [[0]]],
      [slice(None), [[5, 6]], [[0, 3], [4, 4]]],
      [[0, 2, 3], [1, 3, 4], slice(None)],
      [[0], [1, 2, 4], slice(None)],
      [[0, 1, 3], [4], slice(None)],
      [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)],
      [[[0, 1], [1, 0]], [[2, 3]], slice(None)],
      [[[0, 1], [2, 3]], [[0]], slice(None)],
      [[[2, 1]], [[0, 3], [4, 4]], slice(None)],
      [[[2]], [[0, 3], [4, 1]], slice(None)],
      # non-contiguous indexing subspace
      [[0, 2, 3], slice(None), [1, 3, 4]],

      # less dim, ellipsis
      [[0, 2], ],
      [[0, 2], slice(None)],
      [[0, 2], Ellipsis],
      [[0, 2], slice(None), Ellipsis],
      [[0, 2], Ellipsis, slice(None)],
      [[0, 2], [1, 3]],
      [[0, 2], [1, 3], Ellipsis],
      [Ellipsis, [1, 3], [2, 3]],
      [Ellipsis, [2, 3, 4]],
      [Ellipsis, slice(None), [2, 3, 4]],
      [slice(None), Ellipsis, [2, 3, 4]],

      # ellipsis counts for nothing
      [Ellipsis, slice(None), slice(None), [0, 3, 4]],
      [slice(None), Ellipsis, slice(None), [0, 3, 4]],
      [slice(None), slice(None), Ellipsis, [0, 3, 4]],
      [slice(None), slice(None), [0, 3, 4], Ellipsis],
      [Ellipsis, [[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)],
      [[[0, 1], [1, 0]], [[2, 1], [3, 5]], Ellipsis, slice(None)],
      [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None), Ellipsis],
    ]

    for indexer in indices_to_test:
      assert_get_eq(reference, indexer)

      assert_set_eq(reference, indexer, 212)
      assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
      assert_backward_eq(reference, indexer)

    reference = Tensor.arange(0., 1296).reshape(3, 9, 8, 6)

    indices_to_test = [
      [slice(None), slice(None), slice(None), [0, 3, 4]],
      [slice(None), slice(None), [2, 4, 5, 7], slice(None)],
      [slice(None), [2, 3], slice(None), slice(None)],
      [[1, 2], slice(None), slice(None), slice(None)],
      [slice(None), slice(None), [0, 2, 3], [1, 3, 4]],
      [slice(None), slice(None), [0], [1, 2, 4]],
      [slice(None), slice(None), [0, 1, 3], [4]],
      [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3]]],
      [slice(None), slice(None), [[0, 1], [2, 3]], [[0]]],
      [slice(None), slice(None), [[5, 6]], [[0, 3], [4, 4]]],
      [slice(None), [0, 2, 3], [1, 3, 4], slice(None)],
      [slice(None), [0], [1, 2, 4], slice(None)],
      [slice(None), [0, 1, 3], [4], slice(None)],
      [slice(None), [[0, 1], [3, 4]], [[2, 3], [0, 1]], slice(None)],
      [slice(None), [[0, 1], [3, 4]], [[2, 3]], slice(None)],
      [slice(None), [[0, 1], [3, 2]], [[0]], slice(None)],
      [slice(None), [[2, 1]], [[0, 3], [6, 4]], slice(None)],
      [slice(None), [[2]], [[0, 3], [4, 2]], slice(None)],
      [[0, 1, 2], [1, 3, 4], slice(None), slice(None)],
      [[0], [1, 2, 4], slice(None), slice(None)],
      [[0, 1, 2], [4], slice(None), slice(None)],
      [[[0, 1], [0, 2]], [[2, 4], [1, 5]], slice(None), slice(None)],
      [[[0, 1], [1, 2]], [[2, 0]], slice(None), slice(None)],
      [[[2, 2]], [[0, 3], [4, 5]], slice(None), slice(None)],
      [[[2]], [[0, 3], [4, 5]], slice(None), slice(None)],
      [slice(None), [3, 4, 6], [0, 2, 3], [1, 3, 4]],
      [slice(None), [2, 3, 4], [1, 3, 4], [4]],
      [slice(None), [0, 1, 3], [4], [1, 3, 4]],
      [slice(None), [6], [0, 2, 3], [1, 3, 4]],
      [slice(None), [2, 3, 5], [3], [4]],
      [slice(None), [0], [4], [1, 3, 4]],
      [slice(None), [6], [0, 2, 3], [1]],
      [slice(None), [[0, 3], [3, 6]], [[0, 1], [1, 3]], [[5, 3], [1, 2]]],
      [[2, 2, 1], [0, 2, 3], [1, 3, 4], slice(None)],
      [[2, 0, 1], [1, 2, 3], [4], slice(None)],
      [[0, 1, 2], [4], [1, 3, 4], slice(None)],
      [[0], [0, 2, 3], [1, 3, 4], slice(None)],
      [[0, 2, 1], [3], [4], slice(None)],
      [[0], [4], [1, 3, 4], slice(None)],
      [[1], [0, 2, 3], [1], slice(None)],
      [[[1, 2], [1, 2]], [[0, 1], [2, 3]], [[2, 3], [3, 5]], slice(None)],

      # less dim, ellipsis
      [Ellipsis, [0, 3, 4]],
      [Ellipsis, slice(None), [0, 3, 4]],
      [Ellipsis, slice(None), slice(None), [0, 3, 4]],
      [slice(None), Ellipsis, [0, 3, 4]],
      [slice(None), slice(None), Ellipsis, [0, 3, 4]],
      [slice(None), [0, 2, 3], [1, 3, 4]],
      [slice(None), [0, 2, 3], [1, 3, 4], Ellipsis],
      [Ellipsis, [0, 2, 3], [1, 3, 4], slice(None)],
      [[0], [1, 2, 4]],
      [[0], [1, 2, 4], slice(None)],
      [[0], [1, 2, 4], Ellipsis],
      [[0], [1, 2, 4], Ellipsis, slice(None)],
      [[1], ],
      [[0, 2, 1], [3], [4]],
      [[0, 2, 1], [3], [4], slice(None)],
      [[0, 2, 1], [3], [4], Ellipsis],
      [Ellipsis, [0, 2, 1], [3], [4]],
    ]

    for indexer in indices_to_test:
      assert_get_eq(reference, indexer)
      assert_set_eq(reference, indexer, 1333)
      assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))

    indices_to_test += [
      [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3], [3, 0]]],
      [slice(None), slice(None), [[2]], [[0, 3], [4, 4]]],
    ]
    for indexer in indices_to_test:
      assert_get_eq(reference, indexer)
      assert_set_eq(reference, indexer, 1333)
      assert_backward_eq(reference, indexer)

  # TODO setitem backward
  '''
  def test_set_item_to_scalar_tensor(self):
    m = random.randint(1, 10)
    n = random.randint(1, 10)
    z = Tensor.randn([m, n])
    a = 1.0
    w = Tensor(a, requires_grad=True)
    z[:, 0] = w
    z.sum().backward()
    numpy_testing_assert_equal_helper(w.grad, m * a)
  '''

  def test_single_int(self):
    v = Tensor.randn(5, 7, 3)
    numpy_testing_assert_equal_helper(v[4].shape, (7, 3))

  def test_multiple_int(self):
    v = Tensor.randn(5, 7, 3)
    numpy_testing_assert_equal_helper(v[4].shape, (7, 3))
    numpy_testing_assert_equal_helper(v[4, :, 1].shape, (7,))

  def test_none(self):
    v = Tensor.randn(5, 7, 3)
    numpy_testing_assert_equal_helper(v[None].shape, (1, 5, 7, 3))
    numpy_testing_assert_equal_helper(v[:, None].shape, (5, 1, 7, 3))
    numpy_testing_assert_equal_helper(v[:, None, None].shape, (5, 1, 1, 7, 3))
    numpy_testing_assert_equal_helper(v[..., None].shape, (5, 7, 3, 1))

  def test_step(self):
    v = Tensor.arange(10)
    numpy_testing_assert_equal_helper(v[::1], v)
    numpy_testing_assert_equal_helper(v[::2], [0, 2, 4, 6, 8])
    numpy_testing_assert_equal_helper(v[::3], [0, 3, 6, 9])
    numpy_testing_assert_equal_helper(v[::11], [0])
    numpy_testing_assert_equal_helper(v[1:6:2], [1, 3, 5])

  def test_step_assignment(self):
    v = Tensor.zeros(4, 4).contiguous()
    v[0, 1::2] = Tensor([3., 4.])
    numpy_testing_assert_equal_helper(v[0].numpy().tolist(), [0, 3, 0, 4])
    numpy_testing_assert_equal_helper(v[1:].sum(), 0)

  @unittest.skip("bool indexing not supported")
  def test_bool_indices(self):
    v = Tensor.randn(5, 7, 3)
    boolIndices = Tensor([True, False, True, True, False], dtype=dtypes.bool)
    numpy_testing_assert_equal_helper(v[boolIndices].shape, (3, 7, 3))
    numpy_testing_assert_equal_helper(v[boolIndices], Tensor.stack([v[0], v[2], v[3]]))

    v = Tensor([True, False, True], dtype=dtypes.bool)
    boolIndices = Tensor([True, False, False], dtype=dtypes.bool)
    uint8Indices = Tensor([1, 0, 0], dtype=dtypes.uint8)
    with warnings.catch_warnings(record=True) as w:
      numpy_testing_assert_equal_helper(v[boolIndices].shape, v[uint8Indices].shape)
      numpy_testing_assert_equal_helper(v[boolIndices], v[uint8Indices])
      numpy_testing_assert_equal_helper(v[boolIndices], Tensor([True]))
      numpy_testing_assert_equal_helper(len(w), 2)

  @unittest.skip("bool indexing not supported")
  def test_bool_indices_accumulate(self):
    mask = Tensor.zeros(size=(10, ), dtype=dtypes.bool)
    y = Tensor.ones(size=(10, 10))
    index_put_(y, (mask, ), y[mask], accumulate=True)
    numpy_testing_assert_equal_helper(y, Tensor.ones(size=(10, 10)))

  @unittest.skip("bool indexing not supported")
  def test_multiple_bool_indices(self):
    v = Tensor.randn(5, 7, 3)
    # note: these broadcast together and are transposed to the first dim
    mask1 = Tensor([1, 0, 1, 1, 0], dtype=dtypes.bool)
    mask2 = Tensor([1, 1, 1], dtype=dtypes.bool)
    numpy_testing_assert_equal_helper(v[mask1, :, mask2].shape, (3, 7))

  @unittest.skip("bool indexing not supported")
  def test_byte_mask(self):
    v = Tensor.randn(5, 7, 3)
    mask = Tensor([1, 0, 1, 1, 0], dtype=dtypes.uint8)
    with warnings.catch_warnings(record=True) as w:
      numpy_testing_assert_equal_helper(v[mask].shape, (3, 7, 3))
      numpy_testing_assert_equal_helper(v[mask], Tensor.stack([v[0], v[2], v[3]]))
      numpy_testing_assert_equal_helper(len(w), 2)

    v = Tensor([1.])
    numpy_testing_assert_equal_helper(v[v == 0], Tensor([]))

  @unittest.skip("bool indexing not supported")
  def test_byte_mask_accumulate(self):
    mask = Tensor.zeros(size=(10, ), dtype=dtypes.uint8)
    y = Tensor.ones(size=(10, 10))
    with warnings.catch_warnings(record=True) as w:
      warnings.simplefilter("always")
      index_put_(y, (mask, ), y[mask], accumulate=True)
      numpy_testing_assert_equal_helper(y, Tensor.ones(size=(10, 10)))
      numpy_testing_assert_equal_helper(len(w), 2)

  # TODO setitem
  # NOTE: tinygrad doesn't support idx.max that big
  '''
  def test_index_put_accumulate_large_tensor(self):
    # This test is for tensors with number of elements >= INT_MAX (2^31 - 1).
    N = (1 << 31) + 5
    dt = dtypes.int8
    a = Tensor.ones(N, dtype=dt).contiguous()
    indices = Tensor([-2, 0, -2, -1, 0, -1, 1], dtype=dtypes.int64)
    values = Tensor([6, 5, 6, 6, 5, 7, 11], dtype=dt)

    index_put_(a, (indices, ), values, accumulate=True)

    numpy_testing_assert_equal_helper(a[0], 11)
    numpy_testing_assert_equal_helper(a[1], 12)
    numpy_testing_assert_equal_helper(a[2], 1)
    numpy_testing_assert_equal_helper(a[-3], 1)
    numpy_testing_assert_equal_helper(a[-2], 13)
    numpy_testing_assert_equal_helper(a[-1], 14)

    a = Tensor.ones((2, N), dtype=dt).contiguous()
    indices0 = np.array([0, -1, 0, 1], dtype=dtypes.int64)
    indices1 = np.array([-2, -1, 0, 1], dtype=dtypes.int64)
    values = np.array([12, 13, 10, 11], dtype=dt)

    index_put_(a, (indices0, indices1), values, accumulate=True)

    numpy_testing_assert_equal_helper(a[0, 0], 11)
    numpy_testing_assert_equal_helper(a[0, 1], 1)
    numpy_testing_assert_equal_helper(a[1, 0], 1)
    numpy_testing_assert_equal_helper(a[1, 1], 12)
    numpy_testing_assert_equal_helper(a[:, 2], Tensor.ones(2, dtype=dtypes.int8))
    numpy_testing_assert_equal_helper(a[:, -3], Tensor.ones(2, dtype=dtypes.int8))
    numpy_testing_assert_equal_helper(a[0, -2], 13)
    numpy_testing_assert_equal_helper(a[1, -2], 1)
    numpy_testing_assert_equal_helper(a[-1, -1], 14)
    numpy_testing_assert_equal_helper(a[0, -1], 1)
  '''

  # TODO fancy setitem
  '''
  def test_index_put_accumulate_duplicate_indices(self):
    for i in range(1, 512):
      # generate indices by random walk, this will create indices with
      # lots of duplicates interleaved with each other
      delta = Tensor.uniform(low=-1, high=1, dtype=dtypes.double)
      indices = delta.cumsum(0).cast(dtypes.int64)

      # input = torch.randn(indices.abs().max() + 1)
      input = Tensor.randn(indices.abs().max().item() + 1)
      # values = torch.randn(indices.size(0))
      values = Tensor.randn(indices.shape(0))
      output = index_put_(input, (indices,), values, accumulate=True)

      input_list = input.numpy().tolist()
      indices_list = indices.numpy().tolist()
      values_list = values.numpy().tolist()
      for i, v in zip(indices_list, values_list):
        input_list[i] += v

      numpy_testing_assert_equal_helper(output, input_list)
  '''

  @unittest.skipUnless(is_dtype_supported(dtypes.long), f"long dtype not supported on {Device.DEFAULT}")
  def test_index_ind_dtype(self):
    x = Tensor.randn(4, 4)
    # ind_long = torch.randint(4, (4,), dtype=torch.long)
    # TODO should we spend an extra line to allow for randint other dtypes?
    # copied from randint
    ind_long = (Tensor.rand((4,),)*(4-0)+0).cast(dtypes.int64)
    # ind_int = ind_long.int()
    ind_int = (ind_long).cast(dtypes.int32)
    ref = x[ind_long, ind_long]
    res = x[ind_int, ind_int]
    numpy_testing_assert_equal_helper(ref, res)
    ref = x[ind_long, :]
    res = x[ind_int, :]
    numpy_testing_assert_equal_helper(ref, res)
    ref = x[:, ind_long]
    res = x[:, ind_int]
    numpy_testing_assert_equal_helper(ref, res)
    # no repeating indices for index_put
    # TODO fancy setitem
    '''
    src = Tensor.randn(4)
    ind_long = Tensor.arange(4, dtype=dtypes.int64)
    ind_int = ind_long.cast(dtypes.int32)
    for accum in (True, False):
      inp_ref = clone(x)
      inp_res = clone(x)
      index_put_(inp_ref, (ind_long, ind_long), src, accum)
      index_put_(inp_res, (ind_int, ind_int), src, accum)
      numpy_testing_assert_equal_helper(inp_ref, inp_res)
    '''

  # TODO empty setitem
  '''
  def test_index_put_accumulate_empty(self):
    # Regression test for https://github.com/pytorch/pytorch/issues/94667
    input = Tensor.rand([], dtype=dtypes.float32)
    with self.assertRaises(RuntimeError):
      index_put_(input, [], np.array([1.0]), True)
  '''

  @unittest.skip("bool indexing not supported")
  def test_multiple_byte_mask(self):
    v = Tensor.randn(5, 7, 3)
    # note: these broadcast together and are transposed to the first dim
    mask1 = Tensor([1, 0, 1, 1, 0], dtype=dtypes.uint8)
    mask2 = Tensor([1, 1, 1], dtype=dtypes.uint8)
    with warnings.catch_warnings(record=True) as w:
      warnings.simplefilter("always")
      numpy_testing_assert_equal_helper(v[mask1, :, mask2].shape, (3, 7))
      numpy_testing_assert_equal_helper(len(w), 2)

  @unittest.skip("bool indexing not supported")
  def test_byte_mask2d(self):
    v = Tensor.randn(5, 7, 3)
    c = Tensor.randn(5, 7)
    num_ones = (c > 0).sum()
    r = v[c > 0]
    numpy_testing_assert_equal_helper(r.shape, (num_ones, 3))

  @unittest.skip("bool indexing not supported")
  def test_jit_indexing(self):
    def fn1(x):
      x[x < 50] = 1.0
      return x

    def fn2(x):
      x[0:50] = 1.0
      return x

    scripted_fn1 = TinyJit(fn1)
    scripted_fn2 = TinyJit(fn2)
    data = Tensor.arange(100, dtype=dtypes.float)
    out = scripted_fn1(clone(data))
    ref = Tensor(np.concatenate((np.ones(50), np.arange(50, 100))), dtype=dtypes.float)
    numpy_testing_assert_equal_helper(out, ref)
    out = scripted_fn2(clone(data))
    numpy_testing_assert_equal_helper(out, ref)

  def test_int_indices(self):
    v = Tensor.randn(5, 7, 3)
    numpy_testing_assert_equal_helper(v[[0, 4, 2]].shape, (3, 7, 3))
    numpy_testing_assert_equal_helper(v[:, [0, 4, 2]].shape, (5, 3, 3))
    numpy_testing_assert_equal_helper(v[:, [[0, 1], [4, 3]]].shape, (5, 2, 2, 3))

  # TODO fancy setitem
  '''
  def test_index_put_src_datatype(self, dtype):
    src = Tensor.ones(3, 2, 4, dtype=dtype)
    vals = Tensor.ones(3, 2, 4, dtype=dtype)
    indices = (np.array([0, 2, 1]),)
    res = index_put_(src, indices, vals, accumulate=True)
    numpy_testing_assert_equal_helper(res.shape, src.shape)
  '''

  def test_index_src_datatype(self):
    src = Tensor.ones(3, 2, 4)
    # test index
    res = src[[0, 2, 1], :, :]
    numpy_testing_assert_equal_helper(res.shape, src.shape)
    # test index_put, no accum
    # TODO fancy setitem
    '''
    src[[0, 2, 1], :, :] = res
    numpy_testing_assert_equal_helper(res.shape, src.shape)
    '''

  def test_int_indices2d(self):
    # From the NumPy indexing example
    x = Tensor.arange(0, 12).reshape(4, 3)
    rows = Tensor([[0, 0], [3, 3]])
    columns = Tensor([[0, 2], [0, 2]])
    numpy_testing_assert_equal_helper(x[rows, columns].numpy().tolist(), [[0, 2], [9, 11]])

  def test_int_indices_broadcast(self):
    # From the NumPy indexing example
    x = Tensor.arange(0, 12).reshape(4, 3)
    rows = Tensor([0, 3])
    columns = Tensor([0, 2])
    result = x[rows[:, None], columns]
    numpy_testing_assert_equal_helper(result.numpy().tolist(), [[0, 2], [9, 11]])

  # TODO jax supports empty tensor indexing
  @unittest.skip("empty tensor indexing not supported")
  def test_empty_index(self):
    x = Tensor.arange(0, 12).reshape(4, 3)
    idx = Tensor([], dtype=dtypes.int64)
    numpy_testing_assert_equal_helper(x[idx].numel(), 0)

    # TODO empty setitem
    '''
    # empty assignment should have no effect but not throw an exception
    y = clone(x)
    y[idx] = -1
    numpy_testing_assert_equal_helper(x, y)

    mask = Tensor.zeros(4, 3).cast(dtypes.bool)
    y[mask] = -1
    numpy_testing_assert_equal_helper(x, y)
    '''

  # TODO jax supports empty tensor indexing
  @unittest.skip("empty tensor indexing not supported")
  def test_empty_ndim_index(self):
    x = Tensor.randn(5)
    numpy_testing_assert_equal_helper(Tensor.empty(0, 2), x[Tensor.empty(0, 2, dtype=dtypes.int64)])

    x = Tensor.randn(2, 3, 4, 5)
    numpy_testing_assert_equal_helper(Tensor.empty(2, 0, 6, 4, 5),
                      x[:, Tensor.empty(0, 6, dtype=dtypes.int64)])

    x = Tensor.empty(10, 0)
    numpy_testing_assert_equal_helper(x[[1, 2]].shape, (2, 0))
    numpy_testing_assert_equal_helper(x[[], []].shape, (0,))
    with self.assertRaises(IndexError):
      x[:, [0, 1]]

  def test_empty_slice(self):
    x = Tensor.randn(2, 3, 4, 5)
    y = x[:, :, :, 1]
    z = y[:, 1:1, :]
    numpy_testing_assert_equal_helper((2, 0, 4), z.shape)
    # this isn't technically necessary, but matches NumPy stride calculations.
    # NOTE: this is empty and shouldn't have strides
    #numpy_testing_assert_equal_helper((60, 20, 5), z.lazydata.st.real_strides())
    # NOTE tinygrad's int slicing implementation makes this not contiguous
    # self.assertTrue(z.lazydata.st.contiguous)

  @unittest.skip("bool indexing not supported")
  def test_index_getitem_copy_bools_slices(self):
    true = Tensor(1, dtype=dtypes.uint8)
    false = Tensor(0, dtype=dtypes.uint8)

    tensors = [Tensor.randn(2, 3), Tensor(3.)]

    for a in tensors:
      self.assertNotEqual(data_ptr(a), data_ptr(a[True]))
      numpy_testing_assert_equal_helper(Tensor.empty(0, *a.shape), a[False])
      self.assertNotEqual(data_ptr(a), data_ptr(a[true]))
      numpy_testing_assert_equal_helper(Tensor.empty(0, *a.shape), a[false])
      self.assertEqual(data_ptr(a), data_ptr(a[None]))
      self.assertEqual(data_ptr(a), data_ptr(a[...]))

  @unittest.skip("bool indexing not supported")
  def test_index_setitem_bools_slices(self):
    true = Tensor(1, dtype=dtypes.uint8)
    false = Tensor(0, dtype=dtypes.uint8)

    tensors = [Tensor.randn(2, 3), Tensor(3)]

    for a in tensors:
      # prefix with a 1,1, to ensure we are compatible with numpy which cuts off prefix 1s
      # (some of these ops already prefix a 1 to the size)
      neg_ones = Tensor.ones_like(a) * -1
      neg_ones_expanded = neg_ones.unsqueeze(0).unsqueeze(0)
      a[True] = neg_ones_expanded
      numpy_testing_assert_equal_helper(a, neg_ones)
      a[False] = 5
      numpy_testing_assert_equal_helper(a, neg_ones)
      a[true] = neg_ones_expanded * 2
      numpy_testing_assert_equal_helper(a, neg_ones * 2)
      a[false] = 5
      numpy_testing_assert_equal_helper(a, neg_ones * 2)
      a[None] = neg_ones_expanded * 3
      numpy_testing_assert_equal_helper(a, neg_ones * 3)
      a[...] = neg_ones_expanded * 4
      numpy_testing_assert_equal_helper(a, neg_ones * 4)
      if a.dim() == 0:
        with self.assertRaises(IndexError):
          a[:] = neg_ones_expanded * 5

  @unittest.skip("bool indexing not supported")
  def test_index_scalar_with_bool_mask(self):
    a = Tensor(1)
    uintMask = Tensor(True, dtype=dtypes.uint8)
    boolMask = Tensor(True, dtype=dtypes.bool)
    numpy_testing_assert_equal_helper(a[uintMask], a[boolMask])
    numpy_testing_assert_equal_helper(a[uintMask].dtype, a[boolMask].dtype)

    a = Tensor(True, dtype=dtypes.bool)
    numpy_testing_assert_equal_helper(a[uintMask], a[boolMask])
    numpy_testing_assert_equal_helper(a[uintMask].dtype, a[boolMask].dtype)

  @unittest.skip("bool indexing not supported")
  def test_setitem_expansion_error(self):
    true = Tensor(True)
    a = Tensor.randn(2, 3)
    # check prefix with  non-1s doesn't work
    # a_expanded = a.expand(torch.Size([5, 1]) + a.size())
    a_expanded = a.expand((5, 1) + a.shape)
    # NumPy: ValueError
    with self.assertRaises(RuntimeError):
      a[True] = a_expanded
    with self.assertRaises(RuntimeError):
      a[true] = a_expanded

  def test_getitem_scalars_simple(self):
    src = Tensor([[[1.,2.],[3.,4.]], [[1,1],[1,1]]])
    a = src[0].mul(src[1])
    self.assertEqual(a[0,1].item(), 2)

  def test_getitem_scalars(self):
    zero = Tensor(0, dtype=dtypes.int64)
    one = Tensor(1, dtype=dtypes.int64)

    # non-scalar indexed with scalars
    a = Tensor.randn(2, 3)
    numpy_testing_assert_equal_helper(a[0], a[zero])
    numpy_testing_assert_equal_helper(a[0][1], a[zero][one])
    numpy_testing_assert_equal_helper(a[0, 1], a[zero, one])
    numpy_testing_assert_equal_helper(a[0, one], a[zero, 1])

    # indexing by a scalar should slice (not copy)
    numpy_testing_assert_equal_helper(a[0, 1], a[zero, one])
    numpy_testing_assert_equal_helper(a[1], a[one.cast(dtypes.int32)])
    numpy_testing_assert_equal_helper(a[1], a[one.cast(dtypes.int16)])

    # scalar indexed with scalar
    r = Tensor.randn()
    with self.assertRaises(IndexError):
      r[:]
    with self.assertRaises(IndexError):
      r[zero]
    numpy_testing_assert_equal_helper(r, r[...])

  # TODO fancy setitem
  '''
  def test_setitem_scalars(self):
    zero = Tensor(0, dtype=dtypes.int64)

    # non-scalar indexed with scalars
    a = Tensor.randn(2, 3).contiguous()
    a_set_with_number = clone(a).contiguous()
    a_set_with_scalar = clone(a).contiguous()
    b = Tensor.randn(3)

    a_set_with_number[0] = b
    a_set_with_scalar[zero] = b
    numpy_testing_assert_equal_helper(a_set_with_number, a_set_with_scalar)
    a[1, zero] = 7.7
    # TODO: weird inaccuracy Max relative difference: 2.47707621e-08
    # numpy_testing_assert_equal_helper(7.7, a[1, 0])
    np.testing.assert_allclose(7.7, a[1, 0].numpy(), rtol=1e-7)

    # scalar indexed with scalars
    r = Tensor.randn().contiguous()
    with self.assertRaises(IndexError):
      r[:] = 8.8
    with self.assertRaises(IndexError):
      r[zero] = 8.8
    r[...] = 9.9
    # TODO: weird inaccuracy Max relative difference: 3.85322971e-08
    # numpy_testing_assert_equal_helper(9.9, r)
    np.testing.assert_allclose(9.9, r, rtol=1e-7)
  '''

  @unittest.skip("getitem expects elementwise ops folding to be instant, but it happens late in the scheduler")
  def test_getitem_casted_scalars_folding(self):
    Tensor.manual_seed(0)
    # cast of const is just another const, don't need extra kernels for this
    a = Tensor.randn(2, 3)
    one = Tensor(1, dtype=dtypes.int64)
    self.assertEqual(data_ptr(a[1]), data_ptr(a[one.cast(dtypes.int32)]))
    self.assertEqual(data_ptr(a[1]), data_ptr(a[one.cast(dtypes.int16)]))

  def test_getitem_scalars_simple_folding(self):
    a = Tensor.randn(2, 3)
    zero = Tensor(0, dtype=dtypes.int64)
    one = Tensor(1, dtype=dtypes.int64)
    self.assertEqual(data_ptr(a[0, 1]), data_ptr(a[zero, one]))

  def test_basic_advanced_combined(self):
    # From the NumPy indexing example
    x = Tensor.arange(0, 12).reshape(4, 3)
    numpy_testing_assert_equal_helper(x[1:2, 1:3], x[1:2, [1, 2]])
    numpy_testing_assert_equal_helper(x[1:2, 1:3].numpy().tolist(), [[4, 5]])

    # Check that it is a copy
    unmodified = clone(x)
    x[1:2, [1, 2]].zeros_like()
    numpy_testing_assert_equal_helper(x, unmodified)

    # But assignment should modify the original
    # TODO fancy setitem
    '''
    unmodified = clone(x)
    x[1:2, [1, 2]] = 0
    self.assertNotEqual(x, unmodified)
    '''

  def test_int_assignment(self):
    x = Tensor.arange(0, 4).reshape(2, 2)
    x[1] = 5
    numpy_testing_assert_equal_helper(x.numpy().tolist(), [[0, 1], [5, 5]])

    x = Tensor.arange(0, 4).reshape(2, 2)
    x[1] = Tensor.arange(5, 7)
    numpy_testing_assert_equal_helper(x.numpy().tolist(), [[0, 1], [5, 6]])

  # TODO fancy setitem
  '''
  def test_byte_tensor_assignment(self):
    x = Tensor.arange(0., 16).reshape(4, 4)
    b = Tensor([True, False, True, False], dtype=dtypes.uint8)
    value = Tensor([3., 4., 5., 6.])

    with warnings.catch_warnings(record=True) as w:
      x[b] = value
      numpy_testing_assert_equal_helper(len(w), 1)

    numpy_testing_assert_equal_helper(x[0], value)
    numpy_testing_assert_equal_helper(x[1], Tensor.arange(4., 8))
    numpy_testing_assert_equal_helper(x[2], value)
    numpy_testing_assert_equal_helper(x[3], Tensor.arange(12., 16))
  '''

  @unittest.skip("Tensor unpacking not supported")
  def test_variable_slicing(self):
    x = Tensor.arange(0, 16).reshape(4, 4)
    indices = Tensor([0, 1], dtype=dtypes.int32)
    i, j = indices
    numpy_testing_assert_equal_helper(x[i:j], x[0:1])

  def test_ellipsis_tensor(self):
    x = Tensor.arange(0, 9).reshape(3, 3)
    idx = Tensor([0, 2])
    numpy_testing_assert_equal_helper(x[..., idx].numpy().tolist(), [[0, 2],
                                                                     [3, 5],
                                                                     [6, 8]])
    numpy_testing_assert_equal_helper(x[idx, ...].numpy().tolist(), [[0, 1, 2],
                                                                     [6, 7, 8]])

  # TODO unravel_index
  '''
  def test_unravel_index_errors(self):
    with self.assertRaises(TypeError):
      unravel_index(
        Tensor(0.5),
        (2, 2))

    with self.assertRaises(TypeError):
      unravel_index(
        Tensor([]),
        (10, 3, 5))

    with self.assertRaises(TypeError):
      unravel_index(
        Tensor([1], dtype=dtypes.int64),
        Tensor([1, 2, 3]))

    with self.assertRaises(TypeError):
      unravel_index(
        Tensor([1], dtype=dtypes.int64),
        (1, 2, 2.0))

    with self.assertRaises(ValueError):
      unravel_index(
        Tensor(0),
        (2, -3))
  '''

  def test_invalid_index(self):
    x = Tensor.arange(0, 16).reshape(4, 4)
    self.assertRaises(TypeError, lambda: x["0":"1"])

  def test_out_of_bound_index(self):
    x = Tensor.arange(0, 100).reshape(2, 5, 10)
    self.assertRaises(IndexError, lambda: x[0, 5])
    self.assertRaises(IndexError, lambda: x[4, 5])
    self.assertRaises(IndexError, lambda: x[0, 1, 15])
    self.assertRaises(IndexError, lambda: x[:, :, 12])

  def test_zero_dim_index(self):
    x = Tensor(10)
    numpy_testing_assert_equal_helper(x, x.item())

    def runner():
      print(x[0])
      return x[0]

    self.assertRaises(IndexError, runner)

  # TODO fancy setitem
  '''
  def test_cpu_indices(self):
    idx = Tensor([0, 1])
    b = Tensor.zeros(2)
    x = Tensor.ones(10).contiguous()
    x[idx] = b  # index_put_
    ref = Tensor.ones(10).contiguous()
    ref[:2] = 0
    numpy_testing_assert_equal_helper(x, ref)
    out = x[idx]  # index
    numpy_testing_assert_equal_helper(out, Tensor.zeros(2))
  '''

  def test_take_along_dim(self):
    def _test_against_numpy(t: Tensor, indices: Tensor, dim):
      actual = t.gather(dim, indices)
      t_np = t.numpy()
      indices_np = indices.numpy()
      expected = np.take_along_axis(t_np, indices_np, axis=dim)
      numpy_testing_assert_equal_helper(actual, expected)

      # TODO argsort
      '''
      for shape in [(3, 2), (2, 3, 5), (2, 4, 0), (2, 3, 1, 4)]:
        for noncontiguous in [True, False]:
          for dtype in (dtypes.float32, dtypes.int64):
            t = make_tensor(shape, dtype=dtype, noncontiguous=noncontiguous)
            for dim in list(range(t.ndim)) + [None]:
              if dim is None:
                indices = argsort(t.reshape(-1))
              else:
                indices = argsort(t, dim=dim)

          _test_against_numpy(t, indices, dim)
      '''

      # test broadcasting
      t = Tensor.ones((3, 4, 1))
      indices = Tensor.ones((1, 2, 5), dtype=dtypes.int64)

      _test_against_numpy(t, indices, 1)

      # test empty indices
      t = Tensor.ones((3, 4, 5))
      indices = Tensor.ones((3, 0, 5), dtype=dtypes.int64)

      _test_against_numpy(t, indices, 1)

  # TODO argsort
  '''
  def test_take_along_dim_invalid(self):
    for dtype in (dtypes.int64, dtypes.float32):
      shape = (2, 3, 1, 4)
      dim = 0
      t = make_tensor(shape, dtype=dtype)
      indices = argsort(t, dim=dim)

      # dim of `t` and `indices` does not match
      with self.assertRaises(RuntimeError, "input and indices should have the same number of dimensions"):
        t.gather(0, indices[0])

      # invalid `indices` dtype
      with self.assertRaises(RuntimeError):
        t.gather(0, indices.cast(dtypes.bool))

      with self.assertRaises(RuntimeError):
        t.gather(0, indices.cast(dtypes.float32))

      with self.assertRaises(RuntimeError):
        t.gather(0, indices.cast(dtypes.int32))

      # invalid axis
      with self.assertRaises(IndexError):
        t.gather(-7, indices)

      with self.assertRaises(IndexError):
        t.gather(7, indices)
  '''

class TestNumpy(unittest.TestCase):
  def test_index_no_floats(self):
    a = Tensor([[[5.]]])

    self.assertRaises(IndexError, lambda: a[0.0])
    self.assertRaises(IndexError, lambda: a[0, 0.0])
    self.assertRaises(IndexError, lambda: a[0.0, 0])
    self.assertRaises(IndexError, lambda: a[0.0, :])
    self.assertRaises(IndexError, lambda: a[:, 0.0])
    self.assertRaises(IndexError, lambda: a[:, 0.0, :])
    self.assertRaises(IndexError, lambda: a[0.0, :, :])
    self.assertRaises(IndexError, lambda: a[0, 0, 0.0])
    self.assertRaises(IndexError, lambda: a[0.0, 0, 0])
    self.assertRaises(IndexError, lambda: a[0, 0.0, 0])
    self.assertRaises(IndexError, lambda: a[-1.4])
    self.assertRaises(IndexError, lambda: a[0, -1.4])
    self.assertRaises(IndexError, lambda: a[-1.4, 0])
    self.assertRaises(IndexError, lambda: a[-1.4, :])
    self.assertRaises(IndexError, lambda: a[:, -1.4])
    self.assertRaises(IndexError, lambda: a[:, -1.4, :])
    self.assertRaises(IndexError, lambda: a[-1.4, :, :])
    self.assertRaises(IndexError, lambda: a[0, 0, -1.4])
    self.assertRaises(IndexError, lambda: a[-1.4, 0, 0])
    self.assertRaises(IndexError, lambda: a[0, -1.4, 0])
    # these two trigger slice internal type verification first
    self.assertRaises(TypeError, lambda: a[0.0:, 0.0])
    self.assertRaises(TypeError, lambda: a[0.0:, 0.0,:])

  def test_none_index(self):
    # `None` index adds newaxis
    a = Tensor([1, 2, 3])
    numpy_testing_assert_equal_helper(a[None].ndim, a.ndim+1)

  def test_empty_tuple_index(self):
    # Empty tuple index creates a view
    a = Tensor([1, 2, 3])
    numpy_testing_assert_equal_helper(a[()], a)
    #self.assertEqual(data_ptr(a[()]), data_ptr(a))

  # TODO jax supports empty tensor indexing
  @unittest.skip("empty tensor indexing not supported")
  def test_empty_fancy_index(self):
    # Empty list index creates an empty array
    a = Tensor([1, 2, 3])
    numpy_testing_assert_equal_helper(a[[]], np.array([]))

    b = Tensor([]).cast(dtypes.int64)
    numpy_testing_assert_equal_helper(a[[]], np.array([]))

    b = Tensor([]).float()
    self.assertRaises(IndexError, lambda: a[b])

  def test_ellipsis_index(self):
    a = Tensor([[1, 2, 3],
                [4, 5, 6],
                [7, 8, 9]])
    self.assertIsNot(a[...], a)
    numpy_testing_assert_equal_helper(a[...], a)
    # `a[...]` was `a` in numpy <1.9.
    #numpy_testing_assert_equal_helper(data_ptr(a[...]), data_ptr(a))

    # Slicing with ellipsis can skip an
    # arbitrary number of dimensions
    numpy_testing_assert_equal_helper(a[0, ...], a[0])
    numpy_testing_assert_equal_helper(a[0, ...], a[0, :])
    numpy_testing_assert_equal_helper(a[..., 0], a[:, 0])

    # In NumPy, slicing with ellipsis results in a 0-dim array. In PyTorch
    # we don't have separate 0-dim arrays and scalars.
    numpy_testing_assert_equal_helper(a[0, ..., 1], np.array(2))

    # Assignment with `(Ellipsis,)` on 0-d arrays
    b = np.array(1)
    b[(Ellipsis,)] = 2
    numpy_testing_assert_equal_helper(b, 2)

  def test_single_int_index(self):
    # Single integer index selects one row
    a = Tensor([[1, 2, 3],
                [4, 5, 6],
                [7, 8, 9]])

    numpy_testing_assert_equal_helper(a[0], [1, 2, 3])
    numpy_testing_assert_equal_helper(a[-1], [7, 8, 9])

    self.assertRaises(IndexError, a.__getitem__, 1 << 30)
    self.assertRaises(IndexError, a.__getitem__, 1 << 64)

  @unittest.skip("bool indexing not supported")
  def test_single_bool_index(self):
    # Single boolean index
    a = Tensor([[1, 2, 3],
                [4, 5, 6],
                [7, 8, 9]])

    numpy_testing_assert_equal_helper(a[True], a[None])
    numpy_testing_assert_equal_helper(a[False], a[None][0:0])

  @unittest.skip("bool indexing not supported")
  def test_boolean_shape_mismatch(self):
    arr = Tensor.ones((5, 4, 3))

    index = Tensor([True])
    self.assertRaises(IndexError, lambda: arr[index])

    index = Tensor([False] * 6)
    self.assertRaises(IndexError, lambda: arr[index])

    index = Tensor.zeros(4, 4, dtype=dtypes.uint8)
    self.assertRaises(IndexError, lambda: arr[index])
    self.assertRaises(IndexError, lambda: arr[(slice(None), index)])

  @unittest.skip("bool indexing not supported")
  def test_boolean_indexing_onedim(self):
    # Indexing a 2-dimensional array with
    # boolean array of length one
    a = Tensor([[0., 0., 0.]])
    b = Tensor([True])
    numpy_testing_assert_equal_helper(a[b], a)
    # boolean assignment
    a[b] = 1.
    numpy_testing_assert_equal_helper(a, Tensor([[1., 1., 1.]]))

  @unittest.skip("bool indexing not supported")
  def test_boolean_assignment_value_mismatch(self):
    # A boolean assignment should fail when the shape of the values
    # cannot be broadcast to the subscription. (see also gh-3458)
    a = Tensor.arange(0, 4)

    def f(a, v):
      a[a > -1] = Tensor(v)

    self.assertRaises(Exception, f, a, [])
    self.assertRaises(Exception, f, a, [1, 2, 3])
    self.assertRaises(Exception, f, a[:1], [1, 2, 3])

  @unittest.skip("bool indexing not supported")
  def test_boolean_indexing_twodim(self):
    # Indexing a 2-dimensional array with
    # 2-dimensional boolean array
    a = Tensor([[1, 2, 3],
                [4, 5, 6],
                [7, 8, 9]])
    b = Tensor([[True, False, True],
                [False, True, False],
                [True, False, True]])
    numpy_testing_assert_equal_helper(a[b], Tensor([1, 3, 5, 7, 9]))
    numpy_testing_assert_equal_helper(a[b[1]], Tensor([[4, 5, 6]]))
    numpy_testing_assert_equal_helper(a[b[0]], a[b[2]])

    # boolean assignment
    a[b] = 0
    numpy_testing_assert_equal_helper(a, Tensor([[0, 2, 0],
                                                  [4, 0, 6],
                                                  [0, 8, 0]]))

  @unittest.skip("bool indexing not supported")
  def test_boolean_indexing_weirdness(self):
    # Weird boolean indexing things
    a = Tensor.ones((2, 3, 4))
    numpy_testing_assert_equal_helper((0, 2, 3, 4), a[False, True, ...].shape)
    numpy_testing_assert_equal_helper(Tensor.ones(1, 2), a[True, [0, 1], True, True, [1], [[2]]])
    self.assertRaises(IndexError, lambda: a[False, [0, 1], ...])

  @unittest.skip("bool indexing not supported")
  def test_boolean_indexing_weirdness_tensors(self):
    # Weird boolean indexing things
    false = Tensor(False)
    true = Tensor(True)
    a = Tensor.ones((2, 3, 4))
    numpy_testing_assert_equal_helper((0, 2, 3, 4), a[False, True, ...].shape)
    numpy_testing_assert_equal_helper(Tensor.ones(1, 2), a[true, [0, 1], true, true, [1], [[2]]])
    self.assertRaises(IndexError, lambda: a[false, [0, 1], ...])

  @unittest.skip("bool indexing not supported")
  def test_boolean_indexing_alldims(self):
    true = Tensor(True)
    a = Tensor.ones((2, 3))
    numpy_testing_assert_equal_helper((1, 2, 3), a[True, True].shape)
    numpy_testing_assert_equal_helper((1, 2, 3), a[true, true].shape)

  @unittest.skip("bool indexing not supported")
  def test_boolean_list_indexing(self):
    # Indexing a 2-dimensional array with
    # boolean lists
    a = Tensor([[1, 2, 3],
                [4, 5, 6],
                [7, 8, 9]])
    b = [True, False, False]
    c = [True, True, False]
    numpy_testing_assert_equal_helper(a[b], Tensor([[1, 2, 3]]))
    numpy_testing_assert_equal_helper(a[b, b], Tensor([1]))
    numpy_testing_assert_equal_helper(a[c], Tensor([[1, 2, 3], [4, 5, 6]]))
    numpy_testing_assert_equal_helper(a[c, c], Tensor([1, 5]))

  def test_everything_returns_views(self):
    # Before `...` would return a itself.
    a = Tensor([5])

    self.assertIsNot(a, a[()])
    self.assertIsNot(a, a[...])
    self.assertIsNot(a, a[:])

  def test_broaderrors_indexing(self):
    a = Tensor.zeros(5, 5)
    self.assertRaises(IndexError, a.__getitem__, ([0, 1], [0, 1, 2]))
    self.assertRaises(IndexError, a.contiguous().__setitem__, ([0, 1], [0, 1, 2]), 0)

  # TODO out of bound getitem does not raise error
  '''
  def test_trivial_fancy_out_of_bounds(self):
    a = Tensor.zeros(5)
    ind = Tensor.ones(20, dtype=dtypes.int64)
    ind[-1] = 10
    self.assertRaises(IndexError, a.__getitem__, ind)
    self.assertRaises(IndexError, a.__setitem__, ind, 0)
    ind = Tensor.ones(20, dtype=dtypes.int64)
    ind[0] = 11
    self.assertRaises(IndexError, a.__getitem__, ind)
    self.assertRaises(IndexError, a.__setitem__, ind, 0)
  '''

  # TODO fancy setitem
  '''
  def test_index_is_larger(self):
    # Simple case of fancy index broadcasting of the index.
    a = Tensor.zeros((5, 5))
    a[[[0], [1], [2]], [0, 1, 2]] = Tensor([2., 3., 4.])

    self.assertTrue((a[:3, :3] == all_(Tensor([2., 3., 4.]))))
  '''

  # TODO fancy setitem
  '''
  def test_broadcast_subspace(self):
    a = Tensor.zeros((100, 100))
    v = Tensor.arange(0., 100)[:, None]
    b = Tensor.arange(99, -1, -1).cast(dtypes.int64)
    a[b] = v
    expected = b.float().unsqueeze(1).expand(100, 100)
    numpy_testing_assert_equal_helper(a, expected)
  '''

  # TODO fancy setitem
  '''
  def test_truncate_leading_1s(self):
    col_max = Tensor.randn(1, 4)
    kernel = col_max.T * col_max  # [4, 4] tensor
    kernel2 = clone(kernel)
    # Set the diagonal
    # len(torch.tensor) is just tensor.shape[0]
    kernel[range(kernel.shape[0]), range(kernel.shape[0])] = col_max.square()
    kernel2 = diagonal(kernel2)
    # torch.diagonal(kernel2).copy_(torch.square(col_max.view(4)))
    kernel2 = copy_(kernel2, col_max.reshape(4).square())
    numpy_testing_assert_equal_helper(kernel, kernel2)
  '''

if __name__ == '__main__':
  unittest.main()