You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
27 lines
856 B
27 lines
856 B
#!/usr/bin/env python3
"""Benchmark a data-parallel fp16 batched matmul across all visible JAX devices.

Under ``jax.pmap`` each device multiplies its own (BS, N, N) shard of ``A`` by
its copy of ``B``; the script prints wall time, achieved TFLOPS and MFU
against a hard-coded peak.
"""
import time

import jax
import jax.numpy as jnp

print(jax.devices())
devices = jax.devices()
DEVICES = len(devices)
BS = 32  # matrices per device
N = 4096  # square matrix side length
dtype = jnp.float16

# Build ONE per-device shard and give the same (immutable, all-zero) array to
# every device. Creating the full (DEVICES, BS, N, N) tensor first and then
# slicing it with A[i] kept roughly 2 * DEVICES shards alive on the default
# device before the transfer. The values that land on each device are the
# same either way.
A_shard = jnp.zeros((BS, N, N), dtype)
B_shard = jnp.zeros((1, 1, N, N), dtype)
A = jax.device_put_sharded([A_shard] * DEVICES, devices)  # (DEVICES, BS, N, N)
B = jax.device_put_sharded([B_shard] * DEVICES, devices)  # (DEVICES, 1, 1, N, N)
del A_shard, B_shard  # drop the staging references

# Each of the DEVICES*BS*N*N output elements costs N multiply-adds (2*N FLOPs).
OPS = DEVICES*BS*N*N*N*2


def matmul(A, B):
    """Return ``jnp.matmul(A, B)`` computed with a float32 result type.

    ``preferred_element_type`` requests float32 accumulation/output even when
    ``A`` and ``B`` are float16. Leading batch dimensions broadcast exactly as
    in ``jnp.matmul``.
    """
    accumulate_as = jnp.float32
    product = jnp.matmul(A, B, preferred_element_type=accumulate_as)
    return product


# Map matmul over the leading (device) axis: device d computes A[d] @ B[d].
pmatmul = jax.pmap(matmul)

MAX_TFLOPS = 123*DEVICES # Peak FP16 Tensor TFLOPS with FP32 Acc (7900XTX)

# Untimed warm-up: the first pmap call traces and compiles the program. If it
# were timed, that cost would be charged to the first iteration and make its
# reported TFLOPS/MFU meaningless.
pmatmul(A, B).block_until_ready()

for _ in range(10):
    st = time.perf_counter()
    # JAX dispatches asynchronously; block so the timer covers the compute.
    C = pmatmul(A, B).block_until_ready()
    et = time.perf_counter() - st
    tflops = (OPS*1e-12)/et
    print(f"time {et*1e3:.2f} ms, TFLOPS {tflops:6.2f}, MFU {(tflops/MAX_TFLOPS)*100:4.2f}% out shape {C.shape} dtype {C.dtype}")