diff --git a/config/template/kernels/3/bli_gemm_template_noopt_mxn.c b/config/template/kernels/3/bli_gemm_template_noopt_mxn.c index b7a13f3b69..06f25a0e9e 100644 --- a/config/template/kernels/3/bli_gemm_template_noopt_mxn.c +++ b/config/template/kernels/3/bli_gemm_template_noopt_mxn.c @@ -37,6 +37,8 @@ void bli_zgemm_template_noopt ( + dim_t m, + dim_t n, dim_t k, dcomplex* restrict alpha, dcomplex* restrict a1, @@ -88,8 +90,7 @@ void bli_zgemm_template_noopt dim_t l, j, i; - dcomplex ab[ bli_zmr * - bli_znr ]; + dcomplex ab[ mr * nr ]; dcomplex* abij; dcomplex ai, bj; @@ -137,16 +138,16 @@ void bli_zgemm_template_noopt if ( bli_zeq0( *beta ) ) { /* c11 := ab */ - bli_zcopys_mxn( mr, - nr, + bli_zcopys_mxn( m, + n, ab, rs_ab, cs_ab, c11, rs_c, cs_c ); } else { /* c11 := beta * c11 + ab */ - bli_zxpbys_mxn( mr, - nr, + bli_zxpbys_mxn( m, + n, ab, rs_ab, cs_ab, beta, c11, rs_c, cs_c ); diff --git a/config/template/kernels/3/bli_gemmtrsm_l_template_noopt_mxn.c b/config/template/kernels/3/bli_gemmtrsm_l_template_noopt_mxn.c index da0cd3110f..87c21f7edf 100644 --- a/config/template/kernels/3/bli_gemmtrsm_l_template_noopt_mxn.c +++ b/config/template/kernels/3/bli_gemmtrsm_l_template_noopt_mxn.c @@ -74,6 +74,8 @@ void bli_zgemmtrsm_l_template_noopt */ const num_t dt = BLIS_DCOMPLEX; + const inc_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); + const inc_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); const inc_t rs_b = packnr; @@ -84,6 +86,8 @@ void bli_zgemmtrsm_l_template_noopt /* b11 = alpha * b11 - a10 * b01; */ bli_zgemm_template_noopt ( + mr, + nr, k, minus_one, a10, diff --git a/config/template/kernels/3/bli_gemmtrsm_u_template_noopt_mxn.c b/config/template/kernels/3/bli_gemmtrsm_u_template_noopt_mxn.c index 09b3af9cee..0b4544ae1d 100644 --- a/config/template/kernels/3/bli_gemmtrsm_u_template_noopt_mxn.c +++ b/config/template/kernels/3/bli_gemmtrsm_u_template_noopt_mxn.c @@ -74,6 +74,8 @@ void bli_zgemmtrsm_u_template_noopt */ const num_t dt = BLIS_DCOMPLEX; + const inc_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); + const inc_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); const inc_t rs_b = packnr; @@ -84,10 +86,12 @@ void bli_zgemmtrsm_u_template_noopt /* b11 = alpha * b11 - a12 * b21; */ bli_zgemm_template_noopt ( + mr, + nr, k, minus_one, a12, b21, alpha, b11, rs_b, cs_b, data diff --git a/frame/1m/packm/bli_packm_alloc.c b/frame/1m/packm/bli_packm_alloc.c index df6750d7ac..b12a93ddc0 100644 --- a/frame/1m/packm/bli_packm_alloc.c +++ b/frame/1m/packm/bli_packm_alloc.c @@ -36,16 +36,35 @@ #include "blis.h" void* bli_packm_alloc - ( - siz_t size_needed, - rntm_t* rntm, - cntl_t* cntl, - thrinfo_t* thread - ) + ( + siz_t size_needed, + rntm_t* rntm, + cntl_t* cntl, + thrinfo_t* thread + ) { // Query the pack buffer type from the control tree node. packbuf_t pack_buf_type = bli_cntl_packm_params_pack_buf_type( cntl ); + return bli_packm_alloc_ex + ( + size_needed, + pack_buf_type, + rntm, + cntl, + thread + ); +} + +void* bli_packm_alloc_ex + ( + siz_t size_needed, + packbuf_t pack_buf_type, + rntm_t* rntm, + cntl_t* cntl, + thrinfo_t* thread + ) +{ // Query the address of the mem_t entry within the control tree node.
mem_t* cntl_mem_p = bli_cntl_pack_mem( cntl ); @@ -55,7 +74,7 @@ void* bli_packm_alloc siz_t cntl_mem_size = 0; if ( bli_mem_is_alloc( cntl_mem_p ) ) - cntl_mem_size = bli_mem_size( cntl_mem_p ); + cntl_mem_size = bli_mem_size( cntl_mem_p ); if ( cntl_mem_size < size_needed ) { @@ -64,14 +83,15 @@ void* bli_packm_alloc // The chief thread releases the existing block associated with // the mem_t entry in the control tree, and then re-acquires a // new block, saving the associated mem_t entry to local_mem_s. - if ( bli_mem_is_alloc( cntl_mem_p ) ) - { - bli_pba_release - ( - rntm, - cntl_mem_p - ); - } + if ( bli_mem_is_alloc( cntl_mem_p ) ) + { + bli_pba_release + ( + rntm, + cntl_mem_p + ); + } + bli_pba_acquire_m ( rntm, @@ -89,11 +109,11 @@ void* bli_packm_alloc // this thread's control tree node. *cntl_mem_p = *local_mem_p; - // Barrier so that the master thread doesn't return from the function - // before we are done reading. - bli_thread_barrier( thread ); + // Barrier so that the master thread doesn't return from the function + // before we are done reading. + bli_thread_barrier( thread ); } - return bli_mem_buffer( cntl_mem_p ); + return bli_mem_buffer( cntl_mem_p ); } diff --git a/frame/1m/packm/bli_packm_alloc.h b/frame/1m/packm/bli_packm_alloc.h index b433be350a..5a5cf126b1 100644 --- a/frame/1m/packm/bli_packm_alloc.h +++ b/frame/1m/packm/bli_packm_alloc.h @@ -32,11 +32,20 @@ */ -BLIS_EXPORT_BLIS void* bli_packm_alloc - ( - siz_t size_needed, - rntm_t* rntm, - cntl_t* cntl, - thrinfo_t* thread - ); +BLIS_EXPORT_BLIS void* bli_packm_alloc + ( + siz_t size_needed, + rntm_t* rntm, + cntl_t* cntl, + thrinfo_t* thread + ); + +BLIS_EXPORT_BLIS void* bli_packm_alloc_ex + ( + siz_t size_needed, + packbuf_t pack_buf_type, + rntm_t* rntm, + cntl_t* cntl, + thrinfo_t* thread + ); diff --git a/frame/3/bli_l3_cntl.c b/frame/3/bli_l3_cntl.c index 3cdecfbc26..83ff8e5af5 100644 --- a/frame/3/bli_l3_cntl.c +++ b/frame/3/bli_l3_cntl.c @@ -57,7 +57,14 @@ void bli_l3_cntl_create_if family == BLIS_GEMMT || family == BLIS_TRMM ) { - *cntl_use = bli_gemm_cntl_create( rntm, family, schema_a, schema_b ); + *cntl_use = bli_gemm_cntl_create + ( + rntm, + family, + schema_a, + schema_b, + bli_obj_ker_fn( c ) + ); } else // if ( family == BLIS_TRSM ) { @@ -66,7 +73,14 @@ void bli_l3_cntl_create_if if ( bli_obj_is_triangular( a ) ) side = BLIS_LEFT; else side = BLIS_RIGHT; - *cntl_use = bli_trsm_cntl_create( rntm, side, schema_a, schema_b ); + *cntl_use = bli_trsm_cntl_create + ( + rntm, + side, + schema_a, + schema_b, + bli_obj_ker_fn( c ) + ); } } else diff --git a/frame/3/bli_l3_ft_ukr.h b/frame/3/bli_l3_ft_ukr.h index 4249dcbd6b..561c8264fe 100644 --- a/frame/3/bli_l3_ft_ukr.h +++ b/frame/3/bli_l3_ft_ukr.h @@ -47,6 +47,8 @@ \ typedef void (*PASTECH3(ch,opname,_ukr,tsuf)) \ ( \ + dim_t m, \ + dim_t n, \ dim_t k, \ ctype* restrict alpha, \ ctype* restrict a, \ diff --git a/frame/3/bli_l3_ukr_oapi.c b/frame/3/bli_l3_ukr_oapi.c index 33262b0bba..b8f2e00e6a 100644 --- a/frame/3/bli_l3_ukr_oapi.c +++ b/frame/3/bli_l3_ukr_oapi.c @@ -51,6 +51,8 @@ void PASTEMAC0(opname) \ \ num_t dt = bli_obj_dt( c ); \ \ + dim_t m = bli_obj_length( c ); \ + dim_t n = bli_obj_width( c ); \ dim_t k = bli_obj_width( a ); \ void* buf_a = bli_obj_buffer_at_off( a ); \ void* buf_b = bli_obj_buffer_at_off( b ); \ @@ -75,6 +77,8 @@ void PASTEMAC0(opname) \ \ f \ ( \ + m, \ + n, \ k, \ buf_alpha, \ buf_a, \ diff --git a/frame/3/bli_l3_ukr_prot.h b/frame/3/bli_l3_ukr_prot.h index ca523b1d70..f68973ff59 100644 --- 
a/frame/3/bli_l3_ukr_prot.h +++ b/frame/3/bli_l3_ukr_prot.h @@ -42,6 +42,8 @@ \ void PASTEMAC(ch,opname) \ ( \ + dim_t m, \ + dim_t n, \ dim_t k, \ ctype_out* restrict alpha, \ ctype_in* restrict a, \ diff --git a/frame/3/bli_l3_ukr_tapi.c b/frame/3/bli_l3_ukr_tapi.c index 67e33175b7..ab745d12b3 100644 --- a/frame/3/bli_l3_ukr_tapi.c +++ b/frame/3/bli_l3_ukr_tapi.c @@ -39,6 +39,8 @@ \ void PASTEMAC(ch,opname) \ ( \ + dim_t m, \ + dim_t n, \ dim_t k, \ ctype* restrict alpha, \ ctype* restrict a, \ @@ -58,16 +60,19 @@ void PASTEMAC(ch,opname) \ PASTECH2(ch,tname,_ukr_ft) f = bli_cntx_get_l3_vir_ukr_dt( dt, kerid, cntx ); \ \ /* Invoke the typed function for the given datatype. */ \ - f( \ - k, \ - alpha, \ - a, \ - b, \ - beta, \ - c, rs_c, cs_c, \ - data, \ - cntx \ - ); \ + f \ + ( \ + m, \ + n, \ + k, \ + alpha, \ + a, \ + b, \ + beta, \ + c, rs_c, cs_c, \ + data, \ + cntx \ + ); \ } \ INSERT_GENTFUNC_BASIC2( gemm_ukernel, gemm, BLIS_GEMM_UKR ) @@ -98,17 +103,18 @@ void PASTEMAC(ch,opname) \ PASTECH2(ch,tname,_ukr_ft) f = bli_cntx_get_l3_vir_ukr_dt( dt, kerid, cntx ); \ \ /* Invoke the typed function for the given datatype. */ \ - f( \ - k, \ - alpha, \ - a1x, \ - a11, \ - bx1, \ - b11, \ - c11, rs_c, cs_c, \ - data, \ - cntx \ - ); \ + f \ + ( \ + k, \ + alpha, \ + a1x, \ + a11, \ + bx1, \ + b11, \ + c11, rs_c, cs_c, \ + data, \ + cntx \ + ); \ } \ INSERT_GENTFUNC_BASIC2( gemmtrsm_l_ukernel, gemmtrsm, BLIS_GEMMTRSM_L_UKR ) @@ -136,13 +142,14 @@ void PASTEMAC(ch,opname) \ PASTECH2(ch,tname,_ukr_ft) f = bli_cntx_get_l3_vir_ukr_dt( dt, kerid, cntx ); \ \ /* Invoke the typed function for the given datatype. */ \ - f( \ - a, \ - b, \ - c, rs_c, cs_c, \ - data, \ - cntx \ - ); \ + f \ + ( \ + a, \ + b, \ + c, rs_c, cs_c, \ + data, \ + cntx \ + ); \ } \ INSERT_GENTFUNC_BASIC2( trsm_l_ukernel, trsm, BLIS_TRSM_L_UKR ) diff --git a/frame/3/gemm/bli_gemm_cntl.c b/frame/3/gemm/bli_gemm_cntl.c index 72d78efe16..052c812a33 100644 --- a/frame/3/gemm/bli_gemm_cntl.c +++ b/frame/3/gemm/bli_gemm_cntl.c @@ -40,10 +40,11 @@ cntl_t* bli_gemm_cntl_create rntm_t* rntm, opid_t family, pack_t schema_a, - pack_t schema_b + pack_t schema_b, + void_fp ker ) { - return bli_gemmbp_cntl_create( rntm, family, schema_a, schema_b ); + return bli_gemmbp_cntl_create( rntm, family, schema_a, schema_b, ker ); } // ----------------------------------------------------------------------------- @@ -53,18 +54,22 @@ cntl_t* bli_gemmbp_cntl_create rntm_t* rntm, opid_t family, pack_t schema_a, - pack_t schema_b + pack_t schema_b, + void_fp ker ) { void_fp macro_kernel_fp; - // Use the function pointers to the macrokernels that use slab - // assignment of micropanels to threads in the jr and ir loops. + // Choose the default macrokernel based on the operation family... if ( family == BLIS_GEMM ) macro_kernel_fp = bli_gemm_ker_var2; else if ( family == BLIS_GEMMT ) macro_kernel_fp = bli_gemmt_x_ker_var2; else if ( family == BLIS_TRMM ) macro_kernel_fp = bli_trmm_xx_ker_var2; else /* should never execute */ macro_kernel_fp = NULL; + // ...unless a non-NULL kernel function pointer is passed in, in which + // case we use that instead. + if ( ker ) macro_kernel_fp = ker; + // Create two nodes for the macro-kernel. 
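NOTE (illustrative, not part of the patch): the function-type and prototype changes above are the heart of this series: every gemm microkernel now receives the true m x n dimensions of its microtile (m <= MR, n <= NR) and handles edge cases itself. A minimal sketch of a conforming kernel under the new signature, where the kernel name and the MY_MR/MY_NR packing constants are hypothetical stand-ins for a real configuration's register blocking:

void my_dgemm_ukr
     (
       dim_t               m,    /* 0 < m <= MR */
       dim_t               n,    /* 0 < n <= NR */
       dim_t               k,
       double*    restrict alpha,
       double*    restrict a,    /* packed micropanel; column stride MY_MR */
       double*    restrict b,    /* packed micropanel; row stride MY_NR */
       double*    restrict beta,
       double*    restrict c, inc_t rs_c, inc_t cs_c,
       auxinfo_t* restrict data,
       cntx_t*    restrict cntx
     )
{
    for ( dim_t j = 0; j < n; ++j )
    for ( dim_t i = 0; i < m; ++i )
    {
        double ab = 0.0;

        for ( dim_t l = 0; l < k; ++l )
            ab += a[ i + l*MY_MR ] * b[ j + l*MY_NR ];

        double* cij = c + i*rs_c + j*cs_c;

        /* Never read c when beta == 0, so that uninitialized values
           (including infs and NaNs) are not propagated. */
        if ( *beta == 0.0 ) *cij = (*alpha) * ab;
        else                *cij = (*beta) * (*cij) + (*alpha) * ab;
    }
}

Because m and n now arrive at run time, the macrokernels no longer need the interior-vs-edge branch or (for native kernels) the temporary microtile; those removals make up most of the remaining hunks below.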
cntl_t* gemm_cntl_bu_ke = bli_gemm_cntl_create_node ( diff --git a/frame/3/gemm/bli_gemm_cntl.h b/frame/3/gemm/bli_gemm_cntl.h index bff91b58aa..5fa213ac41 100644 --- a/frame/3/gemm/bli_gemm_cntl.h +++ b/frame/3/gemm/bli_gemm_cntl.h @@ -38,7 +38,8 @@ cntl_t* bli_gemm_cntl_create rntm_t* rntm, opid_t family, pack_t schema_a, - pack_t schema_b + pack_t schema_b, + void_fp ker ); // ----------------------------------------------------------------------------- @@ -48,7 +49,8 @@ cntl_t* bli_gemmbp_cntl_create rntm_t* rntm, opid_t family, pack_t schema_a, - pack_t schema_b + pack_t schema_b, + void_fp ker ); #if 0 diff --git a/frame/3/gemm/bli_gemm_front.c b/frame/3/gemm/bli_gemm_front.c index a9ea21dc43..4ff45036fe 100644 --- a/frame/3/gemm/bli_gemm_front.c +++ b/frame/3/gemm/bli_gemm_front.c @@ -283,90 +283,3 @@ void bli_gemm_front #endif } -// ----------------------------------------------------------------------------- - -#if 0 - if ( bli_obj_dt( a ) != bli_obj_dt( b ) || - bli_obj_dt( a ) != bli_obj_dt( c ) || - bli_obj_comp_prec( c ) != bli_obj_prec( c ) ) - { - const bool a_is_real = bli_obj_is_real( a ); - const bool a_is_comp = bli_obj_is_complex( a ); - const bool b_is_real = bli_obj_is_real( b ); - const bool b_is_comp = bli_obj_is_complex( b ); - const bool c_is_real = bli_obj_is_real( c ); - const bool c_is_comp = bli_obj_is_complex( c ); - - const bool a_is_single = bli_obj_is_single_prec( a ); - const bool a_is_double = bli_obj_is_double_prec( a ); - const bool b_is_single = bli_obj_is_single_prec( b ); - const bool b_is_double = bli_obj_is_double_prec( b ); - const bool c_is_single = bli_obj_is_single_prec( c ); - const bool c_is_double = bli_obj_is_double_prec( c ); - - const bool comp_single = bli_obj_comp_prec( c ) == BLIS_SINGLE_PREC; - const bool comp_double = bli_obj_comp_prec( c ) == BLIS_DOUBLE_PREC; - - const bool mixeddomain = bli_obj_domain( c ) != bli_obj_domain( a ) || - bli_obj_domain( c ) != bli_obj_domain( b ); - - ( void )a_is_real; ( void )a_is_comp; - ( void )b_is_real; ( void )b_is_comp; - ( void )c_is_real; ( void )c_is_comp; - ( void )a_is_single; ( void )a_is_double; - ( void )b_is_single; ( void )b_is_double; - ( void )c_is_single; ( void )c_is_double; - ( void )comp_single; ( void )comp_double; - - if ( - //( c_is_comp && a_is_comp && b_is_real ) || - //( c_is_comp && a_is_real && b_is_comp ) || - //( c_is_real && a_is_comp && b_is_comp ) || - //( c_is_comp && a_is_real && b_is_real ) || - //( c_is_real && a_is_comp && b_is_real ) || - //( c_is_real && a_is_real && b_is_comp ) || - //FALSE - TRUE - ) - { - if ( - ( c_is_single && a_is_single && b_is_single && mixeddomain ) || - ( c_is_single && a_is_single && b_is_single && comp_single ) || - ( c_is_single && a_is_single && b_is_single && comp_double ) || - ( c_is_single && a_is_single && b_is_double ) || - ( c_is_single && a_is_double && b_is_single ) || - ( c_is_double && a_is_single && b_is_single ) || - ( c_is_single && a_is_double && b_is_double ) || - ( c_is_double && a_is_single && b_is_double ) || - ( c_is_double && a_is_double && b_is_single ) || - ( c_is_double && a_is_double && b_is_double && comp_single ) || - ( c_is_double && a_is_double && b_is_double && comp_double ) || - ( c_is_double && a_is_double && b_is_double && mixeddomain ) || - FALSE - ) - bli_gemm_md_front( alpha, a, b, beta, c, cntx, cntl ); - else - bli_gemm_md_zgemm( alpha, a, b, beta, c, cntx, cntl ); - } - else - bli_gemm_md_zgemm( alpha, a, b, beta, c, cntx, cntl ); - return; - } -#else -#if 0 - // If any of the storage 
datatypes differ, or if the execution precision - differs from the storage precision of C, utilize the mixed datatype - code path. - // NOTE: We could check the exec dt against the storage dt of C, but for - now we don't support the caller setting the execution domain - explicitly. - if ( bli_obj_dt( a ) != bli_obj_dt( b ) || - bli_obj_dt( a ) != bli_obj_dt( c ) || - bli_obj_comp_prec( c ) != bli_obj_prec( c ) ) - { - bli_gemm_md_front( alpha, a, b, beta, c, cntx, cntl ); - return; - } -#endif -#endif - diff --git a/frame/3/gemm/bli_gemm_ker_var2.c b/frame/3/gemm/bli_gemm_ker_var2.c index 0c90605524..6de361194d 100644 --- a/frame/3/gemm/bli_gemm_ker_var2.c +++ b/frame/3/gemm/bli_gemm_ker_var2.c @@ -35,28 +35,44 @@ #include "blis.h" -#define FUNCPTR_T gemm_fp +typedef void (*xpbys_mxn_vft) + ( + dim_t m, + dim_t n, + void* x, inc_t rs_x, inc_t cs_x, + void* b, + void* y, inc_t rs_y, inc_t cs_y + ); -typedef void (*FUNCPTR_T) - ( - pack_t schema_a, - pack_t schema_b, - dim_t m, - dim_t n, - dim_t k, - void* alpha, - void* a, inc_t cs_a, inc_t is_a, - dim_t pd_a, inc_t ps_a, - void* b, inc_t rs_b, inc_t is_b, - dim_t pd_b, inc_t ps_b, - void* beta, - void* c, inc_t rs_c, inc_t cs_c, - cntx_t* cntx, - rntm_t* rntm, - thrinfo_t* thread - ); +#undef GENTFUNC2 +#define GENTFUNC2(ctypex,ctypey,chx,chy,op) \ +\ +void PASTEMAC2(chx,chy,op) \ + ( \ + dim_t m, \ + dim_t n, \ + void* x, inc_t rs_x, inc_t cs_x, \ + void* b, \ + void* y, inc_t rs_y, inc_t cs_y \ + ) \ +{ \ + ctypex* restrict x_cast = x; \ + ctypey* restrict b_cast = b; \ + ctypey* restrict y_cast = y; \ +\ + PASTEMAC3(chx,chy,chy,xpbys_mxn) \ + ( \ + m, n, \ + x_cast, rs_x, cs_x, \ + b_cast, \ + y_cast, rs_y, cs_y \ + ); \ +} -static FUNCPTR_T GENARRAY(ftypes,gemm_ker_var2); +INSERT_GENTFUNC2_BASIC0(xpbys_mxn_fn); +INSERT_GENTFUNC2_MIXDP0(xpbys_mxn_fn); + +static xpbys_mxn_vft GENARRAY2_ALL(xpbys_mxn, xpbys_mxn_fn); void bli_gemm_ker_var2 ( obj_t* a, obj_t* b, obj_t* c, cntx_t* cntx, rntm_t* rntm, cntl_t* cntl, thrinfo_t* thread ) { -#ifdef BLIS_ENABLE_GEMM_MD - // By now, A and B have been packed and cast to the execution precision. - // In most cases, such as when storage precision of C differs from the - // execution precision, we utilize the mixed datatype code path. However, - // a few cases still fall within this kernel, such as mixed domain with - // equal precision (ccr, crc, rcc), hence those expressions being disabled - // in the conditional below.
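NOTE (illustrative, not part of the patch): the GENTFUNC2/GENARRAY2_ALL machinery added above builds a two-dimensional dispatch table, indexed as xpbys_mxn[ dt_exec ][ dt_c ], of thin wrappers over the level-0 xpbys_mxn kernels: y := x + beta * y over an m x n tile, typecasting from the execution datatype of the microtile to the storage datatype of C. Hand-expanding the float-executed, double-stored instantiation gives roughly the following (the real wrapper defers to bli_sddxpbys_mxn via PASTEMAC3; the function name here is made up):

void bli_sd_xpbys_mxn_sketch
     (
       dim_t m, dim_t n,
       void* x, inc_t rs_x, inc_t cs_x,
       void* b,
       void* y, inc_t rs_y, inc_t cs_y
     )
{
    float*  restrict x_cast = x;  /* microtile, computed in float */
    double* restrict b_cast = b;  /* beta, in C's storage type */
    double* restrict y_cast = y;  /* C, stored in double */

    for ( dim_t j = 0; j < n; ++j )
    for ( dim_t i = 0; i < m; ++i )
    {
        double* yij = y_cast + i*rs_y + j*cs_y;
        *yij = ( double )x_cast[ i*rs_x + j*cs_x ] + (*b_cast) * (*yij);
    }
}

This table is what lets the rewritten variant below compute in dt_exec while C remains stored as dt_c, replacing the deleted bli_gemm_ker_var2_md().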
- if ( //( bli_obj_domain( c ) != bli_obj_domain( a ) ) || - //( bli_obj_domain( c ) != bli_obj_domain( b ) ) || - ( bli_obj_dt( c ) != bli_obj_exec_dt( c ) ) ) - { - bli_gemm_ker_var2_md( a, b, c, cntx, rntm, cntl, thread ); - return; - } -#endif - num_t dt_exec = bli_obj_exec_dt( c ); + num_t dt_c = bli_obj_dt( c ); pack_t schema_a = bli_obj_pack_schema( a ); pack_t schema_b = bli_obj_pack_schema( b ); @@ -95,50 +96,55 @@ void bli_gemm_ker_var2 dim_t n = bli_obj_width( c ); dim_t k = bli_obj_width( a ); - void* buf_a = bli_obj_buffer_at_off( a ); - inc_t cs_a = bli_obj_col_stride( a ); + char* a_cast = bli_obj_buffer_at_off( a ); inc_t is_a = bli_obj_imag_stride( a ); dim_t pd_a = bli_obj_panel_dim( a ); inc_t ps_a = bli_obj_panel_stride( a ); - void* buf_b = bli_obj_buffer_at_off( b ); - inc_t rs_b = bli_obj_row_stride( b ); + char* b_cast = bli_obj_buffer_at_off( b ); inc_t is_b = bli_obj_imag_stride( b ); dim_t pd_b = bli_obj_panel_dim( b ); inc_t ps_b = bli_obj_panel_stride( b ); - void* buf_c = bli_obj_buffer_at_off( c ); + char* c_cast = bli_obj_buffer_at_off( c ); inc_t rs_c = bli_obj_row_stride( c ); inc_t cs_c = bli_obj_col_stride( c ); - obj_t scalar_a; - obj_t scalar_b; - - void* buf_alpha; - void* buf_beta; - - FUNCPTR_T f; + // If any dimension is zero, return immediately. + if ( bli_zero_dim3( m, n, k ) ) return; // Detach and multiply the scalars attached to A and B. + // NOTE: We know that the internal scalars of A and B are already of the + // target datatypes because the necessary typecasting would have already + // taken place during bli_packm_init(). + obj_t scalar_a; + obj_t scalar_b; bli_obj_scalar_detach( a, &scalar_a ); bli_obj_scalar_detach( b, &scalar_b ); bli_mulsc( &scalar_a, &scalar_b ); // Grab the addresses of the internal scalar buffers for the scalar // merged above and the scalar attached to C. - buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b ); - buf_beta = bli_obj_internal_scalar_buffer( c ); + // NOTE: We know that scalar_b is of type dt_exec due to the above code + // that casts the scalars of A and B to dt_exec via scalar_a and scalar_b, + // and we know that the internal scalar in C is already of the type dt_c + // due to the casting in the implementation of bli_obj_scalar_attach(). + char* alpha_cast = bli_obj_internal_scalar_buffer( &scalar_b ); + char* beta_cast = bli_obj_internal_scalar_buffer( c ); // If 1m is being employed on a column- or row-stored matrix with a // real-valued beta, we can use the real domain macro-kernel, which // eliminates a little overhead associated with the 1m virtual // micro-kernel. + // Only employ this optimization if the storage datatype of C is + // equal to the execution/computation datatype. #if 1 if ( bli_cntx_method( cntx ) == BLIS_1M ) { bli_gemm_ind_recast_1m_params ( &dt_exec, + &dt_c, schema_a, c, &m, &n, &k, @@ -151,273 +157,211 @@ void bli_gemm_ker_var2 #ifdef BLIS_ENABLE_GEMM_MD // Tweak parameters in select mixed domain cases (rcc, crc, ccr). - bli_gemm_md_ker_var2_recast - ( - &dt_exec, - bli_obj_dt( a ), - bli_obj_dt( b ), - bli_obj_dt( c ), - &m, &n, &k, - &pd_a, &ps_a, - &pd_b, &ps_b, - c, - &rs_c, &cs_c - ); + if ( bli_cntx_method( cntx ) == BLIS_NAT ) + { + bli_gemm_md_ker_var2_recast + ( + &dt_exec, + bli_obj_dt( a ), + bli_obj_dt( b ), + &dt_c, + &m, &n, &k, + &pd_a, &ps_a, + &pd_b, &ps_b, + c, + &rs_c, &cs_c + ); + } #endif - // Index into the type combination array to extract the correct - // function pointer. - f = ftypes[dt_exec]; - - // Invoke the function. 
- f( schema_a, - schema_b, - m, - n, - k, - buf_alpha, - buf_a, cs_a, is_a, - pd_a, ps_a, - buf_b, rs_b, is_b, - pd_b, ps_b, - buf_beta, - buf_c, rs_c, cs_c, - cntx, - rntm, - thread ); -} + siz_t dt_size = bli_dt_size( dt_exec ); + siz_t dt_c_size = bli_dt_size( dt_c ); + // Alias some constants to simpler names. + const dim_t MR = pd_a; + const dim_t NR = pd_b; + //const dim_t PACKMR = cs_a; + //const dim_t PACKNR = rs_b; + + // Query the context for the micro-kernel address and cast it to its + // function pointer type. + gemm_ukr_vft gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt_exec, BLIS_GEMM_UKR, cntx ); + + // Query the params field from the obj_t. If it is non-NULL, grab the ukr + // field of the params struct. If that function pointer is non-NULL, use it + // as our microkernel instead of the default microkernel queried from the + // cntx above. + gemm_ker_params_t* params = bli_obj_ker_params( c ); + gemm_ukr_vft user_ukr = params ? params->ukr : NULL; + if ( user_ukr ) gemm_ukr = user_ukr; + + // Temporary C buffer for edge cases. Note that the strides of this + // temporary buffer are set so that they match the storage of the + // original C matrix. For example, if C is column-stored, ct will be + // column-stored as well. + char ct[ BLIS_STACK_BUF_MAX_SIZE ] + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); + const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt_exec, BLIS_GEMM_UKR, cntx ); + const inc_t rs_ct = ( col_pref ? 1 : NR ); + const inc_t cs_ct = ( col_pref ? MR : 1 ); + char* zero = bli_obj_buffer_for_const( dt_exec, &BLIS_ZERO ); + + // + // Assumptions/assertions: + // rs_a == 1 + // cs_a == PACKMR + // pd_a == MR + // ps_a == stride to next micro-panel of A + // rs_b == PACKNR + // cs_b == 1 + // pd_b == NR + // ps_b == stride to next micro-panel of B + // rs_c == (no assumptions) + // cs_c == (no assumptions) + // + + // Compute number of primary and leftover components of the m and n + // dimensions. + dim_t n_iter = n / NR; + dim_t n_left = n % NR; + + dim_t m_iter = m / MR; + dim_t m_left = m % MR; + + if ( n_left ) ++n_iter; + if ( m_left ) ++m_iter; + + // Determine some increments used to step through A, B, and C. + inc_t rstep_a = ps_a * dt_size; + + inc_t cstep_b = ps_b * dt_size; + + inc_t rstep_c = rs_c * MR * dt_c_size; + inc_t cstep_c = cs_c * NR * dt_c_size; + + auxinfo_t aux; + + // Save the pack schemas of A and B to the auxinfo_t object. + bli_auxinfo_set_schema_a( schema_a, &aux ); + bli_auxinfo_set_schema_b( schema_b, &aux ); + + // Save the imaginary stride of A and B to the auxinfo_t object. + bli_auxinfo_set_is_a( is_a, &aux ); + bli_auxinfo_set_is_b( is_b, &aux ); + + // Save the virtual microkernel address and the params. + bli_auxinfo_set_ukr( gemm_ukr, &aux ); + bli_auxinfo_set_params( params, &aux ); + + // The 'thread' argument points to the thrinfo_t node for the 2nd (jr) + // loop around the microkernel. Here we query the thrinfo_t node for the + // 1st (ir) loop around the microkernel. + thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); + + // Query the number of threads and thread ids for each loop. + dim_t jr_nt = bli_thread_n_way( thread ); + dim_t jr_tid = bli_thread_work_id( thread ); + dim_t ir_nt = bli_thread_n_way( caucus ); + dim_t ir_tid = bli_thread_work_id( caucus ); + + dim_t jr_start, jr_end; + dim_t ir_start, ir_end; + dim_t jr_inc, ir_inc; + + // Determine the thread range and increment for the 2nd and 1st loops. 
+ // NOTE: The definition of bli_thread_range_jrir() will depend on whether + // slab or round-robin partitioning was requested at configure-time. + bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); + bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); + + // Loop over the n dimension (NR columns at a time). + for ( dim_t j = jr_start; j < jr_end; j += jr_inc ) + { + char* b1 = b_cast + j * cstep_b; + char* c1 = c_cast + j * cstep_c; + + dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); + + // Initialize our next panel of B to be the current panel of B. + char* b2 = b1; + + // Loop over the m dimension (MR rows at a time). + for ( dim_t i = ir_start; i < ir_end; i += ir_inc ) + { + char* a1 = a_cast + i * rstep_a; + char* c11 = c1 + i * rstep_c; + + dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); + + // Compute the addresses of the next panels of A and B. + char* a2 = bli_gemm_get_next_a_upanel( a1, rstep_a, ir_inc ); + if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) ) + { + a2 = a_cast; + b2 = bli_gemm_get_next_b_upanel( b1, cstep_b, jr_inc ); + if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) ) + b2 = b_cast; + } + + // Save addresses of next panels of A and B to the auxinfo_t + // object. + bli_auxinfo_set_next_a( a2, &aux ); + bli_auxinfo_set_next_b( b2, &aux ); + + // Edge case handling now occurs within the microkernel itself, but + // we must still explicitly accumulate to a temporary microtile in + // situations where a virtual microkernel is being used, such as + // during the 1m method or some cases of mixed datatypes. + if ( dt_exec == dt_c ) + { + // Invoke the gemm micro-kernel. + gemm_ukr + ( + m_cur, + n_cur, + k, + alpha_cast, + a1, + b1, + beta_cast, + c11, rs_c, cs_c, + &aux, + cntx + ); + } + else + { + // Invoke the gemm micro-kernel. + gemm_ukr + ( + MR, + NR, + k, + alpha_cast, + a1, + b1, + zero, + &ct, rs_ct, cs_ct, + &aux, + cntx + ); + + // Accumulate to C with typecasting. + xpbys_mxn[ dt_exec ][ dt_c ] + ( + m_cur, n_cur, + &ct, rs_ct, cs_ct, + beta_cast, + c11, rs_c, cs_c + ); + } + } + } -#undef GENTFUNC -#define GENTFUNC( ctype, ch, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - pack_t schema_a, \ - pack_t schema_b, \ - dim_t m, \ - dim_t n, \ - dim_t k, \ - void* alpha, \ - void* a, inc_t cs_a, inc_t is_a, \ - dim_t pd_a, inc_t ps_a, \ - void* b, inc_t rs_b, inc_t is_b, \ - dim_t pd_b, inc_t ps_b, \ - void* beta, \ - void* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntm, \ - thrinfo_t* thread \ - ) \ -{ \ - const num_t dt = PASTEMAC(ch,type); \ -\ - /* Alias some constants to simpler names. */ \ - const dim_t MR = pd_a; \ - const dim_t NR = pd_b; \ - /*const dim_t PACKMR = cs_a;*/ \ - /*const dim_t PACKNR = rs_b;*/ \ -\ - /* Query the context for the micro-kernel address and cast it to its - function pointer type. Note that the virtual gemm ukernel is queried - instead of the native gemm ukernel. This is needed for certain - situations for the 1m method that require an extra layer of logic - to allow for handling (for example) complex values of beta. Also - note that under certain circumstances, the real-domain version of - this macrokernel will be called for 1m (NOT the complex version) - as an optimization.
In these cases, the corresponding real-domain - slots within the cntx_t's virtual gemm ukernel func_t will contain - pointers to the *native* gemm ukernel, thanks to logic in the - context initialization function for the induced method (defined - in bli_cntx_ref.c). */ \ - PASTECH(ch,gemm_ukr_ft) \ - gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ -\ - /* Temporary C buffer for edge cases. Note that the strides of this - temporary buffer are set so that they match the storage of the - original C matrix. For example, if C is column-stored, ct will be - column-stored as well. */ \ - ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ - const inc_t rs_ct = ( col_pref ? 1 : NR ); \ - const inc_t cs_ct = ( col_pref ? MR : 1 ); \ -\ - ctype* restrict zero = PASTEMAC(ch,0); \ - ctype* restrict a_cast = a; \ - ctype* restrict b_cast = b; \ - ctype* restrict c_cast = c; \ - ctype* restrict alpha_cast = alpha; \ - ctype* restrict beta_cast = beta; \ - ctype* restrict b1; \ - ctype* restrict c1; \ -\ - dim_t m_iter, m_left; \ - dim_t n_iter, n_left; \ - dim_t i, j; \ - dim_t m_cur; \ - dim_t n_cur; \ - inc_t rstep_a; \ - inc_t cstep_b; \ - inc_t rstep_c, cstep_c; \ - auxinfo_t aux; \ -\ - /* - Assumptions/assertions: - rs_a == 1 - cs_a == PACKMR - pd_a == MR - ps_a == stride to next micro-panel of A - rs_b == PACKNR - cs_b == 1 - pd_b == NR - ps_b == stride to next micro-panel of B - rs_c == (no assumptions) - cs_c == (no assumptions) - */ \ -\ - /* If any dimension is zero, return immediately. */ \ - if ( bli_zero_dim3( m, n, k ) ) return; \ -\ - /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ - PASTEMAC(ch,set0s_mxn)( MR, NR, \ - ct, rs_ct, cs_ct ); \ -\ - /* Compute number of primary and leftover components of the m and n - dimensions. */ \ - n_iter = n / NR; \ - n_left = n % NR; \ -\ - m_iter = m / MR; \ - m_left = m % MR; \ -\ - if ( n_left ) ++n_iter; \ - if ( m_left ) ++m_iter; \ -\ - /* Determine some increments used to step through A, B, and C. */ \ - rstep_a = ps_a; \ -\ - cstep_b = ps_b; \ -\ - rstep_c = rs_c * MR; \ - cstep_c = cs_c * NR; \ -\ - /* Save the pack schemas of A and B to the auxinfo_t object. */ \ - bli_auxinfo_set_schema_a( schema_a, &aux ); \ - bli_auxinfo_set_schema_b( schema_b, &aux ); \ -\ - /* Save the imaginary stride of A and B to the auxinfo_t object. */ \ - bli_auxinfo_set_is_a( is_a, &aux ); \ - bli_auxinfo_set_is_b( is_b, &aux ); \ -\ - /* The 'thread' argument points to the thrinfo_t node for the 2nd (jr) - loop around the microkernel. Here we query the thrinfo_t node for the - 1st (ir) loop around the microkernel. */ \ - thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); \ -\ - /* Query the number of threads and thread ids for each loop. */ \ - dim_t jr_nt = bli_thread_n_way( thread ); \ - dim_t jr_tid = bli_thread_work_id( thread ); \ - dim_t ir_nt = bli_thread_n_way( caucus ); \ - dim_t ir_tid = bli_thread_work_id( caucus ); \ -\ - dim_t jr_start, jr_end; \ - dim_t ir_start, ir_end; \ - dim_t jr_inc, ir_inc; \ -\ - /* Determine the thread range and increment for the 2nd and 1st loops. - NOTE: The definition of bli_thread_range_jrir() will depend on whether - slab or round-robin partitioning was requested at configure-time. 
*/ \ - bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \ - bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); \ -\ - /* Loop over the n dimension (NR columns at a time). */ \ - for ( j = jr_start; j < jr_end; j += jr_inc ) \ - { \ - ctype* restrict a1; \ - ctype* restrict c11; \ - ctype* restrict b2; \ -\ - b1 = b_cast + j * cstep_b; \ - c1 = c_cast + j * cstep_c; \ -\ - n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \ -\ - /* Initialize our next panel of B to be the current panel of B. */ \ - b2 = b1; \ -\ - /* Loop over the m dimension (MR rows at a time). */ \ - for ( i = ir_start; i < ir_end; i += ir_inc ) \ - { \ - ctype* restrict a2; \ -\ - a1 = a_cast + i * rstep_a; \ - c11 = c1 + i * rstep_c; \ -\ - m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ -\ - /* Compute the addresses of the next panels of A and B. */ \ - a2 = bli_gemm_get_next_a_upanel( a1, rstep_a, ir_inc ); \ - if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) ) \ - { \ - a2 = a_cast; \ - b2 = bli_gemm_get_next_b_upanel( b1, cstep_b, jr_inc ); \ - if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) ) \ - b2 = b_cast; \ - } \ -\ - /* Save addresses of next panels of A and B to the auxinfo_t - object. */ \ - bli_auxinfo_set_next_a( a2, &aux ); \ - bli_auxinfo_set_next_b( b2, &aux ); \ -\ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - beta_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale the bottom edge of C and add the result from above. */ \ - PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_cast, \ - c11, rs_c, cs_c ); \ - } \ - } \ - } \ -\ /* -PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: b1", k, NR, b1, NR, 1, "%4.1f", "" ); \ -PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: a1", MR, k, a1, 1, MR, "%4.1f", "" ); \ -PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: c after", m_cur, n_cur, c11, rs_c, cs_c, "%4.1f", "" ); \ -*/ \ +PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: b1", k, NR, b1, NR, 1, "%4.1f", "" ); +PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: a1", MR, k, a1, 1, MR, "%4.1f", "" ); +PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: c after", m_cur, n_cur, c11, rs_c, cs_c, "%4.1f", "" ); +*/ } -INSERT_GENTFUNC_BASIC0( gemm_ker_var2 ) - diff --git a/frame/3/gemm/bli_gemm_ker_var2_md.c b/frame/3/gemm/bli_gemm_ker_var2_md.c deleted file mode 100644 index 09c279d149..0000000000 --- a/frame/3/gemm/bli_gemm_ker_var2_md.c +++ /dev/null @@ -1,406 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. 
- - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#ifdef BLIS_ENABLE_GEMM_MD - -#define FUNCPTR_T gemm_fp - -typedef void (*FUNCPTR_T) - ( - pack_t schema_a, - pack_t schema_b, - dim_t m, - dim_t n, - dim_t k, - void* alpha, - void* a, inc_t cs_a, inc_t is_a, - dim_t pd_a, inc_t ps_a, - void* b, inc_t rs_b, inc_t is_b, - dim_t pd_b, inc_t ps_b, - void* beta, - void* c, inc_t rs_c, inc_t cs_c, - cntx_t* cntx, - rntm_t* rntm, - thrinfo_t* thread - ); - -static FUNCPTR_T GENARRAY2_ALL(ftypes,gemm_ker_var2_md); - - -void bli_gemm_ker_var2_md - ( - obj_t* a, - obj_t* b, - obj_t* c, - cntx_t* cntx, - rntm_t* rntm, - cntl_t* cntl, - thrinfo_t* thread - ) -{ - num_t dt_exec = bli_obj_exec_dt( c ); - num_t dt_c = bli_obj_dt( c ); - - pack_t schema_a = bli_obj_pack_schema( a ); - pack_t schema_b = bli_obj_pack_schema( b ); - - dim_t m = bli_obj_length( c ); - dim_t n = bli_obj_width( c ); - dim_t k = bli_obj_width( a ); - - void* buf_a = bli_obj_buffer_at_off( a ); - inc_t cs_a = bli_obj_col_stride( a ); - inc_t is_a = bli_obj_imag_stride( a ); - dim_t pd_a = bli_obj_panel_dim( a ); - inc_t ps_a = bli_obj_panel_stride( a ); - - void* buf_b = bli_obj_buffer_at_off( b ); - inc_t rs_b = bli_obj_row_stride( b ); - inc_t is_b = bli_obj_imag_stride( b ); - dim_t pd_b = bli_obj_panel_dim( b ); - inc_t ps_b = bli_obj_panel_stride( b ); - - void* buf_c = bli_obj_buffer_at_off( c ); - inc_t rs_c = bli_obj_row_stride( c ); - inc_t cs_c = bli_obj_col_stride( c ); - - obj_t scalar_a; - obj_t scalar_b; - - void* buf_alpha; - void* buf_beta; - - FUNCPTR_T f; - - // Detach and multiply the scalars attached to A and B. - // NOTE: We know that the internal scalars of A and B are already of the - // target datatypes because the necessary typecasting would have already - // taken place during bli_packm_init(). - bli_obj_scalar_detach( a, &scalar_a ); - bli_obj_scalar_detach( b, &scalar_b ); - bli_mulsc( &scalar_a, &scalar_b ); - - // Grab the addresses of the internal scalar buffers for the scalar - // merged above and the scalar attached to C. - // NOTE: We know that scalar_b is of type dt_exec due to the above code - // that casts the scalars of A and B to dt_exec via scalar_a and scalar_b, - // and we know that the internal scalar in C is already of the type dt_c - // due to the casting in the implementation of bli_obj_scalar_attach(). 
- buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b ); - buf_beta = bli_obj_internal_scalar_buffer( c ); - -#if 0 - // NOTE: Turns out that this optimization will never be employed since - // currently bli_gemm_ker_var2_md() is only called when the storage - // datatype of C differs from the execution/computation datatype, and - // this optimization would only make sense if they are equal. - - // If 1m is being employed on a column- or row-stored matrix with a - // real-valued beta, we can use the real domain macro-kernel, which - // eliminates a little overhead associated with the 1m virtual - // micro-kernel. - if ( bli_cntx_method( cntx ) == BLIS_1M ) - { - // Only employ this optimization if the storage datatype of C is - // equal to the execution/computation datatype. - if ( dt_c == dt_exec ) - { - bli_gemm_ind_recast_1m_params - ( - &dt_exec, - schema_a, - c, - &m, &n, &k, - &pd_a, &ps_a, - &pd_b, &ps_b, - &rs_c, &cs_c - ); - } - } -#endif - - // Tweak parameters in select mixed domain cases (rcc, crc, ccr). - bli_gemm_md_ker_var2_recast - ( - &dt_exec, - bli_obj_dt( a ), - bli_obj_dt( b ), - bli_obj_dt( c ), - &m, &n, &k, - &pd_a, &ps_a, - &pd_b, &ps_b, - c, - &rs_c, &cs_c - ); - - // Index into the type combination array to extract the correct - // function pointer. - f = ftypes[dt_c][dt_exec]; - - // Invoke the function. - f( schema_a, - schema_b, - m, - n, - k, - buf_alpha, - buf_a, cs_a, is_a, - pd_a, ps_a, - buf_b, rs_b, is_b, - pd_b, ps_b, - buf_beta, - buf_c, rs_c, cs_c, - cntx, - rntm, - thread ); -} - - -#undef GENTFUNC2 -#define GENTFUNC2( ctype_c, ctype_e, chc, che, varname ) \ -\ -void PASTEMAC2(chc,che,varname) \ - ( \ - pack_t schema_a, \ - pack_t schema_b, \ - dim_t m, \ - dim_t n, \ - dim_t k, \ - void* alpha, \ - void* a, inc_t cs_a, inc_t is_a, \ - dim_t pd_a, inc_t ps_a, \ - void* b, inc_t rs_b, inc_t is_b, \ - dim_t pd_b, inc_t ps_b, \ - void* beta, \ - void* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntm, \ - thrinfo_t* thread \ - ) \ -{ \ - const num_t dte = PASTEMAC(che,type); \ - /*const num_t dtc = PASTEMAC(chc,type);*/ \ -\ - /* Alias some constants to simpler names. */ \ - const dim_t MR = pd_a; \ - const dim_t NR = pd_b; \ - /*const dim_t PACKMR = cs_a;*/ \ - /*const dim_t PACKNR = rs_b;*/ \ -\ - /* Query the context for the micro-kernel address and cast it to its - function pointer type. */ \ - PASTECH(che,gemm_ukr_ft) \ - gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dte, BLIS_GEMM_UKR, cntx ); \ -\ - /* Temporary C buffer for edge cases. Note that the strides of this - temporary buffer are set so that they match the storage of the - original C matrix. For example, if C is column-stored, ct will be - column-stored as well. */ \ - ctype_e ct[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype_e ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dte, BLIS_GEMM_UKR, cntx ); \ - const inc_t rs_ct = ( col_pref ? 1 : NR ); \ - const inc_t cs_ct = ( col_pref ? 
MR : 1 ); \ -\ - ctype_e* restrict zero = PASTEMAC(che,0); \ - ctype_e* restrict a_cast = a; \ - ctype_e* restrict b_cast = b; \ - ctype_c* restrict c_cast = c; \ - ctype_e* restrict alpha_cast = alpha; \ - ctype_c* restrict beta_cast = beta; \ - ctype_e* restrict b1; \ - ctype_c* restrict c1; \ -\ - dim_t m_iter, m_left; \ - dim_t n_iter, n_left; \ - dim_t i, j; \ - dim_t m_cur; \ - dim_t n_cur; \ - inc_t rstep_a; \ - inc_t cstep_b; \ - inc_t rstep_c, cstep_c; \ - auxinfo_t aux; \ -\ - /* - Assumptions/assertions: - rs_a == 1 - cs_a == PACKMR - pd_a == MR - ps_a == stride to next micro-panel of A - rs_b == PACKNR - cs_b == 1 - pd_b == NR - ps_b == stride to next micro-panel of B - rs_c == (no assumptions) - cs_c == (no assumptions) - */ \ -\ - /* If any dimension is zero, return immediately. */ \ - if ( bli_zero_dim3( m, n, k ) ) return; \ -\ - /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ - PASTEMAC(che,set0s_mxn)( MR, NR, \ - ct, rs_ct, cs_ct ); \ -\ - /* Compute number of primary and leftover components of the m and n - dimensions. */ \ - n_iter = n / NR; \ - n_left = n % NR; \ -\ - m_iter = m / MR; \ - m_left = m % MR; \ -\ - if ( n_left ) ++n_iter; \ - if ( m_left ) ++m_iter; \ -\ - /* Determine some increments used to step through A, B, and C. */ \ - rstep_a = ps_a; \ -\ - cstep_b = ps_b; \ -\ - rstep_c = rs_c * MR; \ - cstep_c = cs_c * NR; \ -\ - /* Save the pack schemas of A and B to the auxinfo_t object. */ \ - bli_auxinfo_set_schema_a( schema_a, &aux ); \ - bli_auxinfo_set_schema_b( schema_b, &aux ); \ -\ - /* Save the imaginary stride of A and B to the auxinfo_t object. */ \ - bli_auxinfo_set_is_a( is_a, &aux ); \ - bli_auxinfo_set_is_b( is_b, &aux ); \ -\ - /* The 'thread' argument points to the thrinfo_t node for the 2nd (jr) - loop around the microkernel. Here we query the thrinfo_t node for the - 1st (ir) loop around the microkernel. */ \ - thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); \ -\ - /* Query the number of threads and thread ids for each loop. */ \ - dim_t jr_nt = bli_thread_n_way( thread ); \ - dim_t jr_tid = bli_thread_work_id( thread ); \ - dim_t ir_nt = bli_thread_n_way( caucus ); \ - dim_t ir_tid = bli_thread_work_id( caucus ); \ -\ - dim_t jr_start, jr_end; \ - dim_t ir_start, ir_end; \ - dim_t jr_inc, ir_inc; \ -\ - /* Determine the thread range and increment for the 2nd and 1st loops. - NOTE: The definition of bli_thread_range_jrir() will depend on whether - slab or round-robin partitioning was requested at configure-time. */ \ - bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \ - bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); \ -\ - /* Loop over the n dimension (NR columns at a time). */ \ - for ( j = jr_start; j < jr_end; j += jr_inc ) \ - { \ - ctype_e* restrict a1; \ - ctype_c* restrict c11; \ - ctype_e* restrict b2; \ -\ - b1 = b_cast + j * cstep_b; \ - c1 = c_cast + j * cstep_c; \ -\ - n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \ -\ - /* Initialize our next panel of B to be the current panel of B. */ \ - b2 = b1; \ -\ - /* Loop over the m dimension (MR rows at a time). */ \ - for ( i = ir_start; i < ir_end; i += ir_inc ) \ - { \ - ctype_e* restrict a2; \ -\ - a1 = a_cast + i * rstep_a; \ - c11 = c1 + i * rstep_c; \ -\ - m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ -\ - /* Compute the addresses of the next panels of A and B. 
*/ \ - a2 = bli_gemm_get_next_a_upanel( a1, rstep_a, ir_inc ); \ - if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) ) \ - { \ - a2 = a_cast; \ - b2 = bli_gemm_get_next_b_upanel( b1, cstep_b, jr_inc ); \ - if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) ) \ - b2 = b_cast; \ - } \ -\ - /* Save addresses of next panels of A and B to the auxinfo_t - object. */ \ - bli_auxinfo_set_next_a( a2, &aux ); \ - bli_auxinfo_set_next_b( b2, &aux ); \ -\ - /* Always save the micropanel product to the local microtile and - then accumulate it into C via the xpbys_mxn macro. */ \ - /*if ( 1 )*/ \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale the microtile of C and add the result from above. */ \ - PASTEMAC3(che,chc,chc,xpbys_mxn) \ - ( \ - m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_cast, \ - c11, rs_c, cs_c \ - ); \ - } \ - } \ - } \ -\ -/* -PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: b1", k, NR, b1, NR, 1, "%4.1f", "" ); \ -PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: a1", MR, k, a1, 1, MR, "%4.1f", "" ); \ -PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: c after", m_cur, n_cur, c11, rs_c, cs_c, "%4.1f", "" ); \ -*/ \ -} - -INSERT_GENTFUNC2_BASIC0( gemm_ker_var2_md ) -INSERT_GENTFUNC2_MIXDP0( gemm_ker_var2_md ) - -#endif diff --git a/frame/3/gemm/bli_gemm_md.h b/frame/3/gemm/bli_gemm_md.h index 8fcf6bd21d..751e271eaf 100644 --- a/frame/3/gemm/bli_gemm_md.h +++ b/frame/3/gemm/bli_gemm_md.h @@ -154,7 +154,7 @@ BLIS_INLINE void bli_gemm_md_ker_var2_recast num_t* dt_comp, num_t dt_a, num_t dt_b, - num_t dt_c, + num_t* dt_c, dim_t* m, dim_t* n, dim_t* k, @@ -164,7 +164,7 @@ BLIS_INLINE void bli_gemm_md_ker_var2_recast inc_t* rs_c, inc_t* cs_c ) { - if ( bli_is_real( dt_c ) && + if ( bli_is_real( *dt_c ) && bli_is_complex( dt_a ) && bli_is_complex( dt_b ) ) { @@ -177,7 +177,7 @@ BLIS_INLINE void bli_gemm_md_ker_var2_recast *ps_a *= 2; *ps_b *= 2; } - else if ( bli_is_complex( dt_c ) && + else if ( bli_is_complex( *dt_c ) && bli_is_real( dt_a ) && bli_is_complex( dt_b ) ) { @@ -197,6 +197,7 @@ BLIS_INLINE void bli_gemm_md_ker_var2_recast // to the real virtual microkernel slots of the context) instead of // the complex macrokernel and c2r virtual microkernel. *dt_comp = bli_dt_proj_to_real( *dt_comp ); + *dt_c = bli_dt_proj_to_real( *dt_c ); *n *= 2; *pd_b *= 2; *ps_b *= 2; *rs_c *= 2; @@ -211,7 +212,7 @@ BLIS_INLINE void bli_gemm_md_ker_var2_recast *ps_a /= 2; } } - else if ( bli_is_complex( dt_c ) && + else if ( bli_is_complex( *dt_c ) && bli_is_complex( dt_a ) && bli_is_real( dt_b ) ) { @@ -231,6 +232,7 @@ BLIS_INLINE void bli_gemm_md_ker_var2_recast // to the real virtual microkernel slots of the context) instead of // the complex macrokernel and c2r virtual microkernel. *dt_comp = bli_dt_proj_to_real( *dt_comp ); + *dt_c = bli_dt_proj_to_real( *dt_c ); *m *= 2; *pd_a *= 2; *ps_a *= 2; *cs_c *= 2; @@ -274,54 +276,3 @@ BLIS_INLINE void bli_gemm_md_ker_var2_recast #endif } -// ----------------------------------------------------------------------------- - -// -// Prototype object-based interfaces. -// - -#undef GENPROT -#define GENPROT( opname ) \ -\ -void PASTEMAC0(opname) \ - ( \ - obj_t* a, \ - obj_t* b, \ - obj_t* c, \ - cntx_t* cntx, \ - rntm_t* rntm, \ - cntl_t* cntl, \ - thrinfo_t* thread \ - ); - -GENPROT( gemm_ker_var2_md ) - -// -// Prototype BLAS-like interfaces with void pointer operands. 
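NOTE (illustrative, not part of the patch): changing dt_c from num_t to num_t* lets the rcc/crc/ccr recasts above also project C's storage datatype to the real domain. The projection is legal because a complex matrix can be reinterpreted as a real matrix with one dimension and the matching stride doubled; for some obj_t c:

/* A column-stored m x n dcomplex matrix with column stride ld_c... */
dcomplex* cz = bli_obj_buffer( &c );   /* rs = 1, cs = ld_c (complex units) */

/* ...is also a column-stored 2m x n double matrix: real and imaginary
   parts interleave down each column. This is why the ccr branch doubles
   m, pd_a, ps_a, and cs_c, and now must also project *dt_c to real. */
double* cr = ( double* )cz;            /* rs = 1, cs = 2*ld_c (real units) */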
-// - -#undef GENTPROT2 -#define GENTPROT2( ctype_c, ctype_e, chc, che, varname ) \ -\ -void PASTEMAC2(chc,che,varname) \ - ( \ - pack_t schema_a, \ - pack_t schema_b, \ - dim_t m, \ - dim_t n, \ - dim_t k, \ - void* alpha, \ - void* a, inc_t cs_a, inc_t is_a, \ - dim_t pd_a, inc_t ps_a, \ - void* b, inc_t rs_b, inc_t is_b, \ - dim_t pd_b, inc_t ps_b, \ - void* beta, \ - void* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntm, \ - thrinfo_t* thread \ - ); - -INSERT_GENTPROT2_BASIC0( gemm_ker_var2_md ) -INSERT_GENTPROT2_MIXDP0( gemm_ker_var2_md ) - diff --git a/frame/3/gemm/bli_gemm_md_c2r_ref.c b/frame/3/gemm/bli_gemm_md_c2r_ref.c index 0bfb596302..bbd9190a9a 100644 --- a/frame/3/gemm/bli_gemm_md_c2r_ref.c +++ b/frame/3/gemm/bli_gemm_md_c2r_ref.c @@ -41,6 +41,8 @@ \ void PASTEMAC2(ch,opname,suf) \ ( \ + dim_t m, \ + dim_t n, \ dim_t k, \ ctype* restrict alpha, \ ctype* restrict a, \ @@ -61,6 +63,9 @@ void PASTEMAC2(ch,opname,suf) \ \ const dim_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \ const dim_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \ +\ + dim_t mr_r = mr; \ + dim_t nr_r = nr; \ \ ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ / sizeof( ctype_r ) ] \ @@ -81,6 +86,9 @@ void PASTEMAC2(ch,opname,suf) \ \ ctype_r* restrict beta_r = &PASTEMAC(ch,real)( *beta ); \ ctype_r* restrict beta_i = &PASTEMAC(ch,imag)( *beta ); \ +\ + dim_t m_use; \ + dim_t n_use; \ \ ctype_r* c_use; \ inc_t rs_c_use; \ @@ -146,17 +154,16 @@ PASTEMAC(chr,fprintm)( stdout, "gemm_ukr: c before", mr, nr, \ rs_c_use = rs_ct; \ cs_c_use = cs_ct; \ \ - /* Convert the strides from being in units of complex elements to - be in units of real elements. Note that we don't need to check for - general storage here because that case corresponds to the scenario - where we are using the ct buffer and its rs_ct/cs_ct strides. */ \ - if ( bli_is_col_stored( rs_c_use, cs_c_use ) ) cs_c_use *= 2; \ - else rs_c_use *= 2; \ -\ + /* Convert the strides and corresponding microtile dimension from being + in units of complex elements to be in units of real elements. */ \ + if ( bli_is_col_stored( rs_c_use, cs_c_use ) ) { cs_c_use *= 2; mr_r *= 2; } \ + else { rs_c_use *= 2; nr_r *= 2; }\ \ /* c = beta * c + alpha_r * a * b; */ \ rgemm_ukr \ ( \ + mr_r, \ + nr_r, \ k, \ alpha_r, \ a_r, \ @@ -166,14 +173,12 @@ PASTEMAC(chr,fprintm)( stdout, "gemm_ukr: c before", mr, nr, \ data, \ cntx \ ); \ -\ - dim_t i, j; \ \ /* Accumulate the final result in ct back to c. 
*/ \ if ( PASTEMAC(ch,eq1)( *beta ) ) \ { \ - for ( j = 0; j < nr; ++j ) \ - for ( i = 0; i < mr; ++i ) \ + for ( dim_t j = 0; j < n; ++j ) \ + for ( dim_t i = 0; i < m; ++i ) \ { \ PASTEMAC(ch,adds)( *(ct + i*rs_ct + j*cs_ct), \ *(c + i*rs_c + j*cs_c ) ); \ @@ -181,8 +186,8 @@ PASTEMAC(chr,fprintm)( stdout, "gemm_ukr: c before", mr, nr, \ } \ else if ( PASTEMAC(ch,eq0)( *beta ) ) \ { \ - for ( j = 0; j < nr; ++j ) \ - for ( i = 0; i < mr; ++i ) \ + for ( dim_t j = 0; j < n; ++j ) \ + for ( dim_t i = 0; i < m; ++i ) \ { \ PASTEMAC(ch,copys)( *(ct + i*rs_ct + j*cs_ct), \ *(c + i*rs_c + j*cs_c ) ); \ @@ -190,8 +195,8 @@ PASTEMAC(chr,fprintm)( stdout, "gemm_ukr: c before", mr, nr, \ } \ else \ { \ - for ( j = 0; j < nr; ++j ) \ - for ( i = 0; i < mr; ++i ) \ + for ( dim_t j = 0; j < n; ++j ) \ + for ( dim_t i = 0; i < m; ++i ) \ { \ PASTEMAC(ch,xpbys)( *(ct + i*rs_ct + j*cs_ct), \ *beta, \ @@ -207,17 +212,19 @@ PASTEMAC(chr,fprintm)( stdout, "gemm_ukr: c before", mr, nr, \ c_use = ( ctype_r* )c; \ rs_c_use = rs_c; \ cs_c_use = cs_c; \ + m_use = m; \ + n_use = n; \ \ - /* Convert the strides from being in units of complex elements to - be in units of real elements. Note that we don't need to check for - general storage here because that case corresponds to the scenario - where we are using the ct buffer and its rs_ct/cs_ct strides. */ \ - if ( bli_is_col_stored( rs_c_use, cs_c_use ) ) cs_c_use *= 2; \ - else rs_c_use *= 2; \ + /* Convert the strides and corresponding microtile dimension from being + in units of complex elements to be in units of real elements. */ \ + if ( bli_is_col_stored( rs_c_use, cs_c_use ) ) { cs_c_use *= 2; m_use *= 2; } \ + else { rs_c_use *= 2; n_use *= 2; } \ \ /* c = beta * c + alpha_r * a * b; */ \ rgemm_ukr \ ( \ + m_use, \ + n_use, \ k, \ alpha_r, \ a_r, \ diff --git a/frame/3/gemm/bli_gemm_var.h b/frame/3/gemm/bli_gemm_var.h index e7befc5b46..888181bad6 100644 --- a/frame/3/gemm/bli_gemm_var.h +++ b/frame/3/gemm/bli_gemm_var.h @@ -34,6 +34,16 @@ */ +// +// gemm kernel parameter struct. +// + +typedef struct +{ + gemm_ukr_vft ukr; +} gemm_ker_params_t; + + // // Prototype object-based interfaces. // @@ -59,32 +69,3 @@ GENPROT( gemm_blk_var3 ) GENPROT( gemm_ker_var1 ) GENPROT( gemm_ker_var2 ) - -// -// Prototype BLAS-like interfaces with void pointer operands. 
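NOTE (illustrative, not part of the patch): gemm_ker_params_t is the hook through which a caller can route a custom microkernel into the stock macrokernel. bli_gemm_ker_var2() above queries bli_obj_ker_params( c ) and, when the ukr field is non-NULL, substitutes it for the context's virtual microkernel, just as bli_l3_cntl_create_if() now forwards bli_obj_ker_fn( c ) to override the macrokernel itself. A sketch of the caller side, assuming obj_t setters that mirror the accessors used by this patch (bli_obj_set_ker_fn/bli_obj_set_ker_params) and a user kernel my_dgemm_ukr with the new microkernel signature:

/* Route a user-supplied microkernel through the stock macrokernel. */
gemm_ker_params_t params;
params.ukr = ( gemm_ukr_vft )my_dgemm_ukr; /* hypothetical user kernel */

bli_obj_set_ker_params( &params, &c );     /* assumed setter mirroring bli_obj_ker_params() */
/* Or swap out the entire macrokernel instead: */
/* bli_obj_set_ker_fn( my_gemm_macrokernel, &c ); */

bli_gemm( &alpha, &a, &b, &beta, &c );     /* params must outlive this call */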
-// - -#undef GENTPROT -#define GENTPROT( ctype, ch, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - pack_t schema_a, \ - pack_t schema_b, \ - dim_t m, \ - dim_t n, \ - dim_t k, \ - void* alpha, \ - void* a, inc_t cs_a, inc_t is_a, \ - dim_t pd_a, inc_t ps_a, \ - void* b, inc_t rs_b, inc_t is_b, \ - dim_t pd_b, inc_t ps_b, \ - void* beta, \ - void* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntm, \ - thrinfo_t* thread \ - ); - -INSERT_GENTPROT_BASIC0( gemm_ker_var2 ) - diff --git a/frame/3/gemm/ind/bli_gemm_ind_opt.h b/frame/3/gemm/ind/bli_gemm_ind_opt.h index 7528c4f03e..52ea81a5e8 100644 --- a/frame/3/gemm/ind/bli_gemm_ind_opt.h +++ b/frame/3/gemm/ind/bli_gemm_ind_opt.h @@ -35,6 +35,7 @@ BLIS_INLINE void bli_gemm_ind_recast_1m_params ( num_t* dt_exec, + num_t* dt_c, pack_t schema_a, obj_t* c, dim_t* m, @@ -57,6 +58,7 @@ BLIS_INLINE void bli_gemm_ind_recast_1m_params !bli_is_gen_stored( *rs_c, *cs_c ) ) { *dt_exec = bli_dt_proj_to_real( *dt_exec ); + *dt_c = bli_dt_proj_to_real( *dt_c ); if ( bli_is_1e_packed( schema_a ) ) { diff --git a/frame/3/gemmt/bli_gemmt_l_ker_var2.c b/frame/3/gemmt/bli_gemmt_l_ker_var2.c index a995e6c521..fea4efec0a 100644 --- a/frame/3/gemmt/bli_gemmt_l_ker_var2.c +++ b/frame/3/gemmt/bli_gemmt_l_ker_var2.c @@ -279,6 +279,9 @@ void PASTEMAC(ch,varname) \ /* Save the imaginary stride of A and B to the auxinfo_t object. */ \ bli_auxinfo_set_is_a( is_a, &aux ); \ bli_auxinfo_set_is_b( is_b, &aux ); \ +\ + /* Save the desired output datatype (indicating no typecasting). */ \ + /*bli_auxinfo_set_dt_on_output( dt, &aux );*/ \ \ /* The 'thread' argument points to the thrinfo_t node for the 2nd (jr) loop around the microkernel. Here we query the thrinfo_t node for the @@ -381,43 +384,20 @@ void PASTEMAC(ch,varname) \ And if we're strictly above the diagonal, we do nothing and continue. */ \ { \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - beta_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale the edge of C and add the result. */ \ - PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_cast, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + beta_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ } \ } \ } \ @@ -490,6 +470,8 @@ void PASTEMAC(ch,varname) \ /* Invoke the gemm micro-kernel. */ \ gemm_ukr \ ( \ + MR, \ + NR, \ k, \ alpha_cast, \ a1, \ @@ -509,43 +491,20 @@ void PASTEMAC(ch,varname) \ } \ else if ( bli_is_strictly_below_diag_n( diagoffc_ij, m_cur, n_cur ) ) \ { \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - beta_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale the edge of C and add the result. */ \ - PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_cast, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. 
*/ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + beta_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ } \ } \ } \ diff --git a/frame/3/gemmt/bli_gemmt_u_ker_var2.c b/frame/3/gemmt/bli_gemmt_u_ker_var2.c index 3115fc67b5..4b849bbc6d 100644 --- a/frame/3/gemmt/bli_gemmt_u_ker_var2.c +++ b/frame/3/gemmt/bli_gemmt_u_ker_var2.c @@ -281,6 +281,9 @@ void PASTEMAC(ch,varname) \ /* Save the imaginary stride of A and B to the auxinfo_t object. */ \ bli_auxinfo_set_is_a( is_a, &aux ); \ bli_auxinfo_set_is_b( is_b, &aux ); \ +\ + /* Save the desired output datatype (indicating no typecasting). */ \ + /*bli_auxinfo_set_dt_on_output( dt, &aux );*/ \ \ /* The 'thread' argument points to the thrinfo_t node for the 2nd (jr) loop around the microkernel. Here we query the thrinfo_t node for the @@ -385,6 +388,8 @@ void PASTEMAC(ch,varname) \ /* Invoke the gemm micro-kernel. */ \ gemm_ukr \ ( \ + MR, \ + NR, \ k, \ alpha_cast, \ a1, \ @@ -404,43 +409,20 @@ void PASTEMAC(ch,varname) \ } \ else if ( bli_is_strictly_above_diag_n( diagoffc_ij, m_cur, n_cur ) ) \ { \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - beta_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale the edge of C and add the result. */ \ - PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_cast, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + beta_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ } \ } \ } \ @@ -512,43 +494,20 @@ void PASTEMAC(ch,varname) \ And if we're strictly below the diagonal, we do nothing and continue. */ \ { \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - beta_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale the edge of C and add the result. */ \ - PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_cast, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + beta_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ } \ } \ } \ diff --git a/frame/3/trmm/bli_trmm_ll_ker_var2.c b/frame/3/trmm/bli_trmm_ll_ker_var2.c index 792281b530..646287f931 100644 --- a/frame/3/trmm/bli_trmm_ll_ker_var2.c +++ b/frame/3/trmm/bli_trmm_ll_ker_var2.c @@ -167,20 +167,8 @@ void PASTEMAC(ch,varname) \ function pointer type. */ \ PASTECH(ch,gemm_ukr_ft) \ gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ -\ - /* Temporary C buffer for edge cases. Note that the strides of this - temporary buffer are set so that they match the storage of the - original C matrix. For example, if C is column-stored, ct will be - column-stored as well. 
*/ \ - ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ - const inc_t rs_ct = ( col_pref ? 1 : NR ); \ - const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ ctype* restrict one = PASTEMAC(ch,1); \ - ctype* restrict zero = PASTEMAC(ch,0); \ ctype* restrict a_cast = a; \ ctype* restrict b_cast = b; \ ctype* restrict c_cast = c; \ @@ -254,10 +242,6 @@ void PASTEMAC(ch,varname) \ diagoffa = 0; \ c_cast = c_cast + (i )*rs_c; \ } \ -\ - /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ - PASTEMAC(ch,set0s_mxn)( MR, NR, \ - ct, rs_ct, cs_ct ); \ \ /* Compute number of primary and leftover components of the m and n dimensions. */ \ @@ -307,8 +291,8 @@ void PASTEMAC(ch,varname) \ dim_t jr_inc; \ \ /* Determine the thread range and increment for the 2nd loop. - NOTE: The definition of bli_thread_range_jrir() will depend on whether - slab or round-robin partitioning was requested at configure-time. \ + NOTE: The definition of bli_thread_range_jrir() will depend on whether + slab or round-robin partitioning was requested at configure-time. \ NOTE: Parallelism in the 1st loop is disabled for now. */ \ bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \ /*bli_thread_range_jrir_rr( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc );*/ \ @@ -379,47 +363,20 @@ void PASTEMAC(ch,varname) \ bli_auxinfo_set_next_a( a2, &aux ); \ bli_auxinfo_set_next_b( b2, &aux ); \ \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k_a1011, \ - alpha_cast, \ - a1, \ - b1_i, \ - beta_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Copy edge elements of C to the temporary buffer. */ \ - PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \ - c11, rs_c, cs_c, \ - ct, rs_ct, cs_ct ); \ -\ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k_a1011, \ - alpha_cast, \ - a1, \ - b1_i, \ - beta_cast, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Copy the result to the edge of C. */ \ - PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k_a1011, \ + alpha_cast, \ + a1, \ + b1_i, \ + beta_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ /*}*/ \ \ a1 += ps_a_cur; \ @@ -446,42 +403,20 @@ void PASTEMAC(ch,varname) \ bli_auxinfo_set_next_a( a2, &aux ); \ bli_auxinfo_set_next_b( b2, &aux ); \ \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - one, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Add the result to the edge of C. */ \ - PASTEMAC(ch,adds_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. 
*/ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + one, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ /*}*/ \ \ a1 += rstep_a; \ diff --git a/frame/3/trmm/bli_trmm_lu_ker_var2.c b/frame/3/trmm/bli_trmm_lu_ker_var2.c index 69498540b7..9ef2a475de 100644 --- a/frame/3/trmm/bli_trmm_lu_ker_var2.c +++ b/frame/3/trmm/bli_trmm_lu_ker_var2.c @@ -167,20 +167,8 @@ void PASTEMAC(ch,varname) \ function pointer type. */ \ PASTECH(ch,gemm_ukr_ft) \ gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ -\ - /* Temporary C buffer for edge cases. Note that the strides of this - temporary buffer are set so that they match the storage of the - original C matrix. For example, if C is column-stored, ct will be - column-stored as well. */ \ - ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ - const inc_t rs_ct = ( col_pref ? 1 : NR ); \ - const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ ctype* restrict one = PASTEMAC(ch,1); \ - ctype* restrict zero = PASTEMAC(ch,0); \ ctype* restrict a_cast = a; \ ctype* restrict b_cast = b; \ ctype* restrict c_cast = c; \ @@ -261,10 +249,6 @@ void PASTEMAC(ch,varname) \ { \ m = -diagoffa + k; \ } \ -\ - /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ - PASTEMAC(ch,set0s_mxn)( MR, NR, \ - ct, rs_ct, cs_ct ); \ \ /* Compute number of primary and leftover components of the m and n dimensions. */ \ @@ -386,47 +370,20 @@ void PASTEMAC(ch,varname) \ bli_auxinfo_set_next_a( a2, &aux ); \ bli_auxinfo_set_next_b( b2, &aux ); \ \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k_a1112, \ - alpha_cast, \ - a1, \ - b1_i, \ - beta_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Copy edge elements of C to the temporary buffer. */ \ - PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \ - c11, rs_c, cs_c, \ - ct, rs_ct, cs_ct ); \ -\ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k_a1112, \ - alpha_cast, \ - a1, \ - b1_i, \ - beta_cast, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Copy the result to the edge of C. */ \ - PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k_a1112, \ + alpha_cast, \ + a1, \ + b1_i, \ + beta_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ /*}*/ \ \ a1 += ps_a_cur; \ @@ -453,42 +410,20 @@ void PASTEMAC(ch,varname) \ bli_auxinfo_set_next_a( a2, &aux ); \ bli_auxinfo_set_next_b( b2, &aux ); \ \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - one, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Add the result to the edge of C. */ \ - PASTEMAC(ch,adds_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. 
*/ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + one, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ /*}*/ \ \ a1 += rstep_a; \ diff --git a/frame/3/trmm/bli_trmm_rl_ker_var2.c b/frame/3/trmm/bli_trmm_rl_ker_var2.c index 03e3f1e531..f6b20af2e5 100644 --- a/frame/3/trmm/bli_trmm_rl_ker_var2.c +++ b/frame/3/trmm/bli_trmm_rl_ker_var2.c @@ -167,20 +167,8 @@ void PASTEMAC(ch,varname) \ function pointer type. */ \ PASTECH(ch,gemm_ukr_ft) \ gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ -\ - /* Temporary C buffer for edge cases. Note that the strides of this - temporary buffer are set so that they match the storage of the - original C matrix. For example, if C is column-stored, ct will be - column-stored as well. */ \ - ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ - const inc_t rs_ct = ( col_pref ? 1 : NR ); \ - const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ ctype* restrict one = PASTEMAC(ch,1); \ - ctype* restrict zero = PASTEMAC(ch,0); \ ctype* restrict a_cast = a; \ ctype* restrict b_cast = b; \ ctype* restrict c_cast = c; \ @@ -261,10 +249,6 @@ void PASTEMAC(ch,varname) \ { \ n = diagoffb + k; \ } \ -\ - /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ - PASTEMAC(ch,set0s_mxn)( MR, NR, \ - ct, rs_ct, cs_ct ); \ \ /* Compute number of primary and leftover components of the m and n dimensions. */ \ @@ -335,9 +319,9 @@ void PASTEMAC(ch,varname) \ \ /* Determine the thread range and increment for the 2nd and 1st loops for the initial rectangular region of B (if it exists). - NOTE: The definition of bli_thread_range_jrir() will depend on whether - slab or round-robin partitioning was requested at configure-time. \ - NOTE: Parallelism in the 1st loop is disabled for now. */ \ + NOTE: The definition of bli_thread_range_jrir() will depend on whether + slab or round-robin partitioning was requested at configure-time. \ + NOTE: Parallelism in the 1st loop is disabled for now. */ \ bli_thread_range_jrir( thread, n_iter_rct, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \ bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); \ \ @@ -382,42 +366,20 @@ void PASTEMAC(ch,varname) \ bli_auxinfo_set_next_a( a2, &aux ); \ bli_auxinfo_set_next_b( b2, &aux ); \ \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - one, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Add the result to the edge of C. */ \ - PASTEMAC(ch,adds_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + one, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ } \ } \ } \ @@ -501,47 +463,20 @@ void PASTEMAC(ch,varname) \ bli_auxinfo_set_next_a( a2, &aux ); \ bli_auxinfo_set_next_b( b2, &aux ); \ \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. 
*/ \ - gemm_ukr \ - ( \ - k_b1121, \ - alpha_cast, \ - a1_i, \ - b1, \ - beta_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Copy edge elements of C to the temporary buffer. */ \ - PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \ - c11, rs_c, cs_c, \ - ct, rs_ct, cs_ct ); \ -\ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k_b1121, \ - alpha_cast, \ - a1_i, \ - b1, \ - beta_cast, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Copy the result to the edge of C. */ \ - PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k_b1121, \ + alpha_cast, \ + a1_i, \ + b1, \ + beta_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ } \ \ a1 += rstep_a; \ diff --git a/frame/3/trmm/bli_trmm_ru_ker_var2.c b/frame/3/trmm/bli_trmm_ru_ker_var2.c index 5d63bd46df..f71fb3c4d8 100644 --- a/frame/3/trmm/bli_trmm_ru_ker_var2.c +++ b/frame/3/trmm/bli_trmm_ru_ker_var2.c @@ -167,20 +167,8 @@ void PASTEMAC(ch,varname) \ function pointer type. */ \ PASTECH(ch,gemm_ukr_ft) \ gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ -\ - /* Temporary C buffer for edge cases. Note that the strides of this - temporary buffer are set so that they match the storage of the - original C matrix. For example, if C is column-stored, ct will be - column-stored as well. */ \ - ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ - const inc_t rs_ct = ( col_pref ? 1 : NR ); \ - const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ ctype* restrict one = PASTEMAC(ch,1); \ - ctype* restrict zero = PASTEMAC(ch,0); \ ctype* restrict a_cast = a; \ ctype* restrict b_cast = b; \ ctype* restrict c_cast = c; \ @@ -262,10 +250,6 @@ void PASTEMAC(ch,varname) \ { \ k = -diagoffb + n; \ } \ -\ - /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ - PASTEMAC(ch,set0s_mxn)( MR, NR, \ - ct, rs_ct, cs_ct ); \ \ /* Compute number of primary and leftover components of the m and n dimensions. */ \ @@ -410,47 +394,20 @@ void PASTEMAC(ch,varname) \ bli_auxinfo_set_next_a( a2, &aux ); \ bli_auxinfo_set_next_b( b2, &aux ); \ \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k_b0111, \ - alpha_cast, \ - a1_i, \ - b1, \ - beta_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Copy edge elements of C to the temporary buffer. */ \ - PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \ - c11, rs_c, cs_c, \ - ct, rs_ct, cs_ct ); \ -\ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k_b0111, \ - alpha_cast, \ - a1_i, \ - b1, \ - beta_cast, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Copy the result to the edge of C. */ \ - PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. 
*/ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k_b0111, \ + alpha_cast, \ + a1_i, \ + b1, \ + beta_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ } \ \ a1 += rstep_a; \ @@ -476,9 +433,9 @@ void PASTEMAC(ch,varname) \ bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); \ \ /* Advance the start and end iteration offsets for the rectangular region - by the number of iterations used for the triangular region. */ \ - jr_start += n_iter_tri; \ - jr_end += n_iter_tri; \ + by the number of iterations used for the triangular region. */ \ + jr_start += n_iter_tri; \ + jr_end += n_iter_tri; \ jb0 = n_iter_tri; \ \ /* Save the resulting value of b1 from the previous loop since it represents @@ -496,7 +453,7 @@ void PASTEMAC(ch,varname) \ the starting address of the rectangular region (which is already n_iter_tri logical iterations through B). */ \ b1 = b_cast + (j-jb0) * cstep_b; \ - c1 = c_cast + j * cstep_c; \ + c1 = c_cast + j * cstep_c; \ \ n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \ \ @@ -533,42 +490,20 @@ void PASTEMAC(ch,varname) \ bli_auxinfo_set_next_a( a2, &aux ); \ bli_auxinfo_set_next_b( b2, &aux ); \ \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - one, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Add the result to the edge of C. */ \ - PASTEMAC(ch,adds_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + one, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ } \ } \ } \ diff --git a/frame/3/trsm/bli_trsm_cntl.c b/frame/3/trsm/bli_trsm_cntl.c index a8196ebb93..0a3be87f74 100644 --- a/frame/3/trsm/bli_trsm_cntl.c +++ b/frame/3/trsm/bli_trsm_cntl.c @@ -40,27 +40,30 @@ cntl_t* bli_trsm_cntl_create rntm_t* rntm, side_t side, pack_t schema_a, - pack_t schema_b + pack_t schema_b, + void_fp ker ) { if ( bli_is_left( side ) ) - return bli_trsm_l_cntl_create( rntm, schema_a, schema_b ); + return bli_trsm_l_cntl_create( rntm, schema_a, schema_b, ker ); else - return bli_trsm_r_cntl_create( rntm, schema_a, schema_b ); + return bli_trsm_r_cntl_create( rntm, schema_a, schema_b, ker ); } cntl_t* bli_trsm_l_cntl_create ( rntm_t* rntm, pack_t schema_a, - pack_t schema_b + pack_t schema_b, + void_fp ker ) { void_fp macro_kernel_p; - // Use the function pointer to the macrokernels that use slab - // assignment of micropanels to threads in the jr and ir loops. + // Set the default macrokernel. If a non-NULL kernel function pointer is + // passed in, we use that instead. macro_kernel_p = bli_trsm_xx_ker_var2; + if ( ker ) macro_kernel_p = ker; const opid_t family = BLIS_TRSM; @@ -202,11 +205,15 @@ cntl_t* bli_trsm_r_cntl_create ( rntm_t* rntm, pack_t schema_a, - pack_t schema_b + pack_t schema_b, + void_fp ker ) { // NOTE: trsm macrokernels are presently disabled for right-side execution. + // Set the default macrokernel. If a non-NULL kernel function pointer is + // passed in, we use that instead. 
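A usage sketch of this new hook: a caller can now substitute its own trsm macrokernel by passing its address through cntl creation, or pass NULL to keep the default. Here my_trsm_ker_var2 is a hypothetical function with the same signature as bli_trsm_xx_ker_var2; the create-call signature is the one introduced above.

    cntl_t* cntl_default = bli_trsm_cntl_create( rntm, BLIS_LEFT,
                                                 schema_a, schema_b, NULL );
    cntl_t* cntl_custom  = bli_trsm_cntl_create( rntm, BLIS_LEFT,
                                                 schema_a, schema_b,
                                                 ( void_fp )my_trsm_ker_var2 );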
void_fp macro_kernel_p = bli_trsm_xx_ker_var2; + if ( ker ) macro_kernel_p = ker; const opid_t family = BLIS_TRSM; diff --git a/frame/3/trsm/bli_trsm_cntl.h b/frame/3/trsm/bli_trsm_cntl.h index 7fdb1fc4f6..86f4a29b2a 100644 --- a/frame/3/trsm/bli_trsm_cntl.h +++ b/frame/3/trsm/bli_trsm_cntl.h @@ -38,21 +38,24 @@ cntl_t* bli_trsm_cntl_create rntm_t* rntm, side_t side, pack_t schema_a, - pack_t schema_b + pack_t schema_b, + void_fp ker ); cntl_t* bli_trsm_l_cntl_create ( rntm_t* rntm, pack_t schema_a, - pack_t schema_b + pack_t schema_b, + void_fp ker ); cntl_t* bli_trsm_r_cntl_create ( rntm_t* rntm, pack_t schema_a, - pack_t schema_b + pack_t schema_b, + void_fp ker ); void bli_trsm_cntl_free diff --git a/frame/3/trsm/bli_trsm_ll_ker_var2.c b/frame/3/trsm/bli_trsm_ll_ker_var2.c index dec41301ac..b503efa5bf 100644 --- a/frame/3/trsm/bli_trsm_ll_ker_var2.c +++ b/frame/3/trsm/bli_trsm_ll_ker_var2.c @@ -183,7 +183,6 @@ void PASTEMAC(ch,varname) \ const inc_t rs_ct = ( col_pref ? 1 : NR ); \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ - ctype* restrict zero = PASTEMAC(ch,0); \ ctype* restrict minus_one = PASTEMAC(ch,m1); \ ctype* restrict a_cast = a; \ ctype* restrict b_cast = b; \ @@ -470,43 +469,20 @@ void PASTEMAC(ch,varname) \ bli_auxinfo_set_next_a( a2, &aux ); \ bli_auxinfo_set_next_b( b2, &aux ); \ \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - minus_one, \ - a1, \ - b1, \ - alpha2_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - minus_one, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Add the result to the edge of C. */ \ - PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - alpha2_cast, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k, \ + minus_one, \ + a1, \ + b1, \ + alpha2_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ \ a1 += rstep_a; \ } \ diff --git a/frame/3/trsm/bli_trsm_lu_ker_var2.c b/frame/3/trsm/bli_trsm_lu_ker_var2.c index 1627a12a39..55ceafb91d 100644 --- a/frame/3/trsm/bli_trsm_lu_ker_var2.c +++ b/frame/3/trsm/bli_trsm_lu_ker_var2.c @@ -183,7 +183,6 @@ void PASTEMAC(ch,varname) \ const inc_t rs_ct = ( col_pref ? 1 : NR ); \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ - ctype* restrict zero = PASTEMAC(ch,0); \ ctype* restrict minus_one = PASTEMAC(ch,m1); \ ctype* restrict a_cast = a; \ ctype* restrict b_cast = b; \ @@ -480,43 +479,20 @@ void PASTEMAC(ch,varname) \ bli_auxinfo_set_next_a( a2, &aux ); \ bli_auxinfo_set_next_b( b2, &aux ); \ \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - minus_one, \ - a1, \ - b1, \ - alpha2_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - minus_one, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Add the result to the edge of C. */ \ - PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - alpha2_cast, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. 
*/ \
+					gemm_ukr \
+					( \
+					  m_cur, \
+					  n_cur, \
+					  k, \
+					  minus_one, \
+					  a1, \
+					  b1, \
+					  alpha2_cast, \
+					  c11, rs_c, cs_c, \
+					  &aux, \
+					  cntx \
+					); \
 \
 			a1 += rstep_a; \
 		} \
diff --git a/frame/3/trsm/bli_trsm_rl_ker_var2.c b/frame/3/trsm/bli_trsm_rl_ker_var2.c
index 8cbc26b36a..23d4dd7289 100644
--- a/frame/3/trsm/bli_trsm_rl_ker_var2.c
+++ b/frame/3/trsm/bli_trsm_rl_ker_var2.c
@@ -188,7 +188,6 @@ void PASTEMAC(ch,varname) \
 	const inc_t rs_ct = ( col_pref ? 1 : NR ); \
 	const inc_t cs_ct = ( col_pref ? MR : 1 ); \
 \
-	ctype* restrict zero = PASTEMAC(ch,0); \
 	ctype* restrict minus_one = PASTEMAC(ch,m1); \
 	ctype* restrict a_cast = a; \
 	ctype* restrict b_cast = b; \
@@ -499,43 +498,20 @@ void PASTEMAC(ch,varname) \
 				bli_auxinfo_set_next_a( b2, &aux ); \
 				bli_auxinfo_set_next_b( a2, &aux ); \
 \
-				/* Handle interior and edge cases separately. */ \
-				if ( m_cur == MR && n_cur == NR ) \
-				{ \
-					/* Invoke the gemm micro-kernel. */ \
-					gemm_ukr \
-					( \
-					  k, \
-					  minus_one, \
-					  b1, \
-					  a1, \
-					  alpha2_cast, \
-					  c11, cs_c, rs_c, \
-					  &aux, \
-					  cntx \
-					); \
-				} \
-				else \
-				{ \
-					/* Invoke the gemm micro-kernel. */ \
-					gemm_ukr \
-					( \
-					  k, \
-					  minus_one, \
-					  b1, \
-					  a1, \
-					  zero, \
-					  ct, cs_ct, rs_ct, \
-					  &aux, \
-					  cntx \
-					); \
-\
-					/* Add the result to the edge of C. */ \
-					PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \
-					                        ct, rs_ct, cs_ct, \
-					                        alpha2_cast, \
-					                        c11, rs_c, cs_c ); \
-				} \
+				/* Invoke the gemm micro-kernel. Since the operands and the
+				   strides of C are swapped for this right-side variant, the
+				   m and n dimensions must be passed in swapped order too. */ \
+				gemm_ukr \
+				( \
+				  n_cur, \
+				  m_cur, \
+				  k, \
+				  minus_one, \
+				  b1, \
+				  a1, \
+				  alpha2_cast, \
+				  c11, cs_c, rs_c, \
+				  &aux, \
+				  cntx \
+				); \
 			} \
 \
 			a1 += rstep_a; \
diff --git a/frame/3/trsm/bli_trsm_ru_ker_var2.c b/frame/3/trsm/bli_trsm_ru_ker_var2.c
index 97399d0ae0..71381707c4 100644
--- a/frame/3/trsm/bli_trsm_ru_ker_var2.c
+++ b/frame/3/trsm/bli_trsm_ru_ker_var2.c
@@ -188,7 +188,6 @@ void PASTEMAC(ch,varname) \
 	const inc_t rs_ct = ( col_pref ? 1 : NR ); \
 	const inc_t cs_ct = ( col_pref ? MR : 1 ); \
 \
-	ctype* restrict zero = PASTEMAC(ch,0); \
 	ctype* restrict minus_one = PASTEMAC(ch,m1); \
 	ctype* restrict a_cast = a; \
 	ctype* restrict b_cast = b; \
@@ -492,43 +491,20 @@ void PASTEMAC(ch,varname) \
 				bli_auxinfo_set_next_a( b2, &aux ); \
 				bli_auxinfo_set_next_b( a2, &aux ); \
 \
-				/* Handle interior and edge cases separately. */ \
-				if ( m_cur == MR && n_cur == NR ) \
-				{ \
-					/* Invoke the gemm micro-kernel. */ \
-					gemm_ukr \
-					( \
-					  k, \
-					  minus_one, \
-					  b1, \
-					  a1, \
-					  alpha2_cast, \
-					  c11, cs_c, rs_c, \
-					  &aux, \
-					  cntx \
-					); \
-				} \
-				else \
-				{ \
-					/* Invoke the gemm micro-kernel. */ \
-					gemm_ukr \
-					( \
-					  k, \
-					  minus_one, \
-					  b1, \
-					  a1, \
-					  zero, \
-					  ct, cs_ct, rs_ct, \
-					  &aux, \
-					  cntx \
-					); \
-\
-					/* Add the result to the edge of C. */ \
-					PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \
-					                        ct, rs_ct, cs_ct, \
-					                        alpha2_cast, \
-					                        c11, rs_c, cs_c ); \
-				} \
+				/* Invoke the gemm micro-kernel. As in the rl variant, the
+				   operands and the strides of C are swapped, so the m and n
+				   dimensions are swapped as well. */ \
+				gemm_ukr \
+				( \
+				  n_cur, \
+				  m_cur, \
+				  k, \
+				  minus_one, \
+				  b1, \
+				  a1, \
+				  alpha2_cast, \
+				  c11, cs_c, rs_c, \
+				  &aux, \
+				  cntx \
+				); \
 			} \
 \
 			a1 += rstep_a; \
diff --git a/frame/base/bli_auxinfo.h b/frame/base/bli_auxinfo.h
index 68b6cc7cd6..d8c6cbb13f 100644
--- a/frame/base/bli_auxinfo.h
+++ b/frame/base/bli_auxinfo.h
@@ -74,6 +74,15 @@ BLIS_INLINE inc_t bli_auxinfo_ps_b( auxinfo_t* ai )
 	return ai->ps_b;
 }
 
+BLIS_INLINE void_fp bli_auxinfo_ukr( auxinfo_t* ai )
+{
+	return ai->ukr;
+}
+BLIS_INLINE void* bli_auxinfo_params( auxinfo_t* ai )
+{
+	return ai->params;
+}
+
 
 // auxinfo_t field modification
 
@@ -118,5 +127,14 @@ BLIS_INLINE void bli_auxinfo_set_ps_b( inc_t ps, auxinfo_t* ai )
 	ai->ps_b = ps;
 }
 
-#endif
+BLIS_INLINE void bli_auxinfo_set_ukr( void_fp ukr, auxinfo_t* ai )
+{
+	ai->ukr = ukr;
+}
+BLIS_INLINE void bli_auxinfo_set_params( void* params, auxinfo_t* ai )
+{
+	ai->params = params;
+}
+
+#endif
diff --git a/frame/include/bli_edge_case_macro_defs.h b/frame/include/bli_edge_case_macro_defs.h
new file mode 100644
index 0000000000..242045a029
--- /dev/null
+++ b/frame/include/bli_edge_case_macro_defs.h
@@ -0,0 +1,109 @@
+/*
+
+   BLIS
+   An object-based framework for developing high-performance BLAS-like
+   libraries.
+
+   Copyright (C) 2021, The University of Texas at Austin
+
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions are
+   met:
+    - Redistributions of source code must retain the above copyright
+      notice, this list of conditions and the following disclaimer.
+    - Redistributions in binary form must reproduce the above copyright
+      notice, this list of conditions and the following disclaimer in the
+      documentation and/or other materials provided with the distribution.
+    - Neither the name(s) of the copyright holder(s) nor the names of its
+      contributors may be used to endorse or promote products derived
+      from this software without specific prior written permission.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+   HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+*/
+
+#ifndef BLIS_EDGE_CASE_MACRO_DEFS_H
+#define BLIS_EDGE_CASE_MACRO_DEFS_H
+
+
+// Helper macros for edge-case handling within gemm microkernels.
+
+#define GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major) \
+\
+	PASTEMAC(ch,ctype)* restrict _beta = beta; \
+	PASTEMAC(ch,ctype)* restrict _c = c; \
+	const inc_t _rs_c = rs_c; \
+	const inc_t _cs_c = cs_c; \
+	PASTEMAC(ch,ctype) _ct[ BLIS_STACK_BUF_MAX_SIZE / sizeof( PASTEMAC(ch,ctype) ) ] \
+		__attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \
+	const inc_t _rs_ct = row_major ? nr : 1; \
+	const inc_t _cs_ct = row_major ?
1 : mr; + +#define GEMM_UKR_SETUP_CT_POST(ch) \ +\ + PASTEMAC(ch,ctype) _zero; \ + PASTEMAC(ch,set0s)( _zero ); \ + \ + if ( _use_ct ) \ + { \ + c = _ct; \ + rs_c = _rs_ct; \ + cs_c = _cs_ct; \ + beta = &_zero; \ + } + +#define GEMM_UKR_SETUP_CT(ch,mr,nr,row_major) \ +\ + GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major); \ + const bool _use_ct = ( row_major ? cs_c != 1 : rs_c != 1 ) || \ + m != mr || n != nr; \ + GEMM_UKR_SETUP_CT_POST(ch); + +#define GEMM_UKR_SETUP_CT_AMBI(ch,mr,nr,row_major) \ +\ + GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major); \ + const bool _use_ct = ( cs_c != 1 && rs_c != 1 ) || \ + m != mr || n != nr; \ + GEMM_UKR_SETUP_CT_POST(ch); + +#define GEMM_UKR_SETUP_CT_ANY(ch,mr,nr,row_major) \ +\ + GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major); \ + const bool _use_ct = m != mr || n != nr; \ + GEMM_UKR_SETUP_CT_POST(ch); + +#define GEMM_UKR_SETUP_CT_ALIGNED(ch,mr,nr,row_major,alignment) \ +\ + GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major); \ + const bool _use_ct = ( row_major ? cs_c != 1 : rs_c != 1 ) || \ + m != mr || n != nr || \ + ( (uintptr_t)_c % alignment ) || \ + ( ( ( row_major ? _rs_c : _cs_c )*sizeof( PASTEMAC(ch,ctype) ) ) % alignment ); \ + GEMM_UKR_SETUP_CT_POST(ch); + +#define GEMM_UKR_FLUSH_CT(ch) \ +\ + if ( _use_ct ) \ + { \ + PASTEMAC(ch,xpbys_mxn) \ + ( \ + m, n, \ + _ct, _rs_ct, _cs_ct, \ + _beta, \ + _c, _rs_c, _cs_c \ + ); \ + } \ + + +#endif + diff --git a/frame/include/bli_macro_defs.h b/frame/include/bli_macro_defs.h index 03451d4407..be45a12e3f 100644 --- a/frame/include/bli_macro_defs.h +++ b/frame/include/bli_macro_defs.h @@ -98,6 +98,7 @@ #include "bli_gentprot_macro_defs.h" #include "bli_misc_macro_defs.h" +#include "bli_edge_case_macro_defs.h" #include "bli_param_macro_defs.h" #include "bli_obj_macro_defs.h" #include "bli_complex_macro_defs.h" diff --git a/frame/include/bli_type_defs.h b/frame/include/bli_type_defs.h index 5be0ceeb42..c66505bde8 100644 --- a/frame/include/bli_type_defs.h +++ b/frame/include/bli_type_defs.h @@ -1144,6 +1144,13 @@ typedef struct inc_t ps_a; inc_t ps_b; + // The type to convert to on output. + //num_t dt_on_output; + + // (Virtual) microkernel address and additional parameters. + void_fp ukr; + void* params; + } auxinfo_t; diff --git a/kernels/armsve/3/bli_gemm_armsve_asm_c2vx10_unindexed.c b/kernels/armsve/3/bli_gemm_armsve_asm_c2vx10_unindexed.c index 66337e0b73..913abd1f6c 100644 --- a/kernels/armsve/3/bli_gemm_armsve_asm_c2vx10_unindexed.c +++ b/kernels/armsve/3/bli_gemm_armsve_asm_c2vx10_unindexed.c @@ -42,9 +42,13 @@ // 2vx10 microkernels. #include "armsve_asm_2vx10cmplx.h" +#include "arm_sve.h" + void bli_cgemm_armsve_asm_2vx10_unindexed ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, scomplex* restrict alpha, scomplex* restrict a, scomplex* restrict b, @@ -59,12 +63,15 @@ void bli_cgemm_armsve_asm_2vx10_unindexed // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. 
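The CT macros above assume they expand inside a microkernel whose parameters are literally named m, n, beta, c, rs_c, and cs_c, since they capture and then rebind those names. A minimal usage sketch of the intended pattern (the kernel name and the 8x6 shape are hypothetical; the structure mirrors how the kernels below adopt the macros):

    void bli_dgemm_example_8x6      // hypothetical kernel, MR = 8, NR = 6
         (
           dim_t m, dim_t n, dim_t k,
           double* restrict alpha, double* restrict a, double* restrict b,
           double* restrict beta, double* restrict c, inc_t rs_c, inc_t cs_c,
           auxinfo_t* restrict data, cntx_t* restrict cntx
         )
    {
        // Rebinds c/rs_c/cs_c/beta to an aligned 8x6 stack tile (with beta
        // set to zero) whenever m != 8, n != 6, or C is not column-stored
        // (row_major == false selects the rs_c != 1 test).
        GEMM_UKR_SETUP_CT( d, 8, 6, false );

        // ... the optimized 8x6 column-major loop over k writes to c,
        //     rs_c, cs_c exactly as before ...

        // If the tile was used, merge it into the real C: c := beta*c + ct.
        GEMM_UKR_FLUSH_CT( d );
    }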
- uint64_t k_mker = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_mker = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; uint64_t info = 0; + uint64_t mr = svcntw(); + GEMM_UKR_SETUP_CT( c, mr, 10, false ); + __asm__ volatile ( // " ldr x0, %[a] \n\t" // " ldr x1, %[b] \n\t" @@ -310,5 +317,7 @@ GEMM_CCMPLX_STORE_COL2_G(z8 ,z9 ,z10,z11,p0,z28,%2,%4,x16) "z24","z25","z26","z27", "z28","z29","z30","z31" ); + + GEMM_UKR_FLUSH_CT( c ); } diff --git a/kernels/armsve/3/bli_gemm_armsve_asm_d2vx10_unindexed.c b/kernels/armsve/3/bli_gemm_armsve_asm_d2vx10_unindexed.c index e5b78a5921..9730fb8ce3 100644 --- a/kernels/armsve/3/bli_gemm_armsve_asm_d2vx10_unindexed.c +++ b/kernels/armsve/3/bli_gemm_armsve_asm_d2vx10_unindexed.c @@ -42,9 +42,13 @@ // 2vx10 microkernels. #include "armsve_asm_2vx10.h" +#include "arm_sve.h" + void bli_dgemm_armsve_asm_2vx10_unindexed ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, double* restrict alpha, double* restrict a, double* restrict b, @@ -59,11 +63,14 @@ void bli_dgemm_armsve_asm_2vx10_unindexed // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_mker = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_mker = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + uint64_t mr = 2*svcntd(); + GEMM_UKR_SETUP_CT( d, mr, 10, false ); + __asm__ volatile ( " ldr x0, %[a] \n\t" " ldr x1, %[b] \n\t" @@ -324,5 +331,7 @@ GEMM_C_STORE_UKER_G(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,z30,p0,p0,x5,x7,x8,x "z24","z25","z26","z27", "z28","z29","z30","z31" ); + + GEMM_UKR_FLUSH_CT( d ); } diff --git a/kernels/armsve/3/bli_gemm_armsve_asm_s2vx10_unindexed.c b/kernels/armsve/3/bli_gemm_armsve_asm_s2vx10_unindexed.c index 00b3f20b44..74c4779d73 100644 --- a/kernels/armsve/3/bli_gemm_armsve_asm_s2vx10_unindexed.c +++ b/kernels/armsve/3/bli_gemm_armsve_asm_s2vx10_unindexed.c @@ -42,9 +42,13 @@ // 2vx10 microkernels. #include "armsve_asm_2vx10.h" +#include "arm_sve.h" + void bli_sgemm_armsve_asm_2vx10_unindexed ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, float* restrict alpha, float* restrict a, float* restrict b, @@ -59,11 +63,14 @@ void bli_sgemm_armsve_asm_2vx10_unindexed // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_mker = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_mker = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + uint64_t mr = 2*svcntw(); + GEMM_UKR_SETUP_CT( s, mr, 10, false ); + __asm__ volatile ( " ldr x0, %[a] \n\t" " ldr x1, %[b] \n\t" @@ -310,5 +317,7 @@ GEMM_C_STORE_UKER_G(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,z30,p0,p0,x5,x7,x8,x "z24","z25","z26","z27", "z28","z29","z30","z31" ); + + GEMM_UKR_FLUSH_CT( s ); } diff --git a/kernels/armsve/3/bli_gemm_armsve_asm_z2vx10_unindexed.c b/kernels/armsve/3/bli_gemm_armsve_asm_z2vx10_unindexed.c index 2fa37664ae..ee041b3c40 100644 --- a/kernels/armsve/3/bli_gemm_armsve_asm_z2vx10_unindexed.c +++ b/kernels/armsve/3/bli_gemm_armsve_asm_z2vx10_unindexed.c @@ -42,9 +42,13 @@ // 2vx10 microkernels. 
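Because the SVE vector length is implementation-defined, these kernels derive MR from the svcnt*() intrinsics at run time instead of hard-coding it. A small worked sketch of that arithmetic (a standalone fragment; the example values assume a 512-bit vector implementation):

    #include <arm_sve.h>
    #include <stdint.h>

    // MR for a "2v" kernel spans two SVE vectors of the element type.
    uint64_t mr_s = 2 * svcntw();   // 2 vectors of floats:  32 when VL = 512
    uint64_t mr_d = 2 * svcntd();   // 2 vectors of doubles: 16 when VL = 512

    // Each scomplex/dcomplex element occupies two lanes, so two vectors
    // hold only svcntw()/svcntd() complex elements, as the kernels compute.
    uint64_t mr_c = svcntw();       // 16 when VL = 512
    uint64_t mr_z = svcntd();       //  8 when VL = 512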
#include "armsve_asm_2vx10cmplx.h" +#include "arm_sve.h" + void bli_zgemm_armsve_asm_2vx10_unindexed ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, dcomplex* restrict alpha, dcomplex* restrict a, dcomplex* restrict b, @@ -59,12 +63,15 @@ void bli_zgemm_armsve_asm_2vx10_unindexed // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_mker = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_mker = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; uint64_t info = 0; + uint64_t mr = svcntd(); + GEMM_UKR_SETUP_CT( z, mr, 10, false ); + __asm__ volatile ( // " ldr x0, %[a] \n\t" // " ldr x1, %[b] \n\t" @@ -309,5 +316,7 @@ GEMM_CCMPLX_STORE_COL2_G(z8 ,z9 ,z10,z11,p0,z28,%2,%4,x16) "z24","z25","z26","z27", "z28","z29","z30","z31" ); + + GEMM_UKR_FLUSH_CT( z ); } diff --git a/kernels/armsve/3/bli_gemm_armsve_asm_z2vx7_unindexed.c b/kernels/armsve/3/bli_gemm_armsve_asm_z2vx7_unindexed.c index 3d25719d92..641944ecd4 100644 --- a/kernels/armsve/3/bli_gemm_armsve_asm_z2vx7_unindexed.c +++ b/kernels/armsve/3/bli_gemm_armsve_asm_z2vx7_unindexed.c @@ -42,9 +42,13 @@ // 2vx7 microkernels. #include "armsve_asm_2vx7cmplx.h" +#include "arm_sve.h" + void bli_zgemm_armsve_asm_2vx7_unindexed ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, dcomplex* restrict alpha, dcomplex* restrict a, dcomplex* restrict b, @@ -59,12 +63,15 @@ void bli_zgemm_armsve_asm_2vx7_unindexed // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_mker = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_mker = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; uint64_t info = 0; + uint64_t mr = svcntd(); + GEMM_UKR_SETUP_CT( z, mr, 7, false ); + __asm__ volatile ( // " ldr x0, %[a] \n\t" // " ldr x1, %[b] \n\t" @@ -261,6 +268,8 @@ GEMM_CCMPLX_STORE_COL7_G(z14,z15,z16,z17,z18,z19,z20,z21,z22,z23,z24,z25,z26,z27 "z24","z25","z26","z27", "z28","z29","z30","z31" ); + + GEMM_UKR_FLUSH_CT( z ); } diff --git a/kernels/armsve/3/bli_gemm_armsve_asm_z2vx8_unindexed.c b/kernels/armsve/3/bli_gemm_armsve_asm_z2vx8_unindexed.c index d0eef4a8ca..4272f72c02 100644 --- a/kernels/armsve/3/bli_gemm_armsve_asm_z2vx8_unindexed.c +++ b/kernels/armsve/3/bli_gemm_armsve_asm_z2vx8_unindexed.c @@ -42,9 +42,13 @@ // 2vx8 microkernels. #include "armsve_asm_2vx8cmplx.h" +#include "arm_sve.h" + void bli_zgemm_armsve_asm_2vx8_unindexed ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, dcomplex* restrict alpha, dcomplex* restrict a, dcomplex* restrict b, @@ -59,12 +63,15 @@ void bli_zgemm_armsve_asm_2vx8_unindexed // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. 
- uint64_t k_mker = k0 / 6; - uint64_t k_left = k0 % 6; + uint64_t k_mker = k / 6; + uint64_t k_left = k % 6; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; uint64_t info = 0; + uint64_t mr = svcntd(); + GEMM_UKR_SETUP_CT( z, mr, 8, false ); + __asm__ volatile ( // " ldr x0, %[a] \n\t" // " ldr x1, %[b] \n\t" @@ -286,5 +293,7 @@ GEMM_CCMPLX_STORE_COL2_G(z8 ,z9 ,z10,z11,p0,z16,%2,%4,x16) "z24","z25","z26","z27", "z28","z29","z30","z31" ); + + GEMM_UKR_FLUSH_CT( z ); } diff --git a/kernels/armv7a/3/bli_gemm_armv7a_asm_d4x4.c b/kernels/armv7a/3/bli_gemm_armv7a_asm_d4x4.c index b526cd0951..c248285c38 100644 --- a/kernels/armv7a/3/bli_gemm_armv7a_asm_d4x4.c +++ b/kernels/armv7a/3/bli_gemm_armv7a_asm_d4x4.c @@ -48,23 +48,23 @@ void bli_sgemm_armv7a_ker_4x4 void bli_sgemm_armv7a_asm_4x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, float* restrict alpha, float* restrict a, float* restrict b, float* restrict beta, - float* restrict c, inc_t rs_c0, inc_t cs_c0, + float* restrict c, inc_t rs_c, inc_t cs_c, auxinfo_t* restrict data, cntx_t* restrict cntx ) { // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint32_t k = k0; - uint32_t rs_c = rs_c0; - uint32_t cs_c = cs_c0; - + GEMM_UKR_SETUP_CT_ANY( s, 4, 4, false ); bli_sgemm_armv7a_ker_4x4( k, alpha, a, b, beta, c, rs_c, cs_c, data ); + GEMM_UKR_FLUSH_CT( s ); } @@ -83,23 +83,23 @@ void bli_dgemm_armv7a_ker_4x4 void bli_dgemm_armv7a_asm_4x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, double* restrict alpha, double* restrict a, double* restrict b, double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict c, inc_t rs_c, inc_t cs_c, auxinfo_t* restrict data, cntx_t* restrict cntx ) { // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint32_t k = k0; - uint32_t rs_c = rs_c0; - uint32_t cs_c = cs_c0; - + GEMM_UKR_SETUP_CT_ANY( d, 4, 4, false ); bli_dgemm_armv7a_ker_4x4( k, alpha, a, b, beta, c, rs_c, cs_c, data ); + GEMM_UKR_FLUSH_CT( d ); } @@ -118,23 +118,23 @@ void bli_cgemm_armv7a_ker_2x2 void bli_cgemm_armv7a_asm_2x2 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, scomplex* restrict alpha, scomplex* restrict a, scomplex* restrict b, scomplex* restrict beta, - scomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + scomplex* restrict c, inc_t rs_c, inc_t cs_c, auxinfo_t* restrict data, cntx_t* restrict cntx ) { // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint32_t k = k0; - uint32_t rs_c = rs_c0; - uint32_t cs_c = cs_c0; - + GEMM_UKR_SETUP_CT_ANY( c, 2, 2, false ); bli_cgemm_armv7a_ker_2x2( k, alpha, a, b, beta, c, rs_c, cs_c, data ); + GEMM_UKR_FLUSH_CT( c ); } @@ -153,22 +153,22 @@ void bli_zgemm_armv7a_ker_2x2 void bli_zgemm_armv7a_asm_2x2 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, dcomplex* restrict alpha, dcomplex* restrict a, dcomplex* restrict b, dcomplex* restrict beta, - dcomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + dcomplex* restrict c, inc_t rs_c, inc_t cs_c, auxinfo_t* restrict data, cntx_t* restrict cntx ) { // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. 
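+	// Here GEMM_UKR_SETUP_CT_ANY routes through the temporary tile only
+	// when m != 2 || n != 2; the strides of C play no role in that test,
+	// since bli_zgemm_armv7a_ker_2x2 already supports general rs_c/cs_c.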
- uint32_t k = k0; - uint32_t rs_c = rs_c0; - uint32_t cs_c = cs_c0; - + GEMM_UKR_SETUP_CT_ANY( z, 2, 2, false ); bli_zgemm_armv7a_ker_2x2( k, alpha, a, b, beta, c, rs_c, cs_c, data ); + GEMM_UKR_FLUSH_CT( z ); } diff --git a/kernels/armv7a/3/bli_gemm_armv7a_int_d4x4.c b/kernels/armv7a/3/bli_gemm_armv7a_int_d4x4.c index b9db587266..06f36a3463 100644 --- a/kernels/armv7a/3/bli_gemm_armv7a_int_d4x4.c +++ b/kernels/armv7a/3/bli_gemm_armv7a_int_d4x4.c @@ -37,7 +37,9 @@ void bli_sgemm_armv7a_int_4x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, float* restrict alpha, float* restrict a, float* restrict b, @@ -49,12 +51,14 @@ void bli_sgemm_armv7a_int_4x4 { // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint32_t k_iter = k0 / 4; - uint32_t k_left = k0 % 4; + uint32_t k_iter = k / 4; + uint32_t k_left = k % 4; uint32_t rs_c = rs_c0; uint32_t cs_c = cs_c0; uint32_t i; + GEMM_UKR_SETUP_CT( s, 4, 4, false ); + void* a_next = bli_auxinfo_next_a( data ); void* b_next = bli_auxinfo_next_b( data ); @@ -82,47 +86,17 @@ void bli_sgemm_armv7a_int_4x4 if ( *beta != 0.0F ) { - if ( rs_c == 1 ) - { - // Load column 0 - cv0 = vld1q_f32( c + 0*rs_c + 0*cs_c ); - - // Load column 1 - cv1 = vld1q_f32( c + 0*rs_c + 1*cs_c ); - - // Load column 2 - cv2 = vld1q_f32( c + 0*rs_c + 2*cs_c ); - - // Load column 3 - cv3 = vld1q_f32( c + 0*rs_c + 3*cs_c ); - } - else - { - // Load column 0 - cv0 = vld1q_lane_f32( c + 0*rs_c + 0*cs_c, cv0, 0); - cv0 = vld1q_lane_f32( c + 1*rs_c + 0*cs_c, cv0, 1); - cv0 = vld1q_lane_f32( c + 2*rs_c + 0*cs_c, cv0, 2); - cv0 = vld1q_lane_f32( c + 3*rs_c + 0*cs_c, cv0, 3); - - // Load column 1 - cv1 = vld1q_lane_f32( c + 0*rs_c + 1*cs_c, cv1, 0); - cv1 = vld1q_lane_f32( c + 1*rs_c + 1*cs_c, cv1, 1); - cv1 = vld1q_lane_f32( c + 2*rs_c + 1*cs_c, cv1, 2); - cv1 = vld1q_lane_f32( c + 3*rs_c + 1*cs_c, cv1, 3); - - // Load column 2 - cv2 = vld1q_lane_f32( c + 0*rs_c + 2*cs_c, cv2, 0); - cv2 = vld1q_lane_f32( c + 1*rs_c + 2*cs_c, cv2, 1); - cv2 = vld1q_lane_f32( c + 2*rs_c + 2*cs_c, cv2, 2); - cv2 = vld1q_lane_f32( c + 3*rs_c + 2*cs_c, cv2, 3); - - // Load column 3 - cv3 = vld1q_lane_f32( c + 0*rs_c + 3*cs_c, cv3, 0); - cv3 = vld1q_lane_f32( c + 1*rs_c + 3*cs_c, cv3, 1); - cv3 = vld1q_lane_f32( c + 2*rs_c + 3*cs_c, cv3, 2); - cv3 = vld1q_lane_f32( c + 3*rs_c + 3*cs_c, cv3, 3); - - } + // Load column 0 + cv0 = vld1q_f32( c + 0*cs_c ); + + // Load column 1 + cv1 = vld1q_f32( c + 1*cs_c ); + + // Load column 2 + cv2 = vld1q_f32( c + 2*cs_c ); + + // Load column 3 + cv3 = vld1q_f32( c + 3*cs_c ); } else { @@ -255,47 +229,22 @@ void bli_sgemm_armv7a_int_4x4 cv3 = vmlaq_f32( cv3, abv3, alphav ); } - if ( rs_c == 1 ) - { - // Store column 0 - vst1q_f32( c + 0*rs_c + 0*cs_c, cv0 ); - // Store column 1 - vst1q_f32( c + 0*rs_c + 1*cs_c, cv1 ); - // Store column 2 - vst1q_f32( c + 0*rs_c + 2*cs_c, cv2 ); - // Store column 3 - vst1q_f32( c + 0*rs_c + 3*cs_c, cv3 ); - } - else - { - // Store column 0 - vst1q_lane_f32( c + 0*rs_c + 0*cs_c, cv0, 0); - vst1q_lane_f32( c + 1*rs_c + 0*cs_c, cv0, 1); - vst1q_lane_f32( c + 2*rs_c + 0*cs_c, cv0, 2); - vst1q_lane_f32( c + 3*rs_c + 0*cs_c, cv0, 3); - - // Store column 1 - vst1q_lane_f32( c + 0*rs_c + 1*cs_c, cv1, 0); - vst1q_lane_f32( c + 1*rs_c + 1*cs_c, cv1, 1); - vst1q_lane_f32( c + 2*rs_c + 1*cs_c, cv1, 2); - vst1q_lane_f32( c + 3*rs_c + 1*cs_c, cv1, 3); - - // Store column 2 - vst1q_lane_f32( c + 0*rs_c + 2*cs_c, cv2, 0); - vst1q_lane_f32( c + 1*rs_c + 2*cs_c, cv2, 1); - vst1q_lane_f32( c + 
2*rs_c + 2*cs_c, cv2, 2); - vst1q_lane_f32( c + 3*rs_c + 2*cs_c, cv2, 3); - - // Store column 3 - vst1q_lane_f32( c + 0*rs_c + 3*cs_c, cv3, 0); - vst1q_lane_f32( c + 1*rs_c + 3*cs_c, cv3, 1); - vst1q_lane_f32( c + 2*rs_c + 3*cs_c, cv3, 2); - vst1q_lane_f32( c + 3*rs_c + 3*cs_c, cv3, 3); - } + // Store column 0 + vst1q_f32( c + 0*cs_c, cv0 ); + // Store column 1 + vst1q_f32( c + 1*cs_c, cv1 ); + // Store column 2 + vst1q_f32( c + 2*cs_c, cv2 ); + // Store column 3 + vst1q_f32( c + 3*cs_c, cv3 ); + + GEMM_UKR_FLUSH_CT( s ); } void bli_dgemm_armv7a_int_4x4 ( + dim_t m, + dim_t n, dim_t k, double* restrict alpha, double* restrict a, @@ -314,6 +263,8 @@ void bli_dgemm_armv7a_int_4x4 uint32_t cs_c = cs_c0; uint32_t i; + GEMM_UKR_SETUP_CT_ANY( d, 4, 4, false ); + //void* a_next = bli_auxinfo_next_a( data ); //void* b_next = bli_auxinfo_next_b( data ); @@ -568,5 +519,7 @@ void bli_dgemm_armv7a_int_4x4 *c23 += ab23 * *alpha; *c33 += ab33 * *alpha; } + + GEMM_UKR_FLUSH_CT( d ); } diff --git a/kernels/armv8a/3/bli_gemm_armv8a_asm_d6x8.c b/kernels/armv8a/3/bli_gemm_armv8a_asm_d6x8.c index dfdda863b1..7b420f202f 100644 --- a/kernels/armv8a/3/bli_gemm_armv8a_asm_d6x8.c +++ b/kernels/armv8a/3/bli_gemm_armv8a_asm_d6x8.c @@ -1,4 +1,4 @@ - /* + /* BLIS An object-based framework for developing high-performance BLAS-like @@ -40,20 +40,22 @@ o 4x4 Single precision micro-kernel fully functional. o Runnable on ARMv8, compiled with aarch64 GCC. o Use it together with the armv8 BLIS configuration. - o Tested on Juno board. Around 7.3 GFLOPS @ 1.1 GHz. + o Tested on Juno board. Around 7.3 GFLOPS @ 1.1 GHz. December 2014. - + * UPDATE NOVEMBER 2015 * Micro-kernel changed to 8x12 * Tested on Juno Board. Around 8.1 GFLOPS, 1 x A57 core @ 1.1 GHz. * Tested on Juno Board. Around 15.9 GFLOPS, 2 x A57 cores @ 1.1 GHz. - * Tested on Juno board. Around 3.1 GFLOPS, 1 x A53 core @ 850 MHz. + * Tested on Juno board. Around 3.1 GFLOPS, 1 x A53 core @ 850 MHz. * Tested on Juno board. Around 12 GFLOPS, 4 x A53 cores @ 850 MHz. */ void bli_sgemm_armv8a_asm_8x12 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, float* restrict alpha, float* restrict a, float* restrict b, @@ -68,1020 +70,1023 @@ void bli_sgemm_armv8a_asm_8x12 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( s, 8, 12, false ); -__asm__ volatile -( -" \n\t" -" \n\t" -" ldr x0,%[aaddr] \n\t" // Load address of A. -" ldr x1,%[baddr] \n\t" // Load address of B. -" ldr x2,%[caddr] \n\t" // Load address of C. -" \n\t" -" ldr x5,%[k_iter] \n\t" // Number of unrolled iterations (k_iter). -" ldr x6,%[k_left] \n\t" // Number of remaining iterations (k_left). -" \n\t" -" ldr x10,%[cs_c] \n\t" // Load cs_c. -" lsl x10,x10,#2 \n\t" // cs_c * sizeof(float) -- AUX. -" \n\t" -" ldr x14,%[rs_c] \n\t" // Load rs_c. -" lsl x14,x14,#2 \n\t" // rs_c * sizeof(float). 
-" \n\t" -" add x16,x2,x10 \n\t" //Load address Column 1 of C -" add x17,x16,x10 \n\t" //Load address Column 2 of C -" add x19,x17,x10 \n\t" //Load address Column 3 of C -" add x20,x19,x10 \n\t" //Load address Column 4 of C -" add x21,x20,x10 \n\t" //Load address Column 5 of C -" add x22,x21,x10 \n\t" //Load address Column 6 of C -" add x23,x22,x10 \n\t" //Load address Column 7 of C -" add x24,x23,x10 \n\t" //Load address Column 8 of C -" add x25,x24,x10 \n\t" //Load address Column 9 of C -" add x26,x25,x10 \n\t" //Load address Column 10 of C -" add x27,x26,x10 \n\t" //Load address Column 11 of C -" \n\t" -" prfm pldl1keep,[x2] \n\t" // Prefetch c. -" prfm pldl1keep,[x16] \n\t" // Prefetch c. -" prfm pldl1keep,[x17] \n\t" // Prefetch c. -" prfm pldl1keep,[x19] \n\t" // Prefetch c. -" prfm pldl1keep,[x20] \n\t" // Prefetch c. -" prfm pldl1keep,[x21] \n\t" // Prefetch c. -" prfm pldl1keep,[x22] \n\t" // Prefetch c. -" prfm pldl1keep,[x23] \n\t" // Prefetch c. -" prfm pldl1keep,[x24] \n\t" // Prefetch c. -" prfm pldl1keep,[x25] \n\t" // Prefetch c. -" prfm pldl1keep,[x26] \n\t" // Prefetch c. -" prfm pldl1keep,[x27] \n\t" // Prefetch c. -" \n\t" -" dup v8.4s, wzr \n\t" // Vector for accummulating column 0 -" prfm PLDL1KEEP, [x1, #192] \n\t" -" dup v9.4s, wzr \n\t" // Vector for accummulating column 0 -" prfm PLDL1KEEP, [x1, #256] \n\t" -" dup v10.4s, wzr \n\t" // Vector for accummulating column 1 -" prfm PLDL1KEEP, [x1, #320] \n\t" -" dup v11.4s, wzr \n\t" // Vector for accummulating column 1 -" dup v12.4s, wzr \n\t" // Vector for accummulating column 2 -" dup v13.4s, wzr \n\t" // Vector for accummulating column 2 -" \n\t" -" dup v14.4s, wzr \n\t" // Vector for accummulating column 3 -" prfm PLDL1KEEP, [x0, #128] \n\t" -" dup v15.4s, wzr \n\t" // Vector for accummulating column 3 -" prfm PLDL1KEEP, [x0, #192] \n\t" -" dup v16.4s, wzr \n\t" // Vector for accummulating column 4 -" dup v17.4s, wzr \n\t" // Vector for accummulating column 4 -" dup v18.4s, wzr \n\t" // Vector for accummulating column 5 -" dup v19.4s, wzr \n\t" // Vector for accummulating column 5 -" \n\t" -" dup v20.4s, wzr \n\t" // Vector for accummulating column 6 -" dup v21.4s, wzr \n\t" // Vector for accummulating column 6 -" dup v22.4s, wzr \n\t" // Vector for accummulating column 7 -" dup v23.4s, wzr \n\t" // Vector for accummulating column 7 -" dup v24.4s, wzr \n\t" // Vector for accummulating column 8 -" dup v25.4s, wzr \n\t" // Vector for accummulating column 8 -" \n\t" -" dup v26.4s, wzr \n\t" // Vector for accummulating column 9 -" dup v27.4s, wzr \n\t" // Vector for accummulating column 9 -" dup v28.4s, wzr \n\t" // Vector for accummulating column 10 -" dup v29.4s, wzr \n\t" // Vector for accummulating column 10 -" dup v30.4s, wzr \n\t" // Vector for accummulating column 11 -" dup v31.4s, wzr \n\t" // Vector for accummulating column 11 -" \n\t" -" cmp x5,#0 \n\t" // If k_iter == 0, jump to k_left. -BEQ(SCONSIDERKLEFT) -" \n\t" -" ldr q0, [x0] \n\t" -" ldr q1, [x0, #16] \n\t" // Load a -" \n\t" -" ldr q2, [x1] \n\t" // Load b -" ldr q3, [x1, #16] \n\t" -" ldr q4, [x1, #32] \n\t" -" \n\t" -" add x0, x0, #32 \n\t" //update address of A -" add x1, x1, #48 \n\t" //update address of B -" \n\t" -" cmp x5,1 \n\t" // If there is just one k_iter, jump to that one. -BEQ(SLASTITER) // (as loop is do-while-like). -" \n\t" -LABEL(SLOOPKITER) // Body of the k_iter loop. -" \n\t" -" ldr q5, [x0] \n\t" -" fmla v8.4s, v0.4s,v2.s[0] \n\t" // Accummulate. -" fmla v9.4s, v1.4s,v2.s[0] \n\t" // Accummulate. 
-" ldr q6, [x0, #16] \n\t" -" fmla v10.4s,v0.4s,v2.s[1] \n\t" // Accummulate. -" fmla v11.4s,v1.4s,v2.s[1] \n\t" // Accummulate. -" fmla v12.4s,v0.4s,v2.s[2] \n\t" // Accummulate. -" fmla v13.4s,v1.4s,v2.s[2] \n\t" // Accummulate. -" fmla v14.4s,v0.4s,v2.s[3] \n\t" // Accummulate. -" fmla v15.4s,v1.4s,v2.s[3] \n\t" // Accummulate. -" ldr q2, [x1] \n\t" -" \n\t" -" fmla v16.4s,v0.4s,v3.s[0] \n\t" // Accummulate. -" prfm PLDL1KEEP, [x1, #336] \n\t" -" fmla v17.4s,v1.4s,v3.s[0] \n\t" // Accummulate. -" prfm PLDL1KEEP, [x1, #400] \n\t" -" fmla v18.4s,v0.4s,v3.s[1] \n\t" // Accummulate. -" fmla v19.4s,v1.4s,v3.s[1] \n\t" // Accummulate. -" prfm PLDL1KEEP, [x1, #464] \n\t" -" fmla v20.4s,v0.4s,v3.s[2] \n\t" // Accummulate. -" fmla v21.4s,v1.4s,v3.s[2] \n\t" // Accummulate. -" fmla v22.4s,v0.4s,v3.s[3] \n\t" // Accummulate. -" fmla v23.4s,v1.4s,v3.s[3] \n\t" // Accummulate. -" \n\t" -" fmla v24.4s,v0.4s,v4.s[0] \n\t" // Accummulate. -" fmla v26.4s,v0.4s,v4.s[1] \n\t" // Accummulate. -" fmla v28.4s,v0.4s,v4.s[2] \n\t" // Accummulate. -" fmla v30.4s,v0.4s,v4.s[3] \n\t" // Accummulate. -" ldr q3, [x1, #16] \n\t" -" \n\t" -" fmla v25.4s,v1.4s,v4.s[0] \n\t" // Accummulate. -" fmla v27.4s,v1.4s,v4.s[1] \n\t" // Accummulate. -" fmla v29.4s,v1.4s,v4.s[2] \n\t" // Accummulate. -" fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. -" ldr q4, [x1, #32] \n\t" -" \n\t" //End It 1 -" \n\t" -" ldr q0, [x0, #32] \n\t" -" fmla v8.4s,v5.4s,v2.s[0] \n\t" // Accummulate. -" fmla v9.4s,v6.4s,v2.s[0] \n\t" // Accummulate. -" ldr q1, [x0, #48] \n\t" -" fmla v10.4s,v5.4s,v2.s[1] \n\t" // Accummulate. -" fmla v11.4s,v6.4s,v2.s[1] \n\t" // Accummulate. -" fmla v12.4s,v5.4s,v2.s[2] \n\t" // Accummulate. -" fmla v13.4s,v6.4s,v2.s[2] \n\t" // Accummulate. -" fmla v14.4s,v5.4s,v2.s[3] \n\t" // Accummulate. -" fmla v15.4s,v6.4s,v2.s[3] \n\t" // Accummulate. -" ldr q2, [x1, #48] \n\t" -" \n\t" -" fmla v16.4s,v5.4s,v3.s[0] \n\t" // Accummulate. -" prfm PLDL1KEEP, [x0, #224] \n\t" -" fmla v17.4s,v6.4s,v3.s[0] \n\t" // Accummulate. -" prfm PLDL1KEEP, [x0, #288] \n\t" -" fmla v18.4s,v5.4s,v3.s[1] \n\t" // Accummulate. -" fmla v19.4s,v6.4s,v3.s[1] \n\t" // Accummulate. -" fmla v20.4s,v5.4s,v3.s[2] \n\t" // Accummulate. -" fmla v21.4s,v6.4s,v3.s[2] \n\t" // Accummulate. -" fmla v22.4s,v5.4s,v3.s[3] \n\t" // Accummulate. -" fmla v23.4s,v6.4s,v3.s[3] \n\t" // Accummulate. -" \n\t" -" fmla v24.4s,v5.4s,v4.s[0] \n\t" // Accummulate. -" fmla v26.4s,v5.4s,v4.s[1] \n\t" // Accummulate. -" fmla v28.4s,v5.4s,v4.s[2] \n\t" // Accummulate. -" fmla v30.4s,v5.4s,v4.s[3] \n\t" // Accummulate. -" ldr q3, [x1, #64] \n\t" -" \n\t" -" fmla v25.4s,v6.4s,v4.s[0] \n\t" // Accummulate. -" fmla v27.4s,v6.4s,v4.s[1] \n\t" // Accummulate. -" fmla v29.4s,v6.4s,v4.s[2] \n\t" // Accummulate. -" fmla v31.4s,v6.4s,v4.s[3] \n\t" // Accummulate. -" ldr q4, [x1, #80] \n\t" -" \n\t" //End It 2 -" \n\t" -" ldr q5, [x0, #64] \n\t" -" fmla v8.4s,v0.4s,v2.s[0] \n\t" // Accummulate. -" fmla v9.4s,v1.4s,v2.s[0] \n\t" // Accummulate. -" ldr q6, [x0, #80] \n\t" -" fmla v10.4s,v0.4s,v2.s[1] \n\t" // Accummulate. -" fmla v11.4s,v1.4s,v2.s[1] \n\t" // Accummulate. -" fmla v12.4s,v0.4s,v2.s[2] \n\t" // Accummulate. -" fmla v13.4s,v1.4s,v2.s[2] \n\t" // Accummulate. -" fmla v14.4s,v0.4s,v2.s[3] \n\t" // Accummulate. -" fmla v15.4s,v1.4s,v2.s[3] \n\t" // Accummulate. -" ldr q2, [x1, #96] \n\t" -" \n\t" -" fmla v16.4s,v0.4s,v3.s[0] \n\t" // Accummulate. -" fmla v17.4s,v1.4s,v3.s[0] \n\t" // Accummulate. -" fmla v18.4s,v0.4s,v3.s[1] \n\t" // Accummulate. 
-" fmla v19.4s,v1.4s,v3.s[1] \n\t" // Accummulate. -" fmla v20.4s,v0.4s,v3.s[2] \n\t" // Accummulate. -" fmla v21.4s,v1.4s,v3.s[2] \n\t" // Accummulate. -" fmla v22.4s,v0.4s,v3.s[3] \n\t" // Accummulate. -" fmla v23.4s,v1.4s,v3.s[3] \n\t" // Accummulate. -" \n\t" -" fmla v24.4s,v0.4s,v4.s[0] \n\t" // Accummulate. -" fmla v26.4s,v0.4s,v4.s[1] \n\t" // Accummulate. -" fmla v28.4s,v0.4s,v4.s[2] \n\t" // Accummulate. -" fmla v30.4s,v0.4s,v4.s[3] \n\t" // Accummulate. -" ldr q3, [x1, #112] \n\t" -" \n\t" -" fmla v25.4s,v1.4s,v4.s[0] \n\t" // Accummulate. -" fmla v27.4s,v1.4s,v4.s[1] \n\t" // Accummulate. -" fmla v29.4s,v1.4s,v4.s[2] \n\t" // Accummulate. -" fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. -" ldr q4, [x1, #128] \n\t" -" \n\t" //End It 3 -" \n\t" -" ldr q0, [x0, #96] \n\t" -" fmla v8.4s,v5.4s,v2.s[0] \n\t" // Accummulate. -" fmla v9.4s,v6.4s,v2.s[0] \n\t" // Accummulate. -" ldr q1, [x0, #112] \n\t" -" fmla v10.4s,v5.4s,v2.s[1] \n\t" // Accummulate. -" fmla v11.4s,v6.4s,v2.s[1] \n\t" // Accummulate. -" fmla v12.4s,v5.4s,v2.s[2] \n\t" // Accummulate. -" fmla v13.4s,v6.4s,v2.s[2] \n\t" // Accummulate. -" fmla v14.4s,v5.4s,v2.s[3] \n\t" // Accummulate. -" fmla v15.4s,v6.4s,v2.s[3] \n\t" // Accummulate. -" ldr q2, [x1, #144] \n\t" -" \n\t" -" fmla v16.4s,v5.4s,v3.s[0] \n\t" // Accummulate. -" fmla v17.4s,v6.4s,v3.s[0] \n\t" // Accummulate. -" fmla v18.4s,v5.4s,v3.s[1] \n\t" // Accummulate. -" fmla v19.4s,v6.4s,v3.s[1] \n\t" // Accummulate. -" fmla v20.4s,v5.4s,v3.s[2] \n\t" // Accummulate. -" fmla v21.4s,v6.4s,v3.s[2] \n\t" // Accummulate. -" fmla v22.4s,v5.4s,v3.s[3] \n\t" // Accummulate. -" fmla v23.4s,v6.4s,v3.s[3] \n\t" // Accummulate. -" \n\t" -" fmla v24.4s,v5.4s,v4.s[0] \n\t" // Accummulate. -" fmla v26.4s,v5.4s,v4.s[1] \n\t" // Accummulate. -" fmla v28.4s,v5.4s,v4.s[2] \n\t" // Accummulate. -" fmla v30.4s,v5.4s,v4.s[3] \n\t" // Accummulate. -" ldr q3, [x1, #160] \n\t" -" \n\t" -" fmla v25.4s,v6.4s,v4.s[0] \n\t" // Accummulate. -" fmla v27.4s,v6.4s,v4.s[1] \n\t" // Accummulate. -" fmla v29.4s,v6.4s,v4.s[2] \n\t" // Accummulate. -" fmla v31.4s,v6.4s,v4.s[3] \n\t" // Accummulate. -" ldr q4, [x1, #176] \n\t" -" add x1, x1, #192 \n\t" -" add x0, x0, #128 \n\t" -" \n\t" //End It 4 -" sub x5,x5,1 \n\t" // i-=1. -" cmp x5,1 \n\t" // Iterate again if we are not in k_iter == 1. -BNE(SLOOPKITER) -" \n\t" -LABEL(SLASTITER) // Last iteration of k_iter loop. -" \n\t" -" \n\t" -" ldr q5, [x0] \n\t" -" fmla v8.4s,v0.4s,v2.s[0] \n\t" // Accummulate. -" fmla v9.4s,v1.4s,v2.s[0] \n\t" // Accummulate. -" ldr q6, [x0, #16] \n\t" -" fmla v10.4s,v0.4s,v2.s[1] \n\t" // Accummulate. -" fmla v11.4s,v1.4s,v2.s[1] \n\t" // Accummulate. -" fmla v12.4s,v0.4s,v2.s[2] \n\t" // Accummulate. -" fmla v13.4s,v1.4s,v2.s[2] \n\t" // Accummulate. -" fmla v14.4s,v0.4s,v2.s[3] \n\t" // Accummulate. -" fmla v15.4s,v1.4s,v2.s[3] \n\t" // Accummulate. -" ldr q2, [x1] \n\t" -" \n\t" -" fmla v16.4s,v0.4s,v3.s[0] \n\t" // Accummulate. -" fmla v17.4s,v1.4s,v3.s[0] \n\t" // Accummulate. -" fmla v18.4s,v0.4s,v3.s[1] \n\t" // Accummulate. -" fmla v19.4s,v1.4s,v3.s[1] \n\t" // Accummulate. -" fmla v20.4s,v0.4s,v3.s[2] \n\t" // Accummulate. -" fmla v21.4s,v1.4s,v3.s[2] \n\t" // Accummulate. -" fmla v22.4s,v0.4s,v3.s[3] \n\t" // Accummulate. -" fmla v23.4s,v1.4s,v3.s[3] \n\t" // Accummulate. -" \n\t" -" fmla v24.4s,v0.4s,v4.s[0] \n\t" // Accummulate. -" fmla v26.4s,v0.4s,v4.s[1] \n\t" // Accummulate. -" fmla v28.4s,v0.4s,v4.s[2] \n\t" // Accummulate. -" fmla v30.4s,v0.4s,v4.s[3] \n\t" // Accummulate. 
-" ldr q3, [x1, #16] \n\t" -" \n\t" -" fmla v25.4s,v1.4s,v4.s[0] \n\t" // Accummulate. -" fmla v27.4s,v1.4s,v4.s[1] \n\t" // Accummulate. -" fmla v29.4s,v1.4s,v4.s[2] \n\t" // Accummulate. -" fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. -" ldr q4, [x1, #32] \n\t" -" \n\t" //End It 1 -" \n\t" -" ldr q0, [x0, #32] \n\t" -" fmla v8.4s,v5.4s,v2.s[0] \n\t" // Accummulate. -" fmla v9.4s,v6.4s,v2.s[0] \n\t" // Accummulate. -" ldr q1, [x0, #48] \n\t" -" fmla v10.4s,v5.4s,v2.s[1] \n\t" // Accummulate. -" fmla v11.4s,v6.4s,v2.s[1] \n\t" // Accummulate. -" fmla v12.4s,v5.4s,v2.s[2] \n\t" // Accummulate. -" fmla v13.4s,v6.4s,v2.s[2] \n\t" // Accummulate. -" fmla v14.4s,v5.4s,v2.s[3] \n\t" // Accummulate. -" fmla v15.4s,v6.4s,v2.s[3] \n\t" // Accummulate. -" ldr q2, [x1, #48] \n\t" -" \n\t" -" fmla v16.4s,v5.4s,v3.s[0] \n\t" // Accummulate. -" fmla v17.4s,v6.4s,v3.s[0] \n\t" // Accummulate. -" fmla v18.4s,v5.4s,v3.s[1] \n\t" // Accummulate. -" fmla v19.4s,v6.4s,v3.s[1] \n\t" // Accummulate. -" fmla v20.4s,v5.4s,v3.s[2] \n\t" // Accummulate. -" fmla v21.4s,v6.4s,v3.s[2] \n\t" // Accummulate. -" fmla v22.4s,v5.4s,v3.s[3] \n\t" // Accummulate. -" fmla v23.4s,v6.4s,v3.s[3] \n\t" // Accummulate. -" \n\t" -" fmla v24.4s,v5.4s,v4.s[0] \n\t" // Accummulate. -" fmla v26.4s,v5.4s,v4.s[1] \n\t" // Accummulate. -" fmla v28.4s,v5.4s,v4.s[2] \n\t" // Accummulate. -" fmla v30.4s,v5.4s,v4.s[3] \n\t" // Accummulate. -" ldr q3, [x1, #64] \n\t" -" \n\t" -" fmla v25.4s,v6.4s,v4.s[0] \n\t" // Accummulate. -" fmla v27.4s,v6.4s,v4.s[1] \n\t" // Accummulate. -" fmla v29.4s,v6.4s,v4.s[2] \n\t" // Accummulate. -" fmla v31.4s,v6.4s,v4.s[3] \n\t" // Accummulate. -" ldr q4, [x1, #80] \n\t" -" \n\t" //End It 2 -" \n\t" -" ldr q5, [x0, #64] \n\t" -" fmla v8.4s,v0.4s,v2.s[0] \n\t" // Accummulate. -" fmla v9.4s,v1.4s,v2.s[0] \n\t" // Accummulate. -" ldr q6, [x0, #80] \n\t" -" fmla v10.4s,v0.4s,v2.s[1] \n\t" // Accummulate. -" fmla v11.4s,v1.4s,v2.s[1] \n\t" // Accummulate. -" fmla v12.4s,v0.4s,v2.s[2] \n\t" // Accummulate. -" fmla v13.4s,v1.4s,v2.s[2] \n\t" // Accummulate. -" fmla v14.4s,v0.4s,v2.s[3] \n\t" // Accummulate. -" fmla v15.4s,v1.4s,v2.s[3] \n\t" // Accummulate. -" ldr q2, [x1, #96] \n\t" -" \n\t" -" fmla v16.4s,v0.4s,v3.s[0] \n\t" // Accummulate. -" fmla v17.4s,v1.4s,v3.s[0] \n\t" // Accummulate. -" fmla v18.4s,v0.4s,v3.s[1] \n\t" // Accummulate. -" fmla v19.4s,v1.4s,v3.s[1] \n\t" // Accummulate. -" fmla v20.4s,v0.4s,v3.s[2] \n\t" // Accummulate. -" fmla v21.4s,v1.4s,v3.s[2] \n\t" // Accummulate. -" fmla v22.4s,v0.4s,v3.s[3] \n\t" // Accummulate. -" fmla v23.4s,v1.4s,v3.s[3] \n\t" // Accummulate. -" \n\t" -" fmla v24.4s,v0.4s,v4.s[0] \n\t" // Accummulate. -" fmla v26.4s,v0.4s,v4.s[1] \n\t" // Accummulate. -" fmla v28.4s,v0.4s,v4.s[2] \n\t" // Accummulate. -" fmla v30.4s,v0.4s,v4.s[3] \n\t" // Accummulate. -" ldr q3, [x1, #112] \n\t" -" \n\t" -" fmla v25.4s,v1.4s,v4.s[0] \n\t" // Accummulate. -" fmla v27.4s,v1.4s,v4.s[1] \n\t" // Accummulate. -" fmla v29.4s,v1.4s,v4.s[2] \n\t" // Accummulate. -" fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. -" ldr q4, [x1, #128] \n\t" -" \n\t" //End It 3 -" \n\t" -" fmla v8.4s,v5.4s,v2.s[0] \n\t" // Accummulate. -" fmla v9.4s,v6.4s,v2.s[0] \n\t" // Accummulate. -" fmla v10.4s,v5.4s,v2.s[1] \n\t" // Accummulate. -" fmla v11.4s,v6.4s,v2.s[1] \n\t" // Accummulate. -" fmla v12.4s,v5.4s,v2.s[2] \n\t" // Accummulate. -" fmla v13.4s,v6.4s,v2.s[2] \n\t" // Accummulate. -" fmla v14.4s,v5.4s,v2.s[3] \n\t" // Accummulate. -" fmla v15.4s,v6.4s,v2.s[3] \n\t" // Accummulate. 
-" \n\t" -" fmla v16.4s,v5.4s,v3.s[0] \n\t" // Accummulate. -" fmla v17.4s,v6.4s,v3.s[0] \n\t" // Accummulate. -" fmla v18.4s,v5.4s,v3.s[1] \n\t" // Accummulate. -" fmla v19.4s,v6.4s,v3.s[1] \n\t" // Accummulate. -" fmla v20.4s,v5.4s,v3.s[2] \n\t" // Accummulate. -" fmla v21.4s,v6.4s,v3.s[2] \n\t" // Accummulate. -" fmla v22.4s,v5.4s,v3.s[3] \n\t" // Accummulate. -" fmla v23.4s,v6.4s,v3.s[3] \n\t" // Accummulate. -" \n\t" -" fmla v24.4s,v5.4s,v4.s[0] \n\t" // Accummulate. -" fmla v26.4s,v5.4s,v4.s[1] \n\t" // Accummulate. -" fmla v28.4s,v5.4s,v4.s[2] \n\t" // Accummulate. -" fmla v30.4s,v5.4s,v4.s[3] \n\t" // Accummulate. -" \n\t" -" fmla v25.4s,v6.4s,v4.s[0] \n\t" // Accummulate. -" fmla v27.4s,v6.4s,v4.s[1] \n\t" // Accummulate. -" fmla v29.4s,v6.4s,v4.s[2] \n\t" // Accummulate. -" fmla v31.4s,v6.4s,v4.s[3] \n\t" // Accummulate. -" add x1, x1, #144 \n\t" -" add x0, x0, #96 \n\t" -" \n\t" //End It 4 -" \n\t" -LABEL(SCONSIDERKLEFT) -" cmp x6,0 \n\t" // If k_left == 0, we are done. -BEQ(SPOSTACCUM) // else, we enter the k_left loop. -" \n\t" -LABEL(SLOOPKLEFT) // Body of the left iterations -" \n\t" -" ldr q0, [x0],#16 \n\t" -" ldr q1, [x0],#16 \n\t" // Load a -" \n\t" -" ldr q2, [x1],#16 \n\t" // Load b -" ldr q3, [x1],#16 \n\t" -" ldr q4, [x1],#16 \n\t" -" \n\t" -" sub x6,x6,1 \n\t" // i = i-1. -" \n\t" -" fmla v8.4s,v0.4s,v2.s[0] \n\t" // Accummulate. -" fmla v9.4s,v1.4s,v2.s[0] \n\t" // Accummulate. -" fmla v10.4s,v0.4s,v2.s[1] \n\t" // Accummulate. -" fmla v11.4s,v1.4s,v2.s[1] \n\t" // Accummulate. -" fmla v12.4s,v0.4s,v2.s[2] \n\t" // Accummulate. -" fmla v13.4s,v1.4s,v2.s[2] \n\t" // Accummulate. -" fmla v14.4s,v0.4s,v2.s[3] \n\t" // Accummulate. -" fmla v15.4s,v1.4s,v2.s[3] \n\t" // Accummulate. -" \n\t" -" fmla v16.4s,v0.4s,v3.s[0] \n\t" // Accummulate. -" fmla v17.4s,v1.4s,v3.s[0] \n\t" // Accummulate. -" fmla v18.4s,v0.4s,v3.s[1] \n\t" // Accummulate. -" fmla v19.4s,v1.4s,v3.s[1] \n\t" // Accummulate. -" fmla v20.4s,v0.4s,v3.s[2] \n\t" // Accummulate. -" fmla v21.4s,v1.4s,v3.s[2] \n\t" // Accummulate. -" fmla v22.4s,v0.4s,v3.s[3] \n\t" // Accummulate. -" fmla v23.4s,v1.4s,v3.s[3] \n\t" // Accummulate. -" \n\t" -" fmla v24.4s,v0.4s,v4.s[0] \n\t" // Accummulate. -" fmla v26.4s,v0.4s,v4.s[1] \n\t" // Accummulate. -" fmla v28.4s,v0.4s,v4.s[2] \n\t" // Accummulate. -" fmla v30.4s,v0.4s,v4.s[3] \n\t" // Accummulate. -" fmla v25.4s,v1.4s,v4.s[0] \n\t" // Accummulate. -" fmla v27.4s,v1.4s,v4.s[1] \n\t" // Accummulate. -" fmla v29.4s,v1.4s,v4.s[2] \n\t" // Accummulate. -" fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. -" \n\t" -" cmp x6,0 \n\t" // Iterate again. -BNE(SLOOPKLEFT) // if i!=0. -" \n\t" -LABEL(SPOSTACCUM) -" \n\t" -" ldr x0,%[alpha] \n\t" // Alpha address. -" ldr x1,%[beta] \n\t" // Beta address. -" \n\t" -" ld1r {v6.4s},[x0] \n\t" // Load alpha. -" ld1r {v7.4s},[x1] \n\t" // Load beta -" \n\t" -" ldr x0,%[a_next] \n\t" // Pointer to next block of A. -" ldr x1,%[b_next] \n\t" // Pointer to next pointer of B. -" \n\t" -" cmp x14,#4 \n\t" // If rs_c != 1 (column-major) -BNE(SGENSTORED) -" \n\t" -LABEL(SCOLSTORED) // C is column-major. -" \n\t" -" dup v0.4s, wzr \n\t" -" dup v1.4s, wzr \n\t" -" dup v2.4s, wzr \n\t" -" dup v3.4s, wzr \n\t" -" dup v4.4s, wzr \n\t" -" dup v5.4s, wzr \n\t" -" \n\t" -" fcmp s7,#0.0 \n\t" -BEQ(SBETAZEROCOLSTOREDS1) // Taking care of the beta==0 case. 
-" \n\t" -" ldr q0, [x2] \n\t" //Load column 0 of C -" ldr q1, [x2, #16] \n\t" -" ldr q2, [x16] \n\t" //Load column 1 of C -" ldr q3, [x16, #16] \n\t" -" ldr q4, [x17] \n\t" //Load column 2 of C -" ldr q5, [x17, #16] \n\t" -" \n\t" -" fmul v0.4s,v0.4s,v7.s[0] \n\t" // Scale by beta -" fmul v1.4s,v1.4s,v7.s[0] \n\t" // Scale by beta -" fmul v2.4s,v2.4s,v7.s[0] \n\t" // Scale by beta -" fmul v3.4s,v3.4s,v7.s[0] \n\t" // Scale by beta -" fmul v4.4s,v4.4s,v7.s[0] \n\t" // Scale by beta -" fmul v5.4s,v5.4s,v7.s[0] \n\t" // Scale by beta -" \n\t" -LABEL(SBETAZEROCOLSTOREDS1) -" \n\t" -" fmla v0.4s,v8.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v1.4s,v9.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v2.4s,v10.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v3.4s,v11.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v4.4s,v12.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v5.4s,v13.4s,v6.s[0] \n\t" // Scale by alpha -" \n\t" -" str q0, [x2] \n\t" //Store column 0 of C -" str q1, [x2, #16] \n\t" -" str q2, [x16] \n\t" //Store column 1 of C -" str q3, [x16, #16] \n\t" -" str q4, [x17] \n\t" //Store column 2 of C -" str q5, [x17, #16] \n\t" -" \n\t" -" dup v8.4s, wzr \n\t" -" dup v9.4s, wzr \n\t" -" dup v10.4s, wzr \n\t" -" dup v11.4s, wzr \n\t" -" dup v12.4s, wzr \n\t" -" dup v13.4s, wzr \n\t" -" \n\t" -" fcmp s7,#0.0 \n\t" -BEQ(SBETAZEROCOLSTOREDS2) // Taking care of the beta==0 case. -" \n\t" -" ldr q8, [x19] \n\t" //Load column 3 of C -" ldr q9, [x19, #16] \n\t" -" ldr q10, [x20] \n\t" //Load column 4 of C -" ldr q11, [x20, #16] \n\t" -" ldr q12, [x21] \n\t" //Load column 5 of C -" ldr q13, [x21, #16] \n\t" -" \n\t" -" fmul v8.4s, v8.4s, v7.s[0] \n\t" // Scale by beta -" fmul v9.4s, v9.4s, v7.s[0] \n\t" // Scale by beta -" fmul v10.4s,v10.4s,v7.s[0] \n\t" // Scale by beta -" fmul v11.4s,v11.4s,v7.s[0] \n\t" // Scale by beta -" fmul v12.4s,v12.4s,v7.s[0] \n\t" // Scale by beta -" fmul v13.4s,v13.4s,v7.s[0] \n\t" // Scale by beta -" \n\t" -LABEL(SBETAZEROCOLSTOREDS2) -" \n\t" -" fmla v8.4s, v14.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v9.4s, v15.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v10.4s,v16.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v11.4s,v17.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v12.4s,v18.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v13.4s,v19.4s,v6.s[0] \n\t" // Scale by alpha -" \n\t" -" str q8, [x19] \n\t" //Store column 3 of C -" str q9, [x19, #16] \n\t" -" str q10, [x20] \n\t" //Store column 4 of C -" str q11, [x20, #16] \n\t" -" str q12, [x21] \n\t" //Store column 5 of C -" str q13, [x21, #16] \n\t" -" \n\t" -" dup v0.4s, wzr \n\t" -" dup v1.4s, wzr \n\t" -" dup v2.4s, wzr \n\t" -" dup v3.4s, wzr \n\t" -" dup v4.4s, wzr \n\t" -" dup v5.4s, wzr \n\t" -" \n\t" -" fcmp s7,#0.0 \n\t" -BEQ(SBETAZEROCOLSTOREDS3) // Taking care of the beta==0 case. 
-" \n\t" -" ldr q0, [x22] \n\t" //Load column 6 of C -" ldr q1, [x22, #16] \n\t" -" ldr q2, [x23] \n\t" //Load column 7 of C -" ldr q3, [x23, #16] \n\t" -" ldr q4, [x24] \n\t" //Load column 8 of C -" ldr q5, [x24, #16] \n\t" -" \n\t" -" fmul v0.4s,v0.4s,v7.s[0] \n\t" // Scale by beta -" fmul v1.4s,v1.4s,v7.s[0] \n\t" // Scale by beta -" fmul v2.4s,v2.4s,v7.s[0] \n\t" // Scale by beta -" fmul v3.4s,v3.4s,v7.s[0] \n\t" // Scale by beta -" fmul v4.4s,v4.4s,v7.s[0] \n\t" // Scale by beta -" fmul v5.4s,v5.4s,v7.s[0] \n\t" // Scale by beta -" \n\t" -LABEL(SBETAZEROCOLSTOREDS3) -" \n\t" -" fmla v0.4s,v20.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v1.4s,v21.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v2.4s,v22.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v3.4s,v23.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v4.4s,v24.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v5.4s,v25.4s,v6.s[0] \n\t" // Scale by alpha -" \n\t" -" str q0, [x22] \n\t" //Store column 6 of C -" str q1, [x22, #16] \n\t" -" str q2, [x23] \n\t" //Store column 7 of C -" str q3, [x23, #16] \n\t" -" str q4, [x24] \n\t" //Store column 8 of C -" str q5, [x24, #16] \n\t" -" \n\t" -" dup v8.4s, wzr \n\t" -" dup v9.4s, wzr \n\t" -" dup v10.4s, wzr \n\t" -" dup v11.4s, wzr \n\t" -" dup v12.4s, wzr \n\t" -" dup v13.4s, wzr \n\t" -" \n\t" -" fcmp s7,#0.0 \n\t" -BEQ(SBETAZEROCOLSTOREDS4) // Taking care of the beta==0 case. -" \n\t" -" ldr q8, [x25] \n\t" //Load column 9 of C -" ldr q9, [x25, #16] \n\t" -" ldr q10, [x26] \n\t" //Load column 10 of C -" ldr q11, [x26, #16] \n\t" -" ldr q12, [x27] \n\t" //Load column 11 of C -" ldr q13, [x27, #16] \n\t" -" \n\t" -" fmul v8.4s, v8.4s, v7.s[0] \n\t" // Scale by beta -" fmul v9.4s, v9.4s, v7.s[0] \n\t" // Scale by beta -" fmul v10.4s,v10.4s,v7.s[0] \n\t" // Scale by beta -" fmul v11.4s,v11.4s,v7.s[0] \n\t" // Scale by beta -" fmul v12.4s,v12.4s,v7.s[0] \n\t" // Scale by beta -" fmul v13.4s,v13.4s,v7.s[0] \n\t" // Scale by beta -" \n\t" -LABEL(SBETAZEROCOLSTOREDS4) -" \n\t" -" prfm pldl2keep,[x0] \n\t" -" prfm pldl2keep,[x1] \n\t" -" \n\t" -" fmla v8.4s, v26.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v9.4s, v27.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v10.4s,v28.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v11.4s,v29.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v12.4s,v30.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v13.4s,v31.4s,v6.s[0] \n\t" // Scale by alpha -" \n\t" -" str q8, [x25] \n\t" //Store column 9 of C -" str q9, [x25, #16] \n\t" -" str q10, [x26] \n\t" //Store column 10 of C -" str q11, [x26, #16] \n\t" -" str q12, [x27] \n\t" //Store column 11 of C -" str q13, [x27, #16] \n\t" -" \n\t" -" \n\t" -BRANCH(SEND) // Done. -" \n\t" -" \n\t" -LABEL(SGENSTORED) // C is general-stride stored. -" \n\t" -" \n\t" -" dup v0.4s, wzr \n\t" -" dup v1.4s, wzr \n\t" -" dup v2.4s, wzr \n\t" -" dup v3.4s, wzr \n\t" -" dup v4.4s, wzr \n\t" -" dup v5.4s, wzr \n\t" -" \n\t" -" fcmp s7,#0.0 \n\t" -BEQ(SBETAZEROGENSTOREDS1) // Taking care of the beta==0 case. -" \n\t" -" mov x5, x2 \n\t" -" \n\t" -" ld1 {v0.s}[0],[x5],x14 \n\t" // Load c00 into quad and increment by rs_c. -" ld1 {v0.s}[1],[x5],x14 \n\t" // Load c01 into quad and increment by rs_c. -" ld1 {v0.s}[2],[x5],x14 \n\t" // Load c02 into quad and increment by rs_c. -" ld1 {v0.s}[3],[x5],x14 \n\t" // Load c03 into quad and increment by rs_c. -" ld1 {v1.s}[0],[x5],x14 \n\t" // Load c04 into quad and increment by rs_c. -" ld1 {v1.s}[1],[x5],x14 \n\t" // Load c05 into quad and increment by rs_c. 
-" ld1 {v1.s}[2],[x5],x14 \n\t" // Load c06 into quad and increment by rs_c. -" ld1 {v1.s}[3],[x5],x14 \n\t" // Load c07 into quad and increment by rs_c. -" \n\t" -" mov x5, x16 \n\t" -" \n\t" -" ld1 {v2.s}[0],[x5],x14 \n\t" // Load c10 into quad and increment by rs_c. -" ld1 {v2.s}[1],[x5],x14 \n\t" // Load c11 into quad and increment by rs_c. -" ld1 {v2.s}[2],[x5],x14 \n\t" // Load c12 into quad and increment by rs_c. -" ld1 {v2.s}[3],[x5],x14 \n\t" // Load c13 into quad and increment by rs_c. -" ld1 {v3.s}[0],[x5],x14 \n\t" // Load c14 into quad and increment by rs_c. -" ld1 {v3.s}[1],[x5],x14 \n\t" // Load c15 into quad and increment by rs_c. -" ld1 {v3.s}[2],[x5],x14 \n\t" // Load c16 into quad and increment by rs_c. -" ld1 {v3.s}[3],[x5],x14 \n\t" // Load c17 into quad and increment by rs_c. -" \n\t" -" mov x5, x17 \n\t" -" \n\t" -" ld1 {v4.s}[0],[x5],x14 \n\t" // Load c20 into quad and increment by rs_c. -" ld1 {v4.s}[1],[x5],x14 \n\t" // Load c21 into quad and increment by rs_c. -" ld1 {v4.s}[2],[x5],x14 \n\t" // Load c22 into quad and increment by rs_c. -" ld1 {v4.s}[3],[x5],x14 \n\t" // Load c23 into quad and increment by rs_c. -" ld1 {v5.s}[0],[x5],x14 \n\t" // Load c24 into quad and increment by rs_c. -" ld1 {v5.s}[1],[x5],x14 \n\t" // Load c25 into quad and increment by rs_c. -" ld1 {v5.s}[2],[x5],x14 \n\t" // Load c26 into quad and increment by rs_c. -" ld1 {v5.s}[3],[x5],x14 \n\t" // Load c27 into quad and increment by rs_c. -" \n\t" -" fmul v0.4s,v0.4s,v7.s[0] \n\t" // Scale by beta -" fmul v1.4s,v1.4s,v7.s[0] \n\t" // Scale by beta -" fmul v2.4s,v2.4s,v7.s[0] \n\t" // Scale by beta -" fmul v3.4s,v3.4s,v7.s[0] \n\t" // Scale by beta -" fmul v4.4s,v4.4s,v7.s[0] \n\t" // Scale by beta -" fmul v5.4s,v5.4s,v7.s[0] \n\t" // Scale by beta -" \n\t" -LABEL(SBETAZEROGENSTOREDS1) -" \n\t" -" fmla v0.4s, v8.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v1.4s, v9.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v2.4s,v10.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v3.4s,v11.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v4.4s,v12.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v5.4s,v13.4s,v6.s[0] \n\t" // Scale by alpha -" \n\t" -" mov x5, x2 \n\t" -" \n\t" -" st1 {v0.s}[0],[x5],x14 \n\t" // Store c00 into quad and increment by rs_c. -" st1 {v0.s}[1],[x5],x14 \n\t" // Store c01 into quad and increment by rs_c. -" st1 {v0.s}[2],[x5],x14 \n\t" // Store c02 into quad and increment by rs_c. -" st1 {v0.s}[3],[x5],x14 \n\t" // Store c03 into quad and increment by rs_c. -" st1 {v1.s}[0],[x5],x14 \n\t" // Store c04 into quad and increment by rs_c. -" st1 {v1.s}[1],[x5],x14 \n\t" // Store c05 into quad and increment by rs_c. -" st1 {v1.s}[2],[x5],x14 \n\t" // Store c06 into quad and increment by rs_c. -" st1 {v1.s}[3],[x5],x14 \n\t" // Store c07 into quad and increment by rs_c. -" \n\t" -" mov x5, x16 \n\t" -" \n\t" -" st1 {v2.s}[0],[x5],x14 \n\t" // Store c10 into quad and increment by rs_c. -" st1 {v2.s}[1],[x5],x14 \n\t" // Store c11 into quad and increment by rs_c. -" st1 {v2.s}[2],[x5],x14 \n\t" // Store c12 into quad and increment by rs_c. -" st1 {v2.s}[3],[x5],x14 \n\t" // Store c13 into quad and increment by rs_c. -" st1 {v3.s}[0],[x5],x14 \n\t" // Store c14 into quad and increment by rs_c. -" st1 {v3.s}[1],[x5],x14 \n\t" // Store c15 into quad and increment by rs_c. -" st1 {v3.s}[2],[x5],x14 \n\t" // Store c16 into quad and increment by rs_c. -" st1 {v3.s}[3],[x5],x14 \n\t" // Store c17 into quad and increment by rs_c. 
-" \n\t" -" mov x5, x17 \n\t" -" \n\t" -" st1 {v4.s}[0],[x5],x14 \n\t" // Store c20 into quad and increment by rs_c. -" st1 {v4.s}[1],[x5],x14 \n\t" // Store c21 into quad and increment by rs_c. -" st1 {v4.s}[2],[x5],x14 \n\t" // Store c22 into quad and increment by rs_c. -" st1 {v4.s}[3],[x5],x14 \n\t" // Store c23 into quad and increment by rs_c. -" st1 {v5.s}[0],[x5],x14 \n\t" // Store c24 into quad and increment by rs_c. -" st1 {v5.s}[1],[x5],x14 \n\t" // Store c25 into quad and increment by rs_c. -" st1 {v5.s}[2],[x5],x14 \n\t" // Store c26 into quad and increment by rs_c. -" st1 {v5.s}[3],[x5],x14 \n\t" // Store c27 into quad and increment by rs_c. -" \n\t" -" dup v8.4s, wzr \n\t" -" dup v9.4s, wzr \n\t" -" dup v10.4s, wzr \n\t" -" dup v11.4s, wzr \n\t" -" dup v12.4s, wzr \n\t" -" dup v13.4s, wzr \n\t" -" \n\t" -" fcmp s7,#0.0 \n\t" -BEQ(SBETAZEROGENSTOREDS2) // Taking care of the beta==0 case. -" \n\t" -" mov x5, x19 \n\t" -" \n\t" -" ld1 {v8.s}[0],[x5],x14 \n\t" // Load c30 into quad and increment by rs_c. -" ld1 {v8.s}[1],[x5],x14 \n\t" // Load c31 into quad and increment by rs_c. -" ld1 {v8.s}[2],[x5],x14 \n\t" // Load c32 into quad and increment by rs_c. -" ld1 {v8.s}[3],[x5],x14 \n\t" // Load c33 into quad and increment by rs_c. -" ld1 {v9.s}[0],[x5],x14 \n\t" // Load c34 into quad and increment by rs_c. -" ld1 {v9.s}[1],[x5],x14 \n\t" // Load c35 into quad and increment by rs_c. -" ld1 {v9.s}[2],[x5],x14 \n\t" // Load c36 into quad and increment by rs_c. -" ld1 {v9.s}[3],[x5],x14 \n\t" // Load c37 into quad and increment by rs_c. -" \n\t" -" mov x5, x20 \n\t" -" \n\t" -" ld1 {v10.s}[0],[x5],x14 \n\t" // Load c40 into quad and increment by rs_c. -" ld1 {v10.s}[1],[x5],x14 \n\t" // Load c41 into quad and increment by rs_c. -" ld1 {v10.s}[2],[x5],x14 \n\t" // Load c42 into quad and increment by rs_c. -" ld1 {v10.s}[3],[x5],x14 \n\t" // Load c43 into quad and increment by rs_c. -" ld1 {v11.s}[0],[x5],x14 \n\t" // Load c44 into quad and increment by rs_c. -" ld1 {v11.s}[1],[x5],x14 \n\t" // Load c45 into quad and increment by rs_c. -" ld1 {v11.s}[2],[x5],x14 \n\t" // Load c46 into quad and increment by rs_c. -" ld1 {v11.s}[3],[x5],x14 \n\t" // Load c47 into quad and increment by rs_c. -" \n\t" -" mov x5, x21 \n\t" -" \n\t" -" ld1 {v12.s}[0],[x5],x14 \n\t" // Load c50 into quad and increment by rs_c. -" ld1 {v12.s}[1],[x5],x14 \n\t" // Load c51 into quad and increment by rs_c. -" ld1 {v12.s}[2],[x5],x14 \n\t" // Load c52 into quad and increment by rs_c. -" ld1 {v12.s}[3],[x5],x14 \n\t" // Load c53 into quad and increment by rs_c. -" ld1 {v13.s}[0],[x5],x14 \n\t" // Load c54 into quad and increment by rs_c. -" ld1 {v13.s}[1],[x5],x14 \n\t" // Load c55 into quad and increment by rs_c. -" ld1 {v13.s}[2],[x5],x14 \n\t" // Load c56 into quad and increment by rs_c. -" ld1 {v13.s}[3],[x5],x14 \n\t" // Load c57 into quad and increment by rs_c. 
-" \n\t" -" fmul v8.4s, v8.4s, v7.s[0] \n\t" // Scale by beta -" fmul v9.4s, v9.4s, v7.s[0] \n\t" // Scale by beta -" fmul v10.4s,v10.4s,v7.s[0] \n\t" // Scale by beta -" fmul v11.4s,v11.4s,v7.s[0] \n\t" // Scale by beta -" fmul v12.4s,v12.4s,v7.s[0] \n\t" // Scale by beta -" fmul v13.4s,v13.4s,v7.s[0] \n\t" // Scale by beta -" \n\t" -LABEL(SBETAZEROGENSTOREDS2) -" \n\t" -" fmla v8.4s, v14.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v9.4s, v15.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v10.4s,v16.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v11.4s,v17.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v12.4s,v18.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v13.4s,v19.4s,v6.s[0] \n\t" // Scale by alpha -" \n\t" -" mov x5, x19 \n\t" -" \n\t" -" st1 {v8.s}[0],[x5],x14 \n\t" // Store c30 into quad and increment by rs_c. -" st1 {v8.s}[1],[x5],x14 \n\t" // Store c31 into quad and increment by rs_c. -" st1 {v8.s}[2],[x5],x14 \n\t" // Store c32 into quad and increment by rs_c. -" st1 {v8.s}[3],[x5],x14 \n\t" // Store c33 into quad and increment by rs_c. -" st1 {v9.s}[0],[x5],x14 \n\t" // Store c34 into quad and increment by rs_c. -" st1 {v9.s}[1],[x5],x14 \n\t" // Store c35 into quad and increment by rs_c. -" st1 {v9.s}[2],[x5],x14 \n\t" // Store c36 into quad and increment by rs_c. -" st1 {v9.s}[3],[x5],x14 \n\t" // Store c37 into quad and increment by rs_c. -" \n\t" -" mov x5, x20 \n\t" -" \n\t" -" st1 {v10.s}[0],[x5],x14 \n\t" // Store c40 into quad and increment by rs_c. -" st1 {v10.s}[1],[x5],x14 \n\t" // Store c41 into quad and increment by rs_c. -" st1 {v10.s}[2],[x5],x14 \n\t" // Store c42 into quad and increment by rs_c. -" st1 {v10.s}[3],[x5],x14 \n\t" // Store c43 into quad and increment by rs_c. -" st1 {v11.s}[0],[x5],x14 \n\t" // Store c44 into quad and increment by rs_c. -" st1 {v11.s}[1],[x5],x14 \n\t" // Store c45 into quad and increment by rs_c. -" st1 {v11.s}[2],[x5],x14 \n\t" // Store c46 into quad and increment by rs_c. -" st1 {v11.s}[3],[x5],x14 \n\t" // Store c47 into quad and increment by rs_c. -" \n\t" -" mov x5, x21 \n\t" -" \n\t" -" st1 {v12.s}[0],[x5],x14 \n\t" // Store c50 into quad and increment by rs_c. -" st1 {v12.s}[1],[x5],x14 \n\t" // Store c51 into quad and increment by rs_c. -" st1 {v12.s}[2],[x5],x14 \n\t" // Store c52 into quad and increment by rs_c. -" st1 {v12.s}[3],[x5],x14 \n\t" // Store c53 into quad and increment by rs_c. -" st1 {v13.s}[0],[x5],x14 \n\t" // Store c54 into quad and increment by rs_c. -" st1 {v13.s}[1],[x5],x14 \n\t" // Store c55 into quad and increment by rs_c. -" st1 {v13.s}[2],[x5],x14 \n\t" // Store c56 into quad and increment by rs_c. -" st1 {v13.s}[3],[x5],x14 \n\t" // Store c57 into quad and increment by rs_c. -" \n\t" -" dup v0.4s, wzr \n\t" -" dup v1.4s, wzr \n\t" -" dup v2.4s, wzr \n\t" -" dup v3.4s, wzr \n\t" -" dup v4.4s, wzr \n\t" -" dup v5.4s, wzr \n\t" -" \n\t" -" fcmp s7,#0.0 \n\t" -BEQ(SBETAZEROGENSTOREDS3) // Taking care of the beta==0 case. -" \n\t" -" mov x5, x22 \n\t" -" \n\t" -" ld1 {v0.s}[0],[x5],x14 \n\t" // Load c60 into quad and increment by rs_c. -" ld1 {v0.s}[1],[x5],x14 \n\t" // Load c61 into quad and increment by rs_c. -" ld1 {v0.s}[2],[x5],x14 \n\t" // Load c62 into quad and increment by rs_c. -" ld1 {v0.s}[3],[x5],x14 \n\t" // Load c63 into quad and increment by rs_c. -" ld1 {v1.s}[0],[x5],x14 \n\t" // Load c64 into quad and increment by rs_c. -" ld1 {v1.s}[1],[x5],x14 \n\t" // Load c65 into quad and increment by rs_c. -" ld1 {v1.s}[2],[x5],x14 \n\t" // Load c66 into quad and increment by rs_c. 
-" ld1 {v1.s}[3],[x5],x14 \n\t" // Load c67 into quad and increment by rs_c. -" \n\t" -" mov x5, x23 \n\t" -" \n\t" -" ld1 {v2.s}[0],[x5],x14 \n\t" // Load c70 into quad and increment by rs_c. -" ld1 {v2.s}[1],[x5],x14 \n\t" // Load c71 into quad and increment by rs_c. -" ld1 {v2.s}[2],[x5],x14 \n\t" // Load c72 into quad and increment by rs_c. -" ld1 {v2.s}[3],[x5],x14 \n\t" // Load c73 into quad and increment by rs_c. -" ld1 {v3.s}[0],[x5],x14 \n\t" // Load c74 into quad and increment by rs_c. -" ld1 {v3.s}[1],[x5],x14 \n\t" // Load c75 into quad and increment by rs_c. -" ld1 {v3.s}[2],[x5],x14 \n\t" // Load c76 into quad and increment by rs_c. -" ld1 {v3.s}[3],[x5],x14 \n\t" // Load c77 into quad and increment by rs_c. -" \n\t" -" mov x5, x24 \n\t" -" \n\t" -" ld1 {v4.s}[0],[x5],x14 \n\t" // Load c80 into quad and increment by rs_c. -" ld1 {v4.s}[1],[x5],x14 \n\t" // Load c81 into quad and increment by rs_c. -" ld1 {v4.s}[2],[x5],x14 \n\t" // Load c82 into quad and increment by rs_c. -" ld1 {v4.s}[3],[x5],x14 \n\t" // Load c83 into quad and increment by rs_c. -" ld1 {v5.s}[0],[x5],x14 \n\t" // Load c84 into quad and increment by rs_c. -" ld1 {v5.s}[1],[x5],x14 \n\t" // Load c85 into quad and increment by rs_c. -" ld1 {v5.s}[2],[x5],x14 \n\t" // Load c86 into quad and increment by rs_c. -" ld1 {v5.s}[3],[x5],x14 \n\t" // Load c87 into quad and increment by rs_c. -" \n\t" -" fmul v0.4s,v0.4s,v7.s[0] \n\t" // Scale by beta -" fmul v1.4s,v1.4s,v7.s[0] \n\t" // Scale by beta -" fmul v2.4s,v2.4s,v7.s[0] \n\t" // Scale by beta -" fmul v3.4s,v3.4s,v7.s[0] \n\t" // Scale by beta -" fmul v4.4s,v4.4s,v7.s[0] \n\t" // Scale by beta -" fmul v5.4s,v5.4s,v7.s[0] \n\t" // Scale by beta -" \n\t" -LABEL(SBETAZEROGENSTOREDS3) -" \n\t" -" fmla v0.4s,v20.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v1.4s,v21.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v2.4s,v22.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v3.4s,v23.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v4.4s,v24.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v5.4s,v25.4s,v6.s[0] \n\t" // Scale by alpha -" \n\t" -" mov x5, x22 \n\t" -" \n\t" -" st1 {v0.s}[0],[x5],x14 \n\t" // Store c60 into quad and increment by rs_c. -" st1 {v0.s}[1],[x5],x14 \n\t" // Store c61 into quad and increment by rs_c. -" st1 {v0.s}[2],[x5],x14 \n\t" // Store c62 into quad and increment by rs_c. -" st1 {v0.s}[3],[x5],x14 \n\t" // Store c63 into quad and increment by rs_c. -" st1 {v1.s}[0],[x5],x14 \n\t" // Store c64 into quad and increment by rs_c. -" st1 {v1.s}[1],[x5],x14 \n\t" // Store c65 into quad and increment by rs_c. -" st1 {v1.s}[2],[x5],x14 \n\t" // Store c66 into quad and increment by rs_c. -" st1 {v1.s}[3],[x5],x14 \n\t" // Store c67 into quad and increment by rs_c. -" \n\t" -" mov x5, x23 \n\t" -" \n\t" -" st1 {v2.s}[0],[x5],x14 \n\t" // Store c70 into quad and increment by rs_c. -" st1 {v2.s}[1],[x5],x14 \n\t" // Store c71 into quad and increment by rs_c. -" st1 {v2.s}[2],[x5],x14 \n\t" // Store c72 into quad and increment by rs_c. -" st1 {v2.s}[3],[x5],x14 \n\t" // Store c73 into quad and increment by rs_c. -" st1 {v3.s}[0],[x5],x14 \n\t" // Store c74 into quad and increment by rs_c. -" st1 {v3.s}[1],[x5],x14 \n\t" // Store c75 into quad and increment by rs_c. -" st1 {v3.s}[2],[x5],x14 \n\t" // Store c76 into quad and increment by rs_c. -" st1 {v3.s}[3],[x5],x14 \n\t" // Store c77 into quad and increment by rs_c. -" \n\t" -" mov x5, x24 \n\t" -" \n\t" -" st1 {v4.s}[0],[x5],x14 \n\t" // Store c80 into quad and increment by rs_c. 
-" st1 {v4.s}[1],[x5],x14 \n\t" // Store c81 into quad and increment by rs_c. -" st1 {v4.s}[2],[x5],x14 \n\t" // Store c82 into quad and increment by rs_c. -" st1 {v4.s}[3],[x5],x14 \n\t" // Store c83 into quad and increment by rs_c. -" st1 {v5.s}[0],[x5],x14 \n\t" // Store c84 into quad and increment by rs_c. -" st1 {v5.s}[1],[x5],x14 \n\t" // Store c85 into quad and increment by rs_c. -" st1 {v5.s}[2],[x5],x14 \n\t" // Store c86 into quad and increment by rs_c. -" st1 {v5.s}[3],[x5],x14 \n\t" // Store c87 into quad and increment by rs_c. -" \n\t" -" dup v8.4s, wzr \n\t" -" dup v9.4s, wzr \n\t" -" dup v10.4s, wzr \n\t" -" dup v11.4s, wzr \n\t" -" dup v12.4s, wzr \n\t" -" dup v13.4s, wzr \n\t" -" \n\t" -" fcmp s7,#0.0 \n\t" -BEQ(SBETAZEROGENSTOREDS4) // Taking care of the beta==0 case. -" \n\t" -" mov x5, x25 \n\t" -" \n\t" -" ld1 {v8.s}[0],[x5],x14 \n\t" // Load c90 into quad and increment by rs_c. -" ld1 {v8.s}[1],[x5],x14 \n\t" // Load c91 into quad and increment by rs_c. -" ld1 {v8.s}[2],[x5],x14 \n\t" // Load c92 into quad and increment by rs_c. -" ld1 {v8.s}[3],[x5],x14 \n\t" // Load c93 into quad and increment by rs_c. -" ld1 {v9.s}[0],[x5],x14 \n\t" // Load c94 into quad and increment by rs_c. -" ld1 {v9.s}[1],[x5],x14 \n\t" // Load c95 into quad and increment by rs_c. -" ld1 {v9.s}[2],[x5],x14 \n\t" // Load c96 into quad and increment by rs_c. -" ld1 {v9.s}[3],[x5],x14 \n\t" // Load c97 into quad and increment by rs_c. -" \n\t" -" mov x5, x26 \n\t" -" \n\t" -" ld1 {v10.s}[0],[x5],x14 \n\t" // Load c100 into quad and increment by rs_c. -" ld1 {v10.s}[1],[x5],x14 \n\t" // Load c101 into quad and increment by rs_c. -" ld1 {v10.s}[2],[x5],x14 \n\t" // Load c102 into quad and increment by rs_c. -" ld1 {v10.s}[3],[x5],x14 \n\t" // Load c103 into quad and increment by rs_c. -" ld1 {v11.s}[0],[x5],x14 \n\t" // Load c104 into quad and increment by rs_c. -" ld1 {v11.s}[1],[x5],x14 \n\t" // Load c105 into quad and increment by rs_c. -" ld1 {v11.s}[2],[x5],x14 \n\t" // Load c106 into quad and increment by rs_c. -" ld1 {v11.s}[3],[x5],x14 \n\t" // Load c107 into quad and increment by rs_c. -" \n\t" -" mov x5, x27 \n\t" -" \n\t" -" ld1 {v12.s}[0],[x5],x14 \n\t" // Load c110 into quad and increment by rs_c. -" ld1 {v12.s}[1],[x5],x14 \n\t" // Load c111 into quad and increment by rs_c. -" ld1 {v12.s}[2],[x5],x14 \n\t" // Load c112 into quad and increment by rs_c. -" ld1 {v12.s}[3],[x5],x14 \n\t" // Load c113 into quad and increment by rs_c. -" ld1 {v13.s}[0],[x5],x14 \n\t" // Load c114 into quad and increment by rs_c. -" ld1 {v13.s}[1],[x5],x14 \n\t" // Load c115 into quad and increment by rs_c. -" ld1 {v13.s}[2],[x5],x14 \n\t" // Load c116 into quad and increment by rs_c. -" ld1 {v13.s}[3],[x5],x14 \n\t" // Load c117 into quad and increment by rs_c. 
-" \n\t" -" fmul v8.4s, v8.4s, v7.s[0] \n\t" // Scale by beta -" fmul v9.4s, v9.4s, v7.s[0] \n\t" // Scale by beta -" fmul v10.4s,v10.4s,v7.s[0] \n\t" // Scale by beta -" fmul v11.4s,v11.4s,v7.s[0] \n\t" // Scale by beta -" fmul v12.4s,v12.4s,v7.s[0] \n\t" // Scale by beta -" fmul v13.4s,v13.4s,v7.s[0] \n\t" // Scale by beta -" \n\t" -LABEL(SBETAZEROGENSTOREDS4) -" \n\t" -" prfm pldl2keep,[x0] \n\t" -" prfm pldl2keep,[x1] \n\t" -" \n\t" -" fmla v8.4s, v26.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v9.4s, v27.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v10.4s,v28.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v11.4s,v29.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v12.4s,v30.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v13.4s,v31.4s,v6.s[0] \n\t" // Scale by alpha -" \n\t" -" mov x5, x25 \n\t" -" \n\t" -" st1 {v8.s}[0],[x5],x14 \n\t" // Store c90 into quad and increment by rs_c. -" st1 {v8.s}[1],[x5],x14 \n\t" // Store c91 into quad and increment by rs_c. -" st1 {v8.s}[2],[x5],x14 \n\t" // Store c92 into quad and increment by rs_c. -" st1 {v8.s}[3],[x5],x14 \n\t" // Store c93 into quad and increment by rs_c. -" st1 {v9.s}[0],[x5],x14 \n\t" // Store c94 into quad and increment by rs_c. -" st1 {v9.s}[1],[x5],x14 \n\t" // Store c95 into quad and increment by rs_c. -" st1 {v9.s}[2],[x5],x14 \n\t" // Store c96 into quad and increment by rs_c. -" st1 {v9.s}[3],[x5],x14 \n\t" // Store c97 into quad and increment by rs_c. -" \n\t" -" mov x5, x26 \n\t" -" \n\t" -" st1 {v10.s}[0],[x5],x14 \n\t" // Store c100 into quad and increment by rs_c. -" st1 {v10.s}[1],[x5],x14 \n\t" // Store c101 into quad and increment by rs_c. -" st1 {v10.s}[2],[x5],x14 \n\t" // Store c102 into quad and increment by rs_c. -" st1 {v10.s}[3],[x5],x14 \n\t" // Store c103 into quad and increment by rs_c. -" st1 {v11.s}[0],[x5],x14 \n\t" // Store c104 into quad and increment by rs_c. -" st1 {v11.s}[1],[x5],x14 \n\t" // Store c105 into quad and increment by rs_c. -" st1 {v11.s}[2],[x5],x14 \n\t" // Store c106 into quad and increment by rs_c. -" st1 {v11.s}[3],[x5],x14 \n\t" // Store c107 into quad and increment by rs_c. -" \n\t" -" mov x5, x27 \n\t" -" \n\t" -" st1 {v12.s}[0],[x5],x14 \n\t" // Store c110 into quad and increment by rs_c. -" st1 {v12.s}[1],[x5],x14 \n\t" // Store c111 into quad and increment by rs_c. -" st1 {v12.s}[2],[x5],x14 \n\t" // Store c112 into quad and increment by rs_c. -" st1 {v12.s}[3],[x5],x14 \n\t" // Store c113 into quad and increment by rs_c. -" st1 {v13.s}[0],[x5],x14 \n\t" // Store c114 into quad and increment by rs_c. -" st1 {v13.s}[1],[x5],x14 \n\t" // Store c115 into quad and increment by rs_c. -" st1 {v13.s}[2],[x5],x14 \n\t" // Store c116 into quad and increment by rs_c. -" st1 {v13.s}[3],[x5],x14 \n\t" // Store c147 into quad and increment by rs_c. -" \n\t" -LABEL(SEND) // Done! 
-" \n\t" -:// output operands (none) -:// input operands - [aaddr] "m" (a), // 0 - [baddr] "m" (b), // 1 - [caddr] "m" (c), // 2 - [k_iter] "m" (k_iter), // 3 - [k_left] "m" (k_left), // 4 - [alpha] "m" (alpha), // 5 - [beta] "m" (beta), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c), // 8 - [a_next] "m" (a_next), // 9 - [b_next] "m" (b_next) // 10 -:// Register clobber list - "x0", "x1", "x2", - "x5", "x6", "x10","x14", - "x16","x17","x19","x20", - "x21","x22","x23","x24", - "x25","x26","x27", - "v0", "v1", "v2", "v3", - "v4", "v5", "v6", "v7", - "v8", "v9", "v10","v11", - "v12","v13","v14","v15", - "v16","v17","v18","v19", - "v20","v21","v22","v23", - "v24","v25","v26","v27", - "v28","v29","v30","v31" -); + __asm__ volatile + ( + " \n\t" + " \n\t" + " ldr x0,%[aaddr] \n\t" // Load address of A. + " ldr x1,%[baddr] \n\t" // Load address of B. + " ldr x2,%[caddr] \n\t" // Load address of C. + " \n\t" + " ldr x5,%[k_iter] \n\t" // Number of unrolled iterations (k_iter). + " ldr x6,%[k_left] \n\t" // Number of remaining iterations (k_left). + " \n\t" + " ldr x10,%[cs_c] \n\t" // Load cs_c. + " lsl x10,x10,#2 \n\t" // cs_c * sizeof(float) -- AUX. + " \n\t" + " ldr x14,%[rs_c] \n\t" // Load rs_c. + " lsl x14,x14,#2 \n\t" // rs_c * sizeof(float). + " \n\t" + " add x16,x2,x10 \n\t" //Load address Column 1 of C + " add x17,x16,x10 \n\t" //Load address Column 2 of C + " add x19,x17,x10 \n\t" //Load address Column 3 of C + " add x20,x19,x10 \n\t" //Load address Column 4 of C + " add x21,x20,x10 \n\t" //Load address Column 5 of C + " add x22,x21,x10 \n\t" //Load address Column 6 of C + " add x23,x22,x10 \n\t" //Load address Column 7 of C + " add x24,x23,x10 \n\t" //Load address Column 8 of C + " add x25,x24,x10 \n\t" //Load address Column 9 of C + " add x26,x25,x10 \n\t" //Load address Column 10 of C + " add x27,x26,x10 \n\t" //Load address Column 11 of C + " \n\t" + " prfm pldl1keep,[x2] \n\t" // Prefetch c. + " prfm pldl1keep,[x16] \n\t" // Prefetch c. + " prfm pldl1keep,[x17] \n\t" // Prefetch c. + " prfm pldl1keep,[x19] \n\t" // Prefetch c. + " prfm pldl1keep,[x20] \n\t" // Prefetch c. + " prfm pldl1keep,[x21] \n\t" // Prefetch c. + " prfm pldl1keep,[x22] \n\t" // Prefetch c. + " prfm pldl1keep,[x23] \n\t" // Prefetch c. + " prfm pldl1keep,[x24] \n\t" // Prefetch c. + " prfm pldl1keep,[x25] \n\t" // Prefetch c. + " prfm pldl1keep,[x26] \n\t" // Prefetch c. + " prfm pldl1keep,[x27] \n\t" // Prefetch c. 
+ " \n\t" + " dup v8.4s, wzr \n\t" // Vector for accummulating column 0 + " prfm PLDL1KEEP, [x1, #192] \n\t" + " dup v9.4s, wzr \n\t" // Vector for accummulating column 0 + " prfm PLDL1KEEP, [x1, #256] \n\t" + " dup v10.4s, wzr \n\t" // Vector for accummulating column 1 + " prfm PLDL1KEEP, [x1, #320] \n\t" + " dup v11.4s, wzr \n\t" // Vector for accummulating column 1 + " dup v12.4s, wzr \n\t" // Vector for accummulating column 2 + " dup v13.4s, wzr \n\t" // Vector for accummulating column 2 + " \n\t" + " dup v14.4s, wzr \n\t" // Vector for accummulating column 3 + " prfm PLDL1KEEP, [x0, #128] \n\t" + " dup v15.4s, wzr \n\t" // Vector for accummulating column 3 + " prfm PLDL1KEEP, [x0, #192] \n\t" + " dup v16.4s, wzr \n\t" // Vector for accummulating column 4 + " dup v17.4s, wzr \n\t" // Vector for accummulating column 4 + " dup v18.4s, wzr \n\t" // Vector for accummulating column 5 + " dup v19.4s, wzr \n\t" // Vector for accummulating column 5 + " \n\t" + " dup v20.4s, wzr \n\t" // Vector for accummulating column 6 + " dup v21.4s, wzr \n\t" // Vector for accummulating column 6 + " dup v22.4s, wzr \n\t" // Vector for accummulating column 7 + " dup v23.4s, wzr \n\t" // Vector for accummulating column 7 + " dup v24.4s, wzr \n\t" // Vector for accummulating column 8 + " dup v25.4s, wzr \n\t" // Vector for accummulating column 8 + " \n\t" + " dup v26.4s, wzr \n\t" // Vector for accummulating column 9 + " dup v27.4s, wzr \n\t" // Vector for accummulating column 9 + " dup v28.4s, wzr \n\t" // Vector for accummulating column 10 + " dup v29.4s, wzr \n\t" // Vector for accummulating column 10 + " dup v30.4s, wzr \n\t" // Vector for accummulating column 11 + " dup v31.4s, wzr \n\t" // Vector for accummulating column 11 + " \n\t" + " cmp x5,#0 \n\t" // If k_iter == 0, jump to k_left. + BEQ(SCONSIDERKLEFT) + " \n\t" + " ldr q0, [x0] \n\t" + " ldr q1, [x0, #16] \n\t" // Load a + " \n\t" + " ldr q2, [x1] \n\t" // Load b + " ldr q3, [x1, #16] \n\t" + " ldr q4, [x1, #32] \n\t" + " \n\t" + " add x0, x0, #32 \n\t" //update address of A + " add x1, x1, #48 \n\t" //update address of B + " \n\t" + " cmp x5,1 \n\t" // If there is just one k_iter, jump to that one. + BEQ(SLASTITER) // (as loop is do-while-like). + " \n\t" + LABEL(SLOOPKITER) // Body of the k_iter loop. + " \n\t" + " ldr q5, [x0] \n\t" + " fmla v8.4s, v0.4s,v2.s[0] \n\t" // Accummulate. + " fmla v9.4s, v1.4s,v2.s[0] \n\t" // Accummulate. + " ldr q6, [x0, #16] \n\t" + " fmla v10.4s,v0.4s,v2.s[1] \n\t" // Accummulate. + " fmla v11.4s,v1.4s,v2.s[1] \n\t" // Accummulate. + " fmla v12.4s,v0.4s,v2.s[2] \n\t" // Accummulate. + " fmla v13.4s,v1.4s,v2.s[2] \n\t" // Accummulate. + " fmla v14.4s,v0.4s,v2.s[3] \n\t" // Accummulate. + " fmla v15.4s,v1.4s,v2.s[3] \n\t" // Accummulate. + " ldr q2, [x1] \n\t" + " \n\t" + " fmla v16.4s,v0.4s,v3.s[0] \n\t" // Accummulate. + " prfm PLDL1KEEP, [x1, #336] \n\t" + " fmla v17.4s,v1.4s,v3.s[0] \n\t" // Accummulate. + " prfm PLDL1KEEP, [x1, #400] \n\t" + " fmla v18.4s,v0.4s,v3.s[1] \n\t" // Accummulate. + " fmla v19.4s,v1.4s,v3.s[1] \n\t" // Accummulate. + " prfm PLDL1KEEP, [x1, #464] \n\t" + " fmla v20.4s,v0.4s,v3.s[2] \n\t" // Accummulate. + " fmla v21.4s,v1.4s,v3.s[2] \n\t" // Accummulate. + " fmla v22.4s,v0.4s,v3.s[3] \n\t" // Accummulate. + " fmla v23.4s,v1.4s,v3.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v24.4s,v0.4s,v4.s[0] \n\t" // Accummulate. + " fmla v26.4s,v0.4s,v4.s[1] \n\t" // Accummulate. + " fmla v28.4s,v0.4s,v4.s[2] \n\t" // Accummulate. + " fmla v30.4s,v0.4s,v4.s[3] \n\t" // Accummulate. 
+ " ldr q3, [x1, #16] \n\t" + " \n\t" + " fmla v25.4s,v1.4s,v4.s[0] \n\t" // Accummulate. + " fmla v27.4s,v1.4s,v4.s[1] \n\t" // Accummulate. + " fmla v29.4s,v1.4s,v4.s[2] \n\t" // Accummulate. + " fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. + " ldr q4, [x1, #32] \n\t" + " \n\t" //End It 1 + " \n\t" + " ldr q0, [x0, #32] \n\t" + " fmla v8.4s,v5.4s,v2.s[0] \n\t" // Accummulate. + " fmla v9.4s,v6.4s,v2.s[0] \n\t" // Accummulate. + " ldr q1, [x0, #48] \n\t" + " fmla v10.4s,v5.4s,v2.s[1] \n\t" // Accummulate. + " fmla v11.4s,v6.4s,v2.s[1] \n\t" // Accummulate. + " fmla v12.4s,v5.4s,v2.s[2] \n\t" // Accummulate. + " fmla v13.4s,v6.4s,v2.s[2] \n\t" // Accummulate. + " fmla v14.4s,v5.4s,v2.s[3] \n\t" // Accummulate. + " fmla v15.4s,v6.4s,v2.s[3] \n\t" // Accummulate. + " ldr q2, [x1, #48] \n\t" + " \n\t" + " fmla v16.4s,v5.4s,v3.s[0] \n\t" // Accummulate. + " prfm PLDL1KEEP, [x0, #224] \n\t" + " fmla v17.4s,v6.4s,v3.s[0] \n\t" // Accummulate. + " prfm PLDL1KEEP, [x0, #288] \n\t" + " fmla v18.4s,v5.4s,v3.s[1] \n\t" // Accummulate. + " fmla v19.4s,v6.4s,v3.s[1] \n\t" // Accummulate. + " fmla v20.4s,v5.4s,v3.s[2] \n\t" // Accummulate. + " fmla v21.4s,v6.4s,v3.s[2] \n\t" // Accummulate. + " fmla v22.4s,v5.4s,v3.s[3] \n\t" // Accummulate. + " fmla v23.4s,v6.4s,v3.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v24.4s,v5.4s,v4.s[0] \n\t" // Accummulate. + " fmla v26.4s,v5.4s,v4.s[1] \n\t" // Accummulate. + " fmla v28.4s,v5.4s,v4.s[2] \n\t" // Accummulate. + " fmla v30.4s,v5.4s,v4.s[3] \n\t" // Accummulate. + " ldr q3, [x1, #64] \n\t" + " \n\t" + " fmla v25.4s,v6.4s,v4.s[0] \n\t" // Accummulate. + " fmla v27.4s,v6.4s,v4.s[1] \n\t" // Accummulate. + " fmla v29.4s,v6.4s,v4.s[2] \n\t" // Accummulate. + " fmla v31.4s,v6.4s,v4.s[3] \n\t" // Accummulate. + " ldr q4, [x1, #80] \n\t" + " \n\t" //End It 2 + " \n\t" + " ldr q5, [x0, #64] \n\t" + " fmla v8.4s,v0.4s,v2.s[0] \n\t" // Accummulate. + " fmla v9.4s,v1.4s,v2.s[0] \n\t" // Accummulate. + " ldr q6, [x0, #80] \n\t" + " fmla v10.4s,v0.4s,v2.s[1] \n\t" // Accummulate. + " fmla v11.4s,v1.4s,v2.s[1] \n\t" // Accummulate. + " fmla v12.4s,v0.4s,v2.s[2] \n\t" // Accummulate. + " fmla v13.4s,v1.4s,v2.s[2] \n\t" // Accummulate. + " fmla v14.4s,v0.4s,v2.s[3] \n\t" // Accummulate. + " fmla v15.4s,v1.4s,v2.s[3] \n\t" // Accummulate. + " ldr q2, [x1, #96] \n\t" + " \n\t" + " fmla v16.4s,v0.4s,v3.s[0] \n\t" // Accummulate. + " fmla v17.4s,v1.4s,v3.s[0] \n\t" // Accummulate. + " fmla v18.4s,v0.4s,v3.s[1] \n\t" // Accummulate. + " fmla v19.4s,v1.4s,v3.s[1] \n\t" // Accummulate. + " fmla v20.4s,v0.4s,v3.s[2] \n\t" // Accummulate. + " fmla v21.4s,v1.4s,v3.s[2] \n\t" // Accummulate. + " fmla v22.4s,v0.4s,v3.s[3] \n\t" // Accummulate. + " fmla v23.4s,v1.4s,v3.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v24.4s,v0.4s,v4.s[0] \n\t" // Accummulate. + " fmla v26.4s,v0.4s,v4.s[1] \n\t" // Accummulate. + " fmla v28.4s,v0.4s,v4.s[2] \n\t" // Accummulate. + " fmla v30.4s,v0.4s,v4.s[3] \n\t" // Accummulate. + " ldr q3, [x1, #112] \n\t" + " \n\t" + " fmla v25.4s,v1.4s,v4.s[0] \n\t" // Accummulate. + " fmla v27.4s,v1.4s,v4.s[1] \n\t" // Accummulate. + " fmla v29.4s,v1.4s,v4.s[2] \n\t" // Accummulate. + " fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. + " ldr q4, [x1, #128] \n\t" + " \n\t" //End It 3 + " \n\t" + " ldr q0, [x0, #96] \n\t" + " fmla v8.4s,v5.4s,v2.s[0] \n\t" // Accummulate. + " fmla v9.4s,v6.4s,v2.s[0] \n\t" // Accummulate. + " ldr q1, [x0, #112] \n\t" + " fmla v10.4s,v5.4s,v2.s[1] \n\t" // Accummulate. + " fmla v11.4s,v6.4s,v2.s[1] \n\t" // Accummulate. 
+ " fmla v12.4s,v5.4s,v2.s[2] \n\t" // Accummulate. + " fmla v13.4s,v6.4s,v2.s[2] \n\t" // Accummulate. + " fmla v14.4s,v5.4s,v2.s[3] \n\t" // Accummulate. + " fmla v15.4s,v6.4s,v2.s[3] \n\t" // Accummulate. + " ldr q2, [x1, #144] \n\t" + " \n\t" + " fmla v16.4s,v5.4s,v3.s[0] \n\t" // Accummulate. + " fmla v17.4s,v6.4s,v3.s[0] \n\t" // Accummulate. + " fmla v18.4s,v5.4s,v3.s[1] \n\t" // Accummulate. + " fmla v19.4s,v6.4s,v3.s[1] \n\t" // Accummulate. + " fmla v20.4s,v5.4s,v3.s[2] \n\t" // Accummulate. + " fmla v21.4s,v6.4s,v3.s[2] \n\t" // Accummulate. + " fmla v22.4s,v5.4s,v3.s[3] \n\t" // Accummulate. + " fmla v23.4s,v6.4s,v3.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v24.4s,v5.4s,v4.s[0] \n\t" // Accummulate. + " fmla v26.4s,v5.4s,v4.s[1] \n\t" // Accummulate. + " fmla v28.4s,v5.4s,v4.s[2] \n\t" // Accummulate. + " fmla v30.4s,v5.4s,v4.s[3] \n\t" // Accummulate. + " ldr q3, [x1, #160] \n\t" + " \n\t" + " fmla v25.4s,v6.4s,v4.s[0] \n\t" // Accummulate. + " fmla v27.4s,v6.4s,v4.s[1] \n\t" // Accummulate. + " fmla v29.4s,v6.4s,v4.s[2] \n\t" // Accummulate. + " fmla v31.4s,v6.4s,v4.s[3] \n\t" // Accummulate. + " ldr q4, [x1, #176] \n\t" + " add x1, x1, #192 \n\t" + " add x0, x0, #128 \n\t" + " \n\t" //End It 4 + " sub x5,x5,1 \n\t" // i-=1. + " cmp x5,1 \n\t" // Iterate again if we are not in k_iter == 1. + BNE(SLOOPKITER) + " \n\t" + LABEL(SLASTITER) // Last iteration of k_iter loop. + " \n\t" + " \n\t" + " ldr q5, [x0] \n\t" + " fmla v8.4s,v0.4s,v2.s[0] \n\t" // Accummulate. + " fmla v9.4s,v1.4s,v2.s[0] \n\t" // Accummulate. + " ldr q6, [x0, #16] \n\t" + " fmla v10.4s,v0.4s,v2.s[1] \n\t" // Accummulate. + " fmla v11.4s,v1.4s,v2.s[1] \n\t" // Accummulate. + " fmla v12.4s,v0.4s,v2.s[2] \n\t" // Accummulate. + " fmla v13.4s,v1.4s,v2.s[2] \n\t" // Accummulate. + " fmla v14.4s,v0.4s,v2.s[3] \n\t" // Accummulate. + " fmla v15.4s,v1.4s,v2.s[3] \n\t" // Accummulate. + " ldr q2, [x1] \n\t" + " \n\t" + " fmla v16.4s,v0.4s,v3.s[0] \n\t" // Accummulate. + " fmla v17.4s,v1.4s,v3.s[0] \n\t" // Accummulate. + " fmla v18.4s,v0.4s,v3.s[1] \n\t" // Accummulate. + " fmla v19.4s,v1.4s,v3.s[1] \n\t" // Accummulate. + " fmla v20.4s,v0.4s,v3.s[2] \n\t" // Accummulate. + " fmla v21.4s,v1.4s,v3.s[2] \n\t" // Accummulate. + " fmla v22.4s,v0.4s,v3.s[3] \n\t" // Accummulate. + " fmla v23.4s,v1.4s,v3.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v24.4s,v0.4s,v4.s[0] \n\t" // Accummulate. + " fmla v26.4s,v0.4s,v4.s[1] \n\t" // Accummulate. + " fmla v28.4s,v0.4s,v4.s[2] \n\t" // Accummulate. + " fmla v30.4s,v0.4s,v4.s[3] \n\t" // Accummulate. + " ldr q3, [x1, #16] \n\t" + " \n\t" + " fmla v25.4s,v1.4s,v4.s[0] \n\t" // Accummulate. + " fmla v27.4s,v1.4s,v4.s[1] \n\t" // Accummulate. + " fmla v29.4s,v1.4s,v4.s[2] \n\t" // Accummulate. + " fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. + " ldr q4, [x1, #32] \n\t" + " \n\t" //End It 1 + " \n\t" + " ldr q0, [x0, #32] \n\t" + " fmla v8.4s,v5.4s,v2.s[0] \n\t" // Accummulate. + " fmla v9.4s,v6.4s,v2.s[0] \n\t" // Accummulate. + " ldr q1, [x0, #48] \n\t" + " fmla v10.4s,v5.4s,v2.s[1] \n\t" // Accummulate. + " fmla v11.4s,v6.4s,v2.s[1] \n\t" // Accummulate. + " fmla v12.4s,v5.4s,v2.s[2] \n\t" // Accummulate. + " fmla v13.4s,v6.4s,v2.s[2] \n\t" // Accummulate. + " fmla v14.4s,v5.4s,v2.s[3] \n\t" // Accummulate. + " fmla v15.4s,v6.4s,v2.s[3] \n\t" // Accummulate. + " ldr q2, [x1, #48] \n\t" + " \n\t" + " fmla v16.4s,v5.4s,v3.s[0] \n\t" // Accummulate. + " fmla v17.4s,v6.4s,v3.s[0] \n\t" // Accummulate. + " fmla v18.4s,v5.4s,v3.s[1] \n\t" // Accummulate. 
+ " fmla v19.4s,v6.4s,v3.s[1] \n\t" // Accummulate. + " fmla v20.4s,v5.4s,v3.s[2] \n\t" // Accummulate. + " fmla v21.4s,v6.4s,v3.s[2] \n\t" // Accummulate. + " fmla v22.4s,v5.4s,v3.s[3] \n\t" // Accummulate. + " fmla v23.4s,v6.4s,v3.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v24.4s,v5.4s,v4.s[0] \n\t" // Accummulate. + " fmla v26.4s,v5.4s,v4.s[1] \n\t" // Accummulate. + " fmla v28.4s,v5.4s,v4.s[2] \n\t" // Accummulate. + " fmla v30.4s,v5.4s,v4.s[3] \n\t" // Accummulate. + " ldr q3, [x1, #64] \n\t" + " \n\t" + " fmla v25.4s,v6.4s,v4.s[0] \n\t" // Accummulate. + " fmla v27.4s,v6.4s,v4.s[1] \n\t" // Accummulate. + " fmla v29.4s,v6.4s,v4.s[2] \n\t" // Accummulate. + " fmla v31.4s,v6.4s,v4.s[3] \n\t" // Accummulate. + " ldr q4, [x1, #80] \n\t" + " \n\t" //End It 2 + " \n\t" + " ldr q5, [x0, #64] \n\t" + " fmla v8.4s,v0.4s,v2.s[0] \n\t" // Accummulate. + " fmla v9.4s,v1.4s,v2.s[0] \n\t" // Accummulate. + " ldr q6, [x0, #80] \n\t" + " fmla v10.4s,v0.4s,v2.s[1] \n\t" // Accummulate. + " fmla v11.4s,v1.4s,v2.s[1] \n\t" // Accummulate. + " fmla v12.4s,v0.4s,v2.s[2] \n\t" // Accummulate. + " fmla v13.4s,v1.4s,v2.s[2] \n\t" // Accummulate. + " fmla v14.4s,v0.4s,v2.s[3] \n\t" // Accummulate. + " fmla v15.4s,v1.4s,v2.s[3] \n\t" // Accummulate. + " ldr q2, [x1, #96] \n\t" + " \n\t" + " fmla v16.4s,v0.4s,v3.s[0] \n\t" // Accummulate. + " fmla v17.4s,v1.4s,v3.s[0] \n\t" // Accummulate. + " fmla v18.4s,v0.4s,v3.s[1] \n\t" // Accummulate. + " fmla v19.4s,v1.4s,v3.s[1] \n\t" // Accummulate. + " fmla v20.4s,v0.4s,v3.s[2] \n\t" // Accummulate. + " fmla v21.4s,v1.4s,v3.s[2] \n\t" // Accummulate. + " fmla v22.4s,v0.4s,v3.s[3] \n\t" // Accummulate. + " fmla v23.4s,v1.4s,v3.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v24.4s,v0.4s,v4.s[0] \n\t" // Accummulate. + " fmla v26.4s,v0.4s,v4.s[1] \n\t" // Accummulate. + " fmla v28.4s,v0.4s,v4.s[2] \n\t" // Accummulate. + " fmla v30.4s,v0.4s,v4.s[3] \n\t" // Accummulate. + " ldr q3, [x1, #112] \n\t" + " \n\t" + " fmla v25.4s,v1.4s,v4.s[0] \n\t" // Accummulate. + " fmla v27.4s,v1.4s,v4.s[1] \n\t" // Accummulate. + " fmla v29.4s,v1.4s,v4.s[2] \n\t" // Accummulate. + " fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. + " ldr q4, [x1, #128] \n\t" + " \n\t" //End It 3 + " \n\t" + " fmla v8.4s,v5.4s,v2.s[0] \n\t" // Accummulate. + " fmla v9.4s,v6.4s,v2.s[0] \n\t" // Accummulate. + " fmla v10.4s,v5.4s,v2.s[1] \n\t" // Accummulate. + " fmla v11.4s,v6.4s,v2.s[1] \n\t" // Accummulate. + " fmla v12.4s,v5.4s,v2.s[2] \n\t" // Accummulate. + " fmla v13.4s,v6.4s,v2.s[2] \n\t" // Accummulate. + " fmla v14.4s,v5.4s,v2.s[3] \n\t" // Accummulate. + " fmla v15.4s,v6.4s,v2.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v16.4s,v5.4s,v3.s[0] \n\t" // Accummulate. + " fmla v17.4s,v6.4s,v3.s[0] \n\t" // Accummulate. + " fmla v18.4s,v5.4s,v3.s[1] \n\t" // Accummulate. + " fmla v19.4s,v6.4s,v3.s[1] \n\t" // Accummulate. + " fmla v20.4s,v5.4s,v3.s[2] \n\t" // Accummulate. + " fmla v21.4s,v6.4s,v3.s[2] \n\t" // Accummulate. + " fmla v22.4s,v5.4s,v3.s[3] \n\t" // Accummulate. + " fmla v23.4s,v6.4s,v3.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v24.4s,v5.4s,v4.s[0] \n\t" // Accummulate. + " fmla v26.4s,v5.4s,v4.s[1] \n\t" // Accummulate. + " fmla v28.4s,v5.4s,v4.s[2] \n\t" // Accummulate. + " fmla v30.4s,v5.4s,v4.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v25.4s,v6.4s,v4.s[0] \n\t" // Accummulate. + " fmla v27.4s,v6.4s,v4.s[1] \n\t" // Accummulate. + " fmla v29.4s,v6.4s,v4.s[2] \n\t" // Accummulate. + " fmla v31.4s,v6.4s,v4.s[3] \n\t" // Accummulate. 
+ " add x1, x1, #144 \n\t" + " add x0, x0, #96 \n\t" + " \n\t" //End It 4 + " \n\t" + LABEL(SCONSIDERKLEFT) + " cmp x6,0 \n\t" // If k_left == 0, we are done. + BEQ(SPOSTACCUM) // else, we enter the k_left loop. + " \n\t" + LABEL(SLOOPKLEFT) // Body of the left iterations + " \n\t" + " ldr q0, [x0],#16 \n\t" + " ldr q1, [x0],#16 \n\t" // Load a + " \n\t" + " ldr q2, [x1],#16 \n\t" // Load b + " ldr q3, [x1],#16 \n\t" + " ldr q4, [x1],#16 \n\t" + " \n\t" + " sub x6,x6,1 \n\t" // i = i-1. + " \n\t" + " fmla v8.4s,v0.4s,v2.s[0] \n\t" // Accummulate. + " fmla v9.4s,v1.4s,v2.s[0] \n\t" // Accummulate. + " fmla v10.4s,v0.4s,v2.s[1] \n\t" // Accummulate. + " fmla v11.4s,v1.4s,v2.s[1] \n\t" // Accummulate. + " fmla v12.4s,v0.4s,v2.s[2] \n\t" // Accummulate. + " fmla v13.4s,v1.4s,v2.s[2] \n\t" // Accummulate. + " fmla v14.4s,v0.4s,v2.s[3] \n\t" // Accummulate. + " fmla v15.4s,v1.4s,v2.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v16.4s,v0.4s,v3.s[0] \n\t" // Accummulate. + " fmla v17.4s,v1.4s,v3.s[0] \n\t" // Accummulate. + " fmla v18.4s,v0.4s,v3.s[1] \n\t" // Accummulate. + " fmla v19.4s,v1.4s,v3.s[1] \n\t" // Accummulate. + " fmla v20.4s,v0.4s,v3.s[2] \n\t" // Accummulate. + " fmla v21.4s,v1.4s,v3.s[2] \n\t" // Accummulate. + " fmla v22.4s,v0.4s,v3.s[3] \n\t" // Accummulate. + " fmla v23.4s,v1.4s,v3.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v24.4s,v0.4s,v4.s[0] \n\t" // Accummulate. + " fmla v26.4s,v0.4s,v4.s[1] \n\t" // Accummulate. + " fmla v28.4s,v0.4s,v4.s[2] \n\t" // Accummulate. + " fmla v30.4s,v0.4s,v4.s[3] \n\t" // Accummulate. + " fmla v25.4s,v1.4s,v4.s[0] \n\t" // Accummulate. + " fmla v27.4s,v1.4s,v4.s[1] \n\t" // Accummulate. + " fmla v29.4s,v1.4s,v4.s[2] \n\t" // Accummulate. + " fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. + " \n\t" + " cmp x6,0 \n\t" // Iterate again. + BNE(SLOOPKLEFT) // if i!=0. + " \n\t" + LABEL(SPOSTACCUM) + " \n\t" + " ldr x0,%[alpha] \n\t" // Alpha address. + " ldr x1,%[beta] \n\t" // Beta address. + " \n\t" + " ld1r {v6.4s},[x0] \n\t" // Load alpha. + " ld1r {v7.4s},[x1] \n\t" // Load beta + " \n\t" + " ldr x0,%[a_next] \n\t" // Pointer to next block of A. + " ldr x1,%[b_next] \n\t" // Pointer to next pointer of B. + " \n\t" + " cmp x14,#4 \n\t" // If rs_c != 1 (column-major) + BNE(SGENSTORED) + " \n\t" + LABEL(SCOLSTORED) // C is column-major. + " \n\t" + " dup v0.4s, wzr \n\t" + " dup v1.4s, wzr \n\t" + " dup v2.4s, wzr \n\t" + " dup v3.4s, wzr \n\t" + " dup v4.4s, wzr \n\t" + " dup v5.4s, wzr \n\t" + " \n\t" + " fcmp s7,#0.0 \n\t" + BEQ(SBETAZEROCOLSTOREDS1) // Taking care of the beta==0 case. 
+ " \n\t" + " ldr q0, [x2] \n\t" //Load column 0 of C + " ldr q1, [x2, #16] \n\t" + " ldr q2, [x16] \n\t" //Load column 1 of C + " ldr q3, [x16, #16] \n\t" + " ldr q4, [x17] \n\t" //Load column 2 of C + " ldr q5, [x17, #16] \n\t" + " \n\t" + " fmul v0.4s,v0.4s,v7.s[0] \n\t" // Scale by beta + " fmul v1.4s,v1.4s,v7.s[0] \n\t" // Scale by beta + " fmul v2.4s,v2.4s,v7.s[0] \n\t" // Scale by beta + " fmul v3.4s,v3.4s,v7.s[0] \n\t" // Scale by beta + " fmul v4.4s,v4.4s,v7.s[0] \n\t" // Scale by beta + " fmul v5.4s,v5.4s,v7.s[0] \n\t" // Scale by beta + " \n\t" + LABEL(SBETAZEROCOLSTOREDS1) + " \n\t" + " fmla v0.4s,v8.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v1.4s,v9.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v2.4s,v10.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v3.4s,v11.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v4.4s,v12.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v5.4s,v13.4s,v6.s[0] \n\t" // Scale by alpha + " \n\t" + " str q0, [x2] \n\t" //Store column 0 of C + " str q1, [x2, #16] \n\t" + " str q2, [x16] \n\t" //Store column 1 of C + " str q3, [x16, #16] \n\t" + " str q4, [x17] \n\t" //Store column 2 of C + " str q5, [x17, #16] \n\t" + " \n\t" + " dup v8.4s, wzr \n\t" + " dup v9.4s, wzr \n\t" + " dup v10.4s, wzr \n\t" + " dup v11.4s, wzr \n\t" + " dup v12.4s, wzr \n\t" + " dup v13.4s, wzr \n\t" + " \n\t" + " fcmp s7,#0.0 \n\t" + BEQ(SBETAZEROCOLSTOREDS2) // Taking care of the beta==0 case. + " \n\t" + " ldr q8, [x19] \n\t" //Load column 3 of C + " ldr q9, [x19, #16] \n\t" + " ldr q10, [x20] \n\t" //Load column 4 of C + " ldr q11, [x20, #16] \n\t" + " ldr q12, [x21] \n\t" //Load column 5 of C + " ldr q13, [x21, #16] \n\t" + " \n\t" + " fmul v8.4s, v8.4s, v7.s[0] \n\t" // Scale by beta + " fmul v9.4s, v9.4s, v7.s[0] \n\t" // Scale by beta + " fmul v10.4s,v10.4s,v7.s[0] \n\t" // Scale by beta + " fmul v11.4s,v11.4s,v7.s[0] \n\t" // Scale by beta + " fmul v12.4s,v12.4s,v7.s[0] \n\t" // Scale by beta + " fmul v13.4s,v13.4s,v7.s[0] \n\t" // Scale by beta + " \n\t" + LABEL(SBETAZEROCOLSTOREDS2) + " \n\t" + " fmla v8.4s, v14.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v9.4s, v15.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v10.4s,v16.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v11.4s,v17.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v12.4s,v18.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v13.4s,v19.4s,v6.s[0] \n\t" // Scale by alpha + " \n\t" + " str q8, [x19] \n\t" //Store column 3 of C + " str q9, [x19, #16] \n\t" + " str q10, [x20] \n\t" //Store column 4 of C + " str q11, [x20, #16] \n\t" + " str q12, [x21] \n\t" //Store column 5 of C + " str q13, [x21, #16] \n\t" + " \n\t" + " dup v0.4s, wzr \n\t" + " dup v1.4s, wzr \n\t" + " dup v2.4s, wzr \n\t" + " dup v3.4s, wzr \n\t" + " dup v4.4s, wzr \n\t" + " dup v5.4s, wzr \n\t" + " \n\t" + " fcmp s7,#0.0 \n\t" + BEQ(SBETAZEROCOLSTOREDS3) // Taking care of the beta==0 case. 
+ " \n\t" + " ldr q0, [x22] \n\t" //Load column 6 of C + " ldr q1, [x22, #16] \n\t" + " ldr q2, [x23] \n\t" //Load column 7 of C + " ldr q3, [x23, #16] \n\t" + " ldr q4, [x24] \n\t" //Load column 8 of C + " ldr q5, [x24, #16] \n\t" + " \n\t" + " fmul v0.4s,v0.4s,v7.s[0] \n\t" // Scale by beta + " fmul v1.4s,v1.4s,v7.s[0] \n\t" // Scale by beta + " fmul v2.4s,v2.4s,v7.s[0] \n\t" // Scale by beta + " fmul v3.4s,v3.4s,v7.s[0] \n\t" // Scale by beta + " fmul v4.4s,v4.4s,v7.s[0] \n\t" // Scale by beta + " fmul v5.4s,v5.4s,v7.s[0] \n\t" // Scale by beta + " \n\t" + LABEL(SBETAZEROCOLSTOREDS3) + " \n\t" + " fmla v0.4s,v20.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v1.4s,v21.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v2.4s,v22.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v3.4s,v23.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v4.4s,v24.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v5.4s,v25.4s,v6.s[0] \n\t" // Scale by alpha + " \n\t" + " str q0, [x22] \n\t" //Store column 6 of C + " str q1, [x22, #16] \n\t" + " str q2, [x23] \n\t" //Store column 7 of C + " str q3, [x23, #16] \n\t" + " str q4, [x24] \n\t" //Store column 8 of C + " str q5, [x24, #16] \n\t" + " \n\t" + " dup v8.4s, wzr \n\t" + " dup v9.4s, wzr \n\t" + " dup v10.4s, wzr \n\t" + " dup v11.4s, wzr \n\t" + " dup v12.4s, wzr \n\t" + " dup v13.4s, wzr \n\t" + " \n\t" + " fcmp s7,#0.0 \n\t" + BEQ(SBETAZEROCOLSTOREDS4) // Taking care of the beta==0 case. + " \n\t" + " ldr q8, [x25] \n\t" //Load column 9 of C + " ldr q9, [x25, #16] \n\t" + " ldr q10, [x26] \n\t" //Load column 10 of C + " ldr q11, [x26, #16] \n\t" + " ldr q12, [x27] \n\t" //Load column 11 of C + " ldr q13, [x27, #16] \n\t" + " \n\t" + " fmul v8.4s, v8.4s, v7.s[0] \n\t" // Scale by beta + " fmul v9.4s, v9.4s, v7.s[0] \n\t" // Scale by beta + " fmul v10.4s,v10.4s,v7.s[0] \n\t" // Scale by beta + " fmul v11.4s,v11.4s,v7.s[0] \n\t" // Scale by beta + " fmul v12.4s,v12.4s,v7.s[0] \n\t" // Scale by beta + " fmul v13.4s,v13.4s,v7.s[0] \n\t" // Scale by beta + " \n\t" + LABEL(SBETAZEROCOLSTOREDS4) + " \n\t" + " prfm pldl2keep,[x0] \n\t" + " prfm pldl2keep,[x1] \n\t" + " \n\t" + " fmla v8.4s, v26.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v9.4s, v27.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v10.4s,v28.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v11.4s,v29.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v12.4s,v30.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v13.4s,v31.4s,v6.s[0] \n\t" // Scale by alpha + " \n\t" + " str q8, [x25] \n\t" //Store column 9 of C + " str q9, [x25, #16] \n\t" + " str q10, [x26] \n\t" //Store column 10 of C + " str q11, [x26, #16] \n\t" + " str q12, [x27] \n\t" //Store column 11 of C + " str q13, [x27, #16] \n\t" + " \n\t" + " \n\t" + BRANCH(SEND) // Done. + " \n\t" + " \n\t" + LABEL(SGENSTORED) // C is general-stride stored. + " \n\t" + " \n\t" + " dup v0.4s, wzr \n\t" + " dup v1.4s, wzr \n\t" + " dup v2.4s, wzr \n\t" + " dup v3.4s, wzr \n\t" + " dup v4.4s, wzr \n\t" + " dup v5.4s, wzr \n\t" + " \n\t" + " fcmp s7,#0.0 \n\t" + BEQ(SBETAZEROGENSTOREDS1) // Taking care of the beta==0 case. + " \n\t" + " mov x5, x2 \n\t" + " \n\t" + " ld1 {v0.s}[0],[x5],x14 \n\t" // Load c00 into quad and increment by rs_c. + " ld1 {v0.s}[1],[x5],x14 \n\t" // Load c01 into quad and increment by rs_c. + " ld1 {v0.s}[2],[x5],x14 \n\t" // Load c02 into quad and increment by rs_c. + " ld1 {v0.s}[3],[x5],x14 \n\t" // Load c03 into quad and increment by rs_c. + " ld1 {v1.s}[0],[x5],x14 \n\t" // Load c04 into quad and increment by rs_c. 
+ " ld1 {v1.s}[1],[x5],x14 \n\t" // Load c05 into quad and increment by rs_c. + " ld1 {v1.s}[2],[x5],x14 \n\t" // Load c06 into quad and increment by rs_c. + " ld1 {v1.s}[3],[x5],x14 \n\t" // Load c07 into quad and increment by rs_c. + " \n\t" + " mov x5, x16 \n\t" + " \n\t" + " ld1 {v2.s}[0],[x5],x14 \n\t" // Load c10 into quad and increment by rs_c. + " ld1 {v2.s}[1],[x5],x14 \n\t" // Load c11 into quad and increment by rs_c. + " ld1 {v2.s}[2],[x5],x14 \n\t" // Load c12 into quad and increment by rs_c. + " ld1 {v2.s}[3],[x5],x14 \n\t" // Load c13 into quad and increment by rs_c. + " ld1 {v3.s}[0],[x5],x14 \n\t" // Load c14 into quad and increment by rs_c. + " ld1 {v3.s}[1],[x5],x14 \n\t" // Load c15 into quad and increment by rs_c. + " ld1 {v3.s}[2],[x5],x14 \n\t" // Load c16 into quad and increment by rs_c. + " ld1 {v3.s}[3],[x5],x14 \n\t" // Load c17 into quad and increment by rs_c. + " \n\t" + " mov x5, x17 \n\t" + " \n\t" + " ld1 {v4.s}[0],[x5],x14 \n\t" // Load c20 into quad and increment by rs_c. + " ld1 {v4.s}[1],[x5],x14 \n\t" // Load c21 into quad and increment by rs_c. + " ld1 {v4.s}[2],[x5],x14 \n\t" // Load c22 into quad and increment by rs_c. + " ld1 {v4.s}[3],[x5],x14 \n\t" // Load c23 into quad and increment by rs_c. + " ld1 {v5.s}[0],[x5],x14 \n\t" // Load c24 into quad and increment by rs_c. + " ld1 {v5.s}[1],[x5],x14 \n\t" // Load c25 into quad and increment by rs_c. + " ld1 {v5.s}[2],[x5],x14 \n\t" // Load c26 into quad and increment by rs_c. + " ld1 {v5.s}[3],[x5],x14 \n\t" // Load c27 into quad and increment by rs_c. + " \n\t" + " fmul v0.4s,v0.4s,v7.s[0] \n\t" // Scale by beta + " fmul v1.4s,v1.4s,v7.s[0] \n\t" // Scale by beta + " fmul v2.4s,v2.4s,v7.s[0] \n\t" // Scale by beta + " fmul v3.4s,v3.4s,v7.s[0] \n\t" // Scale by beta + " fmul v4.4s,v4.4s,v7.s[0] \n\t" // Scale by beta + " fmul v5.4s,v5.4s,v7.s[0] \n\t" // Scale by beta + " \n\t" + LABEL(SBETAZEROGENSTOREDS1) + " \n\t" + " fmla v0.4s, v8.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v1.4s, v9.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v2.4s,v10.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v3.4s,v11.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v4.4s,v12.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v5.4s,v13.4s,v6.s[0] \n\t" // Scale by alpha + " \n\t" + " mov x5, x2 \n\t" + " \n\t" + " st1 {v0.s}[0],[x5],x14 \n\t" // Store c00 into quad and increment by rs_c. + " st1 {v0.s}[1],[x5],x14 \n\t" // Store c01 into quad and increment by rs_c. + " st1 {v0.s}[2],[x5],x14 \n\t" // Store c02 into quad and increment by rs_c. + " st1 {v0.s}[3],[x5],x14 \n\t" // Store c03 into quad and increment by rs_c. + " st1 {v1.s}[0],[x5],x14 \n\t" // Store c04 into quad and increment by rs_c. + " st1 {v1.s}[1],[x5],x14 \n\t" // Store c05 into quad and increment by rs_c. + " st1 {v1.s}[2],[x5],x14 \n\t" // Store c06 into quad and increment by rs_c. + " st1 {v1.s}[3],[x5],x14 \n\t" // Store c07 into quad and increment by rs_c. + " \n\t" + " mov x5, x16 \n\t" + " \n\t" + " st1 {v2.s}[0],[x5],x14 \n\t" // Store c10 into quad and increment by rs_c. + " st1 {v2.s}[1],[x5],x14 \n\t" // Store c11 into quad and increment by rs_c. + " st1 {v2.s}[2],[x5],x14 \n\t" // Store c12 into quad and increment by rs_c. + " st1 {v2.s}[3],[x5],x14 \n\t" // Store c13 into quad and increment by rs_c. + " st1 {v3.s}[0],[x5],x14 \n\t" // Store c14 into quad and increment by rs_c. + " st1 {v3.s}[1],[x5],x14 \n\t" // Store c15 into quad and increment by rs_c. + " st1 {v3.s}[2],[x5],x14 \n\t" // Store c16 into quad and increment by rs_c. 
+ " st1 {v3.s}[3],[x5],x14 \n\t" // Store c17 into quad and increment by rs_c. + " \n\t" + " mov x5, x17 \n\t" + " \n\t" + " st1 {v4.s}[0],[x5],x14 \n\t" // Store c20 into quad and increment by rs_c. + " st1 {v4.s}[1],[x5],x14 \n\t" // Store c21 into quad and increment by rs_c. + " st1 {v4.s}[2],[x5],x14 \n\t" // Store c22 into quad and increment by rs_c. + " st1 {v4.s}[3],[x5],x14 \n\t" // Store c23 into quad and increment by rs_c. + " st1 {v5.s}[0],[x5],x14 \n\t" // Store c24 into quad and increment by rs_c. + " st1 {v5.s}[1],[x5],x14 \n\t" // Store c25 into quad and increment by rs_c. + " st1 {v5.s}[2],[x5],x14 \n\t" // Store c26 into quad and increment by rs_c. + " st1 {v5.s}[3],[x5],x14 \n\t" // Store c27 into quad and increment by rs_c. + " \n\t" + " dup v8.4s, wzr \n\t" + " dup v9.4s, wzr \n\t" + " dup v10.4s, wzr \n\t" + " dup v11.4s, wzr \n\t" + " dup v12.4s, wzr \n\t" + " dup v13.4s, wzr \n\t" + " \n\t" + " fcmp s7,#0.0 \n\t" + BEQ(SBETAZEROGENSTOREDS2) // Taking care of the beta==0 case. + " \n\t" + " mov x5, x19 \n\t" + " \n\t" + " ld1 {v8.s}[0],[x5],x14 \n\t" // Load c30 into quad and increment by rs_c. + " ld1 {v8.s}[1],[x5],x14 \n\t" // Load c31 into quad and increment by rs_c. + " ld1 {v8.s}[2],[x5],x14 \n\t" // Load c32 into quad and increment by rs_c. + " ld1 {v8.s}[3],[x5],x14 \n\t" // Load c33 into quad and increment by rs_c. + " ld1 {v9.s}[0],[x5],x14 \n\t" // Load c34 into quad and increment by rs_c. + " ld1 {v9.s}[1],[x5],x14 \n\t" // Load c35 into quad and increment by rs_c. + " ld1 {v9.s}[2],[x5],x14 \n\t" // Load c36 into quad and increment by rs_c. + " ld1 {v9.s}[3],[x5],x14 \n\t" // Load c37 into quad and increment by rs_c. + " \n\t" + " mov x5, x20 \n\t" + " \n\t" + " ld1 {v10.s}[0],[x5],x14 \n\t" // Load c40 into quad and increment by rs_c. + " ld1 {v10.s}[1],[x5],x14 \n\t" // Load c41 into quad and increment by rs_c. + " ld1 {v10.s}[2],[x5],x14 \n\t" // Load c42 into quad and increment by rs_c. + " ld1 {v10.s}[3],[x5],x14 \n\t" // Load c43 into quad and increment by rs_c. + " ld1 {v11.s}[0],[x5],x14 \n\t" // Load c44 into quad and increment by rs_c. + " ld1 {v11.s}[1],[x5],x14 \n\t" // Load c45 into quad and increment by rs_c. + " ld1 {v11.s}[2],[x5],x14 \n\t" // Load c46 into quad and increment by rs_c. + " ld1 {v11.s}[3],[x5],x14 \n\t" // Load c47 into quad and increment by rs_c. + " \n\t" + " mov x5, x21 \n\t" + " \n\t" + " ld1 {v12.s}[0],[x5],x14 \n\t" // Load c50 into quad and increment by rs_c. + " ld1 {v12.s}[1],[x5],x14 \n\t" // Load c51 into quad and increment by rs_c. + " ld1 {v12.s}[2],[x5],x14 \n\t" // Load c52 into quad and increment by rs_c. + " ld1 {v12.s}[3],[x5],x14 \n\t" // Load c53 into quad and increment by rs_c. + " ld1 {v13.s}[0],[x5],x14 \n\t" // Load c54 into quad and increment by rs_c. + " ld1 {v13.s}[1],[x5],x14 \n\t" // Load c55 into quad and increment by rs_c. + " ld1 {v13.s}[2],[x5],x14 \n\t" // Load c56 into quad and increment by rs_c. + " ld1 {v13.s}[3],[x5],x14 \n\t" // Load c57 into quad and increment by rs_c. 
+ " \n\t" + " fmul v8.4s, v8.4s, v7.s[0] \n\t" // Scale by beta + " fmul v9.4s, v9.4s, v7.s[0] \n\t" // Scale by beta + " fmul v10.4s,v10.4s,v7.s[0] \n\t" // Scale by beta + " fmul v11.4s,v11.4s,v7.s[0] \n\t" // Scale by beta + " fmul v12.4s,v12.4s,v7.s[0] \n\t" // Scale by beta + " fmul v13.4s,v13.4s,v7.s[0] \n\t" // Scale by beta + " \n\t" + LABEL(SBETAZEROGENSTOREDS2) + " \n\t" + " fmla v8.4s, v14.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v9.4s, v15.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v10.4s,v16.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v11.4s,v17.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v12.4s,v18.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v13.4s,v19.4s,v6.s[0] \n\t" // Scale by alpha + " \n\t" + " mov x5, x19 \n\t" + " \n\t" + " st1 {v8.s}[0],[x5],x14 \n\t" // Store c30 into quad and increment by rs_c. + " st1 {v8.s}[1],[x5],x14 \n\t" // Store c31 into quad and increment by rs_c. + " st1 {v8.s}[2],[x5],x14 \n\t" // Store c32 into quad and increment by rs_c. + " st1 {v8.s}[3],[x5],x14 \n\t" // Store c33 into quad and increment by rs_c. + " st1 {v9.s}[0],[x5],x14 \n\t" // Store c34 into quad and increment by rs_c. + " st1 {v9.s}[1],[x5],x14 \n\t" // Store c35 into quad and increment by rs_c. + " st1 {v9.s}[2],[x5],x14 \n\t" // Store c36 into quad and increment by rs_c. + " st1 {v9.s}[3],[x5],x14 \n\t" // Store c37 into quad and increment by rs_c. + " \n\t" + " mov x5, x20 \n\t" + " \n\t" + " st1 {v10.s}[0],[x5],x14 \n\t" // Store c40 into quad and increment by rs_c. + " st1 {v10.s}[1],[x5],x14 \n\t" // Store c41 into quad and increment by rs_c. + " st1 {v10.s}[2],[x5],x14 \n\t" // Store c42 into quad and increment by rs_c. + " st1 {v10.s}[3],[x5],x14 \n\t" // Store c43 into quad and increment by rs_c. + " st1 {v11.s}[0],[x5],x14 \n\t" // Store c44 into quad and increment by rs_c. + " st1 {v11.s}[1],[x5],x14 \n\t" // Store c45 into quad and increment by rs_c. + " st1 {v11.s}[2],[x5],x14 \n\t" // Store c46 into quad and increment by rs_c. + " st1 {v11.s}[3],[x5],x14 \n\t" // Store c47 into quad and increment by rs_c. + " \n\t" + " mov x5, x21 \n\t" + " \n\t" + " st1 {v12.s}[0],[x5],x14 \n\t" // Store c50 into quad and increment by rs_c. + " st1 {v12.s}[1],[x5],x14 \n\t" // Store c51 into quad and increment by rs_c. + " st1 {v12.s}[2],[x5],x14 \n\t" // Store c52 into quad and increment by rs_c. + " st1 {v12.s}[3],[x5],x14 \n\t" // Store c53 into quad and increment by rs_c. + " st1 {v13.s}[0],[x5],x14 \n\t" // Store c54 into quad and increment by rs_c. + " st1 {v13.s}[1],[x5],x14 \n\t" // Store c55 into quad and increment by rs_c. + " st1 {v13.s}[2],[x5],x14 \n\t" // Store c56 into quad and increment by rs_c. + " st1 {v13.s}[3],[x5],x14 \n\t" // Store c57 into quad and increment by rs_c. + " \n\t" + " dup v0.4s, wzr \n\t" + " dup v1.4s, wzr \n\t" + " dup v2.4s, wzr \n\t" + " dup v3.4s, wzr \n\t" + " dup v4.4s, wzr \n\t" + " dup v5.4s, wzr \n\t" + " \n\t" + " fcmp s7,#0.0 \n\t" + BEQ(SBETAZEROGENSTOREDS3) // Taking care of the beta==0 case. + " \n\t" + " mov x5, x22 \n\t" + " \n\t" + " ld1 {v0.s}[0],[x5],x14 \n\t" // Load c60 into quad and increment by rs_c. + " ld1 {v0.s}[1],[x5],x14 \n\t" // Load c61 into quad and increment by rs_c. + " ld1 {v0.s}[2],[x5],x14 \n\t" // Load c62 into quad and increment by rs_c. + " ld1 {v0.s}[3],[x5],x14 \n\t" // Load c63 into quad and increment by rs_c. + " ld1 {v1.s}[0],[x5],x14 \n\t" // Load c64 into quad and increment by rs_c. + " ld1 {v1.s}[1],[x5],x14 \n\t" // Load c65 into quad and increment by rs_c. 
+ " ld1 {v1.s}[2],[x5],x14 \n\t" // Load c66 into quad and increment by rs_c. + " ld1 {v1.s}[3],[x5],x14 \n\t" // Load c67 into quad and increment by rs_c. + " \n\t" + " mov x5, x23 \n\t" + " \n\t" + " ld1 {v2.s}[0],[x5],x14 \n\t" // Load c70 into quad and increment by rs_c. + " ld1 {v2.s}[1],[x5],x14 \n\t" // Load c71 into quad and increment by rs_c. + " ld1 {v2.s}[2],[x5],x14 \n\t" // Load c72 into quad and increment by rs_c. + " ld1 {v2.s}[3],[x5],x14 \n\t" // Load c73 into quad and increment by rs_c. + " ld1 {v3.s}[0],[x5],x14 \n\t" // Load c74 into quad and increment by rs_c. + " ld1 {v3.s}[1],[x5],x14 \n\t" // Load c75 into quad and increment by rs_c. + " ld1 {v3.s}[2],[x5],x14 \n\t" // Load c76 into quad and increment by rs_c. + " ld1 {v3.s}[3],[x5],x14 \n\t" // Load c77 into quad and increment by rs_c. + " \n\t" + " mov x5, x24 \n\t" + " \n\t" + " ld1 {v4.s}[0],[x5],x14 \n\t" // Load c80 into quad and increment by rs_c. + " ld1 {v4.s}[1],[x5],x14 \n\t" // Load c81 into quad and increment by rs_c. + " ld1 {v4.s}[2],[x5],x14 \n\t" // Load c82 into quad and increment by rs_c. + " ld1 {v4.s}[3],[x5],x14 \n\t" // Load c83 into quad and increment by rs_c. + " ld1 {v5.s}[0],[x5],x14 \n\t" // Load c84 into quad and increment by rs_c. + " ld1 {v5.s}[1],[x5],x14 \n\t" // Load c85 into quad and increment by rs_c. + " ld1 {v5.s}[2],[x5],x14 \n\t" // Load c86 into quad and increment by rs_c. + " ld1 {v5.s}[3],[x5],x14 \n\t" // Load c87 into quad and increment by rs_c. + " \n\t" + " fmul v0.4s,v0.4s,v7.s[0] \n\t" // Scale by beta + " fmul v1.4s,v1.4s,v7.s[0] \n\t" // Scale by beta + " fmul v2.4s,v2.4s,v7.s[0] \n\t" // Scale by beta + " fmul v3.4s,v3.4s,v7.s[0] \n\t" // Scale by beta + " fmul v4.4s,v4.4s,v7.s[0] \n\t" // Scale by beta + " fmul v5.4s,v5.4s,v7.s[0] \n\t" // Scale by beta + " \n\t" + LABEL(SBETAZEROGENSTOREDS3) + " \n\t" + " fmla v0.4s,v20.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v1.4s,v21.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v2.4s,v22.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v3.4s,v23.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v4.4s,v24.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v5.4s,v25.4s,v6.s[0] \n\t" // Scale by alpha + " \n\t" + " mov x5, x22 \n\t" + " \n\t" + " st1 {v0.s}[0],[x5],x14 \n\t" // Store c60 into quad and increment by rs_c. + " st1 {v0.s}[1],[x5],x14 \n\t" // Store c61 into quad and increment by rs_c. + " st1 {v0.s}[2],[x5],x14 \n\t" // Store c62 into quad and increment by rs_c. + " st1 {v0.s}[3],[x5],x14 \n\t" // Store c63 into quad and increment by rs_c. + " st1 {v1.s}[0],[x5],x14 \n\t" // Store c64 into quad and increment by rs_c. + " st1 {v1.s}[1],[x5],x14 \n\t" // Store c65 into quad and increment by rs_c. + " st1 {v1.s}[2],[x5],x14 \n\t" // Store c66 into quad and increment by rs_c. + " st1 {v1.s}[3],[x5],x14 \n\t" // Store c67 into quad and increment by rs_c. + " \n\t" + " mov x5, x23 \n\t" + " \n\t" + " st1 {v2.s}[0],[x5],x14 \n\t" // Store c70 into quad and increment by rs_c. + " st1 {v2.s}[1],[x5],x14 \n\t" // Store c71 into quad and increment by rs_c. + " st1 {v2.s}[2],[x5],x14 \n\t" // Store c72 into quad and increment by rs_c. + " st1 {v2.s}[3],[x5],x14 \n\t" // Store c73 into quad and increment by rs_c. + " st1 {v3.s}[0],[x5],x14 \n\t" // Store c74 into quad and increment by rs_c. + " st1 {v3.s}[1],[x5],x14 \n\t" // Store c75 into quad and increment by rs_c. + " st1 {v3.s}[2],[x5],x14 \n\t" // Store c76 into quad and increment by rs_c. + " st1 {v3.s}[3],[x5],x14 \n\t" // Store c77 into quad and increment by rs_c. 
+ " \n\t" + " mov x5, x24 \n\t" + " \n\t" + " st1 {v4.s}[0],[x5],x14 \n\t" // Store c80 into quad and increment by rs_c. + " st1 {v4.s}[1],[x5],x14 \n\t" // Store c81 into quad and increment by rs_c. + " st1 {v4.s}[2],[x5],x14 \n\t" // Store c82 into quad and increment by rs_c. + " st1 {v4.s}[3],[x5],x14 \n\t" // Store c83 into quad and increment by rs_c. + " st1 {v5.s}[0],[x5],x14 \n\t" // Store c84 into quad and increment by rs_c. + " st1 {v5.s}[1],[x5],x14 \n\t" // Store c85 into quad and increment by rs_c. + " st1 {v5.s}[2],[x5],x14 \n\t" // Store c86 into quad and increment by rs_c. + " st1 {v5.s}[3],[x5],x14 \n\t" // Store c87 into quad and increment by rs_c. + " \n\t" + " dup v8.4s, wzr \n\t" + " dup v9.4s, wzr \n\t" + " dup v10.4s, wzr \n\t" + " dup v11.4s, wzr \n\t" + " dup v12.4s, wzr \n\t" + " dup v13.4s, wzr \n\t" + " \n\t" + " fcmp s7,#0.0 \n\t" + BEQ(SBETAZEROGENSTOREDS4) // Taking care of the beta==0 case. + " \n\t" + " mov x5, x25 \n\t" + " \n\t" + " ld1 {v8.s}[0],[x5],x14 \n\t" // Load c90 into quad and increment by rs_c. + " ld1 {v8.s}[1],[x5],x14 \n\t" // Load c91 into quad and increment by rs_c. + " ld1 {v8.s}[2],[x5],x14 \n\t" // Load c92 into quad and increment by rs_c. + " ld1 {v8.s}[3],[x5],x14 \n\t" // Load c93 into quad and increment by rs_c. + " ld1 {v9.s}[0],[x5],x14 \n\t" // Load c94 into quad and increment by rs_c. + " ld1 {v9.s}[1],[x5],x14 \n\t" // Load c95 into quad and increment by rs_c. + " ld1 {v9.s}[2],[x5],x14 \n\t" // Load c96 into quad and increment by rs_c. + " ld1 {v9.s}[3],[x5],x14 \n\t" // Load c97 into quad and increment by rs_c. + " \n\t" + " mov x5, x26 \n\t" + " \n\t" + " ld1 {v10.s}[0],[x5],x14 \n\t" // Load c100 into quad and increment by rs_c. + " ld1 {v10.s}[1],[x5],x14 \n\t" // Load c101 into quad and increment by rs_c. + " ld1 {v10.s}[2],[x5],x14 \n\t" // Load c102 into quad and increment by rs_c. + " ld1 {v10.s}[3],[x5],x14 \n\t" // Load c103 into quad and increment by rs_c. + " ld1 {v11.s}[0],[x5],x14 \n\t" // Load c104 into quad and increment by rs_c. + " ld1 {v11.s}[1],[x5],x14 \n\t" // Load c105 into quad and increment by rs_c. + " ld1 {v11.s}[2],[x5],x14 \n\t" // Load c106 into quad and increment by rs_c. + " ld1 {v11.s}[3],[x5],x14 \n\t" // Load c107 into quad and increment by rs_c. + " \n\t" + " mov x5, x27 \n\t" + " \n\t" + " ld1 {v12.s}[0],[x5],x14 \n\t" // Load c110 into quad and increment by rs_c. + " ld1 {v12.s}[1],[x5],x14 \n\t" // Load c111 into quad and increment by rs_c. + " ld1 {v12.s}[2],[x5],x14 \n\t" // Load c112 into quad and increment by rs_c. + " ld1 {v12.s}[3],[x5],x14 \n\t" // Load c113 into quad and increment by rs_c. + " ld1 {v13.s}[0],[x5],x14 \n\t" // Load c114 into quad and increment by rs_c. + " ld1 {v13.s}[1],[x5],x14 \n\t" // Load c115 into quad and increment by rs_c. + " ld1 {v13.s}[2],[x5],x14 \n\t" // Load c116 into quad and increment by rs_c. + " ld1 {v13.s}[3],[x5],x14 \n\t" // Load c117 into quad and increment by rs_c. 
+ " \n\t" + " fmul v8.4s, v8.4s, v7.s[0] \n\t" // Scale by beta + " fmul v9.4s, v9.4s, v7.s[0] \n\t" // Scale by beta + " fmul v10.4s,v10.4s,v7.s[0] \n\t" // Scale by beta + " fmul v11.4s,v11.4s,v7.s[0] \n\t" // Scale by beta + " fmul v12.4s,v12.4s,v7.s[0] \n\t" // Scale by beta + " fmul v13.4s,v13.4s,v7.s[0] \n\t" // Scale by beta + " \n\t" + LABEL(SBETAZEROGENSTOREDS4) + " \n\t" + " prfm pldl2keep,[x0] \n\t" + " prfm pldl2keep,[x1] \n\t" + " \n\t" + " fmla v8.4s, v26.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v9.4s, v27.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v10.4s,v28.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v11.4s,v29.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v12.4s,v30.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v13.4s,v31.4s,v6.s[0] \n\t" // Scale by alpha + " \n\t" + " mov x5, x25 \n\t" + " \n\t" + " st1 {v8.s}[0],[x5],x14 \n\t" // Store c90 into quad and increment by rs_c. + " st1 {v8.s}[1],[x5],x14 \n\t" // Store c91 into quad and increment by rs_c. + " st1 {v8.s}[2],[x5],x14 \n\t" // Store c92 into quad and increment by rs_c. + " st1 {v8.s}[3],[x5],x14 \n\t" // Store c93 into quad and increment by rs_c. + " st1 {v9.s}[0],[x5],x14 \n\t" // Store c94 into quad and increment by rs_c. + " st1 {v9.s}[1],[x5],x14 \n\t" // Store c95 into quad and increment by rs_c. + " st1 {v9.s}[2],[x5],x14 \n\t" // Store c96 into quad and increment by rs_c. + " st1 {v9.s}[3],[x5],x14 \n\t" // Store c97 into quad and increment by rs_c. + " \n\t" + " mov x5, x26 \n\t" + " \n\t" + " st1 {v10.s}[0],[x5],x14 \n\t" // Store c100 into quad and increment by rs_c. + " st1 {v10.s}[1],[x5],x14 \n\t" // Store c101 into quad and increment by rs_c. + " st1 {v10.s}[2],[x5],x14 \n\t" // Store c102 into quad and increment by rs_c. + " st1 {v10.s}[3],[x5],x14 \n\t" // Store c103 into quad and increment by rs_c. + " st1 {v11.s}[0],[x5],x14 \n\t" // Store c104 into quad and increment by rs_c. + " st1 {v11.s}[1],[x5],x14 \n\t" // Store c105 into quad and increment by rs_c. + " st1 {v11.s}[2],[x5],x14 \n\t" // Store c106 into quad and increment by rs_c. + " st1 {v11.s}[3],[x5],x14 \n\t" // Store c107 into quad and increment by rs_c. + " \n\t" + " mov x5, x27 \n\t" + " \n\t" + " st1 {v12.s}[0],[x5],x14 \n\t" // Store c110 into quad and increment by rs_c. + " st1 {v12.s}[1],[x5],x14 \n\t" // Store c111 into quad and increment by rs_c. + " st1 {v12.s}[2],[x5],x14 \n\t" // Store c112 into quad and increment by rs_c. + " st1 {v12.s}[3],[x5],x14 \n\t" // Store c113 into quad and increment by rs_c. + " st1 {v13.s}[0],[x5],x14 \n\t" // Store c114 into quad and increment by rs_c. + " st1 {v13.s}[1],[x5],x14 \n\t" // Store c115 into quad and increment by rs_c. + " st1 {v13.s}[2],[x5],x14 \n\t" // Store c116 into quad and increment by rs_c. + " st1 {v13.s}[3],[x5],x14 \n\t" // Store c147 into quad and increment by rs_c. + " \n\t" + LABEL(SEND) // Done! 
+ " \n\t" + :// output operands (none) + :// input operands + [aaddr] "m" (a), // 0 + [baddr] "m" (b), // 1 + [caddr] "m" (c), // 2 + [k_iter] "m" (k_iter), // 3 + [k_left] "m" (k_left), // 4 + [alpha] "m" (alpha), // 5 + [beta] "m" (beta), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c), // 8 + [a_next] "m" (a_next), // 9 + [b_next] "m" (b_next) // 10 + :// Register clobber list + "x0", "x1", "x2", + "x5", "x6", "x10","x14", + "x16","x17","x19","x20", + "x21","x22","x23","x24", + "x25","x26","x27", + "v0", "v1", "v2", "v3", + "v4", "v5", "v6", "v7", + "v8", "v9", "v10","v11", + "v12","v13","v14","v15", + "v16","v17","v18","v19", + "v20","v21","v22","v23", + "v24","v25","v26","v27", + "v28","v29","v30","v31" + ); + + GEMM_UKR_FLUSH_CT( s ); } @@ -1089,24 +1094,26 @@ LABEL(SEND) // Done! o 4x4 Double precision micro-kernel NOT fully functional yet. o Runnable on ARMv8, compiled with aarch64 GCC. o Use it together with the armv8 BLIS configuration. - o Tested on Juno board. Around 3 GFLOPS @ 1.1 GHz. + o Tested on Juno board. Around 3 GFLOPS @ 1.1 GHz. December 2014. - + * UPDATE OCTOBER 2015: Now is fully functional. * Tested on Juno board. Around 5.6 GFLOPS, 2 A57 cores @ 1.1 GHz. * Tested on Juno board. Around 4 GFLOPS, 4 A53 cores @ 850 MHz. - + * UPDATE NOVEMBER 2015 * Micro-kernel changed to 6x8 * Tested on Juno Board. Around 4 GFLOPS, 1 x A57 core @ 1.1 GHz. * Tested on Juno Board. Around 7.6 GFLOPS, 2 x A57 cores @ 1.1 GHz. - * Tested on Juno board. Around 1.5 GFLOPS, 1 x A53 core @ 850 MHz. + * Tested on Juno board. Around 1.5 GFLOPS, 1 x A53 core @ 850 MHz. * Tested on Juno board. Around 5.5 GFLOPS, 4 x A53 cores @ 850 MHz. */ void bli_dgemm_armv8a_asm_6x8 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, double* restrict alpha, double* restrict a, double* restrict b, @@ -1121,966 +1128,969 @@ void bli_dgemm_armv8a_asm_6x8 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; -__asm__ volatile -( -" \n\t" -" ldr x0,%[aaddr] \n\t" // Load address of A -" ldr x1,%[baddr] \n\t" // Load address of B -" ldr x2,%[caddr] \n\t" // Load address of C -" \n\t" -" ldr x5,%[k_iter] \n\t" // Init guard (k_iter) -" ldr x6,%[k_left] \n\t" // Init guard (k_iter) -" \n\t" -" ldr x10,%[cs_c] \n\t" // Load cs_c -" lsl x10,x10,#3 \n\t" // cs_c * sizeof(double) -" \n\t" -" ldr x14,%[rs_c] \n\t" // Load rs_c. -" lsl x14,x14,#3 \n\t" // rs_c * sizeof(double). -" \n\t" -" add x20,x2,x10 \n\t" //Load address Column 1 of C -" add x21,x20,x10 \n\t" //Load address Column 2 of C -" add x22,x21,x10 \n\t" //Load address Column 3 of C -" add x23,x22,x10 \n\t" //Load address Column 4 of C -" add x24,x23,x10 \n\t" //Load address Column 5 of C -" add x25,x24,x10 \n\t" //Load address Column 6 of C -" add x26,x25,x10 \n\t" //Load address Column 7 of C -" \n\t" -" prfm pldl1keep,[x2] \n\t" // Prefetch c. -" prfm pldl1keep,[x20] \n\t" // Prefetch c. -" prfm pldl1keep,[x21] \n\t" // Prefetch c. -" prfm pldl1keep,[x22] \n\t" // Prefetch c. -" prfm pldl1keep,[x23] \n\t" // Prefetch c. -" prfm pldl1keep,[x24] \n\t" // Prefetch c. -" prfm pldl1keep,[x25] \n\t" // Prefetch c. -" prfm pldl1keep,[x26] \n\t" // Prefetch c. 
-" \n\t" -" dup v8.2d, xzr \n\t" // Vector for accummulating column 0 -" prfm PLDL1KEEP, [x1, #256] \n\t" -" dup v9.2d, xzr \n\t" // Vector for accummulating column 0 -" prfm PLDL1KEEP, [x1, #320] \n\t" -" dup v10.2d, xzr \n\t" // Vector for accummulating column 0 -" prfm PLDL1KEEP, [x1, #384] \n\t" -" dup v11.2d, xzr \n\t" // Vector for accummulating column 1 -" prfm PLDL1KEEP, [x1, #448] \n\t" -" dup v12.2d, xzr \n\t" // Vector for accummulating column 1 -" dup v13.2d, xzr \n\t" // Vector for accummulating column 1 -" \n\t" -" dup v14.2d, xzr \n\t" // Vector for accummulating column 2 -" prfm PLDL1KEEP, [x0, #192] \n\t" -" dup v15.2d, xzr \n\t" // Vector for accummulating column 2 -" prfm PLDL1KEEP, [x0, #256] \n\t" -" dup v16.2d, xzr \n\t" // Vector for accummulating column 2 -" prfm PLDL1KEEP, [x0, #320] \n\t" -" dup v17.2d, xzr \n\t" // Vector for accummulating column 3 -" dup v18.2d, xzr \n\t" // Vector for accummulating column 3 -" dup v19.2d, xzr \n\t" // Vector for accummulating column 3 -" \n\t" -" dup v20.2d, xzr \n\t" // Vector for accummulating column 4 -" dup v21.2d, xzr \n\t" // Vector for accummulating column 4 -" dup v22.2d, xzr \n\t" // Vector for accummulating column 4 -" dup v23.2d, xzr \n\t" // Vector for accummulating column 5 -" dup v24.2d, xzr \n\t" // Vector for accummulating column 5 -" dup v25.2d, xzr \n\t" // Vector for accummulating column 5 -" \n\t" -" dup v26.2d, xzr \n\t" // Vector for accummulating column 6 -" dup v27.2d, xzr \n\t" // Vector for accummulating column 6 -" dup v28.2d, xzr \n\t" // Vector for accummulating column 6 -" dup v29.2d, xzr \n\t" // Vector for accummulating column 7 -" dup v30.2d, xzr \n\t" // Vector for accummulating column 7 -" dup v31.2d, xzr \n\t" // Vector for accummulating column 7 -" \n\t" -" \n\t" -" cmp x5,#0 \n\t" // If k_iter == 0, jump to k_left. -BEQ(DCONSIDERKLEFT) -" \n\t" -" ldr q0, [x0] \n\t" // Load a -" ldr q1, [x0, #16] \n\t" -" ldr q2, [x0, #32] \n\t" -" \n\t" -" ldr q3, [x1] \n\t" // Load b -" ldr q4, [x1, #16] \n\t" -" ldr q5, [x1, #32] \n\t" -" ldr q6, [x1, #48] \n\t" -" \n\t" -" add x0, x0, #48 \n\t" //update address of A -" add x1, x1, #64 \n\t" //update address of B -" \n\t" -" cmp x5,1 \n\t" // If there is just one k_iter, jump to that one. -BEQ(DLASTITER) // (as loop is do-while-like). 
-" \n\t" -LABEL(DLOOP) // Body -" \n\t" -" fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate -" prfm PLDL1KEEP, [x1, #448] \n\t" //512-64=448 -" fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate -" prfm PLDL1KEEP, [x1, #512] \n\t" -" fmla v10.2d,v2.2d,v3.d[0] \n\t" // Accummulate -" prfm PLDL1KEEP, [x1, #576] \n\t" -" \n\t" -" fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate -" fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate -" fmla v13.2d,v2.2d,v3.d[1] \n\t" // Accummulate -" \n\t" -" fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate -" fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate -" fmla v16.2d,v2.2d,v4.d[0] \n\t" // Accummulate -" ldr q3, [x1] \n\t" -" \n\t" -" fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate -" fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate -" fmla v19.2d,v2.2d,v4.d[1] \n\t" // Accummulate -" ldr q7, [x0, #32] \n\t" -" \n\t" -" fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate -" fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate -" fmla v22.2d,v2.2d,v5.d[0] \n\t" // Accummulate -" ldr q4, [x1, #16] \n\t" -" \n\t" -" fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate -" fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate -" fmla v25.2d,v2.2d,v5.d[1] \n\t" // Accummulate -" ldr q5, [x1, #32] \n\t" -" \n\t" -" fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate -" fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate -" ldr q0, [x0] \n\t" -" \n\t" -" fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate -" fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate -" ldr q1, [x0, #16] \n\t" -" \n\t" -" fmla v28.2d,v2.2d,v6.d[0] \n\t" // Accummulate -" fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate -" ldr q6, [x1, #48] \n\t" -" \n\t" // End it 1 -" fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate -" prfm PLDL1KEEP, [x1, #640] \n\t" -" fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate -" prfm PLDL1KEEP, [x0, #336] \n\t" -" fmla v10.2d,v7.2d,v3.d[0] \n\t" // Accummulate -" prfm PLDL1KEEP, [x0, #400] \n\t" -" \n\t" -" fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate -" fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate -" fmla v13.2d,v7.2d,v3.d[1] \n\t" // Accummulate -" \n\t" -" fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate -" fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate -" fmla v16.2d,v7.2d,v4.d[0] \n\t" // Accummulate -" ldr q3, [x1, #64] \n\t" -" \n\t" -" fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate -" fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate -" fmla v19.2d,v7.2d,v4.d[1] \n\t" // Accummulate -" ldr q2, [x0, #80] \n\t" -" \n\t" -" fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate -" fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate -" fmla v22.2d,v7.2d,v5.d[0] \n\t" // Accummulate -" ldr q4, [x1, #80] \n\t" -" \n\t" -" fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate -" fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate -" fmla v25.2d,v7.2d,v5.d[1] \n\t" // Accummulate -" ldr q5, [x1, #96] \n\t" -" \n\t" -" fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate -" fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate -" ldr q0, [x0, #48] \n\t" -" \n\t" -" fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate -" fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate -" ldr q1, [x0, #64] \n\t" -" \n\t" -" fmla v28.2d,v7.2d,v6.d[0] \n\t" // Accummulate -" fmla v31.2d,v7.2d,v6.d[1] \n\t" // Accummulate -" ldr q6, [x1, #112] \n\t" -" \n\t" //End it 2 -" fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate -" prfm PLDL1KEEP, [x0, #464] \n\t" -" fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate -" fmla v10.2d,v2.2d,v3.d[0] \n\t" // Accummulate -" \n\t" -" fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate -" fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate -" fmla v13.2d,v2.2d,v3.d[1] \n\t" 
// Accummulate -" \n\t" -" fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate -" fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate -" fmla v16.2d,v2.2d,v4.d[0] \n\t" // Accummulate -" ldr q3, [x1, #128] \n\t" -" \n\t" -" fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate -" fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate -" fmla v19.2d,v2.2d,v4.d[1] \n\t" // Accummulate -" ldr q7, [x0, #128] \n\t" -" \n\t" -" fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate -" fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate -" fmla v22.2d,v2.2d,v5.d[0] \n\t" // Accummulate -" ldr q4, [x1, #144] \n\t" -" \n\t" -" fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate -" fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate -" fmla v25.2d,v2.2d,v5.d[1] \n\t" // Accummulate -" ldr q5, [x1, #160] \n\t" -" \n\t" -" fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate -" fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate -" ldr q0, [x0, #96] \n\t" -" \n\t" -" fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate -" fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate -" ldr q1, [x0, #112] \n\t" -" \n\t" -" fmla v28.2d,v2.2d,v6.d[0] \n\t" // Accummulate -" fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate -" ldr q6, [x1, #176] \n\t" -" \n\t" // End it 3 -" fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate -" fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate -" fmla v10.2d,v7.2d,v3.d[0] \n\t" // Accummulate -" \n\t" -" fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate -" fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate -" fmla v13.2d,v7.2d,v3.d[1] \n\t" // Accummulate -" ldr q3, [x1, #192] \n\t" -" \n\t" -" fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate -" fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate -" fmla v16.2d,v7.2d,v4.d[0] \n\t" // Accummulate -" ldr q2, [x0, #176] \n\t" -" \n\t" -" fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate -" fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate -" fmla v19.2d,v7.2d,v4.d[1] \n\t" // Accummulate -" ldr q4, [x1, #208] \n\t" -" \n\t" -" fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate -" fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate -" fmla v22.2d,v7.2d,v5.d[0] \n\t" // Accummulate -" \n\t" -" fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate -" fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate -" fmla v25.2d,v7.2d,v5.d[1] \n\t" // Accummulate -" ldr q5, [x1, #224] \n\t" -" \n\t" -" fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate -" fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate -" ldr q0, [x0, #144] \n\t" -" \n\t" -" fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate -" fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate -" ldr q1, [x0, #160] \n\t" -" \n\t" -" fmla v28.2d,v7.2d,v6.d[0] \n\t" // Accummulate -" fmla v31.2d,v7.2d,v6.d[1] \n\t" // Accummulate -" ldr q6, [x1, #240] \n\t" -" \n\t" //End it 4 -" add x0, x0, #192 \n\t" -" add x1, x1, #256 \n\t" -" \n\t" -" sub x5,x5,1 \n\t" // i-=1 -" cmp x5,1 \n\t" // Iterate again if we are not in k_iter == 1. 
-BNE(DLOOP) -" \n\t" -LABEL(DLASTITER) -" \n\t" -" fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate -" fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate -" fmla v10.2d,v2.2d,v3.d[0] \n\t" // Accummulate -" \n\t" -" fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate -" fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate -" fmla v13.2d,v2.2d,v3.d[1] \n\t" // Accummulate -" ldr q3, [x1] \n\t" -" \n\t" -" fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate -" fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate -" fmla v16.2d,v2.2d,v4.d[0] \n\t" // Accummulate -" ldr q7, [x0, #32] \n\t" -" \n\t" -" fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate -" fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate -" fmla v19.2d,v2.2d,v4.d[1] \n\t" // Accummulate -" ldr q4, [x1, #16] \n\t" -" \n\t" -" fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate -" fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate -" fmla v22.2d,v2.2d,v5.d[0] \n\t" // Accummulate -" \n\t" -" fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate -" fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate -" fmla v25.2d,v2.2d,v5.d[1] \n\t" // Accummulate -" ldr q5, [x1, #32] \n\t" -" \n\t" -" fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate -" fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate -" ldr q0, [x0] \n\t" -" \n\t" -" fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate -" fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate -" ldr q1, [x0, #16] \n\t" -" \n\t" -" fmla v28.2d,v2.2d,v6.d[0] \n\t" // Accummulate -" fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate -" ldr q6, [x1, #48] \n\t" -" \n\t" // End it 1 -" fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate -" fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate -" fmla v10.2d,v7.2d,v3.d[0] \n\t" // Accummulate -" \n\t" -" fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate -" fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate -" fmla v13.2d,v7.2d,v3.d[1] \n\t" // Accummulate -" ldr q3, [x1, #64] \n\t" -" \n\t" -" fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate -" fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate -" fmla v16.2d,v7.2d,v4.d[0] \n\t" // Accummulate -" ldr q2, [x0, #80] \n\t" -" \n\t" -" fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate -" fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate -" fmla v19.2d,v7.2d,v4.d[1] \n\t" // Accummulate -" ldr q4, [x1, #80] \n\t" -" \n\t" -" fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate -" fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate -" fmla v22.2d,v7.2d,v5.d[0] \n\t" // Accummulate -" \n\t" -" fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate -" fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate -" fmla v25.2d,v7.2d,v5.d[1] \n\t" // Accummulate -" ldr q5, [x1, #96] \n\t" -" \n\t" -" fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate -" fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate -" ldr q0, [x0, #48] \n\t" -" \n\t" -" fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate -" fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate -" ldr q1, [x0, #64] \n\t" -" \n\t" -" fmla v28.2d,v7.2d,v6.d[0] \n\t" // Accummulate -" fmla v31.2d,v7.2d,v6.d[1] \n\t" // Accummulate -" ldr q6, [x1, #112] \n\t" -" \n\t" //End it 2 -" fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate -" fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate -" fmla v10.2d,v2.2d,v3.d[0] \n\t" // Accummulate -" \n\t" -" fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate -" fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate -" fmla v13.2d,v2.2d,v3.d[1] \n\t" // Accummulate -" ldr q3, [x1, #128] \n\t" -" \n\t" -" fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate -" fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate -" fmla v16.2d,v2.2d,v4.d[0] \n\t" // Accummulate -" ldr q7, [x0, #128] \n\t" -" \n\t" -" fmla 
v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate -" fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate -" fmla v19.2d,v2.2d,v4.d[1] \n\t" // Accummulate -" ldr q4, [x1, #144] \n\t" -" \n\t" -" fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate -" fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate -" fmla v22.2d,v2.2d,v5.d[0] \n\t" // Accummulate -" \n\t" -" fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate -" fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate -" fmla v25.2d,v2.2d,v5.d[1] \n\t" // Accummulate -" ldr q5, [x1, #160] \n\t" -" \n\t" -" fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate -" fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate -" ldr q0, [x0, #96] \n\t" -" \n\t" -" fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate -" fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate -" ldr q1, [x0, #112] \n\t" -" \n\t" -" fmla v28.2d,v2.2d,v6.d[0] \n\t" // Accummulate -" fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate -" ldr q6, [x1, #176] \n\t" -" \n\t" // End it 3 -" fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate -" fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate -" fmla v10.2d,v7.2d,v3.d[0] \n\t" // Accummulate -" \n\t" -" fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate -" fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate -" fmla v13.2d,v7.2d,v3.d[1] \n\t" // Accummulate -" \n\t" -" fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate -" fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate -" fmla v16.2d,v7.2d,v4.d[0] \n\t" // Accummulate -" \n\t" -" fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate -" fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate -" fmla v19.2d,v7.2d,v4.d[1] \n\t" // Accummulate -" \n\t" -" fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate -" fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate -" fmla v22.2d,v7.2d,v5.d[0] \n\t" // Accummulate -" \n\t" -" fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate -" fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate -" fmla v25.2d,v7.2d,v5.d[1] \n\t" // Accummulate -" \n\t" -" fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate -" add x1, x1, #192 \n\t" -" fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate -" \n\t" -" fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate -" fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate -" \n\t" -" fmla v28.2d,v7.2d,v6.d[0] \n\t" // Accummulate -" fmla v31.2d,v7.2d,v6.d[1] \n\t" // Accummulate -" \n\t" //End it 4 -" add x0, x0, #144 \n\t" -" \n\t" -LABEL(DCONSIDERKLEFT) -" cmp x6,0 \n\t" // If k_left == 0, we are done. -BEQ(DPOSTACCUM) // else, we enter the k_left loop. 
-" \n\t" -LABEL(DLOOPKLEFT) -" \n\t" -" ldr q0, [x0],#16 \n\t" -" ldr q1, [x0],#16 \n\t" // Load a -" ldr q2, [x0],#16 \n\t" -" \n\t" -" ldr q3, [x1],#16 \n\t" // Load b -" ldr q4, [x1],#16 \n\t" -" ldr q5, [x1],#16 \n\t" -" ldr q6, [x1],#16 \n\t" -" \n\t" -" sub x6,x6,1 \n\t" -" \n\t" -" fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate -" fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate -" fmla v10.2d,v2.2d,v3.d[0] \n\t" // Accummulate -" \n\t" -" fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate -" fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate -" fmla v13.2d,v2.2d,v3.d[1] \n\t" // Accummulate -" \n\t" -" fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate -" fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate -" fmla v16.2d,v2.2d,v4.d[0] \n\t" // Accummulate -" \n\t" -" fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate -" fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate -" fmla v19.2d,v2.2d,v4.d[1] \n\t" // Accummulate -" \n\t" -" fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate -" fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate -" fmla v22.2d,v2.2d,v5.d[0] \n\t" // Accummulate -" \n\t" -" fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate -" fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate -" fmla v25.2d,v2.2d,v5.d[1] \n\t" // Accummulate -" \n\t" -" fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate -" fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate -" \n\t" -" fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate -" fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate -" \n\t" -" fmla v28.2d,v2.2d,v6.d[0] \n\t" // Accummulate -" fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate -" \n\t" -" cmp x6,0 \n\t" // Iterate again. -BNE(DLOOPKLEFT) // if i!=0. -" \n\t" -LABEL(DPOSTACCUM) -" \n\t" -" ldr x0,%[alpha] \n\t" // Alpha address -" ldr x1,%[beta] \n\t" // Beta address -" \n\t" -" ld1r {v6.2d},[x0] \n\t" // Load alpha. -" ld1r {v7.2d},[x1] \n\t" // Load beta -" \n\t" -" ldr x0,%[a_next] \n\t" // Next A address for later use. -" ldr x1,%[b_next] \n\t" // Next B address for later use. -" \n\t" -" cmp x14,#8 \n\t" // If rs_c != 1 (column-major) -BNE(DGENSTORED) -" \n\t" -LABEL(DCOLSTORED) // C is column-major. -" \n\t" -" dup v0.2d, xzr \n\t" -" dup v1.2d, xzr \n\t" -" dup v2.2d, xzr \n\t" -" dup v3.2d, xzr \n\t" -" dup v4.2d, xzr \n\t" -" dup v5.2d, xzr \n\t" -" \n\t" -" fcmp d7,#0.0 \n\t" -BEQ(DBETAZEROCOLSTOREDS1) // Taking care of the beta==0 case. 
-" \n\t" -" ldr q0, [x2] \n\t" //Load column 0 of C -" ldr q1, [x2, #16] \n\t" -" ldr q2, [x2, #32] \n\t" -" \n\t" -" ldr q3, [x20] \n\t" //Load column 1 of C -" ldr q4, [x20, #16] \n\t" -" ldr q5, [x20, #32] \n\t" -" \n\t" -" fmul v0.2d,v0.2d,v7.d[0] \n\t" // Scale by beta -" fmul v1.2d,v1.2d,v7.d[0] \n\t" // Scale by beta -" fmul v2.2d,v2.2d,v7.d[0] \n\t" // Scale by beta -" fmul v3.2d,v3.2d,v7.d[0] \n\t" // Scale by beta -" fmul v4.2d,v4.2d,v7.d[0] \n\t" // Scale by beta -" fmul v5.2d,v5.2d,v7.d[0] \n\t" // Scale by beta -" \n\t" -LABEL(DBETAZEROCOLSTOREDS1) -" \n\t" -" fmla v0.2d,v8.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v1.2d,v9.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v2.2d,v10.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v3.2d,v11.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v4.2d,v12.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v5.2d,v13.2d,v6.d[0] \n\t" // Scale by alpha -" \n\t" -" str q0, [x2] \n\t" //Store column 0 of C -" str q1, [x2, #16] \n\t" -" str q2, [x2, #32] \n\t" -" \n\t" -" str q3, [x20] \n\t" //Store column 1 of C -" str q4, [x20, #16] \n\t" -" str q5, [x20, #32] \n\t" -" \n\t" -" dup v8.2d, xzr \n\t" -" dup v9.2d, xzr \n\t" -" dup v10.2d, xzr \n\t" -" dup v11.2d, xzr \n\t" -" dup v12.2d, xzr \n\t" -" dup v13.2d, xzr \n\t" -" \n\t" -" fcmp d7,#0.0 \n\t" -BEQ(DBETAZEROCOLSTOREDS2) // Taking care of the beta==0 case. -" \n\t" -" ldr q8, [x21] \n\t" //Load column 2 of C -" ldr q9, [x21, #16] \n\t" -" ldr q10, [x21, #32] \n\t" -" \n\t" -" ldr q11, [x22] \n\t" //Load column 3 of C -" ldr q12, [x22, #16] \n\t" -" ldr q13, [x22, #32] \n\t" -" \n\t" -" fmul v8.2d, v8.2d, v7.d[0] \n\t" // Scale by beta -" fmul v9.2d, v9.2d, v7.d[0] \n\t" // Scale by beta -" fmul v10.2d,v10.2d,v7.d[0] \n\t" // Scale by beta -" fmul v11.2d,v11.2d,v7.d[0] \n\t" // Scale by beta -" fmul v12.2d,v12.2d,v7.d[0] \n\t" // Scale by beta -" fmul v13.2d,v13.2d,v7.d[0] \n\t" // Scale by beta -" \n\t" -LABEL(DBETAZEROCOLSTOREDS2) -" \n\t" -" fmla v8.2d, v14.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v9.2d, v15.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v10.2d,v16.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v11.2d,v17.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v12.2d,v18.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v13.2d,v19.2d,v6.d[0] \n\t" // Scale by alpha -" \n\t" -" str q8, [x21] \n\t" //Store column 2 of C -" str q9, [x21, #16] \n\t" -" str q10, [x21, #32] \n\t" -" \n\t" -" str q11, [x22] \n\t" //Store column 3 of C -" str q12, [x22, #16] \n\t" -" str q13, [x22, #32] \n\t" -" \n\t" -" dup v0.2d, xzr \n\t" -" dup v1.2d, xzr \n\t" -" dup v2.2d, xzr \n\t" -" dup v3.2d, xzr \n\t" -" dup v4.2d, xzr \n\t" -" dup v5.2d, xzr \n\t" -" \n\t" -" fcmp d7,#0.0 \n\t" -BEQ(DBETAZEROCOLSTOREDS3) // Taking care of the beta==0 case. 
-" \n\t" -" ldr q0, [x23] \n\t" //Load column 4 of C -" ldr q1, [x23, #16] \n\t" -" ldr q2, [x23, #32] \n\t" -" \n\t" -" ldr q3, [x24] \n\t" //Load column 5 of C -" ldr q4, [x24, #16] \n\t" -" ldr q5, [x24, #32] \n\t" -" \n\t" -" fmul v0.2d,v0.2d,v7.d[0] \n\t" // Scale by beta -" fmul v1.2d,v1.2d,v7.d[0] \n\t" // Scale by beta -" fmul v2.2d,v2.2d,v7.d[0] \n\t" // Scale by beta -" fmul v3.2d,v3.2d,v7.d[0] \n\t" // Scale by beta -" fmul v4.2d,v4.2d,v7.d[0] \n\t" // Scale by beta -" fmul v5.2d,v5.2d,v7.d[0] \n\t" // Scale by beta -" \n\t" -LABEL(DBETAZEROCOLSTOREDS3) -" \n\t" -" fmla v0.2d,v20.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v1.2d,v21.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v2.2d,v22.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v3.2d,v23.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v4.2d,v24.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v5.2d,v25.2d,v6.d[0] \n\t" // Scale by alpha -" \n\t" -" str q0, [x23] \n\t" //Store column 4 of C -" str q1, [x23, #16] \n\t" -" str q2, [x23, #32] \n\t" -" \n\t" -" str q3, [x24] \n\t" //Store column 5 of C -" str q4, [x24, #16] \n\t" -" str q5, [x24, #32] \n\t" -" \n\t" -" dup v8.2d, xzr \n\t" -" dup v9.2d, xzr \n\t" -" dup v10.2d, xzr \n\t" -" dup v11.2d, xzr \n\t" -" dup v12.2d, xzr \n\t" -" dup v13.2d, xzr \n\t" -" \n\t" -" fcmp d7,#0.0 \n\t" -BEQ(DBETAZEROCOLSTOREDS4) // Taking care of the beta==0 case. -" \n\t" -" ldr q8, [x25] \n\t" //Load column 6 of C -" ldr q9, [x25, #16] \n\t" -" ldr q10, [x25, #32] \n\t" -" \n\t" -" ldr q11, [x26] \n\t" //Load column 7 of C -" ldr q12, [x26, #16] \n\t" -" ldr q13, [x26, #32] \n\t" -" \n\t" -" fmul v8.2d, v8.2d, v7.d[0] \n\t" // Scale by beta -" fmul v9.2d, v9.2d, v7.d[0] \n\t" // Scale by beta -" fmul v10.2d,v10.2d,v7.d[0] \n\t" // Scale by beta -" fmul v11.2d,v11.2d,v7.d[0] \n\t" // Scale by beta -" fmul v12.2d,v12.2d,v7.d[0] \n\t" // Scale by beta -" fmul v13.2d,v13.2d,v7.d[0] \n\t" // Scale by beta -" \n\t" -LABEL(DBETAZEROCOLSTOREDS4) -" \n\t" -" prfm pldl2keep,[x0] \n\t" -" prfm pldl2keep,[x1] \n\t" -" \n\t" -" fmla v8.2d, v26.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v9.2d, v27.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v10.2d,v28.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v11.2d,v29.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v12.2d,v30.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v13.2d,v31.2d,v6.d[0] \n\t" // Scale by alpha -" \n\t" -" str q8, [x25] \n\t" //Store column 6 of C -" str q9, [x25, #16] \n\t" -" str q10, [x25, #32] \n\t" -" \n\t" -" str q11, [x26] \n\t" //Store column 7 of C -" str q12, [x26, #16] \n\t" -" str q13, [x26, #32] \n\t" -" \n\t" -BRANCH(DEND) -" \n\t" -LABEL(DGENSTORED) // C is general-stride stored. -" \n\t" -" dup v0.2d, xzr \n\t" -" dup v1.2d, xzr \n\t" -" dup v2.2d, xzr \n\t" -" dup v3.2d, xzr \n\t" -" dup v4.2d, xzr \n\t" -" dup v5.2d, xzr \n\t" -" \n\t" -" fcmp d7,#0.0 \n\t" -BEQ(DBETAZEROGENSTOREDS1) // Taking care of the beta==0 case. -" \n\t" -" mov x27, x2 \n\t" -" \n\t" // Load address of C. -" ld1 {v0.d}[0],[x27],x14 \n\t" // Load c00 into quad and increment by rs_c. -" ld1 {v0.d}[1],[x27],x14 \n\t" // Load c01 into quad and increment by rs_c. -" ld1 {v1.d}[0],[x27],x14 \n\t" // Load c02 into quad and increment by rs_c. -" ld1 {v1.d}[1],[x27],x14 \n\t" // Load c03 into quad and increment by rs_c. -" ld1 {v2.d}[0],[x27],x14 \n\t" // Load c04 into quad and increment by rs_c. -" ld1 {v2.d}[1],[x27],x14 \n\t" // Load c05 into quad and increment by rs_c. -" \n\t" -" mov x27, x20 \n\t" // Load address of C. 
-" \n\t" -" ld1 {v3.d}[0],[x27],x14 \n\t" // Load c10 into quad and increment by rs_c. -" ld1 {v3.d}[1],[x27],x14 \n\t" // Load c11 into quad and increment by rs_c. -" ld1 {v4.d}[0],[x27],x14 \n\t" // Load c12 into quad and increment by rs_c. -" ld1 {v4.d}[1],[x27],x14 \n\t" // Load c13 into quad and increment by rs_c. -" ld1 {v5.d}[0],[x27],x14 \n\t" // Load c14 into quad and increment by rs_c. -" ld1 {v5.d}[1],[x27],x14 \n\t" // Load c15 into quad and increment by rs_c. -" \n\t" -" fmul v0.2d,v0.2d,v7.d[0] \n\t" // Scale by beta -" fmul v1.2d,v1.2d,v7.d[0] \n\t" // Scale by beta -" fmul v2.2d,v2.2d,v7.d[0] \n\t" // Scale by beta -" fmul v3.2d,v3.2d,v7.d[0] \n\t" // Scale by beta -" fmul v4.2d,v4.2d,v7.d[0] \n\t" // Scale by beta -" fmul v5.2d,v5.2d,v7.d[0] \n\t" // Scale by beta -" \n\t" -LABEL(DBETAZEROGENSTOREDS1) -" \n\t" -" fmla v0.2d,v8.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v1.2d,v9.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v2.2d,v10.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v3.2d,v11.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v4.2d,v12.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v5.2d,v13.2d,v6.d[0] \n\t" // Scale by alpha -" \n\t" -" mov x27, x2 \n\t" // Load address of C. -" \n\t" -" st1 {v0.d}[0],[x27],x14 \n\t" // Store c00 into quad and increment by rs_c. -" st1 {v0.d}[1],[x27],x14 \n\t" // Store c01 into quad and increment by rs_c. -" st1 {v1.d}[0],[x27],x14 \n\t" // Store c02 into quad and increment by rs_c. -" st1 {v1.d}[1],[x27],x14 \n\t" // Store c03 into quad and increment by rs_c. -" st1 {v2.d}[0],[x27],x14 \n\t" // Store c04 into quad and increment by rs_c. -" st1 {v2.d}[1],[x27],x14 \n\t" // Store c05 into quad and increment by rs_c. -" \n\t" -" mov x27, x20 \n\t" // Load address of C. -" \n\t" -" st1 {v3.d}[0],[x27],x14 \n\t" // Store c10 into quad and increment by rs_c. -" st1 {v3.d}[1],[x27],x14 \n\t" // Store c11 into quad and increment by rs_c. -" st1 {v4.d}[0],[x27],x14 \n\t" // Store c12 into quad and increment by rs_c. -" st1 {v4.d}[1],[x27],x14 \n\t" // Store c13 into quad and increment by rs_c. -" st1 {v5.d}[0],[x27],x14 \n\t" // Store c14 into quad and increment by rs_c. -" st1 {v5.d}[1],[x27],x14 \n\t" // Store c15 into quad and increment by rs_c. -" \n\t" -" dup v8.2d, xzr \n\t" -" dup v9.2d, xzr \n\t" -" dup v10.2d, xzr \n\t" -" dup v11.2d, xzr \n\t" -" dup v12.2d, xzr \n\t" -" dup v13.2d, xzr \n\t" -" \n\t" -" fcmp d7,#0.0 \n\t" -BEQ(DBETAZEROGENSTOREDS2) // Taking care of the beta==0 case. -" \n\t" -" mov x27, x21 \n\t" // Load address of C. -" \n\t" -" ld1 {v8.d}[0], [x27],x14 \n\t" // Load c20 into quad and increment by rs_c. -" ld1 {v8.d}[1], [x27],x14 \n\t" // Load c21 into quad and increment by rs_c. -" ld1 {v9.d}[0], [x27],x14 \n\t" // Load c22 into quad and increment by rs_c. -" ld1 {v9.d}[1], [x27],x14 \n\t" // Load c23 into quad and increment by rs_c. -" ld1 {v10.d}[0],[x27],x14 \n\t" // Load c24 into quad and increment by rs_c. -" ld1 {v10.d}[1],[x27],x14 \n\t" // Load c25 into quad and increment by rs_c. -" \n\t" -" mov x27, x22 \n\t" // Load address of C. -" \n\t" -" ld1 {v11.d}[0],[x27],x14 \n\t" // Load c30 into quad and increment by rs_c. -" ld1 {v11.d}[1],[x27],x14 \n\t" // Load c31 into quad and increment by rs_c. -" ld1 {v12.d}[0],[x27],x14 \n\t" // Load c32 into quad and increment by rs_c. -" ld1 {v12.d}[1],[x27],x14 \n\t" // Load c33 into quad and increment by rs_c. -" ld1 {v13.d}[0],[x27],x14 \n\t" // Load c34 into quad and increment by rs_c. -" ld1 {v13.d}[1],[x27],x14 \n\t" // Load c35 into quad and increment by rs_c. 
-" \n\t" -" fmul v8.2d, v8.2d, v7.d[0] \n\t" // Scale by beta -" fmul v9.2d, v9.2d, v7.d[0] \n\t" // Scale by beta -" fmul v10.2d,v10.2d,v7.d[0] \n\t" // Scale by beta -" fmul v11.2d,v11.2d,v7.d[0] \n\t" // Scale by beta -" fmul v12.2d,v12.2d,v7.d[0] \n\t" // Scale by beta -" fmul v13.2d,v13.2d,v7.d[0] \n\t" // Scale by beta -" \n\t" -LABEL(DBETAZEROGENSTOREDS2) -" \n\t" -" fmla v8.2d, v14.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v9.2d, v15.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v10.2d,v16.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v11.2d,v17.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v12.2d,v18.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v13.2d,v19.2d,v6.d[0] \n\t" // Scale by alpha -" \n\t" -" mov x27, x21 \n\t" // Load address of C. -" \n\t" -" st1 {v8.d}[0], [x27],x14 \n\t" // Store c20 into quad and increment by rs_c. -" st1 {v8.d}[1], [x27],x14 \n\t" // Store c21 into quad and increment by rs_c. -" st1 {v9.d}[0], [x27],x14 \n\t" // Store c22 into quad and increment by rs_c. -" st1 {v9.d}[1], [x27],x14 \n\t" // Store c23 into quad and increment by rs_c. -" st1 {v10.d}[0],[x27],x14 \n\t" // Store c24 into quad and increment by rs_c. -" st1 {v10.d}[1],[x27],x14 \n\t" // Store c25 into quad and increment by rs_c. -" \n\t" -" mov x27, x22 \n\t" // Load address of C. -" \n\t" -" st1 {v11.d}[0],[x27],x14 \n\t" // Store c30 into quad and increment by rs_c. -" st1 {v11.d}[1],[x27],x14 \n\t" // Store c31 into quad and increment by rs_c. -" st1 {v12.d}[0],[x27],x14 \n\t" // Store c32 into quad and increment by rs_c. -" st1 {v12.d}[1],[x27],x14 \n\t" // Store c33 into quad and increment by rs_c. -" st1 {v13.d}[0],[x27],x14 \n\t" // Store c34 into quad and increment by rs_c. -" st1 {v13.d}[1],[x27],x14 \n\t" // Store c35 into quad and increment by rs_c. -" \n\t" -" dup v0.2d, xzr \n\t" -" dup v1.2d, xzr \n\t" -" dup v2.2d, xzr \n\t" -" dup v3.2d, xzr \n\t" -" dup v4.2d, xzr \n\t" -" dup v5.2d, xzr \n\t" -" \n\t" -" fcmp d7,#0.0 \n\t" -BEQ(DBETAZEROGENSTOREDS3) // Taking care of the beta==0 case. -" \n\t" -" mov x27, x23 \n\t" // Load address of C. -" \n\t" -" ld1 {v0.d}[0],[x27],x14 \n\t" // Load c40 into quad and increment by rs_c. -" ld1 {v0.d}[1],[x27],x14 \n\t" // Load c41 into quad and increment by rs_c. -" ld1 {v1.d}[0],[x27],x14 \n\t" // Load c42 into quad and increment by rs_c. -" ld1 {v1.d}[1],[x27],x14 \n\t" // Load c43 into quad and increment by rs_c. -" ld1 {v2.d}[0],[x27],x14 \n\t" // Load c44 into quad and increment by rs_c. -" ld1 {v2.d}[1],[x27],x14 \n\t" // Load c45 into quad and increment by rs_c. -" \n\t" -" mov x27, x24 \n\t" // Load address of C. -" \n\t" -" ld1 {v3.d}[0],[x27],x14 \n\t" // Load c50 into quad and increment by rs_c. -" ld1 {v3.d}[1],[x27],x14 \n\t" // Load c51 into quad and increment by rs_c. -" ld1 {v4.d}[0],[x27],x14 \n\t" // Load c52 into quad and increment by rs_c. -" ld1 {v4.d}[1],[x27],x14 \n\t" // Load c53 into quad and increment by rs_c. -" ld1 {v5.d}[0],[x27],x14 \n\t" // Load c54 into quad and increment by rs_c. -" ld1 {v5.d}[1],[x27],x14 \n\t" // Load c55 into quad and increment by rs_c. 
-" \n\t" -" fmul v0.2d,v0.2d,v7.d[0] \n\t" // Scale by beta -" fmul v1.2d,v1.2d,v7.d[0] \n\t" // Scale by beta -" fmul v2.2d,v2.2d,v7.d[0] \n\t" // Scale by beta -" fmul v3.2d,v3.2d,v7.d[0] \n\t" // Scale by beta -" fmul v4.2d,v4.2d,v7.d[0] \n\t" // Scale by beta -" fmul v5.2d,v5.2d,v7.d[0] \n\t" // Scale by beta -" \n\t" -LABEL(DBETAZEROGENSTOREDS3) -" \n\t" -" fmla v0.2d,v20.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v1.2d,v21.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v2.2d,v22.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v3.2d,v23.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v4.2d,v24.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v5.2d,v25.2d,v6.d[0] \n\t" // Scale by alpha -" \n\t" -" mov x27, x23 \n\t" // Load address of C. -" \n\t" -" st1 {v0.d}[0],[x27],x14 \n\t" // Store c40 into quad and increment by rs_c. -" st1 {v0.d}[1],[x27],x14 \n\t" // Store c41 into quad and increment by rs_c. -" st1 {v1.d}[0],[x27],x14 \n\t" // Store c42 into quad and increment by rs_c. -" st1 {v1.d}[1],[x27],x14 \n\t" // Store c43 into quad and increment by rs_c. -" st1 {v2.d}[0],[x27],x14 \n\t" // Store c44 into quad and increment by rs_c. -" st1 {v2.d}[1],[x27],x14 \n\t" // Store c45 into quad and increment by rs_c. -" \n\t" -" mov x27, x24 \n\t" // Load address of C. -" \n\t" -" st1 {v3.d}[0],[x27],x14 \n\t" // Store c50 into quad and increment by rs_c. -" st1 {v3.d}[1],[x27],x14 \n\t" // Store c51 into quad and increment by rs_c. -" st1 {v4.d}[0],[x27],x14 \n\t" // Store c52 into quad and increment by rs_c. -" st1 {v4.d}[1],[x27],x14 \n\t" // Store c53 into quad and increment by rs_c. -" st1 {v5.d}[0],[x27],x14 \n\t" // Store c54 into quad and increment by rs_c. -" st1 {v5.d}[1],[x27],x14 \n\t" // Store c55 into quad and increment by rs_c. -" \n\t" -" dup v8.2d, xzr \n\t" -" dup v9.2d, xzr \n\t" -" dup v10.2d, xzr \n\t" -" dup v11.2d, xzr \n\t" -" dup v12.2d, xzr \n\t" -" dup v13.2d, xzr \n\t" -" \n\t" -" fcmp d7,#0.0 \n\t" -BEQ(DBETAZEROGENSTOREDS4) // Taking care of the beta==0 case. -" \n\t" -" mov x27, x25 \n\t" -" \n\t" -" ld1 {v8.d}[0], [x27],x14 \n\t" // Load c60 into quad and increment by rs_c. -" ld1 {v8.d}[1], [x27],x14 \n\t" // Load c61 into quad and increment by rs_c. -" ld1 {v9.d}[0], [x27],x14 \n\t" // Load c62 into quad and increment by rs_c. -" ld1 {v9.d}[1], [x27],x14 \n\t" // Load c63 into quad and increment by rs_c. -" ld1 {v10.d}[0],[x27],x14 \n\t" // Load c64 into quad and increment by rs_c. -" ld1 {v10.d}[1],[x27],x14 \n\t" // Load c65 into quad and increment by rs_c. -" \n\t" -" mov x27, x26 \n\t" // Load address of C. -" \n\t" -" ld1 {v11.d}[0],[x27],x14 \n\t" // Load c70 into quad and increment by rs_c. -" ld1 {v11.d}[1],[x27],x14 \n\t" // Load c71 into quad and increment by rs_c. -" ld1 {v12.d}[0],[x27],x14 \n\t" // Load c72 into quad and increment by rs_c. -" ld1 {v12.d}[1],[x27],x14 \n\t" // Load c73 into quad and increment by rs_c. -" ld1 {v13.d}[0],[x27],x14 \n\t" // Load c74 into quad and increment by rs_c. -" ld1 {v13.d}[1],[x27],x14 \n\t" // Load c75 into quad and increment by rs_c. 
-" \n\t" -" fmul v8.2d, v8.2d, v7.d[0] \n\t" // Scale by beta -" fmul v9.2d, v9.2d, v7.d[0] \n\t" // Scale by beta -" fmul v10.2d,v10.2d,v7.d[0] \n\t" // Scale by beta -" fmul v11.2d,v11.2d,v7.d[0] \n\t" // Scale by beta -" fmul v12.2d,v12.2d,v7.d[0] \n\t" // Scale by beta -" fmul v13.2d,v13.2d,v7.d[0] \n\t" // Scale by beta -" \n\t" -LABEL(DBETAZEROGENSTOREDS4) -" \n\t" -" prfm pldl2keep,[x0] \n\t" -" prfm pldl2keep,[x1] \n\t" -" \n\t" -" fmla v8.2d, v26.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v9.2d, v27.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v10.2d,v28.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v11.2d,v29.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v12.2d,v30.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v13.2d,v31.2d,v6.d[0] \n\t" // Scale by alpha -" \n\t" -" mov x27, x25 \n\t" // Load address of C. -" \n\t" -" st1 {v8.d}[0], [x27],x14 \n\t" // Store c60 into quad and increment by rs_c. -" st1 {v8.d}[1], [x27],x14 \n\t" // Store c61 into quad and increment by rs_c. -" st1 {v9.d}[0], [x27],x14 \n\t" // Store c62 into quad and increment by rs_c. -" st1 {v9.d}[1], [x27],x14 \n\t" // Store c63 into quad and increment by rs_c. -" st1 {v10.d}[0],[x27],x14 \n\t" // Store c64 into quad and increment by rs_c. -" st1 {v10.d}[1],[x27],x14 \n\t" // Store c65 into quad and increment by rs_c. -" \n\t" -" mov x27, x26 \n\t" // Load address of C. -" \n\t" -" st1 {v11.d}[0],[x27],x14 \n\t" // Store c70 into quad and increment by rs_c. -" st1 {v11.d}[1],[x27],x14 \n\t" // Store c71 into quad and increment by rs_c. -" st1 {v12.d}[0],[x27],x14 \n\t" // Store c72 into quad and increment by rs_c. -" st1 {v12.d}[1],[x27],x14 \n\t" // Store c73 into quad and increment by rs_c. -" st1 {v13.d}[0],[x27],x14 \n\t" // Store c74 into quad and increment by rs_c. -" st1 {v13.d}[1],[x27],x14 \n\t" // Store c75 into quad and increment by rs_c. -" \n\t" -LABEL(DEND) // Done! -" \n\t" -:// output operands (none) -:// input operands - [aaddr] "m" (a), // 0 - [baddr] "m" (b), // 1 - [caddr] "m" (c), // 2 - [k_iter] "m" (k_iter), // 3 - [k_left] "m" (k_left), // 4 - [alpha] "m" (alpha), // 5 - [beta] "m" (beta), // 6 - [rs_c] "m" (rs_c), // 6 - [cs_c] "m" (cs_c), // 7 - [a_next] "m" (a_next), // 8 - [b_next] "m" (b_next) // 9 -:// Register clobber list - "x0","x1","x2", - "x5","x6","x10", - "x14","x16","x17", - "x20","x21","x22","x23","x24","x25","x26","x27", - "v0","v1","v2", - "v3","v4","v5", - "v6","v7","v8", - "v9","v10","v11", - "v12","v13","v14", - "v15","v16","v17","v18","v19", - "v20","v21","v22","v23", - "v24","v25","v26","v27", - "v28","v29","v30","v31" -); - + GEMM_UKR_SETUP_CT( d, 6, 8, false ); + __asm__ volatile + ( + " \n\t" + " ldr x0,%[aaddr] \n\t" // Load address of A + " ldr x1,%[baddr] \n\t" // Load address of B + " ldr x2,%[caddr] \n\t" // Load address of C + " \n\t" + " ldr x5,%[k_iter] \n\t" // Init guard (k_iter) + " ldr x6,%[k_left] \n\t" // Init guard (k_iter) + " \n\t" + " ldr x10,%[cs_c] \n\t" // Load cs_c + " lsl x10,x10,#3 \n\t" // cs_c * sizeof(double) + " \n\t" + " ldr x14,%[rs_c] \n\t" // Load rs_c. + " lsl x14,x14,#3 \n\t" // rs_c * sizeof(double). 
+ " \n\t" + " add x20,x2,x10 \n\t" //Load address Column 1 of C + " add x21,x20,x10 \n\t" //Load address Column 2 of C + " add x22,x21,x10 \n\t" //Load address Column 3 of C + " add x23,x22,x10 \n\t" //Load address Column 4 of C + " add x24,x23,x10 \n\t" //Load address Column 5 of C + " add x25,x24,x10 \n\t" //Load address Column 6 of C + " add x26,x25,x10 \n\t" //Load address Column 7 of C + " \n\t" + " prfm pldl1keep,[x2] \n\t" // Prefetch c. + " prfm pldl1keep,[x20] \n\t" // Prefetch c. + " prfm pldl1keep,[x21] \n\t" // Prefetch c. + " prfm pldl1keep,[x22] \n\t" // Prefetch c. + " prfm pldl1keep,[x23] \n\t" // Prefetch c. + " prfm pldl1keep,[x24] \n\t" // Prefetch c. + " prfm pldl1keep,[x25] \n\t" // Prefetch c. + " prfm pldl1keep,[x26] \n\t" // Prefetch c. + " \n\t" + " dup v8.2d, xzr \n\t" // Vector for accummulating column 0 + " prfm PLDL1KEEP, [x1, #256] \n\t" + " dup v9.2d, xzr \n\t" // Vector for accummulating column 0 + " prfm PLDL1KEEP, [x1, #320] \n\t" + " dup v10.2d, xzr \n\t" // Vector for accummulating column 0 + " prfm PLDL1KEEP, [x1, #384] \n\t" + " dup v11.2d, xzr \n\t" // Vector for accummulating column 1 + " prfm PLDL1KEEP, [x1, #448] \n\t" + " dup v12.2d, xzr \n\t" // Vector for accummulating column 1 + " dup v13.2d, xzr \n\t" // Vector for accummulating column 1 + " \n\t" + " dup v14.2d, xzr \n\t" // Vector for accummulating column 2 + " prfm PLDL1KEEP, [x0, #192] \n\t" + " dup v15.2d, xzr \n\t" // Vector for accummulating column 2 + " prfm PLDL1KEEP, [x0, #256] \n\t" + " dup v16.2d, xzr \n\t" // Vector for accummulating column 2 + " prfm PLDL1KEEP, [x0, #320] \n\t" + " dup v17.2d, xzr \n\t" // Vector for accummulating column 3 + " dup v18.2d, xzr \n\t" // Vector for accummulating column 3 + " dup v19.2d, xzr \n\t" // Vector for accummulating column 3 + " \n\t" + " dup v20.2d, xzr \n\t" // Vector for accummulating column 4 + " dup v21.2d, xzr \n\t" // Vector for accummulating column 4 + " dup v22.2d, xzr \n\t" // Vector for accummulating column 4 + " dup v23.2d, xzr \n\t" // Vector for accummulating column 5 + " dup v24.2d, xzr \n\t" // Vector for accummulating column 5 + " dup v25.2d, xzr \n\t" // Vector for accummulating column 5 + " \n\t" + " dup v26.2d, xzr \n\t" // Vector for accummulating column 6 + " dup v27.2d, xzr \n\t" // Vector for accummulating column 6 + " dup v28.2d, xzr \n\t" // Vector for accummulating column 6 + " dup v29.2d, xzr \n\t" // Vector for accummulating column 7 + " dup v30.2d, xzr \n\t" // Vector for accummulating column 7 + " dup v31.2d, xzr \n\t" // Vector for accummulating column 7 + " \n\t" + " \n\t" + " cmp x5,#0 \n\t" // If k_iter == 0, jump to k_left. + BEQ(DCONSIDERKLEFT) + " \n\t" + " ldr q0, [x0] \n\t" // Load a + " ldr q1, [x0, #16] \n\t" + " ldr q2, [x0, #32] \n\t" + " \n\t" + " ldr q3, [x1] \n\t" // Load b + " ldr q4, [x1, #16] \n\t" + " ldr q5, [x1, #32] \n\t" + " ldr q6, [x1, #48] \n\t" + " \n\t" + " add x0, x0, #48 \n\t" //update address of A + " add x1, x1, #64 \n\t" //update address of B + " \n\t" + " cmp x5,1 \n\t" // If there is just one k_iter, jump to that one. + BEQ(DLASTITER) // (as loop is do-while-like). 
+ " \n\t" + LABEL(DLOOP) // Body + " \n\t" + " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate + " prfm PLDL1KEEP, [x1, #448] \n\t" //512-64=448 + " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate + " prfm PLDL1KEEP, [x1, #512] \n\t" + " fmla v10.2d,v2.2d,v3.d[0] \n\t" // Accummulate + " prfm PLDL1KEEP, [x1, #576] \n\t" + " \n\t" + " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate + " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate + " fmla v13.2d,v2.2d,v3.d[1] \n\t" // Accummulate + " \n\t" + " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate + " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate + " fmla v16.2d,v2.2d,v4.d[0] \n\t" // Accummulate + " ldr q3, [x1] \n\t" + " \n\t" + " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate + " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate + " fmla v19.2d,v2.2d,v4.d[1] \n\t" // Accummulate + " ldr q7, [x0, #32] \n\t" + " \n\t" + " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate + " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate + " fmla v22.2d,v2.2d,v5.d[0] \n\t" // Accummulate + " ldr q4, [x1, #16] \n\t" + " \n\t" + " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate + " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate + " fmla v25.2d,v2.2d,v5.d[1] \n\t" // Accummulate + " ldr q5, [x1, #32] \n\t" + " \n\t" + " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate + " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate + " ldr q0, [x0] \n\t" + " \n\t" + " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate + " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate + " ldr q1, [x0, #16] \n\t" + " \n\t" + " fmla v28.2d,v2.2d,v6.d[0] \n\t" // Accummulate + " fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate + " ldr q6, [x1, #48] \n\t" + " \n\t" // End it 1 + " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate + " prfm PLDL1KEEP, [x1, #640] \n\t" + " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate + " prfm PLDL1KEEP, [x0, #336] \n\t" + " fmla v10.2d,v7.2d,v3.d[0] \n\t" // Accummulate + " prfm PLDL1KEEP, [x0, #400] \n\t" + " \n\t" + " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate + " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate + " fmla v13.2d,v7.2d,v3.d[1] \n\t" // Accummulate + " \n\t" + " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate + " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate + " fmla v16.2d,v7.2d,v4.d[0] \n\t" // Accummulate + " ldr q3, [x1, #64] \n\t" + " \n\t" + " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate + " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate + " fmla v19.2d,v7.2d,v4.d[1] \n\t" // Accummulate + " ldr q2, [x0, #80] \n\t" + " \n\t" + " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate + " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate + " fmla v22.2d,v7.2d,v5.d[0] \n\t" // Accummulate + " ldr q4, [x1, #80] \n\t" + " \n\t" + " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate + " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate + " fmla v25.2d,v7.2d,v5.d[1] \n\t" // Accummulate + " ldr q5, [x1, #96] \n\t" + " \n\t" + " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate + " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate + " ldr q0, [x0, #48] \n\t" + " \n\t" + " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate + " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate + " ldr q1, [x0, #64] \n\t" + " \n\t" + " fmla v28.2d,v7.2d,v6.d[0] \n\t" // Accummulate + " fmla v31.2d,v7.2d,v6.d[1] \n\t" // Accummulate + " ldr q6, [x1, #112] \n\t" + " \n\t" //End it 2 + " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate + " prfm PLDL1KEEP, [x0, #464] \n\t" + " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate + " fmla v10.2d,v2.2d,v3.d[0] \n\t" // Accummulate + " \n\t" + " fmla v11.2d,v0.2d,v3.d[1] \n\t" // 
Accummulate + " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate + " fmla v13.2d,v2.2d,v3.d[1] \n\t" // Accummulate + " \n\t" + " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate + " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate + " fmla v16.2d,v2.2d,v4.d[0] \n\t" // Accummulate + " ldr q3, [x1, #128] \n\t" + " \n\t" + " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate + " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate + " fmla v19.2d,v2.2d,v4.d[1] \n\t" // Accummulate + " ldr q7, [x0, #128] \n\t" + " \n\t" + " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate + " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate + " fmla v22.2d,v2.2d,v5.d[0] \n\t" // Accummulate + " ldr q4, [x1, #144] \n\t" + " \n\t" + " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate + " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate + " fmla v25.2d,v2.2d,v5.d[1] \n\t" // Accummulate + " ldr q5, [x1, #160] \n\t" + " \n\t" + " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate + " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate + " ldr q0, [x0, #96] \n\t" + " \n\t" + " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate + " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate + " ldr q1, [x0, #112] \n\t" + " \n\t" + " fmla v28.2d,v2.2d,v6.d[0] \n\t" // Accummulate + " fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate + " ldr q6, [x1, #176] \n\t" + " \n\t" // End it 3 + " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate + " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate + " fmla v10.2d,v7.2d,v3.d[0] \n\t" // Accummulate + " \n\t" + " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate + " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate + " fmla v13.2d,v7.2d,v3.d[1] \n\t" // Accummulate + " ldr q3, [x1, #192] \n\t" + " \n\t" + " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate + " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate + " fmla v16.2d,v7.2d,v4.d[0] \n\t" // Accummulate + " ldr q2, [x0, #176] \n\t" + " \n\t" + " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate + " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate + " fmla v19.2d,v7.2d,v4.d[1] \n\t" // Accummulate + " ldr q4, [x1, #208] \n\t" + " \n\t" + " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate + " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate + " fmla v22.2d,v7.2d,v5.d[0] \n\t" // Accummulate + " \n\t" + " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate + " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate + " fmla v25.2d,v7.2d,v5.d[1] \n\t" // Accummulate + " ldr q5, [x1, #224] \n\t" + " \n\t" + " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate + " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate + " ldr q0, [x0, #144] \n\t" + " \n\t" + " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate + " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate + " ldr q1, [x0, #160] \n\t" + " \n\t" + " fmla v28.2d,v7.2d,v6.d[0] \n\t" // Accummulate + " fmla v31.2d,v7.2d,v6.d[1] \n\t" // Accummulate + " ldr q6, [x1, #240] \n\t" + " \n\t" //End it 4 + " add x0, x0, #192 \n\t" + " add x1, x1, #256 \n\t" + " \n\t" + " sub x5,x5,1 \n\t" // i-=1 + " cmp x5,1 \n\t" // Iterate again if we are not in k_iter == 1. 
+ BNE(DLOOP) + " \n\t" + LABEL(DLASTITER) + " \n\t" + " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate + " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate + " fmla v10.2d,v2.2d,v3.d[0] \n\t" // Accummulate + " \n\t" + " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate + " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate + " fmla v13.2d,v2.2d,v3.d[1] \n\t" // Accummulate + " ldr q3, [x1] \n\t" + " \n\t" + " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate + " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate + " fmla v16.2d,v2.2d,v4.d[0] \n\t" // Accummulate + " ldr q7, [x0, #32] \n\t" + " \n\t" + " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate + " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate + " fmla v19.2d,v2.2d,v4.d[1] \n\t" // Accummulate + " ldr q4, [x1, #16] \n\t" + " \n\t" + " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate + " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate + " fmla v22.2d,v2.2d,v5.d[0] \n\t" // Accummulate + " \n\t" + " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate + " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate + " fmla v25.2d,v2.2d,v5.d[1] \n\t" // Accummulate + " ldr q5, [x1, #32] \n\t" + " \n\t" + " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate + " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate + " ldr q0, [x0] \n\t" + " \n\t" + " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate + " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate + " ldr q1, [x0, #16] \n\t" + " \n\t" + " fmla v28.2d,v2.2d,v6.d[0] \n\t" // Accummulate + " fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate + " ldr q6, [x1, #48] \n\t" + " \n\t" // End it 1 + " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate + " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate + " fmla v10.2d,v7.2d,v3.d[0] \n\t" // Accummulate + " \n\t" + " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate + " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate + " fmla v13.2d,v7.2d,v3.d[1] \n\t" // Accummulate + " ldr q3, [x1, #64] \n\t" + " \n\t" + " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate + " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate + " fmla v16.2d,v7.2d,v4.d[0] \n\t" // Accummulate + " ldr q2, [x0, #80] \n\t" + " \n\t" + " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate + " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate + " fmla v19.2d,v7.2d,v4.d[1] \n\t" // Accummulate + " ldr q4, [x1, #80] \n\t" + " \n\t" + " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate + " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate + " fmla v22.2d,v7.2d,v5.d[0] \n\t" // Accummulate + " \n\t" + " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate + " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate + " fmla v25.2d,v7.2d,v5.d[1] \n\t" // Accummulate + " ldr q5, [x1, #96] \n\t" + " \n\t" + " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate + " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate + " ldr q0, [x0, #48] \n\t" + " \n\t" + " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate + " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate + " ldr q1, [x0, #64] \n\t" + " \n\t" + " fmla v28.2d,v7.2d,v6.d[0] \n\t" // Accummulate + " fmla v31.2d,v7.2d,v6.d[1] \n\t" // Accummulate + " ldr q6, [x1, #112] \n\t" + " \n\t" //End it 2 + " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate + " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate + " fmla v10.2d,v2.2d,v3.d[0] \n\t" // Accummulate + " \n\t" + " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate + " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate + " fmla v13.2d,v2.2d,v3.d[1] \n\t" // Accummulate + " ldr q3, [x1, #128] \n\t" + " \n\t" + " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate + " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate + " fmla 
v16.2d,v2.2d,v4.d[0] \n\t" // Accummulate + " ldr q7, [x0, #128] \n\t" + " \n\t" + " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate + " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate + " fmla v19.2d,v2.2d,v4.d[1] \n\t" // Accummulate + " ldr q4, [x1, #144] \n\t" + " \n\t" + " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate + " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate + " fmla v22.2d,v2.2d,v5.d[0] \n\t" // Accummulate + " \n\t" + " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate + " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate + " fmla v25.2d,v2.2d,v5.d[1] \n\t" // Accummulate + " ldr q5, [x1, #160] \n\t" + " \n\t" + " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate + " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate + " ldr q0, [x0, #96] \n\t" + " \n\t" + " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate + " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate + " ldr q1, [x0, #112] \n\t" + " \n\t" + " fmla v28.2d,v2.2d,v6.d[0] \n\t" // Accummulate + " fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate + " ldr q6, [x1, #176] \n\t" + " \n\t" // End it 3 + " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate + " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate + " fmla v10.2d,v7.2d,v3.d[0] \n\t" // Accummulate + " \n\t" + " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate + " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate + " fmla v13.2d,v7.2d,v3.d[1] \n\t" // Accummulate + " \n\t" + " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate + " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate + " fmla v16.2d,v7.2d,v4.d[0] \n\t" // Accummulate + " \n\t" + " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate + " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate + " fmla v19.2d,v7.2d,v4.d[1] \n\t" // Accummulate + " \n\t" + " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate + " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate + " fmla v22.2d,v7.2d,v5.d[0] \n\t" // Accummulate + " \n\t" + " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate + " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate + " fmla v25.2d,v7.2d,v5.d[1] \n\t" // Accummulate + " \n\t" + " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate + " add x1, x1, #192 \n\t" + " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate + " \n\t" + " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate + " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate + " \n\t" + " fmla v28.2d,v7.2d,v6.d[0] \n\t" // Accummulate + " fmla v31.2d,v7.2d,v6.d[1] \n\t" // Accummulate + " \n\t" //End it 4 + " add x0, x0, #144 \n\t" + " \n\t" + LABEL(DCONSIDERKLEFT) + " cmp x6,0 \n\t" // If k_left == 0, we are done. + BEQ(DPOSTACCUM) // else, we enter the k_left loop. 
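The four-way-unrolled main loop above and the scalar loop that follows implement the usual split k = 4*k_iter + k_left. A hedged C analogue of the control flow, where rank1_6x8 is a hypothetical helper standing in for one group of fmla instructions (the real asm additionally software-pipelines the loads and prefetches):

    for ( uint64_t l = 0; l < k_iter; l++ )             /* k / 4 */
        for ( int u = 0; u < 4; u++, a += 6, b += 8 )
            rank1_6x8( ab, a, b );                      /* ab += a * b^T */

    for ( uint64_t l = 0; l < k_left; l++, a += 6, b += 8 )
        rank1_6x8( ab, a, b );                          /* k % 4 remainder */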
+ " \n\t" + LABEL(DLOOPKLEFT) + " \n\t" + " ldr q0, [x0],#16 \n\t" + " ldr q1, [x0],#16 \n\t" // Load a + " ldr q2, [x0],#16 \n\t" + " \n\t" + " ldr q3, [x1],#16 \n\t" // Load b + " ldr q4, [x1],#16 \n\t" + " ldr q5, [x1],#16 \n\t" + " ldr q6, [x1],#16 \n\t" + " \n\t" + " sub x6,x6,1 \n\t" + " \n\t" + " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate + " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate + " fmla v10.2d,v2.2d,v3.d[0] \n\t" // Accummulate + " \n\t" + " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate + " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate + " fmla v13.2d,v2.2d,v3.d[1] \n\t" // Accummulate + " \n\t" + " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate + " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate + " fmla v16.2d,v2.2d,v4.d[0] \n\t" // Accummulate + " \n\t" + " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate + " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate + " fmla v19.2d,v2.2d,v4.d[1] \n\t" // Accummulate + " \n\t" + " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate + " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate + " fmla v22.2d,v2.2d,v5.d[0] \n\t" // Accummulate + " \n\t" + " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate + " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate + " fmla v25.2d,v2.2d,v5.d[1] \n\t" // Accummulate + " \n\t" + " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate + " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate + " \n\t" + " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate + " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate + " \n\t" + " fmla v28.2d,v2.2d,v6.d[0] \n\t" // Accummulate + " fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate + " \n\t" + " cmp x6,0 \n\t" // Iterate again. + BNE(DLOOPKLEFT) // if i!=0. + " \n\t" + LABEL(DPOSTACCUM) + " \n\t" + " ldr x0,%[alpha] \n\t" // Alpha address + " ldr x1,%[beta] \n\t" // Beta address + " \n\t" + " ld1r {v6.2d},[x0] \n\t" // Load alpha. + " ld1r {v7.2d},[x1] \n\t" // Load beta + " \n\t" + " ldr x0,%[a_next] \n\t" // Next A address for later use. + " ldr x1,%[b_next] \n\t" // Next B address for later use. + " \n\t" + " cmp x14,#8 \n\t" // If rs_c != 1 (column-major) + BNE(DGENSTORED) + " \n\t" + LABEL(DCOLSTORED) // C is column-major. + " \n\t" + " dup v0.2d, xzr \n\t" + " dup v1.2d, xzr \n\t" + " dup v2.2d, xzr \n\t" + " dup v3.2d, xzr \n\t" + " dup v4.2d, xzr \n\t" + " dup v5.2d, xzr \n\t" + " \n\t" + " fcmp d7,#0.0 \n\t" + BEQ(DBETAZEROCOLSTOREDS1) // Taking care of the beta==0 case. 
+ " \n\t" + " ldr q0, [x2] \n\t" //Load column 0 of C + " ldr q1, [x2, #16] \n\t" + " ldr q2, [x2, #32] \n\t" + " \n\t" + " ldr q3, [x20] \n\t" //Load column 1 of C + " ldr q4, [x20, #16] \n\t" + " ldr q5, [x20, #32] \n\t" + " \n\t" + " fmul v0.2d,v0.2d,v7.d[0] \n\t" // Scale by beta + " fmul v1.2d,v1.2d,v7.d[0] \n\t" // Scale by beta + " fmul v2.2d,v2.2d,v7.d[0] \n\t" // Scale by beta + " fmul v3.2d,v3.2d,v7.d[0] \n\t" // Scale by beta + " fmul v4.2d,v4.2d,v7.d[0] \n\t" // Scale by beta + " fmul v5.2d,v5.2d,v7.d[0] \n\t" // Scale by beta + " \n\t" + LABEL(DBETAZEROCOLSTOREDS1) + " \n\t" + " fmla v0.2d,v8.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v1.2d,v9.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v2.2d,v10.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v3.2d,v11.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v4.2d,v12.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v5.2d,v13.2d,v6.d[0] \n\t" // Scale by alpha + " \n\t" + " str q0, [x2] \n\t" //Store column 0 of C + " str q1, [x2, #16] \n\t" + " str q2, [x2, #32] \n\t" + " \n\t" + " str q3, [x20] \n\t" //Store column 1 of C + " str q4, [x20, #16] \n\t" + " str q5, [x20, #32] \n\t" + " \n\t" + " dup v8.2d, xzr \n\t" + " dup v9.2d, xzr \n\t" + " dup v10.2d, xzr \n\t" + " dup v11.2d, xzr \n\t" + " dup v12.2d, xzr \n\t" + " dup v13.2d, xzr \n\t" + " \n\t" + " fcmp d7,#0.0 \n\t" + BEQ(DBETAZEROCOLSTOREDS2) // Taking care of the beta==0 case. + " \n\t" + " ldr q8, [x21] \n\t" //Load column 2 of C + " ldr q9, [x21, #16] \n\t" + " ldr q10, [x21, #32] \n\t" + " \n\t" + " ldr q11, [x22] \n\t" //Load column 3 of C + " ldr q12, [x22, #16] \n\t" + " ldr q13, [x22, #32] \n\t" + " \n\t" + " fmul v8.2d, v8.2d, v7.d[0] \n\t" // Scale by beta + " fmul v9.2d, v9.2d, v7.d[0] \n\t" // Scale by beta + " fmul v10.2d,v10.2d,v7.d[0] \n\t" // Scale by beta + " fmul v11.2d,v11.2d,v7.d[0] \n\t" // Scale by beta + " fmul v12.2d,v12.2d,v7.d[0] \n\t" // Scale by beta + " fmul v13.2d,v13.2d,v7.d[0] \n\t" // Scale by beta + " \n\t" + LABEL(DBETAZEROCOLSTOREDS2) + " \n\t" + " fmla v8.2d, v14.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v9.2d, v15.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v10.2d,v16.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v11.2d,v17.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v12.2d,v18.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v13.2d,v19.2d,v6.d[0] \n\t" // Scale by alpha + " \n\t" + " str q8, [x21] \n\t" //Store column 2 of C + " str q9, [x21, #16] \n\t" + " str q10, [x21, #32] \n\t" + " \n\t" + " str q11, [x22] \n\t" //Store column 3 of C + " str q12, [x22, #16] \n\t" + " str q13, [x22, #32] \n\t" + " \n\t" + " dup v0.2d, xzr \n\t" + " dup v1.2d, xzr \n\t" + " dup v2.2d, xzr \n\t" + " dup v3.2d, xzr \n\t" + " dup v4.2d, xzr \n\t" + " dup v5.2d, xzr \n\t" + " \n\t" + " fcmp d7,#0.0 \n\t" + BEQ(DBETAZEROCOLSTOREDS3) // Taking care of the beta==0 case. 
+ " \n\t" + " ldr q0, [x23] \n\t" //Load column 4 of C + " ldr q1, [x23, #16] \n\t" + " ldr q2, [x23, #32] \n\t" + " \n\t" + " ldr q3, [x24] \n\t" //Load column 5 of C + " ldr q4, [x24, #16] \n\t" + " ldr q5, [x24, #32] \n\t" + " \n\t" + " fmul v0.2d,v0.2d,v7.d[0] \n\t" // Scale by beta + " fmul v1.2d,v1.2d,v7.d[0] \n\t" // Scale by beta + " fmul v2.2d,v2.2d,v7.d[0] \n\t" // Scale by beta + " fmul v3.2d,v3.2d,v7.d[0] \n\t" // Scale by beta + " fmul v4.2d,v4.2d,v7.d[0] \n\t" // Scale by beta + " fmul v5.2d,v5.2d,v7.d[0] \n\t" // Scale by beta + " \n\t" + LABEL(DBETAZEROCOLSTOREDS3) + " \n\t" + " fmla v0.2d,v20.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v1.2d,v21.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v2.2d,v22.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v3.2d,v23.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v4.2d,v24.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v5.2d,v25.2d,v6.d[0] \n\t" // Scale by alpha + " \n\t" + " str q0, [x23] \n\t" //Store column 4 of C + " str q1, [x23, #16] \n\t" + " str q2, [x23, #32] \n\t" + " \n\t" + " str q3, [x24] \n\t" //Store column 5 of C + " str q4, [x24, #16] \n\t" + " str q5, [x24, #32] \n\t" + " \n\t" + " dup v8.2d, xzr \n\t" + " dup v9.2d, xzr \n\t" + " dup v10.2d, xzr \n\t" + " dup v11.2d, xzr \n\t" + " dup v12.2d, xzr \n\t" + " dup v13.2d, xzr \n\t" + " \n\t" + " fcmp d7,#0.0 \n\t" + BEQ(DBETAZEROCOLSTOREDS4) // Taking care of the beta==0 case. + " \n\t" + " ldr q8, [x25] \n\t" //Load column 6 of C + " ldr q9, [x25, #16] \n\t" + " ldr q10, [x25, #32] \n\t" + " \n\t" + " ldr q11, [x26] \n\t" //Load column 7 of C + " ldr q12, [x26, #16] \n\t" + " ldr q13, [x26, #32] \n\t" + " \n\t" + " fmul v8.2d, v8.2d, v7.d[0] \n\t" // Scale by beta + " fmul v9.2d, v9.2d, v7.d[0] \n\t" // Scale by beta + " fmul v10.2d,v10.2d,v7.d[0] \n\t" // Scale by beta + " fmul v11.2d,v11.2d,v7.d[0] \n\t" // Scale by beta + " fmul v12.2d,v12.2d,v7.d[0] \n\t" // Scale by beta + " fmul v13.2d,v13.2d,v7.d[0] \n\t" // Scale by beta + " \n\t" + LABEL(DBETAZEROCOLSTOREDS4) + " \n\t" + " prfm pldl2keep,[x0] \n\t" + " prfm pldl2keep,[x1] \n\t" + " \n\t" + " fmla v8.2d, v26.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v9.2d, v27.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v10.2d,v28.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v11.2d,v29.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v12.2d,v30.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v13.2d,v31.2d,v6.d[0] \n\t" // Scale by alpha + " \n\t" + " str q8, [x25] \n\t" //Store column 6 of C + " str q9, [x25, #16] \n\t" + " str q10, [x25, #32] \n\t" + " \n\t" + " str q11, [x26] \n\t" //Store column 7 of C + " str q12, [x26, #16] \n\t" + " str q13, [x26, #32] \n\t" + " \n\t" + BRANCH(DEND) + " \n\t" + LABEL(DGENSTORED) // C is general-stride stored. + " \n\t" + " dup v0.2d, xzr \n\t" + " dup v1.2d, xzr \n\t" + " dup v2.2d, xzr \n\t" + " dup v3.2d, xzr \n\t" + " dup v4.2d, xzr \n\t" + " dup v5.2d, xzr \n\t" + " \n\t" + " fcmp d7,#0.0 \n\t" + BEQ(DBETAZEROGENSTOREDS1) // Taking care of the beta==0 case. + " \n\t" + " mov x27, x2 \n\t" + " \n\t" // Load address of C. + " ld1 {v0.d}[0],[x27],x14 \n\t" // Load c00 into quad and increment by rs_c. + " ld1 {v0.d}[1],[x27],x14 \n\t" // Load c01 into quad and increment by rs_c. + " ld1 {v1.d}[0],[x27],x14 \n\t" // Load c02 into quad and increment by rs_c. + " ld1 {v1.d}[1],[x27],x14 \n\t" // Load c03 into quad and increment by rs_c. + " ld1 {v2.d}[0],[x27],x14 \n\t" // Load c04 into quad and increment by rs_c. + " ld1 {v2.d}[1],[x27],x14 \n\t" // Load c05 into quad and increment by rs_c. 
+ " \n\t" + " mov x27, x20 \n\t" // Load address of C. + " \n\t" + " ld1 {v3.d}[0],[x27],x14 \n\t" // Load c10 into quad and increment by rs_c. + " ld1 {v3.d}[1],[x27],x14 \n\t" // Load c11 into quad and increment by rs_c. + " ld1 {v4.d}[0],[x27],x14 \n\t" // Load c12 into quad and increment by rs_c. + " ld1 {v4.d}[1],[x27],x14 \n\t" // Load c13 into quad and increment by rs_c. + " ld1 {v5.d}[0],[x27],x14 \n\t" // Load c14 into quad and increment by rs_c. + " ld1 {v5.d}[1],[x27],x14 \n\t" // Load c15 into quad and increment by rs_c. + " \n\t" + " fmul v0.2d,v0.2d,v7.d[0] \n\t" // Scale by beta + " fmul v1.2d,v1.2d,v7.d[0] \n\t" // Scale by beta + " fmul v2.2d,v2.2d,v7.d[0] \n\t" // Scale by beta + " fmul v3.2d,v3.2d,v7.d[0] \n\t" // Scale by beta + " fmul v4.2d,v4.2d,v7.d[0] \n\t" // Scale by beta + " fmul v5.2d,v5.2d,v7.d[0] \n\t" // Scale by beta + " \n\t" + LABEL(DBETAZEROGENSTOREDS1) + " \n\t" + " fmla v0.2d,v8.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v1.2d,v9.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v2.2d,v10.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v3.2d,v11.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v4.2d,v12.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v5.2d,v13.2d,v6.d[0] \n\t" // Scale by alpha + " \n\t" + " mov x27, x2 \n\t" // Load address of C. + " \n\t" + " st1 {v0.d}[0],[x27],x14 \n\t" // Store c00 into quad and increment by rs_c. + " st1 {v0.d}[1],[x27],x14 \n\t" // Store c01 into quad and increment by rs_c. + " st1 {v1.d}[0],[x27],x14 \n\t" // Store c02 into quad and increment by rs_c. + " st1 {v1.d}[1],[x27],x14 \n\t" // Store c03 into quad and increment by rs_c. + " st1 {v2.d}[0],[x27],x14 \n\t" // Store c04 into quad and increment by rs_c. + " st1 {v2.d}[1],[x27],x14 \n\t" // Store c05 into quad and increment by rs_c. + " \n\t" + " mov x27, x20 \n\t" // Load address of C. + " \n\t" + " st1 {v3.d}[0],[x27],x14 \n\t" // Store c10 into quad and increment by rs_c. + " st1 {v3.d}[1],[x27],x14 \n\t" // Store c11 into quad and increment by rs_c. + " st1 {v4.d}[0],[x27],x14 \n\t" // Store c12 into quad and increment by rs_c. + " st1 {v4.d}[1],[x27],x14 \n\t" // Store c13 into quad and increment by rs_c. + " st1 {v5.d}[0],[x27],x14 \n\t" // Store c14 into quad and increment by rs_c. + " st1 {v5.d}[1],[x27],x14 \n\t" // Store c15 into quad and increment by rs_c. + " \n\t" + " dup v8.2d, xzr \n\t" + " dup v9.2d, xzr \n\t" + " dup v10.2d, xzr \n\t" + " dup v11.2d, xzr \n\t" + " dup v12.2d, xzr \n\t" + " dup v13.2d, xzr \n\t" + " \n\t" + " fcmp d7,#0.0 \n\t" + BEQ(DBETAZEROGENSTOREDS2) // Taking care of the beta==0 case. + " \n\t" + " mov x27, x21 \n\t" // Load address of C. + " \n\t" + " ld1 {v8.d}[0], [x27],x14 \n\t" // Load c20 into quad and increment by rs_c. + " ld1 {v8.d}[1], [x27],x14 \n\t" // Load c21 into quad and increment by rs_c. + " ld1 {v9.d}[0], [x27],x14 \n\t" // Load c22 into quad and increment by rs_c. + " ld1 {v9.d}[1], [x27],x14 \n\t" // Load c23 into quad and increment by rs_c. + " ld1 {v10.d}[0],[x27],x14 \n\t" // Load c24 into quad and increment by rs_c. + " ld1 {v10.d}[1],[x27],x14 \n\t" // Load c25 into quad and increment by rs_c. + " \n\t" + " mov x27, x22 \n\t" // Load address of C. + " \n\t" + " ld1 {v11.d}[0],[x27],x14 \n\t" // Load c30 into quad and increment by rs_c. + " ld1 {v11.d}[1],[x27],x14 \n\t" // Load c31 into quad and increment by rs_c. + " ld1 {v12.d}[0],[x27],x14 \n\t" // Load c32 into quad and increment by rs_c. + " ld1 {v12.d}[1],[x27],x14 \n\t" // Load c33 into quad and increment by rs_c. 
+ " ld1 {v13.d}[0],[x27],x14 \n\t" // Load c34 into quad and increment by rs_c. + " ld1 {v13.d}[1],[x27],x14 \n\t" // Load c35 into quad and increment by rs_c. + " \n\t" + " fmul v8.2d, v8.2d, v7.d[0] \n\t" // Scale by beta + " fmul v9.2d, v9.2d, v7.d[0] \n\t" // Scale by beta + " fmul v10.2d,v10.2d,v7.d[0] \n\t" // Scale by beta + " fmul v11.2d,v11.2d,v7.d[0] \n\t" // Scale by beta + " fmul v12.2d,v12.2d,v7.d[0] \n\t" // Scale by beta + " fmul v13.2d,v13.2d,v7.d[0] \n\t" // Scale by beta + " \n\t" + LABEL(DBETAZEROGENSTOREDS2) + " \n\t" + " fmla v8.2d, v14.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v9.2d, v15.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v10.2d,v16.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v11.2d,v17.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v12.2d,v18.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v13.2d,v19.2d,v6.d[0] \n\t" // Scale by alpha + " \n\t" + " mov x27, x21 \n\t" // Load address of C. + " \n\t" + " st1 {v8.d}[0], [x27],x14 \n\t" // Store c20 into quad and increment by rs_c. + " st1 {v8.d}[1], [x27],x14 \n\t" // Store c21 into quad and increment by rs_c. + " st1 {v9.d}[0], [x27],x14 \n\t" // Store c22 into quad and increment by rs_c. + " st1 {v9.d}[1], [x27],x14 \n\t" // Store c23 into quad and increment by rs_c. + " st1 {v10.d}[0],[x27],x14 \n\t" // Store c24 into quad and increment by rs_c. + " st1 {v10.d}[1],[x27],x14 \n\t" // Store c25 into quad and increment by rs_c. + " \n\t" + " mov x27, x22 \n\t" // Load address of C. + " \n\t" + " st1 {v11.d}[0],[x27],x14 \n\t" // Store c30 into quad and increment by rs_c. + " st1 {v11.d}[1],[x27],x14 \n\t" // Store c31 into quad and increment by rs_c. + " st1 {v12.d}[0],[x27],x14 \n\t" // Store c32 into quad and increment by rs_c. + " st1 {v12.d}[1],[x27],x14 \n\t" // Store c33 into quad and increment by rs_c. + " st1 {v13.d}[0],[x27],x14 \n\t" // Store c34 into quad and increment by rs_c. + " st1 {v13.d}[1],[x27],x14 \n\t" // Store c35 into quad and increment by rs_c. + " \n\t" + " dup v0.2d, xzr \n\t" + " dup v1.2d, xzr \n\t" + " dup v2.2d, xzr \n\t" + " dup v3.2d, xzr \n\t" + " dup v4.2d, xzr \n\t" + " dup v5.2d, xzr \n\t" + " \n\t" + " fcmp d7,#0.0 \n\t" + BEQ(DBETAZEROGENSTOREDS3) // Taking care of the beta==0 case. + " \n\t" + " mov x27, x23 \n\t" // Load address of C. + " \n\t" + " ld1 {v0.d}[0],[x27],x14 \n\t" // Load c40 into quad and increment by rs_c. + " ld1 {v0.d}[1],[x27],x14 \n\t" // Load c41 into quad and increment by rs_c. + " ld1 {v1.d}[0],[x27],x14 \n\t" // Load c42 into quad and increment by rs_c. + " ld1 {v1.d}[1],[x27],x14 \n\t" // Load c43 into quad and increment by rs_c. + " ld1 {v2.d}[0],[x27],x14 \n\t" // Load c44 into quad and increment by rs_c. + " ld1 {v2.d}[1],[x27],x14 \n\t" // Load c45 into quad and increment by rs_c. + " \n\t" + " mov x27, x24 \n\t" // Load address of C. + " \n\t" + " ld1 {v3.d}[0],[x27],x14 \n\t" // Load c50 into quad and increment by rs_c. + " ld1 {v3.d}[1],[x27],x14 \n\t" // Load c51 into quad and increment by rs_c. + " ld1 {v4.d}[0],[x27],x14 \n\t" // Load c52 into quad and increment by rs_c. + " ld1 {v4.d}[1],[x27],x14 \n\t" // Load c53 into quad and increment by rs_c. + " ld1 {v5.d}[0],[x27],x14 \n\t" // Load c54 into quad and increment by rs_c. + " ld1 {v5.d}[1],[x27],x14 \n\t" // Load c55 into quad and increment by rs_c. 
+ " \n\t" + " fmul v0.2d,v0.2d,v7.d[0] \n\t" // Scale by beta + " fmul v1.2d,v1.2d,v7.d[0] \n\t" // Scale by beta + " fmul v2.2d,v2.2d,v7.d[0] \n\t" // Scale by beta + " fmul v3.2d,v3.2d,v7.d[0] \n\t" // Scale by beta + " fmul v4.2d,v4.2d,v7.d[0] \n\t" // Scale by beta + " fmul v5.2d,v5.2d,v7.d[0] \n\t" // Scale by beta + " \n\t" + LABEL(DBETAZEROGENSTOREDS3) + " \n\t" + " fmla v0.2d,v20.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v1.2d,v21.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v2.2d,v22.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v3.2d,v23.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v4.2d,v24.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v5.2d,v25.2d,v6.d[0] \n\t" // Scale by alpha + " \n\t" + " mov x27, x23 \n\t" // Load address of C. + " \n\t" + " st1 {v0.d}[0],[x27],x14 \n\t" // Store c40 into quad and increment by rs_c. + " st1 {v0.d}[1],[x27],x14 \n\t" // Store c41 into quad and increment by rs_c. + " st1 {v1.d}[0],[x27],x14 \n\t" // Store c42 into quad and increment by rs_c. + " st1 {v1.d}[1],[x27],x14 \n\t" // Store c43 into quad and increment by rs_c. + " st1 {v2.d}[0],[x27],x14 \n\t" // Store c44 into quad and increment by rs_c. + " st1 {v2.d}[1],[x27],x14 \n\t" // Store c45 into quad and increment by rs_c. + " \n\t" + " mov x27, x24 \n\t" // Load address of C. + " \n\t" + " st1 {v3.d}[0],[x27],x14 \n\t" // Store c50 into quad and increment by rs_c. + " st1 {v3.d}[1],[x27],x14 \n\t" // Store c51 into quad and increment by rs_c. + " st1 {v4.d}[0],[x27],x14 \n\t" // Store c52 into quad and increment by rs_c. + " st1 {v4.d}[1],[x27],x14 \n\t" // Store c53 into quad and increment by rs_c. + " st1 {v5.d}[0],[x27],x14 \n\t" // Store c54 into quad and increment by rs_c. + " st1 {v5.d}[1],[x27],x14 \n\t" // Store c55 into quad and increment by rs_c. + " \n\t" + " dup v8.2d, xzr \n\t" + " dup v9.2d, xzr \n\t" + " dup v10.2d, xzr \n\t" + " dup v11.2d, xzr \n\t" + " dup v12.2d, xzr \n\t" + " dup v13.2d, xzr \n\t" + " \n\t" + " fcmp d7,#0.0 \n\t" + BEQ(DBETAZEROGENSTOREDS4) // Taking care of the beta==0 case. + " \n\t" + " mov x27, x25 \n\t" + " \n\t" + " ld1 {v8.d}[0], [x27],x14 \n\t" // Load c60 into quad and increment by rs_c. + " ld1 {v8.d}[1], [x27],x14 \n\t" // Load c61 into quad and increment by rs_c. + " ld1 {v9.d}[0], [x27],x14 \n\t" // Load c62 into quad and increment by rs_c. + " ld1 {v9.d}[1], [x27],x14 \n\t" // Load c63 into quad and increment by rs_c. + " ld1 {v10.d}[0],[x27],x14 \n\t" // Load c64 into quad and increment by rs_c. + " ld1 {v10.d}[1],[x27],x14 \n\t" // Load c65 into quad and increment by rs_c. + " \n\t" + " mov x27, x26 \n\t" // Load address of C. + " \n\t" + " ld1 {v11.d}[0],[x27],x14 \n\t" // Load c70 into quad and increment by rs_c. + " ld1 {v11.d}[1],[x27],x14 \n\t" // Load c71 into quad and increment by rs_c. + " ld1 {v12.d}[0],[x27],x14 \n\t" // Load c72 into quad and increment by rs_c. + " ld1 {v12.d}[1],[x27],x14 \n\t" // Load c73 into quad and increment by rs_c. + " ld1 {v13.d}[0],[x27],x14 \n\t" // Load c74 into quad and increment by rs_c. + " ld1 {v13.d}[1],[x27],x14 \n\t" // Load c75 into quad and increment by rs_c. 
+ " \n\t" + " fmul v8.2d, v8.2d, v7.d[0] \n\t" // Scale by beta + " fmul v9.2d, v9.2d, v7.d[0] \n\t" // Scale by beta + " fmul v10.2d,v10.2d,v7.d[0] \n\t" // Scale by beta + " fmul v11.2d,v11.2d,v7.d[0] \n\t" // Scale by beta + " fmul v12.2d,v12.2d,v7.d[0] \n\t" // Scale by beta + " fmul v13.2d,v13.2d,v7.d[0] \n\t" // Scale by beta + " \n\t" + LABEL(DBETAZEROGENSTOREDS4) + " \n\t" + " prfm pldl2keep,[x0] \n\t" + " prfm pldl2keep,[x1] \n\t" + " \n\t" + " fmla v8.2d, v26.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v9.2d, v27.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v10.2d,v28.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v11.2d,v29.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v12.2d,v30.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v13.2d,v31.2d,v6.d[0] \n\t" // Scale by alpha + " \n\t" + " mov x27, x25 \n\t" // Load address of C. + " \n\t" + " st1 {v8.d}[0], [x27],x14 \n\t" // Store c60 into quad and increment by rs_c. + " st1 {v8.d}[1], [x27],x14 \n\t" // Store c61 into quad and increment by rs_c. + " st1 {v9.d}[0], [x27],x14 \n\t" // Store c62 into quad and increment by rs_c. + " st1 {v9.d}[1], [x27],x14 \n\t" // Store c63 into quad and increment by rs_c. + " st1 {v10.d}[0],[x27],x14 \n\t" // Store c64 into quad and increment by rs_c. + " st1 {v10.d}[1],[x27],x14 \n\t" // Store c65 into quad and increment by rs_c. + " \n\t" + " mov x27, x26 \n\t" // Load address of C. + " \n\t" + " st1 {v11.d}[0],[x27],x14 \n\t" // Store c70 into quad and increment by rs_c. + " st1 {v11.d}[1],[x27],x14 \n\t" // Store c71 into quad and increment by rs_c. + " st1 {v12.d}[0],[x27],x14 \n\t" // Store c72 into quad and increment by rs_c. + " st1 {v12.d}[1],[x27],x14 \n\t" // Store c73 into quad and increment by rs_c. + " st1 {v13.d}[0],[x27],x14 \n\t" // Store c74 into quad and increment by rs_c. + " st1 {v13.d}[1],[x27],x14 \n\t" // Store c75 into quad and increment by rs_c. + " \n\t" + LABEL(DEND) // Done! + " \n\t" + :// output operands (none) + :// input operands + [aaddr] "m" (a), // 0 + [baddr] "m" (b), // 1 + [caddr] "m" (c), // 2 + [k_iter] "m" (k_iter), // 3 + [k_left] "m" (k_left), // 4 + [alpha] "m" (alpha), // 5 + [beta] "m" (beta), // 6 + [rs_c] "m" (rs_c), // 6 + [cs_c] "m" (cs_c), // 7 + [a_next] "m" (a_next), // 8 + [b_next] "m" (b_next) // 9 + :// Register clobber list + "x0","x1","x2", + "x5","x6","x10", + "x14","x16","x17", + "x20","x21","x22","x23","x24","x25","x26","x27", + "v0","v1","v2", + "v3","v4","v5", + "v6","v7","v8", + "v9","v10","v11", + "v12","v13","v14", + "v15","v16","v17","v18","v19", + "v20","v21","v22","v23", + "v24","v25","v26","v27", + "v28","v29","v30","v31" + ); + GEMM_UKR_FLUSH_CT( d ); } #if 0 void bli_cgemm_armv8a_opt_4x4 ( + dim_t m, + dim_t n, dim_t k, scomplex* restrict alpha, scomplex* restrict a, @@ -2095,6 +2105,8 @@ void bli_cgemm_armv8a_opt_4x4 void bli_zgemm_armv8a_opt_4x4 ( + dim_t m, + dim_t n, dim_t k, dcomplex* restrict alpha, dcomplex* restrict a, diff --git a/kernels/bgq/3/bli_gemm_bgq_int_8x8.c b/kernels/bgq/3/bli_gemm_bgq_int_8x8.c index 1612e69b0d..15e3e072f3 100644 --- a/kernels/bgq/3/bli_gemm_bgq_int_8x8.c +++ b/kernels/bgq/3/bli_gemm_bgq_int_8x8.c @@ -56,6 +56,8 @@ void bli_dgemm_bgq_int_8x8 ( + dim_t m, + dim_t n, dim_t k, double* restrict alpha, double* restrict a, @@ -66,6 +68,8 @@ void bli_dgemm_bgq_int_8x8 cntx_t* restrict cntx ) { + GEMM_UKR_SETUP_CT_ANY( d, 8, 8, false ); + //Registers for storing C. 
//4 4x4 subblocks of C, c00, c01, c10, c11 //4 registers per subblock: a, b, c, d @@ -201,6 +205,8 @@ void bli_dgemm_bgq_int_8x8 UPDATE( AB, c, 0 ); AB = vec_perm( c11d, c11d, pattern ); UPDATE( AB, c, 4 ); + + GEMM_UKR_FLUSH_CT( d ); } void printvec(vector4double v) @@ -214,6 +220,8 @@ void printvec(vector4double v) void bli_zgemm_bgq_int_4x4 ( + dim_t m, + dim_t n, dim_t k, dcomplex* restrict alpha, dcomplex* restrict a, @@ -224,6 +232,8 @@ void bli_zgemm_bgq_int_4x4 cntx_t* restrict cntx ) { + GEMM_UKR_SETUP_CT_ANY( z, 4, 4, false ); + double* a_d = ( double* )a; double* b_d = ( double* )b; double* c_d = ( double* )c; @@ -368,4 +378,6 @@ void bli_zgemm_bgq_int_4x4 c_d += 2*cs_c; ZUPDATE( c03a, c03b, c_d, 0 ); ZUPDATE( c13a, c13b, c_d, 4 ); + + GEMM_UKR_FLUSH_CT( z ); } diff --git a/kernels/bulldozer/3/bli_gemm_bulldozer_asm_d4x6_fma4.c b/kernels/bulldozer/3/bli_gemm_bulldozer_asm_d4x6_fma4.c index 403aaaaeef..3a75d61d73 100644 --- a/kernels/bulldozer/3/bli_gemm_bulldozer_asm_d4x6_fma4.c +++ b/kernels/bulldozer/3/bli_gemm_bulldozer_asm_d4x6_fma4.c @@ -90,7 +90,9 @@ void bli_sgemm_bulldozer_asm_8x8_fma4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, float* restrict alpha, float* restrict a, float* restrict b, @@ -102,25 +104,27 @@ void bli_sgemm_bulldozer_asm_8x8_fma4 { // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT_ALIGNED( s, 8, 8, false, 32 ); + begin_asm() - + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. - + vmovaps(mem(rax, 0*32), ymm0) // initialize loop by pre-loading vmovsldup(mem(rbx, 0*32), ymm2) // elements of a and b. vpermilps(imm(0x4e), ymm2, ymm3) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 4), rdi) // cs_c *= sizeof(float) lea(mem(rcx, rdi, 4), r10) // load address of c + 4*cs_c; - + lea(mem(rdi, rdi, 2), r14) // r14 = 3*cs_c; prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*cs_c prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*cs_c @@ -130,7 +134,7 @@ void bli_sgemm_bulldozer_asm_8x8_fma4 prefetch(0, mem(r10, rdi, 1, 7*8)) // prefetch c + 5*cs_c prefetch(0, mem(r10, rdi, 2, 7*8)) // prefetch c + 6*cs_c prefetch(0, mem(r10, r14, 1, 7*8)) // prefetch c + 7*cs_c - + vxorps(ymm8, ymm8, ymm8) vxorps(ymm9, ymm9, ymm9) vxorps(ymm10, ymm10, ymm10) @@ -139,15 +143,15 @@ void bli_sgemm_bulldozer_asm_8x8_fma4 vxorps(ymm13, ymm13, ymm13) vxorps(ymm14, ymm14, ymm14) vxorps(ymm15, ymm15, ymm15) - - + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.SCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. 
- + label(.SLOOPKITER) // MAIN LOOP - + // iteration 0 prefetch(0, mem(rax, 16*32)) vfmaddps(ymm15, ymm0, ymm2, ymm15) @@ -155,44 +159,44 @@ void bli_sgemm_bulldozer_asm_8x8_fma4 vmovshdup(mem(rbx, 0*32), ymm2) vfmaddps(ymm13, ymm0, ymm3, ymm13) vperm2f128(imm(0x03), ymm3, ymm3, ymm5) - + vmovaps(mem(rax, 1*32), ymm1) vpermilps(imm(0x4e), ymm2, ymm3) vfmaddps(ymm11, ymm0, ymm4, ymm11) vfmaddps(ymm9, ymm0, ymm5, ymm9) - + vfmaddps(ymm14, ymm0, ymm2, ymm14) vperm2f128(imm(0x03), ymm2, ymm2, ymm4) vmovsldup(mem(rbx, 1*32), ymm2) vfmaddps(ymm12, ymm0, ymm3, ymm12) vperm2f128(imm(0x03), ymm3, ymm3, ymm5) - + vpermilps(imm(0x4e), ymm2, ymm3) vfmaddps(ymm10, ymm0, ymm4, ymm10) vfmaddps(ymm8, ymm0, ymm5, ymm8) - + // iteration 1 vfmaddps(ymm15, ymm1, ymm2, ymm15) vperm2f128(imm(0x03), ymm2, ymm2, ymm4) vmovshdup(mem(rbx, 1*32), ymm2) vfmaddps(ymm13, ymm1, ymm3, ymm13) vperm2f128(imm(0x03), ymm3, ymm3, ymm5) - + vmovaps(mem(rax, 2*32), ymm0) vpermilps(imm(0x4e), ymm2, ymm3) vfmaddps(ymm11, ymm1, ymm4, ymm11) vfmaddps(ymm9, ymm1, ymm5, ymm9) - + vfmaddps(ymm14, ymm1, ymm2, ymm14) vperm2f128(imm(0x03), ymm2, ymm2, ymm4) vmovsldup(mem(rbx, 2*32), ymm2) vfmaddps(ymm12, ymm1, ymm3, ymm12) vperm2f128(imm(0x03), ymm3, ymm3, ymm5) - + vpermilps(imm(0x4e), ymm2, ymm3) vfmaddps(ymm10, ymm1, ymm4, ymm10) vfmaddps(ymm8, ymm1, ymm5, ymm8) - + // iteration 2 prefetch(0, mem(rax, 18*32)) vfmaddps(ymm15, ymm0, ymm2, ymm15) @@ -200,23 +204,23 @@ void bli_sgemm_bulldozer_asm_8x8_fma4 vmovshdup(mem(rbx, 2*32), ymm2) vfmaddps(ymm13, ymm0, ymm3, ymm13) vperm2f128(imm(0x03), ymm3, ymm3, ymm5) - + vmovaps(mem(rax, 3*32), ymm1) add(imm(4*8*4), rax) // a += 4*8 (unroll x mr) vpermilps(imm(0x4e), ymm2, ymm3) vfmaddps(ymm11, ymm0, ymm4, ymm11) vfmaddps(ymm9, ymm0, ymm5, ymm9) - + vfmaddps(ymm14, ymm0, ymm2, ymm14) vperm2f128(imm(0x03), ymm2, ymm2, ymm4) vmovsldup(mem(rbx, 3*32), ymm2) vfmaddps(ymm12, ymm0, ymm3, ymm12) vperm2f128(imm(0x03), ymm3, ymm3, ymm5) - + vpermilps(imm(0x4e), ymm2, ymm3) vfmaddps(ymm10, ymm0, ymm4, ymm10) vfmaddps(ymm8, ymm0, ymm5, ymm8) - + // iteration 3 vfmaddps(ymm15, ymm1, ymm2, ymm15) vperm2f128(imm(0x03), ymm2, ymm2, ymm4) @@ -224,134 +228,134 @@ void bli_sgemm_bulldozer_asm_8x8_fma4 add(imm(4*8*4), rbx) // b += 4*8 (unroll x nr) vfmaddps(ymm13, ymm1, ymm3, ymm13) vperm2f128(imm(0x03), ymm3, ymm3, ymm5) - + vmovaps(mem(rax, 0*32), ymm0) vpermilps(imm(0x4e), ymm2, ymm3) vfmaddps(ymm11, ymm1, ymm4, ymm11) vfmaddps(ymm9, ymm1, ymm5, ymm9) - + vfmaddps(ymm14, ymm1, ymm2, ymm14) vperm2f128(imm(0x03), ymm2, ymm2, ymm4) vmovsldup(mem(rbx, 0*32), ymm2) vfmaddps(ymm12, ymm1, ymm3, ymm12) vperm2f128(imm(0x03), ymm3, ymm3, ymm5) - + vpermilps(imm(0x4e), ymm2, ymm3) vfmaddps(ymm10, ymm1, ymm4, ymm10) vfmaddps(ymm8, ymm1, ymm5, ymm8) - - - + + + dec(rsi) // i -= 1; jne(.SLOOPKITER) // iterate again if i != 0. - - - - + + + + label(.SCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.SPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. 
- + label(.SLOOPKLEFT) // EDGE LOOP - + prefetch(0, mem(rax, 16*32)) vfmaddps(ymm15, ymm0, ymm2, ymm15) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vmovshdup(mem(rbx, 0*32), ymm2) vfmaddps(ymm13, ymm0, ymm3, ymm13) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) - + vmovaps(mem(rax, 1*32), ymm1) add(imm(8*1*4), rax) // a += 8 (1 x mr) vpermilps(imm(0x4e), ymm2, ymm3) vfmaddps(ymm11, ymm0, ymm4, ymm11) vfmaddps(ymm9, ymm0, ymm5, ymm9) - - vfmaddps(ymm14, ymm0, ymm2, ymm14) + + vfmaddps(ymm14, ymm0, ymm2, ymm14) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vmovsldup(mem(rbx, 1*32), ymm2) add(imm(8*1*4), rbx) // b += 8 (1 x nr) - vfmaddps(ymm12, ymm0, ymm3, ymm12) + vfmaddps(ymm12, ymm0, ymm3, ymm12) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) - + vpermilps(imm(0x4e), ymm2, ymm3) - vfmaddps(ymm10, ymm0, ymm4, ymm10) - vfmaddps(ymm8, ymm0, ymm5, ymm8) + vfmaddps(ymm10, ymm0, ymm4, ymm10) + vfmaddps(ymm8, ymm0, ymm5, ymm8) vmovaps(ymm1, ymm0) - - + + dec(rsi) // i -= 1; jne(.SLOOPKLEFT) // iterate again if i != 0. - - + + label(.SPOSTACCUM) // ymm15: ymm13: ymm11: ymm9: // ( ab00 ( ab02 ( ab04 ( ab06 - // ab10 ab12 ab14 ab16 + // ab10 ab12 ab14 ab16 // ab22 ab20 ab26 ab24 // ab32 ab30 ab36 ab34 // ab44 ab46 ab40 ab42 - // ab54 ab56 ab50 ab52 + // ab54 ab56 ab50 ab52 // ab66 ab64 ab62 ab60 // ab76 ) ab74 ) ab72 ) ab70 ) - + // ymm14: ymm12: ymm10: ymm8: // ( ab01 ( ab03 ( ab05 ( ab07 - // ab11 ab13 ab15 ab17 + // ab11 ab13 ab15 ab17 // ab23 ab21 ab27 ab25 // ab33 ab31 ab37 ab35 // ab45 ab47 ab41 ab43 - // ab55 ab57 ab51 ab53 + // ab55 ab57 ab51 ab53 // ab67 ab65 ab63 ab61 // ab77 ) ab75 ) ab73 ) ab71 ) GROUP_YMM_BY_4 // ymm15: ymm13: ymm11: ymm9: // ( ab00 ( ab02 ( ab04 ( ab06 - // ab10 ab12 ab14 ab16 + // ab10 ab12 ab14 ab16 // ab20 ab22 ab24 ab26 // ab30 ab32 ab34 ab36 // ab44 ab46 ab40 ab42 - // ab54 ab56 ab50 ab52 + // ab54 ab56 ab50 ab52 // ab64 ab66 ab60 ab62 // ab74 ) ab76 ) ab70 ) ab72 ) - + // ymm14: ymm12: ymm10: ymm8: // ( ab01 ( ab03 ( ab05 ( ab07 - // ab11 ab13 ab15 ab17 + // ab11 ab13 ab15 ab17 // ab21 ab23 ab25 ab27 // ab31 ab33 ab35 ab37 // ab45 ab47 ab41 ab43 - // ab55 ab57 ab51 ab53 + // ab55 ab57 ab51 ab53 // ab65 ab67 ab61 ab63 // ab75 ) ab77 ) ab71 ) ab73 ) // ymm15: ymm13: ymm11: ymm9: // ( ab00 ( ab02 ( ab04 ( ab06 - // ab10 ab12 ab14 ab16 + // ab10 ab12 ab14 ab16 // ab20 ab22 ab24 ab26 // ab30 ab32 ab34 ab36 // ab40 ab42 ab44 ab46 - // ab50 ab52 ab54 ab56 + // ab50 ab52 ab54 ab56 // ab60 ab62 ab64 ab66 // ab70 ) ab72 ) ab74 ) ab76 ) - + // ymm14: ymm12: ymm10: ymm8: // ( ab01 ( ab03 ( ab05 ( ab07 - // ab11 ab13 ab15 ab17 + // ab11 ab13 ab15 ab17 // ab21 ab23 ab25 ab27 // ab31 ab33 ab35 ab37 // ab41 ab43 ab45 ab47 - // ab51 ab53 ab55 ab57 + // ab51 ab53 ab55 ab57 // ab61 ab63 ab65 ab67 // ab71 ) ab73 ) ab75 ) ab77 ) - + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastss(mem(rax), ymm0) // load alpha and duplicate vbroadcastss(mem(rbx), ymm4) // load beta and duplicate - + vmulps(ymm0, ymm8, ymm8) // scale by alpha vmulps(ymm0, ymm9, ymm9) vmulps(ymm0, ymm10, ymm10) @@ -360,401 +364,115 @@ void bli_sgemm_bulldozer_asm_8x8_fma4 vmulps(ymm0, ymm13, ymm13) vmulps(ymm0, ymm14, ymm14) vmulps(ymm0, ymm15, ymm15) - - - - - mov(var(rs_c), rsi) // load rs_c - lea(mem(, rsi, 4), rsi) // rsi = rs_c * sizeof(float) - - lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*rs_c; - - lea(mem(, rsi, 2), r12) // r12 = 2*rs_c; - lea(mem(r12, rsi, 1), r13) // r13 = 3*rs_c; - - - // determine if - // c % 32 == 0, AND - // 4*cs_c % 32 == 0, AND - // rs_c == 1 - // ie: 
aligned, ldim aligned, and - // column-stored - - cmp(imm(4), rsi) // set ZF if (4*rs_c) == 4. - sete(bl) // bl = ( ZF == 1 ? 1 : 0 ); - test(imm(31), rcx) // set ZF if c & 32 is zero. - setz(bh) // bh = ( ZF == 0 ? 1 : 0 ); - test(imm(31), rdi) // set ZF if (4*cs_c) & 32 is zero. - setz(al) // al = ( ZF == 0 ? 1 : 0 ); - // and(bl,bh) followed by - // and(bh,al) will reveal result - + + // now avoid loading C if beta == 0 - + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomiss(xmm0, xmm4) // set ZF if beta == 0. je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case - - - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. - jne(.SCOLSTORED) // jump to column storage case - - - label(.SGENSTORED) - // update c00:c70 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - - vfmaddps(ymm15, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, - - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - // update c01:c71 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - -// vmulps(ymm4, ymm0, ymm0) // scale by beta, -// vaddps(ymm14, ymm0, ymm0) // add the gemm result, - vfmaddps(ymm14, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, - - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c02:c72 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - -// vmulps(ymm4, ymm0, ymm0) // scale by beta, -// vaddps(ymm13, ymm0, ymm0) // add the gemm result, - vfmaddps(ymm13, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, - - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c03:c73 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - -// vmulps(ymm4, ymm0, ymm0) // scale by beta, -// vaddps(ymm12, ymm0, ymm0) // add the gemm result, - vfmaddps(ymm12, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, - - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c04:c74 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, 
rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - -// vmulps(ymm4, ymm0, ymm0) // scale by beta, -// vaddps(ymm11, ymm0, ymm0) // add the gemm result, - vfmaddps(ymm11, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, - - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c05:c75 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - -// vmulps(ymm4, ymm0, ymm0) // scale by beta, -// vaddps(ymm10, ymm0, ymm0) // add the gemm result, - vfmaddps(ymm10, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, - - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c06:c76 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - -// vmulps(ymm4, ymm0, ymm0) // scale by beta, -// vaddps(ymm9, ymm0, ymm0) // add the gemm result, - vfmaddps(ymm9, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, - - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c07:c77 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - -// vmulps(ymm4, ymm0, ymm0) // scale by beta, -// vaddps(ymm8, ymm0, ymm0) // add the gemm result, - vfmaddps(ymm8, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, - - STORE_SS - - - - jmp(.SDONE) // jump to end. - - - - label(.SCOLSTORED) - - - vmovaps(mem(rcx), ymm0) // load c00:c70, -// vmulps(ymm4, ymm0, ymm0) // scale by beta, -// vaddps(ymm15, ymm0, ymm0) // add the gemm result, - vfmaddps(ymm15, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, - vmovaps(ymm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(mem(rcx), ymm1) // load c01:c71, -// vmulps(ymm4, ymm1, ymm1) // scale by beta, -// vaddps(ymm14, ymm1, ymm1) // add the gemm result, - vfmaddps(ymm14, ymm1, ymm4, ymm1) // scale by beta and add the gemm result, - vmovaps(ymm1, mem(rcx)) // and store back to memory. 
- add(rdi, rcx) // c += cs_c; - - vmovaps(mem(rcx), ymm0) // load c02:c72, -// vmulps(ymm4, ymm0, ymm0) // scale by beta, -// vaddps(ymm13, ymm0, ymm0) // add the gemm result, - vfmaddps(ymm13, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, - vmovaps(ymm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(mem(rcx), ymm1) // load c03:c73, -// vmulps(ymm4, ymm1, ymm1) // scale by beta, -// vaddps(ymm12, ymm1, ymm1) // add the gemm result, - vfmaddps(ymm12, ymm1, ymm4, ymm1) // scale by beta and add the gemm result, - vmovaps(ymm1, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(mem(rcx), ymm0) // load c04:c74, -// vmulps(ymm4, ymm0, ymm0) // scale by beta, -// vaddps(ymm11, ymm0, ymm0) // add the gemm result, - vfmaddps(ymm11, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, - vmovaps(ymm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(mem(rcx), ymm1) // load c05:c75, -// vmulps(ymm4, ymm1, ymm1) // scale by beta, -// vaddps(ymm10, ymm1, ymm1) // add the gemm result, - vfmaddps(ymm10, ymm1, ymm4, ymm1) // scale by beta and add the gemm result, - vmovaps(ymm1, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(mem(rcx), ymm0) // load c06:c76, -// vmulps(ymm4, ymm0, ymm0) // scale by beta, -// vaddps(ymm9, ymm0, ymm0) // add the gemm result, - vfmaddps(ymm9, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, - vmovaps(ymm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(mem(rcx), ymm1) // load c07:c77, -// vmulps(ymm4, ymm1, ymm1) // scale by beta, -// vaddps(ymm8, ymm1, ymm1) // add the gemm result, - vfmaddps(ymm8, ymm1, ymm4, ymm1) // scale by beta and add the gemm result, - vmovaps(ymm1, mem(rcx)) // and store back to memory. - - jmp(.SDONE) // jump to end. - - + + vmovaps(mem(rcx), ymm0) // load c00:c70, + //vmulps(ymm4, ymm0, ymm0) // scale by beta, + //vaddps(ymm15, ymm0, ymm0) // add the gemm result, + vfmaddps(ymm15, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, + vmovaps(ymm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(mem(rcx), ymm1) // load c01:c71, + //vmulps(ymm4, ymm1, ymm1) // scale by beta, + //vaddps(ymm14, ymm1, ymm1) // add the gemm result, + vfmaddps(ymm14, ymm1, ymm4, ymm1) // scale by beta and add the gemm result, + vmovaps(ymm1, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(mem(rcx), ymm0) // load c02:c72, + //vmulps(ymm4, ymm0, ymm0) // scale by beta, + //vaddps(ymm13, ymm0, ymm0) // add the gemm result, + vfmaddps(ymm13, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, + vmovaps(ymm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(mem(rcx), ymm1) // load c03:c73, + //vmulps(ymm4, ymm1, ymm1) // scale by beta, + //vaddps(ymm12, ymm1, ymm1) // add the gemm result, + vfmaddps(ymm12, ymm1, ymm4, ymm1) // scale by beta and add the gemm result, + vmovaps(ymm1, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(mem(rcx), ymm0) // load c04:c74, + //vmulps(ymm4, ymm0, ymm0) // scale by beta, + //vaddps(ymm11, ymm0, ymm0) // add the gemm result, + vfmaddps(ymm11, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, + vmovaps(ymm0, mem(rcx)) // and store back to memory. 
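As the commented-out vmulps/vaddps pairs above make explicit, each column update folds the former scale-by-beta and add-the-product steps into a single FMA4 vfmaddps. The scalar picture of one element, as a sketch:

    float t = alpha * ab_ij;   /* alpha pass, done tile-wide with vmulps   */
    c_ij    = beta * c_ij + t; /* vfmaddps: one fused multiply-add per lane */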
+ add(rdi, rcx) // c += cs_c; + + vmovaps(mem(rcx), ymm1) // load c05:c75, + //vmulps(ymm4, ymm1, ymm1) // scale by beta, + //vaddps(ymm10, ymm1, ymm1) // add the gemm result, + vfmaddps(ymm10, ymm1, ymm4, ymm1) // scale by beta and add the gemm result, + vmovaps(ymm1, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(mem(rcx), ymm0) // load c06:c76, + //vmulps(ymm4, ymm0, ymm0) // scale by beta, + //vaddps(ymm9, ymm0, ymm0) // add the gemm result, + vfmaddps(ymm9, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, + vmovaps(ymm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(mem(rcx), ymm1) // load c07:c77, + //vmulps(ymm4, ymm1, ymm1) // scale by beta, + //vaddps(ymm8, ymm1, ymm1) // add the gemm result, + vfmaddps(ymm8, ymm1, ymm4, ymm1) // scale by beta and add the gemm result, + vmovaps(ymm1, mem(rcx)) // and store back to memory. + + jmp(.SDONE) // jump to end. + label(.SBETAZERO) - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. - jne(.SCOLSTORBZ) // jump to column storage case - - - label(.SGENSTORBZ) - // update c00:c70 - vmovapd(ymm15, ymm0) - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - // update c01:c71 - vmovapd(ymm14, ymm0) - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - // update c02:c72 - vmovapd(ymm13, ymm0) - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - // update c03:c73 - vmovapd(ymm12, ymm0) - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - // update c04:c74 - vmovapd(ymm11, ymm0) - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - // update c05:c75 - vmovapd(ymm10, ymm0) - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - // update c06:c76 - vmovapd(ymm9, ymm0) - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - // update c07:c77 - vmovapd(ymm8, ymm0) - STORE_SS - - jmp(.SDONE) // jump to end. - - - label(.SCOLSTORBZ) - - vmovaps(ymm15, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(ymm14, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(ymm13, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(ymm12, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(ymm11, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(ymm10, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(ymm9, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(ymm8, mem(rcx)) // and store back to memory. - + + vmovaps(ymm15, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(ymm14, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(ymm13, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(ymm12, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(ymm11, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(ymm10, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(ymm9, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(ymm8, mem(rcx)) // and store back to memory. 
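The net effect of this hunk is that both general-stride paths (.SGENSTORED/.SGENSTORBZ) and the runtime aligned/column-stored detection are gone; only the aligned, column-stored stores remain. That is presumably sound because GEMM_UKR_SETUP_CT_ALIGNED( s, 8, 8, false, 32 ) is expected to establish, before the asm runs, exactly what the deleted code used to test at runtime, substituting an aligned local tile otherwise. The assumed post-macro invariant, sketched in C (mirroring the three deleted checks):

    assert( rs_c == 1 );                            /* column-stored C     */
    assert( ( (uintptr_t)c % 32 ) == 0 );           /* C base 32B-aligned  */
    assert( ( cs_c * sizeof(float) ) % 32 == 0 );   /* aligned leading dim */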
+ label(.SDONE) - - end_asm( + + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c)/*, // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next)*/ // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c)/*, // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next)*/ // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -762,6 +480,8 @@ void bli_sgemm_bulldozer_asm_8x8_fma4 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( s ); } #undef KERNEL4x6_1 @@ -862,7 +582,9 @@ void bli_sgemm_bulldozer_asm_8x8_fma4 void bli_dgemm_bulldozer_asm_4x6_fma4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, double* restrict alpha, double* restrict a, double* restrict b, @@ -874,66 +596,68 @@ void bli_dgemm_bulldozer_asm_4x6_fma4 { // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 12; - uint64_t k_left = k0 % 12; + uint64_t k_iter = k / 12; + uint64_t k_left = k % 12; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT_ANY( d, 4, 6, false ); + begin_asm() - - + + vzeroall() mov(var(b), rbx) // load address of b. mov(var(a), rax) // load address of a. prefetch(0, mem(rax, 64)) - - + + vmovaps(mem(rbx, 0*8), xmm1) vmovaps(mem(rbx, 2*8), xmm2) vmovaps(mem(rbx, 4*8), xmm3) add(imm(12*8), rbx) add(imm(8*8), rax) - + mov(var(k_iter), rsi) // i = k_iter; notice var(k_iter) not $0 test(rsi, rsi) je(.CONSIDERKLEFT) - + ALIGN32 label(.LOOPKITER) // MAIN LOOP - - KERNEL4x6_1(xx) - KERNEL4x6_2(xx) - KERNEL4x6_3(xx) - KERNEL4x6_4(xx) - KERNEL4x6_1(xx) - KERNEL4x6_2(xx) - KERNEL4x6_3(xx) - KERNEL4x6_4(xx) - KERNEL4x6_1(xx) - KERNEL4x6_2(xx) - KERNEL4x6_3(xx) - KERNEL4x6_4(xx) - + + KERNEL4x6_1(xx) + KERNEL4x6_2(xx) + KERNEL4x6_3(xx) + KERNEL4x6_4(xx) + KERNEL4x6_1(xx) + KERNEL4x6_2(xx) + KERNEL4x6_3(xx) + KERNEL4x6_4(xx) + KERNEL4x6_1(xx) + KERNEL4x6_2(xx) + KERNEL4x6_3(xx) + KERNEL4x6_4(xx) + dec(rsi) jne(.LOOPKITER) - + label(.CONSIDERKLEFT) - + mov(var(k_left), rsi) - test(rsi, rsi) + test(rsi, rsi) label(.LOOPKLEFT) je(.POSTACCUM) - - KERNEL4x6_1(xx) + + KERNEL4x6_1(xx) add(imm(6*8), rbx) add(imm(4*8), rax) - + dec(rsi) jmp(.LOOPKLEFT) // iterate again if i != 0. 
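
	// Reference sketch (hypothetical C, for orientation only -- not part
	// of the kernel): the 12x-unrolled main loop and the one-step edge
	// loop above together cover every rank-1 update of the k dimension:
	//
	//   uint64_t k_iter = k / 12, k_left = k % 12;
	//   for ( uint64_t i = 0; i < k_iter; ++i ) { /* 12 fused updates */ }
	//   for ( uint64_t i = 0; i < k_left; ++i ) { /* one update */ }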
-
+
	label(.POSTACCUM)
-
-
+
+
	mov(var(rs_c), rsi) // load rs_c
	mov(var(cs_c), rdi) // load cs_c
	vmovddup(mem(var(alpha)), xmm2) // load alpha
@@ -942,32 +666,32 @@ void bli_dgemm_bulldozer_asm_4x6_fma4
	sal(imm(3), rsi) // rs_c *= sizeof(double)
	sal(imm(3), rdi) // cs_c *= sizeof(double)
	lea(mem(rcx, rdi, 2), rdx)
-
-	vmovlpd(mem(rcx), xmm0, xmm0)
-	vmovlpd(mem(rdx), xmm1, xmm1)
+
+	vmovlpd(mem(rcx), xmm0, xmm0)
+	vmovlpd(mem(rdx), xmm1, xmm1)
	vmovhpd(mem(rcx, rdi, 1), xmm0, xmm0)
	vmovhpd(mem(rdx, rdi, 1), xmm1, xmm1)
	lea(mem(rdx, rdi, 2), r8)
	vmulpd(xmm2, xmm4, xmm4) // scale by alpha,
	vmulpd(xmm2, xmm5, xmm5) // scale by alpha,
	vfmaddpd(xmm4, xmm0, xmm3, xmm4) // scale by beta, and add the gemm result
-	vmovlpd(mem(r8), xmm0, xmm0)
+	vmovlpd(mem(r8), xmm0, xmm0)
	vfmaddpd(xmm5, xmm1, xmm3, xmm5) // scale by beta, and add the gemm result
	vmovhpd(mem(r8, rdi, 1), xmm0, xmm0)
	vmovlpd(xmm4, mem(rcx)) // and store back to memory.
	vmovlpd(xmm5, mem(rdx)) // and store back to memory.
	vmovhpd(xmm4, mem(rcx, rdi, 1))
-	add(rsi, rcx)
+	add(rsi, rcx)
	vmovhpd(xmm5, mem(rdx, rdi, 1))
-	add(rsi, rdx)
-
+	add(rsi, rdx)
+
	vmulpd(xmm2, xmm6, xmm6) // scale by alpha,
	vfmaddpd(xmm6, xmm0, xmm3, xmm6) // scale by beta, and add the gemm result
	vmovlpd(xmm6, mem(r8)) // and store back to memory.
	vmovhpd(xmm6, mem(r8, rdi, 1))
-	add(rsi, r8)
-
-
+	add(rsi, r8)
+
+
	vmovlpd(mem(rcx), xmm0, xmm0)
	vmovlpd(mem(rdx), xmm1, xmm1)
	vmovlpd(mem(r8), xmm4, xmm4)
@@ -984,13 +708,13 @@ void bli_dgemm_bulldozer_asm_4x6_fma4
	vmovlpd(xmm8, mem(rdx)) // and store back to memory.
	vmovlpd(xmm9, mem(r8)) // and store back to memory.
	vmovhpd(xmm7, mem(rcx, rdi, 1))
-	add(rsi, rcx)
+	add(rsi, rcx)
	vmovhpd(xmm8, mem(rdx, rdi, 1))
-	add(rsi, rdx)
+	add(rsi, rdx)
	vmovhpd(xmm9, mem(r8, rdi, 1))
-	add(rsi, r8)
-
-
+	add(rsi, r8)
+
+
	vmovlpd(mem(rcx), xmm0, xmm0)
	vmovlpd(mem(rdx), xmm1, xmm1)
	vmovlpd(mem(r8), xmm4, xmm4)
@@ -1007,13 +731,13 @@ void bli_dgemm_bulldozer_asm_4x6_fma4
	vmovlpd(xmm11, mem(rdx)) // and store back to memory.
	vmovlpd(xmm12, mem(r8)) // and store back to memory.
	vmovhpd(xmm10, mem(rcx, rdi, 1))
-	add(rsi, rcx)
+	add(rsi, rcx)
	vmovhpd(xmm11, mem(rdx, rdi, 1))
-	add(rsi, rdx)
+	add(rsi, rdx)
	vmovhpd(xmm12, mem(r8, rdi, 1))
-	add(rsi, r8)
-
-
+	add(rsi, r8)
+
+
	vmovlpd(mem(rcx), xmm0, xmm0)
	vmovlpd(mem(rdx), xmm1, xmm1)
	vmovlpd(mem(r8), xmm4, xmm4)
@@ -1031,30 +755,32 @@ void bli_dgemm_bulldozer_asm_4x6_fma4
	vmovlpd(xmm15, mem(r8)) // and store back to memory.
vmovhpd(xmm13, mem(rcx, rdi, 1)) vmovhpd(xmm14, mem(rdx, rdi, 1)) - vmovhpd(xmm15, mem(r8, rdi, 1)) - - end_asm( - : // output operands (none) - : // input operands - [k_iter] "r" (k_iter), // 0 - [k_left] "r" (k_left), // 1 - [a] "r" (a), // 2 - [b] "r" (b), // 3 - [alpha] "r" (alpha), // 4 - [beta] "r" (beta), // 5 - [c] "r" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c)/*, // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next)*/ // 10 - : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", - "xmm0", "xmm1", "xmm2", "xmm3", - "xmm4", "xmm5", "xmm6", "xmm7", - "xmm8", "xmm9", "xmm10", "xmm11", - "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + vmovhpd(xmm15, mem(r8, rdi, 1)) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "r" (k_iter), // 0 + [k_left] "r" (k_left), // 1 + [a] "r" (a), // 2 + [b] "r" (b), // 3 + [alpha] "r" (alpha), // 4 + [beta] "r" (beta), // 5 + [c] "r" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c)/*, // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next)*/ // 10 + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" ) + + GEMM_UKR_FLUSH_CT( d ); } //The parameter "i" is the iteration number, i.e. the B values to read #define MADD_TO_YMM(i) \ @@ -1076,7 +802,9 @@ void bli_dgemm_bulldozer_asm_4x6_fma4 void bli_cgemm_bulldozer_asm_8x4_fma4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, scomplex* restrict alpha, scomplex* restrict a, scomplex* restrict b, @@ -1091,33 +819,35 @@ void bli_cgemm_bulldozer_asm_8x4_fma4 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT_ALIGNED( c, 8, 4, false, 32 ); + begin_asm() - + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. mov(var(b_next), r15) // load address of b_next. //mov(var(a_next), r14) // load address of a_next. sub(imm(4*64), r15) - + vmovaps(mem(rax, 0*32), ymm0) // initialize loop by pre-loading vmovsldup(mem(rbx, 0*32), ymm2) vpermilps(imm(0x4e), ymm2, ymm3) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(scomplex) lea(mem(rcx, rdi, 2), r10) // load address of c + 2*cs_c; - + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*cs_c prefetch(0, mem(r10, 3*8)) // prefetch c + 2*cs_c prefetch(0, mem(r10, rdi, 1, 3*8)) // prefetch c + 3*cs_c - + vxorps(ymm8, ymm8, ymm8) vxorps(ymm9, ymm9, ymm9) vxorps(ymm10, ymm10, ymm10) @@ -1126,343 +856,312 @@ void bli_cgemm_bulldozer_asm_8x4_fma4 vxorps(ymm13, ymm13, ymm13) vxorps(ymm14, ymm14, ymm14) vxorps(ymm15, ymm15, ymm15) - + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.CCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. 
- + label(.CLOOPKITER) // MAIN LOOP - + add(imm(4*4*8), r15) // b_next += 4*4 (unroll x nr) - + // iteration 0 prefetch(0, mem(rax, 8*32)) vmovaps(mem(rax, 1*32), ymm1) MADD_TO_YMM(0) - + vpermilps(imm(0xb1), ymm1, ymm1) vmulps(ymm0, ymm2, ymm6) vaddsubps(ymm6, ymm15, ymm15) vmulps(ymm0, ymm3, ymm7) vaddsubps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovsldup(mem(rbx, 1*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddsubps(ymm6, ymm14, ymm14) vaddsubps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vmovaps(mem(rax, 2*32), ymm0) vaddsubps(ymm6, ymm11, ymm11) vaddsubps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddsubps(ymm6, ymm10, ymm10) vaddsubps(ymm7, ymm8, ymm8) - - + + // iteration 1 prefetch(0, mem(rax, 10*32)) vmovaps(mem(rax, 3*32), ymm1) MADD_TO_YMM(1) - + vpermilps(imm(0xb1), ymm1, ymm1) vmulps(ymm0, ymm2, ymm6) vmulps(ymm0, ymm3, ymm7) vaddsubps(ymm6, ymm15, ymm15) vaddsubps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovsldup(mem(rbx, 2*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddsubps(ymm6, ymm14, ymm14) vaddsubps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vmovaps(mem(rax, 4*32), ymm0) vaddsubps(ymm6, ymm11, ymm11) vaddsubps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddsubps(ymm6, ymm10, ymm10) vaddsubps(ymm7, ymm8, ymm8) - + // iteration 2 prefetch(0, mem(rax, 12*32)) vmovaps(mem(rax, 5*32), ymm1) MADD_TO_YMM(2) prefetch(0, mem(r15, 2*32)) // prefetch b_next[2*4] - + vpermilps(imm(0xb1), ymm1, ymm1) vmulps(ymm0, ymm2, ymm6) vmulps(ymm0, ymm3, ymm7) vaddsubps(ymm6, ymm15, ymm15) vaddsubps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovsldup(mem(rbx, 3*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddsubps(ymm6, ymm14, ymm14) vaddsubps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vmovaps(mem(rax, 6*32), ymm0) vaddsubps(ymm6, ymm11, ymm11) vaddsubps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddsubps(ymm6, ymm10, ymm10) vaddsubps(ymm7, ymm8, ymm8) - - + + // iteration 3 prefetch(0, mem(rax, 14*32)) vmovaps(mem(rax, 7*32), ymm1) MADD_TO_YMM(3) - + vpermilps(imm(0xb1), ymm1, ymm1) vmulps(ymm0, ymm2, ymm6) vmulps(ymm0, ymm3, ymm7) vaddsubps(ymm6, ymm15, ymm15) vaddsubps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovsldup(mem(rbx, 4*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddsubps(ymm6, ymm14, ymm14) vaddsubps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vmovaps(mem(rax, 8*32), ymm0) vaddsubps(ymm6, ymm11, ymm11) vaddsubps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddsubps(ymm6, ymm10, ymm10) vaddsubps(ymm7, ymm8, ymm8) - - + + add(imm(8*4*8), rax) // a += 8*4 (unroll x mr) add(imm(4*4*8), rbx) // b += 4*4 (unroll x nr) - - + + dec(rsi) // i -= 1; jne(.CLOOPKITER) // iterate again if i != 0. - - - + + + label(.CCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.CPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. 
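
	// The edge loop below retires one rank-1 update per pass, mopping
	// up the k % 4 iterations left over from the 4x-unrolled main loop.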
- - + + label(.CLOOPKLEFT) // EDGE LOOP - + // iteration 0 prefetch(0, mem(rax, 8*32)) vmovaps(mem(rax, 1*32), ymm1) MADD_TO_YMM(0) - + vpermilps(imm(0xb1), ymm1, ymm1) vmulps(ymm0, ymm2, ymm6) vmulps(ymm0, ymm3, ymm7) vaddsubps(ymm6, ymm15, ymm15) vaddsubps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovsldup(mem(rbx, 1*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddsubps(ymm6, ymm14, ymm14) vaddsubps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vmovaps(mem(rax, 2*32), ymm0) vaddsubps(ymm6, ymm11, ymm11) vaddsubps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddsubps(ymm6, ymm10, ymm10) vaddsubps(ymm7, ymm8, ymm8) - - + + add(imm(8*1*8), rax) // a += 8 (1 x mr) add(imm(4*1*8), rbx) // b += 4 (1 x nr) - - + + dec(rsi) // i -= 1; jne(.CLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.CPOSTACCUM) - + // ymm15: ymm13: ymm11: ymm9: - // ( ab00 ( ab01 ( ab02 ( ab03 - // ab10 ab11 ab12 ab13 - // ab21 ab20 ab23 ab22 - // ab31 ab30 ab33 ab32 - // ab42 ab43 ab40 ab41 - // ab52 ab53 ab50 ab51 - // ab63 ab62 ab61 ab60 + // ( ab00 ( ab01 ( ab02 ( ab03 + // ab10 ab11 ab12 ab13 + // ab21 ab20 ab23 ab22 + // ab31 ab30 ab33 ab32 + // ab42 ab43 ab40 ab41 + // ab52 ab53 ab50 ab51 + // ab63 ab62 ab61 ab60 // ab73 ) ab72 ) ab71 ) ab70 ) - + // ymm14: ymm12: ymm10: ymm8: - // ( ab80 ( ab81 ( ab82 ( ab83 - // ab90 ab91 ab92 ab93 - // aba1 aba0 aba3 aba2 - // abb1 abb0 abb3 abb2 - // abc2 abc3 abc0 abc1 - // abd2 abd3 abd0 abd1 - // abe3 abe2 abe1 abe0 + // ( ab80 ( ab81 ( ab82 ( ab83 + // ab90 ab91 ab92 ab93 + // aba1 aba0 aba3 aba2 + // abb1 abb0 abb3 abb2 + // abc2 abc3 abc0 abc1 + // abd2 abd3 abd0 abd1 + // abe3 abe2 abe1 abe0 // abf3 abf2 abf1 abf0 ) GROUP_YMM_BY_4 // ymm15: ymm13: ymm11: ymm9: - // ( ab00 ( ab01 ( ab02 ( ab03 - // ab10 ab11 ab12 ab13 - // ab20 ab21 ab22 ab23 - // ab30 ab31 ab32 ab33 - // ab42 ab43 ab40 ab41 - // ab52 ab53 ab50 ab51 - // ab62 ab63 ab60 ab61 + // ( ab00 ( ab01 ( ab02 ( ab03 + // ab10 ab11 ab12 ab13 + // ab20 ab21 ab22 ab23 + // ab30 ab31 ab32 ab33 + // ab42 ab43 ab40 ab41 + // ab52 ab53 ab50 ab51 + // ab62 ab63 ab60 ab61 // ab72 ) ab73 ) ab70 ) ab71 ) - + // ymm14: ymm12: ymm10: ymm8: - // ( ab80 ( ab81 ( ab82 ( ab83 - // ab90 ab91 ab92 ab93 - // aba0 aba1 aba2 aba3 - // abb0 abb1 abb2 abb3 - // abc2 abc3 abc0 abc1 - // abd2 abd3 abd0 abd1 - // abe2 abe3 abe0 abe1 + // ( ab80 ( ab81 ( ab82 ( ab83 + // ab90 ab91 ab92 ab93 + // aba0 aba1 aba2 aba3 + // abb0 abb1 abb2 abb3 + // abc2 abc3 abc0 abc1 + // abd2 abd3 abd0 abd1 + // abe2 abe3 abe0 abe1 // abf2 ) abf3 ) abf0 ) abf1 ) - + // ymm15: ymm13: ymm11: ymm9: - // ( ab00 ( ab01 ( ab02 ( ab03 - // ab10 ab11 ab12 ab13 - // ab20 ab21 ab22 ab23 - // ab30 ab31 ab32 ab33 - // ab40 ab41 ab42 ab43 - // ab50 ab51 ab52 ab53 - // ab60 ab61 ab62 ab63 + // ( ab00 ( ab01 ( ab02 ( ab03 + // ab10 ab11 ab12 ab13 + // ab20 ab21 ab22 ab23 + // ab30 ab31 ab32 ab33 + // ab40 ab41 ab42 ab43 + // ab50 ab51 ab52 ab53 + // ab60 ab61 ab62 ab63 // ab70 ) ab71 ) ab72 ) ab73 ) - + // ymm14: ymm12: ymm10: ymm8: - // ( ab80 ( ab81 ( ab82 ( ab83 - // ab90 ab91 ab92 ab93 - // aba0 aba1 aba2 aba3 - // abb0 abb1 abb2 abb3 - // abc0 abc1 abc2 abc3 - // abd0 abd1 abd2 abd3 - // abe0 abe1 abe2 abe3 + // ( ab80 ( ab81 ( ab82 ( ab83 + // ab90 ab91 ab92 ab93 + // aba0 aba1 aba2 aba3 + // abb0 abb1 abb2 abb3 + // abc0 abc1 abc2 abc3 + // abd0 abd1 abd2 abd3 + // abe0 abe1 abe2 abe3 // abf0 ) abf1 ) abf2 ) abf3 ) - + // scale by alpha - + mov(var(alpha), rax) // load 
address of alpha vbroadcastss(mem(rax), ymm7) // load alpha_r and duplicate vbroadcastss(mem(rax, 4), ymm6) // load alpha_i and duplicate - + vpermilps(imm(0xb1), ymm15, ymm3) vmulps(ymm7, ymm15, ymm15) vmulps(ymm6, ymm3, ymm3) vaddsubps(ymm3, ymm15, ymm15) - + vpermilps(imm(0xb1), ymm14, ymm2) vmulps(ymm7, ymm14, ymm14) vmulps(ymm6, ymm2, ymm2) vaddsubps(ymm2, ymm14, ymm14) - + vpermilps(imm(0xb1), ymm13, ymm1) vmulps(ymm7, ymm13, ymm13) vmulps(ymm6, ymm1, ymm1) vaddsubps(ymm1, ymm13, ymm13) - + vpermilps(imm(0xb1), ymm12, ymm0) vmulps(ymm7, ymm12, ymm12) vmulps(ymm6, ymm0, ymm0) vaddsubps(ymm0, ymm12, ymm12) - + vpermilps(imm(0xb1), ymm11, ymm3) vmulps(ymm7, ymm11, ymm11) vmulps(ymm6, ymm3, ymm3) vaddsubps(ymm3, ymm11, ymm11) - + vpermilps(imm(0xb1), ymm10, ymm2) vmulps(ymm7, ymm10, ymm10) vmulps(ymm6, ymm2, ymm2) vaddsubps(ymm2, ymm10, ymm10) - + vpermilps(imm(0xb1), ymm9, ymm1) vmulps(ymm7, ymm9, ymm9) vmulps(ymm6, ymm1, ymm1) vaddsubps(ymm1, ymm9, ymm9) - + vpermilps(imm(0xb1), ymm8, ymm0) vmulps(ymm7, ymm8, ymm8) vmulps(ymm6, ymm0, ymm0) vaddsubps(ymm0, ymm8, ymm8) - - - - + + + + mov(var(beta), rbx) // load address of beta vbroadcastss(mem(rbx), ymm7) // load beta_r and duplicate vbroadcastss(mem(rbx, 4), ymm6) // load beta_i and duplicate - - - - - - - - mov(var(rs_c), rsi) // load rs_c - lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(scomplex) - - lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*rs_c; - - lea(mem(, rsi, 2), r12) // r12 = 2*rs_c; - lea(mem(r12, rsi, 1), r13) // r13 = 3*rs_c; - - - - // determine if - // c % 32 == 0, AND - // 8*cs_c % 32 == 0, AND - // rs_c == 1 - // ie: aligned, ldim aligned, and - // column-stored - - cmp(imm(8), rsi) // set ZF if (8*rs_c) == 8. - sete(bl) // bl = ( ZF == 1 ? 1 : 0 ); - test(imm(31), rcx) // set ZF if c & 32 is zero. - setz(bh) // bh = ( ZF == 0 ? 1 : 0 ); - test(imm(31), rdi) // set ZF if (8*cs_c) & 32 is zero. - setz(al) // al = ( ZF == 0 ? 1 : 0 ); - // and(bl,bh) followed by - // and(bh,al) will reveal result - + + // now avoid loading C if beta == 0 - + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomiss(xmm0, xmm7) // set ZF if beta_r == 0. sete(r8b) // r8b = ( ZF == 1 ? 1 : 0 ); @@ -1470,388 +1169,126 @@ void bli_cgemm_bulldozer_asm_8x4_fma4 sete(r9b) // r9b = ( ZF == 1 ? 1 : 0 ); and(r8b, r9b) // set ZF if r8b & r9b == 1. jne(.CBETAZERO) // if ZF = 0, jump to beta == 0 case - - - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. 
- jne(.CCOLSTORED) // jump to column storage case - - - - label(.CGENSTORED) - - // update c00:c70 - - vmovlpd(mem(rcx), xmm0, xmm0) // load (c00,10) into xmm0[0:1] - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) // load (c20,30) into xmm0[2:3] - vmovlpd(mem(rcx, r12, 1), xmm2, xmm2) // load (c40,50) into xmm2[0:1] - vmovhpd(mem(rcx, r13, 1), xmm2, xmm2) // load (c60,70) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm15, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rcx)) // store (c00,c10) - vmovhpd(xmm0, mem(rcx, rsi, 1)) // store (c20,c30) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c40,c50) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c60,c70) - add(rdi, rcx) // c += cs_c; - - // update c80:cf0 - - vmovlpd(mem(rdx), xmm0, xmm0) // load (c80,90) into xmm0[0:1] - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) // load (ca0,b0) into xmm0[2:3] - vmovlpd(mem(rdx, r12, 1), xmm2, xmm2) // load (cc0,d0) into xmm2[0:1] - vmovhpd(mem(rdx, r13, 1), xmm2, xmm2) // load (ce0,f0) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm14, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rdx)) // store (c80,c90) - vmovhpd(xmm0, mem(rdx, rsi, 1)) // store (ca0,cb0) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc0,cd0) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce0,cf0) - add(rdi, rdx) // c += cs_c; - - // update c01:c71 - - vmovlpd(mem(rcx), xmm0, xmm0) // load (c01,11) into xmm0[0:1] - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) // load (c21,31) into xmm0[2:3] - vmovlpd(mem(rcx, r12, 1), xmm2, xmm2) // load (c41,51) into xmm2[0:1] - vmovhpd(mem(rcx, r13, 1), xmm2, xmm2) // load (c61,71) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm13, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rcx)) // store (c01,c11) - vmovhpd(xmm0, mem(rcx, rsi, 1)) // store (c21,c31) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c41,c51) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c61,c71) - add(rdi, rcx) // c += cs_c; - - // update c81:cf1 - - vmovlpd(mem(rdx), xmm0, xmm0) // load (c81,91) into xmm0[0:1] - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) // load (ca1,b1) into xmm0[2:3] - vmovlpd(mem(rdx, r12, 1), xmm2, xmm2) // load (cc1,d1) into xmm2[0:1] - vmovhpd(mem(rdx, r13, 1), xmm2, xmm2) // load (ce1,f1) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm12, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rdx)) // store (c81,c91) - vmovhpd(xmm0, mem(rdx, rsi, 1)) // store (ca1,cb1) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc1,cd1) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce1,cf1) - add(rdi, rdx) // c += cs_c; - - // update c02:c72 - - 
vmovlpd(mem(rcx), xmm0, xmm0) // load (c02,12) into xmm0[0:1] - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) // load (c22,32) into xmm0[2:3] - vmovlpd(mem(rcx, r12, 1), xmm2, xmm2) // load (c42,52) into xmm2[0:1] - vmovhpd(mem(rcx, r13, 1), xmm2, xmm2) // load (c62,72) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm11, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rcx)) // store (c02,c12) - vmovhpd(xmm0, mem(rcx, rsi, 1)) // store (c22,c32) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c42,c52) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c62,c72) - add(rdi, rcx) // c += cs_c; - - // update c82:cf2 - - vmovlpd(mem(rdx), xmm0, xmm0) // load (c82,92) into xmm0[0:1] - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) // load (ca2,b2) into xmm0[2:3] - vmovlpd(mem(rdx, r12, 1), xmm2, xmm2) // load (cc2,d2) into xmm2[0:1] - vmovhpd(mem(rdx, r13, 1), xmm2, xmm2) // load (ce2,f2) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm10, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rdx)) // store (c82,c92) - vmovhpd(xmm0, mem(rdx, rsi, 1)) // store (ca2,cb2) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc2,cd2) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce2,cf2) - add(rdi, rdx) // c += cs_c; - - // update c03:c73 - - vmovlpd(mem(rcx), xmm0, xmm0) // load (c03,13) into xmm0[0:1] - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) // load (c23,33) into xmm0[2:3] - vmovlpd(mem(rcx, r12, 1), xmm2, xmm2) // load (c43,53) into xmm2[0:1] - vmovhpd(mem(rcx, r13, 1), xmm2, xmm2) // load (c63,73) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm9, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rcx)) // store (c03,c13) - vmovhpd(xmm0, mem(rcx, rsi, 1)) // store (c23,c33) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c43,c53) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c63,c73) - add(rdi, rcx) // c += cs_c; - - // update c83:cf3 - - vmovlpd(mem(rdx), xmm0, xmm0) // load (c83,93) into xmm0[0:1] - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) // load (ca3,b3) into xmm0[2:3] - vmovlpd(mem(rdx, r12, 1), xmm2, xmm2) // load (cc3,d3) into xmm2[0:1] - vmovhpd(mem(rdx, r13, 1), xmm2, xmm2) // load (ce3,f3) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm8, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rdx)) // store (c83,c93) - vmovhpd(xmm0, mem(rdx, rsi, 1)) // store (ca3,cb3) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc3,cd3) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce3,cf3) - add(rdi, rdx) // c += cs_c; - - - - jmp(.CDONE) // jump to end. 
- - - - label(.CCOLSTORED) - - // update c00:c70 - - vmovaps(mem(rcx), ymm0) // load c00:c70 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm15, ymm0, ymm0) // add the gemm result to ymm0 - vmovaps(ymm0, mem(rcx)) // store c00:c70 - add(rdi, rcx) // c += cs_c; - - // update c80:cf0 - - vmovaps(mem(rdx), ymm0) // load c80:f0 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm14, ymm0, ymm0) // add the gemm result to ymm0 - vmovaps(ymm0, mem(rdx)) // store c80:cf0 - add(rdi, rdx) // c += cs_c; - - // update c00:c70 - - vmovaps(mem(rcx), ymm0) // load c01:c71 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm13, ymm0, ymm0) // add the gemm result to ymm0 - vmovaps(ymm0, mem(rcx)) // store c01:c71 - add(rdi, rcx) // c += cs_c; - - // update c81:cf1 - - vmovaps(mem(rdx), ymm0) // load c81:f1 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm12, ymm0, ymm0) // add the gemm result to ymm0 - vmovaps(ymm0, mem(rdx)) // store c81:cf1 - add(rdi, rdx) // c += cs_c; - - // update c02:c72 - vmovaps(mem(rcx), ymm0) // load c02:c72 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm11, ymm0, ymm0) // add the gemm result to ymm0 - vmovaps(ymm0, mem(rcx)) // store c02:c72 - add(rdi, rcx) // c += cs_c; - - // update c82:cf2 - vmovaps(mem(rdx), ymm0) // load c82:f2 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm10, ymm0, ymm0) // add the gemm result to ymm0 - vmovaps(ymm0, mem(rdx)) // store c82:cf2 - add(rdi, rdx) // c += cs_c; - - // update c03:c73 - vmovaps(mem(rcx), ymm0) // load c03:c73 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm9, ymm0, ymm0) // add the gemm result to ymm0 - vmovaps(ymm0, mem(rcx)) // store c03:c73 - add(rdi, rcx) // c += cs_c; - - // update c83:cf3 - vmovaps(mem(rdx), ymm0) // load c83:f3 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm8, ymm0, ymm0) // add the gemm result to ymm0 - vmovaps(ymm0, mem(rdx)) // store c83:cf3 - add(rdi, rdx) // c += cs_c; - - jmp(.CDONE) // jump to end. 
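
	// The rewritten update below assumes GEMM_UKR_SETUP_CT_ALIGNED has
	// handed us a column-major, 32-byte-aligned tile: each 8-row column
	// of scomplex spans 64 contiguous bytes, so the upper half (rows
	// 4..7) sits at a fixed mem(rcx,32) offset and the separate rdx
	// pointer of the old path is no longer needed.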
+
+	// update c00:c70
+
+	vmovaps(mem(rcx), ymm0) // load c00:c70 into ymm0
+	vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta
+	vmulps(ymm7, ymm0, ymm0)
+	vmulps(ymm6, ymm2, ymm2)
+	vaddsubps(ymm2, ymm0, ymm0)
+	vaddps(ymm15, ymm0, ymm0) // add the gemm result to ymm0
+	vmovaps(ymm0, mem(rcx)) // store c00:c70
+
+	// update c80:cf0
+
+	vmovaps(mem(rcx,32), ymm0) // load c80:cf0 into ymm0
+	vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta
+	vmulps(ymm7, ymm0, ymm0)
+	vmulps(ymm6, ymm2, ymm2)
+	vaddsubps(ymm2, ymm0, ymm0)
+	vaddps(ymm14, ymm0, ymm0) // add the gemm result to ymm0
+	vmovaps(ymm0, mem(rcx,32)) // store c80:cf0
+	add(rdi, rcx) // c += cs_c;
+
+	// update c01:c71
+
+	vmovaps(mem(rcx), ymm0) // load c01:c71 into ymm0
+	vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta
+	vmulps(ymm7, ymm0, ymm0)
+	vmulps(ymm6, ymm2, ymm2)
+	vaddsubps(ymm2, ymm0, ymm0)
+	vaddps(ymm13, ymm0, ymm0) // add the gemm result to ymm0
+	vmovaps(ymm0, mem(rcx)) // store c01:c71
+
+	// update c81:cf1
+
+	vmovaps(mem(rcx,32), ymm0) // load c81:cf1 into ymm0
+	vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta
+	vmulps(ymm7, ymm0, ymm0)
+	vmulps(ymm6, ymm2, ymm2)
+	vaddsubps(ymm2, ymm0, ymm0)
+	vaddps(ymm12, ymm0, ymm0) // add the gemm result to ymm0
+	vmovaps(ymm0, mem(rcx,32)) // store c81:cf1
+	add(rdi, rcx) // c += cs_c;
+
+	// update c02:c72
+	vmovaps(mem(rcx), ymm0) // load c02:c72 into ymm0
+	vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta
+	vmulps(ymm7, ymm0, ymm0)
+	vmulps(ymm6, ymm2, ymm2)
+	vaddsubps(ymm2, ymm0, ymm0)
+	vaddps(ymm11, ymm0, ymm0) // add the gemm result to ymm0
+	vmovaps(ymm0, mem(rcx)) // store c02:c72
+
+	// update c82:cf2
+	vmovaps(mem(rcx,32), ymm0) // load c82:cf2 into ymm0
+	vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta
+	vmulps(ymm7, ymm0, ymm0)
+	vmulps(ymm6, ymm2, ymm2)
+	vaddsubps(ymm2, ymm0, ymm0)
+	vaddps(ymm10, ymm0, ymm0) // add the gemm result to ymm0
+	vmovaps(ymm0, mem(rcx,32)) // store c82:cf2
+	add(rdi, rcx) // c += cs_c;
+
+	// update c03:c73
+	vmovaps(mem(rcx), ymm0) // load c03:c73 into ymm0
+	vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta
+	vmulps(ymm7, ymm0, ymm0)
+	vmulps(ymm6, ymm2, ymm2)
+	vaddsubps(ymm2, ymm0, ymm0)
+	vaddps(ymm9, ymm0, ymm0) // add the gemm result to ymm0
+	vmovaps(ymm0, mem(rcx)) // store c03:c73
+
+	// update c83:cf3
+	vmovaps(mem(rcx,32), ymm0) // load c83:cf3 into ymm0
+	vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta
+	vmulps(ymm7, ymm0, ymm0)
+	vmulps(ymm6, ymm2, ymm2)
+	vaddsubps(ymm2, ymm0, ymm0)
+	vaddps(ymm8, ymm0, ymm0) // add the gemm result to ymm0
+	vmovaps(ymm0, mem(rcx,32)) // store c83:cf3
+	//add(rdi, rcx) // c += cs_c;
+
+	jmp(.CDONE) // jump to end.
+
	label(.CBETAZERO)
-	// check if aligned/column-stored
-	and(bl, bh) // set ZF if bl & bh == 1.
-	and(bh, al) // set ZF if bh & al == 1.
- jne(.CCOLSTORBZ) // jump to column storage case - - - label(.CGENSTORBZ) - // update c00:c70 - vextractf128(imm(1), ymm15, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm15, mem(rcx)) // store (c00,c10) - vmovhpd(xmm15, mem(rcx, rsi, 1)) // store (c20,c30) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c40,c50) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c60,c70) - add(rdi, rcx) // c += cs_c; - - // update c80:cf0 - vextractf128(imm(1), ymm14, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm14, mem(rdx)) // store (c80,c90) - vmovhpd(xmm14, mem(rdx, rsi, 1)) // store (ca0,cb0) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc0,cd0) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce0,cf0) - add(rdi, rdx) // c += cs_c; - - // update c01:c71 - vextractf128(imm(1), ymm13, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm13, mem(rcx)) // store (c01,c11) - vmovhpd(xmm13, mem(rcx, rsi, 1)) // store (c21,c31) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c41,c51) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c61,c71) - add(rdi, rcx) // c += cs_c; - - // update c81:cf1 - vextractf128(imm(1), ymm12, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm12, mem(rdx)) // store (c81,c91) - vmovhpd(xmm12, mem(rdx, rsi, 1)) // store (ca1,cb1) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc1,cd1) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce1,cf1) - add(rdi, rdx) // c += cs_c; - - // update c02:c72 - vextractf128(imm(1), ymm11, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm11, mem(rcx)) // store (c02,c12) - vmovhpd(xmm11, mem(rcx, rsi, 1)) // store (c22,c32) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c42,c52) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c62,c72) - add(rdi, rcx) // c += cs_c; - - // update c82:cf2 - vextractf128(imm(1), ymm10, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm10, mem(rdx)) // store (c82,c92) - vmovhpd(xmm10, mem(rdx, rsi, 1)) // store (ca2,cb2) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc2,cd2) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce2,cf2) - add(rdi, rdx) // c += cs_c; - - // update c03:c73 - vextractf128(imm(1), ymm9, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm9, mem(rcx)) // store (c03,c13) - vmovhpd(xmm9, mem(rcx, rsi, 1)) // store (c23,c33) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c43,c53) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c63,c73) - add(rdi, rcx) // c += cs_c; - - // update c83:cf3 - vextractf128(imm(1), ymm8, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm8, mem(rdx)) // store (c83,c93) - vmovhpd(xmm8, mem(rdx, rsi, 1)) // store (ca3,cb3) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc3,cd3) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce3,cf3) - add(rdi, rdx) // c += cs_c; - - - jmp(.CDONE) // jump to end. 
- - - label(.CCOLSTORBZ) - - vmovaps(ymm15, mem(rcx)) // store c00:c70 - add(rdi, rcx) // c += cs_c; - - vmovaps(ymm14, mem(rdx)) // store c80:cf0 - add(rdi, rdx) // c += cs_c; - - vmovaps(ymm13, mem(rcx)) // store c01:c71 - add(rdi, rcx) // c += cs_c; - - vmovaps(ymm12, mem(rdx)) // store c81:cf1 - add(rdi, rdx) // c += cs_c; - - vmovaps(ymm11, mem(rcx)) // store c02:c72 - add(rdi, rcx) // c += cs_c; - - vmovaps(ymm10, mem(rdx)) // store c82:cf2 - add(rdi, rdx) // c += cs_c; - - vmovaps(ymm9, mem(rcx)) // store c03:c73 - add(rdi, rcx) // c += cs_c; - - vmovaps(ymm8, mem(rdx)) // store c83:cf3 - add(rdi, rdx) // c += cs_c; - - - + + vmovaps(ymm15, mem(rcx)) // store c00:c70 + vmovaps(ymm14, mem(rcx,32)) // store c80:cf0 + add(rdi, rcx) // c += cs_c; + + vmovaps(ymm13, mem(rcx)) // store c01:c71 + vmovaps(ymm12, mem(rcx,32)) // store c81:cf1 + add(rdi, rcx) // c += cs_c; + + vmovaps(ymm11, mem(rcx)) // store c02:c72 + vmovaps(ymm10, mem(rcx,32)) // store c82:cf2 + add(rdi, rcx) // c += cs_c; + + vmovaps(ymm9, mem(rcx)) // store c03:c73 + vmovaps(ymm8, mem(rcx,32)) // store c83:cf3 + add(rdi, rcx) // c += cs_c; + label(.CDONE) - - end_asm( + + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c), // 8 - [b_next] "m" (b_next)/*, // 9 - [a_next] "m" (a_next)*/ // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c), // 8 + [b_next] "m" (b_next)/*, // 9 + [a_next] "m" (a_next)*/ // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", @@ -1859,6 +1296,8 @@ void bli_cgemm_bulldozer_asm_8x4_fma4 "ymm12", "ymm13", "ymm14", "ymm15", "memory" ) + + GEMM_UKR_FLUSH_CT( c ); } #define MADDSUBPD_TO_YMM \ @@ -1883,11 +1322,13 @@ void bli_cgemm_bulldozer_asm_8x4_fma4 vmulpd(ymm7, ymm(i), ymm(i))\ vmulpd(ymm6, ymm(j), ymm(j))\ vaddsubpd(ymm(j), ymm(i), ymm(i))\ - + void bli_zgemm_bulldozer_asm_4x4_fma4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, dcomplex* restrict alpha, dcomplex* restrict a, dcomplex* restrict b, @@ -1902,34 +1343,36 @@ void bli_zgemm_bulldozer_asm_4x4_fma4 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT_ALIGNED( z, 4, 4, false, 32 ); + begin_asm() - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. //mov(var(b_next), r15) // load address of b_next. //mov(var(a_next), r14) // load address of a_next. 
- + vmovapd(mem(rax, 0*32), ymm0) // initialize loop by pre-loading vmovddup(mem(rbx, 0+0*32), ymm2) vmovddup(mem(rbx, 0+1*32), ymm3) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(dcomplex) lea(mem(, rdi, 2), rdi) lea(mem(rcx, rdi, 2), r10) // load address of c + 2*cs_c; - + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*cs_c prefetch(0, mem(r10, 3*8)) // prefetch c + 2*cs_c prefetch(0, mem(r10, rdi, 1, 3*8)) // prefetch c + 3*cs_c - + vxorpd(ymm8, ymm8, ymm8) vxorpd(ymm9, ymm9, ymm9) vxorpd(ymm10, ymm10, ymm10) @@ -1938,28 +1381,28 @@ void bli_zgemm_bulldozer_asm_4x4_fma4 vxorpd(ymm13, ymm13, ymm13) vxorpd(ymm14, ymm14, ymm14) vxorpd(ymm15, ymm15, ymm15) - - + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.ZCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - + label(.ZLOOPKITER) // MAIN LOOP - + // iteration 0 vmovapd(mem(rax, 1*32), ymm1) vfmaddpd(ymm15, ymm0, ymm2, ymm15) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vfmaddpd(ymm11, ymm0, ymm3, ymm11) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) - + prefetch(0, mem(rax, 16*32)) vfmaddpd(ymm14, ymm1, ymm2, ymm14) vmovddup(mem(rbx, 8+0*32), ymm2) vfmaddpd(ymm10, ymm1, ymm3, ymm10) vmovddup(mem(rbx, 8+1*32), ymm3) - + MADDSUBPD_TO_YMM vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 0+2*32), ymm2) @@ -1967,31 +1410,31 @@ void bli_zgemm_bulldozer_asm_4x4_fma4 vmovddup(mem(rbx, 0+3*32), ymm3) vaddsubpd(ymm6, ymm14, ymm14) vaddsubpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 2*32), ymm0) vaddsubpd(ymm6, ymm13, ymm13) vaddsubpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddsubpd(ymm6, ymm12, ymm12) vaddsubpd(ymm7, ymm8, ymm8) - + // iteration 1 vmovapd(mem(rax, 3*32), ymm1) vfmaddpd(ymm15, ymm0, ymm2, ymm15) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vfmaddpd(ymm11, ymm0, ymm3, ymm11) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) - + prefetch(0, mem(rax, 18*32)) vfmaddpd(ymm14, ymm1, ymm2, ymm14) vmovddup(mem(rbx, 8+2*32), ymm2) vfmaddpd(ymm10, ymm1, ymm3, ymm10) vmovddup(mem(rbx, 8+3*32), ymm3) - + MADDSUBPD_TO_YMM vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 0+4*32), ymm2) @@ -1999,31 +1442,31 @@ void bli_zgemm_bulldozer_asm_4x4_fma4 vmovddup(mem(rbx, 0+5*32), ymm3) vaddsubpd(ymm6, ymm14, ymm14) vaddsubpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 4*32), ymm0) vaddsubpd(ymm6, ymm13, ymm13) vaddsubpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddsubpd(ymm6, ymm12, ymm12) vaddsubpd(ymm7, ymm8, ymm8) - + // iteration 2 vmovapd(mem(rax, 5*32), ymm1) vfmaddpd(ymm15, ymm0, ymm2, ymm15) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vfmaddpd(ymm11, ymm0, ymm3, ymm11) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) - + prefetch(0, mem(rax, 20*32)) vfmaddpd(ymm14, ymm1, ymm2, ymm14) vmovddup(mem(rbx, 8+4*32), ymm2) vfmaddpd(ymm10, ymm1, ymm3, ymm10) vmovddup(mem(rbx, 8+5*32), ymm3) - + MADDSUBPD_TO_YMM vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 0+6*32), ymm2) @@ -2031,31 +1474,31 @@ void bli_zgemm_bulldozer_asm_4x4_fma4 vmovddup(mem(rbx, 0+7*32), ymm3) vaddsubpd(ymm6, ymm14, ymm14) vaddsubpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 6*32), ymm0) vaddsubpd(ymm6, ymm13, ymm13) vaddsubpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddsubpd(ymm6, ymm12, ymm12) vaddsubpd(ymm7, ymm8, 
ymm8) - + // iteration 3 vmovapd(mem(rax, 7*32), ymm1) vfmaddpd(ymm15, ymm0, ymm2, ymm15) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vfmaddpd(ymm11, ymm0, ymm3, ymm11) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) - + prefetch(0, mem(rax, 22*32)) vfmaddpd(ymm14, ymm1, ymm2, ymm14) vmovddup(mem(rbx, 8+6*32), ymm2) vfmaddpd(ymm10, ymm1, ymm3, ymm10) vmovddup(mem(rbx, 8+7*32), ymm3) - + MADDSUBPD_TO_YMM vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 0+8*32), ymm2) @@ -2063,48 +1506,48 @@ void bli_zgemm_bulldozer_asm_4x4_fma4 vmovddup(mem(rbx, 0+9*32), ymm3) vaddsubpd(ymm6, ymm14, ymm14) vaddsubpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 8*32), ymm0) vaddsubpd(ymm6, ymm13, ymm13) vaddsubpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddsubpd(ymm6, ymm12, ymm12) vaddsubpd(ymm7, ymm8, ymm8) - + add(imm(4*4*16), rbx) // b += 4*4 (unroll x nr) add(imm(4*4*16), rax) // a += 4*4 (unroll x mr) - + dec(rsi) // i -= 1; jne(.ZLOOPKITER) // iterate again if i != 0. - - + + label(.ZCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.ZPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.ZLOOPKLEFT) // EDGE LOOP - + // iteration 0 vmovapd(mem(rax, 1*32), ymm1) vfmaddpd(ymm15, ymm0, ymm2, ymm15) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vfmaddpd(ymm11, ymm0, ymm3, ymm11) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) - + prefetch(0, mem(rax, 16*32)) vfmaddpd(ymm14, ymm1, ymm2, ymm14) vmovddup(mem(rbx, 8+0*32), ymm2) vfmaddpd(ymm10, ymm1, ymm3, ymm10) vmovddup(mem(rbx, 8+1*32), ymm3) - + MADDSUBPD_TO_YMM vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 0+2*32), ymm2) @@ -2112,75 +1555,75 @@ void bli_zgemm_bulldozer_asm_4x4_fma4 vmovddup(mem(rbx, 0+3*32), ymm3) vaddsubpd(ymm6, ymm14, ymm14) vaddsubpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 2*32), ymm0) vaddsubpd(ymm6, ymm13, ymm13) vaddsubpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddsubpd(ymm6, ymm12, ymm12) vaddsubpd(ymm7, ymm8, ymm8) - - + + add(imm(4*1*16), rax) // a += 4 (1 x mr) add(imm(4*1*16), rbx) // b += 4 (1 x nr) - + dec(rsi) // i -= 1; jne(.ZLOOPKLEFT) // iterate again if i != 0. 
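
	// On exit from the k loops the accumulators hold the 4x4 dcomplex
	// tile with 128-bit halves crossed between register pairs; the
	// vperm2f128 sequence after .ZPOSTACCUM swaps them back into plain
	// column order before the alpha scaling.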
- - + + label(.ZPOSTACCUM) // ymm15: ymm13: ymm11: ymm9: // ( ab00 ( ab01 ( ab02 ( ab03 - // ab10 ab11 ab12 ab13 + // ab10 ab11 ab12 ab13 // ab21 ab20 ab23 ab22 // ab31 ) ab30 ) ab33 ) ab32 ) - + // ymm14: ymm12: ymm10: ymm8: // ( ab40 ( ab41 ( ab42 ( ab43 - // ab50 ab51 ab52 ab53 + // ab50 ab51 ab52 ab53 // ab61 ab60 ab63 ab62 // ab71 ) ab70 ) ab73 ) ab72 ) - + vmovapd(ymm15, ymm7) vperm2f128(imm(0x12), ymm15, ymm13, ymm15) vperm2f128(imm(0x30), ymm7, ymm13, ymm13) - + vmovapd(ymm11, ymm7) vperm2f128(imm(0x12), ymm11, ymm9, ymm11) vperm2f128(imm(0x30), ymm7, ymm9, ymm9) - + vmovapd(ymm14, ymm7) vperm2f128(imm(0x12), ymm14, ymm12, ymm14) vperm2f128(imm(0x30), ymm7, ymm12, ymm12) - + vmovapd(ymm10, ymm7) vperm2f128(imm(0x12), ymm10, ymm8, ymm10) vperm2f128(imm(0x30), ymm7, ymm8, ymm8) - - + + // ymm15: ymm13: ymm11: ymm9: // ( ab00 ( ab01 ( ab02 ( ab03 - // ab10 ab11 ab12 ab13 + // ab10 ab11 ab12 ab13 // ab20 ab21 ab22 ab23 // ab30 ) ab31 ) ab32 ) ab33 ) - + // ymm14: ymm12: ymm10: ymm8: // ( ab40 ( ab41 ( ab42 ( ab43 - // ab50 ab51 ab52 ab53 + // ab50 ab51 ab52 ab53 // ab60 ab61 ab62 ab63 // ab70 ) ab71 ) ab72 ) ab73 ) - - + + // scale by alpha - + mov(var(alpha), rax) // load address of alpha vbroadcastsd(mem(rax), ymm7) // load alpha_r and duplicate vbroadcastsd(mem(rax, 8), ymm6) // load alpha_i and duplicate - + Z_ALPHA(15, 3) Z_ALPHA(14, 2) Z_ALPHA(13, 1) @@ -2190,38 +1633,14 @@ void bli_zgemm_bulldozer_asm_4x4_fma4 Z_ALPHA(10, 2) Z_ALPHA(9, 1) Z_ALPHA(8, 0) - + mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rbx), ymm7) // load beta_r and duplicate vbroadcastsd(mem(rbx, 8), ymm6) // load beta_i and duplicate - - - - mov(var(rs_c), rsi) // load rs_c - lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(dcomplex) - lea(mem(, rsi, 2), rsi) - lea(mem(rcx, rsi, 2), rdx) // load address of c + 2*rs_c; - - - - // determine if - // c % 32 == 0, AND - // 16*cs_c % 32 == 0, AND - // rs_c == 1 - // ie: aligned, ldim aligned, and - // column-stored - - cmp(imm(16), rsi) // set ZF if (16*rs_c) == 16. - sete(bl) // bl = ( ZF == 1 ? 1 : 0 ); - test(imm(31), rcx) // set ZF if c & 32 is zero. - setz(bh) // bh = ( ZF == 0 ? 1 : 0 ); - test(imm(31), rdi) // set ZF if (16*cs_c) & 32 is zero. - setz(al) // al = ( ZF == 0 ? 1 : 0 ); - // and(bl,bh) followed by - // and(bh,al) will reveal result - + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm7) // set ZF if beta_r == 0. sete(r8b) // r8b = ( ZF == 1 ? 1 : 0 ); @@ -2229,287 +1648,91 @@ void bli_zgemm_bulldozer_asm_4x4_fma4 sete(r9b) // r9b = ( ZF == 1 ? 1 : 0 ); and(r8b, r9b) // set ZF if r8b & r9b == 1. jne(.ZBETAZERO) // if ZF = 0, jump to beta == 0 case - - - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. 
- jne(.ZCOLSTORED) // jump to column storage case - - - - label(.ZGENSTORED) - // update c00:c30 - - vmovupd(mem(rcx), xmm0) // load (c00,c10) into xmm0 - vmovupd(mem(rcx, rsi, 1), xmm2) // load (c20,c30) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm15, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rcx)) // store (c00,c10) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c20,c30) - add(rdi, rcx) // c += cs_c; - - // update c40:c70 - - vmovupd(mem(rdx), xmm0) // load (c40,c50) into xmm0 - vmovupd(mem(rdx, rsi, 1), xmm2) // load (c60,c70) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm14, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rdx)) // store (c40,c50) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c60,c70) - add(rdi, rdx) // c += cs_c; - - // update c01:c31 - - vmovupd(mem(rcx), xmm0) // load (c01,c11) into xmm0 - vmovupd(mem(rcx, rsi, 1), xmm2) // load (c21,c31) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm13, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rcx)) // store (c01,c11) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c21,c31) - add(rdi, rcx) // c += cs_c; - - // update c41:c71 - - vmovupd(mem(rdx), xmm0) // load (c41,c51) into xmm0 - vmovupd(mem(rdx, rsi, 1), xmm2) // load (c61,c71) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm12, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rdx)) // store (c41,c51) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c61,c71) - add(rdi, rdx) // c += cs_c; - - // update c02:c32 - - vmovupd(mem(rcx), xmm0) // load (c02,c12) into xmm0 - vmovupd(mem(rcx, rsi, 1), xmm2) // load (c22,c32) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm11, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rcx)) // store (c02,c12) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c22,c32) - add(rdi, rcx) // c += cs_c; - - // update c42:c72 - - vmovupd(mem(rdx), xmm0) // load (c42,c52) into xmm0 - vmovupd(mem(rdx, rsi, 1), xmm2) // load (c62,c72) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm10, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rdx)) // store (c42,c52) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c62,c72) - add(rdi, rdx) // c += cs_c; - - // update c03:c33 - - vmovupd(mem(rcx), xmm0) // load (c03,c13) into xmm0 - vmovupd(mem(rcx, rsi, 1), xmm2) // load (c23,c33) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm9, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rcx)) // store (c03,c13) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c23,c33) - add(rdi, rcx) // c += cs_c; - - // update c43:c73 - - vmovupd(mem(rdx), xmm0) // load 
(c43,c53) into xmm0 - vmovupd(mem(rdx, rsi, 1), xmm2) // load (c63,c73) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm8, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rdx)) // store (c43,c53) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c63,c73) - - - - jmp(.ZDONE) // jump to end. - - - - label(.ZCOLSTORED) - // update c00:c30 - - vmovapd(mem(rcx), ymm0) // load c00:c30 into ymm0 - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm15, ymm0, ymm0) // add the gemm result to ymm0 - vmovapd(ymm0, mem(rcx)) // store c00:c30 - add(rdi, rcx) // c += cs_c; - - // update c40:c70 - - vmovapd(mem(rdx), ymm0) // load c40:c70 into ymm0 - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm14, ymm0, ymm0) // add the gemm result to ymm0 - vmovapd(ymm0, mem(rdx)) // store c40:c70 - add(rdi, rdx) // c += cs_c; - - // update c01:c31 - - vmovapd(mem(rcx), ymm0) // load c01:c31 into ymm0 - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm13, ymm0, ymm0) // add the gemm result to ymm0 - vmovapd(ymm0, mem(rcx)) // store c01:c31 - add(rdi, rcx) // c += cs_c; - - // update c41:c71 - - vmovapd(mem(rdx), ymm0) // load c41:c71 into ymm0 - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm12, ymm0, ymm0) // add the gemm result to ymm0 - vmovapd(ymm0, mem(rdx)) // store c41:c71 - add(rdi, rdx) // c += cs_c; - - // update c02:c32 - - vmovapd(mem(rcx), ymm0) // load c02:c32 into ymm0 - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm11, ymm0, ymm0) // add the gemm result to ymm0 - vmovapd(ymm0, mem(rcx)) // store c02:c32 - add(rdi, rcx) // c += cs_c; - - // update c42:c72 - - vmovapd(mem(rdx), ymm0) // load c42:c72 into ymm0 - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm10, ymm0, ymm0) // add the gemm result to ymm0 - vmovapd(ymm0, mem(rdx)) // store c42:c72 - add(rdi, rdx) // c += cs_c; - - // update c03:c33 - - vmovapd(mem(rcx), ymm0) // load c03:c33 into ymm0 - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm9, ymm0, ymm0) // add the gemm result to ymm0 - vmovapd(ymm0, mem(rcx)) // store c03:c33 - add(rdi, rcx) // c += cs_c; - - // update c43:c73 - - vmovapd(mem(rdx), ymm0) // load c43:c73 into ymm0 - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm8, ymm0, ymm0) // add the gemm result to ymm0 - vmovapd(ymm0, mem(rdx)) // store c43:c73 - - - - jmp(.ZDONE) // jump to end. 
- - - + + // update c00:c30 + + vmovapd(mem(rcx), ymm0) // load c00:c30 into ymm0 + Z_ALPHA(0, 2) // scale ymm0 by beta + vaddpd(ymm15, ymm0, ymm0) // add the gemm result to ymm0 + vmovapd(ymm0, mem(rcx)) // store c00:c30 + + // update c40:c70 + + vmovapd(mem(rcx,32), ymm0) // load c40:c70 into ymm0 + Z_ALPHA(0, 2) // scale ymm0 by beta + vaddpd(ymm14, ymm0, ymm0) // add the gemm result to ymm0 + vmovapd(ymm0, mem(rcx,32)) // store c40:c70 + add(rdi, rcx) // c += cs_c; + + // update c01:c31 + + vmovapd(mem(rcx), ymm0) // load c01:c31 into ymm0 + Z_ALPHA(0, 2) // scale ymm0 by beta + vaddpd(ymm13, ymm0, ymm0) // add the gemm result to ymm0 + vmovapd(ymm0, mem(rcx)) // store c01:c31 + + // update c41:c71 + + vmovapd(mem(rcx,32), ymm0) // load c41:c71 into ymm0 + Z_ALPHA(0, 2) // scale ymm0 by beta + vaddpd(ymm12, ymm0, ymm0) // add the gemm result to ymm0 + vmovapd(ymm0, mem(rcx,32)) // store c41:c71 + add(rdi, rcx) // c += cs_c; + + // update c02:c32 + + vmovapd(mem(rcx), ymm0) // load c02:c32 into ymm0 + Z_ALPHA(0, 2) // scale ymm0 by beta + vaddpd(ymm11, ymm0, ymm0) // add the gemm result to ymm0 + vmovapd(ymm0, mem(rcx)) // store c02:c32 + + // update c42:c72 + + vmovapd(mem(rcx,32), ymm0) // load c42:c72 into ymm0 + Z_ALPHA(0, 2) // scale ymm0 by beta + vaddpd(ymm10, ymm0, ymm0) // add the gemm result to ymm0 + vmovapd(ymm0, mem(rcx,32)) // store c42:c72 + add(rdi, rcx) // c += cs_c; + + // update c03:c33 + + vmovapd(mem(rcx), ymm0) // load c03:c33 into ymm0 + Z_ALPHA(0, 2) // scale ymm0 by beta + vaddpd(ymm9, ymm0, ymm0) // add the gemm result to ymm0 + vmovapd(ymm0, mem(rcx)) // store c03:c33 + + // update c43:c73 + + vmovapd(mem(rcx,32), ymm0) // load c43:c73 into ymm0 + Z_ALPHA(0, 2) // scale ymm0 by beta + vaddpd(ymm8, ymm0, ymm0) // add the gemm result to ymm0 + vmovapd(ymm0, mem(rcx,32)) // store c43:c73 + add(rdi, rcx) // c += cs_c; + + jmp(.ZDONE) // jump to end. + label(.ZBETAZERO) - // check if aligned/column-stored - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. 
- jne(.ZCOLSTORBZ) // jump to column storage case - - - - label(.ZGENSTORBZ) - // update c00:c30 - - vextractf128(imm(1), ymm15, xmm2) - vmovupd(xmm15, mem(rcx)) // store (c00,c10) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c20,c30) - add(rdi, rcx) // c += cs_c; - - // update c40:c70 - - vextractf128(imm(1), ymm14, xmm2) - vmovupd(xmm14, mem(rdx)) // store (c40,c50) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c60,c70) - add(rdi, rdx) // c += cs_c; - - // update c01:c31 - - vextractf128(imm(1), ymm13, xmm2) - vmovupd(xmm13, mem(rcx)) // store (c01,c11) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c21,c31) - add(rdi, rcx) // c += cs_c; - - // update c41:c71 - - vextractf128(imm(1), ymm12, xmm2) - vmovupd(xmm12, mem(rdx)) // store (c41,c51) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c61,c71) - add(rdi, rdx) // c += cs_c; - - // update c02:c32 - - vextractf128(imm(1), ymm11, xmm2) - vmovupd(xmm11, mem(rcx)) // store (c02,c12) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c22,c32) - add(rdi, rcx) // c += cs_c; - - // update c42:c72 - - vextractf128(imm(1), ymm10, xmm2) - vmovupd(xmm10, mem(rdx)) // store (c42,c52) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c62,c72) - add(rdi, rdx) // c += cs_c; - - // update c03:c33 - - vextractf128(imm(1), ymm9, xmm2) - vmovupd(xmm9, mem(rcx)) // store (c03,c13) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c23,c33) - add(rdi, rcx) // c += cs_c; - - // update c43:c73 - - vextractf128(imm(1), ymm8, xmm2) - vmovupd(xmm8, mem(rdx)) // store (c43,c53) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c63,c73) - - - jmp(.ZDONE) // jump to end. - - - label(.ZCOLSTORBZ) - - - vmovapd(ymm15, mem(rcx)) // store c00:c30 - add(rdi, rcx) // c += cs_c; - - vmovapd(ymm14, mem(rdx)) // store c40:c70 - add(rdi, rdx) // c += cs_c; - - vmovapd(ymm13, mem(rcx)) // store c01:c31 - add(rdi, rcx) // c += cs_c; - - vmovapd(ymm12, mem(rdx)) // store c41:c71 - add(rdi, rdx) // c += cs_c; - - vmovapd(ymm11, mem(rcx)) // store c02:c32 - add(rdi, rcx) // c += cs_c; - - vmovapd(ymm10, mem(rdx)) // store c42:c72 - add(rdi, rdx) // c += cs_c; - - vmovapd(ymm9, mem(rcx)) // store c03:c33 - add(rdi, rcx) // c += cs_c; - - vmovapd(ymm8, mem(rdx)) // store c43:c73 - - + + vmovapd(ymm15, mem(rcx)) // store c00:c30 + vmovapd(ymm14, mem(rcx,32)) // store c40:c70 + add(rdi, rcx) // c += cs_c; + + vmovapd(ymm13, mem(rcx)) // store c01:c31 + vmovapd(ymm12, mem(rcx,32)) // store c41:c71 + add(rdi, rcx) // c += cs_c; + + vmovapd(ymm11, mem(rcx)) // store c02:c32 + vmovapd(ymm10, mem(rcx,32)) // store c42:c72 + add(rdi, rcx) // c += cs_c; + + vmovapd(ymm9, mem(rcx)) // store c03:c33 + vmovapd(ymm8, mem(rcx,32)) // store c43:c73 + //add(rdi, rcx) // c += cs_c; + label(.ZDONE) - - end_asm( + + end_asm( : // output operands (none) : // input operands [k_iter] "m" (k_iter), // 0 @@ -2524,7 +1747,7 @@ void bli_zgemm_bulldozer_asm_4x4_fma4 [b_next] "m" (b_next), // 9 [a_next] "m" (a_next)*/ // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", @@ -2532,5 +1755,7 @@ void bli_zgemm_bulldozer_asm_4x4_fma4 "ymm12", "ymm13", "ymm14", "ymm15", "memory" ) + + GEMM_UKR_FLUSH_CT( z ); } diff --git a/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c b/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c index 7907bd9018..d0e7938678 100644 --- a/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c +++ b/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c @@ -79,7 +79,9 
@@ void bli_sgemm_haswell_asm_6x16 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, float* restrict alpha, float* restrict a, float* restrict b, @@ -94,11 +96,13 @@ void bli_sgemm_haswell_asm_6x16 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT_AMBI( s, 6, 16, true ); + begin_asm() vzeroall() // zero all xmm/ymm registers. @@ -109,36 +113,65 @@ void bli_sgemm_haswell_asm_6x16 //mov(%9, r15) // load address of b_next. add(imm(32*4), rbx) - // initialize loop by pre-loading + // initialize loop by pre-loading vmovaps(mem(rbx, -4*32), ymm0) vmovaps(mem(rbx, -3*32), ymm1) mov(var(c), rcx) // load address of c mov(var(rs_c), rdi) // load rs_c lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) - lea(mem(rdi, rdi, 2), r13) // r13 = 3*rs_c; - lea(mem(rcx, r13, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c - prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c - prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c - prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c - prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c - prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c - - - + cmp(imm(4), rsi) // set ZF if (4*cs_c) == 4. + jz(.SCOLPREFETCH) // jump to column prefetch case + + lea(mem(rdi, rdi, 2), r13) // r13 = 3*rs_c; + lea(mem(rcx, r13, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c + + jmp(.SPREFETCHDONE) + + label(.SCOLPREFETCH) + + lea(mem(rsi, rsi, 2), r13) // r13 = 3*cs_c; + lea(mem(rcx, r13, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 7*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 7*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 7*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 7*8)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, r13, 1, 7*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 4, 7*8)) // prefetch c + 7*cs_c + lea(mem(rcx, rsi, 8), r14) // r14 = c + 8*cs_c; + lea(mem(r14, r13, 1), rdx) // rdx = c + 11*cs_c; + prefetch(0, mem(r14, 7*8)) // prefetch c + 8*cs_c + prefetch(0, mem(r14, rsi, 1, 7*8)) // prefetch c + 9*cs_c + prefetch(0, mem(r14, rsi, 2, 7*8)) // prefetch c + 10*cs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 11*cs_c + prefetch(0, mem(rdx, rsi, 1, 7*8)) // prefetch c + 12*cs_c + prefetch(0, mem(rdx, rsi, 2, 7*8)) // prefetch c + 13*cs_c + prefetch(0, mem(rdx, r13, 1, 7*8)) // prefetch c + 14*cs_c + prefetch(0, mem(rdx, rsi, 4, 7*8)) // prefetch c + 15*cs_c + + label(.SPREFETCHDONE) mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.SCONSIDKLEFT) // if i == 0, jump to code that - // contains the k_left loop. + // contains the k_left loop. 
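
	// (The two prefetch branches above are chosen at run time: the
	// added cmp(imm(4), rsi) test detects column-stored C, where
	// cs_c == 1 element == 4 bytes, instead of unconditionally assuming
	// the row-major prefetch pattern as the old code did.)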
label(.SLOOPKITER) // MAIN LOOP - // iteration 0 + // iteration 0 prefetch(0, mem(rax, 64*4)) vbroadcastss(mem(rax, 0*4), ymm2) @@ -165,7 +198,7 @@ void bli_sgemm_haswell_asm_6x16 vmovaps(mem(rbx, -2*32), ymm0) vmovaps(mem(rbx, -1*32), ymm1) - // iteration 1 + // iteration 1 vbroadcastss(mem(rax, 6*4), ymm2) vbroadcastss(mem(rax, 7*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) @@ -190,7 +223,7 @@ void bli_sgemm_haswell_asm_6x16 vmovaps(mem(rbx, 0*32), ymm0) vmovaps(mem(rbx, 1*32), ymm1) - // iteration 2 + // iteration 2 prefetch(0, mem(rax, 76*4)) vbroadcastss(mem(rax, 12*4), ymm2) @@ -217,7 +250,7 @@ void bli_sgemm_haswell_asm_6x16 vmovaps(mem(rbx, 2*32), ymm0) vmovaps(mem(rbx, 3*32), ymm1) - // iteration 3 + // iteration 3 vbroadcastss(mem(rax, 18*4), ymm2) vbroadcastss(mem(rax, 19*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) @@ -259,7 +292,7 @@ void bli_sgemm_haswell_asm_6x16 mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.SPOSTACCUM) // if i == 0, we're done; jump to end. - // else, we prepare to enter k_left loop. + // else, we prepare to enter k_left loop. label(.SLOOPKLEFT) // EDGE LOOP @@ -338,533 +371,330 @@ void bli_sgemm_haswell_asm_6x16 lea(mem(r13, rsi, 4), r10) // r10 = 7*cs_c; - // now avoid loading C if beta == 0 + // now avoid loading C if beta == 0 vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomiss(xmm0, xmm3) // set ZF if beta == 0. je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case - - cmp(imm(4), rsi) // set ZF if (4*cs_c) == 4. - jz(.SROWSTORED) // jump to row storage case - - - cmp(imm(4), rdi) // set ZF if (4*cs_c) == 4. - jz(.SCOLSTORED) // jump to column storage case - - - - label(.SGENSTORED) - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm4, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm6, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm8, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm10, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm12, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm14, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - //add(rdi, rcx) // c += rs_c; - - - mov(rdx, rcx) // rcx = c + 8*cs_c - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm5, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm7, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm9, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm11, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm13, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm15, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - //add(rdi, rcx) // c += rs_c; - - - - jmp(.SDONE) // jump to end. 
- - - - label(.SROWSTORED) - - - vfmadd231ps(mem(rcx), ymm3, ymm4) - vmovups(ymm4, mem(rcx)) - add(rdi, rcx) - vfmadd231ps(mem(rdx), ymm3, ymm5) - vmovups(ymm5, mem(rdx)) - add(rdi, rdx) - - - vfmadd231ps(mem(rcx), ymm3, ymm6) - vmovups(ymm6, mem(rcx)) - add(rdi, rcx) - vfmadd231ps(mem(rdx), ymm3, ymm7) - vmovups(ymm7, mem(rdx)) - add(rdi, rdx) - - - vfmadd231ps(mem(rcx), ymm3, ymm8) - vmovups(ymm8, mem(rcx)) - add(rdi, rcx) - vfmadd231ps(mem(rdx), ymm3, ymm9) - vmovups(ymm9, mem(rdx)) - add(rdi, rdx) - - - vfmadd231ps(mem(rcx), ymm3, ymm10) - vmovups(ymm10, mem(rcx)) - add(rdi, rcx) - vfmadd231ps(mem(rdx), ymm3, ymm11) - vmovups(ymm11, mem(rdx)) - add(rdi, rdx) - - - vfmadd231ps(mem(rcx), ymm3, ymm12) - vmovups(ymm12, mem(rcx)) - add(rdi, rcx) - vfmadd231ps(mem(rdx), ymm3, ymm13) - vmovups(ymm13, mem(rdx)) - add(rdi, rdx) - - - vfmadd231ps(mem(rcx), ymm3, ymm14) - vmovups(ymm14, mem(rcx)) - //add(rdi, rcx) - vfmadd231ps(mem(rdx), ymm3, ymm15) - vmovups(ymm15, mem(rdx)) - //add(rdi, rdx) - - - - jmp(.SDONE) // jump to end. - - - - label(.SCOLSTORED) - - - vbroadcastss(mem(rbx), ymm3) - - vunpcklps(ymm6, ymm4, ymm0) - vunpcklps(ymm10, ymm8, ymm1) - vshufps(imm(0x4e), ymm1, ymm0, ymm2) - vblendps(imm(0xcc), ymm2, ymm0, ymm0) - vblendps(imm(0x33), ymm2, ymm1, ymm1) - - vextractf128(imm(0x1), ymm0, xmm2) - vfmadd231ps(mem(rcx), xmm3, xmm0) - vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) - vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) - vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) - - vextractf128(imm(0x1), ymm1, xmm2) - vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) - vfmadd231ps(mem(rcx, r15, 1), xmm3, xmm2) - vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) - vmovups(xmm2, mem(rcx, r15, 1)) // store ( gamma05..gamma35 ) - - - vunpckhps(ymm6, ymm4, ymm0) - vunpckhps(ymm10, ymm8, ymm1) - vshufps(imm(0x4e), ymm1, ymm0, ymm2) - vblendps(imm(0xcc), ymm2, ymm0, ymm0) - vblendps(imm(0x33), ymm2, ymm1, ymm1) - - vextractf128(imm(0x1), ymm0, xmm2) - vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) - vfmadd231ps(mem(rcx, r13, 2), xmm3, xmm2) - vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) - vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma06..gamma36 ) - - vextractf128(imm(0x1), ymm1, xmm2) - vfmadd231ps(mem(rcx, r13, 1), xmm3, xmm1) - vfmadd231ps(mem(rcx, r10, 1), xmm3, xmm2) - vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma03..gamma33 ) - vmovups(xmm2, mem(rcx, r10, 1)) // store ( gamma07..gamma37 ) - - lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c - - vunpcklps(ymm14, ymm12, ymm0) - vextractf128(imm(0x1), ymm0, xmm2) - vmovlpd(mem(r14), xmm1, xmm1) - vmovhpd(mem(r14, rsi, 1), xmm1, xmm1) - vfmadd231ps(xmm1, xmm3, xmm0) - vmovlpd(xmm0, mem(r14)) // store ( gamma40..gamma50 ) - vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma41..gamma51 ) - vmovlpd(mem(r14, rsi, 4), xmm1, xmm1) - vmovhpd(mem(r14, r15, 1), xmm1, xmm1) - vfmadd231ps(xmm1, xmm3, xmm2) - vmovlpd(xmm2, mem(r14, rsi, 4)) // store ( gamma44..gamma54 ) - vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma45..gamma55 ) - - vunpckhps(ymm14, ymm12, ymm0) - vextractf128(imm(0x1), ymm0, xmm2) - vmovlpd(mem(r14, rsi, 2), xmm1, xmm1) - vmovhpd(mem(r14, r13, 1), xmm1, xmm1) - vfmadd231ps(xmm1, xmm3, xmm0) - vmovlpd(xmm0, mem(r14, rsi, 2)) // store ( gamma42..gamma52 ) - vmovhpd(xmm0, mem(r14, r13, 1)) // store ( gamma43..gamma53 ) - vmovlpd(mem(r14, r13, 2), xmm1, xmm1) - vmovhpd(mem(r14, r10, 1), xmm1, xmm1) - vfmadd231ps(xmm1, xmm3, xmm2) - vmovlpd(xmm2, mem(r14, r13, 2)) // store ( gamma46..gamma56 ) - 
vmovhpd(xmm2, mem(r14, r10, 1)) // store ( gamma47..gamma57 )
-
- lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_c
-
-
-
- vunpcklps(ymm7, ymm5, ymm0)
- vunpcklps(ymm11, ymm9, ymm1)
- vshufps(imm(0x4e), ymm1, ymm0, ymm2)
- vblendps(imm(0xcc), ymm2, ymm0, ymm0)
- vblendps(imm(0x33), ymm2, ymm1, ymm1)
-
- vextractf128(imm(0x1), ymm0, xmm2)
- vfmadd231ps(mem(rcx), xmm3, xmm0)
- vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2)
- vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 )
- vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 )
-
- vextractf128(imm(0x1), ymm1, xmm2)
- vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1)
- vfmadd231ps(mem(rcx, r15, 1), xmm3, xmm2)
- vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 )
- vmovups(xmm2, mem(rcx, r15, 1)) // store ( gamma05..gamma35 )
-
-
- vunpckhps(ymm7, ymm5, ymm0)
- vunpckhps(ymm11, ymm9, ymm1)
- vshufps(imm(0x4e), ymm1, ymm0, ymm2)
- vblendps(imm(0xcc), ymm2, ymm0, ymm0)
- vblendps(imm(0x33), ymm2, ymm1, ymm1)
-
- vextractf128(imm(0x1), ymm0, xmm2)
- vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0)
- vfmadd231ps(mem(rcx, r13, 2), xmm3, xmm2)
- vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 )
- vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma06..gamma36 )
-
- vextractf128(imm(0x1), ymm1, xmm2)
- vfmadd231ps(mem(rcx, r13, 1), xmm3, xmm1)
- vfmadd231ps(mem(rcx, r10, 1), xmm3, xmm2)
- vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma03..gamma33 )
- vmovups(xmm2, mem(rcx, r10, 1)) // store ( gamma07..gamma37 )
-
- //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c
-
- vunpcklps(ymm15, ymm13, ymm0)
- vextractf128(imm(0x1), ymm0, xmm2)
- vmovlpd(mem(r14), xmm1, xmm1)
- vmovhpd(mem(r14, rsi, 1), xmm1, xmm1)
- vfmadd231ps(xmm1, xmm3, xmm0)
- vmovlpd(xmm0, mem(r14)) // store ( gamma40..gamma50 )
- vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma41..gamma51 )
- vmovlpd(mem(r14, rsi, 4), xmm1, xmm1)
- vmovhpd(mem(r14, r15, 1), xmm1, xmm1)
- vfmadd231ps(xmm1, xmm3, xmm2)
- vmovlpd(xmm2, mem(r14, rsi, 4)) // store ( gamma44..gamma54 )
- vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma45..gamma55 )
-
- vunpckhps(ymm15, ymm13, ymm0)
- vextractf128(imm(0x1), ymm0, xmm2)
- vmovlpd(mem(r14, rsi, 2), xmm1, xmm1)
- vmovhpd(mem(r14, r13, 1), xmm1, xmm1)
- vfmadd231ps(xmm1, xmm3, xmm0)
- vmovlpd(xmm0, mem(r14, rsi, 2)) // store ( gamma42..gamma52 )
- vmovhpd(xmm0, mem(r14, r13, 1)) // store ( gamma43..gamma53 )
- vmovlpd(mem(r14, r13, 2), xmm1, xmm1)
- vmovhpd(mem(r14, r10, 1), xmm1, xmm1)
- vfmadd231ps(xmm1, xmm3, xmm2)
- vmovlpd(xmm2, mem(r14, r13, 2)) // store ( gamma46..gamma56 )
- vmovhpd(xmm2, mem(r14, r10, 1)) // store ( gamma47..gamma57 )
-
- //lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_c
-
-
-
- jmp(.SDONE) // jump to end.
-
-
+ cmp(imm(4), rdi) // set ZF if (4*rs_c) == 4.
+ jz(.SCOLSTORED) // jump to column storage case + + vfmadd231ps(mem(rcx), ymm3, ymm4) + vmovups(ymm4, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm5) + vmovups(ymm5, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm6) + vmovups(ymm6, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm7) + vmovups(ymm7, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm8) + vmovups(ymm8, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm9) + vmovups(ymm9, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm10) + vmovups(ymm10, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm11) + vmovups(ymm11, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm12) + vmovups(ymm12, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm13) + vmovups(ymm13, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm14) + vmovups(ymm14, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm15) + vmovups(ymm15, mem(rcx,32)) + //add(rdi, rcx) + + jmp(.SDONE) // jump to end. + + label(.SCOLSTORED) + + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, r15, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, r15, 1)) // store ( gamma05..gamma35 ) + + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vfmadd231ps(mem(rcx, r13, 2), xmm3, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, r13, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, r10, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, r10, 1)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + vunpcklps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(r14), xmm1, xmm1) + vmovhpd(mem(r14, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(r14)) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma41..gamma51 ) + vmovlpd(mem(r14, rsi, 4), xmm1, xmm1) + vmovhpd(mem(r14, r15, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(r14, rsi, 4)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma45..gamma55 ) + + vunpckhps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(r14, rsi, 2), xmm1, xmm1) + vmovhpd(mem(r14, r13, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(r14, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(r14, r13, 1)) // store ( gamma43..gamma53 ) + vmovlpd(mem(r14, r13, 2), xmm1, xmm1) + vmovhpd(mem(r14, r10, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(r14, r13, 2)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(r14, r10, 1)) // store ( gamma47..gamma57 ) + + lea(mem(r14, rsi, 8), r14) // 
r14 += 8*cs_c + + + + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, r15, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, r15, 1)) // store ( gamma05..gamma35 ) + + + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vfmadd231ps(mem(rcx, r13, 2), xmm3, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, r13, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, r10, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, r10, 1)) // store ( gamma07..gamma37 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + vunpcklps(ymm15, ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(r14), xmm1, xmm1) + vmovhpd(mem(r14, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(r14)) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma41..gamma51 ) + vmovlpd(mem(r14, rsi, 4), xmm1, xmm1) + vmovhpd(mem(r14, r15, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(r14, rsi, 4)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma45..gamma55 ) + + vunpckhps(ymm15, ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(r14, rsi, 2), xmm1, xmm1) + vmovhpd(mem(r14, r13, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(r14, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(r14, r13, 1)) // store ( gamma43..gamma53 ) + vmovlpd(mem(r14, r13, 2), xmm1, xmm1) + vmovhpd(mem(r14, r10, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(r14, r13, 2)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(r14, r10, 1)) // store ( gamma47..gamma57 ) + + //lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_c + + jmp(.SDONE) // jump to end. label(.SBETAZERO) - cmp(imm(4), rsi) // set ZF if (4*cs_c) == 4. - jz(.SROWSTORBZ) // jump to row storage case - - cmp(imm(4), rdi) // set ZF if (4*cs_c) == 4. 
- jz(.SCOLSTORBZ) // jump to column storage case - - - - label(.SGENSTORBZ) - - - vmovaps(ymm4, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovaps(ymm6, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovaps(ymm8, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovaps(ymm10, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovaps(ymm12, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovaps(ymm14, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - //add(rdi, rcx) // c += rs_c; - - - mov(rdx, rcx) // rcx = c + 8*cs_c - - - vmovaps(ymm5, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovaps(ymm7, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - + cmp(imm(4), rdi) // set ZF if (4*cs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case - vmovaps(ymm9, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; + vmovups(ymm4, mem(rcx)) + vmovups(ymm5, mem(rcx,32)) + add(rdi, rcx) + vmovups(ymm6, mem(rcx)) + vmovups(ymm7, mem(rcx,32)) + add(rdi, rcx) - vmovaps(ymm11, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; + vmovups(ymm8, mem(rcx)) + vmovups(ymm9, mem(rcx,32)) + add(rdi, rcx) - vmovaps(ymm13, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; + vmovups(ymm10, mem(rcx)) + vmovups(ymm11, mem(rcx,32)) + add(rdi, rcx) - vmovaps(ymm15, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - //add(rdi, rcx) // c += rs_c; + vmovups(ymm12, mem(rcx)) + vmovups(ymm13, mem(rcx,32)) + add(rdi, rcx) - jmp(.SDONE) // jump to end. + vmovups(ymm14, mem(rcx)) + vmovups(ymm15, mem(rcx,32)) + //add(rdi, rcx) + jmp(.SDONE) // jump to end. + label(.SCOLSTORBZ) - label(.SROWSTORBZ) + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) - vmovups(ymm4, mem(rcx)) - add(rdi, rcx) - vmovups(ymm5, mem(rdx)) - add(rdi, rdx) + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, r15, 1)) // store ( gamma05..gamma35 ) - vmovups(ymm6, mem(rcx)) - add(rdi, rcx) - vmovups(ymm7, mem(rdx)) - add(rdi, rdx) + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) - vmovups(ymm8, mem(rcx)) - add(rdi, rcx) - vmovups(ymm9, mem(rdx)) - add(rdi, rdx) + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma06..gamma36 ) + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, r10, 1)) // store ( gamma07..gamma37 ) - vmovups(ymm10, mem(rcx)) - add(rdi, rcx) - vmovups(ymm11, mem(rdx)) - add(rdi, rdx) + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + vunpcklps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(r14)) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma41..gamma51 ) + vmovlpd(xmm2, mem(r14, rsi, 4)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma45..gamma55 ) - vmovups(ymm12, mem(rcx)) - add(rdi, rcx) - vmovups(ymm13, mem(rdx)) - add(rdi, rdx) + vunpckhps(ymm14, ymm12, 
ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(r14, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(r14, r13, 1)) // store ( gamma43..gamma53 ) + vmovlpd(xmm2, mem(r14, r13, 2)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(r14, r10, 1)) // store ( gamma47..gamma57 ) + lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_c - vmovups(ymm14, mem(rcx)) - //add(rdi, rcx) - vmovups(ymm15, mem(rdx)) - //add(rdi, rdx) + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) - jmp(.SDONE) // jump to end. + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, r15, 1)) // store ( gamma05..gamma35 ) - label(.SCOLSTORBZ) + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma06..gamma36 ) - vunpcklps(ymm6, ymm4, ymm0) - vunpcklps(ymm10, ymm8, ymm1) - vshufps(imm(0x4e), ymm1, ymm0, ymm2) - vblendps(imm(0xcc), ymm2, ymm0, ymm0) - vblendps(imm(0x33), ymm2, ymm1, ymm1) + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, r10, 1)) // store ( gamma07..gamma37 ) - vextractf128(imm(0x1), ymm0, xmm2) - vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) - vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) - - vextractf128(imm(0x1), ymm1, xmm2) - vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) - vmovups(xmm2, mem(rcx, r15, 1)) // store ( gamma05..gamma35 ) - - - vunpckhps(ymm6, ymm4, ymm0) - vunpckhps(ymm10, ymm8, ymm1) - vshufps(imm(0x4e), ymm1, ymm0, ymm2) - vblendps(imm(0xcc), ymm2, ymm0, ymm0) - vblendps(imm(0x33), ymm2, ymm1, ymm1) - - vextractf128(imm(0x1), ymm0, xmm2) - vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) - vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma06..gamma36 ) - - vextractf128(imm(0x1), ymm1, xmm2) - vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma03..gamma33 ) - vmovups(xmm2, mem(rcx, r10, 1)) // store ( gamma07..gamma37 ) - - lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c - - vunpcklps(ymm14, ymm12, ymm0) - vextractf128(imm(0x1), ymm0, xmm2) - vmovlpd(xmm0, mem(r14)) // store ( gamma40..gamma50 ) - vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma41..gamma51 ) - vmovlpd(xmm2, mem(r14, rsi, 4)) // store ( gamma44..gamma54 ) - vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma45..gamma55 ) - - vunpckhps(ymm14, ymm12, ymm0) - vextractf128(imm(0x1), ymm0, xmm2) - vmovlpd(xmm0, mem(r14, rsi, 2)) // store ( gamma42..gamma52 ) - vmovhpd(xmm0, mem(r14, r13, 1)) // store ( gamma43..gamma53 ) - vmovlpd(xmm2, mem(r14, r13, 2)) // store ( gamma46..gamma56 ) - vmovhpd(xmm2, mem(r14, r10, 1)) // store ( gamma47..gamma57 ) - - lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_c - - - - vunpcklps(ymm7, ymm5, ymm0) - vunpcklps(ymm11, ymm9, ymm1) - vshufps(imm(0x4e), ymm1, ymm0, ymm2) - vblendps(imm(0xcc), ymm2, ymm0, ymm0) - vblendps(imm(0x33), ymm2, ymm1, ymm1) - - vextractf128(imm(0x1), ymm0, xmm2) - vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) - 
vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 )
-
- vextractf128(imm(0x1), ymm1, xmm2)
- vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 )
- vmovups(xmm2, mem(rcx, r15, 1)) // store ( gamma05..gamma35 )
-
-
- vunpckhps(ymm7, ymm5, ymm0)
- vunpckhps(ymm11, ymm9, ymm1)
- vshufps(imm(0x4e), ymm1, ymm0, ymm2)
- vblendps(imm(0xcc), ymm2, ymm0, ymm0)
- vblendps(imm(0x33), ymm2, ymm1, ymm1)
-
- vextractf128(imm(0x1), ymm0, xmm2)
- vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 )
- vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma06..gamma36 )
-
- vextractf128(imm(0x1), ymm1, xmm2)
- vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma03..gamma33 )
- vmovups(xmm2, mem(rcx, r10, 1)) // store ( gamma07..gamma37 )
-
- //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c
-
- vunpcklps(ymm15, ymm13, ymm0)
- vextractf128(imm(0x1), ymm0, xmm2)
- vmovlpd(xmm0, mem(r14)) // store ( gamma40..gamma50 )
- vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma41..gamma51 )
- vmovlpd(xmm2, mem(r14, rsi, 4)) // store ( gamma44..gamma54 )
- vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma45..gamma55 )
-
- vunpckhps(ymm15, ymm13, ymm0)
- vextractf128(imm(0x1), ymm0, xmm2)
- vmovlpd(xmm0, mem(r14, rsi, 2)) // store ( gamma42..gamma52 )
- vmovhpd(xmm0, mem(r14, r13, 1)) // store ( gamma43..gamma53 )
- vmovlpd(xmm2, mem(r14, r13, 2)) // store ( gamma46..gamma56 )
- vmovhpd(xmm2, mem(r14, r10, 1)) // store ( gamma47..gamma57 )
-
- //lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_c
+ //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c
+ vunpcklps(ymm15, ymm13, ymm0)
+ vextractf128(imm(0x1), ymm0, xmm2)
+ vmovlpd(xmm0, mem(r14)) // store ( gamma40..gamma50 )
+ vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma41..gamma51 )
+ vmovlpd(xmm2, mem(r14, rsi, 4)) // store ( gamma44..gamma54 )
+ vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma45..gamma55 )
+ vunpckhps(ymm15, ymm13, ymm0)
+ vextractf128(imm(0x1), ymm0, xmm2)
+ vmovlpd(xmm0, mem(r14, rsi, 2)) // store ( gamma42..gamma52 )
+ vmovhpd(xmm0, mem(r14, r13, 1)) // store ( gamma43..gamma53 )
+ vmovlpd(xmm2, mem(r14, r13, 2)) // store ( gamma46..gamma56 )
+ vmovhpd(xmm2, mem(r14, r10, 1)) // store ( gamma47..gamma57 )
+ //lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_c

 label(.SDONE)

@@ -896,6 +726,8 @@ void bli_sgemm_haswell_asm_6x16
 "xmm12", "xmm13", "xmm14", "xmm15",
 "memory"
 )
+
+ GEMM_UKR_FLUSH_CT( s );
}

@@ -927,7 +759,9 @@ void bli_sgemm_haswell_asm_6x16

void bli_dgemm_haswell_asm_6x8
 (
- dim_t k0,
+ dim_t m,
+ dim_t n,
+ dim_t k,
 double* restrict alpha,
 double* restrict a,
 double* restrict b,
@@ -942,11 +776,13 @@ void bli_dgemm_haswell_asm_6x8
 // Typecast local copies of integers in case dim_t and inc_t are a
 // different size than is expected by load instructions.
- uint64_t k_iter = k0 / 4;
- uint64_t k_left = k0 % 4;
+ uint64_t k_iter = k / 4;
+ uint64_t k_left = k % 4;
 uint64_t rs_c = rs_c0;
 uint64_t cs_c = cs_c0;
+
+ GEMM_UKR_SETUP_CT_AMBI( d, 6, 8, true );

 begin_asm()

 vzeroall() // zero all xmm/ymm registers.
@@ -957,36 +793,56 @@ void bli_dgemm_haswell_asm_6x8
 //mov(%9, r15) // load address of b_next.
 add(imm(32*4), rbx)

- // initialize loop by pre-loading
+ // initialize loop by pre-loading
 vmovapd(mem(rbx, -4*32), ymm0)
 vmovapd(mem(rbx, -3*32), ymm1)

 mov(var(c), rcx) // load address of c
 mov(var(rs_c), rdi) // load rs_c
 lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
+ mov(var(cs_c), rsi) // load cs_c
+ lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double)

- lea(mem(rdi, rdi, 2), r13) // r13 = 3*rs_c;
- lea(mem(rcx, r13, 1), rdx) // rdx = c + 3*rs_c;
- prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c
- prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c
- prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c
- prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c
- prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c
- prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c
+ cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8.
+ jz(.DCOLPREFETCH) // jump to column prefetch case
+
+ lea(mem(rdi, rdi, 2), r13) // r13 = 3*rs_c;
+ lea(mem(rcx, r13, 1), rdx) // rdx = c + 3*rs_c;
+ prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c
+ prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c
+ prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c
+ prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c
+ prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c
+ prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c
+
+ jmp(.DPREFETCHDONE)
+
+ label(.DCOLPREFETCH)
+ lea(mem(rsi, rsi, 2), r13) // r13 = 3*cs_c;
+ lea(mem(rcx, r13, 1), rdx) // rdx = c + 3*cs_c;
+ prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*cs_c
+ prefetch(0, mem(rcx, rsi, 1, 7*8)) // prefetch c + 1*cs_c
+ prefetch(0, mem(rcx, rsi, 2, 7*8)) // prefetch c + 2*cs_c
+ prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*cs_c
+ prefetch(0, mem(rdx, rsi, 1, 7*8)) // prefetch c + 4*cs_c
+ prefetch(0, mem(rdx, rsi, 2, 7*8)) // prefetch c + 5*cs_c
+ prefetch(0, mem(rdx, r13, 1, 7*8)) // prefetch c + 6*cs_c
+ prefetch(0, mem(rdx, rsi, 4, 7*8)) // prefetch c + 7*cs_c
+ label(.DPREFETCHDONE)

 mov(var(k_iter), rsi) // i = k_iter;
 test(rsi, rsi) // check i via logical AND.
 je(.DCONSIDKLEFT) // if i == 0, jump to code that
- // contains the k_left loop.
+ // contains the k_left loop.

 label(.DLOOPKITER) // MAIN LOOP

- // iteration 0
+ // iteration 0
 prefetch(0, mem(rax, 64*8))

 vbroadcastsd(mem(rax, 0*8), ymm2)
@@ -1013,7 +869,7 @@ void bli_dgemm_haswell_asm_6x8
 vmovapd(mem(rbx, -2*32), ymm0)
 vmovapd(mem(rbx, -1*32), ymm1)

- // iteration 1
+ // iteration 1
 prefetch(0, mem(rax, 72*8))

 vbroadcastsd(mem(rax, 6*8), ymm2)
@@ -1040,7 +896,7 @@ void bli_dgemm_haswell_asm_6x8
 vmovapd(mem(rbx, 0*32), ymm0)
 vmovapd(mem(rbx, 1*32), ymm1)

- // iteration 2
+ // iteration 2
 prefetch(0, mem(rax, 80*8))

 vbroadcastsd(mem(rax, 12*8), ymm2)
@@ -1067,7 +923,7 @@ void bli_dgemm_haswell_asm_6x8
 vmovapd(mem(rbx, 2*32), ymm0)
 vmovapd(mem(rbx, 3*32), ymm1)

- // iteration 3
+ // iteration 3
 vbroadcastsd(mem(rax, 18*8), ymm2)
 vbroadcastsd(mem(rax, 19*8), ymm3)
 vfmadd231pd(ymm0, ymm2, ymm4)
@@ -1109,7 +965,7 @@ void bli_dgemm_haswell_asm_6x8
 mov(var(k_left), rsi) // i = k_left;
 test(rsi, rsi) // check i via logical AND.
 je(.DPOSTACCUM) // if i == 0, we're done; jump to end.
- // else, we prepare to enter k_left loop.
+ // else, we prepare to enter k_left loop.

 label(.DLOOPKLEFT) // EDGE LOOP
@@ -1188,428 +1044,232 @@ void bli_dgemm_haswell_asm_6x8
 //lea(mem(r13, rsi, 4), r10) // r10 = 7*cs_c;

- // now avoid loading C if beta == 0
+ // now avoid loading C if beta == 0
 vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero.
 vucomisd(xmm0, xmm3) // set ZF if beta == 0.
je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx,32)) + add(rdi, rcx) - cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8. - jz(.DROWSTORED) // jump to row storage case - - - cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. - jz(.DCOLSTORED) // jump to column storage case - - - - label(.DGENSTORED) - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm4, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm6, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm8, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm10, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm12, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm14, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - - - mov(rdx, rcx) // rcx = c + 4*cs_c - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm5, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm7, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm9, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm11, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm13, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm15, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - - - - jmp(.DDONE) // jump to end. - - - - label(.DROWSTORED) - - - vfmadd231pd(mem(rcx), ymm3, ymm4) - vmovupd(ymm4, mem(rcx)) - add(rdi, rcx) - vfmadd231pd(mem(rdx), ymm3, ymm5) - vmovupd(ymm5, mem(rdx)) - add(rdi, rdx) - - - vfmadd231pd(mem(rcx), ymm3, ymm6) - vmovupd(ymm6, mem(rcx)) - add(rdi, rcx) - vfmadd231pd(mem(rdx), ymm3, ymm7) - vmovupd(ymm7, mem(rdx)) - add(rdi, rdx) - - - vfmadd231pd(mem(rcx), ymm3, ymm8) - vmovupd(ymm8, mem(rcx)) - add(rdi, rcx) - vfmadd231pd(mem(rdx), ymm3, ymm9) - vmovupd(ymm9, mem(rdx)) - add(rdi, rdx) - - - vfmadd231pd(mem(rcx), ymm3, ymm10) - vmovupd(ymm10, mem(rcx)) - add(rdi, rcx) - vfmadd231pd(mem(rdx), ymm3, ymm11) - vmovupd(ymm11, mem(rdx)) - add(rdi, rdx) - - - vfmadd231pd(mem(rcx), ymm3, ymm12) - vmovupd(ymm12, mem(rcx)) - add(rdi, rcx) - vfmadd231pd(mem(rdx), ymm3, ymm13) - vmovupd(ymm13, mem(rdx)) - add(rdi, rdx) - - - vfmadd231pd(mem(rcx), ymm3, ymm14) - vmovupd(ymm14, mem(rcx)) - //add(rdi, rcx) - vfmadd231pd(mem(rdx), ymm3, ymm15) - vmovupd(ymm15, mem(rdx)) - //add(rdi, rdx) - - - - jmp(.DDONE) // jump to end. 
- - - - label(.DCOLSTORED) - - - vunpcklpd(ymm6, ymm4, ymm0) - vunpckhpd(ymm6, ymm4, ymm1) - vunpcklpd(ymm10, ymm8, ymm2) - vunpckhpd(ymm10, ymm8, ymm3) - vinsertf128(imm(0x1), xmm2, ymm0, ymm4) - vinsertf128(imm(0x1), xmm3, ymm1, ymm6) - vperm2f128(imm(0x31), ymm2, ymm0, ymm8) - vperm2f128(imm(0x31), ymm3, ymm1, ymm10) - - vbroadcastsd(mem(rbx), ymm3) - - vfmadd231pd(mem(rcx), ymm3, ymm4) - vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) - vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) - vfmadd231pd(mem(rcx, r13, 1), ymm3, ymm10) - vmovupd(ymm4, mem(rcx)) - vmovupd(ymm6, mem(rcx, rsi, 1)) - vmovupd(ymm8, mem(rcx, rsi, 2)) - vmovupd(ymm10, mem(rcx, r13, 1)) - - lea(mem(rcx, rsi, 4), rcx) - - vunpcklpd(ymm14, ymm12, ymm0) - vunpckhpd(ymm14, ymm12, ymm1) - vextractf128(imm(0x1), ymm0, xmm2) - vextractf128(imm(0x1), ymm1, xmm4) - - vfmadd231pd(mem(r14), xmm3, xmm0) - vfmadd231pd(mem(r14, rsi, 1), xmm3, xmm1) - vfmadd231pd(mem(r14, rsi, 2), xmm3, xmm2) - vfmadd231pd(mem(r14, r13, 1), xmm3, xmm4) - vmovupd(xmm0, mem(r14)) - vmovupd(xmm1, mem(r14, rsi, 1)) - vmovupd(xmm2, mem(r14, rsi, 2)) - vmovupd(xmm4, mem(r14, r13, 1)) - - lea(mem(r14, rsi, 4), r14) - - - vunpcklpd(ymm7, ymm5, ymm0) - vunpckhpd(ymm7, ymm5, ymm1) - vunpcklpd(ymm11, ymm9, ymm2) - vunpckhpd(ymm11, ymm9, ymm3) - vinsertf128(imm(0x1), xmm2, ymm0, ymm5) - vinsertf128(imm(0x1), xmm3, ymm1, ymm7) - vperm2f128(imm(0x31), ymm2, ymm0, ymm9) - vperm2f128(imm(0x31), ymm3, ymm1, ymm11) - - vbroadcastsd(mem(rbx), ymm3) - - vfmadd231pd(mem(rcx), ymm3, ymm5) - vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) - vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) - vfmadd231pd(mem(rcx, r13, 1), ymm3, ymm11) - vmovupd(ymm5, mem(rcx)) - vmovupd(ymm7, mem(rcx, rsi, 1)) - vmovupd(ymm9, mem(rcx, rsi, 2)) - vmovupd(ymm11, mem(rcx, r13, 1)) - - //lea(mem(rcx, rsi, 4), rcx) - - vunpcklpd(ymm15, ymm13, ymm0) - vunpckhpd(ymm15, ymm13, ymm1) - vextractf128(imm(0x1), ymm0, xmm2) - vextractf128(imm(0x1), ymm1, xmm4) - - vfmadd231pd(mem(r14), xmm3, xmm0) - vfmadd231pd(mem(r14, rsi, 1), xmm3, xmm1) - vfmadd231pd(mem(r14, rsi, 2), xmm3, xmm2) - vfmadd231pd(mem(r14, r13, 1), xmm3, xmm4) - vmovupd(xmm0, mem(r14)) - vmovupd(xmm1, mem(r14, rsi, 1)) - vmovupd(xmm2, mem(r14, rsi, 2)) - vmovupd(xmm4, mem(r14, r13, 1)) - - //lea(mem(r14, rsi, 4), r14) - - - - jmp(.DDONE) // jump to end. + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm8) + vmovupd(ymm8, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm9) + vmovupd(ymm9, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm10) + vmovupd(ymm10, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm11) + vmovupd(ymm11, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm12) + vmovupd(ymm12, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm13) + vmovupd(ymm13, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm14) + vmovupd(ymm14, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm15) + vmovupd(ymm15, mem(rcx,32)) + //add(rdi, rcx) + + jmp(.DDONE) // jump to end. 
+ + label(.DCOLSTORED) + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, r13, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, r13, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(r14), xmm3, xmm0) + vfmadd231pd(mem(r14, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(r14, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(r14, r13, 1), xmm3, xmm4) + vmovupd(xmm0, mem(r14)) + vmovupd(xmm1, mem(r14, rsi, 1)) + vmovupd(xmm2, mem(r14, rsi, 2)) + vmovupd(xmm4, mem(r14, r13, 1)) + + lea(mem(r14, rsi, 4), r14) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + vfmadd231pd(mem(rcx, r13, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, r13, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(r14), xmm3, xmm0) + vfmadd231pd(mem(r14, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(r14, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(r14, r13, 1), xmm3, xmm4) + vmovupd(xmm0, mem(r14)) + vmovupd(xmm1, mem(r14, rsi, 1)) + vmovupd(xmm2, mem(r14, rsi, 2)) + vmovupd(xmm4, mem(r14, r13, 1)) + + //lea(mem(r14, rsi, 4), r14) + + jmp(.DDONE) // jump to end. label(.DBETAZERO) - cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8. - jz(.DROWSTORBZ) // jump to row storage case - - cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. - jz(.DCOLSTORBZ) // jump to column storage case - - - - label(.DGENSTORBZ) - - - vmovapd(ymm4, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovapd(ymm6, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovapd(ymm8, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovapd(ymm10, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovapd(ymm12, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovapd(ymm14, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - - - mov(rdx, rcx) // rcx = c + 4*cs_c - - - vmovapd(ymm5, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLSTORBZ) // jump to column storage case - vmovapd(ymm7, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm5, mem(rcx,32)) + add(rdi, rcx) + vmovupd(ymm6, mem(rcx)) + vmovupd(ymm7, mem(rcx,32)) + add(rdi, rcx) - vmovapd(ymm9, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; + vmovupd(ymm8, mem(rcx)) + vmovupd(ymm9, mem(rcx,32)) + add(rdi, rcx) - vmovapd(ymm11, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; + vmovupd(ymm10, mem(rcx)) + vmovupd(ymm11, mem(rcx,32)) + add(rdi, rcx) - vmovapd(ymm13, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; + vmovupd(ymm12, mem(rcx)) + vmovupd(ymm13, mem(rcx,32)) + add(rdi, rcx) - vmovapd(ymm15, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ + vmovupd(ymm14, mem(rcx)) + vmovupd(ymm15, mem(rcx,32)) + //add(rdi, rcx) + jmp(.DDONE) // jump to end. - jmp(.DDONE) // jump to end. + label(.DCOLSTORBZ) + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, r13, 1)) - label(.DROWSTORBZ) + lea(mem(rcx, rsi, 4), rcx) + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) - vmovupd(ymm4, mem(rcx)) - add(rdi, rcx) - vmovupd(ymm5, mem(rdx)) - add(rdi, rdx) + vmovupd(xmm0, mem(r14)) + vmovupd(xmm1, mem(r14, rsi, 1)) + vmovupd(xmm2, mem(r14, rsi, 2)) + vmovupd(xmm4, mem(r14, r13, 1)) - vmovupd(ymm6, mem(rcx)) - add(rdi, rcx) - vmovupd(ymm7, mem(rdx)) - add(rdi, rdx) + lea(mem(r14, rsi, 4), r14) - vmovupd(ymm8, mem(rcx)) - add(rdi, rcx) - vmovupd(ymm9, mem(rdx)) - add(rdi, rdx) + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, r13, 1)) - vmovupd(ymm10, mem(rcx)) - add(rdi, rcx) - vmovupd(ymm11, mem(rdx)) - add(rdi, rdx) - - - vmovupd(ymm12, mem(rcx)) - add(rdi, rcx) - vmovupd(ymm13, mem(rdx)) - add(rdi, rdx) - - - vmovupd(ymm14, mem(rcx)) - //add(rdi, rcx) - vmovupd(ymm15, mem(rdx)) - //add(rdi, rdx) - - - jmp(.DDONE) // jump to end. 
- - - - label(.DCOLSTORBZ) - - - vunpcklpd(ymm6, ymm4, ymm0) - vunpckhpd(ymm6, ymm4, ymm1) - vunpcklpd(ymm10, ymm8, ymm2) - vunpckhpd(ymm10, ymm8, ymm3) - vinsertf128(imm(0x1), xmm2, ymm0, ymm4) - vinsertf128(imm(0x1), xmm3, ymm1, ymm6) - vperm2f128(imm(0x31), ymm2, ymm0, ymm8) - vperm2f128(imm(0x31), ymm3, ymm1, ymm10) - - vmovupd(ymm4, mem(rcx)) - vmovupd(ymm6, mem(rcx, rsi, 1)) - vmovupd(ymm8, mem(rcx, rsi, 2)) - vmovupd(ymm10, mem(rcx, r13, 1)) - - lea(mem(rcx, rsi, 4), rcx) - - vunpcklpd(ymm14, ymm12, ymm0) - vunpckhpd(ymm14, ymm12, ymm1) - vextractf128(imm(0x1), ymm0, xmm2) - vextractf128(imm(0x1), ymm1, xmm4) - - vmovupd(xmm0, mem(r14)) - vmovupd(xmm1, mem(r14, rsi, 1)) - vmovupd(xmm2, mem(r14, rsi, 2)) - vmovupd(xmm4, mem(r14, r13, 1)) - - lea(mem(r14, rsi, 4), r14) - - - vunpcklpd(ymm7, ymm5, ymm0) - vunpckhpd(ymm7, ymm5, ymm1) - vunpcklpd(ymm11, ymm9, ymm2) - vunpckhpd(ymm11, ymm9, ymm3) - vinsertf128(imm(0x1), xmm2, ymm0, ymm5) - vinsertf128(imm(0x1), xmm3, ymm1, ymm7) - vperm2f128(imm(0x31), ymm2, ymm0, ymm9) - vperm2f128(imm(0x31), ymm3, ymm1, ymm11) - - vmovupd(ymm5, mem(rcx)) - vmovupd(ymm7, mem(rcx, rsi, 1)) - vmovupd(ymm9, mem(rcx, rsi, 2)) - vmovupd(ymm11, mem(rcx, r13, 1)) - - //lea(mem(rcx, rsi, 4), rcx) - - vunpcklpd(ymm15, ymm13, ymm0) - vunpckhpd(ymm15, ymm13, ymm1) - vextractf128(imm(0x1), ymm0, xmm2) - vextractf128(imm(0x1), ymm1, xmm4) - - vmovupd(xmm0, mem(r14)) - vmovupd(xmm1, mem(r14, rsi, 1)) - vmovupd(xmm2, mem(r14, rsi, 2)) - vmovupd(xmm4, mem(r14, r13, 1)) - - //lea(mem(r14, rsi, 4), r14) + //lea(mem(rcx, rsi, 4), rcx) + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + vmovupd(xmm0, mem(r14)) + vmovupd(xmm1, mem(r14, rsi, 1)) + vmovupd(xmm2, mem(r14, rsi, 2)) + vmovupd(xmm4, mem(r14, r13, 1)) + //lea(mem(r14, rsi, 4), r14) label(.DDONE) @@ -1641,45 +1301,26 @@ void bli_dgemm_haswell_asm_6x8 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( d ); } -// assumes beta.r, beta.i have been broadcast into ymm1, ymm2. -// outputs to ymm0 -#define CGEMM_INPUT_SCALE_GS_BETA_NZ \ - vmovlpd(mem(rcx), xmm0, xmm0) \ - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) \ - vmovlpd(mem(rcx, rsi, 2), xmm3, xmm3) \ - vmovhpd(mem(rcx, r13, 1), xmm3, xmm3) \ - vinsertf128(imm(1), xmm3, ymm0, ymm0) \ - vpermilps(imm(0xb1), ymm0, ymm3) \ - vmulps(ymm1, ymm0, ymm0) \ - vmulps(ymm2, ymm3, ymm3) \ - vaddsubps(ymm3, ymm0, ymm0) -// assumes values to output are in ymm0 -#define CGEMM_OUTPUT_GS \ - vextractf128(imm(1), ymm0, xmm3) \ - vmovlpd(xmm0, mem(rcx)) \ - vmovhpd(xmm0, mem(rcx, rsi, 1)) \ - vmovlpd(xmm3, mem(rcx, rsi, 2)) \ - vmovhpd(xmm3, mem(rcx, r13, 1)) - -#define CGEMM_INPUT_SCALE_RS_BETA_NZ \ - vmovups(mem(rcx), ymm0) \ +#define CGEMM_INPUT_SCALE_RS_BETA_NZ(where) \ + vmovups(where, ymm0) \ vpermilps(imm(0xb1), ymm0, ymm3) \ vmulps(ymm1, ymm0, ymm0) \ vmulps(ymm2, ymm3, ymm3) \ vaddsubps(ymm3, ymm0, ymm0) -#define CGEMM_OUTPUT_RS \ - vmovups(ymm0, mem(rcx)) \ - void bli_cgemm_haswell_asm_3x8 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, scomplex* restrict alpha, scomplex* restrict a, scomplex* restrict b, @@ -1694,11 +1335,13 @@ void bli_cgemm_haswell_asm_3x8 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. 
- uint64_t k_iter = k0 / 4;
- uint64_t k_left = k0 % 4;
+ uint64_t k_iter = k / 4;
+ uint64_t k_left = k % 4;
 uint64_t rs_c = rs_c0;
 uint64_t cs_c = cs_c0;
+
+ GEMM_UKR_SETUP_CT( c, 3, 8, true );

 begin_asm()

 vzeroall() // zero all xmm/ymm registers.
@@ -1709,7 +1352,7 @@ void bli_cgemm_haswell_asm_3x8
 //mov(%9, r15) // load address of b_next.

 add(imm(32*4), rbx)

- // initialize loop by pre-loading
+ // initialize loop by pre-loading
 vmovaps(mem(rbx, -4*32), ymm0)
 vmovaps(mem(rbx, -3*32), ymm1)
@@ -1730,13 +1373,13 @@ void bli_cgemm_haswell_asm_3x8
 mov(var(k_iter), rsi) // i = k_iter;
 test(rsi, rsi) // check i via logical AND.
 je(.CCONSIDKLEFT) // if i == 0, jump to code that
- // contains the k_left loop.
+ // contains the k_left loop.

 label(.CLOOPKITER) // MAIN LOOP

- // iteration 0
+ // iteration 0
 prefetch(0, mem(rax, 32*8))

 vbroadcastss(mem(rax, 0*4), ymm2)
@@ -1763,7 +1406,7 @@ void bli_cgemm_haswell_asm_3x8
 vmovaps(mem(rbx, -2*32), ymm0)
 vmovaps(mem(rbx, -1*32), ymm1)

- // iteration 1
+ // iteration 1
 vbroadcastss(mem(rax, 6*4), ymm2)
 vbroadcastss(mem(rax, 7*4), ymm3)
 vfmadd231ps(ymm0, ymm2, ymm4)
@@ -1788,7 +1431,7 @@ void bli_cgemm_haswell_asm_3x8
 vmovaps(mem(rbx, 0*32), ymm0)
 vmovaps(mem(rbx, 1*32), ymm1)

- // iteration 2
+ // iteration 2
 prefetch(0, mem(rax, 38*8))

 vbroadcastss(mem(rax, 12*4), ymm2)
@@ -1815,7 +1458,7 @@ void bli_cgemm_haswell_asm_3x8
 vmovaps(mem(rbx, 2*32), ymm0)
 vmovaps(mem(rbx, 3*32), ymm1)

- // iteration 3
+ // iteration 3
 vbroadcastss(mem(rax, 18*4), ymm2)
 vbroadcastss(mem(rax, 19*4), ymm3)
 vfmadd231ps(ymm0, ymm2, ymm4)
@@ -1857,7 +1500,7 @@ void bli_cgemm_haswell_asm_3x8
 mov(var(k_left), rsi) // i = k_left;
 test(rsi, rsi) // check i via logical AND.
 je(.CPOSTACCUM) // if i == 0, we're done; jump to end.
- // else, we prepare to enter k_left loop.
+ // else, we prepare to enter k_left loop.

 label(.CLOOPKLEFT) // EDGE LOOP
@@ -1900,8 +1543,8 @@ void bli_cgemm_haswell_asm_3x8
 label(.CPOSTACCUM)

- // permute even and odd elements
- // of ymm6/7, ymm10/11, ymm/14/15
+ // permute even and odd elements
+ // of ymm6/7, ymm10/11, ymm14/15
 vpermilps(imm(0xb1), ymm6, ymm6)
 vpermilps(imm(0xb1), ymm7, ymm7)
 vpermilps(imm(0xb1), ymm10, ymm10)
@@ -1910,7 +1553,7 @@ void bli_cgemm_haswell_asm_3x8
 vpermilps(imm(0xb1), ymm15, ymm15)

- // subtract/add even/odd elements
+ // subtract/add even/odd elements
 vaddsubps(ymm6, ymm4, ymm4)
 vaddsubps(ymm7, ymm5, ymm5)
@@ -1969,16 +1612,7 @@ void bli_cgemm_haswell_asm_3x8
 vbroadcastss(mem(rbx, 4), ymm2) // load beta_i and duplicate

-
-
- mov(var(cs_c), rsi) // load cs_c
- lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(scomplex)
- lea(mem(, rsi, 4), rdx) // rdx = 4*cs_c;
- lea(mem(rsi, rsi, 2), r13) // r13 = 3*cs_c;
-
-
-
- // now avoid loading C if beta == 0
+ // now avoid loading C if beta == 0
 vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero.
 vucomiss(xmm0, xmm1) // set ZF if beta_r == 0.
 sete(r8b) // r8b = ( ZF == 1 ? 1 : 0 );
@@ -1987,162 +1621,49 @@ void bli_cgemm_haswell_asm_3x8
 and(r8b, r9b) // set ZF if r8b & r9b == 1.
 jne(.CBETAZERO) // if ZF = 1, jump to beta == 0 case

+ CGEMM_INPUT_SCALE_RS_BETA_NZ(mem(rcx))
+ vaddps(ymm4, ymm0, ymm0)
+ vmovups(ymm0, mem(rcx))

- cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8.
- jz(.CROWSTORED) // jump to row storage case - - - - label(.CGENSTORED) - - - CGEMM_INPUT_SCALE_GS_BETA_NZ - vaddps(ymm4, ymm0, ymm0) - CGEMM_OUTPUT_GS - add(rdx, rcx) // c += 4*cs_c; - - - CGEMM_INPUT_SCALE_GS_BETA_NZ - vaddps(ymm5, ymm0, ymm0) - CGEMM_OUTPUT_GS - mov(r11, rcx) // rcx = c + 1*rs_c + CGEMM_INPUT_SCALE_RS_BETA_NZ(mem(rcx,32)) + vaddps(ymm5, ymm0, ymm0) + vmovups(ymm0, mem(rcx,32)) - CGEMM_INPUT_SCALE_GS_BETA_NZ - vaddps(ymm8, ymm0, ymm0) - CGEMM_OUTPUT_GS - add(rdx, rcx) // c += 4*cs_c; + CGEMM_INPUT_SCALE_RS_BETA_NZ(mem(r11)) + vaddps(ymm8, ymm0, ymm0) + vmovups(ymm0, mem(r11)) - CGEMM_INPUT_SCALE_GS_BETA_NZ - vaddps(ymm9, ymm0, ymm0) - CGEMM_OUTPUT_GS - mov(r12, rcx) // rcx = c + 2*rs_c + CGEMM_INPUT_SCALE_RS_BETA_NZ(mem(r11,32)) + vaddps(ymm9, ymm0, ymm0) + vmovups(ymm0, mem(r11,32)) - CGEMM_INPUT_SCALE_GS_BETA_NZ - vaddps(ymm12, ymm0, ymm0) - CGEMM_OUTPUT_GS - add(rdx, rcx) // c += 4*cs_c; + CGEMM_INPUT_SCALE_RS_BETA_NZ(mem(r12)) + vaddps(ymm12, ymm0, ymm0) + vmovups(ymm0, mem(r12)) - CGEMM_INPUT_SCALE_GS_BETA_NZ - vaddps(ymm13, ymm0, ymm0) - CGEMM_OUTPUT_GS - - - - jmp(.CDONE) // jump to end. - - - - label(.CROWSTORED) - - - CGEMM_INPUT_SCALE_RS_BETA_NZ - vaddps(ymm4, ymm0, ymm0) - CGEMM_OUTPUT_RS - add(rdx, rcx) // c += 4*cs_c; - - - CGEMM_INPUT_SCALE_RS_BETA_NZ - vaddps(ymm5, ymm0, ymm0) - CGEMM_OUTPUT_RS - mov(r11, rcx) // rcx = c + 1*rs_c - - - - CGEMM_INPUT_SCALE_RS_BETA_NZ - vaddps(ymm8, ymm0, ymm0) - CGEMM_OUTPUT_RS - add(rdx, rcx) // c += 4*cs_c; - - - CGEMM_INPUT_SCALE_RS_BETA_NZ - vaddps(ymm9, ymm0, ymm0) - CGEMM_OUTPUT_RS - mov(r12, rcx) // rcx = c + 2*rs_c - - - - CGEMM_INPUT_SCALE_RS_BETA_NZ - vaddps(ymm12, ymm0, ymm0) - CGEMM_OUTPUT_RS - add(rdx, rcx) // c += 4*cs_c; - - - CGEMM_INPUT_SCALE_RS_BETA_NZ - vaddps(ymm13, ymm0, ymm0) - CGEMM_OUTPUT_RS - - - - jmp(.CDONE) // jump to end. + CGEMM_INPUT_SCALE_RS_BETA_NZ(mem(r12,32)) + vaddps(ymm13, ymm0, ymm0) + vmovups(ymm0, mem(r12,32)) + jmp(.CDONE) // jump to end. label(.CBETAZERO) - cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8. - jz(.CROWSTORBZ) // jump to row storage case - - - - label(.CGENSTORBZ) - - - vmovaps(ymm4, ymm0) - CGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*cs_c; - - - vmovaps(ymm5, ymm0) - CGEMM_OUTPUT_GS - mov(r11, rcx) // rcx = c + 1*rs_c - - - - vmovaps(ymm8, ymm0) - CGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*cs_c; - - - vmovaps(ymm9, ymm0) - CGEMM_OUTPUT_GS - mov(r12, rcx) // rcx = c + 2*rs_c - - - - vmovaps(ymm12, ymm0) - CGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*cs_c; - - - vmovaps(ymm13, ymm0) - CGEMM_OUTPUT_GS - - - - jmp(.CDONE) // jump to end. - - - - label(.CROWSTORBZ) - - - vmovups(ymm4, mem(rcx)) - vmovups(ymm5, mem(rcx, rdx, 1)) - - vmovups(ymm8, mem(r11)) - vmovups(ymm9, mem(r11, rdx, 1)) - - vmovups(ymm12, mem(r12)) - vmovups(ymm13, mem(r12, rdx, 1)) - + vmovups(ymm4, mem(rcx)) + vmovups(ymm5, mem(rcx,32)) + vmovups(ymm8, mem(r11)) + vmovups(ymm9, mem(r11,32)) + vmovups(ymm12, mem(r12)) + vmovups(ymm13, mem(r12,32)) label(.CDONE) @@ -2174,41 +1695,25 @@ void bli_cgemm_haswell_asm_3x8 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( c ); } -// assumes beta.r, beta.i have been broadcast into ymm1, ymm2. 
-// outputs to ymm0 -#define ZGEMM_INPUT_SCALE_GS_BETA_NZ \ - vmovupd(mem(rcx), xmm0) \ - vmovupd(mem(rcx, rsi, 1), xmm3) \ - vinsertf128(imm(1), xmm3, ymm0, ymm0) \ +#define ZGEMM_INPUT_SCALE_RS_BETA_NZ(where) \ + vmovupd(where, ymm0) \ vpermilpd(imm(0x5), ymm0, ymm3) \ vmulpd(ymm1, ymm0, ymm0) \ vmulpd(ymm2, ymm3, ymm3) \ vaddsubpd(ymm3, ymm0, ymm0) -// assumes values to output are in ymm0 -#define ZGEMM_OUTPUT_GS \ - vextractf128(imm(1), ymm0, xmm3) \ - vmovupd(xmm0, mem(rcx)) \ - vmovupd(xmm3, mem(rcx, rsi, 1)) \ - -#define ZGEMM_INPUT_SCALE_RS_BETA_NZ \ - vmovupd(mem(rcx), ymm0) \ - vpermilpd(imm(0x5), ymm0, ymm3) \ - vmulpd(ymm1, ymm0, ymm0) \ - vmulpd(ymm2, ymm3, ymm3) \ - vaddsubpd(ymm3, ymm0, ymm0) - -#define ZGEMM_OUTPUT_RS \ - vmovupd(ymm0, mem(rcx)) \ - void bli_zgemm_haswell_asm_3x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, dcomplex* restrict alpha, dcomplex* restrict a, dcomplex* restrict b, @@ -2223,11 +1728,13 @@ void bli_zgemm_haswell_asm_3x4 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( z, 3, 4, true ); + begin_asm() vzeroall() // zero all xmm/ymm registers. @@ -2238,7 +1745,7 @@ void bli_zgemm_haswell_asm_3x4 //mov(%9, r15) // load address of b_next. add(imm(32*4), rbx) - // initialize loop by pre-loading + // initialize loop by pre-loading vmovapd(mem(rbx, -4*32), ymm0) vmovapd(mem(rbx, -3*32), ymm1) @@ -2260,13 +1767,13 @@ void bli_zgemm_haswell_asm_3x4 mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.ZCONSIDKLEFT) // if i == 0, jump to code that - // contains the k_left loop. + // contains the k_left loop. label(.ZLOOPKITER) // MAIN LOOP - // iteration 0 + // iteration 0 prefetch(0, mem(rax, 32*16)) vbroadcastsd(mem(rax, 0*8), ymm2) @@ -2293,7 +1800,7 @@ void bli_zgemm_haswell_asm_3x4 vmovapd(mem(rbx, -2*32), ymm0) vmovapd(mem(rbx, -1*32), ymm1) - // iteration 1 + // iteration 1 prefetch(0, mem(rax, 36*16)) vbroadcastsd(mem(rax, 6*8), ymm2) @@ -2320,7 +1827,7 @@ void bli_zgemm_haswell_asm_3x4 vmovapd(mem(rbx, 0*32), ymm0) vmovapd(mem(rbx, 1*32), ymm1) - // iteration 2 + // iteration 2 prefetch(0, mem(rax, 40*16)) vbroadcastsd(mem(rax, 12*8), ymm2) @@ -2347,7 +1854,7 @@ void bli_zgemm_haswell_asm_3x4 vmovapd(mem(rbx, 2*32), ymm0) vmovapd(mem(rbx, 3*32), ymm1) - // iteration 3 + // iteration 3 vbroadcastsd(mem(rax, 18*8), ymm2) vbroadcastsd(mem(rax, 19*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) @@ -2389,7 +1896,7 @@ void bli_zgemm_haswell_asm_3x4 mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.ZPOSTACCUM) // if i == 0, we're done; jump to end. - // else, we prepare to enter k_left loop. + // else, we prepare to enter k_left loop. 
 label(.ZLOOPKLEFT) // EDGE LOOP
@@ -2431,8 +1938,8 @@ void bli_zgemm_haswell_asm_3x4
 label(.ZPOSTACCUM)

- // permute even and odd elements
- // of ymm6/7, ymm10/11, ymm/14/15
+ // permute even and odd elements
+ // of ymm6/7, ymm10/11, ymm14/15
 vpermilpd(imm(0x5), ymm6, ymm6)
 vpermilpd(imm(0x5), ymm7, ymm7)
 vpermilpd(imm(0x5), ymm10, ymm10)
@@ -2441,7 +1948,7 @@ void bli_zgemm_haswell_asm_3x4
 vpermilpd(imm(0x5), ymm15, ymm15)

- // subtract/add even/odd elements
+ // subtract/add even/odd elements
 vaddsubpd(ymm6, ymm4, ymm4)
 vaddsubpd(ymm7, ymm5, ymm5)
@@ -2501,15 +2008,7 @@ void bli_zgemm_haswell_asm_3x4

-
-
- mov(var(cs_c), rsi) // load cs_c
- lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dcomplex)
- lea(mem(, rsi, 2), rsi)
- lea(mem(, rsi, 2), rdx) // rdx = 2*cs_c;
-
-
-
- // now avoid loading C if beta == 0
+ // now avoid loading C if beta == 0
 vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero.
 vucomisd(xmm0, xmm1) // set ZF if beta_r == 0.
 sete(r8b) // r8b = ( ZF == 1 ? 1 : 0 );
@@ -2518,162 +2017,49 @@ void bli_zgemm_haswell_asm_3x4
 and(r8b, r9b) // set ZF if r8b & r9b == 1.
 jne(.ZBETAZERO) // if ZF = 1, jump to beta == 0 case

-
- cmp(imm(16), rsi) // set ZF if (16*cs_c) == 16.
- jz(.ZROWSTORED) // jump to row storage case
-
-
-
- label(.ZGENSTORED)
-
-
- ZGEMM_INPUT_SCALE_GS_BETA_NZ
- vaddpd(ymm4, ymm0, ymm0)
- ZGEMM_OUTPUT_GS
- add(rdx, rcx) // c += 2*cs_c;
-
-
- ZGEMM_INPUT_SCALE_GS_BETA_NZ
- vaddpd(ymm5, ymm0, ymm0)
- ZGEMM_OUTPUT_GS
- mov(r11, rcx) // rcx = c + 1*rs_c
-
+ ZGEMM_INPUT_SCALE_RS_BETA_NZ(mem(rcx))
+ vaddpd(ymm4, ymm0, ymm0)
+ vmovupd(ymm0, mem(rcx))

- ZGEMM_INPUT_SCALE_GS_BETA_NZ
- vaddpd(ymm8, ymm0, ymm0)
- ZGEMM_OUTPUT_GS
- add(rdx, rcx) // c += 2*cs_c;
+ ZGEMM_INPUT_SCALE_RS_BETA_NZ(mem(rcx,32))
+ vaddpd(ymm5, ymm0, ymm0)
+ vmovupd(ymm0, mem(rcx,32))

- ZGEMM_INPUT_SCALE_GS_BETA_NZ
- vaddpd(ymm9, ymm0, ymm0)
- ZGEMM_OUTPUT_GS
- mov(r12, rcx) // rcx = c + 2*rs_c
+ ZGEMM_INPUT_SCALE_RS_BETA_NZ(mem(r11))
+ vaddpd(ymm8, ymm0, ymm0)
+ vmovupd(ymm0, mem(r11))

- ZGEMM_INPUT_SCALE_GS_BETA_NZ
- vaddpd(ymm12, ymm0, ymm0)
- ZGEMM_OUTPUT_GS
- add(rdx, rcx) // c += 2*cs_c;
+ ZGEMM_INPUT_SCALE_RS_BETA_NZ(mem(r11,32))
+ vaddpd(ymm9, ymm0, ymm0)
+ vmovupd(ymm0, mem(r11,32))

- ZGEMM_INPUT_SCALE_GS_BETA_NZ
- vaddpd(ymm13, ymm0, ymm0)
- ZGEMM_OUTPUT_GS
+ ZGEMM_INPUT_SCALE_RS_BETA_NZ(mem(r12))
+ vaddpd(ymm12, ymm0, ymm0)
+ vmovupd(ymm0, mem(r12))

- jmp(.ZDONE) // jump to end.
-
-
-
- label(.ZROWSTORED)
-
-
- ZGEMM_INPUT_SCALE_RS_BETA_NZ
- vaddpd(ymm4, ymm0, ymm0)
- ZGEMM_OUTPUT_RS
- add(rdx, rcx) // c += 2*cs_c;
-
-
- ZGEMM_INPUT_SCALE_RS_BETA_NZ
- vaddpd(ymm5, ymm0, ymm0)
- ZGEMM_OUTPUT_RS
- mov(r11, rcx) // rcx = c + 1*rs_c
-
-
-
- ZGEMM_INPUT_SCALE_RS_BETA_NZ
- vaddpd(ymm8, ymm0, ymm0)
- ZGEMM_OUTPUT_RS
- add(rdx, rcx) // c += 2*cs_c;
-
-
- ZGEMM_INPUT_SCALE_RS_BETA_NZ
- vaddpd(ymm9, ymm0, ymm0)
- ZGEMM_OUTPUT_RS
- mov(r12, rcx) // rcx = c + 2*rs_c
-
-
-
- ZGEMM_INPUT_SCALE_RS_BETA_NZ
- vaddpd(ymm12, ymm0, ymm0)
- ZGEMM_OUTPUT_RS
- add(rdx, rcx) // c += 2*cs_c;
-
-
- ZGEMM_INPUT_SCALE_RS_BETA_NZ
- vaddpd(ymm13, ymm0, ymm0)
- ZGEMM_OUTPUT_RS
-
-
-
- jmp(.ZDONE) // jump to end.
-
+ ZGEMM_INPUT_SCALE_RS_BETA_NZ(mem(r12,32))
+ vaddpd(ymm13, ymm0, ymm0)
+ vmovupd(ymm0, mem(r12,32))

+ jmp(.ZDONE) // jump to end.

 label(.ZBETAZERO)

- cmp(imm(16), rsi) // set ZF if (16*cs_c) == 16.
- jz(.ZROWSTORBZ) // jump to row storage case - - - - label(.ZGENSTORBZ) - - - vmovapd(ymm4, ymm0) - ZGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*cs_c; - - - vmovapd(ymm5, ymm0) - ZGEMM_OUTPUT_GS - mov(r11, rcx) // rcx = c + 1*rs_c - - - - vmovapd(ymm8, ymm0) - ZGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*cs_c; - - - vmovapd(ymm9, ymm0) - ZGEMM_OUTPUT_GS - mov(r12, rcx) // rcx = c + 2*rs_c - - - - vmovapd(ymm12, ymm0) - ZGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*cs_c; - - - vmovapd(ymm13, ymm0) - ZGEMM_OUTPUT_GS - - - - jmp(.ZDONE) // jump to end. - - - - label(.ZROWSTORBZ) - - - vmovupd(ymm4, mem(rcx)) - vmovupd(ymm5, mem(rcx, rdx, 1)) - - vmovupd(ymm8, mem(r11)) - vmovupd(ymm9, mem(r11, rdx, 1)) - - vmovupd(ymm12, mem(r12)) - vmovupd(ymm13, mem(r12, rdx, 1)) - + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm5, mem(rcx,32)) + vmovupd(ymm8, mem(r11)) + vmovupd(ymm9, mem(r11,32)) + vmovupd(ymm12, mem(r12)) + vmovupd(ymm13, mem(r12,32)) label(.ZDONE) @@ -2705,6 +2091,8 @@ void bli_zgemm_haswell_asm_3x4 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( z ); } diff --git a/kernels/haswell/3/bli_gemm_haswell_asm_d8x6.c b/kernels/haswell/3/bli_gemm_haswell_asm_d8x6.c index b074da965c..a3a8b0b09f 100644 --- a/kernels/haswell/3/bli_gemm_haswell_asm_d8x6.c +++ b/kernels/haswell/3/bli_gemm_haswell_asm_d8x6.c @@ -78,7 +78,9 @@ void bli_sgemm_haswell_asm_16x6 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, float* restrict alpha, float* restrict a, float* restrict b, @@ -93,29 +95,31 @@ void bli_sgemm_haswell_asm_16x6 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( s, 16, 6, true ); + begin_asm() - + vzeroall() // zero all xmm/ymm registers. - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. //mov(%9, r15) // load address of b_next. - + add(imm(32*4), rax) // initialize loop by pre-loading vmovaps(mem(rax, -4*32), ymm0) vmovaps(mem(rax, -3*32), ymm1) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 4), rdi) // cs_c *= sizeof(float) - + lea(mem(rdi, rdi, 2), r13) // r13 = 3*cs_c; lea(mem(rcx, r13, 1), rdx) // rdx = c + 3*cs_c; prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*cs_c @@ -124,46 +128,46 @@ void bli_sgemm_haswell_asm_16x6 prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*cs_c prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*cs_c prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*cs_c - - - - + + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.SCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. 
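As in the other kernels, renaming k0 to k changes nothing semantically: the k dimension is still consumed in 4x-unrolled chunks (k_iter trips through the MAIN LOOP) plus a remainder (k_left trips through the EDGE LOOP). A minimal model of that control flow, with a stand-in loop body:

    #include <stdint.h>

    void gemm_k_loop_model( uint64_t k, double* acc,
                            const double* a, const double* b )
    {
        uint64_t k_iter = k / 4;   /* trips through the unrolled MAIN LOOP */
        uint64_t k_left = k % 4;   /* trips through the EDGE LOOP */

        for ( uint64_t i = 0; i < k_iter; i++, a += 4, b += 4 )
        {
            *acc += a[0]*b[0];     /* iteration 0 */
            *acc += a[1]*b[1];     /* iteration 1 */
            *acc += a[2]*b[2];     /* iteration 2 */
            *acc += a[3]*b[3];     /* iteration 3 */
        }
        for ( uint64_t i = 0; i < k_left; i++, a++, b++ )
            *acc += a[0]*b[0];     /* one update per trip */
    }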
- - + + label(.SLOOPKITER) // MAIN LOOP - - + + // iteration 0 prefetch(0, mem(rax, 128*4)) - + vbroadcastss(mem(rbx, 0*4), ymm2) vbroadcastss(mem(rbx, 1*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rbx, 2*4), ymm2) vbroadcastss(mem(rbx, 3*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rbx, 4*4), ymm2) vbroadcastss(mem(rbx, 5*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + vmovaps(mem(rax, -2*32), ymm0) vmovaps(mem(rax, -1*32), ymm1) - + // iteration 1 vbroadcastss(mem(rbx, 6*4), ymm2) vbroadcastss(mem(rbx, 7*4), ymm3) @@ -171,51 +175,51 @@ void bli_sgemm_haswell_asm_16x6 vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rbx, 8*4), ymm2) vbroadcastss(mem(rbx, 9*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rbx, 10*4), ymm2) vbroadcastss(mem(rbx, 11*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + vmovaps(mem(rax, 0*32), ymm0) vmovaps(mem(rax, 1*32), ymm1) - + // iteration 2 prefetch(0, mem(rax, 152*4)) - + vbroadcastss(mem(rbx, 12*4), ymm2) vbroadcastss(mem(rbx, 13*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rbx, 14*4), ymm2) vbroadcastss(mem(rbx, 15*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rbx, 16*4), ymm2) vbroadcastss(mem(rbx, 17*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + vmovaps(mem(rax, 2*32), ymm0) vmovaps(mem(rax, 3*32), ymm1) - + // iteration 3 vbroadcastss(mem(rbx, 18*4), ymm2) vbroadcastss(mem(rbx, 19*4), ymm3) @@ -223,91 +227,91 @@ void bli_sgemm_haswell_asm_16x6 vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rbx, 20*4), ymm2) vbroadcastss(mem(rbx, 21*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rbx, 22*4), ymm2) vbroadcastss(mem(rbx, 23*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + add(imm(4*16*4), rax) // a += 4*16 (unroll x mr) add(imm(4*6*4), rbx) // b += 4*6 (unroll x nr) - + vmovaps(mem(rax, -4*32), ymm0) vmovaps(mem(rax, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.SLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.SCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.SPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. 
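For reference, the twelve accumulators ymm4..ymm15 in the loop above tile the 16x6 block of C two registers per column (rows 0..7 and rows 8..15); each k step broadcasts one element of the packed B row (vbroadcastss) and issues two FMAs against the two halves of the packed A column. The scalar equivalent of one full accumulation:

    /* ab[] is the 16x6 tile, column-major; a steps by MR = 16 and
       b by NR = 6 per k iteration, matching the packed panels. */
    void sgemm_16x6_ref( long k, const float* a, const float* b,
                         float ab[ 16 * 6 ] )
    {
        for ( long l = 0; l < k; l++, a += 16, b += 6 )
            for ( long j = 0; j < 6; j++ )        /* vbroadcastss b[j]   */
                for ( long i = 0; i < 16; i++ )   /* two vfmadd231ps     */
                    ab[ i + j*16 ] += a[ i ] * b[ j ];
    }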
- - + + label(.SLOOPKLEFT) // EDGE LOOP - + prefetch(0, mem(rax, 128*4)) - + vbroadcastss(mem(rbx, 0*4), ymm2) vbroadcastss(mem(rbx, 1*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rbx, 2*4), ymm2) vbroadcastss(mem(rbx, 3*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rbx, 4*4), ymm2) vbroadcastss(mem(rbx, 5*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + add(imm(1*16*4), rax) // a += 1*16 (unroll x mr) add(imm(1*6*4), rbx) // b += 1*6 (unroll x nr) - + vmovaps(mem(rax, -4*32), ymm0) vmovaps(mem(rax, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.SLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.SPOSTACCUM) - - - - + + + + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastss(mem(rax), ymm0) // load alpha and duplicate vbroadcastss(mem(rbx), ymm3) // load beta and duplicate - + vmulps(ymm0, ymm4, ymm4) // scale by alpha vmulps(ymm0, ymm5, ymm5) vmulps(ymm0, ymm6, ymm6) @@ -320,315 +324,107 @@ void bli_sgemm_haswell_asm_16x6 vmulps(ymm0, ymm13, ymm13) vmulps(ymm0, ymm14, ymm14) vmulps(ymm0, ymm15, ymm15) - - - - - - - mov(var(rs_c), rsi) // load rs_c - lea(mem(, rsi, 4), rsi) // rsi = rs_c * sizeof(float) - - lea(mem(rcx, rsi, 8), rdx) // load address of c + 8*rs_c; - - lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_c; - lea(mem(rsi, rsi, 4), r15) // r15 = 5*rs_c; - lea(mem(r13, rsi, 4), r10) // r10 = 7*rs_c; - - + + // now avoid loading C if beta == 0 - + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomiss(xmm0, xmm3) // set ZF if beta == 0. je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case - - - cmp(imm(4), rsi) // set ZF if (4*rs_c) == 4. - jz(.SCOLSTORED) // jump to column storage case - - - - label(.SGENSTORED) - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm4, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm6, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm8, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm10, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm12, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm14, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - //add(rdi, rcx) // c += cs_c; - - - mov(rdx, rcx) // rcx = c + 8*rs_c - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm5, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm7, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm9, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm11, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm13, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm15, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - //add(rdi, rcx) // c += cs_c; - - - - jmp(.SDONE) // jump to end. 
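The epilogue first scales the whole accumulated tile by alpha, then compares beta against zero (vucomiss) so that the .SBETAZERO path never reads C; this saves bandwidth and avoids touching memory whose contents are allowed to be undefined when beta is zero. In scalar terms:

    /* ab[] holds the accumulated tile; mn is the tile size in elements. */
    void sgemm_store_model( float alpha, float beta,
                            float* ab, float* c, long mn )
    {
        for ( long i = 0; i < mn; i++ ) ab[ i ] *= alpha;

        if ( beta == 0.0f )                    /* .SBETAZERO path: C never loaded */
            for ( long i = 0; i < mn; i++ ) c[ i ] = ab[ i ];
        else                                   /* fused vfmadd231ps path */
            for ( long i = 0; i < mn; i++ ) c[ i ] = beta * c[ i ] + ab[ i ];
    }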
- - - - label(.SCOLSTORED) - - - vfmadd231ps(mem(rcx), ymm3, ymm4) - vmovups(ymm4, mem(rcx)) - add(rdi, rcx) - vfmadd231ps(mem(rdx), ymm3, ymm5) - vmovups(ymm5, mem(rdx)) - add(rdi, rdx) - - - vfmadd231ps(mem(rcx), ymm3, ymm6) - vmovups(ymm6, mem(rcx)) - add(rdi, rcx) - vfmadd231ps(mem(rdx), ymm3, ymm7) - vmovups(ymm7, mem(rdx)) - add(rdi, rdx) - - - vfmadd231ps(mem(rcx), ymm3, ymm8) - vmovups(ymm8, mem(rcx)) - add(rdi, rcx) - vfmadd231ps(mem(rdx), ymm3, ymm9) - vmovups(ymm9, mem(rdx)) - add(rdi, rdx) - - - vfmadd231ps(mem(rcx), ymm3, ymm10) - vmovups(ymm10, mem(rcx)) - add(rdi, rcx) - vfmadd231ps(mem(rdx), ymm3, ymm11) - vmovups(ymm11, mem(rdx)) - add(rdi, rdx) - - - vfmadd231ps(mem(rcx), ymm3, ymm12) - vmovups(ymm12, mem(rcx)) - add(rdi, rcx) - vfmadd231ps(mem(rdx), ymm3, ymm13) - vmovups(ymm13, mem(rdx)) - add(rdi, rdx) - - - vfmadd231ps(mem(rcx), ymm3, ymm14) - vmovups(ymm14, mem(rcx)) - //add(rdi, rcx) - vfmadd231ps(mem(rdx), ymm3, ymm15) - vmovups(ymm15, mem(rdx)) - //add(rdi, rdx) - - - - + + vfmadd231ps(mem(rcx), ymm3, ymm4) + vmovups(ymm4, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm5) + vmovups(ymm5, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm6) + vmovups(ymm6, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm7) + vmovups(ymm7, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm8) + vmovups(ymm8, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm9) + vmovups(ymm9, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm10) + vmovups(ymm10, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm11) + vmovups(ymm11, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm12) + vmovups(ymm12, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm13) + vmovups(ymm13, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm14) + vmovups(ymm14, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm15) + vmovups(ymm15, mem(rcx,32)) + //add(rdi, rcx) + jmp(.SDONE) // jump to end. - - - + label(.SBETAZERO) - - cmp(imm(4), rsi) // set ZF if (4*rs_c) == 4. - jz(.SCOLSTORBZ) // jump to column storage case - - - - label(.SGENSTORBZ) - - - vmovaps(ymm4, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovaps(ymm6, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovaps(ymm8, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovaps(ymm10, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovaps(ymm12, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovaps(ymm14, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - //add(rdi, rcx) // c += cs_c; - - - mov(rdx, rcx) // rcx = c + 8*rs_c - - - vmovaps(ymm5, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovaps(ymm7, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovaps(ymm9, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovaps(ymm11, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovaps(ymm13, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovaps(ymm15, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - //add(rdi, rcx) // c += cs_c; - - - - jmp(.SDONE) // jump to end. 
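The rewritten store path works because GEMM_UKR_SETUP_CT guarantees the kernel unit row stride, so each 16-float column of C is exactly two contiguous 32-byte halves, mem(rcx) and mem(rcx,32); the deleted SGENSTORED/SCOLSTORED split is no longer needed. A pointer model of the beta-zero store loop (cs_c in elements):

    /* ab[] holds the alpha-scaled 16x6 tile, column-major; each column of C
       is written as two contiguous 8-float halves (ymm pair at +0 and +32 bytes). */
    void store_16x6_bz( float* c, long cs_c, const float ab[ 16 * 6 ] )
    {
        for ( long j = 0; j < 6; j++, c += cs_c )   /* add(rdi, rcx) */
            for ( long i = 0; i < 16; i++ )
                c[ i ] = ab[ i + j*16 ];
    }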
- - - - label(.SCOLSTORBZ) - - - vmovups(ymm4, mem(rcx)) - add(rdi, rcx) - vmovups(ymm5, mem(rdx)) - add(rdi, rdx) - - vmovups(ymm6, mem(rcx)) - add(rdi, rcx) - vmovups(ymm7, mem(rdx)) - add(rdi, rdx) - - - vmovups(ymm8, mem(rcx)) - add(rdi, rcx) - vmovups(ymm9, mem(rdx)) - add(rdi, rdx) - - - vmovups(ymm10, mem(rcx)) - add(rdi, rcx) - vmovups(ymm11, mem(rdx)) - add(rdi, rdx) - - - vmovups(ymm12, mem(rcx)) - add(rdi, rcx) - vmovups(ymm13, mem(rdx)) - add(rdi, rdx) - - - vmovups(ymm14, mem(rcx)) - //add(rdi, rcx) - vmovups(ymm15, mem(rdx)) - //add(rdi, rdx) - - - - - - - + + vmovups(ymm4, mem(rcx)) + vmovups(ymm5, mem(rcx,32)) + add(rdi, rcx) + + vmovups(ymm6, mem(rcx)) + vmovups(ymm7, mem(rcx,32)) + add(rdi, rcx) + + + vmovups(ymm8, mem(rcx)) + vmovups(ymm9, mem(rcx,32)) + add(rdi, rcx) + + + vmovups(ymm10, mem(rcx)) + vmovups(ymm11, mem(rcx,32)) + add(rdi, rcx) + + + vmovups(ymm12, mem(rcx)) + vmovups(ymm13, mem(rcx,32)) + add(rdi, rcx) + + + vmovups(ymm14, mem(rcx)) + vmovups(ymm15, mem(rcx,32)) + //add(rdi, rcx) + label(.SDONE) - - + + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c)/*, // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next)*/ // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c)/*, // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next)*/ // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -636,6 +432,8 @@ void bli_sgemm_haswell_asm_16x6 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( s ); } #define DGEMM_INPUT_GS_BETA_NZ \ @@ -664,7 +462,9 @@ void bli_sgemm_haswell_asm_16x6 void bli_dgemm_haswell_asm_8x6 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, double* restrict alpha, double* restrict a, double* restrict b, @@ -679,29 +479,31 @@ void bli_dgemm_haswell_asm_8x6 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( d, 8, 6, false ); + begin_asm() - + vzeroall() // zero all xmm/ymm registers. - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. //mov(%9, r15) // load address of b_next. - + add(imm(32*4), rax) // initialize loop by pre-loading vmovapd(mem(rax, -4*32), ymm0) vmovapd(mem(rax, -3*32), ymm1) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(double) - + lea(mem(rdi, rdi, 2), r13) // r13 = 3*cs_c; lea(mem(rcx, r13, 1), rdx) // rdx = c + 3*cs_c; prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*cs_c @@ -710,46 +512,46 @@ void bli_dgemm_haswell_asm_8x6 prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*cs_c prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*cs_c prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*cs_c - - - - + + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. 
je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.DLOOPKITER) // MAIN LOOP - - + + // iteration 0 prefetch(0, mem(rax, 64*8)) - + vbroadcastsd(mem(rbx, 0*8), ymm2) vbroadcastsd(mem(rbx, 1*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rbx, 2*8), ymm2) vbroadcastsd(mem(rbx, 3*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rbx, 4*8), ymm2) vbroadcastsd(mem(rbx, 5*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rax, -2*32), ymm0) vmovapd(mem(rax, -1*32), ymm1) - + // iteration 1 vbroadcastsd(mem(rbx, 6*8), ymm2) vbroadcastsd(mem(rbx, 7*8), ymm3) @@ -757,51 +559,51 @@ void bli_dgemm_haswell_asm_8x6 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rbx, 8*8), ymm2) vbroadcastsd(mem(rbx, 9*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rbx, 10*8), ymm2) vbroadcastsd(mem(rbx, 11*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rax, 0*32), ymm0) vmovapd(mem(rax, 1*32), ymm1) - + // iteration 2 prefetch(0, mem(rax, 76*8)) - + vbroadcastsd(mem(rbx, 12*8), ymm2) vbroadcastsd(mem(rbx, 13*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rbx, 14*8), ymm2) vbroadcastsd(mem(rbx, 15*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rbx, 16*8), ymm2) vbroadcastsd(mem(rbx, 17*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rax, 2*32), ymm0) vmovapd(mem(rax, 3*32), ymm1) - + // iteration 3 vbroadcastsd(mem(rbx, 18*8), ymm2) vbroadcastsd(mem(rbx, 19*8), ymm3) @@ -809,91 +611,91 @@ void bli_dgemm_haswell_asm_8x6 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rbx, 20*8), ymm2) vbroadcastsd(mem(rbx, 21*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rbx, 22*8), ymm2) vbroadcastsd(mem(rbx, 23*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + add(imm(4*8*8), rax) // a += 4*8 (unroll x mr) add(imm(4*6*8), rbx) // b += 4*6 (unroll x nr) - + vmovapd(mem(rax, -4*32), ymm0) vmovapd(mem(rax, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. 
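The pointer bumps at the bottom of the unrolled loop encode the packed operand layout: A is stored as an MR x k panel (MR doubles per k step) and B as a k x NR panel (NR doubles per k step), hence a += 4*8 and b += 4*6 elements per unrolled trip. A sketch of how an element is addressed in these panels:

    /* Packed-panel addressing assumed by the kernel (MR = 8, NR = 6):
       a_panel[ l*MR + i ] is A(i,l), b_panel[ l*NR + j ] is B(l,j). */
    double packed_dot( long k, const double* a_panel,
                       const double* b_panel, long i, long j )
    {
        double abij = 0.0;
        for ( long l = 0; l < k; l++ )
            abij += a_panel[ l*8 + i ] * b_panel[ l*6 + j ];
        return abij;
    }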
- - + + label(.DLOOPKLEFT) // EDGE LOOP - + prefetch(0, mem(rax, 64*8)) - + vbroadcastsd(mem(rbx, 0*8), ymm2) vbroadcastsd(mem(rbx, 1*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rbx, 2*8), ymm2) vbroadcastsd(mem(rbx, 3*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rbx, 4*8), ymm2) vbroadcastsd(mem(rbx, 5*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + add(imm(1*8*8), rax) // a += 1*8 (unroll x mr) add(imm(1*6*8), rbx) // b += 1*6 (unroll x nr) - + vmovapd(mem(rax, -4*32), ymm0) vmovapd(mem(rax, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - - - - + + + + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(ymm0, ymm4, ymm4) // scale by alpha vmulpd(ymm0, ymm5, ymm5) vmulpd(ymm0, ymm6, ymm6) @@ -906,314 +708,107 @@ void bli_dgemm_haswell_asm_8x6 vmulpd(ymm0, ymm13, ymm13) vmulpd(ymm0, ymm14, ymm14) vmulpd(ymm0, ymm15, ymm15) - - - - - - - mov(var(rs_c), rsi) // load rs_c - lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(double) - - lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*rs_c; - - lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_c; - //lea(mem(rsi, rsi, 4), r15) // r15 = 5*rs_c; - //lea(mem(r13, rsi, 4), r10) // r10 = 7*rs_c; - - + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - - - cmp(imm(8), rsi) // set ZF if (8*rs_c) == 8. - jz(.DCOLSTORED) // jump to column storage case - - - - label(.DGENSTORED) - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm4, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm6, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm8, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm10, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm12, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm14, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - //add(rdi, rcx) // c += cs_c; - - - mov(rdx, rcx) // rcx = c + 4*rs_c - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm5, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm7, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm9, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm11, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm13, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm15, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - //add(rdi, rcx) // c += cs_c; - - - - jmp(.DDONE) // jump to end. 
- - - - label(.DCOLSTORED) - - - vfmadd231pd(mem(rcx), ymm3, ymm4) - vmovupd(ymm4, mem(rcx)) - add(rdi, rcx) - vfmadd231pd(mem(rdx), ymm3, ymm5) - vmovupd(ymm5, mem(rdx)) - add(rdi, rdx) - - - vfmadd231pd(mem(rcx), ymm3, ymm6) - vmovupd(ymm6, mem(rcx)) - add(rdi, rcx) - vfmadd231pd(mem(rdx), ymm3, ymm7) - vmovupd(ymm7, mem(rdx)) - add(rdi, rdx) - - - vfmadd231pd(mem(rcx), ymm3, ymm8) - vmovupd(ymm8, mem(rcx)) - add(rdi, rcx) - vfmadd231pd(mem(rdx), ymm3, ymm9) - vmovupd(ymm9, mem(rdx)) - add(rdi, rdx) - - - vfmadd231pd(mem(rcx), ymm3, ymm10) - vmovupd(ymm10, mem(rcx)) - add(rdi, rcx) - vfmadd231pd(mem(rdx), ymm3, ymm11) - vmovupd(ymm11, mem(rdx)) - add(rdi, rdx) - - - vfmadd231pd(mem(rcx), ymm3, ymm12) - vmovupd(ymm12, mem(rcx)) - add(rdi, rcx) - vfmadd231pd(mem(rdx), ymm3, ymm13) - vmovupd(ymm13, mem(rdx)) - add(rdi, rdx) - - - vfmadd231pd(mem(rcx), ymm3, ymm14) - vmovupd(ymm14, mem(rcx)) - //add(rdi, rcx) - vfmadd231pd(mem(rdx), ymm3, ymm15) - vmovupd(ymm15, mem(rdx)) - //add(rdi, rdx) - - - - jmp(.DDONE) // jump to end. - - - + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm8) + vmovupd(ymm8, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm9) + vmovupd(ymm9, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm10) + vmovupd(ymm10, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm11) + vmovupd(ymm11, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm12) + vmovupd(ymm12, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm13) + vmovupd(ymm13, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm14) + vmovupd(ymm14, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm15) + vmovupd(ymm15, mem(rcx,32)) + //add(rdi, rcx) + + jmp(.DDONE) // jump to end. + label(.DBETAZERO) - - cmp(imm(8), rsi) // set ZF if (8*rs_c) == 8. - jz(.DCOLSTORBZ) // jump to column storage case - - - - label(.DGENSTORBZ) - - - vmovapd(ymm4, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovapd(ymm6, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovapd(ymm8, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovapd(ymm10, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovapd(ymm12, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovapd(ymm14, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - //add(rdi, rcx) // c += cs_c; - - - mov(rdx, rcx) // rcx = c + 4*rs_c - - - vmovapd(ymm5, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovapd(ymm7, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovapd(ymm9, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovapd(ymm11, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovapd(ymm13, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovapd(ymm15, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - //add(rdi, rcx) // c += cs_c; - - - - jmp(.DDONE) // jump to end. 
- - - - label(.DCOLSTORBZ) - - - vmovupd(ymm4, mem(rcx)) - add(rdi, rcx) - vmovupd(ymm5, mem(rdx)) - add(rdi, rdx) - - vmovupd(ymm6, mem(rcx)) - add(rdi, rcx) - vmovupd(ymm7, mem(rdx)) - add(rdi, rdx) - - - vmovupd(ymm8, mem(rcx)) - add(rdi, rcx) - vmovupd(ymm9, mem(rdx)) - add(rdi, rdx) - - - vmovupd(ymm10, mem(rcx)) - add(rdi, rcx) - vmovupd(ymm11, mem(rdx)) - add(rdi, rdx) - - - vmovupd(ymm12, mem(rcx)) - add(rdi, rcx) - vmovupd(ymm13, mem(rdx)) - add(rdi, rdx) - - - vmovupd(ymm14, mem(rcx)) - //add(rdi, rcx) - vmovupd(ymm15, mem(rdx)) - //add(rdi, rdx) - - - - - - - + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm5, mem(rcx,32)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + vmovupd(ymm7, mem(rcx,32)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx)) + vmovupd(ymm9, mem(rcx,32)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx)) + vmovupd(ymm11, mem(rcx,32)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx)) + vmovupd(ymm13, mem(rcx,32)) + add(rdi, rcx) + + + vmovupd(ymm14, mem(rcx)) + vmovupd(ymm15, mem(rcx,32)) + //add(rdi, rcx) + label(.DDONE) - - - end_asm( + + + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c)/*, // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next)*/ // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c)/*, // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next)*/ // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -1221,45 +816,25 @@ void bli_dgemm_haswell_asm_8x6 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( d ); } -// assumes beta.r, beta.i have been broadcast into ymm1, ymm2. -// outputs to ymm0 -#define CGEMM_INPUT_SCALE_GS_BETA_NZ \ - vmovlpd(mem(rcx), xmm0, xmm0) \ - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) \ - vmovlpd(mem(rcx, rsi, 2), xmm3, xmm3) \ - vmovhpd(mem(rcx, r13, 1), xmm3, xmm3) \ - vinsertf128(imm(1), xmm3, ymm0, ymm0) \ +#define CGEMM_INPUT_SCALE_CS_BETA_NZ(where) \ + vmovups(where, ymm0) \ vpermilps(imm(0xb1), ymm0, ymm3) \ vmulps(ymm1, ymm0, ymm0) \ vmulps(ymm2, ymm3, ymm3) \ vaddsubps(ymm3, ymm0, ymm0) -// assumes values to output are in ymm0 -#define CGEMM_OUTPUT_GS \ - vextractf128(imm(1), ymm0, xmm3) \ - vmovlpd(xmm0, mem(rcx)) \ - vmovhpd(xmm0, mem(rcx, rsi, 1)) \ - vmovlpd(xmm3, mem(rcx, rsi, 2)) \ - vmovhpd(xmm3, mem(rcx, r13, 1)) - -#define CGEMM_INPUT_SCALE_CS_BETA_NZ \ - vmovups(mem(rcx), ymm0) \ - vpermilps(imm(0xb1), ymm0, ymm3) \ - vmulps(ymm1, ymm0, ymm0) \ - vmulps(ymm2, ymm3, ymm3) \ - vaddsubps(ymm3, ymm0, ymm0) - -#define CGEMM_OUTPUT_CS \ - vmovups(ymm0, mem(rcx)) \ - void bli_cgemm_haswell_asm_8x3 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, scomplex* restrict alpha, scomplex* restrict a, scomplex* restrict b, @@ -1274,75 +849,77 @@ void bli_cgemm_haswell_asm_8x3 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. 
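Passing the memory operand into CGEMM_INPUT_SCALE_CS_BETA_NZ(where), instead of hard-coding mem(rcx), is what collapses the old GS/CS macro variants into one definition: the same load/permute/scale sequence can now be aimed at mem(rcx), mem(rcx,32), mem(r11), and so on. The same technique in plain C, reduced to its essentials:

    #include <stdio.h>

    /* One parameterized macro replaces per-address copies of the sequence. */
    #define INPUT_SCALE_BETA_NZ( where ) ( beta * (where) )

    int main( void )
    {
        double beta = 0.5;
        double c[4] = { 1.0, 2.0, 3.0, 4.0 };
        /* mirrors aiming the macro at rcx, rcx+32, r11, r11+32, ... */
        for ( int i = 0; i < 4; i++ )
            printf( "%g\n", INPUT_SCALE_BETA_NZ( c[ i ] ) );
        return 0;
    }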
- uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( c, 8, 3, false ); + begin_asm() - + vzeroall() // zero all xmm/ymm registers. - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. //mov(%9, r15) // load address of b_next. - + add(imm(32*4), rax) // initialize loop by pre-loading vmovaps(mem(rax, -4*32), ymm0) vmovaps(mem(rax, -3*32), ymm1) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(scomplex) - + lea(mem(rcx, rdi, 1), r11) // r11 = c + 1*cs_c; lea(mem(rcx, rdi, 2), r12) // r12 = c + 2*cs_c; - + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*cs_c prefetch(0, mem(r11, 7*8)) // prefetch c + 1*cs_c prefetch(0, mem(r12, 7*8)) // prefetch c + 2*cs_c - - - - + + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.CCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.CLOOPKITER) // MAIN LOOP - - + + // iteration 0 prefetch(0, mem(rax, 32*8)) - + vbroadcastss(mem(rbx, 0*4), ymm2) vbroadcastss(mem(rbx, 1*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rbx, 2*4), ymm2) vbroadcastss(mem(rbx, 3*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rbx, 4*4), ymm2) vbroadcastss(mem(rbx, 5*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + vmovaps(mem(rax, -2*32), ymm0) vmovaps(mem(rax, -1*32), ymm1) - + // iteration 1 vbroadcastss(mem(rbx, 6*4), ymm2) vbroadcastss(mem(rbx, 7*4), ymm3) @@ -1350,51 +927,51 @@ void bli_cgemm_haswell_asm_8x3 vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rbx, 8*4), ymm2) vbroadcastss(mem(rbx, 9*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rbx, 10*4), ymm2) vbroadcastss(mem(rbx, 11*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + vmovaps(mem(rax, 0*32), ymm0) vmovaps(mem(rax, 1*32), ymm1) - + // iteration 2 prefetch(0, mem(rax, 38*8)) - + vbroadcastss(mem(rbx, 12*4), ymm2) vbroadcastss(mem(rbx, 13*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rbx, 14*4), ymm2) vbroadcastss(mem(rbx, 15*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rbx, 16*4), ymm2) vbroadcastss(mem(rbx, 17*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + vmovaps(mem(rax, 2*32), ymm0) vmovaps(mem(rax, 3*32), ymm1) - + // iteration 3 vbroadcastss(mem(rbx, 18*4), ymm2) vbroadcastss(mem(rbx, 19*4), ymm3) @@ -1402,84 +979,84 @@ void bli_cgemm_haswell_asm_8x3 vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rbx, 20*4), ymm2) vbroadcastss(mem(rbx, 21*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) 
vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rbx, 22*4), ymm2) vbroadcastss(mem(rbx, 23*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + add(imm(4*8*8), rax) // a += 4*8 (unroll x mr) add(imm(4*3*8), rbx) // b += 4*3 (unroll x nr) - + vmovaps(mem(rax, -4*32), ymm0) vmovaps(mem(rax, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.CLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.CCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.CPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.CLOOPKLEFT) // EDGE LOOP - + prefetch(0, mem(rax, 32*8)) - + vbroadcastss(mem(rbx, 0*4), ymm2) vbroadcastss(mem(rbx, 1*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rbx, 2*4), ymm2) vbroadcastss(mem(rbx, 3*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rbx, 4*4), ymm2) vbroadcastss(mem(rbx, 5*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + add(imm(1*8*8), rax) // a += 1*8 (unroll x mr) add(imm(1*3*8), rbx) // b += 1*3 (unroll x nr) - + vmovaps(mem(rax, -4*32), ymm0) vmovaps(mem(rax, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.CLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.CPOSTACCUM) - - + + // permute even and odd elements // of ymm6/7, ymm10/11, ymm/14/15 vpermilps(imm(0xb1), ymm6, ymm6) @@ -1488,76 +1065,68 @@ void bli_cgemm_haswell_asm_8x3 vpermilps(imm(0xb1), ymm11, ymm11) vpermilps(imm(0xb1), ymm14, ymm14) vpermilps(imm(0xb1), ymm15, ymm15) - - + + // subtract/add even/odd elements vaddsubps(ymm6, ymm4, ymm4) vaddsubps(ymm7, ymm5, ymm5) - + vaddsubps(ymm10, ymm8, ymm8) vaddsubps(ymm11, ymm9, ymm9) - + vaddsubps(ymm14, ymm12, ymm12) vaddsubps(ymm15, ymm13, ymm13) - - - - + + + + mov(var(alpha), rax) // load address of alpha vbroadcastss(mem(rax), ymm0) // load alpha_r and duplicate vbroadcastss(mem(rax, 4), ymm1) // load alpha_i and duplicate - - + + vpermilps(imm(0xb1), ymm4, ymm3) vmulps(ymm0, ymm4, ymm4) vmulps(ymm1, ymm3, ymm3) vaddsubps(ymm3, ymm4, ymm4) - + vpermilps(imm(0xb1), ymm5, ymm3) vmulps(ymm0, ymm5, ymm5) vmulps(ymm1, ymm3, ymm3) vaddsubps(ymm3, ymm5, ymm5) - - + + vpermilps(imm(0xb1), ymm8, ymm3) vmulps(ymm0, ymm8, ymm8) vmulps(ymm1, ymm3, ymm3) vaddsubps(ymm3, ymm8, ymm8) - + vpermilps(imm(0xb1), ymm9, ymm3) vmulps(ymm0, ymm9, ymm9) vmulps(ymm1, ymm3, ymm3) vaddsubps(ymm3, ymm9, ymm9) - - + + vpermilps(imm(0xb1), ymm12, ymm3) vmulps(ymm0, ymm12, ymm12) vmulps(ymm1, ymm3, ymm3) vaddsubps(ymm3, ymm12, ymm12) - + vpermilps(imm(0xb1), ymm13, ymm3) vmulps(ymm0, ymm13, ymm13) vmulps(ymm1, ymm3, ymm3) vaddsubps(ymm3, ymm13, ymm13) - - - - - + + + + + mov(var(beta), rbx) // load address of beta vbroadcastss(mem(rbx), ymm1) // load beta_r and duplicate vbroadcastss(mem(rbx, 4), ymm2) // load beta_i and duplicate - - - - - mov(var(rs_c), rsi) // load rs_c - lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(scomplex) - lea(mem(, rsi, 4), rdx) // rdx = 4*rs_c; - lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_c; - - - + + + // now avoid loading C if beta == 0 vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomiss(xmm0, xmm1) // set ZF if beta_r == 0. 
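The CPOSTACCUM permutes explain why every complex column carries two accumulators: with b_r and b_i broadcast separately inside the k loop, the even register sums a*b_r and the odd register sums a*b_i, and the vpermilps(0xb1) + vaddsubps pair recombines them into true complex products. A scalar model for one scomplex element:

    #include <stdio.h>

    typedef struct { float r, i; } scomplex;

    /* acc_e = sum of a*b_r (ymm4/5/8/9/12/13); acc_o = sum of a*b_i (ymm6/7/...). */
    scomplex recombine( scomplex acc_e, scomplex acc_o )
    {
        float swap_r = acc_o.i;       /* vpermilps(imm(0xb1)): swap re/im */
        float swap_i = acc_o.r;
        /* vaddsubps: subtract in even lanes, add in odd lanes */
        return (scomplex){ acc_e.r - swap_r, acc_e.i + swap_i };
    }

    int main( void )
    {
        /* a = 1+2i, b = 3+4i: acc_e = a*b_r = (3,6), acc_o = a*b_i = (4,8) */
        scomplex ab = recombine( (scomplex){ 3, 6 }, (scomplex){ 4, 8 } );
        printf( "%g%+gi\n", ab.r, ab.i );   /* -5+10i = (1+2i)(3+4i) */
        return 0;
    }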
@@ -1566,186 +1135,71 @@ void bli_cgemm_haswell_asm_8x3 sete(r9b) // r9b = ( ZF == 1 ? 1 : 0 ); and(r8b, r9b) // set ZF if r8b & r9b == 1. jne(.CBETAZERO) // if ZF = 1, jump to beta == 0 case - - - cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8. - jz(.CCOLSTORED) // jump to row storage case - - - - label(.CGENSTORED) - - - CGEMM_INPUT_SCALE_GS_BETA_NZ - vaddps(ymm4, ymm0, ymm0) - CGEMM_OUTPUT_GS - add(rdx, rcx) // c += 4*rs_c; - - - CGEMM_INPUT_SCALE_GS_BETA_NZ - vaddps(ymm5, ymm0, ymm0) - CGEMM_OUTPUT_GS - mov(r11, rcx) // rcx = c + 1*cs_c - - - - CGEMM_INPUT_SCALE_GS_BETA_NZ - vaddps(ymm8, ymm0, ymm0) - CGEMM_OUTPUT_GS - add(rdx, rcx) // c += 4*rs_c; - - - CGEMM_INPUT_SCALE_GS_BETA_NZ - vaddps(ymm9, ymm0, ymm0) - CGEMM_OUTPUT_GS - mov(r12, rcx) // rcx = c + 2*cs_c - - - - CGEMM_INPUT_SCALE_GS_BETA_NZ - vaddps(ymm12, ymm0, ymm0) - CGEMM_OUTPUT_GS - add(rdx, rcx) // c += 4*rs_c; - - - CGEMM_INPUT_SCALE_GS_BETA_NZ - vaddps(ymm13, ymm0, ymm0) - CGEMM_OUTPUT_GS - - - - jmp(.CDONE) // jump to end. - - - - label(.CCOLSTORED) - - - CGEMM_INPUT_SCALE_CS_BETA_NZ - vaddps(ymm4, ymm0, ymm0) - CGEMM_OUTPUT_CS - add(rdx, rcx) // c += 4*rs_c; - - - CGEMM_INPUT_SCALE_CS_BETA_NZ - vaddps(ymm5, ymm0, ymm0) - CGEMM_OUTPUT_CS - mov(r11, rcx) // rcx = c + 1*cs_c - - - - CGEMM_INPUT_SCALE_CS_BETA_NZ - vaddps(ymm8, ymm0, ymm0) - CGEMM_OUTPUT_CS - add(rdx, rcx) // c += 4*rs_c; - - - CGEMM_INPUT_SCALE_CS_BETA_NZ - vaddps(ymm9, ymm0, ymm0) - CGEMM_OUTPUT_CS - mov(r12, rcx) // rcx = c + 2*cs_c - - - - CGEMM_INPUT_SCALE_CS_BETA_NZ - vaddps(ymm12, ymm0, ymm0) - CGEMM_OUTPUT_CS - add(rdx, rcx) // c += 4*rs_c; - - - CGEMM_INPUT_SCALE_CS_BETA_NZ - vaddps(ymm13, ymm0, ymm0) - CGEMM_OUTPUT_CS - - - - jmp(.CDONE) // jump to end. - - - + + CGEMM_INPUT_SCALE_CS_BETA_NZ(mem(rcx)) + vaddps(ymm4, ymm0, ymm0) + vmovups(ymm0, mem(rcx)) + + + CGEMM_INPUT_SCALE_CS_BETA_NZ(mem(rcx,32)) + vaddps(ymm5, ymm0, ymm0) + vmovups(ymm0, mem(rcx,32)) + + + + CGEMM_INPUT_SCALE_CS_BETA_NZ(mem(r11)) + vaddps(ymm8, ymm0, ymm0) + vmovups(ymm0, mem(r11)) + + + CGEMM_INPUT_SCALE_CS_BETA_NZ(mem(r11,32)) + vaddps(ymm9, ymm0, ymm0) + vmovups(ymm0, mem(r11,32)) + + + + CGEMM_INPUT_SCALE_CS_BETA_NZ(mem(r12)) + vaddps(ymm12, ymm0, ymm0) + vmovups(ymm0, mem(r12)) + + + CGEMM_INPUT_SCALE_CS_BETA_NZ(mem(r12,32)) + vaddps(ymm13, ymm0, ymm0) + vmovups(ymm0, mem(r12,32)) + + jmp(.CDONE) // jump to end. + label(.CBETAZERO) - - cmp(imm(8), rsi) // set ZF if (8*rs_c) == 8. - jz(.CCOLSTORBZ) // jump to row storage case - - - - label(.CGENSTORBZ) - - - vmovaps(ymm4, ymm0) - CGEMM_OUTPUT_GS - add(rdx, rcx) // c += 4*rs_c; - - - vmovaps(ymm5, ymm0) - CGEMM_OUTPUT_GS - mov(r11, rcx) // rcx = c + 1*cs_c - - - - vmovaps(ymm8, ymm0) - CGEMM_OUTPUT_GS - add(rdx, rcx) // c += 4*rs_c; - - - vmovaps(ymm9, ymm0) - CGEMM_OUTPUT_GS - mov(r12, rcx) // rcx = c + 2*cs_c - - - - vmovaps(ymm12, ymm0) - CGEMM_OUTPUT_GS - add(rdx, rcx) // c += 4*rs_c; - - - vmovaps(ymm13, ymm0) - CGEMM_OUTPUT_GS - - - - jmp(.CDONE) // jump to end. 
- - - - label(.CCOLSTORBZ) - - - vmovups(ymm4, mem(rcx)) - vmovups(ymm5, mem(rcx, rdx, 1)) - - vmovups(ymm8, mem(r11)) - vmovups(ymm9, mem(r11, rdx, 1)) - - vmovups(ymm12, mem(r12)) - vmovups(ymm13, mem(r12, rdx, 1)) - - - - - - + + vmovups(ymm4, mem(rcx)) + vmovups(ymm5, mem(rcx,32)) + + vmovups(ymm8, mem(r11)) + vmovups(ymm9, mem(r11,32)) + + vmovups(ymm12, mem(r12)) + vmovups(ymm13, mem(r12,32)) + label(.CDONE) - - - end_asm( + + + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c)/*, // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next)*/ // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c)/*, // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next)*/ // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -1753,41 +1207,25 @@ void bli_cgemm_haswell_asm_8x3 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( c ); } -// assumes beta.r, beta.i have been broadcast into ymm1, ymm2. -// outputs to ymm0 -#define ZGEMM_INPUT_SCALE_GS_BETA_NZ \ - vmovupd(mem(rcx), xmm0) \ - vmovupd(mem(rcx, rsi, 1), xmm3) \ - vinsertf128(imm(1), xmm3, ymm0, ymm0) \ +#define ZGEMM_INPUT_SCALE_CS_BETA_NZ(where) \ + vmovups(where, ymm0) \ vpermilpd(imm(0x5), ymm0, ymm3) \ vmulpd(ymm1, ymm0, ymm0) \ vmulpd(ymm2, ymm3, ymm3) \ vaddsubpd(ymm3, ymm0, ymm0) - -// assumes values to output are in ymm0 -#define ZGEMM_OUTPUT_GS \ - vextractf128(imm(1), ymm0, xmm3) \ - vmovupd(xmm0, mem(rcx)) \ - vmovupd(xmm3, mem(rcx, rsi, 1)) \ - -#define ZGEMM_INPUT_SCALE_CS_BETA_NZ \ - vmovups(mem(rcx), ymm0) \ - vpermilpd(imm(0x5), ymm0, ymm3) \ - vmulpd(ymm1, ymm0, ymm0) \ - vmulpd(ymm2, ymm3, ymm3) \ - vaddsubpd(ymm3, ymm0, ymm0) - -#define ZGEMM_OUTPUT_CS \ - vmovupd(ymm0, mem(rcx)) \ void bli_zgemm_haswell_asm_4x3 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, dcomplex* restrict alpha, dcomplex* restrict a, dcomplex* restrict b, @@ -1802,76 +1240,78 @@ void bli_zgemm_haswell_asm_4x3 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( z, 4, 3, false ); + begin_asm() - + vzeroall() // zero all xmm/ymm registers. - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. //mov(%9, r15) // load address of b_next. 
- + add(imm(32*4), rax) // initialize loop by pre-loading vmovapd(mem(rax, -4*32), ymm0) vmovapd(mem(rax, -3*32), ymm1) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(dcomplex) lea(mem(, rdi, 2), rdi) - + lea(mem(rcx, rdi, 1), r11) // r11 = c + 1*cs_c; lea(mem(rcx, rdi, 2), r12) // r12 = c + 2*cs_c; - + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*cs_c prefetch(0, mem(r11, 7*8)) // prefetch c + 1*cs_c prefetch(0, mem(r12, 7*8)) // prefetch c + 2*cs_c - - - - + + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.ZCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.ZLOOPKITER) // MAIN LOOP - - + + // iteration 0 prefetch(0, mem(rax, 32*16)) - + vbroadcastsd(mem(rbx, 0*8), ymm2) vbroadcastsd(mem(rbx, 1*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rbx, 2*8), ymm2) vbroadcastsd(mem(rbx, 3*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rbx, 4*8), ymm2) vbroadcastsd(mem(rbx, 5*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rax, -2*32), ymm0) vmovapd(mem(rax, -1*32), ymm1) - + // iteration 1 vbroadcastsd(mem(rbx, 6*8), ymm2) vbroadcastsd(mem(rbx, 7*8), ymm3) @@ -1879,51 +1319,51 @@ void bli_zgemm_haswell_asm_4x3 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rbx, 8*8), ymm2) vbroadcastsd(mem(rbx, 9*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rbx, 10*8), ymm2) vbroadcastsd(mem(rbx, 11*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rax, 0*32), ymm0) vmovapd(mem(rax, 1*32), ymm1) - + // iteration 2 prefetch(0, mem(rax, 38*16)) - + vbroadcastsd(mem(rbx, 12*8), ymm2) vbroadcastsd(mem(rbx, 13*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rbx, 14*8), ymm2) vbroadcastsd(mem(rbx, 15*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rbx, 16*8), ymm2) vbroadcastsd(mem(rbx, 17*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rax, 2*32), ymm0) vmovapd(mem(rax, 3*32), ymm1) - + // iteration 3 vbroadcastsd(mem(rbx, 18*8), ymm2) vbroadcastsd(mem(rbx, 19*8), ymm3) @@ -1931,83 +1371,83 @@ void bli_zgemm_haswell_asm_4x3 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rbx, 20*8), ymm2) vbroadcastsd(mem(rbx, 21*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rbx, 22*8), ymm2) vbroadcastsd(mem(rbx, 23*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + add(imm(4*4*16), rax) // a += 4*4 (unroll x mr) add(imm(4*3*16), rbx) // b += 4*3 (unroll x 
nr) - + vmovapd(mem(rax, -4*32), ymm0) vmovapd(mem(rax, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.ZLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.ZCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.ZPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.ZLOOPKLEFT) // EDGE LOOP - + prefetch(0, mem(rax, 32*16)) - + vbroadcastsd(mem(rbx, 0*8), ymm2) vbroadcastsd(mem(rbx, 1*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rbx, 2*8), ymm2) vbroadcastsd(mem(rbx, 3*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rbx, 4*8), ymm2) vbroadcastsd(mem(rbx, 5*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + add(imm(1*4*16), rax) // a += 1*4 (unroll x mr) add(imm(1*3*16), rbx) // b += 1*3 (unroll x nr) - + vmovapd(mem(rax, -4*32), ymm0) vmovapd(mem(rax, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.ZLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.ZPOSTACCUM) - + // permute even and odd elements // of ymm6/7, ymm10/11, ymm/14/15 vpermilpd(imm(0x5), ymm6, ymm6) @@ -2016,76 +1456,69 @@ void bli_zgemm_haswell_asm_4x3 vpermilpd(imm(0x5), ymm11, ymm11) vpermilpd(imm(0x5), ymm14, ymm14) vpermilpd(imm(0x5), ymm15, ymm15) - - + + // subtract/add even/odd elements vaddsubpd(ymm6, ymm4, ymm4) vaddsubpd(ymm7, ymm5, ymm5) - + vaddsubpd(ymm10, ymm8, ymm8) vaddsubpd(ymm11, ymm9, ymm9) - + vaddsubpd(ymm14, ymm12, ymm12) vaddsubpd(ymm15, ymm13, ymm13) - - - - + + + + mov(var(alpha), rax) // load address of alpha vbroadcastsd(mem(rax), ymm0) // load alpha_r and duplicate vbroadcastsd(mem(rax, 8), ymm1) // load alpha_i and duplicate - - + + vpermilpd(imm(0x5), ymm4, ymm3) vmulpd(ymm0, ymm4, ymm4) vmulpd(ymm1, ymm3, ymm3) vaddsubpd(ymm3, ymm4, ymm4) - + vpermilpd(imm(0x5), ymm5, ymm3) vmulpd(ymm0, ymm5, ymm5) vmulpd(ymm1, ymm3, ymm3) vaddsubpd(ymm3, ymm5, ymm5) - - + + vpermilpd(imm(0x5), ymm8, ymm3) vmulpd(ymm0, ymm8, ymm8) vmulpd(ymm1, ymm3, ymm3) vaddsubpd(ymm3, ymm8, ymm8) - + vpermilpd(imm(0x5), ymm9, ymm3) vmulpd(ymm0, ymm9, ymm9) vmulpd(ymm1, ymm3, ymm3) vaddsubpd(ymm3, ymm9, ymm9) - - + + vpermilpd(imm(0x5), ymm12, ymm3) vmulpd(ymm0, ymm12, ymm12) vmulpd(ymm1, ymm3, ymm3) vaddsubpd(ymm3, ymm12, ymm12) - + vpermilpd(imm(0x5), ymm13, ymm3) vmulpd(ymm0, ymm13, ymm13) vmulpd(ymm1, ymm3, ymm3) vaddsubpd(ymm3, ymm13, ymm13) - - - - - + + + + + mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate - - - - - mov(var(rs_c), rsi) // load rs_c - lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(dcomplex) - lea(mem(, rsi, 2), rsi) - lea(mem(, rsi, 2), rdx) // rdx = 2*rs_c; - - - + + + + // now avoid loading C if beta == 0 vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm1) // set ZF if beta_r == 0. @@ -2094,171 +1527,56 @@ void bli_zgemm_haswell_asm_4x3 sete(r9b) // r9b = ( ZF == 1 ? 1 : 0 ); and(r8b, r9b) // set ZF if r8b & r9b == 1. jne(.ZBETAZERO) // if ZF = 1, jump to beta == 0 case - - - cmp(imm(16), rsi) // set ZF if (16*rs_c) == 16. 
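The beta test that follows must check both components before taking the .ZBETAZERO shortcut; the asm does this with two vucomisd/sete pairs and an and on the byte registers. Equivalent C:

    typedef struct { double r, i; } dcomplex;

    /* Mirrors: vucomisd/sete(r8b), vucomisd/sete(r9b), and(r8b, r9b). */
    int beta_is_zero( const dcomplex* beta )
    {
        int r8b = ( beta->r == 0.0 );   /* sete after comparing beta_r to 0 */
        int r9b = ( beta->i == 0.0 );   /* sete after comparing beta_i to 0 */
        return r8b & r9b;               /* both must be zero to skip loading C */
    }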
- jz(.ZCOLSTORED) // jump to row storage case - - - - label(.ZGENSTORED) - - - ZGEMM_INPUT_SCALE_GS_BETA_NZ - vaddpd(ymm4, ymm0, ymm0) - ZGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*rs_c; - - - ZGEMM_INPUT_SCALE_GS_BETA_NZ - vaddpd(ymm5, ymm0, ymm0) - ZGEMM_OUTPUT_GS - mov(r11, rcx) // rcx = c + 1*cs_c - - - - ZGEMM_INPUT_SCALE_GS_BETA_NZ - vaddpd(ymm8, ymm0, ymm0) - ZGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*rs_c; - - - ZGEMM_INPUT_SCALE_GS_BETA_NZ - vaddpd(ymm9, ymm0, ymm0) - ZGEMM_OUTPUT_GS - mov(r12, rcx) // rcx = c + 2*cs_c - - - - ZGEMM_INPUT_SCALE_GS_BETA_NZ - vaddpd(ymm12, ymm0, ymm0) - ZGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*rs_c; - - - ZGEMM_INPUT_SCALE_GS_BETA_NZ - vaddpd(ymm13, ymm0, ymm0) - ZGEMM_OUTPUT_GS - - - - jmp(.ZDONE) // jump to end. - - - - label(.ZCOLSTORED) - - - ZGEMM_INPUT_SCALE_CS_BETA_NZ - vaddpd(ymm4, ymm0, ymm0) - ZGEMM_OUTPUT_CS - add(rdx, rcx) // c += 2*rs_c; - - - ZGEMM_INPUT_SCALE_CS_BETA_NZ - vaddpd(ymm5, ymm0, ymm0) - ZGEMM_OUTPUT_CS - mov(r11, rcx) // rcx = c + 1*cs_c - - - - ZGEMM_INPUT_SCALE_CS_BETA_NZ - vaddpd(ymm8, ymm0, ymm0) - ZGEMM_OUTPUT_CS - add(rdx, rcx) // c += 2*rs_c; - - - ZGEMM_INPUT_SCALE_CS_BETA_NZ - vaddpd(ymm9, ymm0, ymm0) - ZGEMM_OUTPUT_CS - mov(r12, rcx) // rcx = c + 2*cs_c - - - - ZGEMM_INPUT_SCALE_CS_BETA_NZ - vaddpd(ymm12, ymm0, ymm0) - ZGEMM_OUTPUT_CS - add(rdx, rcx) // c += 2*rs_c; - - - ZGEMM_INPUT_SCALE_CS_BETA_NZ - vaddpd(ymm13, ymm0, ymm0) - ZGEMM_OUTPUT_CS - - - - jmp(.ZDONE) // jump to end. - - - + + ZGEMM_INPUT_SCALE_CS_BETA_NZ(mem(rcx)) + vaddpd(ymm4, ymm0, ymm0) + vmovupd(ymm0, mem(rcx)) + + + ZGEMM_INPUT_SCALE_CS_BETA_NZ(mem(rcx,32)) + vaddpd(ymm5, ymm0, ymm0) + vmovupd(ymm0, mem(rcx,32)) + + + + ZGEMM_INPUT_SCALE_CS_BETA_NZ(mem(r11)) + vaddpd(ymm8, ymm0, ymm0) + vmovupd(ymm0, mem(r11)) + + + ZGEMM_INPUT_SCALE_CS_BETA_NZ(mem(r11,32)) + vaddpd(ymm9, ymm0, ymm0) + vmovupd(ymm0, mem(r11,32)) + + + + ZGEMM_INPUT_SCALE_CS_BETA_NZ(mem(r12)) + vaddpd(ymm12, ymm0, ymm0) + vmovupd(ymm0, mem(r12)) + + + ZGEMM_INPUT_SCALE_CS_BETA_NZ(mem(r12,32)) + vaddpd(ymm13, ymm0, ymm0) + vmovupd(ymm0, mem(r12,32)) + + jmp(.ZDONE) // jump to end. + label(.ZBETAZERO) - - cmp(imm(16), rsi) // set ZF if (16*rs_c) == 16. - jz(.ZCOLSTORBZ) // jump to row storage case - - - - label(.ZGENSTORBZ) - - - vmovapd(ymm4, ymm0) - ZGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*rs_c; - - - vmovapd(ymm5, ymm0) - ZGEMM_OUTPUT_GS - mov(r11, rcx) // rcx = c + 1*cs_c - - - - vmovapd(ymm8, ymm0) - ZGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*rs_c; - - - vmovapd(ymm9, ymm0) - ZGEMM_OUTPUT_GS - mov(r12, rcx) // rcx = c + 2*cs_c - - - - vmovapd(ymm12, ymm0) - ZGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*rs_c; - - - vmovapd(ymm13, ymm0) - ZGEMM_OUTPUT_GS - - - - jmp(.ZDONE) // jump to end. 
- - - - label(.ZCOLSTORBZ) - - - vmovupd(ymm4, mem(rcx)) - vmovupd(ymm5, mem(rcx, rdx, 1)) - - vmovupd(ymm8, mem(r11)) - vmovupd(ymm9, mem(r11, rdx, 1)) - - vmovupd(ymm12, mem(r12)) - vmovupd(ymm13, mem(r12, rdx, 1)) - - - - - - + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm5, mem(rcx,32)) + + vmovupd(ymm8, mem(r11)) + vmovupd(ymm9, mem(r11,32)) + + vmovupd(ymm12, mem(r12)) + vmovupd(ymm13, mem(r12,32)) + label(.ZDONE) - - - end_asm( + + + end_asm( : // output operands (none) : // input operands [k_iter] "m" (k_iter), // 0 @@ -2273,7 +1591,7 @@ void bli_zgemm_haswell_asm_4x3 [b_next] "m" (b_next), // 9 [a_next] "m" (a_next)*/ // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -2281,6 +1599,8 @@ void bli_zgemm_haswell_asm_4x3 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( z ); } diff --git a/kernels/knc/3/bli_dgemm_knc_asm_30x8.c b/kernels/knc/3/bli_dgemm_knc_asm_30x8.c index 880632ae07..f20e43f7cc 100644 --- a/kernels/knc/3/bli_dgemm_knc_asm_30x8.c +++ b/kernels/knc/3/bli_dgemm_knc_asm_30x8.c @@ -256,6 +256,8 @@ extern int offsets[16]; //#define LOOPMON void bli_dgemm_knc_asm_30x8 ( + dim_t m, + dim_t n, dim_t k, double* restrict alpha, double* restrict a, @@ -273,80 +275,82 @@ void bli_dgemm_knc_asm_30x8 uint64_t k64 = k; + GEMM_UKR_SETUP_CT( d, 30, 8, true ); + #ifdef MONITORS int toph, topl, both, botl, midl, midh, mid2l, mid2h; #endif #ifdef LOOPMON int tlooph, tloopl, blooph, bloopl; #endif - + __asm { #ifdef MONITORS rdtsc mov topl, eax - mov toph, edx + mov toph, edx #endif vpxord zmm0, zmm0, zmm0 vmovaps zmm1, zmm0 //clear out registers - vmovaps zmm2, zmm0 + vmovaps zmm2, zmm0 mov rsi, k64 //loop index - vmovaps zmm3, zmm0 + vmovaps zmm3, zmm0 mov r11, rs_c //load row stride - vmovaps zmm4, zmm0 + vmovaps zmm4, zmm0 sal r11, 3 //scale row stride - vmovaps zmm5, zmm0 + vmovaps zmm5, zmm0 mov r15, a //load address of a - vmovaps zmm6, zmm0 + vmovaps zmm6, zmm0 mov rbx, b //load address of b - vmovaps zmm7, zmm0 + vmovaps zmm7, zmm0 - vmovaps zmm8, zmm0 + vmovaps zmm8, zmm0 lea r10, [r11 + 2*r11 + 0] //r10 has 3 * r11 vmovaps zmm9, zmm0 - vmovaps zmm10, zmm0 - mov rdi, r11 - vmovaps zmm11, zmm0 + vmovaps zmm10, zmm0 + mov rdi, r11 + vmovaps zmm11, zmm0 sal rdi, 2 //rdi has 4*r11 - vmovaps zmm12, zmm0 + vmovaps zmm12, zmm0 mov rcx, c //load address of c for prefetching - vmovaps zmm13, zmm0 - vmovaps zmm14, zmm0 + vmovaps zmm13, zmm0 + vmovaps zmm14, zmm0 mov r8, k64 - vmovaps zmm15, zmm0 + vmovaps zmm15, zmm0 vmovaps zmm16, zmm0 vmovaps zmm17, zmm0 mov r13, L2_PREFETCH_DIST*8*8 - vmovaps zmm18, zmm0 + vmovaps zmm18, zmm0 mov r14, L2_PREFETCH_DIST*8*32 - vmovaps zmm19, zmm0 - vmovaps zmm20, zmm0 - vmovaps zmm21, zmm0 - vmovaps zmm22, zmm0 + vmovaps zmm19, zmm0 + vmovaps zmm20, zmm0 + vmovaps zmm21, zmm0 + vmovaps zmm22, zmm0 - vmovaps zmm23, zmm0 + vmovaps zmm23, zmm0 sub r8, 30 + L2_PREFETCH_DIST //Check if we have over 40 operations to do. 
- vmovaps zmm24, zmm0 + vmovaps zmm24, zmm0 mov r8, 30 - vmovaps zmm25, zmm0 + vmovaps zmm25, zmm0 mov r9, 8*8 //amount to increment b* by each iteration - vmovaps zmm26, zmm0 + vmovaps zmm26, zmm0 mov r12, 32*8 //amount to increment a* by each iteration - vmovaps zmm27, zmm0 - vmovaps zmm28, zmm0 - vmovaps zmm29, zmm0 + vmovaps zmm27, zmm0 + vmovaps zmm28, zmm0 + vmovaps zmm29, zmm0 #ifdef MONITORS rdtsc mov midl, eax - mov midh, edx + mov midh, edx #endif jle CONSIDER_UNDER_40 sub rsi, 30 + L2_PREFETCH_DIST - + //First 30 iterations LOOPREFECHCL2: ONE_ITER_PC_L2(rcx) @@ -357,26 +361,26 @@ void bli_dgemm_knc_asm_30x8 LOOPMAIN: ONE_ITER_MAIN_LOOP(rcx, rsi) jne LOOPMAIN - + //Penultimate 22 iterations. //Break these off from the main loop to avoid prefetching extra shit. mov r14, a_next mov r13, b_next sub r14, r15 sub r13, rbx - + mov rsi, L2_PREFETCH_DIST-10 LOOPMAIN2: ONE_ITER_MAIN_LOOP(rcx, rsi) jne LOOPMAIN2 - - + + //Last 10 iterations mov r8, 10 LOOPREFETCHCL1: ONE_ITER_PC_L1(rcx) jne LOOPREFETCHCL1 - + jmp POSTACCUM @@ -403,14 +407,8 @@ void bli_dgemm_knc_asm_30x8 mov r9, c //load address of c for update mov r12, alpha //load address of alpha - // Check if C is row stride. If not, jump to the slow scattered update - mov r14, cs_c - dec r14 - jne SCATTEREDUPDATE - mov r14, beta - vbroadcastsd zmm31, 0[r14] - + vbroadcastsd zmm31, 0[r14] vmulpd zmm0, zmm0, 0[r12]{1to8} vmulpd zmm1, zmm1, 0[r12]{1to8} @@ -467,7 +465,7 @@ void bli_dgemm_knc_asm_30x8 vmovapd [r9+2*r11+0], zmm14 vmovapd [r9+r10+0], zmm15 add r9, rdi - + vmulpd zmm16, zmm16, 0[r12]{1to8} vmulpd zmm17, zmm17, 0[r12]{1to8} vmulpd zmm18, zmm18, 0[r12]{1to8} @@ -516,47 +514,6 @@ void bli_dgemm_knc_asm_30x8 vfmadd231pd zmm29, zmm31, [r9+r11+0] vmovapd [r9+0], zmm28 vmovapd [r9+r11+0], zmm29 - - jmp END - - SCATTEREDUPDATE: - mov r10, offsetPtr - vmovapd zmm31, 0[r10] - vpbroadcastd zmm30, cs_c - mov r13, beta - vpmulld zmm30, zmm31, zmm30 - - mov ebx, 255 - UPDATE_C_ROW_SCATTERED(zmm0, 0, r9) - UPDATE_C_ROW_SCATTERED(zmm1, 1, r9) - UPDATE_C_ROW_SCATTERED(zmm2, 2, r9) - UPDATE_C_ROW_SCATTERED(zmm3, 3, r9) - UPDATE_C_ROW_SCATTERED(zmm4, 4, r9) - UPDATE_C_ROW_SCATTERED(zmm5, 5, r9) - UPDATE_C_ROW_SCATTERED(zmm6, 6, r9) - UPDATE_C_ROW_SCATTERED(zmm7, 7, r9) - UPDATE_C_ROW_SCATTERED(zmm8, 8, r9) - UPDATE_C_ROW_SCATTERED(zmm9, 9, r9) - UPDATE_C_ROW_SCATTERED(zmm10, 10, r9) - UPDATE_C_ROW_SCATTERED(zmm11, 11, r9) - UPDATE_C_ROW_SCATTERED(zmm12, 12, r9) - UPDATE_C_ROW_SCATTERED(zmm13, 13, r9) - UPDATE_C_ROW_SCATTERED(zmm14, 14, r9) - UPDATE_C_ROW_SCATTERED(zmm15, 15, r9) - UPDATE_C_ROW_SCATTERED(zmm16, 16, r9) - UPDATE_C_ROW_SCATTERED(zmm17, 17, r9) - UPDATE_C_ROW_SCATTERED(zmm18, 18, r9) - UPDATE_C_ROW_SCATTERED(zmm19, 19, r9) - UPDATE_C_ROW_SCATTERED(zmm20, 20, r9) - UPDATE_C_ROW_SCATTERED(zmm21, 21, r9) - UPDATE_C_ROW_SCATTERED(zmm22, 22, r9) - UPDATE_C_ROW_SCATTERED(zmm23, 23, r9) - UPDATE_C_ROW_SCATTERED(zmm24, 24, r9) - UPDATE_C_ROW_SCATTERED(zmm25, 25, r9) - UPDATE_C_ROW_SCATTERED(zmm26, 26, r9) - UPDATE_C_ROW_SCATTERED(zmm27, 27, r9) - UPDATE_C_ROW_SCATTERED(zmm28, 28, r9) - UPDATE_C_ROW_SCATTERED(zmm29, 29, r9) END: #ifdef MONITORS @@ -566,6 +523,8 @@ void bli_dgemm_knc_asm_30x8 #endif } + GEMM_UKR_FLUSH_CT( d ); + #ifdef LOOPMON printf("looptime = \t%d\n", bloopl - tloopl); #endif diff --git a/kernels/knc/3/bli_sgemm_knc_asm_30x16.c b/kernels/knc/3/bli_sgemm_knc_asm_30x16.c index 866cb62ec1..18a8e5e2ee 100644 --- a/kernels/knc/3/bli_sgemm_knc_asm_30x16.c +++ b/kernels/knc/3/bli_sgemm_knc_asm_30x16.c @@ -256,6 +256,8 @@ int 
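Deleting the SCATTEREDUPDATE path (and the cs_c check that guarded it) in the KNC kernels is safe for the same reason the Haswell kernels lost their GENSTORED paths: GEMM_UKR_SETUP_CT( d, 30, 8, true ) only ever hands the asm a row-major tile with unit column stride, so the contiguous vmovapd update is the only live path. The distinction the removed code used to make, in scalar form:

    /* acc[] is one alpha-scaled 8-wide row of AB (NR = 8 for the dgemm kernel). */
    void update_row( double* c, long cs_c, double beta, const double* acc )
    {
        if ( cs_c == 1 )      /* fast contiguous path: now the only path */
            for ( int j = 0; j < 8; j++ )
                c[ j ] = beta * c[ j ] + acc[ j ];
        else                  /* removed UPDATE_C_ROW_SCATTERED path */
            for ( int j = 0; j < 8; j++ )
                c[ j*cs_c ] = beta * c[ j*cs_c ] + acc[ j ];
    }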
offsets[16] __attribute__((aligned(0x1000))) = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9 //#define LOOPMON void bli_sgemm_knc_asm_30x16 ( + dim_t m, + dim_t n, dim_t k, float* restrict alpha, float* restrict a, @@ -273,80 +275,82 @@ void bli_sgemm_knc_asm_30x16 uint64_t k64 = k; + GEMM_UKR_SETUP_CT( s, 30, 16, true ); + #ifdef MONITORS int toph, topl, both, botl, midl, midh, mid2l, mid2h; #endif #ifdef LOOPMON int tlooph, tloopl, blooph, bloopl; #endif - + __asm { #ifdef MONITORS rdtsc mov topl, eax - mov toph, edx + mov toph, edx #endif vpxord zmm0, zmm0, zmm0 vmovaps zmm1, zmm0 //clear out registers - vmovaps zmm2, zmm0 + vmovaps zmm2, zmm0 mov rsi, k64 //loop index - vmovaps zmm3, zmm0 + vmovaps zmm3, zmm0 mov r11, rs_c //load row stride - vmovaps zmm4, zmm0 + vmovaps zmm4, zmm0 sal r11, 2 //scale row stride - vmovaps zmm5, zmm0 + vmovaps zmm5, zmm0 mov r15, a //load address of a - vmovaps zmm6, zmm0 + vmovaps zmm6, zmm0 mov rbx, b //load address of b - vmovaps zmm7, zmm0 + vmovaps zmm7, zmm0 - vmovaps zmm8, zmm0 + vmovaps zmm8, zmm0 lea r10, [r11 + 2*r11 + 0] //r10 has 3 * r11 vmovaps zmm9, zmm0 - vmovaps zmm10, zmm0 - mov rdi, r11 - vmovaps zmm11, zmm0 + vmovaps zmm10, zmm0 + mov rdi, r11 + vmovaps zmm11, zmm0 sal rdi, 2 //rdi has 4*r11 - vmovaps zmm12, zmm0 + vmovaps zmm12, zmm0 mov rcx, c //load address of c for prefetching - vmovaps zmm13, zmm0 - vmovaps zmm14, zmm0 + vmovaps zmm13, zmm0 + vmovaps zmm14, zmm0 mov r8, k64 - vmovaps zmm15, zmm0 + vmovaps zmm15, zmm0 vmovaps zmm16, zmm0 vmovaps zmm17, zmm0 mov r13, L2_PREFETCH_DIST*4*16 - vmovaps zmm18, zmm0 + vmovaps zmm18, zmm0 mov r14, L2_PREFETCH_DIST*4*32 - vmovaps zmm19, zmm0 - vmovaps zmm20, zmm0 - vmovaps zmm21, zmm0 - vmovaps zmm22, zmm0 + vmovaps zmm19, zmm0 + vmovaps zmm20, zmm0 + vmovaps zmm21, zmm0 + vmovaps zmm22, zmm0 - vmovaps zmm23, zmm0 + vmovaps zmm23, zmm0 sub r8, 30 + L2_PREFETCH_DIST //Check if we have over 40 operations to do. - vmovaps zmm24, zmm0 + vmovaps zmm24, zmm0 mov r8, 30 - vmovaps zmm25, zmm0 + vmovaps zmm25, zmm0 mov r9, 16*4 //amount to increment b* by each iteration - vmovaps zmm26, zmm0 + vmovaps zmm26, zmm0 mov r12, 32*4 //amount to increment a* by each iteration - vmovaps zmm27, zmm0 - vmovaps zmm28, zmm0 - vmovaps zmm29, zmm0 + vmovaps zmm27, zmm0 + vmovaps zmm28, zmm0 + vmovaps zmm29, zmm0 #ifdef MONITORS rdtsc mov midl, eax - mov midh, edx + mov midh, edx #endif jle CONSIDER_UNDER_40 sub rsi, 30 + L2_PREFETCH_DIST - + //First 30 iterations LOOPREFECHCL2: ONE_ITER_PC_L2(rcx) @@ -357,26 +361,26 @@ void bli_sgemm_knc_asm_30x16 LOOPMAIN: ONE_ITER_MAIN_LOOP(rcx, rsi) jne LOOPMAIN - + //Penultimate 22 iterations. //Break these off from the main loop to avoid prefetching extra shit. mov r14, a_next mov r13, b_next sub r14, r15 sub r13, rbx - + mov rsi, L2_PREFETCH_DIST-10 LOOPMAIN2: ONE_ITER_MAIN_LOOP(rcx, rsi) jne LOOPMAIN2 - - + + //Last 10 iterations mov r8, 10 LOOPREFETCHCL1: ONE_ITER_PC_L1(rcx) jne LOOPREFETCHCL1 - + jmp POSTACCUM @@ -384,7 +388,7 @@ void bli_sgemm_knc_asm_30x16 //Used when <= 40 iterations CONSIDER_UNDER_40: mov rsi, k64 - test rsi, rsi + test rsi, rsi je POSTACCUM LOOP_UNDER_40: ONE_ITER_MAIN_LOOP(rcx, rsi) @@ -403,13 +407,8 @@ void bli_sgemm_knc_asm_30x16 mov r9, c //load address of c for update mov r12, alpha //load address of alpha - // Check if C is row stride. 
If not, jump to the slow scattered update - mov r14, cs_c - dec r14 - jne SCATTEREDUPDATE - mov r14, beta - vbroadcastss zmm31, 0[r14] + vbroadcastss zmm31, 0[r14] vmulps zmm0, zmm0, 0[r12]{1to16} @@ -467,7 +466,7 @@ void bli_sgemm_knc_asm_30x16 vmovaps [r9+2*r11+0], zmm14 vmovaps [r9+r10+0], zmm15 add r9, rdi - + vmulps zmm16, zmm16, 0[r12]{1to16} vmulps zmm17, zmm17, 0[r12]{1to16} vmulps zmm18, zmm18, 0[r12]{1to16} @@ -516,48 +515,6 @@ void bli_sgemm_knc_asm_30x16 vfmadd231ps zmm29, zmm31, [r9+r11+0] vmovaps [r9+0], zmm28 vmovaps [r9+r11+0], zmm29 - - jmp END - - SCATTEREDUPDATE: - - mov r10, offsetPtr - vmovaps zmm31, 0[r10] - vpbroadcastd zmm30, cs_c - mov r13, beta - vpmulld zmm30, zmm31, zmm30 - - mov ebx, 0xFFFF - UPDATE_C_ROW_SCATTERED(zmm0, 0, r9) - UPDATE_C_ROW_SCATTERED(zmm1, 1, r9) - UPDATE_C_ROW_SCATTERED(zmm2, 2, r9) - UPDATE_C_ROW_SCATTERED(zmm3, 3, r9) - UPDATE_C_ROW_SCATTERED(zmm4, 4, r9) - UPDATE_C_ROW_SCATTERED(zmm5, 5, r9) - UPDATE_C_ROW_SCATTERED(zmm6, 6, r9) - UPDATE_C_ROW_SCATTERED(zmm7, 7, r9) - UPDATE_C_ROW_SCATTERED(zmm8, 8, r9) - UPDATE_C_ROW_SCATTERED(zmm9, 9, r9) - UPDATE_C_ROW_SCATTERED(zmm10, 10, r9) - UPDATE_C_ROW_SCATTERED(zmm11, 11, r9) - UPDATE_C_ROW_SCATTERED(zmm12, 12, r9) - UPDATE_C_ROW_SCATTERED(zmm13, 13, r9) - UPDATE_C_ROW_SCATTERED(zmm14, 14, r9) - UPDATE_C_ROW_SCATTERED(zmm15, 15, r9) - UPDATE_C_ROW_SCATTERED(zmm16, 16, r9) - UPDATE_C_ROW_SCATTERED(zmm17, 17, r9) - UPDATE_C_ROW_SCATTERED(zmm18, 18, r9) - UPDATE_C_ROW_SCATTERED(zmm19, 19, r9) - UPDATE_C_ROW_SCATTERED(zmm20, 20, r9) - UPDATE_C_ROW_SCATTERED(zmm21, 21, r9) - UPDATE_C_ROW_SCATTERED(zmm22, 22, r9) - UPDATE_C_ROW_SCATTERED(zmm23, 23, r9) - UPDATE_C_ROW_SCATTERED(zmm24, 24, r9) - UPDATE_C_ROW_SCATTERED(zmm25, 25, r9) - UPDATE_C_ROW_SCATTERED(zmm26, 26, r9) - UPDATE_C_ROW_SCATTERED(zmm27, 27, r9) - UPDATE_C_ROW_SCATTERED(zmm28, 28, r9) - UPDATE_C_ROW_SCATTERED(zmm29, 29, r9) END: #ifdef MONITORS @@ -567,6 +524,8 @@ void bli_sgemm_knc_asm_30x16 #endif } + GEMM_UKR_FLUSH_CT( s ); + #ifdef LOOPMON printf("looptime = \t%d\n", bloopl - tloopl); #endif diff --git a/kernels/knl/3/bli_dgemm_knl_asm_24x8.c b/kernels/knl/3/bli_dgemm_knl_asm_24x8.c index b794e7c059..a7f860ae02 100644 --- a/kernels/knl/3/bli_dgemm_knl_asm_24x8.c +++ b/kernels/knl/3/bli_dgemm_knl_asm_24x8.c @@ -185,6 +185,8 @@ static int32_t offsets[32] __attribute__((aligned(64))) = //#define LOOPMON void bli_dgemm_knl_asm_24x8 ( + dim_t m, + dim_t n, dim_t k_, double* restrict alpha, double* restrict a, @@ -201,10 +203,12 @@ void bli_dgemm_knl_asm_24x8 const double * a_next = bli_auxinfo_next_a( data ); const double * b_next = bli_auxinfo_next_b( data ); - const int32_t * offsetPtr = &offsets[0]; - const int64_t k = k_; - const int64_t rs_c = rs_c_; - const int64_t cs_c = cs_c_; + int32_t * offsetPtr = &offsets[0]; + int64_t k = k_; + int64_t rs_c = rs_c_; + int64_t cs_c = cs_c_; + + GEMM_UKR_SETUP_CT( d, 24, 8, true ); #ifdef MONITORS int toph, topl, both, botl, midl, midh, mid2l, mid2h; @@ -565,10 +569,7 @@ void bli_dgemm_knl_asm_24x8 // Check if C is row stride. 
If not, jump to the slow scattered update MOV(RAX, VAR(rs_c)) LEA(RAX, MEM(,RAX,8)) - MOV(RBX, VAR(cs_c)) LEA(RDI, MEM(RAX,RAX,2)) - CMP(RBX, IMM(1)) - JNE(SCATTEREDUPDATE) VMOVQ(RDX, XMM(1)) SAL(RDX) //shift out sign bit @@ -592,74 +593,6 @@ void bli_dgemm_knl_asm_24x8 UPDATE_C_BZ_FOUR_ROWS(24,25,26,27) UPDATE_C_BZ_FOUR_ROWS(28,29,30,31) - JMP(END) - - LABEL(SCATTEREDUPDATE) - - MOV(RDI, VAR(offsetPtr)) - VMOVAPS(ZMM(2), MEM(RDI)) - /* Note that this ignores the upper 32 bits in cs_c */ - VPBROADCASTD(ZMM(3), EBX) - VPMULLD(ZMM(2), ZMM(3), ZMM(2)) - - VMOVQ(RDX, XMM(1)) - SAL(RDX) //shift out sign bit - JZ(SCATTERBZ) - - UPDATE_C_ROW_SCATTERED( 8) - UPDATE_C_ROW_SCATTERED( 9) - UPDATE_C_ROW_SCATTERED(10) - UPDATE_C_ROW_SCATTERED(11) - UPDATE_C_ROW_SCATTERED(12) - UPDATE_C_ROW_SCATTERED(13) - UPDATE_C_ROW_SCATTERED(14) - UPDATE_C_ROW_SCATTERED(15) - UPDATE_C_ROW_SCATTERED(16) - UPDATE_C_ROW_SCATTERED(17) - UPDATE_C_ROW_SCATTERED(18) - UPDATE_C_ROW_SCATTERED(19) - UPDATE_C_ROW_SCATTERED(20) - UPDATE_C_ROW_SCATTERED(21) - UPDATE_C_ROW_SCATTERED(22) - UPDATE_C_ROW_SCATTERED(23) - UPDATE_C_ROW_SCATTERED(24) - UPDATE_C_ROW_SCATTERED(25) - UPDATE_C_ROW_SCATTERED(26) - UPDATE_C_ROW_SCATTERED(27) - UPDATE_C_ROW_SCATTERED(28) - UPDATE_C_ROW_SCATTERED(29) - UPDATE_C_ROW_SCATTERED(30) - UPDATE_C_ROW_SCATTERED(31) - - JMP(END) - - LABEL(SCATTERBZ) - - UPDATE_C_BZ_ROW_SCATTERED( 8) - UPDATE_C_BZ_ROW_SCATTERED( 9) - UPDATE_C_BZ_ROW_SCATTERED(10) - UPDATE_C_BZ_ROW_SCATTERED(11) - UPDATE_C_BZ_ROW_SCATTERED(12) - UPDATE_C_BZ_ROW_SCATTERED(13) - UPDATE_C_BZ_ROW_SCATTERED(14) - UPDATE_C_BZ_ROW_SCATTERED(15) - UPDATE_C_BZ_ROW_SCATTERED(16) - UPDATE_C_BZ_ROW_SCATTERED(17) - UPDATE_C_BZ_ROW_SCATTERED(18) - UPDATE_C_BZ_ROW_SCATTERED(19) - UPDATE_C_BZ_ROW_SCATTERED(20) - UPDATE_C_BZ_ROW_SCATTERED(21) - UPDATE_C_BZ_ROW_SCATTERED(22) - UPDATE_C_BZ_ROW_SCATTERED(23) - UPDATE_C_BZ_ROW_SCATTERED(24) - UPDATE_C_BZ_ROW_SCATTERED(25) - UPDATE_C_BZ_ROW_SCATTERED(26) - UPDATE_C_BZ_ROW_SCATTERED(27) - UPDATE_C_BZ_ROW_SCATTERED(28) - UPDATE_C_BZ_ROW_SCATTERED(29) - UPDATE_C_BZ_ROW_SCATTERED(30) - UPDATE_C_BZ_ROW_SCATTERED(31) - LABEL(END) #ifdef MONITORS @@ -701,6 +634,8 @@ void bli_dgemm_knl_asm_24x8 "zmm30", "zmm31", "memory" ) + GEMM_UKR_FLUSH_CT( d ); + #ifdef LOOPMON printf("looptime = \t%d\n", bloopl - tloopl); #endif diff --git a/kernels/knl/3/bli_sgemm_knl_asm_24x16.c b/kernels/knl/3/bli_sgemm_knl_asm_24x16.c index 6d485b5308..64feba09f1 100644 --- a/kernels/knl/3/bli_sgemm_knl_asm_24x16.c +++ b/kernels/knl/3/bli_sgemm_knl_asm_24x16.c @@ -182,6 +182,8 @@ static int32_t offsets[32] __attribute__((aligned(64))) = //#define LOOPMON void bli_sgemm_knl_asm_24x16 ( + dim_t m, + dim_t n, dim_t k_, float* restrict alpha, float* restrict a, @@ -198,10 +200,12 @@ void bli_sgemm_knl_asm_24x16 const double * a_next = bli_auxinfo_next_a( data ); const double * b_next = bli_auxinfo_next_b( data ); - const int32_t * offsetPtr = &offsets[0]; - const int64_t k = k_; - const int64_t rs_c = rs_c_; - const int64_t cs_c = cs_c_; + int32_t * offsetPtr = &offsets[0]; + int64_t k = k_; + int64_t rs_c = rs_c_; + int64_t cs_c = cs_c_; + + GEMM_UKR_SETUP_CT( s, 24, 16, true ); #ifdef MONITORS int toph, topl, both, botl, midl, midh, mid2l, mid2h; @@ -562,10 +566,7 @@ void bli_sgemm_knl_asm_24x16 // Check if C is row stride. 
If not, jump to the slow scattered update MOV(RAX, VAR(rs_c)) LEA(RAX, MEM(,RAX,4)) - MOV(RBX, VAR(cs_c)) LEA(RDI, MEM(RAX,RAX,2)) - CMP(RBX, IMM(1)) - JNE(SCATTEREDUPDATE) VMOVD(EDX, XMM(1)) SAL(EDX) //shift out sign bit @@ -589,74 +590,6 @@ void bli_sgemm_knl_asm_24x16 UPDATE_C_BZ_FOUR_ROWS(24,25,26,27) UPDATE_C_BZ_FOUR_ROWS(28,29,30,31) - JMP(END) - - LABEL(SCATTEREDUPDATE) - - MOV(RDI, VAR(offsetPtr)) - VMOVAPS(ZMM(2), MEM(RDI)) - /* Note that this ignores the upper 32 bits in cs_c */ - VPBROADCASTD(ZMM(3), EBX) - VPMULLD(ZMM(2), ZMM(3), ZMM(2)) - - VMOVD(EDX, XMM(1)) - SAL(EDX) //shift out sign bit - JZ(SCATTERBZ) - - UPDATE_C_ROW_SCATTERED( 8) - UPDATE_C_ROW_SCATTERED( 9) - UPDATE_C_ROW_SCATTERED(10) - UPDATE_C_ROW_SCATTERED(11) - UPDATE_C_ROW_SCATTERED(12) - UPDATE_C_ROW_SCATTERED(13) - UPDATE_C_ROW_SCATTERED(14) - UPDATE_C_ROW_SCATTERED(15) - UPDATE_C_ROW_SCATTERED(16) - UPDATE_C_ROW_SCATTERED(17) - UPDATE_C_ROW_SCATTERED(18) - UPDATE_C_ROW_SCATTERED(19) - UPDATE_C_ROW_SCATTERED(20) - UPDATE_C_ROW_SCATTERED(21) - UPDATE_C_ROW_SCATTERED(22) - UPDATE_C_ROW_SCATTERED(23) - UPDATE_C_ROW_SCATTERED(24) - UPDATE_C_ROW_SCATTERED(25) - UPDATE_C_ROW_SCATTERED(26) - UPDATE_C_ROW_SCATTERED(27) - UPDATE_C_ROW_SCATTERED(28) - UPDATE_C_ROW_SCATTERED(29) - UPDATE_C_ROW_SCATTERED(30) - UPDATE_C_ROW_SCATTERED(31) - - JMP(END) - - LABEL(SCATTERBZ) - - UPDATE_C_BZ_ROW_SCATTERED( 8) - UPDATE_C_BZ_ROW_SCATTERED( 9) - UPDATE_C_BZ_ROW_SCATTERED(10) - UPDATE_C_BZ_ROW_SCATTERED(11) - UPDATE_C_BZ_ROW_SCATTERED(12) - UPDATE_C_BZ_ROW_SCATTERED(13) - UPDATE_C_BZ_ROW_SCATTERED(14) - UPDATE_C_BZ_ROW_SCATTERED(15) - UPDATE_C_BZ_ROW_SCATTERED(16) - UPDATE_C_BZ_ROW_SCATTERED(17) - UPDATE_C_BZ_ROW_SCATTERED(18) - UPDATE_C_BZ_ROW_SCATTERED(19) - UPDATE_C_BZ_ROW_SCATTERED(20) - UPDATE_C_BZ_ROW_SCATTERED(21) - UPDATE_C_BZ_ROW_SCATTERED(22) - UPDATE_C_BZ_ROW_SCATTERED(23) - UPDATE_C_BZ_ROW_SCATTERED(24) - UPDATE_C_BZ_ROW_SCATTERED(25) - UPDATE_C_BZ_ROW_SCATTERED(26) - UPDATE_C_BZ_ROW_SCATTERED(27) - UPDATE_C_BZ_ROW_SCATTERED(28) - UPDATE_C_BZ_ROW_SCATTERED(29) - UPDATE_C_BZ_ROW_SCATTERED(30) - UPDATE_C_BZ_ROW_SCATTERED(31) - LABEL(END) #ifdef MONITORS @@ -698,6 +631,8 @@ void bli_sgemm_knl_asm_24x16 "zmm30", "zmm31", "memory" ) + GEMM_UKR_FLUSH_CT( s ); + #ifdef LOOPMON printf("looptime = \t%d\n", bloopl - tloopl); #endif diff --git a/kernels/penryn/3/bli_gemm_penryn_asm_d4x4.c b/kernels/penryn/3/bli_gemm_penryn_asm_d4x4.c index e52cc9e0e0..a3e39c3ac1 100644 --- a/kernels/penryn/3/bli_gemm_penryn_asm_d4x4.c +++ b/kernels/penryn/3/bli_gemm_penryn_asm_d4x4.c @@ -39,7 +39,9 @@ void bli_sgemm_penryn_asm_8x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, float* restrict alpha, float* restrict a, float* restrict b, @@ -54,38 +56,40 @@ void bli_sgemm_penryn_asm_8x4 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT_ALIGNED( s, 8, 4, false, 16 ); + begin_asm() - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. mov(var(b_next), r9) // load address of b_next. - + sub(imm(0-8*16), rax) // increment pointers to allow byte sub(imm(0-8*16), rbx) // offsets in the unrolled iterations. - + movaps(mem(rax, -8*16), xmm0) // initialize loop by pre-loading elements movaps(mem(rax, -7*16), xmm1) // of a and b. 
movaps(mem(rbx, -8*16), xmm2) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 4), rdi) // cs_c *= sizeof(float) mov(rdi, r12) // make a copy of cs_c (in bytes) lea(mem(rcx, rdi, 2), r10) // load address of c + 2*cs_c; - + prefetch(2, mem(r9, 0*4)) // prefetch b_next - + xorps(xmm3, xmm3) xorps(xmm4, xmm4) xorps(xmm5, xmm5) xorps(xmm6, xmm6) - + prefetch(2, mem(rcx, 6*4)) // prefetch c + 0*cs_c xorps(xmm8, xmm8) xorps(xmm9, xmm9) @@ -98,33 +102,33 @@ void bli_sgemm_penryn_asm_8x4 prefetch(2, mem(r10, rdi, 1, 6*4)) // prefetch c + 3*cs_c xorps(xmm14, xmm14) xorps(xmm15, xmm15) - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.SCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.SLOOPKITER) // MAIN LOOP - + prefetch(0, mem(rax, (4*35+1)*8)) - + addps(xmm6, xmm10) // iteration 0 addps(xmm3, xmm14) movaps(xmm2, xmm3) pshufd(imm(0x39), xmm2, xmm7) mulps(xmm0, xmm2) mulps(xmm1, xmm3) - + addps(xmm4, xmm11) addps(xmm5, xmm15) movaps(xmm7, xmm5) pshufd(imm(0x39), xmm7, xmm6) mulps(xmm0, xmm7) mulps(xmm1, xmm5) - + addps(xmm2, xmm8) movaps(mem(rbx, -7*16), xmm2) addps(xmm3, xmm12) @@ -132,7 +136,7 @@ void bli_sgemm_penryn_asm_8x4 pshufd(imm(0x39), xmm6, xmm4) mulps(xmm0, xmm6) mulps(xmm1, xmm3) - + addps(xmm7, xmm9) addps(xmm5, xmm13) movaps(xmm4, xmm5) @@ -140,22 +144,22 @@ void bli_sgemm_penryn_asm_8x4 movaps(mem(rax, -6*16), xmm0) mulps(xmm1, xmm5) movaps(mem(rax, -5*16), xmm1) - - + + addps(xmm6, xmm10) // iteration 1 addps(xmm3, xmm14) movaps(xmm2, xmm3) pshufd(imm(0x39), xmm2, xmm7) mulps(xmm0, xmm2) mulps(xmm1, xmm3) - + addps(xmm4, xmm11) addps(xmm5, xmm15) movaps(xmm7, xmm5) pshufd(imm(0x39), xmm7, xmm6) mulps(xmm0, xmm7) mulps(xmm1, xmm5) - + addps(xmm2, xmm8) movaps(mem(rbx, -6*16), xmm2) addps(xmm3, xmm12) @@ -163,7 +167,7 @@ void bli_sgemm_penryn_asm_8x4 pshufd(imm(0x39), xmm6, xmm4) mulps(xmm0, xmm6) mulps(xmm1, xmm3) - + addps(xmm7, xmm9) addps(xmm5, xmm13) movaps(xmm4, xmm5) @@ -171,22 +175,22 @@ void bli_sgemm_penryn_asm_8x4 movaps(mem(rax, -4*16), xmm0) mulps(xmm1, xmm5) movaps(mem(rax, -3*16), xmm1) - - + + addps(xmm6, xmm10) // iteration 2 addps(xmm3, xmm14) movaps(xmm2, xmm3) pshufd(imm(0x39), xmm2, xmm7) mulps(xmm0, xmm2) mulps(xmm1, xmm3) - + addps(xmm4, xmm11) addps(xmm5, xmm15) movaps(xmm7, xmm5) pshufd(imm(0x39), xmm7, xmm6) mulps(xmm0, xmm7) mulps(xmm1, xmm5) - + addps(xmm2, xmm8) movaps(mem(rbx, -5*16), xmm2) addps(xmm3, xmm12) @@ -194,7 +198,7 @@ void bli_sgemm_penryn_asm_8x4 pshufd(imm(0x39), xmm6, xmm4) mulps(xmm0, xmm6) mulps(xmm1, xmm3) - + addps(xmm7, xmm9) addps(xmm5, xmm13) movaps(xmm4, xmm5) @@ -202,26 +206,26 @@ void bli_sgemm_penryn_asm_8x4 movaps(mem(rax, -2*16), xmm0) mulps(xmm1, xmm5) movaps(mem(rax, -1*16), xmm1) - - + + addps(xmm6, xmm10) // iteration 3 addps(xmm3, xmm14) movaps(xmm2, xmm3) pshufd(imm(0x39), xmm2, xmm7) mulps(xmm0, xmm2) mulps(xmm1, xmm3) - + sub(imm(0-4*8*4), rax) // a += 4*8 (unroll x mr) - + addps(xmm4, xmm11) addps(xmm5, xmm15) movaps(xmm7, xmm5) pshufd(imm(0x39), xmm7, xmm6) mulps(xmm0, xmm7) mulps(xmm1, xmm5) - + sub(imm(0-4*4*4), r9) // b_next += 4*4 (unroll x nr) - + addps(xmm2, xmm8) movaps(mem(rbx, -4*16), xmm2) addps(xmm3, xmm12) @@ -229,9 +233,9 @@ void bli_sgemm_penryn_asm_8x4 pshufd(imm(0x39), xmm6, xmm4) mulps(xmm0, xmm6) mulps(xmm1, xmm3) - + sub(imm(0-4*4*4), rbx) // b += 4*4 (unroll x nr) - + addps(xmm7, xmm9) addps(xmm5, xmm13) movaps(xmm4, xmm5) @@ -239,40 +243,40 @@ void bli_sgemm_penryn_asm_8x4 
movaps(mem(rax, -8*16), xmm0) mulps(xmm1, xmm5) movaps(mem(rax, -7*16), xmm1) - + prefetch(2, mem(r9, 0*4)) // prefetch b_next[0] prefetch(2, mem(r9, 16*4)) // prefetch b_next[16] - - + + dec(rsi) // i -= 1; jne(.SLOOPKITER) // iterate again if i != 0. - - - + + + label(.SCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.SPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.SLOOPKLEFT) // EDGE LOOP - + addps(xmm6, xmm10) // iteration 0 addps(xmm3, xmm14) movaps(xmm2, xmm3) pshufd(imm(0x39), xmm2, xmm7) mulps(xmm0, xmm2) mulps(xmm1, xmm3) - + addps(xmm4, xmm11) addps(xmm5, xmm15) movaps(xmm7, xmm5) pshufd(imm(0x39), xmm7, xmm6) mulps(xmm0, xmm7) mulps(xmm1, xmm5) - + addps(xmm2, xmm8) movaps(mem(rbx, -7*16), xmm2) addps(xmm3, xmm12) @@ -280,7 +284,7 @@ void bli_sgemm_penryn_asm_8x4 pshufd(imm(0x39), xmm6, xmm4) mulps(xmm0, xmm6) mulps(xmm1, xmm3) - + addps(xmm7, xmm9) addps(xmm5, xmm13) movaps(xmm4, xmm5) @@ -288,40 +292,40 @@ void bli_sgemm_penryn_asm_8x4 movaps(mem(rax, -6*16), xmm0) mulps(xmm1, xmm5) movaps(mem(rax, -5*16), xmm1) - + sub(imm(0-1*8*4), rax) // a += 8 (1 x mr) sub(imm(0-1*4*4), rbx) // b += 4 (1 x nr) - - + + dec(rsi) // i -= 1; jne(.SLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.SPOSTACCUM) - + addps(xmm6, xmm10) addps(xmm3, xmm14) addps(xmm4, xmm11) addps(xmm5, xmm15) - - + + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta movss(mem(rax), xmm6) // load alpha to bottom 4 bytes of xmm6 movss(mem(rbx), xmm7) // load beta to bottom 4 bytes of xmm7 pshufd(imm(0x00), xmm6, xmm6) // populate xmm6 with four alphas pshufd(imm(0x00), xmm7, xmm7) // populate xmm7 with four betas - - + + mov(var(rs_c), rsi) // load rs_c mov(rsi, r8) // make a copy of rs_c - + lea(mem(, rsi, 4), rsi) // rsi = rs_c * sizeof(float) lea(mem(rsi, rsi, 2), r11) // r11 = 3*(rs_c * sizeof(float)) - + lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*rs_c; - + // xmm8: xmm9: xmm10: xmm11: // ( ab00 ( ab01 ( ab02 ( ab03 // ab11 ab12 ab13 ab10 @@ -338,20 +342,20 @@ void bli_sgemm_penryn_asm_8x4 shufps(imm(0xd8), xmm11, xmm8) shufps(imm(0xd8), xmm10, xmm11) shufps(imm(0xd8), xmm4, xmm10) - + movaps(xmm8, xmm4) shufps(imm(0xd8), xmm10, xmm8) shufps(imm(0xd8), xmm4, xmm10) movaps(xmm9, xmm5) shufps(imm(0xd8), xmm11, xmm9) shufps(imm(0xd8), xmm5, xmm11) - + movaps(xmm13, xmm4) shufps(imm(0xd8), xmm12, xmm13) shufps(imm(0xd8), xmm15, xmm12) shufps(imm(0xd8), xmm14, xmm15) shufps(imm(0xd8), xmm4, xmm14) - + movaps(xmm12, xmm4) shufps(imm(0xd8), xmm14, xmm12) shufps(imm(0xd8), xmm4, xmm14) @@ -369,471 +373,133 @@ void bli_sgemm_penryn_asm_8x4 // ab50 ab51 ab52 ab53 // ab60 ab61 ab62 ab63 // ab70 ) ab71 ) ab72 ) ab73 ) - - - - // determine if - // c % 16 == 0, AND - // 8*cs_c % 16 == 0, AND - // rs_c == 1 - // ie: aligned, ldim aligned, and - // column-stored - - cmp(imm(1), r8) // set ZF if rs_c == 1. - sete(bl) // bl = ( ZF == 1 ? 1 : 0 ); - test(imm(15), rcx) // set ZF if c & 16 is zero. - setz(bh) // bh = ( ZF == 1 ? 1 : 0 ); - test(imm(15), r12) // set ZF if (4*cs_c) & 16 is zero. - setz(al) // al = ( ZF == 1 ? 1 : 0 ); - // and(bl,bh) followed by - // and(bh,al) will reveal result - + // now avoid loading C if beta == 0 - + xorpd(xmm0, xmm0) // set xmm0 to zero. ucomisd(xmm0, xmm7) // check if beta == 0. je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case - - - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. 
- and(bh, al) // set ZF if bh & al == 1. - jne(.SCOLSTORED) // jump to column storage case - - - - label(.SGENSTORED) - - movlps(mem(rcx), xmm0) // load c00 ~ c30 - movhps(mem(rcx, rsi, 1), xmm0) - movlps(mem(rcx, rsi, 2), xmm1) - movhps(mem(rcx, r11, 1), xmm1) - shufps(imm(0x88), xmm1, xmm0) - - mulps(xmm6, xmm8) // scale by alpha, - mulps(xmm7, xmm0) // scale by beta, - addps(xmm8, xmm0) // add the gemm result, - - movss(xmm0, mem(rcx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rcx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rcx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rcx, r11, 1)) - - add(rdi, rcx) - - - movlps(mem(rdx), xmm0) // load c40 ~ c70 - movhps(mem(rdx, rsi, 1), xmm0) - movlps(mem(rdx, rsi, 2), xmm1) - movhps(mem(rdx, r11, 1), xmm1) - shufps(imm(0x88), xmm1, xmm0) - - mulps(xmm6, xmm12) // scale by alpha, - mulps(xmm7, xmm0) // scale by beta, - addps(xmm12, xmm0) // add the gemm result, - - movss(xmm0, mem(rdx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rdx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rdx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rdx, r11, 1)) - - add(rdi, rdx) - - - movlps(mem(rcx), xmm0) // load c01 ~ c31 - movhps(mem(rcx, rsi, 1), xmm0) - movlps(mem(rcx, rsi, 2), xmm1) - movhps(mem(rcx, r11, 1), xmm1) - shufps(imm(0x88), xmm1, xmm0) - - mulps(xmm6, xmm9) // scale by alpha, - mulps(xmm7, xmm0) // scale by beta, - addps(xmm9, xmm0) // add the gemm result, - - movss(xmm0, mem(rcx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rcx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rcx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rcx, r11, 1)) - - add(rdi, rcx) - - - movlps(mem(rdx), xmm0) // load c41 ~ c71 - movhps(mem(rdx, rsi, 1), xmm0) - movlps(mem(rdx, rsi, 2), xmm1) - movhps(mem(rdx, r11, 1), xmm1) - shufps(imm(0x88), xmm1, xmm0) - - mulps(xmm6, xmm13) // scale by alpha, - mulps(xmm7, xmm0) // scale by beta, - addps(xmm13, xmm0) // add the gemm result, - - movss(xmm0, mem(rdx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rdx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rdx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rdx, r11, 1)) - - add(rdi, rdx) - - - movlps(mem(rcx), xmm0) // load c02 ~ c32 - movhps(mem(rcx, rsi, 1), xmm0) - movlps(mem(rcx, rsi, 2), xmm1) - movhps(mem(rcx, r11, 1), xmm1) - shufps(imm(0x88), xmm1, xmm0) - - mulps(xmm6, xmm10) // scale by alpha, - mulps(xmm7, xmm0) // scale by beta, - addps(xmm10, xmm0) // add the gemm result, - - movss(xmm0, mem(rcx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rcx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rcx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rcx, r11, 1)) - - add(rdi, rcx) - - - movlps(mem(rdx), xmm0) // load c42 ~ c72 - movhps(mem(rdx, rsi, 1), xmm0) - movlps(mem(rdx, rsi, 2), xmm1) - movhps(mem(rdx, r11, 1), xmm1) - shufps(imm(0x88), xmm1, xmm0) - - mulps(xmm6, xmm14) // scale by alpha, - mulps(xmm7, xmm0) // scale by beta, - addps(xmm14, xmm0) // add the gemm result, - - movss(xmm0, mem(rdx)) // and store back to memory. 
- pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rdx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rdx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rdx, r11, 1)) - - add(rdi, rdx) - - - movlps(mem(rcx), xmm0) // load c03 ~ c33 - movhps(mem(rcx, rsi, 1), xmm0) - movlps(mem(rcx, rsi, 2), xmm1) - movhps(mem(rcx, r11, 1), xmm1) - shufps(imm(0x88), xmm1, xmm0) - - mulps(xmm6, xmm11) // scale by alpha, - mulps(xmm7, xmm0) // scale by beta, - addps(xmm11, xmm0) // add the gemm result, - - movss(xmm0, mem(rcx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rcx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rcx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rcx, r11, 1)) - - - - - movlps(mem(rdx), xmm0) // load c43 ~ c73 - movhps(mem(rdx, rsi, 1), xmm0) - movlps(mem(rdx, rsi, 2), xmm1) - movhps(mem(rdx, r11, 1), xmm1) - shufps(imm(0x88), xmm1, xmm0) - - mulps(xmm6, xmm15) // scale by alpha, - mulps(xmm7, xmm0) // scale by beta, - addps(xmm15, xmm0) // add the gemm result, - - movss(xmm0, mem(rdx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rdx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rdx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rdx, r11, 1)) - - - - - jmp(.SDONE) // jump to end. - - - - label(.SCOLSTORED) - - movaps(mem(rcx), xmm0) // load c00 ~ c30, - mulps(xmm6, xmm8) // scale by alpha, - mulps(xmm7, xmm0) // scale by beta, - addps(xmm8, xmm0) // add the gemm result, - movaps(xmm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) - - movaps(mem(rdx), xmm1) // load c40 ~ c70, - mulps(xmm6, xmm12) // scale by alpha, - mulps(xmm7, xmm1) // scale by beta, - addps(xmm12, xmm1) // add the gemm result, - movaps(xmm1, mem(rdx)) // and store back to memory. - add(rdi, rdx) - - - - movaps(mem(rcx), xmm0) // load c01 ~ c31, - mulps(xmm6, xmm9) // scale by alpha, - mulps(xmm7, xmm0) // scale by beta, - addps(xmm9, xmm0) // add the gemm result, - movaps(xmm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) - - movaps(mem(rdx), xmm1) // load c41 ~ c71, - mulps(xmm6, xmm13) // scale by alpha, - mulps(xmm7, xmm1) // scale by beta, - addps(xmm13, xmm1) // add the gemm result, - movaps(xmm1, mem(rdx)) // and store back to memory. - add(rdi, rdx) - - - - movaps(mem(rcx), xmm0) // load c02 ~ c32, - mulps(xmm6, xmm10) // scale by alpha, - mulps(xmm7, xmm0) // scale by beta, - addps(xmm10, xmm0) // add the gemm result, - movaps(xmm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) - - movaps(mem(rdx), xmm1) // load c42 ~ c72, - mulps(xmm6, xmm14) // scale by alpha, - mulps(xmm7, xmm1) // scale by beta, - addps(xmm14, xmm1) // add the gemm result, - movaps(xmm1, mem(rdx)) // and store back to memory. - add(rdi, rdx) - - - - movaps(mem(rcx), xmm0) // load c03 ~ c33, - mulps(xmm6, xmm11) // scale by alpha, - mulps(xmm7, xmm0) // scale by beta, - addps(xmm11, xmm0) // add the gemm result, - movaps(xmm0, mem(rcx)) // and store back to memory. - - - movaps(mem(rdx), xmm1) // load c43 ~ c73, - mulps(xmm6, xmm15) // scale by alpha, - mulps(xmm7, xmm1) // scale by beta, - addps(xmm15, xmm1) // add the gemm result, - movaps(xmm1, mem(rdx)) // and store back to memory. - - jmp(.SDONE) // jump to end. - - - - + + movaps(mem(rcx), xmm0) // load c00 ~ c30, + mulps(xmm6, xmm8) // scale by alpha, + mulps(xmm7, xmm0) // scale by beta, + addps(xmm8, xmm0) // add the gemm result, + movaps(xmm0, mem(rcx)) // and store back to memory. 
+ add(rdi, rcx) + + movaps(mem(rdx), xmm1) // load c40 ~ c70, + mulps(xmm6, xmm12) // scale by alpha, + mulps(xmm7, xmm1) // scale by beta, + addps(xmm12, xmm1) // add the gemm result, + movaps(xmm1, mem(rdx)) // and store back to memory. + add(rdi, rdx) + + + + movaps(mem(rcx), xmm0) // load c01 ~ c31, + mulps(xmm6, xmm9) // scale by alpha, + mulps(xmm7, xmm0) // scale by beta, + addps(xmm9, xmm0) // add the gemm result, + movaps(xmm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) + + movaps(mem(rdx), xmm1) // load c41 ~ c71, + mulps(xmm6, xmm13) // scale by alpha, + mulps(xmm7, xmm1) // scale by beta, + addps(xmm13, xmm1) // add the gemm result, + movaps(xmm1, mem(rdx)) // and store back to memory. + add(rdi, rdx) + + + + movaps(mem(rcx), xmm0) // load c02 ~ c32, + mulps(xmm6, xmm10) // scale by alpha, + mulps(xmm7, xmm0) // scale by beta, + addps(xmm10, xmm0) // add the gemm result, + movaps(xmm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) + + movaps(mem(rdx), xmm1) // load c42 ~ c72, + mulps(xmm6, xmm14) // scale by alpha, + mulps(xmm7, xmm1) // scale by beta, + addps(xmm14, xmm1) // add the gemm result, + movaps(xmm1, mem(rdx)) // and store back to memory. + add(rdi, rdx) + + + + movaps(mem(rcx), xmm0) // load c03 ~ c33, + mulps(xmm6, xmm11) // scale by alpha, + mulps(xmm7, xmm0) // scale by beta, + addps(xmm11, xmm0) // add the gemm result, + movaps(xmm0, mem(rcx)) // and store back to memory. + + + movaps(mem(rdx), xmm1) // load c43 ~ c73, + mulps(xmm6, xmm15) // scale by alpha, + mulps(xmm7, xmm1) // scale by beta, + addps(xmm15, xmm1) // add the gemm result, + movaps(xmm1, mem(rdx)) // and store back to memory. + + jmp(.SDONE) // jump to end. + label(.SBETAZERO) - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. - jne(.SCOLSTORBZ) // jump to column storage case - - - - label(.SGENSTORBZ) - - mulps(xmm6, xmm8) // scale by alpha, - movaps(xmm8, xmm0) - - movss(xmm0, mem(rcx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rcx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rcx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rcx, r11, 1)) - - add(rdi, rcx) - - - mulps(xmm6, xmm12) // scale by alpha, - movaps(xmm12, xmm0) - - movss(xmm0, mem(rdx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rdx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rdx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rdx, r11, 1)) - - add(rdi, rdx) - - - mulps(xmm6, xmm9) // scale by alpha, - movaps(xmm9, xmm0) - - movss(xmm0, mem(rcx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rcx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rcx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rcx, r11, 1)) - - add(rdi, rcx) - - - mulps(xmm6, xmm13) // scale by alpha, - movaps(xmm13, xmm0) - - movss(xmm0, mem(rdx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rdx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rdx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rdx, r11, 1)) - - add(rdi, rdx) - - - mulps(xmm6, xmm10) // scale by alpha, - movaps(xmm10, xmm0) - - movss(xmm0, mem(rcx)) // and store back to memory. 
- pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rcx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rcx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rcx, r11, 1)) - - add(rdi, rcx) - - - mulps(xmm6, xmm14) // scale by alpha, - movaps(xmm14, xmm0) - - movss(xmm0, mem(rdx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rdx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rdx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rdx, r11, 1)) - - add(rdi, rdx) - - - mulps(xmm6, xmm11) // scale by alpha, - movaps(xmm11, xmm0) - - movss(xmm0, mem(rcx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rcx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rcx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rcx, r11, 1)) - - - - - mulps(xmm6, xmm15) // scale by alpha, - movaps(xmm15, xmm0) - - movss(xmm0, mem(rdx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rdx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rdx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rdx, r11, 1)) - - - - - jmp(.SDONE) // jump to end. - - - - label(.SCOLSTORBZ) - - // skip loading c00 ~ c30, - mulps(xmm6, xmm8) // scale by alpha, - movaps(xmm8, mem(rcx)) // and store back to memory. - add(rdi, rcx) - // skip loading c40 ~ c70, - mulps(xmm6, xmm12) // scale by alpha, - movaps(xmm12, mem(rdx)) // and store back to memory. - add(rdi, rdx) - - - // skip loading c01 ~ c31, - mulps(xmm6, xmm9) // scale by alpha, - movaps(xmm9, mem(rcx)) // and store back to memory. - add(rdi, rcx) - // skip loading c41 ~ c71, - mulps(xmm6, xmm13) // scale by alpha, - movaps(xmm13, mem(rdx)) // and store back to memory. - add(rdi, rdx) - - - // skip loading c02 ~ c32, - mulps(xmm6, xmm10) // scale by alpha, - movaps(xmm10, mem(rcx)) // and store back to memory. - add(rdi, rcx) - // skip loading c42 ~ c72, - mulps(xmm6, xmm14) // scale by alpha, - movaps(xmm14, mem(rdx)) // and store back to memory. - add(rdi, rdx) - - - // skip loading c03 ~ c33, - mulps(xmm6, xmm11) // scale by alpha, - movaps(xmm11, mem(rcx)) // and store back to memory. - - // skip loading c43 ~ c73, - mulps(xmm6, xmm15) // scale by alpha, - movaps(xmm15, mem(rdx)) // and store back to memory. - - - - - - - - + + // skip loading c00 ~ c30, + mulps(xmm6, xmm8) // scale by alpha, + movaps(xmm8, mem(rcx)) // and store back to memory. + add(rdi, rcx) + // skip loading c40 ~ c70, + mulps(xmm6, xmm12) // scale by alpha, + movaps(xmm12, mem(rdx)) // and store back to memory. + add(rdi, rdx) + + + // skip loading c01 ~ c31, + mulps(xmm6, xmm9) // scale by alpha, + movaps(xmm9, mem(rcx)) // and store back to memory. + add(rdi, rcx) + // skip loading c41 ~ c71, + mulps(xmm6, xmm13) // scale by alpha, + movaps(xmm13, mem(rdx)) // and store back to memory. + add(rdi, rdx) + + + // skip loading c02 ~ c32, + mulps(xmm6, xmm10) // scale by alpha, + movaps(xmm10, mem(rcx)) // and store back to memory. + add(rdi, rcx) + // skip loading c42 ~ c72, + mulps(xmm6, xmm14) // scale by alpha, + movaps(xmm14, mem(rdx)) // and store back to memory. + add(rdi, rdx) + + + // skip loading c03 ~ c33, + mulps(xmm6, xmm11) // scale by alpha, + movaps(xmm11, mem(rcx)) // and store back to memory. + + // skip loading c43 ~ c73, + mulps(xmm6, xmm15) // scale by alpha, + movaps(xmm15, mem(rdx)) // and store back to memory. 
+ label(.SDONE) - - end_asm( + + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c), // 8 - [b_next] "m" (b_next)/*, // 9 - [a_next] "m" (a_next)*/ // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c), // 8 + [b_next] "m" (b_next)/*, // 9 + [a_next] "m" (a_next)*/ // 10 : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "xmm0", "xmm1", "xmm2", "xmm3", @@ -842,11 +508,15 @@ void bli_sgemm_penryn_asm_8x4 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( s ); } void bli_dgemm_penryn_asm_4x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, double* restrict alpha, double* restrict a, double* restrict b, @@ -861,39 +531,41 @@ void bli_dgemm_penryn_asm_4x4 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT_ALIGNED( d, 4, 4, false, 16 ); + begin_asm() - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. mov(var(b_next), r9) // load address of b_next. mov(var(a_next), r11) // load address of a_next. - + sub(imm(0-8*16), rax) // increment pointers to allow byte sub(imm(0-8*16), rbx) // offsets in the unrolled iterations. - + movaps(mem(rax, -8*16), xmm0) // initialize loop by pre-loading elements movaps(mem(rax, -7*16), xmm1) // of a and b. movaps(mem(rbx, -8*16), xmm2) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(double) mov(rdi, r12) // make a copy of cs_c (in bytes) lea(mem(rcx, rdi, 2), r10) // load address of c + 2*cs_c; - + prefetch(2, mem(r9, 0*8)) // prefetch b_next - + xorpd(xmm3, xmm3) xorpd(xmm4, xmm4) xorpd(xmm5, xmm5) xorpd(xmm6, xmm6) - + prefetch(2, mem(rcx, 3*8)) // prefetch c + 0*cs_c xorpd(xmm8, xmm8) xorpd(xmm9, xmm9) @@ -906,22 +578,22 @@ void bli_dgemm_penryn_asm_4x4 prefetch(2, mem(r10, rdi, 1, 3*8)) // prefetch c + 3*cs_c xorpd(xmm14, xmm14) xorpd(xmm15, xmm15) - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. 
- - + + label(.DLOOPKITER) // MAIN LOOP - + prefetch(0, mem(rax, (4*35+1)*8)) //prefetch(0, mem(rax, (8*97+4)*8)) - + //prefetch(0, mem(r11, 67*4*8)) // prefetch a_next[0] - + addpd(xmm3, xmm11) // iteration 0 movaps(mem(rbx, -7*16), xmm3) addpd(xmm4, xmm15) @@ -929,13 +601,13 @@ void bli_dgemm_penryn_asm_4x4 pshufd(imm(0x4e), xmm2, xmm7) mulpd(xmm0, xmm2) mulpd(xmm1, xmm4) - + addpd(xmm5, xmm10) addpd(xmm6, xmm14) movaps(xmm7, xmm6) mulpd(xmm0, xmm7) mulpd(xmm1, xmm6) - + addpd(xmm2, xmm9) movaps(mem(rbx, -6*16), xmm2) addpd(xmm4, xmm13) @@ -943,7 +615,7 @@ void bli_dgemm_penryn_asm_4x4 pshufd(imm(0x4e), xmm3, xmm5) mulpd(xmm0, xmm3) mulpd(xmm1, xmm4) - + addpd(xmm7, xmm8) addpd(xmm6, xmm12) movaps(xmm5, xmm6) @@ -951,9 +623,9 @@ void bli_dgemm_penryn_asm_4x4 movaps(mem(rax, -6*16), xmm0) mulpd(xmm1, xmm6) movaps(mem(rax, -5*16), xmm1) - - - + + + addpd(xmm3, xmm11) // iteration 1 movaps(mem(rbx, -5*16), xmm3) addpd(xmm4, xmm15) @@ -961,13 +633,13 @@ void bli_dgemm_penryn_asm_4x4 pshufd(imm(0x4e), xmm2, xmm7) mulpd(xmm0, xmm2) mulpd(xmm1, xmm4) - + addpd(xmm5, xmm10) addpd(xmm6, xmm14) movaps(xmm7, xmm6) mulpd(xmm0, xmm7) mulpd(xmm1, xmm6) - + addpd(xmm2, xmm9) movaps(mem(rbx, -4*16), xmm2) addpd(xmm4, xmm13) @@ -975,7 +647,7 @@ void bli_dgemm_penryn_asm_4x4 pshufd(imm(0x4e), xmm3, xmm5) mulpd(xmm0, xmm3) mulpd(xmm1, xmm4) - + addpd(xmm7, xmm8) addpd(xmm6, xmm12) movaps(xmm5, xmm6) @@ -983,16 +655,16 @@ void bli_dgemm_penryn_asm_4x4 movaps(mem(rax, -4*16), xmm0) mulpd(xmm1, xmm6) movaps(mem(rax, -3*16), xmm1) - - + + prefetch(0, mem(rax, (4*37+1)*8)) //prefetch(0, mem(rax, (8*97+12)*8)) - + //prefetch(0, mem(r11, 69*4*8)) // prefetch a_next[8] //sub(imm(-4*4*8), r11) // a_next += 4*4 (unroll x mr) - - - + + + addpd(xmm3, xmm11) // iteration 2 movaps(mem(rbx, -3*16), xmm3) addpd(xmm4, xmm15) @@ -1000,13 +672,13 @@ void bli_dgemm_penryn_asm_4x4 pshufd(imm(0x4e), xmm2, xmm7) mulpd(xmm0, xmm2) mulpd(xmm1, xmm4) - + addpd(xmm5, xmm10) addpd(xmm6, xmm14) movaps(xmm7, xmm6) mulpd(xmm0, xmm7) mulpd(xmm1, xmm6) - + addpd(xmm2, xmm9) movaps(mem(rbx, -2*16), xmm2) addpd(xmm4, xmm13) @@ -1014,8 +686,8 @@ void bli_dgemm_penryn_asm_4x4 pshufd(imm(0x4e), xmm3, xmm5) mulpd(xmm0, xmm3) mulpd(xmm1, xmm4) - - + + addpd(xmm7, xmm8) addpd(xmm6, xmm12) movaps(xmm5, xmm6) @@ -1023,9 +695,9 @@ void bli_dgemm_penryn_asm_4x4 movaps(mem(rax, -2*16), xmm0) mulpd(xmm1, xmm6) movaps(mem(rax, -1*16), xmm1) - - - + + + addpd(xmm3, xmm11) // iteration 3 movaps(mem(rbx, -1*16), xmm3) addpd(xmm4, xmm15) @@ -1033,17 +705,17 @@ void bli_dgemm_penryn_asm_4x4 pshufd(imm(0x4e), xmm2, xmm7) mulpd(xmm0, xmm2) mulpd(xmm1, xmm4) - + sub(imm(0-4*4*8), rax) // a += 4*4 (unroll x mr) - + addpd(xmm5, xmm10) addpd(xmm6, xmm14) movaps(xmm7, xmm6) mulpd(xmm0, xmm7) mulpd(xmm1, xmm6) - + sub(imm(0-4*4*8), r9) // b_next += 4*4 (unroll x nr) - + addpd(xmm2, xmm9) movaps(mem(rbx, 0*16), xmm2) addpd(xmm4, xmm13) @@ -1051,9 +723,9 @@ void bli_dgemm_penryn_asm_4x4 pshufd(imm(0x4e), xmm3, xmm5) mulpd(xmm0, xmm3) mulpd(xmm1, xmm4) - + sub(imm(0-4*4*8), rbx) // b += 4*4 (unroll x nr) - + addpd(xmm7, xmm8) addpd(xmm6, xmm12) movaps(xmm5, xmm6) @@ -1061,29 +733,29 @@ void bli_dgemm_penryn_asm_4x4 movaps(mem(rax, -8*16), xmm0) mulpd(xmm1, xmm6) movaps(mem(rax, -7*16), xmm1) - + prefetch(2, mem(r9, 0*8)) // prefetch b_next[0] prefetch(2, mem(r9, 8*8)) // prefetch b_next[8] - + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. 
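/*
 * Aside: the k_iter / k_left structure of the loop above, restated as a
 * plain C sketch (the helper name and the 4x4 tile size are illustrative
 * assumptions, not the kernel itself). k_iter = k/4 trips through a
 * 4x-unrolled main loop amortize loop overhead; the k_left = k%4
 * leftover rank-1 updates run in a separate edge loop.
 */
#include <stddef.h>

static void rank1_loop_sketch
     (
       size_t k,
       const double* a,          /* packed 4 x k panel            */
       const double* b,          /* packed k x 4 panel            */
       double ab[ 4*4 ]          /* accumulator, column-major     */
     )
{
    size_t k_iter = k / 4;
    size_t k_left = k % 4;

    /* MAIN LOOP: four rank-1 updates per trip.                        */
    for ( size_t t = 0; t < k_iter; ++t )
        for ( size_t u = 0; u < 4; ++u, a += 4, b += 4 )
            for ( size_t j = 0; j < 4; ++j )
                for ( size_t i = 0; i < 4; ++i )
                    ab[ i + j*4 ] += a[ i ] * b[ j ];

    /* EDGE LOOP: one rank-1 update per trip for the remainder.        */
    for ( size_t t = 0; t < k_left; ++t, a += 4, b += 4 )
        for ( size_t j = 0; j < 4; ++j )
            for ( size_t i = 0; i < 4; ++i )
                ab[ i + j*4 ] += a[ i ] * b[ j ];
}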
- - - + + + //prefetch(2, mem(r9, -8*8)) // prefetch b_next[-8] - - - + + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP - + addpd(xmm3, xmm11) // iteration 0 movaps(mem(rbx, -7*16), xmm3) addpd(xmm4, xmm15) @@ -1091,13 +763,13 @@ void bli_dgemm_penryn_asm_4x4 pshufd(imm(0x4e), xmm2, xmm7) mulpd(xmm0, xmm2) mulpd(xmm1, xmm4) - + addpd(xmm5, xmm10) addpd(xmm6, xmm14) movaps(xmm7, xmm6) mulpd(xmm0, xmm7) mulpd(xmm1, xmm6) - + addpd(xmm2, xmm9) movaps(mem(rbx, -6*16), xmm2) addpd(xmm4, xmm13) @@ -1105,7 +777,7 @@ void bli_dgemm_penryn_asm_4x4 pshufd(imm(0x4e), xmm3, xmm5) mulpd(xmm0, xmm3) mulpd(xmm1, xmm4) - + addpd(xmm7, xmm8) addpd(xmm6, xmm12) movaps(xmm5, xmm6) @@ -1113,38 +785,38 @@ void bli_dgemm_penryn_asm_4x4 movaps(mem(rax, -6*16), xmm0) mulpd(xmm1, xmm6) movaps(mem(rax, -5*16), xmm1) - - + + sub(imm(0-4*1*8), rax) // a += 4 (1 x mr) sub(imm(0-4*1*8), rbx) // b += 4 (1 x nr) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - + addpd(xmm3, xmm11) addpd(xmm4, xmm15) addpd(xmm5, xmm10) addpd(xmm6, xmm14) - - + + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta movddup(mem(rax), xmm6) // load alpha and duplicate movddup(mem(rbx), xmm7) // load beta and duplicate - - + + mov(var(rs_c), rsi) // load rs_c mov(rsi, r8) // make a copy of rs_c - + lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(double) - + lea(mem(rcx, rsi, 2), rdx) // load address of c + 2*rs_c; - + // xmm8: xmm9: xmm10: xmm11: // ( ab01 ( ab00 ( ab03 ( ab02 // ab10 ) ab11 ) ab12 ) ab13 ) @@ -1155,15 +827,15 @@ void bli_dgemm_penryn_asm_4x4 movaps(xmm8, xmm0) movsd(xmm9, xmm8) movsd(xmm0, xmm9) - + movaps(xmm10, xmm0) movsd(xmm11, xmm10) movsd(xmm0, xmm11) - + movaps(xmm12, xmm0) movsd(xmm13, xmm12) movsd(xmm0, xmm13) - + movaps(xmm14, xmm0) movsd(xmm15, xmm14) movsd(xmm0, xmm15) @@ -1174,313 +846,133 @@ void bli_dgemm_penryn_asm_4x4 // xmm12: xmm13: xmm14: xmm15: // ( ab20 ( ab21 ( ab22 ( ab23 // ab30 ) ab31 ) ab32 ) ab33 ) - - - - // determine if - // c % 16 == 0, AND - // 8*cs_c % 16 == 0, AND - // rs_c == 1 - // ie: aligned, ldim aligned, and - // column-stored - - cmp(imm(1), r8) // set ZF if rs_c == 1. - sete(bl) // bl = ( ZF == 1 ? 1 : 0 ); - test(imm(15), rcx) // set ZF if c & 16 is zero. - setz(bh) // bh = ( ZF == 1 ? 1 : 0 ); - test(imm(15), r12) // set ZF if (8*cs_c) & 16 is zero. - setz(al) // al = ( ZF == 1 ? 1 : 0 ); - // and(bl,bh) followed by - // and(bh,al) will reveal result - + // now avoid loading C if beta == 0 - + xorpd(xmm0, xmm0) // set xmm0 to zero. ucomisd(xmm0, xmm7) // check if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - - - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. - jne(.DCOLSTORED) // jump to column storage case - - - - label(.DGENSTORED) - - movlpd(mem(rcx), xmm0) // load c00 and c10, - movhpd(mem(rcx, rsi, 1), xmm0) - mulpd(xmm6, xmm8) // scale by alpha, - mulpd(xmm7, xmm0) // scale by beta, - addpd(xmm8, xmm0) // add the gemm result, - movlpd(xmm0, mem(rcx)) // and store back to memory. 
- movhpd(xmm0, mem(rcx, rsi, 1)) - add(rdi, rcx) - - movlpd(mem(rdx), xmm1) // load c20 and c30, - movhpd(mem(rdx, rsi, 1), xmm1) - mulpd(xmm6, xmm12) // scale by alpha, - mulpd(xmm7, xmm1) // scale by beta, - addpd(xmm12, xmm1) // add the gemm result, - movlpd(xmm1, mem(rdx)) // and store back to memory. - movhpd(xmm1, mem(rdx, rsi, 1)) - add(rdi, rdx) - - - - movlpd(mem(rcx), xmm0) // load c01 and c11, - movhpd(mem(rcx, rsi, 1), xmm0) - mulpd(xmm6, xmm9) // scale by alpha, - mulpd(xmm7, xmm0) // scale by beta, - addpd(xmm9, xmm0) // add the gemm result, - movlpd(xmm0, mem(rcx)) // and store back to memory. - movhpd(xmm0, mem(rcx, rsi, 1)) - add(rdi, rcx) - - movlpd(mem(rdx), xmm1) // load c21 and c31, - movhpd(mem(rdx, rsi, 1), xmm1) - mulpd(xmm6, xmm13) // scale by alpha, - mulpd(xmm7, xmm1) // scale by beta, - addpd(xmm13, xmm1) // add the gemm result, - movlpd(xmm1, mem(rdx)) // and store back to memory. - movhpd(xmm1, mem(rdx, rsi, 1)) - add(rdi, rdx) - - - - movlpd(mem(rcx), xmm0) // load c02 and c12, - movhpd(mem(rcx, rsi, 1), xmm0) - mulpd(xmm6, xmm10) // scale by alpha, - mulpd(xmm7, xmm0) // scale by beta, - addpd(xmm10, xmm0) // add the gemm result, - movlpd(xmm0, mem(rcx)) // and store back to memory. - movhpd(xmm0, mem(rcx, rsi, 1)) - add(rdi, rcx) - - movlpd(mem(rdx), xmm1) // load c22 and c32, - movhpd(mem(rdx, rsi, 1), xmm1) - mulpd(xmm6, xmm14) // scale by alpha, - mulpd(xmm7, xmm1) // scale by beta, - addpd(xmm14, xmm1) // add the gemm result, - movlpd(xmm1, mem(rdx)) // and store back to memory. - movhpd(xmm1, mem(rdx, rsi, 1)) - add(rdi, rdx) - - - - movlpd(mem(rcx), xmm0) // load c03 and c13, - movhpd(mem(rcx, rsi, 1), xmm0) - mulpd(xmm6, xmm11) // scale by alpha, - mulpd(xmm7, xmm0) // scale by beta, - addpd(xmm11, xmm0) // add the gemm result, - movlpd(xmm0, mem(rcx)) // and store back to memory. - movhpd(xmm0, mem(rcx, rsi, 1)) - - - movlpd(mem(rdx), xmm1) // load c23 and c33, - movhpd(mem(rdx, rsi, 1), xmm1) - mulpd(xmm6, xmm15) // scale by alpha, - mulpd(xmm7, xmm1) // scale by beta, - addpd(xmm15, xmm1) // add the gemm result, - movlpd(xmm1, mem(rdx)) // and store back to memory. - movhpd(xmm1, mem(rdx, rsi, 1)) - - jmp(.DDONE) // jump to end. - - - - label(.DCOLSTORED) - - movaps(mem(rcx), xmm0) // load c00 and c10, - mulpd(xmm6, xmm8) // scale by alpha, - mulpd(xmm7, xmm0) // scale by beta, - addpd(xmm8, xmm0) // add the gemm result, - movaps(xmm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) - - movaps(mem(rdx), xmm1) // load c20 and c30, - mulpd(xmm6, xmm12) // scale by alpha, - mulpd(xmm7, xmm1) // scale by beta, - addpd(xmm12, xmm1) // add the gemm result, - movaps(xmm1, mem(rdx)) // and store back to memory. - add(rdi, rdx) - - - - movaps(mem(rcx), xmm0) // load c01 and c11, - mulpd(xmm6, xmm9) // scale by alpha, - mulpd(xmm7, xmm0) // scale by beta, - addpd(xmm9, xmm0) // add the gemm result, - movaps(xmm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) - - movaps(mem(rdx), xmm1) // load c21 and c31, - mulpd(xmm6, xmm13) // scale by alpha, - mulpd(xmm7, xmm1) // scale by beta, - addpd(xmm13, xmm1) // add the gemm result, - movaps(xmm1, mem(rdx)) // and store back to memory. - add(rdi, rdx) - - - - movaps(mem(rcx), xmm0) // load c02 and c12, - mulpd(xmm6, xmm10) // scale by alpha, - mulpd(xmm7, xmm0) // scale by beta, - addpd(xmm10, xmm0) // add the gemm result, - movaps(xmm0, mem(rcx)) // and store back to memory. 
- add(rdi, rcx) - - movaps(mem(rdx), xmm1) // load c22 and c32, - mulpd(xmm6, xmm14) // scale by alpha, - mulpd(xmm7, xmm1) // scale by beta, - addpd(xmm14, xmm1) // add the gemm result, - movaps(xmm1, mem(rdx)) // and store back to memory. - add(rdi, rdx) - - - - movaps(mem(rcx), xmm0) // load c03 and c13, - mulpd(xmm6, xmm11) // scale by alpha, - mulpd(xmm7, xmm0) // scale by beta, - addpd(xmm11, xmm0) // add the gemm result, - movaps(xmm0, mem(rcx)) // and store back to memory. - - - movaps(mem(rdx), xmm1) // load c23 and c33, - mulpd(xmm6, xmm15) // scale by alpha, - mulpd(xmm7, xmm1) // scale by beta, - addpd(xmm15, xmm1) // add the gemm result, - movaps(xmm1, mem(rdx)) // and store back to memory. - - jmp(.DDONE) // jump to end. - - - - + + movaps(mem(rcx), xmm0) // load c00 and c10, + mulpd(xmm6, xmm8) // scale by alpha, + mulpd(xmm7, xmm0) // scale by beta, + addpd(xmm8, xmm0) // add the gemm result, + movaps(xmm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) + + movaps(mem(rdx), xmm1) // load c20 and c30, + mulpd(xmm6, xmm12) // scale by alpha, + mulpd(xmm7, xmm1) // scale by beta, + addpd(xmm12, xmm1) // add the gemm result, + movaps(xmm1, mem(rdx)) // and store back to memory. + add(rdi, rdx) + + + + movaps(mem(rcx), xmm0) // load c01 and c11, + mulpd(xmm6, xmm9) // scale by alpha, + mulpd(xmm7, xmm0) // scale by beta, + addpd(xmm9, xmm0) // add the gemm result, + movaps(xmm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) + + movaps(mem(rdx), xmm1) // load c21 and c31, + mulpd(xmm6, xmm13) // scale by alpha, + mulpd(xmm7, xmm1) // scale by beta, + addpd(xmm13, xmm1) // add the gemm result, + movaps(xmm1, mem(rdx)) // and store back to memory. + add(rdi, rdx) + + + + movaps(mem(rcx), xmm0) // load c02 and c12, + mulpd(xmm6, xmm10) // scale by alpha, + mulpd(xmm7, xmm0) // scale by beta, + addpd(xmm10, xmm0) // add the gemm result, + movaps(xmm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) + + movaps(mem(rdx), xmm1) // load c22 and c32, + mulpd(xmm6, xmm14) // scale by alpha, + mulpd(xmm7, xmm1) // scale by beta, + addpd(xmm14, xmm1) // add the gemm result, + movaps(xmm1, mem(rdx)) // and store back to memory. + add(rdi, rdx) + + + + movaps(mem(rcx), xmm0) // load c03 and c13, + mulpd(xmm6, xmm11) // scale by alpha, + mulpd(xmm7, xmm0) // scale by beta, + addpd(xmm11, xmm0) // add the gemm result, + movaps(xmm0, mem(rcx)) // and store back to memory. + + + movaps(mem(rdx), xmm1) // load c23 and c33, + mulpd(xmm6, xmm15) // scale by alpha, + mulpd(xmm7, xmm1) // scale by beta, + addpd(xmm15, xmm1) // add the gemm result, + movaps(xmm1, mem(rdx)) // and store back to memory. + + jmp(.DDONE) // jump to end. + label(.DBETAZERO) - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. - jne(.DCOLSTORBZ) // jump to column storage case - - - - label(.DGENSTORBZ) - // skip loading c00 and c10, - mulpd(xmm6, xmm8) // scale by alpha, - movlpd(xmm8, mem(rcx)) // and store back to memory. - movhpd(xmm8, mem(rcx, rsi, 1)) - add(rdi, rcx) - // skip loading c20 and c30, - mulpd(xmm6, xmm12) // scale by alpha, - movlpd(xmm12, mem(rdx)) // and store back to memory. - movhpd(xmm12, mem(rdx, rsi, 1)) - add(rdi, rdx) - - - // skip loading c01 and c11, - mulpd(xmm6, xmm9) // scale by alpha, - movlpd(xmm9, mem(rcx)) // and store back to memory. 
- movhpd(xmm9, mem(rcx, rsi, 1)) - add(rdi, rcx) - // skip loading c21 and c31, - mulpd(xmm6, xmm13) // scale by alpha, - movlpd(xmm13, mem(rdx)) // and store back to memory. - movhpd(xmm13, mem(rdx, rsi, 1)) - add(rdi, rdx) - - - // skip loading c02 and c12, - mulpd(xmm6, xmm10) // scale by alpha, - movlpd(xmm10, mem(rcx)) // and store back to memory. - movhpd(xmm10, mem(rcx, rsi, 1)) - add(rdi, rcx) - // skip loading c22 and c32, - mulpd(xmm6, xmm14) // scale by alpha, - movlpd(xmm14, mem(rdx)) // and store back to memory. - movhpd(xmm14, mem(rdx, rsi, 1)) - add(rdi, rdx) - - - // skip loading c03 and c13, - mulpd(xmm6, xmm11) // scale by alpha, - movlpd(xmm11, mem(rcx)) // and store back to memory. - movhpd(xmm11, mem(rcx, rsi, 1)) - - // skip loading c23 and c33, - mulpd(xmm6, xmm15) // scale by alpha, - movlpd(xmm15, mem(rdx)) // and store back to memory. - movhpd(xmm15, mem(rdx, rsi, 1)) - - jmp(.DDONE) // jump to end. - - - - label(.DCOLSTORBZ) - - // skip loading c00 and c10, - mulpd(xmm6, xmm8) // scale by alpha, - movaps(xmm8, mem(rcx)) // and store back to memory. - add(rdi, rcx) - // skip loading c20 and c30, - mulpd(xmm6, xmm12) // scale by alpha, - movaps(xmm12, mem(rdx)) // and store back to memory. - add(rdi, rdx) - - - // skip loading c01 and c11, - mulpd(xmm6, xmm9) // scale by alpha, - movaps(xmm9, mem(rcx)) // and store back to memory. - add(rdi, rcx) - // skip loading c21 and c31, - mulpd(xmm6, xmm13) // scale by alpha, - movaps(xmm13, mem(rdx)) // and store back to memory. - add(rdi, rdx) - - - // skip loading c02 and c12, - mulpd(xmm6, xmm10) // scale by alpha, - movaps(xmm10, mem(rcx)) // and store back to memory. - add(rdi, rcx) - // skip loading c22 and c32, - mulpd(xmm6, xmm14) // scale by alpha, - movaps(xmm14, mem(rdx)) // and store back to memory. - add(rdi, rdx) - - - // skip loading c03 and c13, - mulpd(xmm6, xmm11) // scale by alpha, - movaps(xmm11, mem(rcx)) // and store back to memory. - - // skip loading c23 and c33, - mulpd(xmm6, xmm15) // scale by alpha, - movaps(xmm15, mem(rdx)) // and store back to memory. - - - - - - - - + + // skip loading c00 and c10, + mulpd(xmm6, xmm8) // scale by alpha, + movaps(xmm8, mem(rcx)) // and store back to memory. + add(rdi, rcx) + // skip loading c20 and c30, + mulpd(xmm6, xmm12) // scale by alpha, + movaps(xmm12, mem(rdx)) // and store back to memory. + add(rdi, rdx) + + + // skip loading c01 and c11, + mulpd(xmm6, xmm9) // scale by alpha, + movaps(xmm9, mem(rcx)) // and store back to memory. + add(rdi, rcx) + // skip loading c21 and c31, + mulpd(xmm6, xmm13) // scale by alpha, + movaps(xmm13, mem(rdx)) // and store back to memory. + add(rdi, rdx) + + + // skip loading c02 and c12, + mulpd(xmm6, xmm10) // scale by alpha, + movaps(xmm10, mem(rcx)) // and store back to memory. + add(rdi, rcx) + // skip loading c22 and c32, + mulpd(xmm6, xmm14) // scale by alpha, + movaps(xmm14, mem(rdx)) // and store back to memory. + add(rdi, rdx) + + + // skip loading c03 and c13, + mulpd(xmm6, xmm11) // scale by alpha, + movaps(xmm11, mem(rcx)) // and store back to memory. + + // skip loading c23 and c33, + mulpd(xmm6, xmm15) // scale by alpha, + movaps(xmm15, mem(rdx)) // and store back to memory. 
+ label(.DDONE) - - end_asm( + + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c), // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next) // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c), // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next) // 10 : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "xmm0", "xmm1", "xmm2", "xmm3", @@ -1489,6 +981,8 @@ void bli_dgemm_penryn_asm_4x4 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( d ); } diff --git a/kernels/piledriver/3/bli_gemm_piledriver_asm_d8x3.c b/kernels/piledriver/3/bli_gemm_piledriver_asm_d8x3.c index 5963dabee6..e65ce7178a 100644 --- a/kernels/piledriver/3/bli_gemm_piledriver_asm_d8x3.c +++ b/kernels/piledriver/3/bli_gemm_piledriver_asm_d8x3.c @@ -42,7 +42,9 @@ void bli_sgemm_piledriver_asm_16x3 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, float* restrict alpha, float* restrict a, float* restrict b, @@ -57,36 +59,38 @@ void bli_sgemm_piledriver_asm_16x3 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 8; - uint64_t k_left = k0 % 8; + uint64_t k_iter = k / 8; + uint64_t k_left = k % 8; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( s, 16, 3, false ); + begin_asm() - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. mov(var(b_next), r15) // load address of b_next. mov(var(a_next), r14) // load address of a_next. - + prefetch(0, mem(rbx, 128)) // prefetch b prefetch(0, mem(rbx, 64+128)) // prefetch b prefetch(0, mem(rbx, 128+128)) // prefetch b - + add(imm(32*4), rax) add(imm(12*4), rbx) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 4), rdi) // cs_c *= sizeof(float) lea(mem(rcx, rdi, 1), r10) // load address of c + 1*cs_c; lea(mem(rcx, rdi, 2), r11) // load address of c + 2*cs_c; - + vbroadcastss(mem(rbx, -12*4), xmm1) vbroadcastss(mem(rbx, -11*4), xmm2) vbroadcastss(mem(rbx, -10*4), xmm3) - + vxorps(xmm4, xmm4, xmm4) vxorps(xmm5, xmm5, xmm5) vxorps(xmm6, xmm6, xmm6) @@ -99,23 +103,23 @@ void bli_sgemm_piledriver_asm_16x3 vxorps(xmm13, xmm13, xmm13) vxorps(xmm14, xmm14, xmm14) vxorps(xmm15, xmm15, xmm15) - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.SCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.SLOOPKITER) // MAIN LOOP - - + + je(.SCONSIDKLEFT) // if i == 0, jump to k_left code. 
- - + + prefetch(0, mem(rbx, 16+192)) // prefetch b - + // iteration 0 vmovaps(mem(rax, -32*4), xmm0) prefetch(0, mem(rax, 384)) @@ -136,7 +140,7 @@ void bli_sgemm_piledriver_asm_16x3 vfmadd231ps(xmm2, xmm0, xmm14) vbroadcastss(mem(rbx, -8*4), xmm2) vfmadd231ps(xmm3, xmm0, xmm15) - + // iteration 1 vmovaps(mem(rax, -16*4), xmm0) vbroadcastss(mem(rbx, -7*4), xmm3) @@ -158,7 +162,7 @@ void bli_sgemm_piledriver_asm_16x3 vfmadd231ps(xmm2, xmm0, xmm14) vbroadcastss(mem(rbx, -5*4), xmm2) vfmadd231ps(xmm3, xmm0, xmm15) - + // iteration 2 vmovaps(mem(rax, 0*4), xmm0) vbroadcastss(mem(rbx, -4*4), xmm3) @@ -180,7 +184,7 @@ void bli_sgemm_piledriver_asm_16x3 vfmadd231ps(xmm2, xmm0, xmm14) vbroadcastss(mem(rbx, -2*4), xmm2) vfmadd231ps(xmm3, xmm0, xmm15) - + // iteration 3 vmovaps(mem(rax, 16*4), xmm0) vbroadcastss(mem(rbx, -1*4), xmm3) @@ -202,10 +206,10 @@ void bli_sgemm_piledriver_asm_16x3 vfmadd231ps(xmm2, xmm0, xmm14) vbroadcastss(mem(rbx, 1*4), xmm2) vfmadd231ps(xmm3, xmm0, xmm15) - - + + add(imm(4*16*4), rax) // a += 4*16 (unroll x mr) - + // iteration 4 vmovaps(mem(rax, -32*4), xmm0) vbroadcastss(mem(rbx, 2*4), xmm3) @@ -227,9 +231,9 @@ void bli_sgemm_piledriver_asm_16x3 vfmadd231ps(xmm2, xmm0, xmm14) vbroadcastss(mem(rbx, 4*4), xmm2) vfmadd231ps(xmm3, xmm0, xmm15) - + prefetch(0, mem(rbx, 80+192)) // prefetch b - + // iteration 5 vmovaps(mem(rax, -16*4), xmm0) vbroadcastss(mem(rbx, 5*4), xmm3) @@ -251,7 +255,7 @@ void bli_sgemm_piledriver_asm_16x3 vfmadd231ps(xmm2, xmm0, xmm14) vbroadcastss(mem(rbx, 7*4), xmm2) vfmadd231ps(xmm3, xmm0, xmm15) - + // iteration 6 vmovaps(mem(rax, 0*4), xmm0) vbroadcastss(mem(rbx, 8*4), xmm3) @@ -273,7 +277,7 @@ void bli_sgemm_piledriver_asm_16x3 vfmadd231ps(xmm2, xmm0, xmm14) vbroadcastss(mem(rbx, 10*4), xmm2) vfmadd231ps(xmm3, xmm0, xmm15) - + // iteration 7 vmovaps(mem(rax, 16*4), xmm0) vbroadcastss(mem(rbx, 11*4), xmm3) @@ -298,34 +302,34 @@ void bli_sgemm_piledriver_asm_16x3 vbroadcastss(mem(rbx, -11*4), xmm2) vfmadd231ps(xmm3, xmm0, xmm15) vbroadcastss(mem(rbx, -10*4), xmm3) - - - - + + + + dec(rsi) // i -= 1; jmp(.SLOOPKITER) // jump to beginning of loop. - - - - - - + + + + + + label(.SCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.SPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.SLOOPKLEFT) // EDGE LOOP - - + + je(.SPOSTACCUM) // if i == 0, we're done. - - + + prefetch(0, mem(rbx, 16+192)) // prefetch b - + // iteration 0 vmovaps(mem(rax, -32*4), xmm0) prefetch(0, mem(rax, 384)) @@ -347,56 +351,56 @@ void bli_sgemm_piledriver_asm_16x3 vbroadcastss(mem(rbx, -8*4), xmm2) vfmadd231ps(xmm3, xmm0, xmm15) vbroadcastss(mem(rbx, -7*4), xmm3) - - + + add(imm(1*16*4), rax) // a += 4*16 (unroll x mr) add(imm(1*3*4), rbx) // a += 4*3 (unroll x nr) - - + + dec(rsi) // i -= 1; jmp(.SLOOPKLEFT) // jump to beginning of loop. 
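/*
 * Aside: a scalar C model of the vbroadcastss + vfmadd231ps pattern in
 * the loop above (names and layout here are illustrative assumptions).
 * One element of the packed B panel is broadcast and fused-multiply-added
 * against a vector of the packed A panel, so each xmm accumulator holds
 * a 4-row slice of one column of the 16x3 microtile.
 */
static void broadcast_fma_sketch
     (
       const float* a_col,    /* 16 floats: one k-slice of packed A */
       const float* b_row,    /*  3 floats: one k-slice of packed B */
       float ab[ 16*3 ]       /* accumulator, column-major          */
     )
{
    for ( int j = 0; j < 3; ++j )
    {
        float bj = b_row[ j ];                   /* broadcast element */
        for ( int i = 0; i < 16; ++i )
            ab[ i + j*16 ] += a_col[ i ] * bj;   /* fma into column j */
    }
}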
- - - + + + label(.SPOSTACCUM) - - + + prefetchw0(mem(rcx, 0*8)) // prefetch c + 0*cs_c prefetchw0(mem(r10, 0*8)) // prefetch c + 1*cs_c prefetchw0(mem(r11, 0*8)) // prefetch c + 2*cs_c - - - // xmm4: xmm5: xmm6: + + + // xmm4: xmm5: xmm6: // ( ab00 ( ab01 ( ab02 - // ab10 ab11 ab12 + // ab10 ab11 ab12 // ab20 ab21 ab22 // ab30 ) ab31 ) ab32 ) - - // xmm7: xmm8: xmm9: + + // xmm7: xmm8: xmm9: // ( ab40 ( ab41 ( ab42 - // ab50 ab51 ab52 + // ab50 ab51 ab52 // ab60 ab61 ab62 // ab70 ) ab71 ) ab72 ) - + // xmm10: xmm11: xmm12: // ( ab80 ( ab01 ( ab02 - // ab90 ab11 ab12 + // ab90 ab11 ab12 // abA0 abA1 abA2 // abB0 ) abB1 ) abB2 ) - + // xmm13: xmm14: xmm15: // ( abC0 ( abC1 ( abC2 - // abD0 abD1 abD2 + // abD0 abD1 abD2 // abE0 abE1 abE2 // abF0 ) abF1 ) abF2 ) - - - + + + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastss(mem(rax), xmm0) // load alpha and duplicate vbroadcastss(mem(rbx), xmm2) // load beta and duplicate - + vmulps(xmm0, xmm4, xmm4) // scale by alpha vmulps(xmm0, xmm5, xmm5) vmulps(xmm0, xmm6, xmm6) @@ -409,32 +413,32 @@ void bli_sgemm_piledriver_asm_16x3 vmulps(xmm0, xmm13, xmm13) vmulps(xmm0, xmm14, xmm14) vmulps(xmm0, xmm15, xmm15) - - - + + + prefetch(0, mem(r14)) // prefetch a_next prefetch(0, mem(r14, 64)) // prefetch a_next - - - - + + + + mov(var(rs_c), rsi) // load rs_c lea(mem(, rsi, 4), rsi) // rsi = rs_c * sizeof(float) - + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*rs_c; - + lea(mem(, rsi, 2), r12) // r12 = 2*rs_c; lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_c; - - - + + + // determine if // c % 32 == 0, AND // 4*cs_c % 32 == 0, AND // rs_c == 1 // ie: aligned, ldim aligned, and // column-stored - + cmp(imm(4), rsi) // set ZF if (4*rs_c) == 4. sete(bl) // bl = ( ZF == 1 ? 1 : 0 ); test(imm(31), rcx) // set ZF if c & 32 is zero. @@ -443,465 +447,69 @@ void bli_sgemm_piledriver_asm_16x3 setz(al) // al = ( ZF == 0 ? 1 : 0 ); // and(bl,bh) followed by // and(bh,al) will reveal result - + prefetch(0, mem(r15)) // prefetch b_next prefetch(0, mem(r15, 64)) // prefetch b_next - + // now avoid loading C if beta == 0 - + vxorps(xmm0, xmm0, xmm0) // set xmm0 to zero. vucomiss(xmm0, xmm2) // set ZF if beta == 0. je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case - - - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. 
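/* The deleted bl/bh/al flag arithmetic above evaluated, in effect:

       bool fast_path = ( rs_c == 1 )                              // column-stored
                     && ( ( (uintptr_t)c           & 31 ) == 0 )   // c 32B-aligned
                     && ( ( cs_c * sizeof(float) ) & 31 ) == 0;    // ldim 32B-aligned

   and dispatched to .SCOLSTORED when true. Because GEMM_UKR_SETUP_CT now
   hands the kernel a contiguous, full-size microtile whenever C itself does
   not qualify, this runtime dispatch and the general-storage paths it
   guarded become dead code, which is what the deletions below remove.
   (C sketch only; the macro's exact criteria live in the framework.) */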
- jne(.SCOLSTORED) // jump to column storage case - - - - label(.SGENSTORED) - - - vmovlps(mem(rcx), xmm0, xmm0) // load c00:c30 - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmulps(xmm2, xmm0, xmm0) - vaddps(xmm4, xmm0, xmm0) - vmovss(xmm0, mem(rcx)) // store c00:c30 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r13, 1)) - lea(mem(rcx, rsi, 4), rcx) // c += 4*rs_c; - - - vmovlps(mem(rcx), xmm0, xmm0) // load c40:c70 - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmulps(xmm2, xmm0, xmm0) - vaddps(xmm7, xmm0, xmm0) - vmovss(xmm0, mem(rcx)) // store c40:c70 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r13, 1)) - lea(mem(rcx, rsi, 4), rcx) // c += 4*rs_c; - - - vmovlps(mem(rcx), xmm0, xmm0) // load c80:cB0 - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmulps(xmm2, xmm0, xmm0) - vaddps(xmm10, xmm0, xmm0) - vmovss(xmm0, mem(rcx)) // store c80:cB0 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r13, 1)) - lea(mem(rcx, rsi, 4), rcx) // c += 4*rs_c; - - - vmovlps(mem(rcx), xmm0, xmm0) // load cC0:cF0 - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmulps(xmm2, xmm0, xmm0) - vaddps(xmm13, xmm0, xmm0) - vmovss(xmm0, mem(rcx)) // store cC0:cF0 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r13, 1)) - lea(mem(rcx, rsi, 4), rcx) // c += 4*rs_c; - - - vmovlps(mem(r10), xmm0, xmm0) // load c01:c31 - vmovhps(mem(r10, rsi, 1), xmm0, xmm0) - vmovlps(mem(r10, r12, 1), xmm1, xmm1) - vmovhps(mem(r10, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmulps(xmm2, xmm0, xmm0) - vaddps(xmm5, xmm0, xmm0) - vmovss(xmm0, mem(r10)) // store c01:c31 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r13, 1)) - lea(mem(r10, rsi, 4), r10) // c += 4*rs_c; - - - vmovlps(mem(r10), xmm0, xmm0) // load c41:c71 - vmovhps(mem(r10, rsi, 1), xmm0, xmm0) - vmovlps(mem(r10, r12, 1), xmm1, xmm1) - vmovhps(mem(r10, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmulps(xmm2, xmm0, xmm0) - vaddps(xmm8, xmm0, xmm0) - vmovss(xmm0, mem(r10)) // store c41:c71 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r13, 1)) - lea(mem(r10, rsi, 4), r10) // c += 4*rs_c; - - - vmovlps(mem(r10), xmm0, xmm0) // load c81:cB1 - vmovhps(mem(r10, rsi, 1), xmm0, xmm0) - vmovlps(mem(r10, r12, 1), xmm1, xmm1) - 
vmovhps(mem(r10, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmulps(xmm2, xmm0, xmm0) - vaddps(xmm11, xmm0, xmm0) - vmovss(xmm0, mem(r10)) // store c81:cB1 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r13, 1)) - lea(mem(r10, rsi, 4), r10) // c += 4*rs_c; - - - vmovlps(mem(r10), xmm0, xmm0) // load cC1:cF1 - vmovhps(mem(r10, rsi, 1), xmm0, xmm0) - vmovlps(mem(r10, r12, 1), xmm1, xmm1) - vmovhps(mem(r10, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmulps(xmm2, xmm0, xmm0) - vaddps(xmm14, xmm0, xmm0) - vmovss(xmm0, mem(r10)) // store cC1:cF1 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r13, 1)) - lea(mem(r10, rsi, 4), r10) // c += 4*rs_c; - - - vmovlps(mem(r11), xmm0, xmm0) // load c02:c32 - vmovhps(mem(r11, rsi, 1), xmm0, xmm0) - vmovlps(mem(r11, r12, 1), xmm1, xmm1) - vmovhps(mem(r11, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmulps(xmm2, xmm0, xmm0) - vaddps(xmm6, xmm0, xmm0) - vmovss(xmm0, mem(r11)) // store c02:c32 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r13, 1)) - lea(mem(r11, rsi, 4), r11) // c += 4*rs_c; - - - vmovlps(mem(r11), xmm0, xmm0) // load c42:c72 - vmovhps(mem(r11, rsi, 1), xmm0, xmm0) - vmovlps(mem(r11, r12, 1), xmm1, xmm1) - vmovhps(mem(r11, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmulps(xmm2, xmm0, xmm0) - vaddps(xmm9, xmm0, xmm0) - vmovss(xmm0, mem(r11)) // store c42:c72 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r13, 1)) - lea(mem(r11, rsi, 4), r11) // c += 4*rs_c; - - - vmovlps(mem(r11), xmm0, xmm0) // load c82:cB2 - vmovhps(mem(r11, rsi, 1), xmm0, xmm0) - vmovlps(mem(r11, r12, 1), xmm1, xmm1) - vmovhps(mem(r11, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmulps(xmm2, xmm0, xmm0) - vaddps(xmm12, xmm0, xmm0) - vmovss(xmm0, mem(r11)) // store c82:cB2 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r13, 1)) - lea(mem(r11, rsi, 4), r11) // c += 4*rs_c; - - - vmovlps(mem(r11), xmm0, xmm0) // load cC2:cF2 - vmovhps(mem(r11, rsi, 1), xmm0, xmm0) - vmovlps(mem(r11, r12, 1), xmm1, xmm1) - vmovhps(mem(r11, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmulps(xmm2, xmm0, xmm0) - vaddps(xmm15, xmm0, xmm0) - vmovss(xmm0, mem(r11)) // store cC2:cF1 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r13, 1)) - lea(mem(r11, rsi, 4), r11) // c += 4*rs_c; - - - - jmp(.SDONE) // jump to end. 
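/* Each deleted .SGENSTORED stanza above performs, for one 4x1 slice of the
   16x3 microtile, the strided update

       c[ i*rs_c + j*cs_c ] = beta * c[ i*rs_c + j*cs_c ] + ab( i, j );

   where ab(i,j) is the already alpha-scaled accumulator element: four
   elements of c are gathered with vmovlps/vmovhps, scaled by beta, summed
   with the accumulator, then scattered back one vmovss at a time, rotating
   the register with vpermilps between stores. */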
- - - - label(.SCOLSTORED) - - - vfmadd231ps(mem(rcx, 0*16), xmm2, xmm4) - vfmadd231ps(mem(rcx, 1*16), xmm2, xmm7) - vfmadd231ps(mem(rcx, 2*16), xmm2, xmm10) - vfmadd231ps(mem(rcx, 3*16), xmm2, xmm13) - - vmovups(xmm4, mem(rcx, 0*16)) - vmovups(xmm7, mem(rcx, 1*16)) - vmovups(xmm10, mem(rcx, 2*16)) - vmovups(xmm13, mem(rcx, 3*16)) - - vfmadd231ps(mem(r10, 0*16), xmm2, xmm5) - vfmadd231ps(mem(r10, 1*16), xmm2, xmm8) - vfmadd231ps(mem(r10, 2*16), xmm2, xmm11) - vfmadd231ps(mem(r10, 3*16), xmm2, xmm14) - - vmovups(xmm5, mem(r10, 0*16)) - vmovups(xmm8, mem(r10, 1*16)) - vmovups(xmm11, mem(r10, 2*16)) - vmovups(xmm14, mem(r10, 3*16)) - - vfmadd231ps(mem(r11, 0*16), xmm2, xmm6) - vfmadd231ps(mem(r11, 1*16), xmm2, xmm9) - vfmadd231ps(mem(r11, 2*16), xmm2, xmm12) - vfmadd231ps(mem(r11, 3*16), xmm2, xmm15) - - vmovups(xmm6, mem(r11, 0*16)) - vmovups(xmm9, mem(r11, 1*16)) - vmovups(xmm12, mem(r11, 2*16)) - vmovups(xmm15, mem(r11, 3*16)) - - - - jmp(.SDONE) // jump to end. - - - + + vfmadd231ps(mem(rcx, 0*16), xmm2, xmm4) + vfmadd231ps(mem(rcx, 1*16), xmm2, xmm7) + vfmadd231ps(mem(rcx, 2*16), xmm2, xmm10) + vfmadd231ps(mem(rcx, 3*16), xmm2, xmm13) + + vfmadd231ps(mem(r10, 0*16), xmm2, xmm5) + vfmadd231ps(mem(r10, 1*16), xmm2, xmm8) + vfmadd231ps(mem(r10, 2*16), xmm2, xmm11) + vfmadd231ps(mem(r10, 3*16), xmm2, xmm14) + + vfmadd231ps(mem(r11, 0*16), xmm2, xmm6) + vfmadd231ps(mem(r11, 1*16), xmm2, xmm9) + vfmadd231ps(mem(r11, 2*16), xmm2, xmm12) + vfmadd231ps(mem(r11, 3*16), xmm2, xmm15) + + // fall through + label(.SBETAZERO) - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. - jne(.SCOLSTORBZ) // jump to column storage case - - - - label(.SGENSTORBZ) - - - vmovaps(xmm4, xmm0) - vmovss(xmm0, mem(rcx)) // store c00:c30 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r13, 1)) - lea(mem(rcx, rsi, 4), rcx) // c += 4*rs_c; - - - vmovaps(xmm7, xmm0) - vmovss(xmm0, mem(rcx)) // store c40:c70 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r13, 1)) - lea(mem(rcx, rsi, 4), rcx) // c += 4*rs_c; - - - vmovaps(xmm10, xmm0) - vmovss(xmm0, mem(rcx)) // store c80:cB0 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r13, 1)) - lea(mem(rcx, rsi, 4), rcx) // c += 4*rs_c; - - - vmovaps(xmm13, xmm0) - vmovss(xmm0, mem(rcx)) // store cC0:cF0 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r13, 1)) - lea(mem(rcx, rsi, 4), rcx) // c += 4*rs_c; - - - vmovaps(xmm5, xmm0) - vmovss(xmm0, mem(r10)) // store c01:c31 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r13, 1)) - lea(mem(r10, rsi, 4), r10) // c += 4*rs_c; - - - vmovaps(xmm8, xmm0) - vmovss(xmm0, mem(r10)) // store c41:c71 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r12, 1)) - vpermilps(imm(0x39), xmm0, 
xmm0) - vmovss(xmm0, mem(r10, r13, 1)) - lea(mem(r10, rsi, 4), r10) // c += 4*rs_c; - - - vmovaps(xmm11, xmm0) - vmovss(xmm0, mem(r10)) // store c81:cB1 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r13, 1)) - lea(mem(r10, rsi, 4), r10) // c += 4*rs_c; - - - vmovaps(xmm14, xmm0) - vmovss(xmm0, mem(r10)) // store cC1:cF1 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r13, 1)) - lea(mem(r10, rsi, 4), r10) // c += 4*rs_c; - - - vmovaps(xmm6, xmm0) - vmovss(xmm0, mem(r11)) // store c02:c32 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r13, 1)) - lea(mem(r11, rsi, 4), r11) // c += 4*rs_c; - - - vmovaps(xmm9, xmm0) - vmovss(xmm0, mem(r11)) // store c42:c72 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r13, 1)) - lea(mem(r11, rsi, 4), r11) // c += 4*rs_c; - - - vmovaps(xmm12, xmm0) - vmovss(xmm0, mem(r11)) // store c82:cB2 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r13, 1)) - lea(mem(r11, rsi, 4), r11) // c += 4*rs_c; - - - vmovaps(xmm15, xmm0) - vmovss(xmm0, mem(r11)) // store cC2:cF1 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r13, 1)) - lea(mem(r11, rsi, 4), r11) // c += 4*rs_c; - - - - jmp(.SDONE) // jump to end. 
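/* .SGENSTORBZ above was the beta == 0 twin of the previous path: the same
   vpermilps/vmovss scatter, minus the gather-and-scale of c, i.e.
   c[ i*rs_c + j*cs_c ] = ab( i, j ). Both generic paths are superseded by
   computing into the local microtile and flushing it with
   GEMM_UKR_FLUSH_CT. */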
- - - - label(.SCOLSTORBZ) - - - vmovups(xmm4, mem(rcx, 0*16)) - vmovups(xmm7, mem(rcx, 1*16)) - vmovups(xmm10, mem(rcx, 2*16)) - vmovups(xmm13, mem(rcx, 3*16)) - - vmovups(xmm5, mem(r10, 0*16)) - vmovups(xmm8, mem(r10, 1*16)) - vmovups(xmm11, mem(r10, 2*16)) - vmovups(xmm14, mem(r10, 3*16)) - - vmovups(xmm6, mem(r11, 0*16)) - vmovups(xmm9, mem(r11, 1*16)) - vmovups(xmm12, mem(r11, 2*16)) - vmovups(xmm15, mem(r11, 3*16)) - - - - - - + + vmovups(xmm4, mem(rcx, 0*16)) + vmovups(xmm7, mem(rcx, 1*16)) + vmovups(xmm10, mem(rcx, 2*16)) + vmovups(xmm13, mem(rcx, 3*16)) + + vmovups(xmm5, mem(r10, 0*16)) + vmovups(xmm8, mem(r10, 1*16)) + vmovups(xmm11, mem(r10, 2*16)) + vmovups(xmm14, mem(r10, 3*16)) + + vmovups(xmm6, mem(r11, 0*16)) + vmovups(xmm9, mem(r11, 1*16)) + vmovups(xmm12, mem(r11, 2*16)) + vmovups(xmm15, mem(r11, 3*16)) + label(.SDONE) - - end_asm( + + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c), // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next) // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c), // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next) // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -909,11 +517,15 @@ void bli_sgemm_piledriver_asm_16x3 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( s ); } void bli_dgemm_piledriver_asm_8x3 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, double* restrict alpha, double* restrict a, double* restrict b, @@ -928,36 +540,38 @@ void bli_dgemm_piledriver_asm_8x3 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 8; - uint64_t k_left = k0 % 8; + uint64_t k_iter = k / 8; + uint64_t k_left = k % 8; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( d, 8, 3, false ); + begin_asm() - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. mov(var(b_next), r15) // load address of b_next. mov(var(a_next), r14) // load address of a_next. - + prefetch(0, mem(rbx, 128)) // prefetch b prefetch(0, mem(rbx, 64+128)) // prefetch b prefetch(0, mem(rbx, 128+128)) // prefetch b - + add(imm(16*8), rax) add(imm(12*8), rbx) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(double) lea(mem(rcx, rdi, 1), r10) // load address of c + 1*cs_c; lea(mem(rcx, rdi, 2), r11) // load address of c + 2*cs_c; - + vmovddup(mem(rbx, -12*8), xmm1) vmovddup(mem(rbx, -11*8), xmm2) vmovddup(mem(rbx, -10*8), xmm3) - + vxorpd(xmm4, xmm4, xmm4) vxorpd(xmm5, xmm5, xmm5) vxorpd(xmm6, xmm6, xmm6) @@ -970,24 +584,24 @@ void bli_dgemm_piledriver_asm_8x3 vxorpd(xmm13, xmm13, xmm13) vxorpd(xmm14, xmm14, xmm14) vxorpd(xmm15, xmm15, xmm15) - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. 
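/* The GEMM_UKR_SETUP_CT( d, 8, 3, false ) / GEMM_UKR_FLUSH_CT( d ) pair now
   bracketing each kernel is what makes the deletions in these hunks safe.
   A rough sketch of the intended mechanics -- illustrative only, not the
   literal macro expansion, and c_use/beta_use are hypothetical names:

       double ct[ 8 * 3 ] __attribute__((aligned(64))); // temporary tile
       double zero  = 0.0;
       bool use_ct  = ( m != 8 || n != 3 || rs_c0 != 1 );
       double* c_use    = use_ct ? ct    : c;           // SETUP_CT
       double* beta_use = use_ct ? &zero : beta;
       // ... kernel computes the full 8x3 tile into c_use ...
       if ( use_ct )                                    // FLUSH_CT
           bli_dxpbys_mxn( m, n, ct, 1, 8, beta, c, rs_c0, cs_c0 );

   The trailing boolean appears to encode the kernel's storage preference:
   false for these column-writing piledriver kernels, true for the
   row-oriented power10 kernel further down. */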
- - + + label(.DLOOPKITER) // MAIN LOOP - - + + je(.DCONSIDKLEFT) // if i == 0, jump to k_left code. - - + + prefetch(0, mem(rbx, -32+256)) // prefetch b prefetch(0, mem(rbx, 32+256)) // prefetch b - + // iteration 0 vmovaps(mem(rax, -8*16), xmm0) prefetch(0, mem(rax, 384)) // prefetch a @@ -1008,7 +622,7 @@ void bli_dgemm_piledriver_asm_8x3 vfmadd231pd(xmm2, xmm0, xmm14) vmovddup(mem(rbx, -8*8), xmm2) vfmadd231pd(xmm3, xmm0, xmm15) - + // iteration 1 vmovaps(mem(rax, -4*16), xmm0) prefetch(0, mem(rax, 64+384)) // prefetch a @@ -1030,7 +644,7 @@ void bli_dgemm_piledriver_asm_8x3 vfmadd231pd(xmm2, xmm0, xmm14) vmovddup(mem(rbx, -5*8), xmm2) vfmadd231pd(xmm3, xmm0, xmm15) - + // iteration 2 vmovaps(mem(rax, 0*16), xmm0) prefetch(0, mem(rax, 128+384)) // prefetch a @@ -1052,7 +666,7 @@ void bli_dgemm_piledriver_asm_8x3 vfmadd231pd(xmm2, xmm0, xmm14) vmovddup(mem(rbx, -2*8), xmm2) vfmadd231pd(xmm3, xmm0, xmm15) - + // iteration 3 vmovaps(mem(rax, 4*16), xmm0) prefetch(0, mem(rax, 192+384)) // prefetch a @@ -1075,7 +689,7 @@ void bli_dgemm_piledriver_asm_8x3 vfmadd231pd(xmm2, xmm0, xmm14) vmovddup(mem(rbx, 1*8), xmm2) vfmadd231pd(xmm3, xmm0, xmm15) - + // iteration 4 vmovaps(mem(rax, -8*16), xmm0) prefetch(0, mem(rax, 384)) // prefetch a @@ -1097,9 +711,9 @@ void bli_dgemm_piledriver_asm_8x3 vfmadd231pd(xmm2, xmm0, xmm14) vmovddup(mem(rbx, 4*8), xmm2) vfmadd231pd(xmm3, xmm0, xmm15) - + prefetch(0, mem(rbx, 96+256)) // prefetch b - + // iteration 5 vmovaps(mem(rax, -4*16), xmm0) prefetch(0, mem(rax, 64+384)) // prefetch a @@ -1121,8 +735,8 @@ void bli_dgemm_piledriver_asm_8x3 vfmadd231pd(xmm2, xmm0, xmm14) vmovddup(mem(rbx, 7*8), xmm2) vfmadd231pd(xmm3, xmm0, xmm15) - - + + // iteration 6 vmovaps(mem(rax, 0*16), xmm0) prefetch(0, mem(rax, 128+384)) // prefetch a @@ -1144,7 +758,7 @@ void bli_dgemm_piledriver_asm_8x3 vfmadd231pd(xmm2, xmm0, xmm14) vmovddup(mem(rbx, 10*8), xmm2) vfmadd231pd(xmm3, xmm0, xmm15) - + // iteration 7 vmovaps(mem(rax, 4*16), xmm0) prefetch(0, mem(rax, 192+384)) // prefetch a @@ -1169,31 +783,31 @@ void bli_dgemm_piledriver_asm_8x3 vmovddup(mem(rbx, -11*8), xmm2) vfmadd231pd(xmm3, xmm0, xmm15) vmovddup(mem(rbx, -10*8), xmm3) - - - + + + dec(rsi) // i -= 1; jmp(.DLOOPKITER) // jump to beginning of loop. - - - - - - + + + + + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done. // else, we prepare to // enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP - - + + je(.DPOSTACCUM) // if i == 0, we're done. - + // iteration 0 vmovaps(mem(rax, -8*16), xmm0) prefetch(0, mem(rax, 512)) // prefetch a @@ -1215,48 +829,48 @@ void bli_dgemm_piledriver_asm_8x3 vmovddup(mem(rbx, -8*8), xmm2) vfmadd231pd(xmm3, xmm0, xmm15) vmovddup(mem(rbx, -7*8), xmm3) - - + + add(imm(1*8*8), rax) // a += 1*8 (1 x mr) add(imm(1*3*8), rbx) // b += 1*3 (1 x nr) - - + + dec(rsi) // i -= 1; jmp(.DLOOPKLEFT) // jump to beginning of loop. 
- - - + + + label(.DPOSTACCUM) - + prefetchw0(mem(rcx, 0*8)) // prefetch c + 0*cs_c prefetchw0(mem(r10, 0*8)) // prefetch c + 1*cs_c prefetchw0(mem(r11, 0*8)) // prefetch c + 2*cs_c - - - // xmm4: xmm5: xmm6: - // ( ab00 ( ab01 ( ab02 + + + // xmm4: xmm5: xmm6: + // ( ab00 ( ab01 ( ab02 // ab10 ) ab11 ) ab12 ) // - // xmm7: xmm8: xmm9: - // ( ab20 ( ab21 ( ab22 + // xmm7: xmm8: xmm9: + // ( ab20 ( ab21 ( ab22 // ab30 ) ab31 ) ab32 ) // - // xmm10: xmm11: xmm12: - // ( ab40 ( ab41 ( ab42 + // xmm10: xmm11: xmm12: + // ( ab40 ( ab41 ( ab42 // ab50 ) ab51 ) ab52 ) // - // xmm13: xmm14: xmm15: - // ( ab60 ( ab61 ( ab62 + // xmm13: xmm14: xmm15: + // ( ab60 ( ab61 ( ab62 // ab70 ) ab71 ) ab72 ) - - - - + + + + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vmovddup(mem(rax), xmm0) // load alpha and duplicate vmovddup(mem(rbx), xmm2) // load beta and duplicate - + vmulpd(xmm0, xmm4, xmm4) // scale by alpha vmulpd(xmm0, xmm5, xmm5) vmulpd(xmm0, xmm6, xmm6) @@ -1269,358 +883,89 @@ void bli_dgemm_piledriver_asm_8x3 vmulpd(xmm0, xmm13, xmm13) vmulpd(xmm0, xmm14, xmm14) vmulpd(xmm0, xmm15, xmm15) - - + + prefetch(0, mem(r14)) // prefetch a_next prefetch(0, mem(r14, 64)) // prefetch a_next - - - - - mov(var(rs_c), rsi) // load rs_c - lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(double) - - lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*rs_c; - - lea(mem(, rsi, 2), r12) // r12 = 2*rs_c; - lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_c; - - - - // determine if - // c % 32 == 0, AND - // 8*cs_c % 32 == 0, AND - // rs_c == 1 - // ie: aligned, ldim aligned, and - // column-stored - - cmp(imm(8), rsi) // set ZF if (8*rs_c) == 8. - sete(bl) // bl = ( ZF == 1 ? 1 : 0 ); - test(imm(31), rcx) // set ZF if c & 32 is zero. - setz(bh) // bh = ( ZF == 0 ? 1 : 0 ); - test(imm(31), rdi) // set ZF if (8*cs_c) & 32 is zero. - setz(al) // al = ( ZF == 0 ? 1 : 0 ); - // and(bl,bh) followed by - // and(bh,al) will reveal result - + prefetch(0, mem(r15)) // prefetch b_next prefetch(0, mem(r15, 64)) // prefetch b_next - + // now avoid loading C if beta == 0 - + vxorpd(xmm0, xmm0, xmm0) // set xmm0 to zero. vucomisd(xmm0, xmm2) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - - - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. 
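/* The vxorpd/vucomisd pair above is the standard beta == 0 fast path: when
   beta is exactly zero, c must not be read at all (it may hold NaNs, Infs,
   or uninitialized memory that 0*c would not annihilate), so the kernel
   branches to a store-only epilogue. Per element (ab_ij denotes an
   alpha-scaled accumulator element; names illustrative):

       if ( *beta == 0.0 ) c_ij = ab_ij;                 // .DBETAZERO
       else                c_ij = (*beta)*c_ij + ab_ij;  // fma path
*/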
- je(.DGENSTORED) // jump to column storage case - - - - label(.DCOLSTORED) - - // xmm4: xmm5: xmm6: - // ( ab00 ( ab01 ( ab02 - // ab10 ) ab11 ) ab12 ) - // - // xmm7: xmm8: xmm9: - // ( ab20 ( ab21 ( ab22 - // ab30 ) ab31 ) ab32 ) - // - // xmm10: xmm11: xmm12: - // ( ab40 ( ab41 ( ab42 - // ab50 ) ab51 ) ab52 ) - // - // xmm13: xmm14: xmm15: - // ( ab60 ( ab61 ( ab62 - // ab70 ) ab71 ) ab72 ) - - - vfmadd231pd(mem(rcx, 0*16), xmm2, xmm4) - vfmadd231pd(mem(rcx, 1*16), xmm2, xmm7) - vfmadd231pd(mem(rcx, 2*16), xmm2, xmm10) - vfmadd231pd(mem(rcx, 3*16), xmm2, xmm13) - - vfmadd231pd(mem(r10, 0*16), xmm2, xmm5) - vfmadd231pd(mem(r10, 1*16), xmm2, xmm8) - vfmadd231pd(mem(r10, 2*16), xmm2, xmm11) - vfmadd231pd(mem(r10, 3*16), xmm2, xmm14) - - vfmadd231pd(mem(r11, 0*16), xmm2, xmm6) - vfmadd231pd(mem(r11, 1*16), xmm2, xmm9) - vfmadd231pd(mem(r11, 2*16), xmm2, xmm12) - vfmadd231pd(mem(r11, 3*16), xmm2, xmm15) - - - vmovups(xmm4, mem(rcx, 0*16)) - vmovups(xmm7, mem(rcx, 1*16)) - vmovups(xmm10, mem(rcx, 2*16)) - vmovups(xmm13, mem(rcx, 3*16)) - - vmovups(xmm5, mem(r10, 0*16)) - vmovups(xmm8, mem(r10, 1*16)) - vmovups(xmm11, mem(r10, 2*16)) - vmovups(xmm14, mem(r10, 3*16)) - - vmovups(xmm6, mem(r11, 0*16)) - vmovups(xmm9, mem(r11, 1*16)) - vmovups(xmm12, mem(r11, 2*16)) - vmovups(xmm15, mem(r11, 3*16)) - - - - -/* - vmovupd(mem(rcx), xmm0) // load c00:c10 - vmovupd(mem(rcx, r12, 1), xmm1) // load c20:c30 - vfmadd231pd(xmm2, xmm0, xmm4) - vfmadd231pd(xmm2, xmm1, xmm7) - vmovupd(xmm4, mem(rcx)) // store c00:c10 - vmovupd(xmm7, mem(rcx, r12, 1)) // store c20:c30 - add(rdi, rcx) - - vmovupd(mem(rdx), xmm0) // load c40:c50 - vmovupd(mem(rdx, r12, 1), xmm1) // load c60:c70 - vfmadd213pd(xmm10, xmm2, xmm0) - vfmadd213pd(xmm13, xmm2, xmm1) - vmovupd(xmm0, mem(rdx)) // store c40:c50 - vmovupd(xmm1, mem(rdx, r12, 1)) // store c60:c70 - add(rdi, rdx) - - - vmovupd(mem(rcx), xmm0) // load c01:c11 - vmovupd(mem(rcx, r12, 1), xmm1) // load c21:c31 - vfmadd213pd(xmm5, xmm2, xmm0) - vfmadd213pd(xmm8, xmm2, xmm1) - vmovupd(xmm0, mem(rcx)) // store c01:c11 - vmovupd(xmm1, mem(rcx, r12, 1)) // store c21:c31 - add(rdi, rcx) - - vmovupd(mem(rdx), xmm0) // load c41:c51 - vmovupd(mem(rdx, r12, 1), xmm1) // load c61:c71 - vfmadd213pd(xmm11, xmm2, xmm0) - vfmadd213pd(xmm14, xmm2, xmm1) - vmovupd(xmm0, mem(rdx)) // store c41:c51 - vmovupd(xmm1, mem(rdx, r12, 1)) // store c61:c71 - add(rdi, rdx) - - - vmovupd(mem(rcx), xmm0) // load c02:c12 - vmovupd(mem(rcx, r12, 1), xmm1) // load c22:c32 - vfmadd213pd(xmm6, xmm2, xmm0) - vfmadd213pd(xmm9, xmm2, xmm1) - vmovupd(xmm0, mem(rcx)) // store c02:c12 - vmovupd(xmm1, mem(rcx, r12, 1)) // store c22:c32 - - vmovupd(mem(rdx), xmm0) // load c42:c52 - vmovupd(mem(rdx, r12, 1), xmm1) // load c62:c72 - vfmadd213pd(xmm12, xmm2, xmm0) - vfmadd213pd(xmm15, xmm2, xmm1) - vmovupd(xmm0, mem(rdx)) // store c42:c52 - vmovupd(xmm1, mem(rdx, r12, 1)) // store c62:c72 -*/ - - - - jmp(.DDONE) // jump to end. 
- - - - label(.DGENSTORED) - - - vmovlpd(mem(rcx), xmm0, xmm0) // load c00:c10 - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) - vaddpd(xmm4, xmm0, xmm0) - vmovlpd(xmm0, mem(rcx)) // store c00:c10 - vmovhpd(xmm0, mem(rcx, rsi, 1)) - vmovlpd(mem(rcx, r12, 1), xmm0, xmm0) // load c20:c30 - vmovhpd(mem(rcx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) - vaddpd(xmm7, xmm0, xmm0) - vmovlpd(xmm0, mem(rcx, r12, 1)) // store c20:c30 - vmovhpd(xmm0, mem(rcx, r13, 1)) - add(rdi, rcx) - - vmovlpd(mem(rdx), xmm0, xmm0) // load c40:c50 - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) - vaddpd(xmm10, xmm0, xmm0) - vmovlpd(xmm0, mem(rdx)) // store c40:c50 - vmovhpd(xmm0, mem(rdx, rsi, 1)) - vmovlpd(mem(rdx, r12, 1), xmm0, xmm0) // load c60:c70 - vmovhpd(mem(rdx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) - vaddpd(xmm13, xmm0, xmm0) - vmovlpd(xmm0, mem(rdx, r12, 1)) // store c60:c70 - vmovhpd(xmm0, mem(rdx, r13, 1)) - add(rdi, rdx) - - - vmovlpd(mem(rcx), xmm0, xmm0) // load c01:c11 - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) - vaddpd(xmm5, xmm0, xmm0) - vmovlpd(xmm0, mem(rcx)) // store c01:c11 - vmovhpd(xmm0, mem(rcx, rsi, 1)) - vmovlpd(mem(rcx, r12, 1), xmm0, xmm0) // load c21:c31 - vmovhpd(mem(rcx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) - vaddpd(xmm8, xmm0, xmm0) - vmovlpd(xmm0, mem(rcx, r12, 1)) // store c21:c31 - vmovhpd(xmm0, mem(rcx, r13, 1)) - add(rdi, rcx) - - vmovlpd(mem(rdx), xmm0, xmm0) // load c41:c51 - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) - vaddpd(xmm11, xmm0, xmm0) - vmovlpd(xmm0, mem(rdx)) // store c41:c51 - vmovhpd(xmm0, mem(rdx, rsi, 1)) - vmovlpd(mem(rdx, r12, 1), xmm0, xmm0) // load c61:c71 - vmovhpd(mem(rdx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) - vaddpd(xmm14, xmm0, xmm0) - vmovlpd(xmm0, mem(rdx, r12, 1)) // store c61:c71 - vmovhpd(xmm0, mem(rdx, r13, 1)) - add(rdi, rdx) - - - vmovlpd(mem(rcx), xmm0, xmm0) // load c02:c12 - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) - vaddpd(xmm6, xmm0, xmm0) - vmovlpd(xmm0, mem(rcx)) // store c02:c12 - vmovhpd(xmm0, mem(rcx, rsi, 1)) - vmovlpd(mem(rcx, r12, 1), xmm0, xmm0) // load c22:c32 - vmovhpd(mem(rcx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) - vaddpd(xmm9, xmm0, xmm0) - vmovlpd(xmm0, mem(rcx, r12, 1)) // store c22:c32 - vmovhpd(xmm0, mem(rcx, r13, 1)) - add(rdi, rcx) - - vmovlpd(mem(rdx), xmm0, xmm0) // load c42:c52 - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) - vaddpd(xmm12, xmm0, xmm0) - vmovlpd(xmm0, mem(rdx)) // store c42:c52 - vmovhpd(xmm0, mem(rdx, rsi, 1)) - vmovlpd(mem(rdx, r12, 1), xmm0, xmm0) // load c62:c72 - vmovhpd(mem(rdx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) - vaddpd(xmm15, xmm0, xmm0) - vmovlpd(xmm0, mem(rdx, r12, 1)) // store c62:c72 - vmovhpd(xmm0, mem(rdx, r13, 1)) - add(rdi, rdx) - - - - jmp(.DDONE) // jump to end. 
- - - + + // xmm4: xmm5: xmm6: + // ( ab00 ( ab01 ( ab02 + // ab10 ) ab11 ) ab12 ) + // + // xmm7: xmm8: xmm9: + // ( ab20 ( ab21 ( ab22 + // ab30 ) ab31 ) ab32 ) + // + // xmm10: xmm11: xmm12: + // ( ab40 ( ab41 ( ab42 + // ab50 ) ab51 ) ab52 ) + // + // xmm13: xmm14: xmm15: + // ( ab60 ( ab61 ( ab62 + // ab70 ) ab71 ) ab72 ) + + vfmadd231pd(mem(rcx, 0*16), xmm2, xmm4) + vfmadd231pd(mem(rcx, 1*16), xmm2, xmm7) + vfmadd231pd(mem(rcx, 2*16), xmm2, xmm10) + vfmadd231pd(mem(rcx, 3*16), xmm2, xmm13) + + vfmadd231pd(mem(r10, 0*16), xmm2, xmm5) + vfmadd231pd(mem(r10, 1*16), xmm2, xmm8) + vfmadd231pd(mem(r10, 2*16), xmm2, xmm11) + vfmadd231pd(mem(r10, 3*16), xmm2, xmm14) + + vfmadd231pd(mem(r11, 0*16), xmm2, xmm6) + vfmadd231pd(mem(r11, 1*16), xmm2, xmm9) + vfmadd231pd(mem(r11, 2*16), xmm2, xmm12) + vfmadd231pd(mem(r11, 3*16), xmm2, xmm15) + + // fall through + label(.DBETAZERO) - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. - jne(.DCOLSTORBZ) // jump to column storage case - - - - label(.DGENSTORBZ) - - - vmovlpd(xmm4, mem(rcx)) - vmovhpd(xmm4, mem(rcx, rsi, 1)) - vmovlpd(xmm7, mem(rcx, r12, 1)) - vmovhpd(xmm7, mem(rcx, r13, 1)) - add(rdi, rcx) - vmovlpd(xmm10, mem(rdx)) - vmovhpd(xmm10, mem(rdx, rsi, 1)) - vmovlpd(xmm13, mem(rdx, r12, 1)) - vmovhpd(xmm13, mem(rdx, r13, 1)) - add(rdi, rdx) - - vmovlpd(xmm5, mem(rcx)) - vmovhpd(xmm5, mem(rcx, rsi, 1)) - vmovlpd(xmm8, mem(rcx, r12, 1)) - vmovhpd(xmm8, mem(rcx, r13, 1)) - add(rdi, rcx) - vmovlpd(xmm11, mem(rdx)) - vmovhpd(xmm11, mem(rdx, rsi, 1)) - vmovlpd(xmm14, mem(rdx, r12, 1)) - vmovhpd(xmm14, mem(rdx, r13, 1)) - add(rdi, rdx) - - vmovlpd(xmm6, mem(rcx)) - vmovhpd(xmm6, mem(rcx, rsi, 1)) - vmovlpd(xmm9, mem(rcx, r12, 1)) - vmovhpd(xmm9, mem(rcx, r13, 1)) - add(rdi, rcx) - vmovlpd(xmm12, mem(rdx)) - vmovhpd(xmm12, mem(rdx, rsi, 1)) - vmovlpd(xmm15, mem(rdx, r12, 1)) - vmovhpd(xmm15, mem(rdx, r13, 1)) - add(rdi, rdx) - - - - jmp(.DDONE) // jump to end. 
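/* The retained epilogue earlier in this hunk is now unconditional: one
   vfmadd231pd per 16-byte pair folds beta*c into the alpha-scaled
   accumulators, then control falls through .DBETAZERO to a single set of
   vmovups stores (the je above enters at the label and skips only the
   fmas). The deleted generic beta == 0 scatter above and the deleted
   .DCOLSTORBZ block below are thereby subsumed. In C terms, with cs_c in
   elements and ab standing for the accumulators xmm4..xmm15:

       for ( dim_t j = 0; j < 3; j++ )
           for ( dim_t i = 0; i < 8; i++ )
               c[ i + j*cs_c ] = (*beta) * c[ i + j*cs_c ] + ab[ i + j*8 ];
*/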
- - - - label(.DCOLSTORBZ) - - - vmovupd(xmm4, mem(rcx)) - vmovupd(xmm7, mem(rcx, r12, 1)) - add(rdi, rcx) - vmovupd(xmm10, mem(rdx)) - vmovupd(xmm13, mem(rdx, r12, 1)) - add(rdi, rdx) - - vmovupd(xmm5, mem(rcx)) - vmovupd(xmm8, mem(rcx, r12, 1)) - add(rdi, rcx) - vmovupd(xmm11, mem(rdx)) - vmovupd(xmm14, mem(rdx, r12, 1)) - add(rdi, rdx) - - vmovupd(xmm6, mem(rcx)) - vmovupd(xmm9, mem(rcx, r12, 1)) - add(rdi, rcx) - vmovupd(xmm12, mem(rdx)) - vmovupd(xmm15, mem(rdx, r12, 1)) - add(rdi, rdx) - - - - - + + vmovups(xmm4, mem(rcx, 0*16)) + vmovups(xmm7, mem(rcx, 1*16)) + vmovups(xmm10, mem(rcx, 2*16)) + vmovups(xmm13, mem(rcx, 3*16)) + + vmovups(xmm5, mem(r10, 0*16)) + vmovups(xmm8, mem(r10, 1*16)) + vmovups(xmm11, mem(r10, 2*16)) + vmovups(xmm14, mem(r10, 3*16)) + + vmovups(xmm6, mem(r11, 0*16)) + vmovups(xmm9, mem(r11, 1*16)) + vmovups(xmm12, mem(r11, 2*16)) + vmovups(xmm15, mem(r11, 3*16)) + label(.DDONE) - - end_asm( + + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c), // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next) // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c), // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next) // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -1628,11 +973,15 @@ void bli_dgemm_piledriver_asm_8x3 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( d ); } void bli_cgemm_piledriver_asm_4x2 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, scomplex* restrict alpha, scomplex* restrict a, scomplex* restrict b, @@ -1647,28 +996,30 @@ void bli_cgemm_piledriver_asm_4x2 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 8; - uint64_t k_left = k0 % 8; + uint64_t k_iter = k / 8; + uint64_t k_left = k % 8; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( c, 4, 2, false ); + begin_asm() - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. mov(var(b_next), r15) // load address of b_next. mov(var(a_next), r14) // load address of a_next. - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(scomplex) lea(mem(rcx, rdi, 1), r10) // load address of c + 1*cs_c; - + add(imm(32*4), rax) add(imm(16*4), rbx) - - + + vxorps(xmm8, xmm8, xmm8) vxorps(xmm9, xmm9, xmm9) vxorps(xmm10, xmm10, xmm10) @@ -1678,24 +1029,24 @@ void bli_cgemm_piledriver_asm_4x2 vxorps(xmm14, xmm14, xmm14) vxorps(xmm15, xmm15, xmm15) //vzeroall() - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.CCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.CLOOPKITER) // MAIN LOOP - - + + je(.CCONSIDKLEFT) // if i == 0, jump to k_left code. 
- - + + prefetch(0, mem(rbx, 256)) prefetch(0, mem(rax, 512)) - + // iteration 0 vmovaps(mem(rax, -32*4), xmm0) vbroadcastss(mem(rbx, -16*4), xmm4) @@ -1711,7 +1062,7 @@ void bli_cgemm_piledriver_asm_4x2 vbroadcastss(mem(rbx, -13*4), xmm7) vfmadd231ps(xmm0, xmm7, xmm11) vfmadd231ps(xmm1, xmm7, xmm15) - + // iteration 1 vmovaps(mem(rax, -24*4), xmm0) vbroadcastss(mem(rbx, -12*4), xmm4) @@ -1727,10 +1078,10 @@ void bli_cgemm_piledriver_asm_4x2 vbroadcastss(mem(rbx, -9*4), xmm7) vfmadd231ps(xmm0, xmm7, xmm11) vfmadd231ps(xmm1, xmm7, xmm15) - + prefetch(0, mem(rbx, 64+256)) prefetch(0, mem(rax, 64+512)) - + // iteration 2 vmovaps(mem(rax, -16*4), xmm0) vbroadcastss(mem(rbx, -8*4), xmm4) @@ -1746,7 +1097,7 @@ void bli_cgemm_piledriver_asm_4x2 vbroadcastss(mem(rbx, -5*4), xmm7) vfmadd231ps(xmm0, xmm7, xmm11) vfmadd231ps(xmm1, xmm7, xmm15) - + // iteration 3 vmovaps(mem(rax, -8*4), xmm0) vbroadcastss(mem(rbx, -4*4), xmm4) @@ -1762,10 +1113,10 @@ void bli_cgemm_piledriver_asm_4x2 vbroadcastss(mem(rbx, -1*4), xmm7) vfmadd231ps(xmm0, xmm7, xmm11) vfmadd231ps(xmm1, xmm7, xmm15) - + prefetch(0, mem(rbx, 128+256)) prefetch(0, mem(rax, 128+512)) - + // iteration 4 vmovaps(mem(rax, 0*4), xmm0) vbroadcastss(mem(rbx, 0*4), xmm4) @@ -1781,7 +1132,7 @@ void bli_cgemm_piledriver_asm_4x2 vbroadcastss(mem(rbx, 3*4), xmm7) vfmadd231ps(xmm0, xmm7, xmm11) vfmadd231ps(xmm1, xmm7, xmm15) - + // iteration 5 vmovaps(mem(rax, 8*4), xmm0) vbroadcastss(mem(rbx, 4*4), xmm4) @@ -1797,10 +1148,10 @@ void bli_cgemm_piledriver_asm_4x2 vbroadcastss(mem(rbx, 7*4), xmm7) vfmadd231ps(xmm0, xmm7, xmm11) vfmadd231ps(xmm1, xmm7, xmm15) - + prefetch(0, mem(rbx, 128+256)) prefetch(0, mem(rax, 128+512)) - + // iteration 6 vmovaps(mem(rax, 16*4), xmm0) vbroadcastss(mem(rbx, 8*4), xmm4) @@ -1816,7 +1167,7 @@ void bli_cgemm_piledriver_asm_4x2 vbroadcastss(mem(rbx, 11*4), xmm7) vfmadd231ps(xmm0, xmm7, xmm11) vfmadd231ps(xmm1, xmm7, xmm15) - + // iteration 7 vmovaps(mem(rax, 24*4), xmm0) vbroadcastss(mem(rbx, 12*4), xmm4) @@ -1834,33 +1185,33 @@ void bli_cgemm_piledriver_asm_4x2 add(imm(8*2*8), rbx) // b += 8*2 (unroll x nr) vfmadd231ps(xmm0, xmm7, xmm11) vfmadd231ps(xmm1, xmm7, xmm15) - - - + + + dec(rsi) // i -= 1; jmp(.CLOOPKITER) // jump to beginning of loop. - - - - - - + + + + + + label(.CCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.CPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.CLOOPKLEFT) // EDGE LOOP - - + + je(.CPOSTACCUM) // if i == 0, we're done. - + prefetch(0, mem(rbx, 256)) prefetch(0, mem(rax, 512)) - + // iteration 0 vmovaps(mem(rax, -32*4), xmm0) vbroadcastss(mem(rbx, -16*4), xmm4) @@ -1876,123 +1227,88 @@ void bli_cgemm_piledriver_asm_4x2 vbroadcastss(mem(rbx, -13*4), xmm7) vfmadd231ps(xmm0, xmm7, xmm11) vfmadd231ps(xmm1, xmm7, xmm15) - - + + add(imm(1*4*8), rax) // a += 1*2 (1 x mr) add(imm(1*2*8), rbx) // b += 1*2 (1 x nr) - - + + dec(rsi) // i -= 1; jmp(.CLOOPKLEFT) // jump to beginning of loop. 
- - - + + + label(.CPOSTACCUM) - - + + prefetchw0(mem(rcx, 0*8)) // prefetch c + 0*cs_c prefetchw0(mem(r10, 0*8)) // prefetch c + 1*cs_c - - + + vpermilps(imm(0xb1), xmm9, xmm9) vpermilps(imm(0xb1), xmm11, xmm11) vpermilps(imm(0xb1), xmm13, xmm13) vpermilps(imm(0xb1), xmm15, xmm15) - + vaddsubps(xmm9, xmm8, xmm8) vaddsubps(xmm11, xmm10, xmm10) vaddsubps(xmm13, xmm12, xmm12) vaddsubps(xmm15, xmm14, xmm14) - - + + // xmm8: xmm10: // ( ab00 ( ab01 // ab10 ab11 // ab20 ab21 // ab30 ) ab31 ) - + // xmm12: xmm14: // ( ab40 ( ab41 // ab50 ab51 // ab60 ab61 // ab70 ) ab71 ) - - + + prefetch(0, mem(r14)) // prefetch a_next prefetch(0, mem(r14, 64)) // prefetch a_next - - + + // scale by alpha - + mov(var(alpha), rax) // load address of alpha vbroadcastss(mem(rax), xmm0) // load alpha_r and duplicate vbroadcastss(mem(rax, 4), xmm1) // load alpha_i and duplicate - + vpermilps(imm(0xb1), xmm8, xmm9) vpermilps(imm(0xb1), xmm10, xmm11) vpermilps(imm(0xb1), xmm12, xmm13) vpermilps(imm(0xb1), xmm14, xmm15) - + vmulps(xmm8, xmm0, xmm8) vmulps(xmm10, xmm0, xmm10) vmulps(xmm12, xmm0, xmm12) vmulps(xmm14, xmm0, xmm14) - + vmulps(xmm9, xmm1, xmm9) vmulps(xmm11, xmm1, xmm11) vmulps(xmm13, xmm1, xmm13) vmulps(xmm15, xmm1, xmm15) - + vaddsubps(xmm9, xmm8, xmm8) vaddsubps(xmm11, xmm10, xmm10) vaddsubps(xmm13, xmm12, xmm12) vaddsubps(xmm15, xmm14, xmm14) - - - - + + mov(var(beta), rbx) // load address of beta vbroadcastss(mem(rbx), xmm6) // load beta_r and duplicate vbroadcastss(mem(rbx, 4), xmm7) // load beta_i and duplicate - - - - - - - - mov(var(rs_c), rsi) // load rs_c - lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(scomplex) - - - lea(mem(, rsi, 2), r12) // r12 = 2*rs_c; - lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_c; - - - + prefetch(0, mem(r15)) // prefetch b_next prefetch(0, mem(r15, 64)) // prefetch b_next - - - - // determine if - // c % 32 == 0, AND - // 8*cs_c % 32 == 0, AND - // rs_c == 1 - // ie: aligned, ldim aligned, and - // column-stored - - cmp(imm(8), rsi) // set ZF if (8*rs_c) == 8. - sete(bl) // bl = ( ZF == 1 ? 1 : 0 ); - test(imm(31), rcx) // set ZF if c & 32 is zero. - setz(bh) // bh = ( ZF == 0 ? 1 : 0 ); - test(imm(31), rdi) // set ZF if (8*cs_c) & 32 is zero. - setz(al) // al = ( ZF == 0 ? 1 : 0 ); - // and(bl,bh) followed by - // and(bh,al) will reveal result - + // now avoid loading C if beta == 0 - + vxorps(xmm0, xmm0, xmm0) // set xmm0 to zero. vucomiss(xmm0, xmm6) // set ZF if beta_r == 0. sete(r8b) // r8b = ( ZF == 1 ? 1 : 0 ); @@ -2000,175 +1316,66 @@ void bli_cgemm_piledriver_asm_4x2 sete(r9b) // r9b = ( ZF == 1 ? 1 : 0 ); and(r8b, r9b) // set ZF if r8b & r9b == 1. jne(.CBETAZERO) // if ZF = 0, jump to beta == 0 case - - - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. 
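/* The vpermilps/vmulps/vaddsubps sequences in this hunk implement the SSE
   packed-complex multiply idiom: for a register x holding (xr, xi) pairs
   and alpha = ar + ai*i broadcast into xmm0/xmm1,

       t  = ( ar*xr , ar*xi )                     // vmulps( x, xmm0 )
       u  = ( ai*xi , ai*xr )                     // vmulps( swap(x), xmm1 )
       x' = ( ar*xr - ai*xi , ar*xi + ai*xr )     // vaddsubps( u, t )

   which is exactly (ar + ai*i)*(xr + xi*i). The same recipe is applied once
   right after .CPOSTACCUM to fold the separately accumulated real-part and
   imaginary-part products (xmm8/xmm9, etc.) into proper complex values. */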
- jne(.CCOLSTORED) // jump to column storage case - - - - label(.CGENSTORED) - - - vmovlps(mem(rcx), xmm0, xmm0) // load c00:c10 - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm2, xmm2) // load c20:c30 - vmovhps(mem(rcx, r13, 1), xmm2, xmm2) - vpermilps(imm(0xb1), xmm0, xmm1) - vpermilps(imm(0xb1), xmm2, xmm3) - - vmulps(xmm6, xmm0, xmm0) - vmulps(xmm7, xmm1, xmm1) - vaddsubps(xmm1, xmm0, xmm0) - vaddps(xmm8, xmm0, xmm0) - vmovlps(xmm0, mem(rcx)) // store c00:c10 - vmovhps(xmm0, mem(rcx, rsi, 1)) - - vmulps(xmm6, xmm2, xmm2) - vmulps(xmm7, xmm3, xmm3) - vaddsubps(xmm3, xmm2, xmm2) - vaddps(xmm12, xmm2, xmm2) - vmovlps(xmm2, mem(rcx, r12, 1)) // store c20:c30 - vmovhps(xmm2, mem(rcx, r13, 1)) - - - - vmovlps(mem(r10), xmm0, xmm0) // load c01:c11 - vmovhps(mem(r10, rsi, 1), xmm0, xmm0) - vmovlps(mem(r10, r12, 1), xmm2, xmm2) // load c21:c31 - vmovhps(mem(r10, r13, 1), xmm2, xmm2) - vpermilps(imm(0xb1), xmm0, xmm1) - vpermilps(imm(0xb1), xmm2, xmm3) - - vmulps(xmm6, xmm0, xmm0) - vmulps(xmm7, xmm1, xmm1) - vaddsubps(xmm1, xmm0, xmm0) - vaddps(xmm10, xmm0, xmm0) - vmovlps(xmm0, mem(r10)) // store c01:c11 - vmovhps(xmm0, mem(r10, rsi, 1)) - - vmulps(xmm6, xmm2, xmm2) - vmulps(xmm7, xmm3, xmm3) - vaddsubps(xmm3, xmm2, xmm2) - vaddps(xmm14, xmm2, xmm2) - vmovlps(xmm2, mem(r10, r12, 1)) // store c21:c31 - vmovhps(xmm2, mem(r10, r13, 1)) - - - - jmp(.CDONE) // jump to end. - - - - label(.CCOLSTORED) - - - vmovups(mem(rcx), xmm0) // load c00:c10 - vmovups(mem(rcx, 16), xmm2) // load c20:c30 - vpermilps(imm(0xb1), xmm0, xmm1) - vpermilps(imm(0xb1), xmm2, xmm3) - - vmulps(xmm6, xmm0, xmm0) - vmulps(xmm7, xmm1, xmm1) - vaddsubps(xmm1, xmm0, xmm0) - vaddps(xmm8, xmm0, xmm0) - vmovups(xmm0, mem(rcx)) // store c00:c10 - - vmulps(xmm6, xmm2, xmm2) - vmulps(xmm7, xmm3, xmm3) - vaddsubps(xmm3, xmm2, xmm2) - vaddps(xmm12, xmm2, xmm2) - vmovups(xmm2, mem(rcx, 16)) // store c20:c30 - - - - vmovups(mem(r10), xmm0) // load c01:c11 - vmovups(mem(r10, 16), xmm2) // load c21:c31 - vpermilps(imm(0xb1), xmm0, xmm1) - vpermilps(imm(0xb1), xmm2, xmm3) - - vmulps(xmm6, xmm0, xmm0) - vmulps(xmm7, xmm1, xmm1) - vaddsubps(xmm1, xmm0, xmm0) - vaddps(xmm10, xmm0, xmm0) - vmovups(xmm0, mem(r10)) // store c01:c11 - - vmulps(xmm6, xmm2, xmm2) - vmulps(xmm7, xmm3, xmm3) - vaddsubps(xmm3, xmm2, xmm2) - vaddps(xmm14, xmm2, xmm2) - vmovups(xmm2, mem(r10, 16)) // store c21:c31 - - - - jmp(.CDONE) // jump to end. - - - + + vmovups(mem(rcx), xmm0) // load c00:c10 + vmovups(mem(rcx, 16), xmm2) // load c20:c30 + vpermilps(imm(0xb1), xmm0, xmm1) + vpermilps(imm(0xb1), xmm2, xmm3) + + vmulps(xmm6, xmm0, xmm0) + vmulps(xmm7, xmm1, xmm1) + vaddsubps(xmm1, xmm0, xmm0) + vaddps(xmm8, xmm0, xmm0) + + vmulps(xmm6, xmm2, xmm2) + vmulps(xmm7, xmm3, xmm3) + vaddsubps(xmm3, xmm2, xmm2) + vaddps(xmm12, xmm2, xmm2) + + vmovups(mem(r10), xmm0) // load c01:c11 + vmovups(mem(r10, 16), xmm2) // load c21:c31 + vpermilps(imm(0xb1), xmm0, xmm1) + vpermilps(imm(0xb1), xmm2, xmm3) + + vmulps(xmm6, xmm0, xmm0) + vmulps(xmm7, xmm1, xmm1) + vaddsubps(xmm1, xmm0, xmm0) + vaddps(xmm10, xmm0, xmm0) + + vmulps(xmm6, xmm2, xmm2) + vmulps(xmm7, xmm3, xmm3) + vaddsubps(xmm3, xmm2, xmm2) + vaddps(xmm14, xmm2, xmm2) + + // fall through + label(.CBETAZERO) - // check if aligned/column-stored - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. 
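/* beta is complex as well, so the retained epilogue above applies the same
   permute/multiply/addsub recipe to the freshly loaded c before summing in
   the alpha-scaled AB tile -- elementwise c' = beta*c + alpha*AB over
   scomplex pairs -- and then falls through so that the shared stores after
   .CBETAZERO write out the updated tile. */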
- jne(.CCOLSTORBZ) // jump to column storage case - - - - label(.CGENSTORBZ) - - - vmovlps(xmm8, mem(rcx)) // store c00:c10 - vmovhps(xmm8, mem(rcx, rsi, 1)) - - vmovlps(xmm12, mem(rcx, r12, 1)) // store c20:c30 - vmovhps(xmm12, mem(rcx, r13, 1)) - - vmovlps(xmm10, mem(r10)) // store c01:c11 - vmovhps(xmm10, mem(r10, rsi, 1)) - - vmovlps(xmm14, mem(r10, r12, 1)) // store c21:c31 - vmovhps(xmm14, mem(r10, r13, 1)) - - - - jmp(.CDONE) // jump to end. - - - - label(.CCOLSTORBZ) - - - vmovups(xmm8, mem(rcx)) // store c00:c10 - vmovups(xmm12, mem(rcx, 16)) // store c20:c30 - - vmovups(xmm10, mem(r10)) // store c01:c11 - vmovups(xmm14, mem(r10, 16)) // store c21:c31 - - - - - + + vmovups(xmm8, mem(rcx)) // store c00:c10 + vmovups(xmm12, mem(rcx, 16)) // store c20:c30 + + vmovups(xmm10, mem(r10)) // store c01:c11 + vmovups(xmm14, mem(r10, 16)) // store c21:c31 + label(.CDONE) - - end_asm( + + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c), // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next) // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c), // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next) // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -2176,11 +1383,15 @@ void bli_cgemm_piledriver_asm_4x2 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( c ); } void bli_zgemm_piledriver_asm_2x2 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, dcomplex* restrict alpha, dcomplex* restrict a, dcomplex* restrict b, @@ -2195,28 +1406,30 @@ void bli_zgemm_piledriver_asm_2x2 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 8; - uint64_t k_left = k0 % 8; + uint64_t k_iter = k / 8; + uint64_t k_left = k % 8; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( z, 2, 2, false ); + begin_asm() - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. mov(var(b_next), r15) // load address of b_next. mov(var(a_next), r14) // load address of a_next. - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(dcomplex) lea(mem(, rdi, 2), rdi) lea(mem(rcx, rdi, 1), r10) // load address of c + 1*cs_c; - + add(imm(16*8), rax) add(imm(16*8), rbx) - + vxorpd(xmm8, xmm8, xmm8) vxorpd(xmm9, xmm9, xmm9) vxorpd(xmm10, xmm10, xmm10) @@ -2225,25 +1438,25 @@ void bli_zgemm_piledriver_asm_2x2 vxorpd(xmm13, xmm13, xmm13) vxorpd(xmm14, xmm14, xmm14) vxorpd(xmm15, xmm15, xmm15) - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.ZCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.ZLOOPKITER) // MAIN LOOP - - + + je(.ZCONSIDKLEFT) // if i == 0, jump to k_left code. 
- - + + prefetch(0, mem(rbx, 256)) - + prefetch(0, mem(rax, 512)) - + // iteration 0 vmovaps(mem(rax, -16*8), xmm0) vmovddup(mem(rbx, -16*8), xmm4) @@ -2261,7 +1474,7 @@ void bli_zgemm_piledriver_asm_2x2 vmovaps(mem(rax, -12*8), xmm0) vmovddup(mem(rbx, -12*8), xmm4) vfmadd231pd(xmm1, xmm7, xmm15) - + // iteration 1 vfmadd231pd(xmm0, xmm4, xmm8) vmovaps(mem(rax, -10*8), xmm1) @@ -2277,11 +1490,11 @@ void bli_zgemm_piledriver_asm_2x2 vmovaps(mem(rax, -8*8), xmm0) vmovddup(mem(rbx, -8*8), xmm4) vfmadd231pd(xmm1, xmm7, xmm15) - + prefetch(0, mem(rbx, 64+256)) - + prefetch(0, mem(rax, 64+512)) - + // iteration 2 vfmadd231pd(xmm0, xmm4, xmm8) vmovaps(mem(rax, -6*8), xmm1) @@ -2297,7 +1510,7 @@ void bli_zgemm_piledriver_asm_2x2 vmovaps(mem(rax, -4*8), xmm0) vmovddup(mem(rbx, -4*8), xmm4) vfmadd231pd(xmm1, xmm7, xmm15) - + // iteration 3 vfmadd231pd(xmm0, xmm4, xmm8) vmovaps(mem(rax, -2*8), xmm1) @@ -2313,11 +1526,11 @@ void bli_zgemm_piledriver_asm_2x2 vmovaps(mem(rax, 0*8), xmm0) vmovddup(mem(rbx, 0*8), xmm4) vfmadd231pd(xmm1, xmm7, xmm15) - + prefetch(0, mem(rbx, 128+256)) - + prefetch(0, mem(rax, 128+512)) - + // iteration 4 vfmadd231pd(xmm0, xmm4, xmm8) vmovaps(mem(rax, 2*8), xmm1) @@ -2333,7 +1546,7 @@ void bli_zgemm_piledriver_asm_2x2 vmovaps(mem(rax, 4*8), xmm0) vmovddup(mem(rbx, 4*8), xmm4) vfmadd231pd(xmm1, xmm7, xmm15) - + // iteration 5 vfmadd231pd(xmm0, xmm4, xmm8) vmovaps(mem(rax, 6*8), xmm1) @@ -2349,11 +1562,11 @@ void bli_zgemm_piledriver_asm_2x2 vmovaps(mem(rax, 8*8), xmm0) vmovddup(mem(rbx, 8*8), xmm4) vfmadd231pd(xmm1, xmm7, xmm15) - + prefetch(0, mem(rbx, 128+256)) - + prefetch(0, mem(rax, 128+512)) - + // iteration 6 vfmadd231pd(xmm0, xmm4, xmm8) vmovaps(mem(rax, 10*8), xmm1) @@ -2369,7 +1582,7 @@ void bli_zgemm_piledriver_asm_2x2 vmovaps(mem(rax, 12*8), xmm0) vmovddup(mem(rbx, 12*8), xmm4) vfmadd231pd(xmm1, xmm7, xmm15) - + // iteration 7 vfmadd231pd(xmm0, xmm4, xmm8) vmovaps(mem(rax, 14*8), xmm1) @@ -2385,34 +1598,34 @@ void bli_zgemm_piledriver_asm_2x2 add(imm(8*2*16), rbx) // b += 8*2 (unroll x nr) vfmadd231pd(xmm0, xmm7, xmm11) vfmadd231pd(xmm1, xmm7, xmm15) - - - + + + dec(rsi) // i -= 1; jmp(.ZLOOPKITER) // jump to beginning of loop. - - - - - - + + + + + + label(.ZCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.ZPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.ZLOOPKLEFT) // EDGE LOOP - - + + je(.ZPOSTACCUM) // if i == 0, we're done. - + prefetch(0, mem(rbx, 256)) - + prefetch(0, mem(rax, 512)) - + // iteration 0 vmovaps(mem(rax, -16*8), xmm0) vmovddup(mem(rbx, -16*8), xmm4) @@ -2428,119 +1641,86 @@ void bli_zgemm_piledriver_asm_2x2 vmovddup(mem(rbx, -13*8), xmm7) vfmadd231pd(xmm0, xmm7, xmm11) vfmadd231pd(xmm1, xmm7, xmm15) - - + + add(imm(1*2*16), rax) // a += 1*2 (1 x mr) add(imm(1*2*16), rbx) // b += 1*2 (1 x nr) - - + + dec(rsi) // i -= 1; jmp(.ZLOOPKLEFT) // jump to beginning of loop. 
- - - + + + label(.ZPOSTACCUM) - - + + prefetchw0(mem(rcx, 0*8)) // prefetch c + 0*cs_c prefetchw0(mem(r10, 0*8)) // prefetch c + 1*cs_c - - + + vpermilpd(imm(0x1), xmm9, xmm9) vpermilpd(imm(0x1), xmm11, xmm11) vpermilpd(imm(0x1), xmm13, xmm13) vpermilpd(imm(0x1), xmm15, xmm15) - + vaddsubpd(xmm9, xmm8, xmm8) vaddsubpd(xmm11, xmm10, xmm10) vaddsubpd(xmm13, xmm12, xmm12) vaddsubpd(xmm15, xmm14, xmm14) - - + + // xmm8: xmm10: // ( ab00 ( ab01 // ab10 ) ab11 ) - + // xmm12: xmm14: // ( ab20 ( ab21 // ab30 ) ab31 ) - - + + prefetch(0, mem(r14)) // prefetch a_next prefetch(0, mem(r14, 64)) // prefetch a_next - - + + // scale by alpha - + mov(var(alpha), rax) // load address of alpha vmovddup(mem(rax), xmm0) // load alpha_r and duplicate vmovddup(mem(rax, 8), xmm1) // load alpha_i and duplicate - + vpermilpd(imm(0x1), xmm8, xmm9) vpermilpd(imm(0x1), xmm10, xmm11) vpermilpd(imm(0x1), xmm12, xmm13) vpermilpd(imm(0x1), xmm14, xmm15) - + vmulpd(xmm8, xmm0, xmm8) vmulpd(xmm10, xmm0, xmm10) vmulpd(xmm12, xmm0, xmm12) vmulpd(xmm14, xmm0, xmm14) - + vmulpd(xmm9, xmm1, xmm9) vmulpd(xmm11, xmm1, xmm11) vmulpd(xmm13, xmm1, xmm13) vmulpd(xmm15, xmm1, xmm15) - + vaddsubpd(xmm9, xmm8, xmm8) vaddsubpd(xmm11, xmm10, xmm10) vaddsubpd(xmm13, xmm12, xmm12) vaddsubpd(xmm15, xmm14, xmm14) - - - - + + + + mov(var(beta), rbx) // load address of beta vmovddup(mem(rbx), xmm6) // load beta_r and duplicate vmovddup(mem(rbx, 8), xmm7) // load beta_i and duplicate - - - - - - - - mov(var(rs_c), rsi) // load rs_c - lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(dcomplex) - lea(mem(, rsi, 2), rsi) - //lea(mem(rcx, rsi, 2), rdx) // load address of c + 2*rs_c; - - - - - + prefetch(0, mem(r15)) // prefetch b_next prefetch(0, mem(r15, 64)) // prefetch b_next - - - - // determine if - // c % 32 == 0, AND - // 16*cs_c % 32 == 0, AND - // rs_c == 1 - // ie: aligned, ldim aligned, and - // column-stored - - cmp(imm(16), rsi) // set ZF if (16*rs_c) == 16. - sete(bl) // bl = ( ZF == 1 ? 1 : 0 ); - test(imm(31), rcx) // set ZF if c & 32 is zero. - setz(bh) // bh = ( ZF == 0 ? 1 : 0 ); - test(imm(31), rdi) // set ZF if (16*cs_c) & 32 is zero. - setz(al) // al = ( ZF == 0 ? 1 : 0 ); - // and(bl,bh) followed by - // and(bh,al) will reveal result - + // now avoid loading C if beta == 0 - + vxorpd(xmm0, xmm0, xmm0) // set xmm0 to zero. vucomisd(xmm0, xmm6) // set ZF if beta_r == 0. sete(r8b) // r8b = ( ZF == 1 ? 1 : 0 ); @@ -2548,161 +1728,66 @@ void bli_zgemm_piledriver_asm_2x2 sete(r9b) // r9b = ( ZF == 1 ? 1 : 0 ); and(r8b, r9b) // set ZF if r8b & r9b == 1. jne(.ZBETAZERO) // if ZF = 0, jump to beta == 0 case - - - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. 
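/* With complex beta the zero test needs both components; the sete/sete/and
   triple above computes, in effect:

       if ( beta->real == 0.0 && beta->imag == 0.0 )
           goto ZBETAZERO;   // store-only epilogue, c is never read
*/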
- jne(.ZCOLSTORED) // jump to column storage case - - - - label(.ZGENSTORED) - - - vmovups(mem(rcx), xmm0) // load c00 - vmovups(mem(rcx, rsi, 1), xmm2) // load c10 - vpermilpd(imm(0x1), xmm0, xmm1) - vpermilpd(imm(0x1), xmm2, xmm3) - - vmulpd(xmm6, xmm0, xmm0) - vmulpd(xmm7, xmm1, xmm1) - vaddsubpd(xmm1, xmm0, xmm0) - vaddpd(xmm8, xmm0, xmm0) - vmovups(xmm0, mem(rcx)) // store c00 - - vmulpd(xmm6, xmm2, xmm2) - vmulpd(xmm7, xmm3, xmm3) - vaddsubpd(xmm3, xmm2, xmm2) - vaddpd(xmm12, xmm2, xmm2) - vmovups(xmm2, mem(rcx, rsi, 1)) // store c10 - - - - vmovups(mem(r10), xmm0) // load c01 - vmovups(mem(r10, rsi, 1), xmm2) // load c11 - vpermilpd(imm(0x1), xmm0, xmm1) - vpermilpd(imm(0x1), xmm2, xmm3) - - vmulpd(xmm6, xmm0, xmm0) - vmulpd(xmm7, xmm1, xmm1) - vaddsubpd(xmm1, xmm0, xmm0) - vaddpd(xmm10, xmm0, xmm0) - vmovups(xmm0, mem(r10)) // store c01 - - vmulpd(xmm6, xmm2, xmm2) - vmulpd(xmm7, xmm3, xmm3) - vaddsubpd(xmm3, xmm2, xmm2) - vaddpd(xmm14, xmm2, xmm2) - vmovups(xmm2, mem(r10, rsi, 1)) // store c11 - - - - jmp(.ZDONE) // jump to end. - - - - label(.ZCOLSTORED) - - - vmovups(mem(rcx), xmm0) // load c00 - vmovups(mem(rcx, 16), xmm2) // load c10 - vpermilpd(imm(0x1), xmm0, xmm1) - vpermilpd(imm(0x1), xmm2, xmm3) - - vmulpd(xmm6, xmm0, xmm0) - vmulpd(xmm7, xmm1, xmm1) - vaddsubpd(xmm1, xmm0, xmm0) - vaddpd(xmm8, xmm0, xmm0) - vmovups(xmm0, mem(rcx)) // store c00 - - vmulpd(xmm6, xmm2, xmm2) - vmulpd(xmm7, xmm3, xmm3) - vaddsubpd(xmm3, xmm2, xmm2) - vaddpd(xmm12, xmm2, xmm2) - vmovups(xmm2, mem(rcx, 16)) // store c10 - - - - vmovups(mem(r10), xmm0) // load c01 - vmovups(mem(r10, 16), xmm2) // load c11 - vpermilpd(imm(0x1), xmm0, xmm1) - vpermilpd(imm(0x1), xmm2, xmm3) - - vmulpd(xmm6, xmm0, xmm0) - vmulpd(xmm7, xmm1, xmm1) - vaddsubpd(xmm1, xmm0, xmm0) - vaddpd(xmm10, xmm0, xmm0) - vmovups(xmm0, mem(r10)) // store c01 - - vmulpd(xmm6, xmm2, xmm2) - vmulpd(xmm7, xmm3, xmm3) - vaddsubpd(xmm3, xmm2, xmm2) - vaddpd(xmm14, xmm2, xmm2) - vmovups(xmm2, mem(r10, 16)) // store c11 - - - - jmp(.ZDONE) // jump to end. - - - + + vmovups(mem(rcx), xmm0) // load c00 + vmovups(mem(rcx, 16), xmm2) // load c10 + vpermilpd(imm(0x1), xmm0, xmm1) + vpermilpd(imm(0x1), xmm2, xmm3) + + vmulpd(xmm6, xmm0, xmm0) + vmulpd(xmm7, xmm1, xmm1) + vaddsubpd(xmm1, xmm0, xmm0) + vaddpd(xmm8, xmm0, xmm0) + + vmulpd(xmm6, xmm2, xmm2) + vmulpd(xmm7, xmm3, xmm3) + vaddsubpd(xmm3, xmm2, xmm2) + vaddpd(xmm12, xmm2, xmm2) + + vmovups(mem(r10), xmm0) // load c01 + vmovups(mem(r10, 16), xmm2) // load c11 + vpermilpd(imm(0x1), xmm0, xmm1) + vpermilpd(imm(0x1), xmm2, xmm3) + + vmulpd(xmm6, xmm0, xmm0) + vmulpd(xmm7, xmm1, xmm1) + vaddsubpd(xmm1, xmm0, xmm0) + vaddpd(xmm10, xmm0, xmm0) + + vmulpd(xmm6, xmm2, xmm2) + vmulpd(xmm7, xmm3, xmm3) + vaddsubpd(xmm3, xmm2, xmm2) + vaddpd(xmm14, xmm2, xmm2) + + // fall through + label(.ZBETAZERO) - // check if aligned/column-stored - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. - jne(.ZCOLSTORBZ) // jump to column storage case - - - - label(.ZGENSTORBZ) - - - vmovups(xmm8, mem(rcx)) // store c00 - vmovups(xmm12, mem(rcx, rsi, 1)) // store c10 - - vmovups(xmm10, mem(r10)) // store c01 - vmovups(xmm14, mem(r10, rsi, 1)) // store c11 - - - - jmp(.ZDONE) // jump to end. 
- - - - label(.ZCOLSTORBZ) - - - vmovups(xmm8, mem(rcx)) // store c00 - vmovups(xmm12, mem(rcx, 16)) // store c10 - - vmovups(xmm10, mem(r10)) // store c01 - vmovups(xmm14, mem(r10, 16)) // store c11 - - - - - + + vmovups(xmm8, mem(rcx)) // store c00 + vmovups(xmm12, mem(rcx, 16)) // store c10 + + vmovups(xmm10, mem(r10)) // store c01 + vmovups(xmm14, mem(r10, 16)) // store c11 + label(.ZDONE) - - end_asm( + + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c), // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next) // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c), // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next) // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -2710,6 +1795,8 @@ void bli_zgemm_piledriver_asm_2x2 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( z ); } diff --git a/kernels/power10/3/bli_dgemm_power10_mma.c b/kernels/power10/3/bli_dgemm_power10_mma.c index 3968249869..84e7d16d34 100644 --- a/kernels/power10/3/bli_dgemm_power10_mma.c +++ b/kernels/power10/3/bli_dgemm_power10_mma.c @@ -37,7 +37,7 @@ #define D_ASSEMBLE_VEC_PAIR \ __builtin_mma_assemble_pair (&colA_1, ca[1], ca[0]); \ - __builtin_mma_assemble_pair (&colA_2, ca[3], ca[2]); + __builtin_mma_assemble_pair (&colA_2, ca[3], ca[2]); #define D_ACCUMULATE \ __builtin_mma_xvf64gerpp (&acc0, colA_1, rb[0]); \ @@ -47,7 +47,7 @@ __builtin_mma_xvf64gerpp (&acc4, colA_2, rb[0]); \ __builtin_mma_xvf64gerpp (&acc5, colA_2, rb[1]); \ __builtin_mma_xvf64gerpp (&acc6, colA_2, rb[2]); \ - __builtin_mma_xvf64gerpp (&acc7, colA_2, rb[3]); + __builtin_mma_xvf64gerpp (&acc7, colA_2, rb[3]); #define D_INCREMENT \ A0+=8; \ @@ -57,17 +57,19 @@ LOAD_VECTORS \ D_ASSEMBLE_VEC_PAIR \ D_INCREMENT \ - D_ACCUMULATE + D_ACCUMULATE void bli_dgemm_power10_mma_8x8 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, double* restrict alpha, double* restrict a, double* restrict b, double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict c, inc_t rs_c0, inc_t cs_c, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -76,11 +78,13 @@ void bli_dgemm_power10_mma_8x8 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. // (1 is subtracted from k0 because 1 iteration of the k loop is pulled out) - uint64_t k_iter = (k0-1) / 4; - uint64_t k_left = (k0-1) % 4; + uint64_t k_iter = (k-1) / 4; + uint64_t k_left = (k-1) % 4; uint64_t rs_c = rs_c0; + GEMM_UKR_SETUP_CT( d, 8, 8, true ); + double* restrict A0 = a; double* restrict B0 = b; double* restrict C0 = c; @@ -92,23 +96,23 @@ void bli_dgemm_power10_mma_8x8 dv4sf_t *rowC; /* 8 accumulator registers that will be used to store the result. - + Each accumulator register is mapped to 4 vector registers. 
Illustration: - + acc0 = [ vs0 vs1 vs2 vs3 ] - These registers are used to store the result of an outer product + These registers are used to store the result of an outer product instruction (general outer product instruction syntax: xv???ger??). */ - __vector_quad acc0, acc1, acc2, acc3, + __vector_quad acc0, acc1, acc2, acc3, acc4, acc5, acc6, acc7; - /* 2 vector pairs are necessary for a double precision outer product + /* 2 vector pairs are necessary for a double precision outer product instruction. */ - __vector_pair colA_1, + __vector_pair colA_1, colA_2; /* Prefetch C so that it stays in cache */ @@ -123,17 +127,17 @@ void bli_dgemm_power10_mma_8x8 /* Load elements into vector registers */ vec_t *ca = (vec_t *) A0; - vec_t *rb = (vec_t *) B0; + vec_t *rb = (vec_t *) B0; - /* Each accumulator represents a matrix of size + /* Each accumulator represents a matrix of size 4 x ( 16 / (datatype size in bytes) ) (vector register size = 16B) - Thus in the case of double, the accumulate registers represent a 4x2 + Thus in the case of double, the accumulator registers represent a 4x2 matrix. However, a vector register can hold at most 2 doubles. Thus, if - we performed an outer product using 2 vector register, we can only get a + we performed an outer product using 2 vector registers, we can only get a 2x2 matrix. Therefore, we must create a vector register pair in order to get the desired 4x2 matrix. - + */ D_ASSEMBLE_VEC_PAIR @@ -158,7 +162,7 @@ void bli_dgemm_power10_mma_8x8 D_AB_PRODUCT D_AB_PRODUCT } - + // edge loop for (int k = 0; k < k_left; k++) { [...] + int kk; + for (kk=k; kk > 0; kk--) { + vector double va00 = vec_splats( *(double *)( pa+0 ) ); + vector double va10 = vec_splats( *(double *)( pa+d1 ) ); + vector double va20 = vec_splats( *(double *)( pa+d2 ) ); + vector double va30 = vec_splats( *(double *)( pa+d3 ) ); + vector double va40 = vec_splats( *(double *)( pa+d4 ) ); + vector double va50 = vec_splats( *(double *)( pa+d5 ) ); + vector double va60 = vec_splats( *(double *)( pa+d6 ) ); + vector double va70 = vec_splats( *(double *)( pa+d7 ) ); + pa += 8*sizeof(double); + + vector double vb00_01 = *(vector double *)( pb+0 ); + vector double vb02_03 = *(vector double *)( pb+d2 ); + pb += 4*sizeof(double); + + vc00_01 = vec_madd(va00, vb00_01, vc00_01); + vc02_03 = vec_madd(va00, vb02_03, vc02_03); + vc10_11 = vec_madd(va10, vb00_01, vc10_11); + vc12_13 = vec_madd(va10, vb02_03, vc12_13); + vc20_21 = vec_madd(va20, vb00_01, vc20_21); + vc22_23 = vec_madd(va20, vb02_03, vc22_23); + vc30_31 = vec_madd(va30, vb00_01, vc30_31); + vc32_33 = vec_madd(va30, vb02_03, vc32_33); + vc40_41 = vec_madd(va40, vb00_01, vc40_41); + vc42_43 = vec_madd(va40, vb02_03, vc42_43); + vc50_51 = vec_madd(va50, vb00_01, vc50_51); + vc52_53 = vec_madd(va50, vb02_03, vc52_53); + vc60_61 = vec_madd(va60, vb00_01, vc60_61); + vc62_63 = vec_madd(va60, vb02_03, vc62_63); + vc70_71 = vec_madd(va70, vb00_01, vc70_71); + vc72_73 = vec_madd(va70, vb02_03, vc72_73); + } + + vector double valpha = vec_splats( *alpha ); + vector double vbeta = (vector double) { *beta, *beta }; + + vector double *pc = (vector double *)c; + + vc00_01 = vec_mul(valpha, vc00_01); + vc02_03 = vec_mul(valpha, vc02_03); + pc[0] = vec_madd( pc[0], vbeta, vc00_01); + pc[1] = vec_madd( pc[1], vbeta, vc02_03); + pc += rs_c/2; + + vc10_11 = vec_mul(valpha, vc10_11); + vc12_13 = vec_mul(valpha, vc12_13); + pc[0] = vec_madd( pc[0], vbeta, vc10_11); + pc[1] = vec_madd( pc[1], vbeta, vc12_13); + pc += rs_c/2; + + vc20_21 = vec_mul(valpha, vc20_21); + vc22_23 = vec_mul(valpha, vc22_23); + pc[0] =
vec_madd( pc[0], vbeta, vc20_21); + pc[1] = vec_madd( pc[1], vbeta, vc22_23); + pc += rs_c/2; + + vc30_31 = vec_mul(valpha, vc30_31); + vc32_33 = vec_mul(valpha, vc32_33); + pc[0] = vec_madd( pc[0], vbeta, vc30_31); + pc[1] = vec_madd( pc[1], vbeta, vc32_33); + pc += rs_c/2; + + vc40_41 = vec_mul(valpha, vc40_41); + vc42_43 = vec_mul(valpha, vc42_43); + pc[0] = vec_madd( pc[0], vbeta, vc40_41); + pc[1] = vec_madd( pc[1], vbeta, vc42_43); + pc += rs_c/2; + + vc50_51 = vec_mul(valpha, vc50_51); + vc52_53 = vec_mul(valpha, vc52_53); + pc[0] = vec_madd( pc[0], vbeta, vc50_51); + pc[1] = vec_madd( pc[1], vbeta, vc52_53); + pc += rs_c/2; + + vc60_61 = vec_mul(valpha, vc60_61); + vc62_63 = vec_mul(valpha, vc62_63); + pc[0] = vec_madd( pc[0], vbeta, vc60_61); + pc[1] = vec_madd( pc[1], vbeta, vc62_63); + pc += rs_c/2; + + vc70_71 = vec_mul(valpha, vc70_71); + vc72_73 = vec_mul(valpha, vc72_73); + pc[0] = vec_madd( pc[0], vbeta, vc70_71); + pc[1] = vec_madd( pc[1], vbeta, vc72_73); + pc += rs_c/2; + } + else + { + GEMM_UKR_SETUP_CT( d, 8, 4, false ); + // Optimized code for case where C columns are contiguous (column-major C) vector double vzero = vec_splats( 0.0 ); @@ -301,168 +433,8 @@ void bli_dgemm_power7_int_8x4 pc[1] = vec_madd( pc[1], vbeta, vc23_33); pc[2] = vec_madd( pc[2], vbeta, vc43_53); pc[3] = vec_madd( pc[3], vbeta, vc63_73); - } - else -#endif -#if 1 - if ( cs_c == 1 ) { - // Optimized code for case where C rows are contiguous (i.e. C is row-major) - - vector double vzero = vec_splats( 0.0 ); - - vector double vc00_01 = vzero; - vector double vc02_03 = vzero; - vector double vc10_11 = vzero; - vector double vc12_13 = vzero; - vector double vc20_21 = vzero; - vector double vc22_23 = vzero; - vector double vc30_31 = vzero; - vector double vc32_33 = vzero; - vector double vc40_41 = vzero; - vector double vc42_43 = vzero; - vector double vc50_51 = vzero; - vector double vc52_53 = vzero; - vector double vc60_61 = vzero; - vector double vc62_63 = vzero; - vector double vc70_71 = vzero; - vector double vc72_73 = vzero; - - unsigned long long pa = (unsigned long long)a; - unsigned long long pb = (unsigned long long)b; - -#if 0 - unsigned long long d1 = 1*sizeof(double); - unsigned long long d2 = 2*sizeof(double); - unsigned long long d3 = 3*sizeof(double); - unsigned long long d4 = 4*sizeof(double); - unsigned long long d6 = 6*sizeof(double); -#else - // ppc64 linux abi: r14-r31 Nonvolatile registers used for local variables - register unsigned long long d1 __asm ("r21") = 1*sizeof(double); - register unsigned long long d2 __asm ("r22") = 2*sizeof(double); - register unsigned long long d3 __asm ("r23") = 3*sizeof(double); - register unsigned long long d4 __asm ("r24") = 4*sizeof(double); - register unsigned long long d5 __asm ("r25") = 5*sizeof(double); - register unsigned long long d6 __asm ("r26") = 6*sizeof(double); - register unsigned long long d7 __asm ("r27") = 7*sizeof(double); - - __asm__ volatile (";" : "=r" (d1) : "r" (d1) ); - __asm__ volatile (";" : "=r" (d2) : "r" (d2) ); - __asm__ volatile (";" : "=r" (d3) : "r" (d3) ); - __asm__ volatile (";" : "=r" (d4) : "r" (d4) ); - __asm__ volatile (";" : "=r" (d5) : "r" (d5) ); - __asm__ volatile (";" : "=r" (d6) : "r" (d6) ); - __asm__ volatile (";" : "=r" (d7) : "r" (d7) ); -#endif - - int kk; - for (kk=k; kk > 0; kk--) { - vector double va00 = vec_splats( *(double *)( pa+0 ) ); - vector double va10 = vec_splats( *(double *)( pa+d1 ) ); - vector double va20 = vec_splats( *(double *)( pa+d2 ) ); - vector double va30 = vec_splats( 
*(double *)( pa+d3 ) ); - vector double va40 = vec_splats( *(double *)( pa+d4 ) ); - vector double va50 = vec_splats( *(double *)( pa+d5 ) ); - vector double va60 = vec_splats( *(double *)( pa+d6 ) ); - vector double va70 = vec_splats( *(double *)( pa+d7 ) ); - pa += 8*sizeof(double); - - vector double vb00_01 = *(vector double *)( pb+0 ); - vector double vb02_03 = *(vector double *)( pb+d2 ); - pb += 4*sizeof(double); - - vc00_01 = vec_madd(va00, vb00_01, vc00_01); - vc02_03 = vec_madd(va00, vb02_03, vc02_03); - vc10_11 = vec_madd(va10, vb00_01, vc10_11); - vc12_13 = vec_madd(va10, vb02_03, vc12_13); - vc20_21 = vec_madd(va20, vb00_01, vc20_21); - vc22_23 = vec_madd(va20, vb02_03, vc22_23); - vc30_31 = vec_madd(va30, vb00_01, vc30_31); - vc32_33 = vec_madd(va30, vb02_03, vc32_33); - vc40_41 = vec_madd(va40, vb00_01, vc40_41); - vc42_43 = vec_madd(va40, vb02_03, vc42_43); - vc50_51 = vec_madd(va50, vb00_01, vc50_51); - vc52_53 = vec_madd(va50, vb02_03, vc52_53); - vc60_61 = vec_madd(va60, vb00_01, vc60_61); - vc62_63 = vec_madd(va60, vb02_03, vc62_63); - vc70_71 = vec_madd(va70, vb00_01, vc70_71); - vc72_73 = vec_madd(va70, vb02_03, vc72_73); - } - - vector double valpha = vec_splats( *alpha ); - vector double vbeta = (vector double) { *beta, *beta }; - - vector double *pc = (vector double *)c; - - vc00_01 = vec_mul(valpha, vc00_01); - vc02_03 = vec_mul(valpha, vc02_03); - pc[0] = vec_madd( pc[0], vbeta, vc00_01); - pc[1] = vec_madd( pc[1], vbeta, vc02_03); - pc += rs_c/2; - - vc10_11 = vec_mul(valpha, vc10_11); - vc12_13 = vec_mul(valpha, vc12_13); - pc[0] = vec_madd( pc[0], vbeta, vc10_11); - pc[1] = vec_madd( pc[1], vbeta, vc12_13); - pc += rs_c/2; - - vc20_21 = vec_mul(valpha, vc20_21); - vc22_23 = vec_mul(valpha, vc22_23); - pc[0] = vec_madd( pc[0], vbeta, vc20_21); - pc[1] = vec_madd( pc[1], vbeta, vc22_23); - pc += rs_c/2; - - vc30_31 = vec_mul(valpha, vc30_31); - vc32_33 = vec_mul(valpha, vc32_33); - pc[0] = vec_madd( pc[0], vbeta, vc30_31); - pc[1] = vec_madd( pc[1], vbeta, vc32_33); - pc += rs_c/2; - - vc40_41 = vec_mul(valpha, vc40_41); - vc42_43 = vec_mul(valpha, vc42_43); - pc[0] = vec_madd( pc[0], vbeta, vc40_41); - pc[1] = vec_madd( pc[1], vbeta, vc42_43); - pc += rs_c/2; - - vc50_51 = vec_mul(valpha, vc50_51); - vc52_53 = vec_mul(valpha, vc52_53); - pc[0] = vec_madd( pc[0], vbeta, vc50_51); - pc[1] = vec_madd( pc[1], vbeta, vc52_53); - pc += rs_c/2; - - vc60_61 = vec_mul(valpha, vc60_61); - vc62_63 = vec_mul(valpha, vc62_63); - pc[0] = vec_madd( pc[0], vbeta, vc60_61); - pc[1] = vec_madd( pc[1], vbeta, vc62_63); - pc += rs_c/2; - - vc70_71 = vec_mul(valpha, vc70_71); - vc72_73 = vec_mul(valpha, vc72_73); - pc[0] = vec_madd( pc[0], vbeta, vc70_71); - pc[1] = vec_madd( pc[1], vbeta, vc72_73); - pc += rs_c/2; - } - else -#endif - { /* General case. Just do it right. 
*/ -#if 1 || defined(UTEST) - const long MR = BLIS_DEFAULT_MR_D, NR = BLIS_DEFAULT_NR_D; - const long LDA = MR, LDB = NR; - int i, j, kk; - double c00; - - for (i=0; i < MR; i++) { - for (j=0; j < NR; j++) { - c00 = c[BLIS_INDEX(i,j,rs_c,cs_c)] * *beta; - for (kk=0; kk < k; kk++) - c00 += *alpha * (a[COLMAJ_INDEX(i,kk,LDA)] * b[ROWMAJ_INDEX(kk,j,LDB)]); - c[BLIS_INDEX(i,j,rs_c,cs_c)] = c00; - } - } -#else - //BLIS_DGEMM_UKERNEL_REF(k, alpha, a, b, beta, c, rs_c, cs_c, data); -#endif + GEMM_UKR_FLUSH_CT( d ); } } @@ -477,30 +449,26 @@ void bli_dgemm_power7_int_8x4 */ void bli_cgemm_power7_int_8x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, scomplex* restrict alpha, scomplex* restrict a, scomplex* restrict b, scomplex* restrict beta, - scomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + scomplex* restrict c, inc_t rs_c, inc_t cs_c, auxinfo_t* restrict data, cntx_t* restrict cntx ) { - // Typecast local copies of integers in case dim_t and inc_t are a - // different size than is expected by load instructions. - uint64_t k = k0; - uint64_t rs_c = rs_c0; - uint64_t cs_c = cs_c0; - #if 1 || defined(UTEST) const long MR = BLIS_DEFAULT_MR_C, NR = BLIS_DEFAULT_NR_C; const long LDA = MR, LDB = NR; int i, j, kk; scomplex c00; - for (i=0; i < MR; i++) { - for (j=0; j < NR; j++) { + for (i=0; i < m; i++) { + for (j=0; j < n; j++) { scomplex tmpc, tmpa, tmpb, tmp; //c00 = c[BLIS_INDEX(i,j,rs_c,cs_c)] * *beta; tmpc = c[BLIS_INDEX(i,j,rs_c,cs_c)]; @@ -534,30 +502,26 @@ void bli_cgemm_power7_int_8x4 */ void bli_zgemm_power7_int_8x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, dcomplex* restrict alpha, dcomplex* restrict a, dcomplex* restrict b, dcomplex* restrict beta, - dcomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + dcomplex* restrict c, inc_t rs_c, inc_t cs_c, auxinfo_t* restrict data, cntx_t* restrict cntx ) { - // Typecast local copies of integers in case dim_t and inc_t are a - // different size than is expected by load instructions.
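The m x n scalar loops in the cgemm/zgemm kernels here, like the dgemm general case removed above, all follow the same reference pattern; a self-contained real-domain sketch with the BLIS_INDEX/COLMAJ_INDEX/ROWMAJ_INDEX macros written out (names illustrative):

/* Reference-style microkernel: C := beta*C + alpha*A*B on an m x n tile,
   with A packed column-major (leading dimension mr) and B packed
   row-major (leading dimension nr), as in the loops above. */
static void gemm_ref_mxn( long m, long n, long k, long mr, long nr,
                          const double* a, const double* b,
                          double alpha, double beta,
                          double* c, long rs_c, long cs_c )
{
    for ( long i = 0; i < m; i++ )
        for ( long j = 0; j < n; j++ )
        {
            double ab = 0.0;
            for ( long p = 0; p < k; p++ )
                ab += a[ i + p*mr ] * b[ j + p*nr ];  /* COLMAJ/ROWMAJ indexing */
            c[ i*rs_c + j*cs_c ] = beta * c[ i*rs_c + j*cs_c ] + alpha * ab;
        }
}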
- uint64_t k = k0; - uint64_t rs_c = rs_c0; - uint64_t cs_c = cs_c0; - #if 1 || defined(UTEST) const long MR = BLIS_DEFAULT_MR_Z, NR = BLIS_DEFAULT_NR_Z; const long LDA = MR, LDB = NR; int i, j, kk; dcomplex c00; - for (i=0; i < MR; i++) { - for (j=0; j < NR; j++) { + for (i=0; i < m; i++) { + for (j=0; j < n; j++) { dcomplex tmpc, tmpa, tmpb, tmp; //c00 = c[BLIS_INDEX(i,j,rs_c,cs_c)] * *beta; tmpc = c[BLIS_INDEX(i,j,rs_c,cs_c)]; diff --git a/kernels/power7/3/test/bli_gemm_power7_int_8x4.h b/kernels/power7/3/test/bli_gemm_power7_int_8x4.h index ef1930907e..50984a67df 100644 --- a/kernels/power7/3/test/bli_gemm_power7_int_8x4.h +++ b/kernels/power7/3/test/bli_gemm_power7_int_8x4.h @@ -43,6 +43,8 @@ void bli_sgemm_opt_8x4 ( + dim_t m, + dim_t n, dim_t k, float* restrict alpha, float* restrict a, @@ -55,6 +57,8 @@ void bli_sgemm_opt_8x4 void bli_dgemm_opt_8x4 ( + dim_t m, + dim_t n, dim_t k, double* restrict alpha, double* restrict a, @@ -67,6 +71,8 @@ void bli_dgemm_opt_8x4 void bli_cgemm_opt_8x4 ( + dim_t m, + dim_t n, dim_t k, scomplex* restrict alpha, scomplex* restrict a, @@ -79,6 +85,8 @@ void bli_cgemm_opt_8x4 void bli_zgemm_opt_8x4 ( + dim_t m, + dim_t n, dim_t k, dcomplex* restrict alpha, dcomplex* restrict a, diff --git a/kernels/power9/3/bli_gemm_power9_asm_d12x6.c b/kernels/power9/3/bli_gemm_power9_asm_d12x6.c index ec09f8e380..3e5f0d4164 100644 --- a/kernels/power9/3/bli_gemm_power9_asm_d12x6.c +++ b/kernels/power9/3/bli_gemm_power9_asm_d12x6.c @@ -37,7 +37,9 @@ void bli_dgemm_power9_asm_12x6 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, double* restrict alpha, double* restrict a, double* restrict b, @@ -50,117 +52,91 @@ void bli_dgemm_power9_asm_12x6 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. 
- uint64_t k_iter = k0 / 16; - uint64_t k_left = k0 % 16; + uint64_t k_iter = k / 16; + uint64_t k_left = k % 16; - uint64_t rs_c = rs_c0; + uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( d, 12, 6, false ); + __asm__ volatile - ( - " \n\t" - "ld %%r7, %2 \n\t" // load ptr of A - "ld %%r8, %3 \n\t" // load ptr of B - "ld %%r16, %6 \n\t" // load ptr of C - " \n\t" - "ld %%r28, %4 \n\t" // load ptr for alpha - "ld %%r29, %5 \n\t" // load ptr for beta - " \n\t" - "ld %%r11, %0 \n\t" // load k_iter - "ld %%r12, %1 \n\t" // load k_left - " \n\t" - "ld %%r10, %8 \n\t" // load cs_c - "slwi %%r10, %%r10, 3 \n\t" // mul by size of elem - " \n\t" - "ld %%r9, %7 \n\t" // load rs_c - "slwi %%r9, %%r9, 3 \n\t" // mul by size of elem - " \n\t" - "ld %%r26, 0(%%r29) \n\t" // load val of beta - " \n\t" - "lxvdsx %%vs62, 0, %%r28 \n\t" // splat alpha - "lxvdsx %%vs63, 0, %%r29 \n\t" // splat beta - " \n\t" - "add %%r17, %%r16, %%r10 \n\t" // addr of col 1 of C - "add %%r18, %%r17, %%r10 \n\t" // col 2 of C - "add %%r19, %%r18, %%r10 \n\t" // col 3 of C - "add %%r20, %%r19, %%r10 \n\t" // col 4 of C - "add %%r21, %%r20, %%r10 \n\t" // col 5 of C - " \n\t" - DZERO_OUT_VREG - " \n\t" - DPRELOAD - " \n\t" - "addi %%r8, %%r8, 96 \n\t" // move to next col/row of A/B - "addi %%r7, %%r7, 96 \n\t" - " \n\t" - DPREFETCH - " \n\t" - "cmpwi %%r11, 0 \n\t" // if k_iter == 0, - "beq DCONSIDERKLEFT \n\t" // then jmp to k_left - "mtctr %%r11 \n\t" // else, do k_iter loop - " \n\t" - "DLOOPKITER: \n\t" // k_iter loop - " \n\t" - A_B_PRODUCT_16 // compute A*B - " \n\t" - "bdnz DLOOPKITER \n\t" - " \n\t" - "DCONSIDERKLEFT: \n\t" - " \n\t" - "cmpwi %%r12, 0 \n\t" // if k_left == 0, - "beq DPOSTACCUM \n\t" // then jmp to post accum - "mtctr %%r12 \n\t" // else, do k_left loop - " \n\t" - "DLOOPKLEFT: \n\t" // k_left loop - " \n\t" - A_B_PRODUCT_1 - " \n\t" - "bdnz DLOOPKLEFT \n\t" - " \n\t" - "DPOSTACCUM: \n\t" - " \n\t" - DSCALE_ALPHA - " \n\t" - "cmpdi %%r26, 0 \n\t" // if beta == 0, - "beq DBETAZERO \n\t" // then jmp to BZ - " \n\t" - "cmpwi %%r9, 8 \n\t" // if rs_c == 8 - "beq DCOLSTOREDBNZ \n\t" // then jmp to col store - " \n\t" - "DGENSTOREDBNZ: \n\t" // BNZ gen stored case - " \n\t" - DGEN_LOAD_OFS_C - " \n\t" - DGEN_SCALE_BETA - " \n\t" - "b DGENSTORED \n\t" - " \n\t" - "DCOLSTOREDBNZ: \n\t" // BNZ col stored case - " \n\t" - DCOL_SCALE_BETA - " \n\t" - "b DCOLSTORED \n\t" - " \n\t" - "DBETAZERO: \n\t" // BZ case - " \n\t" - "cmpwi %%r9, 8 \n\t" // if rs_c == 8, - "beq DCOLSTORED \n\t" // C is col stored - " \n\t" - "DGENSTORED: \n\t" // BZ gen stored case - " \n\t" - DGEN_LOAD_OFS_C - " \n\t" - DGEN_STORE - " \n\t" - "b DDONE \n\t" - " \n\t" - "DCOLSTORED: \n\t" // BZ col stored case - " \n\t" - DCOL_STORE - " \n\t" - "DDONE: \n\t" - " \n\t" - : // output operands (none) + ( + " \n\t" + "ld %%r7, %2 \n\t" // load ptr of A + "ld %%r8, %3 \n\t" // load ptr of B + "ld %%r16, %6 \n\t" // load ptr of C + " \n\t" + "ld %%r28, %4 \n\t" // load ptr for alpha + "ld %%r29, %5 \n\t" // load ptr for beta + " \n\t" + "ld %%r11, %0 \n\t" // load k_iter + "ld %%r12, %1 \n\t" // load k_left + " \n\t" + "ld %%r10, %8 \n\t" // load cs_c + "slwi %%r10, %%r10, 3 \n\t" // mul by size of elem + " \n\t" + "ld %%r9, %7 \n\t" // load rs_c + "slwi %%r9, %%r9, 3 \n\t" // mul by size of elem + " \n\t" + "ld %%r26, 0(%%r29) \n\t" // load val of beta + " \n\t" + "lxvdsx %%vs62, 0, %%r28 \n\t" // splat alpha + "lxvdsx %%vs63, 0, %%r29 \n\t" // splat beta + " \n\t" + "add %%r17, %%r16, %%r10 \n\t" // addr of col 1 of C + "add 
%%r18, %%r17, %%r10 \n\t" // col 2 of C + "add %%r19, %%r18, %%r10 \n\t" // col 3 of C + "add %%r20, %%r19, %%r10 \n\t" // col 4 of C + "add %%r21, %%r20, %%r10 \n\t" // col 5 of C + " \n\t" + DZERO_OUT_VREG + " \n\t" + DPRELOAD + " \n\t" + "addi %%r8, %%r8, 96 \n\t" // move to next col/row of A/B + "addi %%r7, %%r7, 96 \n\t" + " \n\t" + DPREFETCH + " \n\t" + "cmpwi %%r11, 0 \n\t" // if k_iter == 0, + "beq DCONSIDERKLEFT \n\t" // then jmp to k_left + "mtctr %%r11 \n\t" // else, do k_iter loop + " \n\t" + "DLOOPKITER: \n\t" // k_iter loop + " \n\t" + A_B_PRODUCT_16 // compute A*B + " \n\t" + "bdnz DLOOPKITER \n\t" + " \n\t" + "DCONSIDERKLEFT: \n\t" + " \n\t" + "cmpwi %%r12, 0 \n\t" // if k_left == 0, + "beq DPOSTACCUM \n\t" // then jmp to post accum + "mtctr %%r12 \n\t" // else, do k_left loop + " \n\t" + "DLOOPKLEFT: \n\t" // k_left loop + " \n\t" + A_B_PRODUCT_1 + " \n\t" + "bdnz DLOOPKLEFT \n\t" + " \n\t" + "DPOSTACCUM: \n\t" + " \n\t" + DSCALE_ALPHA + " \n\t" + "cmpdi %%r26, 0 \n\t" // if beta == 0, + "beq DBETAZERO \n\t" // then jmp to BZ + " \n\t" + DCOL_SCALE_BETA + " \n\t" + "DBETAZERO: \n\t" // BZ case + " \n\t" + DCOL_STORE + " \n\t" + "DDONE: \n\t" + " \n\t" + : // output operands (none) : // input operands "m" (k_iter), // 0 "m" (k_left), // 1 @@ -174,28 +150,30 @@ void bli_dgemm_power9_asm_12x6 "m" (b_next), // 9 "m" (a_next)*/ // 10 : // register clobber list - /* unclobberable regs: r2, r3, r4, r5, r6, r13, r14, r15, r30, r31 */ - "r0", "r7", "r8", "r9", - "r10", "r11", "r12", "r16", "r17", "r18", "r19", - "r20", "r21", "r22", "r23", "r24", "r25", "r26", "r27", "r28", "r29" + /* unclobberable regs: r2, r3, r4, r5, r6, r13, r14, r15, r30, r31 */ + "r0", "r7", "r8", "r9", + "r10", "r11", "r12", "r16", "r17", "r18", "r19", + "r20", "r21", "r22", "r23", "r24", "r25", "r26", "r27", "r28", "r29" + + #if XLC + ,"f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f8", "f9" + , "f10", "f11", "f12", "f13", "f14", "f15", "f16", "f17", "f18", "f19" + , "f20" ,"f21", "f22", "f23", "f24", "f25", "f26", "f27", "f28", "f29" + , "f30" ,"f31" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9" + , "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19" + , "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29" + , "v30", "v31" + #else + , "vs0", "vs1", "vs2", "vs3", "vs4", "vs5", "vs6", "vs7", "vs8", "vs9" + , "vs10", "vs11", "vs12", "vs13", "vs14", "vs15", "vs16", "vs17", "vs18", "vs19" + , "vs20", "vs21", "vs22", "vs23", "vs24", "vs25", "vs26", "vs27", "vs28", "vs29" + , "vs30", "vs31", "vs32", "vs33", "vs34", "vs35", "vs36", "vs37", "vs38", "vs39" + , "vs40", "vs41", "vs42", "vs43", "vs44", "vs45", "vs46", "vs47", "vs48", "vs49" + , "vs50", "vs51", "vs52", "vs53" + #endif - #if XLC - ,"f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f8", "f9" - , "f10", "f11", "f12", "f13", "f14", "f15", "f16", "f17", "f18", "f19" - , "f20" ,"f21", "f22", "f23", "f24", "f25", "f26", "f27", "f28", "f29" - , "f30" ,"f31" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9" - , "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19" - , "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29" - , "v30", "v31" - #else - , "vs0", "vs1", "vs2", "vs3", "vs4", "vs5", "vs6", "vs7", "vs8", "vs9" - , "vs10", "vs11", "vs12", "vs13", "vs14", "vs15", "vs16", "vs17", "vs18", "vs19" - , "vs20", "vs21", "vs22", "vs23", "vs24", "vs25", "vs26", "vs27", "vs28", "vs29" - , "vs30", "vs31", "vs32", "vs33", "vs34", "vs35", "vs36", "vs37", "vs38", "vs39" - , 
"vs40", "vs41", "vs42", "vs43", "vs44", "vs45", "vs46", "vs47", "vs48", "vs49" - , "vs50", "vs51", "vs52", "vs53" - #endif + ); - ); + GEMM_UKR_FLUSH_CT( d ); } diff --git a/kernels/sandybridge/3/bli_gemm_sandybridge_asm_d8x4.c b/kernels/sandybridge/3/bli_gemm_sandybridge_asm_d8x4.c index a56ef16e5e..7890ad347d 100644 --- a/kernels/sandybridge/3/bli_gemm_sandybridge_asm_d8x4.c +++ b/kernels/sandybridge/3/bli_gemm_sandybridge_asm_d8x4.c @@ -42,7 +42,9 @@ void bli_sgemm_sandybridge_asm_8x8 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, float* restrict alpha, float* restrict a, float* restrict b, @@ -57,27 +59,29 @@ void bli_sgemm_sandybridge_asm_8x8 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( s, 8, 8, false ); + begin_asm() - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. //mov(var(b_next), r15) // load address of b_next. - + vmovaps(mem(rax, 0*32), ymm0) // initialize loop by pre-loading vmovsldup(mem(rbx, 0*32), ymm2) // elements of a and b. vpermilps(imm(0x4e), ymm2, ymm3) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 4), rdi) // cs_c *= sizeof(float) lea(mem(rcx, rdi, 4), r10) // load address of c + 4*cs_c; - + lea(mem(rdi, rdi, 2), r14) // r14 = 3*cs_c; prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*cs_c prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*cs_c @@ -87,7 +91,7 @@ void bli_sgemm_sandybridge_asm_8x8 prefetch(0, mem(r10, rdi, 1, 7*8)) // prefetch c + 5*cs_c prefetch(0, mem(r10, rdi, 2, 7*8)) // prefetch c + 6*cs_c prefetch(0, mem(r10, r14, 1, 7*8)) // prefetch c + 7*cs_c - + vxorps(ymm8, ymm8, ymm8) vxorps(ymm9, ymm9, ymm9) vxorps(ymm10, ymm10, ymm10) @@ -96,18 +100,18 @@ void bli_sgemm_sandybridge_asm_8x8 vxorps(ymm13, ymm13, ymm13) vxorps(ymm14, ymm14, ymm14) vxorps(ymm15, ymm15, ymm15) - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.SCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. 
- - + + label(.SLOOPKITER) // MAIN LOOP - - + + // iteration 0 prefetch(0, mem(rax, 16*32)) vmulps(ymm0, ymm2, ymm6) @@ -117,14 +121,14 @@ void bli_sgemm_sandybridge_asm_8x8 vperm2f128(imm(0x03), ymm3, ymm3, ymm5) vaddps(ymm15, ymm6, ymm15) vaddps(ymm13, ymm7, ymm13) - + vmovaps(mem(rax, 1*32), ymm1) vpermilps(imm(0x4e), ymm2, ymm3) vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vaddps(ymm11, ymm6, ymm11) vaddps(ymm9, ymm7, ymm9) - + vmulps(ymm0, ymm2, ymm6) vperm2f128(imm(0x03), ymm2, ymm2, ymm4) vmovsldup(mem(rbx, 1*32), ymm2) @@ -132,13 +136,13 @@ void bli_sgemm_sandybridge_asm_8x8 vperm2f128(imm(0x03), ymm3, ymm3, ymm5) vaddps(ymm14, ymm6, ymm14) vaddps(ymm12, ymm7, ymm12) - + vpermilps(imm(0x4e), ymm2, ymm3) vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vaddps(ymm10, ymm6, ymm10) vaddps(ymm8, ymm7, ymm8) - + // iteration 1 vmulps(ymm1, ymm2, ymm6) vperm2f128(imm(0x03), ymm2, ymm2, ymm4) @@ -147,14 +151,14 @@ void bli_sgemm_sandybridge_asm_8x8 vperm2f128(imm(0x03), ymm3, ymm3, ymm5) vaddps(ymm15, ymm6, ymm15) vaddps(ymm13, ymm7, ymm13) - + vmovaps(mem(rax, 2*32), ymm0) vpermilps(imm(0x4e), ymm2, ymm3) vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddps(ymm11, ymm6, ymm11) vaddps(ymm9, ymm7, ymm9) - + vmulps(ymm1, ymm2, ymm6) vperm2f128(imm(0x03), ymm2, ymm2, ymm4) vmovsldup(mem(rbx, 2*32), ymm2) @@ -162,14 +166,14 @@ void bli_sgemm_sandybridge_asm_8x8 vperm2f128(imm(0x03), ymm3, ymm3, ymm5) vaddps(ymm14, ymm6, ymm14) vaddps(ymm12, ymm7, ymm12) - + vpermilps(imm(0x4e), ymm2, ymm3) vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddps(ymm10, ymm6, ymm10) vaddps(ymm8, ymm7, ymm8) - - + + // iteration 2 prefetch(0, mem(rax, 18*32)) vmulps(ymm0, ymm2, ymm6) @@ -179,7 +183,7 @@ void bli_sgemm_sandybridge_asm_8x8 vperm2f128(imm(0x03), ymm3, ymm3, ymm5) vaddps(ymm15, ymm6, ymm15) vaddps(ymm13, ymm7, ymm13) - + vmovaps(mem(rax, 3*32), ymm1) add(imm(4*8*4), rax) // a += 4*8 (unroll x mr) vpermilps(imm(0x4e), ymm2, ymm3) @@ -187,7 +191,7 @@ void bli_sgemm_sandybridge_asm_8x8 vmulps(ymm0, ymm5, ymm7) vaddps(ymm11, ymm6, ymm11) vaddps(ymm9, ymm7, ymm9) - + vmulps(ymm0, ymm2, ymm6) vperm2f128(imm(0x03), ymm2, ymm2, ymm4) vmovsldup(mem(rbx, 3*32), ymm2) @@ -195,14 +199,14 @@ void bli_sgemm_sandybridge_asm_8x8 vperm2f128(imm(0x03), ymm3, ymm3, ymm5) vaddps(ymm14, ymm6, ymm14) vaddps(ymm12, ymm7, ymm12) - + vpermilps(imm(0x4e), ymm2, ymm3) vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vaddps(ymm10, ymm6, ymm10) vaddps(ymm8, ymm7, ymm8) - - + + // iteration 3 vmulps(ymm1, ymm2, ymm6) vperm2f128(imm(0x03), ymm2, ymm2, ymm4) @@ -212,14 +216,14 @@ void bli_sgemm_sandybridge_asm_8x8 vperm2f128(imm(0x03), ymm3, ymm3, ymm5) vaddps(ymm15, ymm6, ymm15) vaddps(ymm13, ymm7, ymm13) - + vmovaps(mem(rax, 0*32), ymm0) vpermilps(imm(0x4e), ymm2, ymm3) vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddps(ymm11, ymm6, ymm11) vaddps(ymm9, ymm7, ymm9) - + vmulps(ymm1, ymm2, ymm6) vperm2f128(imm(0x03), ymm2, ymm2, ymm4) vmovsldup(mem(rbx, 0*32), ymm2) @@ -227,35 +231,35 @@ void bli_sgemm_sandybridge_asm_8x8 vperm2f128(imm(0x03), ymm3, ymm3, ymm5) vaddps(ymm14, ymm6, ymm14) vaddps(ymm12, ymm7, ymm12) - + vpermilps(imm(0x4e), ymm2, ymm3) vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddps(ymm10, ymm6, ymm10) vaddps(ymm8, ymm7, ymm8) - - - - + + + + dec(rsi) // i -= 1; jne(.SLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.SCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.SPOSTACCUM) // if i == 0, we're done; jump to end. 
// else, we prepare to enter k_left loop. - - + + label(.SLOOPKLEFT) // EDGE LOOP - - + + prefetch(0, mem(rax, 16*32)) vmulps(ymm0, ymm2, ymm6) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) @@ -264,7 +268,7 @@ void bli_sgemm_sandybridge_asm_8x8 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddps(ymm15, ymm6, ymm15) vaddps(ymm13, ymm7, ymm13) - + vmovaps(mem(rax, 1*32), ymm1) add(imm(8*1*4), rax) // a += 8 (1 x mr) vpermilps(imm(0x4e), ymm2, ymm3) @@ -272,7 +276,7 @@ void bli_sgemm_sandybridge_asm_8x8 vmulps(ymm0, ymm5, ymm7) vaddps(ymm11, ymm6, ymm11) vaddps(ymm9, ymm7, ymm9) - + vmulps(ymm0, ymm2, ymm6) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vmovsldup(mem(rbx, 1*32), ymm2) @@ -281,122 +285,122 @@ void bli_sgemm_sandybridge_asm_8x8 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddps(ymm14, ymm6, ymm14) vaddps(ymm12, ymm7, ymm12) - + vpermilps(imm(0x4e), ymm2, ymm3) vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vmovaps(ymm1, ymm0) vaddps(ymm10, ymm6, ymm10) vaddps(ymm8, ymm7, ymm8) - - - + + + dec(rsi) // i -= 1; jne(.SLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.SPOSTACCUM) - + // ymm15: ymm13: ymm11: ymm9: // ( ab00 ( ab02 ( ab04 ( ab06 - // ab10 ab12 ab14 ab16 + // ab10 ab12 ab14 ab16 // ab22 ab20 ab26 ab24 // ab32 ab30 ab36 ab34 // ab44 ab46 ab40 ab42 - // ab54 ab56 ab50 ab52 + // ab54 ab56 ab50 ab52 // ab66 ab64 ab62 ab60 // ab76 ) ab74 ) ab72 ) ab70 ) - + // ymm14: ymm12: ymm10: ymm8: // ( ab01 ( ab03 ( ab05 ( ab07 - // ab11 ab13 ab15 ab17 + // ab11 ab13 ab15 ab17 // ab23 ab21 ab27 ab25 // ab33 ab31 ab37 ab35 // ab45 ab47 ab41 ab43 - // ab55 ab57 ab51 ab53 + // ab55 ab57 ab51 ab53 // ab67 ab65 ab63 ab61 // ab77 ) ab75 ) ab73 ) ab71 ) - + vmovaps(ymm15, ymm7) vshufps(imm(0xe4), ymm13, ymm15, ymm15) vshufps(imm(0xe4), ymm7, ymm13, ymm13) - + vmovaps(ymm11, ymm7) vshufps(imm(0xe4), ymm9, ymm11, ymm11) vshufps(imm(0xe4), ymm7, ymm9, ymm9) - + vmovaps(ymm14, ymm7) vshufps(imm(0xe4), ymm12, ymm14, ymm14) vshufps(imm(0xe4), ymm7, ymm12, ymm12) - + vmovaps(ymm10, ymm7) vshufps(imm(0xe4), ymm8, ymm10, ymm10) vshufps(imm(0xe4), ymm7, ymm8, ymm8) - + // ymm15: ymm13: ymm11: ymm9: // ( ab00 ( ab02 ( ab04 ( ab06 - // ab10 ab12 ab14 ab16 + // ab10 ab12 ab14 ab16 // ab20 ab22 ab24 ab26 // ab30 ab32 ab34 ab36 // ab44 ab46 ab40 ab42 - // ab54 ab56 ab50 ab52 + // ab54 ab56 ab50 ab52 // ab64 ab66 ab60 ab62 // ab74 ) ab76 ) ab70 ) ab72 ) - + // ymm14: ymm12: ymm10: ymm8: // ( ab01 ( ab03 ( ab05 ( ab07 - // ab11 ab13 ab15 ab17 + // ab11 ab13 ab15 ab17 // ab21 ab23 ab25 ab27 // ab31 ab33 ab35 ab37 // ab45 ab47 ab41 ab43 - // ab55 ab57 ab51 ab53 + // ab55 ab57 ab51 ab53 // ab65 ab67 ab61 ab63 // ab75 ) ab77 ) ab71 ) ab73 ) - + vmovaps(ymm15, ymm7) vperm2f128(imm(0x30), ymm11, ymm15, ymm15) vperm2f128(imm(0x12), ymm11, ymm7, ymm11) - + vmovaps(ymm13, ymm7) vperm2f128(imm(0x30), ymm9, ymm13, ymm13) vperm2f128(imm(0x12), ymm9, ymm7, ymm9) - + vmovaps(ymm14, ymm7) vperm2f128(imm(0x30), ymm10, ymm14, ymm14) vperm2f128(imm(0x12), ymm10, ymm7, ymm10) - + vmovaps(ymm12, ymm7) vperm2f128(imm(0x30), ymm8, ymm12, ymm12) vperm2f128(imm(0x12), ymm8, ymm7, ymm8) - + // ymm15: ymm13: ymm11: ymm9: // ( ab00 ( ab02 ( ab04 ( ab06 - // ab10 ab12 ab14 ab16 + // ab10 ab12 ab14 ab16 // ab20 ab22 ab24 ab26 // ab30 ab32 ab34 ab36 // ab40 ab42 ab44 ab46 - // ab50 ab52 ab54 ab56 + // ab50 ab52 ab54 ab56 // ab60 ab62 ab64 ab66 // ab70 ) ab72 ) ab74 ) ab76 ) - + // ymm14: ymm12: ymm10: ymm8: // ( ab01 ( ab03 ( ab05 ( ab07 - // ab11 ab13 ab15 ab17 + // ab11 ab13 ab15 ab17 // ab21 ab23 ab25 ab27 // ab31 ab33 ab35 ab37 // ab41 ab43 ab45 
ab47 - // ab51 ab53 ab55 ab57 + // ab51 ab53 ab55 ab57 // ab61 ab63 ab65 ab67 // ab71 ) ab73 ) ab75 ) ab77 ) - - - + + + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastss(mem(rax), ymm0) // load alpha and duplicate vbroadcastss(mem(rbx), ymm4) // load beta and duplicate - + vmulps(ymm0, ymm8, ymm8) // scale by alpha vmulps(ymm0, ymm9, ymm9) vmulps(ymm0, ymm10, ymm10) @@ -405,618 +409,118 @@ void bli_sgemm_sandybridge_asm_8x8 vmulps(ymm0, ymm13, ymm13) vmulps(ymm0, ymm14, ymm14) vmulps(ymm0, ymm15, ymm15) - - - - - - + + mov(var(rs_c), rsi) // load rs_c lea(mem(, rsi, 4), rsi) // rsi = rs_c * sizeof(float) - + lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*rs_c; - + lea(mem(, rsi, 2), r12) // r12 = 2*rs_c; lea(mem(r12, rsi, 1), r13) // r13 = 3*rs_c; - - + + // now avoid loading C if beta == 0 - + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomiss(xmm0, xmm4) // set ZF if beta == 0. je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case - - - cmp(imm(4), rsi) // set ZF if (4*cs_c) == 4. - jz(.SCOLSTORED) // jump to column storage case - - - - label(.SGENSTORED) - - // update c00:c70 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - - vmulps(ymm4, ymm0, ymm0) // scale by beta, - vaddps(ymm15, ymm0, ymm0) // add the gemm result, - - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c01:c71 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - - vmulps(ymm4, ymm0, ymm0) // scale by beta, - vaddps(ymm14, ymm0, ymm0) // add the gemm result, - - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c02:c72 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, 
xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - - vmulps(ymm4, ymm0, ymm0) // scale by beta, - vaddps(ymm13, ymm0, ymm0) // add the gemm result, - - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c03:c73 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - - vmulps(ymm4, ymm0, ymm0) // scale by beta, - vaddps(ymm12, ymm0, ymm0) // add the gemm result, - - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c04:c74 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - - vmulps(ymm4, ymm0, ymm0) // scale by beta, - vaddps(ymm11, ymm0, ymm0) // add the gemm result, - - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c05:c75 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - - 
vmulps(ymm4, ymm0, ymm0) // scale by beta, - vaddps(ymm10, ymm0, ymm0) // add the gemm result, - - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c06:c76 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - - vmulps(ymm4, ymm0, ymm0) // scale by beta, - vaddps(ymm9, ymm0, ymm0) // add the gemm result, - - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c07:c77 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - - vmulps(ymm4, ymm0, ymm0) // scale by beta, - vaddps(ymm8, ymm0, ymm0) // add the gemm result, - - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - - - jmp(.SDONE) // jump to end. - - - - label(.SCOLSTORED) - - - vmovups(mem(rcx), ymm0) // load c00:c70, - vmulps(ymm4, ymm0, ymm0) // scale by beta, - vaddps(ymm15, ymm0, ymm0) // add the gemm result, - vmovups(ymm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovups(mem(rcx), ymm1) // load c01:c71, - vmulps(ymm4, ymm1, ymm1) // scale by beta, - vaddps(ymm14, ymm1, ymm1) // add the gemm result, - vmovups(ymm1, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovups(mem(rcx), ymm0) // load c02:c72, - vmulps(ymm4, ymm0, ymm0) // scale by beta, - vaddps(ymm13, ymm0, ymm0) // add the gemm result, - vmovups(ymm0, mem(rcx)) // and store back to memory. 
- add(rdi, rcx) // c += cs_c; - - vmovups(mem(rcx), ymm1) // load c03:c73, - vmulps(ymm4, ymm1, ymm1) // scale by beta, - vaddps(ymm12, ymm1, ymm1) // add the gemm result, - vmovups(ymm1, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovups(mem(rcx), ymm0) // load c04:c74, - vmulps(ymm4, ymm0, ymm0) // scale by beta, - vaddps(ymm11, ymm0, ymm0) // add the gemm result, - vmovups(ymm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovups(mem(rcx), ymm1) // load c05:c75, - vmulps(ymm4, ymm1, ymm1) // scale by beta, - vaddps(ymm10, ymm1, ymm1) // add the gemm result, - vmovups(ymm1, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovups(mem(rcx), ymm0) // load c06:c76, - vmulps(ymm4, ymm0, ymm0) // scale by beta, - vaddps(ymm9, ymm0, ymm0) // add the gemm result, - vmovups(ymm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovups(mem(rcx), ymm1) // load c07:c77, - vmulps(ymm4, ymm1, ymm1) // scale by beta, - vaddps(ymm8, ymm1, ymm1) // add the gemm result, - vmovups(ymm1, mem(rcx)) // and store back to memory. - - - jmp(.SDONE) // jump to end. - - - - + + vmovups(mem(rcx), ymm0) // load c00:c70, + vmulps(ymm4, ymm0, ymm0) // scale by beta, + vaddps(ymm15, ymm0, ymm0) // add the gemm result, + vmovups(ymm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(mem(rcx), ymm1) // load c01:c71, + vmulps(ymm4, ymm1, ymm1) // scale by beta, + vaddps(ymm14, ymm1, ymm1) // add the gemm result, + vmovups(ymm1, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(mem(rcx), ymm0) // load c02:c72, + vmulps(ymm4, ymm0, ymm0) // scale by beta, + vaddps(ymm13, ymm0, ymm0) // add the gemm result, + vmovups(ymm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(mem(rcx), ymm1) // load c03:c73, + vmulps(ymm4, ymm1, ymm1) // scale by beta, + vaddps(ymm12, ymm1, ymm1) // add the gemm result, + vmovups(ymm1, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(mem(rcx), ymm0) // load c04:c74, + vmulps(ymm4, ymm0, ymm0) // scale by beta, + vaddps(ymm11, ymm0, ymm0) // add the gemm result, + vmovups(ymm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(mem(rcx), ymm1) // load c05:c75, + vmulps(ymm4, ymm1, ymm1) // scale by beta, + vaddps(ymm10, ymm1, ymm1) // add the gemm result, + vmovups(ymm1, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(mem(rcx), ymm0) // load c06:c76, + vmulps(ymm4, ymm0, ymm0) // scale by beta, + vaddps(ymm9, ymm0, ymm0) // add the gemm result, + vmovups(ymm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(mem(rcx), ymm1) // load c07:c77, + vmulps(ymm4, ymm1, ymm1) // scale by beta, + vaddps(ymm8, ymm1, ymm1) // add the gemm result, + vmovups(ymm1, mem(rcx)) // and store back to memory. + + jmp(.SDONE) // jump to end. + label(.SBETAZERO) - - cmp(imm(4), rsi) // set ZF if (4*cs_c) == 4. 
- jz(.SCOLSTORBZ) // jump to column storage case - - - - label(.SGENSTORBZ) - - // update c00:c70 - vmovups(ymm15, ymm0) - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c01:c71 - vmovups(ymm14, ymm0) - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c02:c72 - vmovups(ymm13, ymm0) - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c03:c73 - vmovups(ymm12, ymm0) - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c04:c74 - vmovups(ymm11, ymm0) - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c05:c75 - vmovups(ymm10, ymm0) - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 
1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c06:c76 - vmovups(ymm9, ymm0) - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c07:c77 - vmovups(ymm8, ymm0) - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - - jmp(.SDONE) // jump to end. - - - - label(.SCOLSTORBZ) - - - vmovups(ymm15, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovups(ymm14, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovups(ymm13, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovups(ymm12, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovups(ymm11, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovups(ymm10, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovups(ymm9, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovups(ymm8, mem(rcx)) // and store back to memory. - - - - - + + vmovups(ymm15, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(ymm14, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(ymm13, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(ymm12, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(ymm11, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(ymm10, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(ymm9, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(ymm8, mem(rcx)) // and store back to memory. 
+ label(.SDONE) - + vzeroupper() - - end_asm( + + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c)/*, // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next)*/ // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c)/*, // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next)*/ // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -1024,11 +528,15 @@ void bli_sgemm_sandybridge_asm_8x8 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( s ); } void bli_dgemm_sandybridge_asm_8x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, double* restrict alpha, double* restrict a, double* restrict b, @@ -1043,34 +551,36 @@ void bli_dgemm_sandybridge_asm_8x4 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( d, 8, 4, false ); + begin_asm() - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. mov(var(b_next), r15) // load address of b_next. //mov(var(a_next), r14) // load address of a_next. sub(imm(4*64), r15) - + vmovapd(mem(rax, 0*32), ymm0) // initialize loop by pre-loading vmovapd(mem(rbx, 0*32), ymm2) // elements of a and b. vpermilpd(imm(0x5), ymm2, ymm3) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(double) lea(mem(rcx, rdi, 2), r10) // load address of c + 2*cs_c; - + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*cs_c prefetch(0, mem(r10, 3*8)) // prefetch c + 2*cs_c prefetch(0, mem(r10, rdi, 1, 3*8)) // prefetch c + 3*cs_c - + vxorpd(ymm8, ymm8, ymm8) vxorpd(ymm9, ymm9, ymm9) vxorpd(ymm10, ymm10, ymm10) @@ -1079,19 +589,19 @@ void bli_dgemm_sandybridge_asm_8x4 vxorpd(ymm13, ymm13, ymm13) vxorpd(ymm14, ymm14, ymm14) vxorpd(ymm15, ymm15, ymm15) - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. 
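In the DLOOPKITER body below, b is permuted (vpermilpd within 128-bit lanes, vperm2f128 across lanes) rather than broadcast, so each ymm accumulator collects an XOR-permuted diagonal of the outer product; the vshufpd/vperm2f128 block after DPOSTACCUM undoes the permutation (see the ab diagrams there). A scalar model of that bookkeeping for one 4-row half of the 8x4 tile, with illustrative names; this reduced model packs 4 rows of a and 4 columns of b per k step:

/* acc[s][i] += a(i) * b(i XOR s): s=0 yields (ab00,ab11,ab22,ab33) as in
   ymm15, s=1 yields (ab01,ab10,ab23,ab32) as in ymm13, and so on,
   matching the register diagrams shown after the loop. */
static void xor_permuted_rank1( long k, const double* a, const double* b,
                                double acc[4][4] )
{
    for ( long p = 0; p < k; p++ )
        for ( int i = 0; i < 4; i++ )
            for ( int s = 0; s < 4; s++ )
                acc[s][i] += a[ p*4 + i ] * b[ p*4 + (i ^ s) ];
}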
- - + + label(.DLOOPKITER) // MAIN LOOP - + add(imm(4*4*8), r15) // b_next += 4*4 (unroll x nr) - + // iteration 0 vmovapd(mem(rax, 1*32), ymm1) vmulpd(ymm0, ymm2, ymm6) @@ -1100,7 +610,7 @@ void bli_dgemm_sandybridge_asm_8x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm15, ymm6, ymm15) vaddpd(ymm13, ymm7, ymm13) - + prefetch(0, mem(rax, 16*32)) vmulpd(ymm1, ymm2, ymm6) vmovapd(mem(rbx, 1*32), ymm2) @@ -1108,20 +618,20 @@ void bli_dgemm_sandybridge_asm_8x4 vpermilpd(imm(0x5), ymm2, ymm3) vaddpd(ymm14, ymm6, ymm14) vaddpd(ymm12, ymm7, ymm12) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 2*32), ymm0) vaddpd(ymm11, ymm6, ymm11) vaddpd(ymm9, ymm7, ymm9) prefetch(0, mem(r15, 0*32)) // prefetch b_next[0*4] - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddpd(ymm10, ymm6, ymm10) vaddpd(ymm8, ymm7, ymm8) - - + + // iteration 1 vmovapd(mem(rax, 3*32), ymm1) vmulpd(ymm0, ymm2, ymm6) @@ -1130,7 +640,7 @@ void bli_dgemm_sandybridge_asm_8x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm15, ymm6, ymm15) vaddpd(ymm13, ymm7, ymm13) - + prefetch(0, mem(rax, 18*32)) vmulpd(ymm1, ymm2, ymm6) vmovapd(mem(rbx, 2*32), ymm2) @@ -1138,19 +648,19 @@ void bli_dgemm_sandybridge_asm_8x4 vpermilpd(imm(0x5), ymm2, ymm3) vaddpd(ymm14, ymm6, ymm14) vaddpd(ymm12, ymm7, ymm12) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 4*32), ymm0) vaddpd(ymm11, ymm6, ymm11) vaddpd(ymm9, ymm7, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddpd(ymm10, ymm6, ymm10) vaddpd(ymm8, ymm7, ymm8) - - + + // iteration 2 vmovapd(mem(rax, 5*32), ymm1) vmulpd(ymm0, ymm2, ymm6) @@ -1159,7 +669,7 @@ void bli_dgemm_sandybridge_asm_8x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm15, ymm6, ymm15) vaddpd(ymm13, ymm7, ymm13) - + prefetch(0, mem(rax, 20*32)) vmulpd(ymm1, ymm2, ymm6) vmovapd(mem(rbx, 3*32), ymm2) @@ -1168,20 +678,20 @@ void bli_dgemm_sandybridge_asm_8x4 vpermilpd(imm(0x5), ymm2, ymm3) vaddpd(ymm14, ymm6, ymm14) vaddpd(ymm12, ymm7, ymm12) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 6*32), ymm0) vaddpd(ymm11, ymm6, ymm11) vaddpd(ymm9, ymm7, ymm9) prefetch(0, mem(r15, 2*32)) // prefetch b_next[2*4] - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddpd(ymm10, ymm6, ymm10) vaddpd(ymm8, ymm7, ymm8) - - + + // iteration 3 vmovapd(mem(rax, 7*32), ymm1) add(imm(4*8*8), rax) // a += 4*8 (unroll x mr) @@ -1191,7 +701,7 @@ void bli_dgemm_sandybridge_asm_8x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm15, ymm6, ymm15) vaddpd(ymm13, ymm7, ymm13) - + //prefetch(0, mem(rax, 22*32)) prefetch(0, mem(rax, 14*32)) vmulpd(ymm1, ymm2, ymm6) @@ -1200,41 +710,41 @@ void bli_dgemm_sandybridge_asm_8x4 vpermilpd(imm(0x5), ymm2, ymm3) vaddpd(ymm14, ymm6, ymm14) vaddpd(ymm12, ymm7, ymm12) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 0*32), ymm0) vaddpd(ymm11, ymm6, ymm11) vaddpd(ymm9, ymm7, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddpd(ymm10, ymm6, ymm10) vaddpd(ymm8, ymm7, ymm8) - - - + + + //add(imm(4*8*8), rax) // a += 4*8 (unroll x mr) //add(imm(4*4*8), rbx) // b += 4*4 (unroll x nr) - + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. 
- - + + label(.DLOOPKLEFT) // EDGE LOOP - + vmovapd(mem(rax, 1*32), ymm1) add(imm(8*1*8), rax) // a += 8 (1 x mr) vmulpd(ymm0, ymm2, ymm6) @@ -1243,7 +753,7 @@ void bli_dgemm_sandybridge_asm_8x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm15, ymm6, ymm15) vaddpd(ymm13, ymm7, ymm13) - + prefetch(0, mem(rax, 14*32)) vmulpd(ymm1, ymm2, ymm6) vmovapd(mem(rbx, 1*32), ymm2) @@ -1252,101 +762,101 @@ void bli_dgemm_sandybridge_asm_8x4 vpermilpd(imm(0x5), ymm2, ymm3) vaddpd(ymm14, ymm6, ymm14) vaddpd(ymm12, ymm7, ymm12) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 0*32), ymm0) vaddpd(ymm11, ymm6, ymm11) vaddpd(ymm9, ymm7, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddpd(ymm10, ymm6, ymm10) vaddpd(ymm8, ymm7, ymm8) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - - + + // ymm15: ymm13: ymm11: ymm9: // ( ab00 ( ab01 ( ab02 ( ab03 - // ab11 ab10 ab13 ab12 + // ab11 ab10 ab13 ab12 // ab22 ab23 ab20 ab21 // ab33 ) ab32 ) ab31 ) ab30 ) - + // ymm14: ymm12: ymm10: ymm8: // ( ab40 ( ab41 ( ab42 ( ab43 - // ab51 ab50 ab53 ab52 + // ab51 ab50 ab53 ab52 // ab62 ab63 ab60 ab61 // ab73 ) ab72 ) ab71 ) ab70 ) - + vmovapd(ymm15, ymm7) vshufpd(imm(0xa), ymm15, ymm13, ymm15) vshufpd(imm(0xa), ymm13, ymm7, ymm13) - + vmovapd(ymm11, ymm7) vshufpd(imm(0xa), ymm11, ymm9, ymm11) vshufpd(imm(0xa), ymm9, ymm7, ymm9) - + vmovapd(ymm14, ymm7) vshufpd(imm(0xa), ymm14, ymm12, ymm14) vshufpd(imm(0xa), ymm12, ymm7, ymm12) - + vmovapd(ymm10, ymm7) vshufpd(imm(0xa), ymm10, ymm8, ymm10) vshufpd(imm(0xa), ymm8, ymm7, ymm8) - + // ymm15: ymm13: ymm11: ymm9: // ( ab01 ( ab00 ( ab03 ( ab02 - // ab11 ab10 ab13 ab12 + // ab11 ab10 ab13 ab12 // ab23 ab22 ab21 ab20 // ab33 ) ab32 ) ab31 ) ab30 ) - + // ymm14: ymm12: ymm10: ymm8: // ( ab41 ( ab40 ( ab43 ( ab42 - // ab51 ab50 ab53 ab52 + // ab51 ab50 ab53 ab52 // ab63 ab62 ab61 ab60 // ab73 ) ab72 ) ab71 ) ab70 ) - + vmovapd(ymm15, ymm7) vperm2f128(imm(0x30), ymm15, ymm11, ymm15) vperm2f128(imm(0x12), ymm7, ymm11, ymm11) - + vmovapd(ymm13, ymm7) vperm2f128(imm(0x30), ymm13, ymm9, ymm13) vperm2f128(imm(0x12), ymm7, ymm9, ymm9) - + vmovapd(ymm14, ymm7) vperm2f128(imm(0x30), ymm14, ymm10, ymm14) vperm2f128(imm(0x12), ymm7, ymm10, ymm10) - + vmovapd(ymm12, ymm7) vperm2f128(imm(0x30), ymm12, ymm8, ymm12) vperm2f128(imm(0x12), ymm7, ymm8, ymm8) - + // ymm9: ymm11: ymm13: ymm15: // ( ab00 ( ab01 ( ab02 ( ab03 - // ab10 ab11 ab12 ab13 + // ab10 ab11 ab12 ab13 // ab20 ab21 ab22 ab23 // ab30 ) ab31 ) ab32 ) ab33 ) - + // ymm8: ymm10: ymm12: ymm14: // ( ab40 ( ab41 ( ab42 ( ab43 - // ab50 ab51 ab52 ab53 + // ab50 ab51 ab52 ab53 // ab60 ab61 ab62 ab63 // ab70 ) ab71 ) ab72 ) ab73 ) - - + + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm2) // load beta and duplicate - + vmulpd(ymm0, ymm8, ymm8) // scale by alpha vmulpd(ymm0, ymm9, ymm9) vmulpd(ymm0, ymm10, ymm10) @@ -1355,343 +865,124 @@ void bli_dgemm_sandybridge_asm_8x4 vmulpd(ymm0, ymm13, ymm13) vmulpd(ymm0, ymm14, ymm14) vmulpd(ymm0, ymm15, ymm15) - - - - - - + + mov(var(rs_c), rsi) // load rs_c lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(double) - + lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*rs_c; - + lea(mem(, rsi, 2), r12) // r12 = 2*rs_c; lea(mem(r12, rsi, 1), r13) // r13 = 3*rs_c; - - + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. 
vucomisd(xmm0, xmm2) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - - - cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8. - jz(.DCOLSTORED) // jump to column storage case - - - - label(.DGENSTORED) - // update c00:c33 - - vextractf128(imm(1), ymm9, xmm1) - vmovlpd(mem(rcx), xmm0, xmm0) // load c00 and c10, - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm9, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rcx)) // and store back to memory. - vmovhpd(xmm0, mem(rcx, rsi, 1)) - vmovlpd(mem(rcx, r12, 1), xmm0, xmm0) // load c20 and c30, - vmovhpd(mem(rcx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm1, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rcx, r12, 1)) // and store back to memory. - vmovhpd(xmm0, mem(rcx, r13, 1)) - add(rdi, rcx) // c += cs_c; - - vextractf128(imm(1), ymm11, xmm1) - vmovlpd(mem(rcx), xmm0, xmm0) // load c01 and c11, - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm11, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rcx)) // and store back to memory. - vmovhpd(xmm0, mem(rcx, rsi, 1)) - vmovlpd(mem(rcx, r12, 1), xmm0, xmm0) // load c21 and c31, - vmovhpd(mem(rcx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm1, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rcx, r12, 1)) // and store back to memory. - vmovhpd(xmm0, mem(rcx, r13, 1)) - add(rdi, rcx) // c += cs_c; - - vextractf128(imm(1), ymm13, xmm1) - vmovlpd(mem(rcx), xmm0, xmm0) // load c02 and c12, - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm13, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rcx)) // and store back to memory. - vmovhpd(xmm0, mem(rcx, rsi, 1)) - vmovlpd(mem(rcx, r12, 1), xmm0, xmm0) // load c22 and c32, - vmovhpd(mem(rcx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm1, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rcx, r12, 1)) // and store back to memory. - vmovhpd(xmm0, mem(rcx, r13, 1)) - add(rdi, rcx) // c += cs_c; - - vextractf128(imm(1), ymm15, xmm1) - vmovlpd(mem(rcx), xmm0, xmm0) // load c03 and c13, - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm15, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rcx)) // and store back to memory. - vmovhpd(xmm0, mem(rcx, rsi, 1)) - vmovlpd(mem(rcx, r12, 1), xmm0, xmm0) // load c23 and c33, - vmovhpd(mem(rcx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm1, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rcx, r12, 1)) // and store back to memory. - vmovhpd(xmm0, mem(rcx, r13, 1)) - - // update c40:c73 - - vextractf128(imm(1), ymm8, xmm1) - vmovlpd(mem(rdx), xmm0, xmm0) // load c40 and c50, - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm8, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rdx)) // and store back to memory. - vmovhpd(xmm0, mem(rdx, rsi, 1)) - vmovlpd(mem(rdx, r12, 1), xmm0, xmm0) // load c60 and c70, - vmovhpd(mem(rdx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm1, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rdx, r12, 1)) // and store back to memory. 
- vmovhpd(xmm0, mem(rdx, r13, 1)) - add(rdi, rdx) // c += cs_c; - - vextractf128(imm(1), ymm10, xmm1) - vmovlpd(mem(rdx), xmm0, xmm0) // load c41 and c51, - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm10, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rdx)) // and store back to memory. - vmovhpd(xmm0, mem(rdx, rsi, 1)) - vmovlpd(mem(rdx, r12, 1), xmm0, xmm0) // load c61 and c71, - vmovhpd(mem(rdx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm1, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rdx, r12, 1)) // and store back to memory. - vmovhpd(xmm0, mem(rdx, r13, 1)) - add(rdi, rdx) // c += cs_c; - - vextractf128(imm(1), ymm12, xmm1) - vmovlpd(mem(rdx), xmm0, xmm0) // load c42 and c52, - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm12, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rdx)) // and store back to memory. - vmovhpd(xmm0, mem(rdx, rsi, 1)) - vmovlpd(mem(rdx, r12, 1), xmm0, xmm0) // load c62 and c72, - vmovhpd(mem(rdx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm1, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rdx, r12, 1)) // and store back to memory. - vmovhpd(xmm0, mem(rdx, r13, 1)) - add(rdi, rdx) // c += cs_c; - - vextractf128(imm(1), ymm14, xmm1) - vmovlpd(mem(rdx), xmm0, xmm0) // load c43 and c53, - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm14, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rdx)) // and store back to memory. - vmovhpd(xmm0, mem(rdx, rsi, 1)) - vmovlpd(mem(rdx, r12, 1), xmm0, xmm0) // load c63 and c73, - vmovhpd(mem(rdx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm1, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rdx, r12, 1)) // and store back to memory. - vmovhpd(xmm0, mem(rdx, r13, 1)) - - - jmp(.DDONE) // jump to end. - - - - label(.DCOLSTORED) - // update c00:c33 - - vmovupd(mem(rcx), ymm0) // load c00:c30, - vmulpd(ymm2, ymm0, ymm0) // scale by beta, - vaddpd(ymm9, ymm0, ymm0) // add the gemm result, - vmovupd(ymm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovupd(mem(rcx), ymm0) // load c01:c31, - vmulpd(ymm2, ymm0, ymm0) // scale by beta, - vaddpd(ymm11, ymm0, ymm0) // add the gemm result, - vmovupd(ymm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovupd(mem(rcx), ymm0) // load c02:c32, - vmulpd(ymm2, ymm0, ymm0) // scale by beta, - vaddpd(ymm13, ymm0, ymm0) // add the gemm result, - vmovupd(ymm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovupd(mem(rcx), ymm0) // load c03:c33, - vmulpd(ymm2, ymm0, ymm0) // scale by beta, - vaddpd(ymm15, ymm0, ymm0) // add the gemm result, - vmovupd(ymm0, mem(rcx)) // and store back to memory. - - // update c40:c73 - - vmovupd(mem(rdx), ymm0) // load c40:c70, - vmulpd(ymm2, ymm0, ymm0) // scale by beta, - vaddpd(ymm8, ymm0, ymm0) // add the gemm result, - vmovupd(ymm0, mem(rdx)) // and store back to memory. - add(rdi, rdx) // c += cs_c; - - vmovupd(mem(rdx), ymm0) // load c41:c71, - vmulpd(ymm2, ymm0, ymm0) // scale by beta, - vaddpd(ymm10, ymm0, ymm0) // add the gemm result, - vmovupd(ymm0, mem(rdx)) // and store back to memory. 
- add(rdi, rdx) // c += cs_c; - - vmovupd(mem(rdx), ymm0) // load c42:c72, - vmulpd(ymm2, ymm0, ymm0) // scale by beta, - vaddpd(ymm12, ymm0, ymm0) // add the gemm result, - vmovupd(ymm0, mem(rdx)) // and store back to memory. - add(rdi, rdx) // c += cs_c; - - vmovupd(mem(rdx), ymm0) // load c43:c73, - vmulpd(ymm2, ymm0, ymm0) // scale by beta, - vaddpd(ymm14, ymm0, ymm0) // add the gemm result, - vmovupd(ymm0, mem(rdx)) // and store back to memory. - - - jmp(.DDONE) // jump to end. - - - - + + // update c00:c33 + + vmovupd(mem(rcx), ymm0) // load c00:c30, + vmulpd(ymm2, ymm0, ymm0) // scale by beta, + vaddpd(ymm9, ymm0, ymm0) // add the gemm result, + vmovupd(ymm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovupd(mem(rcx), ymm0) // load c01:c31, + vmulpd(ymm2, ymm0, ymm0) // scale by beta, + vaddpd(ymm11, ymm0, ymm0) // add the gemm result, + vmovupd(ymm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovupd(mem(rcx), ymm0) // load c02:c32, + vmulpd(ymm2, ymm0, ymm0) // scale by beta, + vaddpd(ymm13, ymm0, ymm0) // add the gemm result, + vmovupd(ymm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovupd(mem(rcx), ymm0) // load c03:c33, + vmulpd(ymm2, ymm0, ymm0) // scale by beta, + vaddpd(ymm15, ymm0, ymm0) // add the gemm result, + vmovupd(ymm0, mem(rcx)) // and store back to memory. + + // update c40:c73 + + vmovupd(mem(rdx), ymm0) // load c40:c70, + vmulpd(ymm2, ymm0, ymm0) // scale by beta, + vaddpd(ymm8, ymm0, ymm0) // add the gemm result, + vmovupd(ymm0, mem(rdx)) // and store back to memory. + add(rdi, rdx) // c += cs_c; + + vmovupd(mem(rdx), ymm0) // load c41:c71, + vmulpd(ymm2, ymm0, ymm0) // scale by beta, + vaddpd(ymm10, ymm0, ymm0) // add the gemm result, + vmovupd(ymm0, mem(rdx)) // and store back to memory. + add(rdi, rdx) // c += cs_c; + + vmovupd(mem(rdx), ymm0) // load c42:c72, + vmulpd(ymm2, ymm0, ymm0) // scale by beta, + vaddpd(ymm12, ymm0, ymm0) // add the gemm result, + vmovupd(ymm0, mem(rdx)) // and store back to memory. + add(rdi, rdx) // c += cs_c; + + vmovupd(mem(rdx), ymm0) // load c43:c73, + vmulpd(ymm2, ymm0, ymm0) // scale by beta, + vaddpd(ymm14, ymm0, ymm0) // add the gemm result, + vmovupd(ymm0, mem(rdx)) // and store back to memory. + + jmp(.DDONE) // jump to end. + label(.DBETAZERO) - - cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8. 
- jz(.DCOLSTORBZ) // jump to column storage case - - - - label(.DGENSTORBZ) - // update c00:c33 - - vextractf128(imm(1), ymm9, xmm1) - vmovlpd(xmm9, mem(rcx)) // store to c00:c30 - vmovhpd(xmm9, mem(rcx, rsi, 1)) - vmovlpd(xmm1, mem(rcx, r12, 1)) - vmovhpd(xmm1, mem(rcx, r13, 1)) - add(rdi, rcx) // c += cs_c; - - vextractf128(imm(1), ymm11, xmm1) - vmovlpd(xmm11, mem(rcx)) // store to c01:c31 - vmovhpd(xmm11, mem(rcx, rsi, 1)) - vmovlpd(xmm1, mem(rcx, r12, 1)) - vmovhpd(xmm1, mem(rcx, r13, 1)) - add(rdi, rcx) // c += cs_c; - - vextractf128(imm(1), ymm13, xmm1) - vmovlpd(xmm13, mem(rcx)) // store to c02:c32 - vmovhpd(xmm13, mem(rcx, rsi, 1)) - vmovlpd(xmm1, mem(rcx, r12, 1)) - vmovhpd(xmm1, mem(rcx, r13, 1)) - add(rdi, rcx) // c += cs_c; - - vextractf128(imm(1), ymm15, xmm1) - vmovlpd(xmm15, mem(rcx)) // store to c03:c33 - vmovhpd(xmm15, mem(rcx, rsi, 1)) - vmovlpd(xmm1, mem(rcx, r12, 1)) - vmovhpd(xmm1, mem(rcx, r13, 1)) - - // update c40:c73 - - vextractf128(imm(1), ymm8, xmm1) - vmovlpd(xmm8, mem(rdx)) // store to c40:c70 - vmovhpd(xmm8, mem(rdx, rsi, 1)) - vmovlpd(xmm1, mem(rdx, r12, 1)) - vmovhpd(xmm1, mem(rdx, r13, 1)) - add(rdi, rdx) // c += cs_c; - - vextractf128(imm(1), ymm10, xmm1) - vmovlpd(xmm10, mem(rdx)) // store to c41:c71 - vmovhpd(xmm10, mem(rdx, rsi, 1)) - vmovlpd(xmm1, mem(rdx, r12, 1)) - vmovhpd(xmm1, mem(rdx, r13, 1)) - add(rdi, rdx) // c += cs_c; - - vextractf128(imm(1), ymm12, xmm1) - vmovlpd(xmm12, mem(rdx)) // store to c42:c72 - vmovhpd(xmm12, mem(rdx, rsi, 1)) - vmovlpd(xmm1, mem(rdx, r12, 1)) - vmovhpd(xmm1, mem(rdx, r13, 1)) - add(rdi, rdx) // c += cs_c; - - vextractf128(imm(1), ymm14, xmm1) - vmovlpd(xmm14, mem(rdx)) // store to c43:c73 - vmovhpd(xmm14, mem(rdx, rsi, 1)) - vmovlpd(xmm1, mem(rdx, r12, 1)) - vmovhpd(xmm1, mem(rdx, r13, 1)) - - - jmp(.DDONE) // jump to end. 
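Both general-stride epilogues of this kernel (.DGENSTORED above and .DGENSTORBZ here) can be deleted because the new GEMM_UKR_SETUP_CT / GEMM_UKR_FLUSH_CT pair deals with mismatched or partial C tiles outside the asm. A conceptual sketch of what the pair is assumed to do for GEMM_UKR_SETUP_CT( d, 8, 4, false ); the real macros live in the BLIS framework, and the dispatch condition, the helper names, and the meaning of the trailing false are assumptions here:

    #include <stdbool.h>
    #include "blis.h"

    /* Hypothetical shape of the CT (C-temporary) mechanism: when C is
       not column-stored, or the tile is partial (m < 8 or n < 4),
       accumulate into a local column-major tile with beta = 0, then
       merge the tile into C afterward. The asm then only ever needs
       its column-stored store path. */
    void dgemm_ukr_ct_sketch( dim_t m, dim_t n,
                              double* beta, double* c,
                              inc_t rs_c0, inc_t cs_c0 )
    {
        double  ct[ 8 * 4 ];                  // local column-major tile
        double  zero     = 0.0;
        bool    use_ct   = ( rs_c0 != 1 || m < 8 || n < 4 );  // assumed
        double* c_use    = use_ct ? ct    : c;
        inc_t   rs_use   = use_ct ? 1     : rs_c0;
        inc_t   cs_use   = use_ct ? 8     : cs_c0;
        double* beta_use = use_ct ? &zero : beta;

        // ... run the column-stored asm on ( beta_use, c_use, rs_use,
        //     cs_use ) exactly as in the kernel above ...
        ( void )c_use; ( void )rs_use; ( void )cs_use; ( void )beta_use;

        if ( use_ct )                         // flush: c := beta*c + ct
            bli_dxpbys_mxn( m, n, ct, 1, 8, beta, c, rs_c0, cs_c0 );
    }

This is also why the kernels now take m and n explicitly: the flush must know how much of the 8x4 tile is valid at edge cases.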
- - - - label(.DCOLSTORBZ) - // update c00:c33 - - vmovupd(ymm9, mem(rcx)) // store c00:c30 - add(rdi, rcx) // c += cs_c; - - vmovupd(ymm11, mem(rcx)) // store c01:c31 - add(rdi, rcx) // c += cs_c; - - vmovupd(ymm13, mem(rcx)) // store c02:c32 - add(rdi, rcx) // c += cs_c; - - vmovupd(ymm15, mem(rcx)) // store c03:c33 - - // update c40:c73 - - vmovupd(ymm8, mem(rdx)) // store c40:c70 - add(rdi, rdx) // c += cs_c; - - vmovupd(ymm10, mem(rdx)) // store c41:c71 - add(rdi, rdx) // c += cs_c; - - vmovupd(ymm12, mem(rdx)) // store c42:c72 - add(rdi, rdx) // c += cs_c; - - vmovupd(ymm14, mem(rdx)) // store c43:c73 - - - - - + + // update c00:c33 + + vmovupd(ymm9, mem(rcx)) // store c00:c30 + add(rdi, rcx) // c += cs_c; + + vmovupd(ymm11, mem(rcx)) // store c01:c31 + add(rdi, rcx) // c += cs_c; + + vmovupd(ymm13, mem(rcx)) // store c02:c32 + add(rdi, rcx) // c += cs_c; + + vmovupd(ymm15, mem(rcx)) // store c03:c33 + + // update c40:c73 + + vmovupd(ymm8, mem(rdx)) // store c40:c70 + add(rdi, rdx) // c += cs_c; + + vmovupd(ymm10, mem(rdx)) // store c41:c71 + add(rdi, rdx) // c += cs_c; + + vmovupd(ymm12, mem(rdx)) // store c42:c72 + add(rdi, rdx) // c += cs_c; + + vmovupd(ymm14, mem(rdx)) // store c43:c73 + label(.DDONE) - - vzeroupper() - + vzeroupper() - - end_asm( + + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c), // 8 - [b_next] "m" (b_next)/*, // 9 - [a_next] "m" (a_next)*/ // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c), // 8 + [b_next] "m" (b_next)/*, // 9 + [a_next] "m" (a_next)*/ // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -1699,11 +990,15 @@ void bli_dgemm_sandybridge_asm_8x4 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( d ); } void bli_cgemm_sandybridge_asm_8x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, scomplex* restrict alpha, scomplex* restrict a, scomplex* restrict b, @@ -1718,34 +1013,36 @@ void bli_cgemm_sandybridge_asm_8x4 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( c, 8, 4, false ); + begin_asm() - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. mov(var(b_next), r15) // load address of b_next. //mov(var(a_next), r14) // load address of a_next. 
sub(imm(4*64), r15) - + vmovaps(mem(rax, 0*32), ymm0) // initialize loop by pre-loading vmovsldup(mem(rbx, 0*32), ymm2) vpermilps(imm(0x4e), ymm2, ymm3) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(scomplex) lea(mem(rcx, rdi, 2), r10) // load address of c + 2*cs_c; - + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*cs_c prefetch(0, mem(r10, 3*8)) // prefetch c + 2*cs_c prefetch(0, mem(r10, rdi, 1, 3*8)) // prefetch c + 3*cs_c - + vxorps(ymm8, ymm8, ymm8) vxorps(ymm9, ymm9, ymm9) vxorps(ymm10, ymm10, ymm10) @@ -1754,19 +1051,19 @@ void bli_cgemm_sandybridge_asm_8x4 vxorps(ymm13, ymm13, ymm13) vxorps(ymm14, ymm14, ymm14) vxorps(ymm15, ymm15, ymm15) - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.CCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.CLOOPKITER) // MAIN LOOP - + add(imm(4*4*8), r15) // b_next += 4*4 (unroll x nr) - + // iteration 0 prefetch(0, mem(rax, 8*32)) vmovaps(mem(rax, 1*32), ymm1) @@ -1776,20 +1073,20 @@ void bli_cgemm_sandybridge_asm_8x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddps(ymm6, ymm15, ymm15) vaddps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovshdup(mem(rbx, 0*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddps(ymm6, ymm14, ymm14) vaddps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vpermilps(imm(0xb1), ymm0, ymm0) vaddps(ymm6, ymm11, ymm11) vaddps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vmulps(ymm1, ymm5, ymm7) @@ -1797,32 +1094,32 @@ void bli_cgemm_sandybridge_asm_8x4 vaddps(ymm6, ymm10, ymm10) vaddps(ymm7, ymm8, ymm8) prefetch(0, mem(r15, 0*32)) // prefetch b_next[0*4] - + vpermilps(imm(0xb1), ymm1, ymm1) vmulps(ymm0, ymm2, ymm6) vmulps(ymm0, ymm3, ymm7) vaddsubps(ymm6, ymm15, ymm15) vaddsubps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovsldup(mem(rbx, 1*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddsubps(ymm6, ymm14, ymm14) vaddsubps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vmovaps(mem(rax, 2*32), ymm0) vaddsubps(ymm6, ymm11, ymm11) vaddsubps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddsubps(ymm6, ymm10, ymm10) vaddsubps(ymm7, ymm8, ymm8) - - + + // iteration 1 prefetch(0, mem(rax, 10*32)) vmovaps(mem(rax, 3*32), ymm1) @@ -1832,52 +1129,52 @@ void bli_cgemm_sandybridge_asm_8x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddps(ymm6, ymm15, ymm15) vaddps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovshdup(mem(rbx, 1*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddps(ymm6, ymm14, ymm14) vaddps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vpermilps(imm(0xb1), ymm0, ymm0) vaddps(ymm6, ymm11, ymm11) vaddps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vmulps(ymm1, ymm5, ymm7) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddps(ymm6, ymm10, ymm10) vaddps(ymm7, ymm8, ymm8) - + vpermilps(imm(0xb1), ymm1, ymm1) vmulps(ymm0, ymm2, ymm6) vmulps(ymm0, ymm3, ymm7) vaddsubps(ymm6, ymm15, ymm15) vaddsubps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovsldup(mem(rbx, 2*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddsubps(ymm6, ymm14, ymm14) vaddsubps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vmovaps(mem(rax, 
4*32), ymm0) vaddsubps(ymm6, ymm11, ymm11) vaddsubps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddsubps(ymm6, ymm10, ymm10) vaddsubps(ymm7, ymm8, ymm8) - - + + // iteration 2 prefetch(0, mem(rax, 12*32)) vmovaps(mem(rax, 5*32), ymm1) @@ -1887,20 +1184,20 @@ void bli_cgemm_sandybridge_asm_8x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddps(ymm6, ymm15, ymm15) vaddps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovshdup(mem(rbx, 2*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddps(ymm6, ymm14, ymm14) vaddps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vpermilps(imm(0xb1), ymm0, ymm0) vaddps(ymm6, ymm11, ymm11) vaddps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vmulps(ymm1, ymm5, ymm7) @@ -1908,32 +1205,32 @@ void bli_cgemm_sandybridge_asm_8x4 vaddps(ymm6, ymm10, ymm10) vaddps(ymm7, ymm8, ymm8) prefetch(0, mem(r15, 2*32)) // prefetch b_next[2*4] - + vpermilps(imm(0xb1), ymm1, ymm1) vmulps(ymm0, ymm2, ymm6) vmulps(ymm0, ymm3, ymm7) vaddsubps(ymm6, ymm15, ymm15) vaddsubps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovsldup(mem(rbx, 3*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddsubps(ymm6, ymm14, ymm14) vaddsubps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vmovaps(mem(rax, 6*32), ymm0) vaddsubps(ymm6, ymm11, ymm11) vaddsubps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddsubps(ymm6, ymm10, ymm10) vaddsubps(ymm7, ymm8, ymm8) - - + + // iteration 3 prefetch(0, mem(rax, 14*32)) vmovaps(mem(rax, 7*32), ymm1) @@ -1943,74 +1240,74 @@ void bli_cgemm_sandybridge_asm_8x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddps(ymm6, ymm15, ymm15) vaddps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovshdup(mem(rbx, 3*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddps(ymm6, ymm14, ymm14) vaddps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vpermilps(imm(0xb1), ymm0, ymm0) vaddps(ymm6, ymm11, ymm11) vaddps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vmulps(ymm1, ymm5, ymm7) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddps(ymm6, ymm10, ymm10) vaddps(ymm7, ymm8, ymm8) - + vpermilps(imm(0xb1), ymm1, ymm1) vmulps(ymm0, ymm2, ymm6) vmulps(ymm0, ymm3, ymm7) vaddsubps(ymm6, ymm15, ymm15) vaddsubps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovsldup(mem(rbx, 4*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddsubps(ymm6, ymm14, ymm14) vaddsubps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vmovaps(mem(rax, 8*32), ymm0) vaddsubps(ymm6, ymm11, ymm11) vaddsubps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddsubps(ymm6, ymm10, ymm10) vaddsubps(ymm7, ymm8, ymm8) - - + + add(imm(8*4*8), rax) // a += 8*4 (unroll x mr) add(imm(4*4*8), rbx) // b += 4*4 (unroll x nr) - - + + dec(rsi) // i -= 1; jne(.CLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.CCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.CPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. 
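In the c (scomplex) kernel each product above is complex, so one logical multiply-accumulate costs four real multiplies; the vmovsldup/vmovshdup loads broadcast the real and imaginary parts of b separately, and vaddsubps applies the sign pattern of complex multiplication. The scalar equivalent of one k-iteration, as a C sketch (mr = 8, nr = 4):

    #include <complex.h>

    /* Per (i,j) pair: ab_r += a_r*b_r - a_i*b_i;
                       ab_i += a_r*b_i + a_i*b_r; */
    static void cgemm_8x4_kiter_ref( const float complex* restrict a,
                                     const float complex* restrict b,
                                     float complex*       restrict ab )
    {
        for ( int j = 0; j < 4; ++j )
            for ( int i = 0; i < 8; ++i )
                ab[ i + j*8 ] += a[ i ] * b[ j ];
    }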
- - + + label(.CLOOPKLEFT) // EDGE LOOP - + // iteration 0 prefetch(0, mem(rax, 8*32)) vmovaps(mem(rax, 1*32), ymm1) @@ -2020,228 +1317,228 @@ void bli_cgemm_sandybridge_asm_8x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddps(ymm6, ymm15, ymm15) vaddps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovshdup(mem(rbx, 0*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddps(ymm6, ymm14, ymm14) vaddps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vpermilps(imm(0xb1), ymm0, ymm0) vaddps(ymm6, ymm11, ymm11) vaddps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vmulps(ymm1, ymm5, ymm7) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddps(ymm6, ymm10, ymm10) vaddps(ymm7, ymm8, ymm8) - + vpermilps(imm(0xb1), ymm1, ymm1) vmulps(ymm0, ymm2, ymm6) vmulps(ymm0, ymm3, ymm7) vaddsubps(ymm6, ymm15, ymm15) vaddsubps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovsldup(mem(rbx, 1*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddsubps(ymm6, ymm14, ymm14) vaddsubps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vmovaps(mem(rax, 2*32), ymm0) vaddsubps(ymm6, ymm11, ymm11) vaddsubps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddsubps(ymm6, ymm10, ymm10) vaddsubps(ymm7, ymm8, ymm8) - - + + add(imm(8*1*8), rax) // a += 8 (1 x mr) add(imm(4*1*8), rbx) // b += 4 (1 x nr) - - + + dec(rsi) // i -= 1; jne(.CLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.CPOSTACCUM) - + // ymm15: ymm13: ymm11: ymm9: - // ( ab00 ( ab01 ( ab02 ( ab03 - // ab10 ab11 ab12 ab13 - // ab21 ab20 ab23 ab22 - // ab31 ab30 ab33 ab32 - // ab42 ab43 ab40 ab41 - // ab52 ab53 ab50 ab51 - // ab63 ab62 ab61 ab60 + // ( ab00 ( ab01 ( ab02 ( ab03 + // ab10 ab11 ab12 ab13 + // ab21 ab20 ab23 ab22 + // ab31 ab30 ab33 ab32 + // ab42 ab43 ab40 ab41 + // ab52 ab53 ab50 ab51 + // ab63 ab62 ab61 ab60 // ab73 ) ab72 ) ab71 ) ab70 ) - + // ymm14: ymm12: ymm10: ymm8: - // ( ab80 ( ab81 ( ab82 ( ab83 - // ab90 ab91 ab92 ab93 - // aba1 aba0 aba3 aba2 - // abb1 abb0 abb3 abb2 - // abc2 abc3 abc0 abc1 - // abd2 abd3 abd0 abd1 - // abe3 abe2 abe1 abe0 + // ( ab80 ( ab81 ( ab82 ( ab83 + // ab90 ab91 ab92 ab93 + // aba1 aba0 aba3 aba2 + // abb1 abb0 abb3 abb2 + // abc2 abc3 abc0 abc1 + // abd2 abd3 abd0 abd1 + // abe3 abe2 abe1 abe0 // abf3 abf2 abf1 abf0 ) - + vmovaps(ymm15, ymm7) vshufps(imm(0xe4), ymm13, ymm15, ymm15) vshufps(imm(0xe4), ymm7, ymm13, ymm13) - + vmovaps(ymm11, ymm7) vshufps(imm(0xe4), ymm9, ymm11, ymm11) vshufps(imm(0xe4), ymm7, ymm9, ymm9) - + vmovaps(ymm14, ymm7) vshufps(imm(0xe4), ymm12, ymm14, ymm14) vshufps(imm(0xe4), ymm7, ymm12, ymm12) - + vmovaps(ymm10, ymm7) vshufps(imm(0xe4), ymm8, ymm10, ymm10) vshufps(imm(0xe4), ymm7, ymm8, ymm8) - + // ymm15: ymm13: ymm11: ymm9: - // ( ab00 ( ab01 ( ab02 ( ab03 - // ab10 ab11 ab12 ab13 - // ab20 ab21 ab22 ab23 - // ab30 ab31 ab32 ab33 - // ab42 ab43 ab40 ab41 - // ab52 ab53 ab50 ab51 - // ab62 ab63 ab60 ab61 + // ( ab00 ( ab01 ( ab02 ( ab03 + // ab10 ab11 ab12 ab13 + // ab20 ab21 ab22 ab23 + // ab30 ab31 ab32 ab33 + // ab42 ab43 ab40 ab41 + // ab52 ab53 ab50 ab51 + // ab62 ab63 ab60 ab61 // ab72 ) ab73 ) ab70 ) ab71 ) - + // ymm14: ymm12: ymm10: ymm8: - // ( ab80 ( ab81 ( ab82 ( ab83 - // ab90 ab91 ab92 ab93 - // aba0 aba1 aba2 aba3 - // abb0 abb1 abb2 abb3 - // abc2 abc3 abc0 abc1 - // abd2 abd3 abd0 abd1 - // abe2 abe3 abe0 abe1 + // ( ab80 ( ab81 ( ab82 ( ab83 + // ab90 ab91 ab92 ab93 + // aba0 aba1 aba2 aba3 + // abb0 abb1 
abb2 abb3 + // abc2 abc3 abc0 abc1 + // abd2 abd3 abd0 abd1 + // abe2 abe3 abe0 abe1 // abf2 ) abf3 ) abf0 ) abf1 ) - + vmovaps(ymm15, ymm7) vperm2f128(imm(0x12), ymm15, ymm11, ymm15) vperm2f128(imm(0x30), ymm7, ymm11, ymm11) - + vmovaps(ymm13, ymm7) vperm2f128(imm(0x12), ymm13, ymm9, ymm13) vperm2f128(imm(0x30), ymm7, ymm9, ymm9) - + vmovaps(ymm14, ymm7) vperm2f128(imm(0x12), ymm14, ymm10, ymm14) vperm2f128(imm(0x30), ymm7, ymm10, ymm10) - + vmovaps(ymm12, ymm7) vperm2f128(imm(0x12), ymm12, ymm8, ymm12) vperm2f128(imm(0x30), ymm7, ymm8, ymm8) - + // ymm15: ymm13: ymm11: ymm9: - // ( ab00 ( ab01 ( ab02 ( ab03 - // ab10 ab11 ab12 ab13 - // ab20 ab21 ab22 ab23 - // ab30 ab31 ab32 ab33 - // ab40 ab41 ab42 ab43 - // ab50 ab51 ab52 ab53 - // ab60 ab61 ab62 ab63 + // ( ab00 ( ab01 ( ab02 ( ab03 + // ab10 ab11 ab12 ab13 + // ab20 ab21 ab22 ab23 + // ab30 ab31 ab32 ab33 + // ab40 ab41 ab42 ab43 + // ab50 ab51 ab52 ab53 + // ab60 ab61 ab62 ab63 // ab70 ) ab71 ) ab72 ) ab73 ) - + // ymm14: ymm12: ymm10: ymm8: - // ( ab80 ( ab81 ( ab82 ( ab83 - // ab90 ab91 ab92 ab93 - // aba0 aba1 aba2 aba3 - // abb0 abb1 abb2 abb3 - // abc0 abc1 abc2 abc3 - // abd0 abd1 abd2 abd3 - // abe0 abe1 abe2 abe3 + // ( ab80 ( ab81 ( ab82 ( ab83 + // ab90 ab91 ab92 ab93 + // aba0 aba1 aba2 aba3 + // abb0 abb1 abb2 abb3 + // abc0 abc1 abc2 abc3 + // abd0 abd1 abd2 abd3 + // abe0 abe1 abe2 abe3 // abf0 ) abf1 ) abf2 ) abf3 ) - - - - + + + + // scale by alpha - + mov(var(alpha), rax) // load address of alpha vbroadcastss(mem(rax), ymm7) // load alpha_r and duplicate vbroadcastss(mem(rax, 4), ymm6) // load alpha_i and duplicate - + vpermilps(imm(0xb1), ymm15, ymm3) vmulps(ymm7, ymm15, ymm15) vmulps(ymm6, ymm3, ymm3) vaddsubps(ymm3, ymm15, ymm15) - + vpermilps(imm(0xb1), ymm14, ymm2) vmulps(ymm7, ymm14, ymm14) vmulps(ymm6, ymm2, ymm2) vaddsubps(ymm2, ymm14, ymm14) - + vpermilps(imm(0xb1), ymm13, ymm1) vmulps(ymm7, ymm13, ymm13) vmulps(ymm6, ymm1, ymm1) vaddsubps(ymm1, ymm13, ymm13) - + vpermilps(imm(0xb1), ymm12, ymm0) vmulps(ymm7, ymm12, ymm12) vmulps(ymm6, ymm0, ymm0) vaddsubps(ymm0, ymm12, ymm12) - + vpermilps(imm(0xb1), ymm11, ymm3) vmulps(ymm7, ymm11, ymm11) vmulps(ymm6, ymm3, ymm3) vaddsubps(ymm3, ymm11, ymm11) - + vpermilps(imm(0xb1), ymm10, ymm2) vmulps(ymm7, ymm10, ymm10) vmulps(ymm6, ymm2, ymm2) vaddsubps(ymm2, ymm10, ymm10) - + vpermilps(imm(0xb1), ymm9, ymm1) vmulps(ymm7, ymm9, ymm9) vmulps(ymm6, ymm1, ymm1) vaddsubps(ymm1, ymm9, ymm9) - + vpermilps(imm(0xb1), ymm8, ymm0) vmulps(ymm7, ymm8, ymm8) vmulps(ymm6, ymm0, ymm0) vaddsubps(ymm0, ymm8, ymm8) - - - - + + + + mov(var(beta), rbx) // load address of beta vbroadcastss(mem(rbx), ymm7) // load beta_r and duplicate vbroadcastss(mem(rbx, 4), ymm6) // load beta_i and duplicate - - - - - - - + + + + + + + mov(var(rs_c), rsi) // load rs_c lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(scomplex) - + lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*rs_c; - + lea(mem(, rsi, 2), r12) // r12 = 2*rs_c; lea(mem(r12, rsi, 1), r13) // r13 = 3*rs_c; - - + + // now avoid loading C if beta == 0 - + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomiss(xmm0, xmm7) // set ZF if beta_r == 0. sete(r8b) // r8b = ( ZF == 1 ? 1 : 0 ); @@ -2249,410 +1546,144 @@ void bli_cgemm_sandybridge_asm_8x4 sete(r9b) // r9b = ( ZF == 1 ? 1 : 0 ); and(r8b, r9b) // set ZF if r8b & r9b == 1. jne(.CBETAZERO) // if ZF = 0, jump to beta == 0 case - - - cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8. 
- jz(.CCOLSTORED) // jump to column storage case - - - - label(.CGENSTORED) - - // update c00:c70 - - vmovlpd(mem(rcx), xmm0, xmm0) // load (c00,10) into xmm0[0:1] - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) // load (c20,30) into xmm0[2:3] - vmovlpd(mem(rcx, r12, 1), xmm2, xmm2) // load (c40,50) into xmm2[0:1] - vmovhpd(mem(rcx, r13, 1), xmm2, xmm2) // load (c60,70) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm15, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rcx)) // store (c00,c10) - vmovhpd(xmm0, mem(rcx, rsi, 1)) // store (c20,c30) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c40,c50) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c60,c70) - add(rdi, rcx) // c += cs_c; - - // update c80:cf0 - - vmovlpd(mem(rdx), xmm0, xmm0) // load (c80,90) into xmm0[0:1] - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) // load (ca0,b0) into xmm0[2:3] - vmovlpd(mem(rdx, r12, 1), xmm2, xmm2) // load (cc0,d0) into xmm2[0:1] - vmovhpd(mem(rdx, r13, 1), xmm2, xmm2) // load (ce0,f0) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm14, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rdx)) // store (c80,c90) - vmovhpd(xmm0, mem(rdx, rsi, 1)) // store (ca0,cb0) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc0,cd0) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce0,cf0) - add(rdi, rdx) // c += cs_c; - - // update c01:c71 - - vmovlpd(mem(rcx), xmm0, xmm0) // load (c01,11) into xmm0[0:1] - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) // load (c21,31) into xmm0[2:3] - vmovlpd(mem(rcx, r12, 1), xmm2, xmm2) // load (c41,51) into xmm2[0:1] - vmovhpd(mem(rcx, r13, 1), xmm2, xmm2) // load (c61,71) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm13, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rcx)) // store (c01,c11) - vmovhpd(xmm0, mem(rcx, rsi, 1)) // store (c21,c31) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c41,c51) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c61,c71) - add(rdi, rcx) // c += cs_c; - - // update c81:cf1 - - vmovlpd(mem(rdx), xmm0, xmm0) // load (c81,91) into xmm0[0:1] - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) // load (ca1,b1) into xmm0[2:3] - vmovlpd(mem(rdx, r12, 1), xmm2, xmm2) // load (cc1,d1) into xmm2[0:1] - vmovhpd(mem(rdx, r13, 1), xmm2, xmm2) // load (ce1,f1) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm12, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rdx)) // store (c81,c91) - vmovhpd(xmm0, mem(rdx, rsi, 1)) // store (ca1,cb1) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc1,cd1) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce1,cf1) - add(rdi, rdx) // c += cs_c; - - // update c02:c72 - - 
vmovlpd(mem(rcx), xmm0, xmm0) // load (c02,12) into xmm0[0:1] - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) // load (c22,32) into xmm0[2:3] - vmovlpd(mem(rcx, r12, 1), xmm2, xmm2) // load (c42,52) into xmm2[0:1] - vmovhpd(mem(rcx, r13, 1), xmm2, xmm2) // load (c62,72) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm11, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rcx)) // store (c02,c12) - vmovhpd(xmm0, mem(rcx, rsi, 1)) // store (c22,c32) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c42,c52) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c62,c72) - add(rdi, rcx) // c += cs_c; - - // update c82:cf2 - - vmovlpd(mem(rdx), xmm0, xmm0) // load (c82,92) into xmm0[0:1] - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) // load (ca2,b2) into xmm0[2:3] - vmovlpd(mem(rdx, r12, 1), xmm2, xmm2) // load (cc2,d2) into xmm2[0:1] - vmovhpd(mem(rdx, r13, 1), xmm2, xmm2) // load (ce2,f2) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm10, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rdx)) // store (c82,c92) - vmovhpd(xmm0, mem(rdx, rsi, 1)) // store (ca2,cb2) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc2,cd2) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce2,cf2) - add(rdi, rdx) // c += cs_c; - - // update c03:c73 - - vmovlpd(mem(rcx), xmm0, xmm0) // load (c03,13) into xmm0[0:1] - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) // load (c23,33) into xmm0[2:3] - vmovlpd(mem(rcx, r12, 1), xmm2, xmm2) // load (c43,53) into xmm2[0:1] - vmovhpd(mem(rcx, r13, 1), xmm2, xmm2) // load (c63,73) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm9, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rcx)) // store (c03,c13) - vmovhpd(xmm0, mem(rcx, rsi, 1)) // store (c23,c33) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c43,c53) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c63,c73) - add(rdi, rcx) // c += cs_c; - - // update c83:cf3 - - vmovlpd(mem(rdx), xmm0, xmm0) // load (c83,93) into xmm0[0:1] - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) // load (ca3,b3) into xmm0[2:3] - vmovlpd(mem(rdx, r12, 1), xmm2, xmm2) // load (cc3,d3) into xmm2[0:1] - vmovhpd(mem(rdx, r13, 1), xmm2, xmm2) // load (ce3,f3) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm8, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rdx)) // store (c83,c93) - vmovhpd(xmm0, mem(rdx, rsi, 1)) // store (ca3,cb3) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc3,cd3) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce3,cf3) - add(rdi, rdx) // c += cs_c; - - - - jmp(.CDONE) // jump to end. 
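Every removed .CGENSTORED stanza above performs the same element-wise update and differs only in addressing: gather a few strided elements of C with vmovlpd/vmovhpd/vinsertf128, scale by the complex beta held in ymm7/ymm6, add the accumulated result, and scatter back. The per-element operation, as a C sketch:

    #include <complex.h>

    /* c := beta * c + ab, with complex beta
       (beta_r broadcast in ymm7, beta_i in ymm6 above). */
    static void cbeta_update_ref( float complex* restrict c,
                                  float complex beta,
                                  float complex ab )
    {
        *c = beta * (*c) + ab;
    }

After this patch only the contiguous, column-stored form of this update remains in asm; strided C tiles are merged at the C level instead.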
- - - - label(.CCOLSTORED) - - // update c00:c70 - - vmovups(mem(rcx), ymm0) // load c00:c70 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm15, ymm0, ymm0) // add the gemm result to ymm0 - vmovups(ymm0, mem(rcx)) // store c00:c70 - add(rdi, rcx) // c += cs_c; - - // update c80:cf0 - - vmovups(mem(rdx), ymm0) // load c80:f0 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm14, ymm0, ymm0) // add the gemm result to ymm0 - vmovups(ymm0, mem(rdx)) // store c80:cf0 - add(rdi, rdx) // c += cs_c; - - // update c00:c70 - - vmovups(mem(rcx), ymm0) // load c01:c71 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm13, ymm0, ymm0) // add the gemm result to ymm0 - vmovups(ymm0, mem(rcx)) // store c01:c71 - add(rdi, rcx) // c += cs_c; - - // update c81:cf1 - - vmovups(mem(rdx), ymm0) // load c81:f1 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm12, ymm0, ymm0) // add the gemm result to ymm0 - vmovups(ymm0, mem(rdx)) // store c81:cf1 - add(rdi, rdx) // c += cs_c; - - // update c02:c72 - - vmovups(mem(rcx), ymm0) // load c02:c72 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm11, ymm0, ymm0) // add the gemm result to ymm0 - vmovups(ymm0, mem(rcx)) // store c02:c72 - add(rdi, rcx) // c += cs_c; - - // update c82:cf2 - - vmovups(mem(rdx), ymm0) // load c82:f2 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm10, ymm0, ymm0) // add the gemm result to ymm0 - vmovups(ymm0, mem(rdx)) // store c82:cf2 - add(rdi, rdx) // c += cs_c; - - // update c03:c73 - - vmovups(mem(rcx), ymm0) // load c03:c73 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm9, ymm0, ymm0) // add the gemm result to ymm0 - vmovups(ymm0, mem(rcx)) // store c03:c73 - add(rdi, rcx) // c += cs_c; - - // update c83:cf3 - - vmovups(mem(rdx), ymm0) // load c83:f3 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm8, ymm0, ymm0) // add the gemm result to ymm0 - vmovups(ymm0, mem(rdx)) // store c83:cf3 - add(rdi, rdx) // c += cs_c; - - - - jmp(.CDONE) // jump to end. 
-
-
-
+
+	// update c00:c70
+
+	vmovups(mem(rcx), ymm0)            // load c00:c70 into ymm0
+	vpermilps(imm(0xb1), ymm0, ymm2)   // scale ymm0 by beta
+	vmulps(ymm7, ymm0, ymm0)
+	vmulps(ymm6, ymm2, ymm2)
+	vaddsubps(ymm2, ymm0, ymm0)
+	vaddps(ymm15, ymm0, ymm0)          // add the gemm result to ymm0
+	vmovups(ymm0, mem(rcx))            // store c00:c70
+	add(rdi, rcx)                      // c += cs_c;
+
+	// update c80:cf0
+
+	vmovups(mem(rdx), ymm0)            // load c80:f0 into ymm0
+	vpermilps(imm(0xb1), ymm0, ymm2)   // scale ymm0 by beta
+	vmulps(ymm7, ymm0, ymm0)
+	vmulps(ymm6, ymm2, ymm2)
+	vaddsubps(ymm2, ymm0, ymm0)
+	vaddps(ymm14, ymm0, ymm0)          // add the gemm result to ymm0
+	vmovups(ymm0, mem(rdx))            // store c80:cf0
+	add(rdi, rdx)                      // c += cs_c;
+
+	// update c01:c71
+
+	vmovups(mem(rcx), ymm0)            // load c01:c71 into ymm0
+	vpermilps(imm(0xb1), ymm0, ymm2)   // scale ymm0 by beta
+	vmulps(ymm7, ymm0, ymm0)
+	vmulps(ymm6, ymm2, ymm2)
+	vaddsubps(ymm2, ymm0, ymm0)
+	vaddps(ymm13, ymm0, ymm0)          // add the gemm result to ymm0
+	vmovups(ymm0, mem(rcx))            // store c01:c71
+	add(rdi, rcx)                      // c += cs_c;
+
+	// update c81:cf1
+
+	vmovups(mem(rdx), ymm0)            // load c81:f1 into ymm0
+	vpermilps(imm(0xb1), ymm0, ymm2)   // scale ymm0 by beta
+	vmulps(ymm7, ymm0, ymm0)
+	vmulps(ymm6, ymm2, ymm2)
+	vaddsubps(ymm2, ymm0, ymm0)
+	vaddps(ymm12, ymm0, ymm0)          // add the gemm result to ymm0
+	vmovups(ymm0, mem(rdx))            // store c81:cf1
+	add(rdi, rdx)                      // c += cs_c;
+
+	// update c02:c72
+
+	vmovups(mem(rcx), ymm0)            // load c02:c72 into ymm0
+	vpermilps(imm(0xb1), ymm0, ymm2)   // scale ymm0 by beta
+	vmulps(ymm7, ymm0, ymm0)
+	vmulps(ymm6, ymm2, ymm2)
+	vaddsubps(ymm2, ymm0, ymm0)
+	vaddps(ymm11, ymm0, ymm0)          // add the gemm result to ymm0
+	vmovups(ymm0, mem(rcx))            // store c02:c72
+	add(rdi, rcx)                      // c += cs_c;
+
+	// update c82:cf2
+
+	vmovups(mem(rdx), ymm0)            // load c82:f2 into ymm0
+	vpermilps(imm(0xb1), ymm0, ymm2)   // scale ymm0 by beta
+	vmulps(ymm7, ymm0, ymm0)
+	vmulps(ymm6, ymm2, ymm2)
+	vaddsubps(ymm2, ymm0, ymm0)
+	vaddps(ymm10, ymm0, ymm0)          // add the gemm result to ymm0
+	vmovups(ymm0, mem(rdx))            // store c82:cf2
+	add(rdi, rdx)                      // c += cs_c;
+
+	// update c03:c73
+
+	vmovups(mem(rcx), ymm0)            // load c03:c73 into ymm0
+	vpermilps(imm(0xb1), ymm0, ymm2)   // scale ymm0 by beta
+	vmulps(ymm7, ymm0, ymm0)
+	vmulps(ymm6, ymm2, ymm2)
+	vaddsubps(ymm2, ymm0, ymm0)
+	vaddps(ymm9, ymm0, ymm0)           // add the gemm result to ymm0
+	vmovups(ymm0, mem(rcx))            // store c03:c73
+	add(rdi, rcx)                      // c += cs_c;
+
+	// update c83:cf3
+
+	vmovups(mem(rdx), ymm0)            // load c83:f3 into ymm0
+	vpermilps(imm(0xb1), ymm0, ymm2)   // scale ymm0 by beta
+	vmulps(ymm7, ymm0, ymm0)
+	vmulps(ymm6, ymm2, ymm2)
+	vaddsubps(ymm2, ymm0, ymm0)
+	vaddps(ymm8, ymm0, ymm0)           // add the gemm result to ymm0
+	vmovups(ymm0, mem(rdx))            // store c83:cf3
+	add(rdi, rdx)                      // c += cs_c;
+
+	jmp(.CDONE)                        // jump to end.
+
	label(.CBETAZERO)
-
-	cmp(imm(8), rsi)                   // set ZF if (8*cs_c) == 8.
- jz(.CCOLSTORBZ) // jump to column storage case - - - - label(.CGENSTORBZ) - - // update c00:c70 - - vextractf128(imm(1), ymm15, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm15, mem(rcx)) // store (c00,c10) - vmovhpd(xmm15, mem(rcx, rsi, 1)) // store (c20,c30) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c40,c50) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c60,c70) - add(rdi, rcx) // c += cs_c; - - // update c80:cf0 - - vextractf128(imm(1), ymm14, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm14, mem(rdx)) // store (c80,c90) - vmovhpd(xmm14, mem(rdx, rsi, 1)) // store (ca0,cb0) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc0,cd0) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce0,cf0) - add(rdi, rdx) // c += cs_c; - - // update c01:c71 - - vextractf128(imm(1), ymm13, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm13, mem(rcx)) // store (c01,c11) - vmovhpd(xmm13, mem(rcx, rsi, 1)) // store (c21,c31) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c41,c51) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c61,c71) - add(rdi, rcx) // c += cs_c; - - // update c81:cf1 - - vextractf128(imm(1), ymm12, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm12, mem(rdx)) // store (c81,c91) - vmovhpd(xmm12, mem(rdx, rsi, 1)) // store (ca1,cb1) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc1,cd1) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce1,cf1) - add(rdi, rdx) // c += cs_c; - - // update c02:c72 - - vextractf128(imm(1), ymm11, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm11, mem(rcx)) // store (c02,c12) - vmovhpd(xmm11, mem(rcx, rsi, 1)) // store (c22,c32) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c42,c52) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c62,c72) - add(rdi, rcx) // c += cs_c; - - // update c82:cf2 - - vextractf128(imm(1), ymm10, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm10, mem(rdx)) // store (c82,c92) - vmovhpd(xmm10, mem(rdx, rsi, 1)) // store (ca2,cb2) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc2,cd2) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce2,cf2) - add(rdi, rdx) // c += cs_c; - - // update c03:c73 - - vextractf128(imm(1), ymm9, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm9, mem(rcx)) // store (c03,c13) - vmovhpd(xmm9, mem(rcx, rsi, 1)) // store (c23,c33) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c43,c53) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c63,c73) - add(rdi, rcx) // c += cs_c; - - // update c83:cf3 - - vextractf128(imm(1), ymm8, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm8, mem(rdx)) // store (c83,c93) - vmovhpd(xmm8, mem(rdx, rsi, 1)) // store (ca3,cb3) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc3,cd3) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce3,cf3) - add(rdi, rdx) // c += cs_c; - - - - jmp(.CDONE) // jump to end. 
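The reason beta == 0 gets its own label rather than being folded into the general update: the store-only path never reads C, so it is faster and, more importantly, uninitialized or NaN/Inf-filled C memory cannot contaminate the result through 0 * C. The predicate the vucomiss/sete/and sequence above computes, as a C sketch:

    #include <complex.h>
    #include <stdbool.h>

    /* Take the .CBETAZERO (store-only) path iff both components of
       beta are exactly zero. */
    static bool cbeta_is_zero( float complex beta )
    {
        return crealf( beta ) == 0.0f && cimagf( beta ) == 0.0f;
    }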
- - - - label(.CCOLSTORBZ) - - - vmovups(ymm15, mem(rcx)) // store c00:c70 - add(rdi, rcx) // c += cs_c; - - vmovups(ymm14, mem(rdx)) // store c80:cf0 - add(rdi, rdx) // c += cs_c; - - vmovups(ymm13, mem(rcx)) // store c01:c71 - add(rdi, rcx) // c += cs_c; - - vmovups(ymm12, mem(rdx)) // store c81:cf1 - add(rdi, rdx) // c += cs_c; - - vmovups(ymm11, mem(rcx)) // store c02:c72 - add(rdi, rcx) // c += cs_c; - - vmovups(ymm10, mem(rdx)) // store c82:cf2 - add(rdi, rdx) // c += cs_c; - - vmovups(ymm9, mem(rcx)) // store c03:c73 - add(rdi, rcx) // c += cs_c; - - vmovups(ymm8, mem(rdx)) // store c83:cf3 - add(rdi, rdx) // c += cs_c; - - - - - + + vmovups(ymm15, mem(rcx)) // store c00:c70 + add(rdi, rcx) // c += cs_c; + + vmovups(ymm14, mem(rdx)) // store c80:cf0 + add(rdi, rdx) // c += cs_c; + + vmovups(ymm13, mem(rcx)) // store c01:c71 + add(rdi, rcx) // c += cs_c; + + vmovups(ymm12, mem(rdx)) // store c81:cf1 + add(rdi, rdx) // c += cs_c; + + vmovups(ymm11, mem(rcx)) // store c02:c72 + add(rdi, rcx) // c += cs_c; + + vmovups(ymm10, mem(rdx)) // store c82:cf2 + add(rdi, rdx) // c += cs_c; + + vmovups(ymm9, mem(rcx)) // store c03:c73 + add(rdi, rcx) // c += cs_c; + + vmovups(ymm8, mem(rdx)) // store c83:cf3 + add(rdi, rdx) // c += cs_c; + label(.CDONE) - - vzeroupper() - + vzeroupper() - - end_asm( + + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c), // 8 - [b_next] "m" (b_next)/*, // 9 - [a_next] "m" (a_next)*/ // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c), // 8 + [b_next] "m" (b_next)/*, // 9 + [a_next] "m" (a_next)*/ // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -2660,13 +1691,17 @@ void bli_cgemm_sandybridge_asm_8x4 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( c ); } void bli_zgemm_sandybridge_asm_4x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, dcomplex* restrict alpha, dcomplex* restrict a, dcomplex* restrict b, @@ -2681,34 +1716,36 @@ void bli_zgemm_sandybridge_asm_4x4 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( z, 4, 4, false ); + begin_asm() - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. //mov(var(b_next), r15) // load address of b_next. //mov(var(a_next), r14) // load address of a_next. 
- + vmovapd(mem(rax, 0*32), ymm0) // initialize loop by pre-loading vmovddup(mem(rbx, 0+0*32), ymm2) vmovddup(mem(rbx, 0+1*32), ymm3) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(dcomplex) lea(mem(, rdi, 2), rdi) lea(mem(rcx, rdi, 2), r10) // load address of c + 2*cs_c; - + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*cs_c prefetch(0, mem(r10, 3*8)) // prefetch c + 2*cs_c prefetch(0, mem(r10, rdi, 1, 3*8)) // prefetch c + 3*cs_c - + vxorpd(ymm8, ymm8, ymm8) vxorpd(ymm9, ymm9, ymm9) vxorpd(ymm10, ymm10, ymm10) @@ -2717,18 +1754,18 @@ void bli_zgemm_sandybridge_asm_4x4 vxorpd(ymm13, ymm13, ymm13) vxorpd(ymm14, ymm14, ymm14) vxorpd(ymm15, ymm15, ymm15) - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.ZCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.ZLOOPKITER) // MAIN LOOP - - + + // iteration 0 vmovapd(mem(rax, 1*32), ymm1) vmulpd(ymm0, ymm2, ymm6) @@ -2737,7 +1774,7 @@ void bli_zgemm_sandybridge_asm_4x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm6, ymm15, ymm15) vaddpd(ymm7, ymm11, ymm11) - + prefetch(0, mem(rax, 16*32)) vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 8+0*32), ymm2) @@ -2745,45 +1782,45 @@ void bli_zgemm_sandybridge_asm_4x4 vmovddup(mem(rbx, 8+1*32), ymm3) vaddpd(ymm6, ymm14, ymm14) vaddpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vpermilpd(imm(0x5), ymm0, ymm0) vaddpd(ymm6, ymm13, ymm13) vaddpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vmulpd(ymm1, ymm5, ymm7) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm6, ymm12, ymm12) vaddpd(ymm7, ymm8, ymm8) - + vpermilpd(imm(0x5), ymm1, ymm1) vmulpd(ymm0, ymm2, ymm6) vmulpd(ymm0, ymm3, ymm7) vaddsubpd(ymm6, ymm15, ymm15) vaddsubpd(ymm7, ymm11, ymm11) - + vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 0+2*32), ymm2) vmulpd(ymm1, ymm3, ymm7) vmovddup(mem(rbx, 0+3*32), ymm3) vaddsubpd(ymm6, ymm14, ymm14) vaddsubpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 2*32), ymm0) vaddsubpd(ymm6, ymm13, ymm13) vaddsubpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddsubpd(ymm6, ymm12, ymm12) vaddsubpd(ymm7, ymm8, ymm8) - - + + // iteration 1 vmovapd(mem(rax, 3*32), ymm1) vmulpd(ymm0, ymm2, ymm6) @@ -2792,7 +1829,7 @@ void bli_zgemm_sandybridge_asm_4x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm6, ymm15, ymm15) vaddpd(ymm7, ymm11, ymm11) - + prefetch(0, mem(rax, 18*32)) vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 8+2*32), ymm2) @@ -2800,45 +1837,45 @@ void bli_zgemm_sandybridge_asm_4x4 vmovddup(mem(rbx, 8+3*32), ymm3) vaddpd(ymm6, ymm14, ymm14) vaddpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vpermilpd(imm(0x5), ymm0, ymm0) vaddpd(ymm6, ymm13, ymm13) vaddpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vmulpd(ymm1, ymm5, ymm7) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm6, ymm12, ymm12) vaddpd(ymm7, ymm8, ymm8) - + vpermilpd(imm(0x5), ymm1, ymm1) vmulpd(ymm0, ymm2, ymm6) vmulpd(ymm0, ymm3, ymm7) vaddsubpd(ymm6, ymm15, ymm15) vaddsubpd(ymm7, ymm11, ymm11) - + vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 0+4*32), ymm2) vmulpd(ymm1, ymm3, ymm7) vmovddup(mem(rbx, 0+5*32), ymm3) vaddsubpd(ymm6, ymm14, ymm14) vaddsubpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 
4*32), ymm0) vaddsubpd(ymm6, ymm13, ymm13) vaddsubpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddsubpd(ymm6, ymm12, ymm12) vaddsubpd(ymm7, ymm8, ymm8) - - + + // iteration 2 vmovapd(mem(rax, 5*32), ymm1) vmulpd(ymm0, ymm2, ymm6) @@ -2847,7 +1884,7 @@ void bli_zgemm_sandybridge_asm_4x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm6, ymm15, ymm15) vaddpd(ymm7, ymm11, ymm11) - + prefetch(0, mem(rax, 20*32)) vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 8+4*32), ymm2) @@ -2855,45 +1892,45 @@ void bli_zgemm_sandybridge_asm_4x4 vmovddup(mem(rbx, 8+5*32), ymm3) vaddpd(ymm6, ymm14, ymm14) vaddpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vpermilpd(imm(0x5), ymm0, ymm0) vaddpd(ymm6, ymm13, ymm13) vaddpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vmulpd(ymm1, ymm5, ymm7) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm6, ymm12, ymm12) vaddpd(ymm7, ymm8, ymm8) - + vpermilpd(imm(0x5), ymm1, ymm1) vmulpd(ymm0, ymm2, ymm6) vmulpd(ymm0, ymm3, ymm7) vaddsubpd(ymm6, ymm15, ymm15) vaddsubpd(ymm7, ymm11, ymm11) - + vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 0+6*32), ymm2) vmulpd(ymm1, ymm3, ymm7) vmovddup(mem(rbx, 0+7*32), ymm3) vaddsubpd(ymm6, ymm14, ymm14) vaddsubpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 6*32), ymm0) vaddsubpd(ymm6, ymm13, ymm13) vaddsubpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddsubpd(ymm6, ymm12, ymm12) vaddsubpd(ymm7, ymm8, ymm8) - - + + // iteration 3 vmovapd(mem(rax, 7*32), ymm1) vmulpd(ymm0, ymm2, ymm6) @@ -2902,7 +1939,7 @@ void bli_zgemm_sandybridge_asm_4x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm6, ymm15, ymm15) vaddpd(ymm7, ymm11, ymm11) - + prefetch(0, mem(rax, 22*32)) vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 8+6*32), ymm2) @@ -2910,67 +1947,67 @@ void bli_zgemm_sandybridge_asm_4x4 vmovddup(mem(rbx, 8+7*32), ymm3) vaddpd(ymm6, ymm14, ymm14) vaddpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vpermilpd(imm(0x5), ymm0, ymm0) vaddpd(ymm6, ymm13, ymm13) vaddpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vmulpd(ymm1, ymm5, ymm7) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm6, ymm12, ymm12) vaddpd(ymm7, ymm8, ymm8) - + vpermilpd(imm(0x5), ymm1, ymm1) vmulpd(ymm0, ymm2, ymm6) vmulpd(ymm0, ymm3, ymm7) vaddsubpd(ymm6, ymm15, ymm15) vaddsubpd(ymm7, ymm11, ymm11) - + vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 0+8*32), ymm2) vmulpd(ymm1, ymm3, ymm7) vmovddup(mem(rbx, 0+9*32), ymm3) vaddsubpd(ymm6, ymm14, ymm14) vaddsubpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 8*32), ymm0) vaddsubpd(ymm6, ymm13, ymm13) vaddsubpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddsubpd(ymm6, ymm12, ymm12) vaddsubpd(ymm7, ymm8, ymm8) - - + + add(imm(4*4*16), rbx) // b += 4*4 (unroll x nr) add(imm(4*4*16), rax) // a += 4*4 (unroll x mr) - - + + dec(rsi) // i -= 1; jne(.ZLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.ZCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.ZPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. 
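Throughout the z kernel (and in the alpha/beta scaling after .ZPOSTACCUM below), a complex multiply is synthesized from vpermilpd, two vmulpd against broadcast real/imaginary parts, and one vaddsubpd, which subtracts in even lanes and adds in odd lanes. The identity being exploited, as a C sketch for one interleaved (re, im) pair:

    /* w*x via the addsub pattern:
         t  = (x_i, x_r)                    vpermilpd(imm(0x5), x, t)
         p0 = (w_r*x_r, w_r*x_i)            vmulpd with broadcast w_r
         p1 = (w_i*x_i, w_i*x_r)            vmulpd with broadcast w_i
         w*x = (p0[0]-p1[0], p0[1]+p1[1])   vaddsubpd               */
    static void zcmul_ref( const double w[2], const double x[2],
                           double r[2] )
    {
        r[0] = w[0]*x[0] - w[1]*x[1];    // real part
        r[1] = w[0]*x[1] + w[1]*x[0];    // imaginary part
    }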
- - + + label(.ZLOOPKLEFT) // EDGE LOOP - + // iteration 0 vmovapd(mem(rax, 1*32), ymm1) vmulpd(ymm0, ymm2, ymm6) @@ -2979,7 +2016,7 @@ void bli_zgemm_sandybridge_asm_4x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm6, ymm15, ymm15) vaddpd(ymm7, ymm11, ymm11) - + prefetch(0, mem(rax, 16*32)) vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 8+0*32), ymm2) @@ -2987,166 +2024,166 @@ void bli_zgemm_sandybridge_asm_4x4 vmovddup(mem(rbx, 8+1*32), ymm3) vaddpd(ymm6, ymm14, ymm14) vaddpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vpermilpd(imm(0x5), ymm0, ymm0) vaddpd(ymm6, ymm13, ymm13) vaddpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vmulpd(ymm1, ymm5, ymm7) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm6, ymm12, ymm12) vaddpd(ymm7, ymm8, ymm8) - + vpermilpd(imm(0x5), ymm1, ymm1) vmulpd(ymm0, ymm2, ymm6) vmulpd(ymm0, ymm3, ymm7) vaddsubpd(ymm6, ymm15, ymm15) vaddsubpd(ymm7, ymm11, ymm11) - + vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 0+2*32), ymm2) vmulpd(ymm1, ymm3, ymm7) vmovddup(mem(rbx, 0+3*32), ymm3) vaddsubpd(ymm6, ymm14, ymm14) vaddsubpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 2*32), ymm0) vaddsubpd(ymm6, ymm13, ymm13) vaddsubpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddsubpd(ymm6, ymm12, ymm12) vaddsubpd(ymm7, ymm8, ymm8) - - + + add(imm(4*1*16), rax) // a += 4 (1 x mr) add(imm(4*1*16), rbx) // b += 4 (1 x nr) - - + + dec(rsi) // i -= 1; jne(.ZLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.ZPOSTACCUM) - + // ymm15: ymm13: ymm11: ymm9: // ( ab00 ( ab01 ( ab02 ( ab03 - // ab10 ab11 ab12 ab13 + // ab10 ab11 ab12 ab13 // ab21 ab20 ab23 ab22 // ab31 ) ab30 ) ab33 ) ab32 ) - + // ymm14: ymm12: ymm10: ymm8: // ( ab40 ( ab41 ( ab42 ( ab43 - // ab50 ab51 ab52 ab53 + // ab50 ab51 ab52 ab53 // ab61 ab60 ab63 ab62 // ab71 ) ab70 ) ab73 ) ab72 ) - - + + vmovapd(ymm15, ymm7) vperm2f128(imm(0x12), ymm15, ymm13, ymm15) vperm2f128(imm(0x30), ymm7, ymm13, ymm13) - + vmovapd(ymm11, ymm7) vperm2f128(imm(0x12), ymm11, ymm9, ymm11) vperm2f128(imm(0x30), ymm7, ymm9, ymm9) - + vmovapd(ymm14, ymm7) vperm2f128(imm(0x12), ymm14, ymm12, ymm14) vperm2f128(imm(0x30), ymm7, ymm12, ymm12) - + vmovapd(ymm10, ymm7) vperm2f128(imm(0x12), ymm10, ymm8, ymm10) vperm2f128(imm(0x30), ymm7, ymm8, ymm8) - - + + // ymm15: ymm13: ymm11: ymm9: // ( ab00 ( ab01 ( ab02 ( ab03 - // ab10 ab11 ab12 ab13 + // ab10 ab11 ab12 ab13 // ab20 ab21 ab22 ab23 // ab30 ) ab31 ) ab32 ) ab33 ) - + // ymm14: ymm12: ymm10: ymm8: // ( ab40 ( ab41 ( ab42 ( ab43 - // ab50 ab51 ab52 ab53 + // ab50 ab51 ab52 ab53 // ab60 ab61 ab62 ab63 // ab70 ) ab71 ) ab72 ) ab73 ) - - + + // scale by alpha - + mov(var(alpha), rax) // load address of alpha vbroadcastsd(mem(rax), ymm7) // load alpha_r and duplicate vbroadcastsd(mem(rax, 8), ymm6) // load alpha_i and duplicate - + vpermilpd(imm(0x5), ymm15, ymm3) vmulpd(ymm7, ymm15, ymm15) vmulpd(ymm6, ymm3, ymm3) vaddsubpd(ymm3, ymm15, ymm15) - + vpermilpd(imm(0x5), ymm14, ymm2) vmulpd(ymm7, ymm14, ymm14) vmulpd(ymm6, ymm2, ymm2) vaddsubpd(ymm2, ymm14, ymm14) - + vpermilpd(imm(0x5), ymm13, ymm1) vmulpd(ymm7, ymm13, ymm13) vmulpd(ymm6, ymm1, ymm1) vaddsubpd(ymm1, ymm13, ymm13) - + vpermilpd(imm(0x5), ymm12, ymm0) vmulpd(ymm7, ymm12, ymm12) vmulpd(ymm6, ymm0, ymm0) vaddsubpd(ymm0, ymm12, ymm12) - + vpermilpd(imm(0x5), ymm11, ymm3) vmulpd(ymm7, ymm11, ymm11) vmulpd(ymm6, ymm3, ymm3) vaddsubpd(ymm3, ymm11, ymm11) - + vpermilpd(imm(0x5), ymm10, ymm2) 
vmulpd(ymm7, ymm10, ymm10) vmulpd(ymm6, ymm2, ymm2) vaddsubpd(ymm2, ymm10, ymm10) - + vpermilpd(imm(0x5), ymm9, ymm1) vmulpd(ymm7, ymm9, ymm9) vmulpd(ymm6, ymm1, ymm1) vaddsubpd(ymm1, ymm9, ymm9) - + vpermilpd(imm(0x5), ymm8, ymm0) vmulpd(ymm7, ymm8, ymm8) vmulpd(ymm6, ymm0, ymm0) vaddsubpd(ymm0, ymm8, ymm8) - - - - + + + + mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rbx), ymm7) // load beta_r and duplicate vbroadcastsd(mem(rbx, 8), ymm6) // load beta_i and duplicate - - - - - - - + + + + + + + mov(var(rs_c), rsi) // load rs_c lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(dcomplex) lea(mem(, rsi, 2), rsi) lea(mem(rcx, rsi, 2), rdx) // load address of c + 2*rs_c; - - + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm7) // set ZF if beta_r == 0. sete(r8b) // r8b = ( ZF == 1 ? 1 : 0 ); @@ -3154,355 +2191,142 @@ void bli_zgemm_sandybridge_asm_4x4 sete(r9b) // r9b = ( ZF == 1 ? 1 : 0 ); and(r8b, r9b) // set ZF if r8b & r9b == 1. jne(.ZBETAZERO) // if ZF = 0, jump to beta == 0 case - - - cmp(imm(16), rsi) // set ZF if (16*cs_c) == 16. - jz(.ZCOLSTORED) // jump to column storage case - - - - label(.ZGENSTORED) - // update c00:c30 - - vmovupd(mem(rcx), xmm0) // load (c00,c10) into xmm0 - vmovupd(mem(rcx, rsi, 1), xmm2) // load (c20,c30) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm15, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rcx)) // store (c00,c10) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c20,c30) - add(rdi, rcx) // c += cs_c; - - // update c40:c70 - - vmovupd(mem(rdx), xmm0) // load (c40,c50) into xmm0 - vmovupd(mem(rdx, rsi, 1), xmm2) // load (c60,c70) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm14, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rdx)) // store (c40,c50) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c60,c70) - add(rdi, rdx) // c += cs_c; - - // update c01:c31 - - vmovupd(mem(rcx), xmm0) // load (c01,c11) into xmm0 - vmovupd(mem(rcx, rsi, 1), xmm2) // load (c21,c31) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm13, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rcx)) // store (c01,c11) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c21,c31) - add(rdi, rcx) // c += cs_c; - - // update c41:c71 - - vmovupd(mem(rdx), xmm0) // load (c41,c51) into xmm0 - vmovupd(mem(rdx, rsi, 1), xmm2) // load (c61,c71) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm12, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rdx)) // store (c41,c51) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c61,c71) - add(rdi, rdx) // c += 
cs_c; - - // update c02:c32 - - vmovupd(mem(rcx), xmm0) // load (c02,c12) into xmm0 - vmovupd(mem(rcx, rsi, 1), xmm2) // load (c22,c32) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm11, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rcx)) // store (c02,c12) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c22,c32) - add(rdi, rcx) // c += cs_c; - - // update c42:c72 - - vmovupd(mem(rdx), xmm0) // load (c42,c52) into xmm0 - vmovupd(mem(rdx, rsi, 1), xmm2) // load (c62,c72) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm10, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rdx)) // store (c42,c52) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c62,c72) - add(rdi, rdx) // c += cs_c; - - // update c03:c33 - - vmovupd(mem(rcx), xmm0) // load (c03,c13) into xmm0 - vmovupd(mem(rcx, rsi, 1), xmm2) // load (c23,c33) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm9, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rcx)) // store (c03,c13) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c23,c33) - add(rdi, rcx) // c += cs_c; - - // update c43:c73 - - vmovupd(mem(rdx), xmm0) // load (c43,c53) into xmm0 - vmovupd(mem(rdx, rsi, 1), xmm2) // load (c63,c73) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm8, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rdx)) // store (c43,c53) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c63,c73) - - - - jmp(.ZDONE) // jump to end. 
- - - - label(.ZCOLSTORED) - // update c00:c30 - - vmovupd(mem(rcx), ymm0) // load c00:c30 into ymm0 - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm15, ymm0, ymm0) // add the gemm result to ymm0 - vmovupd(ymm0, mem(rcx)) // store c00:c30 - add(rdi, rcx) // c += cs_c; - - // update c40:c70 - - vmovupd(mem(rdx), ymm0) // load c40:c70 into ymm0 - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm14, ymm0, ymm0) // add the gemm result to ymm0 - vmovupd(ymm0, mem(rdx)) // store c40:c70 - add(rdi, rdx) // c += cs_c; - - // update c01:c31 - - vmovupd(mem(rcx), ymm0) // load c01:c31 into ymm0 - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm13, ymm0, ymm0) // add the gemm result to ymm0 - vmovupd(ymm0, mem(rcx)) // store c01:c31 - add(rdi, rcx) // c += cs_c; - - // update c41:c71 - - vmovupd(mem(rdx), ymm0) // load c41:c71 into ymm0 - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm12, ymm0, ymm0) // add the gemm result to ymm0 - vmovupd(ymm0, mem(rdx)) // store c41:c71 - add(rdi, rdx) // c += cs_c; - - // update c02:c32 - - vmovupd(mem(rcx), ymm0) // load c02:c32 into ymm0 - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm11, ymm0, ymm0) // add the gemm result to ymm0 - vmovupd(ymm0, mem(rcx)) // store c02:c32 - add(rdi, rcx) // c += cs_c; - - // update c42:c72 - - vmovupd(mem(rdx), ymm0) // load c42:c72 into ymm0 - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm10, ymm0, ymm0) // add the gemm result to ymm0 - vmovupd(ymm0, mem(rdx)) // store c42:c72 - add(rdi, rdx) // c += cs_c; - - // update c03:c33 - - vmovupd(mem(rcx), ymm0) // load c03:c33 into ymm0 - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm9, ymm0, ymm0) // add the gemm result to ymm0 - vmovupd(ymm0, mem(rcx)) // store c03:c33 - add(rdi, rcx) // c += cs_c; - - // update c43:c73 - - vmovupd(mem(rdx), ymm0) // load c43:c73 into ymm0 - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm8, ymm0, ymm0) // add the gemm result to ymm0 - vmovupd(ymm0, mem(rdx)) // store c43:c73 - - - - jmp(.ZDONE) // jump to end. 
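
For reference: each "update cXY" block in this kernel computes, per element of C, the complex operation c := beta * c + ab, where ab is the accumulated rank-k product. The vpermilpd/vmulpd/vaddsubpd triple is the AVX idiom for the complex product beta * c. A scalar sketch of one element update follows; the dcplx_t struct is a stand-in for BLIS's dcomplex, used purely for illustration:

    typedef struct { double real, imag; } dcplx_t; /* stand-in for dcomplex */

    static void zupdate( dcplx_t* c, dcplx_t beta, dcplx_t ab )
    {
        /* beta * c, as computed by vpermilpd + vmulpd + vaddsubpd */
        double tr = beta.real * c->real - beta.imag * c->imag;
        double ti = beta.real * c->imag + beta.imag * c->real;

        /* add the accumulated gemm result, as computed by vaddpd */
        c->real = tr + ab.real;
        c->imag = ti + ab.imag;
    }

The deleted general-stride (.ZGENSTORED) path above performed the same per-element update but moved C through xmm registers two elements at a time; the column-stored path below keeps whole ymm vectors.
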
- - - + + // update c00:c30 + + vmovupd(mem(rcx), ymm0) // load c00:c30 into ymm0 + vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta + vmulpd(ymm7, ymm0, ymm0) + vmulpd(ymm6, ymm2, ymm2) + vaddsubpd(ymm2, ymm0, ymm0) + vaddpd(ymm15, ymm0, ymm0) // add the gemm result to ymm0 + vmovupd(ymm0, mem(rcx)) // store c00:c30 + add(rdi, rcx) // c += cs_c; + + // update c40:c70 + + vmovupd(mem(rdx), ymm0) // load c40:c70 into ymm0 + vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta + vmulpd(ymm7, ymm0, ymm0) + vmulpd(ymm6, ymm2, ymm2) + vaddsubpd(ymm2, ymm0, ymm0) + vaddpd(ymm14, ymm0, ymm0) // add the gemm result to ymm0 + vmovupd(ymm0, mem(rdx)) // store c40:c70 + add(rdi, rdx) // c += cs_c; + + // update c01:c31 + + vmovupd(mem(rcx), ymm0) // load c01:c31 into ymm0 + vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta + vmulpd(ymm7, ymm0, ymm0) + vmulpd(ymm6, ymm2, ymm2) + vaddsubpd(ymm2, ymm0, ymm0) + vaddpd(ymm13, ymm0, ymm0) // add the gemm result to ymm0 + vmovupd(ymm0, mem(rcx)) // store c01:c31 + add(rdi, rcx) // c += cs_c; + + // update c41:c71 + + vmovupd(mem(rdx), ymm0) // load c41:c71 into ymm0 + vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta + vmulpd(ymm7, ymm0, ymm0) + vmulpd(ymm6, ymm2, ymm2) + vaddsubpd(ymm2, ymm0, ymm0) + vaddpd(ymm12, ymm0, ymm0) // add the gemm result to ymm0 + vmovupd(ymm0, mem(rdx)) // store c41:c71 + add(rdi, rdx) // c += cs_c; + + // update c02:c32 + + vmovupd(mem(rcx), ymm0) // load c02:c32 into ymm0 + vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta + vmulpd(ymm7, ymm0, ymm0) + vmulpd(ymm6, ymm2, ymm2) + vaddsubpd(ymm2, ymm0, ymm0) + vaddpd(ymm11, ymm0, ymm0) // add the gemm result to ymm0 + vmovupd(ymm0, mem(rcx)) // store c02:c32 + add(rdi, rcx) // c += cs_c; + + // update c42:c72 + + vmovupd(mem(rdx), ymm0) // load c42:c72 into ymm0 + vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta + vmulpd(ymm7, ymm0, ymm0) + vmulpd(ymm6, ymm2, ymm2) + vaddsubpd(ymm2, ymm0, ymm0) + vaddpd(ymm10, ymm0, ymm0) // add the gemm result to ymm0 + vmovupd(ymm0, mem(rdx)) // store c42:c72 + add(rdi, rdx) // c += cs_c; + + // update c03:c33 + + vmovupd(mem(rcx), ymm0) // load c03:c33 into ymm0 + vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta + vmulpd(ymm7, ymm0, ymm0) + vmulpd(ymm6, ymm2, ymm2) + vaddsubpd(ymm2, ymm0, ymm0) + vaddpd(ymm9, ymm0, ymm0) // add the gemm result to ymm0 + vmovupd(ymm0, mem(rcx)) // store c03:c33 + add(rdi, rcx) // c += cs_c; + + // update c43:c73 + + vmovupd(mem(rdx), ymm0) // load c43:c73 into ymm0 + vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta + vmulpd(ymm7, ymm0, ymm0) + vmulpd(ymm6, ymm2, ymm2) + vaddsubpd(ymm2, ymm0, ymm0) + vaddpd(ymm8, ymm0, ymm0) // add the gemm result to ymm0 + vmovupd(ymm0, mem(rdx)) // store c43:c73 + + jmp(.ZDONE) // jump to end. + label(.ZBETAZERO) - - cmp(imm(16), rsi) // set ZF if (16*cs_c) == 16. 
- jz(.ZCOLSTORBZ) // jump to column storage case - - - - label(.ZGENSTORBZ) - // update c00:c30 - - vextractf128(imm(1), ymm15, xmm2) - vmovupd(xmm15, mem(rcx)) // store (c00,c10) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c20,c30) - add(rdi, rcx) // c += cs_c; - - // update c40:c70 - - vextractf128(imm(1), ymm14, xmm2) - vmovupd(xmm14, mem(rdx)) // store (c40,c50) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c60,c70) - add(rdi, rdx) // c += cs_c; - - // update c01:c31 - - vextractf128(imm(1), ymm13, xmm2) - vmovupd(xmm13, mem(rcx)) // store (c01,c11) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c21,c31) - add(rdi, rcx) // c += cs_c; - - // update c41:c71 - - vextractf128(imm(1), ymm12, xmm2) - vmovupd(xmm12, mem(rdx)) // store (c41,c51) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c61,c71) - add(rdi, rdx) // c += cs_c; - - // update c02:c32 - - vextractf128(imm(1), ymm11, xmm2) - vmovupd(xmm11, mem(rcx)) // store (c02,c12) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c22,c32) - add(rdi, rcx) // c += cs_c; - - // update c42:c72 - - vextractf128(imm(1), ymm10, xmm2) - vmovupd(xmm10, mem(rdx)) // store (c42,c52) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c62,c72) - add(rdi, rdx) // c += cs_c; - - // update c03:c33 - - vextractf128(imm(1), ymm9, xmm2) - vmovupd(xmm9, mem(rcx)) // store (c03,c13) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c23,c33) - add(rdi, rcx) // c += cs_c; - - // update c43:c73 - - vextractf128(imm(1), ymm8, xmm2) - vmovupd(xmm8, mem(rdx)) // store (c43,c53) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c63,c73) - - - - jmp(.ZDONE) // jump to end. - - - - label(.ZCOLSTORBZ) - - - vmovupd(ymm15, mem(rcx)) // store c00:c30 - add(rdi, rcx) // c += cs_c; - - vmovupd(ymm14, mem(rdx)) // store c40:c70 - add(rdi, rdx) // c += cs_c; - - vmovupd(ymm13, mem(rcx)) // store c01:c31 - add(rdi, rcx) // c += cs_c; - - vmovupd(ymm12, mem(rdx)) // store c41:c71 - add(rdi, rdx) // c += cs_c; - - vmovupd(ymm11, mem(rcx)) // store c02:c32 - add(rdi, rcx) // c += cs_c; - - vmovupd(ymm10, mem(rdx)) // store c42:c72 - add(rdi, rdx) // c += cs_c; - - vmovupd(ymm9, mem(rcx)) // store c03:c33 - add(rdi, rcx) // c += cs_c; - - vmovupd(ymm8, mem(rdx)) // store c43:c73 - - - - - + + vmovupd(ymm15, mem(rcx)) // store c00:c30 + add(rdi, rcx) // c += cs_c; + + vmovupd(ymm14, mem(rdx)) // store c40:c70 + add(rdi, rdx) // c += cs_c; + + vmovupd(ymm13, mem(rcx)) // store c01:c31 + add(rdi, rcx) // c += cs_c; + + vmovupd(ymm12, mem(rdx)) // store c41:c71 + add(rdi, rdx) // c += cs_c; + + vmovupd(ymm11, mem(rcx)) // store c02:c32 + add(rdi, rcx) // c += cs_c; + + vmovupd(ymm10, mem(rdx)) // store c42:c72 + add(rdi, rdx) // c += cs_c; + + vmovupd(ymm9, mem(rcx)) // store c03:c33 + add(rdi, rcx) // c += cs_c; + + vmovupd(ymm8, mem(rdx)) // store c43:c73 + label(.ZDONE) - - vzeroupper() - + vzeroupper() - + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c)/*, // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next)*/ // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c)/*, // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next)*/ // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", 
"rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -3510,6 +2334,8 @@ void bli_zgemm_sandybridge_asm_4x4 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( z ); } diff --git a/kernels/sandybridge/3/bli_gemm_sandybridge_int_d8x4.c b/kernels/sandybridge/3/bli_gemm_sandybridge_int_d8x4.c index 6a1bb04f54..6bf991082b 100644 --- a/kernels/sandybridge/3/bli_gemm_sandybridge_int_d8x4.c +++ b/kernels/sandybridge/3/bli_gemm_sandybridge_int_d8x4.c @@ -32,14 +32,17 @@ */ -#include +#include +#include #include "blis.h" #if 0 void bli_sgemm_sandybridge_int_8x8 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, float* restrict alpha, float* restrict a, float* restrict b, @@ -52,11 +55,11 @@ void bli_sgemm_sandybridge_int_8x8 } #endif - - void bli_dgemm_sandybridge_int_8x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, double* restrict alpha, double* restrict a, double* restrict b, @@ -66,19 +69,22 @@ void bli_dgemm_sandybridge_int_8x4 cntx_t* restrict cntx ) { + //void* a_next = bli_auxinfo_next_a( data ); void* b_next = bli_auxinfo_next_b( data ); // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 2; - uint64_t k_left = k0 % 2; + uint64_t k_iter = k / 2; + uint64_t k_left = k % 2; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; uint64_t i; - double *c00, *c01, *c02, *c03; - double *c40, *c41, *c42, *c43; + GEMM_UKR_SETUP_CT( d, 8, 4, false ); + + double *c00, *c01, *c02, *c03; + double *c40, *c41, *c42, *c43; // Quad registers. __m256d va0_3, va4_7; @@ -87,23 +93,20 @@ void bli_dgemm_sandybridge_int_8x4 __m256d vb; __m256d vB0; - __m256d va0_3b_0, va4_7b_0; - __m256d va0_3b_1, va4_7b_1; - __m256d va0_3b_2, va4_7b_2; - __m256d va0_3b_3, va4_7b_3; - - __m256d va0_3b0, va4_7b0; - __m256d va0_3b1, va4_7b1; - __m256d va0_3b2, va4_7b2; - __m256d va0_3b3, va4_7b3; + __m256d va0_3b_0, va4_7b_0; + __m256d va0_3b_1, va4_7b_1; + __m256d va0_3b_2, va4_7b_2; + __m256d va0_3b_3, va4_7b_3; + __m256d va0_3b0, va4_7b0; + __m256d va0_3b1, va4_7b1; + __m256d va0_3b2, va4_7b2; + __m256d va0_3b3, va4_7b3; - __m256d valpha, vbeta, vtmp; + __m256d valpha, vbeta, vtmp; __m256d vc0_3_0, vc0_3_1, vc0_3_2, vc0_3_3; __m256d vc4_7_0, vc4_7_1, vc4_7_2, vc4_7_3; - __m128d aa, bb; - __asm__ volatile( "prefetcht0 0(%0) \n\t" : :"r"(a) ); __asm__ volatile( "prefetcht2 0(%0) \n\t" : :"r"(b_next) ); __asm__ volatile( "prefetcht0 0(%0) \n\t" : :"r"(c) ); @@ -129,19 +132,19 @@ void bli_dgemm_sandybridge_int_8x4 va4_7b_3 = _mm256_setzero_pd(); // Load va0_3 - va0_3 = _mm256_load_pd( a ); + va0_3 = _mm256_load_pd( a ); // Load va4_7 - va4_7 = _mm256_load_pd( a + 4 ); + va4_7 = _mm256_load_pd( a + 4 ); - // Load vb (b0,b1,b2,b3) - vb0 = _mm256_load_pd( b ); + // Load vb (b0,b1,b2,b3) + vb0 = _mm256_load_pd( b ); for( i = 0; i < k_iter; ++i ) { __asm__ volatile( "prefetcht0 192(%0) \n\t" : :"r"(a) ); // Load va0_3 (Prefetch) - vA0_3 = _mm256_load_pd( a + 8 ); + vA0_3 = _mm256_load_pd( a + 8 ); // Iteration 0. 
vtmp = _mm256_mul_pd( va0_3, vb0 ); @@ -151,10 +154,10 @@ void bli_dgemm_sandybridge_int_8x4 va4_7b_0 = _mm256_add_pd( va4_7b_0, vtmp ); // Load va4_7 (Prefetch) - vA4_7 = _mm256_load_pd( a + 12 ); + vA4_7 = _mm256_load_pd( a + 12 ); // Shuffle vb (b1,b0,b3,b2) - vb1 = _mm256_shuffle_pd( vb0, vb0, 0x5 ); + vb1 = _mm256_shuffle_pd( vb0, vb0, 0x5 ); vtmp = _mm256_mul_pd( va0_3, vb1 ); va0_3b_1 = _mm256_add_pd( va0_3b_1, vtmp ); @@ -163,10 +166,10 @@ void bli_dgemm_sandybridge_int_8x4 va4_7b_1 = _mm256_add_pd( va4_7b_1, vtmp ); // Permute vb (b3,b2,b1,b0) - vb2 = _mm256_permute2f128_pd( vb1, vb1, 0x1 ); + vb2 = _mm256_permute2f128_pd( vb1, vb1, 0x1 ); // Load vb (b0,b1,b2,b3) (Prefetch) - vB0 = _mm256_load_pd( b + 4 ); + vB0 = _mm256_load_pd( b + 4 ); vtmp = _mm256_mul_pd( va0_3, vb2 ); va0_3b_2 = _mm256_add_pd( va0_3b_2, vtmp ); @@ -175,7 +178,7 @@ void bli_dgemm_sandybridge_int_8x4 va4_7b_2 = _mm256_add_pd( va4_7b_2, vtmp ); // Shuffle vb (b3,b2,b1,b0) - vb3 = _mm256_shuffle_pd( vb2, vb2, 0x5 ); + vb3 = _mm256_shuffle_pd( vb2, vb2, 0x5 ); vtmp = _mm256_mul_pd( va0_3, vb3 ); va0_3b_3 = _mm256_add_pd( va0_3b_3, vtmp ); @@ -186,14 +189,14 @@ void bli_dgemm_sandybridge_int_8x4 // Iteration 1. __asm__ volatile( "prefetcht0 512(%0) \n\t" : :"r"(a) ); - + // Load va0_3 (Next iteration) - va0_3 = _mm256_load_pd( a + 16 ); + va0_3 = _mm256_load_pd( a + 16 ); vtmp = _mm256_mul_pd( vA0_3, vB0 ); va0_3b_0 = _mm256_add_pd( va0_3b_0, vtmp ); - vb1 = _mm256_shuffle_pd( vB0, vB0, 0x5 ); + vb1 = _mm256_shuffle_pd( vB0, vB0, 0x5 ); vtmp = _mm256_mul_pd( vA4_7, vB0 ); va4_7b_0 = _mm256_add_pd( va4_7b_0, vtmp ); @@ -202,9 +205,9 @@ void bli_dgemm_sandybridge_int_8x4 va0_3b_1 = _mm256_add_pd( va0_3b_1, vtmp ); // Load va4_7 (Next iteration) - va4_7 = _mm256_load_pd( a + 20 ); + va4_7 = _mm256_load_pd( a + 20 ); - vb2 = _mm256_permute2f128_pd( vb1, vb1, 0x1 ); + vb2 = _mm256_permute2f128_pd( vb1, vb1, 0x1 ); vtmp = _mm256_mul_pd( vA4_7, vb1 ); va4_7b_1 = _mm256_add_pd( va4_7b_1, vtmp ); @@ -212,13 +215,13 @@ void bli_dgemm_sandybridge_int_8x4 vtmp = _mm256_mul_pd( vA0_3, vb2 ); va0_3b_2 = _mm256_add_pd( va0_3b_2, vtmp ); - vb3 = _mm256_shuffle_pd( vb2, vb2, 0x5 ); + vb3 = _mm256_shuffle_pd( vb2, vb2, 0x5 ); vtmp = _mm256_mul_pd( vA4_7, vb2 ); va4_7b_2 = _mm256_add_pd( va4_7b_2, vtmp ); // Load vb0(Next iteration) - vb0 = _mm256_load_pd( b + 8 ); + vb0 = _mm256_load_pd( b + 8 ); vtmp = _mm256_mul_pd( vA0_3, vb3 ); va0_3b_3 = _mm256_add_pd( va0_3b_3, vtmp ); @@ -236,12 +239,12 @@ void bli_dgemm_sandybridge_int_8x4 // Iteration 0. 
// Load va0_3 - va0_3 = _mm256_load_pd( a ); + va0_3 = _mm256_load_pd( a ); // Load va4_7 - va4_7 = _mm256_load_pd( a + 4 ); + va4_7 = _mm256_load_pd( a + 4 ); - // Load vb (b0,b1,b2,b3) - vb = _mm256_load_pd( b ); + // Load vb (b0,b1,b2,b3) + vb = _mm256_load_pd( b ); vtmp = _mm256_mul_pd( va0_3, vb ); va0_3b_0 = _mm256_add_pd( va0_3b_0, vtmp ); @@ -250,7 +253,7 @@ void bli_dgemm_sandybridge_int_8x4 va4_7b_0 = _mm256_add_pd( va4_7b_0, vtmp ); // Shuffle vb (b1,b0,b3,b2) - vb = _mm256_shuffle_pd( vb, vb, 0x5 ); + vb = _mm256_shuffle_pd( vb, vb, 0x5 ); vtmp = _mm256_mul_pd( va0_3, vb ); va0_3b_1 = _mm256_add_pd( va0_3b_1, vtmp ); @@ -259,7 +262,7 @@ void bli_dgemm_sandybridge_int_8x4 va4_7b_1 = _mm256_add_pd( va4_7b_1, vtmp ); // Permute vb (b3,b2,b1,b0) - vb = _mm256_permute2f128_pd( vb, vb, 0x1 ); + vb = _mm256_permute2f128_pd( vb, vb, 0x1 ); vtmp = _mm256_mul_pd( va0_3, vb ); va0_3b_2 = _mm256_add_pd( va0_3b_2, vtmp ); @@ -268,7 +271,7 @@ void bli_dgemm_sandybridge_int_8x4 va4_7b_2 = _mm256_add_pd( va4_7b_2, vtmp ); // Shuffle vb (b3,b2,b1,b0) - vb = _mm256_shuffle_pd( vb, vb, 0x5 ); + vb = _mm256_shuffle_pd( vb, vb, 0x5 ); vtmp = _mm256_mul_pd( va0_3, vb ); va0_3b_3 = _mm256_add_pd( va0_3b_3, vtmp ); @@ -309,131 +312,73 @@ void bli_dgemm_sandybridge_int_8x4 va4_7b1 = _mm256_permute2f128_pd( vtmpa_4_7b_1, vtmpa_4_7b_3, 0x30 ); va4_7b2 = _mm256_permute2f128_pd( vtmpa_4_7b_3, vtmpa_4_7b_1, 0x30 ); - if( rs_c == 1 ) + __m128d vzero = _mm_setzero_pd( ); + + if( _mm_comieq_sd( _mm256_castpd256_pd128(vbeta), vzero ) ) { // Calculate address - c00 = ( c + 0*rs_c + 0*cs_c ); - // Load - //vc0_3_0 = _mm256_load_pd( c + 0*rs_c + 0*cs_c ); - vc0_3_0 = _mm256_load_pd( c00 ); + c00 = ( c + 0 + 0*cs_c ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va0_3b0); - // Scale by beta - vc0_3_0 = _mm256_mul_pd( vbeta, vc0_3_0 ); - // Add gemm result - vc0_3_0 = _mm256_add_pd( vc0_3_0, vtmp ); // Store back to memory - _mm256_store_pd( c00, vc0_3_0 ); - + _mm256_store_pd( c00, vtmp ); + // Calculate address - c40 = ( c + 4*rs_c + 0*cs_c ); - // Load - //vc4_7_0 = _mm256_load_pd( c + 4*rs_c + 0*cs_c ); - vc4_7_0 = _mm256_load_pd( c40 ); + c40 = ( c + 4 + 0*cs_c ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va4_7b0); - // Scale by beta - vc4_7_0 = _mm256_mul_pd( vbeta, vc4_7_0 ); - // Add gemm result - vc4_7_0 = _mm256_add_pd( vc4_7_0, vtmp ); // Store back to memory - _mm256_store_pd( c40, vc4_7_0 ); - + _mm256_store_pd( c40, vtmp ); + // Calculate address - c01 = ( c + 0*rs_c + 1*cs_c ); - // Load - //vc0_3_1 = _mm256_load_pd( c + 0*rs_c + 1*cs_c ); - vc0_3_1 = _mm256_load_pd( c01 ); + c01 = ( c + 0 + 1*cs_c ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va0_3b1); - // Scale by beta - vc0_3_1 = _mm256_mul_pd( vbeta, vc0_3_1 ); - // Add gemm result - vc0_3_1 = _mm256_add_pd( vc0_3_1, vtmp ); // Store back to memory - _mm256_store_pd( c01, vc0_3_1 ); - + _mm256_store_pd( c01, vtmp ); + // Calculate address - c41 = ( c + 4*rs_c + 1*cs_c ); - // Load - //vc4_7_1 = _mm256_load_pd( c + 4*rs_c + 1*cs_c ); - vc4_7_1 = _mm256_load_pd( c41 ); + c41 = ( c + 4 + 1*cs_c ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va4_7b1); - // Scale by beta - vc4_7_1 = _mm256_mul_pd( vbeta, vc4_7_1 ); - // Add gemm result - vc4_7_1 = _mm256_add_pd( vc4_7_1, vtmp ); // Store back to memory - _mm256_store_pd( c41, vc4_7_1 ); - + _mm256_store_pd( c41, vtmp ); + // Calculate address - c02 = ( c + 0*rs_c + 2*cs_c ); - // Load - //vc0_3_2 = _mm256_load_pd( c + 0*rs_c + 2*cs_c ); - vc0_3_2 = _mm256_load_pd( c02 ); + c02 = ( c 
+ 0 + 2*cs_c ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va0_3b2); - // Scale by beta - vc0_3_2 = _mm256_mul_pd( vbeta, vc0_3_2 ); - // Add gemm result - vc0_3_2 = _mm256_add_pd( vc0_3_2, vtmp ); // Store back to memory - _mm256_store_pd( c02, vc0_3_2 ); - + _mm256_store_pd( c02, vtmp ); + // Calculate address - c42 = ( c + 4*rs_c + 2*cs_c ); - // Load - //vc4_7_2 = _mm256_load_pd( c + 4*rs_c + 2*cs_c ); - vc4_7_2 = _mm256_load_pd( c42 ); + c42 = ( c + 4 + 2*cs_c ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va4_7b2); - // Scale by beta - vc4_7_2 = _mm256_mul_pd( vbeta, vc4_7_2 ); - // Add gemm result - vc4_7_2 = _mm256_add_pd( vc4_7_2, vtmp ); // Store back to memory - _mm256_store_pd( c42, vc4_7_2 ); - + _mm256_store_pd( c42, vtmp ); + // Calculate address - c03 = ( c + 0*rs_c + 3*cs_c ); - // Load - //vc0_3_3 = _mm256_load_pd( c + 0*rs_c + 3*cs_c ); - vc0_3_3 = _mm256_load_pd( c03 ); + c03 = ( c + 0 + 3*cs_c ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va0_3b3); - // Scale by beta - vc0_3_3 = _mm256_mul_pd( vbeta, vc0_3_3 ); - // Add gemm result - vc0_3_3 = _mm256_add_pd( vc0_3_3, vtmp ); // Store back to memory - _mm256_store_pd( c03, vc0_3_3 ); - + _mm256_store_pd( c03, vtmp ); + // Calculate address - c43 = ( c + 4*rs_c + 3*cs_c ); - // Load - //vc4_7_3 = _mm256_load_pd( c + 4*rs_c + 3*cs_c ); - vc4_7_3 = _mm256_load_pd( c43 ); + c43 = ( c + 4 + 3*cs_c ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va4_7b3); - // Scale by beta - vc4_7_3 = _mm256_mul_pd( vbeta, vc4_7_3 ); - // Add gemm result - vc4_7_3 = _mm256_add_pd( vc4_7_3, vtmp ); // Store back to memory - _mm256_store_pd( c43, vc4_7_3 ); - + _mm256_store_pd( c43, vtmp ); } else { // Calculate address - c00 = ( c + 0*rs_c + 0*cs_c ); + c00 = ( c + 0 + 0*cs_c ); // Load - //vc0_3_0 = _mm256_load_pd( c + 0*rs_c + 0*cs_c ); - vc0_3_0 = _mm256_set_pd( *(c + 3*rs_c + 0*cs_c ), - *(c + 2*rs_c + 0*cs_c ), - *(c + 1*rs_c + 0*cs_c ), - *(c + 0*rs_c + 0*cs_c ) ); + //vc0_3_0 = _mm256_load_pd( c + 0 + 0*cs_c ); + vc0_3_0 = _mm256_load_pd( c00 ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va0_3b0); // Scale by beta @@ -441,24 +386,13 @@ void bli_dgemm_sandybridge_int_8x4 // Add gemm result vc0_3_0 = _mm256_add_pd( vc0_3_0, vtmp ); // Store back to memory - //_mm256_store_pd( c00, vc0_3_0 ); - - aa = _mm256_extractf128_pd( vc0_3_0, 0 ) ; - bb = _mm256_extractf128_pd( vc0_3_0, 1 ) ; - - _mm_storel_pd( c + 0*rs_c + 0*cs_c, aa ); - _mm_storeh_pd( c + 1*rs_c + 0*cs_c, aa ); - _mm_storel_pd( c + 2*rs_c + 0*cs_c, bb ); - _mm_storeh_pd( c + 3*rs_c + 0*cs_c, bb ); + _mm256_store_pd( c00, vc0_3_0 ); // Calculate address - c40 = ( c + 4*rs_c + 0*cs_c ); + c40 = ( c + 4 + 0*cs_c ); // Load - //vc4_7_0 = _mm256_load_pd( c + 4*rs_c + 0*cs_c ); - vc4_7_0 = _mm256_set_pd( *(c + 7*rs_c + 0*cs_c ), - *(c + 6*rs_c + 0*cs_c ), - *(c + 5*rs_c + 0*cs_c ), - *(c + 4*rs_c + 0*cs_c ) ); + //vc4_7_0 = _mm256_load_pd( c + 4 + 0*cs_c ); + vc4_7_0 = _mm256_load_pd( c40 ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va4_7b0); // Scale by beta @@ -466,24 +400,13 @@ void bli_dgemm_sandybridge_int_8x4 // Add gemm result vc4_7_0 = _mm256_add_pd( vc4_7_0, vtmp ); // Store back to memory - //_mm256_store_pd( c40, vc4_7_0 ); - - aa = _mm256_extractf128_pd( vc4_7_0, 0 ) ; - bb = _mm256_extractf128_pd( vc4_7_0, 1 ) ; - - _mm_storel_pd( c + 4*rs_c + 0*cs_c, aa ); - _mm_storeh_pd( c + 5*rs_c + 0*cs_c, aa ); - _mm_storel_pd( c + 6*rs_c + 0*cs_c, bb ); - _mm_storeh_pd( c + 7*rs_c + 0*cs_c, bb ); + _mm256_store_pd( c40, vc4_7_0 ); // Calculate address - 
c01 = ( c + 0*rs_c + 1*cs_c ); + c01 = ( c + 0 + 1*cs_c ); // Load - //vc0_3_1 = _mm256_load_pd( c + 0*rs_c + 1*cs_c ); - vc0_3_1 = _mm256_set_pd( *(c + 3*rs_c + 1*cs_c ), - *(c + 2*rs_c + 1*cs_c ), - *(c + 1*rs_c + 1*cs_c ), - *(c + 0*rs_c + 1*cs_c ) ); + //vc0_3_1 = _mm256_load_pd( c + 0 + 1*cs_c ); + vc0_3_1 = _mm256_load_pd( c01 ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va0_3b1); // Scale by beta @@ -491,24 +414,13 @@ void bli_dgemm_sandybridge_int_8x4 // Add gemm result vc0_3_1 = _mm256_add_pd( vc0_3_1, vtmp ); // Store back to memory - //_mm256_store_pd( c01, vc0_3_1 ); - - aa = _mm256_extractf128_pd( vc0_3_1, 0 ) ; - bb = _mm256_extractf128_pd( vc0_3_1, 1 ) ; - - _mm_storel_pd( c + 0*rs_c + 1*cs_c, aa ); - _mm_storeh_pd( c + 1*rs_c + 1*cs_c, aa ); - _mm_storel_pd( c + 2*rs_c + 1*cs_c, bb ); - _mm_storeh_pd( c + 3*rs_c + 1*cs_c, bb ); + _mm256_store_pd( c01, vc0_3_1 ); // Calculate address - c41 = ( c + 4*rs_c + 1*cs_c ); + c41 = ( c + 4 + 1*cs_c ); // Load - //vc4_7_1 = _mm256_load_pd( c + 4*rs_c + 1*cs_c ); - vc4_7_1 = _mm256_set_pd( *(c + 7*rs_c + 1*cs_c ), - *(c + 6*rs_c + 1*cs_c ), - *(c + 5*rs_c + 1*cs_c ), - *(c + 4*rs_c + 1*cs_c ) ); + //vc4_7_1 = _mm256_load_pd( c + 4 + 1*cs_c ); + vc4_7_1 = _mm256_load_pd( c41 ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va4_7b1); // Scale by beta @@ -516,24 +428,13 @@ void bli_dgemm_sandybridge_int_8x4 // Add gemm result vc4_7_1 = _mm256_add_pd( vc4_7_1, vtmp ); // Store back to memory - //_mm256_store_pd( c41, vc4_7_1 ); - - aa = _mm256_extractf128_pd( vc4_7_1, 0 ) ; - bb = _mm256_extractf128_pd( vc4_7_1, 1 ) ; - - _mm_storel_pd( c + 4*rs_c + 1*cs_c, aa ); - _mm_storeh_pd( c + 5*rs_c + 1*cs_c, aa ); - _mm_storel_pd( c + 6*rs_c + 1*cs_c, bb ); - _mm_storeh_pd( c + 7*rs_c + 1*cs_c, bb ); + _mm256_store_pd( c41, vc4_7_1 ); // Calculate address - c02 = ( c + 0*rs_c + 2*cs_c ); + c02 = ( c + 0 + 2*cs_c ); // Load - //vc0_3_2 = _mm256_load_pd( c + 0*rs_c + 2*cs_c ); - vc0_3_2 = _mm256_set_pd( *(c + 3*rs_c + 2*cs_c ), - *(c + 2*rs_c + 2*cs_c ), - *(c + 1*rs_c + 2*cs_c ), - *(c + 0*rs_c + 2*cs_c ) ); + //vc0_3_2 = _mm256_load_pd( c + 0 + 2*cs_c ); + vc0_3_2 = _mm256_load_pd( c02 ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va0_3b2); // Scale by beta @@ -541,24 +442,13 @@ void bli_dgemm_sandybridge_int_8x4 // Add gemm result vc0_3_2 = _mm256_add_pd( vc0_3_2, vtmp ); // Store back to memory - //_mm256_store_pd( c02, vc0_3_2 ); - - aa = _mm256_extractf128_pd( vc0_3_2, 0 ) ; - bb = _mm256_extractf128_pd( vc0_3_2, 1 ) ; - - _mm_storel_pd( c + 0*rs_c + 2*cs_c, aa ); - _mm_storeh_pd( c + 1*rs_c + 2*cs_c, aa ); - _mm_storel_pd( c + 2*rs_c + 2*cs_c, bb ); - _mm_storeh_pd( c + 3*rs_c + 2*cs_c, bb ); + _mm256_store_pd( c02, vc0_3_2 ); // Calculate address - c42 = ( c + 4*rs_c + 2*cs_c ); + c42 = ( c + 4 + 2*cs_c ); // Load - //vc4_7_2 = _mm256_load_pd( c + 4*rs_c + 2*cs_c ); - vc4_7_2 = _mm256_set_pd( *(c + 7*rs_c + 2*cs_c ), - *(c + 6*rs_c + 2*cs_c ), - *(c + 5*rs_c + 2*cs_c ), - *(c + 4*rs_c + 2*cs_c ) ); + //vc4_7_2 = _mm256_load_pd( c + 4 + 2*cs_c ); + vc4_7_2 = _mm256_load_pd( c42 ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va4_7b2); // Scale by beta @@ -566,24 +456,13 @@ void bli_dgemm_sandybridge_int_8x4 // Add gemm result vc4_7_2 = _mm256_add_pd( vc4_7_2, vtmp ); // Store back to memory - //_mm256_store_pd( c42, vc4_7_2 ); - - aa = _mm256_extractf128_pd( vc4_7_2, 0 ) ; - bb = _mm256_extractf128_pd( vc4_7_2, 1 ) ; - - _mm_storel_pd( c + 4*rs_c + 2*cs_c, aa ); - _mm_storeh_pd( c + 5*rs_c + 2*cs_c, aa ); - _mm_storel_pd( c + 
6*rs_c + 2*cs_c, bb ); - _mm_storeh_pd( c + 7*rs_c + 2*cs_c, bb ); + _mm256_store_pd( c42, vc4_7_2 ); // Calculate address - c03 = ( c + 0*rs_c + 3*cs_c ); + c03 = ( c + 0 + 3*cs_c ); // Load - //vc0_3_3 = _mm256_load_pd( c + 0*rs_c + 3*cs_c ); - vc0_3_3 = _mm256_set_pd( *(c + 3*rs_c + 3*cs_c ), - *(c + 2*rs_c + 3*cs_c ), - *(c + 1*rs_c + 3*cs_c ), - *(c + 0*rs_c + 3*cs_c ) ); + //vc0_3_3 = _mm256_load_pd( c + 0 + 3*cs_c ); + vc0_3_3 = _mm256_load_pd( c03 ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va0_3b3); // Scale by beta @@ -591,24 +470,13 @@ void bli_dgemm_sandybridge_int_8x4 // Add gemm result vc0_3_3 = _mm256_add_pd( vc0_3_3, vtmp ); // Store back to memory - //_mm256_store_pd( c03, vc0_3_3 ); - - aa = _mm256_extractf128_pd( vc0_3_3, 0 ) ; - bb = _mm256_extractf128_pd( vc0_3_3, 1 ) ; - - _mm_storel_pd( c + 0*rs_c + 3*cs_c, aa ); - _mm_storeh_pd( c + 1*rs_c + 3*cs_c, aa ); - _mm_storel_pd( c + 2*rs_c + 3*cs_c, bb ); - _mm_storeh_pd( c + 3*rs_c + 3*cs_c, bb ); + _mm256_store_pd( c03, vc0_3_3 ); // Calculate address - c43 = ( c + 4*rs_c + 3*cs_c ); + c43 = ( c + 4 + 3*cs_c ); // Load - //vc4_7_3 = _mm256_load_pd( c + 4*rs_c + 3*cs_c ); - vc4_7_3 = _mm256_set_pd( *(c + 7*rs_c + 3*cs_c ), - *(c + 6*rs_c + 3*cs_c ), - *(c + 5*rs_c + 3*cs_c ), - *(c + 4*rs_c + 3*cs_c ) ); + //vc4_7_3 = _mm256_load_pd( c + 4 + 3*cs_c ); + vc4_7_3 = _mm256_load_pd( c43 ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va4_7b3); // Scale by beta @@ -616,17 +484,10 @@ void bli_dgemm_sandybridge_int_8x4 // Add gemm result vc4_7_3 = _mm256_add_pd( vc4_7_3, vtmp ); // Store back to memory - //_mm256_store_pd( c43, vc4_7_3 ); - - aa = _mm256_extractf128_pd( vc4_7_3, 0 ) ; - bb = _mm256_extractf128_pd( vc4_7_3, 1 ) ; - - _mm_storel_pd( c + 4*rs_c + 3*cs_c, aa ); - _mm_storeh_pd( c + 5*rs_c + 3*cs_c, aa ); - _mm_storel_pd( c + 6*rs_c + 3*cs_c, bb ); - _mm_storeh_pd( c + 7*rs_c + 3*cs_c, bb ); + _mm256_store_pd( c43, vc4_7_3 ); } + GEMM_UKR_FLUSH_CT( d ); } @@ -634,7 +495,9 @@ void bli_dgemm_sandybridge_int_8x4 #if 0 void bli_cgemm_sandybridge_int_8x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, scomplex* restrict alpha, scomplex* restrict a, scomplex* restrict b, @@ -652,7 +515,9 @@ void bli_cgemm_sandybridge_int_8x4 #if 0 void bli_zgemm_sandybridge_int_4x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, dcomplex* restrict alpha, dcomplex* restrict a, dcomplex* restrict b, diff --git a/kernels/skx/3/bli_dgemm_skx_asm_16x12_l2.c b/kernels/skx/3/bli_dgemm_skx_asm_16x12_l2.c index 3a20cd8618..9943a170be 100644 --- a/kernels/skx/3/bli_dgemm_skx_asm_16x12_l2.c +++ b/kernels/skx/3/bli_dgemm_skx_asm_16x12_l2.c @@ -287,24 +287,28 @@ static int64_t offsets[16] __attribute__((aligned(64))) = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15}; -void bli_dgemm_skx_asm_16x12_l2( - dim_t k_, - double* restrict alpha, - double* restrict a, - double* restrict b, - double* restrict beta, - double* restrict c, inc_t rs_c_, inc_t cs_c_, - auxinfo_t* data, - cntx_t* restrict cntx - ) +void bli_dgemm_skx_asm_16x12_l2 + ( + dim_t m, + dim_t n, + dim_t k_, + double* restrict alpha, + double* restrict a, + double* restrict b, + double* restrict beta, + double* restrict c, inc_t rs_c_, inc_t cs_c_, + auxinfo_t* data, + cntx_t* restrict cntx + ) { (void)data; (void)cntx; - const int64_t* offsetPtr = &offsets[0]; - const int64_t k = k_; - const int64_t rs_c = rs_c_; - const int64_t cs_c = cs_c_; + int64_t k = k_; + int64_t rs_c = rs_c_; + int64_t cs_c = cs_c_; + + GEMM_UKR_SETUP_CT( d, 16, 12, false ); BEGIN_ASM() @@ -464,62 +468,26 
@@ void bli_dgemm_skx_asm_16x12_l2( MOV(RAX, VAR(cs_c)) LEA(RAX, MEM(,RAX,8)) - MOV(RBX, VAR(rs_c)) - LEA(RBX, MEM(,RBX,8)) - - // Check if C is column stride. If not, jump to the slow scattered update - CMP(RBX, IMM(1)) - JNE(SCATTEREDUPDATE) - - VCOMISD(XMM(1), XMM(7)) - JE(COLSTORBZ) - UPDATE_C( 8, 9,10,11) - UPDATE_C(12,13,14,15) - UPDATE_C(16,17,18,19) - UPDATE_C(20,21,22,23) - UPDATE_C(24,25,26,27) - UPDATE_C(28,29,30,31) + VCOMISD(XMM(1), XMM(7)) + JE(COLSTORBZ) - JMP(END) - LABEL(COLSTORBZ) - - UPDATE_C_BZ( 8, 9,10,11) - UPDATE_C_BZ(12,13,14,15) - UPDATE_C_BZ(16,17,18,19) - UPDATE_C_BZ(20,21,22,23) - UPDATE_C_BZ(24,25,26,27) - UPDATE_C_BZ(28,29,30,31) + UPDATE_C( 8, 9,10,11) + UPDATE_C(12,13,14,15) + UPDATE_C(16,17,18,19) + UPDATE_C(20,21,22,23) + UPDATE_C(24,25,26,27) + UPDATE_C(28,29,30,31) JMP(END) - LABEL(SCATTEREDUPDATE) - - MOV(RDI, VAR(offsetPtr)) - VMOVDQA64(ZMM(2), MEM(RDI,0*64)) - VMOVDQA64(ZMM(3), MEM(RDI,1*64)) - VPBROADCASTQ(ZMM(6), RBX) - VPMULLQ(ZMM(2), ZMM(6), ZMM(2)) - VPMULLQ(ZMM(3), ZMM(6), ZMM(3)) - - VCOMISD(XMM(1), XMM(7)) - JE(SCATTERBZ) - - UPDATE_C_ROW_SCATTERED( 8, 9,10,11) - UPDATE_C_ROW_SCATTERED(12,13,14,15) - UPDATE_C_ROW_SCATTERED(16,17,18,19) - UPDATE_C_ROW_SCATTERED(20,21,22,23) - UPDATE_C_ROW_SCATTERED(24,25,26,27) - UPDATE_C_ROW_SCATTERED(28,29,30,31) - - JMP(END) - LABEL(SCATTERBZ) - - UPDATE_C_BZ_ROW_SCATTERED( 8, 9,10,11) - UPDATE_C_BZ_ROW_SCATTERED(12,13,14,15) - UPDATE_C_BZ_ROW_SCATTERED(16,17,18,19) - UPDATE_C_BZ_ROW_SCATTERED(20,21,22,23) - UPDATE_C_BZ_ROW_SCATTERED(24,25,26,27) - UPDATE_C_BZ_ROW_SCATTERED(28,29,30,31) + LABEL(COLSTORBZ) + + UPDATE_C_BZ( 8, 9,10,11) + UPDATE_C_BZ(12,13,14,15) + UPDATE_C_BZ(16,17,18,19) + UPDATE_C_BZ(20,21,22,23) + UPDATE_C_BZ(24,25,26,27) + UPDATE_C_BZ(28,29,30,31) LABEL(END) @@ -535,8 +503,7 @@ void bli_dgemm_skx_asm_16x12_l2( [beta] "m" (beta), [c] "m" (c), [rs_c] "m" (rs_c), - [cs_c] "m" (cs_c), - [offsetPtr] "m" (offsetPtr) + [cs_c] "m" (cs_c) : // register clobber list "rax", "rbx", "rcx", "rdx", "rdi", "rsi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", @@ -545,4 +512,6 @@ void bli_dgemm_skx_asm_16x12_l2( "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", "zmm30", "zmm31", "memory" ) + + GEMM_UKR_FLUSH_CT( d ); } diff --git a/kernels/skx/3/bli_dgemm_skx_asm_16x14.c b/kernels/skx/3/bli_dgemm_skx_asm_16x14.c index 136f315323..e3bc52041d 100644 --- a/kernels/skx/3/bli_dgemm_skx_asm_16x14.c +++ b/kernels/skx/3/bli_dgemm_skx_asm_16x14.c @@ -153,24 +153,28 @@ static int64_t offsets[16] __attribute__((aligned(64))) = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15}; -void bli_dgemm_skx_asm_16x14( - dim_t k_, - double* restrict alpha, - double* restrict a, - double* restrict b, - double* restrict beta, - double* restrict c, inc_t rs_c_, inc_t cs_c_, - auxinfo_t* data, - cntx_t* restrict cntx - ) +void bli_dgemm_skx_asm_16x14 + ( + dim_t m, + dim_t n, + dim_t k_, + double* restrict alpha, + double* restrict a, + double* restrict b, + double* restrict beta, + double* restrict c, inc_t rs_c_, inc_t cs_c_, + auxinfo_t* data, + cntx_t* restrict cntx + ) { (void)data; (void)cntx; - const int64_t* offsetPtr = &offsets[0]; - const int64_t k = k_; - const int64_t rs_c = rs_c_*8; - const int64_t cs_c = cs_c_*8; + int64_t k = k_; + int64_t rs_c = rs_c_; + int64_t cs_c = cs_c_; + + GEMM_UKR_SETUP_CT( d, 16, 14, false ); BEGIN_ASM() @@ -220,6 +224,8 @@ void bli_dgemm_skx_asm_16x14( MOV(R12, VAR(rs_c)) MOV(R10, VAR(cs_c)) + LEA(R12, MEM(,R12,8)) 
+ LEA(R10, MEM(,R10,8)) MOV(RDI, RSI) AND(RSI, IMM(3)) @@ -320,119 +326,41 @@ void bli_dgemm_skx_asm_16x14( MOV(RAX, R12) MOV(RBX, R10) - // Check if C is column stride. - CMP(RAX, IMM(8)) - JNE(SCATTEREDUPDATE) - - VCOMISD(XMM(1), XMM(2)) - JE(COLSTORBZ) - - UPDATE_C( 4, 5) - UPDATE_C( 6, 7) - UPDATE_C( 8, 9) - UPDATE_C(10,11) - UPDATE_C(12,13) - UPDATE_C(14,15) - UPDATE_C(16,17) - UPDATE_C(18,19) - UPDATE_C(20,21) - UPDATE_C(22,23) - UPDATE_C(24,25) - UPDATE_C(26,27) - UPDATE_C(28,29) - UPDATE_C(30,31) - - JMP(END) - LABEL(COLSTORBZ) - - UPDATE_C_BZ( 4, 5) - UPDATE_C_BZ( 6, 7) - UPDATE_C_BZ( 8, 9) - UPDATE_C_BZ(10,11) - UPDATE_C_BZ(12,13) - UPDATE_C_BZ(14,15) - UPDATE_C_BZ(16,17) - UPDATE_C_BZ(18,19) - UPDATE_C_BZ(20,21) - UPDATE_C_BZ(22,23) - UPDATE_C_BZ(24,25) - UPDATE_C_BZ(26,27) - UPDATE_C_BZ(28,29) - UPDATE_C_BZ(30,31) + VCOMISD(XMM(1), XMM(2)) + JE(COLSTORBZ) + + UPDATE_C( 4, 5) + UPDATE_C( 6, 7) + UPDATE_C( 8, 9) + UPDATE_C(10,11) + UPDATE_C(12,13) + UPDATE_C(14,15) + UPDATE_C(16,17) + UPDATE_C(18,19) + UPDATE_C(20,21) + UPDATE_C(22,23) + UPDATE_C(24,25) + UPDATE_C(26,27) + UPDATE_C(28,29) + UPDATE_C(30,31) JMP(END) - LABEL(SCATTEREDUPDATE) - - VMULPD(ZMM( 4), ZMM( 4), ZMM(0)) - VMULPD(ZMM( 5), ZMM( 5), ZMM(0)) - VMULPD(ZMM( 6), ZMM( 6), ZMM(0)) - VMULPD(ZMM( 7), ZMM( 7), ZMM(0)) - VMULPD(ZMM( 8), ZMM( 8), ZMM(0)) - VMULPD(ZMM( 9), ZMM( 9), ZMM(0)) - VMULPD(ZMM(10), ZMM(10), ZMM(0)) - VMULPD(ZMM(11), ZMM(11), ZMM(0)) - VMULPD(ZMM(12), ZMM(12), ZMM(0)) - VMULPD(ZMM(13), ZMM(13), ZMM(0)) - VMULPD(ZMM(14), ZMM(14), ZMM(0)) - VMULPD(ZMM(15), ZMM(15), ZMM(0)) - VMULPD(ZMM(16), ZMM(16), ZMM(0)) - VMULPD(ZMM(17), ZMM(17), ZMM(0)) - VMULPD(ZMM(18), ZMM(18), ZMM(0)) - VMULPD(ZMM(19), ZMM(19), ZMM(0)) - VMULPD(ZMM(20), ZMM(20), ZMM(0)) - VMULPD(ZMM(21), ZMM(21), ZMM(0)) - VMULPD(ZMM(22), ZMM(22), ZMM(0)) - VMULPD(ZMM(23), ZMM(23), ZMM(0)) - VMULPD(ZMM(24), ZMM(24), ZMM(0)) - VMULPD(ZMM(25), ZMM(25), ZMM(0)) - VMULPD(ZMM(26), ZMM(26), ZMM(0)) - VMULPD(ZMM(27), ZMM(27), ZMM(0)) - VMULPD(ZMM(28), ZMM(28), ZMM(0)) - VMULPD(ZMM(29), ZMM(29), ZMM(0)) - VMULPD(ZMM(30), ZMM(30), ZMM(0)) - VMULPD(ZMM(31), ZMM(31), ZMM(0)) - - VCOMISD(XMM(1), XMM(2)) - - MOV(RDI, VAR(offsetPtr)) - VPBROADCASTQ(ZMM(0), RAX) - VPMULLQ(ZMM(2), ZMM(0), MEM(RDI)) - VPMULLQ(ZMM(3), ZMM(0), MEM(RDI,64)) - - JE(SCATTERBZ) - - UPDATE_C_COL_SCATTERED( 4, 5) - UPDATE_C_COL_SCATTERED( 6, 7) - UPDATE_C_COL_SCATTERED( 8, 9) - UPDATE_C_COL_SCATTERED(10,11) - UPDATE_C_COL_SCATTERED(12,13) - UPDATE_C_COL_SCATTERED(14,15) - UPDATE_C_COL_SCATTERED(16,17) - UPDATE_C_COL_SCATTERED(18,19) - UPDATE_C_COL_SCATTERED(20,21) - UPDATE_C_COL_SCATTERED(22,23) - UPDATE_C_COL_SCATTERED(24,25) - UPDATE_C_COL_SCATTERED(26,27) - UPDATE_C_COL_SCATTERED(28,29) - UPDATE_C_COL_SCATTERED(30,31) - - JMP(END) - LABEL(SCATTERBZ) - - UPDATE_C_BZ_COL_SCATTERED( 4, 5) - UPDATE_C_BZ_COL_SCATTERED( 6, 7) - UPDATE_C_BZ_COL_SCATTERED( 8, 9) - UPDATE_C_BZ_COL_SCATTERED(10,11) - UPDATE_C_BZ_COL_SCATTERED(12,13) - UPDATE_C_BZ_COL_SCATTERED(14,15) - UPDATE_C_BZ_COL_SCATTERED(16,17) - UPDATE_C_BZ_COL_SCATTERED(18,19) - UPDATE_C_BZ_COL_SCATTERED(20,21) - UPDATE_C_BZ_COL_SCATTERED(22,23) - UPDATE_C_BZ_COL_SCATTERED(24,25) - UPDATE_C_BZ_COL_SCATTERED(26,27) - UPDATE_C_BZ_COL_SCATTERED(28,29) - UPDATE_C_BZ_COL_SCATTERED(30,31) + LABEL(COLSTORBZ) + + UPDATE_C_BZ( 4, 5) + UPDATE_C_BZ( 6, 7) + UPDATE_C_BZ( 8, 9) + UPDATE_C_BZ(10,11) + UPDATE_C_BZ(12,13) + UPDATE_C_BZ(14,15) + UPDATE_C_BZ(16,17) + UPDATE_C_BZ(18,19) + UPDATE_C_BZ(20,21) + UPDATE_C_BZ(22,23) + UPDATE_C_BZ(24,25) 
+ UPDATE_C_BZ(26,27) + UPDATE_C_BZ(28,29) + UPDATE_C_BZ(30,31) LABEL(END) @@ -449,8 +377,7 @@ void bli_dgemm_skx_asm_16x14( [beta] "m" (beta), [c] "m" (c), [rs_c] "m" (rs_c), - [cs_c] "m" (cs_c), - [offsetPtr] "m" (offsetPtr) + [cs_c] "m" (cs_c) : // register clobber list "rax", "rbx", "rcx", "rdx", "rdi", "rsi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", @@ -459,4 +386,6 @@ void bli_dgemm_skx_asm_16x14( "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", "zmm30", "zmm31", "memory" ) + + GEMM_UKR_FLUSH_CT( d ); } diff --git a/kernels/skx/3/bli_sgemm_skx_asm_32x12_l2.c b/kernels/skx/3/bli_sgemm_skx_asm_32x12_l2.c index 40af496140..8808449b65 100644 --- a/kernels/skx/3/bli_sgemm_skx_asm_32x12_l2.c +++ b/kernels/skx/3/bli_sgemm_skx_asm_32x12_l2.c @@ -317,24 +317,28 @@ ahead*/ static int64_t offsets[16] __attribute__((aligned(64))) = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15}; -void bli_sgemm_skx_asm_32x12_l2( - dim_t k_, - float* restrict alpha, - float* restrict a, - float* restrict b, - float* restrict beta, - float* restrict c, inc_t rs_c_, inc_t cs_c_, - auxinfo_t* data, - cntx_t* restrict cntx - ) +void bli_sgemm_skx_asm_32x12_l2 + ( + dim_t m, + dim_t n, + dim_t k_, + float* restrict alpha, + float* restrict a, + float* restrict b, + float* restrict beta, + float* restrict c, inc_t rs_c_, inc_t cs_c_, + auxinfo_t* data, + cntx_t* restrict cntx + ) { (void)data; (void)cntx; - const int64_t* offsetPtr = &offsets[0]; - const int64_t k = k_; - const int64_t rs_c = rs_c_; - const int64_t cs_c = cs_c_; + int64_t k = k_; + int64_t rs_c = rs_c_; + int64_t cs_c = cs_c_; + + GEMM_UKR_SETUP_CT( s, 32, 12, false ); BEGIN_ASM() @@ -381,7 +385,7 @@ void bli_sgemm_skx_asm_32x12_l2( #endif #ifdef PREFETCH_B_BEFORE - /* Prefetching 3 cachlines of B (4 iterations worth of data + /* Prefetching 3 cachlines of B (4 iterations worth of data (12 (NR) x 4 (sizeof(float)) x 4 iter /64 = 3 cachelines) */ PREFETCH(0, MEM(RBX,0*64)) PREFETCH(0, MEM(RBX,1*64)) @@ -485,66 +489,26 @@ void bli_sgemm_skx_asm_32x12_l2( MOV(RAX, VAR(cs_c)) LEA(RAX, MEM(,RAX,4)) - MOV(RBX, VAR(rs_c)) - LEA(RBX, MEM(,RBX,4)) - - - // Check if C is column major (rs_c = 1). 
If not, jump to the slow scattered update - CMP(RBX, IMM(4)) - JNE(SCATTEREDUPDATE) - - VCOMISS(XMM(1), XMM(7)) - JE(COLSTORBZ) - UPDATE_C( 8, 9,10,11) - UPDATE_C(12,13,14,15) - UPDATE_C(16,17,18,19) - UPDATE_C(20,21,22,23) - UPDATE_C(24,25,26,27) - UPDATE_C(28,29,30,31) + VCOMISS(XMM(1), XMM(7)) + JE(COLSTORBZ) - JMP(END) - LABEL(COLSTORBZ) - - UPDATE_C_BZ( 8, 9,10,11) - UPDATE_C_BZ(12,13,14,15) - UPDATE_C_BZ(16,17,18,19) - UPDATE_C_BZ(20,21,22,23) - UPDATE_C_BZ(24,25,26,27) - UPDATE_C_BZ(28,29,30,31) + UPDATE_C( 8, 9,10,11) + UPDATE_C(12,13,14,15) + UPDATE_C(16,17,18,19) + UPDATE_C(20,21,22,23) + UPDATE_C(24,25,26,27) + UPDATE_C(28,29,30,31) JMP(END) - LABEL(SCATTEREDUPDATE) - - LEA(RDX, MEM(RCX,RBX,8)) - LEA(RDX, MEM(RDX,RBX,8)) - - MOV(RDI, VAR(offsetPtr)) - VMOVDQA64(ZMM(2), MEM(RDI,0*64)) - VMOVDQA64(ZMM(3), MEM(RDI,1*64)) - VPBROADCASTQ(ZMM(6), RBX) - VPMULLQ(ZMM(2), ZMM(6), ZMM(2)) - VPMULLQ(ZMM(3), ZMM(6), ZMM(3)) - - VCOMISS(XMM(1), XMM(7)) - JE(SCATTERBZ) - - UPDATE_C_ROW_SCATTERED( 8, 9,10,11) - UPDATE_C_ROW_SCATTERED(12,13,14,15) - UPDATE_C_ROW_SCATTERED(16,17,18,19) - UPDATE_C_ROW_SCATTERED(20,21,22,23) - UPDATE_C_ROW_SCATTERED(24,25,26,27) - UPDATE_C_ROW_SCATTERED(28,29,30,31) - - JMP(END) - LABEL(SCATTERBZ) - - UPDATE_C_BZ_ROW_SCATTERED( 8, 9,10,11) - UPDATE_C_BZ_ROW_SCATTERED(12,13,14,15) - UPDATE_C_BZ_ROW_SCATTERED(16,17,18,19) - UPDATE_C_BZ_ROW_SCATTERED(20,21,22,23) - UPDATE_C_BZ_ROW_SCATTERED(24,25,26,27) - UPDATE_C_BZ_ROW_SCATTERED(28,29,30,31) + LABEL(COLSTORBZ) + + UPDATE_C_BZ( 8, 9,10,11) + UPDATE_C_BZ(12,13,14,15) + UPDATE_C_BZ(16,17,18,19) + UPDATE_C_BZ(20,21,22,23) + UPDATE_C_BZ(24,25,26,27) + UPDATE_C_BZ(28,29,30,31) LABEL(END) @@ -560,8 +524,7 @@ void bli_sgemm_skx_asm_32x12_l2( [beta] "m" (beta), [c] "m" (c), [rs_c] "m" (rs_c), - [cs_c] "m" (cs_c), - [offsetPtr] "m" (offsetPtr) + [cs_c] "m" (cs_c) : // register clobber list "rax", "rbx", "rcx", "rdx", "rdi", "rsi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", @@ -570,4 +533,6 @@ void bli_sgemm_skx_asm_32x12_l2( "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", "zmm30", "zmm31", "memory" ) + + GEMM_UKR_FLUSH_CT( s ); } diff --git a/ref_kernels/3/bb/bli_gemmbb_ref.c b/ref_kernels/3/bb/bli_gemmbb_ref.c index b45718d454..4c75c064ce 100644 --- a/ref_kernels/3/bb/bli_gemmbb_ref.c +++ b/ref_kernels/3/bb/bli_gemmbb_ref.c @@ -42,6 +42,8 @@ \ void PASTEMAC3(ch,opname,arch,suf) \ ( \ + dim_t m, \ + dim_t n, \ dim_t k, \ ctype* restrict alpha, \ ctype* restrict a, \ @@ -59,9 +61,6 @@ void PASTEMAC3(ch,opname,arch,suf) \ \ const inc_t packmr = bli_cntx_get_blksz_max_dt( dt, BLIS_MR, cntx ); \ const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); \ -\ - const dim_t m = mr; \ - const dim_t n = nr; \ \ const inc_t cs_a = packmr; \ \ diff --git a/ref_kernels/3/bb/bli_gemmtrsmbb_ref.c b/ref_kernels/3/bb/bli_gemmtrsmbb_ref.c index 681b740b52..dd4e1f153d 100644 --- a/ref_kernels/3/bb/bli_gemmtrsmbb_ref.c +++ b/ref_kernels/3/bb/bli_gemmtrsmbb_ref.c @@ -87,6 +87,8 @@ PASTEMAC(d,fprintm)( stdout, "gemmtrsm_ukr: b11", mr, 2*nr, \ /* upper: b11 = alpha * b11 - a12 * b21; */ \ gemm_ukr \ ( \ + mr, \ + nr, \ k, \ minus_one, \ a1x, \ diff --git a/ref_kernels/3/bli_gemm_ref.c b/ref_kernels/3/bli_gemm_ref.c index 931fe994b3..51ff9df4bd 100644 --- a/ref_kernels/3/bli_gemm_ref.c +++ b/ref_kernels/3/bli_gemm_ref.c @@ -44,6 +44,8 @@ \ void PASTEMAC3(ch,opname,arch,suf) \ ( \ + dim_t m, \ + dim_t n, \ dim_t k, \ ctype* restrict alpha, \ ctype* 
restrict a, \ @@ -107,8 +109,8 @@ void PASTEMAC3(ch,opname,arch,suf) \ \ if ( PASTEMAC(ch,eq0)( *beta ) ) \ { \ - for ( dim_t i = 0; i < mr; ++i ) \ - for ( dim_t j = 0; j < nr; ++j ) \ + for ( dim_t i = 0; i < m; ++i ) \ + for ( dim_t j = 0; j < n; ++j ) \ PASTEMAC(ch,copys) \ ( \ ab[ i*rs_ab + j*cs_ab ], \ @@ -117,8 +119,8 @@ void PASTEMAC3(ch,opname,arch,suf) \ } \ else \ { \ - for ( dim_t i = 0; i < mr; ++i ) \ - for ( dim_t j = 0; j < nr; ++j ) \ + for ( dim_t i = 0; i < m; ++i ) \ + for ( dim_t j = 0; j < n; ++j ) \ PASTEMAC(ch,xpbys) \ ( \ ab[ i*rs_ab + j*cs_ab ], \ @@ -133,8 +135,8 @@ void PASTEMAC3(ch,opname,arch,suf) \ \ if ( PASTEMAC(ch,eq0)( *beta ) ) \ { \ - for ( dim_t j = 0; j < nr; ++j ) \ - for ( dim_t i = 0; i < mr; ++i ) \ + for ( dim_t j = 0; j < n; ++j ) \ + for ( dim_t i = 0; i < m; ++i ) \ PASTEMAC(ch,copys) \ ( \ ab[ i*rs_ab + j*cs_ab ], \ @@ -143,8 +145,8 @@ void PASTEMAC3(ch,opname,arch,suf) \ } \ else \ { \ - for ( dim_t j = 0; j < nr; ++j ) \ - for ( dim_t i = 0; i < mr; ++i ) \ + for ( dim_t j = 0; j < n; ++j ) \ + for ( dim_t i = 0; i < m; ++i ) \ PASTEMAC(ch,xpbys) \ ( \ ab[ i*rs_ab + j*cs_ab ], \ @@ -171,6 +173,8 @@ GENTFUNC( dcomplex, z, gemm, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, 4, 4 ) \ void PASTEMAC3(ch,opname,arch,suf) \ ( \ + dim_t m, \ + dim_t n, \ dim_t k, \ ctype* restrict alpha, \ ctype* restrict a, \ @@ -188,9 +192,6 @@ void PASTEMAC3(ch,opname,arch,suf) \ \ const inc_t packmr = bli_cntx_get_blksz_max_dt( dt, BLIS_MR, cntx ); \ const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); \ -\ - const dim_t m = mr; \ - const dim_t n = nr; \ \ const inc_t cs_a = packmr; \ \ diff --git a/ref_kernels/3/bli_gemmtrsm_ref.c b/ref_kernels/3/bli_gemmtrsm_ref.c index 2b756963e4..2b260c8810 100644 --- a/ref_kernels/3/bli_gemmtrsm_ref.c +++ b/ref_kernels/3/bli_gemmtrsm_ref.c @@ -52,6 +52,8 @@ void PASTEMAC3(ch,opname,arch,suf) \ { \ const num_t dt = PASTEMAC(ch,type); \ \ + const inc_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \ + const inc_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \ const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); \ \ const inc_t rs_b = packnr; \ @@ -68,6 +70,8 @@ void PASTEMAC3(ch,opname,arch,suf) \ /* upper: b11 = alpha * b11 - a12 * b21; */ \ gemm_ukr \ ( \ + mr, \ + nr, \ k, \ minus_one, \ a1x, \ diff --git a/ref_kernels/ind/bli_gemm1m_ref.c b/ref_kernels/ind/bli_gemm1m_ref.c index 6d2464de94..fbd15d695b 100644 --- a/ref_kernels/ind/bli_gemm1m_ref.c +++ b/ref_kernels/ind/bli_gemm1m_ref.c @@ -39,6 +39,8 @@ \ void PASTEMAC3(ch,opname,arch,suf) \ ( \ + dim_t m, \ + dim_t n, \ dim_t k, \ ctype* restrict alpha, \ ctype* restrict a, \ @@ -59,6 +61,9 @@ void PASTEMAC3(ch,opname,arch,suf) \ \ const dim_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \ const dim_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \ +\ + const dim_t mr_r = bli_cntx_get_blksz_def_dt( dt_r, BLIS_MR, cntx ); \ + const dim_t nr_r = bli_cntx_get_blksz_def_dt( dt_r, BLIS_NR, cntx ); \ \ const dim_t k2 = 2 * k; \ \ @@ -118,6 +123,11 @@ void PASTEMAC3(ch,opname,arch,suf) \ else if ( bli_is_gen_stored( rs_c, cs_c ) ) using_ct = TRUE; \ else using_ct = FALSE; \ \ +\ + /* If we are not computing a full micro-tile, then we must write to + ct and then accumulate to c afterwards. 
      */ \
+	if ( mr != m || nr != n ) using_ct = TRUE; \
+\
 \
 	if ( using_ct ) \
 	{ \
@@ -149,6 +159,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
 		/* c = beta * c + alpha_r * a * b; */ \
 		rgemm_ukr \
 		( \
+		  mr_r, \
+		  nr_r, \
 		  k2, \
 		  alpha_r, \
 		  a_r, \
@@ -164,8 +176,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
 		/* Accumulate the final result in ct back to c. */ \
 		if ( PASTEMAC(ch,eq1)( *beta ) ) \
 		{ \
-			for ( j = 0; j < nr; ++j ) \
-			for ( i = 0; i < mr; ++i ) \
+			for ( j = 0; j < n; ++j ) \
+			for ( i = 0; i < m; ++i ) \
 			{ \
 				PASTEMAC(ch,adds)( *(ct + i*rs_ct + j*cs_ct), \
 				                   *(c  + i*rs_c  + j*cs_c ) ); \
@@ -173,8 +185,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
 		} \
 		else if ( PASTEMAC(ch,eq0)( *beta ) ) \
 		{ \
-			for ( j = 0; j < nr; ++j ) \
-			for ( i = 0; i < mr; ++i ) \
+			for ( j = 0; j < n; ++j ) \
+			for ( i = 0; i < m; ++i ) \
 			{ \
 				PASTEMAC(ch,copys)( *(ct + i*rs_ct + j*cs_ct), \
 				                    *(c  + i*rs_c  + j*cs_c ) ); \
@@ -182,8 +194,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
 		} \
 		else \
 		{ \
-			for ( j = 0; j < nr; ++j ) \
-			for ( i = 0; i < mr; ++i ) \
+			for ( j = 0; j < n; ++j ) \
+			for ( i = 0; i < m; ++i ) \
 			{ \
 				PASTEMAC(ch,xpbys)( *(ct + i*rs_ct + j*cs_ct), \
 				                    *beta, \
@@ -215,6 +227,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
 		/* c = beta * c + alpha_r * a * b; */ \
 		rgemm_ukr \
 		( \
+		  mr_r, \
+		  nr_r, \
 		  k2, \
 		  alpha_r, \
 		  a_r, \
diff --git a/ref_kernels/ind/bli_gemmtrsm1m_ref.c b/ref_kernels/ind/bli_gemmtrsm1m_ref.c
index 5cfaee9ec6..96f5a16fed 100644
--- a/ref_kernels/ind/bli_gemmtrsm1m_ref.c
+++ b/ref_kernels/ind/bli_gemmtrsm1m_ref.c
@@ -153,6 +153,8 @@ void PASTEMAC3(ch,opname,arch,suf) \
 	   upper: bt = -1.0 * a12 * b21; */ \
 	rgemm_ukr \
 	( \
+	  mr_r, \
+	  nr_r, \
 	  k2, \
 	  minus_one_r, \
 	  a1x_r, \
diff --git a/test/syrk_diagonal/complex_math.hpp b/test/syrk_diagonal/complex_math.hpp
new file mode 100644
index 0000000000..9c68e730aa
--- /dev/null
+++ b/test/syrk_diagonal/complex_math.hpp
@@ -0,0 +1,267 @@
+#include <algorithm>
+#include <cmath>
+#include <type_traits>
+
+#include "blis.h"
+
+template <typename T>
+struct is_complex : std::false_type {};
+
+template <>
+struct is_complex<scomplex> : std::true_type {};
+
+template <>
+struct is_complex<dcomplex> : std::true_type {};
+
+template <typename T>
+struct is_real : std::integral_constant<bool, !is_complex<T>::value> {};
+
+template <typename T> struct make_complex;
+
+template <> struct make_complex<float>    { using type = scomplex; };
+template <> struct make_complex<double>   { using type = dcomplex; };
+template <> struct make_complex<scomplex> { using type = scomplex; };
+template <> struct make_complex<dcomplex> { using type = dcomplex; };
+
+template <typename T>
+using make_complex_t = typename make_complex<T>::type;
+
+template <typename T> struct make_real;
+
+template <> struct make_real<float>    { using type = float; };
+template <> struct make_real<double>   { using type = double; };
+template <> struct make_real<scomplex> { using type = float; };
+template <> struct make_real<dcomplex> { using type = double; };
+
+template <typename T>
+using make_real_t = typename make_real<T>::type;
+
+template <typename T, bool Complex>
+struct make_complex_if : std::conditional<Complex,make_complex_t<T>,make_real_t<T>> {};
+
+template <typename T, bool Complex>
+using make_complex_if_t = typename make_complex_if<T,Complex>::type;
+
+template <typename T>
+struct real_imag_part
+{
+    real_imag_part& operator=(T) { return *this; }
+
+    operator T() const { return T(); }
+};
+
+template <typename T>
+std::enable_if_t<is_real<typename std::decay<T>::type>::value,T&> real(T& x) { return x; }
+
+template <typename T>
+std::enable_if_t<is_real<T>::value,real_imag_part<T>> imag(T x) { return {}; }
+
+inline float& real(scomplex& x) { return x.real; }
+
+inline float& imag(scomplex& x) { return x.imag; }
+
+inline double& real(dcomplex& x) { return x.real; }
+
+inline double& imag(dcomplex& x) { return x.imag; }
+
+inline const float& real(const scomplex& x) { return x.real; }
+
+inline const float& imag(const scomplex& x) { return x.imag; }
+
+inline const double& real(const dcomplex& x) { return x.real; }
+
+inline const double& imag(const dcomplex& x) { return x.imag; }
+
+template <typename T>
+std::enable_if_t<is_real<T>::value,T> conj(T x) { return x; }
+
+template <typename T>
+std::enable_if_t<is_complex<T>::value,T> conj(const T& x) { return {x.real, -x.imag}; }
+
+template <typename T, typename U, typename Enable = void>
+struct convert_impl;
+
+template <typename T, typename U>
+struct convert_impl<T, U, std::enable_if_t<is_real<T>::value && is_real<U>::value>>
+{
+    void operator()(T x, U& y) const { y = x; }
+};
+
+template <typename T, typename U>
+struct convert_impl<T, U, std::enable_if_t<is_real<T>::value && is_complex<U>::value>>
+{
+    void operator()(T x, U& y) const { y.real = x; y.imag = 0; }
+};
+
+template <typename T, typename U>
+struct convert_impl<T, U, std::enable_if_t<is_complex<T>::value && is_real<U>::value>>
+{
+    void operator()(T x, U& y) const { y = x.real; }
+};
+
+template <typename T, typename U>
+struct convert_impl<T, U, std::enable_if_t<is_complex<T>::value && is_complex<U>::value>>
+{
+    void operator()(T x, U& y) const { y.real = x.real; y.imag = x.imag; }
+};
+
+template <typename U, typename T>
+U convert(T x)
+{
+    U y;
+    convert_impl<T,U>{}(x,y);
+    return y;
+}
+
+template <typename U, typename T>
+auto convert_prec(T x) -> make_complex_if_t<U,is_complex<T>::value>
+{
+    return convert<make_complex_if_t<U,is_complex<T>::value>>(x);
+}
+
+#define COMPLEX_MATH_OPS(rtype, ctype) \
+\
+inline bool operator==(rtype x, ctype y) \
+{ \
+    return x == y.real && y.imag == 0; \
+} \
+\
+inline bool operator==(ctype x, rtype y) \
+{ \
+    return y == x.real && x.imag == 0; \
+} \
+\
+inline bool operator==(ctype x, ctype y) \
+{ \
+    return x.real == y.real && \
+           x.imag == y.imag; \
+} \
+\
+inline ctype operator-(ctype x) \
+{ \
+    return {-x.real, -x.imag}; \
+} \
+\
+inline ctype operator+(rtype x, ctype y) \
+{ \
+    return {x+y.real, y.imag}; \
+} \
+\
+inline ctype operator+(ctype x, rtype y) \
+{ \
+    return {y+x.real, x.imag}; \
+} \
+\
+inline ctype operator+(ctype x, ctype y) \
+{ \
+    return {x.real+y.real, x.imag+y.imag}; \
+} \
+\
+inline ctype operator-(rtype x, ctype y) \
+{ \
+    return {x-y.real, -y.imag}; \
+} \
+\
+inline ctype operator-(ctype x, rtype y) \
+{ \
+    return {x.real-y, x.imag}; \
+} \
+\
+inline ctype operator-(ctype x, ctype y) \
+{ \
+    return {x.real-y.real, x.imag-y.imag}; \
+} \
+\
+inline ctype operator*(rtype x, ctype y) \
+{ \
+    return {x*y.real, x*y.imag}; \
+} \
+\
+inline ctype operator*(ctype x, rtype y) \
+{ \
+    return {y*x.real, y*x.imag}; \
+} \
+\
+inline ctype operator*(ctype x, ctype y) \
+{ \
+    return {x.real*y.real - x.imag*y.imag, \
+            x.real*y.imag + x.imag*y.real}; \
+} \
+\
+inline ctype operator/(rtype x, ctype y) \
+{ \
+    auto scale = std::max(std::abs(y.real), std::abs(y.imag)); \
+    auto n = std::ilogb(scale); \
+    auto yrs = std::scalbn(y.real, -n); \
+    auto yis = std::scalbn(y.imag, -n); \
+    auto denom = y.real*yrs + y.imag*yis; \
+    return {x*yrs/denom, -x*yis/denom}; \
+} \
+\
+inline ctype operator/(ctype x, rtype y) \
+{ \
+    return {x.real/y, x.imag/y}; \
+} \
+\
+inline ctype operator/(ctype x, ctype y) \
+{ \
+    auto scale = std::max(std::abs(y.real), std::abs(y.imag)); \
+    auto n = std::ilogb(scale); \
+    auto yrs = std::scalbn(y.real, -n); \
+    auto yis = std::scalbn(y.imag, -n); \
+    auto denom = y.real*yrs + y.imag*yis; \
+    return {(x.real*yrs + x.imag*yis)/denom, \
+            (x.imag*yrs - x.real*yis)/denom}; \
+} \
+\
+inline ctype& operator+=(ctype& x, rtype y) \
+{ \
+    x.real += y; \
+    return x; \
+} \
+\
+inline ctype& operator+=(ctype& x, ctype y) \
+{ \
+    x.real += y.real; x.imag += y.imag; \
+    return x; \
+} \
+\
+inline ctype& operator-=(ctype& x, rtype y) \
+{ \
+    x.real -= y; \
+    return x; \
+} \
+\
+inline ctype& operator-=(ctype& x, ctype y) \
+{ \
+    x.real -= y.real; x.imag -= y.imag; \
+    return x; \
+} \
+\
+inline ctype& operator*=(ctype& x, rtype y) \
+{ \
+    x.real *= y; x.imag *= y; \
+    return x; \
+} \
+\
+inline ctype& operator*=(ctype& x, ctype y) \
+{ \
+    x = x * y; \
+    return x; \
+} \
+\
+inline ctype& operator/=(ctype& x, rtype y) \
+{ \
+    x.real /= y; x.imag /= y; \
+    return x; \
+} \
+\
+inline ctype& operator/=(ctype& x, ctype y) \
+{ \
+    x = x / y; \
+    return x; \
+}
+
+COMPLEX_MATH_OPS(float, scomplex);
+COMPLEX_MATH_OPS(double, dcomplex);
+
diff --git a/test/syrk_diagonal/syrk_diagonal_example.c b/test/syrk_diagonal/syrk_diagonal_example.c
new file mode 100644
index 0000000000..c2bfd8fa19
--- /dev/null
+++ b/test/syrk_diagonal/syrk_diagonal_example.c
@@ -0,0 +1,186 @@
+#include "syrk_diagonal_ref.h"
+
+/*
+ * Structure which includes all additional information beyond what is
+ * already stored in the obj_t structure.
+ *
+ * This structure is **read-only** during the operation!
+ */
+typedef struct packm_diag_params_t
+{
+    packm_blk_var1_params_t super;
+    void* d;
+    inc_t incd;
+} packm_diag_params_t;
+
+/*
+ * Declare the pack kernel type and set up an array of
+ * packing kernels, one for each data type.
+ */
+#undef GENTFUNC
+#define GENTFUNC(ctype,ch,op) \
+void PASTEMAC(ch,op) \
+    ( \
+      struc_t struca, \
+      diag_t  diaga, \
+      uplo_t  uploa, \
+      conj_t  conja, \
+      pack_t  schema, \
+      bool    invdiag, \
+      dim_t   panel_dim, \
+      dim_t   panel_len, \
+      dim_t   panel_dim_max, \
+      dim_t   panel_len_max, \
+      dim_t   panel_dim_off, \
+      dim_t   panel_len_off, \
+      void*   restrict kappa, \
+      void*   restrict a, inc_t inca, inc_t lda, \
+      void*   restrict p,              inc_t ldp, \
+      inc_t   is_p, \
+      cntx_t* cntx, \
+      void*   params \
+    ) \
+{ \
+    packm_diag_params_t* params_cast = params; \
+    ctype* restrict a_cast = a; \
+    ctype* restrict p_cast = p; \
+    ctype* restrict d_cast = params_cast->d; \
+    inc_t incd = params_cast->incd; \
+    ctype kappa_cast = *( ctype* )kappa; \
+\
+    if ( schema != BLIS_PACKED_ROW_PANELS && \
+         schema != BLIS_PACKED_COL_PANELS ) \
+        bli_abort(); \
+\
+    /* Apply the offset */ \
+    d_cast += panel_len_off * incd; \
+\
+    if ( conja ) \
+    { \
+        for ( dim_t j = 0; j < panel_len; j++ ) \
+        { \
+            ctype kappa_d; \
+            PASTEMAC(ch,scal2s)( kappa_cast, d_cast[ j*incd ], kappa_d ); \
+\
+            for (dim_t i = 0;i < panel_dim;i++) \
+                PASTEMAC(ch,scal2js)( kappa_d, a_cast[ i*inca + j*lda ], p_cast[ i + j*ldp ] ); \
+\
+            for (dim_t i = panel_dim;i < panel_dim_max;i++) \
+                PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \
+        } \
+    } \
+    else \
+    { \
+        for ( dim_t j = 0; j < panel_len; j++ ) \
+        { \
+            ctype kappa_d; \
+            PASTEMAC(ch,scal2s)( kappa_cast, d_cast[ j*incd ], kappa_d ); \
+\
+            for (dim_t i = 0;i < panel_dim;i++) \
+                PASTEMAC(ch,scal2s)( kappa_d, a_cast[ i*inca + j*lda ], p_cast[ i + j*ldp ] ); \
+\
+            for (dim_t i = panel_dim;i < panel_dim_max;i++) \
+                PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \
+        } \
+    } \
+\
+    for (dim_t j = panel_len;j < panel_len_max;j++) \
+        for (dim_t i = 0;i < panel_dim_max;i++) \
+            PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \
+}
+
+INSERT_GENTFUNC_BASIC0(packm_diag_ukr);
+
+static packm_ker_vft GENARRAY( packm_diag_ukrs, packm_diag_ukr );
+
+/*
+ * Modify the object A to include information about the diagonal D,
+ * and imbue it with special function pointers which will take care
+ * of the actual work of forming (D * A^T)
+ */
+void attach_diagonal_factor( packm_diag_params_t* params, obj_t* d, obj_t* a )
+{
+    memset( params, 0, sizeof(*params) );
+
+    // Assumes D is a column vector
+    params->d = bli_obj_buffer_at_off( d );
+    params->incd = bli_obj_row_stride( d );
+
+    for ( int i = BLIS_DT_LO; i <= BLIS_DT_HI; i++ )
+        params->super.ukr_fn[i][i] = packm_diag_ukrs[i];
+
+    // Attach the parameters to the A object.
+    bli_obj_set_pack_params( params, a );
+}
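
The params mechanism above works because the packing variant consults the ukr_fn table attached to the object via bli_obj_set_pack_params(). The variant's internals are not part of this patch; as a hedged illustration only, the per-datatype lookup it performs amounts to something like the following (assuming the types from "blis.h" and the ukr_fn field shown above):

    static packm_ker_vft query_pack_ukr( num_t dt, packm_blk_var1_params_t* params )
    {
        /* attach_diagonal_factor() fills only the diagonal [dt][dt] entries;
           a NULL entry would mean "fall back to the default pack kernel". */
        return params ? params->ukr_fn[ dt ][ dt ] : NULL;
    }
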
+    for ( int i = BLIS_DT_LO; i <= BLIS_DT_HI; i++ )
+        params->super.ukr_fn[i][i] = packm_diag_ukrs[i];
+
+    // Attach the parameters to the A object.
+    bli_obj_set_pack_params( params, a );
+}
+
+/*
+ * Implements C := alpha * A * D * A^T + beta * C
+ *
+ * where D is a diagonal matrix with elements taken from the "d" vector.
+ */
+void syrk_diag( obj_t* alpha, obj_t* a, obj_t* d, obj_t* beta, obj_t* c )
+{
+    obj_t ad; // this is (D * A^T)
+    packm_diag_params_t params;
+
+    bli_obj_alias_to( a, &ad );
+    bli_obj_toggle_trans( &ad ); // because gemmt is A*B instead of A*B^T
+    attach_diagonal_factor( &params, d, &ad );
+
+    // Does C := alpha * A * B + beta * C using B = (D * A^T)
+    bli_gemmtnat( alpha, a, &ad, beta, c, NULL, NULL );
+}
+
+int main( void )
+{
+    obj_t a;
+    obj_t d;
+    obj_t c;
+    obj_t c_copy;
+    obj_t norm;
+
+    dim_t m = 10;
+    dim_t k = 10;
+
+    for ( int dt_ = BLIS_DT_LO; dt_ <= BLIS_DT_HI; dt_++ )
+    for ( int upper = 0; upper <= 1; upper++ )
+    for ( int transa = 0; transa <= 1; transa++ )
+    for ( int transc = 0; transc <= 1; transc++ )
+    {
+        num_t dt = dt_;
+        uplo_t uplo = upper ? BLIS_UPPER : BLIS_LOWER;
+
+        bli_obj_create( dt, m, k, transa ? k : 1, transa ? 1 : m, &a );
+        bli_obj_create( dt, k, 1, 1, 1, &d );
+        bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c );
+        bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c_copy );
+        bli_obj_set_struc( BLIS_SYMMETRIC, &c );
+        bli_obj_set_struc( BLIS_SYMMETRIC, &c_copy );
+        bli_obj_set_uplo( uplo, &c );
+        bli_obj_set_uplo( uplo, &c_copy );
+        bli_obj_create_1x1( bli_dt_proj_to_real( dt ), &norm );
+
+        bli_randm( &a );
+        bli_randm( &d );
+        bli_randm( &c );
+        bli_copym( &c, &c_copy );
+
+        syrk_diag( &BLIS_ONE, &a, &d, &BLIS_ONE, &c );
+        syrk_diag_ref( &BLIS_ONE, &a, &d, &BLIS_ONE, &c_copy );
+
+        bli_subm( &c_copy, &c );
+        bli_normfm( &c, &norm );
+
+        double normr, normi;
+        bli_getsc( &norm, &normr, &normi );
+
+        printf( "dt: %d, upper: %d, transa: %d, transc: %d, norm: %g\n",
+                dt, upper, transa, transc, normr );
+
+        bli_obj_free( &a );
+        bli_obj_free( &d );
+        bli_obj_free( &c );
+        bli_obj_free( &c_copy );
+        bli_obj_free( &norm );
+    }
+}
diff --git a/test/syrk_diagonal/syrk_diagonal_example.cxx b/test/syrk_diagonal/syrk_diagonal_example.cxx
new file mode 100644
index 0000000000..1c269d5c48
--- /dev/null
+++ b/test/syrk_diagonal/syrk_diagonal_example.cxx
@@ -0,0 +1,220 @@
+#include "syrk_diagonal_ref.h"
+
+/*
+ * Forward-declare the pack kernel type and set up an array of
+ * packing kernels, one for each data type.
+ */
+template <typename T>
+void packm_diag_ukr
+    (
+      struc_t /*struca*/,
+      diag_t  /*diaga*/,
+      uplo_t  /*uploa*/,
+      conj_t  conja,
+      pack_t  schema,
+      bool    /*invdiag*/,
+      dim_t   panel_dim,
+      dim_t   panel_len,
+      dim_t   panel_dim_max,
+      dim_t   panel_len_max,
+      dim_t   /*panel_dim_off*/,
+      dim_t   panel_len_off,
+      void*   restrict kappa,
+      void*   restrict a, inc_t inca, inc_t lda,
+      void*   restrict p, inc_t ldp,
+      inc_t   /*is_p*/,
+      cntx_t* /*cntx*/,
+      void*   params
+    );
+
+#undef GENTFUNC
+#define GENTFUNC(ctype,ch,op) \
+static auto PASTEMAC(ch,op) = &packm_diag_ukr<ctype>;
+
+INSERT_GENTFUNC_BASIC0(packm_diag_ukr);
+
+static packm_ker_vft GENARRAY( packm_diag_ukrs, packm_diag_ukr );
+
+/*
+ * Structure which includes all additional information beyond what is
+ * already stored in the obj_t structure.
+ *
+ * This structure is **read-only** during the operation!
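+ * (Descriptive note: the same params object may be read concurrently by
+ * every thread that participates in packing, which is why it must not be
+ * modified once the operation begins.)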
+ */
+struct packm_diag_params_t : packm_blk_var1_params_t
+{
+    void* d;
+    inc_t incd;
+
+    packm_diag_params_t() {}
+
+    packm_diag_params_t( void* d, inc_t incd )
+    : d(d), incd(incd)
+    {
+        for ( int i = BLIS_DT_LO; i <= BLIS_DT_HI; i++ )
+            ukr_fn[i][i] = packm_diag_ukrs[i];
+    }
+};
+
+/*
+ * Selecting a different kernel based on the current architecture is
+ * currently not possible, but is something we plan to support.
+ */
+template <typename T>
+void packm_diag_ukr
+    (
+      struc_t /*struca*/,
+      diag_t  /*diaga*/,
+      uplo_t  /*uploa*/,
+      conj_t  conja,
+      pack_t  schema,
+      bool    /*invdiag*/,
+      dim_t   panel_dim,
+      dim_t   panel_len,
+      dim_t   panel_dim_max,
+      dim_t   panel_len_max,
+      dim_t   /*panel_dim_off*/,
+      dim_t   panel_len_off,
+      void*   restrict kappa,
+      void*   restrict a, inc_t inca, inc_t lda,
+      void*   restrict p, inc_t ldp,
+      inc_t   /*is_p*/,
+      cntx_t* /*cntx*/,
+      void*   params
+    )
+{
+    auto params_cast = ( packm_diag_params_t* )params;
+    T* restrict a_cast = ( T* )a;
+    T* restrict p_cast = ( T* )p;
+    T* restrict d_cast = ( T* )params_cast->d;
+    auto incd = params_cast->incd;
+    auto kappa_cast = *( T* )kappa;
+
+    if ( schema != BLIS_PACKED_ROW_PANELS &&
+         schema != BLIS_PACKED_COL_PANELS )
+        bli_abort();
+
+    /* Apply the offset */
+    d_cast += panel_len_off * incd;
+
+    if ( conja )
+    {
+        for ( dim_t j = 0; j < panel_len; j++ )
+        {
+            auto kappa_d = kappa_cast * d_cast[ j*incd ];
+
+            for ( dim_t i = 0; i < panel_dim; i++ )
+                p_cast[ i + j*ldp ] = kappa_d * conj( a_cast[ i*inca + j*lda ] );
+
+            for ( dim_t i = panel_dim; i < panel_dim_max; i++ )
+                p_cast[ i + j*ldp ] = convert<T>(0.0);
+        }
+    }
+    else
+    {
+        for ( dim_t j = 0; j < panel_len; j++ )
+        {
+            auto kappa_d = kappa_cast * d_cast[ j*incd ];
+
+            for ( dim_t i = 0; i < panel_dim; i++ )
+                p_cast[ i + j*ldp ] = kappa_d * a_cast[ i*inca + j*lda ];
+
+            for ( dim_t i = panel_dim; i < panel_dim_max; i++ )
+                p_cast[ i + j*ldp ] = convert<T>(0.0);
+        }
+    }
+
+    for ( dim_t j = panel_len; j < panel_len_max; j++ )
+        for ( dim_t i = 0; i < panel_dim_max; i++ )
+            p_cast[ i + j*ldp ] = convert<T>(0.0);
+}
+
+/*
+ * Modify the object A to include information about the diagonal D,
+ * and imbue it with special function pointers which will take care
+ * of the actual work of forming (D * A^T)
+ */
+void attach_diagonal_factor( packm_diag_params_t* params, obj_t* d, obj_t* a )
+{
+    // Assumes D is a column vector
+    new (params) packm_diag_params_t
+    (
+        bli_obj_buffer_at_off( d ),
+        bli_obj_row_stride( d )
+    );
+
+    // Attach the parameters to the A object.
+    bli_obj_set_pack_params( params, a );
+}
+
+/*
+ * Implements C := alpha * A * D * A^T + beta * C
+ *
+ * where D is a diagonal matrix with elements taken from the "d" vector.
+ */
+void syrk_diag( obj_t* alpha, obj_t* a, obj_t* d, obj_t* beta, obj_t* c )
+{
+    obj_t ad; // this is (D * A^T)
+    packm_diag_params_t params;
+
+    bli_obj_alias_to( a, &ad );
+    bli_obj_toggle_trans( &ad ); // because gemmt is A*B instead of A*B^T
+    attach_diagonal_factor( &params, d, &ad );
+
+    // Does C := alpha * A * B + beta * C using B = (D * A^T)
+    bli_gemmtnat( alpha, a, &ad, beta, c, NULL, NULL );
+}
+
+int main()
+{
+    obj_t a;
+    obj_t d;
+    obj_t c;
+    obj_t c_copy;
+    obj_t norm;
+
+    auto m = 10;
+    auto k = 10;
+
+    for ( int dt_ = BLIS_DT_LO; dt_ <= BLIS_DT_HI; dt_++ )
+    for ( int upper = 0; upper <= 1; upper++ )
+    for ( int transa = 0; transa <= 1; transa++ )
+    for ( int transc = 0; transc <= 1; transc++ )
+    {
+        auto dt = ( num_t )dt_;
+        auto uplo = upper ? BLIS_UPPER : BLIS_LOWER;
+
+        bli_obj_create( dt, m, k, transa ? k : 1, transa ? 1 : m, &a );
+        bli_obj_create( dt, k, 1, 1, 1, &d );
+        bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c );
+        bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c_copy );
+        bli_obj_set_struc( BLIS_SYMMETRIC, &c );
+        bli_obj_set_struc( BLIS_SYMMETRIC, &c_copy );
+        bli_obj_set_uplo( uplo, &c );
+        bli_obj_set_uplo( uplo, &c_copy );
+        bli_obj_create_1x1( bli_dt_proj_to_real( dt ), &norm );
+
+        bli_randm( &a );
+        bli_randm( &d );
+        bli_randm( &c );
+        bli_copym( &c, &c_copy );
+
+        syrk_diag( &BLIS_ONE, &a, &d, &BLIS_ONE, &c );
+        syrk_diag_ref( &BLIS_ONE, &a, &d, &BLIS_ONE, &c_copy );
+
+        bli_subm( &c_copy, &c );
+        bli_normfm( &c, &norm );
+
+        double normr, normi;
+        bli_getsc( &norm, &normr, &normi );
+
+        printf("dt: %d, upper: %d, transa: %d, transc: %d, norm: %g\n",
+               dt, upper, transa, transc, normr);
+
+        bli_obj_free( &a );
+        bli_obj_free( &d );
+        bli_obj_free( &c );
+        bli_obj_free( &c_copy );
+        bli_obj_free( &norm );
+    }
+}
diff --git a/test/syrk_diagonal/syrk_diagonal_example2.c b/test/syrk_diagonal/syrk_diagonal_example2.c
new file mode 100644
index 0000000000..92371f48b0
--- /dev/null
+++ b/test/syrk_diagonal/syrk_diagonal_example2.c
@@ -0,0 +1,354 @@
+#include "syrk_diagonal_ref.h"
+
+/*
+ * Structure which includes all additional information beyond what is
+ * already stored in the obj_t structure.
+ *
+ * This structure is **read-only** during the operation!
+ */
+typedef struct packm_diag_params_t
+{
+    void* d;
+    inc_t incd;
+} packm_diag_params_t;
+
+typedef void (*packm_diag_ukr_vft)
+    (
+      bool  conja,
+      dim_t panel_dim,
+      dim_t panel_len,
+      dim_t panel_dim_max,
+      dim_t panel_len_max,
+      void* restrict kappa,
+      void* restrict d, inc_t incd,
+      void* restrict a, inc_t inca, inc_t lda,
+      void* restrict p, inc_t ldp
+    );
+
+/*
+ * Declare the pack kernel type and set up an array of
+ * packing kernels, one for each data type.
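+ * (Descriptive note: the kernel for the problem's datatype is looked up
+ * at run time by indexing this table with the BLIS datatype id, as in
+ * packm_diag_ukrs[ dt ] below.)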
+ */
+#undef GENTFUNC
+#define GENTFUNC(ctype,ch,op) \
+void PASTEMAC(ch,op) \
+    ( \
+      bool  conja, \
+      dim_t panel_dim, \
+      dim_t panel_len, \
+      dim_t panel_dim_max, \
+      dim_t panel_len_max, \
+      void* restrict kappa, \
+      void* restrict d, inc_t incd, \
+      void* restrict a, inc_t inca, inc_t lda, \
+      void* restrict p, inc_t ldp \
+    ) \
+{ \
+    ctype* restrict a_cast = a; \
+    ctype* restrict p_cast = p; \
+    ctype* restrict d_cast = d; \
+    ctype kappa_cast = *( ctype* )kappa; \
+\
+    if ( conja ) \
+    { \
+        for ( dim_t j = 0; j < panel_len; j++ ) \
+        { \
+            ctype kappa_d; \
+            PASTEMAC(ch,scal2s)( kappa_cast, d_cast[ j*incd ], kappa_d ); \
+\
+            for ( dim_t i = 0; i < panel_dim; i++ ) \
+                PASTEMAC(ch,scal2js)( kappa_d, a_cast[ i*inca + j*lda ], p_cast[ i + j*ldp ] ); \
+\
+            for ( dim_t i = panel_dim; i < panel_dim_max; i++ ) \
+                PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \
+        } \
+    } \
+    else \
+    { \
+        for ( dim_t j = 0; j < panel_len; j++ ) \
+        { \
+            ctype kappa_d; \
+            PASTEMAC(ch,scal2s)( kappa_cast, d_cast[ j*incd ], kappa_d ); \
+\
+            for ( dim_t i = 0; i < panel_dim; i++ ) \
+                PASTEMAC(ch,scal2s)( kappa_d, a_cast[ i*inca + j*lda ], p_cast[ i + j*ldp ] ); \
+\
+            for ( dim_t i = panel_dim; i < panel_dim_max; i++ ) \
+                PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \
+        } \
+    } \
+\
+    for ( dim_t j = panel_len; j < panel_len_max; j++ ) \
+        for ( dim_t i = 0; i < panel_dim_max; i++ ) \
+            PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \
+}
+
+INSERT_GENTFUNC_BASIC0(packm_diag_ukr);
+
+static packm_diag_ukr_vft GENARRAY( packm_diag_ukrs, packm_diag_ukr );
+
+void packm_diag
+    (
+      obj_t*     a,
+      obj_t*     p,
+      cntx_t*    cntx,
+      rntm_t*    rntm,
+      cntl_t*    cntl,
+      thrinfo_t* thread
+    )
+{
+#if 1
+
+    // We begin by copying the fields of A.
+    bli_obj_alias_to( a, p );
+
+    // Get information about data types.
+    num_t dt        = bli_obj_dt( a );
+    num_t dt_tar    = bli_obj_target_dt( a );
+    num_t dt_scalar = bli_obj_scalar_dt( a );
+    dim_t dt_size   = bli_dt_size( dt );
+
+    if ( dt_scalar != dt || dt_tar != dt )
+        bli_abort();
+
+    // Extract various fields from the control tree.
+    bszid_t bmult_id_m   = bli_cntl_packm_params_bmid_m( cntl );
+    bszid_t bmult_id_n   = bli_cntl_packm_params_bmid_n( cntl );
+    pack_t  schema       = bli_cntl_packm_params_pack_schema( cntl );
+    dim_t   bmult_m_def  = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_m, cntx );
+    dim_t   bmult_m_pack = bli_cntx_get_blksz_max_dt( dt_tar, bmult_id_m, cntx );
+    dim_t   bmult_n_def  = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_n, cntx );
+
+    if ( schema != BLIS_PACKED_ROW_PANELS &&
+         schema != BLIS_PACKED_COL_PANELS )
+        bli_abort();
+
+    // Store the pack schema to the object.
+    bli_obj_set_pack_schema( schema, p );
+
+    // Clear the conjugation field from the object since matrix packing
+    // in BLIS is deemed to take care of all conjugation necessary.
+    bli_obj_set_conj( BLIS_NO_CONJUGATE, p );
+
+    // If we are packing micropanels, mark P as dense.
+    bli_obj_set_uplo( BLIS_DENSE, p );
+
+    // Reset the view offsets to (0,0).
+    bli_obj_set_offs( 0, 0, p );
+
+    // Compute the dimensions padded by the dimension multiples. These
+    // dimensions will be the dimensions of the packed matrices, including
+    // zero-padding, and will be used by the macro- and micro-kernels.
+    // We compute them by starting with the effective dimensions of A (now
+    // in P) and aligning them to the dimension multiples (typically equal
+    // to register blocksizes). This does waste a little bit of space for
+    // level-2 operations, but that's okay with us.
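+    // As a concrete illustration (hypothetical sizes): with MR = 4, NR = 8,
+    // and an m_p x n_p = 10 x 10 matrix, m_p_pad = 12 and n_p_pad = 16, so
+    // the matrix packs into three 4-row micropanels of 16 columns each,
+    // with the extra rows and columns filled with zeros.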
+    dim_t m_p     = bli_obj_length( p );
+    dim_t n_p     = bli_obj_width( p );
+    dim_t m_p_pad = bli_align_dim_to_mult( m_p, bmult_m_def );
+    dim_t n_p_pad = bli_align_dim_to_mult( n_p, bmult_n_def );
+
+    // Save the padded dimensions into the packed object. It is important
+    // to save these dimensions since they represent the actual dimensions
+    // of the zero-padded matrix.
+    bli_obj_set_padded_dims( m_p_pad, n_p_pad, p );
+
+    // The "panel stride" of a micropanel packed object is interpreted as
+    // the distance between the (0,0) element of panel k and the (0,0)
+    // element of panel k+1. We use the padded width computed above to
+    // allow for zero-padding (if necessary/desired) along the far end
+    // of each micropanel (ie: the right edge of the matrix). Zero-padding
+    // can also occur along the long edge of the last micropanel if the m
+    // dimension of the matrix is not a whole multiple of MR.
+    inc_t ps_p = bmult_m_pack * n_p_pad;
+
+    /* Compute the total number of iterations we'll need. */
+    dim_t n_iter = m_p_pad / bmult_m_def;
+
+    // Store the strides and panel dimension in P.
+    bli_obj_set_strides( 1, bmult_m_pack, p );
+    bli_obj_set_imag_stride( 1, p );
+    bli_obj_set_panel_dim( bmult_m_def, p );
+    bli_obj_set_panel_stride( ps_p, p );
+    bli_obj_set_panel_length( bmult_m_def, p );
+    bli_obj_set_panel_width( n_p, p );
+
+    // Compute the size of the packed buffer.
+    siz_t size_p = ps_p * n_iter * dt_size;
+    if ( size_p == 0 ) return;
+
+    // Update the buffer address in p to point to the buffer associated
+    // with the mem_t entry acquired from the memory broker (now cached in
+    // the control tree node).
+    char* p_cast = (char*)bli_packm_alloc( size_p, rntm, cntl, thread );
+    bli_obj_set_buffer( p_cast, p );
+
+#else
+
+    // Every thread initializes p and determines the size of memory
+    // block needed (which gets embedded into the otherwise "blank" mem_t
+    // entry in the control tree node). Return early if no packing is required.
+    if ( !bli_packm_init( a, p, cntx, rntm, cntl, thread ) )
+        return;
+
+    num_t dt      = bli_obj_dt( a );
+    dim_t dt_size = bli_dt_size( dt );
+
+    bszid_t bmult_id_m   = bli_cntl_packm_params_bmid_m( cntl );
+    dim_t   bmult_m_def  = bli_cntx_get_blksz_def_dt( dt, bmult_id_m, cntx );
+    dim_t   bmult_m_pack = bli_cntx_get_blksz_max_dt( dt, bmult_id_m, cntx );
+
+    dim_t m_p     = bli_obj_length( p );
+    dim_t n_p     = bli_obj_width( p );
+    dim_t m_p_pad = bli_obj_padded_length( p );
+    dim_t n_p_pad = bli_obj_padded_width( p );
+    dim_t n_iter  = m_p_pad / bmult_m_def;
+
+    char* p_cast = bli_obj_buffer( p );
+    inc_t ps_p   = bli_obj_panel_stride( p );
+
+#endif
+
+    char* a_cast         = bli_obj_buffer_at_off( a );
+    inc_t inca           = bli_obj_row_stride( a );
+    inc_t lda            = bli_obj_col_stride( a );
+    dim_t panel_len_off  = bli_obj_col_off( a );
+    conj_t conja         = bli_obj_conj_status( a );
+
+    packm_diag_params_t* params = bli_obj_pack_params( a );
+    char* d_cast = params->d;
+    inc_t incd   = params->incd;
+
+    obj_t kappa_local;
+    char* kappa_cast = bli_packm_scalar( &kappa_local, p );
+
+    packm_diag_ukr_vft packm_ker_cast = packm_diag_ukrs[ dt ];
+
+    /* Query the number of threads and thread ids from the current thread's
+       packm thrinfo_t node. */
+    const dim_t nt  = bli_thread_n_way( thread );
+    const dim_t tid = bli_thread_work_id( thread );
+
+    /* Determine the thread range and increment using the current thread's
+       packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir()
+       will depend on whether slab or round-robin partitioning was requested
+       at configure-time. */
+    dim_t it_start, it_end, it_inc;
+    bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc );
+
+    /* Iterate over every logical micropanel in the source matrix. */
+    for ( dim_t it = 0; it < n_iter; it += 1 )
+    {
+        dim_t panel_dim_i = bli_min( bmult_m_def, m_p - it*bmult_m_def );
+
+        char* d_begin = d_cast + panel_len_off*incd*dt_size;
+        char* a_begin = a_cast + it*bmult_m_def*inca*dt_size;
+        char* p_begin = p_cast + it*ps_p*dt_size;
+
+        if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) )
+        {
+            packm_ker_cast
+            (
+              conja,
+              panel_dim_i,
+              n_p,
+              bmult_m_def,
+              n_p_pad,
+              kappa_cast,
+              d_begin, incd,
+              a_begin, inca, lda,
+              p_begin, bmult_m_pack
+            );
+        }
+    }
+}
+
+/*
+ * Modify the object A to include information about the diagonal D,
+ * and imbue it with special function pointers which will take care
+ * of the actual work of forming (D * A^T)
+ */
+void attach_diagonal_factor( packm_diag_params_t* params, obj_t* d, obj_t* a )
+{
+    // Assumes D is a column vector
+    params->d    = bli_obj_buffer_at_off( d );
+    params->incd = bli_obj_row_stride( d );
+
+    // Set the custom pack function.
+    bli_obj_set_pack_fn( packm_diag, a );
+
+    // Attach the parameters to the A object.
+    bli_obj_set_pack_params( params, a );
+}
+
+/*
+ * Implements C := alpha * A * D * A^T + beta * C
+ *
+ * where D is a diagonal matrix with elements taken from the "d" vector.
+ */
+void syrk_diag( obj_t* alpha, obj_t* a, obj_t* d, obj_t* beta, obj_t* c )
+{
+    obj_t ad; // this is (D * A^T)
+    packm_diag_params_t params;
+
+    bli_obj_alias_to( a, &ad );
+    bli_obj_toggle_trans( &ad ); // because gemmt is A*B instead of A*B^T
+    attach_diagonal_factor( &params, d, &ad );
+
+    // Does C := alpha * A * B + beta * C using B = (D * A^T)
+    bli_gemmt( alpha, a, &ad, beta, c );
+}
+
+int main( void )
+{
+    obj_t a;
+    obj_t d;
+    obj_t c;
+    obj_t c_copy;
+    obj_t norm;
+
+    dim_t m = 10;
+    dim_t k = 10;
+
+    for ( int dt_ = BLIS_DT_LO; dt_ <= BLIS_DT_HI; dt_++ )
+    for ( int upper = 0; upper <= 1; upper++ )
+    for ( int transa = 0; transa <= 1; transa++ )
+    for ( int transc = 0; transc <= 1; transc++ )
+    {
+        num_t dt = dt_;
+        uplo_t uplo = upper ? BLIS_UPPER : BLIS_LOWER;
+
+        bli_obj_create( dt, m, k, transa ? k : 1, transa ? 1 : m, &a );
+        bli_obj_create( dt, k, 1, 1, 1, &d );
+        bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c );
+        bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c_copy );
+        bli_obj_set_struc( BLIS_SYMMETRIC, &c );
+        bli_obj_set_struc( BLIS_SYMMETRIC, &c_copy );
+        bli_obj_set_uplo( uplo, &c );
+        bli_obj_set_uplo( uplo, &c_copy );
+        bli_obj_create_1x1( bli_dt_proj_to_real( dt ), &norm );
+
+        bli_randm( &a );
+        bli_randm( &d );
+        bli_randm( &c );
+        bli_copym( &c, &c_copy );
+
+        syrk_diag( &BLIS_ONE, &a, &d, &BLIS_ONE, &c );
+        syrk_diag_ref( &BLIS_ONE, &a, &d, &BLIS_ONE, &c_copy );
+
+        bli_subm( &c_copy, &c );
+        bli_normfm( &c, &norm );
+
+        double normr, normi;
+        bli_getsc( &norm, &normr, &normi );
+
+        printf( "dt: %d, upper: %d, transa: %d, transc: %d, norm: %g\n",
+                dt, upper, transa, transc, normr );
+
+        bli_obj_free( &a );
+        bli_obj_free( &d );
+        bli_obj_free( &c );
+        bli_obj_free( &c_copy );
+        bli_obj_free( &norm );
+    }
+}
diff --git a/test/syrk_diagonal/syrk_diagonal_example2.cxx b/test/syrk_diagonal/syrk_diagonal_example2.cxx
new file mode 100644
index 0000000000..8312a07ee8
--- /dev/null
+++ b/test/syrk_diagonal/syrk_diagonal_example2.cxx
@@ -0,0 +1,338 @@
+#include "syrk_diagonal_ref.h"
+
+/*
+ * Forward-declare the pack kernel type and set up an array of
+ * packing kernels, one for each data type.
+ */
+template <typename T>
+void packm_diag_ukr
+    (
+      bool  conja,
+      dim_t panel_dim,
+      dim_t panel_len,
+      dim_t panel_dim_max,
+      dim_t panel_len_max,
+      void* restrict kappa,
+      void* restrict d, inc_t incd,
+      void* restrict a, inc_t inca, inc_t lda,
+      void* restrict p, inc_t ldp
+    );
+
+#undef GENTFUNC
+#define GENTFUNC(ctype,ch,op) \
+static auto PASTEMAC(ch,op) = &packm_diag_ukr<ctype>;
+
+INSERT_GENTFUNC_BASIC0(packm_diag_ukr);
+
+using packm_diag_ukr_vft = decltype(&packm_diag_ukr<float>);
+static packm_diag_ukr_vft GENARRAY( packm_diag_ukrs, packm_diag_ukr );
+
+/*
+ * Structure which includes all additional information beyond what is
+ * already stored in the obj_t structure.
+ *
+ * This structure is **read-only** during the operation!
+ */
+struct packm_diag_params_t
+{
+    void* d;
+    inc_t incd;
+
+    packm_diag_params_t() {}
+
+    packm_diag_params_t( void* d, inc_t incd )
+    : d(d), incd(incd) {}
+};
+
+/*
+ * Selecting a different kernel based on the current architecture is
+ * currently not possible, but is something we plan to support.
+ */
+template <typename T>
+void packm_diag_ukr
+    (
+      bool  conja,
+      dim_t panel_dim,
+      dim_t panel_len,
+      dim_t panel_dim_max,
+      dim_t panel_len_max,
+      void* restrict kappa,
+      void* restrict d, inc_t incd,
+      void* restrict a, inc_t inca, inc_t lda,
+      void* restrict p, inc_t ldp
+    )
+{
+    T* restrict a_cast = ( T* )a;
+    T* restrict p_cast = ( T* )p;
+    T* restrict d_cast = ( T* )d;
+    auto kappa_cast = *( T* )kappa;
+
+    if ( conja )
+    {
+        for ( dim_t j = 0; j < panel_len; j++ )
+        {
+            auto kappa_d = kappa_cast * d_cast[ j*incd ];
+
+            for ( dim_t i = 0; i < panel_dim; i++ )
+                p_cast[ i + j*ldp ] = kappa_d * conj( a_cast[ i*inca + j*lda ] );
+
+            for ( dim_t i = panel_dim; i < panel_dim_max; i++ )
+                p_cast[ i + j*ldp ] = convert<T>(0.0);
+        }
+    }
+    else
+    {
+        for ( dim_t j = 0; j < panel_len; j++ )
+        {
+            auto kappa_d = kappa_cast * d_cast[ j*incd ];
+
+            for ( dim_t i = 0; i < panel_dim; i++ )
+                p_cast[ i + j*ldp ] = kappa_d * a_cast[ i*inca + j*lda ];
+
+            for ( dim_t i = panel_dim; i < panel_dim_max; i++ )
+                p_cast[ i + j*ldp ] = convert<T>(0.0);
+        }
+    }
+
+    for ( dim_t j = panel_len; j < panel_len_max; j++ )
+        for ( dim_t i = 0; i < panel_dim_max; i++ )
+            p_cast[ i + j*ldp ] = convert<T>(0.0);
+}
+
+void packm_diag
+    (
+      obj_t*     a,
+      obj_t*     p,
+      cntx_t*    cntx,
+      rntm_t*    rntm,
+      cntl_t*    cntl,
+      thrinfo_t* thread
+    )
+{
+    // We begin by copying the fields of A.
+    bli_obj_alias_to( a, p );
+
+    // Get information about data types.
+    num_t dt        = bli_obj_dt( a );
+    num_t dt_tar    = bli_obj_target_dt( a );
+    num_t dt_scalar = bli_obj_scalar_dt( a );
+    dim_t dt_size   = bli_dt_size( dt );
+
+    if ( dt_scalar != dt || dt_tar != dt )
+        bli_abort();
+
+    // Extract various fields from the control tree.
+    bszid_t bmult_id_m   = bli_cntl_packm_params_bmid_m( cntl );
+    bszid_t bmult_id_n   = bli_cntl_packm_params_bmid_n( cntl );
+    pack_t  schema       = bli_cntl_packm_params_pack_schema( cntl );
+    dim_t   bmult_m_def  = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_m, cntx );
+    dim_t   bmult_m_pack = bli_cntx_get_blksz_max_dt( dt_tar, bmult_id_m, cntx );
+    dim_t   bmult_n_def  = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_n, cntx );
+
+    if ( schema != BLIS_PACKED_ROW_PANELS &&
+         schema != BLIS_PACKED_COL_PANELS )
+        bli_abort();
+
+    // Store the pack schema to the object.
+    bli_obj_set_pack_schema( schema, p );
+
+    // Clear the conjugation field from the object since matrix packing
+    // in BLIS is deemed to take care of all conjugation necessary.
+    bli_obj_set_conj( BLIS_NO_CONJUGATE, p );
+
+    // If we are packing micropanels, mark P as dense.
+    bli_obj_set_uplo( BLIS_DENSE, p );
+
+    // Reset the view offsets to (0,0).
+    bli_obj_set_offs( 0, 0, p );
+
+    // Compute the dimensions padded by the dimension multiples. These
+    // dimensions will be the dimensions of the packed matrices, including
+    // zero-padding, and will be used by the macro- and micro-kernels.
+    // We compute them by starting with the effective dimensions of A (now
+    // in P) and aligning them to the dimension multiples (typically equal
+    // to register blocksizes). This does waste a little bit of space for
+    // level-2 operations, but that's okay with us.
+    dim_t m_p     = bli_obj_length( p );
+    dim_t n_p     = bli_obj_width( p );
+    dim_t m_p_pad = bli_align_dim_to_mult( m_p, bmult_m_def );
+    dim_t n_p_pad = bli_align_dim_to_mult( n_p, bmult_n_def );
+
+    // Save the padded dimensions into the packed object. It is important
+    // to save these dimensions since they represent the actual dimensions
+    // of the zero-padded matrix.
+    bli_obj_set_padded_dims( m_p_pad, n_p_pad, p );
+
+    // The "panel stride" of a micropanel packed object is interpreted as
+    // the distance between the (0,0) element of panel k and the (0,0)
+    // element of panel k+1. We use the padded width computed above to
+    // allow for zero-padding (if necessary/desired) along the far end
+    // of each micropanel (ie: the right edge of the matrix). Zero-padding
+    // can also occur along the long edge of the last micropanel if the m
+    // dimension of the matrix is not a whole multiple of MR.
+    inc_t ps_p = bmult_m_pack * n_p_pad;
+
+    /* Compute the total number of iterations we'll need. */
+    dim_t n_iter = m_p_pad / bmult_m_def;
+
+    // Store the strides and panel dimension in P.
+    bli_obj_set_strides( 1, bmult_m_pack, p );
+    bli_obj_set_imag_stride( 1, p );
+    bli_obj_set_panel_dim( bmult_m_def, p );
+    bli_obj_set_panel_stride( ps_p, p );
+    bli_obj_set_panel_length( bmult_m_def, p );
+    bli_obj_set_panel_width( n_p, p );
+
+    // Compute the size of the packed buffer.
+    siz_t size_p = ps_p * n_iter * dt_size;
+    if ( size_p == 0 ) return;
+
+    // Update the buffer address in p to point to the buffer associated
+    // with the mem_t entry acquired from the memory broker (now cached in
+    // the control tree node).
+    char* p_cast = (char*)bli_packm_alloc( size_p, rntm, cntl, thread );
+    bli_obj_set_buffer( p_cast, p );
+
+    char* a_cast        = (char*)bli_obj_buffer_at_off( a );
+    inc_t inca          = bli_obj_row_stride( a );
+    inc_t lda           = bli_obj_col_stride( a );
+    dim_t panel_len_off = bli_obj_col_off( a );
+    conj_t conja        = bli_obj_conj_status( a );
+
+    auto params  = (packm_diag_params_t*)bli_obj_pack_params( a );
+    char* d_cast = (char*)params->d;
+    inc_t incd   = params->incd;
+
+    obj_t kappa_local;
+    char* kappa_cast = (char*)bli_packm_scalar( &kappa_local, p );
+
+    auto packm_ker_cast = packm_diag_ukrs[ dt ];
+
+    /* Query the number of threads and thread ids from the current thread's
+       packm thrinfo_t node. */
+    const dim_t nt  = bli_thread_n_way( thread );
+    const dim_t tid = bli_thread_work_id( thread );
+
+    /* Determine the thread range and increment using the current thread's
+       packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir()
+       will depend on whether slab or round-robin partitioning was requested
+       at configure-time. */
+    dim_t it_start, it_end, it_inc;
+    bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc );
+
+    /* Iterate over every logical micropanel in the source matrix. */
+    for ( dim_t it = 0; it < n_iter; it += 1 )
+    {
+        dim_t panel_dim_i = bli_min( bmult_m_def, m_p - it*bmult_m_def );
+
+        char* d_begin = d_cast + panel_len_off*incd*dt_size;
+        char* a_begin = a_cast + it*bmult_m_def*inca*dt_size;
+        char* p_begin = p_cast + it*ps_p*dt_size;
+
+        if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) )
+        {
+            packm_ker_cast( conja,
+                            panel_dim_i,
+                            n_p,
+                            bmult_m_def,
+                            n_p_pad,
+                            kappa_cast,
+                            d_begin, incd,
+                            a_begin, inca, lda,
+                            p_begin, bmult_m_pack );
+        }
+    }
+}
+
+/*
+ * Modify the object A to include information about the diagonal D,
+ * and imbue it with special function pointers which will take care
+ * of the actual work of forming (D * A^T)
+ */
+void attach_diagonal_factor( packm_diag_params_t* params, obj_t* d, obj_t* a )
+{
+    // Assumes D is a column vector
+    new (params) packm_diag_params_t
+    (
+        bli_obj_buffer_at_off( d ),
+        bli_obj_row_stride( d )
+    );
+
+    // Set the custom pack function.
+    bli_obj_set_pack_fn( packm_diag, a );
+
+    // Attach the parameters to the A object.
+    bli_obj_set_pack_params( params, a );
+}
+
+/*
+ * Implements C := alpha * A * D * A^T + beta * C
+ *
+ * where D is a diagonal matrix with elements taken from the "d" vector.
+ */
+void syrk_diag( obj_t* alpha, obj_t* a, obj_t* d, obj_t* beta, obj_t* c )
+{
+    obj_t ad; // this is (D * A^T)
+    packm_diag_params_t params;
+
+    bli_obj_alias_to( a, &ad );
+    bli_obj_toggle_trans( &ad ); // because gemmt is A*B instead of A*B^T
+    attach_diagonal_factor( &params, d, &ad );
+
+    // Does C := alpha * A * B + beta * C using B = (D * A^T)
+    bli_gemmt( alpha, a, &ad, beta, c );
+}
+
+int main()
+{
+    obj_t a;
+    obj_t d;
+    obj_t c;
+    obj_t c_copy;
+    obj_t norm;
+
+    auto m = 10;
+    auto k = 10;
+
+    for ( int dt_ = BLIS_DT_LO; dt_ <= BLIS_DT_HI; dt_++ )
+    for ( int upper = 0; upper <= 1; upper++ )
+    for ( int transa = 0; transa <= 1; transa++ )
+    for ( int transc = 0; transc <= 1; transc++ )
+    {
+        auto dt = ( num_t )dt_;
+        auto uplo = upper ? BLIS_UPPER : BLIS_LOWER;
+
+        bli_obj_create( dt, m, k, transa ? k : 1, transa ? 1 : m, &a );
+        bli_obj_create( dt, k, 1, 1, 1, &d );
+        bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c );
+        bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c_copy );
+        bli_obj_set_struc( BLIS_SYMMETRIC, &c );
+        bli_obj_set_struc( BLIS_SYMMETRIC, &c_copy );
+        bli_obj_set_uplo( uplo, &c );
+        bli_obj_set_uplo( uplo, &c_copy );
+        bli_obj_create_1x1( bli_dt_proj_to_real( dt ), &norm );
+
+        bli_randm( &a );
+        bli_randm( &d );
+        bli_randm( &c );
+        bli_copym( &c, &c_copy );
+
+        syrk_diag( &BLIS_ONE, &a, &d, &BLIS_ONE, &c );
+        syrk_diag_ref( &BLIS_ONE, &a, &d, &BLIS_ONE, &c_copy );
+
+        bli_subm( &c_copy, &c );
+        bli_normfm( &c, &norm );
+
+        double normr, normi;
+        bli_getsc( &norm, &normr, &normi );
+
+        printf("dt: %d, upper: %d, transa: %d, transc: %d, norm: %g\n",
+               dt, upper, transa, transc, normr);
+
+        bli_obj_free( &a );
+        bli_obj_free( &d );
+        bli_obj_free( &c );
+        bli_obj_free( &c_copy );
+        bli_obj_free( &norm );
+    }
+}
diff --git a/test/syrk_diagonal/syrk_diagonal_ref.cxx b/test/syrk_diagonal/syrk_diagonal_ref.cxx
new file mode 100644
index 0000000000..1d7c5d96e5
--- /dev/null
+++ b/test/syrk_diagonal/syrk_diagonal_ref.cxx
@@ -0,0 +1,102 @@
+#include "syrk_diagonal_ref.h"
+#include "complex_math.hpp"
+
+typedef void (*syrk_diag_ref_vft)
+    (
+      uplo_t uplo,
+      dim_t  m,
+      dim_t  k,
+      void*  alpha,
+      void*  a, inc_t rs_a, inc_t cs_a,
+      void*  d, inc_t incd,
+      void*  beta,
+      void*  c, inc_t rs_c, inc_t cs_c
+    );
+
+template <typename T>
+void syrk_diag_ref
+    (
+      uplo_t uplo,
+      dim_t  m,
+      dim_t  k,
+      void*  alpha,
+      void*  a, inc_t rs_a, inc_t cs_a,
+      void*  d, inc_t incd,
+      void*  beta,
+      void*  c, inc_t rs_c, inc_t cs_c
+    )
+{
+    auto alpha_cast = *( T* )alpha;
+    auto beta_cast  = *( T* )beta;
+    auto a_cast     = ( T* )a;
+    auto d_cast     = ( T* )d;
+    auto c_cast     = ( T* )c;
+
+    for ( dim_t i = 0; i < m; i++ )
+    {
+        dim_t j_min = uplo == BLIS_UPPER ? i : 0;
+        dim_t j_max = uplo == BLIS_UPPER ? m : i+1;
+
+        for ( dim_t j = j_min; j < j_max; j++ )
+        {
+            auto ada = convert<T>(0.0);
+
+            for ( dim_t p = 0; p < k; p++ )
+            {
+                ada += a_cast[ i*rs_a + p*cs_a ] *
+                       d_cast[ p*incd ] *
+                       a_cast[ j*rs_a + p*cs_a ];
+            }
+
+            if ( beta_cast == convert<T>(0.0) )
+            {
+                c_cast[ i*rs_c + j*cs_c ] = alpha_cast * ada;
+            }
+            else
+            {
+                c_cast[ i*rs_c + j*cs_c ] = alpha_cast * ada +
+                                            beta_cast * c_cast[ i*rs_c + j*cs_c ];
+            }
+        }
+    }
+}
+
+#undef GENTFUNC
+#define GENTFUNC(ctype,ch,op) \
+static auto PASTEMAC(ch,op) = &syrk_diag_ref<ctype>;
+
+INSERT_GENTFUNC_BASIC0(syrk_diag_ref);
+
+static syrk_diag_ref_vft GENARRAY( syrk_diag_ref_impl, syrk_diag_ref );
+
+void syrk_diag_ref( obj_t* alpha, obj_t* a, obj_t* d, obj_t* beta, obj_t* c )
+{
+    num_t dt = bli_obj_dt( a );
+
+    dim_t m = bli_obj_length_after_trans( a );
+    dim_t k = bli_obj_width_after_trans( a );
+
+    inc_t rs_a = bli_obj_row_stride( a );
+    inc_t cs_a = bli_obj_col_stride( a );
+    inc_t rs_c = bli_obj_row_stride( c );
+    inc_t cs_c = bli_obj_col_stride( c );
+    inc_t incd = bli_obj_row_stride( d );
+
+    if ( bli_obj_has_trans( a ) )
+        bli_swap_incs( &rs_a, &cs_a );
+
+    if ( bli_obj_has_trans( c ) )
+        bli_swap_incs( &rs_c, &cs_c );
+
+    syrk_diag_ref_impl[ dt ]
+    (
+      bli_obj_uplo( c ),
+      m, k,
+      bli_obj_buffer_for_1x1( dt, alpha ),
+      bli_obj_buffer_at_off( a ), rs_a, cs_a,
+      bli_obj_buffer_at_off( d ), incd,
+      bli_obj_buffer_for_1x1( dt, beta ),
+      bli_obj_buffer_at_off( c ), rs_c, cs_c
+    );
+}
diff --git a/test/syrk_diagonal/syrk_diagonal_ref.h b/test/syrk_diagonal/syrk_diagonal_ref.h
new file mode 100644
index 0000000000..a6864caec8
--- /dev/null
+++ b/test/syrk_diagonal/syrk_diagonal_ref.h
@@ -0,0 +1,8 @@
+#include "blis.h"
+
+#ifdef __cplusplus
+#include "complex_math.hpp"
+extern "C"
+#endif
+void syrk_diag_ref( obj_t* alpha, obj_t* a, obj_t* d, obj_t* beta, obj_t* c );
+
diff --git a/test/tensor_contraction/complex_math.hpp b/test/tensor_contraction/complex_math.hpp
new file mode 100644
index 0000000000..9c68e730aa
--- /dev/null
+++ b/test/tensor_contraction/complex_math.hpp
@@ -0,0 +1,267 @@
+#include <algorithm>
+#include <cmath>
+#include <type_traits>
+
+#include "blis.h"
+
+template <typename T>
+struct is_complex : std::false_type {};
+
+template <>
+struct is_complex<scomplex> : std::true_type {};
+
+template <>
+struct is_complex<dcomplex> : std::true_type {};
+
+template <typename T>
+struct is_real : std::integral_constant<bool, !is_complex<T>::value> {};
+
+template <typename T> struct make_complex;
+
+template <> struct make_complex<float>    { using type = scomplex; };
+template <> struct make_complex<double>   { using type = dcomplex; };
+template <> struct make_complex<scomplex> { using type = scomplex; };
+template <> struct make_complex<dcomplex> { using type = dcomplex; };
+
+template <typename T>
+using make_complex_t = typename make_complex<T>::type;
+
+template <typename T> struct make_real;
+
+template <> struct make_real<float>    { using type = float; };
+template <> struct make_real<double>   { using type = double; };
+template <> struct make_real<scomplex> { using type = float; };
+template <> struct make_real<dcomplex> { using type = double; };
+
+template <typename T>
+using make_real_t = typename make_real<T>::type;
+
+template <typename T, bool Complex>
+struct make_complex_if : std::conditional<Complex, make_complex_t<T>, make_real_t<T>> {};
+
+template <typename T, bool Complex>
+using make_complex_if_t = typename make_complex_if<T,Complex>::type;
+
+template <typename T>
+struct real_imag_part
+{
+    real_imag_part& operator=(T) { return *this; }
+
+    operator T() const { return T(); }
+};
+
+template <typename T>
+std::enable_if_t<is_real<typename std::decay<T>::type>::value,T&> real(T& x) { return x; }
+
+template <typename T>
+std::enable_if_t<is_real<T>::value,real_imag_part<T>> imag(T x) { return {}; }
+
+inline float& real(scomplex& x) { return x.real; }
+
+inline float& imag(scomplex& x) { return x.imag; }
+
+inline double& real(dcomplex& x) { return x.real; }
+
+inline double& imag(dcomplex& x) { return x.imag; }
+
+inline const float& real(const scomplex& x) { return x.real; }
+
+inline const float& imag(const scomplex& x) { return x.imag; }
+
+inline const double& real(const dcomplex& x) { return x.real; }
+
+inline const double& imag(const dcomplex& x) { return x.imag; }
+
+template <typename T>
+std::enable_if_t<is_real<T>::value,T> conj(T x) { return x; }
+
+template <typename T>
+std::enable_if_t<is_complex<T>::value,T> conj(const T& x) { return {x.real, -x.imag}; }
+
+template <typename T, typename U, typename Enable=void>
+struct convert_impl;
+
+template <typename T, typename U>
+struct convert_impl<T, U, std::enable_if_t<is_real<T>::value && is_real<U>::value>>
+{
+    void operator()(T x, U& y) const { y = x; }
+};
+
+template <typename T, typename U>
+struct convert_impl<T, U, std::enable_if_t<is_real<T>::value && is_complex<U>::value>>
+{
+    void operator()(T x, U& y) const { y.real = x; y.imag = 0; }
+};
+
+template <typename T, typename U>
+struct convert_impl<T, U, std::enable_if_t<is_complex<T>::value && is_real<U>::value>>
+{
+    void operator()(T x, U& y) const { y = x.real; }
+};
+
+template <typename T, typename U>
+struct convert_impl<T, U, std::enable_if_t<is_complex<T>::value && is_complex<U>::value>>
+{
+    void operator()(T x, U& y) const { y.real = x.real; y.imag = x.imag; }
+};
+
+template <typename U, typename T>
+U convert(T x)
+{
+    U y;
+    convert_impl<T,U>{}(x,y);
+    return y;
+}
+
+template <typename U, typename T>
+auto convert_prec(T x) -> make_complex_if_t<U,is_complex<T>::value>
+{
+    return convert<make_complex_if_t<U,is_complex<T>::value>>(x);
+}
+
+#define COMPLEX_MATH_OPS(rtype, ctype) \
+\
+inline bool operator==(rtype x, ctype y) \
+{ \
+    return x == y.real && y.imag == 0; \
+} \
+\
+inline bool operator==(ctype x, rtype y) \
+{ \
+    return y == x.real && x.imag == 0; \
+} \
+\
+inline bool operator==(ctype x, ctype y) \
+{ \
+    return x.real == y.real && \
+           x.imag == y.imag; \
+} \
+\
+inline ctype operator-(ctype x) \
+{ \
+    return {-x.real, -x.imag}; \
+} \
+\
+inline ctype operator+(rtype x, ctype y) \
+{ \
+    return {x+y.real, y.imag}; \
+} \
+\
+inline ctype operator+(ctype x, rtype y) \
+{ \
+    return {y+x.real, x.imag}; \
+} \
+\
+inline ctype operator+(ctype x, ctype y) \
+{ \
+    return {x.real+y.real, x.imag+y.imag}; \
+} \
+\
+inline ctype operator-(rtype x, ctype y) \
+{ \
+    return {x-y.real, -y.imag}; \
+} \
+\
+inline ctype operator-(ctype x, rtype y) \
+{ \
+    return {x.real-y, x.imag}; \
+} \
+\
+inline ctype operator-(ctype x, ctype y) \
+{ \
+    return {x.real-y.real, x.imag-y.imag}; \
+} \
+\
+inline ctype operator*(rtype x, ctype y) \
+{ \
+    return {x*y.real, x*y.imag}; \
+} \
+\
+inline ctype operator*(ctype x, rtype y) \
+{ \
+    return {y*x.real, y*x.imag}; \
+} \
+\
+inline ctype operator*(ctype x, ctype y) \
+{ \
+    return {x.real*y.real - x.imag*y.imag, \
+            x.real*y.imag + x.imag*y.real}; \
+} \
+\
+inline ctype operator/(rtype x, ctype y) \
+{ \
+    auto scale = std::max(std::abs(y.real), std::abs(y.imag)); \
+    auto n = std::ilogb(scale); \
+    auto yrs = std::scalbn(y.real, -n); \
+    auto yis = std::scalbn(y.imag, -n); \
+    auto denom = y.real*yrs + y.imag*yis; \
+    return {x*yrs/denom, -x*yis/denom}; \
+} \
+\
+inline ctype operator/(ctype x, rtype y) \
+{ \
+    return {x.real/y, x.imag/y}; \
+} \
+\
+inline ctype operator/(ctype x, ctype y) \
+{ \
+    auto scale = std::max(std::abs(y.real), std::abs(y.imag)); \
+    auto n = std::ilogb(scale); \
+    auto yrs = std::scalbn(y.real, -n); \
+    auto yis = std::scalbn(y.imag, -n); \
+    auto denom = y.real*yrs + y.imag*yis; \
+    return {(x.real*yrs + x.imag*yis)/denom, \
+            (x.imag*yrs - x.real*yis)/denom}; \
+} \
+\
+inline ctype& operator+=(ctype& x, rtype y) \
+{ \
+    x.real += y; \
+    return x; \
+} \
+\
+inline ctype& operator+=(ctype& x, ctype y) \
+{ \
+    x.real += y.real; x.imag += y.imag; \
+    return x; \
+} \
+\
+inline ctype& operator-=(ctype& x, rtype y) \
+{ \
+    x.real -= y; \
+    return x; \
+} \
+\
+inline ctype& operator-=(ctype& x, ctype y) \
+{ \
+    x.real -= y.real; x.imag -= y.imag; \
+    return x; \
+} \
+\
+inline ctype& operator*=(ctype& x, rtype y) \
+{ \
+    x.real *= y; x.imag *= y; \
+    return x; \
+} \
+\
+inline ctype& operator*=(ctype& x, ctype y) \
+{ \
+    x = x * y; \
+    return x; \
+} \
+\
+inline ctype& operator/=(ctype& x, rtype y) \
+{ \
+    x.real /= y; x.imag /= y; \
+    return x; \
+} \
+\
+inline ctype& operator/=(ctype& x, ctype y) \
+{ \
+    x = x / y; \
+    return x; \
+}
+
+COMPLEX_MATH_OPS(float, scomplex);
+COMPLEX_MATH_OPS(double, dcomplex);
+
diff --git a/test/tensor_contraction/tcontract_example.cxx b/test/tensor_contraction/tcontract_example.cxx
new file mode 100644
index 0000000000..0b935c54d4
--- /dev/null
+++ b/test/tensor_contraction/tcontract_example.cxx
@@ -0,0 +1,988 @@
+
+#include "tcontract_ref.hpp"
+
+#include <algorithm>
+#include <numeric>
+
+static constexpr dim_t BS_K = 8;
+
+struct packm_tensor_params_t
+{
+    gint_t ndim_m, ndim_n;
+    const dim_t *len_m, *len_n;
+    const inc_t *stride_m, *stride_n;
+
+    packm_tensor_params_t() {}
+
+    packm_tensor_params_t( gint_t ndim_m, const dim_t* len_m, const inc_t* stride_m,
+                           gint_t ndim_n, const dim_t* len_n, const inc_t* stride_n )
+    : ndim_m(ndim_m), ndim_n(ndim_n),
+      len_m(len_m), len_n(len_n),
+      stride_m(stride_m), stride_n(stride_n) {}
+};
+
+using gemm_tensor_params_t = packm_tensor_params_t;
+
+template <typename T>
+void packm_ckx_nb
+    (
+      bool  conja,
+      dim_t panel_dim,
+      dim_t panel_len,
+      dim_t panel_dim_max,
+      dim_t panel_len_max,
+      void* kappa,
+      void* a, inc_t inca, inc_t* bsa, inc_t* scata,
+      void* p, inc_t ldp
+    )
+{
+    T* restrict a_cast = ( T* )a;
+    T* restrict p_cast = ( T* )p;
+    auto kappa_cast = *( T* )kappa;
+
+    if ( conja )
+    {
+        for ( auto j0 = 0; j0 < panel_len; j0 += BS_K, bsa += BS_K, scata += BS_K )
+        {
+            auto lda = *bsa;
+            auto panel_len_j = std::min( panel_len-j0, BS_K );
+
+            if ( lda )
+            {
+                T* restrict aj = a_cast + *scata;
+
+                for ( auto j = 0; j < panel_len_j; j++ )
+                {
+                    for ( auto i = 0; i < panel_dim; i++ )
+                        p_cast[ i ] = kappa_cast * conj( aj[ i*inca + j*lda ] );
+
+                    for ( auto i = panel_dim; i < panel_dim_max; i++ )
+                        p_cast[ i ] = convert<T>(0.0);
+
+                    p_cast += ldp;
+                }
+            }
+            else
+            {
+                for ( auto j = 0; j < panel_len_j; j++ )
+                {
+                    for ( auto i = 0; i < panel_dim; i++ )
+                        p_cast[ i ] = kappa_cast * conj( a_cast[ i*inca + scata[j] ] );
+
+                    for ( auto i = panel_dim; i < panel_dim_max; i++ )
+                        p_cast[ i ] = convert<T>(0.0);
+
+                    p_cast += ldp;
+                }
+            }
+        }
+    }
+    else
+    {
+        for ( auto j0 = 0; j0 < panel_len; j0 += BS_K, bsa += BS_K, scata += BS_K )
+        {
+            auto lda = *bsa;
+            auto panel_len_j = std::min( panel_len-j0, BS_K );
+
+            if ( lda )
+            {
+                T* restrict aj = a_cast + *scata;
+
+                for ( auto j = 0; j < panel_len_j; j++ )
+                {
+                    for ( auto i = 0; i < panel_dim; i++ )
+                        p_cast[ i ] = kappa_cast * aj[ i*inca + j*lda ];
+
+                    for ( auto i = panel_dim; i < panel_dim_max; i++ )
+                        p_cast[ i ] = convert<T>(0.0);
+
+                    p_cast += ldp;
+                }
+            }
+            else
+            {
+                for ( auto j = 0; j < panel_len_j; j++ )
+                {
+                    for ( auto i = 0; i < panel_dim; i++ )
+                        p_cast[ i ] = kappa_cast * a_cast[ i*inca + scata[j] ];
+
+                    for ( auto i = panel_dim; i < panel_dim_max; i++ )
+                        p_cast[ i ] = convert<T>(0.0);
+
+                    p_cast += ldp;
+                }
+            }
+        }
+    }
+
+    for ( auto j = panel_len; j < panel_len_max; j++ )
+    {
+        for ( auto i = 0; i < panel_dim_max; i++ )
+            p_cast[ i ] = convert<T>(0.0);
+
+        p_cast += ldp;
+    }
+}
+
+template <typename T>
+void packm_ckx_ss
+    (
+      bool  conja,
+      dim_t panel_dim,
+      dim_t panel_len,
+      dim_t panel_dim_max,
+      dim_t panel_len_max,
+      void* kappa,
+      void* a, inc_t* inca, inc_t* scata,
+      void* p, inc_t ldp
+    )
+{
+    T* restrict a_cast = ( T* )a;
+    T* restrict p_cast = ( T* )p;
+    auto kappa_cast = *( T* )kappa;
+
+    if ( conja )
+    {
+        for ( dim_t j = 0; j < panel_len; j++ )
+        {
+            for ( dim_t i = 0; i < panel_dim; i++ )
+                p_cast[ i ] = kappa_cast * conj( a_cast[ inca[i] + scata[j] ] );
+
+            for ( dim_t i = panel_dim; i < panel_dim_max; i++ )
+                p_cast[ i ] = convert<T>(0.0);
+
+            p_cast += ldp;
+        }
+    }
+    else
+    {
+        for ( dim_t j = 0; j < panel_len; j++ )
+        {
+            for ( dim_t i = 0; i < panel_dim; i++ )
+                p_cast[ i ] = kappa_cast * a_cast[ inca[i] + scata[j] ];
+
+            for ( dim_t i = panel_dim; i < panel_dim_max; i++ )
+                p_cast[ i ] = convert<T>(0.0);
+
+            p_cast += ldp;
+        }
+    }
+
+    for ( dim_t j = panel_len; j < panel_len_max; j++ )
+    {
+        for ( dim_t i = 0; i < panel_dim_max; i++ )
+            p_cast[ i ] = convert<T>(0.0);
+
+        p_cast += ldp;
+    }
+}
+
+#undef GENTFUNC
+#define GENTFUNC(ctype,ch,op) \
+static auto PASTEMAC(ch,op) = &packm_ckx_nb<ctype>;
+
+INSERT_GENTFUNC_BASIC0(packm_ckx_nb);
+
+#undef GENTFUNC
+#define GENTFUNC(ctype,ch,op) \
+static auto PASTEMAC(ch,op) = &packm_ckx_ss<ctype>;
+
+INSERT_GENTFUNC_BASIC0(packm_ckx_ss);
+
+static decltype(&packm_ckx_nb<float>) GENARRAY( packm_ckx_nb_ukrs, packm_ckx_nb );
+static decltype(&packm_ckx_ss<float>) GENARRAY( packm_ckx_ss_ukrs, packm_ckx_ss );
+
+static void fill_scatter
+    (
+      gint_t                ndim,
+      const dim_t* restrict len,
+      const inc_t* restrict stride,
+      dim_t                 BS,
+      inc_t                 off,
+      dim_t                 size,
+      inc_t* restrict       scat,
+      inc_t* restrict       bs
+    )
+{
+    if ( size == 0 ) return;
+
+    if ( ndim == 0 )
+    {
+        *scat = 0;
+        *bs = 0;
+        return;
+    }
+
+    if ( ndim == 1 )
+    {
+        auto l = *len;
+        auto s = *stride;
+        for ( auto i = 0; i < l; i++ )
+        {
+            scat[i] = i*s;
+            bs[i] = s;
+        }
+    }
+
+    dim_t tot_len = 1;
+    for ( auto i = 0; i < ndim; i++ )
+        tot_len *= len[i];
+
+    assert(off >= 0);
+    assert(size >= 0);
+    assert(off+size <= tot_len);
+
+    auto len0 = len[0];
+    auto stride0 = stride[0];
+    auto off0 = off % len0;
+    auto off1 = off / len0;
+    auto size1 = ( size + off0 + len0 - 1 ) / len0;
+
+    inc_t pos1 = 0;
+    inc_t idx = 0;
+    for_each( ndim-1, len+1, off1, size1, pos1, stride+1,
+    [&]
+    {
+        auto pos = pos1 + off0 * stride0;
+        auto len_i = std::min( len0-off0, size-idx );
+        for ( auto i = 0; i < len_i; i++ )
+        {
+            scat[idx++] = pos;
+            pos += stride0;
+        }
+        off0 = 0;
+    });
+    assert(idx == size);
+
+    for ( idx = 0; idx < size; idx += BS )
+    {
+        auto len_i = std::min( BS, size-idx );
+        auto s = stride0;
+
+        for ( auto i = idx; i < idx+len_i-1; i++ )
+        {
+            if (scat[i+1]-scat[i] != s)
+            {
+                s = 0;
+                break;
+            }
+        }
+
+        bs[idx] = s;
+    }
+}
+
+void packm_tensor
+    (
+      obj_t*     a,
+      obj_t*     p,
+      cntx_t*    cntx,
+      rntm_t*    rntm,
+      cntl_t*    cntl,
+      thrinfo_t* thread
+    )
+{
+    // We begin by copying the fields of A.
+    bli_obj_alias_to( a, p );
+
+    // Get information about data types.
+    auto dt        = bli_obj_dt( a );
+    auto dt_tar    = bli_obj_target_dt( a );
+    auto dt_scalar = bli_obj_scalar_dt( a );
+    auto dt_size   = bli_dt_size( dt );
+
+    if ( dt_scalar != dt || dt_tar != dt )
+        bli_abort();
+
+    // Extract various fields from the control tree.
+    auto bmult_id_m   = bli_cntl_packm_params_bmid_m( cntl );
+    auto bmult_id_n   = bli_cntl_packm_params_bmid_n( cntl );
+    auto schema       = bli_cntl_packm_params_pack_schema( cntl );
+    auto bmult_m_def  = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_m, cntx );
+    auto bmult_m_pack = bli_cntx_get_blksz_max_dt( dt_tar, bmult_id_m, cntx );
+    auto bmult_n_def  = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_n, cntx );
+
+    if ( schema != BLIS_PACKED_ROW_PANELS &&
+         schema != BLIS_PACKED_COL_PANELS )
+        bli_abort();
+
+    // Store the pack schema to the object.
+    bli_obj_set_pack_schema( schema, p );
+
+    // Clear the conjugation field from the object since matrix packing
+    // in BLIS is deemed to take care of all conjugation necessary.
+    bli_obj_set_conj( BLIS_NO_CONJUGATE, p );
+
+    // If we are packing micropanels, mark P as dense.
+    bli_obj_set_uplo( BLIS_DENSE, p );
+
+    // Reset the view offsets to (0,0).
+    bli_obj_set_offs( 0, 0, p );
+
+    // Compute the dimensions padded by the dimension multiples. These
+    // dimensions will be the dimensions of the packed matrices, including
+    // zero-padding, and will be used by the macro- and micro-kernels.
+    // We compute them by starting with the effective dimensions of A (now
+    // in P) and aligning them to the dimension multiples (typically equal
+    // to register blocksizes). This does waste a little bit of space for
+    // level-2 operations, but that's okay with us.
+    auto m_p     = bli_obj_length( p );
+    auto n_p     = bli_obj_width( p );
+    auto m_p_pad = bli_align_dim_to_mult( m_p, bmult_m_def );
+    auto n_p_pad = bli_align_dim_to_mult( n_p, bmult_n_def );
+
+    // Save the padded dimensions into the packed object. It is important
+    // to save these dimensions since they represent the actual dimensions
+    // of the zero-padded matrix.
+    bli_obj_set_padded_dims( m_p_pad, n_p_pad, p );
+
+    // The "panel stride" of a micropanel packed object is interpreted as
+    // the distance between the (0,0) element of panel k and the (0,0)
+    // element of panel k+1. We use the padded width computed above to
+    // allow for zero-padding (if necessary/desired) along the far end
+    // of each micropanel (ie: the right edge of the matrix). Zero-padding
+    // can also occur along the long edge of the last micropanel if the m
+    // dimension of the matrix is not a whole multiple of MR.
+    auto ps_p = bmult_m_pack * n_p_pad;
+
+    /* Compute the total number of iterations we'll need. */
+    auto n_iter = m_p_pad / bmult_m_def;
+
+    // Store the strides and panel dimension in P.
+    bli_obj_set_strides( 1, bmult_m_pack, p );
+    bli_obj_set_imag_stride( 1, p );
+    bli_obj_set_panel_dim( bmult_m_def, p );
+    bli_obj_set_panel_stride( ps_p, p );
+    bli_obj_set_panel_length( bmult_m_def, p );
+    bli_obj_set_panel_width( n_p, p );
+
+    // Compute the size of the packed buffer.
+    auto size_p = ps_p * n_iter * dt_size;
+    if ( size_p == 0 ) return;
+
+    // Compute the size of the scatter and block-scatter vectors and add
+    // it to the total.
+    // It is never necessary to add padding for alignment because:
+    // 1) ps_p is always even
+    // 2) dt_size is a power of two >= 4
+    // 3) the alignment of the scatter vectors is at most 8
+    auto scat_size = 2 * (m_p + n_p) * sizeof(inc_t);
+
+    // Update the buffer address in p to point to the buffer associated
+    // with the mem_t entry acquired from the memory broker (now cached in
+    // the control tree node).
+    auto p_cast = (char*)bli_packm_alloc( size_p + scat_size, rntm, cntl, thread );
+    bli_obj_set_buffer( p_cast, p );
+
+    // Get the addresses of the scatter and block-scatter vectors. These are
+    // placed directly after the packed matrix buffer.
+    auto rscat = (inc_t*)(p_cast + size_p);
+    auto rbs   = rscat + m_p;
+    auto cscat = rbs + m_p;
+    auto cbs   = cscat + n_p;
+
+    auto a_cast        = (char*)bli_obj_buffer_at_off( a );
+    auto panel_dim_off = bli_obj_row_off( a );
+    auto panel_len_off = bli_obj_col_off( a );
+    auto conja         = bli_obj_conj_status( a );
+
+    auto params   = (packm_tensor_params_t*)bli_obj_pack_params( a );
+    auto ndim_m   = params->ndim_m;
+    auto ndim_n   = params->ndim_n;
+    auto len_m    = params->len_m;
+    auto len_n    = params->len_n;
+    auto stride_m = params->stride_m;
+    auto stride_n = params->stride_n;
+
+    obj_t kappa_local;
+    auto kappa_cast = (char*)bli_packm_scalar( &kappa_local, p );
+
+    auto packm_nb_ker = packm_ckx_nb_ukrs[ dt ];
+    auto packm_ss_ker = packm_ckx_ss_ukrs[ dt ];
+
+    a_cast -= ( panel_dim_off * stride_m[0] +
+                panel_len_off * stride_n[0] ) * dt_size;
+
+    /* Fill in the scatter and block-scatter vectors. This is done
+       single-threaded for now. */
+    if ( bli_thread_am_ochief( thread ) )
+    {
+        fill_scatter
+        (
+          ndim_m,
+          len_m,
+          stride_m,
+          bmult_m_def,
+          panel_dim_off,
+          m_p,
+          rscat,
+          rbs
+        );
+
+        fill_scatter
+        (
+          ndim_n,
+          len_n,
+          stride_n,
+          BS_K,
+          panel_len_off,
+          n_p,
+          cscat,
+          cbs
+        );
+    }
+
+    /* Wait for the scatter vectors to be done. */
+    bli_thread_barrier( thread );
+
+    /* Query the number of threads and thread ids from the current thread's
+       packm thrinfo_t node. */
+    auto nt  = bli_thread_n_way( thread );
+    auto tid = bli_thread_work_id( thread );
+
+    /* Determine the thread range and increment using the current thread's
+       packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir()
+       will depend on whether slab or round-robin partitioning was requested
+       at configure-time. */
+    dim_t it_start, it_end, it_inc;
+    bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc );
+
+    /* Iterate over every logical micropanel in the source matrix. */
+    for ( auto it = 0; it < n_iter; it += 1 )
+    if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) )
+    {
+        auto panel_dim_i = bli_min( bmult_m_def, m_p - it*bmult_m_def );
+
+        auto p_begin = p_cast + it*ps_p*dt_size;
+        auto inca    = rbs[ it*bmult_m_def ];
+
+        if ( inca )
+        {
+            auto a_begin = a_cast + rscat[ it*bmult_m_def ]*dt_size;
+
+            packm_nb_ker( conja,
+                          panel_dim_i,
+                          n_p,
+                          bmult_m_def,
+                          n_p_pad,
+                          kappa_cast,
+                          a_begin, inca, cbs, cscat,
+                          p_begin, bmult_m_pack );
+        }
+        else
+        {
+            auto a_begin   = a_cast;
+            auto rscat_use = rscat + it*bmult_m_def;
+
+            packm_ss_ker( conja,
+                          panel_dim_i,
+                          n_p,
+                          bmult_m_def,
+                          n_p_pad,
+                          kappa_cast,
+                          a_begin, rscat_use, cscat,
+                          p_begin, bmult_m_pack );
+        }
+    }
+}
+
+#undef GENTFUNC
+#define GENTFUNC(ctype,ch,op) \
+void PASTEMAC(ch,op) \
+    ( \
+      dim_t m, \
+      dim_t n, \
+      void* x, inc_t rs_x, inc_t cs_x, \
+      void* b, \
+      void* y, inc_t* rs_y, inc_t* cs_y \
+    ) \
+{ \
+    ctype* restrict x_cast = (ctype*)x; \
+    ctype           b_cast = *(ctype*)b; \
+    ctype* restrict y_cast = (ctype*)y; \
+\
+    if ( PASTEMAC(ch,eq0)( b_cast ) ) \
+    { \
+        for ( auto i = 0; i < m; i++ ) \
+        for ( auto j = 0; j < n; j++ ) \
+            PASTEMAC(ch,copys)( x_cast[ i*rs_x + j*cs_x ], y_cast[ rs_y[i] + cs_y[j] ] ); \
+    } \
+    else \
+    { \
+        for ( auto i = 0; i < m; i++ ) \
+        for ( auto j = 0; j < n; j++ ) \
+            PASTEMAC(ch,xpbys)( x_cast[ i*rs_x + j*cs_x ], b_cast, y_cast[ rs_y[i] + cs_y[j] ] ); \
+    } \
+}
+
+INSERT_GENTFUNC_BASIC0(scatter_mxn);
+
+static decltype(&bli_sscatter_mxn) GENARRAY(scatter_mxn, scatter_mxn);
+
+void gemm_tensor
+    (
+      obj_t*     a,
+      obj_t*     b,
+      obj_t*     c,
+      cntx_t*    cntx,
+      rntm_t*    rntm,
+      cntl_t*    cntl,
+      thrinfo_t* thread
+    )
+{
+    auto dt      = bli_obj_dt( c );
+    auto dt_size = bli_dt_size( dt );
+
+    auto m = bli_obj_length( c );
+    auto n = bli_obj_width( c );
+    auto k = bli_obj_width( a );
+
+    auto a_cast = (char*)bli_obj_buffer_at_off( a );
+    auto pd_a   = bli_obj_panel_dim( a );
+    auto ps_a   = bli_obj_panel_stride( a );
+
+    auto b_cast = (char*)bli_obj_buffer_at_off( b );
+    auto pd_b   = bli_obj_panel_dim( b );
+    auto ps_b   = bli_obj_panel_stride( b );
+
+    auto c_cast = (char*)bli_obj_buffer_at_off( c );
+    auto rs_c0  = bli_obj_row_stride( c );
+    auto cs_c0  = bli_obj_col_stride( c );
+    auto off_m  = bli_obj_row_off( c );
+    auto off_n  = bli_obj_col_off( c );
+
+    auto params   = (gemm_tensor_params_t*)bli_obj_ker_params( c );
+    auto ndim_m   = params->ndim_m;
+    auto ndim_n   = params->ndim_n;
+    auto len_m    = params->len_m;
+    auto len_n    = params->len_n;
+    auto stride_m = params->stride_m;
+    auto stride_n = params->stride_n;
+
+    if ( rs_c0 != stride_m[0] || cs_c0 != stride_n[0] )
+    {
+        std::swap( ndim_m, ndim_n );
+        std::swap( len_m, len_n );
+        std::swap( stride_m, stride_n );
+    }
+
+    /* If any dimension is zero, return immediately. */
+    if ( bli_zero_dim3( m, n, k ) ) return;
+
+    c_cast -= ( off_m * stride_m[0] +
+                off_n * stride_n[0] ) * dt_size;
+
+    // Detach and multiply the scalars attached to A and B.
+    // NOTE: We know that the internal scalars of A and B are already of the
+    // target datatypes because the necessary typecasting would have already
+    // taken place during bli_packm_init().
+    obj_t scalar_a;
+    obj_t scalar_b;
+    bli_obj_scalar_detach( a, &scalar_a );
+    bli_obj_scalar_detach( b, &scalar_b );
+    bli_mulsc( &scalar_a, &scalar_b );
+
+    // Grab the addresses of the internal scalar buffers for the scalar
+    // merged above and the scalar attached to C.
+    // NOTE: We know that scalar_b is of type dt due to the above code
+    // that casts the scalars of A and B to dt via scalar_a and scalar_b,
+    // and we know that the internal scalar in C is already of the type dt
+    // due to the casting in the implementation of bli_obj_scalar_attach().
+    auto alpha_cast = (char*)bli_obj_internal_scalar_buffer( &scalar_b );
+    auto beta_cast  = (char*)bli_obj_internal_scalar_buffer( c );
+
+    /* Alias some constants to simpler names. */
+    auto MR = pd_a;
+    auto NR = pd_b;
+
+    /* Query the context for the micro-kernel address and cast it to its
+       function pointer type. */
+    auto gemm_ukr = (gemm_ukr_vft)bli_cntx_get_l3_nat_ukr_dt( dt, BLIS_GEMM_UKR, cntx );
+
+    /* Temporary C buffer for edge cases. Note that the strides of this
+       temporary buffer are set so that they match the storage of the
+       original C matrix. For example, if C is column-stored, ct will be
+       column-stored as well. */
+    char ct[ BLIS_STACK_BUF_MAX_SIZE ] __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE)));
+    auto col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx );
+    auto rs_ct = ( col_pref ? 1 : NR );
+    auto cs_ct = ( col_pref ? MR : 1 );
+    auto zero = (char*)bli_obj_buffer_for_const( dt, &BLIS_ZERO );
+
+    /*
+       Assumptions/assertions:
+         rs_a == 1
+         cs_a == PACKMR
+         pd_a == MR
+         ps_a == stride to next micro-panel of A
+         rs_b == PACKNR
+         cs_b == 1
+         pd_b == NR
+         ps_b == stride to next micro-panel of B
+         rs_c == (no assumptions)
+         cs_c == (no assumptions)
+    */
+
+    auto scat_size = 2 * (m + n) * sizeof(inc_t);
+    auto rscat_c = (inc_t*)bli_packm_alloc_ex( scat_size, BLIS_BUFFER_FOR_GEN_USE, rntm, cntl, thread );
+    auto rbs_c   = rscat_c + m;
+    auto cscat_c = rbs_c + m;
+    auto cbs_c   = cscat_c + n;
+
+    /* Fill in the scatter and block-scatter vectors. This is done
+       single-threaded for now. */
+    if ( bli_thread_am_ochief( thread ) )
+    {
+        fill_scatter
+        (
+          ndim_m,
+          len_m,
+          stride_m,
+          MR,
+          off_m,
+          m,
+          rscat_c,
+          rbs_c
+        );
+
+        fill_scatter
+        (
+          ndim_n,
+          len_n,
+          stride_n,
+          NR,
+          off_n,
+          n,
+          cscat_c,
+          cbs_c
+        );
+    }
+
+    /* Wait for the scatter vectors to be done. */
+    bli_thread_barrier( thread );
+
+    /* Compute number of primary and leftover components of the m and n
+       dimensions. */
+    auto n_iter = n / NR;
+    auto n_left = n % NR;
+
+    auto m_iter = m / MR;
+    auto m_left = m % MR;
+
+    if ( n_left ) ++n_iter;
+    if ( m_left ) ++m_iter;
+
+    /* Determine some increments used to step through A, B, and C. */
+    auto rstep_a = ps_a * dt_size;
+    auto cstep_b = ps_b * dt_size;
+
+    /* Save the virtual microkernel address and the params. */
+    auxinfo_t aux;
+    bli_auxinfo_set_ukr( (void*)gemm_ukr, &aux );
+    bli_auxinfo_set_params( params, &aux );
+
+    /* The 'thread' argument points to the thrinfo_t node for the 2nd (jr)
+       loop around the microkernel. Here we query the thrinfo_t node for the
+       1st (ir) loop around the microkernel. */
+    auto caucus = bli_thrinfo_sub_node( thread );
+
+    /* Query the number of threads and thread ids for each loop. */
+    auto jr_nt  = bli_thread_n_way( thread );
+    auto jr_tid = bli_thread_work_id( thread );
+    auto ir_nt  = bli_thread_n_way( caucus );
+    auto ir_tid = bli_thread_work_id( caucus );
+
+    /* Determine the thread range and increment for the 2nd and 1st loops.
+       NOTE: The definition of bli_thread_range_jrir() will depend on whether
+       slab or round-robin partitioning was requested at configure-time. */
+    dim_t jr_start, jr_end;
+    dim_t ir_start, ir_end;
+    dim_t jr_inc, ir_inc;
+    bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc );
+    bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc );
+
+    /* Loop over the n dimension (NR columns at a time). */
+    for ( auto j = jr_start; j < jr_end; j += jr_inc )
+    {
+        auto b1 = b_cast + j * cstep_b;
+
+        auto n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left );
+
+        /* Initialize our next panel of B to be the current panel of B. */
+        auto b2 = b1;
+
+        /* Loop over the m dimension (MR rows at a time). */
+        for ( auto i = ir_start; i < ir_end; i += ir_inc )
+        {
+            auto a1       = a_cast + i * rstep_a;
+            auto rscat_c1 = rscat_c + i * MR;
+            auto rbs_c1   = rbs_c + i * MR;
+            auto cscat_c1 = cscat_c + j * NR;
+            auto cbs_c1   = cbs_c + j * NR;
+
+            auto m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left );
+
+            /* Compute the addresses of the next panels of A and B. */
+            auto a2 = bli_gemm_get_next_a_upanel( a1, rstep_a, ir_inc );
+            if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) )
+            {
+                a2 = a_cast;
+                b2 = bli_gemm_get_next_b_upanel( b1, cstep_b, jr_inc );
+                if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) )
+                    b2 = b_cast;
+            }
+
+            /* Save addresses of next panels of A and B to the auxinfo_t
+               object. */
+            bli_auxinfo_set_next_a( a2, &aux );
+            bli_auxinfo_set_next_b( b2, &aux );
+
+            auto rs_c = *rbs_c1;
+            auto cs_c = *cbs_c1;
+
+            if ( rs_c && cs_c )
+            {
+                auto c11 = c_cast + ( *rscat_c1 + *cscat_c1 ) * dt_size;
+
+                /* Invoke the gemm micro-kernel. */
+                gemm_ukr
+                (
+                  m_cur,
+                  n_cur,
+                  k,
+                  alpha_cast,
+                  a1,
+                  b1,
+                  beta_cast,
+                  c11, rs_c, cs_c,
+                  &aux,
+                  cntx
+                );
+            }
+            else
+            {
+                /* Invoke the gemm micro-kernel. */
+                gemm_ukr
+                (
+                  MR,
+                  NR,
+                  k,
+                  alpha_cast,
+                  a1,
+                  b1,
+                  zero,
+                  &ct, rs_ct, cs_ct,
+                  &aux,
+                  cntx
+                );
+
+                /* Scatter to C. */
+                scatter_mxn[ dt ]
+                (
+                  m_cur, n_cur,
+                  &ct, rs_ct, cs_ct,
+                  beta_cast,
+                  c_cast, rscat_c1, cscat_c1
+                );
+            }
+        }
+    }
+}
+
+static bool has_unit_stride( const std::vector<inc_t>& stride )
+{
+    for ( auto s : stride )
+        if ( s == 1 )
+            return true;
+    return false;
+}
+
+void tcontract( num_t dt, const std::vector<dim_t>& m, const std::vector<dim_t>& n, const std::vector<dim_t>& k,
+                const void* alpha, const void* a, std::vector<inc_t> rs_a, std::vector<inc_t> cs_a,
+                                   const void* b, std::vector<inc_t> rs_b, std::vector<inc_t> cs_b,
+                const void* beta,        void* c, std::vector<inc_t> rs_c, std::vector<inc_t> cs_c )
+{
+    if ( rs_a.size() != m.size() ||
+         rs_b.size() != k.size() ||
+         rs_c.size() != m.size() )
+        bli_check_error_code( BLIS_INVALID_ROW_STRIDE );
+
+    if ( cs_a.size() != k.size() ||
+         cs_b.size() != n.size() ||
+         cs_c.size() != n.size() )
+        bli_check_error_code( BLIS_INVALID_COL_STRIDE );
+
+    dim_t m_mat = 1;
+    dim_t n_mat = 1;
+    dim_t k_mat = 1;
+    for ( auto& i : m ) m_mat *= i;
+    for ( auto& i : n ) n_mat *= i;
+    for ( auto& i : k ) k_mat *= i;
+
+    auto& stride_m = has_unit_stride( rs_c ) ? rs_c : rs_a;
+    for ( int i = 1; i < m.size(); i++ )
+    for ( int j = 0; j < m.size()-i; j++ )
+        if ( stride_m[j] > stride_m[j+1] )
+        {
+            std::swap( rs_a[j], rs_a[j+1] );
+            std::swap( rs_c[j], rs_c[j+1] );
+        }
+
+    auto& stride_n = has_unit_stride( cs_c ) ? cs_c : cs_b;
+    for ( int i = 1; i < n.size(); i++ )
+    for ( int j = 0; j < n.size()-i; j++ )
+        if ( stride_n[j] > stride_n[j+1] )
+        {
+            std::swap( cs_b[j], cs_b[j+1] );
+            std::swap( cs_c[j], cs_c[j+1] );
+        }
+
+
+    obj_t a_o, b_o, c_o;
+    bli_obj_create_with_attached_buffer( dt, m_mat, k_mat, const_cast<void*>(a), rs_a[0], cs_a[0], &a_o );
+    bli_obj_create_with_attached_buffer( dt, k_mat, n_mat, const_cast<void*>(b), rs_b[0], cs_b[0], &b_o );
+    bli_obj_create_with_attached_buffer( dt, m_mat, n_mat,                  c  , rs_c[0], cs_c[0], &c_o );
+
+    packm_tensor_params_t params_a( m.size(), m.data(), rs_a.data(),
+                                    k.size(), k.data(), cs_a.data() );
+    packm_tensor_params_t params_b( n.size(), n.data(), cs_b.data(),
+                                    k.size(), k.data(), rs_b.data() );
+    gemm_tensor_params_t  params_c( m.size(), m.data(), rs_c.data(),
+                                    n.size(), n.data(), cs_c.data() );
+
+    bli_obj_set_pack_fn( packm_tensor, &a_o );
+    bli_obj_set_pack_fn( packm_tensor, &b_o );
+    bli_obj_set_ker_fn( gemm_tensor, &c_o );
+    bli_obj_set_pack_params( &params_a, &a_o );
+    bli_obj_set_pack_params( &params_b, &b_o );
+    bli_obj_set_ker_params( &params_c, &c_o );
+
+    obj_t alpha_o, beta_o;
+    bli_obj_create_1x1_with_attached_buffer( dt, const_cast<void*>(alpha), &alpha_o );
+    bli_obj_create_1x1_with_attached_buffer( dt, const_cast<void*>(beta), &beta_o );
+
+    rntm_t rntm;
+    bli_rntm_init_from_global( &rntm );
+    bli_rntm_disable_l3_sup( &rntm );
+
+    bli_gemm_ex( &alpha_o, &a_o, &b_o, &beta_o, &c_o, NULL, &rntm );
+}
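+
+/* Usage sketch (illustrative; buffers and scalars assumed to be set up
+   elsewhere): an ordinary matrix multiply is the degenerate case with one
+   m, n, and k dimension each. For column-major double-precision A (4x6),
+   B (6x5), and C (4x5):
+
+     double alpha = 1.0, beta = 0.0;
+     tcontract( BLIS_DOUBLE, {4}, {5}, {6},
+                &alpha, A, {1}, {4},
+                        B, {1}, {6},
+                &beta,  C, {1}, {4} );
+
+   A higher-order contraction passes one length/stride entry per tensor
+   dimension instead, e.g. m = {4, 2} with rs_a = {1, 4} when the two
+   m-dimensions of A are stored contiguously. */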
+
+int main()
+{
+    auto N = 5;
+
+    gint_t ndim_a = 4;
+    gint_t ndim_b = 4;
+    gint_t ndim_c = 4;
+
+    std::vector<dim_t> len_a(ndim_a, N);
+    std::vector<dim_t> len_b(ndim_b, N);
+    std::vector<dim_t> len_c(ndim_c, N);
+
+    std::vector<inc_t> stride_a(ndim_a, 1);
+    std::vector<inc_t> stride_b(ndim_b, 1);
+    std::vector<inc_t> stride_c(ndim_c, 1);
+    for ( gint_t i = 1; i < ndim_a; i++ )
+        stride_a[i] = stride_a[i-1] * len_a[i-1];
+    for ( gint_t i = 1; i < ndim_b; i++ )
+        stride_b[i] = stride_b[i-1] * len_b[i-1];
+    for ( gint_t i = 1; i < ndim_c; i++ )
+        stride_c[i] = stride_c[i-1] * len_c[i-1];
+
+    std::vector<gint_t> dim_a(ndim_a);
+    std::vector<gint_t> dim_b(ndim_b);
+    std::vector<gint_t> dim_c(ndim_c);
+    std::iota(dim_a.begin(), dim_a.end(), 0);
+    std::iota(dim_b.begin(), dim_b.end(), 0);
+    std::iota(dim_c.begin(), dim_c.end(), 0);
+
+    /* Iterate over all datatypes and all permutations of the dimension
+       orderings of A, B, and C. */
+    for ( int dt_ = BLIS_DT_LO; dt_ <= BLIS_DT_HI; dt_++ )
+    do
+    do
+    do
+    {
+        auto dt = ( num_t )dt_;
+
+        auto ndim_m = (ndim_a + ndim_c - ndim_b)/2;
+        auto ndim_k = (ndim_a + ndim_b - ndim_c)/2;
+
+        std::vector<dim_t> m(len_a.begin(), len_a.begin()+ndim_m);
+        std::vector<dim_t> n(len_b.begin()+ndim_k, len_b.end());
+        std::vector<dim_t> k(len_b.begin(), len_b.begin()+ndim_k);
+
+        std::vector<inc_t> rs_a(stride_a.begin(), stride_a.begin()+ndim_m);
+        std::vector<inc_t> cs_a(stride_a.begin()+ndim_m, stride_a.end());
+        std::vector<inc_t> rs_b(stride_b.begin(), stride_b.begin()+ndim_k);
+        std::vector<inc_t> cs_b(stride_b.begin()+ndim_k, stride_b.end());
+        std::vector<inc_t> rs_c(stride_c.begin(), stride_c.begin()+ndim_m);
+        std::vector<inc_t> cs_c(stride_c.begin()+ndim_m, stride_c.end());
+
+        dim_t m_tot = 1;
+        dim_t n_tot = 1;
+        dim_t k_tot = 1;
+        for ( auto i : m ) m_tot *= i;
+        for ( auto i : n ) n_tot *= i;
+        for ( auto i : k ) k_tot *= i;
+
+        obj_t a, b, c, c_ref, norm;
+
+        bli_obj_create( dt, m_tot*k_tot, 1, 1, 1, &a );
+        bli_obj_create( dt, k_tot*n_tot, 1, 1, 1, &b );
+        bli_obj_create( dt, m_tot*n_tot, 1, 1, 1, &c );
+        bli_obj_create( dt, m_tot*n_tot, 1, 1, 1, &c_ref );
+        bli_obj_create_1x1( bli_dt_proj_to_real( dt ), &norm );
+
+        bli_randv( &a );
+        bli_randv( &b );
+        bli_randv( &c );
+        bli_copyv( &c, &c_ref );
+
+        tcontract( dt, m, n, k,
+                   bli_obj_buffer_for_const( dt, &BLIS_ONE ),
+                   bli_obj_buffer( &a ), rs_a, cs_a,
+                   bli_obj_buffer( &b ), rs_b, cs_b,
+                   bli_obj_buffer_for_const( dt, &BLIS_ZERO ),
+                   bli_obj_buffer( &c ), rs_c, cs_c );
+
+        tcontract_ref( dt, m, n, k,
+                       bli_obj_buffer_for_const( dt, &BLIS_ONE ),
+                       bli_obj_buffer( &a ), rs_a, cs_a,
+                       bli_obj_buffer( &b ), rs_b, cs_b,
+                       bli_obj_buffer_for_const( dt, &BLIS_ZERO ),
+                       bli_obj_buffer( &c_ref ), rs_c, cs_c );
+
+        bli_subv( &c_ref, &c );
+        bli_normfv( &c, &norm );
+
+        double normr, normi;
+        bli_getsc( &norm, &normr, &normi );
+
+        printf("dt: %d, dim_a: [%d,%d,%d,%d], dim_b: [%d,%d,%d,%d], dim_c: [%d,%d,%d,%d], norm: %g\n",
+               (int)dt,
+               (int)dim_a[0], (int)dim_a[1], (int)dim_a[2], (int)dim_a[3],
+               (int)dim_b[0], (int)dim_b[1], (int)dim_b[2], (int)dim_b[3],
+               (int)dim_c[0], (int)dim_c[1], (int)dim_c[2], (int)dim_c[3],
+               normr / std::sqrt( bli_obj_vector_dim( &c ) ) );
+
+        bli_obj_free( &a );
+        bli_obj_free( &b );
+        bli_obj_free( &c );
+        bli_obj_free( &c_ref );
+        bli_obj_free( &norm );
+    }
+    while (std::next_permutation(dim_a.begin(), dim_a.end()));
+    while (std::next_permutation(dim_b.begin(), dim_b.end()));
+    while (std::next_permutation(dim_c.begin(), dim_c.end()));
+}
+
diff --git a/test/tensor_contraction/tcontract_ref.cxx b/test/tensor_contraction/tcontract_ref.cxx
new file mode 100644
index 0000000000..b4cd07f903
--- /dev/null
+++ b/test/tensor_contraction/tcontract_ref.cxx
@@ -0,0 +1,67 @@
+#include "tcontract_ref.hpp"
+
+template <typename T>
+void tcontract_ref( const std::vector<dim_t>& m, const std::vector<dim_t>& n, const std::vector<dim_t>& k,
+                    const void* alpha, const void* a, const std::vector<inc_t>& rs_a, const std::vector<inc_t>& cs_a,
+                                       const void* b, const std::vector<inc_t>& rs_b, const std::vector<inc_t>& cs_b,
+                    const void* beta,        void* c, const std::vector<inc_t>& rs_c, const std::vector<inc_t>& cs_c )
+{
+    auto alpha_cast = *( T* )alpha;
+    auto beta_cast  = *( T* )beta;
+    auto a_cast     = ( T* )a;
+    auto b_cast     = ( T* )b;
+    auto c_cast     = ( T* )c;
+
+    for_each(m.size(), m.data(), a_cast, rs_a.data(), c_cast, rs_c.data(),
+    [&]
+    {
+        for_each(n.size(), n.data(), b_cast, cs_b.data(), c_cast, cs_c.data(),
+        [&]
+        {
+            auto ab = convert<T>(0.0);
+
+            for_each(k.size(), k.data(), a_cast, cs_a.data(), b_cast, rs_b.data(),
+            [&]
+            {
+                ab += (*a_cast) * (*b_cast);
+            });
+
+            if ( beta_cast == convert<T>(0.0) )
+            {
+                *c_cast = alpha_cast * ab;
+            }
+            else
+            {
+                *c_cast = alpha_cast * ab + beta_cast * (*c_cast);
+            }
+        });
+
+        /* Each traversal wraps its pointers back to their starting values. */
+        assert(b_cast == b);
+    });
+
+    assert(a_cast == a);
+    assert(c_cast == c);
+}
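+
+/* The BLIS macro machinery below instantiates the template once per
+   datatype and collects the instances into a type-indexed array. For
+   example, GENTFUNC(float,s,tcontract_ref) expands (roughly) to
+
+     static auto bli_stcontract_ref = &tcontract_ref<float>;
+
+   and GENARRAY then builds { bli_stcontract_ref, bli_ctcontract_ref,
+   bli_dtcontract_ref, bli_ztcontract_ref } so the dispatcher can index it
+   with a num_t. The expansion shown is a sketch; the exact names come from
+   PASTEMAC. */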
+
+#undef GENTFUNC
+#define GENTFUNC(ctype,ch,op) \
+static auto PASTEMAC(ch,op) = &tcontract_ref<ctype>;
+
+INSERT_GENTFUNC_BASIC0(tcontract_ref);
+
+static decltype(&tcontract_ref<float>) GENARRAY( tcontract_ref_impl, tcontract_ref );
+
+void tcontract_ref( num_t dt, const std::vector<dim_t>& m, const std::vector<dim_t>& n, const std::vector<dim_t>& k,
+                    const void* alpha, const void* a, const std::vector<inc_t>& rs_a, const std::vector<inc_t>& cs_a,
+                                       const void* b, const std::vector<inc_t>& rs_b, const std::vector<inc_t>& cs_b,
+                    const void* beta,        void* c, const std::vector<inc_t>& rs_c, const std::vector<inc_t>& cs_c )
+{
+    tcontract_ref_impl[ dt ]
+    (
+      m, n, k,
+      alpha, a, rs_a, cs_a,
+      b, rs_b, cs_b,
+      beta, c, rs_c, cs_c
+    );
+}
+
diff --git a/test/tensor_contraction/tcontract_ref.hpp b/test/tensor_contraction/tcontract_ref.hpp
new file mode 100644
index 0000000000..99d4380dce
--- /dev/null
+++ b/test/tensor_contraction/tcontract_ref.hpp
@@ -0,0 +1,100 @@
+#include "blis.h"
+#include "complex_math.hpp"
+
+#include <array>
+#include <cassert>
+#include <vector>
+
+inline void increment(inc_t, gint_t) {}
+
+template <typename T, typename... Args>
+void increment(inc_t n, gint_t i, T& off, const inc_t* s, Args&... args)
+{
+    off += s[i]*n;
+    increment(n, i, args...);
+}
+
+template <typename Body, typename... Args>
+void for_each_impl(gint_t ndim, const dim_t* n,
+                   dim_t off, dim_t len,
+                   Body& body,
+                   Args&... args)
+{
+    /* Odometer digits, one per dimension; a maximum of 8 dimensions is an
+       assumed bound. */
+    std::array<dim_t,8> i = {};
+    assert( ndim <= (gint_t)i.size() );
+
+    if ( off )
+    {
+        for ( gint_t k = 0; k < ndim; k++ )
+        {
+            i[k] = off % n[k];
+            off /= n[k];
+            increment(i[k], k, args...);
+        }
+    }
+
+    for ( dim_t pos = 0; pos < len; pos++ )
+    {
+        body();
+
+        for ( gint_t k = 0; k < ndim; k++ )
+        {
+            if ( i[k] == n[k]-1 )
+            {
+                increment(-i[k], k, args...);
+                i[k] = 0;
+            }
+            else
+            {
+                increment(1, k, args...);
+                i[k]++;
+                break;
+            }
+        }
+    }
+}
+
+template <typename T, typename Body>
+void for_each(gint_t ndim, const dim_t* n,
+              dim_t off, dim_t len,
+              T& a, const inc_t* s_a,
+              Body&& body)
+{
+    for_each_impl( ndim, n, off, len, body, a, s_a );
+}
+
+template <typename T, typename Body>
+void for_each(gint_t ndim, const dim_t* n,
+              dim_t off, dim_t len,
+              T& a, const inc_t* s_a,
+              T& b, const inc_t* s_b,
+              Body&& body)
+{
+    for_each_impl( ndim, n, off, len, body, a, s_a, b, s_b );
+}
+
+template <typename T, typename Body>
+void for_each(gint_t ndim, const dim_t* n,
+              T& a, const inc_t* s_a,
+              Body&& body)
+{
+    dim_t len = 1;
+    for ( gint_t i = 0; i < ndim; i++ ) len *= n[i];
+    for_each_impl( ndim, n, 0, len, body, a, s_a );
+}
+
+template <typename T, typename Body>
+void for_each(gint_t ndim, const dim_t* n,
+              T& a, const inc_t* s_a,
+              T& b, const inc_t* s_b,
+              Body&& body)
+{
+    dim_t len = 1;
+    for ( gint_t i = 0; i < ndim; i++ ) len *= n[i];
+    for_each_impl( ndim, n, 0, len, body, a, s_a, b, s_b );
+}
+
+void tcontract_ref( num_t dt, const std::vector<dim_t>& m, const std::vector<dim_t>& n, const std::vector<dim_t>& k,
+                    const void* alpha, const void* a, const std::vector<inc_t>& rs_a, const std::vector<inc_t>& cs_a,
+                                       const void* b, const std::vector<inc_t>& rs_b, const std::vector<inc_t>& cs_b,
+                    const void* beta,        void* c, const std::vector<inc_t>& rs_c, const std::vector<inc_t>& cs_c );
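+
+/* Usage sketch (illustrative): visit all six elements of a 2x3 tensor
+   whose strides {1, 2} give column-major order:
+
+     dim_t  n[] = { 2, 3 };
+     inc_t  s[] = { 1, 2 };
+     double t[] = { 0, 1, 2, 3, 4, 5 };
+     double* p = t;
+     for_each( 2, n, p, s, [&] { printf( "%g\n", *p ); } );
+
+   The pointer p advances odometer-style through all n[0]*n[1] elements and
+   is restored to its starting value once the traversal completes. */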