You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
17 lines
414 B
17 lines
414 B
from tinygrad import Tensor, TinyJit, Device
|
|
import numpy as np
|
|
|
|
# Script configuration: number of devices to shard across, and the edge
# length of the cubic test tensor.
GPUS = 4
N = 128

# One device name per GPU, built as "<default device>:<index>" and passed
# through Device.canonicalize.
ds = tuple(Device.canonicalize(f"{Device.DEFAULT}:{idx}") for idx in range(GPUS))

# Random (N, N, N) tensor sharded over the devices in `ds`; the second
# argument to shard (0) is the shard axis in tinygrad's API.
t = Tensor.rand(N, N, N).shard(ds, 0)

# NumPy copy of the tensor; the check loop below compares against its sum.
n = t.numpy()
|
|
|
|
@TinyJit
def allreduce(t: Tensor) -> Tensor:
    """Sum ``t`` along axis 0 and return the resulting tensor.

    The function is wrapped in ``TinyJit``. The result is returned without
    an explicit ``.realize()`` call (the original kept that call commented
    out).
    NOTE(review): presumably TinyJit realizes the returned tensor itself --
    confirm for the tinygrad version in use.
    """
    reduced = t.sum(0)
    return reduced
|
|
|
|
# The NumPy reference never changes between iterations, so compute it once
# instead of re-summing the array on every pass through the loop.
expected = n.sum(0)

# Call the jitted reduction repeatedly and compare each result with the
# NumPy reference, printing the iteration index as a progress marker.
for i in range(10):
    print(i)
    tn = allreduce(t).numpy()
    np.testing.assert_allclose(tn, expected, atol=1e-4, rtol=1e-4)
|
|
|