fix 64 sized warps

This commit is contained in:
mcrcortex
2025-06-22 16:52:08 +10:00
parent 0dd730d8de
commit 8b5e2780c7

View File

@@ -12,92 +12,98 @@ layout(binding = IO_BUFFER, std430) buffer InputBuffer {
uvec4[] ioCount; uvec4[] ioCount;
}; };
shared uint warpPrefixSum[32];//Warps are 32, tricks require full warp shared uint warpPrefixSum[8];//Warps are 32, tricks require full warp
void main() { void main() {
/* if (gl_SubgroupSize == 32) {
uint subgroupId = gl_LocalInvocationID.x>>5; #ifdef IS_INTEL
warpPrefixSum[gl_SubgroupInvocationID] = 0; uint subgroupId = gl_LocalInvocationID.x>>5;
memoryBarrierShared(); #else
uint subgroupId = gl_SubgroupID;
#endif
//todo //todo
//assert(gl_SubgroupSize == 32); //assert(gl_SubgroupSize == 32);
//assert(gl_NumSubgroups == (WORK_SIZE>>5)); //assert(gl_NumSubgroups == (WORK_SIZE>>5));
uint gid = gl_GlobalInvocationID.x; uint gid = gl_GlobalInvocationID.x;
uvec4 count = uvec4(0); uvec4 count = uvec4(0);
uint sum = 0; uint sum = 0;
{ {
uvec4 dat = ioCount[gid]; uvec4 dat = ioCount[gid];
count.yzw = dat.xyz; count.yzw = dat.xyz;
count.z += count.y; count.z += count.y;
count.w += count.z; count.w += count.z;
sum = count.w + dat.w; sum = count.w + dat.w;
}
subgroupBarrier();//Wait for all threads in the subgroup to get the buffer
count += subgroupExclusiveAdd(sum);
if (gl_SubgroupInvocationID==31) {
warpPrefixSum[subgroupId] = count.x+sum;
}
memoryBarrierShared();
barrier();
if (gl_LocalInvocationID.x<8) {
uint val = warpPrefixSum[gl_SubgroupInvocationID];
subgroupBarrier();
//Use warp to do entire add in 1 reduction
warpPrefixSum[gl_SubgroupInvocationID] = subgroupExclusiveAdd(val);
}
memoryBarrierShared();
barrier();
//Add the computed sum across all threads and warps
count += warpPrefixSum[subgroupId];
ioCount[gid] = count;
} else {
#ifdef IS_INTEL
uint subgroupId = gl_LocalInvocationID.x>>6;
#else
uint subgroupId = gl_SubgroupID;
#endif
//todo
//assert(gl_SubgroupSize == 64);
//assert(gl_NumSubgroups == (WORK_SIZE>>6));
uint gid = gl_GlobalInvocationID.x;
uvec4 count = uvec4(0);
uint sum = 0;
{
uvec4 dat = ioCount[gid];
count.yzw = dat.xyz;
count.z += count.y;
count.w += count.z;
sum = count.w + dat.w;
}
subgroupBarrier();//Wait for all threads in the subgroup to get the buffer
count += subgroupExclusiveAdd(sum);
if (gl_SubgroupInvocationID==63) {
warpPrefixSum[subgroupId] = count.x+sum;
}
memoryBarrierShared();
barrier();
if (gl_LocalInvocationID.x<4) {
uint val = warpPrefixSum[gl_SubgroupInvocationID];
subgroupBarrier();
//Use warp to do entire add in 1 reduction
warpPrefixSum[gl_SubgroupInvocationID] = subgroupExclusiveAdd(val);
}
memoryBarrierShared();
barrier();
//Add the computed sum across all threads and warps
count += warpPrefixSum[subgroupId];
ioCount[gid] = count;
} }
barrier();
count += subgroupExclusiveAdd(sum);
if (gl_SubgroupInvocationID==31) {
warpPrefixSum[subgroupId] = count.x+sum;
}
memoryBarrierShared();
barrier();
uint val = warpPrefixSum[gl_SubgroupInvocationID];
barrier();
if (subgroupId == 0) {
//Use warp to do entire add in 1 reduction
warpPrefixSum[gl_SubgroupInvocationID] = subgroupExclusiveAdd(val);
}
memoryBarrierShared();
barrier();
count += warpPrefixSum[subgroupId];
ioCount[gid] = count;
*/
#ifdef IS_INTEL
uint subgroupId = gl_LocalInvocationID.x>>5;
#else
uint subgroupId = gl_SubgroupID;
#endif
//todo
//assert(gl_SubgroupSize == 32);
//assert(gl_NumSubgroups == (WORK_SIZE>>5));
uint gid = gl_GlobalInvocationID.x;
uvec4 count = uvec4(0);
uint sum = 0;
{
uvec4 dat = ioCount[gid];
count.yzw = dat.xyz;
count.z += count.y;
count.w += count.z;
sum = count.w + dat.w;
}
subgroupBarrier();//Wait for all threads in the subgroup to get the buffer
count += subgroupExclusiveAdd(sum);
if (gl_SubgroupInvocationID==31) {
warpPrefixSum[subgroupId] = count.x+sum;
}
memoryBarrierShared();
barrier();
if (subgroupId == 0) {
uint val = warpPrefixSum[gl_SubgroupInvocationID];
subgroupBarrier();
//Use warp to do entire add in 1 reduction
warpPrefixSum[gl_SubgroupInvocationID] = subgroupExclusiveAdd(val);
}
memoryBarrierShared();
barrier();
//Add the computed sum across all threads and warps
count += warpPrefixSum[subgroupId];
ioCount[gid] = count;
} }