import pathlib
from tinygrad.device import Device
from tinygrad.runtime.ops_amd import AMDProgram, HIPCompiler
import time
import os
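
# Launch configuration and instruction counts used for the throughput calculation.
# FLOPS_PER_MATMUL counts the 2*16*16*16 multiply-add operations of one 16x16x16
# WMMA instruction; INSTRUCTIONS_PER_LOOP controls how many instructions are
# unrolled into the template, and INTERNAL_LOOP is assumed to match the loop
# count hardcoded in template.s.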
NUM_WORKGROUPS = 96
WAVE_SIZE = 32
NUM_WAVES = 2
FLOPS_PER_MATMUL = 16*16*16*2
INTERNAL_LOOP = 1_000_000
INSTRUCTIONS_PER_LOOP = 1_000

assemblyTemplate = (pathlib.Path(__file__).parent / "template.s").read_text()
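
# launchBenchmark builds an unrolled block of WMMA/SWMMAC instructions, splices it
# into the assembly template in place of the INSTRUCTION placeholder, compiles and
# runs the resulting kernel, and prints the achieved throughput.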
def launchBenchmark(instruction, vgprIndices, dense=True):
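  # The dense form reuses one VGPR range for both input matrices and passes an
  # inline constant as the last source; the sparse (swmmac) form takes distinct
  # A and B ranges plus a final VGPR (the sparsity index data).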
  if dense:
    instructions = "{} v[0:{}], v[{}:{}], v[{}:{}], 1\n".format(instruction, vgprIndices[0],
                                                                vgprIndices[1], vgprIndices[2],
                                                                vgprIndices[1], vgprIndices[2]) * INSTRUCTIONS_PER_LOOP
  else:
    instructions = "{} v[0:{}], v[{}:{}], v[{}:{}], v{}\n".format(instruction, vgprIndices[0],
                                                                  vgprIndices[1], vgprIndices[2],
                                                                  vgprIndices[3], vgprIndices[4],
                                                                  vgprIndices[5]) * INSTRUCTIONS_PER_LOOP
  src = assemblyTemplate.replace("INSTRUCTION", instructions)
  lib = COMPILER.compile(src)
  fxn = AMDProgram(DEV, "matmul", lib)
  start = time.perf_counter()
  # For some reason the time returned by the runtime is very small after the
  # first kernel execution, so wall-clock time is measured instead.
  fxn(global_size=(NUM_WORKGROUPS,1,1), local_size=(WAVE_SIZE*NUM_WAVES,1,1), wait=True)
  end = time.perf_counter()
  elapsed = end - start
  FLOPs = FLOPS_PER_MATMUL * NUM_WAVES * NUM_WORKGROUPS * INTERNAL_LOOP * INSTRUCTIONS_PER_LOOP
  print("{:<29} : {} T(FL)OPS".format(instruction, round(FLOPs/elapsed/10**12, 2)))
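
# DEVICENUM selects which AMD device to benchmark (defaults to device 0).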
if __name__ == "__main__":
  DEVICENUM = os.getenv("DEVICENUM", "0")
  try:
    DEV = Device['AMD:' + DEVICENUM]
  except Exception as e:
    raise RuntimeError("Error while initializing AMD device") from e

  if (ARCH := DEV.arch) not in ['gfx1100', 'gfx1201']:
    raise RuntimeError("only gfx1100 and gfx1201 are supported")
  COMPILER = HIPCompiler(ARCH)
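
  # Each tuple gives the VGPR indices passed to launchBenchmark: the upper index
  # of the destination/accumulator range v[0:N] and the bounds of the source
  # ranges, which differ per instruction and per architecture.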
  if ARCH == 'gfx1100':
    launchBenchmark("v_wmma_bf16_16x16x16_bf16", (7,8,15))
    launchBenchmark("v_wmma_f16_16x16x16_f16", (7,8,15))
    launchBenchmark("v_wmma_f32_16x16x16_bf16", (7,8,15))
    launchBenchmark("v_wmma_f32_16x16x16_f16", (7,8,15))
    launchBenchmark("v_wmma_i32_16x16x16_iu4", (7,8,9))
    launchBenchmark("v_wmma_i32_16x16x16_iu8", (7,8,11))
  if ARCH == 'gfx1201':
    NUM_WORKGROUPS = 64  # fewer workgroups, presumably matched to the CU count of the gfx1201 test GPU
    launchBenchmark("v_wmma_bf16_16x16x16_bf16", (3,4,7))
    launchBenchmark("v_wmma_f16_16x16x16_f16", (3,4,7))
    launchBenchmark("v_wmma_f32_16x16x16_bf16", (7,8,11))
    launchBenchmark("v_wmma_f32_16x16x16_f16", (7,8,11))
    launchBenchmark("v_wmma_i32_16x16x16_iu4", (7,8,8))
    launchBenchmark("v_wmma_i32_16x16x16_iu8", (7,8,9))
    launchBenchmark("v_wmma_f32_16x16x16_fp8_fp8", (7,8,9))
    launchBenchmark("v_wmma_f32_16x16x16_fp8_bf8", (7,8,9))
    launchBenchmark("v_wmma_f32_16x16x16_bf8_fp8", (7,8,9))
    launchBenchmark("v_wmma_f32_16x16x16_bf8_bf8", (7,8,9))
    # The remaining instructions multiply over K=32, so each one does twice the work.
    FLOPS_PER_MATMUL = 16*16*32*2
    launchBenchmark("v_wmma_i32_16x16x32_iu4", (7,8,9))
    launchBenchmark("v_swmmac_f32_16x16x32_f16", (7,8,11,12,19,20), False)
    launchBenchmark("v_swmmac_f32_16x16x32_bf16", (7,8,11,12,19,20), False)
    launchBenchmark("v_swmmac_f16_16x16x32_f16", (3,4,7,8,15,16), False)
    launchBenchmark("v_swmmac_bf16_16x16x32_bf16", (3,4,7,8,15,16), False)
    launchBenchmark("v_swmmac_i32_16x16x32_iu8", (7,8,9,10,13,14), False)
    launchBenchmark("v_swmmac_i32_16x16x32_iu4", (7,8,8,9,10,11), False)
    launchBenchmark("v_swmmac_f32_16x16x32_fp8_fp8", (7,8,9,10,13,14), False)
    launchBenchmark("v_swmmac_f32_16x16x32_fp8_bf8", (7,8,9,10,13,14), False)
    launchBenchmark("v_swmmac_f32_16x16x32_bf8_fp8", (7,8,9,10,13,14), False)
    launchBenchmark("v_swmmac_f32_16x16x32_bf8_bf8", (7,8,9,10,13,14), False)
    FLOPS_PER_MATMUL = 16*16*64*2
    launchBenchmark("v_swmmac_i32_16x16x64_iu4", (7,8,9,10,13,14), False)