fix 64 sized warps

This commit is contained in:
mcrcortex
2025-06-22 16:52:08 +10:00
parent 0dd730d8de
commit 8b5e2780c7

View File

@@ -12,50 +12,10 @@ layout(binding = IO_BUFFER, std430) buffer InputBuffer {
uvec4[] ioCount; uvec4[] ioCount;
}; };
shared uint warpPrefixSum[32];//Warps are 32, tricks require full warp shared uint warpPrefixSum[8];//Warps are 32, tricks require full warp
void main() { void main() {
/* if (gl_SubgroupSize == 32) {
uint subgroupId = gl_LocalInvocationID.x>>5;
warpPrefixSum[gl_SubgroupInvocationID] = 0;
memoryBarrierShared();
//todo
//assert(gl_SubgroupSize == 32);
//assert(gl_NumSubgroups == (WORK_SIZE>>5));
uint gid = gl_GlobalInvocationID.x;
uvec4 count = uvec4(0);
uint sum = 0;
{
uvec4 dat = ioCount[gid];
count.yzw = dat.xyz;
count.z += count.y;
count.w += count.z;
sum = count.w + dat.w;
}
barrier();
count += subgroupExclusiveAdd(sum);
if (gl_SubgroupInvocationID==31) {
warpPrefixSum[subgroupId] = count.x+sum;
}
memoryBarrierShared();
barrier();
uint val = warpPrefixSum[gl_SubgroupInvocationID];
barrier();
if (subgroupId == 0) {
//Use warp to do entire add in 1 reduction
warpPrefixSum[gl_SubgroupInvocationID] = subgroupExclusiveAdd(val);
}
memoryBarrierShared();
barrier();
count += warpPrefixSum[subgroupId];
ioCount[gid] = count;
*/
#ifdef IS_INTEL #ifdef IS_INTEL
uint subgroupId = gl_LocalInvocationID.x>>5; uint subgroupId = gl_LocalInvocationID.x>>5;
#else #else
@@ -87,7 +47,52 @@ void main() {
memoryBarrierShared(); memoryBarrierShared();
barrier(); barrier();
if (subgroupId == 0) { if (gl_LocalInvocationID.x<8) {
uint val = warpPrefixSum[gl_SubgroupInvocationID];
subgroupBarrier();
//Use warp to do entire add in 1 reduction
warpPrefixSum[gl_SubgroupInvocationID] = subgroupExclusiveAdd(val);
}
memoryBarrierShared();
barrier();
//Add the computed sum across all threads and warps
count += warpPrefixSum[subgroupId];
ioCount[gid] = count;
} else {
#ifdef IS_INTEL
uint subgroupId = gl_LocalInvocationID.x>>6;
#else
uint subgroupId = gl_SubgroupID;
#endif
//todo
//assert(gl_SubgroupSize == 32);
//assert(gl_NumSubgroups == (WORK_SIZE>>5));
uint gid = gl_GlobalInvocationID.x;
uvec4 count = uvec4(0);
uint sum = 0;
{
uvec4 dat = ioCount[gid];
count.yzw = dat.xyz;
count.z += count.y;
count.w += count.z;
sum = count.w + dat.w;
}
subgroupBarrier();//Wait for all threads in the subgroup to get the buffer
count += subgroupExclusiveAdd(sum);
if (gl_SubgroupInvocationID==63) {
warpPrefixSum[subgroupId] = count.x+sum;
}
memoryBarrierShared();
barrier();
if (gl_LocalInvocationID.x<4) {
uint val = warpPrefixSum[gl_SubgroupInvocationID]; uint val = warpPrefixSum[gl_SubgroupInvocationID];
subgroupBarrier(); subgroupBarrier();
//Use warp to do entire add in 1 reduction //Use warp to do entire add in 1 reduction
@@ -101,3 +106,4 @@ void main() {
count += warpPrefixSum[subgroupId]; count += warpPrefixSum[subgroupId];
ioCount[gid] = count; ioCount[gid] = count;
} }
}