fix 64 sized warps
This commit is contained in:
@@ -12,50 +12,10 @@ layout(binding = IO_BUFFER, std430) buffer InputBuffer {
|
||||
uvec4[] ioCount;
|
||||
};
|
||||
|
||||
shared uint warpPrefixSum[32];//Warps are 32, tricks require full warp
|
||||
shared uint warpPrefixSum[8];//Warps are 32, tricks require full warp
|
||||
|
||||
void main() {
|
||||
/*
|
||||
uint subgroupId = gl_LocalInvocationID.x>>5;
|
||||
warpPrefixSum[gl_SubgroupInvocationID] = 0;
|
||||
memoryBarrierShared();
|
||||
|
||||
//todo
|
||||
//assert(gl_SubgroupSize == 32);
|
||||
//assert(gl_NumSubgroups == (WORK_SIZE>>5));
|
||||
|
||||
uint gid = gl_GlobalInvocationID.x;
|
||||
uvec4 count = uvec4(0);
|
||||
uint sum = 0;
|
||||
{
|
||||
uvec4 dat = ioCount[gid];
|
||||
count.yzw = dat.xyz;
|
||||
count.z += count.y;
|
||||
count.w += count.z;
|
||||
sum = count.w + dat.w;
|
||||
}
|
||||
|
||||
barrier();
|
||||
count += subgroupExclusiveAdd(sum);
|
||||
|
||||
if (gl_SubgroupInvocationID==31) {
|
||||
warpPrefixSum[subgroupId] = count.x+sum;
|
||||
}
|
||||
memoryBarrierShared();
|
||||
barrier();
|
||||
uint val = warpPrefixSum[gl_SubgroupInvocationID];
|
||||
barrier();
|
||||
if (subgroupId == 0) {
|
||||
//Use warp to do entire add in 1 reduction
|
||||
warpPrefixSum[gl_SubgroupInvocationID] = subgroupExclusiveAdd(val);
|
||||
}
|
||||
memoryBarrierShared();
|
||||
barrier();
|
||||
count += warpPrefixSum[subgroupId];
|
||||
ioCount[gid] = count;
|
||||
*/
|
||||
|
||||
|
||||
if (gl_SubgroupSize == 32) {
|
||||
#ifdef IS_INTEL
|
||||
uint subgroupId = gl_LocalInvocationID.x>>5;
|
||||
#else
|
||||
@@ -87,7 +47,52 @@ void main() {
|
||||
memoryBarrierShared();
|
||||
barrier();
|
||||
|
||||
if (subgroupId == 0) {
|
||||
if (gl_LocalInvocationID.x<8) {
|
||||
uint val = warpPrefixSum[gl_SubgroupInvocationID];
|
||||
subgroupBarrier();
|
||||
//Use warp to do entire add in 1 reduction
|
||||
warpPrefixSum[gl_SubgroupInvocationID] = subgroupExclusiveAdd(val);
|
||||
}
|
||||
|
||||
memoryBarrierShared();
|
||||
barrier();
|
||||
|
||||
//Add the computed sum across all threads and warps
|
||||
count += warpPrefixSum[subgroupId];
|
||||
ioCount[gid] = count;
|
||||
} else {
|
||||
#ifdef IS_INTEL
|
||||
uint subgroupId = gl_LocalInvocationID.x>>6;
|
||||
#else
|
||||
uint subgroupId = gl_SubgroupID;
|
||||
#endif
|
||||
|
||||
//todo
|
||||
//assert(gl_SubgroupSize == 32);
|
||||
//assert(gl_NumSubgroups == (WORK_SIZE>>5));
|
||||
|
||||
uint gid = gl_GlobalInvocationID.x;
|
||||
uvec4 count = uvec4(0);
|
||||
uint sum = 0;
|
||||
{
|
||||
uvec4 dat = ioCount[gid];
|
||||
count.yzw = dat.xyz;
|
||||
count.z += count.y;
|
||||
count.w += count.z;
|
||||
sum = count.w + dat.w;
|
||||
}
|
||||
subgroupBarrier();//Wait for all threads in the subgroup to get the buffer
|
||||
|
||||
count += subgroupExclusiveAdd(sum);
|
||||
|
||||
if (gl_SubgroupInvocationID==63) {
|
||||
warpPrefixSum[subgroupId] = count.x+sum;
|
||||
}
|
||||
|
||||
memoryBarrierShared();
|
||||
barrier();
|
||||
|
||||
if (gl_LocalInvocationID.x<4) {
|
||||
uint val = warpPrefixSum[gl_SubgroupInvocationID];
|
||||
subgroupBarrier();
|
||||
//Use warp to do entire add in 1 reduction
|
||||
@@ -101,3 +106,4 @@ void main() {
|
||||
count += warpPrefixSum[subgroupId];
|
||||
ioCount[gid] = count;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user