diff --git a/include/alp/imf.hpp b/include/alp/imf.hpp index ad602c035..98f0ee911 100644 --- a/include/alp/imf.hpp +++ b/include/alp/imf.hpp @@ -244,121 +244,100 @@ namespace alp { }; + namespace internal { + + /** + * Ensures that the range of the right IMF matches the domain of the left. + * If the condition is not satisfied, throws an exception + * + * @tparam LeftImf The type of the left IMF + * @tparam RightImf The type of the right IMF + * + * @param[in] left_imf The left IMF + * @param[in] right_imf The right IMF + * + */ + template< typename LeftImf, typename RightImf > + static void ensure_imfs_match( const LeftImf &left_imf, const RightImf &right_imf ) { + if( !( right_imf.N == left_imf.n ) ) { + throw std::runtime_error( "Cannot compose two IMFs with non-matching range and domain" ); + } + } + + } // namespace internal + + /** + * Exposes the type and creates the composed IMF from two provided input IMFs. + * + * For certain combinations of IMFs, the resulting composed IMF is + * one of the fundamental types. In these cases, the factory is + * specialized to produce the appropriate type and object. + */ template< typename LeftImf, typename RightImf > - struct composed_type { + struct ComposedFactory { + typedef Composed< LeftImf, RightImf > type; - }; - template<> - struct composed_type< Strided, Strided > { - typedef Strided type; + static type create( const LeftImf &f, const RightImf &g ) { + internal::ensure_imfs_match( f, g ); + return type( f, g ); + } }; template< typename RightImf > - struct composed_type< Id, RightImf > { + struct ComposedFactory< Id, RightImf > { + typedef RightImf type; + + static type create( const Id &f, const RightImf &g ) { + internal::ensure_imfs_match( f, g ); + return RightImf( g ); + } }; template< typename LeftImf > - struct composed_type< LeftImf, Id > { + struct ComposedFactory< LeftImf, Id > { + typedef LeftImf type; - }; - template<> - struct composed_type< Id, Id > { - typedef Id type; + static type create( const LeftImf &f, const Id &g ) { + internal::ensure_imfs_match( f, g ); + return LeftImf( f ); + } }; template<> - struct composed_type< Zero, Id > { - typedef Zero type; - }; + struct ComposedFactory< Id, Id > { - template<> - struct composed_type< Id, Constant > { - typedef Constant type; - }; + typedef Id type; - template<> - struct composed_type< Strided, Constant > { - typedef Constant type; + static type create( const Id &f, const Id &g ) { + internal::ensure_imfs_match( f, g ); + return type( f.n ); + } }; template<> - struct composed_type< Id, Zero > { - typedef Zero type; - }; + struct ComposedFactory< Strided, Strided >{ - /** - * Creates the composed IMF from two provided input IMFs. - * Depending on the input IMF types, the factory may - * specialize the returned IMF type. - */ - - struct ComposedFactory { + typedef Strided type; - template< typename LeftImf, typename RightImf > - static typename composed_type< LeftImf, RightImf >::type create( const LeftImf &left_imf, const RightImf &right_imf ) { - return typename composed_type< LeftImf, RightImf >::type( left_imf, right_imf ); + static type create( const Strided &f, const Strided &g ) { + internal::ensure_imfs_match( f, g ); + return type( g.n, f.N, f.s * g.b + f.b, f.s * g.s ); } - }; template<> - Strided ComposedFactory::create( const Id &f, const Strided &g ) { - return Strided( g.n, f.N, g.b, g.s ); - } - - template<> - Strided ComposedFactory::create( const Strided &f, const Strided &g ) { - return Strided( g.n, f.N, f.s * g.b + f.b, f.s * g.s ); - } - - template<> - Strided ComposedFactory::create( const Strided &f, const Id &g ) { - return Strided( g.n, f.N, f.b, f.s ); - } + struct ComposedFactory< Strided, Constant > { - /** Composition of two Id IMFs is an Id Imf */ - template<> - Id ComposedFactory::create( const Id &f, const Id &g ) { -#ifdef NDEBUG - (void)f; -#endif - // The first function's co-domain must be equal to the second function's domain. - assert( g.N == f.n ); - return Id( g.n ); - } - - template<> - Constant ComposedFactory::create( const Id &f, const Constant &g ) { - (void)f; - return Constant( g.n, f.N, g.b ); - } - - template<> - Constant ComposedFactory::create( const Strided &f, const Constant &g ) { - (void)f; - return Constant( g.n, f.N, f.b + f.s * g.b ); - } - - template<> - Zero ComposedFactory::create( const Id &f, const Zero &g ) { - (void)f; - return Zero( g.n ); - } - - template<> - Zero ComposedFactory::create( const Zero &f, const Id &g ) { - (void)f; - return Zero( g.n ); - } + typedef Constant type; - template<> - typename composed_type< Id, Select >::type ComposedFactory::create( const Id &f1, const Select &f2 ) { - (void) f1; - return f2; - } + static type create( const Strided &f, const Constant &g ) { + internal::ensure_imfs_match( f, g ); + return type( g.n, f.N, f.b + f.s * g.b ); + } + }; }; // namespace imf diff --git a/include/alp/storage.hpp b/include/alp/storage.hpp index 218669536..c957d971b 100644 --- a/include/alp/storage.hpp +++ b/include/alp/storage.hpp @@ -785,8 +785,8 @@ namespace alp { typedef typename SourceAMF::mapping_polynomial_type SourcePoly; /** Compose row and column IMFs */ - typedef typename imf::composed_type< SourceImfR, ViewImfR >::type composed_imf_r_type; - typedef typename imf::composed_type< SourceImfC, ViewImfC >::type composed_imf_c_type; + typedef typename imf::ComposedFactory< SourceImfR, ViewImfR >::type composed_imf_r_type; + typedef typename imf::ComposedFactory< SourceImfC, ViewImfC >::type composed_imf_c_type; /** Fuse composed row IMF into the target polynomial */ typedef typename polynomials::fuse_on_i< @@ -809,8 +809,8 @@ namespace alp { typedef AMF< final_imf_r_type, final_imf_c_type, final_polynomial_type > amf_type; static amf_type Create( ViewImfR imf_r, ViewImfC imf_c, const AMF< SourceImfR, SourceImfC, SourcePoly > &amf ) { - composed_imf_r_type composed_imf_r { imf::ComposedFactory::create( amf.imf_r, imf_r ) }; - composed_imf_c_type composed_imf_c { imf::ComposedFactory::create( amf.imf_c, imf_c ) }; + composed_imf_r_type composed_imf_r = imf::ComposedFactory< SourceImfR, ViewImfR >::create( amf.imf_r, imf_r ); + composed_imf_c_type composed_imf_c = imf::ComposedFactory< SourceImfC, ViewImfC >::create( amf.imf_c, imf_c ); return amf_type( fused_row::CreateImf( composed_imf_r ), fused_row_col::CreateImf( composed_imf_c ), @@ -1005,7 +1005,7 @@ namespace alp { assert( amf.getLogicalDimensions().first == amf.getLogicalDimensions().second ); return amf_type( imf::Id( amf.getLogicalDimensions().first ), - imf::Zero( amf.getLogicalDimensions().second ), + imf::Zero( 1 ), new_poly_type( orig_p::Ax2 * amf.map_poly.ax2 + orig_p::Ay2 * amf.map_poly.ay2 + orig_p::Axy * amf.map_poly.axy, 0, 0, orig_p::Ax * amf.map_poly.ax + orig_p::Ay * amf.map_poly.ay, 0, diff --git a/tests/unit/alp_dynamic_views.cpp b/tests/unit/alp_dynamic_views.cpp index 90d2b6d16..e12abedfe 100644 --- a/tests/unit/alp_dynamic_views.cpp +++ b/tests/unit/alp_dynamic_views.cpp @@ -106,6 +106,17 @@ void alpProgram( const size_t &n, alp::RC &rc ) { auto v_view = alp::get_view< alp::structures::General >( v, sel_r ); print_vector( "v_view", v_view ); + // select view over select view + std::vector< size_t > sel2_v_data{ 2, 0, 1 }; + alp::Vector< size_t > sel2_v( sel2_v_data.size() ); + alp::buildVector( sel2_v, sel2_v_data.begin(), sel2_v_data.end() ); + auto v_view_2 = alp::get_view< alp::structures::General >( v_view, sel2_v ); + print_vector( "v_view_2", v_view_2 ); + + // matrix view over select x select view + auto v_mat = alp::get_view< alp::view::matrix >( v_view_2 ); + print_matrix( "v_mat", v_mat ); + rc = alp::SUCCESS; }