Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 66 additions & 87 deletions include/alp/imf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,121 +244,100 @@ namespace alp {

};

namespace internal {

/**
* Ensures that the range of the right IMF matches the domain of the left.
* If the condition is not satisfied, throws an exception
*
* @tparam LeftImf The type of the left IMF
* @tparam RightImf The type of the right IMF
*
* @param[in] left_imf The left IMF
* @param[in] right_imf The right IMF
*
*/
template< typename LeftImf, typename RightImf >
static void ensure_imfs_match( const LeftImf &left_imf, const RightImf &right_imf ) {
if( !( right_imf.N == left_imf.n ) ) {
throw std::runtime_error( "Cannot compose two IMFs with non-matching range and domain" );
}
}

} // namespace internal

/**
* Exposes the type and creates the composed IMF from two provided input IMFs.
*
* For certain combinations of IMFs, the resulting composed IMF is
* one of the fundamental types. In these cases, the factory is
* specialized to produce the appropriate type and object.
*/
template< typename LeftImf, typename RightImf >
struct composed_type {
struct ComposedFactory {

typedef Composed< LeftImf, RightImf > type;
};

template<>
struct composed_type< Strided, Strided > {
typedef Strided type;
static type create( const LeftImf &f, const RightImf &g ) {
internal::ensure_imfs_match( f, g );
return type( f, g );
}
};

template< typename RightImf >
struct composed_type< Id, RightImf > {
struct ComposedFactory< Id, RightImf > {

typedef RightImf type;

static type create( const Id &f, const RightImf &g ) {
internal::ensure_imfs_match( f, g );
return RightImf( g );
}
};

template< typename LeftImf >
struct composed_type< LeftImf, Id > {
struct ComposedFactory< LeftImf, Id > {

typedef LeftImf type;
};

template<>
struct composed_type< Id, Id > {
typedef Id type;
static type create( const LeftImf &f, const Id &g ) {
internal::ensure_imfs_match( f, g );
return LeftImf( f );
}
};

template<>
struct composed_type< Zero, Id > {
typedef Zero type;
};
struct ComposedFactory< Id, Id > {

template<>
struct composed_type< Id, Constant > {
typedef Constant type;
};
typedef Id type;

template<>
struct composed_type< Strided, Constant > {
typedef Constant type;
static type create( const Id &f, const Id &g ) {
internal::ensure_imfs_match( f, g );
return type( f.n );
}
};

template<>
struct composed_type< Id, Zero > {
typedef Zero type;
};
struct ComposedFactory< Strided, Strided >{

/**
* Creates the composed IMF from two provided input IMFs.
* Depending on the input IMF types, the factory may
* specialize the returned IMF type.
*/

struct ComposedFactory {
typedef Strided type;

template< typename LeftImf, typename RightImf >
static typename composed_type< LeftImf, RightImf >::type create( const LeftImf &left_imf, const RightImf &right_imf ) {
return typename composed_type< LeftImf, RightImf >::type( left_imf, right_imf );
static type create( const Strided &f, const Strided &g ) {
internal::ensure_imfs_match( f, g );
return type( g.n, f.N, f.s * g.b + f.b, f.s * g.s );
}

};

template<>
Strided ComposedFactory::create( const Id &f, const Strided &g ) {
return Strided( g.n, f.N, g.b, g.s );
}

template<>
Strided ComposedFactory::create( const Strided &f, const Strided &g ) {
return Strided( g.n, f.N, f.s * g.b + f.b, f.s * g.s );
}

template<>
Strided ComposedFactory::create( const Strided &f, const Id &g ) {
return Strided( g.n, f.N, f.b, f.s );
}
struct ComposedFactory< Strided, Constant > {

/** Composition of two Id IMFs is an Id Imf */
template<>
Id ComposedFactory::create( const Id &f, const Id &g ) {
#ifdef NDEBUG
(void)f;
#endif
// The first function's co-domain must be equal to the second function's domain.
assert( g.N == f.n );
return Id( g.n );
}

template<>
Constant ComposedFactory::create( const Id &f, const Constant &g ) {
(void)f;
return Constant( g.n, f.N, g.b );
}

template<>
Constant ComposedFactory::create( const Strided &f, const Constant &g ) {
(void)f;
return Constant( g.n, f.N, f.b + f.s * g.b );
}

template<>
Zero ComposedFactory::create( const Id &f, const Zero &g ) {
(void)f;
return Zero( g.n );
}

template<>
Zero ComposedFactory::create( const Zero &f, const Id &g ) {
(void)f;
return Zero( g.n );
}
typedef Constant type;

template<>
typename composed_type< Id, Select >::type ComposedFactory::create( const Id &f1, const Select &f2 ) {
(void) f1;
return f2;
}
static type create( const Strided &f, const Constant &g ) {
internal::ensure_imfs_match( f, g );
return type( g.n, f.N, f.b + f.s * g.b );
}
};

}; // namespace imf

Expand Down
10 changes: 5 additions & 5 deletions include/alp/storage.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -785,8 +785,8 @@ namespace alp {
typedef typename SourceAMF::mapping_polynomial_type SourcePoly;

/** Compose row and column IMFs */
typedef typename imf::composed_type< SourceImfR, ViewImfR >::type composed_imf_r_type;
typedef typename imf::composed_type< SourceImfC, ViewImfC >::type composed_imf_c_type;
typedef typename imf::ComposedFactory< SourceImfR, ViewImfR >::type composed_imf_r_type;
typedef typename imf::ComposedFactory< SourceImfC, ViewImfC >::type composed_imf_c_type;

/** Fuse composed row IMF into the target polynomial */
typedef typename polynomials::fuse_on_i<
Expand All @@ -809,8 +809,8 @@ namespace alp {
typedef AMF< final_imf_r_type, final_imf_c_type, final_polynomial_type > amf_type;

static amf_type Create( ViewImfR imf_r, ViewImfC imf_c, const AMF< SourceImfR, SourceImfC, SourcePoly > &amf ) {
composed_imf_r_type composed_imf_r { imf::ComposedFactory::create( amf.imf_r, imf_r ) };
composed_imf_c_type composed_imf_c { imf::ComposedFactory::create( amf.imf_c, imf_c ) };
composed_imf_r_type composed_imf_r = imf::ComposedFactory< SourceImfR, ViewImfR >::create( amf.imf_r, imf_r );
composed_imf_c_type composed_imf_c = imf::ComposedFactory< SourceImfC, ViewImfC >::create( amf.imf_c, imf_c );
return amf_type(
fused_row::CreateImf( composed_imf_r ),
fused_row_col::CreateImf( composed_imf_c ),
Expand Down Expand Up @@ -1005,7 +1005,7 @@ namespace alp {
assert( amf.getLogicalDimensions().first == amf.getLogicalDimensions().second );
return amf_type(
imf::Id( amf.getLogicalDimensions().first ),
imf::Zero( amf.getLogicalDimensions().second ),
imf::Zero( 1 ),
new_poly_type(
orig_p::Ax2 * amf.map_poly.ax2 + orig_p::Ay2 * amf.map_poly.ay2 + orig_p::Axy * amf.map_poly.axy, 0, 0,
orig_p::Ax * amf.map_poly.ax + orig_p::Ay * amf.map_poly.ay, 0,
Expand Down
11 changes: 11 additions & 0 deletions tests/unit/alp_dynamic_views.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,17 @@ void alpProgram( const size_t &n, alp::RC &rc ) {
auto v_view = alp::get_view< alp::structures::General >( v, sel_r );
print_vector( "v_view", v_view );

// select view over select view
std::vector< size_t > sel2_v_data{ 2, 0, 1 };
alp::Vector< size_t > sel2_v( sel2_v_data.size() );
alp::buildVector( sel2_v, sel2_v_data.begin(), sel2_v_data.end() );
auto v_view_2 = alp::get_view< alp::structures::General >( v_view, sel2_v );
print_vector( "v_view_2", v_view_2 );

// matrix view over select x select view
auto v_mat = alp::get_view< alp::view::matrix >( v_view_2 );
print_matrix( "v_mat", v_mat );

rc = alp::SUCCESS;

}
Expand Down