#version 450

#extension GL_EXT_control_flow_attributes : require
#if USE_SUBGROUP_ADD
#extension GL_KHR_shader_subgroup_arithmetic : enable
#endif

#include "types.glsl"

// Recurrent scan over n_tok steps. For every step i and every one of the
// SPLIT_H rows j owned by a workgroup, each invocation `tid` updates one state
// element
//     state[j] = state[j] * dA + B[tid] * x[j] * softplus(dt)
// and the workgroup then sums state[j] * C[tid] over all D_STATE invocations
// to produce one output value per row.
// NOTE(review): the buffer names suggest this is a state-space-model (SSM)
// scan step; that is inferred from naming only - confirm against the host code.

// Specialization constants. local_size_x is bound to constant 0, so the
// workgroup has D_STATE invocations along x (one per state element).
layout(constant_id = 0) const uint D_STATE = 128;
layout(constant_id = 1) const uint SUBGROUP_SIZE = 32;
layout(constant_id = 2) const uint SPLIT_H = 16;

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout(binding = 0) readonly buffer Src0 { float s0[]; };  // initial state
layout(binding = 1) readonly buffer Src1 { float x[]; };
layout(binding = 2) readonly buffer Src2 { float dt[]; };
layout(binding = 3) readonly buffer Src3 { float A[]; };
layout(binding = 4) readonly buffer Src4 { float B[]; };
layout(binding = 5) readonly buffer Src5 { float C[]; };
layout(binding = 6) readonly buffer Src6 { int ids[]; };   // per-sequence index into s0
layout(binding = 7) buffer Dst { float d[]; };             // per-step outputs + final state

// The nb* values and s_off are divided by 4 before indexing the float arrays,
// i.e. they are used as byte strides / byte offsets.
layout(push_constant) uniform PushConstants {
    uint nb02; uint nb03; uint nb12; uint nb13;
    uint nb21; uint nb22; uint nb31;
    uint nb42; uint nb43; uint nb52; uint nb53;
    uint s_off;
    uint n_head; uint d_head; uint n_group; uint n_tok;
};

// softplus(x) = log(1 + exp(x)). For x > 20 the input is returned unchanged;
// the exact value differs from x by only about exp(-20) there.
float softplus(float x) {
    if (x <= 20.0) {
        return log(1.0 + exp(x));
    } else {
        return x;
    }
}

// Scratch for the per-step reduction: SPLIT_H rows of D_STATE partial products.
shared float stateC[SPLIT_H * D_STATE];

void main() {
    const uint tid = gl_LocalInvocationID.x;
    // Position of this invocation inside its SUBGROUP_SIZE-wide reduction group.
    const uint lane = tid % SUBGROUP_SIZE;

    // Each workgroup along x handles SPLIT_H consecutive rows; derive the head
    // those rows belong to and the (byte) offset of the first row inside it.
    const uint head_idx = (gl_WorkGroupID.x * SPLIT_H) / d_head;
    const uint head_off = ((gl_WorkGroupID.x * SPLIT_H) % d_head) * 4;
    const uint seq_idx = gl_WorkGroupID.y;

    const uint group_off = (head_idx / (n_head / n_group)) * D_STATE * 4;
    const uint s0_base_idx = (uint(ids[seq_idx]) * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;
    const uint x_base_idx = (seq_idx * nb13 + gl_WorkGroupID.x * SPLIT_H * 4) / 4;
    const uint dt_base_idx = (seq_idx * nb22 + head_idx * 4) / 4;
    const uint A_base_idx = (head_idx * nb31) / 4;
    const uint B_base_idx = (seq_idx * nb43 + group_off) / 4;
    const uint C_base_idx = (seq_idx * nb53 + group_off) / 4;
    // Output index is computed directly in elements (no division by 4).
    const uint y_base_idx = seq_idx * n_tok * n_head * d_head + gl_WorkGroupID.x * SPLIT_H;
    // The final state is written into d[] starting s_off bytes in.
    const uint s_base_idx = (s_off + seq_idx * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;

    const uint stride_x = nb12 / 4;
    const uint stride_dt = nb21 / 4;
    const uint stride_B = nb42 / 4;
    const uint stride_C = nb52 / 4;
    const uint stride_y = n_head * d_head;

    // Load the initial state: one element per row for this invocation.
    float state[SPLIT_H];
    [[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
        state[j] = s0[s0_base_idx + j * D_STATE + tid];
    }

    for (uint i = 0; i < n_tok; i++) {
        const float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]);
        const float dA = exp(dt_soft_plus * A[A_base_idx]);

        const float B_val = B[B_base_idx + i * stride_B + tid];
        const float C_val = C[C_base_idx + i * stride_C + tid];

        // Advance the recurrence and publish this invocation's contribution to
        // each row's output.
        [[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
            const float x_dt = x[x_base_idx + i * stride_x + j] * dt_soft_plus;

            state[j] = (state[j] * dA) + (B_val * x_dt);

            stateC[j * D_STATE + tid] = state[j] * C_val;
        }

        barrier();

        // Stage 1: halve the active width of every row until only the first
        // SUBGROUP_SIZE elements of each row hold partial sums. Writers touch
        // [0, w) of a row and readers touch [w, 2w), so there is no overlap
        // within a step.
        [[unroll]]
        for (uint w = D_STATE / 2; w >= SUBGROUP_SIZE; w >>= 1) {
            [[unroll]] for (uint j = 0; j < (w * SPLIT_H + D_STATE - 1) / D_STATE; j++) {
                const uint k = (tid % w) + (D_STATE * (tid / w)) + j * D_STATE * (D_STATE / w);
                if (k < SPLIT_H * D_STATE && (k + w) < SPLIT_H * D_STATE) {
                    stateC[k] += stateC[k + w];
                }
            }
            barrier();
        }

        // Stage 2: each run of SUBGROUP_SIZE consecutive invocations reduces
        // the remaining SUBGROUP_SIZE partial sums of one row per iteration.
        [[unroll]] for (uint j = 0; j < max(1, SPLIT_H / (D_STATE / SUBGROUP_SIZE)); j++) {
            const uint idx = lane +
                             D_STATE * (tid / SUBGROUP_SIZE) +
                             j * D_STATE * (D_STATE / SUBGROUP_SIZE);
            // Largest idx any invocation can produce for this j; it depends
            // only on constants and j, so it is the same for all invocations.
            const uint max_idx = SUBGROUP_SIZE - 1 +
                                 D_STATE * ((D_STATE - 1) / SUBGROUP_SIZE) +
                                 j * D_STATE * (D_STATE / SUBGROUP_SIZE);

            // NOTE(review): the fallback path below calls barrier() inside this
            // branch. When max_idx >= SPLIT_H * D_STATE the branch depends on
            // idx and may not be taken by every invocation - confirm that such
            // specialization values are never used.
            if (idx < SPLIT_H * D_STATE ||
                max_idx < SPLIT_H * D_STATE) {
                float sc;
#if USE_SUBGROUP_ADD
                // NOTE(review): assumes SUBGROUP_SIZE matches the hardware
                // subgroup size so subgroupAdd covers exactly one row - confirm.
                sc = stateC[idx];
                sc = subgroupAdd(sc);
#else
                [[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) {
                    // Only the lower `offset` lanes accumulate. Without the
                    // lane guard, lane (l + offset) would write the very
                    // element lane l is reading in the same step, which is a
                    // data race on shared memory.
                    if (lane < offset && idx + offset < SPLIT_H * D_STATE) {
                        stateC[idx] += stateC[idx + offset];
                    }
                    barrier();
                }
                if (lane == 0) {
                    sc = stateC[idx];
                }
#endif

                // The first lane of each group writes the row's output value.
                if (lane == 0) {
                    const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE);
                    d[y_base_idx + i * stride_y + k] = sc;
                }
            }
        }

        // Make sure all reads of stateC are done before the next step
        // overwrites it.
        barrier();
    }

    // Store the final state after the last step.
    [[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
        d[s_base_idx + j * D_STATE + tid] = state[j];
    }
}