diff --git a/include/RenderGraph/RunnablePasses/ComputePass.hpp b/include/RenderGraph/RunnablePasses/ComputePass.hpp index 3da4c61..2aee291 100644 --- a/include/RenderGraph/RunnablePasses/ComputePass.hpp +++ b/include/RenderGraph/RunnablePasses/ComputePass.hpp @@ -10,6 +10,11 @@ namespace crg { namespace cp { + struct GroupCountT + { + }; + using GetGroupCountCallback = GetValueCallbackT< GroupCountT, uint32_t >; + template< template< typename ValueT > typename WrapperT > struct ConfigT { @@ -162,6 +167,33 @@ namespace crg } /** *\param[in] config + * The callback to retrieve the X dispatch groups count. + */ + auto & getGroupCountX( GetGroupCountCallback config ) + { + m_getGroupCountX = config; + return *this; + } + /** + *\param[in] config + * The callback to retrieve the Y dispatch groups count. + */ + auto & getGroupCountY( GetGroupCountCallback config ) + { + m_getGroupCountY = config; + return *this; + } + /** + *\param[in] config + * The callback to retrieve the Z dispatch groups count. + */ + auto & getGroupCountZ( GetGroupCountCallback config ) + { + m_getGroupCountZ = config; + return *this; + } + /** + *\param[in] config * The buffer used during indirect compute. */ auto & indirectBuffer( IndirectBuffer config ) @@ -198,6 +230,9 @@ namespace crg WrapperT< uint32_t > m_groupCountX{}; WrapperT< uint32_t > m_groupCountY{}; WrapperT< uint32_t > m_groupCountZ{}; + WrapperT< GetGroupCountCallback > m_getGroupCountX{}; + WrapperT< GetGroupCountCallback > m_getGroupCountY{}; + WrapperT< GetGroupCountCallback > m_getGroupCountZ{}; WrapperT< IndirectBuffer > m_indirectBuffer{}; }; @@ -213,6 +248,9 @@ namespace crg RawTypeT< uint32_t > groupCountX{ 1u }; RawTypeT< uint32_t > groupCountY{ 1u }; RawTypeT< uint32_t > groupCountZ{ 1u }; + std::optional< GetGroupCountCallback > getGroupCountX{}; + std::optional< GetGroupCountCallback > getGroupCountY{}; + std::optional< GetGroupCountCallback > getGroupCountZ{}; RawTypeT< IndirectBuffer > indirectBuffer{ defaultV< IndirectBuffer > }; }; @@ -220,6 +258,19 @@ namespace crg using ConfigData = ConfigT< RawTypeT >; } + template<> + struct DefaultValueGetterT < cp::GetGroupCountCallback > + { + static cp::GetGroupCountCallback get() + { + cp::GetGroupCountCallback const result{ []() + { + return 0u; + } }; + return result; + } + }; + template<> struct DefaultValueGetterT< cp::Config > { diff --git a/source/RenderGraph/RunnablePasses/ComputePass.cpp b/source/RenderGraph/RunnablePasses/ComputePass.cpp index 099cb71..1f3a280 100644 --- a/source/RenderGraph/RunnablePasses/ComputePass.cpp +++ b/source/RenderGraph/RunnablePasses/ComputePass.cpp @@ -35,6 +35,9 @@ namespace crg , cpConfig.m_groupCountX ? std::move( *cpConfig.m_groupCountX ) : 1u , cpConfig.m_groupCountY ? std::move( *cpConfig.m_groupCountY ) : 1u , cpConfig.m_groupCountZ ? std::move( *cpConfig.m_groupCountZ ) : 1u + , cpConfig.m_getGroupCountX ? std::optional< cp::GetGroupCountCallback >( std::move( *cpConfig.m_getGroupCountX ) ) : std::nullopt + , cpConfig.m_getGroupCountY ? std::optional< cp::GetGroupCountCallback >( std::move( *cpConfig.m_getGroupCountY ) ) : std::nullopt + , cpConfig.m_getGroupCountZ ? std::optional< cp::GetGroupCountCallback >( std::move( *cpConfig.m_getGroupCountZ ) ) : std::nullopt , cpConfig.m_indirectBuffer ? *cpConfig.m_indirectBuffer : getDefaultV < IndirectBuffer >() } , m_pipeline{ pass , context @@ -93,7 +96,10 @@ namespace crg } else { - m_context.vkCmdDispatch( commandBuffer, m_cpConfig.groupCountX, m_cpConfig.groupCountY, m_cpConfig.groupCountZ ); + m_context.vkCmdDispatch( commandBuffer + , ( m_cpConfig.getGroupCountX ? ( *m_cpConfig.getGroupCountX )() : m_cpConfig.groupCountX ) + , ( m_cpConfig.getGroupCountY ? ( *m_cpConfig.getGroupCountX )() : m_cpConfig.groupCountY ) + , ( m_cpConfig.getGroupCountZ ? ( *m_cpConfig.getGroupCountX )() : m_cpConfig.groupCountZ ) ); } m_cpConfig.end( context, commandBuffer, index );