Initial commit
Co-authored-by: Zhuohan Li <zhuohan@openai.com> Co-authored-by: Maratyszcza <marat@openai.com> Co-authored-by: Volodymyr Kyrylov <vol@wilab.org.ua>
This commit is contained in:
24
gpt_oss/metal/source/embeddings.metal
Normal file
24
gpt_oss/metal/source/embeddings.metal
Normal file
@@ -0,0 +1,24 @@
|
||||
#include <internal/kernel-args.h>
|
||||
|
||||
#pragma METAL fp math_mode(safe)
|
||||
#pragma METAL fp contract(off)
|
||||
|
||||
|
||||
kernel void gptoss_bf16_f32_embeddings(
|
||||
constant gptoss_embeddings_args& args [[ buffer(0) ]],
|
||||
const device uint* tokens [[ buffer(1) ]],
|
||||
const device bfloat4* weights [[ buffer(2) ]],
|
||||
device float4* output [[ buffer(3) ]],
|
||||
uint gid [[threadgroup_position_in_grid]],
|
||||
uint tid [[thread_position_in_threadgroup]],
|
||||
uint threadgroup_size [[ threads_per_threadgroup ]])
|
||||
{
|
||||
const uint t = tokens[gid];
|
||||
|
||||
weights += t * args.num_vecs;
|
||||
output += gid * args.num_vecs;
|
||||
for (uint i = tid; i < args.num_vecs; i += threadgroup_size) {
|
||||
const bfloat4 w = weights[i];
|
||||
output[i] = static_cast<float4>(w);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user