Fix handling of 64-wide warps
This commit is contained in:
@@ -12,50 +12,10 @@ layout(binding = IO_BUFFER, std430) buffer InputBuffer {
|
|||||||
uvec4[] ioCount;
|
uvec4[] ioCount;
|
||||||
};
|
};
|
||||||
|
|
||||||
shared uint warpPrefixSum[32];//Warps are 32, tricks require full warp
|
shared uint warpPrefixSum[8];//One total per subgroup: the 32-wide path scans 8 entries, the 64-wide path scans 4
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
/*
|
if (gl_SubgroupSize == 32) {
|
||||||
uint subgroupId = gl_LocalInvocationID.x>>5;
|
|
||||||
warpPrefixSum[gl_SubgroupInvocationID] = 0;
|
|
||||||
memoryBarrierShared();
|
|
||||||
|
|
||||||
//todo
|
|
||||||
//assert(gl_SubgroupSize == 32);
|
|
||||||
//assert(gl_NumSubgroups == (WORK_SIZE>>5));
|
|
||||||
|
|
||||||
uint gid = gl_GlobalInvocationID.x;
|
|
||||||
uvec4 count = uvec4(0);
|
|
||||||
uint sum = 0;
|
|
||||||
{
|
|
||||||
uvec4 dat = ioCount[gid];
|
|
||||||
count.yzw = dat.xyz;
|
|
||||||
count.z += count.y;
|
|
||||||
count.w += count.z;
|
|
||||||
sum = count.w + dat.w;
|
|
||||||
}
|
|
||||||
|
|
||||||
barrier();
|
|
||||||
count += subgroupExclusiveAdd(sum);
|
|
||||||
|
|
||||||
if (gl_SubgroupInvocationID==31) {
|
|
||||||
warpPrefixSum[subgroupId] = count.x+sum;
|
|
||||||
}
|
|
||||||
memoryBarrierShared();
|
|
||||||
barrier();
|
|
||||||
uint val = warpPrefixSum[gl_SubgroupInvocationID];
|
|
||||||
barrier();
|
|
||||||
if (subgroupId == 0) {
|
|
||||||
//Use warp to do entire add in 1 reduction
|
|
||||||
warpPrefixSum[gl_SubgroupInvocationID] = subgroupExclusiveAdd(val);
|
|
||||||
}
|
|
||||||
memoryBarrierShared();
|
|
||||||
barrier();
|
|
||||||
count += warpPrefixSum[subgroupId];
|
|
||||||
ioCount[gid] = count;
|
|
||||||
*/
|
|
||||||
|
|
||||||
|
|
||||||
#ifdef IS_INTEL
|
#ifdef IS_INTEL
|
||||||
uint subgroupId = gl_LocalInvocationID.x>>5;
|
uint subgroupId = gl_LocalInvocationID.x>>5;
|
||||||
#else
|
#else
|
||||||
@@ -87,7 +47,7 @@ void main() {
|
|||||||
memoryBarrierShared();
|
memoryBarrierShared();
|
||||||
barrier();
|
barrier();
|
||||||
|
|
||||||
if (subgroupId == 0) {
|
if (gl_LocalInvocationID.x<8) {
|
||||||
uint val = warpPrefixSum[gl_SubgroupInvocationID];
|
uint val = warpPrefixSum[gl_SubgroupInvocationID];
|
||||||
subgroupBarrier();
|
subgroupBarrier();
|
||||||
//Use warp to do entire add in 1 reduction
|
//Use warp to do entire add in 1 reduction
|
||||||
@@ -100,4 +60,50 @@ void main() {
|
|||||||
//Add the computed sum across all threads and warps
|
//Add the computed sum across all threads and warps
|
||||||
count += warpPrefixSum[subgroupId];
|
count += warpPrefixSum[subgroupId];
|
||||||
ioCount[gid] = count;
|
ioCount[gid] = count;
|
||||||
|
} else {
|
||||||
|
#ifdef IS_INTEL
|
||||||
|
uint subgroupId = gl_LocalInvocationID.x>>6;
|
||||||
|
#else
|
||||||
|
uint subgroupId = gl_SubgroupID;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
//todo
|
||||||
|
//assert(gl_SubgroupSize == 64);
|
||||||
|
//assert(gl_NumSubgroups == (WORK_SIZE>>6));
|
||||||
|
|
||||||
|
uint gid = gl_GlobalInvocationID.x;
|
||||||
|
uvec4 count = uvec4(0);
|
||||||
|
uint sum = 0;
|
||||||
|
{
|
||||||
|
uvec4 dat = ioCount[gid];
|
||||||
|
count.yzw = dat.xyz;
|
||||||
|
count.z += count.y;
|
||||||
|
count.w += count.z;
|
||||||
|
sum = count.w + dat.w;
|
||||||
|
}
|
||||||
|
subgroupBarrier();//Wait for all threads in the subgroup to get the buffer
|
||||||
|
|
||||||
|
count += subgroupExclusiveAdd(sum);
|
||||||
|
|
||||||
|
if (gl_SubgroupInvocationID==63) {
|
||||||
|
warpPrefixSum[subgroupId] = count.x+sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
memoryBarrierShared();
|
||||||
|
barrier();
|
||||||
|
|
||||||
|
if (gl_LocalInvocationID.x<4) {
|
||||||
|
uint val = warpPrefixSum[gl_SubgroupInvocationID];
|
||||||
|
subgroupBarrier();
|
||||||
|
//Use warp to do entire add in 1 reduction
|
||||||
|
warpPrefixSum[gl_SubgroupInvocationID] = subgroupExclusiveAdd(val);
|
||||||
|
}
|
||||||
|
|
||||||
|
memoryBarrierShared();
|
||||||
|
barrier();
|
||||||
|
|
||||||
|
//Add the computed sum across all threads and warps
|
||||||
|
count += warpPrefixSum[subgroupId];
|
||||||
|
ioCount[gid] = count;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user