#define _NBL_BUILTIN_HLSL_BLIT_INCLUDED_

#include <nbl/builtin/hlsl/glsl_compat/core.hlsl>
#include <nbl/builtin/hlsl/blit/common.hlsl>

namespace nbl
{
namespace hlsl
{
namespace blit
{

template<
    bool DoCoverage,
    uint16_t WorkGroupSize,
    int32_t Dims,
    typename InCombinedSamplerAccessor,
    typename OutImageAccessor,
    typename KernelWeightsAccessor,
    //typename HistogramAccessor,
    typename SharedAccessor
>
void execute(
    NBL_CONST_REF_ARG(InCombinedSamplerAccessor) inCombinedSamplerAccessor,
    NBL_REF_ARG(OutImageAccessor) outImageAccessor,
    NBL_CONST_REF_ARG(KernelWeightsAccessor) kernelWeightsAccessor,
    //NBL_REF_ARG(HistogramAccessor) histogramAccessor,
    NBL_REF_ARG(SharedAccessor) sharedAccessor,
    NBL_CONST_REF_ARG(SPerWorkgroup) params,
    const uint16_t layer,
    const vector<uint16_t,Dims> virtWorkGroupID
)
{
    const uint16_t lastChannel = params.lastChannel;
    const uint16_t coverageChannel = params.coverageChannel;

    using uint16_tN = vector<uint16_t,Dims>;
    // the dimensional truncation is desired
    const uint16_tN outputTexelsPerWG = params.template getPerWGOutputExtent<Dims>();
    // it's the min XYZ corner of the output region this workgroup will write
    const uint16_tN minOutputTexel = virtWorkGroupID*outputTexelsPerWG;

    using float32_tN = vector<float32_t,Dims>;
    const float32_tN scale = truncate<Dims>(params.scale);
    const float32_tN inputMaxCoord = params.template getInputMaxCoord<Dims>();
    const uint16_t inLevel = _static_cast<uint16_t>(params.inLevel);
    const float32_tN inImageSizeRcp = inCombinedSamplerAccessor.template extentRcp<Dims>(inLevel);

    using int32_tN = vector<int32_t,Dims>;
    // can be negative, it's the min XYZ corner of the area the workgroup will sample from to produce its output
    const float32_tN regionStartCoord = params.inputUpperBound<Dims>(minOutputTexel);
    const float32_tN regionNextStartCoord = params.inputUpperBound<Dims>(minOutputTexel+outputTexelsPerWG);

    const uint16_t localInvocationIndex = _static_cast<uint16_t>(glsl::gl_LocalInvocationIndex()); // workgroup::SubgroupContiguousIndex()
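
    // The blit runs in three phases per workgroup tile:
    //  1. preload: every invocation gathers the input texels this tile needs into shared memory
    //  2. one separable convolution pass per axis, ping-ponging between two scratch regions of shared memory
    //  3. the final axis pass writes the filtered texels to the output image (and, when DoCoverage is enabled,
    //     is meant to feed the alpha histogram used for coverage-preserving alpha scaling)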
    // need to clear our atomic coverage counter to 0
    const uint16_t coverageDWORD = _static_cast<uint16_t>(params.coverageDWORD);
    if (DoCoverage)
    {
        if (localInvocationIndex==0)
            sharedAccessor.set(coverageDWORD,0u);
        glsl::barrier();
    }

    const PatchLayout<Dims> preloadLayout = params.getPreloadMeta();
    // number of texels to preload, also the per-channel stride within scratch
    const uint16_t preloadCount = preloadLayout.getLinearEnd();
    for (uint16_t virtualInvocation=localInvocationIndex; virtualInvocation<preloadCount; virtualInvocation+=WorkGroupSize)
    {
        // if we make all args in snakeCurveInverse 16bit maybe the compiler will optimize the divisions into using float32_t
        const uint16_tN virtualInvocationID = preloadLayout.getID(virtualInvocation);
        const float32_tN inputTexCoordUnnorm = regionStartCoord + float32_tN(virtualInvocationID);
        const float32_tN inputTexCoord = (inputTexCoordUnnorm + promote<float32_tN>(0.5f)) * inImageSizeRcp;
        const float32_t4 loadedData = inCombinedSamplerAccessor.template get<float32_t,Dims>(inputTexCoord,layer,inLevel);

        if (DoCoverage)
        if (loadedData[coverageChannel]>=params.alphaRefValue &&
            all(inputTexCoordUnnorm<regionNextStartCoord) && // not overlapping with the next tile
            all(inputTexCoordUnnorm>=promote<float32_tN>(0.f)) && // within the image from below
            all(inputTexCoordUnnorm<=inputMaxCoord) // within the image from above
        )
        {
            // TODO: atomicIncr or a workgroup reduction of ballots?
            //sharedAccessor.template atomicIncr<uint32_t>(coverageDWORD);
        }

        [unroll(4)]
        for (uint16_t ch=0; ch<4 && ch<=lastChannel; ch++)
            sharedAccessor.template set<float32_t>(preloadCount*ch+virtualInvocation,loadedData[ch]);
    }
    glsl::barrier();
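
    // Scratch is laid out channel-major: channel `ch` of texel `i` lives at `ch*texelCount + i`
    // (relative to the active read/write offset), so gathers along a filter window stay
    // contiguous within each channel.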
    uint16_t readScratchOffset = uint16_t(0);
    uint16_t writeScratchOffset = _static_cast<uint16_t>(params.secondScratchOffDWORD);
    const uint16_tN windowExtent = params.template getWindowExtent<Dims>();
    PatchLayout<Dims> prevLayout = preloadLayout;
    uint32_t kernelWeightOffset = 0;
    [unroll(3)]
    for (int32_t axis=0; axis<Dims; axis++)
    {
        const PatchLayout<Dims> outputLayout = params.getPassMeta<Dims>(axis);
        const uint16_t invocationCount = outputLayout.getLinearEnd();
        const uint16_t phaseCount = params.getPhaseCount(axis);
        const uint16_t windowLength = windowExtent[axis];
        const uint16_t prevPassInvocationCount = prevLayout.getLinearEnd();
        for (uint16_t virtualInvocation=localInvocationIndex; virtualInvocation<invocationCount; virtualInvocation+=WorkGroupSize)
        {
            // this always maps to the index in the current pass' output
            const uint16_tN virtualInvocationID = outputLayout.getID(virtualInvocation);
            // we sweep along a line at a time, `[0]` is not a typo, look at the definition of `params.getPassMeta`
            uint16_t localOutputCoord = virtualInvocationID[0];
            // we can actually compute the output position of this line
            const uint16_t globalOutputCoord = localOutputCoord+minOutputTexel[axis];
            // hopefully the compiler will see that float32_t may be possible here due to `sizeof(float32_t mantissa)>sizeof(uint16_t)`
            const uint32_t windowPhase = globalOutputCoord % phaseCount;
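
            // Polyphase weights: outputs whose coordinate shares the same remainder modulo the
            // phase count reuse the same precomputed window of kernel weights, while
            // `kernelWeightOffset` skips past the weight tables of the previously processed axes.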
            //const int32_t windowStart = ceil(localOutputCoord+0.5f;
            // let us sweep
            float32_t4 accum = promote<float32_t4>(0.f);
            {
                uint32_t kernelWeightIndex = windowPhase*windowLength+kernelWeightOffset;
                // Need to use global coordinate because of ceil(x*scale) involvement
                uint16_tN tmp;
                tmp[0] = params.inputUpperBound(globalOutputCoord,axis)-regionStartCoord[axis];
                [unroll(2)]
                for (int32_t i=1; i<Dims; i++)
                    tmp[i] = virtualInvocationID[i];
                // initialize to the first gather texel in range of the window for the output
                uint16_t inputIndex = readScratchOffset+prevLayout.getIndex(tmp);
                for (uint16_t i=0; i<windowLength; i++,inputIndex++)
                {
                    const float32_t4 kernelWeight = kernelWeightsAccessor.get(kernelWeightIndex++);
                    [unroll(4)]
                    for (uint16_t ch=0; ch<4 && ch<=lastChannel; ch++)
                        accum[ch] += sharedAccessor.template get<float32_t>(ch*prevPassInvocationCount+inputIndex)*kernelWeight[ch];
                }
            }
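
            // Each pass reads from one scratch region and writes the other; the storage index
            // swizzles the output so the next axis gathers contiguously again, and the two
            // regions are swapped at the end of the axis loop below.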
            // now write outputs
            if (axis!=Dims-1) // not last pass
            {
                const uint32_t scratchOffset = writeScratchOffset+params.template getStorageIndex<Dims>(axis,virtualInvocationID);
                [unroll(4)]
                for (uint16_t ch=0; ch<4 && ch<=lastChannel; ch++)
                    sharedAccessor.template set(ch*invocationCount+scratchOffset,accum[ch]);
            }
            else
            {
                const uint16_tN coord = SPerWorkgroup::unswizzle<Dims>(virtualInvocationID)+minOutputTexel;
                outImageAccessor.template set<float32_t,Dims>(coord,layer,accum);
                if (DoCoverage)
                {
                    //const uint32_t bucketIndex = uint32_t(round(accum[coverageChannel] * float(ConstevalParameters::AlphaBinCount - 1)));
                    //histogramAccessor.atomicAdd(workGroupID.z,bucketIndex,uint32_t(1));
                    //intermediateAlphaImageAccessor.template set<float32_t,Dims>(coord,layer,accum);
                }
            }
        }
        glsl::barrier();
        kernelWeightOffset += phaseCount*windowLength;
        prevLayout = outputLayout;
        // TODO: use Przemog's `nbl::hlsl::swap` method when the float64 stuff gets merged
        const uint32_t tmp = readScratchOffset;
        readScratchOffset = writeScratchOffset;
        writeScratchOffset = tmp;
    }
}

}
}
}
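
// ----------------------------------------------------------------------------
// Illustrative sketch only, not part of this header: one possible shape of the
// accessors and entry point that `blit::execute` expects, inferred purely from
// the member functions called above. All binding slots, resource names, the
// shared memory size, the push-constant layout and the `NBL_BLIT_USAGE_SKETCH`
// guard itself are hypothetical placeholders; a real application provides its
// own versions of these in its .comp.hlsl.
// ----------------------------------------------------------------------------
#ifdef NBL_BLIT_USAGE_SKETCH
using namespace nbl::hlsl;

[[vk::binding(0,0)]] Texture2DArray<float32_t4> inTexture;
[[vk::binding(1,0)]] SamplerState inSampler;
[[vk::binding(2,0)]] RWTexture2DArray<float32_t4> outTexture;
[[vk::binding(3,0)]] Buffer<float32_t4> kernelWeights;

groupshared uint32_t sMem[4096]; // size must match what the host sized the passes for

// sketched for the 2D case only
struct InCombinedSamplerAccessor
{
    template<typename T, int32_t Dims>
    vector<T,4> get(const vector<float32_t,Dims> uv, const uint16_t layer, const uint16_t level)
    {
        return inTexture.SampleLevel(inSampler,float32_t3(uv,float32_t(layer)),float32_t(level));
    }
    template<int32_t Dims>
    vector<float32_t,Dims> extentRcp(const uint16_t level)
    {
        float32_t width, height, elements, levels;
        inTexture.GetDimensions(level,width,height,elements,levels);
        return float32_t2(1.f/width,1.f/height);
    }
};
struct OutImageAccessor
{
    template<typename T, int32_t Dims>
    void set(const vector<uint16_t,Dims> coord, const uint16_t layer, const vector<T,4> value)
    {
        outTexture[uint32_t3(uint32_t2(coord),layer)] = value;
    }
};
struct KernelWeightsAccessor
{
    float32_t4 get(const uint32_t ix) {return kernelWeights[ix];}
};
struct SharedAccessor
{
    // `blit::execute` only ever reads floats back from scratch, but writes floats and uints
    template<typename T>
    T get(const uint32_t ix) {return asfloat(sMem[ix]);}
    template<typename T>
    void set(const uint32_t ix, const T value) {sMem[ix] = asuint(value);}
};

// hypothetical: the real push-constant block and workgroup-to-tile mapping are up to the application
[[vk::push_constant]] SPerWorkgroup pc;

[numthreads(256,1,1)]
void main(const uint32_t3 groupID : SV_GroupID)
{
    InCombinedSamplerAccessor inAcc;
    OutImageAccessor outAcc;
    KernelWeightsAccessor weightAcc;
    SharedAccessor smemAcc;
    // 2D blit without coverage adaptation, array layer taken from the Z dispatch dimension
    const uint16_t layer = uint16_t(groupID.z);
    const vector<uint16_t,2> virtWG = vector<uint16_t,2>(groupID.xy);
    blit::execute<false,256,2>(inAcc,outAcc,weightAcc,smemAcc,pc,layer,virtWG);
}
#endif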