You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
28 lines
856 B
28 lines
856 B
1 month ago
|
#!/usr/bin/env python3
|
||
|
import time
|
||
|
import jax
|
||
|
import jax.numpy as jnp
|
||
|
|
||
|
print(jax.devices())
|
||
|
DEVICES = len(jax.devices())
|
||
|
BS = 32
|
||
|
N = 4096
|
||
|
dtype = jnp.float16
|
||
|
A = jnp.zeros((DEVICES, BS, N, N), dtype)
|
||
|
B = jnp.zeros((1, 1, N, N), dtype)
|
||
|
A = jax.device_put_sharded([A[i] for i in range(DEVICES)], jax.devices())
|
||
|
B = jax.device_put_sharded([B for i in range(DEVICES)], jax.devices())
|
||
|
|
||
|
OPS = DEVICES*BS*N*N*N*2
|
||
|
def matmul(A,B): return jnp.matmul(A,B,preferred_element_type=jnp.float32)
|
||
|
pmatmul = jax.pmap(matmul)
|
||
|
|
||
|
MAX_TFLOPS = 123*DEVICES # Peak FP16 Tensor TFLOPS with FP32 Acc (7900XTX)
|
||
|
for i in range(10):
|
||
|
st = time.perf_counter()
|
||
|
C = pmatmul(A,B).block_until_ready()
|
||
|
et = time.perf_counter()-st
|
||
|
tflops = (OPS*1e-12)/et
|
||
|
print(f"time {et*1e3:.2f} ms, TFLOPS {tflops:6.2f}, MFU {(tflops/MAX_TFLOPS)*100:4.2f}% out shape {C.shape} dtype {C.dtype}")
|
||
|
|