You cannot select more than 25 topics.
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							75 lines
						
					
					
						
							1.8 KiB
						
					
					
				
			
		
		
	
	
							75 lines
						
					
					
						
							1.8 KiB
						
					
					
				// https://github.com/moskewcz/boda/issues/13
 | 
						|
 | 
						|
// Precision switch: uncomment to build the kernel with half-precision vectors
// and half-precision image I/O (requires a device exposing cl_khr_fp16).
// Otherwise float4 vectors and read_imagef/write_imagef are used.
//#define USE_FP16

#ifdef USE_FP16
  // Enable cl_khr_fp16 only on the fp16 path. The OpenCL C spec says that
  // "enable" on an extension the device does not support reports an error,
  // so float builds must not request it. read_imageh/write_imageh and the
  // half4 type are provided by this extension.
  #pragma OPENCL EXTENSION cl_khr_fp16 : enable
  #define xtype half4
  #define read_imagep read_imageh
  #define write_imagep write_imageh
#else
  #define xtype float4
  #define read_imagep read_imagef
  #define write_imagep write_imagef
#endif
 | 
						|
// Image-based GEMM: each work-item produces four xtype texels of C (a 4x4
// block of scalar outputs).
//
// For work-item (g0, g1) = (get_global_id(0), get_global_id(1)):
//   * it reads row g0 of image A and row g1 of image B, four texels per step,
//     for K/4 steps. Any K % 4 remainder is not read, so K is presumably a
//     multiple of 4 (TODO confirm with the host code).
//   * each step accumulates c_r[i].{x,y,z,w} += a_r[i] . b_r[0..3].
//   * it writes c_r[0..3] to C at texel (x = g0, y = 4*g1 + i).
//
// M and N are part of the kernel signature but are not read here.
// NOTE(review): the packing of matrix elements into texels of A, B and C is
// defined by host code that is not visible here.
__kernel void gemm(const int M, const int N, const int K,
                   read_only image2d_t A,
                   read_only image2d_t B,
                   write_only image2d_t C)
{
  // Integer (unnormalized) texel coordinates with no filtering. Out-of-range
  // reads return the border color instead of faulting.
  const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE |
                        CLK_ADDRESS_CLAMP           |
                        CLK_FILTER_NEAREST;

  int const a_off_thr = get_global_id(0);
  int const b_off_thr = get_global_id(1);

  // write_image* with coordinates outside the image is undefined, so
  // work-items whose output tile lies outside C exit early. This covers
  // global sizes rounded up to a work-group multiple. No barriers are used,
  // so an early return is safe.
  int const c_w = get_image_width(C);
  int const c_h = get_image_height(C);
  if (a_off_thr >= c_w || b_off_thr * 4 >= c_h) {
    return;
  }

  xtype c_r[4] = {0, 0, 0, 0};
  xtype a_r[4], b_r[4];

  int2 a_samp = (int2)(0, a_off_thr);
  int2 b_samp = (int2)(0, b_off_thr);

  // Use int counters: a short counter overflows once K/4 exceeds 32767.
  for (int k = 0; k < K / 4; ++k) {
    // Load the next four consecutive texels of this row of A and of B.
    for (int i = 0; i < 4; ++i) {
      a_r[i] = read_imagep(A, smp, a_samp);
      b_r[i] = read_imagep(B, smp, b_samp);
      ++a_samp.x;
      ++b_samp.x;
    }

    // c_r[i].j += a_r[i] . b_r[j]. The adds stay explicit and in x, y, z, w
    // order so the floating-point rounding order is unchanged.
    for (int i = 0; i < 4; ++i) {
      // Use xtype, not float4: with USE_FP16, c_r holds half4, and OpenCL has
      // no implicit conversion between vector types.
      xtype ov = c_r[i];

      ov.x += a_r[i].x * b_r[0].x;
      ov.x += a_r[i].y * b_r[0].y;
      ov.x += a_r[i].z * b_r[0].z;
      ov.x += a_r[i].w * b_r[0].w;

      ov.y += a_r[i].x * b_r[1].x;
      ov.y += a_r[i].y * b_r[1].y;
      ov.y += a_r[i].z * b_r[1].z;
      ov.y += a_r[i].w * b_r[1].w;

      ov.z += a_r[i].x * b_r[2].x;
      ov.z += a_r[i].y * b_r[2].y;
      ov.z += a_r[i].z * b_r[2].z;
      ov.z += a_r[i].w * b_r[2].w;

      ov.w += a_r[i].x * b_r[3].x;
      ov.w += a_r[i].y * b_r[3].y;
      ov.w += a_r[i].z * b_r[3].z;
      ov.w += a_r[i].w * b_r[3].w;

      c_r[i] = ov;
    }
  }

  // Write the four result texels down column a_off_thr of C, starting at
  // row 4*b_off_thr. Stop at the image's last row to avoid undefined
  // out-of-bounds writes.
  int2 c_samp = (int2)(a_off_thr, b_off_thr * 4);
  for (int i = 0; i < 4 && c_samp.y < c_h; ++i) {
    write_imagep(C, c_samp, c_r[i]);
    ++c_samp.y;
  }
}
 | 
						|
 | 
						|
 |