Initial commit

Co-authored-by: Zhuohan Li <zhuohan@openai.com>
Co-authored-by: Maratyszcza <marat@openai.com>
Co-authored-by: Volodymyr Kyrylov <vol@wilab.org.ua>
This commit is contained in:
Dominik Kundel
2025-08-05 08:19:49 -07:00
commit 243a1b0276
124 changed files with 20405 additions and 0 deletions

View File

@@ -0,0 +1,64 @@
#include <metal_integer>
#include <internal/kernel-args.h>
#pragma METAL fp math_mode(safe)
#pragma METAL fp contract(off)
kernel void gptoss_mf4_f32_convert(
constant gptoss_convert_args& args [[ buffer(0) ]],
const device uint4* blocks [[ buffer(1) ]],
const device uchar* scales [[ buffer(2) ]],
device float4* output [[ buffer(3) ]],
uint gid [[threadgroup_position_in_grid]],
uint tid [[thread_position_in_threadgroup]],
uint threadgroup_size [[ threads_per_threadgroup ]])
{
const ulong num_vecs_per_threadgroup = args.num_vecs_per_threadgroup;
const ulong threadgroup_start = gid * num_vecs_per_threadgroup;
const ulong threadgroup_end = metal::min(threadgroup_start + num_vecs_per_threadgroup, args.num_vecs);
const ulong thread_start = threadgroup_start + tid;
uint num_iter = static_cast<uint>((threadgroup_end - thread_start + (threadgroup_size - 1)) / threadgroup_size);
blocks += thread_start;
scales += thread_start;
output += 8 * thread_start;
for (; num_iter != 0; num_iter--) {
const uint4 block = *blocks;
const float scale = as_type<float>((static_cast<uint>(*scales) + 14) << 23);
uint4 block02468ACEGIKMOQSU = block + block;
uint4 block13579BDFHJLNPRTV = block >> 3;
block02468ACEGIKMOQSU &= 0x1E1E1E1Eu;
block13579BDFHJLNPRTV &= 0x1E1E1E1Eu;
block02468ACEGIKMOQSU += 0x70707070u;
block13579BDFHJLNPRTV += 0x70707070u;
block02468ACEGIKMOQSU &= 0x8E8E8E8Eu;
block13579BDFHJLNPRTV &= 0x8E8E8E8Eu;
const uint4 block26AEIMQU = block02468ACEGIKMOQSU & 0xFF00FF00u;
const uint4 block048CGKOS = (block02468ACEGIKMOQSU << 8) & 0xFF00FF00u;
const uint4 block37BFJNRV = block13579BDFHJLNPRTV & 0xFF00FF00u;
const uint4 block159DHLPT = (block13579BDFHJLNPRTV << 8) & 0xFF00FF00u;
const float4 block048C = static_cast<float4>(as_type<half4>(block048CGKOS.xy)) * scale;
const float4 blockGKOS = static_cast<float4>(as_type<half4>(block048CGKOS.zw)) * scale;
const float4 block26AE = static_cast<float4>(as_type<half4>(block26AEIMQU.xy)) * scale;
const float4 blockIMQU = static_cast<float4>(as_type<half4>(block26AEIMQU.zw)) * scale;
const float4 block159D = static_cast<float4>(as_type<half4>(block159DHLPT.xy)) * scale;
const float4 blockHLPT = static_cast<float4>(as_type<half4>(block159DHLPT.zw)) * scale;
const float4 block37BF = static_cast<float4>(as_type<half4>(block37BFJNRV.xy)) * scale;
const float4 blockJNRV = static_cast<float4>(as_type<half4>(block37BFJNRV.zw)) * scale;
output[0] = (float4) { block048C.x, block159D.x, block26AE.x, block37BF.x };
output[1] = (float4) { block048C.y, block159D.y, block26AE.y, block37BF.y };
output[2] = (float4) { block048C.z, block159D.z, block26AE.z, block37BF.z };
output[3] = (float4) { block048C.w, block159D.w, block26AE.w, block37BF.w };
output[4] = (float4) { blockGKOS.x, blockHLPT.x, blockIMQU.x, blockJNRV.x };
output[5] = (float4) { blockGKOS.y, blockHLPT.y, blockIMQU.y, blockJNRV.y };
output[6] = (float4) { blockGKOS.z, blockHLPT.z, blockIMQU.z, blockJNRV.z };
output[7] = (float4) { blockGKOS.w, blockHLPT.w, blockIMQU.w, blockJNRV.w };
blocks += threadgroup_size;
scales += threadgroup_size;
output += 8 * threadgroup_size;
}
}