Initial commit
Co-authored-by: Zhuohan Li <zhuohan@openai.com> Co-authored-by: Maratyszcza <marat@openai.com> Co-authored-by: Volodymyr Kyrylov <vol@wilab.org.ua>
This commit is contained in:
60
gpt_oss/triton/moe.py
Normal file
60
gpt_oss/triton/moe.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import torch
|
||||
from torch.profiler import record_function
|
||||
|
||||
import triton_kernels
|
||||
import triton_kernels.swiglu
|
||||
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp
|
||||
from triton_kernels.matmul_ogs import PrecisionConfig, FlexCtx, FnSpecs, FusedActivation
|
||||
from triton_kernels.matmul_ogs import matmul_ogs
|
||||
from triton_kernels.numerics import InFlexData
|
||||
from triton_kernels.routing import routing
|
||||
from triton_kernels.tensor import convert_layout
|
||||
from triton_kernels.tensor_details.layout import StridedLayout, HopperMXScaleLayout, HopperMXValueLayout
|
||||
from triton_kernels.tensor import wrap_torch_tensor, FP4
|
||||
|
||||
|
||||
def quantize_mx4(w):
    """Quantize a weight tensor to MXFP4 with Hopper-friendly layouts.

    Downcasts ``w`` (via bfloat16) along axis 1 into packed FP4 values plus
    per-block uint8 scales, then wraps both in triton_kernels tensor types:
    values get ``HopperMXValueLayout``, scales get ``StridedLayout``.

    Returns a ``(values, scales)`` pair suitable for a matmul_ogs
    ``PrecisionConfig(weight_scale=...)``.
    """
    packed, scales = downcast_to_mxfp(w.to(torch.bfloat16), torch.uint8, axis=1)
    values = convert_layout(wrap_torch_tensor(packed, dtype=FP4), HopperMXValueLayout, mx_axis=1)
    scales = convert_layout(wrap_torch_tensor(scales), StridedLayout)
    return values, scales
|
||||
|
||||
|
||||
def swiglu(x, alpha: float = 1.702, limit: float = 7.0, interleaved: bool = True):
    """Clamped SwiGLU activation over the last dimension of ``x``.

    The last dimension is split into a gate half and a linear half — even/odd
    columns when ``interleaved`` is true, otherwise two contiguous chunks —
    and combined as ``gate * sigmoid(alpha * gate) * (linear + 1)``.
    The gate is clamped from above at ``limit``; the linear term is clamped
    symmetrically to ``[-limit, limit]``.
    """
    if interleaved:
        gate, linear = x[..., 0::2], x[..., 1::2]
    else:
        gate, linear = torch.chunk(x, 2, dim=-1)
    # Gate is only bounded from above; the linear half on both sides.
    gate = torch.clamp(gate, max=limit)
    linear = torch.clamp(linear, min=-limit, max=limit)
    gated = gate * torch.sigmoid(alpha * gate)
    return gated * (linear + 1)
|
||||
|
||||
|
||||
def moe(x, wg, w1, w1_mx, w2, w2_mx, bg, b1, b2, experts_per_token=4, num_experts=128, swiglu_limit=7.0, fused_act=True, interleaved=True):
    """Mixture-of-experts forward pass built on triton_kernels ``matmul_ogs``.

    Computes router logits with ``wg``/``bg``, routes each token to
    ``experts_per_token`` experts, runs the gathered tokens through the
    w1 -> SwiGLU -> w2 expert MLP (w1/w2 carry MX scales ``w1_mx``/``w2_mx``),
    and scatters results back weighted by the routing gates.

    When ``fused_act`` is true the SwiGLU activation runs inside the w1
    matmul epilogue (requires ``interleaved`` gate/linear columns);
    otherwise it is applied as a separate step via :func:`swiglu`.

    ``num_experts`` is accepted for interface compatibility but is not read
    in this function. Each stage is wrapped in ``record_function`` so it
    shows up labeled in torch profiler traces.
    """
    # Empty batch: nothing to route or compute.
    if x.numel() == 0:
        return x

    gate_precision = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=InFlexData()))
    up_precision = PrecisionConfig(weight_scale=w1_mx, flex_ctx=FlexCtx(rhs_data=InFlexData()))
    down_precision = PrecisionConfig(weight_scale=w2_mx, flex_ctx=FlexCtx(rhs_data=InFlexData()))

    with record_function("wg"):
        logits = matmul_ogs(x, wg, bg, precision_config=gate_precision)
    with record_function("routing"):
        rdata, gather_indx, scatter_indx = routing(logits, experts_per_token, simulated_ep=1)

    if not fused_act:
        # Unfused path: plain w1 matmul followed by the reference SwiGLU.
        with record_function("w1"):
            x = matmul_ogs(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=up_precision)
        with record_function("swiglu"):
            x = swiglu(x, limit=swiglu_limit, interleaved=interleaved)
    else:
        assert interleaved, "Fused activation requires interleaved weights"
        with record_function("w1+swiglu"):
            # SwiGLU is fused into the matmul epilogue with fixed alpha=1.702.
            activation = FusedActivation(FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")), (1.702, swiglu_limit), 2)
            x = matmul_ogs(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=up_precision, fused_activation=activation)

    with record_function("w2"):
        # Scatter back to token order, scaling each expert output by its gate.
        x = matmul_ogs(x, w2, b2, rdata, scatter_indx=scatter_indx, precision_config=down_precision, gammas=rdata.gate_scal)
    return x
|
||||
Reference in New Issue
Block a user