Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions config/template/kernels/3/bli_gemm_template_noopt_mxn.c
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@

void bli_zgemm_template_noopt
(
dim_t m,
dim_t n,
dim_t k,
dcomplex* restrict alpha,
dcomplex* restrict a1,
Expand Down Expand Up @@ -88,8 +90,7 @@ void bli_zgemm_template_noopt

dim_t l, j, i;

dcomplex ab[ bli_zmr *
bli_znr ];
dcomplex ab[ mr * nr ];
dcomplex* abij;
dcomplex ai, bj;

Expand Down Expand Up @@ -137,16 +138,16 @@ void bli_zgemm_template_noopt
if ( bli_zeq0( *beta ) )
{
/* c11 := ab */
bli_zcopys_mxn( mr,
nr,
bli_zcopys_mxn( m,
n,
ab, rs_ab, cs_ab,
c11, rs_c, cs_c );
}
else
{
/* c11 := beta * c11 + ab */
bli_zxpbys_mxn( mr,
nr,
bli_zxpbys_mxn( m,
n,
ab, rs_ab, cs_ab,
beta,
c11, rs_c, cs_c );
Expand Down
4 changes: 4 additions & 0 deletions config/template/kernels/3/bli_gemmtrsm_l_template_noopt_mxn.c
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ void bli_zgemmtrsm_l_template_noopt
*/
const num_t dt = BLIS_DCOMPLEX;

const inc_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx );
const inc_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx );
const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx );

const inc_t rs_b = packnr;
Expand All @@ -84,6 +86,8 @@ void bli_zgemmtrsm_l_template_noopt
/* b11 = alpha * b11 - a10 * b01; */
bli_zgemm_template_noopt
(
mr,
nr,
k,
minus_one,
a10,
Expand Down
8 changes: 6 additions & 2 deletions config/template/kernels/3/bli_gemmtrsm_u_template_noopt_mxn.c
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ void bli_zgemmtrsm_u_template_noopt
*/
const num_t dt = BLIS_DCOMPLEX;

const inc_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx );
const inc_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx );
const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx );

const inc_t rs_b = packnr;
Expand All @@ -84,10 +86,12 @@ void bli_zgemmtrsm_u_template_noopt
/* b11 = alpha * b11 - a12 * b21; */
bli_zgemm_template_noopt
(
mr,
nr,
k,
minus_one,
a12,
b21,
a10,
b01,
alpha,
b11, rs_b, cs_b,
data
Expand Down
58 changes: 39 additions & 19 deletions frame/1m/packm/bli_packm_alloc.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,35 @@
#include "blis.h"

// Convenience wrapper around bli_packm_alloc_ex(): looks up the pack buffer
// type recorded in the control tree node and forwards size_needed, rntm,
// cntl, and thread unchanged. The return value is whatever
// bli_packm_alloc_ex() returns.
void* bli_packm_alloc
(
siz_t size_needed,
rntm_t* rntm,
cntl_t* cntl,
thrinfo_t* thread
)
(
siz_t size_needed,
rntm_t* rntm,
cntl_t* cntl,
thrinfo_t* thread
)
{
// Query the pack buffer type from the control tree node.
packbuf_t pack_buf_type = bli_cntl_packm_params_pack_buf_type( cntl );

// Delegate to the extended variant, which takes the buffer type as an
// explicit argument instead of reading it from the control tree node.
return bli_packm_alloc_ex
(
size_needed,
pack_buf_type,
rntm,
cntl,
thread
);
}

void* bli_packm_alloc_ex
(
siz_t size_needed,
packbuf_t pack_buf_type,
rntm_t* rntm,
cntl_t* cntl,
thrinfo_t* thread
)
{
// Query the address of the mem_t entry within the control tree node.
mem_t* cntl_mem_p = bli_cntl_pack_mem( cntl );

Expand All @@ -55,7 +74,7 @@ void* bli_packm_alloc
siz_t cntl_mem_size = 0;

if ( bli_mem_is_alloc( cntl_mem_p ) )
cntl_mem_size = bli_mem_size( cntl_mem_p );
cntl_mem_size = bli_mem_size( cntl_mem_p );

if ( cntl_mem_size < size_needed )
{
Expand All @@ -64,14 +83,15 @@ void* bli_packm_alloc
// The chief thread releases the existing block associated with
// the mem_t entry in the control tree, and then re-acquires a
// new block, saving the associated mem_t entry to local_mem_s.
if ( bli_mem_is_alloc( cntl_mem_p ) )
{
bli_pba_release
(
rntm,
cntl_mem_p
);
}
if ( bli_mem_is_alloc( cntl_mem_p ) )
{
bli_pba_release
(
rntm,
cntl_mem_p
);
}

bli_pba_acquire_m
(
rntm,
Expand All @@ -89,11 +109,11 @@ void* bli_packm_alloc
// this thread's control tree node.
*cntl_mem_p = *local_mem_p;

// Barrier so that the master thread doesn't return from the function
// before we are done reading.
bli_thread_barrier( thread );
// Barrier so that the master thread doesn't return from the function
// before we are done reading.
bli_thread_barrier( thread );
}

return bli_mem_buffer( cntl_mem_p );
return bli_mem_buffer( cntl_mem_p );
}

23 changes: 16 additions & 7 deletions frame/1m/packm/bli_packm_alloc.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,20 @@

*/

// Exported prototype for bli_packm_alloc(). It takes a requested size, a
// runtime object, a control tree node, and the calling thread's info, and
// returns an untyped pointer.
// NOTE(review): presumably the returned pointer is the packing buffer
// associated with cntl -- confirm against the definition.
BLIS_EXPORT_BLIS void* bli_packm_alloc
(
siz_t size_needed,
rntm_t* rntm,
cntl_t* cntl,
thrinfo_t* thread
);
BLIS_EXPORT_BLIS void* bli_packm_alloc
(
siz_t size_needed,
rntm_t* rntm,
cntl_t* cntl,
thrinfo_t* thread
);

// Exported prototype for bli_packm_alloc_ex(). It has the same parameters
// as bli_packm_alloc() plus an explicit pack_buf_type, supplied by the
// caller as the second argument.
BLIS_EXPORT_BLIS void* bli_packm_alloc_ex
(
siz_t size_needed,
packbuf_t pack_buf_type,
rntm_t* rntm,
cntl_t* cntl,
thrinfo_t* thread
);

18 changes: 16 additions & 2 deletions frame/3/bli_l3_cntl.c
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,14 @@ void bli_l3_cntl_create_if
family == BLIS_GEMMT ||
family == BLIS_TRMM )
{
*cntl_use = bli_gemm_cntl_create( rntm, family, schema_a, schema_b );
*cntl_use = bli_gemm_cntl_create
(
rntm,
family,
schema_a,
schema_b,
bli_obj_ker_fn( c )
);
}
else // if ( family == BLIS_TRSM )
{
Expand All @@ -66,7 +73,14 @@ void bli_l3_cntl_create_if
if ( bli_obj_is_triangular( a ) ) side = BLIS_LEFT;
else side = BLIS_RIGHT;

*cntl_use = bli_trsm_cntl_create( rntm, side, schema_a, schema_b );
*cntl_use = bli_trsm_cntl_create
(
rntm,
side,
schema_a,
schema_b,
bli_obj_ker_fn( c )
);
}
}
else
Expand Down
2 changes: 2 additions & 0 deletions frame/3/bli_l3_ft_ukr.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
\
typedef void (*PASTECH3(ch,opname,_ukr,tsuf)) \
( \
dim_t m, \
dim_t n, \
dim_t k, \
ctype* restrict alpha, \
ctype* restrict a, \
Expand Down
4 changes: 4 additions & 0 deletions frame/3/bli_l3_ukr_oapi.c
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ void PASTEMAC0(opname) \
\
num_t dt = bli_obj_dt( c ); \
\
dim_t m = bli_obj_length( c ); \
dim_t n = bli_obj_width( c ); \
dim_t k = bli_obj_width( a ); \
void* buf_a = bli_obj_buffer_at_off( a ); \
void* buf_b = bli_obj_buffer_at_off( b ); \
Expand All @@ -75,6 +77,8 @@ void PASTEMAC0(opname) \
\
f \
( \
m, \
n, \
k, \
buf_alpha, \
buf_a, \
Expand Down
2 changes: 2 additions & 0 deletions frame/3/bli_l3_ukr_prot.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
\
void PASTEMAC(ch,opname) \
( \
dim_t m, \
dim_t n, \
dim_t k, \
ctype_out* restrict alpha, \
ctype_in* restrict a, \
Expand Down
63 changes: 35 additions & 28 deletions frame/3/bli_l3_ukr_tapi.c
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
\
void PASTEMAC(ch,opname) \
( \
dim_t m, \
dim_t n, \
dim_t k, \
ctype* restrict alpha, \
ctype* restrict a, \
Expand All @@ -58,16 +60,19 @@ void PASTEMAC(ch,opname) \
PASTECH2(ch,tname,_ukr_ft) f = bli_cntx_get_l3_vir_ukr_dt( dt, kerid, cntx ); \
\
/* Invoke the typed function for the given datatype. */ \
f( \
k, \
alpha, \
a, \
b, \
beta, \
c, rs_c, cs_c, \
data, \
cntx \
); \
f \
( \
m, \
n, \
k, \
alpha, \
a, \
b, \
beta, \
c, rs_c, cs_c, \
data, \
cntx \
); \
} \

INSERT_GENTFUNC_BASIC2( gemm_ukernel, gemm, BLIS_GEMM_UKR )
Expand Down Expand Up @@ -98,17 +103,18 @@ void PASTEMAC(ch,opname) \
PASTECH2(ch,tname,_ukr_ft) f = bli_cntx_get_l3_vir_ukr_dt( dt, kerid, cntx ); \
\
/* Invoke the typed function for the given datatype. */ \
f( \
k, \
alpha, \
a1x, \
a11, \
bx1, \
b11, \
c11, rs_c, cs_c, \
data, \
cntx \
); \
f \
( \
k, \
alpha, \
a1x, \
a11, \
bx1, \
b11, \
c11, rs_c, cs_c, \
data, \
cntx \
); \
} \

INSERT_GENTFUNC_BASIC2( gemmtrsm_l_ukernel, gemmtrsm, BLIS_GEMMTRSM_L_UKR )
Expand Down Expand Up @@ -136,13 +142,14 @@ void PASTEMAC(ch,opname) \
PASTECH2(ch,tname,_ukr_ft) f = bli_cntx_get_l3_vir_ukr_dt( dt, kerid, cntx ); \
\
/* Invoke the typed function for the given datatype. */ \
f( \
a, \
b, \
c, rs_c, cs_c, \
data, \
cntx \
); \
f \
( \
a, \
b, \
c, rs_c, cs_c, \
data, \
cntx \
); \
} \

INSERT_GENTFUNC_BASIC2( trsm_l_ukernel, trsm, BLIS_TRSM_L_UKR )
Expand Down
15 changes: 10 additions & 5 deletions frame/3/gemm/bli_gemm_cntl.c
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,11 @@ cntl_t* bli_gemm_cntl_create
rntm_t* rntm,
opid_t family,
pack_t schema_a,
pack_t schema_b
pack_t schema_b,
void_fp ker
)
{
return bli_gemmbp_cntl_create( rntm, family, schema_a, schema_b );
return bli_gemmbp_cntl_create( rntm, family, schema_a, schema_b, ker );
}

// -----------------------------------------------------------------------------
Expand All @@ -53,18 +54,22 @@ cntl_t* bli_gemmbp_cntl_create
rntm_t* rntm,
opid_t family,
pack_t schema_a,
pack_t schema_b
pack_t schema_b,
void_fp ker
)
{
void_fp macro_kernel_fp;

// Use the function pointers to the macrokernels that use slab
// assignment of micropanels to threads in the jr and ir loops.
// Choose the default macrokernel based on the operation family...
if ( family == BLIS_GEMM ) macro_kernel_fp = bli_gemm_ker_var2;
else if ( family == BLIS_GEMMT ) macro_kernel_fp = bli_gemmt_x_ker_var2;
else if ( family == BLIS_TRMM ) macro_kernel_fp = bli_trmm_xx_ker_var2;
else /* should never execute */ macro_kernel_fp = NULL;

// ...unless a non-NULL kernel function pointer is passed in, in which
// case we use that instead.
if ( ker ) macro_kernel_fp = ker;

// Create two nodes for the macro-kernel.
cntl_t* gemm_cntl_bu_ke = bli_gemm_cntl_create_node
(
Expand Down
Loading