Fix intel

This commit is contained in:
mcrcortex
2025-05-23 00:21:57 +10:00
parent 0882b71a9f
commit 90a3a16cc2
3 changed files with 57 additions and 27 deletions

View File

@@ -8,15 +8,17 @@
//Does inital parralel prefix sum on batches of WORK_SIZE
layout(local_size_x=WORK_SIZE) in;
layout(binding = IO_BUFFER, std430) restrict buffer InputBuffer {
layout(binding = IO_BUFFER, std430) buffer InputBuffer {
uvec4[] ioCount;
};
shared uint warpPrefixSum[32];//Warps are 32, tricks require full warp
void main() {
/*
uint subgroupId = gl_LocalInvocationID.x>>5;
warpPrefixSum[gl_SubgroupInvocationID] = 0;
barrier();
memoryBarrierShared();
//todo
//assert(gl_SubgroupSize == 32);
@@ -33,33 +35,69 @@ void main() {
sum = count.w + dat.w;
}
barrier();
count += subgroupExclusiveAdd(sum);
if (gl_SubgroupInvocationID==31) {
warpPrefixSum[subgroupId] = count.x+sum;
}
memoryBarrierShared();
barrier();
uint val = warpPrefixSum[gl_SubgroupInvocationID];
barrier();
if (subgroupId == 0) {
//Use warp to do entire add in 1 reduction
warpPrefixSum[gl_SubgroupInvocationID] = subgroupExclusiveAdd(val);
}
memoryBarrierShared();
barrier();
count += warpPrefixSum[subgroupId];
ioCount[gid] = count;
*/
#ifdef IS_INTEL
uint subgroupId = gl_LocalInvocationID.x>>5;
#else
uint subgroupId = gl_SubgroupID;
#endif
//todo
//assert(gl_SubgroupSize == 32);
//assert(gl_NumSubgroups == (WORK_SIZE>>5));
uint gid = gl_GlobalInvocationID.x;
uvec4 count = uvec4(0);
uint sum = 0;
{
uvec4 dat = ioCount[gid];
count.yzw = dat.xyz;
count.z += count.y;
count.w += count.z;
sum = count.w + dat.w;
}
subgroupBarrier();//Wait for all threads in the subgroup to get the buffer
count += subgroupExclusiveAdd(sum);
if ((gl_LocalInvocationID.x&31u)==31) {
warpPrefixSum[gl_SubgroupID] = count.x+sum;
if (gl_SubgroupInvocationID==31) {
warpPrefixSum[subgroupId] = count.x+sum;
}
memoryBarrierShared();
barrier();
#ifdef IS_INTEL
uint val = subgroupExclusiveAdd(warpPrefixSum[gl_SubgroupInvocationID]);
barrier();
if (gl_SubgroupID == 0) {
warpPrefixSum[gl_SubgroupInvocationID] = val;
}
#else
if (gl_SubgroupID == 0) {
uint val = warpPrefixSum[gl_SubgroupInvocationID];
subgroupBarrier();
//Use warp to do entire add in 1 reduction
warpPrefixSum[gl_SubgroupInvocationID] = subgroupExclusiveAdd(val);
}
#endif
memoryBarrierShared();
barrier();
//Add the computed sum across all threads and warps
count += warpPrefixSum[gl_SubgroupID];
count += warpPrefixSum[subgroupId];
ioCount[gid] = count;
}