Files
openharmony-mlx/gpt_oss/metal/test/f32-bf16w-rmsnorm.cc
Dominik Kundel 243a1b0276 Initial commit
Co-authored-by: Zhuohan Li <zhuohan@openai.com>
Co-authored-by: Maratyszcza <marat@openai.com>
Co-authored-by: Volodymyr Kyrylov <vol@wilab.org.ua>
2025-08-05 08:19:49 -07:00

37 lines
881 B
C++

#include <gtest/gtest.h>
#include <cstdint>
#include "rmsnorm-kernel-tester.hpp"
using gptoss::RMSNormKernelTester;
// Sizes that are fixed in the kernel under test; the cases below derive their
// channel counts from them.
constexpr std::uint32_t kThreadgroupSize{1024};  // threadgroup size, fixed in the kernel
constexpr std::uint32_t kVectorSize{4};          // vector size, fixed in the kernel
// Runs the F32/BF16-weight RMSNorm check with a channel count exactly equal to
// the kernel's fixed threadgroup size.
TEST(F32_BF16W_RMSNORM, single_iteration) {
    constexpr std::uint32_t channels = kThreadgroupSize;

    RMSNormKernelTester()
        .num_channels(channels)
        .TestF32_BF16W();
}
// Runs the F32/BF16-weight RMSNorm check with a channel count that is exactly
// twice the kernel's fixed threadgroup size.
TEST(F32_BF16W_RMSNORM, multiple_iterations) {
    constexpr std::uint32_t channels = 2 * kThreadgroupSize;

    RMSNormKernelTester()
        .num_channels(channels)
        .TestF32_BF16W();
}
// Runs the F32/BF16-weight RMSNorm check with a channel count that is not a
// whole multiple of the threadgroup size: two threadgroup sizes plus one
// extra vector's worth of channels.
TEST(F32_BF16W_RMSNORM, partial_iteration) {
    constexpr std::uint32_t channels = 2 * kThreadgroupSize + kVectorSize;

    RMSNormKernelTester()
        .num_channels(channels)
        .TestF32_BF16W();
}
// Runs the F32/BF16-weight RMSNorm check over three tokens, using the same
// channel count as the partial_iteration case (two threadgroup sizes plus one
// vector's worth of channels).
TEST(F32_BF16W_RMSNORM, multiple_tokens) {
    constexpr std::uint32_t channels = 2 * kThreadgroupSize + kVectorSize;

    RMSNormKernelTester()
        .num_tokens(3)
        .num_channels(channels)
        .TestF32_BF16W();
}