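/*
 * NOTE(review): the lines below appear to be residue from the GitHub web page
 * this file was copied from; they are not part of the kernel.
 *
 * You cannot select more than 25 topics.
 * Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 * 75 lines
 * 1.8 KiB
 * 75 lines
 * 1.8 KiB
 */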
// https://github.com/moskewcz/boda/issues/13

// Precision selection for the gemm kernel below.
// Uncomment to run in half precision (requires the cl_khr_fp16 extension).
//#define USE_FP16

#ifdef USE_FP16

// half4 and read_imageh/write_imageh only exist through cl_khr_fp16, so the
// extension is enabled here -- and only here.  Enabling it unconditionally
// (as before) can produce a warning or a build failure on devices without
// fp16 support, even though the float path never needs it.
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

// xtype: type of one 4-channel image texel; read_imagep/write_imagep: the
// image accessors matching that type.
#define xtype half4
#define read_imagep read_imageh
#define write_imagep write_imageh

#else

#define xtype float4
#define read_imagep read_imagef
#define write_imagep write_imagef

#endif
// Image-based GEMM: each work-item accumulates a 4x4 tile (four xtype rows).
//
// Data layout as implied by the indexing below
// (NOTE(review): confirm against the host-side packing code):
//   * Work-item (a, b) = (get_global_id(0), get_global_id(1)).
//   * Row a of image A and row b of image B are read left to right, four
//     texels per step of the K loop.  Within one step, texel i of A
//     (i = 0..3) is dotted with texel j of B (j = 0..3) and the result is
//     accumulated into component j of c_r[i]:
//         c_r[i].j += dot(A(4k+i, a), B(4k+j, b))
//   * Accumulator row i is written to C at (x, y) = (a, 4*b + i).
//   * The K loop runs K/4 steps (integer division), so a K % 4 remainder is
//     not processed.  Reads past the image edge return the border colour
//     (CLK_ADDRESS_CLAMP).
//
// M and N are not read by this kernel.  Writes are instead guarded against
// C's image dimensions, so a global work size rounded up past the image does
// not perform out-of-range write_image* calls (undefined behaviour).
__kernel void gemm(const int M, const int N, const int K,
                   read_only image2d_t A,
                   read_only image2d_t B,
                   write_only image2d_t C)
{
    const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE |
                          CLK_ADDRESS_CLAMP |
                          CLK_FILTER_NEAREST;

    const int a_off_thr = get_global_id(0);
    const int b_off_thr = get_global_id(1);

    // Output coordinates this work-item would write: x = a_off_thr,
    // y = 4*b_off_thr .. 4*b_off_thr + 3.  If none of them lie inside C,
    // there is nothing to compute.
    const int c_w = get_image_width(C);
    const int c_h = get_image_height(C);
    if (a_off_thr >= c_w || b_off_thr * 4 >= c_h) {
        return;
    }

    // One accumulator row per output texel; (xtype)(0) broadcasts zero to all
    // four components.
    xtype c_r[4] = {(xtype)(0), (xtype)(0), (xtype)(0), (xtype)(0)};
    xtype a_r[4], b_r[4];

    int2 a_samp = (int2)(0, a_off_thr);
    int2 b_samp = (int2)(0, b_off_thr);

    // int rather than short: a short counter overflows (and the loop never
    // terminates) once K/4 exceeds SHRT_MAX.
    const int k_steps = K / 4;
    for (int k = 0; k < k_steps; ++k) {
        // Fetch the next four texels of this work-item's A row and B row.
        for (int i = 0; i < 4; ++i) {
            a_r[i] = read_imagep(A, smp, a_samp);
            b_r[i] = read_imagep(B, smp, b_samp);
            ++a_samp.x;
            ++b_samp.x;
        }

        for (int i = 0; i < 4; ++i) {
            // xtype, not float4: under USE_FP16 OpenCL C has no implicit
            // conversion between half4 and float4, so float4 did not compile.
            xtype ov = c_r[i];
            const xtype a_i = a_r[i];

            // Terms are added one at a time, in the same order as before, to
            // keep the same rounding behaviour.
            ov.x += a_i.x * b_r[0].x;
            ov.x += a_i.y * b_r[0].y;
            ov.x += a_i.z * b_r[0].z;
            ov.x += a_i.w * b_r[0].w;

            ov.y += a_i.x * b_r[1].x;
            ov.y += a_i.y * b_r[1].y;
            ov.y += a_i.z * b_r[1].z;
            ov.y += a_i.w * b_r[1].w;

            ov.z += a_i.x * b_r[2].x;
            ov.z += a_i.y * b_r[2].y;
            ov.z += a_i.z * b_r[2].z;
            ov.z += a_i.w * b_r[2].w;

            ov.w += a_i.x * b_r[3].x;
            ov.w += a_i.y * b_r[3].y;
            ov.w += a_i.z * b_r[3].z;
            ov.w += a_i.w * b_r[3].w;

            c_r[i] = ov;
        }
    }

    // Write the four accumulator rows to consecutive rows of C, skipping any
    // that fall below the bottom edge of the image.
    int2 c_samp = (int2)(a_off_thr, b_off_thr * 4);
    for (int i = 0; i < 4; ++i) {
        if (c_samp.y < c_h) {
            write_imagep(C, c_samp, c_r[i]);
        }
        ++c_samp.y;
    }
}