From a3a603d64e052ee4fc03b26090d577815044fade Mon Sep 17 00:00:00 2001 From: Devin Matthews Date: Tue, 7 Dec 2021 13:47:57 -0600 Subject: [PATCH 01/12] Import changes from `obj_t_makeover` on top of `master`. --- .../kernels/3/bli_gemm_template_noopt_mxn.c | 13 +- .../3/bli_gemmtrsm_l_template_noopt_mxn.c | 4 + .../3/bli_gemmtrsm_u_template_noopt_mxn.c | 8 +- frame/3/bli_l3_cntl.c | 4 +- frame/3/bli_l3_ft_ukr.h | 2 + frame/3/bli_l3_ukr_oapi.c | 4 + frame/3/bli_l3_ukr_prot.h | 2 + frame/3/bli_l3_ukr_tapi.c | 4 + frame/3/gemm/bli_gemm_cntl.c | 9 +- frame/3/gemm/bli_gemm_cntl.h | 6 +- frame/3/gemm/bli_gemm_front.c | 87 - frame/3/gemm/bli_gemm_ker_var2.c | 517 ++- frame/3/gemm/bli_gemm_ker_var2_md.c | 406 --- frame/3/gemm/bli_gemm_md_c2r_ref.c | 35 +- frame/3/gemm/bli_gemm_var.h | 39 +- frame/3/gemmt/bli_gemmt_l_ker_var2.c | 107 +- frame/3/gemmt/bli_gemmt_u_ker_var2.c | 107 +- frame/3/gemmt/other/bli_gemmt_l_ker_var2.c | 409 --- frame/3/gemmt/other/bli_gemmt_u_ker_var2.c | 409 --- frame/3/trmm/bli_trmm_ll_ker_var2.c | 121 +- frame/3/trmm/bli_trmm_lu_ker_var2.c | 121 +- frame/3/trmm/bli_trmm_rl_ker_var2.c | 121 +- frame/3/trmm/bli_trmm_ru_ker_var2.c | 121 +- frame/3/trsm/bli_trsm_cntl.c | 17 +- frame/3/trsm/bli_trsm_cntl.h | 9 +- frame/3/trsm/bli_trsm_ll_ker_var2.c | 52 +- frame/3/trsm/bli_trsm_lu_ker_var2.c | 52 +- frame/3/trsm/bli_trsm_rl_ker_var2.c | 54 +- frame/3/trsm/bli_trsm_ru_ker_var2.c | 54 +- frame/base/bli_auxinfo.h | 20 +- frame/include/bli_misc_macro_defs.h | 57 + frame/include/bli_type_defs.h | 7 + .../3/bli_gemm_armsve_asm_c2vx10_unindexed.c | 7 + .../3/bli_gemm_armsve_asm_d2vx10_unindexed.c | 15 +- .../3/bli_gemm_armsve_asm_s2vx10_unindexed.c | 13 +- .../3/bli_gemm_armsve_asm_z2vx10_unindexed.c | 7 + .../3/bli_gemm_armsve_asm_z2vx7_unindexed.c | 7 + .../3/bli_gemm_armsve_asm_z2vx8_unindexed.c | 7 + kernels/armv7a/3/bli_gemm_armv7a_asm_d4x4.c | 56 +- kernels/armv7a/3/bli_gemm_armv7a_asm_wrap.c | 113 + kernels/armv7a/3/bli_gemm_armv7a_int_d4x4.c | 115 +- kernels/armv7a/bli_kernels_armv7a.h | 29 +- kernels/armv8a/3/bli_gemm_armv8a_asm_d6x8.c | 2672 +++++++-------- kernels/bgq/3/bli_gemm_bgq_int_8x8.c | 12 + .../3/bli_gemm_bulldozer_asm_d4x6_fma4.c | 1945 ++++------- kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c | 1866 ++++------- kernels/haswell/3/bli_gemm_haswell_asm_d8x6.c | 1826 ++++------ kernels/knc/3/bli_dgemm_knc_asm_30x8.c | 127 +- kernels/knc/3/bli_sgemm_knc_asm_30x16.c | 129 +- kernels/penryn/3/bli_gemm_penryn_asm_d4x4.c | 1252 +++---- .../3/bli_gemm_piledriver_asm_d8x3.c | 1933 +++-------- kernels/power10/3/bli_dgemm_power10_mma.c | 43 +- kernels/power10/3/bli_i16gemm_power10_mma.c | 10 +- kernels/power10/3/bli_i16sgemm_power10_mma.c | 10 +- kernels/power10/3/bli_i4gemm_power10_mma.c | 16 +- kernels/power10/3/bli_i8gemm_power10_mma.c | 14 +- kernels/power10/3/bli_sbgemm_power10_mma.c | 18 +- kernels/power10/3/bli_sgemm_power10_mma.c | 24 +- kernels/power10/3/bli_shgemm_power10_mma.c | 18 +- kernels/power7/3/bli_gemm_power7_int_8x4.c | 368 +- .../power7/3/test/bli_gemm_power7_int_8x4.h | 8 + kernels/power9/3/bli_gemm_power9_asm_d12x6.c | 66 +- .../3/bli_gemm_sandybridge_asm_d8x4.c | 2946 +++++------------ .../3/bli_gemm_sandybridge_int_d8x4.c | 324 +- ref_kernels/3/bb/bli_gemmbb_ref.c | 5 +- ref_kernels/3/bb/bli_gemmtrsmbb_ref.c | 2 + ref_kernels/3/bli_gemm_ref.c | 23 +- ref_kernels/3/bli_gemmtrsm_ref.c | 4 + ref_kernels/ind/bli_gemm1m_ref.c | 26 +- ref_kernels/ind/bli_gemmtrsm1m_ref.c | 4 +- test/obj_t_makeover/complex_math.hpp | 267 ++ 
test/obj_t_makeover/syrk_diagonal_example.c | 186 ++ test/obj_t_makeover/syrk_diagonal_example.cxx | 220 ++ test/obj_t_makeover/syrk_diagonal_example.tgz | Bin 0 -> 6904 bytes test/obj_t_makeover/syrk_diagonal_example2.c | 351 ++ .../obj_t_makeover/syrk_diagonal_example2.cxx | 338 ++ test/obj_t_makeover/syrk_diagonal_ref.cxx | 102 + test/obj_t_makeover/syrk_diagonal_ref.h | 8 + test/tensor_contraction/complex_math.hpp | 267 ++ test/tensor_contraction/tcontract_example.cxx | 987 ++++++ test/tensor_contraction/tcontract_ref.cxx | 67 + test/tensor_contraction/tcontract_ref.hpp | 100 + 82 files changed, 9198 insertions(+), 12733 deletions(-) delete mode 100644 frame/3/gemm/bli_gemm_ker_var2_md.c delete mode 100644 frame/3/gemmt/other/bli_gemmt_l_ker_var2.c delete mode 100644 frame/3/gemmt/other/bli_gemmt_u_ker_var2.c create mode 100644 kernels/armv7a/3/bli_gemm_armv7a_asm_wrap.c create mode 100644 test/obj_t_makeover/complex_math.hpp create mode 100644 test/obj_t_makeover/syrk_diagonal_example.c create mode 100644 test/obj_t_makeover/syrk_diagonal_example.cxx create mode 100644 test/obj_t_makeover/syrk_diagonal_example.tgz create mode 100644 test/obj_t_makeover/syrk_diagonal_example2.c create mode 100644 test/obj_t_makeover/syrk_diagonal_example2.cxx create mode 100644 test/obj_t_makeover/syrk_diagonal_ref.cxx create mode 100644 test/obj_t_makeover/syrk_diagonal_ref.h create mode 100644 test/tensor_contraction/complex_math.hpp create mode 100644 test/tensor_contraction/tcontract_example.cxx create mode 100644 test/tensor_contraction/tcontract_ref.cxx create mode 100644 test/tensor_contraction/tcontract_ref.hpp diff --git a/config/template/kernels/3/bli_gemm_template_noopt_mxn.c b/config/template/kernels/3/bli_gemm_template_noopt_mxn.c index b7a13f3b69..06f25a0e9e 100644 --- a/config/template/kernels/3/bli_gemm_template_noopt_mxn.c +++ b/config/template/kernels/3/bli_gemm_template_noopt_mxn.c @@ -37,6 +37,8 @@ void bli_zgemm_template_noopt ( + dim_t m, + dim_t n, dim_t k, dcomplex* restrict alpha, dcomplex* restrict a1, @@ -88,8 +90,7 @@ void bli_zgemm_template_noopt dim_t l, j, i; - dcomplex ab[ bli_zmr * - bli_znr ]; + dcomplex ab[ mr * nr ]; dcomplex* abij; dcomplex ai, bj; @@ -137,16 +138,16 @@ void bli_zgemm_template_noopt if ( bli_zeq0( *beta ) ) { /* c11 := ab */ - bli_zcopys_mxn( mr, - nr, + bli_zcopys_mxn( m, + n, ab, rs_ab, cs_ab, c11, rs_c, cs_c ); } else { /* c11 := beta * c11 + ab */ - bli_zxpbys_mxn( mr, - nr, + bli_zxpbys_mxn( m, + n, ab, rs_ab, cs_ab, beta, c11, rs_c, cs_c ); diff --git a/config/template/kernels/3/bli_gemmtrsm_l_template_noopt_mxn.c b/config/template/kernels/3/bli_gemmtrsm_l_template_noopt_mxn.c index da0cd3110f..1582566ae6 100644 --- a/config/template/kernels/3/bli_gemmtrsm_l_template_noopt_mxn.c +++ b/config/template/kernels/3/bli_gemmtrsm_l_template_noopt_mxn.c @@ -74,6 +74,8 @@ void bli_zgemmtrsm_l_template_noopt */ const num_t dt = BLIS_DCOMPLEX; + const inc_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); + const inc_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); const inc_t rs_b = packnr; @@ -84,6 +86,8 @@ void bli_zgemmtrsm_l_template_noopt /* b11 = alpha * b11 - a10 * b01; */ bli_zgemm_template_noopt ( + mr, + nr, k, minus_one, a10, diff --git a/config/template/kernels/3/bli_gemmtrsm_u_template_noopt_mxn.c b/config/template/kernels/3/bli_gemmtrsm_u_template_noopt_mxn.c index 09b3af9cee..18f288f719 100644 --- a/config/template/kernels/3/bli_gemmtrsm_u_template_noopt_mxn.c 
+++ b/config/template/kernels/3/bli_gemmtrsm_u_template_noopt_mxn.c @@ -74,6 +74,8 @@ void bli_zgemmtrsm_u_template_noopt */ const num_t dt = BLIS_DCOMPLEX; + const inc_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); + const inc_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); const inc_t rs_b = packnr; @@ -84,10 +86,12 @@ void bli_zgemmtrsm_u_template_noopt /* b11 = alpha * b11 - a12 * b21; */ bli_zgemm_template_noopt ( + mr, + nr, k, minus_one, a12, b21, alpha, b11, rs_b, cs_b, data diff --git a/frame/3/bli_l3_cntl.c b/frame/3/bli_l3_cntl.c index 3cdecfbc26..c9207c226a 100644 --- a/frame/3/bli_l3_cntl.c +++ b/frame/3/bli_l3_cntl.c @@ -57,7 +57,7 @@ void bli_l3_cntl_create_if family == BLIS_GEMMT || family == BLIS_TRMM ) { - *cntl_use = bli_gemm_cntl_create( rntm, family, schema_a, schema_b ); + *cntl_use = bli_gemm_cntl_create( rntm, family, schema_a, schema_b, bli_obj_ker_fn( c ) ); } else // if ( family == BLIS_TRSM ) { @@ -66,7 +66,7 @@ void bli_l3_cntl_create_if if ( bli_obj_is_triangular( a ) ) side = BLIS_LEFT; else side = BLIS_RIGHT; - *cntl_use = bli_trsm_cntl_create( rntm, side, schema_a, schema_b ); + *cntl_use = bli_trsm_cntl_create( rntm, side, schema_a, schema_b, bli_obj_ker_fn( c ) ); } } else diff --git a/frame/3/bli_l3_ft_ukr.h b/frame/3/bli_l3_ft_ukr.h index 4249dcbd6b..561c8264fe 100644 --- a/frame/3/bli_l3_ft_ukr.h +++ b/frame/3/bli_l3_ft_ukr.h @@ -47,6 +47,8 @@ \ typedef void (*PASTECH3(ch,opname,_ukr,tsuf)) \ ( \ + dim_t m, \ + dim_t n, \ dim_t k, \ ctype* restrict alpha, \ ctype* restrict a, \ diff --git a/frame/3/bli_l3_ukr_oapi.c b/frame/3/bli_l3_ukr_oapi.c index 33262b0bba..805a3f8f0c 100644 --- a/frame/3/bli_l3_ukr_oapi.c +++ b/frame/3/bli_l3_ukr_oapi.c @@ -51,6 +51,8 @@ void PASTEMAC0(opname) \ \ num_t dt = bli_obj_dt( c ); \ \ + dim_t m = bli_obj_length( c ); \ + dim_t n = bli_obj_width( c ); \ dim_t k = bli_obj_width( a ); \ void* buf_a = bli_obj_buffer_at_off( a ); \ void* buf_b = bli_obj_buffer_at_off( b ); \ @@ -75,6 +77,8 @@ void PASTEMAC0(opname) \ \ f \ ( \ + m, \ + n, \ k, \ buf_alpha, \ buf_a, \ diff --git a/frame/3/bli_l3_ukr_prot.h b/frame/3/bli_l3_ukr_prot.h index ca523b1d70..f68973ff59 100644 --- a/frame/3/bli_l3_ukr_prot.h +++ b/frame/3/bli_l3_ukr_prot.h @@ -42,6 +42,8 @@ \ void PASTEMAC(ch,opname) \ ( \ + dim_t m, \ + dim_t n, \ dim_t k, \ ctype_out* restrict alpha, \ ctype_in* restrict a, \ diff --git a/frame/3/bli_l3_ukr_tapi.c b/frame/3/bli_l3_ukr_tapi.c index 67e33175b7..178981976e 100644 --- a/frame/3/bli_l3_ukr_tapi.c +++ b/frame/3/bli_l3_ukr_tapi.c @@ -39,6 +39,8 @@ \ void PASTEMAC(ch,opname) \ ( \ + dim_t m, \ + dim_t n, \ dim_t k, \ ctype* restrict alpha, \ ctype* restrict a, \ @@ -59,6 +61,8 @@ void PASTEMAC(ch,opname) \ \ /* Invoke the typed function for the given datatype.
*/ \ f( \ + m, \ + n, \ k, \ alpha, \ a, \ diff --git a/frame/3/gemm/bli_gemm_cntl.c b/frame/3/gemm/bli_gemm_cntl.c index 72d78efe16..71dad78f30 100644 --- a/frame/3/gemm/bli_gemm_cntl.c +++ b/frame/3/gemm/bli_gemm_cntl.c @@ -40,10 +40,11 @@ cntl_t* bli_gemm_cntl_create rntm_t* rntm, opid_t family, pack_t schema_a, - pack_t schema_b + pack_t schema_b, + void_fp ker ) { - return bli_gemmbp_cntl_create( rntm, family, schema_a, schema_b ); + return bli_gemmbp_cntl_create( rntm, family, schema_a, schema_b, ker ); } // ----------------------------------------------------------------------------- @@ -53,7 +54,8 @@ cntl_t* bli_gemmbp_cntl_create rntm_t* rntm, opid_t family, pack_t schema_a, - pack_t schema_b + pack_t schema_b, + void_fp ker ) { void_fp macro_kernel_fp; @@ -64,6 +66,7 @@ cntl_t* bli_gemmbp_cntl_create else if ( family == BLIS_GEMMT ) macro_kernel_fp = bli_gemmt_x_ker_var2; else if ( family == BLIS_TRMM ) macro_kernel_fp = bli_trmm_xx_ker_var2; else /* should never execute */ macro_kernel_fp = NULL; + if ( ker ) macro_kernel_fp = ker; // Create two nodes for the macro-kernel. cntl_t* gemm_cntl_bu_ke = bli_gemm_cntl_create_node diff --git a/frame/3/gemm/bli_gemm_cntl.h b/frame/3/gemm/bli_gemm_cntl.h index bff91b58aa..5fa213ac41 100644 --- a/frame/3/gemm/bli_gemm_cntl.h +++ b/frame/3/gemm/bli_gemm_cntl.h @@ -38,7 +38,8 @@ cntl_t* bli_gemm_cntl_create rntm_t* rntm, opid_t family, pack_t schema_a, - pack_t schema_b + pack_t schema_b, + void_fp ker ); // ----------------------------------------------------------------------------- @@ -48,7 +49,8 @@ cntl_t* bli_gemmbp_cntl_create rntm_t* rntm, opid_t family, pack_t schema_a, - pack_t schema_b + pack_t schema_b, + void_fp ker ); #if 0 diff --git a/frame/3/gemm/bli_gemm_front.c b/frame/3/gemm/bli_gemm_front.c index a9ea21dc43..4ff45036fe 100644 --- a/frame/3/gemm/bli_gemm_front.c +++ b/frame/3/gemm/bli_gemm_front.c @@ -283,90 +283,3 @@ void bli_gemm_front #endif } -// ----------------------------------------------------------------------------- - -#if 0 - if ( bli_obj_dt( a ) != bli_obj_dt( b ) || - bli_obj_dt( a ) != bli_obj_dt( c ) || - bli_obj_comp_prec( c ) != bli_obj_prec( c ) ) - { - const bool a_is_real = bli_obj_is_real( a ); - const bool a_is_comp = bli_obj_is_complex( a ); - const bool b_is_real = bli_obj_is_real( b ); - const bool b_is_comp = bli_obj_is_complex( b ); - const bool c_is_real = bli_obj_is_real( c ); - const bool c_is_comp = bli_obj_is_complex( c ); - - const bool a_is_single = bli_obj_is_single_prec( a ); - const bool a_is_double = bli_obj_is_double_prec( a ); - const bool b_is_single = bli_obj_is_single_prec( b ); - const bool b_is_double = bli_obj_is_double_prec( b ); - const bool c_is_single = bli_obj_is_single_prec( c ); - const bool c_is_double = bli_obj_is_double_prec( c ); - - const bool comp_single = bli_obj_comp_prec( c ) == BLIS_SINGLE_PREC; - const bool comp_double = bli_obj_comp_prec( c ) == BLIS_DOUBLE_PREC; - - const bool mixeddomain = bli_obj_domain( c ) != bli_obj_domain( a ) || - bli_obj_domain( c ) != bli_obj_domain( b ); - - ( void )a_is_real; ( void )a_is_comp; - ( void )b_is_real; ( void )b_is_comp; - ( void )c_is_real; ( void )c_is_comp; - ( void )a_is_single; ( void )a_is_double; - ( void )b_is_single; ( void )b_is_double; - ( void )c_is_single; ( void )c_is_double; - ( void )comp_single; ( void )comp_double; - - if ( - //( c_is_comp && a_is_comp && b_is_real ) || - //( c_is_comp && a_is_real && b_is_comp ) || - //( c_is_real && a_is_comp && b_is_comp ) || - //( c_is_comp && a_is_real && 
b_is_real ) || - //( c_is_real && a_is_comp && b_is_real ) || - //( c_is_real && a_is_real && b_is_comp ) || - //FALSE - TRUE - ) - { - if ( - ( c_is_single && a_is_single && b_is_single && mixeddomain ) || - ( c_is_single && a_is_single && b_is_single && comp_single ) || - ( c_is_single && a_is_single && b_is_single && comp_double ) || - ( c_is_single && a_is_single && b_is_double ) || - ( c_is_single && a_is_double && b_is_single ) || - ( c_is_double && a_is_single && b_is_single ) || - ( c_is_single && a_is_double && b_is_double ) || - ( c_is_double && a_is_single && b_is_double ) || - ( c_is_double && a_is_double && b_is_single ) || - ( c_is_double && a_is_double && b_is_double && comp_single ) || - ( c_is_double && a_is_double && b_is_double && comp_double ) || - ( c_is_double && a_is_double && b_is_double && mixeddomain ) || - FALSE - ) - bli_gemm_md_front( alpha, a, b, beta, c, cntx, cntl ); - else - bli_gemm_md_zgemm( alpha, a, b, beta, c, cntx, cntl ); - } - else - bli_gemm_md_zgemm( alpha, a, b, beta, c, cntx, cntl ); - return; - } -#else -#if 0 - // If any of the storage datatypes differ, or if the execution precision - // differs from the storage precision of C, utilize the mixed datatype - // code path. - // NOTE: We could check the exec dt against the storage dt of C, but for - // now we don't support the caller setting the execution domain - // explicitly. - if ( bli_obj_dt( a ) != bli_obj_dt( b ) || - bli_obj_dt( a ) != bli_obj_dt( c ) || - bli_obj_comp_prec( c ) != bli_obj_prec( c ) ) - { - bli_gemm_md_front( alpha, a, b, beta, c, cntx, cntl ); - return; - } -#endif -#endif - diff --git a/frame/3/gemm/bli_gemm_ker_var2.c b/frame/3/gemm/bli_gemm_ker_var2.c index 0c90605524..d49c9823e2 100644 --- a/frame/3/gemm/bli_gemm_ker_var2.c +++ b/frame/3/gemm/bli_gemm_ker_var2.c @@ -35,28 +35,43 @@ #include "blis.h" -#define FUNCPTR_T gemm_fp +typedef void (*xpbys_mxn_vft) + ( + dim_t m, + dim_t n, + void* x, inc_t rs_x, inc_t cs_x, + void* b, + void* y, inc_t rs_y, inc_t cs_y + ); -typedef void (*FUNCPTR_T) - ( - pack_t schema_a, - pack_t schema_b, - dim_t m, - dim_t n, - dim_t k, - void* alpha, - void* a, inc_t cs_a, inc_t is_a, - dim_t pd_a, inc_t ps_a, - void* b, inc_t rs_b, inc_t is_b, - dim_t pd_b, inc_t ps_b, - void* beta, - void* c, inc_t rs_c, inc_t cs_c, - cntx_t* cntx, - rntm_t* rntm, - thrinfo_t* thread - ); +#undef GENTFUNC2 +#define GENTFUNC2(ctypex,ctypey,chx,chy,op) \ +void PASTEMAC2(chx,chy,op) \ + ( \ + dim_t m, \ + dim_t n, \ + void* x, inc_t rs_x, inc_t cs_x, \ + void* b, \ + void* y, inc_t rs_y, inc_t cs_y \ + ) \ +{ \ + ctypex* restrict x_cast = x; \ + ctypey* restrict b_cast = b; \ + ctypey* restrict y_cast = y; \ +\ + PASTEMAC3(chx,chy,chy,xpbys_mxn) \ + ( \ + m, n, \ + x_cast, rs_x, cs_x, \ + b_cast, \ + y_cast, rs_y, cs_y \ + ); \ +} -static FUNCPTR_T GENARRAY(ftypes,gemm_ker_var2); +INSERT_GENTFUNC2_BASIC0(xpbys_mxn_fn); +INSERT_GENTFUNC2_MIXDP0(xpbys_mxn_fn); + +static xpbys_mxn_vft GENARRAY2_ALL(xpbys_mxn, xpbys_mxn_fn); void bli_gemm_ker_var2 ( obj_t* a, obj_t* b, obj_t* c, cntx_t* cntx, rntm_t* rntm, cntl_t* cntl, thrinfo_t* thread ) { -#ifdef BLIS_ENABLE_GEMM_MD - // By now, A and B have been packed and cast to the execution precision. - // In most cases, such as when storage precision of C differs from the - // execution precision, we utilize the mixed datatype code path. However, - // a few cases still fall within this kernel, such as mixed domain with - // equal precision (ccr, crc, rcc), hence those expressions being disabled - // in the conditional below.
- if ( //( bli_obj_domain( c ) != bli_obj_domain( a ) ) || - //( bli_obj_domain( c ) != bli_obj_domain( b ) ) || - ( bli_obj_dt( c ) != bli_obj_exec_dt( c ) ) ) - { - bli_gemm_ker_var2_md( a, b, c, cntx, rntm, cntl, thread ); - return; - } -#endif - num_t dt_exec = bli_obj_exec_dt( c ); + num_t dt_c = bli_obj_dt( c ); pack_t schema_a = bli_obj_pack_schema( a ); pack_t schema_b = bli_obj_pack_schema( b ); @@ -95,50 +95,55 @@ void bli_gemm_ker_var2 dim_t n = bli_obj_width( c ); dim_t k = bli_obj_width( a ); - void* buf_a = bli_obj_buffer_at_off( a ); - inc_t cs_a = bli_obj_col_stride( a ); + char* a_cast = bli_obj_buffer_at_off( a ); inc_t is_a = bli_obj_imag_stride( a ); dim_t pd_a = bli_obj_panel_dim( a ); inc_t ps_a = bli_obj_panel_stride( a ); - void* buf_b = bli_obj_buffer_at_off( b ); - inc_t rs_b = bli_obj_row_stride( b ); + char* b_cast = bli_obj_buffer_at_off( b ); inc_t is_b = bli_obj_imag_stride( b ); dim_t pd_b = bli_obj_panel_dim( b ); inc_t ps_b = bli_obj_panel_stride( b ); - void* buf_c = bli_obj_buffer_at_off( c ); + char* c_cast = bli_obj_buffer_at_off( c ); inc_t rs_c = bli_obj_row_stride( c ); inc_t cs_c = bli_obj_col_stride( c ); - obj_t scalar_a; - obj_t scalar_b; - - void* buf_alpha; - void* buf_beta; - - FUNCPTR_T f; + /* If any dimension is zero, return immediately. */ + if ( bli_zero_dim3( m, n, k ) ) return; // Detach and multiply the scalars attached to A and B. + // NOTE: We know that the internal scalars of A and B are already of the + // target datatypes because the necessary typecasting would have already + // taken place during bli_packm_init(). + obj_t scalar_a; + obj_t scalar_b; bli_obj_scalar_detach( a, &scalar_a ); bli_obj_scalar_detach( b, &scalar_b ); bli_mulsc( &scalar_a, &scalar_b ); // Grab the addresses of the internal scalar buffers for the scalar // merged above and the scalar attached to C. - buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b ); - buf_beta = bli_obj_internal_scalar_buffer( c ); + // NOTE: We know that scalar_b is of type dt_exec due to the above code + // that casts the scalars of A and B to dt_exec via scalar_a and scalar_b, + // and we know that the internal scalar in C is already of the type dt_c + // due to the casting in the implementation of bli_obj_scalar_attach(). + char* alpha_cast = bli_obj_internal_scalar_buffer( &scalar_b ); + char* beta_cast = bli_obj_internal_scalar_buffer( c ); // If 1m is being employed on a column- or row-stored matrix with a // real-valued beta, we can use the real domain macro-kernel, which // eliminates a little overhead associated with the 1m virtual // micro-kernel. + // Only employ this optimization if the storage datatype of C is + // equal to the execution/computation datatype. #if 1 if ( bli_cntx_method( cntx ) == BLIS_1M ) { bli_gemm_ind_recast_1m_params ( &dt_exec, + &dt_c, schema_a, c, &m, &n, &k, @@ -151,118 +156,51 @@ void bli_gemm_ker_var2 #ifdef BLIS_ENABLE_GEMM_MD // Tweak parameters in select mixed domain cases (rcc, crc, ccr). - bli_gemm_md_ker_var2_recast - ( - &dt_exec, - bli_obj_dt( a ), - bli_obj_dt( b ), - bli_obj_dt( c ), - &m, &n, &k, - &pd_a, &ps_a, - &pd_b, &ps_b, - c, - &rs_c, &cs_c - ); + if ( bli_cntx_method( cntx ) == BLIS_NAT ) + { + bli_gemm_md_ker_var2_recast + ( + &dt_exec, + bli_obj_dt( a ), + bli_obj_dt( b ), + &dt_c, + &m, &n, &k, + &pd_a, &ps_a, + &pd_b, &ps_b, + c, + &rs_c, &cs_c + ); + } #endif - // Index into the type combination array to extract the correct - // function pointer. - f = ftypes[dt_exec]; - - // Invoke the function. 
- f( schema_a, - schema_b, - m, - n, - k, - buf_alpha, - buf_a, cs_a, is_a, - pd_a, ps_a, - buf_b, rs_b, is_b, - pd_b, ps_b, - buf_beta, - buf_c, rs_c, cs_c, - cntx, - rntm, - thread ); -} + siz_t dt_size = bli_dt_size( dt_exec ); + siz_t dt_c_size = bli_dt_size( dt_c ); + /* Alias some constants to simpler names. */ + const dim_t MR = pd_a; + const dim_t NR = pd_b; + /*const dim_t PACKMR = cs_a;*/ + /*const dim_t PACKNR = rs_b;*/ -#undef GENTFUNC -#define GENTFUNC( ctype, ch, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - pack_t schema_a, \ - pack_t schema_b, \ - dim_t m, \ - dim_t n, \ - dim_t k, \ - void* alpha, \ - void* a, inc_t cs_a, inc_t is_a, \ - dim_t pd_a, inc_t ps_a, \ - void* b, inc_t rs_b, inc_t is_b, \ - dim_t pd_b, inc_t ps_b, \ - void* beta, \ - void* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntm, \ - thrinfo_t* thread \ - ) \ -{ \ - const num_t dt = PASTEMAC(ch,type); \ -\ - /* Alias some constants to simpler names. */ \ - const dim_t MR = pd_a; \ - const dim_t NR = pd_b; \ - /*const dim_t PACKMR = cs_a;*/ \ - /*const dim_t PACKNR = rs_b;*/ \ -\ /* Query the context for the micro-kernel address and cast it to its - function pointer type. Note that the virtual gemm ukernel is queried - instead of the native gemm ukernel. This is needed for certain - situations for the 1m method that require an extra layer of logic - to allow for handling (for example) complex values of beta. Also - note that under certain circumstances, the real-domain version of - this macrokernel will be called for 1m (NOT the complex version) - as an optimization. In these cases, the corresponding real-domain - slots within the cntx_t's virtual gemm ukernel func_t will contain - pointers to the *native* gemm ukernel, thanks to logic in the - context initialization function for the induced method (defined - in bli_cntx_ref.c). */ \ - PASTECH(ch,gemm_ukr_ft) \ - gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ -\ + function pointer type. */ + gemm_ukr_vft gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt_exec, BLIS_GEMM_UKR, cntx ); + + gemm_ker_params_t* params = bli_obj_ker_params( c ); + gemm_ukr_vft user_ukr = params ? params->ukr : NULL; + if ( user_ukr ) gemm_ukr = user_ukr; + /* Temporary C buffer for edge cases. Note that the strides of this temporary buffer are set so that they match the storage of the original C matrix. For example, if C is column-stored, ct will be - column-stored as well. */ \ - ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ - const inc_t rs_ct = ( col_pref ? 1 : NR ); \ - const inc_t cs_ct = ( col_pref ? MR : 1 ); \ -\ - ctype* restrict zero = PASTEMAC(ch,0); \ - ctype* restrict a_cast = a; \ - ctype* restrict b_cast = b; \ - ctype* restrict c_cast = c; \ - ctype* restrict alpha_cast = alpha; \ - ctype* restrict beta_cast = beta; \ - ctype* restrict b1; \ - ctype* restrict c1; \ -\ - dim_t m_iter, m_left; \ - dim_t n_iter, n_left; \ - dim_t i, j; \ - dim_t m_cur; \ - dim_t n_cur; \ - inc_t rstep_a; \ - inc_t cstep_b; \ - inc_t rstep_c, cstep_c; \ - auxinfo_t aux; \ -\ + column-stored as well. */ + char ct[ BLIS_STACK_BUF_MAX_SIZE ] + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); + const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt_exec, BLIS_GEMM_UKR, cntx ); + const inc_t rs_ct = ( col_pref ? 1 : NR ); + const inc_t cs_ct = ( col_pref ? 
MR : 1 ); + char* zero = bli_obj_buffer_for_const( dt_exec, &BLIS_ZERO ); + /* Assumptions/assertions: rs_a == 1 @@ -275,149 +213,146 @@ void PASTEMAC(ch,varname) \ ps_b == stride to next micro-panel of B rs_c == (no assumptions) cs_c == (no assumptions) - */ \ -\ - /* If any dimension is zero, return immediately. */ \ - if ( bli_zero_dim3( m, n, k ) ) return; \ -\ - /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ - PASTEMAC(ch,set0s_mxn)( MR, NR, \ - ct, rs_ct, cs_ct ); \ -\ + */ + /* Compute number of primary and leftover components of the m and n - dimensions. */ \ - n_iter = n / NR; \ - n_left = n % NR; \ -\ - m_iter = m / MR; \ - m_left = m % MR; \ -\ - if ( n_left ) ++n_iter; \ - if ( m_left ) ++m_iter; \ -\ - /* Determine some increments used to step through A, B, and C. */ \ - rstep_a = ps_a; \ -\ - cstep_b = ps_b; \ -\ - rstep_c = rs_c * MR; \ - cstep_c = cs_c * NR; \ -\ - /* Save the pack schemas of A and B to the auxinfo_t object. */ \ - bli_auxinfo_set_schema_a( schema_a, &aux ); \ - bli_auxinfo_set_schema_b( schema_b, &aux ); \ -\ - /* Save the imaginary stride of A and B to the auxinfo_t object. */ \ - bli_auxinfo_set_is_a( is_a, &aux ); \ - bli_auxinfo_set_is_b( is_b, &aux ); \ -\ + dimensions. */ + dim_t n_iter = n / NR; + dim_t n_left = n % NR; + + dim_t m_iter = m / MR; + dim_t m_left = m % MR; + + if ( n_left ) ++n_iter; + if ( m_left ) ++m_iter; + + /* Determine some increments used to step through A, B, and C. */ + inc_t rstep_a = ps_a * dt_size; + + inc_t cstep_b = ps_b * dt_size; + + inc_t rstep_c = rs_c * MR * dt_c_size; + inc_t cstep_c = cs_c * NR * dt_c_size; + + auxinfo_t aux; + + /* Save the pack schemas of A and B to the auxinfo_t object. */ + bli_auxinfo_set_schema_a( schema_a, &aux ); + bli_auxinfo_set_schema_b( schema_b, &aux ); + + /* Save the imaginary stride of A and B to the auxinfo_t object. */ + bli_auxinfo_set_is_a( is_a, &aux ); + bli_auxinfo_set_is_b( is_b, &aux ); + + /* Save the virtual microkernel address and the params. */ + bli_auxinfo_set_ukr( gemm_ukr, &aux ); + bli_auxinfo_set_params( params, &aux ); + /* The 'thread' argument points to the thrinfo_t node for the 2nd (jr) loop around the microkernel. Here we query the thrinfo_t node for the - 1st (ir) loop around the microkernel. */ \ - thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); \ -\ - /* Query the number of threads and thread ids for each loop. */ \ - dim_t jr_nt = bli_thread_n_way( thread ); \ - dim_t jr_tid = bli_thread_work_id( thread ); \ - dim_t ir_nt = bli_thread_n_way( caucus ); \ - dim_t ir_tid = bli_thread_work_id( caucus ); \ -\ - dim_t jr_start, jr_end; \ - dim_t ir_start, ir_end; \ - dim_t jr_inc, ir_inc; \ -\ + 1st (ir) loop around the microkernel. */ + thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); + + /* Query the number of threads and thread ids for each loop. */ + dim_t jr_nt = bli_thread_n_way( thread ); + dim_t jr_tid = bli_thread_work_id( thread ); + dim_t ir_nt = bli_thread_n_way( caucus ); + dim_t ir_tid = bli_thread_work_id( caucus ); + + dim_t jr_start, jr_end; + dim_t ir_start, ir_end; + dim_t jr_inc, ir_inc; + /* Determine the thread range and increment for the 2nd and 1st loops. NOTE: The definition of bli_thread_range_jrir() will depend on whether - slab or round-robin partitioning was requested at configure-time. 
*/ \ - bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \ - bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); \ -\ - /* Loop over the n dimension (NR columns at a time). */ \ - for ( j = jr_start; j < jr_end; j += jr_inc ) \ - { \ - ctype* restrict a1; \ - ctype* restrict c11; \ - ctype* restrict b2; \ -\ - b1 = b_cast + j * cstep_b; \ - c1 = c_cast + j * cstep_c; \ -\ - n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \ -\ - /* Initialize our next panel of B to be the current panel of B. */ \ - b2 = b1; \ -\ - /* Loop over the m dimension (MR rows at a time). */ \ - for ( i = ir_start; i < ir_end; i += ir_inc ) \ - { \ - ctype* restrict a2; \ -\ - a1 = a_cast + i * rstep_a; \ - c11 = c1 + i * rstep_c; \ -\ - m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ -\ - /* Compute the addresses of the next panels of A and B. */ \ - a2 = bli_gemm_get_next_a_upanel( a1, rstep_a, ir_inc ); \ - if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) ) \ - { \ - a2 = a_cast; \ - b2 = bli_gemm_get_next_b_upanel( b1, cstep_b, jr_inc ); \ - if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) ) \ - b2 = b_cast; \ - } \ -\ + slab or round-robin partitioning was requested at configure-time. */ + bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); + bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); + + /* Loop over the n dimension (NR columns at a time). */ + for ( dim_t j = jr_start; j < jr_end; j += jr_inc ) + { + char* b1 = b_cast + j * cstep_b; + char* c1 = c_cast + j * cstep_c; + + dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); + + /* Initialize our next panel of B to be the current panel of B. */ + char* b2 = b1; + + /* Loop over the m dimension (MR rows at a time). */ + for ( dim_t i = ir_start; i < ir_end; i += ir_inc ) + { + char* a1 = a_cast + i * rstep_a; + char* c11 = c1 + i * rstep_c; + + dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); + + /* Compute the addresses of the next panels of A and B. */ + char* a2 = bli_gemm_get_next_a_upanel( a1, rstep_a, ir_inc ); + if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) ) + { + a2 = a_cast; + b2 = bli_gemm_get_next_b_upanel( b1, cstep_b, jr_inc ); + if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) ) + b2 = b_cast; + } + /* Save addresses of next panels of A and B to the auxinfo_t - object. */ \ - bli_auxinfo_set_next_a( a2, &aux ); \ - bli_auxinfo_set_next_b( b2, &aux ); \ -\ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - beta_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale the bottom edge of C and add the result from above. */ \ - PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_cast, \ - c11, rs_c, cs_c ); \ - } \ - } \ - } \ -\ + object. */ + bli_auxinfo_set_next_a( a2, &aux ); + bli_auxinfo_set_next_b( b2, &aux ); + + if ( dt_exec == dt_c ) + { + /* Invoke the gemm micro-kernel. */ + gemm_ukr + ( + m_cur, + n_cur, + k, + alpha_cast, + a1, + b1, + beta_cast, + c11, rs_c, cs_c, + &aux, + cntx + ); + } + else + { + /* Invoke the gemm micro-kernel. 
*/ gemm_ukr ( MR, NR, k, alpha_cast, a1, b1, zero, &ct, rs_ct, cs_ct, &aux, cntx ); + + /* Accumulate to C with type-casting. */ + xpbys_mxn[ dt_exec ][ dt_c ] + ( + m_cur, n_cur, + &ct, rs_ct, cs_ct, + beta_cast, + c11, rs_c, cs_c + ); + } + } + } + /* -PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: b1", k, NR, b1, NR, 1, "%4.1f", "" ); \ -PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: a1", MR, k, a1, 1, MR, "%4.1f", "" ); \ -PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: c after", m_cur, n_cur, c11, rs_c, cs_c, "%4.1f", "" ); \ -*/ \ +PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: b1", k, NR, b1, NR, 1, "%4.1f", "" ); +PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: a1", MR, k, a1, 1, MR, "%4.1f", "" ); +PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: c after", m_cur, n_cur, c11, rs_c, cs_c, "%4.1f", "" ); +*/ } -INSERT_GENTFUNC_BASIC0( gemm_ker_var2 ) - diff --git a/frame/3/gemm/bli_gemm_ker_var2_md.c b/frame/3/gemm/bli_gemm_ker_var2_md.c deleted file mode 100644 index 09c279d149..0000000000 --- a/frame/3/gemm/bli_gemm_ker_var2_md.c +++ /dev/null @@ -1,406 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- -*/ - -#include "blis.h" - -#ifdef BLIS_ENABLE_GEMM_MD - -#define FUNCPTR_T gemm_fp - -typedef void (*FUNCPTR_T) - ( - pack_t schema_a, - pack_t schema_b, - dim_t m, - dim_t n, - dim_t k, - void* alpha, - void* a, inc_t cs_a, inc_t is_a, - dim_t pd_a, inc_t ps_a, - void* b, inc_t rs_b, inc_t is_b, - dim_t pd_b, inc_t ps_b, - void* beta, - void* c, inc_t rs_c, inc_t cs_c, - cntx_t* cntx, - rntm_t* rntm, - thrinfo_t* thread - ); - -static FUNCPTR_T GENARRAY2_ALL(ftypes,gemm_ker_var2_md); - - -void bli_gemm_ker_var2_md - ( - obj_t* a, - obj_t* b, - obj_t* c, - cntx_t* cntx, - rntm_t* rntm, - cntl_t* cntl, - thrinfo_t* thread - ) -{ - num_t dt_exec = bli_obj_exec_dt( c ); - num_t dt_c = bli_obj_dt( c ); - - pack_t schema_a = bli_obj_pack_schema( a ); - pack_t schema_b = bli_obj_pack_schema( b ); - - dim_t m = bli_obj_length( c ); - dim_t n = bli_obj_width( c ); - dim_t k = bli_obj_width( a ); - - void* buf_a = bli_obj_buffer_at_off( a ); - inc_t cs_a = bli_obj_col_stride( a ); - inc_t is_a = bli_obj_imag_stride( a ); - dim_t pd_a = bli_obj_panel_dim( a ); - inc_t ps_a = bli_obj_panel_stride( a ); - - void* buf_b = bli_obj_buffer_at_off( b ); - inc_t rs_b = bli_obj_row_stride( b ); - inc_t is_b = bli_obj_imag_stride( b ); - dim_t pd_b = bli_obj_panel_dim( b ); - inc_t ps_b = bli_obj_panel_stride( b ); - - void* buf_c = bli_obj_buffer_at_off( c ); - inc_t rs_c = bli_obj_row_stride( c ); - inc_t cs_c = bli_obj_col_stride( c ); - - obj_t scalar_a; - obj_t scalar_b; - - void* buf_alpha; - void* buf_beta; - - FUNCPTR_T f; - - // Detach and multiply the scalars attached to A and B. - // NOTE: We know that the internal scalars of A and B are already of the - // target datatypes because the necessary typecasting would have already - // taken place during bli_packm_init(). - bli_obj_scalar_detach( a, &scalar_a ); - bli_obj_scalar_detach( b, &scalar_b ); - bli_mulsc( &scalar_a, &scalar_b ); - - // Grab the addresses of the internal scalar buffers for the scalar - // merged above and the scalar attached to C. - // NOTE: We know that scalar_b is of type dt_exec due to the above code - // that casts the scalars of A and B to dt_exec via scalar_a and scalar_b, - // and we know that the internal scalar in C is already of the type dt_c - // due to the casting in the implementation of bli_obj_scalar_attach(). - buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b ); - buf_beta = bli_obj_internal_scalar_buffer( c ); - -#if 0 - // NOTE: Turns out that this optimization will never be employed since - // currently bli_gemm_ker_var2_md() is only called when the storage - // datatype of C differs from the execution/computation datatype, and - // this optimization would only make sense if they are equal. - - // If 1m is being employed on a column- or row-stored matrix with a - // real-valued beta, we can use the real domain macro-kernel, which - // eliminates a little overhead associated with the 1m virtual - // micro-kernel. - if ( bli_cntx_method( cntx ) == BLIS_1M ) - { - // Only employ this optimization if the storage datatype of C is - // equal to the execution/computation datatype. - if ( dt_c == dt_exec ) - { - bli_gemm_ind_recast_1m_params - ( - &dt_exec, - schema_a, - c, - &m, &n, &k, - &pd_a, &ps_a, - &pd_b, &ps_b, - &rs_c, &cs_c - ); - } - } -#endif - - // Tweak parameters in select mixed domain cases (rcc, crc, ccr). 
- bli_gemm_md_ker_var2_recast - ( - &dt_exec, - bli_obj_dt( a ), - bli_obj_dt( b ), - bli_obj_dt( c ), - &m, &n, &k, - &pd_a, &ps_a, - &pd_b, &ps_b, - c, - &rs_c, &cs_c - ); - - // Index into the type combination array to extract the correct - // function pointer. - f = ftypes[dt_c][dt_exec]; - - // Invoke the function. - f( schema_a, - schema_b, - m, - n, - k, - buf_alpha, - buf_a, cs_a, is_a, - pd_a, ps_a, - buf_b, rs_b, is_b, - pd_b, ps_b, - buf_beta, - buf_c, rs_c, cs_c, - cntx, - rntm, - thread ); -} - - -#undef GENTFUNC2 -#define GENTFUNC2( ctype_c, ctype_e, chc, che, varname ) \ -\ -void PASTEMAC2(chc,che,varname) \ - ( \ - pack_t schema_a, \ - pack_t schema_b, \ - dim_t m, \ - dim_t n, \ - dim_t k, \ - void* alpha, \ - void* a, inc_t cs_a, inc_t is_a, \ - dim_t pd_a, inc_t ps_a, \ - void* b, inc_t rs_b, inc_t is_b, \ - dim_t pd_b, inc_t ps_b, \ - void* beta, \ - void* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntm, \ - thrinfo_t* thread \ - ) \ -{ \ - const num_t dte = PASTEMAC(che,type); \ - /*const num_t dtc = PASTEMAC(chc,type);*/ \ -\ - /* Alias some constants to simpler names. */ \ - const dim_t MR = pd_a; \ - const dim_t NR = pd_b; \ - /*const dim_t PACKMR = cs_a;*/ \ - /*const dim_t PACKNR = rs_b;*/ \ -\ - /* Query the context for the micro-kernel address and cast it to its - function pointer type. */ \ - PASTECH(che,gemm_ukr_ft) \ - gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dte, BLIS_GEMM_UKR, cntx ); \ -\ - /* Temporary C buffer for edge cases. Note that the strides of this - temporary buffer are set so that they match the storage of the - original C matrix. For example, if C is column-stored, ct will be - column-stored as well. */ \ - ctype_e ct[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype_e ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dte, BLIS_GEMM_UKR, cntx ); \ - const inc_t rs_ct = ( col_pref ? 1 : NR ); \ - const inc_t cs_ct = ( col_pref ? MR : 1 ); \ -\ - ctype_e* restrict zero = PASTEMAC(che,0); \ - ctype_e* restrict a_cast = a; \ - ctype_e* restrict b_cast = b; \ - ctype_c* restrict c_cast = c; \ - ctype_e* restrict alpha_cast = alpha; \ - ctype_c* restrict beta_cast = beta; \ - ctype_e* restrict b1; \ - ctype_c* restrict c1; \ -\ - dim_t m_iter, m_left; \ - dim_t n_iter, n_left; \ - dim_t i, j; \ - dim_t m_cur; \ - dim_t n_cur; \ - inc_t rstep_a; \ - inc_t cstep_b; \ - inc_t rstep_c, cstep_c; \ - auxinfo_t aux; \ -\ - /* - Assumptions/assertions: - rs_a == 1 - cs_a == PACKMR - pd_a == MR - ps_a == stride to next micro-panel of A - rs_b == PACKNR - cs_b == 1 - pd_b == NR - ps_b == stride to next micro-panel of B - rs_c == (no assumptions) - cs_c == (no assumptions) - */ \ -\ - /* If any dimension is zero, return immediately. */ \ - if ( bli_zero_dim3( m, n, k ) ) return; \ -\ - /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ - PASTEMAC(che,set0s_mxn)( MR, NR, \ - ct, rs_ct, cs_ct ); \ -\ - /* Compute number of primary and leftover components of the m and n - dimensions. */ \ - n_iter = n / NR; \ - n_left = n % NR; \ -\ - m_iter = m / MR; \ - m_left = m % MR; \ -\ - if ( n_left ) ++n_iter; \ - if ( m_left ) ++m_iter; \ -\ - /* Determine some increments used to step through A, B, and C. */ \ - rstep_a = ps_a; \ -\ - cstep_b = ps_b; \ -\ - rstep_c = rs_c * MR; \ - cstep_c = cs_c * NR; \ -\ - /* Save the pack schemas of A and B to the auxinfo_t object. 
*/ \ - bli_auxinfo_set_schema_a( schema_a, &aux ); \ - bli_auxinfo_set_schema_b( schema_b, &aux ); \ -\ - /* Save the imaginary stride of A and B to the auxinfo_t object. */ \ - bli_auxinfo_set_is_a( is_a, &aux ); \ - bli_auxinfo_set_is_b( is_b, &aux ); \ -\ - /* The 'thread' argument points to the thrinfo_t node for the 2nd (jr) - loop around the microkernel. Here we query the thrinfo_t node for the - 1st (ir) loop around the microkernel. */ \ - thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); \ -\ - /* Query the number of threads and thread ids for each loop. */ \ - dim_t jr_nt = bli_thread_n_way( thread ); \ - dim_t jr_tid = bli_thread_work_id( thread ); \ - dim_t ir_nt = bli_thread_n_way( caucus ); \ - dim_t ir_tid = bli_thread_work_id( caucus ); \ -\ - dim_t jr_start, jr_end; \ - dim_t ir_start, ir_end; \ - dim_t jr_inc, ir_inc; \ -\ - /* Determine the thread range and increment for the 2nd and 1st loops. - NOTE: The definition of bli_thread_range_jrir() will depend on whether - slab or round-robin partitioning was requested at configure-time. */ \ - bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \ - bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); \ -\ - /* Loop over the n dimension (NR columns at a time). */ \ - for ( j = jr_start; j < jr_end; j += jr_inc ) \ - { \ - ctype_e* restrict a1; \ - ctype_c* restrict c11; \ - ctype_e* restrict b2; \ -\ - b1 = b_cast + j * cstep_b; \ - c1 = c_cast + j * cstep_c; \ -\ - n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \ -\ - /* Initialize our next panel of B to be the current panel of B. */ \ - b2 = b1; \ -\ - /* Loop over the m dimension (MR rows at a time). */ \ - for ( i = ir_start; i < ir_end; i += ir_inc ) \ - { \ - ctype_e* restrict a2; \ -\ - a1 = a_cast + i * rstep_a; \ - c11 = c1 + i * rstep_c; \ -\ - m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ -\ - /* Compute the addresses of the next panels of A and B. */ \ - a2 = bli_gemm_get_next_a_upanel( a1, rstep_a, ir_inc ); \ - if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) ) \ - { \ - a2 = a_cast; \ - b2 = bli_gemm_get_next_b_upanel( b1, cstep_b, jr_inc ); \ - if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) ) \ - b2 = b_cast; \ - } \ -\ - /* Save addresses of next panels of A and B to the auxinfo_t - object. */ \ - bli_auxinfo_set_next_a( a2, &aux ); \ - bli_auxinfo_set_next_b( b2, &aux ); \ -\ - /* Always save the micropanel product to the local microtile and - then accumulate it into C via the xpbys_mxn macro. */ \ - /*if ( 1 )*/ \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale the microtile of C and add the result from above. 
*/ \ - PASTEMAC3(che,chc,chc,xpbys_mxn) \ - ( \ - m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_cast, \ - c11, rs_c, cs_c \ - ); \ - } \ - } \ - } \ -\ -/* -PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: b1", k, NR, b1, NR, 1, "%4.1f", "" ); \ -PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: a1", MR, k, a1, 1, MR, "%4.1f", "" ); \ -PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: c after", m_cur, n_cur, c11, rs_c, cs_c, "%4.1f", "" ); \ -*/ \ -} - -INSERT_GENTFUNC2_BASIC0( gemm_ker_var2_md ) -INSERT_GENTFUNC2_MIXDP0( gemm_ker_var2_md ) - -#endif diff --git a/frame/3/gemm/bli_gemm_md_c2r_ref.c b/frame/3/gemm/bli_gemm_md_c2r_ref.c index 0bfb596302..13d66ae9a2 100644 --- a/frame/3/gemm/bli_gemm_md_c2r_ref.c +++ b/frame/3/gemm/bli_gemm_md_c2r_ref.c @@ -41,6 +41,8 @@ \ void PASTEMAC2(ch,opname,suf) \ ( \ + dim_t m, \ + dim_t n, \ dim_t k, \ ctype* restrict alpha, \ ctype* restrict a, \ @@ -61,6 +63,9 @@ void PASTEMAC2(ch,opname,suf) \ \ const dim_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \ const dim_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \ +\ + dim_t mr_r = mr; \ + dim_t nr_r = nr; \ \ ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ / sizeof( ctype_r ) ] \ @@ -81,6 +86,9 @@ void PASTEMAC2(ch,opname,suf) \ \ ctype_r* restrict beta_r = &PASTEMAC(ch,real)( *beta ); \ ctype_r* restrict beta_i = &PASTEMAC(ch,imag)( *beta ); \ +\ + dim_t m_use; \ + dim_t n_use; \ \ ctype_r* c_use; \ inc_t rs_c_use; \ @@ -150,13 +158,14 @@ PASTEMAC(chr,fprintm)( stdout, "gemm_ukr: c before", mr, nr, \ be in units of real elements. Note that we don't need to check for general storage here because that case corresponds to the scenario where we are using the ct buffer and its rs_ct/cs_ct strides. */ \ - if ( bli_is_col_stored( rs_c_use, cs_c_use ) ) cs_c_use *= 2; \ - else rs_c_use *= 2; \ -\ + if ( bli_is_col_stored( rs_c_use, cs_c_use ) ) { cs_c_use *= 2; mr_r *= 2; } \ + else { rs_c_use *= 2; nr_r *= 2; }\ \ /* c = beta * c + alpha_r * a * b; */ \ rgemm_ukr \ ( \ + mr_r, \ + nr_r, \ k, \ alpha_r, \ a_r, \ @@ -172,8 +181,8 @@ PASTEMAC(chr,fprintm)( stdout, "gemm_ukr: c before", mr, nr, \ /* Accumulate the final result in ct back to c. */ \ if ( PASTEMAC(ch,eq1)( *beta ) ) \ { \ - for ( j = 0; j < nr; ++j ) \ - for ( i = 0; i < mr; ++i ) \ + for ( j = 0; j < n; ++j ) \ + for ( i = 0; i < m; ++i ) \ { \ PASTEMAC(ch,adds)( *(ct + i*rs_ct + j*cs_ct), \ *(c + i*rs_c + j*cs_c ) ); \ @@ -181,8 +190,8 @@ PASTEMAC(chr,fprintm)( stdout, "gemm_ukr: c before", mr, nr, \ } \ else if ( PASTEMAC(ch,eq0)( *beta ) ) \ { \ - for ( j = 0; j < nr; ++j ) \ - for ( i = 0; i < mr; ++i ) \ + for ( j = 0; j < n; ++j ) \ + for ( i = 0; i < m; ++i ) \ { \ PASTEMAC(ch,copys)( *(ct + i*rs_ct + j*cs_ct), \ *(c + i*rs_c + j*cs_c ) ); \ @@ -190,8 +199,8 @@ PASTEMAC(chr,fprintm)( stdout, "gemm_ukr: c before", mr, nr, \ } \ else \ { \ - for ( j = 0; j < nr; ++j ) \ - for ( i = 0; i < mr; ++i ) \ + for ( j = 0; j < n; ++j ) \ + for ( i = 0; i < m; ++i ) \ { \ PASTEMAC(ch,xpbys)( *(ct + i*rs_ct + j*cs_ct), \ *beta, \ @@ -207,17 +216,21 @@ PASTEMAC(chr,fprintm)( stdout, "gemm_ukr: c before", mr, nr, \ c_use = ( ctype_r* )c; \ rs_c_use = rs_c; \ cs_c_use = cs_c; \ + m_use = m; \ + n_use = n; \ \ /* Convert the strides from being in units of complex elements to be in units of real elements. Note that we don't need to check for general storage here because that case corresponds to the scenario where we are using the ct buffer and its rs_ct/cs_ct strides. 
*/ \ - if ( bli_is_col_stored( rs_c_use, cs_c_use ) ) cs_c_use *= 2; \ - else rs_c_use *= 2; \ + if ( bli_is_col_stored( rs_c_use, cs_c_use ) ) { cs_c_use *= 2; m_use *= 2; } \ + else { rs_c_use *= 2; n_use *= 2; } \ \ /* c = beta * c + alpha_r * a * b; */ \ rgemm_ukr \ ( \ + m_use, \ + n_use, \ k, \ alpha_r, \ a_r, \ diff --git a/frame/3/gemm/bli_gemm_var.h b/frame/3/gemm/bli_gemm_var.h index e7befc5b46..9fb8510101 100644 --- a/frame/3/gemm/bli_gemm_var.h +++ b/frame/3/gemm/bli_gemm_var.h @@ -34,6 +34,16 @@ */ +// +// GEMM kernel parameter struct. +// + +typedef struct +{ + gemm_ukr_vft ukr; +} gemm_ker_params_t; + + // // Prototype object-based interfaces. // @@ -59,32 +69,3 @@ GENPROT( gemm_blk_var3 ) GENPROT( gemm_ker_var1 ) GENPROT( gemm_ker_var2 ) - -// -// Prototype BLAS-like interfaces with void pointer operands. -// - -#undef GENTPROT -#define GENTPROT( ctype, ch, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - pack_t schema_a, \ - pack_t schema_b, \ - dim_t m, \ - dim_t n, \ - dim_t k, \ - void* alpha, \ - void* a, inc_t cs_a, inc_t is_a, \ - dim_t pd_a, inc_t ps_a, \ - void* b, inc_t rs_b, inc_t is_b, \ - dim_t pd_b, inc_t ps_b, \ - void* beta, \ - void* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntm, \ - thrinfo_t* thread \ - ); - -INSERT_GENTPROT_BASIC0( gemm_ker_var2 ) - diff --git a/frame/3/gemmt/bli_gemmt_l_ker_var2.c b/frame/3/gemmt/bli_gemmt_l_ker_var2.c index a995e6c521..6db72fd55e 100644 --- a/frame/3/gemmt/bli_gemmt_l_ker_var2.c +++ b/frame/3/gemmt/bli_gemmt_l_ker_var2.c @@ -279,6 +279,9 @@ void PASTEMAC(ch,varname) \ /* Save the imaginary stride of A and B to the auxinfo_t object. */ \ bli_auxinfo_set_is_a( is_a, &aux ); \ bli_auxinfo_set_is_b( is_b, &aux ); \ +\ + /* Save the desired output datatype (indicating no typecasting). */ \ + /*bli_auxinfo_set_dt_on_output( dt, &aux );*/ \ \ /* The 'thread' argument points to the thrinfo_t node for the 2nd (jr) loop around the microkernel. Here we query the thrinfo_t node for the @@ -381,43 +384,20 @@ void PASTEMAC(ch,varname) \ And if we're strictly above the diagonal, we do nothing and continue. */ \ { \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - beta_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale the edge of C and add the result. */ \ - PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_cast, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + beta_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ } \ } \ } \ @@ -490,6 +470,8 @@ void PASTEMAC(ch,varname) \ /* Invoke the gemm micro-kernel. */ \ gemm_ukr \ ( \ + MR, \ + NR, \ k, \ alpha_cast, \ a1, \ @@ -509,43 +491,20 @@ void PASTEMAC(ch,varname) \ } \ else if ( bli_is_strictly_below_diag_n( diagoffc_ij, m_cur, n_cur ) ) \ { \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - beta_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. 
*/ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale the edge of C and add the result. */ \ - PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_cast, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + beta_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ } \ } \ } \ diff --git a/frame/3/gemmt/bli_gemmt_u_ker_var2.c b/frame/3/gemmt/bli_gemmt_u_ker_var2.c index 3115fc67b5..0518cc5416 100644 --- a/frame/3/gemmt/bli_gemmt_u_ker_var2.c +++ b/frame/3/gemmt/bli_gemmt_u_ker_var2.c @@ -281,6 +281,9 @@ void PASTEMAC(ch,varname) \ /* Save the imaginary stride of A and B to the auxinfo_t object. */ \ bli_auxinfo_set_is_a( is_a, &aux ); \ bli_auxinfo_set_is_b( is_b, &aux ); \ +\ + /* Save the desired output datatype (indicating no typecasting). */ \ + /*bli_auxinfo_set_dt_on_output( dt, &aux );*/ \ \ /* The 'thread' argument points to the thrinfo_t node for the 2nd (jr) loop around the microkernel. Here we query the thrinfo_t node for the @@ -385,6 +388,8 @@ void PASTEMAC(ch,varname) \ /* Invoke the gemm micro-kernel. */ \ gemm_ukr \ ( \ + MR, \ + NR, \ k, \ alpha_cast, \ a1, \ @@ -404,43 +409,20 @@ void PASTEMAC(ch,varname) \ } \ else if ( bli_is_strictly_above_diag_n( diagoffc_ij, m_cur, n_cur ) ) \ { \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - beta_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale the edge of C and add the result. */ \ - PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_cast, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + beta_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ } \ } \ } \ @@ -512,43 +494,20 @@ void PASTEMAC(ch,varname) \ And if we're strictly below the diagonal, we do nothing and continue. */ \ { \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - beta_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale the edge of C and add the result. */ \ - PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_cast, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + beta_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ } \ } \ } \ diff --git a/frame/3/gemmt/other/bli_gemmt_l_ker_var2.c b/frame/3/gemmt/other/bli_gemmt_l_ker_var2.c deleted file mode 100644 index 0bf4b1a0fb..0000000000 --- a/frame/3/gemmt/other/bli_gemmt_l_ker_var2.c +++ /dev/null @@ -1,409 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. 
- - Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" - -#define FUNCPTR_T gemmt_fp - -typedef void (*FUNCPTR_T) - ( - doff_t diagoffc, - pack_t schema_a, - pack_t schema_b, - dim_t m, - dim_t n, - dim_t k, - void* alpha, - void* a, inc_t cs_a, inc_t is_a, - dim_t pd_a, inc_t ps_a, - void* b, inc_t rs_b, inc_t is_b, - dim_t pd_b, inc_t ps_b, - void* beta, - void* c, inc_t rs_c, inc_t cs_c, - cntx_t* cntx, - rntm_t* rntm, - thrinfo_t* thread - ); - -static FUNCPTR_T GENARRAY(ftypes,gemmt_l_ker_var2); - - -void bli_gemmt_l_ker_var2 - ( - obj_t* a, - obj_t* b, - obj_t* c, - cntx_t* cntx, - rntm_t* rntm, - cntl_t* cntl, - thrinfo_t* thread - ) -{ - num_t dt_exec = bli_obj_exec_dt( c ); - - doff_t diagoffc = bli_obj_diag_offset( c ); - - pack_t schema_a = bli_obj_pack_schema( a ); - pack_t schema_b = bli_obj_pack_schema( b ); - - dim_t m = bli_obj_length( c ); - dim_t n = bli_obj_width( c ); - dim_t k = bli_obj_width( a ); - - void* buf_a = bli_obj_buffer_at_off( a ); - inc_t cs_a = bli_obj_col_stride( a ); - inc_t is_a = bli_obj_imag_stride( a ); - dim_t pd_a = bli_obj_panel_dim( a ); - inc_t ps_a = bli_obj_panel_stride( a ); - - void* buf_b = bli_obj_buffer_at_off( b ); - inc_t rs_b = bli_obj_row_stride( b ); - inc_t is_b = bli_obj_imag_stride( b ); - dim_t pd_b = bli_obj_panel_dim( b ); - inc_t ps_b = bli_obj_panel_stride( b ); - - void* buf_c = bli_obj_buffer_at_off( c ); - inc_t rs_c = bli_obj_row_stride( c ); - inc_t cs_c = bli_obj_col_stride( c ); - - obj_t scalar_a; - obj_t scalar_b; - - void* buf_alpha; - void* buf_beta; - - FUNCPTR_T f; - - // Detach and multiply the scalars attached to A and B. - bli_obj_scalar_detach( a, &scalar_a ); - bli_obj_scalar_detach( b, &scalar_b ); - bli_mulsc( &scalar_a, &scalar_b ); - - // Grab the addresses of the internal scalar buffers for the scalar - // merged above and the scalar attached to C. 
- buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b ); - buf_beta = bli_obj_internal_scalar_buffer( c ); - - // Index into the type combination array to extract the correct - // function pointer. - f = ftypes[dt_exec]; - - // Invoke the function. - f( diagoffc, - schema_a, - schema_b, - m, - n, - k, - buf_alpha, - buf_a, cs_a, is_a, - pd_a, ps_a, - buf_b, rs_b, is_b, - pd_b, ps_b, - buf_beta, - buf_c, rs_c, cs_c, - cntx, - rntm, - thread ); -} - - -#undef GENTFUNC -#define GENTFUNC( ctype, ch, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - doff_t diagoffc, \ - pack_t schema_a, \ - pack_t schema_b, \ - dim_t m, \ - dim_t n, \ - dim_t k, \ - void* alpha, \ - void* a, inc_t cs_a, inc_t is_a, \ - dim_t pd_a, inc_t ps_a, \ - void* b, inc_t rs_b, inc_t is_b, \ - dim_t pd_b, inc_t ps_b, \ - void* beta, \ - void* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntm, \ - thrinfo_t* thread \ - ) \ -{ \ - const num_t dt = PASTEMAC(ch,type); \ -\ - /* Alias some constants to simpler names. */ \ - const dim_t MR = pd_a; \ - const dim_t NR = pd_b; \ - /*const dim_t PACKMR = cs_a;*/ \ - /*const dim_t PACKNR = rs_b;*/ \ -\ - /* Query the context for the micro-kernel address and cast it to its - function pointer type. */ \ - PASTECH(ch,gemm_ukr_ft) \ - gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ -\ - /* Temporary C buffer for edge cases. Note that the strides of this - temporary buffer are set so that they match the storage of the - original C matrix. For example, if C is column-stored, ct will be - column-stored as well. */ \ - ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ - const inc_t rs_ct = ( col_pref ? 1 : NR ); \ - const inc_t cs_ct = ( col_pref ? MR : 1 ); \ -\ - ctype* restrict zero = PASTEMAC(ch,0); \ - ctype* restrict a_cast = a; \ - ctype* restrict b_cast = b; \ - ctype* restrict c_cast = c; \ - ctype* restrict alpha_cast = alpha; \ - ctype* restrict beta_cast = beta; \ - ctype* restrict b1; \ - ctype* restrict c1; \ -\ - doff_t diagoffc_ij; \ - dim_t m_iter, m_left; \ - dim_t n_iter, n_left; \ - dim_t m_cur; \ - dim_t n_cur; \ - dim_t i, j, ip; \ - inc_t rstep_a; \ - inc_t cstep_b; \ - inc_t rstep_c, cstep_c; \ - auxinfo_t aux; \ -\ - /* - Assumptions/assertions: - rs_a == 1 - cs_a == PACKMR - pd_a == MR - ps_a == stride to next micro-panel of A - rs_b == PACKNR - cs_b == 1 - pd_b == NR - ps_b == stride to next micro-panel of B - rs_c == (no assumptions) - cs_c == (no assumptions) - */ \ -\ - /* If any dimension is zero, return immediately. */ \ - if ( bli_zero_dim3( m, n, k ) ) return; \ -\ - /* Safeguard: If the current panel of C is entirely above the diagonal, - it is not stored. So we do nothing. */ \ - if ( bli_is_strictly_above_diag_n( diagoffc, m, n ) ) return; \ -\ - /* If there is a zero region above where the diagonal of C intersects - the left edge of the panel, adjust the pointer to C and A and treat - this case as if the diagonal offset were zero. */ \ - if ( diagoffc < 0 ) \ - { \ - ip = -diagoffc / MR; \ - i = ip * MR; \ - m = m - i; \ - diagoffc = -diagoffc % MR; \ - c_cast = c_cast + (i )*rs_c; \ - a_cast = a_cast + (ip )*ps_a; \ - } \ -\ - /* If there is a zero region to the right of where the diagonal - of C intersects the bottom of the panel, shrink it to prevent - "no-op" iterations from executing. 
*/ \ - if ( diagoffc + m < n ) \ - { \ - n = diagoffc + m; \ - } \ -\ - /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ - PASTEMAC(ch,set0s_mxn)( MR, NR, \ - ct, rs_ct, cs_ct ); \ -\ - /* Compute number of primary and leftover components of the m and n - dimensions. */ \ - n_iter = n / NR; \ - n_left = n % NR; \ -\ - m_iter = m / MR; \ - m_left = m % MR; \ -\ - if ( n_left ) ++n_iter; \ - if ( m_left ) ++m_iter; \ -\ - /* Determine some increments used to step through A, B, and C. */ \ - rstep_a = ps_a; \ -\ - cstep_b = ps_b; \ -\ - rstep_c = rs_c * MR; \ - cstep_c = cs_c * NR; \ -\ - /* Save the pack schemas of A and B to the auxinfo_t object. */ \ - bli_auxinfo_set_schema_a( schema_a, &aux ); \ - bli_auxinfo_set_schema_b( schema_b, &aux ); \ -\ - /* Save the imaginary stride of A and B to the auxinfo_t object. */ \ - bli_auxinfo_set_is_a( is_a, &aux ); \ - bli_auxinfo_set_is_b( is_b, &aux ); \ -\ - b1 = b_cast; \ - c1 = c_cast; \ -\ - thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); \ - dim_t jr_num_threads = bli_thread_n_way( thread ); \ - dim_t jr_thread_id = bli_thread_work_id( thread ); \ - dim_t ir_num_threads = bli_thread_n_way( caucus ); \ - dim_t ir_thread_id = bli_thread_work_id( caucus ); \ -\ - /* Loop over the n dimension (NR columns at a time). */ \ - for ( j = jr_thread_id; j < n_iter; j += jr_num_threads ) \ - { \ - ctype* restrict a1; \ - ctype* restrict c11; \ - ctype* restrict b2; \ -\ - b1 = b_cast + j * cstep_b; \ - c1 = c_cast + j * cstep_c; \ -\ - n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \ -\ - /* Initialize our next panel of B to be the current panel of B. */ \ - b2 = b1; \ -\ - /* Interior loop over the m dimension (MR rows at a time). */ \ - for ( i = ir_thread_id; i < m_iter; i += ir_num_threads ) \ - { \ - ctype* restrict a2; \ -\ - a1 = a_cast + i * rstep_a; \ - c11 = c1 + i * rstep_c; \ -\ - /* Compute the diagonal offset for the submatrix at (i,j). */ \ - diagoffc_ij = diagoffc - (doff_t)j*NR + (doff_t)i*MR; \ -\ - m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ -\ - /* Compute the addresses of the next panels of A and B. */ \ - a2 = bli_gemmt_get_next_a_upanel( caucus, a1, rstep_a ); \ - if ( bli_is_last_iter( i, m_iter, ir_thread_id, ir_num_threads ) ) \ - { \ - a2 = a_cast; \ - b2 = bli_gemmt_get_next_b_upanel( thread, b1, cstep_b ); \ - if ( bli_is_last_iter( j, n_iter, jr_thread_id, jr_num_threads ) ) \ - b2 = b_cast; \ - } \ -\ - /* Save addresses of next panels of A and B to the auxinfo_t - object. */ \ - bli_auxinfo_set_next_a( a2, &aux ); \ - bli_auxinfo_set_next_b( b2, &aux ); \ -\ - /* If the diagonal intersects the current MR x NR submatrix, we - compute it the temporary buffer and then add in the elements - on or below the diagonal. - Otherwise, if the submatrix is strictly below the diagonal, - we compute and store as we normally would. - And if we're strictly above the diagonal, we do nothing and - continue. */ \ - if ( bli_intersects_diag_n( diagoffc_ij, m_cur, n_cur ) ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale C and add the result to only the stored part. */ \ - PASTEMAC(ch,xpbys_mxn_l)( diagoffc_ij, \ - m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_cast, \ - c11, rs_c, cs_c ); \ - } \ - else if ( bli_is_strictly_below_diag_n( diagoffc_ij, m_cur, n_cur ) ) \ - { \ - /* Handle interior and edge cases separately. 
*/ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - beta_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale the edge of C and add the result. */ \ - PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_cast, \ - c11, rs_c, cs_c ); \ - } \ - } \ - } \ - } \ -} - -INSERT_GENTFUNC_BASIC0( gemmt_l_ker_var2 ) - diff --git a/frame/3/gemmt/other/bli_gemmt_u_ker_var2.c b/frame/3/gemmt/other/bli_gemmt_u_ker_var2.c deleted file mode 100644 index 1655bea555..0000000000 --- a/frame/3/gemmt/other/bli_gemmt_u_ker_var2.c +++ /dev/null @@ -1,409 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
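The `xpbys_mxn_l` call in the variant deleted above (and in the surviving gemmt kernels) folds the temporary tile into only the stored, lower region of C. Its reference semantics, written as plain loops for double, assuming BLIS's convention that element (i, j) lies on the diagonal when j - i equals the diagonal offset (the real generated macro may additionally short-circuit beta == 0):

```c
#include "blis.h"

/* Reference sketch of xpbys_mxn_l: c := ct + beta*c on and below the
   diagonal; elements strictly above it are left untouched. */
void xpbys_mxn_l_sketch
     (
       doff_t diagoff, dim_t m, dim_t n,
       const double* ct, inc_t rs_ct, inc_t cs_ct,
       const double* beta,
       double* c, inc_t rs_c, inc_t cs_c
     )
{
    for ( dim_t j = 0; j < n; ++j )
        for ( dim_t i = 0; i < m; ++i )
            if ( ( doff_t )j - ( doff_t )i <= diagoff )
                c[ i*rs_c + j*cs_c ] = ct[ i*rs_ct + j*cs_ct ]
                                     + (*beta) * c[ i*rs_c + j*cs_c ];
}
```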
- -*/ - -#include "blis.h" - -#define FUNCPTR_T gemmt_fp - -typedef void (*FUNCPTR_T) - ( - doff_t diagoffc, - pack_t schema_a, - pack_t schema_b, - dim_t m, - dim_t n, - dim_t k, - void* alpha, - void* a, inc_t cs_a, inc_t is_a, - dim_t pd_a, inc_t ps_a, - void* b, inc_t rs_b, inc_t is_b, - dim_t pd_b, inc_t ps_b, - void* beta, - void* c, inc_t rs_c, inc_t cs_c, - cntx_t* cntx, - rntm_t* rntm, - thrinfo_t* thread - ); - -static FUNCPTR_T GENARRAY(ftypes,gemmt_u_ker_var2); - - -void bli_gemmt_u_ker_var2 - ( - obj_t* a, - obj_t* b, - obj_t* c, - cntx_t* cntx, - rntm_t* rntm, - cntl_t* cntl, - thrinfo_t* thread - ) -{ - num_t dt_exec = bli_obj_exec_dt( c ); - - doff_t diagoffc = bli_obj_diag_offset( c ); - - pack_t schema_a = bli_obj_pack_schema( a ); - pack_t schema_b = bli_obj_pack_schema( b ); - - dim_t m = bli_obj_length( c ); - dim_t n = bli_obj_width( c ); - dim_t k = bli_obj_width( a ); - - void* buf_a = bli_obj_buffer_at_off( a ); - inc_t cs_a = bli_obj_col_stride( a ); - inc_t is_a = bli_obj_imag_stride( a ); - dim_t pd_a = bli_obj_panel_dim( a ); - inc_t ps_a = bli_obj_panel_stride( a ); - - void* buf_b = bli_obj_buffer_at_off( b ); - inc_t rs_b = bli_obj_row_stride( b ); - inc_t is_b = bli_obj_imag_stride( b ); - dim_t pd_b = bli_obj_panel_dim( b ); - inc_t ps_b = bli_obj_panel_stride( b ); - - void* buf_c = bli_obj_buffer_at_off( c ); - inc_t rs_c = bli_obj_row_stride( c ); - inc_t cs_c = bli_obj_col_stride( c ); - - obj_t scalar_a; - obj_t scalar_b; - - void* buf_alpha; - void* buf_beta; - - FUNCPTR_T f; - - // Detach and multiply the scalars attached to A and B. - bli_obj_scalar_detach( a, &scalar_a ); - bli_obj_scalar_detach( b, &scalar_b ); - bli_mulsc( &scalar_a, &scalar_b ); - - // Grab the addresses of the internal scalar buffers for the scalar - // merged above and the scalar attached to C. - buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b ); - buf_beta = bli_obj_internal_scalar_buffer( c ); - - // Index into the type combination array to extract the correct - // function pointer. - f = ftypes[dt_exec]; - - // Invoke the function. - f( diagoffc, - schema_a, - schema_b, - m, - n, - k, - buf_alpha, - buf_a, cs_a, is_a, - pd_a, ps_a, - buf_b, rs_b, is_b, - pd_b, ps_b, - buf_beta, - buf_c, rs_c, cs_c, - cntx, - rntm, - thread ); -} - - -#undef GENTFUNC -#define GENTFUNC( ctype, ch, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - doff_t diagoffc, \ - pack_t schema_a, \ - pack_t schema_b, \ - dim_t m, \ - dim_t n, \ - dim_t k, \ - void* alpha, \ - void* a, inc_t cs_a, inc_t is_a, \ - dim_t pd_a, inc_t ps_a, \ - void* b, inc_t rs_b, inc_t is_b, \ - dim_t pd_b, inc_t ps_b, \ - void* beta, \ - void* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntm, \ - thrinfo_t* thread \ - ) \ -{ \ - const num_t dt = PASTEMAC(ch,type); \ -\ - /* Alias some constants to simpler names. */ \ - const dim_t MR = pd_a; \ - const dim_t NR = pd_b; \ - /*const dim_t PACKMR = cs_a;*/ \ - /*const dim_t PACKNR = rs_b;*/ \ -\ - /* Query the context for the micro-kernel address and cast it to its - function pointer type. */ \ - PASTECH(ch,gemm_ukr_ft) \ - gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ -\ - /* Temporary C buffer for edge cases. Note that the strides of this - temporary buffer are set so that they match the storage of the - original C matrix. For example, if C is column-stored, ct will be - column-stored as well. 
*/ \ - ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ - const inc_t rs_ct = ( col_pref ? 1 : NR ); \ - const inc_t cs_ct = ( col_pref ? MR : 1 ); \ -\ - ctype* restrict zero = PASTEMAC(ch,0); \ - ctype* restrict a_cast = a; \ - ctype* restrict b_cast = b; \ - ctype* restrict c_cast = c; \ - ctype* restrict alpha_cast = alpha; \ - ctype* restrict beta_cast = beta; \ - ctype* restrict b1; \ - ctype* restrict c1; \ -\ - doff_t diagoffc_ij; \ - dim_t m_iter, m_left; \ - dim_t n_iter, n_left; \ - dim_t m_cur; \ - dim_t n_cur; \ - dim_t i, j, jp; \ - inc_t rstep_a; \ - inc_t cstep_b; \ - inc_t rstep_c, cstep_c; \ - auxinfo_t aux; \ -\ - /* - Assumptions/assertions: - rs_a == 1 - cs_a == PACKMR - pd_a == MR - ps_a == stride to next micro-panel of A - rs_b == PACKNR - cs_b == 1 - pd_b == NR - ps_b == stride to next micro-panel of B - rs_c == (no assumptions) - cs_c == (no assumptions) - */ \ -\ - /* If any dimension is zero, return immediately. */ \ - if ( bli_zero_dim3( m, n, k ) ) return; \ -\ - /* Safeguard: If the current panel of C is entirely below the diagonal, - it is not stored. So we do nothing. */ \ - if ( bli_is_strictly_below_diag_n( diagoffc, m, n ) ) return; \ -\ - /* If there is a zero region to the left of where the diagonal of C - intersects the top edge of the panel, adjust the pointer to C and B - and treat this case as if the diagonal offset were zero. */ \ - if ( diagoffc > 0 ) \ - { \ - jp = diagoffc / NR; \ - j = jp * NR; \ - n = n - j; \ - diagoffc = diagoffc % NR; \ - c_cast = c_cast + (j )*cs_c; \ - b_cast = b_cast + (jp )*ps_b; \ - } \ -\ - /* If there is a zero region below where the diagonal of C intersects - the right edge of the panel, shrink it to prevent "no-op" iterations - from executing. */ \ - if ( -diagoffc + n < m ) \ - { \ - m = -diagoffc + n; \ - } \ -\ - /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ - PASTEMAC(ch,set0s_mxn)( MR, NR, \ - ct, rs_ct, cs_ct ); \ -\ - /* Compute number of primary and leftover components of the m and n - dimensions. */ \ - n_iter = n / NR; \ - n_left = n % NR; \ -\ - m_iter = m / MR; \ - m_left = m % MR; \ -\ - if ( n_left ) ++n_iter; \ - if ( m_left ) ++m_iter; \ -\ - /* Determine some increments used to step through A, B, and C. */ \ - rstep_a = ps_a; \ -\ - cstep_b = ps_b; \ -\ - rstep_c = rs_c * MR; \ - cstep_c = cs_c * NR; \ -\ - /* Save the pack schemas of A and B to the auxinfo_t object. */ \ - bli_auxinfo_set_schema_a( schema_a, &aux ); \ - bli_auxinfo_set_schema_b( schema_b, &aux ); \ -\ - /* Save the imaginary stride of A and B to the auxinfo_t object. */ \ - bli_auxinfo_set_is_a( is_a, &aux ); \ - bli_auxinfo_set_is_b( is_b, &aux ); \ -\ - b1 = b_cast; \ - c1 = c_cast; \ -\ - thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); \ - dim_t jr_num_threads = bli_thread_n_way( thread ); \ - dim_t jr_thread_id = bli_thread_work_id( thread ); \ - dim_t ir_num_threads = bli_thread_n_way( caucus ); \ - dim_t ir_thread_id = bli_thread_work_id( caucus ); \ -\ - /* Loop over the n dimension (NR columns at a time). */ \ - for ( j = jr_thread_id; j < n_iter; j += jr_num_threads ) \ - { \ - ctype* restrict a1; \ - ctype* restrict c11; \ - ctype* restrict b2; \ -\ - b1 = b_cast + j * cstep_b; \ - c1 = c_cast + j * cstep_c; \ -\ - n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? 
NR : n_left ); \ -\ - /* Initialize our next panel of B to be the current panel of B. */ \ - b2 = b1; \ -\ - /* Interior loop over the m dimension (MR rows at a time). */ \ - for ( i = ir_thread_id; i < m_iter; i += ir_num_threads ) \ - { \ - ctype* restrict a2; \ -\ - a1 = a_cast + i * rstep_a; \ - c11 = c1 + i * rstep_c; \ -\ - /* Compute the diagonal offset for the submatrix at (i,j). */ \ - diagoffc_ij = diagoffc - (doff_t)j*NR + (doff_t)i*MR; \ -\ - m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ -\ - /* Compute the addresses of the next panels of A and B. */ \ - a2 = bli_gemmt_get_next_a_upanel( caucus, a1, rstep_a ); \ - if ( bli_is_last_iter( i, m_iter, ir_thread_id, ir_num_threads ) ) \ - { \ - a2 = a_cast; \ - b2 = bli_gemmt_get_next_b_upanel( thread, b1, cstep_b ); \ - if ( bli_is_last_iter( j, n_iter, jr_thread_id, jr_num_threads ) ) \ - b2 = b_cast; \ - } \ -\ - /* Save addresses of next panels of A and B to the auxinfo_t - object. */ \ - bli_auxinfo_set_next_a( a2, &aux ); \ - bli_auxinfo_set_next_b( b2, &aux ); \ -\ - /* If the diagonal intersects the current MR x NR submatrix, we - compute it the temporary buffer and then add in the elements - on or below the diagonal. - Otherwise, if the submatrix is strictly above the diagonal, - we compute and store as we normally would. - And if we're strictly below the diagonal, we do nothing and - continue. */ \ - if ( bli_intersects_diag_n( diagoffc_ij, m_cur, n_cur ) ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale C and add the result to only the stored part. */ \ - PASTEMAC(ch,xpbys_mxn_u)( diagoffc_ij, \ - m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_cast, \ - c11, rs_c, cs_c ); \ - } \ - else if ( bli_is_strictly_above_diag_n( diagoffc_ij, m_cur, n_cur ) ) \ - { \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - beta_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Scale the edge of C and add the result. */ \ - PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - beta_cast, \ - c11, rs_c, cs_c ); \ - } \ - } \ - } \ - } \ -} - -INSERT_GENTFUNC_BASIC0( gemmt_u_ker_var2 ) - diff --git a/frame/3/trmm/bli_trmm_ll_ker_var2.c b/frame/3/trmm/bli_trmm_ll_ker_var2.c index 792281b530..faf6ca100b 100644 --- a/frame/3/trmm/bli_trmm_ll_ker_var2.c +++ b/frame/3/trmm/bli_trmm_ll_ker_var2.c @@ -167,20 +167,8 @@ void PASTEMAC(ch,varname) \ function pointer type. */ \ PASTECH(ch,gemm_ukr_ft) \ gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ -\ - /* Temporary C buffer for edge cases. Note that the strides of this - temporary buffer are set so that they match the storage of the - original C matrix. For example, if C is column-stored, ct will be - column-stored as well. */ \ - ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ - const inc_t rs_ct = ( col_pref ? 1 : NR ); \ - const inc_t cs_ct = ( col_pref ? 
MR : 1 ); \ \ ctype* restrict one = PASTEMAC(ch,1); \ - ctype* restrict zero = PASTEMAC(ch,0); \ ctype* restrict a_cast = a; \ ctype* restrict b_cast = b; \ ctype* restrict c_cast = c; \ @@ -254,10 +242,6 @@ void PASTEMAC(ch,varname) \ diagoffa = 0; \ c_cast = c_cast + (i )*rs_c; \ } \ -\ - /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ - PASTEMAC(ch,set0s_mxn)( MR, NR, \ - ct, rs_ct, cs_ct ); \ \ /* Compute number of primary and leftover components of the m and n dimensions. */ \ @@ -379,47 +363,20 @@ void PASTEMAC(ch,varname) \ bli_auxinfo_set_next_a( a2, &aux ); \ bli_auxinfo_set_next_b( b2, &aux ); \ \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k_a1011, \ - alpha_cast, \ - a1, \ - b1_i, \ - beta_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Copy edge elements of C to the temporary buffer. */ \ - PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \ - c11, rs_c, cs_c, \ - ct, rs_ct, cs_ct ); \ -\ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k_a1011, \ - alpha_cast, \ - a1, \ - b1_i, \ - beta_cast, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Copy the result to the edge of C. */ \ - PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k_a1011, \ + alpha_cast, \ + a1, \ + b1_i, \ + beta_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ /*}*/ \ \ a1 += ps_a_cur; \ @@ -446,42 +403,20 @@ void PASTEMAC(ch,varname) \ bli_auxinfo_set_next_a( a2, &aux ); \ bli_auxinfo_set_next_b( b2, &aux ); \ \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - one, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Add the result to the edge of C. */ \ - PASTEMAC(ch,adds_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + one, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ /*}*/ \ \ a1 += rstep_a; \ diff --git a/frame/3/trmm/bli_trmm_lu_ker_var2.c b/frame/3/trmm/bli_trmm_lu_ker_var2.c index 69498540b7..f7a3a717ce 100644 --- a/frame/3/trmm/bli_trmm_lu_ker_var2.c +++ b/frame/3/trmm/bli_trmm_lu_ker_var2.c @@ -167,20 +167,8 @@ void PASTEMAC(ch,varname) \ function pointer type. */ \ PASTECH(ch,gemm_ukr_ft) \ gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ -\ - /* Temporary C buffer for edge cases. Note that the strides of this - temporary buffer are set so that they match the storage of the - original C matrix. For example, if C is column-stored, ct will be - column-stored as well. */ \ - ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ - const inc_t rs_ct = ( col_pref ? 1 : NR ); \ - const inc_t cs_ct = ( col_pref ? 
MR : 1 ); \ \ ctype* restrict one = PASTEMAC(ch,1); \ - ctype* restrict zero = PASTEMAC(ch,0); \ ctype* restrict a_cast = a; \ ctype* restrict b_cast = b; \ ctype* restrict c_cast = c; \ @@ -261,10 +249,6 @@ void PASTEMAC(ch,varname) \ { \ m = -diagoffa + k; \ } \ -\ - /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ - PASTEMAC(ch,set0s_mxn)( MR, NR, \ - ct, rs_ct, cs_ct ); \ \ /* Compute number of primary and leftover components of the m and n dimensions. */ \ @@ -386,47 +370,20 @@ void PASTEMAC(ch,varname) \ bli_auxinfo_set_next_a( a2, &aux ); \ bli_auxinfo_set_next_b( b2, &aux ); \ \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k_a1112, \ - alpha_cast, \ - a1, \ - b1_i, \ - beta_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Copy edge elements of C to the temporary buffer. */ \ - PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \ - c11, rs_c, cs_c, \ - ct, rs_ct, cs_ct ); \ -\ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k_a1112, \ - alpha_cast, \ - a1, \ - b1_i, \ - beta_cast, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Copy the result to the edge of C. */ \ - PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k_a1112, \ + alpha_cast, \ + a1, \ + b1_i, \ + beta_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ /*}*/ \ \ a1 += ps_a_cur; \ @@ -453,42 +410,20 @@ void PASTEMAC(ch,varname) \ bli_auxinfo_set_next_a( a2, &aux ); \ bli_auxinfo_set_next_b( b2, &aux ); \ \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - one, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Add the result to the edge of C. */ \ - PASTEMAC(ch,adds_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + one, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ /*}*/ \ \ a1 += rstep_a; \ diff --git a/frame/3/trmm/bli_trmm_rl_ker_var2.c b/frame/3/trmm/bli_trmm_rl_ker_var2.c index 03e3f1e531..195f8577db 100644 --- a/frame/3/trmm/bli_trmm_rl_ker_var2.c +++ b/frame/3/trmm/bli_trmm_rl_ker_var2.c @@ -167,20 +167,8 @@ void PASTEMAC(ch,varname) \ function pointer type. */ \ PASTECH(ch,gemm_ukr_ft) \ gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ -\ - /* Temporary C buffer for edge cases. Note that the strides of this - temporary buffer are set so that they match the storage of the - original C matrix. For example, if C is column-stored, ct will be - column-stored as well. */ \ - ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ - const inc_t rs_ct = ( col_pref ? 1 : NR ); \ - const inc_t cs_ct = ( col_pref ? 
MR : 1 ); \ \ ctype* restrict one = PASTEMAC(ch,1); \ - ctype* restrict zero = PASTEMAC(ch,0); \ ctype* restrict a_cast = a; \ ctype* restrict b_cast = b; \ ctype* restrict c_cast = c; \ @@ -261,10 +249,6 @@ void PASTEMAC(ch,varname) \ { \ n = diagoffb + k; \ } \ -\ - /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ - PASTEMAC(ch,set0s_mxn)( MR, NR, \ - ct, rs_ct, cs_ct ); \ \ /* Compute number of primary and leftover components of the m and n dimensions. */ \ @@ -382,42 +366,20 @@ void PASTEMAC(ch,varname) \ bli_auxinfo_set_next_a( a2, &aux ); \ bli_auxinfo_set_next_b( b2, &aux ); \ \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - one, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Add the result to the edge of C. */ \ - PASTEMAC(ch,adds_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + one, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ } \ } \ } \ @@ -501,47 +463,20 @@ void PASTEMAC(ch,varname) \ bli_auxinfo_set_next_a( a2, &aux ); \ bli_auxinfo_set_next_b( b2, &aux ); \ \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k_b1121, \ - alpha_cast, \ - a1_i, \ - b1, \ - beta_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Copy edge elements of C to the temporary buffer. */ \ - PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \ - c11, rs_c, cs_c, \ - ct, rs_ct, cs_ct ); \ -\ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k_b1121, \ - alpha_cast, \ - a1_i, \ - b1, \ - beta_cast, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Copy the result to the edge of C. */ \ - PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k_b1121, \ + alpha_cast, \ + a1_i, \ + b1, \ + beta_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ } \ \ a1 += rstep_a; \ diff --git a/frame/3/trmm/bli_trmm_ru_ker_var2.c b/frame/3/trmm/bli_trmm_ru_ker_var2.c index 5d63bd46df..47df210f71 100644 --- a/frame/3/trmm/bli_trmm_ru_ker_var2.c +++ b/frame/3/trmm/bli_trmm_ru_ker_var2.c @@ -167,20 +167,8 @@ void PASTEMAC(ch,varname) \ function pointer type. */ \ PASTECH(ch,gemm_ukr_ft) \ gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ -\ - /* Temporary C buffer for edge cases. Note that the strides of this - temporary buffer are set so that they match the storage of the - original C matrix. For example, if C is column-stored, ct will be - column-stored as well. */ \ - ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ - / sizeof( ctype ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ - const inc_t rs_ct = ( col_pref ? 1 : NR ); \ - const inc_t cs_ct = ( col_pref ? 
MR : 1 ); \ \ ctype* restrict one = PASTEMAC(ch,1); \ - ctype* restrict zero = PASTEMAC(ch,0); \ ctype* restrict a_cast = a; \ ctype* restrict b_cast = b; \ ctype* restrict c_cast = c; \ @@ -262,10 +250,6 @@ void PASTEMAC(ch,varname) \ { \ k = -diagoffb + n; \ } \ -\ - /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ - PASTEMAC(ch,set0s_mxn)( MR, NR, \ - ct, rs_ct, cs_ct ); \ \ /* Compute number of primary and leftover components of the m and n dimensions. */ \ @@ -410,47 +394,20 @@ void PASTEMAC(ch,varname) \ bli_auxinfo_set_next_a( a2, &aux ); \ bli_auxinfo_set_next_b( b2, &aux ); \ \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k_b0111, \ - alpha_cast, \ - a1_i, \ - b1, \ - beta_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Copy edge elements of C to the temporary buffer. */ \ - PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \ - c11, rs_c, cs_c, \ - ct, rs_ct, cs_ct ); \ -\ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k_b0111, \ - alpha_cast, \ - a1_i, \ - b1, \ - beta_cast, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Copy the result to the edge of C. */ \ - PASTEMAC(ch,copys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k_b0111, \ + alpha_cast, \ + a1_i, \ + b1, \ + beta_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ } \ \ a1 += rstep_a; \ @@ -533,42 +490,20 @@ void PASTEMAC(ch,varname) \ bli_auxinfo_set_next_a( a2, &aux ); \ bli_auxinfo_set_next_b( b2, &aux ); \ \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - one, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - alpha_cast, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Add the result to the edge of C. */ \ - PASTEMAC(ch,adds_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + one, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ } \ } \ } \ diff --git a/frame/3/trsm/bli_trsm_cntl.c b/frame/3/trsm/bli_trsm_cntl.c index a8196ebb93..e37e1117c8 100644 --- a/frame/3/trsm/bli_trsm_cntl.c +++ b/frame/3/trsm/bli_trsm_cntl.c @@ -40,20 +40,22 @@ cntl_t* bli_trsm_cntl_create rntm_t* rntm, side_t side, pack_t schema_a, - pack_t schema_b + pack_t schema_b, + void_fp ker ) { if ( bli_is_left( side ) ) - return bli_trsm_l_cntl_create( rntm, schema_a, schema_b ); + return bli_trsm_l_cntl_create( rntm, schema_a, schema_b, ker ); else - return bli_trsm_r_cntl_create( rntm, schema_a, schema_b ); + return bli_trsm_r_cntl_create( rntm, schema_a, schema_b, ker ); } cntl_t* bli_trsm_l_cntl_create ( rntm_t* rntm, pack_t schema_a, - pack_t schema_b + pack_t schema_b, + void_fp ker ) { void_fp macro_kernel_p; @@ -61,6 +63,7 @@ cntl_t* bli_trsm_l_cntl_create // Use the function pointer to the macrokernels that use slab // assignment of micropanels to threads in the jr and ir loops. 
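The new trailing `void_fp ker` argument threaded through `bli_trsm_cntl_create` and its helpers lets a caller substitute its own macrokernel when the control tree is built; as the bodies just below show, passing NULL keeps the stock `bli_trsm_xx_ker_var2`. A hypothetical use, where the wrapper name is not a real BLIS symbol:

```c
#include "blis.h"

/* Hypothetical helper: NULL as the last argument selects the default
   macrokernel; a non-NULL function pointer would override it. */
cntl_t* make_left_trsm_cntl_sketch( rntm_t* rntm, pack_t schema_a, pack_t schema_b )
{
    return bli_trsm_cntl_create( rntm, BLIS_LEFT, schema_a, schema_b, NULL );
}
```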
macro_kernel_p = bli_trsm_xx_ker_var2; + if ( ker ) macro_kernel_p = ker; const opid_t family = BLIS_TRSM; @@ -200,13 +203,15 @@ cntl_t* bli_trsm_l_cntl_create cntl_t* bli_trsm_r_cntl_create ( - rntm_t* rntm, + rntm_t* rntm, pack_t schema_a, - pack_t schema_b + pack_t schema_b, + void_fp ker ) { // NOTE: trsm macrokernels are presently disabled for right-side execution. void_fp macro_kernel_p = bli_trsm_xx_ker_var2; + if ( ker ) macro_kernel_p = ker; const opid_t family = BLIS_TRSM; diff --git a/frame/3/trsm/bli_trsm_cntl.h b/frame/3/trsm/bli_trsm_cntl.h index 7fdb1fc4f6..86f4a29b2a 100644 --- a/frame/3/trsm/bli_trsm_cntl.h +++ b/frame/3/trsm/bli_trsm_cntl.h @@ -38,21 +38,24 @@ cntl_t* bli_trsm_cntl_create rntm_t* rntm, side_t side, pack_t schema_a, - pack_t schema_b + pack_t schema_b, + void_fp ker ); cntl_t* bli_trsm_l_cntl_create ( rntm_t* rntm, pack_t schema_a, - pack_t schema_b + pack_t schema_b, + void_fp ker ); cntl_t* bli_trsm_r_cntl_create ( rntm_t* rntm, pack_t schema_a, - pack_t schema_b + pack_t schema_b, + void_fp ker ); void bli_trsm_cntl_free diff --git a/frame/3/trsm/bli_trsm_ll_ker_var2.c b/frame/3/trsm/bli_trsm_ll_ker_var2.c index dec41301ac..cf6a0cee48 100644 --- a/frame/3/trsm/bli_trsm_ll_ker_var2.c +++ b/frame/3/trsm/bli_trsm_ll_ker_var2.c @@ -183,7 +183,6 @@ void PASTEMAC(ch,varname) \ const inc_t rs_ct = ( col_pref ? 1 : NR ); \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ - ctype* restrict zero = PASTEMAC(ch,0); \ ctype* restrict minus_one = PASTEMAC(ch,m1); \ ctype* restrict a_cast = a; \ ctype* restrict b_cast = b; \ @@ -470,43 +469,20 @@ void PASTEMAC(ch,varname) \ bli_auxinfo_set_next_a( a2, &aux ); \ bli_auxinfo_set_next_b( b2, &aux ); \ \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - minus_one, \ - a1, \ - b1, \ - alpha2_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - minus_one, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Add the result to the edge of C. */ \ - PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - alpha2_cast, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k, \ + minus_one, \ + a1, \ + b1, \ + alpha2_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ \ a1 += rstep_a; \ } \ diff --git a/frame/3/trsm/bli_trsm_lu_ker_var2.c b/frame/3/trsm/bli_trsm_lu_ker_var2.c index 1627a12a39..38d5ab0df1 100644 --- a/frame/3/trsm/bli_trsm_lu_ker_var2.c +++ b/frame/3/trsm/bli_trsm_lu_ker_var2.c @@ -183,7 +183,6 @@ void PASTEMAC(ch,varname) \ const inc_t rs_ct = ( col_pref ? 1 : NR ); \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ - ctype* restrict zero = PASTEMAC(ch,0); \ ctype* restrict minus_one = PASTEMAC(ch,m1); \ ctype* restrict a_cast = a; \ ctype* restrict b_cast = b; \ @@ -480,43 +479,20 @@ void PASTEMAC(ch,varname) \ bli_auxinfo_set_next_a( a2, &aux ); \ bli_auxinfo_set_next_b( b2, &aux ); \ \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - minus_one, \ - a1, \ - b1, \ - alpha2_cast, \ - c11, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. 
*/ \ - gemm_ukr \ - ( \ - k, \ - minus_one, \ - a1, \ - b1, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Add the result to the edge of C. */ \ - PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - alpha2_cast, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k, \ + minus_one, \ + a1, \ + b1, \ + alpha2_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ \ a1 += rstep_a; \ } \ diff --git a/frame/3/trsm/bli_trsm_rl_ker_var2.c b/frame/3/trsm/bli_trsm_rl_ker_var2.c index 8cbc26b36a..3fedc9295f 100644 --- a/frame/3/trsm/bli_trsm_rl_ker_var2.c +++ b/frame/3/trsm/bli_trsm_rl_ker_var2.c @@ -188,7 +188,6 @@ void PASTEMAC(ch,varname) \ const inc_t rs_ct = ( col_pref ? 1 : NR ); \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ - ctype* restrict zero = PASTEMAC(ch,0); \ ctype* restrict minus_one = PASTEMAC(ch,m1); \ ctype* restrict a_cast = a; \ ctype* restrict b_cast = b; \ @@ -384,7 +383,7 @@ void PASTEMAC(ch,varname) \ \ /* Compute the addresses of the triangular block B11 and the panel B21. */ \ - b11 = b1; \ + b11 = b1; \ b21 = b1 + k_b11 * PACKNR; \ /*b21 = bli_ptr_inc_by_frac( b1, sizeof( ctype ), k_b11 * PACKNR, 1 );*/ \ \ @@ -499,43 +498,20 @@ void PASTEMAC(ch,varname) \ bli_auxinfo_set_next_a( b2, &aux ); \ bli_auxinfo_set_next_b( a2, &aux ); \ \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - minus_one, \ - b1, \ - a1, \ - alpha2_cast, \ - c11, cs_c, rs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - minus_one, \ - b1, \ - a1, \ - zero, \ - ct, cs_ct, rs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Add the result to the edge of C. */ \ - PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ - ct, rs_ct, cs_ct, \ - alpha2_cast, \ - c11, rs_c, cs_c ); \ - } \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + m_cur, \ + n_cur, \ + k, \ + minus_one, \ + b1, \ + a1, \ + alpha2_cast, \ + c11, cs_c, rs_c, \ + &aux, \ + cntx \ + ); \ } \ \ a1 += rstep_a; \ diff --git a/frame/3/trsm/bli_trsm_ru_ker_var2.c b/frame/3/trsm/bli_trsm_ru_ker_var2.c index 97399d0ae0..72b917b2c4 100644 --- a/frame/3/trsm/bli_trsm_ru_ker_var2.c +++ b/frame/3/trsm/bli_trsm_ru_ker_var2.c @@ -188,7 +188,6 @@ void PASTEMAC(ch,varname) \ const inc_t rs_ct = ( col_pref ? 1 : NR ); \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ - ctype* restrict zero = PASTEMAC(ch,0); \ ctype* restrict minus_one = PASTEMAC(ch,m1); \ ctype* restrict a_cast = a; \ ctype* restrict b_cast = b; \ @@ -377,7 +376,7 @@ void PASTEMAC(ch,varname) \ \ /* Compute the addresses of the panel B10 and the triangular block B11. */ \ - b01 = b1; \ + b01 = b1; \ b11 = b1 + k_b01 * PACKNR; \ /*b11 = bli_ptr_inc_by_frac( b1, sizeof( ctype ), k_b01 * PACKNR, 1 );*/ \ \ @@ -492,43 +491,20 @@ void PASTEMAC(ch,varname) \ bli_auxinfo_set_next_a( b2, &aux ); \ bli_auxinfo_set_next_b( a2, &aux ); \ \ - /* Handle interior and edge cases separately. */ \ - if ( m_cur == MR && n_cur == NR ) \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - minus_one, \ - b1, \ - a1, \ - alpha2_cast, \ - c11, cs_c, rs_c, \ - &aux, \ - cntx \ - ); \ - } \ - else \ - { \ - /* Invoke the gemm micro-kernel. */ \ - gemm_ukr \ - ( \ - k, \ - minus_one, \ - b1, \ - a1, \ - zero, \ - ct, cs_ct, rs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* Add the result to the edge of C. 
 */ \
-				PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \
-				                        ct,  rs_ct, cs_ct, \
-				                        alpha2_cast, \
-				                        c11, rs_c,  cs_c ); \
-			} \
+			/* Invoke the gemm micro-kernel. */ \
+			gemm_ukr \
+			( \
+			  m_cur, \
+			  n_cur, \
+			  k, \
+			  minus_one, \
+			  b1, \
+			  a1, \
+			  alpha2_cast, \
+			  c11, cs_c, rs_c, \
+			  &aux, \
+			  cntx \
+			); \
 		} \
 \
 		a1 += rstep_a; \
diff --git a/frame/base/bli_auxinfo.h b/frame/base/bli_auxinfo.h
index 68b6cc7cd6..d8c6cbb13f 100644
--- a/frame/base/bli_auxinfo.h
+++ b/frame/base/bli_auxinfo.h
@@ -74,6 +74,15 @@ BLIS_INLINE inc_t bli_auxinfo_ps_b( auxinfo_t* ai )
 	return ai->ps_b;
 }
 
+BLIS_INLINE void_fp bli_auxinfo_ukr( auxinfo_t* ai )
+{
+	return ai->ukr;
+}
+BLIS_INLINE void* bli_auxinfo_params( auxinfo_t* ai )
+{
+	return ai->params;
+}
+
 
 // auxinfo_t field modification
 
@@ -118,5 +127,14 @@ BLIS_INLINE void bli_auxinfo_set_ps_b( inc_t ps, auxinfo_t* ai )
 	ai->ps_b = ps;
 }
 
-#endif
+BLIS_INLINE void bli_auxinfo_set_ukr( void_fp ukr, auxinfo_t* ai )
+{
+	ai->ukr = ukr;
+}
+BLIS_INLINE void bli_auxinfo_set_params( void* params, auxinfo_t* ai )
+{
+	ai->params = params;
+}
+
+#endif
 
diff --git a/frame/include/bli_misc_macro_defs.h b/frame/include/bli_misc_macro_defs.h
index 120338beba..b166b7a171 100644
--- a/frame/include/bli_misc_macro_defs.h
+++ b/frame/include/bli_misc_macro_defs.h
@@ -164,5 +164,62 @@ BLIS_INLINE void bli_toggle_bool( bool* b )
 #define bli_iformatspec() "%6d"
 
 
+// helper macros for gemm microkernels
+
+#define GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major) \
+PASTEMAC(ch,ctype)* restrict _beta = beta; \
+PASTEMAC(ch,ctype)* restrict _c = c; \
+const inc_t _rs_c = rs_c; \
+const inc_t _cs_c = cs_c; \
+PASTEMAC(ch,ctype) _ct[ BLIS_STACK_BUF_MAX_SIZE / sizeof( PASTEMAC(ch,ctype) ) ] \
+    __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \
+const inc_t _rs_ct = row_major ? nr : 1; \
+const inc_t _cs_ct = row_major ? 1 : mr;
+
+#define GEMM_UKR_SETUP_CT_POST(ch) \
+PASTEMAC(ch,ctype) _zero; \
+PASTEMAC(ch,set0s)( _zero ); \
+\
+if ( _use_ct ) \
+{ \
+    c = _ct; \
+    rs_c = _rs_ct; \
+    cs_c = _cs_ct; \
+    beta = &_zero; \
+}
+
+#define GEMM_UKR_SETUP_CT(ch,mr,nr,row_major) \
+GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major); \
+const bool _use_ct = ( row_major ? cs_c != 1 : rs_c != 1 ) || \
+                     m != mr || n != nr; \
+GEMM_UKR_SETUP_CT_POST(ch);
+
+#define GEMM_UKR_SETUP_CT_AMBI(ch,mr,nr,row_major) \
+GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major); \
+const bool _use_ct = ( cs_c != 1 && rs_c != 1 ) || \
+                     m != mr || n != nr; \
+GEMM_UKR_SETUP_CT_POST(ch);
+
+#define GEMM_UKR_SETUP_CT_ANY(ch,mr,nr,row_major) \
+GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major); \
+const bool _use_ct = m != mr || n != nr; \
+GEMM_UKR_SETUP_CT_POST(ch);
+
+#define GEMM_UKR_SETUP_CT_ALIGNED(ch,mr,nr,row_major,alignment) \
+GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major); \
+const bool _use_ct = ( row_major ? cs_c != 1 : rs_c != 1 ) || \
+                     m != mr || n != nr || \
+                     ( (uintptr_t)_c % alignment ) || \
+                     ( ( ( row_major ? _rs_c : _cs_c )*sizeof( PASTEMAC(ch,ctype) ) ) % alignment ); \
+GEMM_UKR_SETUP_CT_POST(ch);
+
+#define GEMM_UKR_FLUSH_CT(ch) \
+if ( _use_ct ) \
+    PASTEMAC(ch,xpbys_mxn)( m, n, \
+                            _ct, _rs_ct, _cs_ct, \
+                            _beta, \
+                            _c, _rs_c, _cs_c); \
+
+
 #endif
diff --git a/frame/include/bli_type_defs.h b/frame/include/bli_type_defs.h
index 5be0ceeb42..ba6db2af89 100644
--- a/frame/include/bli_type_defs.h
+++ b/frame/include/bli_type_defs.h
@@ -1144,6 +1144,13 @@ typedef struct
 	inc_t ps_a;
 	inc_t ps_b;
 
+	// The type to convert to on output.
+ //num_t dt_on_output; + + // (Virtual) microkernel address and additional parameters + void_fp ukr; + void* params; + } auxinfo_t; diff --git a/kernels/armsve/3/bli_gemm_armsve_asm_c2vx10_unindexed.c b/kernels/armsve/3/bli_gemm_armsve_asm_c2vx10_unindexed.c index 66337e0b73..b86c901847 100644 --- a/kernels/armsve/3/bli_gemm_armsve_asm_c2vx10_unindexed.c +++ b/kernels/armsve/3/bli_gemm_armsve_asm_c2vx10_unindexed.c @@ -42,6 +42,8 @@ // 2vx10 microkernels. #include "armsve_asm_2vx10cmplx.h" +#include "arm_sve.h" + void bli_cgemm_armsve_asm_2vx10_unindexed ( dim_t k0, @@ -65,6 +67,9 @@ void bli_cgemm_armsve_asm_2vx10_unindexed uint64_t cs_c = cs_c0; uint64_t info = 0; + uint64_t mr = svcntw(); + GEMM_UKR_SETUP_CT( c, mr, 10, false ); + __asm__ volatile ( // " ldr x0, %[a] \n\t" // " ldr x1, %[b] \n\t" @@ -310,5 +315,7 @@ GEMM_CCMPLX_STORE_COL2_G(z8 ,z9 ,z10,z11,p0,z28,%2,%4,x16) "z24","z25","z26","z27", "z28","z29","z30","z31" ); + + GEMM_UKR_FLUSH_CT( c ); } diff --git a/kernels/armsve/3/bli_gemm_armsve_asm_d2vx10_unindexed.c b/kernels/armsve/3/bli_gemm_armsve_asm_d2vx10_unindexed.c index e5b78a5921..9730fb8ce3 100644 --- a/kernels/armsve/3/bli_gemm_armsve_asm_d2vx10_unindexed.c +++ b/kernels/armsve/3/bli_gemm_armsve_asm_d2vx10_unindexed.c @@ -42,9 +42,13 @@ // 2vx10 microkernels. #include "armsve_asm_2vx10.h" +#include "arm_sve.h" + void bli_dgemm_armsve_asm_2vx10_unindexed ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, double* restrict alpha, double* restrict a, double* restrict b, @@ -59,11 +63,14 @@ void bli_dgemm_armsve_asm_2vx10_unindexed // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_mker = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_mker = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + uint64_t mr = 2*svcntd(); + GEMM_UKR_SETUP_CT( d, mr, 10, false ); + __asm__ volatile ( " ldr x0, %[a] \n\t" " ldr x1, %[b] \n\t" @@ -324,5 +331,7 @@ GEMM_C_STORE_UKER_G(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,z30,p0,p0,x5,x7,x8,x "z24","z25","z26","z27", "z28","z29","z30","z31" ); + + GEMM_UKR_FLUSH_CT( d ); } diff --git a/kernels/armsve/3/bli_gemm_armsve_asm_s2vx10_unindexed.c b/kernels/armsve/3/bli_gemm_armsve_asm_s2vx10_unindexed.c index 00b3f20b44..e89daa0e32 100644 --- a/kernels/armsve/3/bli_gemm_armsve_asm_s2vx10_unindexed.c +++ b/kernels/armsve/3/bli_gemm_armsve_asm_s2vx10_unindexed.c @@ -44,7 +44,9 @@ void bli_sgemm_armsve_asm_2vx10_unindexed ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, float* restrict alpha, float* restrict a, float* restrict b, @@ -59,11 +61,14 @@ void bli_sgemm_armsve_asm_2vx10_unindexed // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. 
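The armsve hunks above and below show the intended adoption pattern for the macros added to bli_misc_macro_defs.h. Stripped of assembly, a kernel built on them reduces to the following skeleton; the function name and the 4x8 tile are illustrative only:

```c
#include "blis.h"

/* Skeleton of a kernel using the new edge-case macros (sketch).
   SETUP_CT: if C is not column-stored (row_major == false here) or the
   tile is partial, redirect c to an aligned stack tile and zero beta.
   FLUSH_CT: if the stack tile was used, fold it back into the true
   m x n edge of C as c := ct + beta*c. */
void bli_dgemm_example_ukr_4x8
     (
       dim_t m, dim_t n, dim_t k,
       double*    restrict alpha,
       double*    restrict a,
       double*    restrict b,
       double*    restrict beta,
       double*    restrict c, inc_t rs_c, inc_t cs_c,
       auxinfo_t* restrict data,
       cntx_t*    restrict cntx
     )
{
    GEMM_UKR_SETUP_CT( d, 4, 8, false );

    /* ...unconditional 4x8 rank-k update writes through c/rs_c/cs_c,
       scaling by *alpha and *beta as usual... */

    GEMM_UKR_FLUSH_CT( d );
}
```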
- uint64_t k_mker = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_mker = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + uint64_t mr = 2*svcntw(); + GEMM_UKR_SETUP_CT( s, mr, 10, false ); + __asm__ volatile ( " ldr x0, %[a] \n\t" " ldr x1, %[b] \n\t" @@ -310,5 +315,7 @@ GEMM_C_STORE_UKER_G(z10,z12,z14,z16,z18,z11,z13,z15,z17,z19,z30,p0,p0,x5,x7,x8,x "z24","z25","z26","z27", "z28","z29","z30","z31" ); + + GEMM_UKR_FLUSH_CT( s ); } diff --git a/kernels/armsve/3/bli_gemm_armsve_asm_z2vx10_unindexed.c b/kernels/armsve/3/bli_gemm_armsve_asm_z2vx10_unindexed.c index 2fa37664ae..9ae666358b 100644 --- a/kernels/armsve/3/bli_gemm_armsve_asm_z2vx10_unindexed.c +++ b/kernels/armsve/3/bli_gemm_armsve_asm_z2vx10_unindexed.c @@ -42,6 +42,8 @@ // 2vx10 microkernels. #include "armsve_asm_2vx10cmplx.h" +#include "arm_sve.h" + void bli_zgemm_armsve_asm_2vx10_unindexed ( dim_t k0, @@ -65,6 +67,9 @@ void bli_zgemm_armsve_asm_2vx10_unindexed uint64_t cs_c = cs_c0; uint64_t info = 0; + uint64_t mr = svcntd(); + GEMM_UKR_SETUP_CT( z, mr, 10, false ); + __asm__ volatile ( // " ldr x0, %[a] \n\t" // " ldr x1, %[b] \n\t" @@ -309,5 +314,7 @@ GEMM_CCMPLX_STORE_COL2_G(z8 ,z9 ,z10,z11,p0,z28,%2,%4,x16) "z24","z25","z26","z27", "z28","z29","z30","z31" ); + + GEMM_UKR_FLUSH_CT( z ); } diff --git a/kernels/armsve/3/bli_gemm_armsve_asm_z2vx7_unindexed.c b/kernels/armsve/3/bli_gemm_armsve_asm_z2vx7_unindexed.c index 3d25719d92..fb0b596a31 100644 --- a/kernels/armsve/3/bli_gemm_armsve_asm_z2vx7_unindexed.c +++ b/kernels/armsve/3/bli_gemm_armsve_asm_z2vx7_unindexed.c @@ -42,6 +42,8 @@ // 2vx7 microkernels. #include "armsve_asm_2vx7cmplx.h" +#include "arm_sve.h" + void bli_zgemm_armsve_asm_2vx7_unindexed ( dim_t k0, @@ -65,6 +67,9 @@ void bli_zgemm_armsve_asm_2vx7_unindexed uint64_t cs_c = cs_c0; uint64_t info = 0; + uint64_t mr = svcntd(); + GEMM_UKR_SETUP_CT( z, mr, 7, false ); + __asm__ volatile ( // " ldr x0, %[a] \n\t" // " ldr x1, %[b] \n\t" @@ -261,6 +266,8 @@ GEMM_CCMPLX_STORE_COL7_G(z14,z15,z16,z17,z18,z19,z20,z21,z22,z23,z24,z25,z26,z27 "z24","z25","z26","z27", "z28","z29","z30","z31" ); + + GEMM_UKR_FLUSH_CT( z ); } diff --git a/kernels/armsve/3/bli_gemm_armsve_asm_z2vx8_unindexed.c b/kernels/armsve/3/bli_gemm_armsve_asm_z2vx8_unindexed.c index d0eef4a8ca..13fc26c8c2 100644 --- a/kernels/armsve/3/bli_gemm_armsve_asm_z2vx8_unindexed.c +++ b/kernels/armsve/3/bli_gemm_armsve_asm_z2vx8_unindexed.c @@ -42,6 +42,8 @@ // 2vx8 microkernels. 
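These kernels are vector-length-agnostic, so MR is derived from the hardware register width at run time rather than hard-coded; the `mr` expressions above follow from the SVE lane counters. A sketch, assuming an SVE-enabled toolchain:

```c
#include <arm_sve.h>
#include <stdint.h>

/* svcntw()/svcntd() return the number of 32-/64-bit lanes per vector.
   A "2v" kernel spans two vectors of rows, hence: */
static inline uint64_t mr_sgemm_2v( void ) { return 2 * svcntw(); } /* floats  */
static inline uint64_t mr_dgemm_2v( void ) { return 2 * svcntd(); } /* doubles */
static inline uint64_t mr_cgemm_2v( void ) { return svcntw(); } /* scomplex: 2 floats each  */
static inline uint64_t mr_zgemm_2v( void ) { return svcntd(); } /* dcomplex: 2 doubles each */
```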
#include "armsve_asm_2vx8cmplx.h" +#include "arm_sve.h" + void bli_zgemm_armsve_asm_2vx8_unindexed ( dim_t k0, @@ -65,6 +67,9 @@ void bli_zgemm_armsve_asm_2vx8_unindexed uint64_t cs_c = cs_c0; uint64_t info = 0; + uint64_t mr = svcntd(); + GEMM_UKR_SETUP_CT( z, mr, 8, false ); + __asm__ volatile ( // " ldr x0, %[a] \n\t" // " ldr x1, %[b] \n\t" @@ -286,5 +291,7 @@ GEMM_CCMPLX_STORE_COL2_G(z8 ,z9 ,z10,z11,p0,z16,%2,%4,x16) "z24","z25","z26","z27", "z28","z29","z30","z31" ); + + GEMM_UKR_FLUSH_CT( z ); } diff --git a/kernels/armv7a/3/bli_gemm_armv7a_asm_d4x4.c b/kernels/armv7a/3/bli_gemm_armv7a_asm_d4x4.c index b526cd0951..9c9b691fc4 100644 --- a/kernels/armv7a/3/bli_gemm_armv7a_asm_d4x4.c +++ b/kernels/armv7a/3/bli_gemm_armv7a_asm_d4x4.c @@ -37,6 +37,8 @@ extern void bli_sgemm_armv7a_ker_4x4 ( + uint32_t m, + uint32_t n, uint32_t k, float* restrict alpha, float* restrict a, @@ -48,23 +50,21 @@ void bli_sgemm_armv7a_ker_4x4 void bli_sgemm_armv7a_asm_4x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, float* restrict alpha, float* restrict a, float* restrict b, float* restrict beta, - float* restrict c, inc_t rs_c0, inc_t cs_c0, + float* restrict c, inc_t rs_c, inc_t cs_c, auxinfo_t* restrict data, cntx_t* restrict cntx ) { // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint32_t k = k0; - uint32_t rs_c = rs_c0; - uint32_t cs_c = cs_c0; - - bli_sgemm_armv7a_ker_4x4( k, alpha, a, b, beta, c, rs_c, cs_c, data ); + bli_sgemm_armv7a_ker_4x4( m, n, k, alpha, a, b, beta, c, rs_c, cs_c, data ); } @@ -72,6 +72,8 @@ void bli_sgemm_armv7a_asm_4x4 extern void bli_dgemm_armv7a_ker_4x4 ( + uint32_t m, + uint32_t n, uint32_t k, double* restrict alpha, double* restrict a, @@ -83,23 +85,21 @@ void bli_dgemm_armv7a_ker_4x4 void bli_dgemm_armv7a_asm_4x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, double* restrict alpha, double* restrict a, double* restrict b, double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict c, inc_t rs_c, inc_t cs_c, auxinfo_t* restrict data, cntx_t* restrict cntx ) { // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint32_t k = k0; - uint32_t rs_c = rs_c0; - uint32_t cs_c = cs_c0; - - bli_dgemm_armv7a_ker_4x4( k, alpha, a, b, beta, c, rs_c, cs_c, data ); + bli_dgemm_armv7a_ker_4x4( m, n, k, alpha, a, b, beta, c, rs_c, cs_c, data ); } @@ -107,6 +107,8 @@ void bli_dgemm_armv7a_asm_4x4 extern void bli_cgemm_armv7a_ker_2x2 ( + uint32_t m, + uint32_t n, uint32_t k, scomplex* restrict alpha, scomplex* restrict a, @@ -118,23 +120,21 @@ void bli_cgemm_armv7a_ker_2x2 void bli_cgemm_armv7a_asm_2x2 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, scomplex* restrict alpha, scomplex* restrict a, scomplex* restrict b, scomplex* restrict beta, - scomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + scomplex* restrict c, inc_t rs_c, inc_t cs_c, auxinfo_t* restrict data, cntx_t* restrict cntx ) { // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. 
- uint32_t k = k0; - uint32_t rs_c = rs_c0; - uint32_t cs_c = cs_c0; - - bli_cgemm_armv7a_ker_2x2( k, alpha, a, b, beta, c, rs_c, cs_c, data ); + bli_cgemm_armv7a_ker_2x2( m, n, k, alpha, a, b, beta, c, rs_c, cs_c, data ); } @@ -142,6 +142,8 @@ void bli_cgemm_armv7a_asm_2x2 extern void bli_zgemm_armv7a_ker_2x2 ( + uint32_t m, + uint32_t n, uint32_t k, dcomplex* restrict alpha, dcomplex* restrict a, @@ -153,22 +155,20 @@ void bli_zgemm_armv7a_ker_2x2 void bli_zgemm_armv7a_asm_2x2 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, dcomplex* restrict alpha, dcomplex* restrict a, dcomplex* restrict b, dcomplex* restrict beta, - dcomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + dcomplex* restrict c, inc_t rs_c, inc_t cs_c, auxinfo_t* restrict data, cntx_t* restrict cntx ) { // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint32_t k = k0; - uint32_t rs_c = rs_c0; - uint32_t cs_c = cs_c0; - - bli_zgemm_armv7a_ker_2x2( k, alpha, a, b, beta, c, rs_c, cs_c, data ); + bli_zgemm_armv7a_ker_2x2( m, n, k, alpha, a, b, beta, c, rs_c, cs_c, data ); } diff --git a/kernels/armv7a/3/bli_gemm_armv7a_asm_wrap.c b/kernels/armv7a/3/bli_gemm_armv7a_asm_wrap.c new file mode 100644 index 0000000000..ab2f91ced7 --- /dev/null +++ b/kernels/armv7a/3/bli_gemm_armv7a_asm_wrap.c @@ -0,0 +1,113 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+
+*/
+
+#include "blis.h"
+#include "../bli_kernels_armv7a.h"
+
+void bli_sgemm_armv7a_asm_wrap_4x4
+     (
+       dim_t               m,
+       dim_t               n,
+       dim_t               k,
+       float*     restrict alpha,
+       float*     restrict a,
+       float*     restrict b,
+       float*     restrict beta,
+       float*     restrict c, inc_t rs_c, inc_t cs_c,
+       auxinfo_t* restrict data,
+       cntx_t*    restrict cntx
+     )
+{
+    GEMM_UKR_SETUP_CT( s, 4, 4, false );
+    bli_sgemm_armv7a_asm_4x4(m, n, k, alpha, a, b, beta, c, rs_c, cs_c, data, cntx);
+    GEMM_UKR_FLUSH_CT( s );
+}
+
+void bli_dgemm_armv7a_asm_wrap_4x4
+     (
+       dim_t               m,
+       dim_t               n,
+       dim_t               k,
+       double*    restrict alpha,
+       double*    restrict a,
+       double*    restrict b,
+       double*    restrict beta,
+       double*    restrict c, inc_t rs_c, inc_t cs_c,
+       auxinfo_t* restrict data,
+       cntx_t*    restrict cntx
+     )
+{
+    GEMM_UKR_SETUP_CT( d, 4, 4, false );
+    bli_dgemm_armv7a_asm_4x4(m, n, k, alpha, a, b, beta, c, rs_c, cs_c, data, cntx);
+    GEMM_UKR_FLUSH_CT( d );
+}
+
+void bli_cgemm_armv7a_asm_wrap_2x2
+     (
+       dim_t               m,
+       dim_t               n,
+       dim_t               k,
+       scomplex*  restrict alpha,
+       scomplex*  restrict a,
+       scomplex*  restrict b,
+       scomplex*  restrict beta,
+       scomplex*  restrict c, inc_t rs_c, inc_t cs_c,
+       auxinfo_t* restrict data,
+       cntx_t*    restrict cntx
+     )
+{
+    GEMM_UKR_SETUP_CT( c, 2, 2, false );
+    bli_cgemm_armv7a_asm_2x2(m, n, k, alpha, a, b, beta, c, rs_c, cs_c, data, cntx);
+    GEMM_UKR_FLUSH_CT( c );
+}
+
+void bli_zgemm_armv7a_asm_wrap_2x2
+     (
+       dim_t               m,
+       dim_t               n,
+       dim_t               k,
+       dcomplex*  restrict alpha,
+       dcomplex*  restrict a,
+       dcomplex*  restrict b,
+       dcomplex*  restrict beta,
+       dcomplex*  restrict c, inc_t rs_c, inc_t cs_c,
+       auxinfo_t* restrict data,
+       cntx_t*    restrict cntx
+     )
+{
+    GEMM_UKR_SETUP_CT( z, 2, 2, false );
+    bli_zgemm_armv7a_asm_2x2(m, n, k, alpha, a, b, beta, c, rs_c, cs_c, data, cntx);
+    GEMM_UKR_FLUSH_CT( z );
+}
+
diff --git a/kernels/armv7a/3/bli_gemm_armv7a_int_d4x4.c b/kernels/armv7a/3/bli_gemm_armv7a_int_d4x4.c
index b9db587266..06f36a3463 100644
--- a/kernels/armv7a/3/bli_gemm_armv7a_int_d4x4.c
+++ b/kernels/armv7a/3/bli_gemm_armv7a_int_d4x4.c
@@ -37,7 +37,9 @@
 
 void bli_sgemm_armv7a_int_4x4
      (
-       dim_t               k0,
+       dim_t               m,
+       dim_t               n,
+       dim_t               k,
        float*    restrict alpha,
        float*    restrict a,
        float*    restrict b,
@@ -49,12 +51,14 @@ void bli_sgemm_armv7a_int_4x4
 {
 	// Typecast local copies of integers in case dim_t and inc_t are a
 	// different size than is expected by load instructions.
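For context (not part of this patch), the wrappers above only take effect once a configuration registers them as native microkernels; the real registrations live in config/cortexa*/bli_cntx_init_*.c, which this section does not touch. A sketch using the standard context API, with FALSE marking the kernels as column-preferring:

```c
#include "blis.h"

/* Hypothetical registration inside a bli_cntx_init_*()-style function. */
void register_armv7a_wrap_ukrs_sketch( cntx_t* cntx )
{
    bli_cntx_set_l3_nat_ukrs
    (
      4,
      BLIS_GEMM_UKR, BLIS_FLOAT,    bli_sgemm_armv7a_asm_wrap_4x4, FALSE,
      BLIS_GEMM_UKR, BLIS_DOUBLE,   bli_dgemm_armv7a_asm_wrap_4x4, FALSE,
      BLIS_GEMM_UKR, BLIS_SCOMPLEX, bli_cgemm_armv7a_asm_wrap_2x2, FALSE,
      BLIS_GEMM_UKR, BLIS_DCOMPLEX, bli_zgemm_armv7a_asm_wrap_2x2, FALSE,
      cntx
    );
}
```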
- uint32_t k_iter = k0 / 4; - uint32_t k_left = k0 % 4; + uint32_t k_iter = k / 4; + uint32_t k_left = k % 4; uint32_t rs_c = rs_c0; uint32_t cs_c = cs_c0; uint32_t i; + GEMM_UKR_SETUP_CT( s, 4, 4, false ); + void* a_next = bli_auxinfo_next_a( data ); void* b_next = bli_auxinfo_next_b( data ); @@ -82,47 +86,17 @@ void bli_sgemm_armv7a_int_4x4 if ( *beta != 0.0F ) { - if ( rs_c == 1 ) - { - // Load column 0 - cv0 = vld1q_f32( c + 0*rs_c + 0*cs_c ); - - // Load column 1 - cv1 = vld1q_f32( c + 0*rs_c + 1*cs_c ); - - // Load column 2 - cv2 = vld1q_f32( c + 0*rs_c + 2*cs_c ); - - // Load column 3 - cv3 = vld1q_f32( c + 0*rs_c + 3*cs_c ); - } - else - { - // Load column 0 - cv0 = vld1q_lane_f32( c + 0*rs_c + 0*cs_c, cv0, 0); - cv0 = vld1q_lane_f32( c + 1*rs_c + 0*cs_c, cv0, 1); - cv0 = vld1q_lane_f32( c + 2*rs_c + 0*cs_c, cv0, 2); - cv0 = vld1q_lane_f32( c + 3*rs_c + 0*cs_c, cv0, 3); - - // Load column 1 - cv1 = vld1q_lane_f32( c + 0*rs_c + 1*cs_c, cv1, 0); - cv1 = vld1q_lane_f32( c + 1*rs_c + 1*cs_c, cv1, 1); - cv1 = vld1q_lane_f32( c + 2*rs_c + 1*cs_c, cv1, 2); - cv1 = vld1q_lane_f32( c + 3*rs_c + 1*cs_c, cv1, 3); - - // Load column 2 - cv2 = vld1q_lane_f32( c + 0*rs_c + 2*cs_c, cv2, 0); - cv2 = vld1q_lane_f32( c + 1*rs_c + 2*cs_c, cv2, 1); - cv2 = vld1q_lane_f32( c + 2*rs_c + 2*cs_c, cv2, 2); - cv2 = vld1q_lane_f32( c + 3*rs_c + 2*cs_c, cv2, 3); - - // Load column 3 - cv3 = vld1q_lane_f32( c + 0*rs_c + 3*cs_c, cv3, 0); - cv3 = vld1q_lane_f32( c + 1*rs_c + 3*cs_c, cv3, 1); - cv3 = vld1q_lane_f32( c + 2*rs_c + 3*cs_c, cv3, 2); - cv3 = vld1q_lane_f32( c + 3*rs_c + 3*cs_c, cv3, 3); - - } + // Load column 0 + cv0 = vld1q_f32( c + 0*cs_c ); + + // Load column 1 + cv1 = vld1q_f32( c + 1*cs_c ); + + // Load column 2 + cv2 = vld1q_f32( c + 2*cs_c ); + + // Load column 3 + cv3 = vld1q_f32( c + 3*cs_c ); } else { @@ -255,47 +229,22 @@ void bli_sgemm_armv7a_int_4x4 cv3 = vmlaq_f32( cv3, abv3, alphav ); } - if ( rs_c == 1 ) - { - // Store column 0 - vst1q_f32( c + 0*rs_c + 0*cs_c, cv0 ); - // Store column 1 - vst1q_f32( c + 0*rs_c + 1*cs_c, cv1 ); - // Store column 2 - vst1q_f32( c + 0*rs_c + 2*cs_c, cv2 ); - // Store column 3 - vst1q_f32( c + 0*rs_c + 3*cs_c, cv3 ); - } - else - { - // Store column 0 - vst1q_lane_f32( c + 0*rs_c + 0*cs_c, cv0, 0); - vst1q_lane_f32( c + 1*rs_c + 0*cs_c, cv0, 1); - vst1q_lane_f32( c + 2*rs_c + 0*cs_c, cv0, 2); - vst1q_lane_f32( c + 3*rs_c + 0*cs_c, cv0, 3); - - // Store column 1 - vst1q_lane_f32( c + 0*rs_c + 1*cs_c, cv1, 0); - vst1q_lane_f32( c + 1*rs_c + 1*cs_c, cv1, 1); - vst1q_lane_f32( c + 2*rs_c + 1*cs_c, cv1, 2); - vst1q_lane_f32( c + 3*rs_c + 1*cs_c, cv1, 3); - - // Store column 2 - vst1q_lane_f32( c + 0*rs_c + 2*cs_c, cv2, 0); - vst1q_lane_f32( c + 1*rs_c + 2*cs_c, cv2, 1); - vst1q_lane_f32( c + 2*rs_c + 2*cs_c, cv2, 2); - vst1q_lane_f32( c + 3*rs_c + 2*cs_c, cv2, 3); - - // Store column 3 - vst1q_lane_f32( c + 0*rs_c + 3*cs_c, cv3, 0); - vst1q_lane_f32( c + 1*rs_c + 3*cs_c, cv3, 1); - vst1q_lane_f32( c + 2*rs_c + 3*cs_c, cv3, 2); - vst1q_lane_f32( c + 3*rs_c + 3*cs_c, cv3, 3); - } + // Store column 0 + vst1q_f32( c + 0*cs_c, cv0 ); + // Store column 1 + vst1q_f32( c + 1*cs_c, cv1 ); + // Store column 2 + vst1q_f32( c + 2*cs_c, cv2 ); + // Store column 3 + vst1q_f32( c + 3*cs_c, cv3 ); + + GEMM_UKR_FLUSH_CT( s ); } void bli_dgemm_armv7a_int_4x4 ( + dim_t m, + dim_t n, dim_t k, double* restrict alpha, double* restrict a, @@ -314,6 +263,8 @@ void bli_dgemm_armv7a_int_4x4 uint32_t cs_c = cs_c0; uint32_t i; + GEMM_UKR_SETUP_CT_ANY( d, 4, 4, false ); + //void* a_next = 
bli_auxinfo_next_a( data ); //void* b_next = bli_auxinfo_next_b( data ); @@ -568,5 +519,7 @@ void bli_dgemm_armv7a_int_4x4 *c23 += ab23 * *alpha; *c33 += ab33 * *alpha; } + + GEMM_UKR_FLUSH_CT( d ); } diff --git a/kernels/armv7a/bli_kernels_armv7a.h b/kernels/armv7a/bli_kernels_armv7a.h index 7eaf16e655..9fe3b4cf42 100644 --- a/kernels/armv7a/bli_kernels_armv7a.h +++ b/kernels/armv7a/bli_kernels_armv7a.h @@ -32,10 +32,31 @@ */ -GEMM_UKR_PROT( float, s, gemm_armv7a_asm_4x4 ) -GEMM_UKR_PROT( double, d, gemm_armv7a_asm_4x4 ) -GEMM_UKR_PROT( scomplex, c, gemm_armv7a_asm_2x2 ) -GEMM_UKR_PROT( dcomplex, z, gemm_armv7a_asm_2x2 ) +void bli_sgemm_armv7a_asm_4x4(dim_t k, float *restrict alpha, + float *restrict a, float *restrict b, + float *restrict beta, float *restrict c, + inc_t rs_c, inc_t cs_c, auxinfo_t *restrict data, + cntx_t *restrict cntx); +void bli_dgemm_armv7a_asm_4x4(dim_t k, double *restrict alpha, + double *restrict a, double *restrict b, + double *restrict beta, double *restrict c, + inc_t rs_c, inc_t cs_c, auxinfo_t *restrict data, + cntx_t *restrict cntx); +void bli_cgemm_armv7a_asm_2x2(dim_t k, scomplex *restrict alpha, + scomplex *restrict a, scomplex *restrict b, + scomplex *restrict beta, scomplex *restrict c, + inc_t rs_c, inc_t cs_c, auxinfo_t *restrict data, + cntx_t *restrict cntx); +void bli_zgemm_armv7a_asm_2x2(dim_t k, dcomplex *restrict alpha, + dcomplex *restrict a, dcomplex *restrict b, + dcomplex *restrict beta, dcomplex *restrict c, + inc_t rs_c, inc_t cs_c, auxinfo_t *restrict data, + cntx_t *restrict cntx); + +GEMM_UKR_PROT( float, s, gemm_armv7a_asm_wrap_4x4 ) +GEMM_UKR_PROT( double, d, gemm_armv7a_asm_wrap_4x4 ) +GEMM_UKR_PROT( scomplex, c, gemm_armv7a_asm_wrap_2x2 ) +GEMM_UKR_PROT( dcomplex, z, gemm_armv7a_asm_wrap_2x2 ) GEMM_UKR_PROT( float, s, gemm_armv7a_int_4x4 ) GEMM_UKR_PROT( double, d, gemm_armv7a_int_4x4 ) diff --git a/kernels/armv8a/3/bli_gemm_armv8a_asm_d6x8.c b/kernels/armv8a/3/bli_gemm_armv8a_asm_d6x8.c index dfdda863b1..2b188da842 100644 --- a/kernels/armv8a/3/bli_gemm_armv8a_asm_d6x8.c +++ b/kernels/armv8a/3/bli_gemm_armv8a_asm_d6x8.c @@ -1,4 +1,4 @@ - /* + /* BLIS An object-based framework for developing high-performance BLAS-like @@ -40,20 +40,22 @@ o 4x4 Single precision micro-kernel fully functional. o Runnable on ARMv8, compiled with aarch64 GCC. o Use it together with the armv8 BLIS configuration. - o Tested on Juno board. Around 7.3 GFLOPS @ 1.1 GHz. + o Tested on Juno board. Around 7.3 GFLOPS @ 1.1 GHz. December 2014. - + * UPDATE NOVEMBER 2015 * Micro-kernel changed to 8x12 * Tested on Juno Board. Around 8.1 GFLOPS, 1 x A57 core @ 1.1 GHz. * Tested on Juno Board. Around 15.9 GFLOPS, 2 x A57 cores @ 1.1 GHz. - * Tested on Juno board. Around 3.1 GFLOPS, 1 x A53 core @ 850 MHz. + * Tested on Juno board. Around 3.1 GFLOPS, 1 x A53 core @ 850 MHz. * Tested on Juno board. Around 12 GFLOPS, 4 x A53 cores @ 850 MHz. */ void bli_sgemm_armv8a_asm_8x12 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, float* restrict alpha, float* restrict a, float* restrict b, @@ -68,611 +70,613 @@ void bli_sgemm_armv8a_asm_8x12 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( s, 8, 12, false ); -__asm__ volatile -( -" \n\t" -" \n\t" -" ldr x0,%[aaddr] \n\t" // Load address of A. 
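For reference while reading the armv8a hunks below: the interface change being
applied is the one the header above records -- the legacy asm entry points
keep their k-only prototypes, while the _wrap_ functions (and the kernels
updated in place) expose the new m/n-aware signature. A minimal sketch of the
two conventions, assuming 64-bit dim_t/inc_t (these are configurable in BLIS)
and with the opaque framework types stubbed out:

#include <stdint.h>

typedef int64_t dim_t;               /* assumption: 64-bit BLIS integers */
typedef int64_t inc_t;
typedef struct auxinfo_s auxinfo_t;  /* opaque stand-ins */
typedef struct cntx_s    cntx_t;

/* Old convention: k only; the kernel always updates a full MR x NR tile,
   and the framework runs a separate edge-case path for fringe tiles. */
typedef void (*dgemm_ukr_old_ft)
     (
       dim_t               k,
       double*    restrict alpha,
       double*    restrict a,
       double*    restrict b,
       double*    restrict beta,
       double*    restrict c, inc_t rs_c, inc_t cs_c,
       auxinfo_t* restrict data,
       cntx_t*    restrict cntx
     );

/* New convention: m and n lead the argument list, and the kernel must
   correctly update any m x n tile with 0 < m <= MR, 0 < n <= NR. */
typedef void (*dgemm_ukr_new_ft)
     (
       dim_t               m,
       dim_t               n,
       dim_t               k,
       double*    restrict alpha,
       double*    restrict a,
       double*    restrict b,
       double*    restrict beta,
       double*    restrict c, inc_t rs_c, inc_t cs_c,
       auxinfo_t* restrict data,
       cntx_t*    restrict cntx
     );

Under the new convention a kernel such as bli_sgemm_armv8a_asm_8x12 keeps its
full-tile body unchanged and relies on GEMM_UKR_SETUP_CT( s, 8, 12, false ) /
GEMM_UKR_FLUSH_CT( s ) for the fringe cases, exactly as the hunk above does.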
-" ldr x1,%[baddr] \n\t" // Load address of B. -" ldr x2,%[caddr] \n\t" // Load address of C. -" \n\t" -" ldr x5,%[k_iter] \n\t" // Number of unrolled iterations (k_iter). -" ldr x6,%[k_left] \n\t" // Number of remaining iterations (k_left). -" \n\t" -" ldr x10,%[cs_c] \n\t" // Load cs_c. -" lsl x10,x10,#2 \n\t" // cs_c * sizeof(float) -- AUX. -" \n\t" -" ldr x14,%[rs_c] \n\t" // Load rs_c. -" lsl x14,x14,#2 \n\t" // rs_c * sizeof(float). -" \n\t" -" add x16,x2,x10 \n\t" //Load address Column 1 of C -" add x17,x16,x10 \n\t" //Load address Column 2 of C -" add x19,x17,x10 \n\t" //Load address Column 3 of C -" add x20,x19,x10 \n\t" //Load address Column 4 of C -" add x21,x20,x10 \n\t" //Load address Column 5 of C -" add x22,x21,x10 \n\t" //Load address Column 6 of C -" add x23,x22,x10 \n\t" //Load address Column 7 of C -" add x24,x23,x10 \n\t" //Load address Column 8 of C -" add x25,x24,x10 \n\t" //Load address Column 9 of C -" add x26,x25,x10 \n\t" //Load address Column 10 of C -" add x27,x26,x10 \n\t" //Load address Column 11 of C -" \n\t" -" prfm pldl1keep,[x2] \n\t" // Prefetch c. -" prfm pldl1keep,[x16] \n\t" // Prefetch c. -" prfm pldl1keep,[x17] \n\t" // Prefetch c. -" prfm pldl1keep,[x19] \n\t" // Prefetch c. -" prfm pldl1keep,[x20] \n\t" // Prefetch c. -" prfm pldl1keep,[x21] \n\t" // Prefetch c. -" prfm pldl1keep,[x22] \n\t" // Prefetch c. -" prfm pldl1keep,[x23] \n\t" // Prefetch c. -" prfm pldl1keep,[x24] \n\t" // Prefetch c. -" prfm pldl1keep,[x25] \n\t" // Prefetch c. -" prfm pldl1keep,[x26] \n\t" // Prefetch c. -" prfm pldl1keep,[x27] \n\t" // Prefetch c. -" \n\t" -" dup v8.4s, wzr \n\t" // Vector for accummulating column 0 -" prfm PLDL1KEEP, [x1, #192] \n\t" -" dup v9.4s, wzr \n\t" // Vector for accummulating column 0 -" prfm PLDL1KEEP, [x1, #256] \n\t" -" dup v10.4s, wzr \n\t" // Vector for accummulating column 1 -" prfm PLDL1KEEP, [x1, #320] \n\t" -" dup v11.4s, wzr \n\t" // Vector for accummulating column 1 -" dup v12.4s, wzr \n\t" // Vector for accummulating column 2 -" dup v13.4s, wzr \n\t" // Vector for accummulating column 2 -" \n\t" -" dup v14.4s, wzr \n\t" // Vector for accummulating column 3 -" prfm PLDL1KEEP, [x0, #128] \n\t" -" dup v15.4s, wzr \n\t" // Vector for accummulating column 3 -" prfm PLDL1KEEP, [x0, #192] \n\t" -" dup v16.4s, wzr \n\t" // Vector for accummulating column 4 -" dup v17.4s, wzr \n\t" // Vector for accummulating column 4 -" dup v18.4s, wzr \n\t" // Vector for accummulating column 5 -" dup v19.4s, wzr \n\t" // Vector for accummulating column 5 -" \n\t" -" dup v20.4s, wzr \n\t" // Vector for accummulating column 6 -" dup v21.4s, wzr \n\t" // Vector for accummulating column 6 -" dup v22.4s, wzr \n\t" // Vector for accummulating column 7 -" dup v23.4s, wzr \n\t" // Vector for accummulating column 7 -" dup v24.4s, wzr \n\t" // Vector for accummulating column 8 -" dup v25.4s, wzr \n\t" // Vector for accummulating column 8 -" \n\t" -" dup v26.4s, wzr \n\t" // Vector for accummulating column 9 -" dup v27.4s, wzr \n\t" // Vector for accummulating column 9 -" dup v28.4s, wzr \n\t" // Vector for accummulating column 10 -" dup v29.4s, wzr \n\t" // Vector for accummulating column 10 -" dup v30.4s, wzr \n\t" // Vector for accummulating column 11 -" dup v31.4s, wzr \n\t" // Vector for accummulating column 11 -" \n\t" -" cmp x5,#0 \n\t" // If k_iter == 0, jump to k_left. 
-BEQ(SCONSIDERKLEFT) -" \n\t" -" ldr q0, [x0] \n\t" -" ldr q1, [x0, #16] \n\t" // Load a -" \n\t" -" ldr q2, [x1] \n\t" // Load b -" ldr q3, [x1, #16] \n\t" -" ldr q4, [x1, #32] \n\t" -" \n\t" -" add x0, x0, #32 \n\t" //update address of A -" add x1, x1, #48 \n\t" //update address of B -" \n\t" -" cmp x5,1 \n\t" // If there is just one k_iter, jump to that one. -BEQ(SLASTITER) // (as loop is do-while-like). -" \n\t" -LABEL(SLOOPKITER) // Body of the k_iter loop. -" \n\t" -" ldr q5, [x0] \n\t" -" fmla v8.4s, v0.4s,v2.s[0] \n\t" // Accummulate. -" fmla v9.4s, v1.4s,v2.s[0] \n\t" // Accummulate. -" ldr q6, [x0, #16] \n\t" -" fmla v10.4s,v0.4s,v2.s[1] \n\t" // Accummulate. -" fmla v11.4s,v1.4s,v2.s[1] \n\t" // Accummulate. -" fmla v12.4s,v0.4s,v2.s[2] \n\t" // Accummulate. -" fmla v13.4s,v1.4s,v2.s[2] \n\t" // Accummulate. -" fmla v14.4s,v0.4s,v2.s[3] \n\t" // Accummulate. -" fmla v15.4s,v1.4s,v2.s[3] \n\t" // Accummulate. -" ldr q2, [x1] \n\t" -" \n\t" -" fmla v16.4s,v0.4s,v3.s[0] \n\t" // Accummulate. -" prfm PLDL1KEEP, [x1, #336] \n\t" -" fmla v17.4s,v1.4s,v3.s[0] \n\t" // Accummulate. -" prfm PLDL1KEEP, [x1, #400] \n\t" -" fmla v18.4s,v0.4s,v3.s[1] \n\t" // Accummulate. -" fmla v19.4s,v1.4s,v3.s[1] \n\t" // Accummulate. -" prfm PLDL1KEEP, [x1, #464] \n\t" -" fmla v20.4s,v0.4s,v3.s[2] \n\t" // Accummulate. -" fmla v21.4s,v1.4s,v3.s[2] \n\t" // Accummulate. -" fmla v22.4s,v0.4s,v3.s[3] \n\t" // Accummulate. -" fmla v23.4s,v1.4s,v3.s[3] \n\t" // Accummulate. -" \n\t" -" fmla v24.4s,v0.4s,v4.s[0] \n\t" // Accummulate. -" fmla v26.4s,v0.4s,v4.s[1] \n\t" // Accummulate. -" fmla v28.4s,v0.4s,v4.s[2] \n\t" // Accummulate. -" fmla v30.4s,v0.4s,v4.s[3] \n\t" // Accummulate. -" ldr q3, [x1, #16] \n\t" -" \n\t" -" fmla v25.4s,v1.4s,v4.s[0] \n\t" // Accummulate. -" fmla v27.4s,v1.4s,v4.s[1] \n\t" // Accummulate. -" fmla v29.4s,v1.4s,v4.s[2] \n\t" // Accummulate. -" fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. -" ldr q4, [x1, #32] \n\t" -" \n\t" //End It 1 -" \n\t" -" ldr q0, [x0, #32] \n\t" -" fmla v8.4s,v5.4s,v2.s[0] \n\t" // Accummulate. -" fmla v9.4s,v6.4s,v2.s[0] \n\t" // Accummulate. -" ldr q1, [x0, #48] \n\t" -" fmla v10.4s,v5.4s,v2.s[1] \n\t" // Accummulate. -" fmla v11.4s,v6.4s,v2.s[1] \n\t" // Accummulate. -" fmla v12.4s,v5.4s,v2.s[2] \n\t" // Accummulate. -" fmla v13.4s,v6.4s,v2.s[2] \n\t" // Accummulate. -" fmla v14.4s,v5.4s,v2.s[3] \n\t" // Accummulate. -" fmla v15.4s,v6.4s,v2.s[3] \n\t" // Accummulate. -" ldr q2, [x1, #48] \n\t" -" \n\t" -" fmla v16.4s,v5.4s,v3.s[0] \n\t" // Accummulate. -" prfm PLDL1KEEP, [x0, #224] \n\t" -" fmla v17.4s,v6.4s,v3.s[0] \n\t" // Accummulate. -" prfm PLDL1KEEP, [x0, #288] \n\t" -" fmla v18.4s,v5.4s,v3.s[1] \n\t" // Accummulate. -" fmla v19.4s,v6.4s,v3.s[1] \n\t" // Accummulate. -" fmla v20.4s,v5.4s,v3.s[2] \n\t" // Accummulate. -" fmla v21.4s,v6.4s,v3.s[2] \n\t" // Accummulate. -" fmla v22.4s,v5.4s,v3.s[3] \n\t" // Accummulate. -" fmla v23.4s,v6.4s,v3.s[3] \n\t" // Accummulate. -" \n\t" -" fmla v24.4s,v5.4s,v4.s[0] \n\t" // Accummulate. -" fmla v26.4s,v5.4s,v4.s[1] \n\t" // Accummulate. -" fmla v28.4s,v5.4s,v4.s[2] \n\t" // Accummulate. -" fmla v30.4s,v5.4s,v4.s[3] \n\t" // Accummulate. -" ldr q3, [x1, #64] \n\t" -" \n\t" -" fmla v25.4s,v6.4s,v4.s[0] \n\t" // Accummulate. -" fmla v27.4s,v6.4s,v4.s[1] \n\t" // Accummulate. -" fmla v29.4s,v6.4s,v4.s[2] \n\t" // Accummulate. -" fmla v31.4s,v6.4s,v4.s[3] \n\t" // Accummulate. -" ldr q4, [x1, #80] \n\t" -" \n\t" //End It 2 -" \n\t" -" ldr q5, [x0, #64] \n\t" -" fmla v8.4s,v0.4s,v2.s[0] \n\t" // Accummulate. 
-" fmla v9.4s,v1.4s,v2.s[0] \n\t" // Accummulate. -" ldr q6, [x0, #80] \n\t" -" fmla v10.4s,v0.4s,v2.s[1] \n\t" // Accummulate. -" fmla v11.4s,v1.4s,v2.s[1] \n\t" // Accummulate. -" fmla v12.4s,v0.4s,v2.s[2] \n\t" // Accummulate. -" fmla v13.4s,v1.4s,v2.s[2] \n\t" // Accummulate. -" fmla v14.4s,v0.4s,v2.s[3] \n\t" // Accummulate. -" fmla v15.4s,v1.4s,v2.s[3] \n\t" // Accummulate. -" ldr q2, [x1, #96] \n\t" -" \n\t" -" fmla v16.4s,v0.4s,v3.s[0] \n\t" // Accummulate. -" fmla v17.4s,v1.4s,v3.s[0] \n\t" // Accummulate. -" fmla v18.4s,v0.4s,v3.s[1] \n\t" // Accummulate. -" fmla v19.4s,v1.4s,v3.s[1] \n\t" // Accummulate. -" fmla v20.4s,v0.4s,v3.s[2] \n\t" // Accummulate. -" fmla v21.4s,v1.4s,v3.s[2] \n\t" // Accummulate. -" fmla v22.4s,v0.4s,v3.s[3] \n\t" // Accummulate. -" fmla v23.4s,v1.4s,v3.s[3] \n\t" // Accummulate. -" \n\t" -" fmla v24.4s,v0.4s,v4.s[0] \n\t" // Accummulate. -" fmla v26.4s,v0.4s,v4.s[1] \n\t" // Accummulate. -" fmla v28.4s,v0.4s,v4.s[2] \n\t" // Accummulate. -" fmla v30.4s,v0.4s,v4.s[3] \n\t" // Accummulate. -" ldr q3, [x1, #112] \n\t" -" \n\t" -" fmla v25.4s,v1.4s,v4.s[0] \n\t" // Accummulate. -" fmla v27.4s,v1.4s,v4.s[1] \n\t" // Accummulate. -" fmla v29.4s,v1.4s,v4.s[2] \n\t" // Accummulate. -" fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. -" ldr q4, [x1, #128] \n\t" -" \n\t" //End It 3 -" \n\t" -" ldr q0, [x0, #96] \n\t" -" fmla v8.4s,v5.4s,v2.s[0] \n\t" // Accummulate. -" fmla v9.4s,v6.4s,v2.s[0] \n\t" // Accummulate. -" ldr q1, [x0, #112] \n\t" -" fmla v10.4s,v5.4s,v2.s[1] \n\t" // Accummulate. -" fmla v11.4s,v6.4s,v2.s[1] \n\t" // Accummulate. -" fmla v12.4s,v5.4s,v2.s[2] \n\t" // Accummulate. -" fmla v13.4s,v6.4s,v2.s[2] \n\t" // Accummulate. -" fmla v14.4s,v5.4s,v2.s[3] \n\t" // Accummulate. -" fmla v15.4s,v6.4s,v2.s[3] \n\t" // Accummulate. -" ldr q2, [x1, #144] \n\t" -" \n\t" -" fmla v16.4s,v5.4s,v3.s[0] \n\t" // Accummulate. -" fmla v17.4s,v6.4s,v3.s[0] \n\t" // Accummulate. -" fmla v18.4s,v5.4s,v3.s[1] \n\t" // Accummulate. -" fmla v19.4s,v6.4s,v3.s[1] \n\t" // Accummulate. -" fmla v20.4s,v5.4s,v3.s[2] \n\t" // Accummulate. -" fmla v21.4s,v6.4s,v3.s[2] \n\t" // Accummulate. -" fmla v22.4s,v5.4s,v3.s[3] \n\t" // Accummulate. -" fmla v23.4s,v6.4s,v3.s[3] \n\t" // Accummulate. -" \n\t" -" fmla v24.4s,v5.4s,v4.s[0] \n\t" // Accummulate. -" fmla v26.4s,v5.4s,v4.s[1] \n\t" // Accummulate. -" fmla v28.4s,v5.4s,v4.s[2] \n\t" // Accummulate. -" fmla v30.4s,v5.4s,v4.s[3] \n\t" // Accummulate. -" ldr q3, [x1, #160] \n\t" -" \n\t" -" fmla v25.4s,v6.4s,v4.s[0] \n\t" // Accummulate. -" fmla v27.4s,v6.4s,v4.s[1] \n\t" // Accummulate. -" fmla v29.4s,v6.4s,v4.s[2] \n\t" // Accummulate. -" fmla v31.4s,v6.4s,v4.s[3] \n\t" // Accummulate. -" ldr q4, [x1, #176] \n\t" -" add x1, x1, #192 \n\t" -" add x0, x0, #128 \n\t" -" \n\t" //End It 4 -" sub x5,x5,1 \n\t" // i-=1. -" cmp x5,1 \n\t" // Iterate again if we are not in k_iter == 1. -BNE(SLOOPKITER) -" \n\t" -LABEL(SLASTITER) // Last iteration of k_iter loop. -" \n\t" -" \n\t" -" ldr q5, [x0] \n\t" -" fmla v8.4s,v0.4s,v2.s[0] \n\t" // Accummulate. -" fmla v9.4s,v1.4s,v2.s[0] \n\t" // Accummulate. -" ldr q6, [x0, #16] \n\t" -" fmla v10.4s,v0.4s,v2.s[1] \n\t" // Accummulate. -" fmla v11.4s,v1.4s,v2.s[1] \n\t" // Accummulate. -" fmla v12.4s,v0.4s,v2.s[2] \n\t" // Accummulate. -" fmla v13.4s,v1.4s,v2.s[2] \n\t" // Accummulate. -" fmla v14.4s,v0.4s,v2.s[3] \n\t" // Accummulate. -" fmla v15.4s,v1.4s,v2.s[3] \n\t" // Accummulate. -" ldr q2, [x1] \n\t" -" \n\t" -" fmla v16.4s,v0.4s,v3.s[0] \n\t" // Accummulate. 
-" fmla v17.4s,v1.4s,v3.s[0] \n\t" // Accummulate. -" fmla v18.4s,v0.4s,v3.s[1] \n\t" // Accummulate. -" fmla v19.4s,v1.4s,v3.s[1] \n\t" // Accummulate. -" fmla v20.4s,v0.4s,v3.s[2] \n\t" // Accummulate. -" fmla v21.4s,v1.4s,v3.s[2] \n\t" // Accummulate. -" fmla v22.4s,v0.4s,v3.s[3] \n\t" // Accummulate. -" fmla v23.4s,v1.4s,v3.s[3] \n\t" // Accummulate. -" \n\t" -" fmla v24.4s,v0.4s,v4.s[0] \n\t" // Accummulate. -" fmla v26.4s,v0.4s,v4.s[1] \n\t" // Accummulate. -" fmla v28.4s,v0.4s,v4.s[2] \n\t" // Accummulate. -" fmla v30.4s,v0.4s,v4.s[3] \n\t" // Accummulate. -" ldr q3, [x1, #16] \n\t" -" \n\t" -" fmla v25.4s,v1.4s,v4.s[0] \n\t" // Accummulate. -" fmla v27.4s,v1.4s,v4.s[1] \n\t" // Accummulate. -" fmla v29.4s,v1.4s,v4.s[2] \n\t" // Accummulate. -" fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. -" ldr q4, [x1, #32] \n\t" -" \n\t" //End It 1 -" \n\t" -" ldr q0, [x0, #32] \n\t" -" fmla v8.4s,v5.4s,v2.s[0] \n\t" // Accummulate. -" fmla v9.4s,v6.4s,v2.s[0] \n\t" // Accummulate. -" ldr q1, [x0, #48] \n\t" -" fmla v10.4s,v5.4s,v2.s[1] \n\t" // Accummulate. -" fmla v11.4s,v6.4s,v2.s[1] \n\t" // Accummulate. -" fmla v12.4s,v5.4s,v2.s[2] \n\t" // Accummulate. -" fmla v13.4s,v6.4s,v2.s[2] \n\t" // Accummulate. -" fmla v14.4s,v5.4s,v2.s[3] \n\t" // Accummulate. -" fmla v15.4s,v6.4s,v2.s[3] \n\t" // Accummulate. -" ldr q2, [x1, #48] \n\t" -" \n\t" -" fmla v16.4s,v5.4s,v3.s[0] \n\t" // Accummulate. -" fmla v17.4s,v6.4s,v3.s[0] \n\t" // Accummulate. -" fmla v18.4s,v5.4s,v3.s[1] \n\t" // Accummulate. -" fmla v19.4s,v6.4s,v3.s[1] \n\t" // Accummulate. -" fmla v20.4s,v5.4s,v3.s[2] \n\t" // Accummulate. -" fmla v21.4s,v6.4s,v3.s[2] \n\t" // Accummulate. -" fmla v22.4s,v5.4s,v3.s[3] \n\t" // Accummulate. -" fmla v23.4s,v6.4s,v3.s[3] \n\t" // Accummulate. -" \n\t" -" fmla v24.4s,v5.4s,v4.s[0] \n\t" // Accummulate. -" fmla v26.4s,v5.4s,v4.s[1] \n\t" // Accummulate. -" fmla v28.4s,v5.4s,v4.s[2] \n\t" // Accummulate. -" fmla v30.4s,v5.4s,v4.s[3] \n\t" // Accummulate. -" ldr q3, [x1, #64] \n\t" -" \n\t" -" fmla v25.4s,v6.4s,v4.s[0] \n\t" // Accummulate. -" fmla v27.4s,v6.4s,v4.s[1] \n\t" // Accummulate. -" fmla v29.4s,v6.4s,v4.s[2] \n\t" // Accummulate. -" fmla v31.4s,v6.4s,v4.s[3] \n\t" // Accummulate. -" ldr q4, [x1, #80] \n\t" -" \n\t" //End It 2 -" \n\t" -" ldr q5, [x0, #64] \n\t" -" fmla v8.4s,v0.4s,v2.s[0] \n\t" // Accummulate. -" fmla v9.4s,v1.4s,v2.s[0] \n\t" // Accummulate. -" ldr q6, [x0, #80] \n\t" -" fmla v10.4s,v0.4s,v2.s[1] \n\t" // Accummulate. -" fmla v11.4s,v1.4s,v2.s[1] \n\t" // Accummulate. -" fmla v12.4s,v0.4s,v2.s[2] \n\t" // Accummulate. -" fmla v13.4s,v1.4s,v2.s[2] \n\t" // Accummulate. -" fmla v14.4s,v0.4s,v2.s[3] \n\t" // Accummulate. -" fmla v15.4s,v1.4s,v2.s[3] \n\t" // Accummulate. -" ldr q2, [x1, #96] \n\t" -" \n\t" -" fmla v16.4s,v0.4s,v3.s[0] \n\t" // Accummulate. -" fmla v17.4s,v1.4s,v3.s[0] \n\t" // Accummulate. -" fmla v18.4s,v0.4s,v3.s[1] \n\t" // Accummulate. -" fmla v19.4s,v1.4s,v3.s[1] \n\t" // Accummulate. -" fmla v20.4s,v0.4s,v3.s[2] \n\t" // Accummulate. -" fmla v21.4s,v1.4s,v3.s[2] \n\t" // Accummulate. -" fmla v22.4s,v0.4s,v3.s[3] \n\t" // Accummulate. -" fmla v23.4s,v1.4s,v3.s[3] \n\t" // Accummulate. -" \n\t" -" fmla v24.4s,v0.4s,v4.s[0] \n\t" // Accummulate. -" fmla v26.4s,v0.4s,v4.s[1] \n\t" // Accummulate. -" fmla v28.4s,v0.4s,v4.s[2] \n\t" // Accummulate. -" fmla v30.4s,v0.4s,v4.s[3] \n\t" // Accummulate. -" ldr q3, [x1, #112] \n\t" -" \n\t" -" fmla v25.4s,v1.4s,v4.s[0] \n\t" // Accummulate. -" fmla v27.4s,v1.4s,v4.s[1] \n\t" // Accummulate. 
-" fmla v29.4s,v1.4s,v4.s[2] \n\t" // Accummulate. -" fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. -" ldr q4, [x1, #128] \n\t" -" \n\t" //End It 3 -" \n\t" -" fmla v8.4s,v5.4s,v2.s[0] \n\t" // Accummulate. -" fmla v9.4s,v6.4s,v2.s[0] \n\t" // Accummulate. -" fmla v10.4s,v5.4s,v2.s[1] \n\t" // Accummulate. -" fmla v11.4s,v6.4s,v2.s[1] \n\t" // Accummulate. -" fmla v12.4s,v5.4s,v2.s[2] \n\t" // Accummulate. -" fmla v13.4s,v6.4s,v2.s[2] \n\t" // Accummulate. -" fmla v14.4s,v5.4s,v2.s[3] \n\t" // Accummulate. -" fmla v15.4s,v6.4s,v2.s[3] \n\t" // Accummulate. -" \n\t" -" fmla v16.4s,v5.4s,v3.s[0] \n\t" // Accummulate. -" fmla v17.4s,v6.4s,v3.s[0] \n\t" // Accummulate. -" fmla v18.4s,v5.4s,v3.s[1] \n\t" // Accummulate. -" fmla v19.4s,v6.4s,v3.s[1] \n\t" // Accummulate. -" fmla v20.4s,v5.4s,v3.s[2] \n\t" // Accummulate. -" fmla v21.4s,v6.4s,v3.s[2] \n\t" // Accummulate. -" fmla v22.4s,v5.4s,v3.s[3] \n\t" // Accummulate. -" fmla v23.4s,v6.4s,v3.s[3] \n\t" // Accummulate. -" \n\t" -" fmla v24.4s,v5.4s,v4.s[0] \n\t" // Accummulate. -" fmla v26.4s,v5.4s,v4.s[1] \n\t" // Accummulate. -" fmla v28.4s,v5.4s,v4.s[2] \n\t" // Accummulate. -" fmla v30.4s,v5.4s,v4.s[3] \n\t" // Accummulate. -" \n\t" -" fmla v25.4s,v6.4s,v4.s[0] \n\t" // Accummulate. -" fmla v27.4s,v6.4s,v4.s[1] \n\t" // Accummulate. -" fmla v29.4s,v6.4s,v4.s[2] \n\t" // Accummulate. -" fmla v31.4s,v6.4s,v4.s[3] \n\t" // Accummulate. -" add x1, x1, #144 \n\t" -" add x0, x0, #96 \n\t" -" \n\t" //End It 4 -" \n\t" -LABEL(SCONSIDERKLEFT) -" cmp x6,0 \n\t" // If k_left == 0, we are done. -BEQ(SPOSTACCUM) // else, we enter the k_left loop. -" \n\t" -LABEL(SLOOPKLEFT) // Body of the left iterations -" \n\t" -" ldr q0, [x0],#16 \n\t" -" ldr q1, [x0],#16 \n\t" // Load a -" \n\t" -" ldr q2, [x1],#16 \n\t" // Load b -" ldr q3, [x1],#16 \n\t" -" ldr q4, [x1],#16 \n\t" -" \n\t" -" sub x6,x6,1 \n\t" // i = i-1. -" \n\t" -" fmla v8.4s,v0.4s,v2.s[0] \n\t" // Accummulate. -" fmla v9.4s,v1.4s,v2.s[0] \n\t" // Accummulate. -" fmla v10.4s,v0.4s,v2.s[1] \n\t" // Accummulate. -" fmla v11.4s,v1.4s,v2.s[1] \n\t" // Accummulate. -" fmla v12.4s,v0.4s,v2.s[2] \n\t" // Accummulate. -" fmla v13.4s,v1.4s,v2.s[2] \n\t" // Accummulate. -" fmla v14.4s,v0.4s,v2.s[3] \n\t" // Accummulate. -" fmla v15.4s,v1.4s,v2.s[3] \n\t" // Accummulate. -" \n\t" -" fmla v16.4s,v0.4s,v3.s[0] \n\t" // Accummulate. -" fmla v17.4s,v1.4s,v3.s[0] \n\t" // Accummulate. -" fmla v18.4s,v0.4s,v3.s[1] \n\t" // Accummulate. -" fmla v19.4s,v1.4s,v3.s[1] \n\t" // Accummulate. -" fmla v20.4s,v0.4s,v3.s[2] \n\t" // Accummulate. -" fmla v21.4s,v1.4s,v3.s[2] \n\t" // Accummulate. -" fmla v22.4s,v0.4s,v3.s[3] \n\t" // Accummulate. -" fmla v23.4s,v1.4s,v3.s[3] \n\t" // Accummulate. -" \n\t" -" fmla v24.4s,v0.4s,v4.s[0] \n\t" // Accummulate. -" fmla v26.4s,v0.4s,v4.s[1] \n\t" // Accummulate. -" fmla v28.4s,v0.4s,v4.s[2] \n\t" // Accummulate. -" fmla v30.4s,v0.4s,v4.s[3] \n\t" // Accummulate. -" fmla v25.4s,v1.4s,v4.s[0] \n\t" // Accummulate. -" fmla v27.4s,v1.4s,v4.s[1] \n\t" // Accummulate. -" fmla v29.4s,v1.4s,v4.s[2] \n\t" // Accummulate. -" fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. -" \n\t" -" cmp x6,0 \n\t" // Iterate again. -BNE(SLOOPKLEFT) // if i!=0. -" \n\t" -LABEL(SPOSTACCUM) -" \n\t" -" ldr x0,%[alpha] \n\t" // Alpha address. -" ldr x1,%[beta] \n\t" // Beta address. -" \n\t" -" ld1r {v6.4s},[x0] \n\t" // Load alpha. -" ld1r {v7.4s},[x1] \n\t" // Load beta -" \n\t" -" ldr x0,%[a_next] \n\t" // Pointer to next block of A. -" ldr x1,%[b_next] \n\t" // Pointer to next pointer of B. 
-" \n\t" + + __asm__ volatile + ( + " \n\t" + " \n\t" + " ldr x0,%[aaddr] \n\t" // Load address of A. + " ldr x1,%[baddr] \n\t" // Load address of B. + " ldr x2,%[caddr] \n\t" // Load address of C. + " \n\t" + " ldr x5,%[k_iter] \n\t" // Number of unrolled iterations (k_iter). + " ldr x6,%[k_left] \n\t" // Number of remaining iterations (k_left). + " \n\t" + " ldr x10,%[cs_c] \n\t" // Load cs_c. + " lsl x10,x10,#2 \n\t" // cs_c * sizeof(float) -- AUX. + " \n\t" + " ldr x14,%[rs_c] \n\t" // Load rs_c. + " lsl x14,x14,#2 \n\t" // rs_c * sizeof(float). + " \n\t" + " add x16,x2,x10 \n\t" //Load address Column 1 of C + " add x17,x16,x10 \n\t" //Load address Column 2 of C + " add x19,x17,x10 \n\t" //Load address Column 3 of C + " add x20,x19,x10 \n\t" //Load address Column 4 of C + " add x21,x20,x10 \n\t" //Load address Column 5 of C + " add x22,x21,x10 \n\t" //Load address Column 6 of C + " add x23,x22,x10 \n\t" //Load address Column 7 of C + " add x24,x23,x10 \n\t" //Load address Column 8 of C + " add x25,x24,x10 \n\t" //Load address Column 9 of C + " add x26,x25,x10 \n\t" //Load address Column 10 of C + " add x27,x26,x10 \n\t" //Load address Column 11 of C + " \n\t" + " prfm pldl1keep,[x2] \n\t" // Prefetch c. + " prfm pldl1keep,[x16] \n\t" // Prefetch c. + " prfm pldl1keep,[x17] \n\t" // Prefetch c. + " prfm pldl1keep,[x19] \n\t" // Prefetch c. + " prfm pldl1keep,[x20] \n\t" // Prefetch c. + " prfm pldl1keep,[x21] \n\t" // Prefetch c. + " prfm pldl1keep,[x22] \n\t" // Prefetch c. + " prfm pldl1keep,[x23] \n\t" // Prefetch c. + " prfm pldl1keep,[x24] \n\t" // Prefetch c. + " prfm pldl1keep,[x25] \n\t" // Prefetch c. + " prfm pldl1keep,[x26] \n\t" // Prefetch c. + " prfm pldl1keep,[x27] \n\t" // Prefetch c. + " \n\t" + " dup v8.4s, wzr \n\t" // Vector for accummulating column 0 + " prfm PLDL1KEEP, [x1, #192] \n\t" + " dup v9.4s, wzr \n\t" // Vector for accummulating column 0 + " prfm PLDL1KEEP, [x1, #256] \n\t" + " dup v10.4s, wzr \n\t" // Vector for accummulating column 1 + " prfm PLDL1KEEP, [x1, #320] \n\t" + " dup v11.4s, wzr \n\t" // Vector for accummulating column 1 + " dup v12.4s, wzr \n\t" // Vector for accummulating column 2 + " dup v13.4s, wzr \n\t" // Vector for accummulating column 2 + " \n\t" + " dup v14.4s, wzr \n\t" // Vector for accummulating column 3 + " prfm PLDL1KEEP, [x0, #128] \n\t" + " dup v15.4s, wzr \n\t" // Vector for accummulating column 3 + " prfm PLDL1KEEP, [x0, #192] \n\t" + " dup v16.4s, wzr \n\t" // Vector for accummulating column 4 + " dup v17.4s, wzr \n\t" // Vector for accummulating column 4 + " dup v18.4s, wzr \n\t" // Vector for accummulating column 5 + " dup v19.4s, wzr \n\t" // Vector for accummulating column 5 + " \n\t" + " dup v20.4s, wzr \n\t" // Vector for accummulating column 6 + " dup v21.4s, wzr \n\t" // Vector for accummulating column 6 + " dup v22.4s, wzr \n\t" // Vector for accummulating column 7 + " dup v23.4s, wzr \n\t" // Vector for accummulating column 7 + " dup v24.4s, wzr \n\t" // Vector for accummulating column 8 + " dup v25.4s, wzr \n\t" // Vector for accummulating column 8 + " \n\t" + " dup v26.4s, wzr \n\t" // Vector for accummulating column 9 + " dup v27.4s, wzr \n\t" // Vector for accummulating column 9 + " dup v28.4s, wzr \n\t" // Vector for accummulating column 10 + " dup v29.4s, wzr \n\t" // Vector for accummulating column 10 + " dup v30.4s, wzr \n\t" // Vector for accummulating column 11 + " dup v31.4s, wzr \n\t" // Vector for accummulating column 11 + " \n\t" + " cmp x5,#0 \n\t" // If k_iter == 0, jump to k_left. 
+ BEQ(SCONSIDERKLEFT) + " \n\t" + " ldr q0, [x0] \n\t" + " ldr q1, [x0, #16] \n\t" // Load a + " \n\t" + " ldr q2, [x1] \n\t" // Load b + " ldr q3, [x1, #16] \n\t" + " ldr q4, [x1, #32] \n\t" + " \n\t" + " add x0, x0, #32 \n\t" //update address of A + " add x1, x1, #48 \n\t" //update address of B + " \n\t" + " cmp x5,1 \n\t" // If there is just one k_iter, jump to that one. + BEQ(SLASTITER) // (as loop is do-while-like). + " \n\t" + LABEL(SLOOPKITER) // Body of the k_iter loop. + " \n\t" + " ldr q5, [x0] \n\t" + " fmla v8.4s, v0.4s,v2.s[0] \n\t" // Accummulate. + " fmla v9.4s, v1.4s,v2.s[0] \n\t" // Accummulate. + " ldr q6, [x0, #16] \n\t" + " fmla v10.4s,v0.4s,v2.s[1] \n\t" // Accummulate. + " fmla v11.4s,v1.4s,v2.s[1] \n\t" // Accummulate. + " fmla v12.4s,v0.4s,v2.s[2] \n\t" // Accummulate. + " fmla v13.4s,v1.4s,v2.s[2] \n\t" // Accummulate. + " fmla v14.4s,v0.4s,v2.s[3] \n\t" // Accummulate. + " fmla v15.4s,v1.4s,v2.s[3] \n\t" // Accummulate. + " ldr q2, [x1] \n\t" + " \n\t" + " fmla v16.4s,v0.4s,v3.s[0] \n\t" // Accummulate. + " prfm PLDL1KEEP, [x1, #336] \n\t" + " fmla v17.4s,v1.4s,v3.s[0] \n\t" // Accummulate. + " prfm PLDL1KEEP, [x1, #400] \n\t" + " fmla v18.4s,v0.4s,v3.s[1] \n\t" // Accummulate. + " fmla v19.4s,v1.4s,v3.s[1] \n\t" // Accummulate. + " prfm PLDL1KEEP, [x1, #464] \n\t" + " fmla v20.4s,v0.4s,v3.s[2] \n\t" // Accummulate. + " fmla v21.4s,v1.4s,v3.s[2] \n\t" // Accummulate. + " fmla v22.4s,v0.4s,v3.s[3] \n\t" // Accummulate. + " fmla v23.4s,v1.4s,v3.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v24.4s,v0.4s,v4.s[0] \n\t" // Accummulate. + " fmla v26.4s,v0.4s,v4.s[1] \n\t" // Accummulate. + " fmla v28.4s,v0.4s,v4.s[2] \n\t" // Accummulate. + " fmla v30.4s,v0.4s,v4.s[3] \n\t" // Accummulate. + " ldr q3, [x1, #16] \n\t" + " \n\t" + " fmla v25.4s,v1.4s,v4.s[0] \n\t" // Accummulate. + " fmla v27.4s,v1.4s,v4.s[1] \n\t" // Accummulate. + " fmla v29.4s,v1.4s,v4.s[2] \n\t" // Accummulate. + " fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. + " ldr q4, [x1, #32] \n\t" + " \n\t" //End It 1 + " \n\t" + " ldr q0, [x0, #32] \n\t" + " fmla v8.4s,v5.4s,v2.s[0] \n\t" // Accummulate. + " fmla v9.4s,v6.4s,v2.s[0] \n\t" // Accummulate. + " ldr q1, [x0, #48] \n\t" + " fmla v10.4s,v5.4s,v2.s[1] \n\t" // Accummulate. + " fmla v11.4s,v6.4s,v2.s[1] \n\t" // Accummulate. + " fmla v12.4s,v5.4s,v2.s[2] \n\t" // Accummulate. + " fmla v13.4s,v6.4s,v2.s[2] \n\t" // Accummulate. + " fmla v14.4s,v5.4s,v2.s[3] \n\t" // Accummulate. + " fmla v15.4s,v6.4s,v2.s[3] \n\t" // Accummulate. + " ldr q2, [x1, #48] \n\t" + " \n\t" + " fmla v16.4s,v5.4s,v3.s[0] \n\t" // Accummulate. + " prfm PLDL1KEEP, [x0, #224] \n\t" + " fmla v17.4s,v6.4s,v3.s[0] \n\t" // Accummulate. + " prfm PLDL1KEEP, [x0, #288] \n\t" + " fmla v18.4s,v5.4s,v3.s[1] \n\t" // Accummulate. + " fmla v19.4s,v6.4s,v3.s[1] \n\t" // Accummulate. + " fmla v20.4s,v5.4s,v3.s[2] \n\t" // Accummulate. + " fmla v21.4s,v6.4s,v3.s[2] \n\t" // Accummulate. + " fmla v22.4s,v5.4s,v3.s[3] \n\t" // Accummulate. + " fmla v23.4s,v6.4s,v3.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v24.4s,v5.4s,v4.s[0] \n\t" // Accummulate. + " fmla v26.4s,v5.4s,v4.s[1] \n\t" // Accummulate. + " fmla v28.4s,v5.4s,v4.s[2] \n\t" // Accummulate. + " fmla v30.4s,v5.4s,v4.s[3] \n\t" // Accummulate. + " ldr q3, [x1, #64] \n\t" + " \n\t" + " fmla v25.4s,v6.4s,v4.s[0] \n\t" // Accummulate. + " fmla v27.4s,v6.4s,v4.s[1] \n\t" // Accummulate. + " fmla v29.4s,v6.4s,v4.s[2] \n\t" // Accummulate. + " fmla v31.4s,v6.4s,v4.s[3] \n\t" // Accummulate. 
+ " ldr q4, [x1, #80] \n\t" + " \n\t" //End It 2 + " \n\t" + " ldr q5, [x0, #64] \n\t" + " fmla v8.4s,v0.4s,v2.s[0] \n\t" // Accummulate. + " fmla v9.4s,v1.4s,v2.s[0] \n\t" // Accummulate. + " ldr q6, [x0, #80] \n\t" + " fmla v10.4s,v0.4s,v2.s[1] \n\t" // Accummulate. + " fmla v11.4s,v1.4s,v2.s[1] \n\t" // Accummulate. + " fmla v12.4s,v0.4s,v2.s[2] \n\t" // Accummulate. + " fmla v13.4s,v1.4s,v2.s[2] \n\t" // Accummulate. + " fmla v14.4s,v0.4s,v2.s[3] \n\t" // Accummulate. + " fmla v15.4s,v1.4s,v2.s[3] \n\t" // Accummulate. + " ldr q2, [x1, #96] \n\t" + " \n\t" + " fmla v16.4s,v0.4s,v3.s[0] \n\t" // Accummulate. + " fmla v17.4s,v1.4s,v3.s[0] \n\t" // Accummulate. + " fmla v18.4s,v0.4s,v3.s[1] \n\t" // Accummulate. + " fmla v19.4s,v1.4s,v3.s[1] \n\t" // Accummulate. + " fmla v20.4s,v0.4s,v3.s[2] \n\t" // Accummulate. + " fmla v21.4s,v1.4s,v3.s[2] \n\t" // Accummulate. + " fmla v22.4s,v0.4s,v3.s[3] \n\t" // Accummulate. + " fmla v23.4s,v1.4s,v3.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v24.4s,v0.4s,v4.s[0] \n\t" // Accummulate. + " fmla v26.4s,v0.4s,v4.s[1] \n\t" // Accummulate. + " fmla v28.4s,v0.4s,v4.s[2] \n\t" // Accummulate. + " fmla v30.4s,v0.4s,v4.s[3] \n\t" // Accummulate. + " ldr q3, [x1, #112] \n\t" + " \n\t" + " fmla v25.4s,v1.4s,v4.s[0] \n\t" // Accummulate. + " fmla v27.4s,v1.4s,v4.s[1] \n\t" // Accummulate. + " fmla v29.4s,v1.4s,v4.s[2] \n\t" // Accummulate. + " fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. + " ldr q4, [x1, #128] \n\t" + " \n\t" //End It 3 + " \n\t" + " ldr q0, [x0, #96] \n\t" + " fmla v8.4s,v5.4s,v2.s[0] \n\t" // Accummulate. + " fmla v9.4s,v6.4s,v2.s[0] \n\t" // Accummulate. + " ldr q1, [x0, #112] \n\t" + " fmla v10.4s,v5.4s,v2.s[1] \n\t" // Accummulate. + " fmla v11.4s,v6.4s,v2.s[1] \n\t" // Accummulate. + " fmla v12.4s,v5.4s,v2.s[2] \n\t" // Accummulate. + " fmla v13.4s,v6.4s,v2.s[2] \n\t" // Accummulate. + " fmla v14.4s,v5.4s,v2.s[3] \n\t" // Accummulate. + " fmla v15.4s,v6.4s,v2.s[3] \n\t" // Accummulate. + " ldr q2, [x1, #144] \n\t" + " \n\t" + " fmla v16.4s,v5.4s,v3.s[0] \n\t" // Accummulate. + " fmla v17.4s,v6.4s,v3.s[0] \n\t" // Accummulate. + " fmla v18.4s,v5.4s,v3.s[1] \n\t" // Accummulate. + " fmla v19.4s,v6.4s,v3.s[1] \n\t" // Accummulate. + " fmla v20.4s,v5.4s,v3.s[2] \n\t" // Accummulate. + " fmla v21.4s,v6.4s,v3.s[2] \n\t" // Accummulate. + " fmla v22.4s,v5.4s,v3.s[3] \n\t" // Accummulate. + " fmla v23.4s,v6.4s,v3.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v24.4s,v5.4s,v4.s[0] \n\t" // Accummulate. + " fmla v26.4s,v5.4s,v4.s[1] \n\t" // Accummulate. + " fmla v28.4s,v5.4s,v4.s[2] \n\t" // Accummulate. + " fmla v30.4s,v5.4s,v4.s[3] \n\t" // Accummulate. + " ldr q3, [x1, #160] \n\t" + " \n\t" + " fmla v25.4s,v6.4s,v4.s[0] \n\t" // Accummulate. + " fmla v27.4s,v6.4s,v4.s[1] \n\t" // Accummulate. + " fmla v29.4s,v6.4s,v4.s[2] \n\t" // Accummulate. + " fmla v31.4s,v6.4s,v4.s[3] \n\t" // Accummulate. + " ldr q4, [x1, #176] \n\t" + " add x1, x1, #192 \n\t" + " add x0, x0, #128 \n\t" + " \n\t" //End It 4 + " sub x5,x5,1 \n\t" // i-=1. + " cmp x5,1 \n\t" // Iterate again if we are not in k_iter == 1. + BNE(SLOOPKITER) + " \n\t" + LABEL(SLASTITER) // Last iteration of k_iter loop. + " \n\t" + " \n\t" + " ldr q5, [x0] \n\t" + " fmla v8.4s,v0.4s,v2.s[0] \n\t" // Accummulate. + " fmla v9.4s,v1.4s,v2.s[0] \n\t" // Accummulate. + " ldr q6, [x0, #16] \n\t" + " fmla v10.4s,v0.4s,v2.s[1] \n\t" // Accummulate. + " fmla v11.4s,v1.4s,v2.s[1] \n\t" // Accummulate. + " fmla v12.4s,v0.4s,v2.s[2] \n\t" // Accummulate. 
+ " fmla v13.4s,v1.4s,v2.s[2] \n\t" // Accummulate. + " fmla v14.4s,v0.4s,v2.s[3] \n\t" // Accummulate. + " fmla v15.4s,v1.4s,v2.s[3] \n\t" // Accummulate. + " ldr q2, [x1] \n\t" + " \n\t" + " fmla v16.4s,v0.4s,v3.s[0] \n\t" // Accummulate. + " fmla v17.4s,v1.4s,v3.s[0] \n\t" // Accummulate. + " fmla v18.4s,v0.4s,v3.s[1] \n\t" // Accummulate. + " fmla v19.4s,v1.4s,v3.s[1] \n\t" // Accummulate. + " fmla v20.4s,v0.4s,v3.s[2] \n\t" // Accummulate. + " fmla v21.4s,v1.4s,v3.s[2] \n\t" // Accummulate. + " fmla v22.4s,v0.4s,v3.s[3] \n\t" // Accummulate. + " fmla v23.4s,v1.4s,v3.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v24.4s,v0.4s,v4.s[0] \n\t" // Accummulate. + " fmla v26.4s,v0.4s,v4.s[1] \n\t" // Accummulate. + " fmla v28.4s,v0.4s,v4.s[2] \n\t" // Accummulate. + " fmla v30.4s,v0.4s,v4.s[3] \n\t" // Accummulate. + " ldr q3, [x1, #16] \n\t" + " \n\t" + " fmla v25.4s,v1.4s,v4.s[0] \n\t" // Accummulate. + " fmla v27.4s,v1.4s,v4.s[1] \n\t" // Accummulate. + " fmla v29.4s,v1.4s,v4.s[2] \n\t" // Accummulate. + " fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. + " ldr q4, [x1, #32] \n\t" + " \n\t" //End It 1 + " \n\t" + " ldr q0, [x0, #32] \n\t" + " fmla v8.4s,v5.4s,v2.s[0] \n\t" // Accummulate. + " fmla v9.4s,v6.4s,v2.s[0] \n\t" // Accummulate. + " ldr q1, [x0, #48] \n\t" + " fmla v10.4s,v5.4s,v2.s[1] \n\t" // Accummulate. + " fmla v11.4s,v6.4s,v2.s[1] \n\t" // Accummulate. + " fmla v12.4s,v5.4s,v2.s[2] \n\t" // Accummulate. + " fmla v13.4s,v6.4s,v2.s[2] \n\t" // Accummulate. + " fmla v14.4s,v5.4s,v2.s[3] \n\t" // Accummulate. + " fmla v15.4s,v6.4s,v2.s[3] \n\t" // Accummulate. + " ldr q2, [x1, #48] \n\t" + " \n\t" + " fmla v16.4s,v5.4s,v3.s[0] \n\t" // Accummulate. + " fmla v17.4s,v6.4s,v3.s[0] \n\t" // Accummulate. + " fmla v18.4s,v5.4s,v3.s[1] \n\t" // Accummulate. + " fmla v19.4s,v6.4s,v3.s[1] \n\t" // Accummulate. + " fmla v20.4s,v5.4s,v3.s[2] \n\t" // Accummulate. + " fmla v21.4s,v6.4s,v3.s[2] \n\t" // Accummulate. + " fmla v22.4s,v5.4s,v3.s[3] \n\t" // Accummulate. + " fmla v23.4s,v6.4s,v3.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v24.4s,v5.4s,v4.s[0] \n\t" // Accummulate. + " fmla v26.4s,v5.4s,v4.s[1] \n\t" // Accummulate. + " fmla v28.4s,v5.4s,v4.s[2] \n\t" // Accummulate. + " fmla v30.4s,v5.4s,v4.s[3] \n\t" // Accummulate. + " ldr q3, [x1, #64] \n\t" + " \n\t" + " fmla v25.4s,v6.4s,v4.s[0] \n\t" // Accummulate. + " fmla v27.4s,v6.4s,v4.s[1] \n\t" // Accummulate. + " fmla v29.4s,v6.4s,v4.s[2] \n\t" // Accummulate. + " fmla v31.4s,v6.4s,v4.s[3] \n\t" // Accummulate. + " ldr q4, [x1, #80] \n\t" + " \n\t" //End It 2 + " \n\t" + " ldr q5, [x0, #64] \n\t" + " fmla v8.4s,v0.4s,v2.s[0] \n\t" // Accummulate. + " fmla v9.4s,v1.4s,v2.s[0] \n\t" // Accummulate. + " ldr q6, [x0, #80] \n\t" + " fmla v10.4s,v0.4s,v2.s[1] \n\t" // Accummulate. + " fmla v11.4s,v1.4s,v2.s[1] \n\t" // Accummulate. + " fmla v12.4s,v0.4s,v2.s[2] \n\t" // Accummulate. + " fmla v13.4s,v1.4s,v2.s[2] \n\t" // Accummulate. + " fmla v14.4s,v0.4s,v2.s[3] \n\t" // Accummulate. + " fmla v15.4s,v1.4s,v2.s[3] \n\t" // Accummulate. + " ldr q2, [x1, #96] \n\t" + " \n\t" + " fmla v16.4s,v0.4s,v3.s[0] \n\t" // Accummulate. + " fmla v17.4s,v1.4s,v3.s[0] \n\t" // Accummulate. + " fmla v18.4s,v0.4s,v3.s[1] \n\t" // Accummulate. + " fmla v19.4s,v1.4s,v3.s[1] \n\t" // Accummulate. + " fmla v20.4s,v0.4s,v3.s[2] \n\t" // Accummulate. + " fmla v21.4s,v1.4s,v3.s[2] \n\t" // Accummulate. + " fmla v22.4s,v0.4s,v3.s[3] \n\t" // Accummulate. + " fmla v23.4s,v1.4s,v3.s[3] \n\t" // Accummulate. 
+ " \n\t" + " fmla v24.4s,v0.4s,v4.s[0] \n\t" // Accummulate. + " fmla v26.4s,v0.4s,v4.s[1] \n\t" // Accummulate. + " fmla v28.4s,v0.4s,v4.s[2] \n\t" // Accummulate. + " fmla v30.4s,v0.4s,v4.s[3] \n\t" // Accummulate. + " ldr q3, [x1, #112] \n\t" + " \n\t" + " fmla v25.4s,v1.4s,v4.s[0] \n\t" // Accummulate. + " fmla v27.4s,v1.4s,v4.s[1] \n\t" // Accummulate. + " fmla v29.4s,v1.4s,v4.s[2] \n\t" // Accummulate. + " fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. + " ldr q4, [x1, #128] \n\t" + " \n\t" //End It 3 + " \n\t" + " fmla v8.4s,v5.4s,v2.s[0] \n\t" // Accummulate. + " fmla v9.4s,v6.4s,v2.s[0] \n\t" // Accummulate. + " fmla v10.4s,v5.4s,v2.s[1] \n\t" // Accummulate. + " fmla v11.4s,v6.4s,v2.s[1] \n\t" // Accummulate. + " fmla v12.4s,v5.4s,v2.s[2] \n\t" // Accummulate. + " fmla v13.4s,v6.4s,v2.s[2] \n\t" // Accummulate. + " fmla v14.4s,v5.4s,v2.s[3] \n\t" // Accummulate. + " fmla v15.4s,v6.4s,v2.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v16.4s,v5.4s,v3.s[0] \n\t" // Accummulate. + " fmla v17.4s,v6.4s,v3.s[0] \n\t" // Accummulate. + " fmla v18.4s,v5.4s,v3.s[1] \n\t" // Accummulate. + " fmla v19.4s,v6.4s,v3.s[1] \n\t" // Accummulate. + " fmla v20.4s,v5.4s,v3.s[2] \n\t" // Accummulate. + " fmla v21.4s,v6.4s,v3.s[2] \n\t" // Accummulate. + " fmla v22.4s,v5.4s,v3.s[3] \n\t" // Accummulate. + " fmla v23.4s,v6.4s,v3.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v24.4s,v5.4s,v4.s[0] \n\t" // Accummulate. + " fmla v26.4s,v5.4s,v4.s[1] \n\t" // Accummulate. + " fmla v28.4s,v5.4s,v4.s[2] \n\t" // Accummulate. + " fmla v30.4s,v5.4s,v4.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v25.4s,v6.4s,v4.s[0] \n\t" // Accummulate. + " fmla v27.4s,v6.4s,v4.s[1] \n\t" // Accummulate. + " fmla v29.4s,v6.4s,v4.s[2] \n\t" // Accummulate. + " fmla v31.4s,v6.4s,v4.s[3] \n\t" // Accummulate. + " add x1, x1, #144 \n\t" + " add x0, x0, #96 \n\t" + " \n\t" //End It 4 + " \n\t" + LABEL(SCONSIDERKLEFT) + " cmp x6,0 \n\t" // If k_left == 0, we are done. + BEQ(SPOSTACCUM) // else, we enter the k_left loop. + " \n\t" + LABEL(SLOOPKLEFT) // Body of the left iterations + " \n\t" + " ldr q0, [x0],#16 \n\t" + " ldr q1, [x0],#16 \n\t" // Load a + " \n\t" + " ldr q2, [x1],#16 \n\t" // Load b + " ldr q3, [x1],#16 \n\t" + " ldr q4, [x1],#16 \n\t" + " \n\t" + " sub x6,x6,1 \n\t" // i = i-1. + " \n\t" + " fmla v8.4s,v0.4s,v2.s[0] \n\t" // Accummulate. + " fmla v9.4s,v1.4s,v2.s[0] \n\t" // Accummulate. + " fmla v10.4s,v0.4s,v2.s[1] \n\t" // Accummulate. + " fmla v11.4s,v1.4s,v2.s[1] \n\t" // Accummulate. + " fmla v12.4s,v0.4s,v2.s[2] \n\t" // Accummulate. + " fmla v13.4s,v1.4s,v2.s[2] \n\t" // Accummulate. + " fmla v14.4s,v0.4s,v2.s[3] \n\t" // Accummulate. + " fmla v15.4s,v1.4s,v2.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v16.4s,v0.4s,v3.s[0] \n\t" // Accummulate. + " fmla v17.4s,v1.4s,v3.s[0] \n\t" // Accummulate. + " fmla v18.4s,v0.4s,v3.s[1] \n\t" // Accummulate. + " fmla v19.4s,v1.4s,v3.s[1] \n\t" // Accummulate. + " fmla v20.4s,v0.4s,v3.s[2] \n\t" // Accummulate. + " fmla v21.4s,v1.4s,v3.s[2] \n\t" // Accummulate. + " fmla v22.4s,v0.4s,v3.s[3] \n\t" // Accummulate. + " fmla v23.4s,v1.4s,v3.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v24.4s,v0.4s,v4.s[0] \n\t" // Accummulate. + " fmla v26.4s,v0.4s,v4.s[1] \n\t" // Accummulate. + " fmla v28.4s,v0.4s,v4.s[2] \n\t" // Accummulate. + " fmla v30.4s,v0.4s,v4.s[3] \n\t" // Accummulate. + " fmla v25.4s,v1.4s,v4.s[0] \n\t" // Accummulate. + " fmla v27.4s,v1.4s,v4.s[1] \n\t" // Accummulate. + " fmla v29.4s,v1.4s,v4.s[2] \n\t" // Accummulate. 
+ " fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. + " \n\t" + " cmp x6,0 \n\t" // Iterate again. + BNE(SLOOPKLEFT) // if i!=0. + " \n\t" + LABEL(SPOSTACCUM) + " \n\t" + " ldr x0,%[alpha] \n\t" // Alpha address. + " ldr x1,%[beta] \n\t" // Beta address. + " \n\t" + " ld1r {v6.4s},[x0] \n\t" // Load alpha. + " ld1r {v7.4s},[x1] \n\t" // Load beta + " \n\t" + " ldr x0,%[a_next] \n\t" // Pointer to next block of A. + " ldr x1,%[b_next] \n\t" // Pointer to next pointer of B. + " \n\t" " cmp x14,#4 \n\t" // If rs_c != 1 (column-major) BNE(SGENSTORED) " \n\t" -LABEL(SCOLSTORED) // C is column-major. -" \n\t" -" dup v0.4s, wzr \n\t" -" dup v1.4s, wzr \n\t" -" dup v2.4s, wzr \n\t" -" dup v3.4s, wzr \n\t" -" dup v4.4s, wzr \n\t" -" dup v5.4s, wzr \n\t" -" \n\t" -" fcmp s7,#0.0 \n\t" -BEQ(SBETAZEROCOLSTOREDS1) // Taking care of the beta==0 case. -" \n\t" -" ldr q0, [x2] \n\t" //Load column 0 of C -" ldr q1, [x2, #16] \n\t" -" ldr q2, [x16] \n\t" //Load column 1 of C -" ldr q3, [x16, #16] \n\t" -" ldr q4, [x17] \n\t" //Load column 2 of C -" ldr q5, [x17, #16] \n\t" -" \n\t" -" fmul v0.4s,v0.4s,v7.s[0] \n\t" // Scale by beta -" fmul v1.4s,v1.4s,v7.s[0] \n\t" // Scale by beta -" fmul v2.4s,v2.4s,v7.s[0] \n\t" // Scale by beta -" fmul v3.4s,v3.4s,v7.s[0] \n\t" // Scale by beta -" fmul v4.4s,v4.4s,v7.s[0] \n\t" // Scale by beta -" fmul v5.4s,v5.4s,v7.s[0] \n\t" // Scale by beta -" \n\t" -LABEL(SBETAZEROCOLSTOREDS1) -" \n\t" -" fmla v0.4s,v8.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v1.4s,v9.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v2.4s,v10.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v3.4s,v11.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v4.4s,v12.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v5.4s,v13.4s,v6.s[0] \n\t" // Scale by alpha -" \n\t" -" str q0, [x2] \n\t" //Store column 0 of C -" str q1, [x2, #16] \n\t" -" str q2, [x16] \n\t" //Store column 1 of C -" str q3, [x16, #16] \n\t" -" str q4, [x17] \n\t" //Store column 2 of C -" str q5, [x17, #16] \n\t" -" \n\t" -" dup v8.4s, wzr \n\t" -" dup v9.4s, wzr \n\t" -" dup v10.4s, wzr \n\t" -" dup v11.4s, wzr \n\t" -" dup v12.4s, wzr \n\t" -" dup v13.4s, wzr \n\t" -" \n\t" -" fcmp s7,#0.0 \n\t" -BEQ(SBETAZEROCOLSTOREDS2) // Taking care of the beta==0 case. 
-" \n\t" -" ldr q8, [x19] \n\t" //Load column 3 of C -" ldr q9, [x19, #16] \n\t" -" ldr q10, [x20] \n\t" //Load column 4 of C -" ldr q11, [x20, #16] \n\t" -" ldr q12, [x21] \n\t" //Load column 5 of C -" ldr q13, [x21, #16] \n\t" -" \n\t" -" fmul v8.4s, v8.4s, v7.s[0] \n\t" // Scale by beta -" fmul v9.4s, v9.4s, v7.s[0] \n\t" // Scale by beta -" fmul v10.4s,v10.4s,v7.s[0] \n\t" // Scale by beta -" fmul v11.4s,v11.4s,v7.s[0] \n\t" // Scale by beta -" fmul v12.4s,v12.4s,v7.s[0] \n\t" // Scale by beta -" fmul v13.4s,v13.4s,v7.s[0] \n\t" // Scale by beta -" \n\t" -LABEL(SBETAZEROCOLSTOREDS2) -" \n\t" -" fmla v8.4s, v14.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v9.4s, v15.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v10.4s,v16.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v11.4s,v17.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v12.4s,v18.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v13.4s,v19.4s,v6.s[0] \n\t" // Scale by alpha -" \n\t" -" str q8, [x19] \n\t" //Store column 3 of C -" str q9, [x19, #16] \n\t" -" str q10, [x20] \n\t" //Store column 4 of C -" str q11, [x20, #16] \n\t" -" str q12, [x21] \n\t" //Store column 5 of C -" str q13, [x21, #16] \n\t" -" \n\t" -" dup v0.4s, wzr \n\t" -" dup v1.4s, wzr \n\t" -" dup v2.4s, wzr \n\t" -" dup v3.4s, wzr \n\t" -" dup v4.4s, wzr \n\t" -" dup v5.4s, wzr \n\t" -" \n\t" -" fcmp s7,#0.0 \n\t" -BEQ(SBETAZEROCOLSTOREDS3) // Taking care of the beta==0 case. -" \n\t" -" ldr q0, [x22] \n\t" //Load column 6 of C -" ldr q1, [x22, #16] \n\t" -" ldr q2, [x23] \n\t" //Load column 7 of C -" ldr q3, [x23, #16] \n\t" -" ldr q4, [x24] \n\t" //Load column 8 of C -" ldr q5, [x24, #16] \n\t" -" \n\t" -" fmul v0.4s,v0.4s,v7.s[0] \n\t" // Scale by beta -" fmul v1.4s,v1.4s,v7.s[0] \n\t" // Scale by beta -" fmul v2.4s,v2.4s,v7.s[0] \n\t" // Scale by beta -" fmul v3.4s,v3.4s,v7.s[0] \n\t" // Scale by beta -" fmul v4.4s,v4.4s,v7.s[0] \n\t" // Scale by beta -" fmul v5.4s,v5.4s,v7.s[0] \n\t" // Scale by beta -" \n\t" -LABEL(SBETAZEROCOLSTOREDS3) -" \n\t" -" fmla v0.4s,v20.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v1.4s,v21.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v2.4s,v22.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v3.4s,v23.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v4.4s,v24.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v5.4s,v25.4s,v6.s[0] \n\t" // Scale by alpha -" \n\t" -" str q0, [x22] \n\t" //Store column 6 of C -" str q1, [x22, #16] \n\t" -" str q2, [x23] \n\t" //Store column 7 of C -" str q3, [x23, #16] \n\t" -" str q4, [x24] \n\t" //Store column 8 of C -" str q5, [x24, #16] \n\t" -" \n\t" -" dup v8.4s, wzr \n\t" -" dup v9.4s, wzr \n\t" -" dup v10.4s, wzr \n\t" -" dup v11.4s, wzr \n\t" -" dup v12.4s, wzr \n\t" -" dup v13.4s, wzr \n\t" -" \n\t" -" fcmp s7,#0.0 \n\t" -BEQ(SBETAZEROCOLSTOREDS4) // Taking care of the beta==0 case. 
-" \n\t" -" ldr q8, [x25] \n\t" //Load column 9 of C -" ldr q9, [x25, #16] \n\t" -" ldr q10, [x26] \n\t" //Load column 10 of C -" ldr q11, [x26, #16] \n\t" -" ldr q12, [x27] \n\t" //Load column 11 of C -" ldr q13, [x27, #16] \n\t" -" \n\t" -" fmul v8.4s, v8.4s, v7.s[0] \n\t" // Scale by beta -" fmul v9.4s, v9.4s, v7.s[0] \n\t" // Scale by beta -" fmul v10.4s,v10.4s,v7.s[0] \n\t" // Scale by beta -" fmul v11.4s,v11.4s,v7.s[0] \n\t" // Scale by beta -" fmul v12.4s,v12.4s,v7.s[0] \n\t" // Scale by beta -" fmul v13.4s,v13.4s,v7.s[0] \n\t" // Scale by beta -" \n\t" -LABEL(SBETAZEROCOLSTOREDS4) -" \n\t" -" prfm pldl2keep,[x0] \n\t" -" prfm pldl2keep,[x1] \n\t" -" \n\t" -" fmla v8.4s, v26.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v9.4s, v27.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v10.4s,v28.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v11.4s,v29.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v12.4s,v30.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v13.4s,v31.4s,v6.s[0] \n\t" // Scale by alpha -" \n\t" -" str q8, [x25] \n\t" //Store column 9 of C -" str q9, [x25, #16] \n\t" -" str q10, [x26] \n\t" //Store column 10 of C -" str q11, [x26, #16] \n\t" -" str q12, [x27] \n\t" //Store column 11 of C -" str q13, [x27, #16] \n\t" -" \n\t" + LABEL(SCOLSTORED) // C is column-major. + " \n\t" + " dup v0.4s, wzr \n\t" + " dup v1.4s, wzr \n\t" + " dup v2.4s, wzr \n\t" + " dup v3.4s, wzr \n\t" + " dup v4.4s, wzr \n\t" + " dup v5.4s, wzr \n\t" + " \n\t" + " fcmp s7,#0.0 \n\t" + BEQ(SBETAZEROCOLSTOREDS1) // Taking care of the beta==0 case. + " \n\t" + " ldr q0, [x2] \n\t" //Load column 0 of C + " ldr q1, [x2, #16] \n\t" + " ldr q2, [x16] \n\t" //Load column 1 of C + " ldr q3, [x16, #16] \n\t" + " ldr q4, [x17] \n\t" //Load column 2 of C + " ldr q5, [x17, #16] \n\t" + " \n\t" + " fmul v0.4s,v0.4s,v7.s[0] \n\t" // Scale by beta + " fmul v1.4s,v1.4s,v7.s[0] \n\t" // Scale by beta + " fmul v2.4s,v2.4s,v7.s[0] \n\t" // Scale by beta + " fmul v3.4s,v3.4s,v7.s[0] \n\t" // Scale by beta + " fmul v4.4s,v4.4s,v7.s[0] \n\t" // Scale by beta + " fmul v5.4s,v5.4s,v7.s[0] \n\t" // Scale by beta + " \n\t" + LABEL(SBETAZEROCOLSTOREDS1) + " \n\t" + " fmla v0.4s,v8.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v1.4s,v9.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v2.4s,v10.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v3.4s,v11.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v4.4s,v12.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v5.4s,v13.4s,v6.s[0] \n\t" // Scale by alpha + " \n\t" + " str q0, [x2] \n\t" //Store column 0 of C + " str q1, [x2, #16] \n\t" + " str q2, [x16] \n\t" //Store column 1 of C + " str q3, [x16, #16] \n\t" + " str q4, [x17] \n\t" //Store column 2 of C + " str q5, [x17, #16] \n\t" + " \n\t" + " dup v8.4s, wzr \n\t" + " dup v9.4s, wzr \n\t" + " dup v10.4s, wzr \n\t" + " dup v11.4s, wzr \n\t" + " dup v12.4s, wzr \n\t" + " dup v13.4s, wzr \n\t" + " \n\t" + " fcmp s7,#0.0 \n\t" + BEQ(SBETAZEROCOLSTOREDS2) // Taking care of the beta==0 case. 
+ " \n\t" + " ldr q8, [x19] \n\t" //Load column 3 of C + " ldr q9, [x19, #16] \n\t" + " ldr q10, [x20] \n\t" //Load column 4 of C + " ldr q11, [x20, #16] \n\t" + " ldr q12, [x21] \n\t" //Load column 5 of C + " ldr q13, [x21, #16] \n\t" + " \n\t" + " fmul v8.4s, v8.4s, v7.s[0] \n\t" // Scale by beta + " fmul v9.4s, v9.4s, v7.s[0] \n\t" // Scale by beta + " fmul v10.4s,v10.4s,v7.s[0] \n\t" // Scale by beta + " fmul v11.4s,v11.4s,v7.s[0] \n\t" // Scale by beta + " fmul v12.4s,v12.4s,v7.s[0] \n\t" // Scale by beta + " fmul v13.4s,v13.4s,v7.s[0] \n\t" // Scale by beta + " \n\t" + LABEL(SBETAZEROCOLSTOREDS2) + " \n\t" + " fmla v8.4s, v14.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v9.4s, v15.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v10.4s,v16.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v11.4s,v17.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v12.4s,v18.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v13.4s,v19.4s,v6.s[0] \n\t" // Scale by alpha + " \n\t" + " str q8, [x19] \n\t" //Store column 3 of C + " str q9, [x19, #16] \n\t" + " str q10, [x20] \n\t" //Store column 4 of C + " str q11, [x20, #16] \n\t" + " str q12, [x21] \n\t" //Store column 5 of C + " str q13, [x21, #16] \n\t" + " \n\t" + " dup v0.4s, wzr \n\t" + " dup v1.4s, wzr \n\t" + " dup v2.4s, wzr \n\t" + " dup v3.4s, wzr \n\t" + " dup v4.4s, wzr \n\t" + " dup v5.4s, wzr \n\t" + " \n\t" + " fcmp s7,#0.0 \n\t" + BEQ(SBETAZEROCOLSTOREDS3) // Taking care of the beta==0 case. + " \n\t" + " ldr q0, [x22] \n\t" //Load column 6 of C + " ldr q1, [x22, #16] \n\t" + " ldr q2, [x23] \n\t" //Load column 7 of C + " ldr q3, [x23, #16] \n\t" + " ldr q4, [x24] \n\t" //Load column 8 of C + " ldr q5, [x24, #16] \n\t" + " \n\t" + " fmul v0.4s,v0.4s,v7.s[0] \n\t" // Scale by beta + " fmul v1.4s,v1.4s,v7.s[0] \n\t" // Scale by beta + " fmul v2.4s,v2.4s,v7.s[0] \n\t" // Scale by beta + " fmul v3.4s,v3.4s,v7.s[0] \n\t" // Scale by beta + " fmul v4.4s,v4.4s,v7.s[0] \n\t" // Scale by beta + " fmul v5.4s,v5.4s,v7.s[0] \n\t" // Scale by beta + " \n\t" + LABEL(SBETAZEROCOLSTOREDS3) + " \n\t" + " fmla v0.4s,v20.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v1.4s,v21.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v2.4s,v22.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v3.4s,v23.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v4.4s,v24.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v5.4s,v25.4s,v6.s[0] \n\t" // Scale by alpha + " \n\t" + " str q0, [x22] \n\t" //Store column 6 of C + " str q1, [x22, #16] \n\t" + " str q2, [x23] \n\t" //Store column 7 of C + " str q3, [x23, #16] \n\t" + " str q4, [x24] \n\t" //Store column 8 of C + " str q5, [x24, #16] \n\t" + " \n\t" + " dup v8.4s, wzr \n\t" + " dup v9.4s, wzr \n\t" + " dup v10.4s, wzr \n\t" + " dup v11.4s, wzr \n\t" + " dup v12.4s, wzr \n\t" + " dup v13.4s, wzr \n\t" + " \n\t" + " fcmp s7,#0.0 \n\t" + BEQ(SBETAZEROCOLSTOREDS4) // Taking care of the beta==0 case. 
+ " \n\t" + " ldr q8, [x25] \n\t" //Load column 9 of C + " ldr q9, [x25, #16] \n\t" + " ldr q10, [x26] \n\t" //Load column 10 of C + " ldr q11, [x26, #16] \n\t" + " ldr q12, [x27] \n\t" //Load column 11 of C + " ldr q13, [x27, #16] \n\t" + " \n\t" + " fmul v8.4s, v8.4s, v7.s[0] \n\t" // Scale by beta + " fmul v9.4s, v9.4s, v7.s[0] \n\t" // Scale by beta + " fmul v10.4s,v10.4s,v7.s[0] \n\t" // Scale by beta + " fmul v11.4s,v11.4s,v7.s[0] \n\t" // Scale by beta + " fmul v12.4s,v12.4s,v7.s[0] \n\t" // Scale by beta + " fmul v13.4s,v13.4s,v7.s[0] \n\t" // Scale by beta + " \n\t" + LABEL(SBETAZEROCOLSTOREDS4) + " \n\t" + " prfm pldl2keep,[x0] \n\t" + " prfm pldl2keep,[x1] \n\t" + " \n\t" + " fmla v8.4s, v26.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v9.4s, v27.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v10.4s,v28.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v11.4s,v29.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v12.4s,v30.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v13.4s,v31.4s,v6.s[0] \n\t" // Scale by alpha + " \n\t" + " str q8, [x25] \n\t" //Store column 9 of C + " str q9, [x25, #16] \n\t" + " str q10, [x26] \n\t" //Store column 10 of C + " str q11, [x26, #16] \n\t" + " str q12, [x27] \n\t" //Store column 11 of C + " str q13, [x27, #16] \n\t" + " \n\t" " \n\t" BRANCH(SEND) // Done. " \n\t" @@ -1051,37 +1055,38 @@ LABEL(SBETAZEROGENSTOREDS4) " st1 {v13.s}[2],[x5],x14 \n\t" // Store c116 into quad and increment by rs_c. " st1 {v13.s}[3],[x5],x14 \n\t" // Store c147 into quad and increment by rs_c. " \n\t" -LABEL(SEND) // Done! -" \n\t" -:// output operands (none) -:// input operands - [aaddr] "m" (a), // 0 - [baddr] "m" (b), // 1 - [caddr] "m" (c), // 2 - [k_iter] "m" (k_iter), // 3 - [k_left] "m" (k_left), // 4 - [alpha] "m" (alpha), // 5 - [beta] "m" (beta), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c), // 8 - [a_next] "m" (a_next), // 9 - [b_next] "m" (b_next) // 10 -:// Register clobber list - "x0", "x1", "x2", - "x5", "x6", "x10","x14", - "x16","x17","x19","x20", - "x21","x22","x23","x24", - "x25","x26","x27", - "v0", "v1", "v2", "v3", - "v4", "v5", "v6", "v7", - "v8", "v9", "v10","v11", - "v12","v13","v14","v15", - "v16","v17","v18","v19", - "v20","v21","v22","v23", - "v24","v25","v26","v27", - "v28","v29","v30","v31" -); + LABEL(SEND) // Done! + " \n\t" + :// output operands (none) + :// input operands + [aaddr] "m" (a), // 0 + [baddr] "m" (b), // 1 + [caddr] "m" (c), // 2 + [k_iter] "m" (k_iter), // 3 + [k_left] "m" (k_left), // 4 + [alpha] "m" (alpha), // 5 + [beta] "m" (beta), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c), // 8 + [a_next] "m" (a_next), // 9 + [b_next] "m" (b_next) // 10 + :// Register clobber list + "x0", "x1", "x2", + "x5", "x6", "x10","x14", + "x16","x17","x19","x20", + "x21","x22","x23","x24", + "x25","x26","x27", + "v0", "v1", "v2", "v3", + "v4", "v5", "v6", "v7", + "v8", "v9", "v10","v11", + "v12","v13","v14","v15", + "v16","v17","v18","v19", + "v20","v21","v22","v23", + "v24","v25","v26","v27", + "v28","v29","v30","v31" + ); + GEMM_UKR_FLUSH_CT( s ); } @@ -1089,24 +1094,26 @@ LABEL(SEND) // Done! o 4x4 Double precision micro-kernel NOT fully functional yet. o Runnable on ARMv8, compiled with aarch64 GCC. o Use it together with the armv8 BLIS configuration. - o Tested on Juno board. Around 3 GFLOPS @ 1.1 GHz. + o Tested on Juno board. Around 3 GFLOPS @ 1.1 GHz. December 2014. - + * UPDATE OCTOBER 2015: Now is fully functional. * Tested on Juno board. Around 5.6 GFLOPS, 2 A57 cores @ 1.1 GHz. * Tested on Juno board. 
Around 4 GFLOPS, 4 A53 cores @ 850 MHz. - + * UPDATE NOVEMBER 2015 * Micro-kernel changed to 6x8 * Tested on Juno Board. Around 4 GFLOPS, 1 x A57 core @ 1.1 GHz. * Tested on Juno Board. Around 7.6 GFLOPS, 2 x A57 cores @ 1.1 GHz. - * Tested on Juno board. Around 1.5 GFLOPS, 1 x A53 core @ 850 MHz. + * Tested on Juno board. Around 1.5 GFLOPS, 1 x A53 core @ 850 MHz. * Tested on Juno board. Around 5.5 GFLOPS, 4 x A53 cores @ 850 MHz. */ void bli_dgemm_armv8a_asm_6x8 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, double* restrict alpha, double* restrict a, double* restrict b, @@ -1121,672 +1128,674 @@ void bli_dgemm_armv8a_asm_6x8 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; -__asm__ volatile -( -" \n\t" -" ldr x0,%[aaddr] \n\t" // Load address of A -" ldr x1,%[baddr] \n\t" // Load address of B -" ldr x2,%[caddr] \n\t" // Load address of C -" \n\t" -" ldr x5,%[k_iter] \n\t" // Init guard (k_iter) -" ldr x6,%[k_left] \n\t" // Init guard (k_iter) -" \n\t" -" ldr x10,%[cs_c] \n\t" // Load cs_c -" lsl x10,x10,#3 \n\t" // cs_c * sizeof(double) -" \n\t" -" ldr x14,%[rs_c] \n\t" // Load rs_c. -" lsl x14,x14,#3 \n\t" // rs_c * sizeof(double). -" \n\t" -" add x20,x2,x10 \n\t" //Load address Column 1 of C -" add x21,x20,x10 \n\t" //Load address Column 2 of C -" add x22,x21,x10 \n\t" //Load address Column 3 of C -" add x23,x22,x10 \n\t" //Load address Column 4 of C -" add x24,x23,x10 \n\t" //Load address Column 5 of C -" add x25,x24,x10 \n\t" //Load address Column 6 of C -" add x26,x25,x10 \n\t" //Load address Column 7 of C -" \n\t" -" prfm pldl1keep,[x2] \n\t" // Prefetch c. -" prfm pldl1keep,[x20] \n\t" // Prefetch c. -" prfm pldl1keep,[x21] \n\t" // Prefetch c. -" prfm pldl1keep,[x22] \n\t" // Prefetch c. -" prfm pldl1keep,[x23] \n\t" // Prefetch c. -" prfm pldl1keep,[x24] \n\t" // Prefetch c. -" prfm pldl1keep,[x25] \n\t" // Prefetch c. -" prfm pldl1keep,[x26] \n\t" // Prefetch c. 
-" \n\t" -" dup v8.2d, xzr \n\t" // Vector for accummulating column 0 -" prfm PLDL1KEEP, [x1, #256] \n\t" -" dup v9.2d, xzr \n\t" // Vector for accummulating column 0 -" prfm PLDL1KEEP, [x1, #320] \n\t" -" dup v10.2d, xzr \n\t" // Vector for accummulating column 0 -" prfm PLDL1KEEP, [x1, #384] \n\t" -" dup v11.2d, xzr \n\t" // Vector for accummulating column 1 -" prfm PLDL1KEEP, [x1, #448] \n\t" -" dup v12.2d, xzr \n\t" // Vector for accummulating column 1 -" dup v13.2d, xzr \n\t" // Vector for accummulating column 1 -" \n\t" -" dup v14.2d, xzr \n\t" // Vector for accummulating column 2 -" prfm PLDL1KEEP, [x0, #192] \n\t" -" dup v15.2d, xzr \n\t" // Vector for accummulating column 2 -" prfm PLDL1KEEP, [x0, #256] \n\t" -" dup v16.2d, xzr \n\t" // Vector for accummulating column 2 -" prfm PLDL1KEEP, [x0, #320] \n\t" -" dup v17.2d, xzr \n\t" // Vector for accummulating column 3 -" dup v18.2d, xzr \n\t" // Vector for accummulating column 3 -" dup v19.2d, xzr \n\t" // Vector for accummulating column 3 -" \n\t" -" dup v20.2d, xzr \n\t" // Vector for accummulating column 4 -" dup v21.2d, xzr \n\t" // Vector for accummulating column 4 -" dup v22.2d, xzr \n\t" // Vector for accummulating column 4 -" dup v23.2d, xzr \n\t" // Vector for accummulating column 5 -" dup v24.2d, xzr \n\t" // Vector for accummulating column 5 -" dup v25.2d, xzr \n\t" // Vector for accummulating column 5 -" \n\t" -" dup v26.2d, xzr \n\t" // Vector for accummulating column 6 -" dup v27.2d, xzr \n\t" // Vector for accummulating column 6 -" dup v28.2d, xzr \n\t" // Vector for accummulating column 6 -" dup v29.2d, xzr \n\t" // Vector for accummulating column 7 -" dup v30.2d, xzr \n\t" // Vector for accummulating column 7 -" dup v31.2d, xzr \n\t" // Vector for accummulating column 7 -" \n\t" -" \n\t" -" cmp x5,#0 \n\t" // If k_iter == 0, jump to k_left. -BEQ(DCONSIDERKLEFT) -" \n\t" -" ldr q0, [x0] \n\t" // Load a -" ldr q1, [x0, #16] \n\t" -" ldr q2, [x0, #32] \n\t" -" \n\t" -" ldr q3, [x1] \n\t" // Load b -" ldr q4, [x1, #16] \n\t" -" ldr q5, [x1, #32] \n\t" -" ldr q6, [x1, #48] \n\t" -" \n\t" -" add x0, x0, #48 \n\t" //update address of A -" add x1, x1, #64 \n\t" //update address of B -" \n\t" -" cmp x5,1 \n\t" // If there is just one k_iter, jump to that one. -BEQ(DLASTITER) // (as loop is do-while-like). 
-" \n\t" -LABEL(DLOOP) // Body -" \n\t" -" fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate -" prfm PLDL1KEEP, [x1, #448] \n\t" //512-64=448 -" fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate -" prfm PLDL1KEEP, [x1, #512] \n\t" -" fmla v10.2d,v2.2d,v3.d[0] \n\t" // Accummulate -" prfm PLDL1KEEP, [x1, #576] \n\t" -" \n\t" -" fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate -" fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate -" fmla v13.2d,v2.2d,v3.d[1] \n\t" // Accummulate -" \n\t" -" fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate -" fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate -" fmla v16.2d,v2.2d,v4.d[0] \n\t" // Accummulate -" ldr q3, [x1] \n\t" -" \n\t" -" fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate -" fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate -" fmla v19.2d,v2.2d,v4.d[1] \n\t" // Accummulate -" ldr q7, [x0, #32] \n\t" -" \n\t" -" fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate -" fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate -" fmla v22.2d,v2.2d,v5.d[0] \n\t" // Accummulate -" ldr q4, [x1, #16] \n\t" -" \n\t" -" fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate -" fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate -" fmla v25.2d,v2.2d,v5.d[1] \n\t" // Accummulate -" ldr q5, [x1, #32] \n\t" -" \n\t" -" fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate -" fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate -" ldr q0, [x0] \n\t" -" \n\t" -" fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate -" fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate -" ldr q1, [x0, #16] \n\t" -" \n\t" -" fmla v28.2d,v2.2d,v6.d[0] \n\t" // Accummulate -" fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate -" ldr q6, [x1, #48] \n\t" -" \n\t" // End it 1 -" fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate -" prfm PLDL1KEEP, [x1, #640] \n\t" -" fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate -" prfm PLDL1KEEP, [x0, #336] \n\t" -" fmla v10.2d,v7.2d,v3.d[0] \n\t" // Accummulate -" prfm PLDL1KEEP, [x0, #400] \n\t" -" \n\t" -" fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate -" fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate -" fmla v13.2d,v7.2d,v3.d[1] \n\t" // Accummulate -" \n\t" -" fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate -" fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate -" fmla v16.2d,v7.2d,v4.d[0] \n\t" // Accummulate -" ldr q3, [x1, #64] \n\t" -" \n\t" -" fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate -" fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate -" fmla v19.2d,v7.2d,v4.d[1] \n\t" // Accummulate -" ldr q2, [x0, #80] \n\t" -" \n\t" -" fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate -" fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate -" fmla v22.2d,v7.2d,v5.d[0] \n\t" // Accummulate -" ldr q4, [x1, #80] \n\t" -" \n\t" -" fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate -" fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate -" fmla v25.2d,v7.2d,v5.d[1] \n\t" // Accummulate -" ldr q5, [x1, #96] \n\t" -" \n\t" -" fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate -" fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate -" ldr q0, [x0, #48] \n\t" -" \n\t" -" fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate -" fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate -" ldr q1, [x0, #64] \n\t" -" \n\t" -" fmla v28.2d,v7.2d,v6.d[0] \n\t" // Accummulate -" fmla v31.2d,v7.2d,v6.d[1] \n\t" // Accummulate -" ldr q6, [x1, #112] \n\t" -" \n\t" //End it 2 -" fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate -" prfm PLDL1KEEP, [x0, #464] \n\t" -" fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate -" fmla v10.2d,v2.2d,v3.d[0] \n\t" // Accummulate -" \n\t" -" fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate -" fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate -" fmla v13.2d,v2.2d,v3.d[1] \n\t" 
// Accummulate -" \n\t" -" fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate -" fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate -" fmla v16.2d,v2.2d,v4.d[0] \n\t" // Accummulate -" ldr q3, [x1, #128] \n\t" -" \n\t" -" fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate -" fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate -" fmla v19.2d,v2.2d,v4.d[1] \n\t" // Accummulate -" ldr q7, [x0, #128] \n\t" -" \n\t" -" fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate -" fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate -" fmla v22.2d,v2.2d,v5.d[0] \n\t" // Accummulate -" ldr q4, [x1, #144] \n\t" -" \n\t" -" fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate -" fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate -" fmla v25.2d,v2.2d,v5.d[1] \n\t" // Accummulate -" ldr q5, [x1, #160] \n\t" -" \n\t" -" fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate -" fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate -" ldr q0, [x0, #96] \n\t" -" \n\t" -" fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate -" fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate -" ldr q1, [x0, #112] \n\t" -" \n\t" -" fmla v28.2d,v2.2d,v6.d[0] \n\t" // Accummulate -" fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate -" ldr q6, [x1, #176] \n\t" -" \n\t" // End it 3 -" fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate -" fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate -" fmla v10.2d,v7.2d,v3.d[0] \n\t" // Accummulate -" \n\t" -" fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate -" fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate -" fmla v13.2d,v7.2d,v3.d[1] \n\t" // Accummulate -" ldr q3, [x1, #192] \n\t" -" \n\t" -" fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate -" fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate -" fmla v16.2d,v7.2d,v4.d[0] \n\t" // Accummulate -" ldr q2, [x0, #176] \n\t" -" \n\t" -" fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate -" fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate -" fmla v19.2d,v7.2d,v4.d[1] \n\t" // Accummulate -" ldr q4, [x1, #208] \n\t" -" \n\t" -" fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate -" fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate -" fmla v22.2d,v7.2d,v5.d[0] \n\t" // Accummulate -" \n\t" -" fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate -" fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate -" fmla v25.2d,v7.2d,v5.d[1] \n\t" // Accummulate -" ldr q5, [x1, #224] \n\t" -" \n\t" -" fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate -" fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate -" ldr q0, [x0, #144] \n\t" -" \n\t" -" fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate -" fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate -" ldr q1, [x0, #160] \n\t" -" \n\t" -" fmla v28.2d,v7.2d,v6.d[0] \n\t" // Accummulate -" fmla v31.2d,v7.2d,v6.d[1] \n\t" // Accummulate -" ldr q6, [x1, #240] \n\t" -" \n\t" //End it 4 -" add x0, x0, #192 \n\t" -" add x1, x1, #256 \n\t" -" \n\t" -" sub x5,x5,1 \n\t" // i-=1 -" cmp x5,1 \n\t" // Iterate again if we are not in k_iter == 1. 
-BNE(DLOOP) -" \n\t" -LABEL(DLASTITER) -" \n\t" -" fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate -" fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate -" fmla v10.2d,v2.2d,v3.d[0] \n\t" // Accummulate -" \n\t" -" fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate -" fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate -" fmla v13.2d,v2.2d,v3.d[1] \n\t" // Accummulate -" ldr q3, [x1] \n\t" -" \n\t" -" fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate -" fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate -" fmla v16.2d,v2.2d,v4.d[0] \n\t" // Accummulate -" ldr q7, [x0, #32] \n\t" -" \n\t" -" fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate -" fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate -" fmla v19.2d,v2.2d,v4.d[1] \n\t" // Accummulate -" ldr q4, [x1, #16] \n\t" -" \n\t" -" fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate -" fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate -" fmla v22.2d,v2.2d,v5.d[0] \n\t" // Accummulate -" \n\t" -" fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate -" fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate -" fmla v25.2d,v2.2d,v5.d[1] \n\t" // Accummulate -" ldr q5, [x1, #32] \n\t" -" \n\t" -" fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate -" fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate -" ldr q0, [x0] \n\t" -" \n\t" -" fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate -" fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate -" ldr q1, [x0, #16] \n\t" -" \n\t" -" fmla v28.2d,v2.2d,v6.d[0] \n\t" // Accummulate -" fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate -" ldr q6, [x1, #48] \n\t" -" \n\t" // End it 1 -" fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate -" fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate -" fmla v10.2d,v7.2d,v3.d[0] \n\t" // Accummulate -" \n\t" -" fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate -" fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate -" fmla v13.2d,v7.2d,v3.d[1] \n\t" // Accummulate -" ldr q3, [x1, #64] \n\t" -" \n\t" -" fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate -" fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate -" fmla v16.2d,v7.2d,v4.d[0] \n\t" // Accummulate -" ldr q2, [x0, #80] \n\t" -" \n\t" -" fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate -" fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate -" fmla v19.2d,v7.2d,v4.d[1] \n\t" // Accummulate -" ldr q4, [x1, #80] \n\t" -" \n\t" -" fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate -" fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate -" fmla v22.2d,v7.2d,v5.d[0] \n\t" // Accummulate -" \n\t" -" fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate -" fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate -" fmla v25.2d,v7.2d,v5.d[1] \n\t" // Accummulate -" ldr q5, [x1, #96] \n\t" -" \n\t" -" fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate -" fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate -" ldr q0, [x0, #48] \n\t" -" \n\t" -" fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate -" fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate -" ldr q1, [x0, #64] \n\t" -" \n\t" -" fmla v28.2d,v7.2d,v6.d[0] \n\t" // Accummulate -" fmla v31.2d,v7.2d,v6.d[1] \n\t" // Accummulate -" ldr q6, [x1, #112] \n\t" -" \n\t" //End it 2 -" fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate -" fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate -" fmla v10.2d,v2.2d,v3.d[0] \n\t" // Accummulate -" \n\t" -" fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate -" fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate -" fmla v13.2d,v2.2d,v3.d[1] \n\t" // Accummulate -" ldr q3, [x1, #128] \n\t" -" \n\t" -" fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate -" fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate -" fmla v16.2d,v2.2d,v4.d[0] \n\t" // Accummulate -" ldr q7, [x0, #128] \n\t" -" \n\t" -" fmla 
v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate -" fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate -" fmla v19.2d,v2.2d,v4.d[1] \n\t" // Accummulate -" ldr q4, [x1, #144] \n\t" -" \n\t" -" fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate -" fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate -" fmla v22.2d,v2.2d,v5.d[0] \n\t" // Accummulate -" \n\t" -" fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate -" fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate -" fmla v25.2d,v2.2d,v5.d[1] \n\t" // Accummulate -" ldr q5, [x1, #160] \n\t" -" \n\t" -" fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate -" fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate -" ldr q0, [x0, #96] \n\t" -" \n\t" -" fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate -" fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate -" ldr q1, [x0, #112] \n\t" -" \n\t" -" fmla v28.2d,v2.2d,v6.d[0] \n\t" // Accummulate -" fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate -" ldr q6, [x1, #176] \n\t" -" \n\t" // End it 3 -" fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate -" fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate -" fmla v10.2d,v7.2d,v3.d[0] \n\t" // Accummulate -" \n\t" -" fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate -" fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate -" fmla v13.2d,v7.2d,v3.d[1] \n\t" // Accummulate -" \n\t" -" fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate -" fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate -" fmla v16.2d,v7.2d,v4.d[0] \n\t" // Accummulate -" \n\t" -" fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate -" fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate -" fmla v19.2d,v7.2d,v4.d[1] \n\t" // Accummulate -" \n\t" -" fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate -" fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate -" fmla v22.2d,v7.2d,v5.d[0] \n\t" // Accummulate -" \n\t" -" fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate -" fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate -" fmla v25.2d,v7.2d,v5.d[1] \n\t" // Accummulate -" \n\t" -" fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate -" add x1, x1, #192 \n\t" -" fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate -" \n\t" -" fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate -" fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate -" \n\t" -" fmla v28.2d,v7.2d,v6.d[0] \n\t" // Accummulate -" fmla v31.2d,v7.2d,v6.d[1] \n\t" // Accummulate -" \n\t" //End it 4 -" add x0, x0, #144 \n\t" -" \n\t" -LABEL(DCONSIDERKLEFT) -" cmp x6,0 \n\t" // If k_left == 0, we are done. -BEQ(DPOSTACCUM) // else, we enter the k_left loop. 
-" \n\t" -LABEL(DLOOPKLEFT) -" \n\t" -" ldr q0, [x0],#16 \n\t" -" ldr q1, [x0],#16 \n\t" // Load a -" ldr q2, [x0],#16 \n\t" -" \n\t" -" ldr q3, [x1],#16 \n\t" // Load b -" ldr q4, [x1],#16 \n\t" -" ldr q5, [x1],#16 \n\t" -" ldr q6, [x1],#16 \n\t" -" \n\t" -" sub x6,x6,1 \n\t" -" \n\t" -" fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate -" fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate -" fmla v10.2d,v2.2d,v3.d[0] \n\t" // Accummulate -" \n\t" -" fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate -" fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate -" fmla v13.2d,v2.2d,v3.d[1] \n\t" // Accummulate -" \n\t" -" fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate -" fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate -" fmla v16.2d,v2.2d,v4.d[0] \n\t" // Accummulate -" \n\t" -" fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate -" fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate -" fmla v19.2d,v2.2d,v4.d[1] \n\t" // Accummulate -" \n\t" -" fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate -" fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate -" fmla v22.2d,v2.2d,v5.d[0] \n\t" // Accummulate -" \n\t" -" fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate -" fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate -" fmla v25.2d,v2.2d,v5.d[1] \n\t" // Accummulate -" \n\t" -" fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate -" fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate -" \n\t" -" fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate -" fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate -" \n\t" -" fmla v28.2d,v2.2d,v6.d[0] \n\t" // Accummulate -" fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate -" \n\t" -" cmp x6,0 \n\t" // Iterate again. -BNE(DLOOPKLEFT) // if i!=0. -" \n\t" -LABEL(DPOSTACCUM) -" \n\t" -" ldr x0,%[alpha] \n\t" // Alpha address -" ldr x1,%[beta] \n\t" // Beta address -" \n\t" -" ld1r {v6.2d},[x0] \n\t" // Load alpha. -" ld1r {v7.2d},[x1] \n\t" // Load beta -" \n\t" -" ldr x0,%[a_next] \n\t" // Next A address for later use. -" ldr x1,%[b_next] \n\t" // Next B address for later use. -" \n\t" + GEMM_UKR_SETUP_CT( d, 6, 8, false ); + + __asm__ volatile + ( + " \n\t" + " ldr x0,%[aaddr] \n\t" // Load address of A + " ldr x1,%[baddr] \n\t" // Load address of B + " ldr x2,%[caddr] \n\t" // Load address of C + " \n\t" + " ldr x5,%[k_iter] \n\t" // Init guard (k_iter) + " ldr x6,%[k_left] \n\t" // Init guard (k_iter) + " \n\t" + " ldr x10,%[cs_c] \n\t" // Load cs_c + " lsl x10,x10,#3 \n\t" // cs_c * sizeof(double) + " \n\t" + " ldr x14,%[rs_c] \n\t" // Load rs_c. + " lsl x14,x14,#3 \n\t" // rs_c * sizeof(double). + " \n\t" + " add x20,x2,x10 \n\t" //Load address Column 1 of C + " add x21,x20,x10 \n\t" //Load address Column 2 of C + " add x22,x21,x10 \n\t" //Load address Column 3 of C + " add x23,x22,x10 \n\t" //Load address Column 4 of C + " add x24,x23,x10 \n\t" //Load address Column 5 of C + " add x25,x24,x10 \n\t" //Load address Column 6 of C + " add x26,x25,x10 \n\t" //Load address Column 7 of C + " \n\t" + " prfm pldl1keep,[x2] \n\t" // Prefetch c. + " prfm pldl1keep,[x20] \n\t" // Prefetch c. + " prfm pldl1keep,[x21] \n\t" // Prefetch c. + " prfm pldl1keep,[x22] \n\t" // Prefetch c. + " prfm pldl1keep,[x23] \n\t" // Prefetch c. + " prfm pldl1keep,[x24] \n\t" // Prefetch c. + " prfm pldl1keep,[x25] \n\t" // Prefetch c. + " prfm pldl1keep,[x26] \n\t" // Prefetch c. 
+ " \n\t" + " dup v8.2d, xzr \n\t" // Vector for accummulating column 0 + " prfm PLDL1KEEP, [x1, #256] \n\t" + " dup v9.2d, xzr \n\t" // Vector for accummulating column 0 + " prfm PLDL1KEEP, [x1, #320] \n\t" + " dup v10.2d, xzr \n\t" // Vector for accummulating column 0 + " prfm PLDL1KEEP, [x1, #384] \n\t" + " dup v11.2d, xzr \n\t" // Vector for accummulating column 1 + " prfm PLDL1KEEP, [x1, #448] \n\t" + " dup v12.2d, xzr \n\t" // Vector for accummulating column 1 + " dup v13.2d, xzr \n\t" // Vector for accummulating column 1 + " \n\t" + " dup v14.2d, xzr \n\t" // Vector for accummulating column 2 + " prfm PLDL1KEEP, [x0, #192] \n\t" + " dup v15.2d, xzr \n\t" // Vector for accummulating column 2 + " prfm PLDL1KEEP, [x0, #256] \n\t" + " dup v16.2d, xzr \n\t" // Vector for accummulating column 2 + " prfm PLDL1KEEP, [x0, #320] \n\t" + " dup v17.2d, xzr \n\t" // Vector for accummulating column 3 + " dup v18.2d, xzr \n\t" // Vector for accummulating column 3 + " dup v19.2d, xzr \n\t" // Vector for accummulating column 3 + " \n\t" + " dup v20.2d, xzr \n\t" // Vector for accummulating column 4 + " dup v21.2d, xzr \n\t" // Vector for accummulating column 4 + " dup v22.2d, xzr \n\t" // Vector for accummulating column 4 + " dup v23.2d, xzr \n\t" // Vector for accummulating column 5 + " dup v24.2d, xzr \n\t" // Vector for accummulating column 5 + " dup v25.2d, xzr \n\t" // Vector for accummulating column 5 + " \n\t" + " dup v26.2d, xzr \n\t" // Vector for accummulating column 6 + " dup v27.2d, xzr \n\t" // Vector for accummulating column 6 + " dup v28.2d, xzr \n\t" // Vector for accummulating column 6 + " dup v29.2d, xzr \n\t" // Vector for accummulating column 7 + " dup v30.2d, xzr \n\t" // Vector for accummulating column 7 + " dup v31.2d, xzr \n\t" // Vector for accummulating column 7 + " \n\t" + " \n\t" + " cmp x5,#0 \n\t" // If k_iter == 0, jump to k_left. + BEQ(DCONSIDERKLEFT) + " \n\t" + " ldr q0, [x0] \n\t" // Load a + " ldr q1, [x0, #16] \n\t" + " ldr q2, [x0, #32] \n\t" + " \n\t" + " ldr q3, [x1] \n\t" // Load b + " ldr q4, [x1, #16] \n\t" + " ldr q5, [x1, #32] \n\t" + " ldr q6, [x1, #48] \n\t" + " \n\t" + " add x0, x0, #48 \n\t" //update address of A + " add x1, x1, #64 \n\t" //update address of B + " \n\t" + " cmp x5,1 \n\t" // If there is just one k_iter, jump to that one. + BEQ(DLASTITER) // (as loop is do-while-like). 
+ " \n\t" + LABEL(DLOOP) // Body + " \n\t" + " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate + " prfm PLDL1KEEP, [x1, #448] \n\t" //512-64=448 + " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate + " prfm PLDL1KEEP, [x1, #512] \n\t" + " fmla v10.2d,v2.2d,v3.d[0] \n\t" // Accummulate + " prfm PLDL1KEEP, [x1, #576] \n\t" + " \n\t" + " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate + " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate + " fmla v13.2d,v2.2d,v3.d[1] \n\t" // Accummulate + " \n\t" + " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate + " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate + " fmla v16.2d,v2.2d,v4.d[0] \n\t" // Accummulate + " ldr q3, [x1] \n\t" + " \n\t" + " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate + " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate + " fmla v19.2d,v2.2d,v4.d[1] \n\t" // Accummulate + " ldr q7, [x0, #32] \n\t" + " \n\t" + " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate + " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate + " fmla v22.2d,v2.2d,v5.d[0] \n\t" // Accummulate + " ldr q4, [x1, #16] \n\t" + " \n\t" + " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate + " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate + " fmla v25.2d,v2.2d,v5.d[1] \n\t" // Accummulate + " ldr q5, [x1, #32] \n\t" + " \n\t" + " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate + " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate + " ldr q0, [x0] \n\t" + " \n\t" + " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate + " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate + " ldr q1, [x0, #16] \n\t" + " \n\t" + " fmla v28.2d,v2.2d,v6.d[0] \n\t" // Accummulate + " fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate + " ldr q6, [x1, #48] \n\t" + " \n\t" // End it 1 + " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate + " prfm PLDL1KEEP, [x1, #640] \n\t" + " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate + " prfm PLDL1KEEP, [x0, #336] \n\t" + " fmla v10.2d,v7.2d,v3.d[0] \n\t" // Accummulate + " prfm PLDL1KEEP, [x0, #400] \n\t" + " \n\t" + " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate + " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate + " fmla v13.2d,v7.2d,v3.d[1] \n\t" // Accummulate + " \n\t" + " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate + " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate + " fmla v16.2d,v7.2d,v4.d[0] \n\t" // Accummulate + " ldr q3, [x1, #64] \n\t" + " \n\t" + " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate + " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate + " fmla v19.2d,v7.2d,v4.d[1] \n\t" // Accummulate + " ldr q2, [x0, #80] \n\t" + " \n\t" + " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate + " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate + " fmla v22.2d,v7.2d,v5.d[0] \n\t" // Accummulate + " ldr q4, [x1, #80] \n\t" + " \n\t" + " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate + " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate + " fmla v25.2d,v7.2d,v5.d[1] \n\t" // Accummulate + " ldr q5, [x1, #96] \n\t" + " \n\t" + " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate + " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate + " ldr q0, [x0, #48] \n\t" + " \n\t" + " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate + " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate + " ldr q1, [x0, #64] \n\t" + " \n\t" + " fmla v28.2d,v7.2d,v6.d[0] \n\t" // Accummulate + " fmla v31.2d,v7.2d,v6.d[1] \n\t" // Accummulate + " ldr q6, [x1, #112] \n\t" + " \n\t" //End it 2 + " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate + " prfm PLDL1KEEP, [x0, #464] \n\t" + " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate + " fmla v10.2d,v2.2d,v3.d[0] \n\t" // Accummulate + " \n\t" + " fmla v11.2d,v0.2d,v3.d[1] \n\t" // 
Accummulate + " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate + " fmla v13.2d,v2.2d,v3.d[1] \n\t" // Accummulate + " \n\t" + " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate + " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate + " fmla v16.2d,v2.2d,v4.d[0] \n\t" // Accummulate + " ldr q3, [x1, #128] \n\t" + " \n\t" + " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate + " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate + " fmla v19.2d,v2.2d,v4.d[1] \n\t" // Accummulate + " ldr q7, [x0, #128] \n\t" + " \n\t" + " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate + " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate + " fmla v22.2d,v2.2d,v5.d[0] \n\t" // Accummulate + " ldr q4, [x1, #144] \n\t" + " \n\t" + " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate + " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate + " fmla v25.2d,v2.2d,v5.d[1] \n\t" // Accummulate + " ldr q5, [x1, #160] \n\t" + " \n\t" + " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate + " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate + " ldr q0, [x0, #96] \n\t" + " \n\t" + " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate + " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate + " ldr q1, [x0, #112] \n\t" + " \n\t" + " fmla v28.2d,v2.2d,v6.d[0] \n\t" // Accummulate + " fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate + " ldr q6, [x1, #176] \n\t" + " \n\t" // End it 3 + " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate + " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate + " fmla v10.2d,v7.2d,v3.d[0] \n\t" // Accummulate + " \n\t" + " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate + " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate + " fmla v13.2d,v7.2d,v3.d[1] \n\t" // Accummulate + " ldr q3, [x1, #192] \n\t" + " \n\t" + " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate + " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate + " fmla v16.2d,v7.2d,v4.d[0] \n\t" // Accummulate + " ldr q2, [x0, #176] \n\t" + " \n\t" + " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate + " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate + " fmla v19.2d,v7.2d,v4.d[1] \n\t" // Accummulate + " ldr q4, [x1, #208] \n\t" + " \n\t" + " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate + " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate + " fmla v22.2d,v7.2d,v5.d[0] \n\t" // Accummulate + " \n\t" + " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate + " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate + " fmla v25.2d,v7.2d,v5.d[1] \n\t" // Accummulate + " ldr q5, [x1, #224] \n\t" + " \n\t" + " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate + " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate + " ldr q0, [x0, #144] \n\t" + " \n\t" + " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate + " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate + " ldr q1, [x0, #160] \n\t" + " \n\t" + " fmla v28.2d,v7.2d,v6.d[0] \n\t" // Accummulate + " fmla v31.2d,v7.2d,v6.d[1] \n\t" // Accummulate + " ldr q6, [x1, #240] \n\t" + " \n\t" //End it 4 + " add x0, x0, #192 \n\t" + " add x1, x1, #256 \n\t" + " \n\t" + " sub x5,x5,1 \n\t" // i-=1 + " cmp x5,1 \n\t" // Iterate again if we are not in k_iter == 1. 
+ BNE(DLOOP) + " \n\t" + LABEL(DLASTITER) + " \n\t" + " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate + " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate + " fmla v10.2d,v2.2d,v3.d[0] \n\t" // Accummulate + " \n\t" + " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate + " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate + " fmla v13.2d,v2.2d,v3.d[1] \n\t" // Accummulate + " ldr q3, [x1] \n\t" + " \n\t" + " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate + " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate + " fmla v16.2d,v2.2d,v4.d[0] \n\t" // Accummulate + " ldr q7, [x0, #32] \n\t" + " \n\t" + " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate + " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate + " fmla v19.2d,v2.2d,v4.d[1] \n\t" // Accummulate + " ldr q4, [x1, #16] \n\t" + " \n\t" + " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate + " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate + " fmla v22.2d,v2.2d,v5.d[0] \n\t" // Accummulate + " \n\t" + " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate + " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate + " fmla v25.2d,v2.2d,v5.d[1] \n\t" // Accummulate + " ldr q5, [x1, #32] \n\t" + " \n\t" + " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate + " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate + " ldr q0, [x0] \n\t" + " \n\t" + " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate + " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate + " ldr q1, [x0, #16] \n\t" + " \n\t" + " fmla v28.2d,v2.2d,v6.d[0] \n\t" // Accummulate + " fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate + " ldr q6, [x1, #48] \n\t" + " \n\t" // End it 1 + " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate + " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate + " fmla v10.2d,v7.2d,v3.d[0] \n\t" // Accummulate + " \n\t" + " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate + " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate + " fmla v13.2d,v7.2d,v3.d[1] \n\t" // Accummulate + " ldr q3, [x1, #64] \n\t" + " \n\t" + " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate + " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate + " fmla v16.2d,v7.2d,v4.d[0] \n\t" // Accummulate + " ldr q2, [x0, #80] \n\t" + " \n\t" + " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate + " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate + " fmla v19.2d,v7.2d,v4.d[1] \n\t" // Accummulate + " ldr q4, [x1, #80] \n\t" + " \n\t" + " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate + " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate + " fmla v22.2d,v7.2d,v5.d[0] \n\t" // Accummulate + " \n\t" + " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate + " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate + " fmla v25.2d,v7.2d,v5.d[1] \n\t" // Accummulate + " ldr q5, [x1, #96] \n\t" + " \n\t" + " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate + " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate + " ldr q0, [x0, #48] \n\t" + " \n\t" + " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate + " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate + " ldr q1, [x0, #64] \n\t" + " \n\t" + " fmla v28.2d,v7.2d,v6.d[0] \n\t" // Accummulate + " fmla v31.2d,v7.2d,v6.d[1] \n\t" // Accummulate + " ldr q6, [x1, #112] \n\t" + " \n\t" //End it 2 + " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate + " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate + " fmla v10.2d,v2.2d,v3.d[0] \n\t" // Accummulate + " \n\t" + " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate + " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate + " fmla v13.2d,v2.2d,v3.d[1] \n\t" // Accummulate + " ldr q3, [x1, #128] \n\t" + " \n\t" + " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate + " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate + " fmla 
v16.2d,v2.2d,v4.d[0] \n\t" // Accummulate + " ldr q7, [x0, #128] \n\t" + " \n\t" + " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate + " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate + " fmla v19.2d,v2.2d,v4.d[1] \n\t" // Accummulate + " ldr q4, [x1, #144] \n\t" + " \n\t" + " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate + " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate + " fmla v22.2d,v2.2d,v5.d[0] \n\t" // Accummulate + " \n\t" + " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate + " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate + " fmla v25.2d,v2.2d,v5.d[1] \n\t" // Accummulate + " ldr q5, [x1, #160] \n\t" + " \n\t" + " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate + " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate + " ldr q0, [x0, #96] \n\t" + " \n\t" + " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate + " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate + " ldr q1, [x0, #112] \n\t" + " \n\t" + " fmla v28.2d,v2.2d,v6.d[0] \n\t" // Accummulate + " fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate + " ldr q6, [x1, #176] \n\t" + " \n\t" // End it 3 + " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate + " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate + " fmla v10.2d,v7.2d,v3.d[0] \n\t" // Accummulate + " \n\t" + " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate + " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate + " fmla v13.2d,v7.2d,v3.d[1] \n\t" // Accummulate + " \n\t" + " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate + " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate + " fmla v16.2d,v7.2d,v4.d[0] \n\t" // Accummulate + " \n\t" + " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate + " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate + " fmla v19.2d,v7.2d,v4.d[1] \n\t" // Accummulate + " \n\t" + " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate + " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate + " fmla v22.2d,v7.2d,v5.d[0] \n\t" // Accummulate + " \n\t" + " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate + " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate + " fmla v25.2d,v7.2d,v5.d[1] \n\t" // Accummulate + " \n\t" + " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate + " add x1, x1, #192 \n\t" + " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate + " \n\t" + " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate + " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate + " \n\t" + " fmla v28.2d,v7.2d,v6.d[0] \n\t" // Accummulate + " fmla v31.2d,v7.2d,v6.d[1] \n\t" // Accummulate + " \n\t" //End it 4 + " add x0, x0, #144 \n\t" + " \n\t" + LABEL(DCONSIDERKLEFT) + " cmp x6,0 \n\t" // If k_left == 0, we are done. + BEQ(DPOSTACCUM) // else, we enter the k_left loop. 
+ " \n\t" + LABEL(DLOOPKLEFT) + " \n\t" + " ldr q0, [x0],#16 \n\t" + " ldr q1, [x0],#16 \n\t" // Load a + " ldr q2, [x0],#16 \n\t" + " \n\t" + " ldr q3, [x1],#16 \n\t" // Load b + " ldr q4, [x1],#16 \n\t" + " ldr q5, [x1],#16 \n\t" + " ldr q6, [x1],#16 \n\t" + " \n\t" + " sub x6,x6,1 \n\t" + " \n\t" + " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate + " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate + " fmla v10.2d,v2.2d,v3.d[0] \n\t" // Accummulate + " \n\t" + " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate + " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate + " fmla v13.2d,v2.2d,v3.d[1] \n\t" // Accummulate + " \n\t" + " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate + " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate + " fmla v16.2d,v2.2d,v4.d[0] \n\t" // Accummulate + " \n\t" + " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate + " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate + " fmla v19.2d,v2.2d,v4.d[1] \n\t" // Accummulate + " \n\t" + " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate + " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate + " fmla v22.2d,v2.2d,v5.d[0] \n\t" // Accummulate + " \n\t" + " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate + " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate + " fmla v25.2d,v2.2d,v5.d[1] \n\t" // Accummulate + " \n\t" + " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate + " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate + " \n\t" + " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate + " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate + " \n\t" + " fmla v28.2d,v2.2d,v6.d[0] \n\t" // Accummulate + " fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate + " \n\t" + " cmp x6,0 \n\t" // Iterate again. + BNE(DLOOPKLEFT) // if i!=0. + " \n\t" + LABEL(DPOSTACCUM) + " \n\t" + " ldr x0,%[alpha] \n\t" // Alpha address + " ldr x1,%[beta] \n\t" // Beta address + " \n\t" + " ld1r {v6.2d},[x0] \n\t" // Load alpha. + " ld1r {v7.2d},[x1] \n\t" // Load beta + " \n\t" + " ldr x0,%[a_next] \n\t" // Next A address for later use. + " ldr x1,%[b_next] \n\t" // Next B address for later use. + " \n\t" " cmp x14,#8 \n\t" // If rs_c != 1 (column-major) BNE(DGENSTORED) " \n\t" -LABEL(DCOLSTORED) // C is column-major. -" \n\t" -" dup v0.2d, xzr \n\t" -" dup v1.2d, xzr \n\t" -" dup v2.2d, xzr \n\t" -" dup v3.2d, xzr \n\t" -" dup v4.2d, xzr \n\t" -" dup v5.2d, xzr \n\t" -" \n\t" -" fcmp d7,#0.0 \n\t" -BEQ(DBETAZEROCOLSTOREDS1) // Taking care of the beta==0 case. 
-" \n\t" -" ldr q0, [x2] \n\t" //Load column 0 of C -" ldr q1, [x2, #16] \n\t" -" ldr q2, [x2, #32] \n\t" -" \n\t" -" ldr q3, [x20] \n\t" //Load column 1 of C -" ldr q4, [x20, #16] \n\t" -" ldr q5, [x20, #32] \n\t" -" \n\t" -" fmul v0.2d,v0.2d,v7.d[0] \n\t" // Scale by beta -" fmul v1.2d,v1.2d,v7.d[0] \n\t" // Scale by beta -" fmul v2.2d,v2.2d,v7.d[0] \n\t" // Scale by beta -" fmul v3.2d,v3.2d,v7.d[0] \n\t" // Scale by beta -" fmul v4.2d,v4.2d,v7.d[0] \n\t" // Scale by beta -" fmul v5.2d,v5.2d,v7.d[0] \n\t" // Scale by beta -" \n\t" -LABEL(DBETAZEROCOLSTOREDS1) -" \n\t" -" fmla v0.2d,v8.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v1.2d,v9.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v2.2d,v10.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v3.2d,v11.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v4.2d,v12.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v5.2d,v13.2d,v6.d[0] \n\t" // Scale by alpha -" \n\t" -" str q0, [x2] \n\t" //Store column 0 of C -" str q1, [x2, #16] \n\t" -" str q2, [x2, #32] \n\t" -" \n\t" -" str q3, [x20] \n\t" //Store column 1 of C -" str q4, [x20, #16] \n\t" -" str q5, [x20, #32] \n\t" -" \n\t" -" dup v8.2d, xzr \n\t" -" dup v9.2d, xzr \n\t" -" dup v10.2d, xzr \n\t" -" dup v11.2d, xzr \n\t" -" dup v12.2d, xzr \n\t" -" dup v13.2d, xzr \n\t" -" \n\t" -" fcmp d7,#0.0 \n\t" -BEQ(DBETAZEROCOLSTOREDS2) // Taking care of the beta==0 case. -" \n\t" -" ldr q8, [x21] \n\t" //Load column 2 of C -" ldr q9, [x21, #16] \n\t" -" ldr q10, [x21, #32] \n\t" -" \n\t" -" ldr q11, [x22] \n\t" //Load column 3 of C -" ldr q12, [x22, #16] \n\t" -" ldr q13, [x22, #32] \n\t" -" \n\t" -" fmul v8.2d, v8.2d, v7.d[0] \n\t" // Scale by beta -" fmul v9.2d, v9.2d, v7.d[0] \n\t" // Scale by beta -" fmul v10.2d,v10.2d,v7.d[0] \n\t" // Scale by beta -" fmul v11.2d,v11.2d,v7.d[0] \n\t" // Scale by beta -" fmul v12.2d,v12.2d,v7.d[0] \n\t" // Scale by beta -" fmul v13.2d,v13.2d,v7.d[0] \n\t" // Scale by beta -" \n\t" -LABEL(DBETAZEROCOLSTOREDS2) -" \n\t" -" fmla v8.2d, v14.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v9.2d, v15.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v10.2d,v16.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v11.2d,v17.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v12.2d,v18.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v13.2d,v19.2d,v6.d[0] \n\t" // Scale by alpha -" \n\t" -" str q8, [x21] \n\t" //Store column 2 of C -" str q9, [x21, #16] \n\t" -" str q10, [x21, #32] \n\t" -" \n\t" -" str q11, [x22] \n\t" //Store column 3 of C -" str q12, [x22, #16] \n\t" -" str q13, [x22, #32] \n\t" -" \n\t" -" dup v0.2d, xzr \n\t" -" dup v1.2d, xzr \n\t" -" dup v2.2d, xzr \n\t" -" dup v3.2d, xzr \n\t" -" dup v4.2d, xzr \n\t" -" dup v5.2d, xzr \n\t" -" \n\t" -" fcmp d7,#0.0 \n\t" -BEQ(DBETAZEROCOLSTOREDS3) // Taking care of the beta==0 case. 
-" \n\t" -" ldr q0, [x23] \n\t" //Load column 4 of C -" ldr q1, [x23, #16] \n\t" -" ldr q2, [x23, #32] \n\t" -" \n\t" -" ldr q3, [x24] \n\t" //Load column 5 of C -" ldr q4, [x24, #16] \n\t" -" ldr q5, [x24, #32] \n\t" -" \n\t" -" fmul v0.2d,v0.2d,v7.d[0] \n\t" // Scale by beta -" fmul v1.2d,v1.2d,v7.d[0] \n\t" // Scale by beta -" fmul v2.2d,v2.2d,v7.d[0] \n\t" // Scale by beta -" fmul v3.2d,v3.2d,v7.d[0] \n\t" // Scale by beta -" fmul v4.2d,v4.2d,v7.d[0] \n\t" // Scale by beta -" fmul v5.2d,v5.2d,v7.d[0] \n\t" // Scale by beta -" \n\t" -LABEL(DBETAZEROCOLSTOREDS3) -" \n\t" -" fmla v0.2d,v20.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v1.2d,v21.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v2.2d,v22.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v3.2d,v23.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v4.2d,v24.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v5.2d,v25.2d,v6.d[0] \n\t" // Scale by alpha -" \n\t" -" str q0, [x23] \n\t" //Store column 4 of C -" str q1, [x23, #16] \n\t" -" str q2, [x23, #32] \n\t" -" \n\t" -" str q3, [x24] \n\t" //Store column 5 of C -" str q4, [x24, #16] \n\t" -" str q5, [x24, #32] \n\t" -" \n\t" -" dup v8.2d, xzr \n\t" -" dup v9.2d, xzr \n\t" -" dup v10.2d, xzr \n\t" -" dup v11.2d, xzr \n\t" -" dup v12.2d, xzr \n\t" -" dup v13.2d, xzr \n\t" -" \n\t" -" fcmp d7,#0.0 \n\t" -BEQ(DBETAZEROCOLSTOREDS4) // Taking care of the beta==0 case. -" \n\t" -" ldr q8, [x25] \n\t" //Load column 6 of C -" ldr q9, [x25, #16] \n\t" -" ldr q10, [x25, #32] \n\t" -" \n\t" -" ldr q11, [x26] \n\t" //Load column 7 of C -" ldr q12, [x26, #16] \n\t" -" ldr q13, [x26, #32] \n\t" -" \n\t" -" fmul v8.2d, v8.2d, v7.d[0] \n\t" // Scale by beta -" fmul v9.2d, v9.2d, v7.d[0] \n\t" // Scale by beta -" fmul v10.2d,v10.2d,v7.d[0] \n\t" // Scale by beta -" fmul v11.2d,v11.2d,v7.d[0] \n\t" // Scale by beta -" fmul v12.2d,v12.2d,v7.d[0] \n\t" // Scale by beta -" fmul v13.2d,v13.2d,v7.d[0] \n\t" // Scale by beta -" \n\t" -LABEL(DBETAZEROCOLSTOREDS4) -" \n\t" -" prfm pldl2keep,[x0] \n\t" -" prfm pldl2keep,[x1] \n\t" -" \n\t" -" fmla v8.2d, v26.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v9.2d, v27.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v10.2d,v28.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v11.2d,v29.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v12.2d,v30.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v13.2d,v31.2d,v6.d[0] \n\t" // Scale by alpha -" \n\t" -" str q8, [x25] \n\t" //Store column 6 of C -" str q9, [x25, #16] \n\t" -" str q10, [x25, #32] \n\t" -" \n\t" -" str q11, [x26] \n\t" //Store column 7 of C -" str q12, [x26, #16] \n\t" -" str q13, [x26, #32] \n\t" -" \n\t" + LABEL(DCOLSTORED) // C is column-major. + " \n\t" + " dup v0.2d, xzr \n\t" + " dup v1.2d, xzr \n\t" + " dup v2.2d, xzr \n\t" + " dup v3.2d, xzr \n\t" + " dup v4.2d, xzr \n\t" + " dup v5.2d, xzr \n\t" + " \n\t" + " fcmp d7,#0.0 \n\t" + BEQ(DBETAZEROCOLSTOREDS1) // Taking care of the beta==0 case. 
+ " \n\t" + " ldr q0, [x2] \n\t" //Load column 0 of C + " ldr q1, [x2, #16] \n\t" + " ldr q2, [x2, #32] \n\t" + " \n\t" + " ldr q3, [x20] \n\t" //Load column 1 of C + " ldr q4, [x20, #16] \n\t" + " ldr q5, [x20, #32] \n\t" + " \n\t" + " fmul v0.2d,v0.2d,v7.d[0] \n\t" // Scale by beta + " fmul v1.2d,v1.2d,v7.d[0] \n\t" // Scale by beta + " fmul v2.2d,v2.2d,v7.d[0] \n\t" // Scale by beta + " fmul v3.2d,v3.2d,v7.d[0] \n\t" // Scale by beta + " fmul v4.2d,v4.2d,v7.d[0] \n\t" // Scale by beta + " fmul v5.2d,v5.2d,v7.d[0] \n\t" // Scale by beta + " \n\t" + LABEL(DBETAZEROCOLSTOREDS1) + " \n\t" + " fmla v0.2d,v8.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v1.2d,v9.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v2.2d,v10.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v3.2d,v11.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v4.2d,v12.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v5.2d,v13.2d,v6.d[0] \n\t" // Scale by alpha + " \n\t" + " str q0, [x2] \n\t" //Store column 0 of C + " str q1, [x2, #16] \n\t" + " str q2, [x2, #32] \n\t" + " \n\t" + " str q3, [x20] \n\t" //Store column 1 of C + " str q4, [x20, #16] \n\t" + " str q5, [x20, #32] \n\t" + " \n\t" + " dup v8.2d, xzr \n\t" + " dup v9.2d, xzr \n\t" + " dup v10.2d, xzr \n\t" + " dup v11.2d, xzr \n\t" + " dup v12.2d, xzr \n\t" + " dup v13.2d, xzr \n\t" + " \n\t" + " fcmp d7,#0.0 \n\t" + BEQ(DBETAZEROCOLSTOREDS2) // Taking care of the beta==0 case. + " \n\t" + " ldr q8, [x21] \n\t" //Load column 2 of C + " ldr q9, [x21, #16] \n\t" + " ldr q10, [x21, #32] \n\t" + " \n\t" + " ldr q11, [x22] \n\t" //Load column 3 of C + " ldr q12, [x22, #16] \n\t" + " ldr q13, [x22, #32] \n\t" + " \n\t" + " fmul v8.2d, v8.2d, v7.d[0] \n\t" // Scale by beta + " fmul v9.2d, v9.2d, v7.d[0] \n\t" // Scale by beta + " fmul v10.2d,v10.2d,v7.d[0] \n\t" // Scale by beta + " fmul v11.2d,v11.2d,v7.d[0] \n\t" // Scale by beta + " fmul v12.2d,v12.2d,v7.d[0] \n\t" // Scale by beta + " fmul v13.2d,v13.2d,v7.d[0] \n\t" // Scale by beta + " \n\t" + LABEL(DBETAZEROCOLSTOREDS2) + " \n\t" + " fmla v8.2d, v14.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v9.2d, v15.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v10.2d,v16.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v11.2d,v17.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v12.2d,v18.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v13.2d,v19.2d,v6.d[0] \n\t" // Scale by alpha + " \n\t" + " str q8, [x21] \n\t" //Store column 2 of C + " str q9, [x21, #16] \n\t" + " str q10, [x21, #32] \n\t" + " \n\t" + " str q11, [x22] \n\t" //Store column 3 of C + " str q12, [x22, #16] \n\t" + " str q13, [x22, #32] \n\t" + " \n\t" + " dup v0.2d, xzr \n\t" + " dup v1.2d, xzr \n\t" + " dup v2.2d, xzr \n\t" + " dup v3.2d, xzr \n\t" + " dup v4.2d, xzr \n\t" + " dup v5.2d, xzr \n\t" + " \n\t" + " fcmp d7,#0.0 \n\t" + BEQ(DBETAZEROCOLSTOREDS3) // Taking care of the beta==0 case. 
+ " \n\t" + " ldr q0, [x23] \n\t" //Load column 4 of C + " ldr q1, [x23, #16] \n\t" + " ldr q2, [x23, #32] \n\t" + " \n\t" + " ldr q3, [x24] \n\t" //Load column 5 of C + " ldr q4, [x24, #16] \n\t" + " ldr q5, [x24, #32] \n\t" + " \n\t" + " fmul v0.2d,v0.2d,v7.d[0] \n\t" // Scale by beta + " fmul v1.2d,v1.2d,v7.d[0] \n\t" // Scale by beta + " fmul v2.2d,v2.2d,v7.d[0] \n\t" // Scale by beta + " fmul v3.2d,v3.2d,v7.d[0] \n\t" // Scale by beta + " fmul v4.2d,v4.2d,v7.d[0] \n\t" // Scale by beta + " fmul v5.2d,v5.2d,v7.d[0] \n\t" // Scale by beta + " \n\t" + LABEL(DBETAZEROCOLSTOREDS3) + " \n\t" + " fmla v0.2d,v20.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v1.2d,v21.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v2.2d,v22.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v3.2d,v23.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v4.2d,v24.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v5.2d,v25.2d,v6.d[0] \n\t" // Scale by alpha + " \n\t" + " str q0, [x23] \n\t" //Store column 4 of C + " str q1, [x23, #16] \n\t" + " str q2, [x23, #32] \n\t" + " \n\t" + " str q3, [x24] \n\t" //Store column 5 of C + " str q4, [x24, #16] \n\t" + " str q5, [x24, #32] \n\t" + " \n\t" + " dup v8.2d, xzr \n\t" + " dup v9.2d, xzr \n\t" + " dup v10.2d, xzr \n\t" + " dup v11.2d, xzr \n\t" + " dup v12.2d, xzr \n\t" + " dup v13.2d, xzr \n\t" + " \n\t" + " fcmp d7,#0.0 \n\t" + BEQ(DBETAZEROCOLSTOREDS4) // Taking care of the beta==0 case. + " \n\t" + " ldr q8, [x25] \n\t" //Load column 6 of C + " ldr q9, [x25, #16] \n\t" + " ldr q10, [x25, #32] \n\t" + " \n\t" + " ldr q11, [x26] \n\t" //Load column 7 of C + " ldr q12, [x26, #16] \n\t" + " ldr q13, [x26, #32] \n\t" + " \n\t" + " fmul v8.2d, v8.2d, v7.d[0] \n\t" // Scale by beta + " fmul v9.2d, v9.2d, v7.d[0] \n\t" // Scale by beta + " fmul v10.2d,v10.2d,v7.d[0] \n\t" // Scale by beta + " fmul v11.2d,v11.2d,v7.d[0] \n\t" // Scale by beta + " fmul v12.2d,v12.2d,v7.d[0] \n\t" // Scale by beta + " fmul v13.2d,v13.2d,v7.d[0] \n\t" // Scale by beta + " \n\t" + LABEL(DBETAZEROCOLSTOREDS4) + " \n\t" + " prfm pldl2keep,[x0] \n\t" + " prfm pldl2keep,[x1] \n\t" + " \n\t" + " fmla v8.2d, v26.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v9.2d, v27.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v10.2d,v28.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v11.2d,v29.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v12.2d,v30.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v13.2d,v31.2d,v6.d[0] \n\t" // Scale by alpha + " \n\t" + " str q8, [x25] \n\t" //Store column 6 of C + " str q9, [x25, #16] \n\t" + " str q10, [x25, #32] \n\t" + " \n\t" + " str q11, [x26] \n\t" //Store column 7 of C + " str q12, [x26, #16] \n\t" + " str q13, [x26, #32] \n\t" + " \n\t" BRANCH(DEND) " \n\t" LABEL(DGENSTORED) // C is general-stride stored. @@ -2042,45 +2051,46 @@ LABEL(DBETAZEROGENSTOREDS4) " st1 {v13.d}[0],[x27],x14 \n\t" // Store c74 into quad and increment by rs_c. " st1 {v13.d}[1],[x27],x14 \n\t" // Store c75 into quad and increment by rs_c. " \n\t" -LABEL(DEND) // Done! 
-" \n\t" -:// output operands (none) -:// input operands - [aaddr] "m" (a), // 0 - [baddr] "m" (b), // 1 - [caddr] "m" (c), // 2 - [k_iter] "m" (k_iter), // 3 - [k_left] "m" (k_left), // 4 - [alpha] "m" (alpha), // 5 - [beta] "m" (beta), // 6 - [rs_c] "m" (rs_c), // 6 - [cs_c] "m" (cs_c), // 7 - [a_next] "m" (a_next), // 8 - [b_next] "m" (b_next) // 9 -:// Register clobber list - "x0","x1","x2", - "x5","x6","x10", - "x14","x16","x17", - "x20","x21","x22","x23","x24","x25","x26","x27", - "v0","v1","v2", - "v3","v4","v5", - "v6","v7","v8", - "v9","v10","v11", - "v12","v13","v14", - "v15","v16","v17","v18","v19", - "v20","v21","v22","v23", - "v24","v25","v26","v27", - "v28","v29","v30","v31" -); - - + LABEL(DEND) // Done! + " \n\t" + :// output operands (none) + :// input operands + [aaddr] "m" (a), // 0 + [baddr] "m" (b), // 1 + [caddr] "m" (c), // 2 + [k_iter] "m" (k_iter), // 3 + [k_left] "m" (k_left), // 4 + [alpha] "m" (alpha), // 5 + [beta] "m" (beta), // 6 + [rs_c] "m" (rs_c), // 6 + [cs_c] "m" (cs_c), // 7 + [a_next] "m" (a_next), // 8 + [b_next] "m" (b_next) // 9 + :// Register clobber list + "x0","x1","x2", + "x5","x6","x10", + "x14","x16","x17", + "x20","x21","x22","x23","x24","x25","x26","x27", + "v0","v1","v2", + "v3","v4","v5", + "v6","v7","v8", + "v9","v10","v11", + "v12","v13","v14", + "v15","v16","v17","v18","v19", + "v20","v21","v22","v23", + "v24","v25","v26","v27", + "v28","v29","v30","v31" + ); + GEMM_UKR_FLUSH_CT( d ); } #if 0 void bli_cgemm_armv8a_opt_4x4 ( + dim_t m, + dim_t n, dim_t k, scomplex* restrict alpha, scomplex* restrict a, @@ -2095,6 +2105,8 @@ void bli_cgemm_armv8a_opt_4x4 void bli_zgemm_armv8a_opt_4x4 ( + dim_t m, + dim_t n, dim_t k, dcomplex* restrict alpha, dcomplex* restrict a, diff --git a/kernels/bgq/3/bli_gemm_bgq_int_8x8.c b/kernels/bgq/3/bli_gemm_bgq_int_8x8.c index 1612e69b0d..15e3e072f3 100644 --- a/kernels/bgq/3/bli_gemm_bgq_int_8x8.c +++ b/kernels/bgq/3/bli_gemm_bgq_int_8x8.c @@ -56,6 +56,8 @@ void bli_dgemm_bgq_int_8x8 ( + dim_t m, + dim_t n, dim_t k, double* restrict alpha, double* restrict a, @@ -66,6 +68,8 @@ void bli_dgemm_bgq_int_8x8 cntx_t* restrict cntx ) { + GEMM_UKR_SETUP_CT_ANY( d, 8, 8, false ); + //Registers for storing C. 
//4 4x4 subblocks of C, c00, c01, c10, c11 //4 registers per subblock: a, b, c, d @@ -201,6 +205,8 @@ void bli_dgemm_bgq_int_8x8 UPDATE( AB, c, 0 ); AB = vec_perm( c11d, c11d, pattern ); UPDATE( AB, c, 4 ); + + GEMM_UKR_FLUSH_CT( d ); } void printvec(vector4double v) @@ -214,6 +220,8 @@ void printvec(vector4double v) void bli_zgemm_bgq_int_4x4 ( + dim_t m, + dim_t n, dim_t k, dcomplex* restrict alpha, dcomplex* restrict a, @@ -224,6 +232,8 @@ void bli_zgemm_bgq_int_4x4 cntx_t* restrict cntx ) { + GEMM_UKR_SETUP_CT_ANY( z, 4, 4, false ); + double* a_d = ( double* )a; double* b_d = ( double* )b; double* c_d = ( double* )c; @@ -368,4 +378,6 @@ void bli_zgemm_bgq_int_4x4 c_d += 2*cs_c; ZUPDATE( c03a, c03b, c_d, 0 ); ZUPDATE( c13a, c13b, c_d, 4 ); + + GEMM_UKR_FLUSH_CT( z ); } diff --git a/kernels/bulldozer/3/bli_gemm_bulldozer_asm_d4x6_fma4.c b/kernels/bulldozer/3/bli_gemm_bulldozer_asm_d4x6_fma4.c index 403aaaaeef..7eba05a138 100644 --- a/kernels/bulldozer/3/bli_gemm_bulldozer_asm_d4x6_fma4.c +++ b/kernels/bulldozer/3/bli_gemm_bulldozer_asm_d4x6_fma4.c @@ -90,7 +90,9 @@ void bli_sgemm_bulldozer_asm_8x8_fma4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, float* restrict alpha, float* restrict a, float* restrict b, @@ -102,25 +104,27 @@ void bli_sgemm_bulldozer_asm_8x8_fma4 { // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT_ALIGNED( s, 8, 8, false, 32 ); + begin_asm() - + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. - + vmovaps(mem(rax, 0*32), ymm0) // initialize loop by pre-loading vmovsldup(mem(rbx, 0*32), ymm2) // elements of a and b. vpermilps(imm(0x4e), ymm2, ymm3) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 4), rdi) // cs_c *= sizeof(float) lea(mem(rcx, rdi, 4), r10) // load address of c + 4*cs_c; - + lea(mem(rdi, rdi, 2), r14) // r14 = 3*cs_c; prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*cs_c prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*cs_c @@ -130,7 +134,7 @@ void bli_sgemm_bulldozer_asm_8x8_fma4 prefetch(0, mem(r10, rdi, 1, 7*8)) // prefetch c + 5*cs_c prefetch(0, mem(r10, rdi, 2, 7*8)) // prefetch c + 6*cs_c prefetch(0, mem(r10, r14, 1, 7*8)) // prefetch c + 7*cs_c - + vxorps(ymm8, ymm8, ymm8) vxorps(ymm9, ymm9, ymm9) vxorps(ymm10, ymm10, ymm10) @@ -139,15 +143,15 @@ void bli_sgemm_bulldozer_asm_8x8_fma4 vxorps(ymm13, ymm13, ymm13) vxorps(ymm14, ymm14, ymm14) vxorps(ymm15, ymm15, ymm15) - - + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.SCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. 
- + label(.SLOOPKITER) // MAIN LOOP - + // iteration 0 prefetch(0, mem(rax, 16*32)) vfmaddps(ymm15, ymm0, ymm2, ymm15) @@ -155,44 +159,44 @@ void bli_sgemm_bulldozer_asm_8x8_fma4 vmovshdup(mem(rbx, 0*32), ymm2) vfmaddps(ymm13, ymm0, ymm3, ymm13) vperm2f128(imm(0x03), ymm3, ymm3, ymm5) - + vmovaps(mem(rax, 1*32), ymm1) vpermilps(imm(0x4e), ymm2, ymm3) vfmaddps(ymm11, ymm0, ymm4, ymm11) vfmaddps(ymm9, ymm0, ymm5, ymm9) - + vfmaddps(ymm14, ymm0, ymm2, ymm14) vperm2f128(imm(0x03), ymm2, ymm2, ymm4) vmovsldup(mem(rbx, 1*32), ymm2) vfmaddps(ymm12, ymm0, ymm3, ymm12) vperm2f128(imm(0x03), ymm3, ymm3, ymm5) - + vpermilps(imm(0x4e), ymm2, ymm3) vfmaddps(ymm10, ymm0, ymm4, ymm10) vfmaddps(ymm8, ymm0, ymm5, ymm8) - + // iteration 1 vfmaddps(ymm15, ymm1, ymm2, ymm15) vperm2f128(imm(0x03), ymm2, ymm2, ymm4) vmovshdup(mem(rbx, 1*32), ymm2) vfmaddps(ymm13, ymm1, ymm3, ymm13) vperm2f128(imm(0x03), ymm3, ymm3, ymm5) - + vmovaps(mem(rax, 2*32), ymm0) vpermilps(imm(0x4e), ymm2, ymm3) vfmaddps(ymm11, ymm1, ymm4, ymm11) vfmaddps(ymm9, ymm1, ymm5, ymm9) - + vfmaddps(ymm14, ymm1, ymm2, ymm14) vperm2f128(imm(0x03), ymm2, ymm2, ymm4) vmovsldup(mem(rbx, 2*32), ymm2) vfmaddps(ymm12, ymm1, ymm3, ymm12) vperm2f128(imm(0x03), ymm3, ymm3, ymm5) - + vpermilps(imm(0x4e), ymm2, ymm3) vfmaddps(ymm10, ymm1, ymm4, ymm10) vfmaddps(ymm8, ymm1, ymm5, ymm8) - + // iteration 2 prefetch(0, mem(rax, 18*32)) vfmaddps(ymm15, ymm0, ymm2, ymm15) @@ -200,23 +204,23 @@ void bli_sgemm_bulldozer_asm_8x8_fma4 vmovshdup(mem(rbx, 2*32), ymm2) vfmaddps(ymm13, ymm0, ymm3, ymm13) vperm2f128(imm(0x03), ymm3, ymm3, ymm5) - + vmovaps(mem(rax, 3*32), ymm1) add(imm(4*8*4), rax) // a += 4*8 (unroll x mr) vpermilps(imm(0x4e), ymm2, ymm3) vfmaddps(ymm11, ymm0, ymm4, ymm11) vfmaddps(ymm9, ymm0, ymm5, ymm9) - + vfmaddps(ymm14, ymm0, ymm2, ymm14) vperm2f128(imm(0x03), ymm2, ymm2, ymm4) vmovsldup(mem(rbx, 3*32), ymm2) vfmaddps(ymm12, ymm0, ymm3, ymm12) vperm2f128(imm(0x03), ymm3, ymm3, ymm5) - + vpermilps(imm(0x4e), ymm2, ymm3) vfmaddps(ymm10, ymm0, ymm4, ymm10) vfmaddps(ymm8, ymm0, ymm5, ymm8) - + // iteration 3 vfmaddps(ymm15, ymm1, ymm2, ymm15) vperm2f128(imm(0x03), ymm2, ymm2, ymm4) @@ -224,134 +228,134 @@ void bli_sgemm_bulldozer_asm_8x8_fma4 add(imm(4*8*4), rbx) // b += 4*8 (unroll x nr) vfmaddps(ymm13, ymm1, ymm3, ymm13) vperm2f128(imm(0x03), ymm3, ymm3, ymm5) - + vmovaps(mem(rax, 0*32), ymm0) vpermilps(imm(0x4e), ymm2, ymm3) vfmaddps(ymm11, ymm1, ymm4, ymm11) vfmaddps(ymm9, ymm1, ymm5, ymm9) - + vfmaddps(ymm14, ymm1, ymm2, ymm14) vperm2f128(imm(0x03), ymm2, ymm2, ymm4) vmovsldup(mem(rbx, 0*32), ymm2) vfmaddps(ymm12, ymm1, ymm3, ymm12) vperm2f128(imm(0x03), ymm3, ymm3, ymm5) - + vpermilps(imm(0x4e), ymm2, ymm3) vfmaddps(ymm10, ymm1, ymm4, ymm10) vfmaddps(ymm8, ymm1, ymm5, ymm8) - - - + + + dec(rsi) // i -= 1; jne(.SLOOPKITER) // iterate again if i != 0. - - - - + + + + label(.SCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.SPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. 
- + label(.SLOOPKLEFT) // EDGE LOOP - + prefetch(0, mem(rax, 16*32)) vfmaddps(ymm15, ymm0, ymm2, ymm15) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vmovshdup(mem(rbx, 0*32), ymm2) vfmaddps(ymm13, ymm0, ymm3, ymm13) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) - + vmovaps(mem(rax, 1*32), ymm1) add(imm(8*1*4), rax) // a += 8 (1 x mr) vpermilps(imm(0x4e), ymm2, ymm3) vfmaddps(ymm11, ymm0, ymm4, ymm11) vfmaddps(ymm9, ymm0, ymm5, ymm9) - + vfmaddps(ymm14, ymm0, ymm2, ymm14) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vmovsldup(mem(rbx, 1*32), ymm2) add(imm(8*1*4), rbx) // b += 8 (1 x nr) vfmaddps(ymm12, ymm0, ymm3, ymm12) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) - + vpermilps(imm(0x4e), ymm2, ymm3) vfmaddps(ymm10, ymm0, ymm4, ymm10) vfmaddps(ymm8, ymm0, ymm5, ymm8) vmovaps(ymm1, ymm0) - - + + dec(rsi) // i -= 1; jne(.SLOOPKLEFT) // iterate again if i != 0. - - + + label(.SPOSTACCUM) // ymm15: ymm13: ymm11: ymm9: // ( ab00 ( ab02 ( ab04 ( ab06 - // ab10 ab12 ab14 ab16 + // ab10 ab12 ab14 ab16 // ab22 ab20 ab26 ab24 // ab32 ab30 ab36 ab34 // ab44 ab46 ab40 ab42 - // ab54 ab56 ab50 ab52 + // ab54 ab56 ab50 ab52 // ab66 ab64 ab62 ab60 // ab76 ) ab74 ) ab72 ) ab70 ) - + // ymm14: ymm12: ymm10: ymm8: // ( ab01 ( ab03 ( ab05 ( ab07 - // ab11 ab13 ab15 ab17 + // ab11 ab13 ab15 ab17 // ab23 ab21 ab27 ab25 // ab33 ab31 ab37 ab35 // ab45 ab47 ab41 ab43 - // ab55 ab57 ab51 ab53 + // ab55 ab57 ab51 ab53 // ab67 ab65 ab63 ab61 // ab77 ) ab75 ) ab73 ) ab71 ) GROUP_YMM_BY_4 // ymm15: ymm13: ymm11: ymm9: // ( ab00 ( ab02 ( ab04 ( ab06 - // ab10 ab12 ab14 ab16 + // ab10 ab12 ab14 ab16 // ab20 ab22 ab24 ab26 // ab30 ab32 ab34 ab36 // ab44 ab46 ab40 ab42 - // ab54 ab56 ab50 ab52 + // ab54 ab56 ab50 ab52 // ab64 ab66 ab60 ab62 // ab74 ) ab76 ) ab70 ) ab72 ) - + // ymm14: ymm12: ymm10: ymm8: // ( ab01 ( ab03 ( ab05 ( ab07 - // ab11 ab13 ab15 ab17 + // ab11 ab13 ab15 ab17 // ab21 ab23 ab25 ab27 // ab31 ab33 ab35 ab37 // ab45 ab47 ab41 ab43 - // ab55 ab57 ab51 ab53 + // ab55 ab57 ab51 ab53 // ab65 ab67 ab61 ab63 // ab75 ) ab77 ) ab71 ) ab73 ) // ymm15: ymm13: ymm11: ymm9: // ( ab00 ( ab02 ( ab04 ( ab06 - // ab10 ab12 ab14 ab16 + // ab10 ab12 ab14 ab16 // ab20 ab22 ab24 ab26 // ab30 ab32 ab34 ab36 // ab40 ab42 ab44 ab46 - // ab50 ab52 ab54 ab56 + // ab50 ab52 ab54 ab56 // ab60 ab62 ab64 ab66 // ab70 ) ab72 ) ab74 ) ab76 ) - + // ymm14: ymm12: ymm10: ymm8: // ( ab01 ( ab03 ( ab05 ( ab07 - // ab11 ab13 ab15 ab17 + // ab11 ab13 ab15 ab17 // ab21 ab23 ab25 ab27 // ab31 ab33 ab35 ab37 // ab41 ab43 ab45 ab47 - // ab51 ab53 ab55 ab57 + // ab51 ab53 ab55 ab57 // ab61 ab63 ab65 ab67 // ab71 ) ab73 ) ab75 ) ab77 ) - + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastss(mem(rax), ymm0) // load alpha and duplicate vbroadcastss(mem(rbx), ymm4) // load beta and duplicate - + vmulps(ymm0, ymm8, ymm8) // scale by alpha vmulps(ymm0, ymm9, ymm9) vmulps(ymm0, ymm10, ymm10) @@ -360,384 +364,98 @@ void bli_sgemm_bulldozer_asm_8x8_fma4 vmulps(ymm0, ymm13, ymm13) vmulps(ymm0, ymm14, ymm14) vmulps(ymm0, ymm15, ymm15) - - - - - mov(var(rs_c), rsi) // load rs_c - lea(mem(, rsi, 4), rsi) // rsi = rs_c * sizeof(float) - - lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*rs_c; - - lea(mem(, rsi, 2), r12) // r12 = 2*rs_c; - lea(mem(r12, rsi, 1), r13) // r13 = 3*rs_c; - - - // determine if - // c % 32 == 0, AND - // 4*cs_c % 32 == 0, AND - // rs_c == 1 - // ie: aligned, ldim aligned, and - // column-stored - - cmp(imm(4), rsi) // set ZF if (4*rs_c) == 4. - sete(bl) // bl = ( ZF == 1 ? 
1 : 0 ); - test(imm(31), rcx) // set ZF if c & 32 is zero. - setz(bh) // bh = ( ZF == 0 ? 1 : 0 ); - test(imm(31), rdi) // set ZF if (4*cs_c) & 32 is zero. - setz(al) // al = ( ZF == 0 ? 1 : 0 ); - // and(bl,bh) followed by - // and(bh,al) will reveal result - + + // now avoid loading C if beta == 0 - + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomiss(xmm0, xmm4) // set ZF if beta == 0. je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case - - - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. - jne(.SCOLSTORED) // jump to column storage case - - - label(.SGENSTORED) - // update c00:c70 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - - vfmaddps(ymm15, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, - - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - // update c01:c71 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - -// vmulps(ymm4, ymm0, ymm0) // scale by beta, -// vaddps(ymm14, ymm0, ymm0) // add the gemm result, - vfmaddps(ymm14, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, - - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c02:c72 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - -// vmulps(ymm4, ymm0, ymm0) // scale by beta, -// vaddps(ymm13, ymm0, ymm0) // add the gemm result, - vfmaddps(ymm13, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, - - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c03:c73 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - -// vmulps(ymm4, ymm0, ymm0) // scale by beta, -// vaddps(ymm12, ymm0, ymm0) // add the gemm result, - vfmaddps(ymm12, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, - - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c04:c74 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), 
xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - -// vmulps(ymm4, ymm0, ymm0) // scale by beta, -// vaddps(ymm11, ymm0, ymm0) // add the gemm result, - vfmaddps(ymm11, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, - - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c05:c75 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - -// vmulps(ymm4, ymm0, ymm0) // scale by beta, -// vaddps(ymm10, ymm0, ymm0) // add the gemm result, - vfmaddps(ymm10, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, - - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c06:c76 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - -// vmulps(ymm4, ymm0, ymm0) // scale by beta, -// vaddps(ymm9, ymm0, ymm0) // add the gemm result, - vfmaddps(ymm9, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, - - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c07:c77 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - -// vmulps(ymm4, ymm0, ymm0) // scale by beta, -// vaddps(ymm8, ymm0, ymm0) // add the gemm result, - vfmaddps(ymm8, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, - - STORE_SS - - - - jmp(.SDONE) // jump to end. - - - - label(.SCOLSTORED) - - - vmovaps(mem(rcx), ymm0) // load c00:c70, -// vmulps(ymm4, ymm0, ymm0) // scale by beta, -// vaddps(ymm15, ymm0, ymm0) // add the gemm result, - vfmaddps(ymm15, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, - vmovaps(ymm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(mem(rcx), ymm1) // load c01:c71, -// vmulps(ymm4, ymm1, ymm1) // scale by beta, -// vaddps(ymm14, ymm1, ymm1) // add the gemm result, - vfmaddps(ymm14, ymm1, ymm4, ymm1) // scale by beta and add the gemm result, - vmovaps(ymm1, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(mem(rcx), ymm0) // load c02:c72, -// vmulps(ymm4, ymm0, ymm0) // scale by beta, -// vaddps(ymm13, ymm0, ymm0) // add the gemm result, - vfmaddps(ymm13, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, - vmovaps(ymm0, mem(rcx)) // and store back to memory. 
- add(rdi, rcx) // c += cs_c; - - vmovaps(mem(rcx), ymm1) // load c03:c73, -// vmulps(ymm4, ymm1, ymm1) // scale by beta, -// vaddps(ymm12, ymm1, ymm1) // add the gemm result, - vfmaddps(ymm12, ymm1, ymm4, ymm1) // scale by beta and add the gemm result, - vmovaps(ymm1, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(mem(rcx), ymm0) // load c04:c74, -// vmulps(ymm4, ymm0, ymm0) // scale by beta, -// vaddps(ymm11, ymm0, ymm0) // add the gemm result, - vfmaddps(ymm11, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, - vmovaps(ymm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(mem(rcx), ymm1) // load c05:c75, -// vmulps(ymm4, ymm1, ymm1) // scale by beta, -// vaddps(ymm10, ymm1, ymm1) // add the gemm result, - vfmaddps(ymm10, ymm1, ymm4, ymm1) // scale by beta and add the gemm result, - vmovaps(ymm1, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(mem(rcx), ymm0) // load c06:c76, -// vmulps(ymm4, ymm0, ymm0) // scale by beta, -// vaddps(ymm9, ymm0, ymm0) // add the gemm result, - vfmaddps(ymm9, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, - vmovaps(ymm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(mem(rcx), ymm1) // load c07:c77, -// vmulps(ymm4, ymm1, ymm1) // scale by beta, -// vaddps(ymm8, ymm1, ymm1) // add the gemm result, - vfmaddps(ymm8, ymm1, ymm4, ymm1) // scale by beta and add the gemm result, - vmovaps(ymm1, mem(rcx)) // and store back to memory. - - jmp(.SDONE) // jump to end. - - + + vmovaps(mem(rcx), ymm0) // load c00:c70, + // vmulps(ymm4, ymm0, ymm0) // scale by beta, + // vaddps(ymm15, ymm0, ymm0) // add the gemm result, + vfmaddps(ymm15, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, + vmovaps(ymm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(mem(rcx), ymm1) // load c01:c71, + // vmulps(ymm4, ymm1, ymm1) // scale by beta, + // vaddps(ymm14, ymm1, ymm1) // add the gemm result, + vfmaddps(ymm14, ymm1, ymm4, ymm1) // scale by beta and add the gemm result, + vmovaps(ymm1, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(mem(rcx), ymm0) // load c02:c72, + // vmulps(ymm4, ymm0, ymm0) // scale by beta, + // vaddps(ymm13, ymm0, ymm0) // add the gemm result, + vfmaddps(ymm13, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, + vmovaps(ymm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(mem(rcx), ymm1) // load c03:c73, + // vmulps(ymm4, ymm1, ymm1) // scale by beta, + // vaddps(ymm12, ymm1, ymm1) // add the gemm result, + vfmaddps(ymm12, ymm1, ymm4, ymm1) // scale by beta and add the gemm result, + vmovaps(ymm1, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(mem(rcx), ymm0) // load c04:c74, + // vmulps(ymm4, ymm0, ymm0) // scale by beta, + // vaddps(ymm11, ymm0, ymm0) // add the gemm result, + vfmaddps(ymm11, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, + vmovaps(ymm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(mem(rcx), ymm1) // load c05:c75, + // vmulps(ymm4, ymm1, ymm1) // scale by beta, + // vaddps(ymm10, ymm1, ymm1) // add the gemm result, + vfmaddps(ymm10, ymm1, ymm4, ymm1) // scale by beta and add the gemm result, + vmovaps(ymm1, mem(rcx)) // and store back to memory. 
+ add(rdi, rcx) // c += cs_c; + + vmovaps(mem(rcx), ymm0) // load c06:c76, + // vmulps(ymm4, ymm0, ymm0) // scale by beta, + // vaddps(ymm9, ymm0, ymm0) // add the gemm result, + vfmaddps(ymm9, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, + vmovaps(ymm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(mem(rcx), ymm1) // load c07:c77, + // vmulps(ymm4, ymm1, ymm1) // scale by beta, + // vaddps(ymm8, ymm1, ymm1) // add the gemm result, + vfmaddps(ymm8, ymm1, ymm4, ymm1) // scale by beta and add the gemm result, + vmovaps(ymm1, mem(rcx)) // and store back to memory. + + jmp(.SDONE) // jump to end. + label(.SBETAZERO) - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. - jne(.SCOLSTORBZ) // jump to column storage case - - - label(.SGENSTORBZ) - // update c00:c70 - vmovapd(ymm15, ymm0) - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - // update c01:c71 - vmovapd(ymm14, ymm0) - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - // update c02:c72 - vmovapd(ymm13, ymm0) - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - // update c03:c73 - vmovapd(ymm12, ymm0) - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - // update c04:c74 - vmovapd(ymm11, ymm0) - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - // update c05:c75 - vmovapd(ymm10, ymm0) - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - // update c06:c76 - vmovapd(ymm9, ymm0) - STORE_SS - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - // update c07:c77 - vmovapd(ymm8, ymm0) - STORE_SS - - jmp(.SDONE) // jump to end. - - - label(.SCOLSTORBZ) - - vmovaps(ymm15, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(ymm14, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(ymm13, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(ymm12, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(ymm11, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(ymm10, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(ymm9, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(ymm8, mem(rcx)) // and store back to memory. - + + vmovaps(ymm15, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(ymm14, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(ymm13, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(ymm12, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(ymm11, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(ymm10, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(ymm9, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(ymm8, mem(rcx)) // and store back to memory. 
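/*
   Conceptual model (invented names, not the BLIS macro definitions) of the
   GEMM_UKR_SETUP_CT_* / GEMM_UKR_FLUSH_CT pair that supersedes the
   .SGENSTORED/.SGENSTORBZ branches deleted above: when C is not an aligned,
   column-stored matrix, the kernel is redirected to a small local tile
   (computed as if beta were zero), and the tile is merged into C afterwards,
   so the assembly only ever sees the "nice" storage case. A real
   implementation also special-cases beta == 0 in the flush.
*/
#include <stddef.h>

typedef struct { float buf[ 8 * 8 ]; int used; } sctile_t;

static float* ct_setup( sctile_t* ct, float* c, long* rs_c, long* cs_c )
{
    int ok = ( *rs_c == 1 )                             /* column-stored */
          && ( (size_t)c % 32 == 0 )                    /* base aligned  */
          && ( ( *cs_c * sizeof( float ) ) % 32 == 0 ); /* ldim aligned  */

    ct->used = !ok;
    if ( ok ) return c;

    *rs_c = 1; *cs_c = 8;   /* kernel writes the local tile instead,  */
    return ct->buf;         /* with its beta argument overridden to 0 */
}

static void ct_flush( sctile_t* ct, float beta,
                      float* c, long rs_c, long cs_c )
{
    if ( !ct->used ) return;
    for ( long j = 0; j < 8; ++j )
        for ( long i = 0; i < 8; ++i )
            c[ i * rs_c + j * cs_c ] = beta * c[ i * rs_c + j * cs_c ]
                                     + ct->buf[ i + j * 8 ];
}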
+ label(.SDONE) - + end_asm( : // output operands (none) @@ -754,7 +472,7 @@ void bli_sgemm_bulldozer_asm_8x8_fma4 [b_next] "m" (b_next), // 9 [a_next] "m" (a_next)*/ // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -762,6 +480,8 @@ void bli_sgemm_bulldozer_asm_8x8_fma4 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( s ); } #undef KERNEL4x6_1 @@ -862,7 +582,9 @@ void bli_sgemm_bulldozer_asm_8x8_fma4 void bli_dgemm_bulldozer_asm_4x6_fma4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, double* restrict alpha, double* restrict a, double* restrict b, @@ -874,33 +596,35 @@ void bli_dgemm_bulldozer_asm_4x6_fma4 { // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 12; - uint64_t k_left = k0 % 12; + uint64_t k_iter = k / 12; + uint64_t k_left = k % 12; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT_ANY( d, 4, 6, false ); + begin_asm() - - + + vzeroall() mov(var(b), rbx) // load address of b. mov(var(a), rax) // load address of a. prefetch(0, mem(rax, 64)) - - + + vmovaps(mem(rbx, 0*8), xmm1) vmovaps(mem(rbx, 2*8), xmm2) vmovaps(mem(rbx, 4*8), xmm3) add(imm(12*8), rbx) add(imm(8*8), rax) - + mov(var(k_iter), rsi) // i = k_iter; notice var(k_iter) not $0 test(rsi, rsi) je(.CONSIDERKLEFT) - + ALIGN32 label(.LOOPKITER) // MAIN LOOP - + KERNEL4x6_1(xx) KERNEL4x6_2(xx) KERNEL4x6_3(xx) @@ -913,27 +637,27 @@ void bli_dgemm_bulldozer_asm_4x6_fma4 KERNEL4x6_2(xx) KERNEL4x6_3(xx) KERNEL4x6_4(xx) - + dec(rsi) jne(.LOOPKITER) - + label(.CONSIDERKLEFT) - + mov(var(k_left), rsi) - test(rsi, rsi) + test(rsi, rsi) label(.LOOPKLEFT) je(.POSTACCUM) - + KERNEL4x6_1(xx) add(imm(6*8), rbx) add(imm(4*8), rax) - + dec(rsi) jmp(.LOOPKLEFT) // iterate again if i != 0. - + label(.POSTACCUM) - - + + mov(var(rs_c), rsi) // load cs_c mov(var(cs_c), rdi) // load rs_c vmovddup(mem(var(alpha)), xmm2) //load alpha @@ -942,32 +666,32 @@ void bli_dgemm_bulldozer_asm_4x6_fma4 sal(imm(3), rsi) // cs_c *= sizeof(double) sal(imm(3), rdi) // rs_c *= sizeof(double) lea(mem(rcx, rdi, 2), rdx) - - vmovlpd(mem(rcx), xmm0, xmm0) - vmovlpd(mem(rdx), xmm1, xmm1) + + vmovlpd(mem(rcx), xmm0, xmm0) + vmovlpd(mem(rdx), xmm1, xmm1) vmovhpd(mem(rcx, rdi, 1), xmm0, xmm0) vmovhpd(mem(rdx, rdi, 1), xmm1, xmm1) lea(mem(rdx, rdi, 2), r8) vmulpd(xmm2, xmm4, xmm4) // scale by alpha, vmulpd(xmm2, xmm5, xmm5) // scale by alpha, vfmaddpd(xmm4, xmm0, xmm3, xmm4) // scale by beta, and add the gemm result - vmovlpd(mem(r8), xmm0, xmm0) + vmovlpd(mem(r8), xmm0, xmm0) vfmaddpd(xmm5, xmm1, xmm3, xmm5) // scale by beta, and add the gemm result vmovhpd(mem(r8, rdi, 1), xmm0, xmm0) vmovlpd(xmm4, mem(rcx)) // and store back to memory. vmovlpd(xmm5, mem(rdx)) // and store back to memory. vmovhpd(xmm4, mem(rcx, rdi, 1)) - add(rsi, rcx) + add(rsi, rcx) vmovhpd(xmm5, mem(rdx, rdi, 1)) - add(rsi, rdx) - + add(rsi, rdx) + vmulpd(xmm2, xmm6, xmm6) // scale by alpha, vfmaddpd(xmm6, xmm0, xmm3, xmm6) // scale by beta, and add the gemm result vmovlpd(xmm6, mem(r8)) // and store back to memory. 
vmovhpd(xmm6, mem(r8, rdi, 1)) - add(rsi, r8) - - + add(rsi, r8) + + vmovlpd(mem(rcx), xmm0, xmm0) vmovlpd(mem(rdx), xmm1, xmm1) vmovlpd(mem(r8), xmm4, xmm4) @@ -984,13 +708,13 @@ void bli_dgemm_bulldozer_asm_4x6_fma4 vmovlpd(xmm8, mem(rdx)) // and store back to memory. vmovlpd(xmm9, mem(r8)) // and store back to memory. vmovhpd(xmm7, mem(rcx, rdi, 1)) - add(rsi, rcx) + add(rsi, rcx) vmovhpd(xmm8, mem(rdx, rdi, 1)) - add(rsi, rdx) + add(rsi, rdx) vmovhpd(xmm9, mem(r8, rdi, 1)) - add(rsi, r8) - - + add(rsi, r8) + + vmovlpd(mem(rcx), xmm0, xmm0) vmovlpd(mem(rdx), xmm1, xmm1) vmovlpd(mem(r8), xmm4, xmm4) @@ -1007,13 +731,13 @@ void bli_dgemm_bulldozer_asm_4x6_fma4 vmovlpd(xmm11, mem(rdx)) // and store back to memory. vmovlpd(xmm12, mem(r8)) // and store back to memory. vmovhpd(xmm10, mem(rcx, rdi, 1)) - add(rsi, rcx) + add(rsi, rcx) vmovhpd(xmm11, mem(rdx, rdi, 1)) - add(rsi, rdx) + add(rsi, rdx) vmovhpd(xmm12, mem(r8, rdi, 1)) - add(rsi, r8) - - + add(rsi, r8) + + vmovlpd(mem(rcx), xmm0, xmm0) vmovlpd(mem(rdx), xmm1, xmm1) vmovlpd(mem(r8), xmm4, xmm4) @@ -1031,7 +755,7 @@ void bli_dgemm_bulldozer_asm_4x6_fma4 vmovlpd(xmm15, mem(r8)) // and store back to memory. vmovhpd(xmm13, mem(rcx, rdi, 1)) vmovhpd(xmm14, mem(rdx, rdi, 1)) - vmovhpd(xmm15, mem(r8, rdi, 1)) + vmovhpd(xmm15, mem(r8, rdi, 1)) end_asm( : // output operands (none) @@ -1055,6 +779,8 @@ void bli_dgemm_bulldozer_asm_4x6_fma4 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( d ); } //The parameter "i" is the iteration number, i.e. the B values to read #define MADD_TO_YMM(i) \ @@ -1076,7 +802,9 @@ void bli_dgemm_bulldozer_asm_4x6_fma4 void bli_cgemm_bulldozer_asm_8x4_fma4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, scomplex* restrict alpha, scomplex* restrict a, scomplex* restrict b, @@ -1091,33 +819,35 @@ void bli_cgemm_bulldozer_asm_8x4_fma4 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT_ALIGNED( c, 8, 4, false, 32 ); + begin_asm() - + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. mov(var(b_next), r15) // load address of b_next. //mov(var(a_next), r14) // load address of a_next. sub(imm(4*64), r15) - + vmovaps(mem(rax, 0*32), ymm0) // initialize loop by pre-loading vmovsldup(mem(rbx, 0*32), ymm2) vpermilps(imm(0x4e), ymm2, ymm3) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(scomplex) lea(mem(rcx, rdi, 2), r10) // load address of c + 2*cs_c; - + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*cs_c prefetch(0, mem(r10, 3*8)) // prefetch c + 2*cs_c prefetch(0, mem(r10, rdi, 1, 3*8)) // prefetch c + 3*cs_c - + vxorps(ymm8, ymm8, ymm8) vxorps(ymm9, ymm9, ymm9) vxorps(ymm10, ymm10, ymm10) @@ -1126,343 +856,312 @@ void bli_cgemm_bulldozer_asm_8x4_fma4 vxorps(ymm13, ymm13, ymm13) vxorps(ymm14, ymm14, ymm14) vxorps(ymm15, ymm15, ymm15) - + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.CCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. 
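/*
   Illustrative C reference (invented names) for the update each
   .CLOOPKITER iteration below applies to the 8x4 scomplex microtile.
   In the assembly, vmovsldup/vmovshdup broadcast the real/imaginary
   parts of b, vpermilps swaps the (re,im) pairs, and the vmulps/vaddsubps
   pairs realize the sign pattern of the complex product
       (ar + i*ai)*(br + i*bi) = (ar*br - ai*bi) + i*(ar*bi + ai*br).
*/
#include <stddef.h>
#include <complex.h>

static void cgemm_accum_ref( size_t k, const float complex* a,
                             const float complex* b,
                             float complex ab[ 8 * 4 ] )
{
    for ( size_t l = 0; l < k; ++l )
    {
        for ( size_t j = 0; j < 4; ++j )
            for ( size_t i = 0; i < 8; ++i )
                ab[ i + j * 8 ] += a[ i ] * b[ j ];

        a += 8;   /* mr scomplex elements per rank-1 update */
        b += 4;   /* nr scomplex elements per rank-1 update */
    }
}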
- + label(.CLOOPKITER) // MAIN LOOP - + add(imm(4*4*8), r15) // b_next += 4*4 (unroll x nr) - + // iteration 0 prefetch(0, mem(rax, 8*32)) vmovaps(mem(rax, 1*32), ymm1) MADD_TO_YMM(0) - + vpermilps(imm(0xb1), ymm1, ymm1) vmulps(ymm0, ymm2, ymm6) vaddsubps(ymm6, ymm15, ymm15) vmulps(ymm0, ymm3, ymm7) vaddsubps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovsldup(mem(rbx, 1*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddsubps(ymm6, ymm14, ymm14) vaddsubps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vmovaps(mem(rax, 2*32), ymm0) vaddsubps(ymm6, ymm11, ymm11) vaddsubps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddsubps(ymm6, ymm10, ymm10) vaddsubps(ymm7, ymm8, ymm8) - - + + // iteration 1 prefetch(0, mem(rax, 10*32)) vmovaps(mem(rax, 3*32), ymm1) MADD_TO_YMM(1) - + vpermilps(imm(0xb1), ymm1, ymm1) vmulps(ymm0, ymm2, ymm6) vmulps(ymm0, ymm3, ymm7) vaddsubps(ymm6, ymm15, ymm15) vaddsubps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovsldup(mem(rbx, 2*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddsubps(ymm6, ymm14, ymm14) vaddsubps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vmovaps(mem(rax, 4*32), ymm0) vaddsubps(ymm6, ymm11, ymm11) vaddsubps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddsubps(ymm6, ymm10, ymm10) vaddsubps(ymm7, ymm8, ymm8) - + // iteration 2 prefetch(0, mem(rax, 12*32)) vmovaps(mem(rax, 5*32), ymm1) MADD_TO_YMM(2) prefetch(0, mem(r15, 2*32)) // prefetch b_next[2*4] - + vpermilps(imm(0xb1), ymm1, ymm1) vmulps(ymm0, ymm2, ymm6) vmulps(ymm0, ymm3, ymm7) vaddsubps(ymm6, ymm15, ymm15) vaddsubps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovsldup(mem(rbx, 3*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddsubps(ymm6, ymm14, ymm14) vaddsubps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vmovaps(mem(rax, 6*32), ymm0) vaddsubps(ymm6, ymm11, ymm11) vaddsubps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddsubps(ymm6, ymm10, ymm10) vaddsubps(ymm7, ymm8, ymm8) - - + + // iteration 3 prefetch(0, mem(rax, 14*32)) vmovaps(mem(rax, 7*32), ymm1) MADD_TO_YMM(3) - + vpermilps(imm(0xb1), ymm1, ymm1) vmulps(ymm0, ymm2, ymm6) vmulps(ymm0, ymm3, ymm7) vaddsubps(ymm6, ymm15, ymm15) vaddsubps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovsldup(mem(rbx, 4*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddsubps(ymm6, ymm14, ymm14) vaddsubps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vmovaps(mem(rax, 8*32), ymm0) vaddsubps(ymm6, ymm11, ymm11) vaddsubps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddsubps(ymm6, ymm10, ymm10) vaddsubps(ymm7, ymm8, ymm8) - - + + add(imm(8*4*8), rax) // a += 8*4 (unroll x mr) add(imm(4*4*8), rbx) // b += 4*4 (unroll x nr) - - + + dec(rsi) // i -= 1; jne(.CLOOPKITER) // iterate again if i != 0. - - - + + + label(.CCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.CPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. 
- - + + label(.CLOOPKLEFT) // EDGE LOOP - + // iteration 0 prefetch(0, mem(rax, 8*32)) vmovaps(mem(rax, 1*32), ymm1) MADD_TO_YMM(0) - + vpermilps(imm(0xb1), ymm1, ymm1) vmulps(ymm0, ymm2, ymm6) vmulps(ymm0, ymm3, ymm7) vaddsubps(ymm6, ymm15, ymm15) vaddsubps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovsldup(mem(rbx, 1*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddsubps(ymm6, ymm14, ymm14) vaddsubps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vmovaps(mem(rax, 2*32), ymm0) vaddsubps(ymm6, ymm11, ymm11) vaddsubps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddsubps(ymm6, ymm10, ymm10) vaddsubps(ymm7, ymm8, ymm8) - - + + add(imm(8*1*8), rax) // a += 8 (1 x mr) add(imm(4*1*8), rbx) // b += 4 (1 x nr) - - + + dec(rsi) // i -= 1; jne(.CLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.CPOSTACCUM) - + // ymm15: ymm13: ymm11: ymm9: - // ( ab00 ( ab01 ( ab02 ( ab03 - // ab10 ab11 ab12 ab13 - // ab21 ab20 ab23 ab22 - // ab31 ab30 ab33 ab32 - // ab42 ab43 ab40 ab41 - // ab52 ab53 ab50 ab51 - // ab63 ab62 ab61 ab60 + // ( ab00 ( ab01 ( ab02 ( ab03 + // ab10 ab11 ab12 ab13 + // ab21 ab20 ab23 ab22 + // ab31 ab30 ab33 ab32 + // ab42 ab43 ab40 ab41 + // ab52 ab53 ab50 ab51 + // ab63 ab62 ab61 ab60 // ab73 ) ab72 ) ab71 ) ab70 ) - + // ymm14: ymm12: ymm10: ymm8: - // ( ab80 ( ab81 ( ab82 ( ab83 - // ab90 ab91 ab92 ab93 - // aba1 aba0 aba3 aba2 - // abb1 abb0 abb3 abb2 - // abc2 abc3 abc0 abc1 - // abd2 abd3 abd0 abd1 - // abe3 abe2 abe1 abe0 + // ( ab80 ( ab81 ( ab82 ( ab83 + // ab90 ab91 ab92 ab93 + // aba1 aba0 aba3 aba2 + // abb1 abb0 abb3 abb2 + // abc2 abc3 abc0 abc1 + // abd2 abd3 abd0 abd1 + // abe3 abe2 abe1 abe0 // abf3 abf2 abf1 abf0 ) GROUP_YMM_BY_4 // ymm15: ymm13: ymm11: ymm9: - // ( ab00 ( ab01 ( ab02 ( ab03 - // ab10 ab11 ab12 ab13 - // ab20 ab21 ab22 ab23 - // ab30 ab31 ab32 ab33 - // ab42 ab43 ab40 ab41 - // ab52 ab53 ab50 ab51 - // ab62 ab63 ab60 ab61 + // ( ab00 ( ab01 ( ab02 ( ab03 + // ab10 ab11 ab12 ab13 + // ab20 ab21 ab22 ab23 + // ab30 ab31 ab32 ab33 + // ab42 ab43 ab40 ab41 + // ab52 ab53 ab50 ab51 + // ab62 ab63 ab60 ab61 // ab72 ) ab73 ) ab70 ) ab71 ) - + // ymm14: ymm12: ymm10: ymm8: - // ( ab80 ( ab81 ( ab82 ( ab83 - // ab90 ab91 ab92 ab93 - // aba0 aba1 aba2 aba3 - // abb0 abb1 abb2 abb3 - // abc2 abc3 abc0 abc1 - // abd2 abd3 abd0 abd1 - // abe2 abe3 abe0 abe1 + // ( ab80 ( ab81 ( ab82 ( ab83 + // ab90 ab91 ab92 ab93 + // aba0 aba1 aba2 aba3 + // abb0 abb1 abb2 abb3 + // abc2 abc3 abc0 abc1 + // abd2 abd3 abd0 abd1 + // abe2 abe3 abe0 abe1 // abf2 ) abf3 ) abf0 ) abf1 ) - + // ymm15: ymm13: ymm11: ymm9: - // ( ab00 ( ab01 ( ab02 ( ab03 - // ab10 ab11 ab12 ab13 - // ab20 ab21 ab22 ab23 - // ab30 ab31 ab32 ab33 - // ab40 ab41 ab42 ab43 - // ab50 ab51 ab52 ab53 - // ab60 ab61 ab62 ab63 + // ( ab00 ( ab01 ( ab02 ( ab03 + // ab10 ab11 ab12 ab13 + // ab20 ab21 ab22 ab23 + // ab30 ab31 ab32 ab33 + // ab40 ab41 ab42 ab43 + // ab50 ab51 ab52 ab53 + // ab60 ab61 ab62 ab63 // ab70 ) ab71 ) ab72 ) ab73 ) - + // ymm14: ymm12: ymm10: ymm8: - // ( ab80 ( ab81 ( ab82 ( ab83 - // ab90 ab91 ab92 ab93 - // aba0 aba1 aba2 aba3 - // abb0 abb1 abb2 abb3 - // abc0 abc1 abc2 abc3 - // abd0 abd1 abd2 abd3 - // abe0 abe1 abe2 abe3 + // ( ab80 ( ab81 ( ab82 ( ab83 + // ab90 ab91 ab92 ab93 + // aba0 aba1 aba2 aba3 + // abb0 abb1 abb2 abb3 + // abc0 abc1 abc2 abc3 + // abd0 abd1 abd2 abd3 + // abe0 abe1 abe2 abe3 // abf0 ) abf1 ) abf2 ) abf3 ) - + // scale by alpha - + mov(var(alpha), rax) // load 
address of alpha vbroadcastss(mem(rax), ymm7) // load alpha_r and duplicate vbroadcastss(mem(rax, 4), ymm6) // load alpha_i and duplicate - + vpermilps(imm(0xb1), ymm15, ymm3) vmulps(ymm7, ymm15, ymm15) vmulps(ymm6, ymm3, ymm3) vaddsubps(ymm3, ymm15, ymm15) - + vpermilps(imm(0xb1), ymm14, ymm2) vmulps(ymm7, ymm14, ymm14) vmulps(ymm6, ymm2, ymm2) vaddsubps(ymm2, ymm14, ymm14) - + vpermilps(imm(0xb1), ymm13, ymm1) vmulps(ymm7, ymm13, ymm13) vmulps(ymm6, ymm1, ymm1) vaddsubps(ymm1, ymm13, ymm13) - + vpermilps(imm(0xb1), ymm12, ymm0) vmulps(ymm7, ymm12, ymm12) vmulps(ymm6, ymm0, ymm0) vaddsubps(ymm0, ymm12, ymm12) - + vpermilps(imm(0xb1), ymm11, ymm3) vmulps(ymm7, ymm11, ymm11) vmulps(ymm6, ymm3, ymm3) vaddsubps(ymm3, ymm11, ymm11) - + vpermilps(imm(0xb1), ymm10, ymm2) vmulps(ymm7, ymm10, ymm10) vmulps(ymm6, ymm2, ymm2) vaddsubps(ymm2, ymm10, ymm10) - + vpermilps(imm(0xb1), ymm9, ymm1) vmulps(ymm7, ymm9, ymm9) vmulps(ymm6, ymm1, ymm1) vaddsubps(ymm1, ymm9, ymm9) - + vpermilps(imm(0xb1), ymm8, ymm0) vmulps(ymm7, ymm8, ymm8) vmulps(ymm6, ymm0, ymm0) vaddsubps(ymm0, ymm8, ymm8) - - - - + + + + mov(var(beta), rbx) // load address of beta vbroadcastss(mem(rbx), ymm7) // load beta_r and duplicate vbroadcastss(mem(rbx, 4), ymm6) // load beta_i and duplicate - - - - - - - - mov(var(rs_c), rsi) // load rs_c - lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(scomplex) - - lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*rs_c; - - lea(mem(, rsi, 2), r12) // r12 = 2*rs_c; - lea(mem(r12, rsi, 1), r13) // r13 = 3*rs_c; - - - - // determine if - // c % 32 == 0, AND - // 8*cs_c % 32 == 0, AND - // rs_c == 1 - // ie: aligned, ldim aligned, and - // column-stored - - cmp(imm(8), rsi) // set ZF if (8*rs_c) == 8. - sete(bl) // bl = ( ZF == 1 ? 1 : 0 ); - test(imm(31), rcx) // set ZF if c & 32 is zero. - setz(bh) // bh = ( ZF == 0 ? 1 : 0 ); - test(imm(31), rdi) // set ZF if (8*cs_c) & 32 is zero. - setz(al) // al = ( ZF == 0 ? 1 : 0 ); - // and(bl,bh) followed by - // and(bh,al) will reveal result - + + // now avoid loading C if beta == 0 - + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomiss(xmm0, xmm7) // set ZF if beta_r == 0. sete(r8b) // r8b = ( ZF == 1 ? 1 : 0 ); @@ -1470,371 +1169,109 @@ void bli_cgemm_bulldozer_asm_8x4_fma4 sete(r9b) // r9b = ( ZF == 1 ? 1 : 0 ); and(r8b, r9b) // set ZF if r8b & r9b == 1. jne(.CBETAZERO) // if ZF = 0, jump to beta == 0 case - - - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. 
- jne(.CCOLSTORED) // jump to column storage case - - - - label(.CGENSTORED) - - // update c00:c70 - - vmovlpd(mem(rcx), xmm0, xmm0) // load (c00,10) into xmm0[0:1] - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) // load (c20,30) into xmm0[2:3] - vmovlpd(mem(rcx, r12, 1), xmm2, xmm2) // load (c40,50) into xmm2[0:1] - vmovhpd(mem(rcx, r13, 1), xmm2, xmm2) // load (c60,70) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm15, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rcx)) // store (c00,c10) - vmovhpd(xmm0, mem(rcx, rsi, 1)) // store (c20,c30) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c40,c50) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c60,c70) - add(rdi, rcx) // c += cs_c; - - // update c80:cf0 - - vmovlpd(mem(rdx), xmm0, xmm0) // load (c80,90) into xmm0[0:1] - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) // load (ca0,b0) into xmm0[2:3] - vmovlpd(mem(rdx, r12, 1), xmm2, xmm2) // load (cc0,d0) into xmm2[0:1] - vmovhpd(mem(rdx, r13, 1), xmm2, xmm2) // load (ce0,f0) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm14, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rdx)) // store (c80,c90) - vmovhpd(xmm0, mem(rdx, rsi, 1)) // store (ca0,cb0) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc0,cd0) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce0,cf0) - add(rdi, rdx) // c += cs_c; - - // update c01:c71 - - vmovlpd(mem(rcx), xmm0, xmm0) // load (c01,11) into xmm0[0:1] - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) // load (c21,31) into xmm0[2:3] - vmovlpd(mem(rcx, r12, 1), xmm2, xmm2) // load (c41,51) into xmm2[0:1] - vmovhpd(mem(rcx, r13, 1), xmm2, xmm2) // load (c61,71) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm13, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rcx)) // store (c01,c11) - vmovhpd(xmm0, mem(rcx, rsi, 1)) // store (c21,c31) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c41,c51) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c61,c71) - add(rdi, rcx) // c += cs_c; - - // update c81:cf1 - - vmovlpd(mem(rdx), xmm0, xmm0) // load (c81,91) into xmm0[0:1] - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) // load (ca1,b1) into xmm0[2:3] - vmovlpd(mem(rdx, r12, 1), xmm2, xmm2) // load (cc1,d1) into xmm2[0:1] - vmovhpd(mem(rdx, r13, 1), xmm2, xmm2) // load (ce1,f1) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm12, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rdx)) // store (c81,c91) - vmovhpd(xmm0, mem(rdx, rsi, 1)) // store (ca1,cb1) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc1,cd1) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce1,cf1) - add(rdi, rdx) // c += cs_c; - - // update c02:c72 - - 
vmovlpd(mem(rcx), xmm0, xmm0) // load (c02,12) into xmm0[0:1] - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) // load (c22,32) into xmm0[2:3] - vmovlpd(mem(rcx, r12, 1), xmm2, xmm2) // load (c42,52) into xmm2[0:1] - vmovhpd(mem(rcx, r13, 1), xmm2, xmm2) // load (c62,72) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm11, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rcx)) // store (c02,c12) - vmovhpd(xmm0, mem(rcx, rsi, 1)) // store (c22,c32) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c42,c52) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c62,c72) - add(rdi, rcx) // c += cs_c; - - // update c82:cf2 - - vmovlpd(mem(rdx), xmm0, xmm0) // load (c82,92) into xmm0[0:1] - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) // load (ca2,b2) into xmm0[2:3] - vmovlpd(mem(rdx, r12, 1), xmm2, xmm2) // load (cc2,d2) into xmm2[0:1] - vmovhpd(mem(rdx, r13, 1), xmm2, xmm2) // load (ce2,f2) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm10, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rdx)) // store (c82,c92) - vmovhpd(xmm0, mem(rdx, rsi, 1)) // store (ca2,cb2) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc2,cd2) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce2,cf2) - add(rdi, rdx) // c += cs_c; - - // update c03:c73 - - vmovlpd(mem(rcx), xmm0, xmm0) // load (c03,13) into xmm0[0:1] - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) // load (c23,33) into xmm0[2:3] - vmovlpd(mem(rcx, r12, 1), xmm2, xmm2) // load (c43,53) into xmm2[0:1] - vmovhpd(mem(rcx, r13, 1), xmm2, xmm2) // load (c63,73) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm9, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rcx)) // store (c03,c13) - vmovhpd(xmm0, mem(rcx, rsi, 1)) // store (c23,c33) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c43,c53) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c63,c73) - add(rdi, rcx) // c += cs_c; - - // update c83:cf3 - - vmovlpd(mem(rdx), xmm0, xmm0) // load (c83,93) into xmm0[0:1] - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) // load (ca3,b3) into xmm0[2:3] - vmovlpd(mem(rdx, r12, 1), xmm2, xmm2) // load (cc3,d3) into xmm2[0:1] - vmovhpd(mem(rdx, r13, 1), xmm2, xmm2) // load (ce3,f3) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm8, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rdx)) // store (c83,c93) - vmovhpd(xmm0, mem(rdx, rsi, 1)) // store (ca3,cb3) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc3,cd3) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce3,cf3) - add(rdi, rdx) // c += cs_c; - - - - jmp(.CDONE) // jump to end. 
- - - - label(.CCOLSTORED) - - // update c00:c70 - - vmovaps(mem(rcx), ymm0) // load c00:c70 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm15, ymm0, ymm0) // add the gemm result to ymm0 - vmovaps(ymm0, mem(rcx)) // store c00:c70 - add(rdi, rcx) // c += cs_c; - - // update c80:cf0 - - vmovaps(mem(rdx), ymm0) // load c80:f0 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm14, ymm0, ymm0) // add the gemm result to ymm0 - vmovaps(ymm0, mem(rdx)) // store c80:cf0 - add(rdi, rdx) // c += cs_c; - - // update c00:c70 - - vmovaps(mem(rcx), ymm0) // load c01:c71 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm13, ymm0, ymm0) // add the gemm result to ymm0 - vmovaps(ymm0, mem(rcx)) // store c01:c71 - add(rdi, rcx) // c += cs_c; - - // update c81:cf1 - - vmovaps(mem(rdx), ymm0) // load c81:f1 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm12, ymm0, ymm0) // add the gemm result to ymm0 - vmovaps(ymm0, mem(rdx)) // store c81:cf1 - add(rdi, rdx) // c += cs_c; - - // update c02:c72 - vmovaps(mem(rcx), ymm0) // load c02:c72 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm11, ymm0, ymm0) // add the gemm result to ymm0 - vmovaps(ymm0, mem(rcx)) // store c02:c72 - add(rdi, rcx) // c += cs_c; - - // update c82:cf2 - vmovaps(mem(rdx), ymm0) // load c82:f2 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm10, ymm0, ymm0) // add the gemm result to ymm0 - vmovaps(ymm0, mem(rdx)) // store c82:cf2 - add(rdi, rdx) // c += cs_c; - - // update c03:c73 - vmovaps(mem(rcx), ymm0) // load c03:c73 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm9, ymm0, ymm0) // add the gemm result to ymm0 - vmovaps(ymm0, mem(rcx)) // store c03:c73 - add(rdi, rcx) // c += cs_c; - - // update c83:cf3 - vmovaps(mem(rdx), ymm0) // load c83:f3 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm8, ymm0, ymm0) // add the gemm result to ymm0 - vmovaps(ymm0, mem(rdx)) // store c83:cf3 - add(rdi, rdx) // c += cs_c; - - jmp(.CDONE) // jump to end. 
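/*
   The replacement block below walks both halves of each 8-row scomplex
   column through a single pointer, addressing rows 8..15 at mem(rcx,32)
   instead of keeping a second pointer (rdx) four rows ahead; this is valid
   now that the CT setup macro guarantees a contiguous, 32-byte-aligned
   column of C. Per column it computes c := beta*c + ab with complex beta;
   an illustrative scalar form (invented names):
*/
#include <stddef.h>
#include <complex.h>

static void c_col_update_ref( float complex beta, float complex* c,
                              const float complex* ab, size_t m )
{
    for ( size_t i = 0; i < m; ++i )
        /* vpermilps + vmulps(beta_r) + vmulps(beta_i) + vaddsubps */
        c[ i ] = beta * c[ i ] + ab[ i ];
}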
- 
- 
+ 
+ // update c00:c70
+ 
+ vmovaps(mem(rcx), ymm0) // load c00:c70 into ymm0
+ vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta
+ vmulps(ymm7, ymm0, ymm0)
+ vmulps(ymm6, ymm2, ymm2)
+ vaddsubps(ymm2, ymm0, ymm0)
+ vaddps(ymm15, ymm0, ymm0) // add the gemm result to ymm0
+ vmovaps(ymm0, mem(rcx)) // store c00:c70
+ 
+ // update c80:cf0
+ 
+ vmovaps(mem(rcx,32), ymm0) // load c80:f0 into ymm0
+ vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta
+ vmulps(ymm7, ymm0, ymm0)
+ vmulps(ymm6, ymm2, ymm2)
+ vaddsubps(ymm2, ymm0, ymm0)
+ vaddps(ymm14, ymm0, ymm0) // add the gemm result to ymm0
+ vmovaps(ymm0, mem(rcx,32)) // store c80:cf0
+ add(rdi, rcx) // c += cs_c;
+ 
+ // update c01:c71
+ 
+ vmovaps(mem(rcx), ymm0) // load c01:c71 into ymm0
+ vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta
+ vmulps(ymm7, ymm0, ymm0)
+ vmulps(ymm6, ymm2, ymm2)
+ vaddsubps(ymm2, ymm0, ymm0)
+ vaddps(ymm13, ymm0, ymm0) // add the gemm result to ymm0
+ vmovaps(ymm0, mem(rcx)) // store c01:c71
+ 
+ // update c81:cf1
+ 
+ vmovaps(mem(rcx,32), ymm0) // load c81:f1 into ymm0
+ vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta
+ vmulps(ymm7, ymm0, ymm0)
+ vmulps(ymm6, ymm2, ymm2)
+ vaddsubps(ymm2, ymm0, ymm0)
+ vaddps(ymm12, ymm0, ymm0) // add the gemm result to ymm0
+ vmovaps(ymm0, mem(rcx,32)) // store c81:cf1
+ add(rdi, rcx) // c += cs_c;
+ 
+ // update c02:c72
+ vmovaps(mem(rcx), ymm0) // load c02:c72 into ymm0
+ vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta
+ vmulps(ymm7, ymm0, ymm0)
+ vmulps(ymm6, ymm2, ymm2)
+ vaddsubps(ymm2, ymm0, ymm0)
+ vaddps(ymm11, ymm0, ymm0) // add the gemm result to ymm0
+ vmovaps(ymm0, mem(rcx)) // store c02:c72
+ 
+ // update c82:cf2
+ vmovaps(mem(rcx,32), ymm0) // load c82:f2 into ymm0
+ vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta
+ vmulps(ymm7, ymm0, ymm0)
+ vmulps(ymm6, ymm2, ymm2)
+ vaddsubps(ymm2, ymm0, ymm0)
+ vaddps(ymm10, ymm0, ymm0) // add the gemm result to ymm0
+ vmovaps(ymm0, mem(rcx,32)) // store c82:cf2
+ add(rdi, rcx) // c += cs_c;
+ 
+ // update c03:c73
+ vmovaps(mem(rcx), ymm0) // load c03:c73 into ymm0
+ vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta
+ vmulps(ymm7, ymm0, ymm0)
+ vmulps(ymm6, ymm2, ymm2)
+ vaddsubps(ymm2, ymm0, ymm0)
+ vaddps(ymm9, ymm0, ymm0) // add the gemm result to ymm0
+ vmovaps(ymm0, mem(rcx)) // store c03:c73
+ 
+ // update c83:cf3
+ vmovaps(mem(rcx,32), ymm0) // load c83:f3 into ymm0
+ vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta
+ vmulps(ymm7, ymm0, ymm0)
+ vmulps(ymm6, ymm2, ymm2)
+ vaddsubps(ymm2, ymm0, ymm0)
+ vaddps(ymm8, ymm0, ymm0) // add the gemm result to ymm0
+ vmovaps(ymm0, mem(rcx,32)) // store c83:cf3
+ //add(rdi, rcx) // c += cs_c;
+ 
+ jmp(.CDONE) // jump to end.
+ 
label(.CBETAZERO)
- // check if aligned/column-stored
- // check if aligned/column-stored
- and(bl, bh) // set ZF if bl & bh == 1.
- and(bh, al) // set ZF if bh & al == 1.
- jne(.CCOLSTORBZ) // jump to column storage case - - - label(.CGENSTORBZ) - // update c00:c70 - vextractf128(imm(1), ymm15, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm15, mem(rcx)) // store (c00,c10) - vmovhpd(xmm15, mem(rcx, rsi, 1)) // store (c20,c30) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c40,c50) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c60,c70) - add(rdi, rcx) // c += cs_c; - - // update c80:cf0 - vextractf128(imm(1), ymm14, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm14, mem(rdx)) // store (c80,c90) - vmovhpd(xmm14, mem(rdx, rsi, 1)) // store (ca0,cb0) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc0,cd0) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce0,cf0) - add(rdi, rdx) // c += cs_c; - - // update c01:c71 - vextractf128(imm(1), ymm13, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm13, mem(rcx)) // store (c01,c11) - vmovhpd(xmm13, mem(rcx, rsi, 1)) // store (c21,c31) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c41,c51) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c61,c71) - add(rdi, rcx) // c += cs_c; - - // update c81:cf1 - vextractf128(imm(1), ymm12, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm12, mem(rdx)) // store (c81,c91) - vmovhpd(xmm12, mem(rdx, rsi, 1)) // store (ca1,cb1) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc1,cd1) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce1,cf1) - add(rdi, rdx) // c += cs_c; - - // update c02:c72 - vextractf128(imm(1), ymm11, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm11, mem(rcx)) // store (c02,c12) - vmovhpd(xmm11, mem(rcx, rsi, 1)) // store (c22,c32) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c42,c52) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c62,c72) - add(rdi, rcx) // c += cs_c; - - // update c82:cf2 - vextractf128(imm(1), ymm10, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm10, mem(rdx)) // store (c82,c92) - vmovhpd(xmm10, mem(rdx, rsi, 1)) // store (ca2,cb2) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc2,cd2) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce2,cf2) - add(rdi, rdx) // c += cs_c; - - // update c03:c73 - vextractf128(imm(1), ymm9, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm9, mem(rcx)) // store (c03,c13) - vmovhpd(xmm9, mem(rcx, rsi, 1)) // store (c23,c33) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c43,c53) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c63,c73) - add(rdi, rcx) // c += cs_c; - - // update c83:cf3 - vextractf128(imm(1), ymm8, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm8, mem(rdx)) // store (c83,c93) - vmovhpd(xmm8, mem(rdx, rsi, 1)) // store (ca3,cb3) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc3,cd3) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce3,cf3) - add(rdi, rdx) // c += cs_c; - - - jmp(.CDONE) // jump to end. 
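/*
   Under .CBETAZERO the kernel must not read C at all (beta == 0 means C
   may be uninitialized, and 0*NaN would otherwise leak into the result),
   so the surviving code below simply stores the already alpha-scaled
   accumulators. An illustrative scalar equivalent (invented names):
*/
#include <stddef.h>
#include <complex.h>

static void c_store_bz_ref( const float complex* ab, float complex* c,
                            size_t m, size_t n, size_t cs_c )
{
    for ( size_t j = 0; j < n; ++j )
        for ( size_t i = 0; i < m; ++i )
            c[ i + j * cs_c ] = ab[ i + j * m ];   /* no load of C */
}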
- - - label(.CCOLSTORBZ) - - vmovaps(ymm15, mem(rcx)) // store c00:c70 - add(rdi, rcx) // c += cs_c; - - vmovaps(ymm14, mem(rdx)) // store c80:cf0 - add(rdi, rdx) // c += cs_c; - - vmovaps(ymm13, mem(rcx)) // store c01:c71 - add(rdi, rcx) // c += cs_c; - - vmovaps(ymm12, mem(rdx)) // store c81:cf1 - add(rdi, rdx) // c += cs_c; - - vmovaps(ymm11, mem(rcx)) // store c02:c72 - add(rdi, rcx) // c += cs_c; - - vmovaps(ymm10, mem(rdx)) // store c82:cf2 - add(rdi, rdx) // c += cs_c; - - vmovaps(ymm9, mem(rcx)) // store c03:c73 - add(rdi, rcx) // c += cs_c; - - vmovaps(ymm8, mem(rdx)) // store c83:cf3 - add(rdi, rdx) // c += cs_c; - - - + + vmovaps(ymm15, mem(rcx)) // store c00:c70 + vmovaps(ymm14, mem(rcx,32)) // store c80:cf0 + add(rdi, rcx) // c += cs_c; + + vmovaps(ymm13, mem(rcx)) // store c01:c71 + vmovaps(ymm12, mem(rcx,32)) // store c81:cf1 + add(rdi, rcx) // c += cs_c; + + vmovaps(ymm11, mem(rcx)) // store c02:c72 + vmovaps(ymm10, mem(rcx,32)) // store c82:cf2 + add(rdi, rcx) // c += cs_c; + + vmovaps(ymm9, mem(rcx)) // store c03:c73 + vmovaps(ymm8, mem(rcx,32)) // store c83:cf3 + add(rdi, rcx) // c += cs_c; + label(.CDONE) - + end_asm( : // output operands (none) @@ -1851,7 +1288,7 @@ void bli_cgemm_bulldozer_asm_8x4_fma4 [b_next] "m" (b_next)/*, // 9 [a_next] "m" (a_next)*/ // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", @@ -1859,6 +1296,8 @@ void bli_cgemm_bulldozer_asm_8x4_fma4 "ymm12", "ymm13", "ymm14", "ymm15", "memory" ) + + GEMM_UKR_FLUSH_CT( c ); } #define MADDSUBPD_TO_YMM \ @@ -1883,11 +1322,13 @@ void bli_cgemm_bulldozer_asm_8x4_fma4 vmulpd(ymm7, ymm(i), ymm(i))\ vmulpd(ymm6, ymm(j), ymm(j))\ vaddsubpd(ymm(j), ymm(i), ymm(i))\ - + void bli_zgemm_bulldozer_asm_4x4_fma4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, dcomplex* restrict alpha, dcomplex* restrict a, dcomplex* restrict b, @@ -1902,34 +1343,36 @@ void bli_zgemm_bulldozer_asm_4x4_fma4 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT_ALIGNED( z, 4, 4, false, 32 ); + begin_asm() - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. //mov(var(b_next), r15) // load address of b_next. //mov(var(a_next), r14) // load address of a_next. - + vmovapd(mem(rax, 0*32), ymm0) // initialize loop by pre-loading vmovddup(mem(rbx, 0+0*32), ymm2) vmovddup(mem(rbx, 0+1*32), ymm3) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(dcomplex) lea(mem(, rdi, 2), rdi) lea(mem(rcx, rdi, 2), r10) // load address of c + 2*cs_c; - + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*cs_c prefetch(0, mem(r10, 3*8)) // prefetch c + 2*cs_c prefetch(0, mem(r10, rdi, 1, 3*8)) // prefetch c + 3*cs_c - + vxorpd(ymm8, ymm8, ymm8) vxorpd(ymm9, ymm9, ymm9) vxorpd(ymm10, ymm10, ymm10) @@ -1938,28 +1381,28 @@ void bli_zgemm_bulldozer_asm_4x4_fma4 vxorpd(ymm13, ymm13, ymm13) vxorpd(ymm14, ymm14, ymm14) vxorpd(ymm15, ymm15, ymm15) - - + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. 
je(.ZCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - + label(.ZLOOPKITER) // MAIN LOOP - + // iteration 0 vmovapd(mem(rax, 1*32), ymm1) vfmaddpd(ymm15, ymm0, ymm2, ymm15) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vfmaddpd(ymm11, ymm0, ymm3, ymm11) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) - + prefetch(0, mem(rax, 16*32)) vfmaddpd(ymm14, ymm1, ymm2, ymm14) vmovddup(mem(rbx, 8+0*32), ymm2) vfmaddpd(ymm10, ymm1, ymm3, ymm10) vmovddup(mem(rbx, 8+1*32), ymm3) - + MADDSUBPD_TO_YMM vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 0+2*32), ymm2) @@ -1967,31 +1410,31 @@ void bli_zgemm_bulldozer_asm_4x4_fma4 vmovddup(mem(rbx, 0+3*32), ymm3) vaddsubpd(ymm6, ymm14, ymm14) vaddsubpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 2*32), ymm0) vaddsubpd(ymm6, ymm13, ymm13) vaddsubpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddsubpd(ymm6, ymm12, ymm12) vaddsubpd(ymm7, ymm8, ymm8) - + // iteration 1 vmovapd(mem(rax, 3*32), ymm1) vfmaddpd(ymm15, ymm0, ymm2, ymm15) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vfmaddpd(ymm11, ymm0, ymm3, ymm11) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) - + prefetch(0, mem(rax, 18*32)) vfmaddpd(ymm14, ymm1, ymm2, ymm14) vmovddup(mem(rbx, 8+2*32), ymm2) vfmaddpd(ymm10, ymm1, ymm3, ymm10) vmovddup(mem(rbx, 8+3*32), ymm3) - + MADDSUBPD_TO_YMM vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 0+4*32), ymm2) @@ -1999,31 +1442,31 @@ void bli_zgemm_bulldozer_asm_4x4_fma4 vmovddup(mem(rbx, 0+5*32), ymm3) vaddsubpd(ymm6, ymm14, ymm14) vaddsubpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 4*32), ymm0) vaddsubpd(ymm6, ymm13, ymm13) vaddsubpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddsubpd(ymm6, ymm12, ymm12) vaddsubpd(ymm7, ymm8, ymm8) - + // iteration 2 vmovapd(mem(rax, 5*32), ymm1) vfmaddpd(ymm15, ymm0, ymm2, ymm15) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vfmaddpd(ymm11, ymm0, ymm3, ymm11) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) - + prefetch(0, mem(rax, 20*32)) vfmaddpd(ymm14, ymm1, ymm2, ymm14) vmovddup(mem(rbx, 8+4*32), ymm2) vfmaddpd(ymm10, ymm1, ymm3, ymm10) vmovddup(mem(rbx, 8+5*32), ymm3) - + MADDSUBPD_TO_YMM vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 0+6*32), ymm2) @@ -2031,31 +1474,31 @@ void bli_zgemm_bulldozer_asm_4x4_fma4 vmovddup(mem(rbx, 0+7*32), ymm3) vaddsubpd(ymm6, ymm14, ymm14) vaddsubpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 6*32), ymm0) vaddsubpd(ymm6, ymm13, ymm13) vaddsubpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddsubpd(ymm6, ymm12, ymm12) vaddsubpd(ymm7, ymm8, ymm8) - + // iteration 3 vmovapd(mem(rax, 7*32), ymm1) vfmaddpd(ymm15, ymm0, ymm2, ymm15) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vfmaddpd(ymm11, ymm0, ymm3, ymm11) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) - + prefetch(0, mem(rax, 22*32)) vfmaddpd(ymm14, ymm1, ymm2, ymm14) vmovddup(mem(rbx, 8+6*32), ymm2) vfmaddpd(ymm10, ymm1, ymm3, ymm10) vmovddup(mem(rbx, 8+7*32), ymm3) - + MADDSUBPD_TO_YMM vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 0+8*32), ymm2) @@ -2063,48 +1506,48 @@ void bli_zgemm_bulldozer_asm_4x4_fma4 vmovddup(mem(rbx, 0+9*32), ymm3) vaddsubpd(ymm6, ymm14, ymm14) vaddsubpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 8*32), ymm0) vaddsubpd(ymm6, ymm13, ymm13) vaddsubpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddsubpd(ymm6, ymm12, ymm12) vaddsubpd(ymm7, ymm8, ymm8) - + add(imm(4*4*16), 
rbx) // b += 4*4 (unroll x nr) add(imm(4*4*16), rax) // a += 4*4 (unroll x mr) - + dec(rsi) // i -= 1; jne(.ZLOOPKITER) // iterate again if i != 0. - - + + label(.ZCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.ZPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.ZLOOPKLEFT) // EDGE LOOP - + // iteration 0 vmovapd(mem(rax, 1*32), ymm1) vfmaddpd(ymm15, ymm0, ymm2, ymm15) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vfmaddpd(ymm11, ymm0, ymm3, ymm11) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) - + prefetch(0, mem(rax, 16*32)) vfmaddpd(ymm14, ymm1, ymm2, ymm14) vmovddup(mem(rbx, 8+0*32), ymm2) vfmaddpd(ymm10, ymm1, ymm3, ymm10) vmovddup(mem(rbx, 8+1*32), ymm3) - + MADDSUBPD_TO_YMM vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 0+2*32), ymm2) @@ -2112,75 +1555,75 @@ void bli_zgemm_bulldozer_asm_4x4_fma4 vmovddup(mem(rbx, 0+3*32), ymm3) vaddsubpd(ymm6, ymm14, ymm14) vaddsubpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 2*32), ymm0) vaddsubpd(ymm6, ymm13, ymm13) vaddsubpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddsubpd(ymm6, ymm12, ymm12) vaddsubpd(ymm7, ymm8, ymm8) - - + + add(imm(4*1*16), rax) // a += 4 (1 x mr) add(imm(4*1*16), rbx) // b += 4 (1 x nr) - + dec(rsi) // i -= 1; jne(.ZLOOPKLEFT) // iterate again if i != 0. - - + + label(.ZPOSTACCUM) // ymm15: ymm13: ymm11: ymm9: // ( ab00 ( ab01 ( ab02 ( ab03 - // ab10 ab11 ab12 ab13 + // ab10 ab11 ab12 ab13 // ab21 ab20 ab23 ab22 // ab31 ) ab30 ) ab33 ) ab32 ) - + // ymm14: ymm12: ymm10: ymm8: // ( ab40 ( ab41 ( ab42 ( ab43 - // ab50 ab51 ab52 ab53 + // ab50 ab51 ab52 ab53 // ab61 ab60 ab63 ab62 // ab71 ) ab70 ) ab73 ) ab72 ) - + vmovapd(ymm15, ymm7) vperm2f128(imm(0x12), ymm15, ymm13, ymm15) vperm2f128(imm(0x30), ymm7, ymm13, ymm13) - + vmovapd(ymm11, ymm7) vperm2f128(imm(0x12), ymm11, ymm9, ymm11) vperm2f128(imm(0x30), ymm7, ymm9, ymm9) - + vmovapd(ymm14, ymm7) vperm2f128(imm(0x12), ymm14, ymm12, ymm14) vperm2f128(imm(0x30), ymm7, ymm12, ymm12) - + vmovapd(ymm10, ymm7) vperm2f128(imm(0x12), ymm10, ymm8, ymm10) vperm2f128(imm(0x30), ymm7, ymm8, ymm8) - - + + // ymm15: ymm13: ymm11: ymm9: // ( ab00 ( ab01 ( ab02 ( ab03 - // ab10 ab11 ab12 ab13 + // ab10 ab11 ab12 ab13 // ab20 ab21 ab22 ab23 // ab30 ) ab31 ) ab32 ) ab33 ) - + // ymm14: ymm12: ymm10: ymm8: // ( ab40 ( ab41 ( ab42 ( ab43 - // ab50 ab51 ab52 ab53 + // ab50 ab51 ab52 ab53 // ab60 ab61 ab62 ab63 // ab70 ) ab71 ) ab72 ) ab73 ) - - + + // scale by alpha - + mov(var(alpha), rax) // load address of alpha vbroadcastsd(mem(rax), ymm7) // load alpha_r and duplicate vbroadcastsd(mem(rax, 8), ymm6) // load alpha_i and duplicate - + Z_ALPHA(15, 3) Z_ALPHA(14, 2) Z_ALPHA(13, 1) @@ -2190,38 +1633,14 @@ void bli_zgemm_bulldozer_asm_4x4_fma4 Z_ALPHA(10, 2) Z_ALPHA(9, 1) Z_ALPHA(8, 0) - + mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rbx), ymm7) // load beta_r and duplicate vbroadcastsd(mem(rbx, 8), ymm6) // load beta_i and duplicate - - - - mov(var(rs_c), rsi) // load rs_c - lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(dcomplex) - lea(mem(, rsi, 2), rsi) - lea(mem(rcx, rsi, 2), rdx) // load address of c + 2*rs_c; - - - - // determine if - // c % 32 == 0, AND - // 16*cs_c % 32 == 0, AND - // rs_c == 1 - // ie: aligned, ldim aligned, and - // column-stored - - cmp(imm(16), rsi) // set ZF if (16*rs_c) == 16. - sete(bl) // bl = ( ZF == 1 ? 1 : 0 ); - test(imm(31), rcx) // set ZF if c & 32 is zero. 
- setz(bh) // bh = ( ZF == 0 ? 1 : 0 ); - test(imm(31), rdi) // set ZF if (16*cs_c) & 32 is zero. - setz(al) // al = ( ZF == 0 ? 1 : 0 ); - // and(bl,bh) followed by - // and(bh,al) will reveal result - + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm7) // set ZF if beta_r == 0. sete(r8b) // r8b = ( ZF == 1 ? 1 : 0 ); @@ -2229,285 +1648,89 @@ void bli_zgemm_bulldozer_asm_4x4_fma4 sete(r9b) // r9b = ( ZF == 1 ? 1 : 0 ); and(r8b, r9b) // set ZF if r8b & r9b == 1. jne(.ZBETAZERO) // if ZF = 0, jump to beta == 0 case - - - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. - jne(.ZCOLSTORED) // jump to column storage case - - - - label(.ZGENSTORED) - // update c00:c30 - - vmovupd(mem(rcx), xmm0) // load (c00,c10) into xmm0 - vmovupd(mem(rcx, rsi, 1), xmm2) // load (c20,c30) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm15, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rcx)) // store (c00,c10) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c20,c30) - add(rdi, rcx) // c += cs_c; - - // update c40:c70 - - vmovupd(mem(rdx), xmm0) // load (c40,c50) into xmm0 - vmovupd(mem(rdx, rsi, 1), xmm2) // load (c60,c70) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm14, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rdx)) // store (c40,c50) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c60,c70) - add(rdi, rdx) // c += cs_c; - - // update c01:c31 - - vmovupd(mem(rcx), xmm0) // load (c01,c11) into xmm0 - vmovupd(mem(rcx, rsi, 1), xmm2) // load (c21,c31) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm13, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rcx)) // store (c01,c11) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c21,c31) - add(rdi, rcx) // c += cs_c; - - // update c41:c71 - - vmovupd(mem(rdx), xmm0) // load (c41,c51) into xmm0 - vmovupd(mem(rdx, rsi, 1), xmm2) // load (c61,c71) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm12, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rdx)) // store (c41,c51) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c61,c71) - add(rdi, rdx) // c += cs_c; - - // update c02:c32 - - vmovupd(mem(rcx), xmm0) // load (c02,c12) into xmm0 - vmovupd(mem(rcx, rsi, 1), xmm2) // load (c22,c32) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm11, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rcx)) // store (c02,c12) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c22,c32) - add(rdi, rcx) // c += cs_c; - - // update c42:c72 - - vmovupd(mem(rdx), xmm0) // load (c42,c52) into xmm0 - vmovupd(mem(rdx, rsi, 1), xmm2) // load (c62,c72) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm10, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, 
xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rdx)) // store (c42,c52) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c62,c72) - add(rdi, rdx) // c += cs_c; - - // update c03:c33 - - vmovupd(mem(rcx), xmm0) // load (c03,c13) into xmm0 - vmovupd(mem(rcx, rsi, 1), xmm2) // load (c23,c33) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm9, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rcx)) // store (c03,c13) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c23,c33) - add(rdi, rcx) // c += cs_c; - - // update c43:c73 - - vmovupd(mem(rdx), xmm0) // load (c43,c53) into xmm0 - vmovupd(mem(rdx, rsi, 1), xmm2) // load (c63,c73) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm8, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rdx)) // store (c43,c53) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c63,c73) - - - - jmp(.ZDONE) // jump to end. - - - - label(.ZCOLSTORED) - // update c00:c30 - - vmovapd(mem(rcx), ymm0) // load c00:c30 into ymm0 - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm15, ymm0, ymm0) // add the gemm result to ymm0 - vmovapd(ymm0, mem(rcx)) // store c00:c30 - add(rdi, rcx) // c += cs_c; - - // update c40:c70 - - vmovapd(mem(rdx), ymm0) // load c40:c70 into ymm0 - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm14, ymm0, ymm0) // add the gemm result to ymm0 - vmovapd(ymm0, mem(rdx)) // store c40:c70 - add(rdi, rdx) // c += cs_c; - - // update c01:c31 - - vmovapd(mem(rcx), ymm0) // load c01:c31 into ymm0 - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm13, ymm0, ymm0) // add the gemm result to ymm0 - vmovapd(ymm0, mem(rcx)) // store c01:c31 - add(rdi, rcx) // c += cs_c; - - // update c41:c71 - - vmovapd(mem(rdx), ymm0) // load c41:c71 into ymm0 - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm12, ymm0, ymm0) // add the gemm result to ymm0 - vmovapd(ymm0, mem(rdx)) // store c41:c71 - add(rdi, rdx) // c += cs_c; - - // update c02:c32 - - vmovapd(mem(rcx), ymm0) // load c02:c32 into ymm0 - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm11, ymm0, ymm0) // add the gemm result to ymm0 - vmovapd(ymm0, mem(rcx)) // store c02:c32 - add(rdi, rcx) // c += cs_c; - - // update c42:c72 - - vmovapd(mem(rdx), ymm0) // load c42:c72 into ymm0 - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm10, ymm0, ymm0) // add the gemm result to ymm0 - vmovapd(ymm0, mem(rdx)) // store c42:c72 - add(rdi, rdx) // c += cs_c; - - // update c03:c33 - - vmovapd(mem(rcx), ymm0) // load c03:c33 into ymm0 - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm9, ymm0, ymm0) // add the gemm result to ymm0 - vmovapd(ymm0, mem(rcx)) // store c03:c33 - add(rdi, rcx) // c += cs_c; - - // update c43:c73 - - vmovapd(mem(rdx), ymm0) // load c43:c73 into ymm0 - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm8, ymm0, ymm0) // add the gemm result to ymm0 - vmovapd(ymm0, mem(rdx)) // store c43:c73 - - - - jmp(.ZDONE) // jump to end. 
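The general- and column-stored branches deleted above become dead weight once entry and exit run through the new GEMM_UKR_SETUP_CT / GEMM_UKR_FLUSH_CT pair: the wrapper can hand the asm a single contiguous, aligned microtile and merge it into the user's C afterward. The macro bodies are not part of this hunk, so the following plain-C sketch (real doubles for brevity; the names ct and flush_ct, and the placement of the beta scaling, are illustrative assumptions) shows only one plausible division of labor, not the actual definitions.

    /* A minimal sketch, assuming the asm stores its alpha*A*B tile into a
       contiguous buffer and the flush applies beta; not the actual macros. */
    #include <stddef.h>

    enum { MR = 4, NR = 4 };

    static void flush_ct
         (
           const double* ct,                  /* contiguous MR x NR tile */
           double*       c,
           size_t        rs_c,
           size_t        cs_c,
           double        beta
         )
    {
        for ( size_t i = 0; i < MR; i++ )
            for ( size_t j = 0; j < NR; j++ )
                c[ i*rs_c + j*cs_c ] = beta * c[ i*rs_c + j*cs_c ]
                                     + ct[ i*NR + j ];
    }
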
- - - + + // update c00:c30 + + vmovapd(mem(rcx), ymm0) // load c00:c30 into ymm0 + Z_ALPHA(0, 2) // scale ymm0 by beta + vaddpd(ymm15, ymm0, ymm0) // add the gemm result to ymm0 + vmovapd(ymm0, mem(rcx)) // store c00:c30 + + // update c40:c70 + + vmovapd(mem(rcx,32), ymm0) // load c40:c70 into ymm0 + Z_ALPHA(0, 2) // scale ymm0 by beta + vaddpd(ymm14, ymm0, ymm0) // add the gemm result to ymm0 + vmovapd(ymm0, mem(rcx,32)) // store c40:c70 + add(rdi, rcx) // c += cs_c; + + // update c01:c31 + + vmovapd(mem(rcx), ymm0) // load c01:c31 into ymm0 + Z_ALPHA(0, 2) // scale ymm0 by beta + vaddpd(ymm13, ymm0, ymm0) // add the gemm result to ymm0 + vmovapd(ymm0, mem(rcx)) // store c01:c31 + + // update c41:c71 + + vmovapd(mem(rcx,32), ymm0) // load c41:c71 into ymm0 + Z_ALPHA(0, 2) // scale ymm0 by beta + vaddpd(ymm12, ymm0, ymm0) // add the gemm result to ymm0 + vmovapd(ymm0, mem(rcx,32)) // store c41:c71 + add(rdi, rcx) // c += cs_c; + + // update c02:c32 + + vmovapd(mem(rcx), ymm0) // load c02:c32 into ymm0 + Z_ALPHA(0, 2) // scale ymm0 by beta + vaddpd(ymm11, ymm0, ymm0) // add the gemm result to ymm0 + vmovapd(ymm0, mem(rcx)) // store c02:c32 + + // update c42:c72 + + vmovapd(mem(rcx,32), ymm0) // load c42:c72 into ymm0 + Z_ALPHA(0, 2) // scale ymm0 by beta + vaddpd(ymm10, ymm0, ymm0) // add the gemm result to ymm0 + vmovapd(ymm0, mem(rcx,32)) // store c42:c72 + add(rdi, rcx) // c += cs_c; + + // update c03:c33 + + vmovapd(mem(rcx), ymm0) // load c03:c33 into ymm0 + Z_ALPHA(0, 2) // scale ymm0 by beta + vaddpd(ymm9, ymm0, ymm0) // add the gemm result to ymm0 + vmovapd(ymm0, mem(rcx)) // store c03:c33 + + // update c43:c73 + + vmovapd(mem(rcx,32), ymm0) // load c43:c73 into ymm0 + Z_ALPHA(0, 2) // scale ymm0 by beta + vaddpd(ymm8, ymm0, ymm0) // add the gemm result to ymm0 + vmovapd(ymm0, mem(rcx,32)) // store c43:c73 + add(rdi, rcx) // c += cs_c; + + jmp(.ZDONE) // jump to end. + label(.ZBETAZERO) - // check if aligned/column-stored - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. 
- jne(.ZCOLSTORBZ) // jump to column storage case - - - - label(.ZGENSTORBZ) - // update c00:c30 - - vextractf128(imm(1), ymm15, xmm2) - vmovupd(xmm15, mem(rcx)) // store (c00,c10) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c20,c30) - add(rdi, rcx) // c += cs_c; - - // update c40:c70 - - vextractf128(imm(1), ymm14, xmm2) - vmovupd(xmm14, mem(rdx)) // store (c40,c50) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c60,c70) - add(rdi, rdx) // c += cs_c; - - // update c01:c31 - - vextractf128(imm(1), ymm13, xmm2) - vmovupd(xmm13, mem(rcx)) // store (c01,c11) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c21,c31) - add(rdi, rcx) // c += cs_c; - - // update c41:c71 - - vextractf128(imm(1), ymm12, xmm2) - vmovupd(xmm12, mem(rdx)) // store (c41,c51) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c61,c71) - add(rdi, rdx) // c += cs_c; - - // update c02:c32 - - vextractf128(imm(1), ymm11, xmm2) - vmovupd(xmm11, mem(rcx)) // store (c02,c12) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c22,c32) - add(rdi, rcx) // c += cs_c; - - // update c42:c72 - - vextractf128(imm(1), ymm10, xmm2) - vmovupd(xmm10, mem(rdx)) // store (c42,c52) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c62,c72) - add(rdi, rdx) // c += cs_c; - - // update c03:c33 - - vextractf128(imm(1), ymm9, xmm2) - vmovupd(xmm9, mem(rcx)) // store (c03,c13) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c23,c33) - add(rdi, rcx) // c += cs_c; - - // update c43:c73 - - vextractf128(imm(1), ymm8, xmm2) - vmovupd(xmm8, mem(rdx)) // store (c43,c53) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c63,c73) - - - jmp(.ZDONE) // jump to end. - - - label(.ZCOLSTORBZ) - - - vmovapd(ymm15, mem(rcx)) // store c00:c30 - add(rdi, rcx) // c += cs_c; - - vmovapd(ymm14, mem(rdx)) // store c40:c70 - add(rdi, rdx) // c += cs_c; - - vmovapd(ymm13, mem(rcx)) // store c01:c31 - add(rdi, rcx) // c += cs_c; - - vmovapd(ymm12, mem(rdx)) // store c41:c71 - add(rdi, rdx) // c += cs_c; - - vmovapd(ymm11, mem(rcx)) // store c02:c32 - add(rdi, rcx) // c += cs_c; - - vmovapd(ymm10, mem(rdx)) // store c42:c72 - add(rdi, rdx) // c += cs_c; - - vmovapd(ymm9, mem(rcx)) // store c03:c33 - add(rdi, rcx) // c += cs_c; - - vmovapd(ymm8, mem(rdx)) // store c43:c73 - - + + vmovapd(ymm15, mem(rcx)) // store c00:c30 + vmovapd(ymm14, mem(rcx,32)) // store c40:c70 + add(rdi, rcx) // c += cs_c; + + vmovapd(ymm13, mem(rcx)) // store c01:c31 + vmovapd(ymm12, mem(rcx,32)) // store c41:c71 + add(rdi, rcx) // c += cs_c; + + vmovapd(ymm11, mem(rcx)) // store c02:c32 + vmovapd(ymm10, mem(rcx,32)) // store c42:c72 + add(rdi, rcx) // c += cs_c; + + vmovapd(ymm9, mem(rcx)) // store c03:c33 + vmovapd(ymm8, mem(rcx,32)) // store c43:c73 + //add(rdi, rcx) // c += cs_c; + label(.ZDONE) - + end_asm( : // output operands (none) @@ -2524,7 +1747,7 @@ void bli_zgemm_bulldozer_asm_4x4_fma4 [b_next] "m" (b_next), // 9 [a_next] "m" (a_next)*/ // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", @@ -2532,5 +1755,7 @@ void bli_zgemm_bulldozer_asm_4x4_fma4 "ymm12", "ymm13", "ymm14", "ymm15", "memory" ) + + GEMM_UKR_FLUSH_CT( z ); } diff --git a/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c b/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c index 7907bd9018..897f6fd778 100644 --- a/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c +++ b/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c @@ -79,7 +79,9 @@ void bli_sgemm_haswell_asm_6x16 ( - dim_t k0, + dim_t m, + 
dim_t n, + dim_t k, float* restrict alpha, float* restrict a, float* restrict b, @@ -94,11 +96,13 @@ void bli_sgemm_haswell_asm_6x16 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT_AMBI( s, 6, 16, true ); + begin_asm() vzeroall() // zero all xmm/ymm registers. @@ -116,18 +120,47 @@ void bli_sgemm_haswell_asm_6x16 mov(var(c), rcx) // load address of c mov(var(rs_c), rdi) // load rs_c lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) - - lea(mem(rdi, rdi, 2), r13) // r13 = 3*rs_c; - lea(mem(rcx, r13, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c - prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c - prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c - prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c - prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c - prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c - - - + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) + + cmp(imm(4), rsi) // set ZF if (4*cs_c) == 4. + jz(.SCOLPREFETCH) // jump to column prefetch case + + lea(mem(rdi, rdi, 2), r13) // r13 = 3*rs_c; + lea(mem(rcx, r13, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c + + jmp(.SPREFETCHDONE) + + label(.SCOLPREFETCH) + + lea(mem(rsi, rsi, 2), r13) // r13 = 3*cs_c; + lea(mem(rcx, r13, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 7*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 7*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 7*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 7*8)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, r13, 1, 7*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 4, 7*8)) // prefetch c + 7*cs_c + lea(mem(rcx, rsi, 8), r14) // r14 = c + 8*cs_c; + lea(mem(r14, r13, 1), rdx) // rdx = c + 11*cs_c; + prefetch(0, mem(r14, 7*8)) // prefetch c + 8*cs_c + prefetch(0, mem(r14, rsi, 1, 7*8)) // prefetch c + 9*cs_c + prefetch(0, mem(r14, rsi, 2, 7*8)) // prefetch c + 10*cs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 11*cs_c + prefetch(0, mem(rdx, rsi, 1, 7*8)) // prefetch c + 12*cs_c + prefetch(0, mem(rdx, rsi, 2, 7*8)) // prefetch c + 13*cs_c + prefetch(0, mem(rdx, r13, 1, 7*8)) // prefetch c + 14*cs_c + prefetch(0, mem(rdx, rsi, 4, 7*8)) // prefetch c + 15*cs_c + + label(.SPREFETCHDONE) mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. @@ -344,527 +377,324 @@ void bli_sgemm_haswell_asm_6x16 vucomiss(xmm0, xmm3) // set ZF if beta == 0. je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case - - cmp(imm(4), rsi) // set ZF if (4*cs_c) == 4. - jz(.SROWSTORED) // jump to row storage case - - - cmp(imm(4), rdi) // set ZF if (4*cs_c) == 4. 
- jz(.SCOLSTORED) // jump to column storage case - - - - label(.SGENSTORED) - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm4, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm6, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm8, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm10, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm12, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm14, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - //add(rdi, rcx) // c += rs_c; - - - mov(rdx, rcx) // rcx = c + 8*cs_c - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm5, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm7, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm9, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm11, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm13, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm15, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - //add(rdi, rcx) // c += rs_c; - - - - jmp(.SDONE) // jump to end. - - - - label(.SROWSTORED) - - - vfmadd231ps(mem(rcx), ymm3, ymm4) - vmovups(ymm4, mem(rcx)) - add(rdi, rcx) - vfmadd231ps(mem(rdx), ymm3, ymm5) - vmovups(ymm5, mem(rdx)) - add(rdi, rdx) - - - vfmadd231ps(mem(rcx), ymm3, ymm6) - vmovups(ymm6, mem(rcx)) - add(rdi, rcx) - vfmadd231ps(mem(rdx), ymm3, ymm7) - vmovups(ymm7, mem(rdx)) - add(rdi, rdx) - - - vfmadd231ps(mem(rcx), ymm3, ymm8) - vmovups(ymm8, mem(rcx)) - add(rdi, rcx) - vfmadd231ps(mem(rdx), ymm3, ymm9) - vmovups(ymm9, mem(rdx)) - add(rdi, rdx) - - - vfmadd231ps(mem(rcx), ymm3, ymm10) - vmovups(ymm10, mem(rcx)) - add(rdi, rcx) - vfmadd231ps(mem(rdx), ymm3, ymm11) - vmovups(ymm11, mem(rdx)) - add(rdi, rdx) - - - vfmadd231ps(mem(rcx), ymm3, ymm12) - vmovups(ymm12, mem(rcx)) - add(rdi, rcx) - vfmadd231ps(mem(rdx), ymm3, ymm13) - vmovups(ymm13, mem(rdx)) - add(rdi, rdx) - - - vfmadd231ps(mem(rcx), ymm3, ymm14) - vmovups(ymm14, mem(rcx)) - //add(rdi, rcx) - vfmadd231ps(mem(rdx), ymm3, ymm15) - vmovups(ymm15, mem(rdx)) - //add(rdi, rdx) - - - - jmp(.SDONE) // jump to end. 
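With that machinery in place, truly general strides never reach the asm, which is why the SGENSTORED branch above and its SGEMM_INPUT_GS_BETA_NZ / SGEMM_OUTPUT_GS_BETA_NZ gather/scatter macros can be deleted outright. What the removed branch computed reduces to the scalar loop below; a reference sketch of the semantics, not the macro expansion.

    /* Scalar equivalent of the removed general-stride ("GS") update:
       c(i,j) := ab(i,j) + beta * c(i,j) for arbitrary rs_c/cs_c. */
    static void sgemm_gs_update
         (
           int          m,  int n,
           const float* ab, int rs_ab, int cs_ab,
           float        beta,
           float*       c,  int rs_c,  int cs_c
         )
    {
        for ( int i = 0; i < m; i++ )
            for ( int j = 0; j < n; j++ )
                c[ i*rs_c + j*cs_c ] = ab[ i*rs_ab + j*cs_ab ]
                                     + beta * c[ i*rs_c + j*cs_c ];
    }
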
- - - - label(.SCOLSTORED) - - - vbroadcastss(mem(rbx), ymm3) - - vunpcklps(ymm6, ymm4, ymm0) - vunpcklps(ymm10, ymm8, ymm1) - vshufps(imm(0x4e), ymm1, ymm0, ymm2) - vblendps(imm(0xcc), ymm2, ymm0, ymm0) - vblendps(imm(0x33), ymm2, ymm1, ymm1) - - vextractf128(imm(0x1), ymm0, xmm2) - vfmadd231ps(mem(rcx), xmm3, xmm0) - vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) - vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) - vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) - - vextractf128(imm(0x1), ymm1, xmm2) - vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) - vfmadd231ps(mem(rcx, r15, 1), xmm3, xmm2) - vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) - vmovups(xmm2, mem(rcx, r15, 1)) // store ( gamma05..gamma35 ) - - - vunpckhps(ymm6, ymm4, ymm0) - vunpckhps(ymm10, ymm8, ymm1) - vshufps(imm(0x4e), ymm1, ymm0, ymm2) - vblendps(imm(0xcc), ymm2, ymm0, ymm0) - vblendps(imm(0x33), ymm2, ymm1, ymm1) - - vextractf128(imm(0x1), ymm0, xmm2) - vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) - vfmadd231ps(mem(rcx, r13, 2), xmm3, xmm2) - vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) - vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma06..gamma36 ) - - vextractf128(imm(0x1), ymm1, xmm2) - vfmadd231ps(mem(rcx, r13, 1), xmm3, xmm1) - vfmadd231ps(mem(rcx, r10, 1), xmm3, xmm2) - vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma03..gamma33 ) - vmovups(xmm2, mem(rcx, r10, 1)) // store ( gamma07..gamma37 ) - - lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c - - vunpcklps(ymm14, ymm12, ymm0) - vextractf128(imm(0x1), ymm0, xmm2) - vmovlpd(mem(r14), xmm1, xmm1) - vmovhpd(mem(r14, rsi, 1), xmm1, xmm1) - vfmadd231ps(xmm1, xmm3, xmm0) - vmovlpd(xmm0, mem(r14)) // store ( gamma40..gamma50 ) - vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma41..gamma51 ) - vmovlpd(mem(r14, rsi, 4), xmm1, xmm1) - vmovhpd(mem(r14, r15, 1), xmm1, xmm1) - vfmadd231ps(xmm1, xmm3, xmm2) - vmovlpd(xmm2, mem(r14, rsi, 4)) // store ( gamma44..gamma54 ) - vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma45..gamma55 ) - - vunpckhps(ymm14, ymm12, ymm0) - vextractf128(imm(0x1), ymm0, xmm2) - vmovlpd(mem(r14, rsi, 2), xmm1, xmm1) - vmovhpd(mem(r14, r13, 1), xmm1, xmm1) - vfmadd231ps(xmm1, xmm3, xmm0) - vmovlpd(xmm0, mem(r14, rsi, 2)) // store ( gamma42..gamma52 ) - vmovhpd(xmm0, mem(r14, r13, 1)) // store ( gamma43..gamma53 ) - vmovlpd(mem(r14, r13, 2), xmm1, xmm1) - vmovhpd(mem(r14, r10, 1), xmm1, xmm1) - vfmadd231ps(xmm1, xmm3, xmm2) - vmovlpd(xmm2, mem(r14, r13, 2)) // store ( gamma46..gamma56 ) - vmovhpd(xmm2, mem(r14, r10, 1)) // store ( gamma47..gamma57 ) - - lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_c - - - - vunpcklps(ymm7, ymm5, ymm0) - vunpcklps(ymm11, ymm9, ymm1) - vshufps(imm(0x4e), ymm1, ymm0, ymm2) - vblendps(imm(0xcc), ymm2, ymm0, ymm0) - vblendps(imm(0x33), ymm2, ymm1, ymm1) - - vextractf128(imm(0x1), ymm0, xmm2) - vfmadd231ps(mem(rcx), xmm3, xmm0) - vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) - vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) - vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) - - vextractf128(imm(0x1), ymm1, xmm2) - vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) - vfmadd231ps(mem(rcx, r15, 1), xmm3, xmm2) - vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) - vmovups(xmm2, mem(rcx, r15, 1)) // store ( gamma05..gamma35 ) - - - vunpckhps(ymm7, ymm5, ymm0) - vunpckhps(ymm11, ymm9, ymm1) - vshufps(imm(0x4e), ymm1, ymm0, ymm2) - vblendps(imm(0xcc), ymm2, ymm0, ymm0) - vblendps(imm(0x33), ymm2, ymm1, ymm1) - - vextractf128(imm(0x1), ymm0, xmm2) - vfmadd231ps(mem(rcx, rsi, 2), xmm3, 
xmm0) - vfmadd231ps(mem(rcx, r13, 2), xmm3, xmm2) - vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) - vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma06..gamma36 ) - - vextractf128(imm(0x1), ymm1, xmm2) - vfmadd231ps(mem(rcx, r13, 1), xmm3, xmm1) - vfmadd231ps(mem(rcx, r10, 1), xmm3, xmm2) - vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma03..gamma33 ) - vmovups(xmm2, mem(rcx, r10, 1)) // store ( gamma07..gamma37 ) - - //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c - - vunpcklps(ymm15, ymm13, ymm0) - vextractf128(imm(0x1), ymm0, xmm2) - vmovlpd(mem(r14), xmm1, xmm1) - vmovhpd(mem(r14, rsi, 1), xmm1, xmm1) - vfmadd231ps(xmm1, xmm3, xmm0) - vmovlpd(xmm0, mem(r14)) // store ( gamma40..gamma50 ) - vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma41..gamma51 ) - vmovlpd(mem(r14, rsi, 4), xmm1, xmm1) - vmovhpd(mem(r14, r15, 1), xmm1, xmm1) - vfmadd231ps(xmm1, xmm3, xmm2) - vmovlpd(xmm2, mem(r14, rsi, 4)) // store ( gamma44..gamma54 ) - vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma45..gamma55 ) - - vunpckhps(ymm15, ymm13, ymm0) - vextractf128(imm(0x1), ymm0, xmm2) - vmovlpd(mem(r14, rsi, 2), xmm1, xmm1) - vmovhpd(mem(r14, r13, 1), xmm1, xmm1) - vfmadd231ps(xmm1, xmm3, xmm0) - vmovlpd(xmm0, mem(r14, rsi, 2)) // store ( gamma42..gamma52 ) - vmovhpd(xmm0, mem(r14, r13, 1)) // store ( gamma43..gamma53 ) - vmovlpd(mem(r14, r13, 2), xmm1, xmm1) - vmovhpd(mem(r14, r10, 1), xmm1, xmm1) - vfmadd231ps(xmm1, xmm3, xmm2) - vmovlpd(xmm2, mem(r14, r13, 2)) // store ( gamma46..gamma56 ) - vmovhpd(xmm2, mem(r14, r10, 1)) // store ( gamma47..gamma57 ) - - //lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_c - - - - jmp(.SDONE) // jump to end. - - + cmp(imm(4), rdi) // set ZF if (4*cs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + vfmadd231ps(mem(rcx), ymm3, ymm4) + vmovups(ymm4, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm5) + vmovups(ymm5, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm6) + vmovups(ymm6, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm7) + vmovups(ymm7, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm8) + vmovups(ymm8, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm9) + vmovups(ymm9, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm10) + vmovups(ymm10, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm11) + vmovups(ymm11, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm12) + vmovups(ymm12, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm13) + vmovups(ymm13, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm14) + vmovups(ymm14, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm15) + vmovups(ymm15, mem(rcx,32)) + //add(rdi, rcx) + + jmp(.SDONE) // jump to end. 
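The new row-stored path above also drops the second write pointer (rdx) in favor of one pointer plus a fixed 32-byte displacement for the upper half of each 16-float row. Per row, the vfmadd231ps/vmovups pairs amount to the intrinsics below (AVX2+FMA assumed, e.g. compile with -mavx2 -mfma; the helper name is illustrative).

    #include <immintrin.h>

    /* One 16-wide row of the 6x16 sgemm microtile:
       acc := beta*c + acc, then store; the same shape as the
       vfmadd231ps / vmovups pairs above. */
    static inline void srow_update
         (
           float* c_row,
           __m256 acc_lo, __m256 acc_hi,
           float  beta
         )
    {
        __m256 b = _mm256_set1_ps( beta );
        acc_lo = _mm256_fmadd_ps( _mm256_loadu_ps( c_row     ), b, acc_lo );
        acc_hi = _mm256_fmadd_ps( _mm256_loadu_ps( c_row + 8 ), b, acc_hi );
        _mm256_storeu_ps( c_row,     acc_lo );
        _mm256_storeu_ps( c_row + 8, acc_hi );
    }
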
+ + label(.SCOLSTORED) + + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, r15, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, r15, 1)) // store ( gamma05..gamma35 ) + + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vfmadd231ps(mem(rcx, r13, 2), xmm3, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, r13, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, r10, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, r10, 1)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + vunpcklps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(r14), xmm1, xmm1) + vmovhpd(mem(r14, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(r14)) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma41..gamma51 ) + vmovlpd(mem(r14, rsi, 4), xmm1, xmm1) + vmovhpd(mem(r14, r15, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(r14, rsi, 4)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma45..gamma55 ) + + vunpckhps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(r14, rsi, 2), xmm1, xmm1) + vmovhpd(mem(r14, r13, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(r14, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(r14, r13, 1)) // store ( gamma43..gamma53 ) + vmovlpd(mem(r14, r13, 2), xmm1, xmm1) + vmovhpd(mem(r14, r10, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(r14, r13, 2)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(r14, r10, 1)) // store ( gamma47..gamma57 ) + + lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_c + + + + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, r15, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, r15, 1)) // store ( gamma05..gamma35 ) + + + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vfmadd231ps(mem(rcx, r13, 2), 
xmm3, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, r13, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, r10, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, r10, 1)) // store ( gamma07..gamma37 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + vunpcklps(ymm15, ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(r14), xmm1, xmm1) + vmovhpd(mem(r14, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(r14)) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma41..gamma51 ) + vmovlpd(mem(r14, rsi, 4), xmm1, xmm1) + vmovhpd(mem(r14, r15, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(r14, rsi, 4)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma45..gamma55 ) + + vunpckhps(ymm15, ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(r14, rsi, 2), xmm1, xmm1) + vmovhpd(mem(r14, r13, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(r14, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(r14, r13, 1)) // store ( gamma43..gamma53 ) + vmovlpd(mem(r14, r13, 2), xmm1, xmm1) + vmovhpd(mem(r14, r10, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(r14, r13, 2)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(r14, r10, 1)) // store ( gamma47..gamma57 ) + + //lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_c + + jmp(.SDONE) // jump to end. label(.SBETAZERO) - cmp(imm(4), rsi) // set ZF if (4*cs_c) == 4. - jz(.SROWSTORBZ) // jump to row storage case - - cmp(imm(4), rdi) // set ZF if (4*cs_c) == 4. - jz(.SCOLSTORBZ) // jump to column storage case - - - - label(.SGENSTORBZ) - - - vmovaps(ymm4, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovaps(ymm6, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovaps(ymm8, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovaps(ymm10, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovaps(ymm12, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovaps(ymm14, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - //add(rdi, rcx) // c += rs_c; - - - mov(rdx, rcx) // rcx = c + 8*cs_c - - - vmovaps(ymm5, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovaps(ymm7, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - + cmp(imm(4), rdi) // set ZF if (4*cs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case - vmovaps(ymm9, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; + vmovups(ymm4, mem(rcx)) + vmovups(ymm5, mem(rcx,32)) + add(rdi, rcx) + vmovups(ymm6, mem(rcx)) + vmovups(ymm7, mem(rcx,32)) + add(rdi, rcx) - vmovaps(ymm11, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; + vmovups(ymm8, mem(rcx)) + vmovups(ymm9, mem(rcx,32)) + add(rdi, rcx) - vmovaps(ymm13, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; + vmovups(ymm10, mem(rcx)) + vmovups(ymm11, mem(rcx,32)) + add(rdi, rcx) - vmovaps(ymm15, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - //add(rdi, rcx) // c += rs_c; + vmovups(ymm12, mem(rcx)) + vmovups(ymm13, mem(rcx,32)) + add(rdi, rcx) - jmp(.SDONE) // jump to end. + vmovups(ymm14, mem(rcx)) + vmovups(ymm15, mem(rcx,32)) + //add(rdi, rcx) + jmp(.SDONE) // jump to end. 
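Both storage cases branch to .SBETAZERO before ever touching memory at C. That is required, not merely faster: with beta == 0 the BLAS contract says C's prior contents are ignored, and since 0.0f * NAN is still NAN, a load-scale-accumulate would let uninitialized or NaN-filled C corrupt the result. The scalar semantics being implemented:

    /* beta == 0 must overwrite, never accumulate: beta*c_ij would
       yield NaN whenever c_ij is NaN, even though beta is zero. */
    static float beta_update( float ab_ij, float beta, float c_ij )
    {
        return ( beta == 0.0f ) ? ab_ij : ab_ij + beta * c_ij;
    }
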
+ label(.SCOLSTORBZ) - label(.SROWSTORBZ) + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) - vmovups(ymm4, mem(rcx)) - add(rdi, rcx) - vmovups(ymm5, mem(rdx)) - add(rdi, rdx) + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, r15, 1)) // store ( gamma05..gamma35 ) - vmovups(ymm6, mem(rcx)) - add(rdi, rcx) - vmovups(ymm7, mem(rdx)) - add(rdi, rdx) + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) - vmovups(ymm8, mem(rcx)) - add(rdi, rcx) - vmovups(ymm9, mem(rdx)) - add(rdi, rdx) + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma06..gamma36 ) + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, r10, 1)) // store ( gamma07..gamma37 ) - vmovups(ymm10, mem(rcx)) - add(rdi, rcx) - vmovups(ymm11, mem(rdx)) - add(rdi, rdx) + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + vunpcklps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(r14)) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma41..gamma51 ) + vmovlpd(xmm2, mem(r14, rsi, 4)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma45..gamma55 ) - vmovups(ymm12, mem(rcx)) - add(rdi, rcx) - vmovups(ymm13, mem(rdx)) - add(rdi, rdx) + vunpckhps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(r14, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(r14, r13, 1)) // store ( gamma43..gamma53 ) + vmovlpd(xmm2, mem(r14, r13, 2)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(r14, r10, 1)) // store ( gamma47..gamma57 ) + lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_c - vmovups(ymm14, mem(rcx)) - //add(rdi, rcx) - vmovups(ymm15, mem(rdx)) - //add(rdi, rdx) + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) - jmp(.SDONE) // jump to end. 
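The column-stored paths transpose four 8-float rows of the microtile in registers before storing, via the unpack / shuffle / blend idiom seen throughout this hunk. The AVX sketch below mirrors the vunpcklps / vshufps(0x4e) / vblendps(0xcc, 0x33) triplet for the low halves, producing columns 0, 1, 4, and 5; the vunpckhps variant yields columns 2, 3, 6, 7 the same way. It assumes rs_c == 1 (a stored column of four rows is contiguous), and the helper name is illustrative.

    #include <immintrin.h>

    /* Transpose the low unpack of rows r0..r3 into 4-float columns and
       store them; cs is the element stride between columns of C. */
    static inline void scols_0145
         (
           __m256 r0, __m256 r1, __m256 r2, __m256 r3,
           float* c, long cs
         )
    {
        __m256 t0 = _mm256_unpacklo_ps( r0, r1 );      /* a0 b0 a1 b1 | a4 b4 a5 b5 */
        __m256 t1 = _mm256_unpacklo_ps( r2, r3 );      /* c0 d0 c1 d1 | c4 d4 c5 d5 */
        __m256 t2 = _mm256_shuffle_ps( t0, t1, 0x4e ); /* a1 b1 c0 d0 | a5 b5 c4 d4 */
        __m256 u0 = _mm256_blend_ps( t0, t2, 0xcc );   /* col 0 | col 4 */
        __m256 u1 = _mm256_blend_ps( t1, t2, 0x33 );   /* col 1 | col 5 */

        _mm_storeu_ps( c + 0*cs, _mm256_castps256_ps128( u0 )    );
        _mm_storeu_ps( c + 4*cs, _mm256_extractf128_ps( u0, 1 )  );
        _mm_storeu_ps( c + 1*cs, _mm256_castps256_ps128( u1 )    );
        _mm_storeu_ps( c + 5*cs, _mm256_extractf128_ps( u1, 1 )  );
    }
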
+ vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, r15, 1)) // store ( gamma05..gamma35 ) - label(.SCOLSTORBZ) + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma06..gamma36 ) - vunpcklps(ymm6, ymm4, ymm0) - vunpcklps(ymm10, ymm8, ymm1) - vshufps(imm(0x4e), ymm1, ymm0, ymm2) - vblendps(imm(0xcc), ymm2, ymm0, ymm0) - vblendps(imm(0x33), ymm2, ymm1, ymm1) + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, r10, 1)) // store ( gamma07..gamma37 ) - vextractf128(imm(0x1), ymm0, xmm2) - vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) - vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) - - vextractf128(imm(0x1), ymm1, xmm2) - vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) - vmovups(xmm2, mem(rcx, r15, 1)) // store ( gamma05..gamma35 ) - - - vunpckhps(ymm6, ymm4, ymm0) - vunpckhps(ymm10, ymm8, ymm1) - vshufps(imm(0x4e), ymm1, ymm0, ymm2) - vblendps(imm(0xcc), ymm2, ymm0, ymm0) - vblendps(imm(0x33), ymm2, ymm1, ymm1) - - vextractf128(imm(0x1), ymm0, xmm2) - vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) - vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma06..gamma36 ) - - vextractf128(imm(0x1), ymm1, xmm2) - vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma03..gamma33 ) - vmovups(xmm2, mem(rcx, r10, 1)) // store ( gamma07..gamma37 ) - - lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c - - vunpcklps(ymm14, ymm12, ymm0) - vextractf128(imm(0x1), ymm0, xmm2) - vmovlpd(xmm0, mem(r14)) // store ( gamma40..gamma50 ) - vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma41..gamma51 ) - vmovlpd(xmm2, mem(r14, rsi, 4)) // store ( gamma44..gamma54 ) - vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma45..gamma55 ) - - vunpckhps(ymm14, ymm12, ymm0) - vextractf128(imm(0x1), ymm0, xmm2) - vmovlpd(xmm0, mem(r14, rsi, 2)) // store ( gamma42..gamma52 ) - vmovhpd(xmm0, mem(r14, r13, 1)) // store ( gamma43..gamma53 ) - vmovlpd(xmm2, mem(r14, r13, 2)) // store ( gamma46..gamma56 ) - vmovhpd(xmm2, mem(r14, r10, 1)) // store ( gamma47..gamma57 ) - - lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_c - - - - vunpcklps(ymm7, ymm5, ymm0) - vunpcklps(ymm11, ymm9, ymm1) - vshufps(imm(0x4e), ymm1, ymm0, ymm2) - vblendps(imm(0xcc), ymm2, ymm0, ymm0) - vblendps(imm(0x33), ymm2, ymm1, ymm1) - - vextractf128(imm(0x1), ymm0, xmm2) - vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) - vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) - - vextractf128(imm(0x1), ymm1, xmm2) - vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) - vmovups(xmm2, mem(rcx, r15, 1)) // store ( gamma05..gamma35 ) - - - vunpckhps(ymm7, ymm5, ymm0) - vunpckhps(ymm11, ymm9, ymm1) - vshufps(imm(0x4e), ymm1, ymm0, ymm2) - vblendps(imm(0xcc), ymm2, ymm0, ymm0) - vblendps(imm(0x33), ymm2, ymm1, ymm1) - - vextractf128(imm(0x1), ymm0, xmm2) - vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) - vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma06..gamma36 ) - - vextractf128(imm(0x1), ymm1, xmm2) - vmovups(xmm1, mem(rcx, r13, 1)) 
// store ( gamma03..gamma33 ) - vmovups(xmm2, mem(rcx, r10, 1)) // store ( gamma07..gamma37 ) - - //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c - - vunpcklps(ymm15, ymm13, ymm0) - vextractf128(imm(0x1), ymm0, xmm2) - vmovlpd(xmm0, mem(r14)) // store ( gamma40..gamma50 ) - vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma41..gamma51 ) - vmovlpd(xmm2, mem(r14, rsi, 4)) // store ( gamma44..gamma54 ) - vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma45..gamma55 ) - - vunpckhps(ymm15, ymm13, ymm0) - vextractf128(imm(0x1), ymm0, xmm2) - vmovlpd(xmm0, mem(r14, rsi, 2)) // store ( gamma42..gamma52 ) - vmovhpd(xmm0, mem(r14, r13, 1)) // store ( gamma43..gamma53 ) - vmovlpd(xmm2, mem(r14, r13, 2)) // store ( gamma46..gamma56 ) - vmovhpd(xmm2, mem(r14, r10, 1)) // store ( gamma47..gamma57 ) - - //lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_c + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + vunpcklps(ymm15, ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(r14)) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma41..gamma51 ) + vmovlpd(xmm2, mem(r14, rsi, 4)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma45..gamma55 ) + vunpckhps(ymm15, ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(r14, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(r14, r13, 1)) // store ( gamma43..gamma53 ) + vmovlpd(xmm2, mem(r14, r13, 2)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(r14, r10, 1)) // store ( gamma47..gamma57 ) + //lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_ label(.SDONE) @@ -896,6 +726,8 @@ void bli_sgemm_haswell_asm_6x16 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( s ); } @@ -927,7 +759,9 @@ void bli_sgemm_haswell_asm_6x16 void bli_dgemm_haswell_asm_6x8 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, double* restrict alpha, double* restrict a, double* restrict b, @@ -942,11 +776,13 @@ void bli_dgemm_haswell_asm_6x8 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT_AMBI( d, 6, 8, true ); + begin_asm() vzeroall() // zero all xmm/ymm registers. @@ -964,17 +800,37 @@ void bli_dgemm_haswell_asm_6x8 mov(var(c), rcx) // load address of c mov(var(rs_c), rdi) // load rs_c lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) - lea(mem(rdi, rdi, 2), r13) // r13 = 3*rs_c; - lea(mem(rcx, r13, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c - prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c - prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c - prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c - prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c - prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c + cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8. 
+ jz(.SCOLPREFETCH) // jump to column prefetch case + + lea(mem(rdi, rdi, 2), r13) // r13 = 3*rs_c; + lea(mem(rcx, r13, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c + + jmp(.SPREFETCHDONE) + + label(.SCOLPREFETCH) + lea(mem(rsi, rsi, 2), r13) // r13 = 3*cs_c; + lea(mem(rcx, r13, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 7*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 7*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 7*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 7*8)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, r13, 1, 7*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 4, 7*8)) // prefetch c + 7*cs_c + label(.SPREFETCHDONE) mov(var(k_iter), rsi) // i = k_iter; @@ -1194,422 +1050,226 @@ void bli_dgemm_haswell_asm_6x8 vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx,32)) + add(rdi, rcx) - cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8. - jz(.DROWSTORED) // jump to row storage case - - - cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. - jz(.DCOLSTORED) // jump to column storage case - - - - label(.DGENSTORED) - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm4, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm6, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm8, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm10, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm12, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm14, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - - - mov(rdx, rcx) // rcx = c + 4*cs_c - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm5, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm7, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm9, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm11, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm13, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm15, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - - - - jmp(.DDONE) // jump to end. 
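The prefetch hunk above now probes the C strides at runtime and issues a second, cs_c-strided prefetch sequence instead of always walking rs_c. The intent, rendered in C below, is to touch one cache line per leading-dimension slice of the microtile whichever way C is laid out; the exact byte offsets (the 7*8 displacements) and the branch polarity follow the asm, which this illustrative sketch does not reproduce literally.

    #include <xmmintrin.h>

    /* Touch one cache line per row or per column of the microtile,
       depending on which dimension of C is contiguous. Illustrative. */
    static void prefetch_c
         (
           const double* c,
           long rs_c, long cs_c,
           int  mr,   int nr
         )
    {
        if ( cs_c == 1 )                      /* rows are contiguous    */
            for ( int i = 0; i < mr; i++ )
                _mm_prefetch( (const char*)( c + i*rs_c ), _MM_HINT_T0 );
        else                                  /* columns (or general)   */
            for ( int j = 0; j < nr; j++ )
                _mm_prefetch( (const char*)( c + j*cs_c ), _MM_HINT_T0 );
    }
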
- - - - label(.DROWSTORED) - - - vfmadd231pd(mem(rcx), ymm3, ymm4) - vmovupd(ymm4, mem(rcx)) - add(rdi, rcx) - vfmadd231pd(mem(rdx), ymm3, ymm5) - vmovupd(ymm5, mem(rdx)) - add(rdi, rdx) - - - vfmadd231pd(mem(rcx), ymm3, ymm6) - vmovupd(ymm6, mem(rcx)) - add(rdi, rcx) - vfmadd231pd(mem(rdx), ymm3, ymm7) - vmovupd(ymm7, mem(rdx)) - add(rdi, rdx) - - - vfmadd231pd(mem(rcx), ymm3, ymm8) - vmovupd(ymm8, mem(rcx)) - add(rdi, rcx) - vfmadd231pd(mem(rdx), ymm3, ymm9) - vmovupd(ymm9, mem(rdx)) - add(rdi, rdx) - - - vfmadd231pd(mem(rcx), ymm3, ymm10) - vmovupd(ymm10, mem(rcx)) - add(rdi, rcx) - vfmadd231pd(mem(rdx), ymm3, ymm11) - vmovupd(ymm11, mem(rdx)) - add(rdi, rdx) - - - vfmadd231pd(mem(rcx), ymm3, ymm12) - vmovupd(ymm12, mem(rcx)) - add(rdi, rcx) - vfmadd231pd(mem(rdx), ymm3, ymm13) - vmovupd(ymm13, mem(rdx)) - add(rdi, rdx) - - - vfmadd231pd(mem(rcx), ymm3, ymm14) - vmovupd(ymm14, mem(rcx)) - //add(rdi, rcx) - vfmadd231pd(mem(rdx), ymm3, ymm15) - vmovupd(ymm15, mem(rdx)) - //add(rdi, rdx) - - - - jmp(.DDONE) // jump to end. - - - - label(.DCOLSTORED) - - - vunpcklpd(ymm6, ymm4, ymm0) - vunpckhpd(ymm6, ymm4, ymm1) - vunpcklpd(ymm10, ymm8, ymm2) - vunpckhpd(ymm10, ymm8, ymm3) - vinsertf128(imm(0x1), xmm2, ymm0, ymm4) - vinsertf128(imm(0x1), xmm3, ymm1, ymm6) - vperm2f128(imm(0x31), ymm2, ymm0, ymm8) - vperm2f128(imm(0x31), ymm3, ymm1, ymm10) - - vbroadcastsd(mem(rbx), ymm3) - - vfmadd231pd(mem(rcx), ymm3, ymm4) - vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) - vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) - vfmadd231pd(mem(rcx, r13, 1), ymm3, ymm10) - vmovupd(ymm4, mem(rcx)) - vmovupd(ymm6, mem(rcx, rsi, 1)) - vmovupd(ymm8, mem(rcx, rsi, 2)) - vmovupd(ymm10, mem(rcx, r13, 1)) - - lea(mem(rcx, rsi, 4), rcx) - - vunpcklpd(ymm14, ymm12, ymm0) - vunpckhpd(ymm14, ymm12, ymm1) - vextractf128(imm(0x1), ymm0, xmm2) - vextractf128(imm(0x1), ymm1, xmm4) - - vfmadd231pd(mem(r14), xmm3, xmm0) - vfmadd231pd(mem(r14, rsi, 1), xmm3, xmm1) - vfmadd231pd(mem(r14, rsi, 2), xmm3, xmm2) - vfmadd231pd(mem(r14, r13, 1), xmm3, xmm4) - vmovupd(xmm0, mem(r14)) - vmovupd(xmm1, mem(r14, rsi, 1)) - vmovupd(xmm2, mem(r14, rsi, 2)) - vmovupd(xmm4, mem(r14, r13, 1)) - - lea(mem(r14, rsi, 4), r14) - - - vunpcklpd(ymm7, ymm5, ymm0) - vunpckhpd(ymm7, ymm5, ymm1) - vunpcklpd(ymm11, ymm9, ymm2) - vunpckhpd(ymm11, ymm9, ymm3) - vinsertf128(imm(0x1), xmm2, ymm0, ymm5) - vinsertf128(imm(0x1), xmm3, ymm1, ymm7) - vperm2f128(imm(0x31), ymm2, ymm0, ymm9) - vperm2f128(imm(0x31), ymm3, ymm1, ymm11) - - vbroadcastsd(mem(rbx), ymm3) - - vfmadd231pd(mem(rcx), ymm3, ymm5) - vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) - vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) - vfmadd231pd(mem(rcx, r13, 1), ymm3, ymm11) - vmovupd(ymm5, mem(rcx)) - vmovupd(ymm7, mem(rcx, rsi, 1)) - vmovupd(ymm9, mem(rcx, rsi, 2)) - vmovupd(ymm11, mem(rcx, r13, 1)) - - //lea(mem(rcx, rsi, 4), rcx) - - vunpcklpd(ymm15, ymm13, ymm0) - vunpckhpd(ymm15, ymm13, ymm1) - vextractf128(imm(0x1), ymm0, xmm2) - vextractf128(imm(0x1), ymm1, xmm4) - - vfmadd231pd(mem(r14), xmm3, xmm0) - vfmadd231pd(mem(r14, rsi, 1), xmm3, xmm1) - vfmadd231pd(mem(r14, rsi, 2), xmm3, xmm2) - vfmadd231pd(mem(r14, r13, 1), xmm3, xmm4) - vmovupd(xmm0, mem(r14)) - vmovupd(xmm1, mem(r14, rsi, 1)) - vmovupd(xmm2, mem(r14, rsi, 2)) - vmovupd(xmm4, mem(r14, r13, 1)) - - //lea(mem(r14, rsi, 4), r14) - - - - jmp(.DDONE) // jump to end. 
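Both the deleted and the replacement DCOLSTORED paths rely on the same 4x4 in-register transpose of the double accumulators: two unpack steps, then a 128-bit insert/permute pair. Its intrinsics equivalent (AVX; the helper name is illustrative):

    #include <immintrin.h>

    /* 4x4 double transpose in ymm registers, mirroring the
       vunpck{l,h}pd / vinsertf128 / vperm2f128 sequence. */
    static inline void dtranspose_4x4
         (
           __m256d r0, __m256d r1, __m256d r2, __m256d r3,
           __m256d col[4]
         )
    {
        __m256d t0 = _mm256_unpacklo_pd( r0, r1 );   /* a0 b0 a2 b2 */
        __m256d t1 = _mm256_unpackhi_pd( r0, r1 );   /* a1 b1 a3 b3 */
        __m256d t2 = _mm256_unpacklo_pd( r2, r3 );   /* c0 d0 c2 d2 */
        __m256d t3 = _mm256_unpackhi_pd( r2, r3 );   /* c1 d1 c3 d3 */

        col[0] = _mm256_insertf128_pd( t0, _mm256_castpd256_pd128( t2 ), 1 );
        col[1] = _mm256_insertf128_pd( t1, _mm256_castpd256_pd128( t3 ), 1 );
        col[2] = _mm256_permute2f128_pd( t0, t2, 0x31 );
        col[3] = _mm256_permute2f128_pd( t1, t3, 0x31 );
    }
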
+ vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm8) + vmovupd(ymm8, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm9) + vmovupd(ymm9, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm10) + vmovupd(ymm10, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm11) + vmovupd(ymm11, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm12) + vmovupd(ymm12, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm13) + vmovupd(ymm13, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm14) + vmovupd(ymm14, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm15) + vmovupd(ymm15, mem(rcx,32)) + //add(rdi, rcx) + + jmp(.DDONE) // jump to end. + + label(.DCOLSTORED) + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, r13, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, r13, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(r14), xmm3, xmm0) + vfmadd231pd(mem(r14, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(r14, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(r14, r13, 1), xmm3, xmm4) + vmovupd(xmm0, mem(r14)) + vmovupd(xmm1, mem(r14, rsi, 1)) + vmovupd(xmm2, mem(r14, rsi, 2)) + vmovupd(xmm4, mem(r14, r13, 1)) + + lea(mem(r14, rsi, 4), r14) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + vfmadd231pd(mem(rcx, r13, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, r13, 1)) + + //lea(mem(rcx, rsi, 4), rcx) + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(r14), xmm3, xmm0) + vfmadd231pd(mem(r14, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(r14, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(r14, r13, 1), xmm3, xmm4) + vmovupd(xmm0, mem(r14)) + vmovupd(xmm1, mem(r14, rsi, 1)) + vmovupd(xmm2, mem(r14, rsi, 2)) + vmovupd(xmm4, mem(r14, r13, 1)) + + //lea(mem(r14, rsi, 4), r14) + + jmp(.DDONE) // jump to end. label(.DBETAZERO) - cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8. - jz(.DROWSTORBZ) // jump to row storage case - - cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
- jz(.DCOLSTORBZ) // jump to column storage case - - - - label(.DGENSTORBZ) - - - vmovapd(ymm4, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovapd(ymm6, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovapd(ymm8, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovapd(ymm10, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovapd(ymm12, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - - - vmovapd(ymm14, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - - - mov(rdx, rcx) // rcx = c + 4*cs_c - - - vmovapd(ymm5, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; - + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case - vmovapd(ymm7, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm5, mem(rcx,32)) + add(rdi, rcx) + vmovupd(ymm6, mem(rcx)) + vmovupd(ymm7, mem(rcx,32)) + add(rdi, rcx) - vmovapd(ymm9, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; + vmovupd(ymm8, mem(rcx)) + vmovupd(ymm9, mem(rcx,32)) + add(rdi, rcx) - vmovapd(ymm11, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; + vmovupd(ymm10, mem(rcx)) + vmovupd(ymm11, mem(rcx,32)) + add(rdi, rcx) - vmovapd(ymm13, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += rs_c; + vmovupd(ymm12, mem(rcx)) + vmovupd(ymm13, mem(rcx,32)) + add(rdi, rcx) - vmovapd(ymm15, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ + vmovupd(ymm14, mem(rcx)) + vmovupd(ymm15, mem(rcx,32)) + //add(rdi, rcx) + jmp(.DDONE) // jump to end. - jmp(.DDONE) // jump to end. + label(.DCOLSTORBZ) + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, r13, 1)) - label(.DROWSTORBZ) + lea(mem(rcx, rsi, 4), rcx) + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) - vmovupd(ymm4, mem(rcx)) - add(rdi, rcx) - vmovupd(ymm5, mem(rdx)) - add(rdi, rdx) + vmovupd(xmm0, mem(r14)) + vmovupd(xmm1, mem(r14, rsi, 1)) + vmovupd(xmm2, mem(r14, rsi, 2)) + vmovupd(xmm4, mem(r14, r13, 1)) - vmovupd(ymm6, mem(rcx)) - add(rdi, rcx) - vmovupd(ymm7, mem(rdx)) - add(rdi, rdx) + lea(mem(r14, rsi, 4), r14) - vmovupd(ymm8, mem(rcx)) - add(rdi, rcx) - vmovupd(ymm9, mem(rdx)) - add(rdi, rdx) + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, r13, 1)) - vmovupd(ymm10, mem(rcx)) - add(rdi, rcx) - vmovupd(ymm11, mem(rdx)) - add(rdi, rdx) - - - vmovupd(ymm12, mem(rcx)) - add(rdi, rcx) - vmovupd(ymm13, mem(rdx)) - add(rdi, rdx) - - - vmovupd(ymm14, mem(rcx)) - //add(rdi, rcx) - vmovupd(ymm15, mem(rdx)) - //add(rdi, rdx) - - - jmp(.DDONE) // jump to end. 
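A note on the two setup flavors in this file: the real kernels invoke GEMM_UKR_SETUP_CT_AMBI while the complex ones below use plain GEMM_UKR_SETUP_CT. The macro bodies live elsewhere in the patch series, so the following is only a guess at the dispatch shape: presumably the ambidextrous variant lets the asm write row- or column-contiguous C in place and falls back to a temporary tile only for general strides, while the plain variant accepts a single preferred layout. Every name here is an assumption.

    /* Assumed dispatch shape behind the two setup macros; purely
       illustrative, not the actual BLIS definitions. */
    typedef long inc_t;

    static int c_needs_temp_ambi( inc_t rs_c, inc_t cs_c )
    {
        /* asm handles either contiguous orientation in place */
        return !( rs_c == 1 || cs_c == 1 );
    }

    static int c_needs_temp( inc_t rs_c, inc_t cs_c, int prefer_rows )
    {
        /* asm handles only its one preferred orientation in place */
        return prefer_rows ? ( cs_c != 1 ) : ( rs_c != 1 );
    }
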
- - - - label(.DCOLSTORBZ) - - - vunpcklpd(ymm6, ymm4, ymm0) - vunpckhpd(ymm6, ymm4, ymm1) - vunpcklpd(ymm10, ymm8, ymm2) - vunpckhpd(ymm10, ymm8, ymm3) - vinsertf128(imm(0x1), xmm2, ymm0, ymm4) - vinsertf128(imm(0x1), xmm3, ymm1, ymm6) - vperm2f128(imm(0x31), ymm2, ymm0, ymm8) - vperm2f128(imm(0x31), ymm3, ymm1, ymm10) - - vmovupd(ymm4, mem(rcx)) - vmovupd(ymm6, mem(rcx, rsi, 1)) - vmovupd(ymm8, mem(rcx, rsi, 2)) - vmovupd(ymm10, mem(rcx, r13, 1)) - - lea(mem(rcx, rsi, 4), rcx) - - vunpcklpd(ymm14, ymm12, ymm0) - vunpckhpd(ymm14, ymm12, ymm1) - vextractf128(imm(0x1), ymm0, xmm2) - vextractf128(imm(0x1), ymm1, xmm4) - - vmovupd(xmm0, mem(r14)) - vmovupd(xmm1, mem(r14, rsi, 1)) - vmovupd(xmm2, mem(r14, rsi, 2)) - vmovupd(xmm4, mem(r14, r13, 1)) - - lea(mem(r14, rsi, 4), r14) - - - vunpcklpd(ymm7, ymm5, ymm0) - vunpckhpd(ymm7, ymm5, ymm1) - vunpcklpd(ymm11, ymm9, ymm2) - vunpckhpd(ymm11, ymm9, ymm3) - vinsertf128(imm(0x1), xmm2, ymm0, ymm5) - vinsertf128(imm(0x1), xmm3, ymm1, ymm7) - vperm2f128(imm(0x31), ymm2, ymm0, ymm9) - vperm2f128(imm(0x31), ymm3, ymm1, ymm11) - - vmovupd(ymm5, mem(rcx)) - vmovupd(ymm7, mem(rcx, rsi, 1)) - vmovupd(ymm9, mem(rcx, rsi, 2)) - vmovupd(ymm11, mem(rcx, r13, 1)) - - //lea(mem(rcx, rsi, 4), rcx) - - vunpcklpd(ymm15, ymm13, ymm0) - vunpckhpd(ymm15, ymm13, ymm1) - vextractf128(imm(0x1), ymm0, xmm2) - vextractf128(imm(0x1), ymm1, xmm4) - - vmovupd(xmm0, mem(r14)) - vmovupd(xmm1, mem(r14, rsi, 1)) - vmovupd(xmm2, mem(r14, rsi, 2)) - vmovupd(xmm4, mem(r14, r13, 1)) - - //lea(mem(r14, rsi, 4), r14) + //lea(mem(rcx, rsi, 4), rcx) + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + vmovupd(xmm0, mem(r14)) + vmovupd(xmm1, mem(r14, rsi, 1)) + vmovupd(xmm2, mem(r14, rsi, 2)) + vmovupd(xmm4, mem(r14, r13, 1)) + //lea(mem(r14, rsi, 4), r14) label(.DDONE) @@ -1641,45 +1301,26 @@ void bli_dgemm_haswell_asm_6x8 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( d ); } -// assumes beta.r, beta.i have been broadcast into ymm1, ymm2. -// outputs to ymm0 -#define CGEMM_INPUT_SCALE_GS_BETA_NZ \ - vmovlpd(mem(rcx), xmm0, xmm0) \ - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) \ - vmovlpd(mem(rcx, rsi, 2), xmm3, xmm3) \ - vmovhpd(mem(rcx, r13, 1), xmm3, xmm3) \ - vinsertf128(imm(1), xmm3, ymm0, ymm0) \ - vpermilps(imm(0xb1), ymm0, ymm3) \ - vmulps(ymm1, ymm0, ymm0) \ - vmulps(ymm2, ymm3, ymm3) \ - vaddsubps(ymm3, ymm0, ymm0) -// assumes values to output are in ymm0 -#define CGEMM_OUTPUT_GS \ - vextractf128(imm(1), ymm0, xmm3) \ - vmovlpd(xmm0, mem(rcx)) \ - vmovhpd(xmm0, mem(rcx, rsi, 1)) \ - vmovlpd(xmm3, mem(rcx, rsi, 2)) \ - vmovhpd(xmm3, mem(rcx, r13, 1)) - -#define CGEMM_INPUT_SCALE_RS_BETA_NZ \ - vmovups(mem(rcx), ymm0) \ +#define CGEMM_INPUT_SCALE_RS_BETA_NZ(where) \ + vmovups(where, ymm0) \ vpermilps(imm(0xb1), ymm0, ymm3) \ vmulps(ymm1, ymm0, ymm0) \ vmulps(ymm2, ymm3, ymm3) \ vaddsubps(ymm3, ymm0, ymm0) -#define CGEMM_OUTPUT_RS \ - vmovups(ymm0, mem(rcx)) \ - void bli_cgemm_haswell_asm_3x8 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, scomplex* restrict alpha, scomplex* restrict a, scomplex* restrict b, @@ -1694,11 +1335,13 @@ void bli_cgemm_haswell_asm_3x8 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. 
- uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( c, 3, 8, true ); + begin_asm() vzeroall() // zero all xmm/ymm registers. @@ -1969,15 +1612,6 @@ void bli_cgemm_haswell_asm_3x8 vbroadcastss(mem(rbx, 4), ymm2) // load beta_i and duplicate - - - mov(var(cs_c), rsi) // load cs_c - lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(scomplex) - lea(mem(, rsi, 4), rdx) // rdx = 4*cs_c; - lea(mem(rsi, rsi, 2), r13) // r13 = 3*cs_c; - - - // now avoid loading C if beta == 0 vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomiss(xmm0, xmm1) // set ZF if beta_r == 0. @@ -1987,162 +1621,49 @@ void bli_cgemm_haswell_asm_3x8 and(r8b, r9b) // set ZF if r8b & r9b == 1. jne(.CBETAZERO) // if ZF = 1, jump to beta == 0 case - - cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8. - jz(.CROWSTORED) // jump to row storage case - - - - label(.CGENSTORED) - - - CGEMM_INPUT_SCALE_GS_BETA_NZ - vaddps(ymm4, ymm0, ymm0) - CGEMM_OUTPUT_GS - add(rdx, rcx) // c += 4*cs_c; - - - CGEMM_INPUT_SCALE_GS_BETA_NZ - vaddps(ymm5, ymm0, ymm0) - CGEMM_OUTPUT_GS - mov(r11, rcx) // rcx = c + 1*rs_c + CGEMM_INPUT_SCALE_RS_BETA_NZ(mem(rcx)) + vaddps(ymm4, ymm0, ymm0) + vmovups(ymm0, mem(rcx)) + CGEMM_INPUT_SCALE_RS_BETA_NZ(mem(rcx,32)) + vaddps(ymm5, ymm0, ymm0) + vmovups(ymm0, mem(rcx,32)) - CGEMM_INPUT_SCALE_GS_BETA_NZ - vaddps(ymm8, ymm0, ymm0) - CGEMM_OUTPUT_GS - add(rdx, rcx) // c += 4*cs_c; - CGEMM_INPUT_SCALE_GS_BETA_NZ - vaddps(ymm9, ymm0, ymm0) - CGEMM_OUTPUT_GS - mov(r12, rcx) // rcx = c + 2*rs_c + CGEMM_INPUT_SCALE_RS_BETA_NZ(mem(r11)) + vaddps(ymm8, ymm0, ymm0) + vmovups(ymm0, mem(r11)) + CGEMM_INPUT_SCALE_RS_BETA_NZ(mem(r11,32)) + vaddps(ymm9, ymm0, ymm0) + vmovups(ymm0, mem(r11,32)) - CGEMM_INPUT_SCALE_GS_BETA_NZ - vaddps(ymm12, ymm0, ymm0) - CGEMM_OUTPUT_GS - add(rdx, rcx) // c += 4*cs_c; - CGEMM_INPUT_SCALE_GS_BETA_NZ - vaddps(ymm13, ymm0, ymm0) - CGEMM_OUTPUT_GS + CGEMM_INPUT_SCALE_RS_BETA_NZ(mem(r12)) + vaddps(ymm12, ymm0, ymm0) + vmovups(ymm0, mem(r12)) + CGEMM_INPUT_SCALE_RS_BETA_NZ(mem(r12,32)) + vaddps(ymm13, ymm0, ymm0) + vmovups(ymm0, mem(r12,32)) - jmp(.CDONE) // jump to end. - - - - label(.CROWSTORED) - - - CGEMM_INPUT_SCALE_RS_BETA_NZ - vaddps(ymm4, ymm0, ymm0) - CGEMM_OUTPUT_RS - add(rdx, rcx) // c += 4*cs_c; - - - CGEMM_INPUT_SCALE_RS_BETA_NZ - vaddps(ymm5, ymm0, ymm0) - CGEMM_OUTPUT_RS - mov(r11, rcx) // rcx = c + 1*rs_c - - - - CGEMM_INPUT_SCALE_RS_BETA_NZ - vaddps(ymm8, ymm0, ymm0) - CGEMM_OUTPUT_RS - add(rdx, rcx) // c += 4*cs_c; - - - CGEMM_INPUT_SCALE_RS_BETA_NZ - vaddps(ymm9, ymm0, ymm0) - CGEMM_OUTPUT_RS - mov(r12, rcx) // rcx = c + 2*rs_c - - - - CGEMM_INPUT_SCALE_RS_BETA_NZ - vaddps(ymm12, ymm0, ymm0) - CGEMM_OUTPUT_RS - add(rdx, rcx) // c += 4*cs_c; - - - CGEMM_INPUT_SCALE_RS_BETA_NZ - vaddps(ymm13, ymm0, ymm0) - CGEMM_OUTPUT_RS - - - - jmp(.CDONE) // jump to end. - - + jmp(.CDONE) // jump to end. label(.CBETAZERO) - cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8. 
- jz(.CROWSTORBZ) // jump to row storage case - - - - label(.CGENSTORBZ) - - - vmovaps(ymm4, ymm0) - CGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*cs_c; - - - vmovaps(ymm5, ymm0) - CGEMM_OUTPUT_GS - mov(r11, rcx) // rcx = c + 1*rs_c - - - - vmovaps(ymm8, ymm0) - CGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*cs_c; - - - vmovaps(ymm9, ymm0) - CGEMM_OUTPUT_GS - mov(r12, rcx) // rcx = c + 2*rs_c - - - - vmovaps(ymm12, ymm0) - CGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*cs_c; - - - vmovaps(ymm13, ymm0) - CGEMM_OUTPUT_GS - - - - jmp(.CDONE) // jump to end. - - - - label(.CROWSTORBZ) - - - vmovups(ymm4, mem(rcx)) - vmovups(ymm5, mem(rcx, rdx, 1)) - - vmovups(ymm8, mem(r11)) - vmovups(ymm9, mem(r11, rdx, 1)) - - vmovups(ymm12, mem(r12)) - vmovups(ymm13, mem(r12, rdx, 1)) - + vmovups(ymm4, mem(rcx)) + vmovups(ymm5, mem(rcx,32)) + vmovups(ymm8, mem(r11)) + vmovups(ymm9, mem(r11,32)) + vmovups(ymm12, mem(r12)) + vmovups(ymm13, mem(r12,32)) label(.CDONE) @@ -2174,41 +1695,25 @@ void bli_cgemm_haswell_asm_3x8 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( c ); } -// assumes beta.r, beta.i have been broadcast into ymm1, ymm2. -// outputs to ymm0 -#define ZGEMM_INPUT_SCALE_GS_BETA_NZ \ - vmovupd(mem(rcx), xmm0) \ - vmovupd(mem(rcx, rsi, 1), xmm3) \ - vinsertf128(imm(1), xmm3, ymm0, ymm0) \ +#define ZGEMM_INPUT_SCALE_RS_BETA_NZ(where) \ + vmovupd(where, ymm0) \ vpermilpd(imm(0x5), ymm0, ymm3) \ vmulpd(ymm1, ymm0, ymm0) \ vmulpd(ymm2, ymm3, ymm3) \ vaddsubpd(ymm3, ymm0, ymm0) -// assumes values to output are in ymm0 -#define ZGEMM_OUTPUT_GS \ - vextractf128(imm(1), ymm0, xmm3) \ - vmovupd(xmm0, mem(rcx)) \ - vmovupd(xmm3, mem(rcx, rsi, 1)) \ - -#define ZGEMM_INPUT_SCALE_RS_BETA_NZ \ - vmovupd(mem(rcx), ymm0) \ - vpermilpd(imm(0x5), ymm0, ymm3) \ - vmulpd(ymm1, ymm0, ymm0) \ - vmulpd(ymm2, ymm3, ymm3) \ - vaddsubpd(ymm3, ymm0, ymm0) - -#define ZGEMM_OUTPUT_RS \ - vmovupd(ymm0, mem(rcx)) \ - void bli_zgemm_haswell_asm_3x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, dcomplex* restrict alpha, dcomplex* restrict a, dcomplex* restrict b, @@ -2223,11 +1728,13 @@ void bli_zgemm_haswell_asm_3x4 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( z, 3, 4, true ); + begin_asm() vzeroall() // zero all xmm/ymm registers. @@ -2501,14 +2008,6 @@ void bli_zgemm_haswell_asm_3x4 - - mov(var(cs_c), rsi) // load cs_c - lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(dcomplex) - lea(mem(, rsi, 2), rsi) - lea(mem(, rsi, 2), rdx) // rdx = 2*cs_c; - - - // now avoid loading C if beta == 0 vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm1) // set ZF if beta_r == 0. @@ -2518,162 +2017,49 @@ void bli_zgemm_haswell_asm_3x4 and(r8b, r9b) // set ZF if r8b & r9b == 1. jne(.ZBETAZERO) // if ZF = 1, jump to beta == 0 case + ZGEMM_INPUT_SCALE_RS_BETA_NZ(mem(rcx)) + vaddpd(ymm4, ymm0, ymm0) + vmovupd(ymm0, mem(rcx)) - cmp(imm(16), rsi) // set ZF if (16*cs_c) == 16. 
- jz(.ZROWSTORED) // jump to row storage case - - - - label(.ZGENSTORED) - - - ZGEMM_INPUT_SCALE_GS_BETA_NZ - vaddpd(ymm4, ymm0, ymm0) - ZGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*cs_c; - - - ZGEMM_INPUT_SCALE_GS_BETA_NZ - vaddpd(ymm5, ymm0, ymm0) - ZGEMM_OUTPUT_GS - mov(r11, rcx) // rcx = c + 1*rs_c - - - - ZGEMM_INPUT_SCALE_GS_BETA_NZ - vaddpd(ymm8, ymm0, ymm0) - ZGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*cs_c; + ZGEMM_INPUT_SCALE_RS_BETA_NZ(mem(rcx,32)) + vaddpd(ymm5, ymm0, ymm0) + vmovupd(ymm0, mem(rcx,32)) - ZGEMM_INPUT_SCALE_GS_BETA_NZ - vaddpd(ymm9, ymm0, ymm0) - ZGEMM_OUTPUT_GS - mov(r12, rcx) // rcx = c + 2*rs_c + ZGEMM_INPUT_SCALE_RS_BETA_NZ(mem(r11)) + vaddpd(ymm8, ymm0, ymm0) + vmovupd(ymm0, mem(r11)) - ZGEMM_INPUT_SCALE_GS_BETA_NZ - vaddpd(ymm12, ymm0, ymm0) - ZGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*cs_c; + ZGEMM_INPUT_SCALE_RS_BETA_NZ(mem(r11,32)) + vaddpd(ymm9, ymm0, ymm0) + vmovupd(ymm0, mem(r11,32)) - ZGEMM_INPUT_SCALE_GS_BETA_NZ - vaddpd(ymm13, ymm0, ymm0) - ZGEMM_OUTPUT_GS + ZGEMM_INPUT_SCALE_RS_BETA_NZ(mem(r12)) + vaddpd(ymm12, ymm0, ymm0) + vmovupd(ymm0, mem(r12)) - jmp(.ZDONE) // jump to end. - - - - label(.ZROWSTORED) - - - ZGEMM_INPUT_SCALE_RS_BETA_NZ - vaddpd(ymm4, ymm0, ymm0) - ZGEMM_OUTPUT_RS - add(rdx, rcx) // c += 2*cs_c; - - - ZGEMM_INPUT_SCALE_RS_BETA_NZ - vaddpd(ymm5, ymm0, ymm0) - ZGEMM_OUTPUT_RS - mov(r11, rcx) // rcx = c + 1*rs_c - - - - ZGEMM_INPUT_SCALE_RS_BETA_NZ - vaddpd(ymm8, ymm0, ymm0) - ZGEMM_OUTPUT_RS - add(rdx, rcx) // c += 2*cs_c; - - - ZGEMM_INPUT_SCALE_RS_BETA_NZ - vaddpd(ymm9, ymm0, ymm0) - ZGEMM_OUTPUT_RS - mov(r12, rcx) // rcx = c + 2*rs_c - - - - ZGEMM_INPUT_SCALE_RS_BETA_NZ - vaddpd(ymm12, ymm0, ymm0) - ZGEMM_OUTPUT_RS - add(rdx, rcx) // c += 2*cs_c; - - - ZGEMM_INPUT_SCALE_RS_BETA_NZ - vaddpd(ymm13, ymm0, ymm0) - ZGEMM_OUTPUT_RS - - - - jmp(.ZDONE) // jump to end. + ZGEMM_INPUT_SCALE_RS_BETA_NZ(mem(r12,32)) + vaddpd(ymm13, ymm0, ymm0) + vmovupd(ymm0, mem(r12,32)) + jmp(.ZDONE) // jump to end. label(.ZBETAZERO) - cmp(imm(16), rsi) // set ZF if (16*cs_c) == 16. - jz(.ZROWSTORBZ) // jump to row storage case - - - - label(.ZGENSTORBZ) - - - vmovapd(ymm4, ymm0) - ZGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*cs_c; - - - vmovapd(ymm5, ymm0) - ZGEMM_OUTPUT_GS - mov(r11, rcx) // rcx = c + 1*rs_c - - - - vmovapd(ymm8, ymm0) - ZGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*cs_c; - - - vmovapd(ymm9, ymm0) - ZGEMM_OUTPUT_GS - mov(r12, rcx) // rcx = c + 2*rs_c - - - - vmovapd(ymm12, ymm0) - ZGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*cs_c; - - - vmovapd(ymm13, ymm0) - ZGEMM_OUTPUT_GS - - - - jmp(.ZDONE) // jump to end. 
- - - - label(.ZROWSTORBZ) - - - vmovupd(ymm4, mem(rcx)) - vmovupd(ymm5, mem(rcx, rdx, 1)) - - vmovupd(ymm8, mem(r11)) - vmovupd(ymm9, mem(r11, rdx, 1)) - - vmovupd(ymm12, mem(r12)) - vmovupd(ymm13, mem(r12, rdx, 1)) - + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm5, mem(rcx,32)) + vmovupd(ymm8, mem(r11)) + vmovupd(ymm9, mem(r11,32)) + vmovupd(ymm12, mem(r12)) + vmovupd(ymm13, mem(r12,32)) label(.ZDONE) @@ -2705,6 +2091,8 @@ void bli_zgemm_haswell_asm_3x4 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( z ); } diff --git a/kernels/haswell/3/bli_gemm_haswell_asm_d8x6.c b/kernels/haswell/3/bli_gemm_haswell_asm_d8x6.c index b074da965c..27dc99f52a 100644 --- a/kernels/haswell/3/bli_gemm_haswell_asm_d8x6.c +++ b/kernels/haswell/3/bli_gemm_haswell_asm_d8x6.c @@ -78,7 +78,9 @@ void bli_sgemm_haswell_asm_16x6 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, float* restrict alpha, float* restrict a, float* restrict b, @@ -93,29 +95,31 @@ void bli_sgemm_haswell_asm_16x6 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( s, 16, 6, true ); + begin_asm() - + vzeroall() // zero all xmm/ymm registers. - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. //mov(%9, r15) // load address of b_next. - + add(imm(32*4), rax) // initialize loop by pre-loading vmovaps(mem(rax, -4*32), ymm0) vmovaps(mem(rax, -3*32), ymm1) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 4), rdi) // cs_c *= sizeof(float) - + lea(mem(rdi, rdi, 2), r13) // r13 = 3*cs_c; lea(mem(rcx, r13, 1), rdx) // rdx = c + 3*cs_c; prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*cs_c @@ -124,46 +128,46 @@ void bli_sgemm_haswell_asm_16x6 prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*cs_c prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*cs_c prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*cs_c - - - - + + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.SCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. 
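/*
 * Every kernel touched by this patch now opens with GEMM_UKR_SETUP_CT(
 * ch, MR, NR, pref ) and closes with GEMM_UKR_FLUSH_CT( ch ), which is
 * what lets the general-stride (GENSTORED) and secondary-layout store
 * paths be deleted wholesale. The macro definitions land elsewhere in
 * this patch (frame/include/bli_misc_macro_defs.h, per the diffstat); the
 * sketch below is only a plausible reading of their contract, with
 * hypothetical local names, not the literal expansion. In particular, the
 * final argument appears to select between the two C layouts a kernel can
 * emit; its precise meaning lives with the macro definitions.
 *
 *     float    ct[ 16 * 6 ];                // local MR x NR scratch tile
 *     float*   c_use  = c;
 *     uint64_t rs_use = rs_c, cs_use = cs_c;
 *
 *     // Redirect the stores whenever C's strides (or a partial m x n
 *     // edge tile, now that m and n are passed in) don't match what the
 *     // kernel's store sequence assumes:
 *     if ( rs_c != 1 || m != 16 || n != 6 )
 *     {
 *         c_use = ct; rs_use = 1; cs_use = 16;
 *     }
 *
 *     // ... unchanged asm body writes alpha*A*B (+ beta*C) to c_use ...
 *
 *     if ( c_use == ct )                    // GEMM_UKR_FLUSH_CT( ch )
 *         for ( uint64_t j = 0; j < n; ++j )
 *             for ( uint64_t i = 0; i < m; ++i )
 *                 c[ i*rs_c + j*cs_c ] = ct[ i + j*16 ];
 *
 * (When the tile is in use, beta must also be folded in at flush time, or
 * ct seeded from C up front; exactly how is defined with the macros.)
 */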
- - + + label(.SLOOPKITER) // MAIN LOOP - - + + // iteration 0 prefetch(0, mem(rax, 128*4)) - + vbroadcastss(mem(rbx, 0*4), ymm2) vbroadcastss(mem(rbx, 1*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rbx, 2*4), ymm2) vbroadcastss(mem(rbx, 3*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rbx, 4*4), ymm2) vbroadcastss(mem(rbx, 5*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + vmovaps(mem(rax, -2*32), ymm0) vmovaps(mem(rax, -1*32), ymm1) - + // iteration 1 vbroadcastss(mem(rbx, 6*4), ymm2) vbroadcastss(mem(rbx, 7*4), ymm3) @@ -171,51 +175,51 @@ void bli_sgemm_haswell_asm_16x6 vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rbx, 8*4), ymm2) vbroadcastss(mem(rbx, 9*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rbx, 10*4), ymm2) vbroadcastss(mem(rbx, 11*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + vmovaps(mem(rax, 0*32), ymm0) vmovaps(mem(rax, 1*32), ymm1) - + // iteration 2 prefetch(0, mem(rax, 152*4)) - + vbroadcastss(mem(rbx, 12*4), ymm2) vbroadcastss(mem(rbx, 13*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rbx, 14*4), ymm2) vbroadcastss(mem(rbx, 15*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rbx, 16*4), ymm2) vbroadcastss(mem(rbx, 17*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + vmovaps(mem(rax, 2*32), ymm0) vmovaps(mem(rax, 3*32), ymm1) - + // iteration 3 vbroadcastss(mem(rbx, 18*4), ymm2) vbroadcastss(mem(rbx, 19*4), ymm3) @@ -223,91 +227,91 @@ void bli_sgemm_haswell_asm_16x6 vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rbx, 20*4), ymm2) vbroadcastss(mem(rbx, 21*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rbx, 22*4), ymm2) vbroadcastss(mem(rbx, 23*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + add(imm(4*16*4), rax) // a += 4*16 (unroll x mr) add(imm(4*6*4), rbx) // b += 4*6 (unroll x nr) - + vmovaps(mem(rax, -4*32), ymm0) vmovaps(mem(rax, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.SLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.SCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.SPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. 
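/* The SLOOPKITER main loop above and the SLOOPKLEFT edge loop below form
 * the usual unroll-by-4 split of the k dimension; in scalar form:
 *
 *     for ( uint64_t i = 0; i < k_iter; ++i )   // k_iter = k / 4
 *         ;  // four rank-1 updates: ab += a(:,l) * b(l,:)
 *     for ( uint64_t i = 0; i < k_left; ++i )   // k_left = k % 4
 *         ;  // one rank-1 update
 *
 * which is why renaming the parameter from k0 to k only touches the
 * quotient/remainder setup and leaves both loop bodies unchanged. */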
- - + + label(.SLOOPKLEFT) // EDGE LOOP - + prefetch(0, mem(rax, 128*4)) - + vbroadcastss(mem(rbx, 0*4), ymm2) vbroadcastss(mem(rbx, 1*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rbx, 2*4), ymm2) vbroadcastss(mem(rbx, 3*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rbx, 4*4), ymm2) vbroadcastss(mem(rbx, 5*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + add(imm(1*16*4), rax) // a += 1*16 (unroll x mr) add(imm(1*6*4), rbx) // b += 1*6 (unroll x nr) - + vmovaps(mem(rax, -4*32), ymm0) vmovaps(mem(rax, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.SLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.SPOSTACCUM) - - - - + + + + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastss(mem(rax), ymm0) // load alpha and duplicate vbroadcastss(mem(rbx), ymm3) // load beta and duplicate - + vmulps(ymm0, ymm4, ymm4) // scale by alpha vmulps(ymm0, ymm5, ymm5) vmulps(ymm0, ymm6, ymm6) @@ -320,298 +324,90 @@ void bli_sgemm_haswell_asm_16x6 vmulps(ymm0, ymm13, ymm13) vmulps(ymm0, ymm14, ymm14) vmulps(ymm0, ymm15, ymm15) - - - - - - - mov(var(rs_c), rsi) // load rs_c - lea(mem(, rsi, 4), rsi) // rsi = rs_c * sizeof(float) - - lea(mem(rcx, rsi, 8), rdx) // load address of c + 8*rs_c; - - lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_c; - lea(mem(rsi, rsi, 4), r15) // r15 = 5*rs_c; - lea(mem(r13, rsi, 4), r10) // r10 = 7*rs_c; - - + + // now avoid loading C if beta == 0 - + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomiss(xmm0, xmm3) // set ZF if beta == 0. je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case - - - cmp(imm(4), rsi) // set ZF if (4*rs_c) == 4. - jz(.SCOLSTORED) // jump to column storage case - - - - label(.SGENSTORED) - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm4, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm6, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm8, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm10, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm12, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm14, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - //add(rdi, rcx) // c += cs_c; - - - mov(rdx, rcx) // rcx = c + 8*rs_c - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm5, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm7, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm9, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm11, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm13, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - SGEMM_INPUT_GS_BETA_NZ - vfmadd213ps(ymm15, ymm3, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - //add(rdi, rcx) // c += cs_c; - - - - jmp(.SDONE) // jump to end. 
- - - - label(.SCOLSTORED) - - - vfmadd231ps(mem(rcx), ymm3, ymm4) - vmovups(ymm4, mem(rcx)) - add(rdi, rcx) - vfmadd231ps(mem(rdx), ymm3, ymm5) - vmovups(ymm5, mem(rdx)) - add(rdi, rdx) - - - vfmadd231ps(mem(rcx), ymm3, ymm6) - vmovups(ymm6, mem(rcx)) - add(rdi, rcx) - vfmadd231ps(mem(rdx), ymm3, ymm7) - vmovups(ymm7, mem(rdx)) - add(rdi, rdx) - - - vfmadd231ps(mem(rcx), ymm3, ymm8) - vmovups(ymm8, mem(rcx)) - add(rdi, rcx) - vfmadd231ps(mem(rdx), ymm3, ymm9) - vmovups(ymm9, mem(rdx)) - add(rdi, rdx) - - - vfmadd231ps(mem(rcx), ymm3, ymm10) - vmovups(ymm10, mem(rcx)) - add(rdi, rcx) - vfmadd231ps(mem(rdx), ymm3, ymm11) - vmovups(ymm11, mem(rdx)) - add(rdi, rdx) - - - vfmadd231ps(mem(rcx), ymm3, ymm12) - vmovups(ymm12, mem(rcx)) - add(rdi, rcx) - vfmadd231ps(mem(rdx), ymm3, ymm13) - vmovups(ymm13, mem(rdx)) - add(rdi, rdx) - - - vfmadd231ps(mem(rcx), ymm3, ymm14) - vmovups(ymm14, mem(rcx)) - //add(rdi, rcx) - vfmadd231ps(mem(rdx), ymm3, ymm15) - vmovups(ymm15, mem(rdx)) - //add(rdi, rdx) - - - - + + vfmadd231ps(mem(rcx), ymm3, ymm4) + vmovups(ymm4, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm5) + vmovups(ymm5, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm6) + vmovups(ymm6, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm7) + vmovups(ymm7, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm8) + vmovups(ymm8, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm9) + vmovups(ymm9, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm10) + vmovups(ymm10, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm11) + vmovups(ymm11, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm12) + vmovups(ymm12, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm13) + vmovups(ymm13, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm14) + vmovups(ymm14, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm15) + vmovups(ymm15, mem(rcx,32)) + //add(rdi, rcx) + jmp(.SDONE) // jump to end. - - - + label(.SBETAZERO) - - cmp(imm(4), rsi) // set ZF if (4*rs_c) == 4. - jz(.SCOLSTORBZ) // jump to column storage case - - - - label(.SGENSTORBZ) - - - vmovaps(ymm4, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovaps(ymm6, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovaps(ymm8, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovaps(ymm10, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovaps(ymm12, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovaps(ymm14, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - //add(rdi, rcx) // c += cs_c; - - - mov(rdx, rcx) // rcx = c + 8*rs_c - - - vmovaps(ymm5, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovaps(ymm7, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovaps(ymm9, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovaps(ymm11, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovaps(ymm13, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovaps(ymm15, ymm0) - SGEMM_OUTPUT_GS_BETA_NZ - //add(rdi, rcx) // c += cs_c; - - - - jmp(.SDONE) // jump to end. 
- - - - label(.SCOLSTORBZ) - - - vmovups(ymm4, mem(rcx)) - add(rdi, rcx) - vmovups(ymm5, mem(rdx)) - add(rdi, rdx) - - vmovups(ymm6, mem(rcx)) - add(rdi, rcx) - vmovups(ymm7, mem(rdx)) - add(rdi, rdx) - - - vmovups(ymm8, mem(rcx)) - add(rdi, rcx) - vmovups(ymm9, mem(rdx)) - add(rdi, rdx) - - - vmovups(ymm10, mem(rcx)) - add(rdi, rcx) - vmovups(ymm11, mem(rdx)) - add(rdi, rdx) - - - vmovups(ymm12, mem(rcx)) - add(rdi, rcx) - vmovups(ymm13, mem(rdx)) - add(rdi, rdx) - - - vmovups(ymm14, mem(rcx)) - //add(rdi, rcx) - vmovups(ymm15, mem(rdx)) - //add(rdi, rdx) - - - - - - - + + vmovups(ymm4, mem(rcx)) + vmovups(ymm5, mem(rcx,32)) + add(rdi, rcx) + + vmovups(ymm6, mem(rcx)) + vmovups(ymm7, mem(rcx,32)) + add(rdi, rcx) + + + vmovups(ymm8, mem(rcx)) + vmovups(ymm9, mem(rcx,32)) + add(rdi, rcx) + + + vmovups(ymm10, mem(rcx)) + vmovups(ymm11, mem(rcx,32)) + add(rdi, rcx) + + + vmovups(ymm12, mem(rcx)) + vmovups(ymm13, mem(rcx,32)) + add(rdi, rcx) + + + vmovups(ymm14, mem(rcx)) + vmovups(ymm15, mem(rcx,32)) + //add(rdi, rcx) + label(.SDONE) - - + + end_asm( : // output operands (none) @@ -628,7 +424,7 @@ void bli_sgemm_haswell_asm_16x6 [b_next] "m" (b_next), // 9 [a_next] "m" (a_next)*/ // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -636,6 +432,8 @@ void bli_sgemm_haswell_asm_16x6 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( s ); } #define DGEMM_INPUT_GS_BETA_NZ \ @@ -664,7 +462,9 @@ void bli_sgemm_haswell_asm_16x6 void bli_dgemm_haswell_asm_8x6 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, double* restrict alpha, double* restrict a, double* restrict b, @@ -679,29 +479,31 @@ void bli_dgemm_haswell_asm_8x6 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( d, 8, 6, false ); + begin_asm() - + vzeroall() // zero all xmm/ymm registers. - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. //mov(%9, r15) // load address of b_next. - + add(imm(32*4), rax) // initialize loop by pre-loading vmovapd(mem(rax, -4*32), ymm0) vmovapd(mem(rax, -3*32), ymm1) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(double) - + lea(mem(rdi, rdi, 2), r13) // r13 = 3*cs_c; lea(mem(rcx, r13, 1), rdx) // rdx = c + 3*cs_c; prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*cs_c @@ -710,46 +512,46 @@ void bli_dgemm_haswell_asm_8x6 prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*cs_c prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*cs_c prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*cs_c - - - - + + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. 
- - + + label(.DLOOPKITER) // MAIN LOOP - - + + // iteration 0 prefetch(0, mem(rax, 64*8)) - + vbroadcastsd(mem(rbx, 0*8), ymm2) vbroadcastsd(mem(rbx, 1*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rbx, 2*8), ymm2) vbroadcastsd(mem(rbx, 3*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rbx, 4*8), ymm2) vbroadcastsd(mem(rbx, 5*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rax, -2*32), ymm0) vmovapd(mem(rax, -1*32), ymm1) - + // iteration 1 vbroadcastsd(mem(rbx, 6*8), ymm2) vbroadcastsd(mem(rbx, 7*8), ymm3) @@ -757,51 +559,51 @@ void bli_dgemm_haswell_asm_8x6 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rbx, 8*8), ymm2) vbroadcastsd(mem(rbx, 9*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rbx, 10*8), ymm2) vbroadcastsd(mem(rbx, 11*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rax, 0*32), ymm0) vmovapd(mem(rax, 1*32), ymm1) - + // iteration 2 prefetch(0, mem(rax, 76*8)) - + vbroadcastsd(mem(rbx, 12*8), ymm2) vbroadcastsd(mem(rbx, 13*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rbx, 14*8), ymm2) vbroadcastsd(mem(rbx, 15*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rbx, 16*8), ymm2) vbroadcastsd(mem(rbx, 17*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rax, 2*32), ymm0) vmovapd(mem(rax, 3*32), ymm1) - + // iteration 3 vbroadcastsd(mem(rbx, 18*8), ymm2) vbroadcastsd(mem(rbx, 19*8), ymm3) @@ -809,91 +611,91 @@ void bli_dgemm_haswell_asm_8x6 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rbx, 20*8), ymm2) vbroadcastsd(mem(rbx, 21*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rbx, 22*8), ymm2) vbroadcastsd(mem(rbx, 23*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + add(imm(4*8*8), rax) // a += 4*8 (unroll x mr) add(imm(4*6*8), rbx) // b += 4*6 (unroll x nr) - + vmovapd(mem(rax, -4*32), ymm0) vmovapd(mem(rax, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. 
- - + + label(.DLOOPKLEFT) // EDGE LOOP - + prefetch(0, mem(rax, 64*8)) - + vbroadcastsd(mem(rbx, 0*8), ymm2) vbroadcastsd(mem(rbx, 1*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rbx, 2*8), ymm2) vbroadcastsd(mem(rbx, 3*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rbx, 4*8), ymm2) vbroadcastsd(mem(rbx, 5*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + add(imm(1*8*8), rax) // a += 1*8 (unroll x mr) add(imm(1*6*8), rbx) // b += 1*6 (unroll x nr) - + vmovapd(mem(rax, -4*32), ymm0) vmovapd(mem(rax, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - - - - + + + + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(ymm0, ymm4, ymm4) // scale by alpha vmulpd(ymm0, ymm5, ymm5) vmulpd(ymm0, ymm6, ymm6) @@ -906,297 +708,90 @@ void bli_dgemm_haswell_asm_8x6 vmulpd(ymm0, ymm13, ymm13) vmulpd(ymm0, ymm14, ymm14) vmulpd(ymm0, ymm15, ymm15) - - - - - - - mov(var(rs_c), rsi) // load rs_c - lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(double) - - lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*rs_c; - - lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_c; - //lea(mem(rsi, rsi, 4), r15) // r15 = 5*rs_c; - //lea(mem(r13, rsi, 4), r10) // r10 = 7*rs_c; - - + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - - - cmp(imm(8), rsi) // set ZF if (8*rs_c) == 8. - jz(.DCOLSTORED) // jump to column storage case - - - - label(.DGENSTORED) - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm4, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm6, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm8, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm10, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm12, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm14, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - //add(rdi, rcx) // c += cs_c; - - - mov(rdx, rcx) // rcx = c + 4*rs_c - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm5, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm7, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm9, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm11, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm13, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - DGEMM_INPUT_GS_BETA_NZ - vfmadd213pd(ymm15, ymm3, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - //add(rdi, rcx) // c += cs_c; - - - - jmp(.DDONE) // jump to end. 
- - - - label(.DCOLSTORED) - - - vfmadd231pd(mem(rcx), ymm3, ymm4) - vmovupd(ymm4, mem(rcx)) - add(rdi, rcx) - vfmadd231pd(mem(rdx), ymm3, ymm5) - vmovupd(ymm5, mem(rdx)) - add(rdi, rdx) - - - vfmadd231pd(mem(rcx), ymm3, ymm6) - vmovupd(ymm6, mem(rcx)) - add(rdi, rcx) - vfmadd231pd(mem(rdx), ymm3, ymm7) - vmovupd(ymm7, mem(rdx)) - add(rdi, rdx) - - - vfmadd231pd(mem(rcx), ymm3, ymm8) - vmovupd(ymm8, mem(rcx)) - add(rdi, rcx) - vfmadd231pd(mem(rdx), ymm3, ymm9) - vmovupd(ymm9, mem(rdx)) - add(rdi, rdx) - - - vfmadd231pd(mem(rcx), ymm3, ymm10) - vmovupd(ymm10, mem(rcx)) - add(rdi, rcx) - vfmadd231pd(mem(rdx), ymm3, ymm11) - vmovupd(ymm11, mem(rdx)) - add(rdi, rdx) - - - vfmadd231pd(mem(rcx), ymm3, ymm12) - vmovupd(ymm12, mem(rcx)) - add(rdi, rcx) - vfmadd231pd(mem(rdx), ymm3, ymm13) - vmovupd(ymm13, mem(rdx)) - add(rdi, rdx) - - - vfmadd231pd(mem(rcx), ymm3, ymm14) - vmovupd(ymm14, mem(rcx)) - //add(rdi, rcx) - vfmadd231pd(mem(rdx), ymm3, ymm15) - vmovupd(ymm15, mem(rdx)) - //add(rdi, rdx) - - - - jmp(.DDONE) // jump to end. - - - + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm8) + vmovupd(ymm8, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm9) + vmovupd(ymm9, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm10) + vmovupd(ymm10, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm11) + vmovupd(ymm11, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm12) + vmovupd(ymm12, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm13) + vmovupd(ymm13, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm14) + vmovupd(ymm14, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm15) + vmovupd(ymm15, mem(rcx,32)) + //add(rdi, rcx) + + jmp(.DDONE) // jump to end. + label(.DBETAZERO) - - cmp(imm(8), rsi) // set ZF if (8*rs_c) == 8. - jz(.DCOLSTORBZ) // jump to column storage case - - - - label(.DGENSTORBZ) - - - vmovapd(ymm4, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovapd(ymm6, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovapd(ymm8, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovapd(ymm10, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovapd(ymm12, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovapd(ymm14, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - //add(rdi, rcx) // c += cs_c; - - - mov(rdx, rcx) // rcx = c + 4*rs_c - - - vmovapd(ymm5, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovapd(ymm7, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovapd(ymm9, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovapd(ymm11, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovapd(ymm13, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - add(rdi, rcx) // c += cs_c; - - - vmovapd(ymm15, ymm0) - DGEMM_OUTPUT_GS_BETA_NZ - //add(rdi, rcx) // c += cs_c; - - - - jmp(.DDONE) // jump to end. 
- - - - label(.DCOLSTORBZ) - - - vmovupd(ymm4, mem(rcx)) - add(rdi, rcx) - vmovupd(ymm5, mem(rdx)) - add(rdi, rdx) - - vmovupd(ymm6, mem(rcx)) - add(rdi, rcx) - vmovupd(ymm7, mem(rdx)) - add(rdi, rdx) - - - vmovupd(ymm8, mem(rcx)) - add(rdi, rcx) - vmovupd(ymm9, mem(rdx)) - add(rdi, rdx) - - - vmovupd(ymm10, mem(rcx)) - add(rdi, rcx) - vmovupd(ymm11, mem(rdx)) - add(rdi, rdx) - - - vmovupd(ymm12, mem(rcx)) - add(rdi, rcx) - vmovupd(ymm13, mem(rdx)) - add(rdi, rdx) - - - vmovupd(ymm14, mem(rcx)) - //add(rdi, rcx) - vmovupd(ymm15, mem(rdx)) - //add(rdi, rdx) - - - - - - - + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm5, mem(rcx,32)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + vmovupd(ymm7, mem(rcx,32)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx)) + vmovupd(ymm9, mem(rcx,32)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx)) + vmovupd(ymm11, mem(rcx,32)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx)) + vmovupd(ymm13, mem(rcx,32)) + add(rdi, rcx) + + + vmovupd(ymm14, mem(rcx)) + vmovupd(ymm15, mem(rcx,32)) + //add(rdi, rcx) + label(.DDONE) - - + + end_asm( : // output operands (none) @@ -1213,7 +808,7 @@ void bli_dgemm_haswell_asm_8x6 [b_next] "m" (b_next), // 9 [a_next] "m" (a_next)*/ // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -1221,45 +816,25 @@ void bli_dgemm_haswell_asm_8x6 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( d ); } -// assumes beta.r, beta.i have been broadcast into ymm1, ymm2. -// outputs to ymm0 -#define CGEMM_INPUT_SCALE_GS_BETA_NZ \ - vmovlpd(mem(rcx), xmm0, xmm0) \ - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) \ - vmovlpd(mem(rcx, rsi, 2), xmm3, xmm3) \ - vmovhpd(mem(rcx, r13, 1), xmm3, xmm3) \ - vinsertf128(imm(1), xmm3, ymm0, ymm0) \ +#define CGEMM_INPUT_SCALE_CS_BETA_NZ(where) \ + vmovups(where, ymm0) \ vpermilps(imm(0xb1), ymm0, ymm3) \ vmulps(ymm1, ymm0, ymm0) \ vmulps(ymm2, ymm3, ymm3) \ vaddsubps(ymm3, ymm0, ymm0) -// assumes values to output are in ymm0 -#define CGEMM_OUTPUT_GS \ - vextractf128(imm(1), ymm0, xmm3) \ - vmovlpd(xmm0, mem(rcx)) \ - vmovhpd(xmm0, mem(rcx, rsi, 1)) \ - vmovlpd(xmm3, mem(rcx, rsi, 2)) \ - vmovhpd(xmm3, mem(rcx, r13, 1)) - -#define CGEMM_INPUT_SCALE_CS_BETA_NZ \ - vmovups(mem(rcx), ymm0) \ - vpermilps(imm(0xb1), ymm0, ymm3) \ - vmulps(ymm1, ymm0, ymm0) \ - vmulps(ymm2, ymm3, ymm3) \ - vaddsubps(ymm3, ymm0, ymm0) - -#define CGEMM_OUTPUT_CS \ - vmovups(ymm0, mem(rcx)) \ - void bli_cgemm_haswell_asm_8x3 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, scomplex* restrict alpha, scomplex* restrict a, scomplex* restrict b, @@ -1274,75 +849,77 @@ void bli_cgemm_haswell_asm_8x3 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( c, 8, 3, false ); + begin_asm() - + vzeroall() // zero all xmm/ymm registers. - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. //mov(%9, r15) // load address of b_next. 
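/* Note the macro refactor above: the *_INPUT_SCALE_*_BETA_NZ macros now
 * take an explicit address operand instead of implicitly reading
 * mem(rcx). Callers pass mem(rcx), mem(rcx,32), mem(r11), and so on; one
 * 32-byte immediate offset reaches the second ymm register's worth of
 * each contiguous row or column, so the pointer-advancing add/mov
 * choreography between stores, and the separate *_OUTPUT_GS/_RS/_CS
 * macros, can be dropped along with the general-stride paths. */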
- + add(imm(32*4), rax) // initialize loop by pre-loading vmovaps(mem(rax, -4*32), ymm0) vmovaps(mem(rax, -3*32), ymm1) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(scomplex) - + lea(mem(rcx, rdi, 1), r11) // r11 = c + 1*cs_c; lea(mem(rcx, rdi, 2), r12) // r12 = c + 2*cs_c; - + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*cs_c prefetch(0, mem(r11, 7*8)) // prefetch c + 1*cs_c prefetch(0, mem(r12, 7*8)) // prefetch c + 2*cs_c - - - - + + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.CCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.CLOOPKITER) // MAIN LOOP - - + + // iteration 0 prefetch(0, mem(rax, 32*8)) - + vbroadcastss(mem(rbx, 0*4), ymm2) vbroadcastss(mem(rbx, 1*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rbx, 2*4), ymm2) vbroadcastss(mem(rbx, 3*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rbx, 4*4), ymm2) vbroadcastss(mem(rbx, 5*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + vmovaps(mem(rax, -2*32), ymm0) vmovaps(mem(rax, -1*32), ymm1) - + // iteration 1 vbroadcastss(mem(rbx, 6*4), ymm2) vbroadcastss(mem(rbx, 7*4), ymm3) @@ -1350,51 +927,51 @@ void bli_cgemm_haswell_asm_8x3 vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rbx, 8*4), ymm2) vbroadcastss(mem(rbx, 9*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rbx, 10*4), ymm2) vbroadcastss(mem(rbx, 11*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + vmovaps(mem(rax, 0*32), ymm0) vmovaps(mem(rax, 1*32), ymm1) - + // iteration 2 prefetch(0, mem(rax, 38*8)) - + vbroadcastss(mem(rbx, 12*4), ymm2) vbroadcastss(mem(rbx, 13*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rbx, 14*4), ymm2) vbroadcastss(mem(rbx, 15*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rbx, 16*4), ymm2) vbroadcastss(mem(rbx, 17*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + vmovaps(mem(rax, 2*32), ymm0) vmovaps(mem(rax, 3*32), ymm1) - + // iteration 3 vbroadcastss(mem(rbx, 18*4), ymm2) vbroadcastss(mem(rbx, 19*4), ymm3) @@ -1402,84 +979,84 @@ void bli_cgemm_haswell_asm_8x3 vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rbx, 20*4), ymm2) vbroadcastss(mem(rbx, 21*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rbx, 22*4), ymm2) vbroadcastss(mem(rbx, 23*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + add(imm(4*8*8), rax) // a += 4*8 (unroll x mr) add(imm(4*3*8), rbx) // b += 4*3 (unroll x nr) - + vmovaps(mem(rax, -4*32), 
ymm0) vmovaps(mem(rax, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.CLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.CCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.CPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.CLOOPKLEFT) // EDGE LOOP - + prefetch(0, mem(rax, 32*8)) - + vbroadcastss(mem(rbx, 0*4), ymm2) vbroadcastss(mem(rbx, 1*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) vfmadd231ps(ymm1, ymm2, ymm5) vfmadd231ps(ymm0, ymm3, ymm6) vfmadd231ps(ymm1, ymm3, ymm7) - + vbroadcastss(mem(rbx, 2*4), ymm2) vbroadcastss(mem(rbx, 3*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm8) vfmadd231ps(ymm1, ymm2, ymm9) vfmadd231ps(ymm0, ymm3, ymm10) vfmadd231ps(ymm1, ymm3, ymm11) - + vbroadcastss(mem(rbx, 4*4), ymm2) vbroadcastss(mem(rbx, 5*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm12) vfmadd231ps(ymm1, ymm2, ymm13) vfmadd231ps(ymm0, ymm3, ymm14) vfmadd231ps(ymm1, ymm3, ymm15) - + add(imm(1*8*8), rax) // a += 1*8 (unroll x mr) add(imm(1*3*8), rbx) // b += 1*3 (unroll x nr) - + vmovaps(mem(rax, -4*32), ymm0) vmovaps(mem(rax, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.CLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.CPOSTACCUM) - - + + // permute even and odd elements // of ymm6/7, ymm10/11, ymm/14/15 vpermilps(imm(0xb1), ymm6, ymm6) @@ -1488,76 +1065,68 @@ void bli_cgemm_haswell_asm_8x3 vpermilps(imm(0xb1), ymm11, ymm11) vpermilps(imm(0xb1), ymm14, ymm14) vpermilps(imm(0xb1), ymm15, ymm15) - - + + // subtract/add even/odd elements vaddsubps(ymm6, ymm4, ymm4) vaddsubps(ymm7, ymm5, ymm5) - + vaddsubps(ymm10, ymm8, ymm8) vaddsubps(ymm11, ymm9, ymm9) - + vaddsubps(ymm14, ymm12, ymm12) vaddsubps(ymm15, ymm13, ymm13) - - - - + + + + mov(var(alpha), rax) // load address of alpha vbroadcastss(mem(rax), ymm0) // load alpha_r and duplicate vbroadcastss(mem(rax, 4), ymm1) // load alpha_i and duplicate - - + + vpermilps(imm(0xb1), ymm4, ymm3) vmulps(ymm0, ymm4, ymm4) vmulps(ymm1, ymm3, ymm3) vaddsubps(ymm3, ymm4, ymm4) - + vpermilps(imm(0xb1), ymm5, ymm3) vmulps(ymm0, ymm5, ymm5) vmulps(ymm1, ymm3, ymm3) vaddsubps(ymm3, ymm5, ymm5) - - + + vpermilps(imm(0xb1), ymm8, ymm3) vmulps(ymm0, ymm8, ymm8) vmulps(ymm1, ymm3, ymm3) vaddsubps(ymm3, ymm8, ymm8) - + vpermilps(imm(0xb1), ymm9, ymm3) vmulps(ymm0, ymm9, ymm9) vmulps(ymm1, ymm3, ymm3) vaddsubps(ymm3, ymm9, ymm9) - - + + vpermilps(imm(0xb1), ymm12, ymm3) vmulps(ymm0, ymm12, ymm12) vmulps(ymm1, ymm3, ymm3) vaddsubps(ymm3, ymm12, ymm12) - + vpermilps(imm(0xb1), ymm13, ymm3) vmulps(ymm0, ymm13, ymm13) vmulps(ymm1, ymm3, ymm3) vaddsubps(ymm3, ymm13, ymm13) - - - - - + + + + + mov(var(beta), rbx) // load address of beta vbroadcastss(mem(rbx), ymm1) // load beta_r and duplicate vbroadcastss(mem(rbx, 4), ymm2) // load beta_i and duplicate - - - - - mov(var(rs_c), rsi) // load rs_c - lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(scomplex) - lea(mem(, rsi, 4), rdx) // rdx = 4*rs_c; - lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_c; - - - + + + // now avoid loading C if beta == 0 vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomiss(xmm0, xmm1) // set ZF if beta_r == 0. @@ -1566,169 +1135,54 @@ void bli_cgemm_haswell_asm_8x3 sete(r9b) // r9b = ( ZF == 1 ? 1 : 0 ); and(r8b, r9b) // set ZF if r8b & r9b == 1. jne(.CBETAZERO) // if ZF = 1, jump to beta == 0 case - - - cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8. 
- jz(.CCOLSTORED) // jump to row storage case - - - - label(.CGENSTORED) - - - CGEMM_INPUT_SCALE_GS_BETA_NZ - vaddps(ymm4, ymm0, ymm0) - CGEMM_OUTPUT_GS - add(rdx, rcx) // c += 4*rs_c; - - - CGEMM_INPUT_SCALE_GS_BETA_NZ - vaddps(ymm5, ymm0, ymm0) - CGEMM_OUTPUT_GS - mov(r11, rcx) // rcx = c + 1*cs_c - - - - CGEMM_INPUT_SCALE_GS_BETA_NZ - vaddps(ymm8, ymm0, ymm0) - CGEMM_OUTPUT_GS - add(rdx, rcx) // c += 4*rs_c; - - - CGEMM_INPUT_SCALE_GS_BETA_NZ - vaddps(ymm9, ymm0, ymm0) - CGEMM_OUTPUT_GS - mov(r12, rcx) // rcx = c + 2*cs_c - - - - CGEMM_INPUT_SCALE_GS_BETA_NZ - vaddps(ymm12, ymm0, ymm0) - CGEMM_OUTPUT_GS - add(rdx, rcx) // c += 4*rs_c; - - - CGEMM_INPUT_SCALE_GS_BETA_NZ - vaddps(ymm13, ymm0, ymm0) - CGEMM_OUTPUT_GS - - - - jmp(.CDONE) // jump to end. - - - - label(.CCOLSTORED) - - - CGEMM_INPUT_SCALE_CS_BETA_NZ - vaddps(ymm4, ymm0, ymm0) - CGEMM_OUTPUT_CS - add(rdx, rcx) // c += 4*rs_c; - - - CGEMM_INPUT_SCALE_CS_BETA_NZ - vaddps(ymm5, ymm0, ymm0) - CGEMM_OUTPUT_CS - mov(r11, rcx) // rcx = c + 1*cs_c - - - - CGEMM_INPUT_SCALE_CS_BETA_NZ - vaddps(ymm8, ymm0, ymm0) - CGEMM_OUTPUT_CS - add(rdx, rcx) // c += 4*rs_c; - - - CGEMM_INPUT_SCALE_CS_BETA_NZ - vaddps(ymm9, ymm0, ymm0) - CGEMM_OUTPUT_CS - mov(r12, rcx) // rcx = c + 2*cs_c - - - - CGEMM_INPUT_SCALE_CS_BETA_NZ - vaddps(ymm12, ymm0, ymm0) - CGEMM_OUTPUT_CS - add(rdx, rcx) // c += 4*rs_c; - - - CGEMM_INPUT_SCALE_CS_BETA_NZ - vaddps(ymm13, ymm0, ymm0) - CGEMM_OUTPUT_CS - - - - jmp(.CDONE) // jump to end. - - - + + CGEMM_INPUT_SCALE_CS_BETA_NZ(mem(rcx)) + vaddps(ymm4, ymm0, ymm0) + vmovups(ymm0, mem(rcx)) + + + CGEMM_INPUT_SCALE_CS_BETA_NZ(mem(rcx,32)) + vaddps(ymm5, ymm0, ymm0) + vmovups(ymm0, mem(rcx,32)) + + + + CGEMM_INPUT_SCALE_CS_BETA_NZ(mem(r11)) + vaddps(ymm8, ymm0, ymm0) + vmovups(ymm0, mem(r11)) + + + CGEMM_INPUT_SCALE_CS_BETA_NZ(mem(r11,32)) + vaddps(ymm9, ymm0, ymm0) + vmovups(ymm0, mem(r11,32)) + + + + CGEMM_INPUT_SCALE_CS_BETA_NZ(mem(r12)) + vaddps(ymm12, ymm0, ymm0) + vmovups(ymm0, mem(r12)) + + + CGEMM_INPUT_SCALE_CS_BETA_NZ(mem(r12,32)) + vaddps(ymm13, ymm0, ymm0) + vmovups(ymm0, mem(r12,32)) + + jmp(.CDONE) // jump to end. + label(.CBETAZERO) - - cmp(imm(8), rsi) // set ZF if (8*rs_c) == 8. - jz(.CCOLSTORBZ) // jump to row storage case - - - - label(.CGENSTORBZ) - - - vmovaps(ymm4, ymm0) - CGEMM_OUTPUT_GS - add(rdx, rcx) // c += 4*rs_c; - - - vmovaps(ymm5, ymm0) - CGEMM_OUTPUT_GS - mov(r11, rcx) // rcx = c + 1*cs_c - - - - vmovaps(ymm8, ymm0) - CGEMM_OUTPUT_GS - add(rdx, rcx) // c += 4*rs_c; - - - vmovaps(ymm9, ymm0) - CGEMM_OUTPUT_GS - mov(r12, rcx) // rcx = c + 2*cs_c - - - - vmovaps(ymm12, ymm0) - CGEMM_OUTPUT_GS - add(rdx, rcx) // c += 4*rs_c; - - - vmovaps(ymm13, ymm0) - CGEMM_OUTPUT_GS - - - - jmp(.CDONE) // jump to end. 
- - - - label(.CCOLSTORBZ) - - - vmovups(ymm4, mem(rcx)) - vmovups(ymm5, mem(rcx, rdx, 1)) - - vmovups(ymm8, mem(r11)) - vmovups(ymm9, mem(r11, rdx, 1)) - - vmovups(ymm12, mem(r12)) - vmovups(ymm13, mem(r12, rdx, 1)) - - - - - - + + vmovups(ymm4, mem(rcx)) + vmovups(ymm5, mem(rcx,32)) + + vmovups(ymm8, mem(r11)) + vmovups(ymm9, mem(r11,32)) + + vmovups(ymm12, mem(r12)) + vmovups(ymm13, mem(r12,32)) + label(.CDONE) - - + + end_asm( : // output operands (none) @@ -1745,7 +1199,7 @@ void bli_cgemm_haswell_asm_8x3 [b_next] "m" (b_next), // 9 [a_next] "m" (a_next)*/ // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -1753,41 +1207,25 @@ void bli_cgemm_haswell_asm_8x3 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( c ); } -// assumes beta.r, beta.i have been broadcast into ymm1, ymm2. -// outputs to ymm0 -#define ZGEMM_INPUT_SCALE_GS_BETA_NZ \ - vmovupd(mem(rcx), xmm0) \ - vmovupd(mem(rcx, rsi, 1), xmm3) \ - vinsertf128(imm(1), xmm3, ymm0, ymm0) \ - vpermilpd(imm(0x5), ymm0, ymm3) \ - vmulpd(ymm1, ymm0, ymm0) \ - vmulpd(ymm2, ymm3, ymm3) \ - vaddsubpd(ymm3, ymm0, ymm0) - -// assumes values to output are in ymm0 -#define ZGEMM_OUTPUT_GS \ - vextractf128(imm(1), ymm0, xmm3) \ - vmovupd(xmm0, mem(rcx)) \ - vmovupd(xmm3, mem(rcx, rsi, 1)) \ - -#define ZGEMM_INPUT_SCALE_CS_BETA_NZ \ - vmovups(mem(rcx), ymm0) \ +#define ZGEMM_INPUT_SCALE_CS_BETA_NZ(where) \ + vmovups(where, ymm0) \ vpermilpd(imm(0x5), ymm0, ymm3) \ vmulpd(ymm1, ymm0, ymm0) \ vmulpd(ymm2, ymm3, ymm3) \ vaddsubpd(ymm3, ymm0, ymm0) - -#define ZGEMM_OUTPUT_CS \ - vmovupd(ymm0, mem(rcx)) \ void bli_zgemm_haswell_asm_4x3 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, dcomplex* restrict alpha, dcomplex* restrict a, dcomplex* restrict b, @@ -1802,76 +1240,78 @@ void bli_zgemm_haswell_asm_4x3 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( z, 4, 3, false ); + begin_asm() - + vzeroall() // zero all xmm/ymm registers. - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. //mov(%9, r15) // load address of b_next. - + add(imm(32*4), rax) // initialize loop by pre-loading vmovapd(mem(rax, -4*32), ymm0) vmovapd(mem(rax, -3*32), ymm1) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(dcomplex) lea(mem(, rdi, 2), rdi) - + lea(mem(rcx, rdi, 1), r11) // r11 = c + 1*cs_c; lea(mem(rcx, rdi, 2), r12) // r12 = c + 2*cs_c; - + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*cs_c prefetch(0, mem(r11, 7*8)) // prefetch c + 1*cs_c prefetch(0, mem(r12, 7*8)) // prefetch c + 2*cs_c - - - - + + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.ZCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. 
- - + + label(.ZLOOPKITER) // MAIN LOOP - - + + // iteration 0 prefetch(0, mem(rax, 32*16)) - + vbroadcastsd(mem(rbx, 0*8), ymm2) vbroadcastsd(mem(rbx, 1*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rbx, 2*8), ymm2) vbroadcastsd(mem(rbx, 3*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rbx, 4*8), ymm2) vbroadcastsd(mem(rbx, 5*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rax, -2*32), ymm0) vmovapd(mem(rax, -1*32), ymm1) - + // iteration 1 vbroadcastsd(mem(rbx, 6*8), ymm2) vbroadcastsd(mem(rbx, 7*8), ymm3) @@ -1879,51 +1319,51 @@ void bli_zgemm_haswell_asm_4x3 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rbx, 8*8), ymm2) vbroadcastsd(mem(rbx, 9*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rbx, 10*8), ymm2) vbroadcastsd(mem(rbx, 11*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rax, 0*32), ymm0) vmovapd(mem(rax, 1*32), ymm1) - + // iteration 2 prefetch(0, mem(rax, 38*16)) - + vbroadcastsd(mem(rbx, 12*8), ymm2) vbroadcastsd(mem(rbx, 13*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rbx, 14*8), ymm2) vbroadcastsd(mem(rbx, 15*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rbx, 16*8), ymm2) vbroadcastsd(mem(rbx, 17*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rax, 2*32), ymm0) vmovapd(mem(rax, 3*32), ymm1) - + // iteration 3 vbroadcastsd(mem(rbx, 18*8), ymm2) vbroadcastsd(mem(rbx, 19*8), ymm3) @@ -1931,83 +1371,83 @@ void bli_zgemm_haswell_asm_4x3 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rbx, 20*8), ymm2) vbroadcastsd(mem(rbx, 21*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rbx, 22*8), ymm2) vbroadcastsd(mem(rbx, 23*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + add(imm(4*4*16), rax) // a += 4*4 (unroll x mr) add(imm(4*3*16), rbx) // b += 4*3 (unroll x nr) - + vmovapd(mem(rax, -4*32), ymm0) vmovapd(mem(rax, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.ZLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.ZCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.ZPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. 
- - + + label(.ZLOOPKLEFT) // EDGE LOOP - + prefetch(0, mem(rax, 32*16)) - + vbroadcastsd(mem(rbx, 0*8), ymm2) vbroadcastsd(mem(rbx, 1*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rbx, 2*8), ymm2) vbroadcastsd(mem(rbx, 3*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rbx, 4*8), ymm2) vbroadcastsd(mem(rbx, 5*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + add(imm(1*4*16), rax) // a += 1*4 (unroll x mr) add(imm(1*3*16), rbx) // b += 1*3 (unroll x nr) - + vmovapd(mem(rax, -4*32), ymm0) vmovapd(mem(rax, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.ZLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.ZPOSTACCUM) - + // permute even and odd elements // of ymm6/7, ymm10/11, ymm/14/15 vpermilpd(imm(0x5), ymm6, ymm6) @@ -2016,76 +1456,69 @@ void bli_zgemm_haswell_asm_4x3 vpermilpd(imm(0x5), ymm11, ymm11) vpermilpd(imm(0x5), ymm14, ymm14) vpermilpd(imm(0x5), ymm15, ymm15) - - + + // subtract/add even/odd elements vaddsubpd(ymm6, ymm4, ymm4) vaddsubpd(ymm7, ymm5, ymm5) - + vaddsubpd(ymm10, ymm8, ymm8) vaddsubpd(ymm11, ymm9, ymm9) - + vaddsubpd(ymm14, ymm12, ymm12) vaddsubpd(ymm15, ymm13, ymm13) - - - - + + + + mov(var(alpha), rax) // load address of alpha vbroadcastsd(mem(rax), ymm0) // load alpha_r and duplicate vbroadcastsd(mem(rax, 8), ymm1) // load alpha_i and duplicate - - + + vpermilpd(imm(0x5), ymm4, ymm3) vmulpd(ymm0, ymm4, ymm4) vmulpd(ymm1, ymm3, ymm3) vaddsubpd(ymm3, ymm4, ymm4) - + vpermilpd(imm(0x5), ymm5, ymm3) vmulpd(ymm0, ymm5, ymm5) vmulpd(ymm1, ymm3, ymm3) vaddsubpd(ymm3, ymm5, ymm5) - - + + vpermilpd(imm(0x5), ymm8, ymm3) vmulpd(ymm0, ymm8, ymm8) vmulpd(ymm1, ymm3, ymm3) vaddsubpd(ymm3, ymm8, ymm8) - + vpermilpd(imm(0x5), ymm9, ymm3) vmulpd(ymm0, ymm9, ymm9) vmulpd(ymm1, ymm3, ymm3) vaddsubpd(ymm3, ymm9, ymm9) - - + + vpermilpd(imm(0x5), ymm12, ymm3) vmulpd(ymm0, ymm12, ymm12) vmulpd(ymm1, ymm3, ymm3) vaddsubpd(ymm3, ymm12, ymm12) - + vpermilpd(imm(0x5), ymm13, ymm3) vmulpd(ymm0, ymm13, ymm13) vmulpd(ymm1, ymm3, ymm3) vaddsubpd(ymm3, ymm13, ymm13) - - - - - + + + + + mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate - - - - - mov(var(rs_c), rsi) // load rs_c - lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(dcomplex) - lea(mem(, rsi, 2), rsi) - lea(mem(, rsi, 2), rdx) // rdx = 2*rs_c; - - - + + + + // now avoid loading C if beta == 0 vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm1) // set ZF if beta_r == 0. @@ -2094,169 +1527,54 @@ void bli_zgemm_haswell_asm_4x3 sete(r9b) // r9b = ( ZF == 1 ? 1 : 0 ); and(r8b, r9b) // set ZF if r8b & r9b == 1. jne(.ZBETAZERO) // if ZF = 1, jump to beta == 0 case - - - cmp(imm(16), rsi) // set ZF if (16*rs_c) == 16. 
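/* The vpermilpd/vmulpd/vaddsubpd triples above are the standard AVX
 * complex-scaling idiom. For a packed pair x = (x_r, x_i) and a scalar
 * s = s_r + i*s_i broadcast into two registers:
 *
 *     t  = (x_i, x_r);           // vpermilpd 0x5: swap re/im lanes
 *     x *= (s_r, s_r);           // scale by the real part
 *     t *= (s_i, s_i);           // scale the swapped copy by the imag part
 *     x  = (x_r*s_r - x_i*s_i,   // vaddsubpd: subtract in even lanes,
 *           x_i*s_r + x_r*s_i);  //            add in odd lanes
 *
 * which is exactly (s_r + i*s_i) * (x_r + i*x_i), with no shuffle of the
 * accumulated result needed afterwards. */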
- jz(.ZCOLSTORED) // jump to row storage case - - - - label(.ZGENSTORED) - - - ZGEMM_INPUT_SCALE_GS_BETA_NZ - vaddpd(ymm4, ymm0, ymm0) - ZGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*rs_c; - - - ZGEMM_INPUT_SCALE_GS_BETA_NZ - vaddpd(ymm5, ymm0, ymm0) - ZGEMM_OUTPUT_GS - mov(r11, rcx) // rcx = c + 1*cs_c - - - - ZGEMM_INPUT_SCALE_GS_BETA_NZ - vaddpd(ymm8, ymm0, ymm0) - ZGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*rs_c; - - - ZGEMM_INPUT_SCALE_GS_BETA_NZ - vaddpd(ymm9, ymm0, ymm0) - ZGEMM_OUTPUT_GS - mov(r12, rcx) // rcx = c + 2*cs_c - - - - ZGEMM_INPUT_SCALE_GS_BETA_NZ - vaddpd(ymm12, ymm0, ymm0) - ZGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*rs_c; - - - ZGEMM_INPUT_SCALE_GS_BETA_NZ - vaddpd(ymm13, ymm0, ymm0) - ZGEMM_OUTPUT_GS - - - - jmp(.ZDONE) // jump to end. - - - - label(.ZCOLSTORED) - - - ZGEMM_INPUT_SCALE_CS_BETA_NZ - vaddpd(ymm4, ymm0, ymm0) - ZGEMM_OUTPUT_CS - add(rdx, rcx) // c += 2*rs_c; - - - ZGEMM_INPUT_SCALE_CS_BETA_NZ - vaddpd(ymm5, ymm0, ymm0) - ZGEMM_OUTPUT_CS - mov(r11, rcx) // rcx = c + 1*cs_c - - - - ZGEMM_INPUT_SCALE_CS_BETA_NZ - vaddpd(ymm8, ymm0, ymm0) - ZGEMM_OUTPUT_CS - add(rdx, rcx) // c += 2*rs_c; - - - ZGEMM_INPUT_SCALE_CS_BETA_NZ - vaddpd(ymm9, ymm0, ymm0) - ZGEMM_OUTPUT_CS - mov(r12, rcx) // rcx = c + 2*cs_c - - - - ZGEMM_INPUT_SCALE_CS_BETA_NZ - vaddpd(ymm12, ymm0, ymm0) - ZGEMM_OUTPUT_CS - add(rdx, rcx) // c += 2*rs_c; - - - ZGEMM_INPUT_SCALE_CS_BETA_NZ - vaddpd(ymm13, ymm0, ymm0) - ZGEMM_OUTPUT_CS - - - - jmp(.ZDONE) // jump to end. - - - + + ZGEMM_INPUT_SCALE_CS_BETA_NZ(mem(rcx)) + vaddpd(ymm4, ymm0, ymm0) + vmovupd(ymm0, mem(rcx)) + + + ZGEMM_INPUT_SCALE_CS_BETA_NZ(mem(rcx,32)) + vaddpd(ymm5, ymm0, ymm0) + vmovupd(ymm0, mem(rcx,32)) + + + + ZGEMM_INPUT_SCALE_CS_BETA_NZ(mem(r11)) + vaddpd(ymm8, ymm0, ymm0) + vmovupd(ymm0, mem(r11)) + + + ZGEMM_INPUT_SCALE_CS_BETA_NZ(mem(r11,32)) + vaddpd(ymm9, ymm0, ymm0) + vmovupd(ymm0, mem(r11,32)) + + + + ZGEMM_INPUT_SCALE_CS_BETA_NZ(mem(r12)) + vaddpd(ymm12, ymm0, ymm0) + vmovupd(ymm0, mem(r12)) + + + ZGEMM_INPUT_SCALE_CS_BETA_NZ(mem(r12,32)) + vaddpd(ymm13, ymm0, ymm0) + vmovupd(ymm0, mem(r12,32)) + + jmp(.ZDONE) // jump to end. + label(.ZBETAZERO) - - cmp(imm(16), rsi) // set ZF if (16*rs_c) == 16. - jz(.ZCOLSTORBZ) // jump to row storage case - - - - label(.ZGENSTORBZ) - - - vmovapd(ymm4, ymm0) - ZGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*rs_c; - - - vmovapd(ymm5, ymm0) - ZGEMM_OUTPUT_GS - mov(r11, rcx) // rcx = c + 1*cs_c - - - - vmovapd(ymm8, ymm0) - ZGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*rs_c; - - - vmovapd(ymm9, ymm0) - ZGEMM_OUTPUT_GS - mov(r12, rcx) // rcx = c + 2*cs_c - - - - vmovapd(ymm12, ymm0) - ZGEMM_OUTPUT_GS - add(rdx, rcx) // c += 2*rs_c; - - - vmovapd(ymm13, ymm0) - ZGEMM_OUTPUT_GS - - - - jmp(.ZDONE) // jump to end. 
- - - - label(.ZCOLSTORBZ) - - - vmovupd(ymm4, mem(rcx)) - vmovupd(ymm5, mem(rcx, rdx, 1)) - - vmovupd(ymm8, mem(r11)) - vmovupd(ymm9, mem(r11, rdx, 1)) - - vmovupd(ymm12, mem(r12)) - vmovupd(ymm13, mem(r12, rdx, 1)) - - - - - - + + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm5, mem(rcx,32)) + + vmovupd(ymm8, mem(r11)) + vmovupd(ymm9, mem(r11,32)) + + vmovupd(ymm12, mem(r12)) + vmovupd(ymm13, mem(r12,32)) + label(.ZDONE) - - + + end_asm( : // output operands (none) @@ -2273,7 +1591,7 @@ void bli_zgemm_haswell_asm_4x3 [b_next] "m" (b_next), // 9 [a_next] "m" (a_next)*/ // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -2281,6 +1599,8 @@ void bli_zgemm_haswell_asm_4x3 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( z ); } diff --git a/kernels/knc/3/bli_dgemm_knc_asm_30x8.c b/kernels/knc/3/bli_dgemm_knc_asm_30x8.c index 880632ae07..f20e43f7cc 100644 --- a/kernels/knc/3/bli_dgemm_knc_asm_30x8.c +++ b/kernels/knc/3/bli_dgemm_knc_asm_30x8.c @@ -256,6 +256,8 @@ extern int offsets[16]; //#define LOOPMON void bli_dgemm_knc_asm_30x8 ( + dim_t m, + dim_t n, dim_t k, double* restrict alpha, double* restrict a, @@ -273,80 +275,82 @@ void bli_dgemm_knc_asm_30x8 uint64_t k64 = k; + GEMM_UKR_SETUP_CT( d, 30, 8, true ); + #ifdef MONITORS int toph, topl, both, botl, midl, midh, mid2l, mid2h; #endif #ifdef LOOPMON int tlooph, tloopl, blooph, bloopl; #endif - + __asm { #ifdef MONITORS rdtsc mov topl, eax - mov toph, edx + mov toph, edx #endif vpxord zmm0, zmm0, zmm0 vmovaps zmm1, zmm0 //clear out registers - vmovaps zmm2, zmm0 + vmovaps zmm2, zmm0 mov rsi, k64 //loop index - vmovaps zmm3, zmm0 + vmovaps zmm3, zmm0 mov r11, rs_c //load row stride - vmovaps zmm4, zmm0 + vmovaps zmm4, zmm0 sal r11, 3 //scale row stride - vmovaps zmm5, zmm0 + vmovaps zmm5, zmm0 mov r15, a //load address of a - vmovaps zmm6, zmm0 + vmovaps zmm6, zmm0 mov rbx, b //load address of b - vmovaps zmm7, zmm0 + vmovaps zmm7, zmm0 - vmovaps zmm8, zmm0 + vmovaps zmm8, zmm0 lea r10, [r11 + 2*r11 + 0] //r10 has 3 * r11 vmovaps zmm9, zmm0 - vmovaps zmm10, zmm0 - mov rdi, r11 - vmovaps zmm11, zmm0 + vmovaps zmm10, zmm0 + mov rdi, r11 + vmovaps zmm11, zmm0 sal rdi, 2 //rdi has 4*r11 - vmovaps zmm12, zmm0 + vmovaps zmm12, zmm0 mov rcx, c //load address of c for prefetching - vmovaps zmm13, zmm0 - vmovaps zmm14, zmm0 + vmovaps zmm13, zmm0 + vmovaps zmm14, zmm0 mov r8, k64 - vmovaps zmm15, zmm0 + vmovaps zmm15, zmm0 vmovaps zmm16, zmm0 vmovaps zmm17, zmm0 mov r13, L2_PREFETCH_DIST*8*8 - vmovaps zmm18, zmm0 + vmovaps zmm18, zmm0 mov r14, L2_PREFETCH_DIST*8*32 - vmovaps zmm19, zmm0 - vmovaps zmm20, zmm0 - vmovaps zmm21, zmm0 - vmovaps zmm22, zmm0 + vmovaps zmm19, zmm0 + vmovaps zmm20, zmm0 + vmovaps zmm21, zmm0 + vmovaps zmm22, zmm0 - vmovaps zmm23, zmm0 + vmovaps zmm23, zmm0 sub r8, 30 + L2_PREFETCH_DIST //Check if we have over 40 operations to do. 
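/* r13 and r14 above hold byte distances, not iteration counts: with 8
 * doubles of B and 32 doubles of A consumed per iteration (the packed A
 * panel is evidently 32 rows wide for MR = 30), they point
 * L2_PREFETCH_DIST iterations ahead in the packed panels. The kernel is
 * then staged in phases: the first 30 iterations stream C into L2
 * (LOOPREFECHCL2), the steady state prefetches A and B at that distance
 * (LOOPMAIN), and the final iterations pull C into L1 (LOOPREFETCHCL1)
 * just before the update. */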
- vmovaps zmm24, zmm0 + vmovaps zmm24, zmm0 mov r8, 30 - vmovaps zmm25, zmm0 + vmovaps zmm25, zmm0 mov r9, 8*8 //amount to increment b* by each iteration - vmovaps zmm26, zmm0 + vmovaps zmm26, zmm0 mov r12, 32*8 //amount to increment a* by each iteration - vmovaps zmm27, zmm0 - vmovaps zmm28, zmm0 - vmovaps zmm29, zmm0 + vmovaps zmm27, zmm0 + vmovaps zmm28, zmm0 + vmovaps zmm29, zmm0 #ifdef MONITORS rdtsc mov midl, eax - mov midh, edx + mov midh, edx #endif jle CONSIDER_UNDER_40 sub rsi, 30 + L2_PREFETCH_DIST - + //First 30 iterations LOOPREFECHCL2: ONE_ITER_PC_L2(rcx) @@ -357,26 +361,26 @@ void bli_dgemm_knc_asm_30x8 LOOPMAIN: ONE_ITER_MAIN_LOOP(rcx, rsi) jne LOOPMAIN - + //Penultimate 22 iterations. //Break these off from the main loop to avoid prefetching extra shit. mov r14, a_next mov r13, b_next sub r14, r15 sub r13, rbx - + mov rsi, L2_PREFETCH_DIST-10 LOOPMAIN2: ONE_ITER_MAIN_LOOP(rcx, rsi) jne LOOPMAIN2 - - + + //Last 10 iterations mov r8, 10 LOOPREFETCHCL1: ONE_ITER_PC_L1(rcx) jne LOOPREFETCHCL1 - + jmp POSTACCUM @@ -403,14 +407,8 @@ void bli_dgemm_knc_asm_30x8 mov r9, c //load address of c for update mov r12, alpha //load address of alpha - // Check if C is row stride. If not, jump to the slow scattered update - mov r14, cs_c - dec r14 - jne SCATTEREDUPDATE - mov r14, beta - vbroadcastsd zmm31, 0[r14] - + vbroadcastsd zmm31, 0[r14] vmulpd zmm0, zmm0, 0[r12]{1to8} vmulpd zmm1, zmm1, 0[r12]{1to8} @@ -467,7 +465,7 @@ void bli_dgemm_knc_asm_30x8 vmovapd [r9+2*r11+0], zmm14 vmovapd [r9+r10+0], zmm15 add r9, rdi - + vmulpd zmm16, zmm16, 0[r12]{1to8} vmulpd zmm17, zmm17, 0[r12]{1to8} vmulpd zmm18, zmm18, 0[r12]{1to8} @@ -516,47 +514,6 @@ void bli_dgemm_knc_asm_30x8 vfmadd231pd zmm29, zmm31, [r9+r11+0] vmovapd [r9+0], zmm28 vmovapd [r9+r11+0], zmm29 - - jmp END - - SCATTEREDUPDATE: - mov r10, offsetPtr - vmovapd zmm31, 0[r10] - vpbroadcastd zmm30, cs_c - mov r13, beta - vpmulld zmm30, zmm31, zmm30 - - mov ebx, 255 - UPDATE_C_ROW_SCATTERED(zmm0, 0, r9) - UPDATE_C_ROW_SCATTERED(zmm1, 1, r9) - UPDATE_C_ROW_SCATTERED(zmm2, 2, r9) - UPDATE_C_ROW_SCATTERED(zmm3, 3, r9) - UPDATE_C_ROW_SCATTERED(zmm4, 4, r9) - UPDATE_C_ROW_SCATTERED(zmm5, 5, r9) - UPDATE_C_ROW_SCATTERED(zmm6, 6, r9) - UPDATE_C_ROW_SCATTERED(zmm7, 7, r9) - UPDATE_C_ROW_SCATTERED(zmm8, 8, r9) - UPDATE_C_ROW_SCATTERED(zmm9, 9, r9) - UPDATE_C_ROW_SCATTERED(zmm10, 10, r9) - UPDATE_C_ROW_SCATTERED(zmm11, 11, r9) - UPDATE_C_ROW_SCATTERED(zmm12, 12, r9) - UPDATE_C_ROW_SCATTERED(zmm13, 13, r9) - UPDATE_C_ROW_SCATTERED(zmm14, 14, r9) - UPDATE_C_ROW_SCATTERED(zmm15, 15, r9) - UPDATE_C_ROW_SCATTERED(zmm16, 16, r9) - UPDATE_C_ROW_SCATTERED(zmm17, 17, r9) - UPDATE_C_ROW_SCATTERED(zmm18, 18, r9) - UPDATE_C_ROW_SCATTERED(zmm19, 19, r9) - UPDATE_C_ROW_SCATTERED(zmm20, 20, r9) - UPDATE_C_ROW_SCATTERED(zmm21, 21, r9) - UPDATE_C_ROW_SCATTERED(zmm22, 22, r9) - UPDATE_C_ROW_SCATTERED(zmm23, 23, r9) - UPDATE_C_ROW_SCATTERED(zmm24, 24, r9) - UPDATE_C_ROW_SCATTERED(zmm25, 25, r9) - UPDATE_C_ROW_SCATTERED(zmm26, 26, r9) - UPDATE_C_ROW_SCATTERED(zmm27, 27, r9) - UPDATE_C_ROW_SCATTERED(zmm28, 28, r9) - UPDATE_C_ROW_SCATTERED(zmm29, 29, r9) END: #ifdef MONITORS @@ -566,6 +523,8 @@ void bli_dgemm_knc_asm_30x8 #endif } + GEMM_UKR_FLUSH_CT( d ); + #ifdef LOOPMON printf("looptime = \t%d\n", bloopl - tloopl); #endif diff --git a/kernels/knc/3/bli_sgemm_knc_asm_30x16.c b/kernels/knc/3/bli_sgemm_knc_asm_30x16.c index 866cb62ec1..18a8e5e2ee 100644 --- a/kernels/knc/3/bli_sgemm_knc_asm_30x16.c +++ b/kernels/knc/3/bli_sgemm_knc_asm_30x16.c @@ -256,6 +256,8 @@ int 
offsets[16] __attribute__((aligned(0x1000))) = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9 //#define LOOPMON void bli_sgemm_knc_asm_30x16 ( + dim_t m, + dim_t n, dim_t k, float* restrict alpha, float* restrict a, @@ -273,80 +275,82 @@ void bli_sgemm_knc_asm_30x16 uint64_t k64 = k; + GEMM_UKR_SETUP_CT( s, 30, 16, true ); + #ifdef MONITORS int toph, topl, both, botl, midl, midh, mid2l, mid2h; #endif #ifdef LOOPMON int tlooph, tloopl, blooph, bloopl; #endif - + __asm { #ifdef MONITORS rdtsc mov topl, eax - mov toph, edx + mov toph, edx #endif vpxord zmm0, zmm0, zmm0 vmovaps zmm1, zmm0 //clear out registers - vmovaps zmm2, zmm0 + vmovaps zmm2, zmm0 mov rsi, k64 //loop index - vmovaps zmm3, zmm0 + vmovaps zmm3, zmm0 mov r11, rs_c //load row stride - vmovaps zmm4, zmm0 + vmovaps zmm4, zmm0 sal r11, 2 //scale row stride - vmovaps zmm5, zmm0 + vmovaps zmm5, zmm0 mov r15, a //load address of a - vmovaps zmm6, zmm0 + vmovaps zmm6, zmm0 mov rbx, b //load address of b - vmovaps zmm7, zmm0 + vmovaps zmm7, zmm0 - vmovaps zmm8, zmm0 + vmovaps zmm8, zmm0 lea r10, [r11 + 2*r11 + 0] //r10 has 3 * r11 vmovaps zmm9, zmm0 - vmovaps zmm10, zmm0 - mov rdi, r11 - vmovaps zmm11, zmm0 + vmovaps zmm10, zmm0 + mov rdi, r11 + vmovaps zmm11, zmm0 sal rdi, 2 //rdi has 4*r11 - vmovaps zmm12, zmm0 + vmovaps zmm12, zmm0 mov rcx, c //load address of c for prefetching - vmovaps zmm13, zmm0 - vmovaps zmm14, zmm0 + vmovaps zmm13, zmm0 + vmovaps zmm14, zmm0 mov r8, k64 - vmovaps zmm15, zmm0 + vmovaps zmm15, zmm0 vmovaps zmm16, zmm0 vmovaps zmm17, zmm0 mov r13, L2_PREFETCH_DIST*4*16 - vmovaps zmm18, zmm0 + vmovaps zmm18, zmm0 mov r14, L2_PREFETCH_DIST*4*32 - vmovaps zmm19, zmm0 - vmovaps zmm20, zmm0 - vmovaps zmm21, zmm0 - vmovaps zmm22, zmm0 + vmovaps zmm19, zmm0 + vmovaps zmm20, zmm0 + vmovaps zmm21, zmm0 + vmovaps zmm22, zmm0 - vmovaps zmm23, zmm0 + vmovaps zmm23, zmm0 sub r8, 30 + L2_PREFETCH_DIST //Check if we have over 40 operations to do. - vmovaps zmm24, zmm0 + vmovaps zmm24, zmm0 mov r8, 30 - vmovaps zmm25, zmm0 + vmovaps zmm25, zmm0 mov r9, 16*4 //amount to increment b* by each iteration - vmovaps zmm26, zmm0 + vmovaps zmm26, zmm0 mov r12, 32*4 //amount to increment a* by each iteration - vmovaps zmm27, zmm0 - vmovaps zmm28, zmm0 - vmovaps zmm29, zmm0 + vmovaps zmm27, zmm0 + vmovaps zmm28, zmm0 + vmovaps zmm29, zmm0 #ifdef MONITORS rdtsc mov midl, eax - mov midh, edx + mov midh, edx #endif jle CONSIDER_UNDER_40 sub rsi, 30 + L2_PREFETCH_DIST - + //First 30 iterations LOOPREFECHCL2: ONE_ITER_PC_L2(rcx) @@ -357,26 +361,26 @@ void bli_sgemm_knc_asm_30x16 LOOPMAIN: ONE_ITER_MAIN_LOOP(rcx, rsi) jne LOOPMAIN - + //Penultimate 22 iterations. //Break these off from the main loop to avoid prefetching extra shit. mov r14, a_next mov r13, b_next sub r14, r15 sub r13, rbx - + mov rsi, L2_PREFETCH_DIST-10 LOOPMAIN2: ONE_ITER_MAIN_LOOP(rcx, rsi) jne LOOPMAIN2 - - + + //Last 10 iterations mov r8, 10 LOOPREFETCHCL1: ONE_ITER_PC_L1(rcx) jne LOOPREFETCHCL1 - + jmp POSTACCUM @@ -384,7 +388,7 @@ void bli_sgemm_knc_asm_30x16 //Used when <= 40 iterations CONSIDER_UNDER_40: mov rsi, k64 - test rsi, rsi + test rsi, rsi je POSTACCUM LOOP_UNDER_40: ONE_ITER_MAIN_LOOP(rcx, rsi) @@ -403,13 +407,8 @@ void bli_sgemm_knc_asm_30x16 mov r9, c //load address of c for update mov r12, alpha //load address of alpha - // Check if C is row stride. 
If not, jump to the slow scattered update - mov r14, cs_c - dec r14 - jne SCATTEREDUPDATE - mov r14, beta - vbroadcastss zmm31, 0[r14] + vbroadcastss zmm31, 0[r14] vmulps zmm0, zmm0, 0[r12]{1to16} @@ -467,7 +466,7 @@ void bli_sgemm_knc_asm_30x16 vmovaps [r9+2*r11+0], zmm14 vmovaps [r9+r10+0], zmm15 add r9, rdi - + vmulps zmm16, zmm16, 0[r12]{1to16} vmulps zmm17, zmm17, 0[r12]{1to16} vmulps zmm18, zmm18, 0[r12]{1to16} @@ -516,48 +515,6 @@ void bli_sgemm_knc_asm_30x16 vfmadd231ps zmm29, zmm31, [r9+r11+0] vmovaps [r9+0], zmm28 vmovaps [r9+r11+0], zmm29 - - jmp END - - SCATTEREDUPDATE: - - mov r10, offsetPtr - vmovaps zmm31, 0[r10] - vpbroadcastd zmm30, cs_c - mov r13, beta - vpmulld zmm30, zmm31, zmm30 - - mov ebx, 0xFFFF - UPDATE_C_ROW_SCATTERED(zmm0, 0, r9) - UPDATE_C_ROW_SCATTERED(zmm1, 1, r9) - UPDATE_C_ROW_SCATTERED(zmm2, 2, r9) - UPDATE_C_ROW_SCATTERED(zmm3, 3, r9) - UPDATE_C_ROW_SCATTERED(zmm4, 4, r9) - UPDATE_C_ROW_SCATTERED(zmm5, 5, r9) - UPDATE_C_ROW_SCATTERED(zmm6, 6, r9) - UPDATE_C_ROW_SCATTERED(zmm7, 7, r9) - UPDATE_C_ROW_SCATTERED(zmm8, 8, r9) - UPDATE_C_ROW_SCATTERED(zmm9, 9, r9) - UPDATE_C_ROW_SCATTERED(zmm10, 10, r9) - UPDATE_C_ROW_SCATTERED(zmm11, 11, r9) - UPDATE_C_ROW_SCATTERED(zmm12, 12, r9) - UPDATE_C_ROW_SCATTERED(zmm13, 13, r9) - UPDATE_C_ROW_SCATTERED(zmm14, 14, r9) - UPDATE_C_ROW_SCATTERED(zmm15, 15, r9) - UPDATE_C_ROW_SCATTERED(zmm16, 16, r9) - UPDATE_C_ROW_SCATTERED(zmm17, 17, r9) - UPDATE_C_ROW_SCATTERED(zmm18, 18, r9) - UPDATE_C_ROW_SCATTERED(zmm19, 19, r9) - UPDATE_C_ROW_SCATTERED(zmm20, 20, r9) - UPDATE_C_ROW_SCATTERED(zmm21, 21, r9) - UPDATE_C_ROW_SCATTERED(zmm22, 22, r9) - UPDATE_C_ROW_SCATTERED(zmm23, 23, r9) - UPDATE_C_ROW_SCATTERED(zmm24, 24, r9) - UPDATE_C_ROW_SCATTERED(zmm25, 25, r9) - UPDATE_C_ROW_SCATTERED(zmm26, 26, r9) - UPDATE_C_ROW_SCATTERED(zmm27, 27, r9) - UPDATE_C_ROW_SCATTERED(zmm28, 28, r9) - UPDATE_C_ROW_SCATTERED(zmm29, 29, r9) END: #ifdef MONITORS @@ -567,6 +524,8 @@ void bli_sgemm_knc_asm_30x16 #endif } + GEMM_UKR_FLUSH_CT( s ); + #ifdef LOOPMON printf("looptime = \t%d\n", bloopl - tloopl); #endif diff --git a/kernels/penryn/3/bli_gemm_penryn_asm_d4x4.c b/kernels/penryn/3/bli_gemm_penryn_asm_d4x4.c index e52cc9e0e0..5d24f6e86e 100644 --- a/kernels/penryn/3/bli_gemm_penryn_asm_d4x4.c +++ b/kernels/penryn/3/bli_gemm_penryn_asm_d4x4.c @@ -39,7 +39,9 @@ void bli_sgemm_penryn_asm_8x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, float* restrict alpha, float* restrict a, float* restrict b, @@ -54,38 +56,40 @@ void bli_sgemm_penryn_asm_8x4 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT_ALIGNED( s, 8, 4, false, 16 ); + begin_asm() - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. mov(var(b_next), r9) // load address of b_next. - + sub(imm(0-8*16), rax) // increment pointers to allow byte sub(imm(0-8*16), rbx) // offsets in the unrolled iterations. - + movaps(mem(rax, -8*16), xmm0) // initialize loop by pre-loading elements movaps(mem(rax, -7*16), xmm1) // of a and b. 
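Here the Penryn kernel switches to GEMM_UKR_SETUP_CT_ALIGNED( s, 8, 4, false, 16 ): its accesses to C below use movaps, which faults on addresses that are not 16-byte aligned, so this macro variant presumably also guarantees alignment of whatever tile the asm sees. That is what allows the patch to delete the kernel's runtime cmp/sete/test/setz alignment checks further down. A sketch of the predicate those checks implemented, under that assumption (names are illustrative; any C failing the test would simply be redirected to the aligned scratch tile):

    #include <stdalign.h>
    #include <stdbool.h>
    #include <stdint.h>

    typedef long dim_t;
    typedef long inc_t;

    static bool c_ok_for_movaps( const float* c, dim_t m, dim_t n,
                                 inc_t rs_c, inc_t cs_c )
    {
        return m == 8 && n == 4                        /* full microtile */
            && rs_c == 1                               /* column-stored  */
            && (uintptr_t)c % 16 == 0                  /* base aligned   */
            && ( cs_c * sizeof(float) ) % 16 == 0;     /* ldim aligned   */
    }

    void example( float* c, dim_t m, dim_t n, inc_t rs_c, inc_t cs_c )
    {
        alignas(16) float ct[ 8 * 4 ];  /* scratch tile: always aligned */
        float* c_use = c_ok_for_movaps( c, m, n, rs_c, cs_c ) ? c : ct;
        (void)c_use;  /* ...the asm then only ever sees an aligned C... */
    }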
movaps(mem(rbx, -8*16), xmm2) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 4), rdi) // cs_c *= sizeof(float) mov(rdi, r12) // make a copy of cs_c (in bytes) lea(mem(rcx, rdi, 2), r10) // load address of c + 2*cs_c; - + prefetch(2, mem(r9, 0*4)) // prefetch b_next - + xorps(xmm3, xmm3) xorps(xmm4, xmm4) xorps(xmm5, xmm5) xorps(xmm6, xmm6) - + prefetch(2, mem(rcx, 6*4)) // prefetch c + 0*cs_c xorps(xmm8, xmm8) xorps(xmm9, xmm9) @@ -98,33 +102,33 @@ void bli_sgemm_penryn_asm_8x4 prefetch(2, mem(r10, rdi, 1, 6*4)) // prefetch c + 3*cs_c xorps(xmm14, xmm14) xorps(xmm15, xmm15) - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.SCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.SLOOPKITER) // MAIN LOOP - + prefetch(0, mem(rax, (4*35+1)*8)) - + addps(xmm6, xmm10) // iteration 0 addps(xmm3, xmm14) movaps(xmm2, xmm3) pshufd(imm(0x39), xmm2, xmm7) mulps(xmm0, xmm2) mulps(xmm1, xmm3) - + addps(xmm4, xmm11) addps(xmm5, xmm15) movaps(xmm7, xmm5) pshufd(imm(0x39), xmm7, xmm6) mulps(xmm0, xmm7) mulps(xmm1, xmm5) - + addps(xmm2, xmm8) movaps(mem(rbx, -7*16), xmm2) addps(xmm3, xmm12) @@ -132,7 +136,7 @@ void bli_sgemm_penryn_asm_8x4 pshufd(imm(0x39), xmm6, xmm4) mulps(xmm0, xmm6) mulps(xmm1, xmm3) - + addps(xmm7, xmm9) addps(xmm5, xmm13) movaps(xmm4, xmm5) @@ -140,22 +144,22 @@ void bli_sgemm_penryn_asm_8x4 movaps(mem(rax, -6*16), xmm0) mulps(xmm1, xmm5) movaps(mem(rax, -5*16), xmm1) - - + + addps(xmm6, xmm10) // iteration 1 addps(xmm3, xmm14) movaps(xmm2, xmm3) pshufd(imm(0x39), xmm2, xmm7) mulps(xmm0, xmm2) mulps(xmm1, xmm3) - + addps(xmm4, xmm11) addps(xmm5, xmm15) movaps(xmm7, xmm5) pshufd(imm(0x39), xmm7, xmm6) mulps(xmm0, xmm7) mulps(xmm1, xmm5) - + addps(xmm2, xmm8) movaps(mem(rbx, -6*16), xmm2) addps(xmm3, xmm12) @@ -163,7 +167,7 @@ void bli_sgemm_penryn_asm_8x4 pshufd(imm(0x39), xmm6, xmm4) mulps(xmm0, xmm6) mulps(xmm1, xmm3) - + addps(xmm7, xmm9) addps(xmm5, xmm13) movaps(xmm4, xmm5) @@ -171,22 +175,22 @@ void bli_sgemm_penryn_asm_8x4 movaps(mem(rax, -4*16), xmm0) mulps(xmm1, xmm5) movaps(mem(rax, -3*16), xmm1) - - + + addps(xmm6, xmm10) // iteration 2 addps(xmm3, xmm14) movaps(xmm2, xmm3) pshufd(imm(0x39), xmm2, xmm7) mulps(xmm0, xmm2) mulps(xmm1, xmm3) - + addps(xmm4, xmm11) addps(xmm5, xmm15) movaps(xmm7, xmm5) pshufd(imm(0x39), xmm7, xmm6) mulps(xmm0, xmm7) mulps(xmm1, xmm5) - + addps(xmm2, xmm8) movaps(mem(rbx, -5*16), xmm2) addps(xmm3, xmm12) @@ -194,7 +198,7 @@ void bli_sgemm_penryn_asm_8x4 pshufd(imm(0x39), xmm6, xmm4) mulps(xmm0, xmm6) mulps(xmm1, xmm3) - + addps(xmm7, xmm9) addps(xmm5, xmm13) movaps(xmm4, xmm5) @@ -202,26 +206,26 @@ void bli_sgemm_penryn_asm_8x4 movaps(mem(rax, -2*16), xmm0) mulps(xmm1, xmm5) movaps(mem(rax, -1*16), xmm1) - - + + addps(xmm6, xmm10) // iteration 3 addps(xmm3, xmm14) movaps(xmm2, xmm3) pshufd(imm(0x39), xmm2, xmm7) mulps(xmm0, xmm2) mulps(xmm1, xmm3) - + sub(imm(0-4*8*4), rax) // a += 4*8 (unroll x mr) - + addps(xmm4, xmm11) addps(xmm5, xmm15) movaps(xmm7, xmm5) pshufd(imm(0x39), xmm7, xmm6) mulps(xmm0, xmm7) mulps(xmm1, xmm5) - + sub(imm(0-4*4*4), r9) // b_next += 4*4 (unroll x nr) - + addps(xmm2, xmm8) movaps(mem(rbx, -4*16), xmm2) addps(xmm3, xmm12) @@ -229,9 +233,9 @@ void bli_sgemm_penryn_asm_8x4 pshufd(imm(0x39), xmm6, xmm4) mulps(xmm0, xmm6) mulps(xmm1, xmm3) - + sub(imm(0-4*4*4), rbx) // b += 4*4 (unroll x nr) - + addps(xmm7, xmm9) addps(xmm5, xmm13) movaps(xmm4, xmm5) @@ -239,40 +243,40 @@ void bli_sgemm_penryn_asm_8x4 
movaps(mem(rax, -8*16), xmm0) mulps(xmm1, xmm5) movaps(mem(rax, -7*16), xmm1) - + prefetch(2, mem(r9, 0*4)) // prefetch b_next[0] prefetch(2, mem(r9, 16*4)) // prefetch b_next[16] - - + + dec(rsi) // i -= 1; jne(.SLOOPKITER) // iterate again if i != 0. - - - + + + label(.SCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.SPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.SLOOPKLEFT) // EDGE LOOP - + addps(xmm6, xmm10) // iteration 0 addps(xmm3, xmm14) movaps(xmm2, xmm3) pshufd(imm(0x39), xmm2, xmm7) mulps(xmm0, xmm2) mulps(xmm1, xmm3) - + addps(xmm4, xmm11) addps(xmm5, xmm15) movaps(xmm7, xmm5) pshufd(imm(0x39), xmm7, xmm6) mulps(xmm0, xmm7) mulps(xmm1, xmm5) - + addps(xmm2, xmm8) movaps(mem(rbx, -7*16), xmm2) addps(xmm3, xmm12) @@ -280,7 +284,7 @@ void bli_sgemm_penryn_asm_8x4 pshufd(imm(0x39), xmm6, xmm4) mulps(xmm0, xmm6) mulps(xmm1, xmm3) - + addps(xmm7, xmm9) addps(xmm5, xmm13) movaps(xmm4, xmm5) @@ -288,40 +292,40 @@ void bli_sgemm_penryn_asm_8x4 movaps(mem(rax, -6*16), xmm0) mulps(xmm1, xmm5) movaps(mem(rax, -5*16), xmm1) - + sub(imm(0-1*8*4), rax) // a += 8 (1 x mr) sub(imm(0-1*4*4), rbx) // b += 4 (1 x nr) - - + + dec(rsi) // i -= 1; jne(.SLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.SPOSTACCUM) - + addps(xmm6, xmm10) addps(xmm3, xmm14) addps(xmm4, xmm11) addps(xmm5, xmm15) - - + + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta movss(mem(rax), xmm6) // load alpha to bottom 4 bytes of xmm6 movss(mem(rbx), xmm7) // load beta to bottom 4 bytes of xmm7 pshufd(imm(0x00), xmm6, xmm6) // populate xmm6 with four alphas pshufd(imm(0x00), xmm7, xmm7) // populate xmm7 with four betas - - + + mov(var(rs_c), rsi) // load rs_c mov(rsi, r8) // make a copy of rs_c - + lea(mem(, rsi, 4), rsi) // rsi = rs_c * sizeof(float) lea(mem(rsi, rsi, 2), r11) // r11 = 3*(rs_c * sizeof(float)) - + lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*rs_c; - + // xmm8: xmm9: xmm10: xmm11: // ( ab00 ( ab01 ( ab02 ( ab03 // ab11 ab12 ab13 ab10 @@ -338,20 +342,20 @@ void bli_sgemm_penryn_asm_8x4 shufps(imm(0xd8), xmm11, xmm8) shufps(imm(0xd8), xmm10, xmm11) shufps(imm(0xd8), xmm4, xmm10) - + movaps(xmm8, xmm4) shufps(imm(0xd8), xmm10, xmm8) shufps(imm(0xd8), xmm4, xmm10) movaps(xmm9, xmm5) shufps(imm(0xd8), xmm11, xmm9) shufps(imm(0xd8), xmm5, xmm11) - + movaps(xmm13, xmm4) shufps(imm(0xd8), xmm12, xmm13) shufps(imm(0xd8), xmm15, xmm12) shufps(imm(0xd8), xmm14, xmm15) shufps(imm(0xd8), xmm4, xmm14) - + movaps(xmm12, xmm4) shufps(imm(0xd8), xmm14, xmm12) shufps(imm(0xd8), xmm4, xmm14) @@ -369,456 +373,118 @@ void bli_sgemm_penryn_asm_8x4 // ab50 ab51 ab52 ab53 // ab60 ab61 ab62 ab63 // ab70 ) ab71 ) ab72 ) ab73 ) - - - - // determine if - // c % 16 == 0, AND - // 8*cs_c % 16 == 0, AND - // rs_c == 1 - // ie: aligned, ldim aligned, and - // column-stored - - cmp(imm(1), r8) // set ZF if rs_c == 1. - sete(bl) // bl = ( ZF == 1 ? 1 : 0 ); - test(imm(15), rcx) // set ZF if c & 16 is zero. - setz(bh) // bh = ( ZF == 1 ? 1 : 0 ); - test(imm(15), r12) // set ZF if (4*cs_c) & 16 is zero. - setz(al) // al = ( ZF == 1 ? 1 : 0 ); - // and(bl,bh) followed by - // and(bh,al) will reveal result - + // now avoid loading C if beta == 0 - + xorpd(xmm0, xmm0) // set xmm0 to zero. ucomisd(xmm0, xmm7) // check if beta == 0. je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case - - - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. 
- and(bh, al) // set ZF if bh & al == 1. - jne(.SCOLSTORED) // jump to column storage case - - - - label(.SGENSTORED) - - movlps(mem(rcx), xmm0) // load c00 ~ c30 - movhps(mem(rcx, rsi, 1), xmm0) - movlps(mem(rcx, rsi, 2), xmm1) - movhps(mem(rcx, r11, 1), xmm1) - shufps(imm(0x88), xmm1, xmm0) - - mulps(xmm6, xmm8) // scale by alpha, - mulps(xmm7, xmm0) // scale by beta, - addps(xmm8, xmm0) // add the gemm result, - - movss(xmm0, mem(rcx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rcx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rcx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rcx, r11, 1)) - - add(rdi, rcx) - - - movlps(mem(rdx), xmm0) // load c40 ~ c70 - movhps(mem(rdx, rsi, 1), xmm0) - movlps(mem(rdx, rsi, 2), xmm1) - movhps(mem(rdx, r11, 1), xmm1) - shufps(imm(0x88), xmm1, xmm0) - - mulps(xmm6, xmm12) // scale by alpha, - mulps(xmm7, xmm0) // scale by beta, - addps(xmm12, xmm0) // add the gemm result, - - movss(xmm0, mem(rdx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rdx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rdx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rdx, r11, 1)) - - add(rdi, rdx) - - - movlps(mem(rcx), xmm0) // load c01 ~ c31 - movhps(mem(rcx, rsi, 1), xmm0) - movlps(mem(rcx, rsi, 2), xmm1) - movhps(mem(rcx, r11, 1), xmm1) - shufps(imm(0x88), xmm1, xmm0) - - mulps(xmm6, xmm9) // scale by alpha, - mulps(xmm7, xmm0) // scale by beta, - addps(xmm9, xmm0) // add the gemm result, - - movss(xmm0, mem(rcx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rcx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rcx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rcx, r11, 1)) - - add(rdi, rcx) - - - movlps(mem(rdx), xmm0) // load c41 ~ c71 - movhps(mem(rdx, rsi, 1), xmm0) - movlps(mem(rdx, rsi, 2), xmm1) - movhps(mem(rdx, r11, 1), xmm1) - shufps(imm(0x88), xmm1, xmm0) - - mulps(xmm6, xmm13) // scale by alpha, - mulps(xmm7, xmm0) // scale by beta, - addps(xmm13, xmm0) // add the gemm result, - - movss(xmm0, mem(rdx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rdx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rdx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rdx, r11, 1)) - - add(rdi, rdx) - - - movlps(mem(rcx), xmm0) // load c02 ~ c32 - movhps(mem(rcx, rsi, 1), xmm0) - movlps(mem(rcx, rsi, 2), xmm1) - movhps(mem(rcx, r11, 1), xmm1) - shufps(imm(0x88), xmm1, xmm0) - - mulps(xmm6, xmm10) // scale by alpha, - mulps(xmm7, xmm0) // scale by beta, - addps(xmm10, xmm0) // add the gemm result, - - movss(xmm0, mem(rcx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rcx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rcx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rcx, r11, 1)) - - add(rdi, rcx) - - - movlps(mem(rdx), xmm0) // load c42 ~ c72 - movhps(mem(rdx, rsi, 1), xmm0) - movlps(mem(rdx, rsi, 2), xmm1) - movhps(mem(rdx, r11, 1), xmm1) - shufps(imm(0x88), xmm1, xmm0) - - mulps(xmm6, xmm14) // scale by alpha, - mulps(xmm7, xmm0) // scale by beta, - addps(xmm14, xmm0) // add the gemm result, - - movss(xmm0, mem(rdx)) // and store back to memory. 
- pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rdx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rdx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rdx, r11, 1)) - - add(rdi, rdx) - - - movlps(mem(rcx), xmm0) // load c03 ~ c33 - movhps(mem(rcx, rsi, 1), xmm0) - movlps(mem(rcx, rsi, 2), xmm1) - movhps(mem(rcx, r11, 1), xmm1) - shufps(imm(0x88), xmm1, xmm0) - - mulps(xmm6, xmm11) // scale by alpha, - mulps(xmm7, xmm0) // scale by beta, - addps(xmm11, xmm0) // add the gemm result, - - movss(xmm0, mem(rcx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rcx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rcx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rcx, r11, 1)) - - - - - movlps(mem(rdx), xmm0) // load c43 ~ c73 - movhps(mem(rdx, rsi, 1), xmm0) - movlps(mem(rdx, rsi, 2), xmm1) - movhps(mem(rdx, r11, 1), xmm1) - shufps(imm(0x88), xmm1, xmm0) - - mulps(xmm6, xmm15) // scale by alpha, - mulps(xmm7, xmm0) // scale by beta, - addps(xmm15, xmm0) // add the gemm result, - - movss(xmm0, mem(rdx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rdx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rdx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rdx, r11, 1)) - - - - - jmp(.SDONE) // jump to end. - - - - label(.SCOLSTORED) - - movaps(mem(rcx), xmm0) // load c00 ~ c30, - mulps(xmm6, xmm8) // scale by alpha, - mulps(xmm7, xmm0) // scale by beta, - addps(xmm8, xmm0) // add the gemm result, - movaps(xmm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) - - movaps(mem(rdx), xmm1) // load c40 ~ c70, - mulps(xmm6, xmm12) // scale by alpha, - mulps(xmm7, xmm1) // scale by beta, - addps(xmm12, xmm1) // add the gemm result, - movaps(xmm1, mem(rdx)) // and store back to memory. - add(rdi, rdx) - - - - movaps(mem(rcx), xmm0) // load c01 ~ c31, - mulps(xmm6, xmm9) // scale by alpha, - mulps(xmm7, xmm0) // scale by beta, - addps(xmm9, xmm0) // add the gemm result, - movaps(xmm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) - - movaps(mem(rdx), xmm1) // load c41 ~ c71, - mulps(xmm6, xmm13) // scale by alpha, - mulps(xmm7, xmm1) // scale by beta, - addps(xmm13, xmm1) // add the gemm result, - movaps(xmm1, mem(rdx)) // and store back to memory. - add(rdi, rdx) - - - - movaps(mem(rcx), xmm0) // load c02 ~ c32, - mulps(xmm6, xmm10) // scale by alpha, - mulps(xmm7, xmm0) // scale by beta, - addps(xmm10, xmm0) // add the gemm result, - movaps(xmm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) - - movaps(mem(rdx), xmm1) // load c42 ~ c72, - mulps(xmm6, xmm14) // scale by alpha, - mulps(xmm7, xmm1) // scale by beta, - addps(xmm14, xmm1) // add the gemm result, - movaps(xmm1, mem(rdx)) // and store back to memory. - add(rdi, rdx) - - - - movaps(mem(rcx), xmm0) // load c03 ~ c33, - mulps(xmm6, xmm11) // scale by alpha, - mulps(xmm7, xmm0) // scale by beta, - addps(xmm11, xmm0) // add the gemm result, - movaps(xmm0, mem(rcx)) // and store back to memory. - - - movaps(mem(rdx), xmm1) // load c43 ~ c73, - mulps(xmm6, xmm15) // scale by alpha, - mulps(xmm7, xmm1) // scale by beta, - addps(xmm15, xmm1) // add the gemm result, - movaps(xmm1, mem(rdx)) // and store back to memory. - - jmp(.SDONE) // jump to end. - - - - + + movaps(mem(rcx), xmm0) // load c00 ~ c30, + mulps(xmm6, xmm8) // scale by alpha, + mulps(xmm7, xmm0) // scale by beta, + addps(xmm8, xmm0) // add the gemm result, + movaps(xmm0, mem(rcx)) // and store back to memory. 
+ add(rdi, rcx) + + movaps(mem(rdx), xmm1) // load c40 ~ c70, + mulps(xmm6, xmm12) // scale by alpha, + mulps(xmm7, xmm1) // scale by beta, + addps(xmm12, xmm1) // add the gemm result, + movaps(xmm1, mem(rdx)) // and store back to memory. + add(rdi, rdx) + + + + movaps(mem(rcx), xmm0) // load c01 ~ c31, + mulps(xmm6, xmm9) // scale by alpha, + mulps(xmm7, xmm0) // scale by beta, + addps(xmm9, xmm0) // add the gemm result, + movaps(xmm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) + + movaps(mem(rdx), xmm1) // load c41 ~ c71, + mulps(xmm6, xmm13) // scale by alpha, + mulps(xmm7, xmm1) // scale by beta, + addps(xmm13, xmm1) // add the gemm result, + movaps(xmm1, mem(rdx)) // and store back to memory. + add(rdi, rdx) + + + + movaps(mem(rcx), xmm0) // load c02 ~ c32, + mulps(xmm6, xmm10) // scale by alpha, + mulps(xmm7, xmm0) // scale by beta, + addps(xmm10, xmm0) // add the gemm result, + movaps(xmm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) + + movaps(mem(rdx), xmm1) // load c42 ~ c72, + mulps(xmm6, xmm14) // scale by alpha, + mulps(xmm7, xmm1) // scale by beta, + addps(xmm14, xmm1) // add the gemm result, + movaps(xmm1, mem(rdx)) // and store back to memory. + add(rdi, rdx) + + + + movaps(mem(rcx), xmm0) // load c03 ~ c33, + mulps(xmm6, xmm11) // scale by alpha, + mulps(xmm7, xmm0) // scale by beta, + addps(xmm11, xmm0) // add the gemm result, + movaps(xmm0, mem(rcx)) // and store back to memory. + + + movaps(mem(rdx), xmm1) // load c43 ~ c73, + mulps(xmm6, xmm15) // scale by alpha, + mulps(xmm7, xmm1) // scale by beta, + addps(xmm15, xmm1) // add the gemm result, + movaps(xmm1, mem(rdx)) // and store back to memory. + + jmp(.SDONE) // jump to end. + label(.SBETAZERO) - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. - jne(.SCOLSTORBZ) // jump to column storage case - - - - label(.SGENSTORBZ) - - mulps(xmm6, xmm8) // scale by alpha, - movaps(xmm8, xmm0) - - movss(xmm0, mem(rcx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rcx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rcx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rcx, r11, 1)) - - add(rdi, rcx) - - - mulps(xmm6, xmm12) // scale by alpha, - movaps(xmm12, xmm0) - - movss(xmm0, mem(rdx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rdx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rdx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rdx, r11, 1)) - - add(rdi, rdx) - - - mulps(xmm6, xmm9) // scale by alpha, - movaps(xmm9, xmm0) - - movss(xmm0, mem(rcx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rcx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rcx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rcx, r11, 1)) - - add(rdi, rcx) - - - mulps(xmm6, xmm13) // scale by alpha, - movaps(xmm13, xmm0) - - movss(xmm0, mem(rdx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rdx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rdx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rdx, r11, 1)) - - add(rdi, rdx) - - - mulps(xmm6, xmm10) // scale by alpha, - movaps(xmm10, xmm0) - - movss(xmm0, mem(rcx)) // and store back to memory. 
- pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rcx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rcx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rcx, r11, 1)) - - add(rdi, rcx) - - - mulps(xmm6, xmm14) // scale by alpha, - movaps(xmm14, xmm0) - - movss(xmm0, mem(rdx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rdx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rdx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rdx, r11, 1)) - - add(rdi, rdx) - - - mulps(xmm6, xmm11) // scale by alpha, - movaps(xmm11, xmm0) - - movss(xmm0, mem(rcx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rcx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rcx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rcx, r11, 1)) - - - - - mulps(xmm6, xmm15) // scale by alpha, - movaps(xmm15, xmm0) - - movss(xmm0, mem(rdx)) // and store back to memory. - pshufd(imm(0x39), xmm0, xmm1) - movss(xmm1, mem(rdx, rsi, 1)) - pshufd(imm(0x39), xmm1, xmm2) - movss(xmm2, mem(rdx, rsi, 2)) - pshufd(imm(0x39), xmm2, xmm3) - movss(xmm3, mem(rdx, r11, 1)) - - - - - jmp(.SDONE) // jump to end. - - - - label(.SCOLSTORBZ) - - // skip loading c00 ~ c30, - mulps(xmm6, xmm8) // scale by alpha, - movaps(xmm8, mem(rcx)) // and store back to memory. - add(rdi, rcx) - // skip loading c40 ~ c70, - mulps(xmm6, xmm12) // scale by alpha, - movaps(xmm12, mem(rdx)) // and store back to memory. - add(rdi, rdx) - - - // skip loading c01 ~ c31, - mulps(xmm6, xmm9) // scale by alpha, - movaps(xmm9, mem(rcx)) // and store back to memory. - add(rdi, rcx) - // skip loading c41 ~ c71, - mulps(xmm6, xmm13) // scale by alpha, - movaps(xmm13, mem(rdx)) // and store back to memory. - add(rdi, rdx) - - - // skip loading c02 ~ c32, - mulps(xmm6, xmm10) // scale by alpha, - movaps(xmm10, mem(rcx)) // and store back to memory. - add(rdi, rcx) - // skip loading c42 ~ c72, - mulps(xmm6, xmm14) // scale by alpha, - movaps(xmm14, mem(rdx)) // and store back to memory. - add(rdi, rdx) - - - // skip loading c03 ~ c33, - mulps(xmm6, xmm11) // scale by alpha, - movaps(xmm11, mem(rcx)) // and store back to memory. - - // skip loading c43 ~ c73, - mulps(xmm6, xmm15) // scale by alpha, - movaps(xmm15, mem(rdx)) // and store back to memory. - - - - - - - - + + // skip loading c00 ~ c30, + mulps(xmm6, xmm8) // scale by alpha, + movaps(xmm8, mem(rcx)) // and store back to memory. + add(rdi, rcx) + // skip loading c40 ~ c70, + mulps(xmm6, xmm12) // scale by alpha, + movaps(xmm12, mem(rdx)) // and store back to memory. + add(rdi, rdx) + + + // skip loading c01 ~ c31, + mulps(xmm6, xmm9) // scale by alpha, + movaps(xmm9, mem(rcx)) // and store back to memory. + add(rdi, rcx) + // skip loading c41 ~ c71, + mulps(xmm6, xmm13) // scale by alpha, + movaps(xmm13, mem(rdx)) // and store back to memory. + add(rdi, rdx) + + + // skip loading c02 ~ c32, + mulps(xmm6, xmm10) // scale by alpha, + movaps(xmm10, mem(rcx)) // and store back to memory. + add(rdi, rcx) + // skip loading c42 ~ c72, + mulps(xmm6, xmm14) // scale by alpha, + movaps(xmm14, mem(rdx)) // and store back to memory. + add(rdi, rdx) + + + // skip loading c03 ~ c33, + mulps(xmm6, xmm11) // scale by alpha, + movaps(xmm11, mem(rcx)) // and store back to memory. + + // skip loading c43 ~ c73, + mulps(xmm6, xmm15) // scale by alpha, + movaps(xmm15, mem(rdx)) // and store back to memory. 
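With C guaranteed column-stored and aligned by the setup macro, the kernel keeps only its fast movaps paths above; the deleted .SGENSTORED/.SGENSTORBZ blocks, which gathered and scattered strided elements of C with movlps/movhps/pshufd/movss, are subsumed by the flush step. Their job, reduced to scalar C (a hypothetical helper, handling one four-element column slice at a time, as each xmm update did):

    typedef long inc_t;

    static void genstored_col_update( float alpha, const float ab[4],
                                      float beta, float* c, inc_t rs_c )
    {
        /* movlps/movhps gathered four strided elements; after the axpby
           update, movss scattered them back one at a time. */
        for ( int i = 0; i < 4; i++ )
            c[ i*rs_c ] = alpha*ab[i] + beta*c[ i*rs_c ];
    }

Centralizing this in the flush removes hundreds of lines of store code from each kernel, as the diffstat for these files reflects.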
+ label(.SDONE) - + end_asm( : // output operands (none) @@ -842,11 +508,15 @@ void bli_sgemm_penryn_asm_8x4 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( s ); } void bli_dgemm_penryn_asm_4x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, double* restrict alpha, double* restrict a, double* restrict b, @@ -861,39 +531,41 @@ void bli_dgemm_penryn_asm_4x4 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT_ALIGNED( d, 4, 4, false, 16 ); + begin_asm() - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. mov(var(b_next), r9) // load address of b_next. mov(var(a_next), r11) // load address of a_next. - + sub(imm(0-8*16), rax) // increment pointers to allow byte sub(imm(0-8*16), rbx) // offsets in the unrolled iterations. - + movaps(mem(rax, -8*16), xmm0) // initialize loop by pre-loading elements movaps(mem(rax, -7*16), xmm1) // of a and b. movaps(mem(rbx, -8*16), xmm2) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(double) mov(rdi, r12) // make a copy of cs_c (in bytes) lea(mem(rcx, rdi, 2), r10) // load address of c + 2*cs_c; - + prefetch(2, mem(r9, 0*8)) // prefetch b_next - + xorpd(xmm3, xmm3) xorpd(xmm4, xmm4) xorpd(xmm5, xmm5) xorpd(xmm6, xmm6) - + prefetch(2, mem(rcx, 3*8)) // prefetch c + 0*cs_c xorpd(xmm8, xmm8) xorpd(xmm9, xmm9) @@ -906,22 +578,22 @@ void bli_dgemm_penryn_asm_4x4 prefetch(2, mem(r10, rdi, 1, 3*8)) // prefetch c + 3*cs_c xorpd(xmm14, xmm14) xorpd(xmm15, xmm15) - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. 
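As in the sgemm kernel above, k is split into a by-four unrolled main loop (k_iter) and a remainder loop (k_left), computed just above from the renamed k argument. The structure in C, with rank1() standing in for one k-iteration's worth of multiply-accumulate work:

    typedef long dim_t;

    static void rank1( void ) { /* one outer-product update of ab */ }

    static void kloop( dim_t k )
    {
        dim_t k_iter = k / 4;
        dim_t k_left = k % 4;

        for ( dim_t i = 0; i < k_iter; i++ )
        {   /* .DLOOPKITER: four updates per branch, amortizing loop
               overhead and giving the scheduler a longer window */
            rank1(); rank1(); rank1(); rank1();
        }
        for ( dim_t i = 0; i < k_left; i++ )
            rank1();  /* .DLOOPKLEFT: mop up the k % 4 remainder */
    }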
- - + + label(.DLOOPKITER) // MAIN LOOP - + prefetch(0, mem(rax, (4*35+1)*8)) //prefetch(0, mem(rax, (8*97+4)*8)) - + //prefetch(0, mem(r11, 67*4*8)) // prefetch a_next[0] - + addpd(xmm3, xmm11) // iteration 0 movaps(mem(rbx, -7*16), xmm3) addpd(xmm4, xmm15) @@ -929,13 +601,13 @@ void bli_dgemm_penryn_asm_4x4 pshufd(imm(0x4e), xmm2, xmm7) mulpd(xmm0, xmm2) mulpd(xmm1, xmm4) - + addpd(xmm5, xmm10) addpd(xmm6, xmm14) movaps(xmm7, xmm6) mulpd(xmm0, xmm7) mulpd(xmm1, xmm6) - + addpd(xmm2, xmm9) movaps(mem(rbx, -6*16), xmm2) addpd(xmm4, xmm13) @@ -943,7 +615,7 @@ void bli_dgemm_penryn_asm_4x4 pshufd(imm(0x4e), xmm3, xmm5) mulpd(xmm0, xmm3) mulpd(xmm1, xmm4) - + addpd(xmm7, xmm8) addpd(xmm6, xmm12) movaps(xmm5, xmm6) @@ -951,9 +623,9 @@ void bli_dgemm_penryn_asm_4x4 movaps(mem(rax, -6*16), xmm0) mulpd(xmm1, xmm6) movaps(mem(rax, -5*16), xmm1) - - - + + + addpd(xmm3, xmm11) // iteration 1 movaps(mem(rbx, -5*16), xmm3) addpd(xmm4, xmm15) @@ -961,13 +633,13 @@ void bli_dgemm_penryn_asm_4x4 pshufd(imm(0x4e), xmm2, xmm7) mulpd(xmm0, xmm2) mulpd(xmm1, xmm4) - + addpd(xmm5, xmm10) addpd(xmm6, xmm14) movaps(xmm7, xmm6) mulpd(xmm0, xmm7) mulpd(xmm1, xmm6) - + addpd(xmm2, xmm9) movaps(mem(rbx, -4*16), xmm2) addpd(xmm4, xmm13) @@ -975,7 +647,7 @@ void bli_dgemm_penryn_asm_4x4 pshufd(imm(0x4e), xmm3, xmm5) mulpd(xmm0, xmm3) mulpd(xmm1, xmm4) - + addpd(xmm7, xmm8) addpd(xmm6, xmm12) movaps(xmm5, xmm6) @@ -983,16 +655,16 @@ void bli_dgemm_penryn_asm_4x4 movaps(mem(rax, -4*16), xmm0) mulpd(xmm1, xmm6) movaps(mem(rax, -3*16), xmm1) - - + + prefetch(0, mem(rax, (4*37+1)*8)) //prefetch(0, mem(rax, (8*97+12)*8)) - + //prefetch(0, mem(r11, 69*4*8)) // prefetch a_next[8] //sub(imm(-4*4*8), r11) // a_next += 4*4 (unroll x mr) - - - + + + addpd(xmm3, xmm11) // iteration 2 movaps(mem(rbx, -3*16), xmm3) addpd(xmm4, xmm15) @@ -1000,13 +672,13 @@ void bli_dgemm_penryn_asm_4x4 pshufd(imm(0x4e), xmm2, xmm7) mulpd(xmm0, xmm2) mulpd(xmm1, xmm4) - + addpd(xmm5, xmm10) addpd(xmm6, xmm14) movaps(xmm7, xmm6) mulpd(xmm0, xmm7) mulpd(xmm1, xmm6) - + addpd(xmm2, xmm9) movaps(mem(rbx, -2*16), xmm2) addpd(xmm4, xmm13) @@ -1014,8 +686,8 @@ void bli_dgemm_penryn_asm_4x4 pshufd(imm(0x4e), xmm3, xmm5) mulpd(xmm0, xmm3) mulpd(xmm1, xmm4) - - + + addpd(xmm7, xmm8) addpd(xmm6, xmm12) movaps(xmm5, xmm6) @@ -1023,9 +695,9 @@ void bli_dgemm_penryn_asm_4x4 movaps(mem(rax, -2*16), xmm0) mulpd(xmm1, xmm6) movaps(mem(rax, -1*16), xmm1) - - - + + + addpd(xmm3, xmm11) // iteration 3 movaps(mem(rbx, -1*16), xmm3) addpd(xmm4, xmm15) @@ -1033,17 +705,17 @@ void bli_dgemm_penryn_asm_4x4 pshufd(imm(0x4e), xmm2, xmm7) mulpd(xmm0, xmm2) mulpd(xmm1, xmm4) - + sub(imm(0-4*4*8), rax) // a += 4*4 (unroll x mr) - + addpd(xmm5, xmm10) addpd(xmm6, xmm14) movaps(xmm7, xmm6) mulpd(xmm0, xmm7) mulpd(xmm1, xmm6) - + sub(imm(0-4*4*8), r9) // b_next += 4*4 (unroll x nr) - + addpd(xmm2, xmm9) movaps(mem(rbx, 0*16), xmm2) addpd(xmm4, xmm13) @@ -1051,9 +723,9 @@ void bli_dgemm_penryn_asm_4x4 pshufd(imm(0x4e), xmm3, xmm5) mulpd(xmm0, xmm3) mulpd(xmm1, xmm4) - + sub(imm(0-4*4*8), rbx) // b += 4*4 (unroll x nr) - + addpd(xmm7, xmm8) addpd(xmm6, xmm12) movaps(xmm5, xmm6) @@ -1061,29 +733,29 @@ void bli_dgemm_penryn_asm_4x4 movaps(mem(rax, -8*16), xmm0) mulpd(xmm1, xmm6) movaps(mem(rax, -7*16), xmm1) - + prefetch(2, mem(r9, 0*8)) // prefetch b_next[0] prefetch(2, mem(r9, 8*8)) // prefetch b_next[8] - + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. 
- - - + + + //prefetch(2, mem(r9, -8*8)) // prefetch b_next[-8] - - - + + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP - + addpd(xmm3, xmm11) // iteration 0 movaps(mem(rbx, -7*16), xmm3) addpd(xmm4, xmm15) @@ -1091,13 +763,13 @@ void bli_dgemm_penryn_asm_4x4 pshufd(imm(0x4e), xmm2, xmm7) mulpd(xmm0, xmm2) mulpd(xmm1, xmm4) - + addpd(xmm5, xmm10) addpd(xmm6, xmm14) movaps(xmm7, xmm6) mulpd(xmm0, xmm7) mulpd(xmm1, xmm6) - + addpd(xmm2, xmm9) movaps(mem(rbx, -6*16), xmm2) addpd(xmm4, xmm13) @@ -1105,7 +777,7 @@ void bli_dgemm_penryn_asm_4x4 pshufd(imm(0x4e), xmm3, xmm5) mulpd(xmm0, xmm3) mulpd(xmm1, xmm4) - + addpd(xmm7, xmm8) addpd(xmm6, xmm12) movaps(xmm5, xmm6) @@ -1113,38 +785,38 @@ void bli_dgemm_penryn_asm_4x4 movaps(mem(rax, -6*16), xmm0) mulpd(xmm1, xmm6) movaps(mem(rax, -5*16), xmm1) - - + + sub(imm(0-4*1*8), rax) // a += 4 (1 x mr) sub(imm(0-4*1*8), rbx) // b += 4 (1 x nr) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - + addpd(xmm3, xmm11) addpd(xmm4, xmm15) addpd(xmm5, xmm10) addpd(xmm6, xmm14) - - + + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta movddup(mem(rax), xmm6) // load alpha and duplicate movddup(mem(rbx), xmm7) // load beta and duplicate - - + + mov(var(rs_c), rsi) // load rs_c mov(rsi, r8) // make a copy of rs_c - + lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(double) - + lea(mem(rcx, rsi, 2), rdx) // load address of c + 2*rs_c; - + // xmm8: xmm9: xmm10: xmm11: // ( ab01 ( ab00 ( ab03 ( ab02 // ab10 ) ab11 ) ab12 ) ab13 ) @@ -1155,15 +827,15 @@ void bli_dgemm_penryn_asm_4x4 movaps(xmm8, xmm0) movsd(xmm9, xmm8) movsd(xmm0, xmm9) - + movaps(xmm10, xmm0) movsd(xmm11, xmm10) movsd(xmm0, xmm11) - + movaps(xmm12, xmm0) movsd(xmm13, xmm12) movsd(xmm0, xmm13) - + movaps(xmm14, xmm0) movsd(xmm15, xmm14) movsd(xmm0, xmm15) @@ -1174,298 +846,118 @@ void bli_dgemm_penryn_asm_4x4 // xmm12: xmm13: xmm14: xmm15: // ( ab20 ( ab21 ( ab22 ( ab23 // ab30 ) ab31 ) ab32 ) ab33 ) - - - - // determine if - // c % 16 == 0, AND - // 8*cs_c % 16 == 0, AND - // rs_c == 1 - // ie: aligned, ldim aligned, and - // column-stored - - cmp(imm(1), r8) // set ZF if rs_c == 1. - sete(bl) // bl = ( ZF == 1 ? 1 : 0 ); - test(imm(15), rcx) // set ZF if c & 16 is zero. - setz(bh) // bh = ( ZF == 1 ? 1 : 0 ); - test(imm(15), r12) // set ZF if (8*cs_c) & 16 is zero. - setz(al) // al = ( ZF == 1 ? 1 : 0 ); - // and(bl,bh) followed by - // and(bh,al) will reveal result - + // now avoid loading C if beta == 0 - + xorpd(xmm0, xmm0) // set xmm0 to zero. ucomisd(xmm0, xmm7) // check if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - - - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. - jne(.DCOLSTORED) // jump to column storage case - - - - label(.DGENSTORED) - - movlpd(mem(rcx), xmm0) // load c00 and c10, - movhpd(mem(rcx, rsi, 1), xmm0) - mulpd(xmm6, xmm8) // scale by alpha, - mulpd(xmm7, xmm0) // scale by beta, - addpd(xmm8, xmm0) // add the gemm result, - movlpd(xmm0, mem(rcx)) // and store back to memory. 
- movhpd(xmm0, mem(rcx, rsi, 1)) - add(rdi, rcx) - - movlpd(mem(rdx), xmm1) // load c20 and c30, - movhpd(mem(rdx, rsi, 1), xmm1) - mulpd(xmm6, xmm12) // scale by alpha, - mulpd(xmm7, xmm1) // scale by beta, - addpd(xmm12, xmm1) // add the gemm result, - movlpd(xmm1, mem(rdx)) // and store back to memory. - movhpd(xmm1, mem(rdx, rsi, 1)) - add(rdi, rdx) - - - - movlpd(mem(rcx), xmm0) // load c01 and c11, - movhpd(mem(rcx, rsi, 1), xmm0) - mulpd(xmm6, xmm9) // scale by alpha, - mulpd(xmm7, xmm0) // scale by beta, - addpd(xmm9, xmm0) // add the gemm result, - movlpd(xmm0, mem(rcx)) // and store back to memory. - movhpd(xmm0, mem(rcx, rsi, 1)) - add(rdi, rcx) - - movlpd(mem(rdx), xmm1) // load c21 and c31, - movhpd(mem(rdx, rsi, 1), xmm1) - mulpd(xmm6, xmm13) // scale by alpha, - mulpd(xmm7, xmm1) // scale by beta, - addpd(xmm13, xmm1) // add the gemm result, - movlpd(xmm1, mem(rdx)) // and store back to memory. - movhpd(xmm1, mem(rdx, rsi, 1)) - add(rdi, rdx) - - - - movlpd(mem(rcx), xmm0) // load c02 and c12, - movhpd(mem(rcx, rsi, 1), xmm0) - mulpd(xmm6, xmm10) // scale by alpha, - mulpd(xmm7, xmm0) // scale by beta, - addpd(xmm10, xmm0) // add the gemm result, - movlpd(xmm0, mem(rcx)) // and store back to memory. - movhpd(xmm0, mem(rcx, rsi, 1)) - add(rdi, rcx) - - movlpd(mem(rdx), xmm1) // load c22 and c32, - movhpd(mem(rdx, rsi, 1), xmm1) - mulpd(xmm6, xmm14) // scale by alpha, - mulpd(xmm7, xmm1) // scale by beta, - addpd(xmm14, xmm1) // add the gemm result, - movlpd(xmm1, mem(rdx)) // and store back to memory. - movhpd(xmm1, mem(rdx, rsi, 1)) - add(rdi, rdx) - - - - movlpd(mem(rcx), xmm0) // load c03 and c13, - movhpd(mem(rcx, rsi, 1), xmm0) - mulpd(xmm6, xmm11) // scale by alpha, - mulpd(xmm7, xmm0) // scale by beta, - addpd(xmm11, xmm0) // add the gemm result, - movlpd(xmm0, mem(rcx)) // and store back to memory. - movhpd(xmm0, mem(rcx, rsi, 1)) - - - movlpd(mem(rdx), xmm1) // load c23 and c33, - movhpd(mem(rdx, rsi, 1), xmm1) - mulpd(xmm6, xmm15) // scale by alpha, - mulpd(xmm7, xmm1) // scale by beta, - addpd(xmm15, xmm1) // add the gemm result, - movlpd(xmm1, mem(rdx)) // and store back to memory. - movhpd(xmm1, mem(rdx, rsi, 1)) - - jmp(.DDONE) // jump to end. - - - - label(.DCOLSTORED) - - movaps(mem(rcx), xmm0) // load c00 and c10, - mulpd(xmm6, xmm8) // scale by alpha, - mulpd(xmm7, xmm0) // scale by beta, - addpd(xmm8, xmm0) // add the gemm result, - movaps(xmm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) - - movaps(mem(rdx), xmm1) // load c20 and c30, - mulpd(xmm6, xmm12) // scale by alpha, - mulpd(xmm7, xmm1) // scale by beta, - addpd(xmm12, xmm1) // add the gemm result, - movaps(xmm1, mem(rdx)) // and store back to memory. - add(rdi, rdx) - - - - movaps(mem(rcx), xmm0) // load c01 and c11, - mulpd(xmm6, xmm9) // scale by alpha, - mulpd(xmm7, xmm0) // scale by beta, - addpd(xmm9, xmm0) // add the gemm result, - movaps(xmm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) - - movaps(mem(rdx), xmm1) // load c21 and c31, - mulpd(xmm6, xmm13) // scale by alpha, - mulpd(xmm7, xmm1) // scale by beta, - addpd(xmm13, xmm1) // add the gemm result, - movaps(xmm1, mem(rdx)) // and store back to memory. - add(rdi, rdx) - - - - movaps(mem(rcx), xmm0) // load c02 and c12, - mulpd(xmm6, xmm10) // scale by alpha, - mulpd(xmm7, xmm0) // scale by beta, - addpd(xmm10, xmm0) // add the gemm result, - movaps(xmm0, mem(rcx)) // and store back to memory. 
- add(rdi, rcx) - - movaps(mem(rdx), xmm1) // load c22 and c32, - mulpd(xmm6, xmm14) // scale by alpha, - mulpd(xmm7, xmm1) // scale by beta, - addpd(xmm14, xmm1) // add the gemm result, - movaps(xmm1, mem(rdx)) // and store back to memory. - add(rdi, rdx) - - - - movaps(mem(rcx), xmm0) // load c03 and c13, - mulpd(xmm6, xmm11) // scale by alpha, - mulpd(xmm7, xmm0) // scale by beta, - addpd(xmm11, xmm0) // add the gemm result, - movaps(xmm0, mem(rcx)) // and store back to memory. - - - movaps(mem(rdx), xmm1) // load c23 and c33, - mulpd(xmm6, xmm15) // scale by alpha, - mulpd(xmm7, xmm1) // scale by beta, - addpd(xmm15, xmm1) // add the gemm result, - movaps(xmm1, mem(rdx)) // and store back to memory. - - jmp(.DDONE) // jump to end. - - - - + + movaps(mem(rcx), xmm0) // load c00 and c10, + mulpd(xmm6, xmm8) // scale by alpha, + mulpd(xmm7, xmm0) // scale by beta, + addpd(xmm8, xmm0) // add the gemm result, + movaps(xmm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) + + movaps(mem(rdx), xmm1) // load c20 and c30, + mulpd(xmm6, xmm12) // scale by alpha, + mulpd(xmm7, xmm1) // scale by beta, + addpd(xmm12, xmm1) // add the gemm result, + movaps(xmm1, mem(rdx)) // and store back to memory. + add(rdi, rdx) + + + + movaps(mem(rcx), xmm0) // load c01 and c11, + mulpd(xmm6, xmm9) // scale by alpha, + mulpd(xmm7, xmm0) // scale by beta, + addpd(xmm9, xmm0) // add the gemm result, + movaps(xmm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) + + movaps(mem(rdx), xmm1) // load c21 and c31, + mulpd(xmm6, xmm13) // scale by alpha, + mulpd(xmm7, xmm1) // scale by beta, + addpd(xmm13, xmm1) // add the gemm result, + movaps(xmm1, mem(rdx)) // and store back to memory. + add(rdi, rdx) + + + + movaps(mem(rcx), xmm0) // load c02 and c12, + mulpd(xmm6, xmm10) // scale by alpha, + mulpd(xmm7, xmm0) // scale by beta, + addpd(xmm10, xmm0) // add the gemm result, + movaps(xmm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) + + movaps(mem(rdx), xmm1) // load c22 and c32, + mulpd(xmm6, xmm14) // scale by alpha, + mulpd(xmm7, xmm1) // scale by beta, + addpd(xmm14, xmm1) // add the gemm result, + movaps(xmm1, mem(rdx)) // and store back to memory. + add(rdi, rdx) + + + + movaps(mem(rcx), xmm0) // load c03 and c13, + mulpd(xmm6, xmm11) // scale by alpha, + mulpd(xmm7, xmm0) // scale by beta, + addpd(xmm11, xmm0) // add the gemm result, + movaps(xmm0, mem(rcx)) // and store back to memory. + + + movaps(mem(rdx), xmm1) // load c23 and c33, + mulpd(xmm6, xmm15) // scale by alpha, + mulpd(xmm7, xmm1) // scale by beta, + addpd(xmm15, xmm1) // add the gemm result, + movaps(xmm1, mem(rdx)) // and store back to memory. + + jmp(.DDONE) // jump to end. + label(.DBETAZERO) - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. - jne(.DCOLSTORBZ) // jump to column storage case - - - - label(.DGENSTORBZ) - // skip loading c00 and c10, - mulpd(xmm6, xmm8) // scale by alpha, - movlpd(xmm8, mem(rcx)) // and store back to memory. - movhpd(xmm8, mem(rcx, rsi, 1)) - add(rdi, rcx) - // skip loading c20 and c30, - mulpd(xmm6, xmm12) // scale by alpha, - movlpd(xmm12, mem(rdx)) // and store back to memory. - movhpd(xmm12, mem(rdx, rsi, 1)) - add(rdi, rdx) - - - // skip loading c01 and c11, - mulpd(xmm6, xmm9) // scale by alpha, - movlpd(xmm9, mem(rcx)) // and store back to memory. 
- movhpd(xmm9, mem(rcx, rsi, 1)) - add(rdi, rcx) - // skip loading c21 and c31, - mulpd(xmm6, xmm13) // scale by alpha, - movlpd(xmm13, mem(rdx)) // and store back to memory. - movhpd(xmm13, mem(rdx, rsi, 1)) - add(rdi, rdx) - - - // skip loading c02 and c12, - mulpd(xmm6, xmm10) // scale by alpha, - movlpd(xmm10, mem(rcx)) // and store back to memory. - movhpd(xmm10, mem(rcx, rsi, 1)) - add(rdi, rcx) - // skip loading c22 and c32, - mulpd(xmm6, xmm14) // scale by alpha, - movlpd(xmm14, mem(rdx)) // and store back to memory. - movhpd(xmm14, mem(rdx, rsi, 1)) - add(rdi, rdx) - - - // skip loading c03 and c13, - mulpd(xmm6, xmm11) // scale by alpha, - movlpd(xmm11, mem(rcx)) // and store back to memory. - movhpd(xmm11, mem(rcx, rsi, 1)) - - // skip loading c23 and c33, - mulpd(xmm6, xmm15) // scale by alpha, - movlpd(xmm15, mem(rdx)) // and store back to memory. - movhpd(xmm15, mem(rdx, rsi, 1)) - - jmp(.DDONE) // jump to end. - - - - label(.DCOLSTORBZ) - - // skip loading c00 and c10, - mulpd(xmm6, xmm8) // scale by alpha, - movaps(xmm8, mem(rcx)) // and store back to memory. - add(rdi, rcx) - // skip loading c20 and c30, - mulpd(xmm6, xmm12) // scale by alpha, - movaps(xmm12, mem(rdx)) // and store back to memory. - add(rdi, rdx) - - - // skip loading c01 and c11, - mulpd(xmm6, xmm9) // scale by alpha, - movaps(xmm9, mem(rcx)) // and store back to memory. - add(rdi, rcx) - // skip loading c21 and c31, - mulpd(xmm6, xmm13) // scale by alpha, - movaps(xmm13, mem(rdx)) // and store back to memory. - add(rdi, rdx) - - - // skip loading c02 and c12, - mulpd(xmm6, xmm10) // scale by alpha, - movaps(xmm10, mem(rcx)) // and store back to memory. - add(rdi, rcx) - // skip loading c22 and c32, - mulpd(xmm6, xmm14) // scale by alpha, - movaps(xmm14, mem(rdx)) // and store back to memory. - add(rdi, rdx) - - - // skip loading c03 and c13, - mulpd(xmm6, xmm11) // scale by alpha, - movaps(xmm11, mem(rcx)) // and store back to memory. - - // skip loading c23 and c33, - mulpd(xmm6, xmm15) // scale by alpha, - movaps(xmm15, mem(rdx)) // and store back to memory. - - - - - - - - + + // skip loading c00 and c10, + mulpd(xmm6, xmm8) // scale by alpha, + movaps(xmm8, mem(rcx)) // and store back to memory. + add(rdi, rcx) + // skip loading c20 and c30, + mulpd(xmm6, xmm12) // scale by alpha, + movaps(xmm12, mem(rdx)) // and store back to memory. + add(rdi, rdx) + + + // skip loading c01 and c11, + mulpd(xmm6, xmm9) // scale by alpha, + movaps(xmm9, mem(rcx)) // and store back to memory. + add(rdi, rcx) + // skip loading c21 and c31, + mulpd(xmm6, xmm13) // scale by alpha, + movaps(xmm13, mem(rdx)) // and store back to memory. + add(rdi, rdx) + + + // skip loading c02 and c12, + mulpd(xmm6, xmm10) // scale by alpha, + movaps(xmm10, mem(rcx)) // and store back to memory. + add(rdi, rcx) + // skip loading c22 and c32, + mulpd(xmm6, xmm14) // scale by alpha, + movaps(xmm14, mem(rdx)) // and store back to memory. + add(rdi, rdx) + + + // skip loading c03 and c13, + mulpd(xmm6, xmm11) // scale by alpha, + movaps(xmm11, mem(rcx)) // and store back to memory. + + // skip loading c23 and c33, + mulpd(xmm6, xmm15) // scale by alpha, + movaps(xmm15, mem(rdx)) // and store back to memory. 
+ label(.DDONE) - + end_asm( : // output operands (none) @@ -1489,6 +981,8 @@ void bli_dgemm_penryn_asm_4x4 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( d ); } diff --git a/kernels/piledriver/3/bli_gemm_piledriver_asm_d8x3.c b/kernels/piledriver/3/bli_gemm_piledriver_asm_d8x3.c index 5963dabee6..709c2b40b2 100644 --- a/kernels/piledriver/3/bli_gemm_piledriver_asm_d8x3.c +++ b/kernels/piledriver/3/bli_gemm_piledriver_asm_d8x3.c @@ -42,7 +42,9 @@ void bli_sgemm_piledriver_asm_16x3 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, float* restrict alpha, float* restrict a, float* restrict b, @@ -57,36 +59,38 @@ void bli_sgemm_piledriver_asm_16x3 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 8; - uint64_t k_left = k0 % 8; + uint64_t k_iter = k / 8; + uint64_t k_left = k % 8; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( s, 16, 3, false ); + begin_asm() - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. mov(var(b_next), r15) // load address of b_next. mov(var(a_next), r14) // load address of a_next. - + prefetch(0, mem(rbx, 128)) // prefetch b prefetch(0, mem(rbx, 64+128)) // prefetch b prefetch(0, mem(rbx, 128+128)) // prefetch b - + add(imm(32*4), rax) add(imm(12*4), rbx) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 4), rdi) // cs_c *= sizeof(float) lea(mem(rcx, rdi, 1), r10) // load address of c + 1*cs_c; lea(mem(rcx, rdi, 2), r11) // load address of c + 2*cs_c; - + vbroadcastss(mem(rbx, -12*4), xmm1) vbroadcastss(mem(rbx, -11*4), xmm2) vbroadcastss(mem(rbx, -10*4), xmm3) - + vxorps(xmm4, xmm4, xmm4) vxorps(xmm5, xmm5, xmm5) vxorps(xmm6, xmm6, xmm6) @@ -99,23 +103,23 @@ void bli_sgemm_piledriver_asm_16x3 vxorps(xmm13, xmm13, xmm13) vxorps(xmm14, xmm14, xmm14) vxorps(xmm15, xmm15, xmm15) - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.SCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.SLOOPKITER) // MAIN LOOP - - + + je(.SCONSIDKLEFT) // if i == 0, jump to k_left code. 
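Note the third macro argument: the KNC kernels above pass true where this one passes false. The flag presumably records the kernel's preferred storage of C; the KNC kernels write C a row at a time, while this one writes 16-element columns at rcx/r10/r11 (and with vmovups, so no alignment guarantee is required and the plain, non-_ALIGNED macro suffices). Under that assumption, the setup decision includes a stride test like this (illustrative only):

    #include <stdbool.h>

    typedef long inc_t;

    static bool c_in_native_layout( inc_t rs_c, inc_t cs_c, bool row_pref )
    {
        return row_pref ? cs_c == 1   /* row-major:    contiguous rows    */
                        : rs_c == 1;  /* column-major: contiguous columns */
    }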
- - + + prefetch(0, mem(rbx, 16+192)) // prefetch b - + // iteration 0 vmovaps(mem(rax, -32*4), xmm0) prefetch(0, mem(rax, 384)) @@ -136,7 +140,7 @@ void bli_sgemm_piledriver_asm_16x3 vfmadd231ps(xmm2, xmm0, xmm14) vbroadcastss(mem(rbx, -8*4), xmm2) vfmadd231ps(xmm3, xmm0, xmm15) - + // iteration 1 vmovaps(mem(rax, -16*4), xmm0) vbroadcastss(mem(rbx, -7*4), xmm3) @@ -158,7 +162,7 @@ void bli_sgemm_piledriver_asm_16x3 vfmadd231ps(xmm2, xmm0, xmm14) vbroadcastss(mem(rbx, -5*4), xmm2) vfmadd231ps(xmm3, xmm0, xmm15) - + // iteration 2 vmovaps(mem(rax, 0*4), xmm0) vbroadcastss(mem(rbx, -4*4), xmm3) @@ -180,7 +184,7 @@ void bli_sgemm_piledriver_asm_16x3 vfmadd231ps(xmm2, xmm0, xmm14) vbroadcastss(mem(rbx, -2*4), xmm2) vfmadd231ps(xmm3, xmm0, xmm15) - + // iteration 3 vmovaps(mem(rax, 16*4), xmm0) vbroadcastss(mem(rbx, -1*4), xmm3) @@ -202,10 +206,10 @@ void bli_sgemm_piledriver_asm_16x3 vfmadd231ps(xmm2, xmm0, xmm14) vbroadcastss(mem(rbx, 1*4), xmm2) vfmadd231ps(xmm3, xmm0, xmm15) - - + + add(imm(4*16*4), rax) // a += 4*16 (unroll x mr) - + // iteration 4 vmovaps(mem(rax, -32*4), xmm0) vbroadcastss(mem(rbx, 2*4), xmm3) @@ -227,9 +231,9 @@ void bli_sgemm_piledriver_asm_16x3 vfmadd231ps(xmm2, xmm0, xmm14) vbroadcastss(mem(rbx, 4*4), xmm2) vfmadd231ps(xmm3, xmm0, xmm15) - + prefetch(0, mem(rbx, 80+192)) // prefetch b - + // iteration 5 vmovaps(mem(rax, -16*4), xmm0) vbroadcastss(mem(rbx, 5*4), xmm3) @@ -251,7 +255,7 @@ void bli_sgemm_piledriver_asm_16x3 vfmadd231ps(xmm2, xmm0, xmm14) vbroadcastss(mem(rbx, 7*4), xmm2) vfmadd231ps(xmm3, xmm0, xmm15) - + // iteration 6 vmovaps(mem(rax, 0*4), xmm0) vbroadcastss(mem(rbx, 8*4), xmm3) @@ -273,7 +277,7 @@ void bli_sgemm_piledriver_asm_16x3 vfmadd231ps(xmm2, xmm0, xmm14) vbroadcastss(mem(rbx, 10*4), xmm2) vfmadd231ps(xmm3, xmm0, xmm15) - + // iteration 7 vmovaps(mem(rax, 16*4), xmm0) vbroadcastss(mem(rbx, 11*4), xmm3) @@ -298,34 +302,34 @@ void bli_sgemm_piledriver_asm_16x3 vbroadcastss(mem(rbx, -11*4), xmm2) vfmadd231ps(xmm3, xmm0, xmm15) vbroadcastss(mem(rbx, -10*4), xmm3) - - - - + + + + dec(rsi) // i -= 1; jmp(.SLOOPKITER) // jump to beginning of loop. - - - - - - + + + + + + label(.SCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.SPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.SLOOPKLEFT) // EDGE LOOP - - + + je(.SPOSTACCUM) // if i == 0, we're done. - - + + prefetch(0, mem(rbx, 16+192)) // prefetch b - + // iteration 0 vmovaps(mem(rax, -32*4), xmm0) prefetch(0, mem(rax, 384)) @@ -347,56 +351,56 @@ void bli_sgemm_piledriver_asm_16x3 vbroadcastss(mem(rbx, -8*4), xmm2) vfmadd231ps(xmm3, xmm0, xmm15) vbroadcastss(mem(rbx, -7*4), xmm3) - - + + add(imm(1*16*4), rax) // a += 4*16 (unroll x mr) add(imm(1*3*4), rbx) // a += 4*3 (unroll x nr) - - + + dec(rsi) // i -= 1; jmp(.SLOOPKLEFT) // jump to beginning of loop. 
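One structural difference from the Penryn kernels above: Penryn's loops are bottom-tested (a test/je zero-trip guard, then dec/jne at the bottom), while these Piledriver loops test at the top (je out of the loop, jmp back to the label). The two shapes in C:

    typedef long dim_t;

    static void bottom_tested( dim_t n )  /* penryn: test/je, then dec/jne */
    {
        if ( n == 0 ) return;
        do { /* loop body */ } while ( --n != 0 );
    }

    static void top_tested( dim_t n )     /* piledriver: je out, jmp back  */
    {
        while ( n != 0 ) { /* loop body */ n--; }
    }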
- - - + + + label(.SPOSTACCUM) - - + + prefetchw0(mem(rcx, 0*8)) // prefetch c + 0*cs_c prefetchw0(mem(r10, 0*8)) // prefetch c + 1*cs_c prefetchw0(mem(r11, 0*8)) // prefetch c + 2*cs_c - - - // xmm4: xmm5: xmm6: + + + // xmm4: xmm5: xmm6: // ( ab00 ( ab01 ( ab02 - // ab10 ab11 ab12 + // ab10 ab11 ab12 // ab20 ab21 ab22 // ab30 ) ab31 ) ab32 ) - - // xmm7: xmm8: xmm9: + + // xmm7: xmm8: xmm9: // ( ab40 ( ab41 ( ab42 - // ab50 ab51 ab52 + // ab50 ab51 ab52 // ab60 ab61 ab62 // ab70 ) ab71 ) ab72 ) - + // xmm10: xmm11: xmm12: // ( ab80 ( ab01 ( ab02 - // ab90 ab11 ab12 + // ab90 ab11 ab12 // abA0 abA1 abA2 // abB0 ) abB1 ) abB2 ) - + // xmm13: xmm14: xmm15: // ( abC0 ( abC1 ( abC2 - // abD0 abD1 abD2 + // abD0 abD1 abD2 // abE0 abE1 abE2 // abF0 ) abF1 ) abF2 ) - - - + + + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastss(mem(rax), xmm0) // load alpha and duplicate vbroadcastss(mem(rbx), xmm2) // load beta and duplicate - + vmulps(xmm0, xmm4, xmm4) // scale by alpha vmulps(xmm0, xmm5, xmm5) vmulps(xmm0, xmm6, xmm6) @@ -409,32 +413,32 @@ void bli_sgemm_piledriver_asm_16x3 vmulps(xmm0, xmm13, xmm13) vmulps(xmm0, xmm14, xmm14) vmulps(xmm0, xmm15, xmm15) - - - + + + prefetch(0, mem(r14)) // prefetch a_next prefetch(0, mem(r14, 64)) // prefetch a_next - - - - + + + + mov(var(rs_c), rsi) // load rs_c lea(mem(, rsi, 4), rsi) // rsi = rs_c * sizeof(float) - + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*rs_c; - + lea(mem(, rsi, 2), r12) // r12 = 2*rs_c; lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_c; - - - + + + // determine if // c % 32 == 0, AND // 4*cs_c % 32 == 0, AND // rs_c == 1 // ie: aligned, ldim aligned, and // column-stored - + cmp(imm(4), rsi) // set ZF if (4*rs_c) == 4. sete(bl) // bl = ( ZF == 1 ? 1 : 0 ); test(imm(31), rcx) // set ZF if c & 32 is zero. @@ -443,448 +447,52 @@ void bli_sgemm_piledriver_asm_16x3 setz(al) // al = ( ZF == 0 ? 1 : 0 ); // and(bl,bh) followed by // and(bh,al) will reveal result - + prefetch(0, mem(r15)) // prefetch b_next prefetch(0, mem(r15, 64)) // prefetch b_next - + // now avoid loading C if beta == 0 - + vxorps(xmm0, xmm0, xmm0) // set xmm0 to zero. vucomiss(xmm0, xmm2) // set ZF if beta == 0. je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case - - - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. 
- jne(.SCOLSTORED) // jump to column storage case - - - - label(.SGENSTORED) - - - vmovlps(mem(rcx), xmm0, xmm0) // load c00:c30 - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmulps(xmm2, xmm0, xmm0) - vaddps(xmm4, xmm0, xmm0) - vmovss(xmm0, mem(rcx)) // store c00:c30 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r13, 1)) - lea(mem(rcx, rsi, 4), rcx) // c += 4*rs_c; - - - vmovlps(mem(rcx), xmm0, xmm0) // load c40:c70 - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmulps(xmm2, xmm0, xmm0) - vaddps(xmm7, xmm0, xmm0) - vmovss(xmm0, mem(rcx)) // store c40:c70 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r13, 1)) - lea(mem(rcx, rsi, 4), rcx) // c += 4*rs_c; - - - vmovlps(mem(rcx), xmm0, xmm0) // load c80:cB0 - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmulps(xmm2, xmm0, xmm0) - vaddps(xmm10, xmm0, xmm0) - vmovss(xmm0, mem(rcx)) // store c80:cB0 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r13, 1)) - lea(mem(rcx, rsi, 4), rcx) // c += 4*rs_c; - - - vmovlps(mem(rcx), xmm0, xmm0) // load cC0:cF0 - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmulps(xmm2, xmm0, xmm0) - vaddps(xmm13, xmm0, xmm0) - vmovss(xmm0, mem(rcx)) // store cC0:cF0 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r13, 1)) - lea(mem(rcx, rsi, 4), rcx) // c += 4*rs_c; - - - vmovlps(mem(r10), xmm0, xmm0) // load c01:c31 - vmovhps(mem(r10, rsi, 1), xmm0, xmm0) - vmovlps(mem(r10, r12, 1), xmm1, xmm1) - vmovhps(mem(r10, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmulps(xmm2, xmm0, xmm0) - vaddps(xmm5, xmm0, xmm0) - vmovss(xmm0, mem(r10)) // store c01:c31 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r13, 1)) - lea(mem(r10, rsi, 4), r10) // c += 4*rs_c; - - - vmovlps(mem(r10), xmm0, xmm0) // load c41:c71 - vmovhps(mem(r10, rsi, 1), xmm0, xmm0) - vmovlps(mem(r10, r12, 1), xmm1, xmm1) - vmovhps(mem(r10, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmulps(xmm2, xmm0, xmm0) - vaddps(xmm8, xmm0, xmm0) - vmovss(xmm0, mem(r10)) // store c41:c71 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r13, 1)) - lea(mem(r10, rsi, 4), r10) // c += 4*rs_c; - - - vmovlps(mem(r10), xmm0, xmm0) // load c81:cB1 - vmovhps(mem(r10, rsi, 1), xmm0, xmm0) - vmovlps(mem(r10, r12, 1), xmm1, xmm1) - 
vmovhps(mem(r10, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmulps(xmm2, xmm0, xmm0) - vaddps(xmm11, xmm0, xmm0) - vmovss(xmm0, mem(r10)) // store c81:cB1 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r13, 1)) - lea(mem(r10, rsi, 4), r10) // c += 4*rs_c; - - - vmovlps(mem(r10), xmm0, xmm0) // load cC1:cF1 - vmovhps(mem(r10, rsi, 1), xmm0, xmm0) - vmovlps(mem(r10, r12, 1), xmm1, xmm1) - vmovhps(mem(r10, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmulps(xmm2, xmm0, xmm0) - vaddps(xmm14, xmm0, xmm0) - vmovss(xmm0, mem(r10)) // store cC1:cF1 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r13, 1)) - lea(mem(r10, rsi, 4), r10) // c += 4*rs_c; - - - vmovlps(mem(r11), xmm0, xmm0) // load c02:c32 - vmovhps(mem(r11, rsi, 1), xmm0, xmm0) - vmovlps(mem(r11, r12, 1), xmm1, xmm1) - vmovhps(mem(r11, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmulps(xmm2, xmm0, xmm0) - vaddps(xmm6, xmm0, xmm0) - vmovss(xmm0, mem(r11)) // store c02:c32 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r13, 1)) - lea(mem(r11, rsi, 4), r11) // c += 4*rs_c; - - - vmovlps(mem(r11), xmm0, xmm0) // load c42:c72 - vmovhps(mem(r11, rsi, 1), xmm0, xmm0) - vmovlps(mem(r11, r12, 1), xmm1, xmm1) - vmovhps(mem(r11, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmulps(xmm2, xmm0, xmm0) - vaddps(xmm9, xmm0, xmm0) - vmovss(xmm0, mem(r11)) // store c42:c72 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r13, 1)) - lea(mem(r11, rsi, 4), r11) // c += 4*rs_c; - - - vmovlps(mem(r11), xmm0, xmm0) // load c82:cB2 - vmovhps(mem(r11, rsi, 1), xmm0, xmm0) - vmovlps(mem(r11, r12, 1), xmm1, xmm1) - vmovhps(mem(r11, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmulps(xmm2, xmm0, xmm0) - vaddps(xmm12, xmm0, xmm0) - vmovss(xmm0, mem(r11)) // store c82:cB2 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r13, 1)) - lea(mem(r11, rsi, 4), r11) // c += 4*rs_c; - - - vmovlps(mem(r11), xmm0, xmm0) // load cC2:cF2 - vmovhps(mem(r11, rsi, 1), xmm0, xmm0) - vmovlps(mem(r11, r12, 1), xmm1, xmm1) - vmovhps(mem(r11, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmulps(xmm2, xmm0, xmm0) - vaddps(xmm15, xmm0, xmm0) - vmovss(xmm0, mem(r11)) // store cC2:cF1 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r13, 1)) - lea(mem(r11, rsi, 4), r11) // c += 4*rs_c; - - - - jmp(.SDONE) // jump to end. 
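/*
   For reference, the general-storage (.SGENSTORED) path deleted above
   amounts, for each four-element chunk of one column of C, to the strided
   update sketched below. This is an illustrative C equivalent, not code
   from the patch; ab is assumed to already hold the alpha-scaled
   accumulator lanes.
*/
static void sgen_stored_update_4
     (
       const float* ab,    // four alpha-scaled accumulator entries
       float        beta,
       float*       c,     // first of four elements in one column of C
       long         rs_c   // row stride of C, in elements
     )
{
	for ( int i = 0; i < 4; i++ )
		c[ i*rs_c ] = beta * c[ i*rs_c ] + ab[ i ];
}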
- - - - label(.SCOLSTORED) - - - vfmadd231ps(mem(rcx, 0*16), xmm2, xmm4) - vfmadd231ps(mem(rcx, 1*16), xmm2, xmm7) - vfmadd231ps(mem(rcx, 2*16), xmm2, xmm10) - vfmadd231ps(mem(rcx, 3*16), xmm2, xmm13) - - vmovups(xmm4, mem(rcx, 0*16)) - vmovups(xmm7, mem(rcx, 1*16)) - vmovups(xmm10, mem(rcx, 2*16)) - vmovups(xmm13, mem(rcx, 3*16)) - - vfmadd231ps(mem(r10, 0*16), xmm2, xmm5) - vfmadd231ps(mem(r10, 1*16), xmm2, xmm8) - vfmadd231ps(mem(r10, 2*16), xmm2, xmm11) - vfmadd231ps(mem(r10, 3*16), xmm2, xmm14) - - vmovups(xmm5, mem(r10, 0*16)) - vmovups(xmm8, mem(r10, 1*16)) - vmovups(xmm11, mem(r10, 2*16)) - vmovups(xmm14, mem(r10, 3*16)) - - vfmadd231ps(mem(r11, 0*16), xmm2, xmm6) - vfmadd231ps(mem(r11, 1*16), xmm2, xmm9) - vfmadd231ps(mem(r11, 2*16), xmm2, xmm12) - vfmadd231ps(mem(r11, 3*16), xmm2, xmm15) - - vmovups(xmm6, mem(r11, 0*16)) - vmovups(xmm9, mem(r11, 1*16)) - vmovups(xmm12, mem(r11, 2*16)) - vmovups(xmm15, mem(r11, 3*16)) - - - - jmp(.SDONE) // jump to end. - - - + + vfmadd231ps(mem(rcx, 0*16), xmm2, xmm4) + vfmadd231ps(mem(rcx, 1*16), xmm2, xmm7) + vfmadd231ps(mem(rcx, 2*16), xmm2, xmm10) + vfmadd231ps(mem(rcx, 3*16), xmm2, xmm13) + + vfmadd231ps(mem(r10, 0*16), xmm2, xmm5) + vfmadd231ps(mem(r10, 1*16), xmm2, xmm8) + vfmadd231ps(mem(r10, 2*16), xmm2, xmm11) + vfmadd231ps(mem(r10, 3*16), xmm2, xmm14) + + vfmadd231ps(mem(r11, 0*16), xmm2, xmm6) + vfmadd231ps(mem(r11, 1*16), xmm2, xmm9) + vfmadd231ps(mem(r11, 2*16), xmm2, xmm12) + vfmadd231ps(mem(r11, 3*16), xmm2, xmm15) + + // fall through + label(.SBETAZERO) - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. - jne(.SCOLSTORBZ) // jump to column storage case - - - - label(.SGENSTORBZ) - - - vmovaps(xmm4, xmm0) - vmovss(xmm0, mem(rcx)) // store c00:c30 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r13, 1)) - lea(mem(rcx, rsi, 4), rcx) // c += 4*rs_c; - - - vmovaps(xmm7, xmm0) - vmovss(xmm0, mem(rcx)) // store c40:c70 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r13, 1)) - lea(mem(rcx, rsi, 4), rcx) // c += 4*rs_c; - - - vmovaps(xmm10, xmm0) - vmovss(xmm0, mem(rcx)) // store c80:cB0 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r13, 1)) - lea(mem(rcx, rsi, 4), rcx) // c += 4*rs_c; - - - vmovaps(xmm13, xmm0) - vmovss(xmm0, mem(rcx)) // store cC0:cF0 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(rcx, r13, 1)) - lea(mem(rcx, rsi, 4), rcx) // c += 4*rs_c; - - - vmovaps(xmm5, xmm0) - vmovss(xmm0, mem(r10)) // store c01:c31 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r13, 1)) - lea(mem(r10, rsi, 4), r10) // c += 4*rs_c; - - - vmovaps(xmm8, xmm0) - vmovss(xmm0, mem(r10)) // store c41:c71 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r12, 1)) - vpermilps(imm(0x39), xmm0, 
xmm0) - vmovss(xmm0, mem(r10, r13, 1)) - lea(mem(r10, rsi, 4), r10) // c += 4*rs_c; - - - vmovaps(xmm11, xmm0) - vmovss(xmm0, mem(r10)) // store c81:cB1 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r13, 1)) - lea(mem(r10, rsi, 4), r10) // c += 4*rs_c; - - - vmovaps(xmm14, xmm0) - vmovss(xmm0, mem(r10)) // store cC1:cF1 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r10, r13, 1)) - lea(mem(r10, rsi, 4), r10) // c += 4*rs_c; - - - vmovaps(xmm6, xmm0) - vmovss(xmm0, mem(r11)) // store c02:c32 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r13, 1)) - lea(mem(r11, rsi, 4), r11) // c += 4*rs_c; - - - vmovaps(xmm9, xmm0) - vmovss(xmm0, mem(r11)) // store c42:c72 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r13, 1)) - lea(mem(r11, rsi, 4), r11) // c += 4*rs_c; - - - vmovaps(xmm12, xmm0) - vmovss(xmm0, mem(r11)) // store c82:cB2 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r13, 1)) - lea(mem(r11, rsi, 4), r11) // c += 4*rs_c; - - - vmovaps(xmm15, xmm0) - vmovss(xmm0, mem(r11)) // store cC2:cF1 - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, rsi, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm0) - vmovss(xmm0, mem(r11, r13, 1)) - lea(mem(r11, rsi, 4), r11) // c += 4*rs_c; - - - - jmp(.SDONE) // jump to end. 
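/*
   Both storage-specific beta == 0 paths (.SGENSTORBZ above, .SCOLSTORBZ
   just below) collapse into a single unconditional store sequence because
   this patch introduces the GEMM_UKR_SETUP_CT / GEMM_UKR_FLUSH_CT macros:
   when C is smaller than MR x NR or not stored the way the kernel prefers,
   the kernel computes into an aligned local tile that is merged back into
   C afterward. The sketch below shows one plausible shape of that
   mechanism; sgemm_16x3_body is a hypothetical stand-in for the assembly
   in this file, and the real macros defined elsewhere in this patch may
   differ in detail.
*/
void sgemm_16x3_body( long k, const float* alpha, const float* a,
                      const float* b, const float* beta,
                      float* c, long rs_c, long cs_c );

void sgemm_16x3_edge_safe( long m, long n, long k,
                           const float* alpha, const float* a,
                           const float* b, const float* beta,
                           float* c, long rs_c, long cs_c )
{
	if ( m == 16 && n == 3 && rs_c == 1 )
	{
		// C conforms to the kernel's preferred shape: compute in place.
		sgemm_16x3_body( k, alpha, a, b, beta, c, rs_c, cs_c );
		return;
	}

	// Otherwise compute the full tile into aligned scratch with beta = 0 ...
	float ct[ 16 * 3 ] __attribute__(( aligned( 32 ) ));
	const float zero = 0.0f;
	sgemm_16x3_body( k, alpha, a, b, &zero, ct, 1, 16 );

	// ... then merge only the valid m x n corner back into C.
	// (Production code would branch on *beta == 0 to avoid reading C.)
	for ( long j = 0; j < n; j++ )
		for ( long i = 0; i < m; i++ )
			c[ i*rs_c + j*cs_c ] = (*beta) * c[ i*rs_c + j*cs_c ]
			                       + ct[ i + j*16 ];
}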
- - - - label(.SCOLSTORBZ) - - - vmovups(xmm4, mem(rcx, 0*16)) - vmovups(xmm7, mem(rcx, 1*16)) - vmovups(xmm10, mem(rcx, 2*16)) - vmovups(xmm13, mem(rcx, 3*16)) - - vmovups(xmm5, mem(r10, 0*16)) - vmovups(xmm8, mem(r10, 1*16)) - vmovups(xmm11, mem(r10, 2*16)) - vmovups(xmm14, mem(r10, 3*16)) - - vmovups(xmm6, mem(r11, 0*16)) - vmovups(xmm9, mem(r11, 1*16)) - vmovups(xmm12, mem(r11, 2*16)) - vmovups(xmm15, mem(r11, 3*16)) - - - - - - + + vmovups(xmm4, mem(rcx, 0*16)) + vmovups(xmm7, mem(rcx, 1*16)) + vmovups(xmm10, mem(rcx, 2*16)) + vmovups(xmm13, mem(rcx, 3*16)) + + vmovups(xmm5, mem(r10, 0*16)) + vmovups(xmm8, mem(r10, 1*16)) + vmovups(xmm11, mem(r10, 2*16)) + vmovups(xmm14, mem(r10, 3*16)) + + vmovups(xmm6, mem(r11, 0*16)) + vmovups(xmm9, mem(r11, 1*16)) + vmovups(xmm12, mem(r11, 2*16)) + vmovups(xmm15, mem(r11, 3*16)) + label(.SDONE) - + end_asm( : // output operands (none) @@ -901,7 +509,7 @@ void bli_sgemm_piledriver_asm_16x3 [b_next] "m" (b_next), // 9 [a_next] "m" (a_next) // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -909,11 +517,15 @@ void bli_sgemm_piledriver_asm_16x3 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( s ); } void bli_dgemm_piledriver_asm_8x3 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, double* restrict alpha, double* restrict a, double* restrict b, @@ -928,36 +540,38 @@ void bli_dgemm_piledriver_asm_8x3 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 8; - uint64_t k_left = k0 % 8; + uint64_t k_iter = k / 8; + uint64_t k_left = k % 8; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( d, 8, 3, false ); + begin_asm() - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. mov(var(b_next), r15) // load address of b_next. mov(var(a_next), r14) // load address of a_next. - + prefetch(0, mem(rbx, 128)) // prefetch b prefetch(0, mem(rbx, 64+128)) // prefetch b prefetch(0, mem(rbx, 128+128)) // prefetch b - + add(imm(16*8), rax) add(imm(12*8), rbx) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(double) lea(mem(rcx, rdi, 1), r10) // load address of c + 1*cs_c; lea(mem(rcx, rdi, 2), r11) // load address of c + 2*cs_c; - + vmovddup(mem(rbx, -12*8), xmm1) vmovddup(mem(rbx, -11*8), xmm2) vmovddup(mem(rbx, -10*8), xmm3) - + vxorpd(xmm4, xmm4, xmm4) vxorpd(xmm5, xmm5, xmm5) vxorpd(xmm6, xmm6, xmm6) @@ -970,24 +584,24 @@ void bli_dgemm_piledriver_asm_8x3 vxorpd(xmm13, xmm13, xmm13) vxorpd(xmm14, xmm14, xmm14) vxorpd(xmm15, xmm15, xmm15) - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.DLOOPKITER) // MAIN LOOP - - + + je(.DCONSIDKLEFT) // if i == 0, jump to k_left code. 
- - + + prefetch(0, mem(rbx, -32+256)) // prefetch b prefetch(0, mem(rbx, 32+256)) // prefetch b - + // iteration 0 vmovaps(mem(rax, -8*16), xmm0) prefetch(0, mem(rax, 384)) // prefetch a @@ -1008,7 +622,7 @@ void bli_dgemm_piledriver_asm_8x3 vfmadd231pd(xmm2, xmm0, xmm14) vmovddup(mem(rbx, -8*8), xmm2) vfmadd231pd(xmm3, xmm0, xmm15) - + // iteration 1 vmovaps(mem(rax, -4*16), xmm0) prefetch(0, mem(rax, 64+384)) // prefetch a @@ -1030,7 +644,7 @@ void bli_dgemm_piledriver_asm_8x3 vfmadd231pd(xmm2, xmm0, xmm14) vmovddup(mem(rbx, -5*8), xmm2) vfmadd231pd(xmm3, xmm0, xmm15) - + // iteration 2 vmovaps(mem(rax, 0*16), xmm0) prefetch(0, mem(rax, 128+384)) // prefetch a @@ -1052,7 +666,7 @@ void bli_dgemm_piledriver_asm_8x3 vfmadd231pd(xmm2, xmm0, xmm14) vmovddup(mem(rbx, -2*8), xmm2) vfmadd231pd(xmm3, xmm0, xmm15) - + // iteration 3 vmovaps(mem(rax, 4*16), xmm0) prefetch(0, mem(rax, 192+384)) // prefetch a @@ -1075,7 +689,7 @@ void bli_dgemm_piledriver_asm_8x3 vfmadd231pd(xmm2, xmm0, xmm14) vmovddup(mem(rbx, 1*8), xmm2) vfmadd231pd(xmm3, xmm0, xmm15) - + // iteration 4 vmovaps(mem(rax, -8*16), xmm0) prefetch(0, mem(rax, 384)) // prefetch a @@ -1097,9 +711,9 @@ void bli_dgemm_piledriver_asm_8x3 vfmadd231pd(xmm2, xmm0, xmm14) vmovddup(mem(rbx, 4*8), xmm2) vfmadd231pd(xmm3, xmm0, xmm15) - + prefetch(0, mem(rbx, 96+256)) // prefetch b - + // iteration 5 vmovaps(mem(rax, -4*16), xmm0) prefetch(0, mem(rax, 64+384)) // prefetch a @@ -1121,8 +735,8 @@ void bli_dgemm_piledriver_asm_8x3 vfmadd231pd(xmm2, xmm0, xmm14) vmovddup(mem(rbx, 7*8), xmm2) vfmadd231pd(xmm3, xmm0, xmm15) - - + + // iteration 6 vmovaps(mem(rax, 0*16), xmm0) prefetch(0, mem(rax, 128+384)) // prefetch a @@ -1144,7 +758,7 @@ void bli_dgemm_piledriver_asm_8x3 vfmadd231pd(xmm2, xmm0, xmm14) vmovddup(mem(rbx, 10*8), xmm2) vfmadd231pd(xmm3, xmm0, xmm15) - + // iteration 7 vmovaps(mem(rax, 4*16), xmm0) prefetch(0, mem(rax, 192+384)) // prefetch a @@ -1169,31 +783,31 @@ void bli_dgemm_piledriver_asm_8x3 vmovddup(mem(rbx, -11*8), xmm2) vfmadd231pd(xmm3, xmm0, xmm15) vmovddup(mem(rbx, -10*8), xmm3) - - - + + + dec(rsi) // i -= 1; jmp(.DLOOPKITER) // jump to beginning of loop. - - - - - - + + + + + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done. // else, we prepare to // enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP - - + + je(.DPOSTACCUM) // if i == 0, we're done. - + // iteration 0 vmovaps(mem(rax, -8*16), xmm0) prefetch(0, mem(rax, 512)) // prefetch a @@ -1215,48 +829,48 @@ void bli_dgemm_piledriver_asm_8x3 vmovddup(mem(rbx, -8*8), xmm2) vfmadd231pd(xmm3, xmm0, xmm15) vmovddup(mem(rbx, -7*8), xmm3) - - + + add(imm(1*8*8), rax) // a += 1*8 (1 x mr) add(imm(1*3*8), rbx) // b += 1*3 (1 x nr) - - + + dec(rsi) // i -= 1; jmp(.DLOOPKLEFT) // jump to beginning of loop. 
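/*
   The .DLOOPKITER / .DLOOPKLEFT pair above is the usual unroll-by-eight
   split: k_iter passes of eight software-pipelined rank-1 updates, then
   k_left single updates for the remainder. A plain C sketch of the same
   control flow and arithmetic (names are illustrative only):
*/
#include <stdint.h>

static void rank1_update_8x3( double ab[8][3], const double* a_col,
                              const double* b_row )
{
	// ab += a_col * b_row: one rank-1 update of the 8x3 tile
	for ( int i = 0; i < 8; i++ )
		for ( int j = 0; j < 3; j++ )
			ab[ i ][ j ] += a_col[ i ] * b_row[ j ];
}

static void accumulate_8x3( double ab[8][3], const double* a,
                            const double* b, uint64_t k )
{
	uint64_t k_iter = k / 8, k_left = k % 8;

	for ( uint64_t it = 0; it < k_iter; it++ )      // main loop
		for ( int u = 0; u < 8; u++ )               // fully unrolled in asm
			rank1_update_8x3( ab, a + (it*8 + u)*8, b + (it*8 + u)*3 );

	for ( uint64_t it = 0; it < k_left; it++ )      // edge loop
		rank1_update_8x3( ab, a + (k_iter*8 + it)*8, b + (k_iter*8 + it)*3 );
}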
- - - + + + label(.DPOSTACCUM) - + prefetchw0(mem(rcx, 0*8)) // prefetch c + 0*cs_c prefetchw0(mem(r10, 0*8)) // prefetch c + 1*cs_c prefetchw0(mem(r11, 0*8)) // prefetch c + 2*cs_c - - - // xmm4: xmm5: xmm6: - // ( ab00 ( ab01 ( ab02 + + + // xmm4: xmm5: xmm6: + // ( ab00 ( ab01 ( ab02 // ab10 ) ab11 ) ab12 ) // - // xmm7: xmm8: xmm9: - // ( ab20 ( ab21 ( ab22 + // xmm7: xmm8: xmm9: + // ( ab20 ( ab21 ( ab22 // ab30 ) ab31 ) ab32 ) // - // xmm10: xmm11: xmm12: - // ( ab40 ( ab41 ( ab42 + // xmm10: xmm11: xmm12: + // ( ab40 ( ab41 ( ab42 // ab50 ) ab51 ) ab52 ) // - // xmm13: xmm14: xmm15: - // ( ab60 ( ab61 ( ab62 + // xmm13: xmm14: xmm15: + // ( ab60 ( ab61 ( ab62 // ab70 ) ab71 ) ab72 ) - - - - + + + + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vmovddup(mem(rax), xmm0) // load alpha and duplicate vmovddup(mem(rbx), xmm2) // load beta and duplicate - + vmulpd(xmm0, xmm4, xmm4) // scale by alpha vmulpd(xmm0, xmm5, xmm5) vmulpd(xmm0, xmm6, xmm6) @@ -1269,341 +883,72 @@ void bli_dgemm_piledriver_asm_8x3 vmulpd(xmm0, xmm13, xmm13) vmulpd(xmm0, xmm14, xmm14) vmulpd(xmm0, xmm15, xmm15) - - + + prefetch(0, mem(r14)) // prefetch a_next prefetch(0, mem(r14, 64)) // prefetch a_next - - - - - mov(var(rs_c), rsi) // load rs_c - lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(double) - - lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*rs_c; - - lea(mem(, rsi, 2), r12) // r12 = 2*rs_c; - lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_c; - - - - // determine if - // c % 32 == 0, AND - // 8*cs_c % 32 == 0, AND - // rs_c == 1 - // ie: aligned, ldim aligned, and - // column-stored - - cmp(imm(8), rsi) // set ZF if (8*rs_c) == 8. - sete(bl) // bl = ( ZF == 1 ? 1 : 0 ); - test(imm(31), rcx) // set ZF if c & 32 is zero. - setz(bh) // bh = ( ZF == 0 ? 1 : 0 ); - test(imm(31), rdi) // set ZF if (8*cs_c) & 32 is zero. - setz(al) // al = ( ZF == 0 ? 1 : 0 ); - // and(bl,bh) followed by - // and(bh,al) will reveal result - + prefetch(0, mem(r15)) // prefetch b_next prefetch(0, mem(r15, 64)) // prefetch b_next - + // now avoid loading C if beta == 0 - + vxorpd(xmm0, xmm0, xmm0) // set xmm0 to zero. vucomisd(xmm0, xmm2) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - - - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. 
- je(.DGENSTORED) // jump to column storage case - - - - label(.DCOLSTORED) - - // xmm4: xmm5: xmm6: - // ( ab00 ( ab01 ( ab02 - // ab10 ) ab11 ) ab12 ) - // - // xmm7: xmm8: xmm9: - // ( ab20 ( ab21 ( ab22 - // ab30 ) ab31 ) ab32 ) - // - // xmm10: xmm11: xmm12: - // ( ab40 ( ab41 ( ab42 - // ab50 ) ab51 ) ab52 ) - // - // xmm13: xmm14: xmm15: - // ( ab60 ( ab61 ( ab62 - // ab70 ) ab71 ) ab72 ) - - - vfmadd231pd(mem(rcx, 0*16), xmm2, xmm4) - vfmadd231pd(mem(rcx, 1*16), xmm2, xmm7) - vfmadd231pd(mem(rcx, 2*16), xmm2, xmm10) - vfmadd231pd(mem(rcx, 3*16), xmm2, xmm13) - - vfmadd231pd(mem(r10, 0*16), xmm2, xmm5) - vfmadd231pd(mem(r10, 1*16), xmm2, xmm8) - vfmadd231pd(mem(r10, 2*16), xmm2, xmm11) - vfmadd231pd(mem(r10, 3*16), xmm2, xmm14) - - vfmadd231pd(mem(r11, 0*16), xmm2, xmm6) - vfmadd231pd(mem(r11, 1*16), xmm2, xmm9) - vfmadd231pd(mem(r11, 2*16), xmm2, xmm12) - vfmadd231pd(mem(r11, 3*16), xmm2, xmm15) - - - vmovups(xmm4, mem(rcx, 0*16)) - vmovups(xmm7, mem(rcx, 1*16)) - vmovups(xmm10, mem(rcx, 2*16)) - vmovups(xmm13, mem(rcx, 3*16)) - - vmovups(xmm5, mem(r10, 0*16)) - vmovups(xmm8, mem(r10, 1*16)) - vmovups(xmm11, mem(r10, 2*16)) - vmovups(xmm14, mem(r10, 3*16)) - - vmovups(xmm6, mem(r11, 0*16)) - vmovups(xmm9, mem(r11, 1*16)) - vmovups(xmm12, mem(r11, 2*16)) - vmovups(xmm15, mem(r11, 3*16)) - - - - -/* - vmovupd(mem(rcx), xmm0) // load c00:c10 - vmovupd(mem(rcx, r12, 1), xmm1) // load c20:c30 - vfmadd231pd(xmm2, xmm0, xmm4) - vfmadd231pd(xmm2, xmm1, xmm7) - vmovupd(xmm4, mem(rcx)) // store c00:c10 - vmovupd(xmm7, mem(rcx, r12, 1)) // store c20:c30 - add(rdi, rcx) - - vmovupd(mem(rdx), xmm0) // load c40:c50 - vmovupd(mem(rdx, r12, 1), xmm1) // load c60:c70 - vfmadd213pd(xmm10, xmm2, xmm0) - vfmadd213pd(xmm13, xmm2, xmm1) - vmovupd(xmm0, mem(rdx)) // store c40:c50 - vmovupd(xmm1, mem(rdx, r12, 1)) // store c60:c70 - add(rdi, rdx) - - - vmovupd(mem(rcx), xmm0) // load c01:c11 - vmovupd(mem(rcx, r12, 1), xmm1) // load c21:c31 - vfmadd213pd(xmm5, xmm2, xmm0) - vfmadd213pd(xmm8, xmm2, xmm1) - vmovupd(xmm0, mem(rcx)) // store c01:c11 - vmovupd(xmm1, mem(rcx, r12, 1)) // store c21:c31 - add(rdi, rcx) - - vmovupd(mem(rdx), xmm0) // load c41:c51 - vmovupd(mem(rdx, r12, 1), xmm1) // load c61:c71 - vfmadd213pd(xmm11, xmm2, xmm0) - vfmadd213pd(xmm14, xmm2, xmm1) - vmovupd(xmm0, mem(rdx)) // store c41:c51 - vmovupd(xmm1, mem(rdx, r12, 1)) // store c61:c71 - add(rdi, rdx) - - - vmovupd(mem(rcx), xmm0) // load c02:c12 - vmovupd(mem(rcx, r12, 1), xmm1) // load c22:c32 - vfmadd213pd(xmm6, xmm2, xmm0) - vfmadd213pd(xmm9, xmm2, xmm1) - vmovupd(xmm0, mem(rcx)) // store c02:c12 - vmovupd(xmm1, mem(rcx, r12, 1)) // store c22:c32 - - vmovupd(mem(rdx), xmm0) // load c42:c52 - vmovupd(mem(rdx, r12, 1), xmm1) // load c62:c72 - vfmadd213pd(xmm12, xmm2, xmm0) - vfmadd213pd(xmm15, xmm2, xmm1) - vmovupd(xmm0, mem(rdx)) // store c42:c52 - vmovupd(xmm1, mem(rdx, r12, 1)) // store c62:c72 -*/ - - - - jmp(.DDONE) // jump to end. 
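/*
   Each vfmadd231pd(mem(...), xmm2, xmmN) in the column-stored path above
   folds the beta update into a single fused multiply-add:
   xmmN := beta * c + xmmN, where xmmN already holds alpha * ab.
   Scalar equivalent for one element of C:
*/
#include <math.h>

static double beta_update_one( double c_ij, double beta, double alpha_ab_ij )
{
	// beta * c_ij + alpha * (a*b)_ij in one rounding, as the FMA does
	return fma( beta, c_ij, alpha_ab_ij );
}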
- - - - label(.DGENSTORED) - - - vmovlpd(mem(rcx), xmm0, xmm0) // load c00:c10 - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) - vaddpd(xmm4, xmm0, xmm0) - vmovlpd(xmm0, mem(rcx)) // store c00:c10 - vmovhpd(xmm0, mem(rcx, rsi, 1)) - vmovlpd(mem(rcx, r12, 1), xmm0, xmm0) // load c20:c30 - vmovhpd(mem(rcx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) - vaddpd(xmm7, xmm0, xmm0) - vmovlpd(xmm0, mem(rcx, r12, 1)) // store c20:c30 - vmovhpd(xmm0, mem(rcx, r13, 1)) - add(rdi, rcx) - - vmovlpd(mem(rdx), xmm0, xmm0) // load c40:c50 - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) - vaddpd(xmm10, xmm0, xmm0) - vmovlpd(xmm0, mem(rdx)) // store c40:c50 - vmovhpd(xmm0, mem(rdx, rsi, 1)) - vmovlpd(mem(rdx, r12, 1), xmm0, xmm0) // load c60:c70 - vmovhpd(mem(rdx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) - vaddpd(xmm13, xmm0, xmm0) - vmovlpd(xmm0, mem(rdx, r12, 1)) // store c60:c70 - vmovhpd(xmm0, mem(rdx, r13, 1)) - add(rdi, rdx) - - - vmovlpd(mem(rcx), xmm0, xmm0) // load c01:c11 - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) - vaddpd(xmm5, xmm0, xmm0) - vmovlpd(xmm0, mem(rcx)) // store c01:c11 - vmovhpd(xmm0, mem(rcx, rsi, 1)) - vmovlpd(mem(rcx, r12, 1), xmm0, xmm0) // load c21:c31 - vmovhpd(mem(rcx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) - vaddpd(xmm8, xmm0, xmm0) - vmovlpd(xmm0, mem(rcx, r12, 1)) // store c21:c31 - vmovhpd(xmm0, mem(rcx, r13, 1)) - add(rdi, rcx) - - vmovlpd(mem(rdx), xmm0, xmm0) // load c41:c51 - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) - vaddpd(xmm11, xmm0, xmm0) - vmovlpd(xmm0, mem(rdx)) // store c41:c51 - vmovhpd(xmm0, mem(rdx, rsi, 1)) - vmovlpd(mem(rdx, r12, 1), xmm0, xmm0) // load c61:c71 - vmovhpd(mem(rdx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) - vaddpd(xmm14, xmm0, xmm0) - vmovlpd(xmm0, mem(rdx, r12, 1)) // store c61:c71 - vmovhpd(xmm0, mem(rdx, r13, 1)) - add(rdi, rdx) - - - vmovlpd(mem(rcx), xmm0, xmm0) // load c02:c12 - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) - vaddpd(xmm6, xmm0, xmm0) - vmovlpd(xmm0, mem(rcx)) // store c02:c12 - vmovhpd(xmm0, mem(rcx, rsi, 1)) - vmovlpd(mem(rcx, r12, 1), xmm0, xmm0) // load c22:c32 - vmovhpd(mem(rcx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) - vaddpd(xmm9, xmm0, xmm0) - vmovlpd(xmm0, mem(rcx, r12, 1)) // store c22:c32 - vmovhpd(xmm0, mem(rcx, r13, 1)) - add(rdi, rcx) - - vmovlpd(mem(rdx), xmm0, xmm0) // load c42:c52 - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) - vaddpd(xmm12, xmm0, xmm0) - vmovlpd(xmm0, mem(rdx)) // store c42:c52 - vmovhpd(xmm0, mem(rdx, rsi, 1)) - vmovlpd(mem(rdx, r12, 1), xmm0, xmm0) // load c62:c72 - vmovhpd(mem(rdx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) - vaddpd(xmm15, xmm0, xmm0) - vmovlpd(xmm0, mem(rdx, r12, 1)) // store c62:c72 - vmovhpd(xmm0, mem(rdx, r13, 1)) - add(rdi, rdx) - - - - jmp(.DDONE) // jump to end. 
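/*
   The vmovlpd/vmovhpd pairs in the general-storage path above pack two
   row-strided doubles into one xmm register so the beta update still runs
   on two-element vectors when C is not column-stored. The same
   gather/update/scatter expressed with SSE2 intrinsics (illustrative, not
   code from the patch):
*/
#include <emmintrin.h>

static void dgen_update_pair( double* c, long rs_c, __m128d beta, __m128d ab )
{
	__m128d v = _mm_loadl_pd( _mm_setzero_pd(), c );      // low  lane <- c[0]
	v         = _mm_loadh_pd( v, c + rs_c );              // high lane <- c[rs_c]
	v         = _mm_add_pd( _mm_mul_pd( beta, v ), ab );  // beta*c + alpha*ab
	_mm_storel_pd( c,        v );                         // scatter both
	_mm_storeh_pd( c + rs_c, v );                         // lanes back
}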
- - - + + // xmm4: xmm5: xmm6: + // ( ab00 ( ab01 ( ab02 + // ab10 ) ab11 ) ab12 ) + // + // xmm7: xmm8: xmm9: + // ( ab20 ( ab21 ( ab22 + // ab30 ) ab31 ) ab32 ) + // + // xmm10: xmm11: xmm12: + // ( ab40 ( ab41 ( ab42 + // ab50 ) ab51 ) ab52 ) + // + // xmm13: xmm14: xmm15: + // ( ab60 ( ab61 ( ab62 + // ab70 ) ab71 ) ab72 ) + + vfmadd231pd(mem(rcx, 0*16), xmm2, xmm4) + vfmadd231pd(mem(rcx, 1*16), xmm2, xmm7) + vfmadd231pd(mem(rcx, 2*16), xmm2, xmm10) + vfmadd231pd(mem(rcx, 3*16), xmm2, xmm13) + + vfmadd231pd(mem(r10, 0*16), xmm2, xmm5) + vfmadd231pd(mem(r10, 1*16), xmm2, xmm8) + vfmadd231pd(mem(r10, 2*16), xmm2, xmm11) + vfmadd231pd(mem(r10, 3*16), xmm2, xmm14) + + vfmadd231pd(mem(r11, 0*16), xmm2, xmm6) + vfmadd231pd(mem(r11, 1*16), xmm2, xmm9) + vfmadd231pd(mem(r11, 2*16), xmm2, xmm12) + vfmadd231pd(mem(r11, 3*16), xmm2, xmm15) + + // fall through + label(.DBETAZERO) - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. - jne(.DCOLSTORBZ) // jump to column storage case - - - - label(.DGENSTORBZ) - - - vmovlpd(xmm4, mem(rcx)) - vmovhpd(xmm4, mem(rcx, rsi, 1)) - vmovlpd(xmm7, mem(rcx, r12, 1)) - vmovhpd(xmm7, mem(rcx, r13, 1)) - add(rdi, rcx) - vmovlpd(xmm10, mem(rdx)) - vmovhpd(xmm10, mem(rdx, rsi, 1)) - vmovlpd(xmm13, mem(rdx, r12, 1)) - vmovhpd(xmm13, mem(rdx, r13, 1)) - add(rdi, rdx) - - vmovlpd(xmm5, mem(rcx)) - vmovhpd(xmm5, mem(rcx, rsi, 1)) - vmovlpd(xmm8, mem(rcx, r12, 1)) - vmovhpd(xmm8, mem(rcx, r13, 1)) - add(rdi, rcx) - vmovlpd(xmm11, mem(rdx)) - vmovhpd(xmm11, mem(rdx, rsi, 1)) - vmovlpd(xmm14, mem(rdx, r12, 1)) - vmovhpd(xmm14, mem(rdx, r13, 1)) - add(rdi, rdx) - - vmovlpd(xmm6, mem(rcx)) - vmovhpd(xmm6, mem(rcx, rsi, 1)) - vmovlpd(xmm9, mem(rcx, r12, 1)) - vmovhpd(xmm9, mem(rcx, r13, 1)) - add(rdi, rcx) - vmovlpd(xmm12, mem(rdx)) - vmovhpd(xmm12, mem(rdx, rsi, 1)) - vmovlpd(xmm15, mem(rdx, r12, 1)) - vmovhpd(xmm15, mem(rdx, r13, 1)) - add(rdi, rdx) - - - - jmp(.DDONE) // jump to end. 
- - - - label(.DCOLSTORBZ) - - - vmovupd(xmm4, mem(rcx)) - vmovupd(xmm7, mem(rcx, r12, 1)) - add(rdi, rcx) - vmovupd(xmm10, mem(rdx)) - vmovupd(xmm13, mem(rdx, r12, 1)) - add(rdi, rdx) - - vmovupd(xmm5, mem(rcx)) - vmovupd(xmm8, mem(rcx, r12, 1)) - add(rdi, rcx) - vmovupd(xmm11, mem(rdx)) - vmovupd(xmm14, mem(rdx, r12, 1)) - add(rdi, rdx) - - vmovupd(xmm6, mem(rcx)) - vmovupd(xmm9, mem(rcx, r12, 1)) - add(rdi, rcx) - vmovupd(xmm12, mem(rdx)) - vmovupd(xmm15, mem(rdx, r12, 1)) - add(rdi, rdx) - - - - - + + vmovups(xmm4, mem(rcx, 0*16)) + vmovups(xmm7, mem(rcx, 1*16)) + vmovups(xmm10, mem(rcx, 2*16)) + vmovups(xmm13, mem(rcx, 3*16)) + + vmovups(xmm5, mem(r10, 0*16)) + vmovups(xmm8, mem(r10, 1*16)) + vmovups(xmm11, mem(r10, 2*16)) + vmovups(xmm14, mem(r10, 3*16)) + + vmovups(xmm6, mem(r11, 0*16)) + vmovups(xmm9, mem(r11, 1*16)) + vmovups(xmm12, mem(r11, 2*16)) + vmovups(xmm15, mem(r11, 3*16)) + label(.DDONE) - + end_asm( : // output operands (none) @@ -1620,7 +965,7 @@ void bli_dgemm_piledriver_asm_8x3 [b_next] "m" (b_next), // 9 [a_next] "m" (a_next) // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -1628,11 +973,15 @@ void bli_dgemm_piledriver_asm_8x3 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( d ); } void bli_cgemm_piledriver_asm_4x2 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, scomplex* restrict alpha, scomplex* restrict a, scomplex* restrict b, @@ -1647,28 +996,30 @@ void bli_cgemm_piledriver_asm_4x2 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 8; - uint64_t k_left = k0 % 8; + uint64_t k_iter = k / 8; + uint64_t k_left = k % 8; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( c, 4, 2, false ); + begin_asm() - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. mov(var(b_next), r15) // load address of b_next. mov(var(a_next), r14) // load address of a_next. - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(scomplex) lea(mem(rcx, rdi, 1), r10) // load address of c + 1*cs_c; - + add(imm(32*4), rax) add(imm(16*4), rbx) - - + + vxorps(xmm8, xmm8, xmm8) vxorps(xmm9, xmm9, xmm9) vxorps(xmm10, xmm10, xmm10) @@ -1678,24 +1029,24 @@ void bli_cgemm_piledriver_asm_4x2 vxorps(xmm14, xmm14, xmm14) vxorps(xmm15, xmm15, xmm15) //vzeroall() - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.CCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.CLOOPKITER) // MAIN LOOP - - + + je(.CCONSIDKLEFT) // if i == 0, jump to k_left code. 
- - + + prefetch(0, mem(rbx, 256)) prefetch(0, mem(rax, 512)) - + // iteration 0 vmovaps(mem(rax, -32*4), xmm0) vbroadcastss(mem(rbx, -16*4), xmm4) @@ -1711,7 +1062,7 @@ void bli_cgemm_piledriver_asm_4x2 vbroadcastss(mem(rbx, -13*4), xmm7) vfmadd231ps(xmm0, xmm7, xmm11) vfmadd231ps(xmm1, xmm7, xmm15) - + // iteration 1 vmovaps(mem(rax, -24*4), xmm0) vbroadcastss(mem(rbx, -12*4), xmm4) @@ -1727,10 +1078,10 @@ void bli_cgemm_piledriver_asm_4x2 vbroadcastss(mem(rbx, -9*4), xmm7) vfmadd231ps(xmm0, xmm7, xmm11) vfmadd231ps(xmm1, xmm7, xmm15) - + prefetch(0, mem(rbx, 64+256)) prefetch(0, mem(rax, 64+512)) - + // iteration 2 vmovaps(mem(rax, -16*4), xmm0) vbroadcastss(mem(rbx, -8*4), xmm4) @@ -1746,7 +1097,7 @@ void bli_cgemm_piledriver_asm_4x2 vbroadcastss(mem(rbx, -5*4), xmm7) vfmadd231ps(xmm0, xmm7, xmm11) vfmadd231ps(xmm1, xmm7, xmm15) - + // iteration 3 vmovaps(mem(rax, -8*4), xmm0) vbroadcastss(mem(rbx, -4*4), xmm4) @@ -1762,10 +1113,10 @@ void bli_cgemm_piledriver_asm_4x2 vbroadcastss(mem(rbx, -1*4), xmm7) vfmadd231ps(xmm0, xmm7, xmm11) vfmadd231ps(xmm1, xmm7, xmm15) - + prefetch(0, mem(rbx, 128+256)) prefetch(0, mem(rax, 128+512)) - + // iteration 4 vmovaps(mem(rax, 0*4), xmm0) vbroadcastss(mem(rbx, 0*4), xmm4) @@ -1781,7 +1132,7 @@ void bli_cgemm_piledriver_asm_4x2 vbroadcastss(mem(rbx, 3*4), xmm7) vfmadd231ps(xmm0, xmm7, xmm11) vfmadd231ps(xmm1, xmm7, xmm15) - + // iteration 5 vmovaps(mem(rax, 8*4), xmm0) vbroadcastss(mem(rbx, 4*4), xmm4) @@ -1797,10 +1148,10 @@ void bli_cgemm_piledriver_asm_4x2 vbroadcastss(mem(rbx, 7*4), xmm7) vfmadd231ps(xmm0, xmm7, xmm11) vfmadd231ps(xmm1, xmm7, xmm15) - + prefetch(0, mem(rbx, 128+256)) prefetch(0, mem(rax, 128+512)) - + // iteration 6 vmovaps(mem(rax, 16*4), xmm0) vbroadcastss(mem(rbx, 8*4), xmm4) @@ -1816,7 +1167,7 @@ void bli_cgemm_piledriver_asm_4x2 vbroadcastss(mem(rbx, 11*4), xmm7) vfmadd231ps(xmm0, xmm7, xmm11) vfmadd231ps(xmm1, xmm7, xmm15) - + // iteration 7 vmovaps(mem(rax, 24*4), xmm0) vbroadcastss(mem(rbx, 12*4), xmm4) @@ -1834,33 +1185,33 @@ void bli_cgemm_piledriver_asm_4x2 add(imm(8*2*8), rbx) // b += 8*2 (unroll x nr) vfmadd231ps(xmm0, xmm7, xmm11) vfmadd231ps(xmm1, xmm7, xmm15) - - - + + + dec(rsi) // i -= 1; jmp(.CLOOPKITER) // jump to beginning of loop. - - - - - - + + + + + + label(.CCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.CPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.CLOOPKLEFT) // EDGE LOOP - - + + je(.CPOSTACCUM) // if i == 0, we're done. - + prefetch(0, mem(rbx, 256)) prefetch(0, mem(rax, 512)) - + // iteration 0 vmovaps(mem(rax, -32*4), xmm0) vbroadcastss(mem(rbx, -16*4), xmm4) @@ -1876,123 +1227,88 @@ void bli_cgemm_piledriver_asm_4x2 vbroadcastss(mem(rbx, -13*4), xmm7) vfmadd231ps(xmm0, xmm7, xmm11) vfmadd231ps(xmm1, xmm7, xmm15) - - + + add(imm(1*4*8), rax) // a += 1*2 (1 x mr) add(imm(1*2*8), rbx) // b += 1*2 (1 x nr) - - + + dec(rsi) // i -= 1; jmp(.CLOOPKLEFT) // jump to beginning of loop. 
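/*
   In the .CPOSTACCUM block that follows, each column of the 4x2 tile is
   held in two accumulators: one carrying a * Re(b), the other a * Im(b).
   The vpermilps(0xb1) + vaddsubps pair swaps the re/im lanes of the second
   accumulator, then subtracts in even lanes and adds in odd lanes, which
   is exactly a complex multiply-accumulate. Scalar sketch for one element:
*/
#include <complex.h>

static float complex cmadd( float complex a, float complex b,
                            float complex acc )
{
	float ar = crealf( a ), ai = cimagf( a );
	float br = crealf( b ), bi = cimagf( b );
	// even lane: ar*br - ai*bi (real part); odd lane: ai*br + ar*bi (imag)
	return acc + ( ( ar*br - ai*bi ) + ( ai*br + ar*bi ) * I );
}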
- - - + + + label(.CPOSTACCUM) - - + + prefetchw0(mem(rcx, 0*8)) // prefetch c + 0*cs_c prefetchw0(mem(r10, 0*8)) // prefetch c + 1*cs_c - - + + vpermilps(imm(0xb1), xmm9, xmm9) vpermilps(imm(0xb1), xmm11, xmm11) vpermilps(imm(0xb1), xmm13, xmm13) vpermilps(imm(0xb1), xmm15, xmm15) - + vaddsubps(xmm9, xmm8, xmm8) vaddsubps(xmm11, xmm10, xmm10) vaddsubps(xmm13, xmm12, xmm12) vaddsubps(xmm15, xmm14, xmm14) - - + + // xmm8: xmm10: // ( ab00 ( ab01 // ab10 ab11 // ab20 ab21 // ab30 ) ab31 ) - + // xmm12: xmm14: // ( ab40 ( ab41 // ab50 ab51 // ab60 ab61 // ab70 ) ab71 ) - - + + prefetch(0, mem(r14)) // prefetch a_next prefetch(0, mem(r14, 64)) // prefetch a_next - - + + // scale by alpha - + mov(var(alpha), rax) // load address of alpha vbroadcastss(mem(rax), xmm0) // load alpha_r and duplicate vbroadcastss(mem(rax, 4), xmm1) // load alpha_i and duplicate - + vpermilps(imm(0xb1), xmm8, xmm9) vpermilps(imm(0xb1), xmm10, xmm11) vpermilps(imm(0xb1), xmm12, xmm13) vpermilps(imm(0xb1), xmm14, xmm15) - + vmulps(xmm8, xmm0, xmm8) vmulps(xmm10, xmm0, xmm10) vmulps(xmm12, xmm0, xmm12) vmulps(xmm14, xmm0, xmm14) - + vmulps(xmm9, xmm1, xmm9) vmulps(xmm11, xmm1, xmm11) vmulps(xmm13, xmm1, xmm13) vmulps(xmm15, xmm1, xmm15) - + vaddsubps(xmm9, xmm8, xmm8) vaddsubps(xmm11, xmm10, xmm10) vaddsubps(xmm13, xmm12, xmm12) vaddsubps(xmm15, xmm14, xmm14) - - - - + + mov(var(beta), rbx) // load address of beta vbroadcastss(mem(rbx), xmm6) // load beta_r and duplicate vbroadcastss(mem(rbx, 4), xmm7) // load beta_i and duplicate - - - - - - - - mov(var(rs_c), rsi) // load rs_c - lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(scomplex) - - - lea(mem(, rsi, 2), r12) // r12 = 2*rs_c; - lea(mem(rsi, rsi, 2), r13) // r13 = 3*rs_c; - - - + prefetch(0, mem(r15)) // prefetch b_next prefetch(0, mem(r15, 64)) // prefetch b_next - - - - // determine if - // c % 32 == 0, AND - // 8*cs_c % 32 == 0, AND - // rs_c == 1 - // ie: aligned, ldim aligned, and - // column-stored - - cmp(imm(8), rsi) // set ZF if (8*rs_c) == 8. - sete(bl) // bl = ( ZF == 1 ? 1 : 0 ); - test(imm(31), rcx) // set ZF if c & 32 is zero. - setz(bh) // bh = ( ZF == 0 ? 1 : 0 ); - test(imm(31), rdi) // set ZF if (8*cs_c) & 32 is zero. - setz(al) // al = ( ZF == 0 ? 1 : 0 ); - // and(bl,bh) followed by - // and(bh,al) will reveal result - + // now avoid loading C if beta == 0 - + vxorps(xmm0, xmm0, xmm0) // set xmm0 to zero. vucomiss(xmm0, xmm6) // set ZF if beta_r == 0. sete(r8b) // r8b = ( ZF == 1 ? 1 : 0 ); @@ -2000,158 +1316,49 @@ void bli_cgemm_piledriver_asm_4x2 sete(r9b) // r9b = ( ZF == 1 ? 1 : 0 ); and(r8b, r9b) // set ZF if r8b & r9b == 1. jne(.CBETAZERO) // if ZF = 0, jump to beta == 0 case - - - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. 
- jne(.CCOLSTORED) // jump to column storage case - - - - label(.CGENSTORED) - - - vmovlps(mem(rcx), xmm0, xmm0) // load c00:c10 - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm2, xmm2) // load c20:c30 - vmovhps(mem(rcx, r13, 1), xmm2, xmm2) - vpermilps(imm(0xb1), xmm0, xmm1) - vpermilps(imm(0xb1), xmm2, xmm3) - - vmulps(xmm6, xmm0, xmm0) - vmulps(xmm7, xmm1, xmm1) - vaddsubps(xmm1, xmm0, xmm0) - vaddps(xmm8, xmm0, xmm0) - vmovlps(xmm0, mem(rcx)) // store c00:c10 - vmovhps(xmm0, mem(rcx, rsi, 1)) - - vmulps(xmm6, xmm2, xmm2) - vmulps(xmm7, xmm3, xmm3) - vaddsubps(xmm3, xmm2, xmm2) - vaddps(xmm12, xmm2, xmm2) - vmovlps(xmm2, mem(rcx, r12, 1)) // store c20:c30 - vmovhps(xmm2, mem(rcx, r13, 1)) - - - - vmovlps(mem(r10), xmm0, xmm0) // load c01:c11 - vmovhps(mem(r10, rsi, 1), xmm0, xmm0) - vmovlps(mem(r10, r12, 1), xmm2, xmm2) // load c21:c31 - vmovhps(mem(r10, r13, 1), xmm2, xmm2) - vpermilps(imm(0xb1), xmm0, xmm1) - vpermilps(imm(0xb1), xmm2, xmm3) - - vmulps(xmm6, xmm0, xmm0) - vmulps(xmm7, xmm1, xmm1) - vaddsubps(xmm1, xmm0, xmm0) - vaddps(xmm10, xmm0, xmm0) - vmovlps(xmm0, mem(r10)) // store c01:c11 - vmovhps(xmm0, mem(r10, rsi, 1)) - - vmulps(xmm6, xmm2, xmm2) - vmulps(xmm7, xmm3, xmm3) - vaddsubps(xmm3, xmm2, xmm2) - vaddps(xmm14, xmm2, xmm2) - vmovlps(xmm2, mem(r10, r12, 1)) // store c21:c31 - vmovhps(xmm2, mem(r10, r13, 1)) - - - - jmp(.CDONE) // jump to end. - - - - label(.CCOLSTORED) - - - vmovups(mem(rcx), xmm0) // load c00:c10 - vmovups(mem(rcx, 16), xmm2) // load c20:c30 - vpermilps(imm(0xb1), xmm0, xmm1) - vpermilps(imm(0xb1), xmm2, xmm3) - - vmulps(xmm6, xmm0, xmm0) - vmulps(xmm7, xmm1, xmm1) - vaddsubps(xmm1, xmm0, xmm0) - vaddps(xmm8, xmm0, xmm0) - vmovups(xmm0, mem(rcx)) // store c00:c10 - - vmulps(xmm6, xmm2, xmm2) - vmulps(xmm7, xmm3, xmm3) - vaddsubps(xmm3, xmm2, xmm2) - vaddps(xmm12, xmm2, xmm2) - vmovups(xmm2, mem(rcx, 16)) // store c20:c30 - - - - vmovups(mem(r10), xmm0) // load c01:c11 - vmovups(mem(r10, 16), xmm2) // load c21:c31 - vpermilps(imm(0xb1), xmm0, xmm1) - vpermilps(imm(0xb1), xmm2, xmm3) - - vmulps(xmm6, xmm0, xmm0) - vmulps(xmm7, xmm1, xmm1) - vaddsubps(xmm1, xmm0, xmm0) - vaddps(xmm10, xmm0, xmm0) - vmovups(xmm0, mem(r10)) // store c01:c11 - - vmulps(xmm6, xmm2, xmm2) - vmulps(xmm7, xmm3, xmm3) - vaddsubps(xmm3, xmm2, xmm2) - vaddps(xmm14, xmm2, xmm2) - vmovups(xmm2, mem(r10, 16)) // store c21:c31 - - - - jmp(.CDONE) // jump to end. - - - + + vmovups(mem(rcx), xmm0) // load c00:c10 + vmovups(mem(rcx, 16), xmm2) // load c20:c30 + vpermilps(imm(0xb1), xmm0, xmm1) + vpermilps(imm(0xb1), xmm2, xmm3) + + vmulps(xmm6, xmm0, xmm0) + vmulps(xmm7, xmm1, xmm1) + vaddsubps(xmm1, xmm0, xmm0) + vaddps(xmm8, xmm0, xmm0) + + vmulps(xmm6, xmm2, xmm2) + vmulps(xmm7, xmm3, xmm3) + vaddsubps(xmm3, xmm2, xmm2) + vaddps(xmm12, xmm2, xmm2) + + vmovups(mem(r10), xmm0) // load c01:c11 + vmovups(mem(r10, 16), xmm2) // load c21:c31 + vpermilps(imm(0xb1), xmm0, xmm1) + vpermilps(imm(0xb1), xmm2, xmm3) + + vmulps(xmm6, xmm0, xmm0) + vmulps(xmm7, xmm1, xmm1) + vaddsubps(xmm1, xmm0, xmm0) + vaddps(xmm10, xmm0, xmm0) + + vmulps(xmm6, xmm2, xmm2) + vmulps(xmm7, xmm3, xmm3) + vaddsubps(xmm3, xmm2, xmm2) + vaddps(xmm14, xmm2, xmm2) + + // fall through + label(.CBETAZERO) - // check if aligned/column-stored - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. 
- jne(.CCOLSTORBZ) // jump to column storage case - - - - label(.CGENSTORBZ) - - - vmovlps(xmm8, mem(rcx)) // store c00:c10 - vmovhps(xmm8, mem(rcx, rsi, 1)) - - vmovlps(xmm12, mem(rcx, r12, 1)) // store c20:c30 - vmovhps(xmm12, mem(rcx, r13, 1)) - - vmovlps(xmm10, mem(r10)) // store c01:c11 - vmovhps(xmm10, mem(r10, rsi, 1)) - - vmovlps(xmm14, mem(r10, r12, 1)) // store c21:c31 - vmovhps(xmm14, mem(r10, r13, 1)) - - - - jmp(.CDONE) // jump to end. - - - - label(.CCOLSTORBZ) - - - vmovups(xmm8, mem(rcx)) // store c00:c10 - vmovups(xmm12, mem(rcx, 16)) // store c20:c30 - - vmovups(xmm10, mem(r10)) // store c01:c11 - vmovups(xmm14, mem(r10, 16)) // store c21:c31 - - - - - + + vmovups(xmm8, mem(rcx)) // store c00:c10 + vmovups(xmm12, mem(rcx, 16)) // store c20:c30 + + vmovups(xmm10, mem(r10)) // store c01:c11 + vmovups(xmm14, mem(r10, 16)) // store c21:c31 + label(.CDONE) - + end_asm( : // output operands (none) @@ -2168,7 +1375,7 @@ void bli_cgemm_piledriver_asm_4x2 [b_next] "m" (b_next), // 9 [a_next] "m" (a_next) // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -2176,11 +1383,15 @@ void bli_cgemm_piledriver_asm_4x2 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( c ); } void bli_zgemm_piledriver_asm_2x2 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, dcomplex* restrict alpha, dcomplex* restrict a, dcomplex* restrict b, @@ -2195,28 +1406,30 @@ void bli_zgemm_piledriver_asm_2x2 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 8; - uint64_t k_left = k0 % 8; + uint64_t k_iter = k / 8; + uint64_t k_left = k % 8; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( z, 2, 2, false ); + begin_asm() - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. mov(var(b_next), r15) // load address of b_next. mov(var(a_next), r14) // load address of a_next. - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(dcomplex) lea(mem(, rdi, 2), rdi) lea(mem(rcx, rdi, 1), r10) // load address of c + 1*cs_c; - + add(imm(16*8), rax) add(imm(16*8), rbx) - + vxorpd(xmm8, xmm8, xmm8) vxorpd(xmm9, xmm9, xmm9) vxorpd(xmm10, xmm10, xmm10) @@ -2225,25 +1438,25 @@ void bli_zgemm_piledriver_asm_2x2 vxorpd(xmm13, xmm13, xmm13) vxorpd(xmm14, xmm14, xmm14) vxorpd(xmm15, xmm15, xmm15) - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.ZCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.ZLOOPKITER) // MAIN LOOP - - + + je(.ZCONSIDKLEFT) // if i == 0, jump to k_left code. 
- - + + prefetch(0, mem(rbx, 256)) - + prefetch(0, mem(rax, 512)) - + // iteration 0 vmovaps(mem(rax, -16*8), xmm0) vmovddup(mem(rbx, -16*8), xmm4) @@ -2261,7 +1474,7 @@ void bli_zgemm_piledriver_asm_2x2 vmovaps(mem(rax, -12*8), xmm0) vmovddup(mem(rbx, -12*8), xmm4) vfmadd231pd(xmm1, xmm7, xmm15) - + // iteration 1 vfmadd231pd(xmm0, xmm4, xmm8) vmovaps(mem(rax, -10*8), xmm1) @@ -2277,11 +1490,11 @@ void bli_zgemm_piledriver_asm_2x2 vmovaps(mem(rax, -8*8), xmm0) vmovddup(mem(rbx, -8*8), xmm4) vfmadd231pd(xmm1, xmm7, xmm15) - + prefetch(0, mem(rbx, 64+256)) - + prefetch(0, mem(rax, 64+512)) - + // iteration 2 vfmadd231pd(xmm0, xmm4, xmm8) vmovaps(mem(rax, -6*8), xmm1) @@ -2297,7 +1510,7 @@ void bli_zgemm_piledriver_asm_2x2 vmovaps(mem(rax, -4*8), xmm0) vmovddup(mem(rbx, -4*8), xmm4) vfmadd231pd(xmm1, xmm7, xmm15) - + // iteration 3 vfmadd231pd(xmm0, xmm4, xmm8) vmovaps(mem(rax, -2*8), xmm1) @@ -2313,11 +1526,11 @@ void bli_zgemm_piledriver_asm_2x2 vmovaps(mem(rax, 0*8), xmm0) vmovddup(mem(rbx, 0*8), xmm4) vfmadd231pd(xmm1, xmm7, xmm15) - + prefetch(0, mem(rbx, 128+256)) - + prefetch(0, mem(rax, 128+512)) - + // iteration 4 vfmadd231pd(xmm0, xmm4, xmm8) vmovaps(mem(rax, 2*8), xmm1) @@ -2333,7 +1546,7 @@ void bli_zgemm_piledriver_asm_2x2 vmovaps(mem(rax, 4*8), xmm0) vmovddup(mem(rbx, 4*8), xmm4) vfmadd231pd(xmm1, xmm7, xmm15) - + // iteration 5 vfmadd231pd(xmm0, xmm4, xmm8) vmovaps(mem(rax, 6*8), xmm1) @@ -2349,11 +1562,11 @@ void bli_zgemm_piledriver_asm_2x2 vmovaps(mem(rax, 8*8), xmm0) vmovddup(mem(rbx, 8*8), xmm4) vfmadd231pd(xmm1, xmm7, xmm15) - + prefetch(0, mem(rbx, 128+256)) - + prefetch(0, mem(rax, 128+512)) - + // iteration 6 vfmadd231pd(xmm0, xmm4, xmm8) vmovaps(mem(rax, 10*8), xmm1) @@ -2369,7 +1582,7 @@ void bli_zgemm_piledriver_asm_2x2 vmovaps(mem(rax, 12*8), xmm0) vmovddup(mem(rbx, 12*8), xmm4) vfmadd231pd(xmm1, xmm7, xmm15) - + // iteration 7 vfmadd231pd(xmm0, xmm4, xmm8) vmovaps(mem(rax, 14*8), xmm1) @@ -2385,34 +1598,34 @@ void bli_zgemm_piledriver_asm_2x2 add(imm(8*2*16), rbx) // b += 8*2 (unroll x nr) vfmadd231pd(xmm0, xmm7, xmm11) vfmadd231pd(xmm1, xmm7, xmm15) - - - + + + dec(rsi) // i -= 1; jmp(.ZLOOPKITER) // jump to beginning of loop. - - - - - - + + + + + + label(.ZCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.ZPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.ZLOOPKLEFT) // EDGE LOOP - - + + je(.ZPOSTACCUM) // if i == 0, we're done. - + prefetch(0, mem(rbx, 256)) - + prefetch(0, mem(rax, 512)) - + // iteration 0 vmovaps(mem(rax, -16*8), xmm0) vmovddup(mem(rbx, -16*8), xmm4) @@ -2428,119 +1641,86 @@ void bli_zgemm_piledriver_asm_2x2 vmovddup(mem(rbx, -13*8), xmm7) vfmadd231pd(xmm0, xmm7, xmm11) vfmadd231pd(xmm1, xmm7, xmm15) - - + + add(imm(1*2*16), rax) // a += 1*2 (1 x mr) add(imm(1*2*16), rbx) // b += 1*2 (1 x nr) - - + + dec(rsi) // i -= 1; jmp(.ZLOOPKLEFT) // jump to beginning of loop. 
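/*
   In the .ZPOSTACCUM sequence that follows, the kernel tests the real and
   imaginary parts of beta separately (vucomisd + sete on each, then and)
   and skips loading C entirely when beta == 0 + 0i; by BLAS convention C
   may be uninitialized in that case and must not be read. Plain C
   equivalent of the test:
*/
#include <complex.h>
#include <stdbool.h>

static bool beta_is_zero( double complex beta )
{
	return creal( beta ) == 0.0 && cimag( beta ) == 0.0;
}

// if ( beta_is_zero( *beta ) )  store alpha*ab directly (.ZBETAZERO path);
// else                          compute c = beta*c + alpha*ab first.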
- - - + + + label(.ZPOSTACCUM) - - + + prefetchw0(mem(rcx, 0*8)) // prefetch c + 0*cs_c prefetchw0(mem(r10, 0*8)) // prefetch c + 1*cs_c - - + + vpermilpd(imm(0x1), xmm9, xmm9) vpermilpd(imm(0x1), xmm11, xmm11) vpermilpd(imm(0x1), xmm13, xmm13) vpermilpd(imm(0x1), xmm15, xmm15) - + vaddsubpd(xmm9, xmm8, xmm8) vaddsubpd(xmm11, xmm10, xmm10) vaddsubpd(xmm13, xmm12, xmm12) vaddsubpd(xmm15, xmm14, xmm14) - - + + // xmm8: xmm10: // ( ab00 ( ab01 // ab10 ) ab11 ) - + // xmm12: xmm14: // ( ab20 ( ab21 // ab30 ) ab31 ) - - + + prefetch(0, mem(r14)) // prefetch a_next prefetch(0, mem(r14, 64)) // prefetch a_next - - + + // scale by alpha - + mov(var(alpha), rax) // load address of alpha vmovddup(mem(rax), xmm0) // load alpha_r and duplicate vmovddup(mem(rax, 8), xmm1) // load alpha_i and duplicate - + vpermilpd(imm(0x1), xmm8, xmm9) vpermilpd(imm(0x1), xmm10, xmm11) vpermilpd(imm(0x1), xmm12, xmm13) vpermilpd(imm(0x1), xmm14, xmm15) - + vmulpd(xmm8, xmm0, xmm8) vmulpd(xmm10, xmm0, xmm10) vmulpd(xmm12, xmm0, xmm12) vmulpd(xmm14, xmm0, xmm14) - + vmulpd(xmm9, xmm1, xmm9) vmulpd(xmm11, xmm1, xmm11) vmulpd(xmm13, xmm1, xmm13) vmulpd(xmm15, xmm1, xmm15) - + vaddsubpd(xmm9, xmm8, xmm8) vaddsubpd(xmm11, xmm10, xmm10) vaddsubpd(xmm13, xmm12, xmm12) vaddsubpd(xmm15, xmm14, xmm14) - - - - + + + + mov(var(beta), rbx) // load address of beta vmovddup(mem(rbx), xmm6) // load beta_r and duplicate vmovddup(mem(rbx, 8), xmm7) // load beta_i and duplicate - - - - - - - - mov(var(rs_c), rsi) // load rs_c - lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(dcomplex) - lea(mem(, rsi, 2), rsi) - //lea(mem(rcx, rsi, 2), rdx) // load address of c + 2*rs_c; - - - - - + prefetch(0, mem(r15)) // prefetch b_next prefetch(0, mem(r15, 64)) // prefetch b_next - - - - // determine if - // c % 32 == 0, AND - // 16*cs_c % 32 == 0, AND - // rs_c == 1 - // ie: aligned, ldim aligned, and - // column-stored - - cmp(imm(16), rsi) // set ZF if (16*rs_c) == 16. - sete(bl) // bl = ( ZF == 1 ? 1 : 0 ); - test(imm(31), rcx) // set ZF if c & 32 is zero. - setz(bh) // bh = ( ZF == 0 ? 1 : 0 ); - test(imm(31), rdi) // set ZF if (16*cs_c) & 32 is zero. - setz(al) // al = ( ZF == 0 ? 1 : 0 ); - // and(bl,bh) followed by - // and(bh,al) will reveal result - + // now avoid loading C if beta == 0 - + vxorpd(xmm0, xmm0, xmm0) // set xmm0 to zero. vucomisd(xmm0, xmm6) // set ZF if beta_r == 0. sete(r8b) // r8b = ( ZF == 1 ? 1 : 0 ); @@ -2548,144 +1728,49 @@ void bli_zgemm_piledriver_asm_2x2 sete(r9b) // r9b = ( ZF == 1 ? 1 : 0 ); and(r8b, r9b) // set ZF if r8b & r9b == 1. jne(.ZBETAZERO) // if ZF = 0, jump to beta == 0 case - - - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. 
- jne(.ZCOLSTORED) // jump to column storage case - - - - label(.ZGENSTORED) - - - vmovups(mem(rcx), xmm0) // load c00 - vmovups(mem(rcx, rsi, 1), xmm2) // load c10 - vpermilpd(imm(0x1), xmm0, xmm1) - vpermilpd(imm(0x1), xmm2, xmm3) - - vmulpd(xmm6, xmm0, xmm0) - vmulpd(xmm7, xmm1, xmm1) - vaddsubpd(xmm1, xmm0, xmm0) - vaddpd(xmm8, xmm0, xmm0) - vmovups(xmm0, mem(rcx)) // store c00 - - vmulpd(xmm6, xmm2, xmm2) - vmulpd(xmm7, xmm3, xmm3) - vaddsubpd(xmm3, xmm2, xmm2) - vaddpd(xmm12, xmm2, xmm2) - vmovups(xmm2, mem(rcx, rsi, 1)) // store c10 - - - - vmovups(mem(r10), xmm0) // load c01 - vmovups(mem(r10, rsi, 1), xmm2) // load c11 - vpermilpd(imm(0x1), xmm0, xmm1) - vpermilpd(imm(0x1), xmm2, xmm3) - - vmulpd(xmm6, xmm0, xmm0) - vmulpd(xmm7, xmm1, xmm1) - vaddsubpd(xmm1, xmm0, xmm0) - vaddpd(xmm10, xmm0, xmm0) - vmovups(xmm0, mem(r10)) // store c01 - - vmulpd(xmm6, xmm2, xmm2) - vmulpd(xmm7, xmm3, xmm3) - vaddsubpd(xmm3, xmm2, xmm2) - vaddpd(xmm14, xmm2, xmm2) - vmovups(xmm2, mem(r10, rsi, 1)) // store c11 - - - - jmp(.ZDONE) // jump to end. - - - - label(.ZCOLSTORED) - - - vmovups(mem(rcx), xmm0) // load c00 - vmovups(mem(rcx, 16), xmm2) // load c10 - vpermilpd(imm(0x1), xmm0, xmm1) - vpermilpd(imm(0x1), xmm2, xmm3) - - vmulpd(xmm6, xmm0, xmm0) - vmulpd(xmm7, xmm1, xmm1) - vaddsubpd(xmm1, xmm0, xmm0) - vaddpd(xmm8, xmm0, xmm0) - vmovups(xmm0, mem(rcx)) // store c00 - - vmulpd(xmm6, xmm2, xmm2) - vmulpd(xmm7, xmm3, xmm3) - vaddsubpd(xmm3, xmm2, xmm2) - vaddpd(xmm12, xmm2, xmm2) - vmovups(xmm2, mem(rcx, 16)) // store c10 - - - - vmovups(mem(r10), xmm0) // load c01 - vmovups(mem(r10, 16), xmm2) // load c11 - vpermilpd(imm(0x1), xmm0, xmm1) - vpermilpd(imm(0x1), xmm2, xmm3) - - vmulpd(xmm6, xmm0, xmm0) - vmulpd(xmm7, xmm1, xmm1) - vaddsubpd(xmm1, xmm0, xmm0) - vaddpd(xmm10, xmm0, xmm0) - vmovups(xmm0, mem(r10)) // store c01 - - vmulpd(xmm6, xmm2, xmm2) - vmulpd(xmm7, xmm3, xmm3) - vaddsubpd(xmm3, xmm2, xmm2) - vaddpd(xmm14, xmm2, xmm2) - vmovups(xmm2, mem(r10, 16)) // store c11 - - - - jmp(.ZDONE) // jump to end. - - - + + vmovups(mem(rcx), xmm0) // load c00 + vmovups(mem(rcx, 16), xmm2) // load c10 + vpermilpd(imm(0x1), xmm0, xmm1) + vpermilpd(imm(0x1), xmm2, xmm3) + + vmulpd(xmm6, xmm0, xmm0) + vmulpd(xmm7, xmm1, xmm1) + vaddsubpd(xmm1, xmm0, xmm0) + vaddpd(xmm8, xmm0, xmm0) + + vmulpd(xmm6, xmm2, xmm2) + vmulpd(xmm7, xmm3, xmm3) + vaddsubpd(xmm3, xmm2, xmm2) + vaddpd(xmm12, xmm2, xmm2) + + vmovups(mem(r10), xmm0) // load c01 + vmovups(mem(r10, 16), xmm2) // load c11 + vpermilpd(imm(0x1), xmm0, xmm1) + vpermilpd(imm(0x1), xmm2, xmm3) + + vmulpd(xmm6, xmm0, xmm0) + vmulpd(xmm7, xmm1, xmm1) + vaddsubpd(xmm1, xmm0, xmm0) + vaddpd(xmm10, xmm0, xmm0) + + vmulpd(xmm6, xmm2, xmm2) + vmulpd(xmm7, xmm3, xmm3) + vaddsubpd(xmm3, xmm2, xmm2) + vaddpd(xmm14, xmm2, xmm2) + + // fall through + label(.ZBETAZERO) - // check if aligned/column-stored - // check if aligned/column-stored - and(bl, bh) // set ZF if bl & bh == 1. - and(bh, al) // set ZF if bh & al == 1. - jne(.ZCOLSTORBZ) // jump to column storage case - - - - label(.ZGENSTORBZ) - - - vmovups(xmm8, mem(rcx)) // store c00 - vmovups(xmm12, mem(rcx, rsi, 1)) // store c10 - - vmovups(xmm10, mem(r10)) // store c01 - vmovups(xmm14, mem(r10, rsi, 1)) // store c11 - - - - jmp(.ZDONE) // jump to end. 
- - - - label(.ZCOLSTORBZ) - - - vmovups(xmm8, mem(rcx)) // store c00 - vmovups(xmm12, mem(rcx, 16)) // store c10 - - vmovups(xmm10, mem(r10)) // store c01 - vmovups(xmm14, mem(r10, 16)) // store c11 - - - - - + + vmovups(xmm8, mem(rcx)) // store c00 + vmovups(xmm12, mem(rcx, 16)) // store c10 + + vmovups(xmm10, mem(r10)) // store c01 + vmovups(xmm14, mem(r10, 16)) // store c11 + label(.ZDONE) - + end_asm( : // output operands (none) @@ -2702,7 +1787,7 @@ void bli_zgemm_piledriver_asm_2x2 [b_next] "m" (b_next), // 9 [a_next] "m" (a_next) // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -2710,6 +1795,8 @@ void bli_zgemm_piledriver_asm_2x2 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( z ); } diff --git a/kernels/power10/3/bli_dgemm_power10_mma.c b/kernels/power10/3/bli_dgemm_power10_mma.c index 3968249869..84e7d16d34 100644 --- a/kernels/power10/3/bli_dgemm_power10_mma.c +++ b/kernels/power10/3/bli_dgemm_power10_mma.c @@ -37,7 +37,7 @@ #define D_ASSEMBLE_VEC_PAIR \ __builtin_mma_assemble_pair (&colA_1, ca[1], ca[0]); \ - __builtin_mma_assemble_pair (&colA_2, ca[3], ca[2]); + __builtin_mma_assemble_pair (&colA_2, ca[3], ca[2]); #define D_ACCUMULATE \ __builtin_mma_xvf64gerpp (&acc0, colA_1, rb[0]); \ @@ -47,7 +47,7 @@ __builtin_mma_xvf64gerpp (&acc4, colA_2, rb[0]); \ __builtin_mma_xvf64gerpp (&acc5, colA_2, rb[1]); \ __builtin_mma_xvf64gerpp (&acc6, colA_2, rb[2]); \ - __builtin_mma_xvf64gerpp (&acc7, colA_2, rb[3]); + __builtin_mma_xvf64gerpp (&acc7, colA_2, rb[3]); #define D_INCREMENT \ A0+=8; \ @@ -57,17 +57,19 @@ LOAD_VECTORS \ D_ASSEMBLE_VEC_PAIR \ D_INCREMENT \ - D_ACCUMULATE + D_ACCUMULATE void bli_dgemm_power10_mma_8x8 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, double* restrict alpha, double* restrict a, double* restrict b, double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict c, inc_t rs_c0, inc_t cs_c, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -76,11 +78,13 @@ void bli_dgemm_power10_mma_8x8 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. // (1 is subtracted from k0 because 1 iteration of the k loop is pulled out) - uint64_t k_iter = (k0-1) / 4; - uint64_t k_left = (k0-1) % 4; + uint64_t k_iter = (k-1) / 4; + uint64_t k_left = (k-1) % 4; uint64_t rs_c = rs_c0; + GEMM_UKR_SETUP_CT( d, 8, 8, true ); + double* restrict A0 = a; double* restrict B0 = b; double* restrict C0 = c; @@ -92,23 +96,23 @@ void bli_dgemm_power10_mma_8x8 dv4sf_t *rowC; /* 8 accumulator registers that will be used to store the result. - + Each accumulator register is mapped to 4 vector registers. Illustration: - + acc0 = [ vs0 vs1 vs3 vs4 ] - These registers are used to store the result of an outer product + These registers are used to store the result of an outer product instruction (general outer product instruction syntax: xv???ger??). */ - __vector_quad acc0, acc1, acc2, acc3, + __vector_quad acc0, acc1, acc2, acc3, acc4, acc5, acc6, acc7; - /* 2 vector pairs are necessary for a double precision outer product + /* 2 vector pairs are necessary for a double precision outer product instruction. 
*/ - __vector_pair colA_1, + __vector_pair colA_1, colA_2; /* Prefetch C so that it stays in cache */ @@ -123,17 +127,17 @@ void bli_dgemm_power10_mma_8x8 /* Load elements into vector registers */ vec_t *ca = (vec_t *) A0; - vec_t *rb = (vec_t *) B0; + vec_t *rb = (vec_t *) B0; - /* Each accumulator represents a matrix of size + /* Each accumulator represents a matrix of size 4 x ( 16 / (datatype size in bytes) ) (vector register size = 16B) - Thus in the case of double, the accumulate registers represent a 4x2 + Thus in the case of double, the accumulate registers represent a 4x2 matrix. However, a vector register can hold at most 2 doubles. Thus, if - we performed an outer product using 2 vector register, we can only get a + we performed an outer product using 2 vector register, we can only get a 2x2 matrix. Therefore, we must create a vector register pair in order to get the desired 4x2 matrix. - + */ D_ASSEMBLE_VEC_PAIR @@ -158,7 +162,7 @@ void bli_dgemm_power10_mma_8x8 D_AB_PRODUCT D_AB_PRODUCT } - + // edge loop for (int k = 0; k 0; kk--) { + vector double va00 = vec_splats( *(double *)( pa+0 ) ); + vector double va10 = vec_splats( *(double *)( pa+d1 ) ); + vector double va20 = vec_splats( *(double *)( pa+d2 ) ); + vector double va30 = vec_splats( *(double *)( pa+d3 ) ); + vector double va40 = vec_splats( *(double *)( pa+d4 ) ); + vector double va50 = vec_splats( *(double *)( pa+d5 ) ); + vector double va60 = vec_splats( *(double *)( pa+d6 ) ); + vector double va70 = vec_splats( *(double *)( pa+d7 ) ); + pa += 8*sizeof(double); + + vector double vb00_01 = *(vector double *)( pb+0 ); + vector double vb02_03 = *(vector double *)( pb+d2 ); + pb += 4*sizeof(double); + + vc00_01 = vec_madd(va00, vb00_01, vc00_01); + vc02_03 = vec_madd(va00, vb02_03, vc02_03); + vc10_11 = vec_madd(va10, vb00_01, vc10_11); + vc12_13 = vec_madd(va10, vb02_03, vc12_13); + vc20_21 = vec_madd(va20, vb00_01, vc20_21); + vc22_23 = vec_madd(va20, vb02_03, vc22_23); + vc30_31 = vec_madd(va30, vb00_01, vc30_31); + vc32_33 = vec_madd(va30, vb02_03, vc32_33); + vc40_41 = vec_madd(va40, vb00_01, vc40_41); + vc42_43 = vec_madd(va40, vb02_03, vc42_43); + vc50_51 = vec_madd(va50, vb00_01, vc50_51); + vc52_53 = vec_madd(va50, vb02_03, vc52_53); + vc60_61 = vec_madd(va60, vb00_01, vc60_61); + vc62_63 = vec_madd(va60, vb02_03, vc62_63); + vc70_71 = vec_madd(va70, vb00_01, vc70_71); + vc72_73 = vec_madd(va70, vb02_03, vc72_73); + } + + vector double valpha = vec_splats( *alpha ); + vector double vbeta = (vector double) { *beta, *beta }; + + vector double *pc = (vector double *)c; + + vc00_01 = vec_mul(valpha, vc00_01); + vc02_03 = vec_mul(valpha, vc02_03); + pc[0] = vec_madd( pc[0], vbeta, vc00_01); + pc[1] = vec_madd( pc[1], vbeta, vc02_03); + pc += rs_c/2; + + vc10_11 = vec_mul(valpha, vc10_11); + vc12_13 = vec_mul(valpha, vc12_13); + pc[0] = vec_madd( pc[0], vbeta, vc10_11); + pc[1] = vec_madd( pc[1], vbeta, vc12_13); + pc += rs_c/2; + + vc20_21 = vec_mul(valpha, vc20_21); + vc22_23 = vec_mul(valpha, vc22_23); + pc[0] = vec_madd( pc[0], vbeta, vc20_21); + pc[1] = vec_madd( pc[1], vbeta, vc22_23); + pc += rs_c/2; + + vc30_31 = vec_mul(valpha, vc30_31); + vc32_33 = vec_mul(valpha, vc32_33); + pc[0] = vec_madd( pc[0], vbeta, vc30_31); + pc[1] = vec_madd( pc[1], vbeta, vc32_33); + pc += rs_c/2; + + vc40_41 = vec_mul(valpha, vc40_41); + vc42_43 = vec_mul(valpha, vc42_43); + pc[0] = vec_madd( pc[0], vbeta, vc40_41); + pc[1] = vec_madd( pc[1], vbeta, vc42_43); + pc += rs_c/2; + + vc50_51 = vec_mul(valpha, vc50_51); + vc52_53 = 
vec_mul(valpha, vc52_53); + pc[0] = vec_madd( pc[0], vbeta, vc50_51); + pc[1] = vec_madd( pc[1], vbeta, vc52_53); + pc += rs_c/2; + + vc60_61 = vec_mul(valpha, vc60_61); + vc62_63 = vec_mul(valpha, vc62_63); + pc[0] = vec_madd( pc[0], vbeta, vc60_61); + pc[1] = vec_madd( pc[1], vbeta, vc62_63); + pc += rs_c/2; + + vc70_71 = vec_mul(valpha, vc70_71); + vc72_73 = vec_mul(valpha, vc72_73); + pc[0] = vec_madd( pc[0], vbeta, vc70_71); + pc[1] = vec_madd( pc[1], vbeta, vc72_73); + pc += rs_c/2; + } + else + { + GEMM_UKR_SETUP_CT( d, 8, 4, false ); + // Optimized code for case where C columns are contiguous (column-major C) vector double vzero = vec_splats( 0.0 ); @@ -301,168 +433,8 @@ void bli_dgemm_power7_int_8x4 pc[1] = vec_madd( pc[1], vbeta, vc23_33); pc[2] = vec_madd( pc[2], vbeta, vc43_53); pc[3] = vec_madd( pc[3], vbeta, vc63_73); - } - else -#endif -#if 1 - if ( cs_c == 1 ) { - // Optimized code for case where C rows are contiguous (i.e. C is row-major) - - vector double vzero = vec_splats( 0.0 ); - - vector double vc00_01 = vzero; - vector double vc02_03 = vzero; - vector double vc10_11 = vzero; - vector double vc12_13 = vzero; - vector double vc20_21 = vzero; - vector double vc22_23 = vzero; - vector double vc30_31 = vzero; - vector double vc32_33 = vzero; - vector double vc40_41 = vzero; - vector double vc42_43 = vzero; - vector double vc50_51 = vzero; - vector double vc52_53 = vzero; - vector double vc60_61 = vzero; - vector double vc62_63 = vzero; - vector double vc70_71 = vzero; - vector double vc72_73 = vzero; - - unsigned long long pa = (unsigned long long)a; - unsigned long long pb = (unsigned long long)b; - -#if 0 - unsigned long long d1 = 1*sizeof(double); - unsigned long long d2 = 2*sizeof(double); - unsigned long long d3 = 3*sizeof(double); - unsigned long long d4 = 4*sizeof(double); - unsigned long long d6 = 6*sizeof(double); -#else - // ppc64 linux abi: r14-r31 Nonvolatile registers used for local variables - register unsigned long long d1 __asm ("r21") = 1*sizeof(double); - register unsigned long long d2 __asm ("r22") = 2*sizeof(double); - register unsigned long long d3 __asm ("r23") = 3*sizeof(double); - register unsigned long long d4 __asm ("r24") = 4*sizeof(double); - register unsigned long long d5 __asm ("r25") = 5*sizeof(double); - register unsigned long long d6 __asm ("r26") = 6*sizeof(double); - register unsigned long long d7 __asm ("r27") = 7*sizeof(double); - - __asm__ volatile (";" : "=r" (d1) : "r" (d1) ); - __asm__ volatile (";" : "=r" (d2) : "r" (d2) ); - __asm__ volatile (";" : "=r" (d3) : "r" (d3) ); - __asm__ volatile (";" : "=r" (d4) : "r" (d4) ); - __asm__ volatile (";" : "=r" (d5) : "r" (d5) ); - __asm__ volatile (";" : "=r" (d6) : "r" (d6) ); - __asm__ volatile (";" : "=r" (d7) : "r" (d7) ); -#endif - - int kk; - for (kk=k; kk > 0; kk--) { - vector double va00 = vec_splats( *(double *)( pa+0 ) ); - vector double va10 = vec_splats( *(double *)( pa+d1 ) ); - vector double va20 = vec_splats( *(double *)( pa+d2 ) ); - vector double va30 = vec_splats( *(double *)( pa+d3 ) ); - vector double va40 = vec_splats( *(double *)( pa+d4 ) ); - vector double va50 = vec_splats( *(double *)( pa+d5 ) ); - vector double va60 = vec_splats( *(double *)( pa+d6 ) ); - vector double va70 = vec_splats( *(double *)( pa+d7 ) ); - pa += 8*sizeof(double); - - vector double vb00_01 = *(vector double *)( pb+0 ); - vector double vb02_03 = *(vector double *)( pb+d2 ); - pb += 4*sizeof(double); - - vc00_01 = vec_madd(va00, vb00_01, vc00_01); - vc02_03 = vec_madd(va00, vb02_03, 
vc02_03); - vc10_11 = vec_madd(va10, vb00_01, vc10_11); - vc12_13 = vec_madd(va10, vb02_03, vc12_13); - vc20_21 = vec_madd(va20, vb00_01, vc20_21); - vc22_23 = vec_madd(va20, vb02_03, vc22_23); - vc30_31 = vec_madd(va30, vb00_01, vc30_31); - vc32_33 = vec_madd(va30, vb02_03, vc32_33); - vc40_41 = vec_madd(va40, vb00_01, vc40_41); - vc42_43 = vec_madd(va40, vb02_03, vc42_43); - vc50_51 = vec_madd(va50, vb00_01, vc50_51); - vc52_53 = vec_madd(va50, vb02_03, vc52_53); - vc60_61 = vec_madd(va60, vb00_01, vc60_61); - vc62_63 = vec_madd(va60, vb02_03, vc62_63); - vc70_71 = vec_madd(va70, vb00_01, vc70_71); - vc72_73 = vec_madd(va70, vb02_03, vc72_73); - } - - vector double valpha = vec_splats( *alpha ); - vector double vbeta = (vector double) { *beta, *beta }; - - vector double *pc = (vector double *)c; - - vc00_01 = vec_mul(valpha, vc00_01); - vc02_03 = vec_mul(valpha, vc02_03); - pc[0] = vec_madd( pc[0], vbeta, vc00_01); - pc[1] = vec_madd( pc[1], vbeta, vc02_03); - pc += rs_c/2; - - vc10_11 = vec_mul(valpha, vc10_11); - vc12_13 = vec_mul(valpha, vc12_13); - pc[0] = vec_madd( pc[0], vbeta, vc10_11); - pc[1] = vec_madd( pc[1], vbeta, vc12_13); - pc += rs_c/2; - - vc20_21 = vec_mul(valpha, vc20_21); - vc22_23 = vec_mul(valpha, vc22_23); - pc[0] = vec_madd( pc[0], vbeta, vc20_21); - pc[1] = vec_madd( pc[1], vbeta, vc22_23); - pc += rs_c/2; - - vc30_31 = vec_mul(valpha, vc30_31); - vc32_33 = vec_mul(valpha, vc32_33); - pc[0] = vec_madd( pc[0], vbeta, vc30_31); - pc[1] = vec_madd( pc[1], vbeta, vc32_33); - pc += rs_c/2; - - vc40_41 = vec_mul(valpha, vc40_41); - vc42_43 = vec_mul(valpha, vc42_43); - pc[0] = vec_madd( pc[0], vbeta, vc40_41); - pc[1] = vec_madd( pc[1], vbeta, vc42_43); - pc += rs_c/2; - - vc50_51 = vec_mul(valpha, vc50_51); - vc52_53 = vec_mul(valpha, vc52_53); - pc[0] = vec_madd( pc[0], vbeta, vc50_51); - pc[1] = vec_madd( pc[1], vbeta, vc52_53); - pc += rs_c/2; - - vc60_61 = vec_mul(valpha, vc60_61); - vc62_63 = vec_mul(valpha, vc62_63); - pc[0] = vec_madd( pc[0], vbeta, vc60_61); - pc[1] = vec_madd( pc[1], vbeta, vc62_63); - pc += rs_c/2; - - vc70_71 = vec_mul(valpha, vc70_71); - vc72_73 = vec_mul(valpha, vc72_73); - pc[0] = vec_madd( pc[0], vbeta, vc70_71); - pc[1] = vec_madd( pc[1], vbeta, vc72_73); - pc += rs_c/2; - } - else -#endif - { /* General case. Just do it right. */ -#if 1 || defined(UTEST) - const long MR = BLIS_DEFAULT_MR_D, NR = BLIS_DEFAULT_NR_D; - const long LDA = MR, LDB = NR; - int i, j, kk; - double c00; - - for (i=0; i < MR; i++) { - for (j=0; j < NR; j++) { - c00 = c[BLIS_INDEX(i,j,rs_c,cs_c)] * *beta; - for (kk=0; kk < k; kk++) - c00 += *alpha * (a[COLMAJ_INDEX(i,kk,LDA)] * b[ROWMAJ_INDEX(kk,j,LDB)]); - c[BLIS_INDEX(i,j,rs_c,cs_c)] = c00; - } - } -#else - //BLIS_DGEMM_UKERNEL_REF(k, alpha, a, b, beta, c, rs_c, cs_c, data); -#endif + GEMM_UKR_FLUSH_CT( d ); } } @@ -477,30 +449,26 @@ void bli_dgemm_power7_int_8x4 */ void bli_cgemm_power7_int_8x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, scomplex* restrict alpha, scomplex* restrict a, scomplex* restrict b, scomplex* restrict beta, - scomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + scomplex* restrict c, inc_t rs_c, inc_t cs_c, auxinfo_t* restrict data, cntx_t* restrict cntx ) { - // Typecast local copies of integers in case dim_t and inc_t are a - // different size than is expected by load instructions. 
- uint64_t k = k0; - uint64_t rs_c = rs_c0; - uint64_t cs_c = cs_c0; - #if 1 || defined(UTEST) const long MR = BLIS_DEFAULT_MR_C, NR = BLIS_DEFAULT_NR_C; const long LDA = MR, LDB = NR; int i, j, kk; scomplex c00; - for (i=0; i < MR; i++) { - for (j=0; j < NR; j++) { + for (i=0; i < m; i++) { + for (j=0; j < n; j++) { scomplex tmpc, tmpa, tmpb, tmp; //c00 = c[BLIS_INDEX(i,j,rs_c,cs_c)] * *beta; tmpc = c[BLIS_INDEX(i,j,rs_c,cs_c)]; @@ -534,30 +502,26 @@ void bli_cgemm_power7_int_8x4 */ void bli_zgemm_power7_int_8x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, scomplex* restrict alpha, scomplex* restrict a, scomplex* restrict b, scomplex* restrict beta, - scomplex* restrict c, inc_t rs_c0, inc_t cs_c0, + scomplex* restrict c, inc_t rs_c, inc_t cs_c, auxinfo_t* restrict data, cntx_t* restrict cntx ) { - // Typecast local copies of integers in case dim_t and inc_t are a - // different size than is expected by load instructions. - uint64_t k = k0; - uint64_t rs_c = rs_c0; - uint64_t cs_c = cs_c0; - #if 1 || defined(UTEST) const long MR = BLIS_DEFAULT_MR_Z, NR = BLIS_DEFAULT_NR_Z; const long LDA = MR, LDB = NR; int i, j, kk; dcomplex c00; - for (i=0; i < MR; i++) { - for (j=0; j < NR; j++) { + for (i=0; i < m; i++) { + for (j=0; j < n; j++) { dcomplex tmpc, tmpa, tmpb, tmp; //c00 = c[BLIS_INDEX(i,j,rs_c,cs_c)] * *beta; tmpc = c[BLIS_INDEX(i,j,rs_c,cs_c)]; diff --git a/kernels/power7/3/test/bli_gemm_power7_int_8x4.h b/kernels/power7/3/test/bli_gemm_power7_int_8x4.h index ef1930907e..50984a67df 100644 --- a/kernels/power7/3/test/bli_gemm_power7_int_8x4.h +++ b/kernels/power7/3/test/bli_gemm_power7_int_8x4.h @@ -43,6 +43,8 @@ void bli_sgemm_opt_8x4 ( + dim_t m, + dim_t n, dim_t k, float* restrict alpha, float* restrict a, @@ -55,6 +57,8 @@ void bli_sgemm_opt_8x4 void bli_dgemm_opt_8x4 ( + dim_t m, + dim_t n, dim_t k, double* restrict alpha, double* restrict a, @@ -67,6 +71,8 @@ void bli_dgemm_opt_8x4 void bli_cgemm_opt_8x4 ( + dim_t m, + dim_t n, dim_t k, scomplex* restrict alpha, scomplex* restrict a, @@ -79,6 +85,8 @@ void bli_cgemm_opt_8x4 void bli_zgemm_opt_8x4 ( + dim_t m, + dim_t n, dim_t k, dcomplex* restrict alpha, dcomplex* restrict a, diff --git a/kernels/power9/3/bli_gemm_power9_asm_d12x6.c b/kernels/power9/3/bli_gemm_power9_asm_d12x6.c index ec09f8e380..1f1c1d209b 100644 --- a/kernels/power9/3/bli_gemm_power9_asm_d12x6.c +++ b/kernels/power9/3/bli_gemm_power9_asm_d12x6.c @@ -37,7 +37,9 @@ void bli_dgemm_power9_asm_12x6 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, double* restrict alpha, double* restrict a, double* restrict b, @@ -50,12 +52,14 @@ void bli_dgemm_power9_asm_12x6 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. 
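
The "typecast local copies" comment above is worth unpacking, since the same idiom recurs in every kernel this patch touches: BLIS may be configured with 32-bit dim_t/inc_t, but the inline assembly reads these operands with 64-bit loads, so each kernel first widens them into uint64_t locals. A minimal sketch of the hazard (illustrative names, not from this patch):

    #include <stdint.h>

    /* If dim_t were 32 bits and handed straight to a 64-bit asm load,
       the upper four bytes read would be garbage. Widening first makes
       the memory operand as wide as the instruction expects. */
    void sketch_kernel( int32_t k0 /* stand-in for a 32-bit dim_t */ )
    {
        uint64_t k = k0;                /* widen before the asm block */
        __asm__ volatile
        (
            "movq %[k], %%rsi \n\t"     /* safe: %[k] is 8 bytes wide */
            :
            : [k] "m" (k)
            : "rsi"
        );
    }
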
- uint64_t k_iter = k0 / 16; - uint64_t k_left = k0 % 16; + uint64_t k_iter = k / 16; + uint64_t k_left = k % 16; - uint64_t rs_c = rs_c0; + uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( d, 12, 6, false ); + __asm__ volatile ( " \n\t" @@ -86,9 +90,9 @@ void bli_dgemm_power9_asm_12x6 "add %%r20, %%r19, %%r10 \n\t" // col 4 of C "add %%r21, %%r20, %%r10 \n\t" // col 5 of C " \n\t" - DZERO_OUT_VREG + DZERO_OUT_VREG " \n\t" - DPRELOAD + DPRELOAD " \n\t" "addi %%r8, %%r8, 96 \n\t" // move to next col/row of A/B "addi %%r7, %%r7, 96 \n\t" @@ -98,10 +102,10 @@ void bli_dgemm_power9_asm_12x6 "cmpwi %%r11, 0 \n\t" // if k_iter == 0, "beq DCONSIDERKLEFT \n\t" // then jmp to k_left "mtctr %%r11 \n\t" // else, do k_iter loop - " \n\t" + " \n\t" "DLOOPKITER: \n\t" // k_iter loop " \n\t" - A_B_PRODUCT_16 // compute A*B + A_B_PRODUCT_16 // compute A*B " \n\t" "bdnz DLOOPKITER \n\t" " \n\t" @@ -111,54 +115,26 @@ void bli_dgemm_power9_asm_12x6 "beq DPOSTACCUM \n\t" // then jmp to post accum "mtctr %%r12 \n\t" // else, do k_left loop " \n\t" - "DLOOPKLEFT: \n\t" // k_left loop + "DLOOPKLEFT: \n\t" // k_left loop " \n\t" A_B_PRODUCT_1 " \n\t" - "bdnz DLOOPKLEFT \n\t" + "bdnz DLOOPKLEFT \n\t" " \n\t" - "DPOSTACCUM: \n\t" + "DPOSTACCUM: \n\t" " \n\t" - DSCALE_ALPHA + DSCALE_ALPHA " \n\t" "cmpdi %%r26, 0 \n\t" // if beta == 0, "beq DBETAZERO \n\t" // then jmp to BZ " \n\t" - "cmpwi %%r9, 8 \n\t" // if rs_c == 8 - "beq DCOLSTOREDBNZ \n\t" // then jmp to col store - " \n\t" - "DGENSTOREDBNZ: \n\t" // BNZ gen stored case - " \n\t" - DGEN_LOAD_OFS_C - " \n\t" - DGEN_SCALE_BETA - " \n\t" - "b DGENSTORED \n\t" - " \n\t" - "DCOLSTOREDBNZ: \n\t" // BNZ col stored case - " \n\t" - DCOL_SCALE_BETA - " \n\t" - "b DCOLSTORED \n\t" + DCOL_SCALE_BETA " \n\t" "DBETAZERO: \n\t" // BZ case - " \n\t" - "cmpwi %%r9, 8 \n\t" // if rs_c == 8, - "beq DCOLSTORED \n\t" // C is col stored - " \n\t" - "DGENSTORED: \n\t" // BZ gen stored case - " \n\t" - DGEN_LOAD_OFS_C - " \n\t" - DGEN_STORE - " \n\t" - "b DDONE \n\t" - " \n\t" - "DCOLSTORED: \n\t" // BZ col stored case " \n\t" DCOL_STORE " \n\t" - "DDONE: \n\t" + "DDONE: \n\t" " \n\t" : // output operands (none) : // input operands @@ -176,8 +152,8 @@ void bli_dgemm_power9_asm_12x6 : // register clobber list /* unclobberable regs: r2, r3, r4, r5, r6, r13, r14, r15, r30, r31 */ "r0", "r7", "r8", "r9", - "r10", "r11", "r12", "r16", "r17", "r18", "r19", - "r20", "r21", "r22", "r23", "r24", "r25", "r26", "r27", "r28", "r29" + "r10", "r11", "r12", "r16", "r17", "r18", "r19", + "r20", "r21", "r22", "r23", "r24", "r25", "r26", "r27", "r28", "r29" #if XLC ,"f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f8", "f9" @@ -198,4 +174,6 @@ void bli_dgemm_power9_asm_12x6 #endif ); + + GEMM_UKR_FLUSH_CT( d ); } diff --git a/kernels/sandybridge/3/bli_gemm_sandybridge_asm_d8x4.c b/kernels/sandybridge/3/bli_gemm_sandybridge_asm_d8x4.c index a56ef16e5e..0cf1dd20c3 100644 --- a/kernels/sandybridge/3/bli_gemm_sandybridge_asm_d8x4.c +++ b/kernels/sandybridge/3/bli_gemm_sandybridge_asm_d8x4.c @@ -42,7 +42,9 @@ void bli_sgemm_sandybridge_asm_8x8 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, float* restrict alpha, float* restrict a, float* restrict b, @@ -57,27 +59,29 @@ void bli_sgemm_sandybridge_asm_8x8 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. 
- uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; - begin_asm() - - + GEMM_UKR_SETUP_CT( s, 8, 8, false ); + + begin_asm() + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. //mov(var(b_next), r15) // load address of b_next. - + vmovaps(mem(rax, 0*32), ymm0) // initialize loop by pre-loading vmovsldup(mem(rbx, 0*32), ymm2) // elements of a and b. vpermilps(imm(0x4e), ymm2, ymm3) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c - lea(mem(, rdi, 4), rdi) // cs_c *= sizeof(float) + lea(mem(, rdi, 4), rdi) // cs_c *= sizeof(float) lea(mem(rcx, rdi, 4), r10) // load address of c + 4*cs_c; - + lea(mem(rdi, rdi, 2), r14) // r14 = 3*cs_c; prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*cs_c prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*cs_c @@ -87,7 +91,7 @@ void bli_sgemm_sandybridge_asm_8x8 prefetch(0, mem(r10, rdi, 1, 7*8)) // prefetch c + 5*cs_c prefetch(0, mem(r10, rdi, 2, 7*8)) // prefetch c + 6*cs_c prefetch(0, mem(r10, r14, 1, 7*8)) // prefetch c + 7*cs_c - + vxorps(ymm8, ymm8, ymm8) vxorps(ymm9, ymm9, ymm9) vxorps(ymm10, ymm10, ymm10) @@ -96,18 +100,18 @@ void bli_sgemm_sandybridge_asm_8x8 vxorps(ymm13, ymm13, ymm13) vxorps(ymm14, ymm14, ymm14) vxorps(ymm15, ymm15, ymm15) - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.SCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.SLOOPKITER) // MAIN LOOP - - + + // iteration 0 prefetch(0, mem(rax, 16*32)) vmulps(ymm0, ymm2, ymm6) @@ -117,14 +121,14 @@ void bli_sgemm_sandybridge_asm_8x8 vperm2f128(imm(0x03), ymm3, ymm3, ymm5) vaddps(ymm15, ymm6, ymm15) vaddps(ymm13, ymm7, ymm13) - + vmovaps(mem(rax, 1*32), ymm1) vpermilps(imm(0x4e), ymm2, ymm3) vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vaddps(ymm11, ymm6, ymm11) vaddps(ymm9, ymm7, ymm9) - + vmulps(ymm0, ymm2, ymm6) vperm2f128(imm(0x03), ymm2, ymm2, ymm4) vmovsldup(mem(rbx, 1*32), ymm2) @@ -132,13 +136,13 @@ void bli_sgemm_sandybridge_asm_8x8 vperm2f128(imm(0x03), ymm3, ymm3, ymm5) vaddps(ymm14, ymm6, ymm14) vaddps(ymm12, ymm7, ymm12) - + vpermilps(imm(0x4e), ymm2, ymm3) vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vaddps(ymm10, ymm6, ymm10) vaddps(ymm8, ymm7, ymm8) - + // iteration 1 vmulps(ymm1, ymm2, ymm6) vperm2f128(imm(0x03), ymm2, ymm2, ymm4) @@ -147,14 +151,14 @@ void bli_sgemm_sandybridge_asm_8x8 vperm2f128(imm(0x03), ymm3, ymm3, ymm5) vaddps(ymm15, ymm6, ymm15) vaddps(ymm13, ymm7, ymm13) - + vmovaps(mem(rax, 2*32), ymm0) vpermilps(imm(0x4e), ymm2, ymm3) vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddps(ymm11, ymm6, ymm11) vaddps(ymm9, ymm7, ymm9) - + vmulps(ymm1, ymm2, ymm6) vperm2f128(imm(0x03), ymm2, ymm2, ymm4) vmovsldup(mem(rbx, 2*32), ymm2) @@ -162,14 +166,14 @@ void bli_sgemm_sandybridge_asm_8x8 vperm2f128(imm(0x03), ymm3, ymm3, ymm5) vaddps(ymm14, ymm6, ymm14) vaddps(ymm12, ymm7, ymm12) - + vpermilps(imm(0x4e), ymm2, ymm3) vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddps(ymm10, ymm6, ymm10) vaddps(ymm8, ymm7, ymm8) - - + + // iteration 2 prefetch(0, mem(rax, 18*32)) vmulps(ymm0, ymm2, ymm6) @@ -179,7 +183,7 @@ void bli_sgemm_sandybridge_asm_8x8 vperm2f128(imm(0x03), ymm3, ymm3, ymm5) vaddps(ymm15, ymm6, ymm15) vaddps(ymm13, ymm7, ymm13) - + vmovaps(mem(rax, 3*32), ymm1) add(imm(4*8*4), rax) // a += 4*8 (unroll x mr) vpermilps(imm(0x4e), ymm2, ymm3) vmulps(ymm0, ymm4, ymm6) @@ -187,7 +191,7 @@ void bli_sgemm_sandybridge_asm_8x8
vmulps(ymm0, ymm5, ymm7) vaddps(ymm11, ymm6, ymm11) vaddps(ymm9, ymm7, ymm9) - + vmulps(ymm0, ymm2, ymm6) vperm2f128(imm(0x03), ymm2, ymm2, ymm4) vmovsldup(mem(rbx, 3*32), ymm2) @@ -195,14 +199,14 @@ void bli_sgemm_sandybridge_asm_8x8 vperm2f128(imm(0x03), ymm3, ymm3, ymm5) vaddps(ymm14, ymm6, ymm14) vaddps(ymm12, ymm7, ymm12) - + vpermilps(imm(0x4e), ymm2, ymm3) vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vaddps(ymm10, ymm6, ymm10) vaddps(ymm8, ymm7, ymm8) - - + + // iteration 3 vmulps(ymm1, ymm2, ymm6) vperm2f128(imm(0x03), ymm2, ymm2, ymm4) @@ -212,14 +216,14 @@ void bli_sgemm_sandybridge_asm_8x8 vperm2f128(imm(0x03), ymm3, ymm3, ymm5) vaddps(ymm15, ymm6, ymm15) vaddps(ymm13, ymm7, ymm13) - + vmovaps(mem(rax, 0*32), ymm0) vpermilps(imm(0x4e), ymm2, ymm3) vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddps(ymm11, ymm6, ymm11) vaddps(ymm9, ymm7, ymm9) - + vmulps(ymm1, ymm2, ymm6) vperm2f128(imm(0x03), ymm2, ymm2, ymm4) vmovsldup(mem(rbx, 0*32), ymm2) @@ -227,35 +231,35 @@ void bli_sgemm_sandybridge_asm_8x8 vperm2f128(imm(0x03), ymm3, ymm3, ymm5) vaddps(ymm14, ymm6, ymm14) vaddps(ymm12, ymm7, ymm12) - + vpermilps(imm(0x4e), ymm2, ymm3) vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddps(ymm10, ymm6, ymm10) vaddps(ymm8, ymm7, ymm8) - - - - + + + + dec(rsi) // i -= 1; jne(.SLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.SCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.SPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.SLOOPKLEFT) // EDGE LOOP - - + + prefetch(0, mem(rax, 16*32)) vmulps(ymm0, ymm2, ymm6) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) @@ -264,7 +268,7 @@ void bli_sgemm_sandybridge_asm_8x8 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddps(ymm15, ymm6, ymm15) vaddps(ymm13, ymm7, ymm13) - + vmovaps(mem(rax, 1*32), ymm1) add(imm(8*1*4), rax) // a += 8 (1 x mr) vpermilps(imm(0x4e), ymm2, ymm3) @@ -272,7 +276,7 @@ void bli_sgemm_sandybridge_asm_8x8 vmulps(ymm0, ymm5, ymm7) vaddps(ymm11, ymm6, ymm11) vaddps(ymm9, ymm7, ymm9) - + vmulps(ymm0, ymm2, ymm6) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vmovsldup(mem(rbx, 1*32), ymm2) @@ -281,122 +285,122 @@ void bli_sgemm_sandybridge_asm_8x8 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddps(ymm14, ymm6, ymm14) vaddps(ymm12, ymm7, ymm12) - + vpermilps(imm(0x4e), ymm2, ymm3) vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vmovaps(ymm1, ymm0) vaddps(ymm10, ymm6, ymm10) vaddps(ymm8, ymm7, ymm8) - - - + + + dec(rsi) // i -= 1; jne(.SLOOPKLEFT) // iterate again if i != 0. 
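
The SPOSTACCUM block that follows is pure data reshuffling: the duplicate-and-permute multiply scheme above leaves each column of the 8x8 product scattered across the accumulators (the ab diagrams below trace this), and two shuffle stages gather the columns back together. A plain-C model of the data movement, on float[8] stand-ins for ymm registers (a sketch of what the shuffles accomplish, not of the instruction encodings):

    /* Stage 1 (the vshufps(imm(0xe4), ...) pairs): swap elements 2..3
       and 6..7 between two registers. */
    static void swap_pairs( float* x, float* y )
    {
        for ( int i = 2; i < 8; i += 4 )       /* i = 2, then i = 6 */
            for ( int j = 0; j < 2; j++ )
            {
                float t = x[i+j]; x[i+j] = y[i+j]; y[i+j] = t;
            }
    }

    /* Stage 2 (the vperm2f128(imm(0x30))/vperm2f128(imm(0x12)) pairs):
       swap the upper 128-bit halves, i.e. elements 4..7. */
    static void swap_high_halves( float* x, float* y )
    {
        for ( int i = 4; i < 8; i++ )
        {
            float t = x[i]; x[i] = y[i]; y[i] = t;
        }
    }
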
- - - + + + label(.SPOSTACCUM) - + // ymm15: ymm13: ymm11: ymm9: // ( ab00 ( ab02 ( ab04 ( ab06 - // ab10 ab12 ab14 ab16 + // ab10 ab12 ab14 ab16 // ab22 ab20 ab26 ab24 // ab32 ab30 ab36 ab34 // ab44 ab46 ab40 ab42 - // ab54 ab56 ab50 ab52 + // ab54 ab56 ab50 ab52 // ab66 ab64 ab62 ab60 // ab76 ) ab74 ) ab72 ) ab70 ) - + // ymm14: ymm12: ymm10: ymm8: // ( ab01 ( ab03 ( ab05 ( ab07 - // ab11 ab13 ab15 ab17 + // ab11 ab13 ab15 ab17 // ab23 ab21 ab27 ab25 // ab33 ab31 ab37 ab35 // ab45 ab47 ab41 ab43 - // ab55 ab57 ab51 ab53 + // ab55 ab57 ab51 ab53 // ab67 ab65 ab63 ab61 // ab77 ) ab75 ) ab73 ) ab71 ) - + vmovaps(ymm15, ymm7) vshufps(imm(0xe4), ymm13, ymm15, ymm15) vshufps(imm(0xe4), ymm7, ymm13, ymm13) - + vmovaps(ymm11, ymm7) vshufps(imm(0xe4), ymm9, ymm11, ymm11) vshufps(imm(0xe4), ymm7, ymm9, ymm9) - + vmovaps(ymm14, ymm7) vshufps(imm(0xe4), ymm12, ymm14, ymm14) vshufps(imm(0xe4), ymm7, ymm12, ymm12) - + vmovaps(ymm10, ymm7) vshufps(imm(0xe4), ymm8, ymm10, ymm10) vshufps(imm(0xe4), ymm7, ymm8, ymm8) - + // ymm15: ymm13: ymm11: ymm9: // ( ab00 ( ab02 ( ab04 ( ab06 - // ab10 ab12 ab14 ab16 + // ab10 ab12 ab14 ab16 // ab20 ab22 ab24 ab26 // ab30 ab32 ab34 ab36 // ab44 ab46 ab40 ab42 - // ab54 ab56 ab50 ab52 + // ab54 ab56 ab50 ab52 // ab64 ab66 ab60 ab62 // ab74 ) ab76 ) ab70 ) ab72 ) - + // ymm14: ymm12: ymm10: ymm8: // ( ab01 ( ab03 ( ab05 ( ab07 - // ab11 ab13 ab15 ab17 + // ab11 ab13 ab15 ab17 // ab21 ab23 ab25 ab27 // ab31 ab33 ab35 ab37 // ab45 ab47 ab41 ab43 - // ab55 ab57 ab51 ab53 + // ab55 ab57 ab51 ab53 // ab65 ab67 ab61 ab63 // ab75 ) ab77 ) ab71 ) ab73 ) - + vmovaps(ymm15, ymm7) vperm2f128(imm(0x30), ymm11, ymm15, ymm15) vperm2f128(imm(0x12), ymm11, ymm7, ymm11) - + vmovaps(ymm13, ymm7) vperm2f128(imm(0x30), ymm9, ymm13, ymm13) vperm2f128(imm(0x12), ymm9, ymm7, ymm9) - + vmovaps(ymm14, ymm7) vperm2f128(imm(0x30), ymm10, ymm14, ymm14) vperm2f128(imm(0x12), ymm10, ymm7, ymm10) - + vmovaps(ymm12, ymm7) vperm2f128(imm(0x30), ymm8, ymm12, ymm12) vperm2f128(imm(0x12), ymm8, ymm7, ymm8) - + // ymm15: ymm13: ymm11: ymm9: // ( ab00 ( ab02 ( ab04 ( ab06 - // ab10 ab12 ab14 ab16 + // ab10 ab12 ab14 ab16 // ab20 ab22 ab24 ab26 // ab30 ab32 ab34 ab36 // ab40 ab42 ab44 ab46 - // ab50 ab52 ab54 ab56 + // ab50 ab52 ab54 ab56 // ab60 ab62 ab64 ab66 // ab70 ) ab72 ) ab74 ) ab76 ) - + // ymm14: ymm12: ymm10: ymm8: // ( ab01 ( ab03 ( ab05 ( ab07 - // ab11 ab13 ab15 ab17 + // ab11 ab13 ab15 ab17 // ab21 ab23 ab25 ab27 // ab31 ab33 ab35 ab37 // ab41 ab43 ab45 ab47 - // ab51 ab53 ab55 ab57 + // ab51 ab53 ab55 ab57 // ab61 ab63 ab65 ab67 // ab71 ) ab73 ) ab75 ) ab77 ) - - - + + + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastss(mem(rax), ymm0) // load alpha and duplicate vbroadcastss(mem(rbx), ymm4) // load beta and duplicate - + vmulps(ymm0, ymm8, ymm8) // scale by alpha vmulps(ymm0, ymm9, ymm9) vmulps(ymm0, ymm10, ymm10) @@ -405,601 +409,101 @@ void bli_sgemm_sandybridge_asm_8x8 vmulps(ymm0, ymm13, ymm13) vmulps(ymm0, ymm14, ymm14) vmulps(ymm0, ymm15, ymm15) - - - - - - + + mov(var(rs_c), rsi) // load rs_c - lea(mem(, rsi, 4), rsi) // rsi = rs_c * sizeof(float) - + lea(mem(, rsi, 4), rsi) // rsi = rs_c * sizeof(float) + lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*rs_c; - - lea(mem(, rsi, 2), r12) // r12 = 2*rs_c; + + lea(mem(, rsi, 2), r12) // r12 = 2*rs_c; lea(mem(r12, rsi, 1), r13) // r13 = 3*rs_c; - - + + // now avoid loading C if beta == 0 - + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomiss(xmm0, xmm4) // set ZF if beta == 0.
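
This compare feeds the je(.SBETAZERO) below and implements the usual beta == 0 short-circuit: when beta is exactly zero the kernel must not read C at all, partly to save the loads, but mainly so that uninitialized C cannot poison the result (0.0 * NaN is NaN, not 0.0). The scalar logic, for one element:

    /* Scalar model of the two store paths: the beta == 0 path never
       loads the old value of C. ab is the alpha-scaled product. */
    static void update_cij( float* cij, float ab, float beta )
    {
        if ( beta == 0.0f ) *cij = ab;                 /* store only        */
        else                *cij = beta * (*cij) + ab; /* load, scale, add  */
    }
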
je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case - - - cmp(imm(4), rsi) // set ZF if (4*cs_c) == 4. - jz(.SCOLSTORED) // jump to column storage case - - - - label(.SGENSTORED) - - // update c00:c70 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - - vmulps(ymm4, ymm0, ymm0) // scale by beta, - vaddps(ymm15, ymm0, ymm0) // add the gemm result, - - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c01:c71 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - - vmulps(ymm4, ymm0, ymm0) // scale by beta, - vaddps(ymm14, ymm0, ymm0) // add the gemm result, - - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c02:c72 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - - vmulps(ymm4, ymm0, ymm0) // scale by beta, - vaddps(ymm13, ymm0, ymm0) // add the gemm result, - - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c03:c73 - vmovlps(mem(rcx), xmm0, xmm0) - 
vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - - vmulps(ymm4, ymm0, ymm0) // scale by beta, - vaddps(ymm12, ymm0, ymm0) // add the gemm result, - - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c04:c74 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - - vmulps(ymm4, ymm0, ymm0) // scale by beta, - vaddps(ymm11, ymm0, ymm0) // add the gemm result, - - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c05:c75 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - - vmulps(ymm4, ymm0, ymm0) // scale by beta, - vaddps(ymm10, ymm0, ymm0) // add the gemm result, - - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c06:c76 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - 
vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - - vmulps(ymm4, ymm0, ymm0) // scale by beta, - vaddps(ymm9, ymm0, ymm0) // add the gemm result, - - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c07:c77 - vmovlps(mem(rcx), xmm0, xmm0) - vmovhps(mem(rcx, rsi, 1), xmm0, xmm0) - vmovlps(mem(rcx, r12, 1), xmm1, xmm1) - vmovhps(mem(rcx, r13, 1), xmm1, xmm1) - vshufps(imm(0x88), xmm1, xmm0, xmm0) - vmovlps(mem(rdx), xmm2, xmm2) - vmovhps(mem(rdx, rsi, 1), xmm2, xmm2) - vmovlps(mem(rdx, r12, 1), xmm3, xmm3) - vmovhps(mem(rdx, r13, 1), xmm3, xmm3) - vshufps(imm(0x88), xmm3, xmm2, xmm2) - vperm2f128(imm(0x20), ymm2, ymm0, ymm0) - - vmulps(ymm4, ymm0, ymm0) // scale by beta, - vaddps(ymm8, ymm0, ymm0) // add the gemm result, - - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - - - jmp(.SDONE) // jump to end. - - - - label(.SCOLSTORED) - - - vmovups(mem(rcx), ymm0) // load c00:c70, - vmulps(ymm4, ymm0, ymm0) // scale by beta, - vaddps(ymm15, ymm0, ymm0) // add the gemm result, - vmovups(ymm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovups(mem(rcx), ymm1) // load c01:c71, - vmulps(ymm4, ymm1, ymm1) // scale by beta, - vaddps(ymm14, ymm1, ymm1) // add the gemm result, - vmovups(ymm1, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovups(mem(rcx), ymm0) // load c02:c72, - vmulps(ymm4, ymm0, ymm0) // scale by beta, - vaddps(ymm13, ymm0, ymm0) // add the gemm result, - vmovups(ymm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovups(mem(rcx), ymm1) // load c03:c73, - vmulps(ymm4, ymm1, ymm1) // scale by beta, - vaddps(ymm12, ymm1, ymm1) // add the gemm result, - vmovups(ymm1, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovups(mem(rcx), ymm0) // load c04:c74, - vmulps(ymm4, ymm0, ymm0) // scale by beta, - vaddps(ymm11, ymm0, ymm0) // add the gemm result, - vmovups(ymm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovups(mem(rcx), ymm1) // load c05:c75, - vmulps(ymm4, ymm1, ymm1) // scale by beta, - vaddps(ymm10, ymm1, ymm1) // add the gemm result, - vmovups(ymm1, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovups(mem(rcx), ymm0) // load c06:c76, - vmulps(ymm4, ymm0, ymm0) // scale by beta, - vaddps(ymm9, ymm0, ymm0) // add the gemm result, - vmovups(ymm0, mem(rcx)) // and store back to memory. 
- add(rdi, rcx) // c += cs_c; - - vmovups(mem(rcx), ymm1) // load c07:c77, - vmulps(ymm4, ymm1, ymm1) // scale by beta, - vaddps(ymm8, ymm1, ymm1) // add the gemm result, - vmovups(ymm1, mem(rcx)) // and store back to memory. - - - jmp(.SDONE) // jump to end. - - - - + + vmovups(mem(rcx), ymm0) // load c00:c70, + vmulps(ymm4, ymm0, ymm0) // scale by beta, + vaddps(ymm15, ymm0, ymm0) // add the gemm result, + vmovups(ymm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(mem(rcx), ymm1) // load c01:c71, + vmulps(ymm4, ymm1, ymm1) // scale by beta, + vaddps(ymm14, ymm1, ymm1) // add the gemm result, + vmovups(ymm1, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(mem(rcx), ymm0) // load c02:c72, + vmulps(ymm4, ymm0, ymm0) // scale by beta, + vaddps(ymm13, ymm0, ymm0) // add the gemm result, + vmovups(ymm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(mem(rcx), ymm1) // load c03:c73, + vmulps(ymm4, ymm1, ymm1) // scale by beta, + vaddps(ymm12, ymm1, ymm1) // add the gemm result, + vmovups(ymm1, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(mem(rcx), ymm0) // load c04:c74, + vmulps(ymm4, ymm0, ymm0) // scale by beta, + vaddps(ymm11, ymm0, ymm0) // add the gemm result, + vmovups(ymm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(mem(rcx), ymm1) // load c05:c75, + vmulps(ymm4, ymm1, ymm1) // scale by beta, + vaddps(ymm10, ymm1, ymm1) // add the gemm result, + vmovups(ymm1, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(mem(rcx), ymm0) // load c06:c76, + vmulps(ymm4, ymm0, ymm0) // scale by beta, + vaddps(ymm9, ymm0, ymm0) // add the gemm result, + vmovups(ymm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(mem(rcx), ymm1) // load c07:c77, + vmulps(ymm4, ymm1, ymm1) // scale by beta, + vaddps(ymm8, ymm1, ymm1) // add the gemm result, + vmovups(ymm1, mem(rcx)) // and store back to memory. + + jmp(.SDONE) // jump to end. + label(.SBETAZERO) - - cmp(imm(4), rsi) // set ZF if (4*cs_c) == 4. 
- jz(.SCOLSTORBZ) // jump to column storage case - - - - label(.SGENSTORBZ) - - // update c00:c70 - vmovups(ymm15, ymm0) - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c01:c71 - vmovups(ymm14, ymm0) - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c02:c72 - vmovups(ymm13, ymm0) - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c03:c73 - vmovups(ymm12, ymm0) - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c04:c74 - vmovups(ymm11, ymm0) - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c05:c75 - vmovups(ymm10, ymm0) - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 
1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c06:c76 - vmovups(ymm9, ymm0) - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - add(rdi, rcx) // c += cs_c; - add(rdi, rdx) // c += cs_c; - - - // update c07:c77 - vmovups(ymm8, ymm0) - vextractf128(imm(1), ymm0, xmm2) - vmovss(xmm0, mem(rcx)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, rsi, 1)) - vpermilps(imm(0x39), xmm1, xmm0) - vmovss(xmm0, mem(rcx, r12, 1)) - vpermilps(imm(0x39), xmm0, xmm1) - vmovss(xmm1, mem(rcx, r13, 1)) - vmovss(xmm2, mem(rdx)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, rsi, 1)) - vpermilps(imm(0x39), xmm3, xmm2) - vmovss(xmm2, mem(rdx, r12, 1)) - vpermilps(imm(0x39), xmm2, xmm3) - vmovss(xmm3, mem(rdx, r13, 1)) - - - jmp(.SDONE) // jump to end. - - - - label(.SCOLSTORBZ) - - - vmovups(ymm15, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovups(ymm14, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovups(ymm13, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovups(ymm12, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovups(ymm11, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovups(ymm10, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovups(ymm9, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovups(ymm8, mem(rcx)) // and store back to memory. - - - - - + + vmovups(ymm15, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(ymm14, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(ymm13, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(ymm12, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(ymm11, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(ymm10, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(ymm9, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovups(ymm8, mem(rcx)) // and store back to memory. 
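
With the general-stride store paths gone, each kernel keeps exactly two C-update sequences, both column-stored: the beta != 0 path loads each column, scales it by beta, adds the alpha-scaled product, and stores it back; the beta == 0 path just stores. The loop they unroll, in plain C (a sketch; ab here holds the already alpha-scaled 8x8 product, one column per row of the array):

    static void store_tile( float* c, long cs_c, const float ab[8][8],
                            float beta )
    {
        for ( int j = 0; j < 8; j++, c += cs_c )   /* the add(rdi, rcx) */
            for ( int i = 0; i < 8; i++ )
                c[i] = ( beta == 0.0f ) ? ab[j][i]
                                        : beta * c[i] + ab[j][i];
    }
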
+ label(.SDONE) - + vzeroupper() - + end_asm( : // output operands (none) @@ -1016,7 +520,7 @@ void bli_sgemm_sandybridge_asm_8x8 [b_next] "m" (b_next), // 9 [a_next] "m" (a_next)*/ // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -1024,11 +528,17 @@ void bli_sgemm_sandybridge_asm_8x8 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( s ); } void bli_dgemm_sandybridge_asm_8x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, double* restrict alpha, double* restrict a, double* restrict b, @@ -1043,34 +553,36 @@ void bli_dgemm_sandybridge_asm_8x4 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( d, 8, 4, false ); + begin_asm() - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. mov(var(b_next), r15) // load address of b_next. //mov(var(a_next), r14) // load address of a_next. sub(imm(4*64), r15) - + vmovapd(mem(rax, 0*32), ymm0) // initialize loop by pre-loading vmovapd(mem(rbx, 0*32), ymm2) // elements of a and b. vpermilpd(imm(0x5), ymm2, ymm3) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(double) lea(mem(rcx, rdi, 2), r10) // load address of c + 2*cs_c; - + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*cs_c prefetch(0, mem(r10, 3*8)) // prefetch c + 2*cs_c prefetch(0, mem(r10, rdi, 1, 3*8)) // prefetch c + 3*cs_c - + vxorpd(ymm8, ymm8, ymm8) vxorpd(ymm9, ymm9, ymm9) vxorpd(ymm10, ymm10, ymm10) @@ -1079,19 +591,19 @@ void bli_dgemm_sandybridge_asm_8x4 vxorpd(ymm13, ymm13, ymm13) vxorpd(ymm14, ymm14, ymm14) vxorpd(ymm15, ymm15, ymm15) - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop.
- - + + label(.DLOOPKITER) // MAIN LOOP - + add(imm(4*4*8), r15) // b_next += 4*4 (unroll x nr) - + // iteration 0 vmovapd(mem(rax, 1*32), ymm1) vmulpd(ymm0, ymm2, ymm6) @@ -1100,7 +612,7 @@ void bli_dgemm_sandybridge_asm_8x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm15, ymm6, ymm15) vaddpd(ymm13, ymm7, ymm13) - + prefetch(0, mem(rax, 16*32)) vmulpd(ymm1, ymm2, ymm6) vmovapd(mem(rbx, 1*32), ymm2) @@ -1108,20 +620,20 @@ void bli_dgemm_sandybridge_asm_8x4 vpermilpd(imm(0x5), ymm2, ymm3) vaddpd(ymm14, ymm6, ymm14) vaddpd(ymm12, ymm7, ymm12) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 2*32), ymm0) vaddpd(ymm11, ymm6, ymm11) vaddpd(ymm9, ymm7, ymm9) prefetch(0, mem(r15, 0*32)) // prefetch b_next[0*4] - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddpd(ymm10, ymm6, ymm10) vaddpd(ymm8, ymm7, ymm8) - - + + // iteration 1 vmovapd(mem(rax, 3*32), ymm1) vmulpd(ymm0, ymm2, ymm6) @@ -1130,7 +642,7 @@ void bli_dgemm_sandybridge_asm_8x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm15, ymm6, ymm15) vaddpd(ymm13, ymm7, ymm13) - + prefetch(0, mem(rax, 18*32)) vmulpd(ymm1, ymm2, ymm6) vmovapd(mem(rbx, 2*32), ymm2) @@ -1138,19 +650,19 @@ void bli_dgemm_sandybridge_asm_8x4 vpermilpd(imm(0x5), ymm2, ymm3) vaddpd(ymm14, ymm6, ymm14) vaddpd(ymm12, ymm7, ymm12) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 4*32), ymm0) vaddpd(ymm11, ymm6, ymm11) vaddpd(ymm9, ymm7, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddpd(ymm10, ymm6, ymm10) vaddpd(ymm8, ymm7, ymm8) - - + + // iteration 2 vmovapd(mem(rax, 5*32), ymm1) vmulpd(ymm0, ymm2, ymm6) @@ -1159,7 +671,7 @@ void bli_dgemm_sandybridge_asm_8x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm15, ymm6, ymm15) vaddpd(ymm13, ymm7, ymm13) - + prefetch(0, mem(rax, 20*32)) vmulpd(ymm1, ymm2, ymm6) vmovapd(mem(rbx, 3*32), ymm2) @@ -1168,20 +680,20 @@ void bli_dgemm_sandybridge_asm_8x4 vpermilpd(imm(0x5), ymm2, ymm3) vaddpd(ymm14, ymm6, ymm14) vaddpd(ymm12, ymm7, ymm12) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 6*32), ymm0) vaddpd(ymm11, ymm6, ymm11) vaddpd(ymm9, ymm7, ymm9) prefetch(0, mem(r15, 2*32)) // prefetch b_next[2*4] - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddpd(ymm10, ymm6, ymm10) vaddpd(ymm8, ymm7, ymm8) - - + + // iteration 3 vmovapd(mem(rax, 7*32), ymm1) add(imm(4*8*8), rax) // a += 4*8 (unroll x mr) @@ -1191,7 +703,7 @@ void bli_dgemm_sandybridge_asm_8x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm15, ymm6, ymm15) vaddpd(ymm13, ymm7, ymm13) - + //prefetch(0, mem(rax, 22*32)) prefetch(0, mem(rax, 14*32)) vmulpd(ymm1, ymm2, ymm6) @@ -1200,41 +712,41 @@ void bli_dgemm_sandybridge_asm_8x4 vpermilpd(imm(0x5), ymm2, ymm3) vaddpd(ymm14, ymm6, ymm14) vaddpd(ymm12, ymm7, ymm12) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 0*32), ymm0) vaddpd(ymm11, ymm6, ymm11) vaddpd(ymm9, ymm7, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddpd(ymm10, ymm6, ymm10) vaddpd(ymm8, ymm7, ymm8) - - - + + + //add(imm(4*8*8), rax) // a += 4*8 (unroll x mr) //add(imm(4*4*8), rbx) // b += 4*4 (unroll x nr) - + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. 
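
DLOOPKITER above and the DLOOPKLEFT edge loop below are the standard unroll-by-4 split of the k dimension: k = 4*k_iter + k_left, with four rank-1 updates retired per main-loop branch and one per edge-loop branch. The control flow, sketched in C with illustrative names:

    #include <stdint.h>

    /* One k-step of the 8x4 outer product; a and b walk the packed
       panels in steps of MR = 8 and NR = 4. */
    static void rank1_update( double ab[8][4], const double** a,
                              const double** b )
    {
        for ( int i = 0; i < 8; i++ )
            for ( int j = 0; j < 4; j++ )
                ab[i][j] += (*a)[i] * (*b)[j];
        *a += 8; *b += 4;
    }

    static void k_loops( double ab[8][4], const double* a,
                         const double* b, uint64_t k )
    {
        uint64_t k_iter = k / 4, k_left = k % 4;
        for ( uint64_t i = 0; i < k_iter; i++ )   /* DLOOPKITER */
        {
            rank1_update( ab, &a, &b );           /* iteration 0 */
            rank1_update( ab, &a, &b );           /* iteration 1 */
            rank1_update( ab, &a, &b );           /* iteration 2 */
            rank1_update( ab, &a, &b );           /* iteration 3 */
        }
        for ( uint64_t i = 0; i < k_left; i++ )   /* DLOOPKLEFT */
            rank1_update( ab, &a, &b );
    }
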
- - + + label(.DLOOPKLEFT) // EDGE LOOP - + vmovapd(mem(rax, 1*32), ymm1) add(imm(8*1*8), rax) // a += 8 (1 x mr) vmulpd(ymm0, ymm2, ymm6) @@ -1243,7 +755,7 @@ void bli_dgemm_sandybridge_asm_8x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm15, ymm6, ymm15) vaddpd(ymm13, ymm7, ymm13) - + prefetch(0, mem(rax, 14*32)) vmulpd(ymm1, ymm2, ymm6) vmovapd(mem(rbx, 1*32), ymm2) @@ -1252,101 +764,101 @@ void bli_dgemm_sandybridge_asm_8x4 vpermilpd(imm(0x5), ymm2, ymm3) vaddpd(ymm14, ymm6, ymm14) vaddpd(ymm12, ymm7, ymm12) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 0*32), ymm0) vaddpd(ymm11, ymm6, ymm11) vaddpd(ymm9, ymm7, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddpd(ymm10, ymm6, ymm10) vaddpd(ymm8, ymm7, ymm8) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - - + + // ymm15: ymm13: ymm11: ymm9: // ( ab00 ( ab01 ( ab02 ( ab03 - // ab11 ab10 ab13 ab12 + // ab11 ab10 ab13 ab12 // ab22 ab23 ab20 ab21 // ab33 ) ab32 ) ab31 ) ab30 ) - + // ymm14: ymm12: ymm10: ymm8: // ( ab40 ( ab41 ( ab42 ( ab43 - // ab51 ab50 ab53 ab52 + // ab51 ab50 ab53 ab52 // ab62 ab63 ab60 ab61 // ab73 ) ab72 ) ab71 ) ab70 ) - + vmovapd(ymm15, ymm7) vshufpd(imm(0xa), ymm15, ymm13, ymm15) vshufpd(imm(0xa), ymm13, ymm7, ymm13) - + vmovapd(ymm11, ymm7) vshufpd(imm(0xa), ymm11, ymm9, ymm11) vshufpd(imm(0xa), ymm9, ymm7, ymm9) - + vmovapd(ymm14, ymm7) vshufpd(imm(0xa), ymm14, ymm12, ymm14) vshufpd(imm(0xa), ymm12, ymm7, ymm12) - + vmovapd(ymm10, ymm7) vshufpd(imm(0xa), ymm10, ymm8, ymm10) vshufpd(imm(0xa), ymm8, ymm7, ymm8) - + // ymm15: ymm13: ymm11: ymm9: // ( ab01 ( ab00 ( ab03 ( ab02 - // ab11 ab10 ab13 ab12 + // ab11 ab10 ab13 ab12 // ab23 ab22 ab21 ab20 // ab33 ) ab32 ) ab31 ) ab30 ) - + // ymm14: ymm12: ymm10: ymm8: // ( ab41 ( ab40 ( ab43 ( ab42 - // ab51 ab50 ab53 ab52 + // ab51 ab50 ab53 ab52 // ab63 ab62 ab61 ab60 // ab73 ) ab72 ) ab71 ) ab70 ) - + vmovapd(ymm15, ymm7) vperm2f128(imm(0x30), ymm15, ymm11, ymm15) vperm2f128(imm(0x12), ymm7, ymm11, ymm11) - + vmovapd(ymm13, ymm7) vperm2f128(imm(0x30), ymm13, ymm9, ymm13) vperm2f128(imm(0x12), ymm7, ymm9, ymm9) - + vmovapd(ymm14, ymm7) vperm2f128(imm(0x30), ymm14, ymm10, ymm14) vperm2f128(imm(0x12), ymm7, ymm10, ymm10) - + vmovapd(ymm12, ymm7) vperm2f128(imm(0x30), ymm12, ymm8, ymm12) vperm2f128(imm(0x12), ymm7, ymm8, ymm8) - + // ymm9: ymm11: ymm13: ymm15: // ( ab00 ( ab01 ( ab02 ( ab03 - // ab10 ab11 ab12 ab13 + // ab10 ab11 ab12 ab13 // ab20 ab21 ab22 ab23 // ab30 ) ab31 ) ab32 ) ab33 ) - + // ymm8: ymm10: ymm12: ymm14: // ( ab40 ( ab41 ( ab42 ( ab43 - // ab50 ab51 ab52 ab53 + // ab50 ab51 ab52 ab53 // ab60 ab61 ab62 ab63 // ab70 ) ab71 ) ab72 ) ab73 ) - - + + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm2) // load beta and duplicate - + vmulpd(ymm0, ymm8, ymm8) // scale by alpha vmulpd(ymm0, ymm9, ymm9) vmulpd(ymm0, ymm10, ymm10) @@ -1355,326 +867,107 @@ void bli_dgemm_sandybridge_asm_8x4 vmulpd(ymm0, ymm13, ymm13) vmulpd(ymm0, ymm14, ymm14) vmulpd(ymm0, ymm15, ymm15) - - - - - - + + mov(var(rs_c), rsi) // load rs_c lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(double) - + lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*rs_c; - + lea(mem(, rsi, 2), r12) // r12 = 2*rs_c; lea(mem(r12, rsi, 1), r13) // r13 = 3*rs_c; - - + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. 
vucomisd(xmm0, xmm2) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - - - cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8. - jz(.DCOLSTORED) // jump to column storage case - - - - label(.DGENSTORED) - // update c00:c33 - - vextractf128(imm(1), ymm9, xmm1) - vmovlpd(mem(rcx), xmm0, xmm0) // load c00 and c10, - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm9, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rcx)) // and store back to memory. - vmovhpd(xmm0, mem(rcx, rsi, 1)) - vmovlpd(mem(rcx, r12, 1), xmm0, xmm0) // load c20 and c30, - vmovhpd(mem(rcx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm1, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rcx, r12, 1)) // and store back to memory. - vmovhpd(xmm0, mem(rcx, r13, 1)) - add(rdi, rcx) // c += cs_c; - - vextractf128(imm(1), ymm11, xmm1) - vmovlpd(mem(rcx), xmm0, xmm0) // load c01 and c11, - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm11, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rcx)) // and store back to memory. - vmovhpd(xmm0, mem(rcx, rsi, 1)) - vmovlpd(mem(rcx, r12, 1), xmm0, xmm0) // load c21 and c31, - vmovhpd(mem(rcx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm1, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rcx, r12, 1)) // and store back to memory. - vmovhpd(xmm0, mem(rcx, r13, 1)) - add(rdi, rcx) // c += cs_c; - - vextractf128(imm(1), ymm13, xmm1) - vmovlpd(mem(rcx), xmm0, xmm0) // load c02 and c12, - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm13, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rcx)) // and store back to memory. - vmovhpd(xmm0, mem(rcx, rsi, 1)) - vmovlpd(mem(rcx, r12, 1), xmm0, xmm0) // load c22 and c32, - vmovhpd(mem(rcx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm1, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rcx, r12, 1)) // and store back to memory. - vmovhpd(xmm0, mem(rcx, r13, 1)) - add(rdi, rcx) // c += cs_c; - - vextractf128(imm(1), ymm15, xmm1) - vmovlpd(mem(rcx), xmm0, xmm0) // load c03 and c13, - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm15, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rcx)) // and store back to memory. - vmovhpd(xmm0, mem(rcx, rsi, 1)) - vmovlpd(mem(rcx, r12, 1), xmm0, xmm0) // load c23 and c33, - vmovhpd(mem(rcx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm1, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rcx, r12, 1)) // and store back to memory. - vmovhpd(xmm0, mem(rcx, r13, 1)) - - // update c40:c73 - - vextractf128(imm(1), ymm8, xmm1) - vmovlpd(mem(rdx), xmm0, xmm0) // load c40 and c50, - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm8, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rdx)) // and store back to memory. - vmovhpd(xmm0, mem(rdx, rsi, 1)) - vmovlpd(mem(rdx, r12, 1), xmm0, xmm0) // load c60 and c70, - vmovhpd(mem(rdx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm1, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rdx, r12, 1)) // and store back to memory. 
- vmovhpd(xmm0, mem(rdx, r13, 1)) - add(rdi, rdx) // c += cs_c; - - vextractf128(imm(1), ymm10, xmm1) - vmovlpd(mem(rdx), xmm0, xmm0) // load c41 and c51, - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm10, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rdx)) // and store back to memory. - vmovhpd(xmm0, mem(rdx, rsi, 1)) - vmovlpd(mem(rdx, r12, 1), xmm0, xmm0) // load c61 and c71, - vmovhpd(mem(rdx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm1, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rdx, r12, 1)) // and store back to memory. - vmovhpd(xmm0, mem(rdx, r13, 1)) - add(rdi, rdx) // c += cs_c; - - vextractf128(imm(1), ymm12, xmm1) - vmovlpd(mem(rdx), xmm0, xmm0) // load c42 and c52, - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm12, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rdx)) // and store back to memory. - vmovhpd(xmm0, mem(rdx, rsi, 1)) - vmovlpd(mem(rdx, r12, 1), xmm0, xmm0) // load c62 and c72, - vmovhpd(mem(rdx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm1, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rdx, r12, 1)) // and store back to memory. - vmovhpd(xmm0, mem(rdx, r13, 1)) - add(rdi, rdx) // c += cs_c; - - vextractf128(imm(1), ymm14, xmm1) - vmovlpd(mem(rdx), xmm0, xmm0) // load c43 and c53, - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm14, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rdx)) // and store back to memory. - vmovhpd(xmm0, mem(rdx, rsi, 1)) - vmovlpd(mem(rdx, r12, 1), xmm0, xmm0) // load c63 and c73, - vmovhpd(mem(rdx, r13, 1), xmm0, xmm0) - vmulpd(xmm2, xmm0, xmm0) // scale by beta, - vaddpd(xmm1, xmm0, xmm0) // add the gemm result, - vmovlpd(xmm0, mem(rdx, r12, 1)) // and store back to memory. - vmovhpd(xmm0, mem(rdx, r13, 1)) - - - jmp(.DDONE) // jump to end. - - - - label(.DCOLSTORED) - // update c00:c33 - - vmovupd(mem(rcx), ymm0) // load c00:c30, - vmulpd(ymm2, ymm0, ymm0) // scale by beta, - vaddpd(ymm9, ymm0, ymm0) // add the gemm result, - vmovupd(ymm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovupd(mem(rcx), ymm0) // load c01:c31, - vmulpd(ymm2, ymm0, ymm0) // scale by beta, - vaddpd(ymm11, ymm0, ymm0) // add the gemm result, - vmovupd(ymm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovupd(mem(rcx), ymm0) // load c02:c32, - vmulpd(ymm2, ymm0, ymm0) // scale by beta, - vaddpd(ymm13, ymm0, ymm0) // add the gemm result, - vmovupd(ymm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovupd(mem(rcx), ymm0) // load c03:c33, - vmulpd(ymm2, ymm0, ymm0) // scale by beta, - vaddpd(ymm15, ymm0, ymm0) // add the gemm result, - vmovupd(ymm0, mem(rcx)) // and store back to memory. - - // update c40:c73 - - vmovupd(mem(rdx), ymm0) // load c40:c70, - vmulpd(ymm2, ymm0, ymm0) // scale by beta, - vaddpd(ymm8, ymm0, ymm0) // add the gemm result, - vmovupd(ymm0, mem(rdx)) // and store back to memory. - add(rdi, rdx) // c += cs_c; - - vmovupd(mem(rdx), ymm0) // load c41:c71, - vmulpd(ymm2, ymm0, ymm0) // scale by beta, - vaddpd(ymm10, ymm0, ymm0) // add the gemm result, - vmovupd(ymm0, mem(rdx)) // and store back to memory. 
- add(rdi, rdx) // c += cs_c; - - vmovupd(mem(rdx), ymm0) // load c42:c72, - vmulpd(ymm2, ymm0, ymm0) // scale by beta, - vaddpd(ymm12, ymm0, ymm0) // add the gemm result, - vmovupd(ymm0, mem(rdx)) // and store back to memory. - add(rdi, rdx) // c += cs_c; - - vmovupd(mem(rdx), ymm0) // load c43:c73, - vmulpd(ymm2, ymm0, ymm0) // scale by beta, - vaddpd(ymm14, ymm0, ymm0) // add the gemm result, - vmovupd(ymm0, mem(rdx)) // and store back to memory. - - - jmp(.DDONE) // jump to end. - - - - + + // update c00:c33 + + vmovupd(mem(rcx), ymm0) // load c00:c30, + vmulpd(ymm2, ymm0, ymm0) // scale by beta, + vaddpd(ymm9, ymm0, ymm0) // add the gemm result, + vmovupd(ymm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovupd(mem(rcx), ymm0) // load c01:c31, + vmulpd(ymm2, ymm0, ymm0) // scale by beta, + vaddpd(ymm11, ymm0, ymm0) // add the gemm result, + vmovupd(ymm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovupd(mem(rcx), ymm0) // load c02:c32, + vmulpd(ymm2, ymm0, ymm0) // scale by beta, + vaddpd(ymm13, ymm0, ymm0) // add the gemm result, + vmovupd(ymm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovupd(mem(rcx), ymm0) // load c03:c33, + vmulpd(ymm2, ymm0, ymm0) // scale by beta, + vaddpd(ymm15, ymm0, ymm0) // add the gemm result, + vmovupd(ymm0, mem(rcx)) // and store back to memory. + + // update c40:c73 + + vmovupd(mem(rdx), ymm0) // load c40:c70, + vmulpd(ymm2, ymm0, ymm0) // scale by beta, + vaddpd(ymm8, ymm0, ymm0) // add the gemm result, + vmovupd(ymm0, mem(rdx)) // and store back to memory. + add(rdi, rdx) // c += cs_c; + + vmovupd(mem(rdx), ymm0) // load c41:c71, + vmulpd(ymm2, ymm0, ymm0) // scale by beta, + vaddpd(ymm10, ymm0, ymm0) // add the gemm result, + vmovupd(ymm0, mem(rdx)) // and store back to memory. + add(rdi, rdx) // c += cs_c; + + vmovupd(mem(rdx), ymm0) // load c42:c72, + vmulpd(ymm2, ymm0, ymm0) // scale by beta, + vaddpd(ymm12, ymm0, ymm0) // add the gemm result, + vmovupd(ymm0, mem(rdx)) // and store back to memory. + add(rdi, rdx) // c += cs_c; + + vmovupd(mem(rdx), ymm0) // load c43:c73, + vmulpd(ymm2, ymm0, ymm0) // scale by beta, + vaddpd(ymm14, ymm0, ymm0) // add the gemm result, + vmovupd(ymm0, mem(rdx)) // and store back to memory. + + jmp(.DDONE) // jump to end. + label(.DBETAZERO) - - cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8. 
- jz(.DCOLSTORBZ) // jump to column storage case - - - - label(.DGENSTORBZ) - // update c00:c33 - - vextractf128(imm(1), ymm9, xmm1) - vmovlpd(xmm9, mem(rcx)) // store to c00:c30 - vmovhpd(xmm9, mem(rcx, rsi, 1)) - vmovlpd(xmm1, mem(rcx, r12, 1)) - vmovhpd(xmm1, mem(rcx, r13, 1)) - add(rdi, rcx) // c += cs_c; - - vextractf128(imm(1), ymm11, xmm1) - vmovlpd(xmm11, mem(rcx)) // store to c01:c31 - vmovhpd(xmm11, mem(rcx, rsi, 1)) - vmovlpd(xmm1, mem(rcx, r12, 1)) - vmovhpd(xmm1, mem(rcx, r13, 1)) - add(rdi, rcx) // c += cs_c; - - vextractf128(imm(1), ymm13, xmm1) - vmovlpd(xmm13, mem(rcx)) // store to c02:c32 - vmovhpd(xmm13, mem(rcx, rsi, 1)) - vmovlpd(xmm1, mem(rcx, r12, 1)) - vmovhpd(xmm1, mem(rcx, r13, 1)) - add(rdi, rcx) // c += cs_c; - - vextractf128(imm(1), ymm15, xmm1) - vmovlpd(xmm15, mem(rcx)) // store to c03:c33 - vmovhpd(xmm15, mem(rcx, rsi, 1)) - vmovlpd(xmm1, mem(rcx, r12, 1)) - vmovhpd(xmm1, mem(rcx, r13, 1)) - - // update c40:c73 - - vextractf128(imm(1), ymm8, xmm1) - vmovlpd(xmm8, mem(rdx)) // store to c40:c70 - vmovhpd(xmm8, mem(rdx, rsi, 1)) - vmovlpd(xmm1, mem(rdx, r12, 1)) - vmovhpd(xmm1, mem(rdx, r13, 1)) - add(rdi, rdx) // c += cs_c; - - vextractf128(imm(1), ymm10, xmm1) - vmovlpd(xmm10, mem(rdx)) // store to c41:c71 - vmovhpd(xmm10, mem(rdx, rsi, 1)) - vmovlpd(xmm1, mem(rdx, r12, 1)) - vmovhpd(xmm1, mem(rdx, r13, 1)) - add(rdi, rdx) // c += cs_c; - - vextractf128(imm(1), ymm12, xmm1) - vmovlpd(xmm12, mem(rdx)) // store to c42:c72 - vmovhpd(xmm12, mem(rdx, rsi, 1)) - vmovlpd(xmm1, mem(rdx, r12, 1)) - vmovhpd(xmm1, mem(rdx, r13, 1)) - add(rdi, rdx) // c += cs_c; - - vextractf128(imm(1), ymm14, xmm1) - vmovlpd(xmm14, mem(rdx)) // store to c43:c73 - vmovhpd(xmm14, mem(rdx, rsi, 1)) - vmovlpd(xmm1, mem(rdx, r12, 1)) - vmovhpd(xmm1, mem(rdx, r13, 1)) - - - jmp(.DDONE) // jump to end. 
- - - - label(.DCOLSTORBZ) - // update c00:c33 - - vmovupd(ymm9, mem(rcx)) // store c00:c30 - add(rdi, rcx) // c += cs_c; - - vmovupd(ymm11, mem(rcx)) // store c01:c31 - add(rdi, rcx) // c += cs_c; - - vmovupd(ymm13, mem(rcx)) // store c02:c32 - add(rdi, rcx) // c += cs_c; - - vmovupd(ymm15, mem(rcx)) // store c03:c33 - - // update c40:c73 - - vmovupd(ymm8, mem(rdx)) // store c40:c70 - add(rdi, rdx) // c += cs_c; - - vmovupd(ymm10, mem(rdx)) // store c41:c71 - add(rdi, rdx) // c += cs_c; - - vmovupd(ymm12, mem(rdx)) // store c42:c72 - add(rdi, rdx) // c += cs_c; - - vmovupd(ymm14, mem(rdx)) // store c43:c73 - - - - - + + // update c00:c33 + + vmovupd(ymm9, mem(rcx)) // store c00:c30 + add(rdi, rcx) // c += cs_c; + + vmovupd(ymm11, mem(rcx)) // store c01:c31 + add(rdi, rcx) // c += cs_c; + + vmovupd(ymm13, mem(rcx)) // store c02:c32 + add(rdi, rcx) // c += cs_c; + + vmovupd(ymm15, mem(rcx)) // store c03:c33 + + // update c40:c73 + + vmovupd(ymm8, mem(rdx)) // store c40:c70 + add(rdi, rdx) // c += cs_c; + + vmovupd(ymm10, mem(rdx)) // store c41:c71 + add(rdi, rdx) // c += cs_c; + + vmovupd(ymm12, mem(rdx)) // store c42:c72 + add(rdi, rdx) // c += cs_c; + + vmovupd(ymm14, mem(rdx)) // store c43:c73 + label(.DDONE) - + vzeroupper() - - vzeroupper() - + end_asm( : // output operands (none) @@ -1691,7 +984,7 @@ void bli_dgemm_sandybridge_asm_8x4 [b_next] "m" (b_next)/*, // 9 [a_next] "m" (a_next)*/ // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -1699,11 +992,15 @@ void bli_dgemm_sandybridge_asm_8x4 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( d ); } void bli_cgemm_sandybridge_asm_8x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, scomplex* restrict alpha, scomplex* restrict a, scomplex* restrict b, @@ -1718,34 +1015,36 @@ void bli_cgemm_sandybridge_asm_8x4 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( c, 8, 4, false ); + begin_asm() - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. mov(var(b_next), r15) // load address of b_next. //mov(var(a_next), r14) // load address of a_next. sub(imm(4*64), r15) - + vmovaps(mem(rax, 0*32), ymm0) // initialize loop by pre-loading vmovsldup(mem(rbx, 0*32), ymm2) vpermilps(imm(0x4e), ymm2, ymm3) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(scomplex) lea(mem(rcx, rdi, 2), r10) // load address of c + 2*cs_c; - + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*cs_c prefetch(0, mem(r10, 3*8)) // prefetch c + 2*cs_c prefetch(0, mem(r10, rdi, 1, 3*8)) // prefetch c + 3*cs_c - + vxorps(ymm8, ymm8, ymm8) vxorps(ymm9, ymm9, ymm9) vxorps(ymm10, ymm10, ymm10) @@ -1754,19 +1053,19 @@ void bli_cgemm_sandybridge_asm_8x4 vxorps(ymm13, ymm13, ymm13) vxorps(ymm14, ymm14, ymm14) vxorps(ymm15, ymm15, ymm15) - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.CCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. 
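The GEMM_UKR_SETUP_CT / GEMM_UKR_FLUSH_CT pair now wrapped around each kernel is what lets all of these general-storage branches be deleted: when C is not stored the way the kernel prefers, the kernel computes into a small contiguous tile and the flush step writes the tile out with the caller's strides. A rough sketch of the flush half, using stand-in typedefs for the BLIS types (the real macros live in frame/include/bli_misc_macro_defs.h and differ in detail):

    typedef long dim_t;   // assumption: stand-ins for the BLIS typedefs
    typedef long inc_t;

    // Copy an m x n temporary microtile ct out to C with arbitrary
    // strides, applying beta here so the kernel itself ran with beta == 0.
    static void flush_ct_d( dim_t m, dim_t n, double beta,
                            const double* ct, inc_t rs_t, inc_t cs_t,
                            double*       c,  inc_t rs_c, inc_t cs_c )
    {
        for ( dim_t j = 0; j < n; ++j )
            for ( dim_t i = 0; i < m; ++i )
            {
                double* cij = &c[ i*rs_c + j*cs_c ];
                // never read C when beta == 0; it may be uninitialized
                *cij = ( beta == 0.0 ) ? ct[ i*rs_t + j*cs_t ]
                                       : beta * (*cij) + ct[ i*rs_t + j*cs_t ];
            }
    }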
- - + + label(.CLOOPKITER) // MAIN LOOP - + add(imm(4*4*8), r15) // b_next += 4*4 (unroll x nr) - + // iteration 0 prefetch(0, mem(rax, 8*32)) vmovaps(mem(rax, 1*32), ymm1) @@ -1776,20 +1075,20 @@ void bli_cgemm_sandybridge_asm_8x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddps(ymm6, ymm15, ymm15) vaddps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovshdup(mem(rbx, 0*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddps(ymm6, ymm14, ymm14) vaddps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vpermilps(imm(0xb1), ymm0, ymm0) vaddps(ymm6, ymm11, ymm11) vaddps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vmulps(ymm1, ymm5, ymm7) @@ -1797,32 +1096,32 @@ void bli_cgemm_sandybridge_asm_8x4 vaddps(ymm6, ymm10, ymm10) vaddps(ymm7, ymm8, ymm8) prefetch(0, mem(r15, 0*32)) // prefetch b_next[0*4] - + vpermilps(imm(0xb1), ymm1, ymm1) vmulps(ymm0, ymm2, ymm6) vmulps(ymm0, ymm3, ymm7) vaddsubps(ymm6, ymm15, ymm15) vaddsubps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovsldup(mem(rbx, 1*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddsubps(ymm6, ymm14, ymm14) vaddsubps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vmovaps(mem(rax, 2*32), ymm0) vaddsubps(ymm6, ymm11, ymm11) vaddsubps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddsubps(ymm6, ymm10, ymm10) vaddsubps(ymm7, ymm8, ymm8) - - + + // iteration 1 prefetch(0, mem(rax, 10*32)) vmovaps(mem(rax, 3*32), ymm1) @@ -1832,52 +1131,52 @@ void bli_cgemm_sandybridge_asm_8x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddps(ymm6, ymm15, ymm15) vaddps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovshdup(mem(rbx, 1*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddps(ymm6, ymm14, ymm14) vaddps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vpermilps(imm(0xb1), ymm0, ymm0) vaddps(ymm6, ymm11, ymm11) vaddps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vmulps(ymm1, ymm5, ymm7) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddps(ymm6, ymm10, ymm10) vaddps(ymm7, ymm8, ymm8) - + vpermilps(imm(0xb1), ymm1, ymm1) vmulps(ymm0, ymm2, ymm6) vmulps(ymm0, ymm3, ymm7) vaddsubps(ymm6, ymm15, ymm15) vaddsubps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovsldup(mem(rbx, 2*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddsubps(ymm6, ymm14, ymm14) vaddsubps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vmovaps(mem(rax, 4*32), ymm0) vaddsubps(ymm6, ymm11, ymm11) vaddsubps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddsubps(ymm6, ymm10, ymm10) vaddsubps(ymm7, ymm8, ymm8) - - + + // iteration 2 prefetch(0, mem(rax, 12*32)) vmovaps(mem(rax, 5*32), ymm1) @@ -1887,20 +1186,20 @@ void bli_cgemm_sandybridge_asm_8x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddps(ymm6, ymm15, ymm15) vaddps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovshdup(mem(rbx, 2*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddps(ymm6, ymm14, ymm14) vaddps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vpermilps(imm(0xb1), ymm0, ymm0) vaddps(ymm6, ymm11, ymm11) vaddps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vmulps(ymm1, ymm5, ymm7) @@ -1908,32 +1207,32 @@ void bli_cgemm_sandybridge_asm_8x4 vaddps(ymm6, ymm10, ymm10) vaddps(ymm7, ymm8, ymm8) prefetch(0, mem(r15, 
2*32)) // prefetch b_next[2*4] - + vpermilps(imm(0xb1), ymm1, ymm1) vmulps(ymm0, ymm2, ymm6) vmulps(ymm0, ymm3, ymm7) vaddsubps(ymm6, ymm15, ymm15) vaddsubps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovsldup(mem(rbx, 3*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddsubps(ymm6, ymm14, ymm14) vaddsubps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vmovaps(mem(rax, 6*32), ymm0) vaddsubps(ymm6, ymm11, ymm11) vaddsubps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddsubps(ymm6, ymm10, ymm10) vaddsubps(ymm7, ymm8, ymm8) - - + + // iteration 3 prefetch(0, mem(rax, 14*32)) vmovaps(mem(rax, 7*32), ymm1) @@ -1943,74 +1242,74 @@ void bli_cgemm_sandybridge_asm_8x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddps(ymm6, ymm15, ymm15) vaddps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovshdup(mem(rbx, 3*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddps(ymm6, ymm14, ymm14) vaddps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vpermilps(imm(0xb1), ymm0, ymm0) vaddps(ymm6, ymm11, ymm11) vaddps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vmulps(ymm1, ymm5, ymm7) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddps(ymm6, ymm10, ymm10) vaddps(ymm7, ymm8, ymm8) - + vpermilps(imm(0xb1), ymm1, ymm1) vmulps(ymm0, ymm2, ymm6) vmulps(ymm0, ymm3, ymm7) vaddsubps(ymm6, ymm15, ymm15) vaddsubps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovsldup(mem(rbx, 4*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddsubps(ymm6, ymm14, ymm14) vaddsubps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vmovaps(mem(rax, 8*32), ymm0) vaddsubps(ymm6, ymm11, ymm11) vaddsubps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddsubps(ymm6, ymm10, ymm10) vaddsubps(ymm7, ymm8, ymm8) - - + + add(imm(8*4*8), rax) // a += 8*4 (unroll x mr) add(imm(4*4*8), rbx) // b += 4*4 (unroll x nr) - - + + dec(rsi) // i -= 1; jne(.CLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.CCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.CPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. 
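Each vmulps/vaddps and vmulps/vaddsubps cluster in the loop body above is one half of a complex multiply-accumulate carried out on eight scomplex lanes at once: vmovsldup duplicates the real parts of b, vmovshdup the imaginary parts, vpermilps(0xb1) swaps each (re,im) pair of a, and vaddsubps supplies the (-,+) sign pattern. Reduced to one scalar element, the arithmetic is (illustration only, not kernel code):

    #include <complex.h>

    // One complex multiply-accumulate c += a*b, split into the same
    // "real of b" and "imag of b" halves the vector code uses.
    static inline float complex cmac( float complex c, float complex a,
                                      float complex b )
    {
        float ar = crealf( a ), ai = cimagf( a );
        float br = crealf( b ), bi = cimagf( b );

        // half 1: accumulate ( ar*br, ai*br )   -- vmulps + vaddps
        float cr = crealf( c ) + ar * br;
        float ci = cimagf( c ) + ai * br;

        // half 2: addsub with ( ai*bi, ar*bi )  -- vmulps + vaddsubps
        cr = cr - ai * bi;
        ci = ci + ar * bi;

        return cr + ci * I;
    }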
- - + + label(.CLOOPKLEFT) // EDGE LOOP - + // iteration 0 prefetch(0, mem(rax, 8*32)) vmovaps(mem(rax, 1*32), ymm1) @@ -2020,228 +1319,228 @@ void bli_cgemm_sandybridge_asm_8x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddps(ymm6, ymm15, ymm15) vaddps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovshdup(mem(rbx, 0*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddps(ymm6, ymm14, ymm14) vaddps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vpermilps(imm(0xb1), ymm0, ymm0) vaddps(ymm6, ymm11, ymm11) vaddps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vmulps(ymm1, ymm5, ymm7) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddps(ymm6, ymm10, ymm10) vaddps(ymm7, ymm8, ymm8) - + vpermilps(imm(0xb1), ymm1, ymm1) vmulps(ymm0, ymm2, ymm6) vmulps(ymm0, ymm3, ymm7) vaddsubps(ymm6, ymm15, ymm15) vaddsubps(ymm7, ymm13, ymm13) - + vmulps(ymm1, ymm2, ymm6) vmovsldup(mem(rbx, 1*32), ymm2) vmulps(ymm1, ymm3, ymm7) vpermilps(imm(0x4e), ymm2, ymm3) vaddsubps(ymm6, ymm14, ymm14) vaddsubps(ymm7, ymm12, ymm12) - + vmulps(ymm0, ymm4, ymm6) vmulps(ymm0, ymm5, ymm7) vmovaps(mem(rax, 2*32), ymm0) vaddsubps(ymm6, ymm11, ymm11) vaddsubps(ymm7, ymm9, ymm9) - + vmulps(ymm1, ymm4, ymm6) vmulps(ymm1, ymm5, ymm7) vaddsubps(ymm6, ymm10, ymm10) vaddsubps(ymm7, ymm8, ymm8) - - + + add(imm(8*1*8), rax) // a += 8 (1 x mr) add(imm(4*1*8), rbx) // b += 4 (1 x nr) - - + + dec(rsi) // i -= 1; jne(.CLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.CPOSTACCUM) - + // ymm15: ymm13: ymm11: ymm9: - // ( ab00 ( ab01 ( ab02 ( ab03 - // ab10 ab11 ab12 ab13 - // ab21 ab20 ab23 ab22 - // ab31 ab30 ab33 ab32 - // ab42 ab43 ab40 ab41 - // ab52 ab53 ab50 ab51 - // ab63 ab62 ab61 ab60 + // ( ab00 ( ab01 ( ab02 ( ab03 + // ab10 ab11 ab12 ab13 + // ab21 ab20 ab23 ab22 + // ab31 ab30 ab33 ab32 + // ab42 ab43 ab40 ab41 + // ab52 ab53 ab50 ab51 + // ab63 ab62 ab61 ab60 // ab73 ) ab72 ) ab71 ) ab70 ) - + // ymm14: ymm12: ymm10: ymm8: - // ( ab80 ( ab81 ( ab82 ( ab83 - // ab90 ab91 ab92 ab93 - // aba1 aba0 aba3 aba2 - // abb1 abb0 abb3 abb2 - // abc2 abc3 abc0 abc1 - // abd2 abd3 abd0 abd1 - // abe3 abe2 abe1 abe0 + // ( ab80 ( ab81 ( ab82 ( ab83 + // ab90 ab91 ab92 ab93 + // aba1 aba0 aba3 aba2 + // abb1 abb0 abb3 abb2 + // abc2 abc3 abc0 abc1 + // abd2 abd3 abd0 abd1 + // abe3 abe2 abe1 abe0 // abf3 abf2 abf1 abf0 ) - + vmovaps(ymm15, ymm7) vshufps(imm(0xe4), ymm13, ymm15, ymm15) vshufps(imm(0xe4), ymm7, ymm13, ymm13) - + vmovaps(ymm11, ymm7) vshufps(imm(0xe4), ymm9, ymm11, ymm11) vshufps(imm(0xe4), ymm7, ymm9, ymm9) - + vmovaps(ymm14, ymm7) vshufps(imm(0xe4), ymm12, ymm14, ymm14) vshufps(imm(0xe4), ymm7, ymm12, ymm12) - + vmovaps(ymm10, ymm7) vshufps(imm(0xe4), ymm8, ymm10, ymm10) vshufps(imm(0xe4), ymm7, ymm8, ymm8) - + // ymm15: ymm13: ymm11: ymm9: - // ( ab00 ( ab01 ( ab02 ( ab03 - // ab10 ab11 ab12 ab13 - // ab20 ab21 ab22 ab23 - // ab30 ab31 ab32 ab33 - // ab42 ab43 ab40 ab41 - // ab52 ab53 ab50 ab51 - // ab62 ab63 ab60 ab61 + // ( ab00 ( ab01 ( ab02 ( ab03 + // ab10 ab11 ab12 ab13 + // ab20 ab21 ab22 ab23 + // ab30 ab31 ab32 ab33 + // ab42 ab43 ab40 ab41 + // ab52 ab53 ab50 ab51 + // ab62 ab63 ab60 ab61 // ab72 ) ab73 ) ab70 ) ab71 ) - + // ymm14: ymm12: ymm10: ymm8: - // ( ab80 ( ab81 ( ab82 ( ab83 - // ab90 ab91 ab92 ab93 - // aba0 aba1 aba2 aba3 - // abb0 abb1 abb2 abb3 - // abc2 abc3 abc0 abc1 - // abd2 abd3 abd0 abd1 - // abe2 abe3 abe0 abe1 + // ( ab80 ( ab81 ( ab82 ( ab83 + // ab90 ab91 ab92 ab93 + // aba0 aba1 aba2 aba3 + // abb0 abb1 
abb2 abb3 + // abc2 abc3 abc0 abc1 + // abd2 abd3 abd0 abd1 + // abe2 abe3 abe0 abe1 // abf2 ) abf3 ) abf0 ) abf1 ) - + vmovaps(ymm15, ymm7) vperm2f128(imm(0x12), ymm15, ymm11, ymm15) vperm2f128(imm(0x30), ymm7, ymm11, ymm11) - + vmovaps(ymm13, ymm7) vperm2f128(imm(0x12), ymm13, ymm9, ymm13) vperm2f128(imm(0x30), ymm7, ymm9, ymm9) - + vmovaps(ymm14, ymm7) vperm2f128(imm(0x12), ymm14, ymm10, ymm14) vperm2f128(imm(0x30), ymm7, ymm10, ymm10) - + vmovaps(ymm12, ymm7) vperm2f128(imm(0x12), ymm12, ymm8, ymm12) vperm2f128(imm(0x30), ymm7, ymm8, ymm8) - + // ymm15: ymm13: ymm11: ymm9: - // ( ab00 ( ab01 ( ab02 ( ab03 - // ab10 ab11 ab12 ab13 - // ab20 ab21 ab22 ab23 - // ab30 ab31 ab32 ab33 - // ab40 ab41 ab42 ab43 - // ab50 ab51 ab52 ab53 - // ab60 ab61 ab62 ab63 + // ( ab00 ( ab01 ( ab02 ( ab03 + // ab10 ab11 ab12 ab13 + // ab20 ab21 ab22 ab23 + // ab30 ab31 ab32 ab33 + // ab40 ab41 ab42 ab43 + // ab50 ab51 ab52 ab53 + // ab60 ab61 ab62 ab63 // ab70 ) ab71 ) ab72 ) ab73 ) - + // ymm14: ymm12: ymm10: ymm8: - // ( ab80 ( ab81 ( ab82 ( ab83 - // ab90 ab91 ab92 ab93 - // aba0 aba1 aba2 aba3 - // abb0 abb1 abb2 abb3 - // abc0 abc1 abc2 abc3 - // abd0 abd1 abd2 abd3 - // abe0 abe1 abe2 abe3 + // ( ab80 ( ab81 ( ab82 ( ab83 + // ab90 ab91 ab92 ab93 + // aba0 aba1 aba2 aba3 + // abb0 abb1 abb2 abb3 + // abc0 abc1 abc2 abc3 + // abd0 abd1 abd2 abd3 + // abe0 abe1 abe2 abe3 // abf0 ) abf1 ) abf2 ) abf3 ) - - - - + + + + // scale by alpha - + mov(var(alpha), rax) // load address of alpha vbroadcastss(mem(rax), ymm7) // load alpha_r and duplicate vbroadcastss(mem(rax, 4), ymm6) // load alpha_i and duplicate - + vpermilps(imm(0xb1), ymm15, ymm3) vmulps(ymm7, ymm15, ymm15) vmulps(ymm6, ymm3, ymm3) vaddsubps(ymm3, ymm15, ymm15) - + vpermilps(imm(0xb1), ymm14, ymm2) vmulps(ymm7, ymm14, ymm14) vmulps(ymm6, ymm2, ymm2) vaddsubps(ymm2, ymm14, ymm14) - + vpermilps(imm(0xb1), ymm13, ymm1) vmulps(ymm7, ymm13, ymm13) vmulps(ymm6, ymm1, ymm1) vaddsubps(ymm1, ymm13, ymm13) - + vpermilps(imm(0xb1), ymm12, ymm0) vmulps(ymm7, ymm12, ymm12) vmulps(ymm6, ymm0, ymm0) vaddsubps(ymm0, ymm12, ymm12) - + vpermilps(imm(0xb1), ymm11, ymm3) vmulps(ymm7, ymm11, ymm11) vmulps(ymm6, ymm3, ymm3) vaddsubps(ymm3, ymm11, ymm11) - + vpermilps(imm(0xb1), ymm10, ymm2) vmulps(ymm7, ymm10, ymm10) vmulps(ymm6, ymm2, ymm2) vaddsubps(ymm2, ymm10, ymm10) - + vpermilps(imm(0xb1), ymm9, ymm1) vmulps(ymm7, ymm9, ymm9) vmulps(ymm6, ymm1, ymm1) vaddsubps(ymm1, ymm9, ymm9) - + vpermilps(imm(0xb1), ymm8, ymm0) vmulps(ymm7, ymm8, ymm8) vmulps(ymm6, ymm0, ymm0) vaddsubps(ymm0, ymm8, ymm8) - - - - + + + + mov(var(beta), rbx) // load address of beta vbroadcastss(mem(rbx), ymm7) // load beta_r and duplicate vbroadcastss(mem(rbx, 4), ymm6) // load beta_i and duplicate - - - - - - - + + + + + + + mov(var(rs_c), rsi) // load rs_c lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(scomplex) - + lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*rs_c; - + lea(mem(, rsi, 2), r12) // r12 = 2*rs_c; lea(mem(r12, rsi, 1), r13) // r13 = 3*rs_c; - - + + // now avoid loading C if beta == 0 - + vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomiss(xmm0, xmm7) // set ZF if beta_r == 0. sete(r8b) // r8b = ( ZF == 1 ? 1 : 0 ); @@ -2249,393 +1548,127 @@ void bli_cgemm_sandybridge_asm_8x4 sete(r9b) // r9b = ( ZF == 1 ? 1 : 0 ); and(r8b, r9b) // set ZF if r8b & r9b == 1. jne(.CBETAZERO) // if ZF = 0, jump to beta == 0 case - - - cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8. 
- jz(.CCOLSTORED) // jump to column storage case - - - - label(.CGENSTORED) - - // update c00:c70 - - vmovlpd(mem(rcx), xmm0, xmm0) // load (c00,10) into xmm0[0:1] - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) // load (c20,30) into xmm0[2:3] - vmovlpd(mem(rcx, r12, 1), xmm2, xmm2) // load (c40,50) into xmm2[0:1] - vmovhpd(mem(rcx, r13, 1), xmm2, xmm2) // load (c60,70) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm15, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rcx)) // store (c00,c10) - vmovhpd(xmm0, mem(rcx, rsi, 1)) // store (c20,c30) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c40,c50) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c60,c70) - add(rdi, rcx) // c += cs_c; - - // update c80:cf0 - - vmovlpd(mem(rdx), xmm0, xmm0) // load (c80,90) into xmm0[0:1] - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) // load (ca0,b0) into xmm0[2:3] - vmovlpd(mem(rdx, r12, 1), xmm2, xmm2) // load (cc0,d0) into xmm2[0:1] - vmovhpd(mem(rdx, r13, 1), xmm2, xmm2) // load (ce0,f0) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm14, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rdx)) // store (c80,c90) - vmovhpd(xmm0, mem(rdx, rsi, 1)) // store (ca0,cb0) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc0,cd0) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce0,cf0) - add(rdi, rdx) // c += cs_c; - - // update c01:c71 - - vmovlpd(mem(rcx), xmm0, xmm0) // load (c01,11) into xmm0[0:1] - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) // load (c21,31) into xmm0[2:3] - vmovlpd(mem(rcx, r12, 1), xmm2, xmm2) // load (c41,51) into xmm2[0:1] - vmovhpd(mem(rcx, r13, 1), xmm2, xmm2) // load (c61,71) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm13, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rcx)) // store (c01,c11) - vmovhpd(xmm0, mem(rcx, rsi, 1)) // store (c21,c31) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c41,c51) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c61,c71) - add(rdi, rcx) // c += cs_c; - - // update c81:cf1 - - vmovlpd(mem(rdx), xmm0, xmm0) // load (c81,91) into xmm0[0:1] - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) // load (ca1,b1) into xmm0[2:3] - vmovlpd(mem(rdx, r12, 1), xmm2, xmm2) // load (cc1,d1) into xmm2[0:1] - vmovhpd(mem(rdx, r13, 1), xmm2, xmm2) // load (ce1,f1) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm12, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rdx)) // store (c81,c91) - vmovhpd(xmm0, mem(rdx, rsi, 1)) // store (ca1,cb1) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc1,cd1) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce1,cf1) - add(rdi, rdx) // c += cs_c; - - // update c02:c72 - - 
vmovlpd(mem(rcx), xmm0, xmm0) // load (c02,12) into xmm0[0:1] - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) // load (c22,32) into xmm0[2:3] - vmovlpd(mem(rcx, r12, 1), xmm2, xmm2) // load (c42,52) into xmm2[0:1] - vmovhpd(mem(rcx, r13, 1), xmm2, xmm2) // load (c62,72) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm11, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rcx)) // store (c02,c12) - vmovhpd(xmm0, mem(rcx, rsi, 1)) // store (c22,c32) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c42,c52) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c62,c72) - add(rdi, rcx) // c += cs_c; - - // update c82:cf2 - - vmovlpd(mem(rdx), xmm0, xmm0) // load (c82,92) into xmm0[0:1] - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) // load (ca2,b2) into xmm0[2:3] - vmovlpd(mem(rdx, r12, 1), xmm2, xmm2) // load (cc2,d2) into xmm2[0:1] - vmovhpd(mem(rdx, r13, 1), xmm2, xmm2) // load (ce2,f2) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm10, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rdx)) // store (c82,c92) - vmovhpd(xmm0, mem(rdx, rsi, 1)) // store (ca2,cb2) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc2,cd2) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce2,cf2) - add(rdi, rdx) // c += cs_c; - - // update c03:c73 - - vmovlpd(mem(rcx), xmm0, xmm0) // load (c03,13) into xmm0[0:1] - vmovhpd(mem(rcx, rsi, 1), xmm0, xmm0) // load (c23,33) into xmm0[2:3] - vmovlpd(mem(rcx, r12, 1), xmm2, xmm2) // load (c43,53) into xmm2[0:1] - vmovhpd(mem(rcx, r13, 1), xmm2, xmm2) // load (c63,73) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm9, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rcx)) // store (c03,c13) - vmovhpd(xmm0, mem(rcx, rsi, 1)) // store (c23,c33) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c43,c53) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c63,c73) - add(rdi, rcx) // c += cs_c; - - // update c83:cf3 - - vmovlpd(mem(rdx), xmm0, xmm0) // load (c83,93) into xmm0[0:1] - vmovhpd(mem(rdx, rsi, 1), xmm0, xmm0) // load (ca3,b3) into xmm0[2:3] - vmovlpd(mem(rdx, r12, 1), xmm2, xmm2) // load (cc3,d3) into xmm2[0:1] - vmovhpd(mem(rdx, r13, 1), xmm2, xmm2) // load (ce3,f3) into xmm2[2:3] - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:3],xmm2) - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm8, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm0, mem(rdx)) // store (c83,c93) - vmovhpd(xmm0, mem(rdx, rsi, 1)) // store (ca3,cb3) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc3,cd3) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce3,cf3) - add(rdi, rdx) // c += cs_c; - - - - jmp(.CDONE) // jump to end. 
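On the load side, the .CGENSTORED path above assembled each ymm from four row-strided 64-bit pieces of C before applying beta, the mirror image of the scatter used for stores. Roughly, in intrinsics form (gather_col_rs is a hypothetical helper):

    #include <immintrin.h>

    // Gather four 64-bit chunks of C separated by a row stride rs into
    // one ymm, as the deleted vmovlpd/vmovhpd + vinsertf128 sequence did.
    static inline __m256d gather_col_rs( const double* c, long rs )
    {
        __m128d lo = _mm_loadl_pd( _mm_setzero_pd(), c + 0*rs );
        lo         = _mm_loadh_pd( lo,               c + 1*rs );
        __m128d hi = _mm_loadl_pd( _mm_setzero_pd(), c + 2*rs );
        hi         = _mm_loadh_pd( hi,               c + 3*rs );

        return _mm256_insertf128_pd( _mm256_castpd128_pd256( lo ), hi, 1 );
    }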
- - - - label(.CCOLSTORED) - - // update c00:c70 - - vmovups(mem(rcx), ymm0) // load c00:c70 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm15, ymm0, ymm0) // add the gemm result to ymm0 - vmovups(ymm0, mem(rcx)) // store c00:c70 - add(rdi, rcx) // c += cs_c; - - // update c80:cf0 - - vmovups(mem(rdx), ymm0) // load c80:f0 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm14, ymm0, ymm0) // add the gemm result to ymm0 - vmovups(ymm0, mem(rdx)) // store c80:cf0 - add(rdi, rdx) // c += cs_c; - - // update c00:c70 - - vmovups(mem(rcx), ymm0) // load c01:c71 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm13, ymm0, ymm0) // add the gemm result to ymm0 - vmovups(ymm0, mem(rcx)) // store c01:c71 - add(rdi, rcx) // c += cs_c; - - // update c81:cf1 - - vmovups(mem(rdx), ymm0) // load c81:f1 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm12, ymm0, ymm0) // add the gemm result to ymm0 - vmovups(ymm0, mem(rdx)) // store c81:cf1 - add(rdi, rdx) // c += cs_c; - - // update c02:c72 - - vmovups(mem(rcx), ymm0) // load c02:c72 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm11, ymm0, ymm0) // add the gemm result to ymm0 - vmovups(ymm0, mem(rcx)) // store c02:c72 - add(rdi, rcx) // c += cs_c; - - // update c82:cf2 - - vmovups(mem(rdx), ymm0) // load c82:f2 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm10, ymm0, ymm0) // add the gemm result to ymm0 - vmovups(ymm0, mem(rdx)) // store c82:cf2 - add(rdi, rdx) // c += cs_c; - - // update c03:c73 - - vmovups(mem(rcx), ymm0) // load c03:c73 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm9, ymm0, ymm0) // add the gemm result to ymm0 - vmovups(ymm0, mem(rcx)) // store c03:c73 - add(rdi, rcx) // c += cs_c; - - // update c83:cf3 - - vmovups(mem(rdx), ymm0) // load c83:f3 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm8, ymm0, ymm0) // add the gemm result to ymm0 - vmovups(ymm0, mem(rdx)) // store c83:cf3 - add(rdi, rdx) // c += cs_c; - - - - jmp(.CDONE) // jump to end. 
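The column-stored update retained below is unchanged in substance: each column of C is scaled by the complex beta with the same permute/multiply/addsub pattern used for alpha, then the accumulated product is added. One column in intrinsics form (update_col_c is a hypothetical helper; beta_r/beta_i hold the broadcast parts as after vbroadcastss):

    #include <immintrin.h>

    // Scale four scomplex elements of C by beta, add the accumulated
    // gemm result acc, and store back -- one block of .CCOLSTORED.
    static inline void update_col_c( float* cj, __m256 beta_r,
                                     __m256 beta_i, __m256 acc )
    {
        __m256 c0 = _mm256_loadu_ps( cj );          // vmovups load
        __m256 t  = _mm256_permute_ps( c0, 0xb1 );  // swap each (re,im)
        c0 = _mm256_mul_ps( beta_r, c0 );           // ( br*cr, br*ci, ... )
        t  = _mm256_mul_ps( beta_i, t );            // ( bi*ci, bi*cr, ... )
        c0 = _mm256_addsub_ps( c0, t );             // beta*c, complex product
        c0 = _mm256_add_ps( acc, c0 );              // + alpha*a*b
        _mm256_storeu_ps( cj, c0 );                 // vmovups store
    }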
- - - + + // update c00:c70 + + vmovups(mem(rcx), ymm0) // load c00:c70 into ymm0 + vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta + vmulps(ymm7, ymm0, ymm0) + vmulps(ymm6, ymm2, ymm2) + vaddsubps(ymm2, ymm0, ymm0) + vaddps(ymm15, ymm0, ymm0) // add the gemm result to ymm0 + vmovups(ymm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + // update c80:cf0 + + vmovups(mem(rdx), ymm0) // load c80:cf0 into ymm0 + vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta + vmulps(ymm7, ymm0, ymm0) + vmulps(ymm6, ymm2, ymm2) + vaddsubps(ymm2, ymm0, ymm0) + vaddps(ymm14, ymm0, ymm0) // add the gemm result to ymm0 + vmovups(ymm0, mem(rdx)) // store c80:cf0 + add(rdi, rdx) // c += cs_c; + + // update c01:c71 + + vmovups(mem(rcx), ymm0) // load c01:c71 into ymm0 + vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta + vmulps(ymm7, ymm0, ymm0) + vmulps(ymm6, ymm2, ymm2) + vaddsubps(ymm2, ymm0, ymm0) + vaddps(ymm13, ymm0, ymm0) // add the gemm result to ymm0 + vmovups(ymm0, mem(rcx)) // store c01:c71 + add(rdi, rcx) // c += cs_c; + + // update c81:cf1 + + vmovups(mem(rdx), ymm0) // load c81:cf1 into ymm0 + vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta + vmulps(ymm7, ymm0, ymm0) + vmulps(ymm6, ymm2, ymm2) + vaddsubps(ymm2, ymm0, ymm0) + vaddps(ymm12, ymm0, ymm0) // add the gemm result to ymm0 + vmovups(ymm0, mem(rdx)) // store c81:cf1 + add(rdi, rdx) // c += cs_c; + + // update c02:c72 + + vmovups(mem(rcx), ymm0) // load c02:c72 into ymm0 + vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta + vmulps(ymm7, ymm0, ymm0) + vmulps(ymm6, ymm2, ymm2) + vaddsubps(ymm2, ymm0, ymm0) + vaddps(ymm11, ymm0, ymm0) // add the gemm result to ymm0 + vmovups(ymm0, mem(rcx)) // store c02:c72 + add(rdi, rcx) // c += cs_c; + + // update c82:cf2 + + vmovups(mem(rdx), ymm0) // load c82:cf2 into ymm0 + vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta + vmulps(ymm7, ymm0, ymm0) + vmulps(ymm6, ymm2, ymm2) + vaddsubps(ymm2, ymm0, ymm0) + vaddps(ymm10, ymm0, ymm0) // add the gemm result to ymm0 + vmovups(ymm0, mem(rdx)) // store c82:cf2 + add(rdi, rdx) // c += cs_c; + + // update c03:c73 + + vmovups(mem(rcx), ymm0) // load c03:c73 into ymm0 + vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta + vmulps(ymm7, ymm0, ymm0) + vmulps(ymm6, ymm2, ymm2) + vaddsubps(ymm2, ymm0, ymm0) + vaddps(ymm9, ymm0, ymm0) // add the gemm result to ymm0 + vmovups(ymm0, mem(rcx)) // store c03:c73 + add(rdi, rcx) // c += cs_c; + + // update c83:cf3 + + vmovups(mem(rdx), ymm0) // load c83:cf3 into ymm0 + vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta + vmulps(ymm7, ymm0, ymm0) + vmulps(ymm6, ymm2, ymm2) + vaddsubps(ymm2, ymm0, ymm0) + vaddps(ymm8, ymm0, ymm0) // add the gemm result to ymm0 + vmovups(ymm0, mem(rdx)) // store c83:cf3 + add(rdi, rdx) // c += cs_c; + + jmp(.CDONE) // jump to end. + label(.CBETAZERO) - - cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8.
- jz(.CCOLSTORBZ) // jump to column storage case - - - - label(.CGENSTORBZ) - - // update c00:c70 - - vextractf128(imm(1), ymm15, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm15, mem(rcx)) // store (c00,c10) - vmovhpd(xmm15, mem(rcx, rsi, 1)) // store (c20,c30) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c40,c50) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c60,c70) - add(rdi, rcx) // c += cs_c; - - // update c80:cf0 - - vextractf128(imm(1), ymm14, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm14, mem(rdx)) // store (c80,c90) - vmovhpd(xmm14, mem(rdx, rsi, 1)) // store (ca0,cb0) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc0,cd0) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce0,cf0) - add(rdi, rdx) // c += cs_c; - - // update c01:c71 - - vextractf128(imm(1), ymm13, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm13, mem(rcx)) // store (c01,c11) - vmovhpd(xmm13, mem(rcx, rsi, 1)) // store (c21,c31) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c41,c51) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c61,c71) - add(rdi, rcx) // c += cs_c; - - // update c81:cf1 - - vextractf128(imm(1), ymm12, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm12, mem(rdx)) // store (c81,c91) - vmovhpd(xmm12, mem(rdx, rsi, 1)) // store (ca1,cb1) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc1,cd1) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce1,cf1) - add(rdi, rdx) // c += cs_c; - - // update c02:c72 - - vextractf128(imm(1), ymm11, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm11, mem(rcx)) // store (c02,c12) - vmovhpd(xmm11, mem(rcx, rsi, 1)) // store (c22,c32) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c42,c52) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c62,c72) - add(rdi, rcx) // c += cs_c; - - // update c82:cf2 - - vextractf128(imm(1), ymm10, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm10, mem(rdx)) // store (c82,c92) - vmovhpd(xmm10, mem(rdx, rsi, 1)) // store (ca2,cb2) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc2,cd2) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce2,cf2) - add(rdi, rdx) // c += cs_c; - - // update c03:c73 - - vextractf128(imm(1), ymm9, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm9, mem(rcx)) // store (c03,c13) - vmovhpd(xmm9, mem(rcx, rsi, 1)) // store (c23,c33) - vmovlpd(xmm2, mem(rcx, r12, 1)) // store (c43,c53) - vmovhpd(xmm2, mem(rcx, r13, 1)) // store (c63,c73) - add(rdi, rcx) // c += cs_c; - - // update c83:cf3 - - vextractf128(imm(1), ymm8, xmm2) // xmm2 := ymm0[4:7] - vmovlpd(xmm8, mem(rdx)) // store (c83,c93) - vmovhpd(xmm8, mem(rdx, rsi, 1)) // store (ca3,cb3) - vmovlpd(xmm2, mem(rdx, r12, 1)) // store (cc3,cd3) - vmovhpd(xmm2, mem(rdx, r13, 1)) // store (ce3,cf3) - add(rdi, rdx) // c += cs_c; - - - - jmp(.CDONE) // jump to end. 
- - - - label(.CCOLSTORBZ) - - - vmovups(ymm15, mem(rcx)) // store c00:c70 - add(rdi, rcx) // c += cs_c; - - vmovups(ymm14, mem(rdx)) // store c80:cf0 - add(rdi, rdx) // c += cs_c; - - vmovups(ymm13, mem(rcx)) // store c01:c71 - add(rdi, rcx) // c += cs_c; - - vmovups(ymm12, mem(rdx)) // store c81:cf1 - add(rdi, rdx) // c += cs_c; - - vmovups(ymm11, mem(rcx)) // store c02:c72 - add(rdi, rcx) // c += cs_c; - - vmovups(ymm10, mem(rdx)) // store c82:cf2 - add(rdi, rdx) // c += cs_c; - - vmovups(ymm9, mem(rcx)) // store c03:c73 - add(rdi, rcx) // c += cs_c; - - vmovups(ymm8, mem(rdx)) // store c83:cf3 - add(rdi, rdx) // c += cs_c; - - - - - + + vmovups(ymm15, mem(rcx)) // store c00:c70 + add(rdi, rcx) // c += cs_c; + + vmovups(ymm14, mem(rdx)) // store c80:cf0 + add(rdi, rdx) // c += cs_c; + + vmovups(ymm13, mem(rcx)) // store c01:c71 + add(rdi, rcx) // c += cs_c; + + vmovups(ymm12, mem(rdx)) // store c81:cf1 + add(rdi, rdx) // c += cs_c; + + vmovups(ymm11, mem(rcx)) // store c02:c72 + add(rdi, rcx) // c += cs_c; + + vmovups(ymm10, mem(rdx)) // store c82:cf2 + add(rdi, rdx) // c += cs_c; + + vmovups(ymm9, mem(rcx)) // store c03:c73 + add(rdi, rcx) // c += cs_c; + + vmovups(ymm8, mem(rdx)) // store c83:cf3 + add(rdi, rdx) // c += cs_c; + label(.CDONE) - + vzeroupper() - - vzeroupper() - + end_asm( : // output operands (none) @@ -2652,7 +1685,7 @@ void bli_cgemm_sandybridge_asm_8x4 [b_next] "m" (b_next)/*, // 9 [a_next] "m" (a_next)*/ // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -2660,13 +1693,17 @@ void bli_cgemm_sandybridge_asm_8x4 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( c ); } void bli_zgemm_sandybridge_asm_4x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, dcomplex* restrict alpha, dcomplex* restrict a, dcomplex* restrict b, @@ -2681,34 +1718,36 @@ void bli_zgemm_sandybridge_asm_4x4 // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_iter = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; + GEMM_UKR_SETUP_CT( z, 4, 4, false ); + begin_asm() - - + + mov(var(a), rax) // load address of a. mov(var(b), rbx) // load address of b. //mov(var(b_next), r15) // load address of b_next. //mov(var(a_next), r14) // load address of a_next. - + vmovapd(mem(rax, 0*32), ymm0) // initialize loop by pre-loading vmovddup(mem(rbx, 0+0*32), ymm2) vmovddup(mem(rbx, 0+1*32), ymm3) - + mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c lea(mem(, rdi, 8), rdi) // cs_c *= sizeof(dcomplex) lea(mem(, rdi, 2), rdi) lea(mem(rcx, rdi, 2), r10) // load address of c + 2*cs_c; - + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*cs_c prefetch(0, mem(r10, 3*8)) // prefetch c + 2*cs_c prefetch(0, mem(r10, rdi, 1, 3*8)) // prefetch c + 3*cs_c - + vxorpd(ymm8, ymm8, ymm8) vxorpd(ymm9, ymm9, ymm9) vxorpd(ymm10, ymm10, ymm10) @@ -2717,18 +1756,18 @@ void bli_zgemm_sandybridge_asm_4x4 vxorpd(ymm13, ymm13, ymm13) vxorpd(ymm14, ymm14, ymm14) vxorpd(ymm15, ymm15, ymm15) - - - + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.ZCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. 
- - + + label(.ZLOOPKITER) // MAIN LOOP - - + + // iteration 0 vmovapd(mem(rax, 1*32), ymm1) vmulpd(ymm0, ymm2, ymm6) @@ -2737,7 +1776,7 @@ void bli_zgemm_sandybridge_asm_4x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm6, ymm15, ymm15) vaddpd(ymm7, ymm11, ymm11) - + prefetch(0, mem(rax, 16*32)) vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 8+0*32), ymm2) @@ -2745,45 +1784,45 @@ void bli_zgemm_sandybridge_asm_4x4 vmovddup(mem(rbx, 8+1*32), ymm3) vaddpd(ymm6, ymm14, ymm14) vaddpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vpermilpd(imm(0x5), ymm0, ymm0) vaddpd(ymm6, ymm13, ymm13) vaddpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vmulpd(ymm1, ymm5, ymm7) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm6, ymm12, ymm12) vaddpd(ymm7, ymm8, ymm8) - + vpermilpd(imm(0x5), ymm1, ymm1) vmulpd(ymm0, ymm2, ymm6) vmulpd(ymm0, ymm3, ymm7) vaddsubpd(ymm6, ymm15, ymm15) vaddsubpd(ymm7, ymm11, ymm11) - + vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 0+2*32), ymm2) vmulpd(ymm1, ymm3, ymm7) vmovddup(mem(rbx, 0+3*32), ymm3) vaddsubpd(ymm6, ymm14, ymm14) vaddsubpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 2*32), ymm0) vaddsubpd(ymm6, ymm13, ymm13) vaddsubpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddsubpd(ymm6, ymm12, ymm12) vaddsubpd(ymm7, ymm8, ymm8) - - + + // iteration 1 vmovapd(mem(rax, 3*32), ymm1) vmulpd(ymm0, ymm2, ymm6) @@ -2792,7 +1831,7 @@ void bli_zgemm_sandybridge_asm_4x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm6, ymm15, ymm15) vaddpd(ymm7, ymm11, ymm11) - + prefetch(0, mem(rax, 18*32)) vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 8+2*32), ymm2) @@ -2800,45 +1839,45 @@ void bli_zgemm_sandybridge_asm_4x4 vmovddup(mem(rbx, 8+3*32), ymm3) vaddpd(ymm6, ymm14, ymm14) vaddpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vpermilpd(imm(0x5), ymm0, ymm0) vaddpd(ymm6, ymm13, ymm13) vaddpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vmulpd(ymm1, ymm5, ymm7) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm6, ymm12, ymm12) vaddpd(ymm7, ymm8, ymm8) - + vpermilpd(imm(0x5), ymm1, ymm1) vmulpd(ymm0, ymm2, ymm6) vmulpd(ymm0, ymm3, ymm7) vaddsubpd(ymm6, ymm15, ymm15) vaddsubpd(ymm7, ymm11, ymm11) - + vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 0+4*32), ymm2) vmulpd(ymm1, ymm3, ymm7) vmovddup(mem(rbx, 0+5*32), ymm3) vaddsubpd(ymm6, ymm14, ymm14) vaddsubpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 4*32), ymm0) vaddsubpd(ymm6, ymm13, ymm13) vaddsubpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddsubpd(ymm6, ymm12, ymm12) vaddsubpd(ymm7, ymm8, ymm8) - - + + // iteration 2 vmovapd(mem(rax, 5*32), ymm1) vmulpd(ymm0, ymm2, ymm6) @@ -2847,7 +1886,7 @@ void bli_zgemm_sandybridge_asm_4x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm6, ymm15, ymm15) vaddpd(ymm7, ymm11, ymm11) - + prefetch(0, mem(rax, 20*32)) vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 8+4*32), ymm2) @@ -2855,45 +1894,45 @@ void bli_zgemm_sandybridge_asm_4x4 vmovddup(mem(rbx, 8+5*32), ymm3) vaddpd(ymm6, ymm14, ymm14) vaddpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vpermilpd(imm(0x5), ymm0, ymm0) vaddpd(ymm6, ymm13, ymm13) vaddpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vmulpd(ymm1, ymm5, ymm7) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm6, ymm12, ymm12) vaddpd(ymm7, ymm8, 
ymm8) - + vpermilpd(imm(0x5), ymm1, ymm1) vmulpd(ymm0, ymm2, ymm6) vmulpd(ymm0, ymm3, ymm7) vaddsubpd(ymm6, ymm15, ymm15) vaddsubpd(ymm7, ymm11, ymm11) - + vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 0+6*32), ymm2) vmulpd(ymm1, ymm3, ymm7) vmovddup(mem(rbx, 0+7*32), ymm3) vaddsubpd(ymm6, ymm14, ymm14) vaddsubpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 6*32), ymm0) vaddsubpd(ymm6, ymm13, ymm13) vaddsubpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddsubpd(ymm6, ymm12, ymm12) vaddsubpd(ymm7, ymm8, ymm8) - - + + // iteration 3 vmovapd(mem(rax, 7*32), ymm1) vmulpd(ymm0, ymm2, ymm6) @@ -2902,7 +1941,7 @@ void bli_zgemm_sandybridge_asm_4x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm6, ymm15, ymm15) vaddpd(ymm7, ymm11, ymm11) - + prefetch(0, mem(rax, 22*32)) vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 8+6*32), ymm2) @@ -2910,67 +1949,67 @@ void bli_zgemm_sandybridge_asm_4x4 vmovddup(mem(rbx, 8+7*32), ymm3) vaddpd(ymm6, ymm14, ymm14) vaddpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vpermilpd(imm(0x5), ymm0, ymm0) vaddpd(ymm6, ymm13, ymm13) vaddpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vmulpd(ymm1, ymm5, ymm7) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm6, ymm12, ymm12) vaddpd(ymm7, ymm8, ymm8) - + vpermilpd(imm(0x5), ymm1, ymm1) vmulpd(ymm0, ymm2, ymm6) vmulpd(ymm0, ymm3, ymm7) vaddsubpd(ymm6, ymm15, ymm15) vaddsubpd(ymm7, ymm11, ymm11) - + vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 0+8*32), ymm2) vmulpd(ymm1, ymm3, ymm7) vmovddup(mem(rbx, 0+9*32), ymm3) vaddsubpd(ymm6, ymm14, ymm14) vaddsubpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 8*32), ymm0) vaddsubpd(ymm6, ymm13, ymm13) vaddsubpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddsubpd(ymm6, ymm12, ymm12) vaddsubpd(ymm7, ymm8, ymm8) - - + + add(imm(4*4*16), rbx) // b += 4*4 (unroll x nr) add(imm(4*4*16), rax) // a += 4*4 (unroll x mr) - - + + dec(rsi) // i -= 1; jne(.ZLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.ZCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.ZPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. 
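As in the other datatypes, the k dimension is consumed in two phases: .ZLOOPKITER runs k/4 copies of the unrolled body and .ZLOOPKLEFT mops up the k%4 remainder one rank-1 update at a time. The same peeling structure on a trivial scalar example (hypothetical, for illustration only):

    #include <stdint.h>

    // Main loop unrolled by four plus an edge loop for the remainder,
    // the shape shared by .ZLOOPKITER/.ZLOOPKLEFT above.
    static double dot_unroll4( uint64_t k, const double* x, const double* y )
    {
        uint64_t k_iter = k / 4;   // full unrolled iterations
        uint64_t k_left = k % 4;   // 0..3 leftover updates
        double   acc    = 0.0;

        for ( uint64_t i = 0; i < k_iter; ++i, x += 4, y += 4 )
            acc += x[0]*y[0] + x[1]*y[1] + x[2]*y[2] + x[3]*y[3];

        for ( uint64_t i = 0; i < k_left; ++i, ++x, ++y )
            acc += x[0]*y[0];

        return acc;
    }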
- - + + label(.ZLOOPKLEFT) // EDGE LOOP - + // iteration 0 vmovapd(mem(rax, 1*32), ymm1) vmulpd(ymm0, ymm2, ymm6) @@ -2979,7 +2018,7 @@ void bli_zgemm_sandybridge_asm_4x4 vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm6, ymm15, ymm15) vaddpd(ymm7, ymm11, ymm11) - + prefetch(0, mem(rax, 16*32)) vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 8+0*32), ymm2) @@ -2987,166 +2026,166 @@ void bli_zgemm_sandybridge_asm_4x4 vmovddup(mem(rbx, 8+1*32), ymm3) vaddpd(ymm6, ymm14, ymm14) vaddpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vpermilpd(imm(0x5), ymm0, ymm0) vaddpd(ymm6, ymm13, ymm13) vaddpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vmulpd(ymm1, ymm5, ymm7) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vaddpd(ymm6, ymm12, ymm12) vaddpd(ymm7, ymm8, ymm8) - + vpermilpd(imm(0x5), ymm1, ymm1) vmulpd(ymm0, ymm2, ymm6) vmulpd(ymm0, ymm3, ymm7) vaddsubpd(ymm6, ymm15, ymm15) vaddsubpd(ymm7, ymm11, ymm11) - + vmulpd(ymm1, ymm2, ymm6) vmovddup(mem(rbx, 0+2*32), ymm2) vmulpd(ymm1, ymm3, ymm7) vmovddup(mem(rbx, 0+3*32), ymm3) vaddsubpd(ymm6, ymm14, ymm14) vaddsubpd(ymm7, ymm10, ymm10) - + vmulpd(ymm0, ymm4, ymm6) vmulpd(ymm0, ymm5, ymm7) vmovapd(mem(rax, 2*32), ymm0) vaddsubpd(ymm6, ymm13, ymm13) vaddsubpd(ymm7, ymm9, ymm9) - + vmulpd(ymm1, ymm4, ymm6) vmulpd(ymm1, ymm5, ymm7) vaddsubpd(ymm6, ymm12, ymm12) vaddsubpd(ymm7, ymm8, ymm8) - - + + add(imm(4*1*16), rax) // a += 4 (1 x mr) add(imm(4*1*16), rbx) // b += 4 (1 x nr) - - + + dec(rsi) // i -= 1; jne(.ZLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.ZPOSTACCUM) - + // ymm15: ymm13: ymm11: ymm9: // ( ab00 ( ab01 ( ab02 ( ab03 - // ab10 ab11 ab12 ab13 + // ab10 ab11 ab12 ab13 // ab21 ab20 ab23 ab22 // ab31 ) ab30 ) ab33 ) ab32 ) - + // ymm14: ymm12: ymm10: ymm8: // ( ab40 ( ab41 ( ab42 ( ab43 - // ab50 ab51 ab52 ab53 + // ab50 ab51 ab52 ab53 // ab61 ab60 ab63 ab62 // ab71 ) ab70 ) ab73 ) ab72 ) - - + + vmovapd(ymm15, ymm7) vperm2f128(imm(0x12), ymm15, ymm13, ymm15) vperm2f128(imm(0x30), ymm7, ymm13, ymm13) - + vmovapd(ymm11, ymm7) vperm2f128(imm(0x12), ymm11, ymm9, ymm11) vperm2f128(imm(0x30), ymm7, ymm9, ymm9) - + vmovapd(ymm14, ymm7) vperm2f128(imm(0x12), ymm14, ymm12, ymm14) vperm2f128(imm(0x30), ymm7, ymm12, ymm12) - + vmovapd(ymm10, ymm7) vperm2f128(imm(0x12), ymm10, ymm8, ymm10) vperm2f128(imm(0x30), ymm7, ymm8, ymm8) - - + + // ymm15: ymm13: ymm11: ymm9: // ( ab00 ( ab01 ( ab02 ( ab03 - // ab10 ab11 ab12 ab13 + // ab10 ab11 ab12 ab13 // ab20 ab21 ab22 ab23 // ab30 ) ab31 ) ab32 ) ab33 ) - + // ymm14: ymm12: ymm10: ymm8: // ( ab40 ( ab41 ( ab42 ( ab43 - // ab50 ab51 ab52 ab53 + // ab50 ab51 ab52 ab53 // ab60 ab61 ab62 ab63 // ab70 ) ab71 ) ab72 ) ab73 ) - - + + // scale by alpha - + mov(var(alpha), rax) // load address of alpha vbroadcastsd(mem(rax), ymm7) // load alpha_r and duplicate vbroadcastsd(mem(rax, 8), ymm6) // load alpha_i and duplicate - + vpermilpd(imm(0x5), ymm15, ymm3) vmulpd(ymm7, ymm15, ymm15) vmulpd(ymm6, ymm3, ymm3) vaddsubpd(ymm3, ymm15, ymm15) - + vpermilpd(imm(0x5), ymm14, ymm2) vmulpd(ymm7, ymm14, ymm14) vmulpd(ymm6, ymm2, ymm2) vaddsubpd(ymm2, ymm14, ymm14) - + vpermilpd(imm(0x5), ymm13, ymm1) vmulpd(ymm7, ymm13, ymm13) vmulpd(ymm6, ymm1, ymm1) vaddsubpd(ymm1, ymm13, ymm13) - + vpermilpd(imm(0x5), ymm12, ymm0) vmulpd(ymm7, ymm12, ymm12) vmulpd(ymm6, ymm0, ymm0) vaddsubpd(ymm0, ymm12, ymm12) - + vpermilpd(imm(0x5), ymm11, ymm3) vmulpd(ymm7, ymm11, ymm11) vmulpd(ymm6, ymm3, ymm3) vaddsubpd(ymm3, ymm11, ymm11) - + vpermilpd(imm(0x5), ymm10, ymm2) 
vmulpd(ymm7, ymm10, ymm10) vmulpd(ymm6, ymm2, ymm2) vaddsubpd(ymm2, ymm10, ymm10) - + vpermilpd(imm(0x5), ymm9, ymm1) vmulpd(ymm7, ymm9, ymm9) vmulpd(ymm6, ymm1, ymm1) vaddsubpd(ymm1, ymm9, ymm9) - + vpermilpd(imm(0x5), ymm8, ymm0) vmulpd(ymm7, ymm8, ymm8) vmulpd(ymm6, ymm0, ymm0) vaddsubpd(ymm0, ymm8, ymm8) - - - - + + + + mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rbx), ymm7) // load beta_r and duplicate vbroadcastsd(mem(rbx, 8), ymm6) // load beta_i and duplicate - - - - - - - + + + + + + + mov(var(rs_c), rsi) // load rs_c lea(mem(, rsi, 8), rsi) // rsi = rs_c * sizeof(dcomplex) lea(mem(, rsi, 2), rsi) lea(mem(rcx, rsi, 2), rdx) // load address of c + 2*rs_c; - - + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm7) // set ZF if beta_r == 0. sete(r8b) // r8b = ( ZF == 1 ? 1 : 0 ); @@ -3154,338 +2193,125 @@ void bli_zgemm_sandybridge_asm_4x4 sete(r9b) // r9b = ( ZF == 1 ? 1 : 0 ); and(r8b, r9b) // set ZF if r8b & r9b == 1. jne(.ZBETAZERO) // if ZF = 0, jump to beta == 0 case - - - cmp(imm(16), rsi) // set ZF if (16*cs_c) == 16. - jz(.ZCOLSTORED) // jump to column storage case - - - - label(.ZGENSTORED) - // update c00:c30 - - vmovupd(mem(rcx), xmm0) // load (c00,c10) into xmm0 - vmovupd(mem(rcx, rsi, 1), xmm2) // load (c20,c30) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm15, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rcx)) // store (c00,c10) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c20,c30) - add(rdi, rcx) // c += cs_c; - - // update c40:c70 - - vmovupd(mem(rdx), xmm0) // load (c40,c50) into xmm0 - vmovupd(mem(rdx, rsi, 1), xmm2) // load (c60,c70) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm14, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rdx)) // store (c40,c50) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c60,c70) - add(rdi, rdx) // c += cs_c; - - // update c01:c31 - - vmovupd(mem(rcx), xmm0) // load (c01,c11) into xmm0 - vmovupd(mem(rcx, rsi, 1), xmm2) // load (c21,c31) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm13, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rcx)) // store (c01,c11) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c21,c31) - add(rdi, rcx) // c += cs_c; - - // update c41:c71 - - vmovupd(mem(rdx), xmm0) // load (c41,c51) into xmm0 - vmovupd(mem(rdx, rsi, 1), xmm2) // load (c61,c71) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm12, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rdx)) // store (c41,c51) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c61,c71) - add(rdi, rdx) // c += 
cs_c; - - // update c02:c32 - - vmovupd(mem(rcx), xmm0) // load (c02,c12) into xmm0 - vmovupd(mem(rcx, rsi, 1), xmm2) // load (c22,c32) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm11, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rcx)) // store (c02,c12) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c22,c32) - add(rdi, rcx) // c += cs_c; - - // update c42:c72 - - vmovupd(mem(rdx), xmm0) // load (c42,c52) into xmm0 - vmovupd(mem(rdx, rsi, 1), xmm2) // load (c62,c72) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm10, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rdx)) // store (c42,c52) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c62,c72) - add(rdi, rdx) // c += cs_c; - - // update c03:c33 - - vmovupd(mem(rcx), xmm0) // load (c03,c13) into xmm0 - vmovupd(mem(rcx, rsi, 1), xmm2) // load (c23,c33) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm9, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rcx)) // store (c03,c13) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c23,c33) - add(rdi, rcx) // c += cs_c; - - // update c43:c73 - - vmovupd(mem(rdx), xmm0) // load (c43,c53) into xmm0 - vmovupd(mem(rdx, rsi, 1), xmm2) // load (c63,c73) into xmm2 - vinsertf128(imm(1), xmm2, ymm0, ymm0) // ymm0 := (ymm0[0:1],xmm2) - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm8, ymm0, ymm0) // add the gemm result to ymm0 - vextractf128(imm(1), ymm0, xmm2) // xmm2 := ymm0[2:3] - vmovupd(xmm0, mem(rdx)) // store (c43,c53) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c63,c73) - - - - jmp(.ZDONE) // jump to end. 
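The vxorpd/vucomisd/sete/and sequence that guards these update paths tests both components of beta for exact equality with zero, so the branch to .ZBETAZERO can skip reading C altogether; that matters because with beta == 0 the output may be uninitialized memory. The equivalent C predicate, with dcomplex_t as a stand-in for the BLIS dcomplex:

    #include <stdbool.h>

    typedef struct { double real, imag; } dcomplex_t;  // stand-in type

    // vucomisd sets ZF on equality, sete captures it per component, and
    // the two results are combined before the jump to the beta == 0 path.
    static inline bool beta_is_zero( const dcomplex_t* beta )
    {
        return ( beta->real == 0.0 ) && ( beta->imag == 0.0 );
    }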
- - - - label(.ZCOLSTORED) - // update c00:c30 - - vmovupd(mem(rcx), ymm0) // load c00:c30 into ymm0 - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm15, ymm0, ymm0) // add the gemm result to ymm0 - vmovupd(ymm0, mem(rcx)) // store c00:c30 - add(rdi, rcx) // c += cs_c; - - // update c40:c70 - - vmovupd(mem(rdx), ymm0) // load c40:c70 into ymm0 - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm14, ymm0, ymm0) // add the gemm result to ymm0 - vmovupd(ymm0, mem(rdx)) // store c40:c70 - add(rdi, rdx) // c += cs_c; - - // update c01:c31 - - vmovupd(mem(rcx), ymm0) // load c01:c31 into ymm0 - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm13, ymm0, ymm0) // add the gemm result to ymm0 - vmovupd(ymm0, mem(rcx)) // store c01:c31 - add(rdi, rcx) // c += cs_c; - - // update c41:c71 - - vmovupd(mem(rdx), ymm0) // load c41:c71 into ymm0 - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm12, ymm0, ymm0) // add the gemm result to ymm0 - vmovupd(ymm0, mem(rdx)) // store c41:c71 - add(rdi, rdx) // c += cs_c; - - // update c02:c32 - - vmovupd(mem(rcx), ymm0) // load c02:c32 into ymm0 - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm11, ymm0, ymm0) // add the gemm result to ymm0 - vmovupd(ymm0, mem(rcx)) // store c02:c32 - add(rdi, rcx) // c += cs_c; - - // update c42:c72 - - vmovupd(mem(rdx), ymm0) // load c42:c72 into ymm0 - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm10, ymm0, ymm0) // add the gemm result to ymm0 - vmovupd(ymm0, mem(rdx)) // store c42:c72 - add(rdi, rdx) // c += cs_c; - - // update c03:c33 - - vmovupd(mem(rcx), ymm0) // load c03:c33 into ymm0 - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm9, ymm0, ymm0) // add the gemm result to ymm0 - vmovupd(ymm0, mem(rcx)) // store c03:c33 - add(rdi, rcx) // c += cs_c; - - // update c43:c73 - - vmovupd(mem(rdx), ymm0) // load c43:c73 into ymm0 - vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta - vmulpd(ymm7, ymm0, ymm0) - vmulpd(ymm6, ymm2, ymm2) - vaddsubpd(ymm2, ymm0, ymm0) - vaddpd(ymm8, ymm0, ymm0) // add the gemm result to ymm0 - vmovupd(ymm0, mem(rdx)) // store c43:c73 - - - - jmp(.ZDONE) // jump to end. 
- - - + + // update c00:c30 + + vmovupd(mem(rcx), ymm0) // load c00:c30 into ymm0 + vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta + vmulpd(ymm7, ymm0, ymm0) + vmulpd(ymm6, ymm2, ymm2) + vaddsubpd(ymm2, ymm0, ymm0) + vaddpd(ymm15, ymm0, ymm0) // add the gemm result to ymm0 + vmovupd(ymm0, mem(rcx)) // store c00:c30 + add(rdi, rcx) // c += cs_c; + + // update c40:c70 + + vmovupd(mem(rdx), ymm0) // load c40:c70 into ymm0 + vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta + vmulpd(ymm7, ymm0, ymm0) + vmulpd(ymm6, ymm2, ymm2) + vaddsubpd(ymm2, ymm0, ymm0) + vaddpd(ymm14, ymm0, ymm0) // add the gemm result to ymm0 + vmovupd(ymm0, mem(rdx)) // store c40:c70 + add(rdi, rdx) // c += cs_c; + + // update c01:c31 + + vmovupd(mem(rcx), ymm0) // load c01:c31 into ymm0 + vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta + vmulpd(ymm7, ymm0, ymm0) + vmulpd(ymm6, ymm2, ymm2) + vaddsubpd(ymm2, ymm0, ymm0) + vaddpd(ymm13, ymm0, ymm0) // add the gemm result to ymm0 + vmovupd(ymm0, mem(rcx)) // store c01:c31 + add(rdi, rcx) // c += cs_c; + + // update c41:c71 + + vmovupd(mem(rdx), ymm0) // load c41:c71 into ymm0 + vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta + vmulpd(ymm7, ymm0, ymm0) + vmulpd(ymm6, ymm2, ymm2) + vaddsubpd(ymm2, ymm0, ymm0) + vaddpd(ymm12, ymm0, ymm0) // add the gemm result to ymm0 + vmovupd(ymm0, mem(rdx)) // store c41:c71 + add(rdi, rdx) // c += cs_c; + + // update c02:c32 + + vmovupd(mem(rcx), ymm0) // load c02:c32 into ymm0 + vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta + vmulpd(ymm7, ymm0, ymm0) + vmulpd(ymm6, ymm2, ymm2) + vaddsubpd(ymm2, ymm0, ymm0) + vaddpd(ymm11, ymm0, ymm0) // add the gemm result to ymm0 + vmovupd(ymm0, mem(rcx)) // store c02:c32 + add(rdi, rcx) // c += cs_c; + + // update c42:c72 + + vmovupd(mem(rdx), ymm0) // load c42:c72 into ymm0 + vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta + vmulpd(ymm7, ymm0, ymm0) + vmulpd(ymm6, ymm2, ymm2) + vaddsubpd(ymm2, ymm0, ymm0) + vaddpd(ymm10, ymm0, ymm0) // add the gemm result to ymm0 + vmovupd(ymm0, mem(rdx)) // store c42:c72 + add(rdi, rdx) // c += cs_c; + + // update c03:c33 + + vmovupd(mem(rcx), ymm0) // load c03:c33 into ymm0 + vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta + vmulpd(ymm7, ymm0, ymm0) + vmulpd(ymm6, ymm2, ymm2) + vaddsubpd(ymm2, ymm0, ymm0) + vaddpd(ymm9, ymm0, ymm0) // add the gemm result to ymm0 + vmovupd(ymm0, mem(rcx)) // store c03:c33 + add(rdi, rcx) // c += cs_c; + + // update c43:c73 + + vmovupd(mem(rdx), ymm0) // load c43:c73 into ymm0 + vpermilpd(imm(0x5), ymm0, ymm2) // scale ymm0 by beta + vmulpd(ymm7, ymm0, ymm0) + vmulpd(ymm6, ymm2, ymm2) + vaddsubpd(ymm2, ymm0, ymm0) + vaddpd(ymm8, ymm0, ymm0) // add the gemm result to ymm0 + vmovupd(ymm0, mem(rdx)) // store c43:c73 + + jmp(.ZDONE) // jump to end. + label(.ZBETAZERO) - - cmp(imm(16), rsi) // set ZF if (16*cs_c) == 16. 
- jz(.ZCOLSTORBZ) // jump to column storage case - - - - label(.ZGENSTORBZ) - // update c00:c30 - - vextractf128(imm(1), ymm15, xmm2) - vmovupd(xmm15, mem(rcx)) // store (c00,c10) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c20,c30) - add(rdi, rcx) // c += cs_c; - - // update c40:c70 - - vextractf128(imm(1), ymm14, xmm2) - vmovupd(xmm14, mem(rdx)) // store (c40,c50) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c60,c70) - add(rdi, rdx) // c += cs_c; - - // update c01:c31 - - vextractf128(imm(1), ymm13, xmm2) - vmovupd(xmm13, mem(rcx)) // store (c01,c11) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c21,c31) - add(rdi, rcx) // c += cs_c; - - // update c41:c71 - - vextractf128(imm(1), ymm12, xmm2) - vmovupd(xmm12, mem(rdx)) // store (c41,c51) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c61,c71) - add(rdi, rdx) // c += cs_c; - - // update c02:c32 - - vextractf128(imm(1), ymm11, xmm2) - vmovupd(xmm11, mem(rcx)) // store (c02,c12) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c22,c32) - add(rdi, rcx) // c += cs_c; - - // update c42:c72 - - vextractf128(imm(1), ymm10, xmm2) - vmovupd(xmm10, mem(rdx)) // store (c42,c52) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c62,c72) - add(rdi, rdx) // c += cs_c; - - // update c03:c33 - - vextractf128(imm(1), ymm9, xmm2) - vmovupd(xmm9, mem(rcx)) // store (c03,c13) - vmovupd(xmm2, mem(rcx, rsi, 1)) // store (c23,c33) - add(rdi, rcx) // c += cs_c; - - // update c43:c73 - - vextractf128(imm(1), ymm8, xmm2) - vmovupd(xmm8, mem(rdx)) // store (c43,c53) - vmovupd(xmm2, mem(rdx, rsi, 1)) // store (c63,c73) - - - - jmp(.ZDONE) // jump to end. - - - - label(.ZCOLSTORBZ) - - - vmovupd(ymm15, mem(rcx)) // store c00:c30 - add(rdi, rcx) // c += cs_c; - - vmovupd(ymm14, mem(rdx)) // store c40:c70 - add(rdi, rdx) // c += cs_c; - - vmovupd(ymm13, mem(rcx)) // store c01:c31 - add(rdi, rcx) // c += cs_c; - - vmovupd(ymm12, mem(rdx)) // store c41:c71 - add(rdi, rdx) // c += cs_c; - - vmovupd(ymm11, mem(rcx)) // store c02:c32 - add(rdi, rcx) // c += cs_c; - - vmovupd(ymm10, mem(rdx)) // store c42:c72 - add(rdi, rdx) // c += cs_c; - - vmovupd(ymm9, mem(rcx)) // store c03:c33 - add(rdi, rcx) // c += cs_c; - - vmovupd(ymm8, mem(rdx)) // store c43:c73 - - - - - + + vmovupd(ymm15, mem(rcx)) // store c00:c30 + add(rdi, rcx) // c += cs_c; + + vmovupd(ymm14, mem(rdx)) // store c40:c70 + add(rdi, rdx) // c += cs_c; + + vmovupd(ymm13, mem(rcx)) // store c01:c31 + add(rdi, rcx) // c += cs_c; + + vmovupd(ymm12, mem(rdx)) // store c41:c71 + add(rdi, rdx) // c += cs_c; + + vmovupd(ymm11, mem(rcx)) // store c02:c32 + add(rdi, rcx) // c += cs_c; + + vmovupd(ymm10, mem(rdx)) // store c42:c72 + add(rdi, rdx) // c += cs_c; + + vmovupd(ymm9, mem(rcx)) // store c03:c33 + add(rdi, rcx) // c += cs_c; + + vmovupd(ymm8, mem(rdx)) // store c43:c73 + label(.ZDONE) - + vzeroupper() - - vzeroupper() - + end_asm( : // output operands (none) @@ -3502,7 +2328,7 @@ void bli_zgemm_sandybridge_asm_4x4 [b_next] "m" (b_next), // 9 [a_next] "m" (a_next)*/ // 10 : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -3510,6 +2336,8 @@ void bli_zgemm_sandybridge_asm_4x4 "xmm12", "xmm13", "xmm14", "xmm15", "memory" ) + + GEMM_UKR_FLUSH_CT( z ); } diff --git a/kernels/sandybridge/3/bli_gemm_sandybridge_int_d8x4.c b/kernels/sandybridge/3/bli_gemm_sandybridge_int_d8x4.c index 6a1bb04f54..7d093afade 100644 --- 
a/kernels/sandybridge/3/bli_gemm_sandybridge_int_d8x4.c +++ b/kernels/sandybridge/3/bli_gemm_sandybridge_int_d8x4.c @@ -32,14 +32,18 @@ */ -#include +#include +#include +#include "bli_misc_macro_defs.h" #include "blis.h" #if 0 void bli_sgemm_sandybridge_int_8x8 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, float* restrict alpha, float* restrict a, float* restrict b, @@ -52,11 +56,11 @@ void bli_sgemm_sandybridge_int_8x8 } #endif - - void bli_dgemm_sandybridge_int_8x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, double* restrict alpha, double* restrict a, double* restrict b, @@ -66,19 +70,22 @@ void bli_dgemm_sandybridge_int_8x4 cntx_t* restrict cntx ) { + //void* a_next = bli_auxinfo_next_a( data ); void* b_next = bli_auxinfo_next_b( data ); // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_iter = k0 / 2; - uint64_t k_left = k0 % 2; + uint64_t k_iter = k / 2; + uint64_t k_left = k % 2; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; uint64_t i; - double *c00, *c01, *c02, *c03; - double *c40, *c41, *c42, *c43; + GEMM_UKR_SETUP_CT( d, 8, 4, false ); + + double *c00, *c01, *c02, *c03; + double *c40, *c41, *c42, *c43; // Quad registers. __m256d va0_3, va4_7; @@ -87,23 +94,20 @@ void bli_dgemm_sandybridge_int_8x4 __m256d vb; __m256d vB0; - __m256d va0_3b_0, va4_7b_0; - __m256d va0_3b_1, va4_7b_1; - __m256d va0_3b_2, va4_7b_2; - __m256d va0_3b_3, va4_7b_3; - - __m256d va0_3b0, va4_7b0; - __m256d va0_3b1, va4_7b1; - __m256d va0_3b2, va4_7b2; - __m256d va0_3b3, va4_7b3; + __m256d va0_3b_0, va4_7b_0; + __m256d va0_3b_1, va4_7b_1; + __m256d va0_3b_2, va4_7b_2; + __m256d va0_3b_3, va4_7b_3; + __m256d va0_3b0, va4_7b0; + __m256d va0_3b1, va4_7b1; + __m256d va0_3b2, va4_7b2; + __m256d va0_3b3, va4_7b3; - __m256d valpha, vbeta, vtmp; + __m256d valpha, vbeta, vtmp; __m256d vc0_3_0, vc0_3_1, vc0_3_2, vc0_3_3; __m256d vc4_7_0, vc4_7_1, vc4_7_2, vc4_7_3; - __m128d aa, bb; - __asm__ volatile( "prefetcht0 0(%0) \n\t" : :"r"(a) ); __asm__ volatile( "prefetcht2 0(%0) \n\t" : :"r"(b_next) ); __asm__ volatile( "prefetcht0 0(%0) \n\t" : :"r"(c) ); @@ -133,7 +137,7 @@ void bli_dgemm_sandybridge_int_8x4 // Load va4_7 va4_7 = _mm256_load_pd( a + 4 ); - // Load vb (b0,b1,b2,b3) + // Load vb (b0,b1,b2,b3) vb0 = _mm256_load_pd( b ); for( i = 0; i < k_iter; ++i ) @@ -166,7 +170,7 @@ void bli_dgemm_sandybridge_int_8x4 vb2 = _mm256_permute2f128_pd( vb1, vb1, 0x1 ); // Load vb (b0,b1,b2,b3) (Prefetch) - vB0 = _mm256_load_pd( b + 4 ); + vB0 = _mm256_load_pd( b + 4 ); vtmp = _mm256_mul_pd( va0_3, vb2 ); va0_3b_2 = _mm256_add_pd( va0_3b_2, vtmp ); @@ -186,7 +190,7 @@ void bli_dgemm_sandybridge_int_8x4 // Iteration 1. 
__asm__ volatile( "prefetcht0 512(%0) \n\t" : :"r"(a) ); - + // Load va0_3 (Next iteration) va0_3 = _mm256_load_pd( a + 16 ); @@ -218,7 +222,7 @@ void bli_dgemm_sandybridge_int_8x4 va4_7b_2 = _mm256_add_pd( va4_7b_2, vtmp ); // Load vb0(Next iteration) - vb0 = _mm256_load_pd( b + 8 ); + vb0 = _mm256_load_pd( b + 8 ); vtmp = _mm256_mul_pd( vA0_3, vb3 ); va0_3b_3 = _mm256_add_pd( va0_3b_3, vtmp ); @@ -240,7 +244,7 @@ void bli_dgemm_sandybridge_int_8x4 // Load va4_7 va4_7 = _mm256_load_pd( a + 4 ); - // Load vb (b0,b1,b2,b3) + // Load vb (b0,b1,b2,b3) vb = _mm256_load_pd( b ); vtmp = _mm256_mul_pd( va0_3, vb ); @@ -309,131 +313,73 @@ void bli_dgemm_sandybridge_int_8x4 va4_7b1 = _mm256_permute2f128_pd( vtmpa_4_7b_1, vtmpa_4_7b_3, 0x30 ); va4_7b2 = _mm256_permute2f128_pd( vtmpa_4_7b_3, vtmpa_4_7b_1, 0x30 ); - if( rs_c == 1 ) + __m128d vzero = _mm_setzero_pd( ); + + if( _mm_comieq_sd( _mm256_castpd256_pd128(vbeta), vzero ) ) { // Calculate address - c00 = ( c + 0*rs_c + 0*cs_c ); - // Load - //vc0_3_0 = _mm256_load_pd( c + 0*rs_c + 0*cs_c ); - vc0_3_0 = _mm256_load_pd( c00 ); + c00 = ( c + 0 + 0*cs_c ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va0_3b0); - // Scale by beta - vc0_3_0 = _mm256_mul_pd( vbeta, vc0_3_0 ); - // Add gemm result - vc0_3_0 = _mm256_add_pd( vc0_3_0, vtmp ); // Store back to memory - _mm256_store_pd( c00, vc0_3_0 ); - + _mm256_store_pd( c00, vtmp ); + // Calculate address - c40 = ( c + 4*rs_c + 0*cs_c ); - // Load - //vc4_7_0 = _mm256_load_pd( c + 4*rs_c + 0*cs_c ); - vc4_7_0 = _mm256_load_pd( c40 ); + c40 = ( c + 4 + 0*cs_c ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va4_7b0); - // Scale by beta - vc4_7_0 = _mm256_mul_pd( vbeta, vc4_7_0 ); - // Add gemm result - vc4_7_0 = _mm256_add_pd( vc4_7_0, vtmp ); // Store back to memory - _mm256_store_pd( c40, vc4_7_0 ); - + _mm256_store_pd( c40, vtmp ); + // Calculate address - c01 = ( c + 0*rs_c + 1*cs_c ); - // Load - //vc0_3_1 = _mm256_load_pd( c + 0*rs_c + 1*cs_c ); - vc0_3_1 = _mm256_load_pd( c01 ); + c01 = ( c + 0 + 1*cs_c ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va0_3b1); - // Scale by beta - vc0_3_1 = _mm256_mul_pd( vbeta, vc0_3_1 ); - // Add gemm result - vc0_3_1 = _mm256_add_pd( vc0_3_1, vtmp ); // Store back to memory - _mm256_store_pd( c01, vc0_3_1 ); - + _mm256_store_pd( c01, vtmp ); + // Calculate address - c41 = ( c + 4*rs_c + 1*cs_c ); - // Load - //vc4_7_1 = _mm256_load_pd( c + 4*rs_c + 1*cs_c ); - vc4_7_1 = _mm256_load_pd( c41 ); + c41 = ( c + 4 + 1*cs_c ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va4_7b1); - // Scale by beta - vc4_7_1 = _mm256_mul_pd( vbeta, vc4_7_1 ); - // Add gemm result - vc4_7_1 = _mm256_add_pd( vc4_7_1, vtmp ); // Store back to memory - _mm256_store_pd( c41, vc4_7_1 ); - + _mm256_store_pd( c41, vtmp ); + // Calculate address - c02 = ( c + 0*rs_c + 2*cs_c ); - // Load - //vc0_3_2 = _mm256_load_pd( c + 0*rs_c + 2*cs_c ); - vc0_3_2 = _mm256_load_pd( c02 ); + c02 = ( c + 0 + 2*cs_c ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va0_3b2); - // Scale by beta - vc0_3_2 = _mm256_mul_pd( vbeta, vc0_3_2 ); - // Add gemm result - vc0_3_2 = _mm256_add_pd( vc0_3_2, vtmp ); // Store back to memory - _mm256_store_pd( c02, vc0_3_2 ); - + _mm256_store_pd( c02, vtmp ); + // Calculate address - c42 = ( c + 4*rs_c + 2*cs_c ); - // Load - //vc4_7_2 = _mm256_load_pd( c + 4*rs_c + 2*cs_c ); - vc4_7_2 = _mm256_load_pd( c42 ); + c42 = ( c + 4 + 2*cs_c ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va4_7b2); - // Scale by beta - vc4_7_2 = _mm256_mul_pd( vbeta, vc4_7_2 ); - // 
Add gemm result - vc4_7_2 = _mm256_add_pd( vc4_7_2, vtmp ); // Store back to memory - _mm256_store_pd( c42, vc4_7_2 ); - + _mm256_store_pd( c42, vtmp ); + // Calculate address - c03 = ( c + 0*rs_c + 3*cs_c ); - // Load - //vc0_3_3 = _mm256_load_pd( c + 0*rs_c + 3*cs_c ); - vc0_3_3 = _mm256_load_pd( c03 ); + c03 = ( c + 0 + 3*cs_c ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va0_3b3); - // Scale by beta - vc0_3_3 = _mm256_mul_pd( vbeta, vc0_3_3 ); - // Add gemm result - vc0_3_3 = _mm256_add_pd( vc0_3_3, vtmp ); // Store back to memory - _mm256_store_pd( c03, vc0_3_3 ); - + _mm256_store_pd( c03, vtmp ); + // Calculate address - c43 = ( c + 4*rs_c + 3*cs_c ); - // Load - //vc4_7_3 = _mm256_load_pd( c + 4*rs_c + 3*cs_c ); - vc4_7_3 = _mm256_load_pd( c43 ); + c43 = ( c + 4 + 3*cs_c ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va4_7b3); - // Scale by beta - vc4_7_3 = _mm256_mul_pd( vbeta, vc4_7_3 ); - // Add gemm result - vc4_7_3 = _mm256_add_pd( vc4_7_3, vtmp ); // Store back to memory - _mm256_store_pd( c43, vc4_7_3 ); - + _mm256_store_pd( c43, vtmp ); } else { // Calculate address - c00 = ( c + 0*rs_c + 0*cs_c ); + c00 = ( c + 0 + 0*cs_c ); // Load - //vc0_3_0 = _mm256_load_pd( c + 0*rs_c + 0*cs_c ); - vc0_3_0 = _mm256_set_pd( *(c + 3*rs_c + 0*cs_c ), - *(c + 2*rs_c + 0*cs_c ), - *(c + 1*rs_c + 0*cs_c ), - *(c + 0*rs_c + 0*cs_c ) ); + //vc0_3_0 = _mm256_load_pd( c + 0 + 0*cs_c ); + vc0_3_0 = _mm256_load_pd( c00 ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va0_3b0); // Scale by beta @@ -441,24 +387,13 @@ void bli_dgemm_sandybridge_int_8x4 // Add gemm result vc0_3_0 = _mm256_add_pd( vc0_3_0, vtmp ); // Store back to memory - //_mm256_store_pd( c00, vc0_3_0 ); - - aa = _mm256_extractf128_pd( vc0_3_0, 0 ) ; - bb = _mm256_extractf128_pd( vc0_3_0, 1 ) ; - - _mm_storel_pd( c + 0*rs_c + 0*cs_c, aa ); - _mm_storeh_pd( c + 1*rs_c + 0*cs_c, aa ); - _mm_storel_pd( c + 2*rs_c + 0*cs_c, bb ); - _mm_storeh_pd( c + 3*rs_c + 0*cs_c, bb ); + _mm256_store_pd( c00, vc0_3_0 ); // Calculate address - c40 = ( c + 4*rs_c + 0*cs_c ); + c40 = ( c + 4 + 0*cs_c ); // Load - //vc4_7_0 = _mm256_load_pd( c + 4*rs_c + 0*cs_c ); - vc4_7_0 = _mm256_set_pd( *(c + 7*rs_c + 0*cs_c ), - *(c + 6*rs_c + 0*cs_c ), - *(c + 5*rs_c + 0*cs_c ), - *(c + 4*rs_c + 0*cs_c ) ); + //vc4_7_0 = _mm256_load_pd( c + 4 + 0*cs_c ); + vc4_7_0 = _mm256_load_pd( c40 ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va4_7b0); // Scale by beta @@ -466,24 +401,13 @@ void bli_dgemm_sandybridge_int_8x4 // Add gemm result vc4_7_0 = _mm256_add_pd( vc4_7_0, vtmp ); // Store back to memory - //_mm256_store_pd( c40, vc4_7_0 ); - - aa = _mm256_extractf128_pd( vc4_7_0, 0 ) ; - bb = _mm256_extractf128_pd( vc4_7_0, 1 ) ; - - _mm_storel_pd( c + 4*rs_c + 0*cs_c, aa ); - _mm_storeh_pd( c + 5*rs_c + 0*cs_c, aa ); - _mm_storel_pd( c + 6*rs_c + 0*cs_c, bb ); - _mm_storeh_pd( c + 7*rs_c + 0*cs_c, bb ); + _mm256_store_pd( c40, vc4_7_0 ); // Calculate address - c01 = ( c + 0*rs_c + 1*cs_c ); + c01 = ( c + 0 + 1*cs_c ); // Load - //vc0_3_1 = _mm256_load_pd( c + 0*rs_c + 1*cs_c ); - vc0_3_1 = _mm256_set_pd( *(c + 3*rs_c + 1*cs_c ), - *(c + 2*rs_c + 1*cs_c ), - *(c + 1*rs_c + 1*cs_c ), - *(c + 0*rs_c + 1*cs_c ) ); + //vc0_3_1 = _mm256_load_pd( c + 0 + 1*cs_c ); + vc0_3_1 = _mm256_load_pd( c01 ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va0_3b1); // Scale by beta @@ -491,24 +415,13 @@ void bli_dgemm_sandybridge_int_8x4 // Add gemm result vc0_3_1 = _mm256_add_pd( vc0_3_1, vtmp ); // Store back to memory - //_mm256_store_pd( c01, vc0_3_1 ); - - aa = 
_mm256_extractf128_pd( vc0_3_1, 0 ) ; - bb = _mm256_extractf128_pd( vc0_3_1, 1 ) ; - - _mm_storel_pd( c + 0*rs_c + 1*cs_c, aa ); - _mm_storeh_pd( c + 1*rs_c + 1*cs_c, aa ); - _mm_storel_pd( c + 2*rs_c + 1*cs_c, bb ); - _mm_storeh_pd( c + 3*rs_c + 1*cs_c, bb ); + _mm256_store_pd( c01, vc0_3_1 ); // Calculate address - c41 = ( c + 4*rs_c + 1*cs_c ); + c41 = ( c + 4 + 1*cs_c ); // Load - //vc4_7_1 = _mm256_load_pd( c + 4*rs_c + 1*cs_c ); - vc4_7_1 = _mm256_set_pd( *(c + 7*rs_c + 1*cs_c ), - *(c + 6*rs_c + 1*cs_c ), - *(c + 5*rs_c + 1*cs_c ), - *(c + 4*rs_c + 1*cs_c ) ); + //vc4_7_1 = _mm256_load_pd( c + 4 + 1*cs_c ); + vc4_7_1 = _mm256_load_pd( c41 ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va4_7b1); // Scale by beta @@ -516,24 +429,13 @@ void bli_dgemm_sandybridge_int_8x4 // Add gemm result vc4_7_1 = _mm256_add_pd( vc4_7_1, vtmp ); // Store back to memory - //_mm256_store_pd( c41, vc4_7_1 ); - - aa = _mm256_extractf128_pd( vc4_7_1, 0 ) ; - bb = _mm256_extractf128_pd( vc4_7_1, 1 ) ; - - _mm_storel_pd( c + 4*rs_c + 1*cs_c, aa ); - _mm_storeh_pd( c + 5*rs_c + 1*cs_c, aa ); - _mm_storel_pd( c + 6*rs_c + 1*cs_c, bb ); - _mm_storeh_pd( c + 7*rs_c + 1*cs_c, bb ); + _mm256_store_pd( c41, vc4_7_1 ); // Calculate address - c02 = ( c + 0*rs_c + 2*cs_c ); + c02 = ( c + 0 + 2*cs_c ); // Load - //vc0_3_2 = _mm256_load_pd( c + 0*rs_c + 2*cs_c ); - vc0_3_2 = _mm256_set_pd( *(c + 3*rs_c + 2*cs_c ), - *(c + 2*rs_c + 2*cs_c ), - *(c + 1*rs_c + 2*cs_c ), - *(c + 0*rs_c + 2*cs_c ) ); + //vc0_3_2 = _mm256_load_pd( c + 0 + 2*cs_c ); + vc0_3_2 = _mm256_load_pd( c02 ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va0_3b2); // Scale by beta @@ -541,24 +443,13 @@ void bli_dgemm_sandybridge_int_8x4 // Add gemm result vc0_3_2 = _mm256_add_pd( vc0_3_2, vtmp ); // Store back to memory - //_mm256_store_pd( c02, vc0_3_2 ); - - aa = _mm256_extractf128_pd( vc0_3_2, 0 ) ; - bb = _mm256_extractf128_pd( vc0_3_2, 1 ) ; - - _mm_storel_pd( c + 0*rs_c + 2*cs_c, aa ); - _mm_storeh_pd( c + 1*rs_c + 2*cs_c, aa ); - _mm_storel_pd( c + 2*rs_c + 2*cs_c, bb ); - _mm_storeh_pd( c + 3*rs_c + 2*cs_c, bb ); + _mm256_store_pd( c02, vc0_3_2 ); // Calculate address - c42 = ( c + 4*rs_c + 2*cs_c ); + c42 = ( c + 4 + 2*cs_c ); // Load - //vc4_7_2 = _mm256_load_pd( c + 4*rs_c + 2*cs_c ); - vc4_7_2 = _mm256_set_pd( *(c + 7*rs_c + 2*cs_c ), - *(c + 6*rs_c + 2*cs_c ), - *(c + 5*rs_c + 2*cs_c ), - *(c + 4*rs_c + 2*cs_c ) ); + //vc4_7_2 = _mm256_load_pd( c + 4 + 2*cs_c ); + vc4_7_2 = _mm256_load_pd( c42 ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va4_7b2); // Scale by beta @@ -566,24 +457,13 @@ void bli_dgemm_sandybridge_int_8x4 // Add gemm result vc4_7_2 = _mm256_add_pd( vc4_7_2, vtmp ); // Store back to memory - //_mm256_store_pd( c42, vc4_7_2 ); - - aa = _mm256_extractf128_pd( vc4_7_2, 0 ) ; - bb = _mm256_extractf128_pd( vc4_7_2, 1 ) ; - - _mm_storel_pd( c + 4*rs_c + 2*cs_c, aa ); - _mm_storeh_pd( c + 5*rs_c + 2*cs_c, aa ); - _mm_storel_pd( c + 6*rs_c + 2*cs_c, bb ); - _mm_storeh_pd( c + 7*rs_c + 2*cs_c, bb ); + _mm256_store_pd( c42, vc4_7_2 ); // Calculate address - c03 = ( c + 0*rs_c + 3*cs_c ); + c03 = ( c + 0 + 3*cs_c ); // Load - //vc0_3_3 = _mm256_load_pd( c + 0*rs_c + 3*cs_c ); - vc0_3_3 = _mm256_set_pd( *(c + 3*rs_c + 3*cs_c ), - *(c + 2*rs_c + 3*cs_c ), - *(c + 1*rs_c + 3*cs_c ), - *(c + 0*rs_c + 3*cs_c ) ); + //vc0_3_3 = _mm256_load_pd( c + 0 + 3*cs_c ); + vc0_3_3 = _mm256_load_pd( c03 ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va0_3b3); // Scale by beta @@ -591,24 +471,13 @@ void bli_dgemm_sandybridge_int_8x4 // 
Add gemm result vc0_3_3 = _mm256_add_pd( vc0_3_3, vtmp ); // Store back to memory - //_mm256_store_pd( c03, vc0_3_3 ); - - aa = _mm256_extractf128_pd( vc0_3_3, 0 ) ; - bb = _mm256_extractf128_pd( vc0_3_3, 1 ) ; - - _mm_storel_pd( c + 0*rs_c + 3*cs_c, aa ); - _mm_storeh_pd( c + 1*rs_c + 3*cs_c, aa ); - _mm_storel_pd( c + 2*rs_c + 3*cs_c, bb ); - _mm_storeh_pd( c + 3*rs_c + 3*cs_c, bb ); + _mm256_store_pd( c03, vc0_3_3 ); // Calculate address - c43 = ( c + 4*rs_c + 3*cs_c ); + c43 = ( c + 4 + 3*cs_c ); // Load - //vc4_7_3 = _mm256_load_pd( c + 4*rs_c + 3*cs_c ); - vc4_7_3 = _mm256_set_pd( *(c + 7*rs_c + 3*cs_c ), - *(c + 6*rs_c + 3*cs_c ), - *(c + 5*rs_c + 3*cs_c ), - *(c + 4*rs_c + 3*cs_c ) ); + //vc4_7_3 = _mm256_load_pd( c + 4 + 3*cs_c ); + vc4_7_3 = _mm256_load_pd( c43 ); // Scale by alpha vtmp = _mm256_mul_pd( valpha, va4_7b3); // Scale by beta @@ -616,17 +485,10 @@ void bli_dgemm_sandybridge_int_8x4 // Add gemm result vc4_7_3 = _mm256_add_pd( vc4_7_3, vtmp ); // Store back to memory - //_mm256_store_pd( c43, vc4_7_3 ); - - aa = _mm256_extractf128_pd( vc4_7_3, 0 ) ; - bb = _mm256_extractf128_pd( vc4_7_3, 1 ) ; - - _mm_storel_pd( c + 4*rs_c + 3*cs_c, aa ); - _mm_storeh_pd( c + 5*rs_c + 3*cs_c, aa ); - _mm_storel_pd( c + 6*rs_c + 3*cs_c, bb ); - _mm_storeh_pd( c + 7*rs_c + 3*cs_c, bb ); + _mm256_store_pd( c43, vc4_7_3 ); } + GEMM_UKR_FLUSH_CT( d ); } @@ -634,7 +496,9 @@ void bli_dgemm_sandybridge_int_8x4 #if 0 void bli_cgemm_sandybridge_int_8x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, scomplex* restrict alpha, scomplex* restrict a, scomplex* restrict b, @@ -652,7 +516,9 @@ void bli_cgemm_sandybridge_int_8x4 #if 0 void bli_zgemm_sandybridge_int_4x4 ( - dim_t k0, + dim_t m, + dim_t n, + dim_t k, dcomplex* restrict alpha, dcomplex* restrict a, dcomplex* restrict b, diff --git a/ref_kernels/3/bb/bli_gemmbb_ref.c b/ref_kernels/3/bb/bli_gemmbb_ref.c index b45718d454..4c75c064ce 100644 --- a/ref_kernels/3/bb/bli_gemmbb_ref.c +++ b/ref_kernels/3/bb/bli_gemmbb_ref.c @@ -42,6 +42,8 @@ \ void PASTEMAC3(ch,opname,arch,suf) \ ( \ + dim_t m, \ + dim_t n, \ dim_t k, \ ctype* restrict alpha, \ ctype* restrict a, \ @@ -59,9 +61,6 @@ void PASTEMAC3(ch,opname,arch,suf) \ \ const inc_t packmr = bli_cntx_get_blksz_max_dt( dt, BLIS_MR, cntx ); \ const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); \ -\ - const dim_t m = mr; \ - const dim_t n = nr; \ \ const inc_t cs_a = packmr; \ \ diff --git a/ref_kernels/3/bb/bli_gemmtrsmbb_ref.c b/ref_kernels/3/bb/bli_gemmtrsmbb_ref.c index 681b740b52..158f2da3f7 100644 --- a/ref_kernels/3/bb/bli_gemmtrsmbb_ref.c +++ b/ref_kernels/3/bb/bli_gemmtrsmbb_ref.c @@ -87,6 +87,8 @@ PASTEMAC(d,fprintm)( stdout, "gemmtrsm_ukr: b11", mr, 2*nr, \ /* upper: b11 = alpha * b11 - a12 * b21; */ \ gemm_ukr \ ( \ + mr, \ + nr, \ k, \ minus_one, \ a1x, \ diff --git a/ref_kernels/3/bli_gemm_ref.c b/ref_kernels/3/bli_gemm_ref.c index 931fe994b3..51ff9df4bd 100644 --- a/ref_kernels/3/bli_gemm_ref.c +++ b/ref_kernels/3/bli_gemm_ref.c @@ -44,6 +44,8 @@ \ void PASTEMAC3(ch,opname,arch,suf) \ ( \ + dim_t m, \ + dim_t n, \ dim_t k, \ ctype* restrict alpha, \ ctype* restrict a, \ @@ -107,8 +109,8 @@ void PASTEMAC3(ch,opname,arch,suf) \ \ if ( PASTEMAC(ch,eq0)( *beta ) ) \ { \ - for ( dim_t i = 0; i < mr; ++i ) \ - for ( dim_t j = 0; j < nr; ++j ) \ + for ( dim_t i = 0; i < m; ++i ) \ + for ( dim_t j = 0; j < n; ++j ) \ PASTEMAC(ch,copys) \ ( \ ab[ i*rs_ab + j*cs_ab ], \ @@ -117,8 +119,8 @@ void PASTEMAC3(ch,opname,arch,suf) \ } \ else \ { \ - for ( dim_t i = 0; i < mr; ++i ) \ - 
for ( dim_t j = 0; j < nr; ++j ) \ + for ( dim_t i = 0; i < m; ++i ) \ + for ( dim_t j = 0; j < n; ++j ) \ PASTEMAC(ch,xpbys) \ ( \ ab[ i*rs_ab + j*cs_ab ], \ @@ -133,8 +135,8 @@ void PASTEMAC3(ch,opname,arch,suf) \ \ if ( PASTEMAC(ch,eq0)( *beta ) ) \ { \ - for ( dim_t j = 0; j < nr; ++j ) \ - for ( dim_t i = 0; i < mr; ++i ) \ + for ( dim_t j = 0; j < n; ++j ) \ + for ( dim_t i = 0; i < m; ++i ) \ PASTEMAC(ch,copys) \ ( \ ab[ i*rs_ab + j*cs_ab ], \ @@ -143,8 +145,8 @@ void PASTEMAC3(ch,opname,arch,suf) \ } \ else \ { \ - for ( dim_t j = 0; j < nr; ++j ) \ - for ( dim_t i = 0; i < mr; ++i ) \ + for ( dim_t j = 0; j < n; ++j ) \ + for ( dim_t i = 0; i < m; ++i ) \ PASTEMAC(ch,xpbys) \ ( \ ab[ i*rs_ab + j*cs_ab ], \ @@ -171,6 +173,8 @@ GENTFUNC( dcomplex, z, gemm, BLIS_CNAME_INFIX, BLIS_REF_SUFFIX, 4, 4 ) \ void PASTEMAC3(ch,opname,arch,suf) \ ( \ + dim_t m, \ + dim_t n, \ dim_t k, \ ctype* restrict alpha, \ ctype* restrict a, \ @@ -188,9 +192,6 @@ void PASTEMAC3(ch,opname,arch,suf) \ \ const inc_t packmr = bli_cntx_get_blksz_max_dt( dt, BLIS_MR, cntx ); \ const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); \ -\ - const dim_t m = mr; \ - const dim_t n = nr; \ \ const inc_t cs_a = packmr; \ \ diff --git a/ref_kernels/3/bli_gemmtrsm_ref.c b/ref_kernels/3/bli_gemmtrsm_ref.c index 2b756963e4..14a53d0868 100644 --- a/ref_kernels/3/bli_gemmtrsm_ref.c +++ b/ref_kernels/3/bli_gemmtrsm_ref.c @@ -52,6 +52,8 @@ void PASTEMAC3(ch,opname,arch,suf) \ { \ const num_t dt = PASTEMAC(ch,type); \ \ + const inc_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \ + const inc_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \ const inc_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); \ \ const inc_t rs_b = packnr; \ @@ -68,6 +70,8 @@ void PASTEMAC3(ch,opname,arch,suf) \ /* upper: b11 = alpha * b11 - a12 * b21; */ \ gemm_ukr \ ( \ + mr, \ + nr, \ k, \ minus_one, \ a1x, \ diff --git a/ref_kernels/ind/bli_gemm1m_ref.c b/ref_kernels/ind/bli_gemm1m_ref.c index 6d2464de94..cd532bcd08 100644 --- a/ref_kernels/ind/bli_gemm1m_ref.c +++ b/ref_kernels/ind/bli_gemm1m_ref.c @@ -39,6 +39,8 @@ \ void PASTEMAC3(ch,opname,arch,suf) \ ( \ + dim_t m, \ + dim_t n, \ dim_t k, \ ctype* restrict alpha, \ ctype* restrict a, \ @@ -59,6 +61,9 @@ void PASTEMAC3(ch,opname,arch,suf) \ \ const dim_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \ const dim_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \ +\ + const dim_t mr_r = bli_cntx_get_blksz_def_dt( dt_r, BLIS_MR, cntx ); \ + const dim_t nr_r = bli_cntx_get_blksz_def_dt( dt_r, BLIS_NR, cntx ); \ \ const dim_t k2 = 2 * k; \ \ @@ -118,6 +123,11 @@ void PASTEMAC3(ch,opname,arch,suf) \ else if ( bli_is_gen_stored( rs_c, cs_c ) ) using_ct = TRUE; \ else using_ct = FALSE; \ \ +\ + /* If we are not packing a full micro-tile, then we must write to + ct and then accumulate to c afterwards. */ \ + if ( mr != m || nr != n ) using_ct = TRUE; \ +\ \ if ( using_ct ) \ { \ @@ -149,6 +159,8 @@ void PASTEMAC3(ch,opname,arch,suf) \ /* c = beta * c + alpha_r * a * b; */ \ rgemm_ukr \ ( \ + mr_r, \ + nr_r, \ k2, \ alpha_r, \ a_r, \ @@ -164,8 +176,8 @@ void PASTEMAC3(ch,opname,arch,suf) \ /* Accumulate the final result in ct back to c. 
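   (Three cases follow: beta == 1 accumulates with adds; beta == 0 \
   overwrites with copys, so the possibly uninitialized contents of c \
   are never read; any other beta scales and accumulates with xpbys.) \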
*/ \ if ( PASTEMAC(ch,eq1)( *beta ) ) \ { \ - for ( j = 0; j < nr; ++j ) \ - for ( i = 0; i < mr; ++i ) \ + for ( j = 0; j < n; ++j ) \ + for ( i = 0; i < m; ++i ) \ { \ PASTEMAC(ch,adds)( *(ct + i*rs_ct + j*cs_ct), \ *(c + i*rs_c + j*cs_c ) ); \ @@ -173,8 +185,8 @@ void PASTEMAC3(ch,opname,arch,suf) \ } \ else if ( PASTEMAC(ch,eq0)( *beta ) ) \ { \ - for ( j = 0; j < nr; ++j ) \ - for ( i = 0; i < mr; ++i ) \ + for ( j = 0; j < n; ++j ) \ + for ( i = 0; i < m; ++i ) \ { \ PASTEMAC(ch,copys)( *(ct + i*rs_ct + j*cs_ct), \ *(c + i*rs_c + j*cs_c ) ); \ @@ -182,8 +194,8 @@ void PASTEMAC3(ch,opname,arch,suf) \ } \ else \ { \ - for ( j = 0; j < nr; ++j ) \ - for ( i = 0; i < mr; ++i ) \ + for ( j = 0; j < n; ++j ) \ + for ( i = 0; i < m; ++i ) \ { \ PASTEMAC(ch,xpbys)( *(ct + i*rs_ct + j*cs_ct), \ *beta, \ @@ -215,6 +227,8 @@ void PASTEMAC3(ch,opname,arch,suf) \ /* c = beta * c + alpha_r * a * b; */ \ rgemm_ukr \ ( \ + mr_r, \ + nr_r, \ k2, \ alpha_r, \ a_r, \ diff --git a/ref_kernels/ind/bli_gemmtrsm1m_ref.c b/ref_kernels/ind/bli_gemmtrsm1m_ref.c index 5cfaee9ec6..e35a1bf13d 100644 --- a/ref_kernels/ind/bli_gemmtrsm1m_ref.c +++ b/ref_kernels/ind/bli_gemmtrsm1m_ref.c @@ -78,7 +78,7 @@ void PASTEMAC3(ch,opname,arch,suf) \ \ const dim_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); \ \ - const pack_t schema_b = bli_auxinfo_schema_b( data ); \ + const pack_t schema_b = bli_cntx_schema_b_panel( cntx ); \ \ const dim_t k2 = 2 * k; \ \ @@ -153,6 +153,8 @@ void PASTEMAC3(ch,opname,arch,suf) \ upper: bt = -1.0 * a12 * b21; */ \ rgemm_ukr \ ( \ + mr_r, \ + nr_r, \ k2, \ minus_one_r, \ a1x_r, \ diff --git a/test/obj_t_makeover/complex_math.hpp b/test/obj_t_makeover/complex_math.hpp new file mode 100644 index 0000000000..9c68e730aa --- /dev/null +++ b/test/obj_t_makeover/complex_math.hpp @@ -0,0 +1,267 @@ +#include +#include +#include + +#include "blis.h" + +template +struct is_complex : std::false_type {}; + +template <> +struct is_complex : std::true_type {}; + +template <> +struct is_complex : std::true_type {}; + +template +struct is_real : std::integral_constant::value> {}; + +template struct make_complex; + +template <> struct make_complex { using type = scomplex; }; +template <> struct make_complex { using type = dcomplex; }; +template <> struct make_complex { using type = scomplex; }; +template <> struct make_complex { using type = dcomplex; }; + +template +using make_complex_t = typename make_complex::type; + +template struct make_real; + +template <> struct make_real { using type = float; }; +template <> struct make_real { using type = double; }; +template <> struct make_real { using type = float; }; +template <> struct make_real { using type = double; }; + +template +using make_real_t = typename make_real::type; + +template +struct make_complex_if : std::conditional,make_real_t> {}; + +template +using make_complex_if_t = typename make_complex_if::type; + +template +struct real_imag_part +{ + real_imag_part& operator=(T) { return *this; } + + operator T() const { return T(); } +}; + +template +std::enable_if_t::type>::value,T&> real(T& x) { return x; } + +template +std::enable_if_t::value,real_imag_part> imag(T x) { return {}; } + +inline float& real(scomplex& x) { return x.real; } + +inline float& imag(scomplex& x) { return x.imag; } + +inline double& real(dcomplex& x) { return x.real; } + +inline double& imag(dcomplex& x) { return x.imag; } + +inline const float& real(const scomplex& x) { return x.real; } + +inline const float& imag(const scomplex& x) { return x.imag; } + +inline 
const double& real(const dcomplex& x) { return x.real; } + +inline const double& imag(const dcomplex& x) { return x.imag; } + +template +std::enable_if_t::value,T> conj(T x) { return x; } + +template +std::enable_if_t::value,T> conj(const T& x) { return {x.real, -x.imag}; } + +template +struct convert_impl; + +template +struct convert_impl::value && is_real::value>> +{ + void operator()(T x, U& y) const { y = x; } +}; + +template +struct convert_impl::value && is_complex::value>> +{ + void operator()(T x, U& y) const { y.real = x; y.imag = 0; } +}; + +template +struct convert_impl::value && is_real::value>> +{ + void operator()(T x, U& y) const { y = x.real; } +}; + +template +struct convert_impl::value && is_complex::value>> +{ + void operator()(T x, U& y) const { y.real = x.real; y.imag = x.imag; } +}; + +template +U convert(T x) +{ + U y; + convert_impl{}(x,y); + return y; +} + +template +auto convert_prec(T x) -> make_complex_if_t::value> +{ + return convert::value>>(x); +} + +#define COMPLEX_MATH_OPS(rtype, ctype) \ +\ +inline bool operator==(rtype x, ctype y) \ +{ \ + return x == y.real && y.imag == 0; \ +} \ +\ +inline bool operator==(ctype x, rtype y) \ +{ \ + return y == x.real && x.imag == 0; \ +} \ +\ +inline bool operator==(ctype x, ctype y) \ +{ \ + return x.real == y.real && \ + x.imag == y.imag; \ + } \ + \ +inline ctype operator-(ctype x) \ +{ \ + return {-x.real, -x.imag}; \ +} \ +\ +inline ctype operator+(rtype x, ctype y) \ +{ \ + return {x+y.real, y.imag}; \ +} \ +\ +inline ctype operator+(ctype x, rtype y) \ +{ \ + return {y+x.real, x.imag}; \ +} \ +\ +inline ctype operator+(ctype x, ctype y) \ +{ \ + return {x.real+y.real, x.imag+y.imag}; \ +} \ +\ +inline ctype operator-(rtype x, ctype y) \ +{ \ + return {x-y.real, -y.imag}; \ +} \ +\ +inline ctype operator-(ctype x, rtype y) \ +{ \ + return {x.real-y, x.imag}; \ +} \ +\ +inline ctype operator-(ctype x, ctype y) \ +{ \ + return {x.real-y.real, x.imag-y.imag}; \ +} \ +\ +inline ctype operator*(rtype x, ctype y) \ +{ \ + return {x*y.real, x*y.imag}; \ +} \ +\ +inline ctype operator*(ctype x, rtype y) \ +{ \ + return {y*x.real, y*x.imag}; \ +} \ +\ +inline ctype operator*(ctype x, ctype y) \ +{ \ + return {x.real*y.real - x.imag*y.imag, \ + x.real*y.imag + x.imag*y.real}; \ +} \ +\ +inline ctype operator/(rtype x, ctype y) \ +{ \ + auto scale = std::max(std::abs(y.real), std::abs(y.imag)); \ + auto n = std::ilogb(scale); \ + auto yrs = std::scalbn(y.real, -n); \ + auto yis = std::scalbn(y.imag, -n); \ + auto denom = y.real*yrs + y.imag*yis; \ + return {x*yrs/denom, -x*yis/denom}; \ +} \ +\ +inline ctype operator/(ctype x, rtype y) \ +{ \ + return {x.real/y, x.imag/y}; \ +} \ +\ +inline ctype operator/(ctype x, ctype y) \ +{ \ + auto scale = std::max(std::abs(y.real), std::abs(y.imag)); \ + auto n = std::ilogb(scale); \ + auto yrs = std::scalbn(y.real, -n); \ + auto yis = std::scalbn(y.imag, -n); \ + auto denom = y.real*yrs + y.imag*yis; \ + return {(x.real*yrs + x.imag*yis)/denom, \ + (x.imag*yrs - x.real*yis)/denom}; \ +} \ +\ +inline ctype& operator+=(ctype& x, rtype y) \ +{ \ + x.real += y; \ + return x; \ +} \ +\ +inline ctype& operator+=(ctype& x, ctype y) \ +{ \ + x.real += y.real; x.imag += y.imag; \ + return x; \ +} \ +\ +inline ctype& operator-=(ctype& x, rtype y) \ +{ \ + x.real -= y; \ + return x; \ +} \ +\ +inline ctype& operator-=(ctype& x, ctype y) \ +{ \ + x.real -= y.real; x.imag -= y.imag; \ + return x; \ +} \ +\ +inline ctype& operator*=(ctype& x, rtype y) \ +{ \ + x.real *= y; x.imag *= y; \ + return 
x; \ +} \ +\ +inline ctype& operator*=(ctype& x, ctype y) \ +{ \ + x = x * y; \ + return x; \ +} \ +\ +inline ctype& operator/=(ctype& x, rtype y) \ +{ \ + x.real /= y; x.imag /= y; \ + return x; \ +} \ +\ +inline ctype& operator/=(ctype& x, ctype y) \ +{ \ + x = x / y; \ + return x; \ +} + +COMPLEX_MATH_OPS(float, scomplex); +COMPLEX_MATH_OPS(double, dcomplex); + diff --git a/test/obj_t_makeover/syrk_diagonal_example.c b/test/obj_t_makeover/syrk_diagonal_example.c new file mode 100644 index 0000000000..7c604c8612 --- /dev/null +++ b/test/obj_t_makeover/syrk_diagonal_example.c @@ -0,0 +1,186 @@ +#include "syrk_diagonal_ref.h" + +/* + * Structure which includes all additional information beyond what is + * already stored in the obj_t structure. + * + * This structure is **read-only** during the operation! + */ +typedef struct packm_diag_params_t +{ + packm_blk_var1_params_t super; + void* d; + inc_t incd; +} packm_diag_params_t; + +/* + * Declare the pack kernel type and set up and array of + * packing kernels, one for each data type. + */ +#undef GENTFUNC +#define GENTFUNC(ctype,ch,op) \ +void PASTEMAC(ch,op) \ + ( \ + struc_t struca, \ + diag_t diaga, \ + uplo_t uploa, \ + conj_t conja, \ + pack_t schema, \ + bool invdiag, \ + dim_t panel_dim, \ + dim_t panel_len, \ + dim_t panel_dim_max, \ + dim_t panel_len_max, \ + dim_t panel_dim_off, \ + dim_t panel_len_off, \ + void* restrict kappa, \ + void* restrict a, inc_t inca, inc_t lda, \ + void* restrict p, inc_t ldp, \ + inc_t is_p, \ + cntx_t* cntx, \ + void* params \ + ) \ +{ \ + packm_diag_params_t* params_cast = params; \ + ctype* restrict a_cast = a; \ + ctype* restrict p_cast = p; \ + ctype* restrict d_cast = params_cast->d; \ + inc_t incd = params_cast->incd; \ + ctype kappa_cast = *( ctype* )kappa; \ +\ + if ( schema != BLIS_PACKED_ROW_PANELS && \ + schema != BLIS_PACKED_COL_PANELS ) \ + bli_abort(); \ +\ + /* Apply the offset */ \ + d_cast += panel_len_off * incd; \ +\ + if ( conja ) \ + { \ + for ( dim_t j = 0; j < panel_len; j++ ) \ + { \ + ctype kappa_d; \ + PASTEMAC(ch,scal2s)( kappa_cast, d_cast[ j*incd ], kappa_d ); \ +\ + for (dim_t i = 0;i < panel_dim;i++) \ + PASTEMAC(ch,scal2js)( kappa_d, a_cast[ i*inca + j*lda ], p_cast[ i + j*ldp ] ); \ +\ + for (dim_t i = panel_dim;i < panel_dim_max;i++) \ + PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \ + } \ + } \ + else \ + { \ + for ( dim_t j = 0; j < panel_len; j++ ) \ + { \ + ctype kappa_d; \ + PASTEMAC(ch,scal2s)( kappa_cast, d_cast[ j*incd ], kappa_d ); \ +\ + for (dim_t i = 0;i < panel_dim;i++) \ + PASTEMAC(ch,scal2s)( kappa_d, a_cast[ i*inca + j*lda ], p_cast[ i + j*ldp ] ); \ +\ + for (dim_t i = panel_dim;i < panel_dim_max;i++) \ + PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \ + } \ + } \ +\ + for (dim_t j = panel_len;j < panel_len_max;j++) \ + for (dim_t i = 0;i < panel_dim_max;i++) \ + PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \ +} + +INSERT_GENTFUNC_BASIC0(packm_diag_ukr); + +static packm_ker_vft GENARRAY( packm_diag_ukrs, packm_diag_ukr ); + +/* + * Modify the object A to include information about the diagonal D, + * and imbue it with special function pointers which will take care + * of the actual work of forming (D * A^T) + */ +void attach_diagonal_factor( packm_diag_params_t* params, obj_t* d, obj_t* a ) +{ + memset( params, 0, sizeof(*params) ); + + // Assumes D is a column vector + params->d = bli_obj_buffer_at_off( d ); + params->incd = bli_obj_row_stride( d ); + + for ( int i = BLIS_DT_LO; i <= BLIS_DT_HI; i++ ) + params->super.ukr_fn[i][i] = packm_diag_ukrs[i]; + + 
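+ // NOTE: packm_diag_params_t embeds packm_blk_var1_params_t as its first
+ // member ("super"), the usual C idiom for extending a struct: the pointer
+ // attached below can presumably also be read by the framework as a plain
+ // packm_blk_var1_params_t. Only the same-datatype entries ukr_fn[i][i]
+ // are filled in, on the assumption that A is packed without a datatype
+ // change.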
// Attach the parameters to the A object. + bli_obj_set_pack_params( params, a ); +} + +/* + * Implements C := alpha * A * D * A^T + beta * C + * + * where D is a diagonal matrix with elements taken from the "d" vector. + */ +void syrk_diag( obj_t* alpha, obj_t* a, obj_t* d, obj_t* beta, obj_t* c ) +{ + obj_t ad; // this is (D * A^T) + packm_diag_params_t params; + + bli_obj_alias_to( a, &ad ); + bli_obj_toggle_trans( &ad ); // because gemmt is A*B instead of A*B^T + attach_diagonal_factor( ¶ms, d, &ad ); + + // Does C := alpha * A * B + beta * C using B = (D + A^T) + bli_gemmtnat( alpha, a, &ad, beta, c, NULL, NULL ); +} + +int main() +{ + obj_t a; + obj_t d; + obj_t c; + obj_t c_copy; + obj_t norm; + + dim_t m = 10; + dim_t k = 10; + + for ( int dt_ = BLIS_DT_LO; dt_ <= BLIS_DT_HI; dt_++ ) + for ( int upper = 0; upper <= 1; upper++ ) + for ( int transa = 0; transa <= 1; transa++ ) + for ( int transc = 0; transc <= 1; transc++ ) + { + num_t dt = dt_; + uplo_t uplo = upper ? BLIS_UPPER : BLIS_LOWER; + + bli_obj_create( dt, m, k, transa ? k : 1, transa ? 1 : m, &a ); + bli_obj_create( dt, k, 1, 1, 1, &d ); + bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c ); + bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c_copy ); + bli_obj_set_struc( BLIS_SYMMETRIC , &c ); + bli_obj_set_struc( BLIS_SYMMETRIC , &c_copy ); + bli_obj_set_uplo( uplo , &c ); + bli_obj_set_uplo( uplo , &c_copy ); + bli_obj_create_1x1( bli_dt_proj_to_real( dt ), &norm ); + + bli_randm( &a ); + bli_randm( &d ); + bli_randm( &c ); + bli_copym( &c, &c_copy ); + + syrk_diag( &BLIS_ONE, &a, &d, &BLIS_ONE, &c ); + syrk_diag_ref( &BLIS_ONE, &a, &d, &BLIS_ONE, &c_copy ); + + bli_subm( &c_copy, &c ); + bli_normfm( &c, &norm ); + + double normr, normi; + bli_getsc( &norm, &normr, &normi ); + + printf("dt: %d, upper: %d, transa: %d, transc: %d, norm: %g\n", + dt, upper, transa, transc, normr); + + bli_obj_free( &a ); + bli_obj_free( &d ); + bli_obj_free( &c ); + bli_obj_free( &c_copy ); + bli_obj_free( &norm ); + } +} diff --git a/test/obj_t_makeover/syrk_diagonal_example.cxx b/test/obj_t_makeover/syrk_diagonal_example.cxx new file mode 100644 index 0000000000..1c269d5c48 --- /dev/null +++ b/test/obj_t_makeover/syrk_diagonal_example.cxx @@ -0,0 +1,220 @@ +#include "syrk_diagonal_ref.h" + +/* + * Forward-declare the pack kernel type and set up and array of + * packing kernels, one for each data type. + */ +template +void packm_diag_ukr + ( + struc_t /*struca*/, + diag_t /*diaga*/, + uplo_t /*uploa*/, + conj_t conja, + pack_t schema, + bool /*invdiag*/, + dim_t panel_dim, + dim_t panel_len, + dim_t panel_dim_max, + dim_t panel_len_max, + dim_t /*panel_dim_off*/, + dim_t panel_len_off, + void* restrict kappa, + void* restrict a, inc_t inca, inc_t lda, + void* restrict p, inc_t ldp, + inc_t /*is_p*/, + cntx_t* /*cntx*/, + void* params + ); + +#undef GENTFUNC +#define GENTFUNC(ctype,ch,op) \ +static auto PASTEMAC(ch,op) = &packm_diag_ukr; + +INSERT_GENTFUNC_BASIC0(packm_diag_ukr); + +static packm_ker_vft GENARRAY( packm_diag_ukrs, packm_diag_ukr ); + +/* + * Structure which includes all additional information beyond what is + * already stored in the obj_t structure. + * + * This structure is **read-only** during the operation! 
+ */ +struct packm_diag_params_t : packm_blk_var1_params_t +{ + void* d; + inc_t incd; + + packm_diag_params_t() {} + + packm_diag_params_t( void* d, inc_t incd ) + : d(d), incd(incd) + { + for ( int i = BLIS_DT_LO; i <= BLIS_DT_HI; i++ ) + ukr_fn[i][i] = packm_diag_ukrs[i]; + } +}; + +/* + * Selecting a different kernel based on the current architecture is + * currently not possible, but is something we plan to support. + */ +template +void packm_diag_ukr + ( + struc_t /*struca*/, + diag_t /*diaga*/, + uplo_t /*uploa*/, + conj_t conja, + pack_t schema, + bool /*invdiag*/, + dim_t panel_dim, + dim_t panel_len, + dim_t panel_dim_max, + dim_t panel_len_max, + dim_t /*panel_dim_off*/, + dim_t panel_len_off, + void* restrict kappa, + void* restrict a, inc_t inca, inc_t lda, + void* restrict p, inc_t ldp, + inc_t /*is_p*/, + cntx_t* /*cntx*/, + void* params + ) +{ + auto params_cast = ( packm_diag_params_t* )params; + T* restrict a_cast = ( T* )a; + T* restrict p_cast = ( T* )p; + T* restrict d_cast = ( T* )params_cast->d; + auto incd = params_cast->incd; + auto kappa_cast = *( T* )kappa; + + if ( schema != BLIS_PACKED_ROW_PANELS && + schema != BLIS_PACKED_COL_PANELS ) + bli_abort(); + + /* Apply the offset */ + d_cast += panel_len_off * incd; + + if ( conja ) + { + for ( dim_t j = 0; j < panel_len; j++ ) + { + auto kappa_d = kappa_cast * d_cast[ j*incd ]; + + for (dim_t i = 0;i < panel_dim;i++) + p_cast[ i + j*ldp ] = kappa_d * conj( a_cast[ i*inca + j*lda ] ); + + for (dim_t i = panel_dim;i < panel_dim_max;i++) + p_cast[ i + j*ldp ] = convert(0.0); + } + } + else + { + for ( dim_t j = 0; j < panel_len; j++ ) + { + auto kappa_d = kappa_cast * d_cast[ j*incd ]; + + for (dim_t i = 0;i < panel_dim;i++) + p_cast[ i + j*ldp ] = kappa_d * a_cast[ i*inca + j*lda ]; + + for (dim_t i = panel_dim;i < panel_dim_max;i++) + p_cast[ i + j*ldp ] = convert(0.0); + } + } + + for (dim_t j = panel_len;j < panel_len_max;j++) + for (dim_t i = 0;i < panel_dim_max;i++) + p_cast[ i + j*ldp ] = convert(0.0); +} + +/* + * Modify the object A to include information about the diagonal D, + * and imbue it with special function pointers which will take care + * of the actual work of forming (D * A^T) + */ +void attach_diagonal_factor( packm_diag_params_t* params, obj_t* d, obj_t* a ) +{ + // Assumes D is a column vector + new (params) packm_diag_params_t + ( + bli_obj_buffer_at_off( d ), + bli_obj_row_stride( d ) + ); + + // Attach the parameters to the A object. + bli_obj_set_pack_params( params, a ); +} + +/* + * Implements C := alpha * A * D * A^T + beta * C + * + * where D is a diagonal matrix with elements taken from the "d" vector. + */ +void syrk_diag( obj_t* alpha, obj_t* a, obj_t* d, obj_t* beta, obj_t* c ) +{ + obj_t ad; // this is (D * A^T) + packm_diag_params_t params; + + bli_obj_alias_to( a, &ad ); + bli_obj_toggle_trans( &ad ); // because gemmt is A*B instead of A*B^T + attach_diagonal_factor( ¶ms, d, &ad ); + + // Does C := alpha * A * B + beta * C using B = (D + A^T) + bli_gemmtnat( alpha, a, &ad, beta, c, NULL, NULL ); +} + +int main() +{ + obj_t a; + obj_t d; + obj_t c; + obj_t c_copy; + obj_t norm; + + auto m = 10; + auto k = 10; + + for ( int dt_ = BLIS_DT_LO; dt_ <= BLIS_DT_HI; dt_++ ) + for ( int upper = 0; upper <= 1; upper++ ) + for ( int transa = 0; transa <= 1; transa++ ) + for ( int transc = 0; transc <= 1; transc++ ) + { + auto dt = ( num_t )dt_; + auto uplo = upper ? BLIS_UPPER : BLIS_LOWER; + + bli_obj_create( dt, m, k, transa ? k : 1, transa ? 
1 : m, &a ); + bli_obj_create( dt, k, 1, 1, 1, &d ); + bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c ); + bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c_copy ); + bli_obj_set_struc( BLIS_SYMMETRIC , &c ); + bli_obj_set_struc( BLIS_SYMMETRIC , &c_copy ); + bli_obj_set_uplo( uplo , &c ); + bli_obj_set_uplo( uplo , &c_copy ); + bli_obj_create_1x1( bli_dt_proj_to_real( dt ), &norm ); + + bli_randm( &a ); + bli_randm( &d ); + bli_randm( &c ); + bli_copym( &c, &c_copy ); + + syrk_diag( &BLIS_ONE, &a, &d, &BLIS_ONE, &c ); + syrk_diag_ref( &BLIS_ONE, &a, &d, &BLIS_ONE, &c_copy ); + + bli_subm( &c_copy, &c ); + bli_normfm( &c, &norm ); + + double normr, normi; + bli_getsc( &norm, &normr, &normi ); + + printf("dt: %d, upper: %d, transa: %d, transc: %d, norm: %g\n", + dt, upper, transa, transc, normr); + + bli_obj_free( &a ); + bli_obj_free( &d ); + bli_obj_free( &c ); + bli_obj_free( &c_copy ); + bli_obj_free( &norm ); + } +} diff --git a/test/obj_t_makeover/syrk_diagonal_example.tgz b/test/obj_t_makeover/syrk_diagonal_example.tgz new file mode 100644 index 0000000000000000000000000000000000000000..79ac14959103c7a43631814e16d4d9d23dd21efa GIT binary patch literal 6904 zcmV9#Hqc# zommw{LN;THAMO5$ z4eynThG(8;P6DKOGbhW=`BmBg2Vm3b3Ou4QO#l8)o^Soo4dxza57UAqBz(Ga{;<;- z^g67^K4i(<&E^SbSLeQa&SasKIYGc2&+{`MOF{ZHN`R5@#U^|ig&q`eGUlg<#R(Gb zcuSUMQNld{X25V3O+Jk?NRu_}03-sRp8IK@3K&?ghsFC*7%Y1|=FJm7JQEmVo=~0N zK*2#LTgKeuQ(2hBj(ahqhK*ws)sjJR)ZjT#PTBWTbOhItPUN1acYq9@Hg1 zpw*BI|HAi89o~J#uNT}69H1JhLpXN9laL1tNn;LhCFL2L$MotXiL+$U6w4tjwn~7f zHj6^efTfH(z;(~b973@}?fiBgVl#g_d2{;X$2ZSA-@=y<@G8M_5!QCk?I`ZDpF7wV z_U`z@>B;M3$W+PLGE4r#LtKMqnxBrH=Tn1qHhr4e^Eil%Z2UB{-6;HIWaFoqjqNwG zQ}>+D%=}3d1uTF3@Dl4UP&YHm13q9Q^wDe`NWjAlu#9KUVgpQ@fMH@Zoo<4y2qq*l z;m|QYNcV*k$4+xLWNE3@?}4|XL~JwjQ6*xdMD>w%q+=7z4YS2K>oMcAQmxTk96~Z1 z#qwId*9^0sgdV$2nz2LqwJ!^h)Ef<=#}Tmf95QDmY+QhiSHgNl*qp5Xk*5j@=2SW6 zsrOPrQZQzv<+9Z0vf@M5vlMAv%D|$s;wcPlF^1SThwO(}FF%am9Y6o?lNaOnZ~qFf zZ%$r)U|U;e&(?!KfBQ;-ca6~)_+w`hC7IRDD;@OMaU2Irp_!*s)O5XpEF{^reOMH4 z(0r=V2IZt2bbZ$bm8cOdIbc2k3x@mfbCBbN&+ToK(n9H+uY!@9+3GQMWa>J>ZrZgB zrrWX^e`lY1)Q!K`3KKIJQ6N^(>JzQLCKXcm{q5~;&CsgQ@F^$4v!y)$o%u+C!?uAG z7}ZEgta4;x%zoK`DuW)AI@DwvQOL6)($`qIDs&^?)h|qPyzSr;4PQzBiZ_FhPs|wi z`-KiALOvAO*m(qzB#!Bo$j3(j1yOc6&kCCm;O z*ih~&N-i)7D@PO5dI2EE|2XZEtx2{cERetq&DRK1C>kZ!swvtQ<~^9u`VCaNw7+M3 z20dhH(4oyz|5qMOt)58g64wxPFkr`NI-kLk;{~n-9GD$~`7C6Y9BGi20+xY^7lZ|; zCaiBVpH88xoD8QK&?tz|gr(t~gD26|7^ioSOW0h5p(9mX-;mBA1XFsX>bLhpR{Q%OEX0Tet5e1+BI3-L?187V2m8UCKlJFwpH9hr%(MN2a;&}Z5L<$Duhg$5*e!zfw+6&S#DvS|uLu8y}we;TAr|ZWJ%gln@k(Y>!ap zGoW-d+&5A#WJ+nYdD*x$-Y}&!;vh+nxxDB+h7l(W2k{D}M)JJ|och!e@bX<05Fbq? 
zT%&|*mT+~6Ypo2zIX2b91q;x;uhZp<8Gj&4Q2V`T;>UOIPTsR;;^WoZzn;98JTk@7 zO}LYxz5=@i*0~q9YQguw*Jo^Ge2?IJh612Oy9o~<85P%ob>CawIy^J_Bl+_^^n6Z_ zYtZA~7akJ6I%KFFX)D7L%>3}T*RM}b-@km$nyO!2Vr50xm6qtxRoF{ljVJ{p#-qi^ zqBM}6IEhdnjUgvM4zVs^z=5V!he7}>@@7^|m#d6QAF7NB-AEs06sXl8hF0Gqj=g0}<<{{Wc{5eSR~Y3*^GU!-QzbV2_{HL9JWD|k zsH8-Y$X~yJDNaB+Os$7r_Kf`*Xs6*UUd1Rh-d*vINZ{@4=kTFjHo{T%356=Iii%=M znGYx~(}eT74ma{@$F0n(6<3wlD5DZUaTc>~U^4lwcemST-ILb&e|GQJ0kqoxv-@=T zcu)C%_8#w{|L5uO(XjRZe2(Xj?Eks7b4OvORSYgt%@rK(2MC+bYH(KwZ^v3!#ac(j zMsA9YoD?ga3JsZ_;ZpFLfExV{xAZy)cfi-+SF4qUFWYf-J6BUtEWGV_R$z_ zpA}x8*o12I_jo2$ql3oq&8&6Jv<{Be!SM&`;IFQKzh(!AY1Q1)%R%#Do-T;jHj&)I zWh}P^bW*I7bnE~{4XD&4%x08|FD7iDU;~|+ohN9GVg|f|;V3e2Q}bKUzN1d(!C=7t z%E3-O150$WWVp6fCgs%UftS+m?M~-G?GABlZqxq6vl_>ZHY`#=qWAyglY1RJ+@3{j z8fFY@k|GQwobjv<6vvj67@#7-02LyJp#HC1W0l`x`v_X{faate;WxNX$^P{(MUME^ z&JEM~FL-&f$Px!8%1bBlqj{=$sC@x2-DL?FL0Q5%GBy3x_kga+Y#wA|c%5m|s7Hka zN$YzuLja4qnA%h+9j+@K8l_d7k&JdQJEU5gfCZ$1LfKG)01$OB8FLg1+%m$@vh-KX zq0SI|w2cPaB;K{?$hxgX2WsKA=)${D7v|`4lVyWwl3ZD>k0ogqrVqGA%2k z;Y1XCItB6YFUvF@Gb@a)l!UzNiuGsUpM`QhnJLMw;Rv%}hVVh>21~~OId=lW1Mmq< zOligw7#)#&fx0^F?udO6550m{Fy(`Mu)xo<0Om12Lt0W`7OiawO~5aC(BI9sNPu}0 z3Mp{1?|}9TSRzyN=Ew)>pBu-e2@33&v+Udw-7lfS3e=VFm4N^ij7#NVCytP5lSN}} zzl8vHL4_77SY0;E%Z93~&>haDQXM3XMrZk<>r~} zaiJa{c?o?K!vc|qnHfrvgAyJo#9B&~Brhro4D+y!siin)D>j&8?-U*Dqz0G>R(MF7 z$50Q&=26oy76fR`Ak`6c33&u5$JR`EcE!2aks~djT>6Mw;JhHQ)&PW~3bZ}i!NG(M zaia@JHq;37&^$uMMuY<-Ke{4g*c`~rf(b`!&wtsqi18t8K#dWy%im1D5L z^7%7@OZ>C*jB)RbtF932hr#rJMumjOb-;ED(uhT_J5P+t=*7el;1b`c&X=f*nypwP zc0?Fu&{bFG5e!+S0}$%`Ecqz(+i2`CFdWV_X)KsHj(ll~? zC@2Psp37z&B>~26vIM>L&pD2++z^7?2r?!~1hXW~ZZI^!rovr&tvpPF9^~5NXgav( zPC{!rBQlz%35#;N$X;Py#RoV~z z%!hdc=1nHHgguO-n)x9ln7t+%O9*JHrSSqywS=<+oxhC3;37??XOIZ~Nv?d*a}Oth z6J9(NEEQ^z*eV?yHEaSC3_BKdRsyp=D!ND;d7vW_F8X6f3dJ`$)38xu`Dx3D`Z~UW zW&PkmE}%(f&!I;EA;$nCm+{Nd0aPyS~^Bev}x}!R#VE1xGWaI?ZDb1ZWNS? 
zC}}7jDkv2(s8Xkp#0GqA6##(yVe?d$X$mD&fvyEM8c~=!3`6nHEAvl>t(tzw;sKEY zoN!Tg9tgjw@;HFra)SN5o6IvJx@m%bvA<)nEu*BK*lKqF@`z`B-liD2it6DPX_x%Z zTxei!*dszKrP6#9y1B41G*1#?LSx={V4W&n=eu~MIE4)ZodaaZ5s|63iNdfFxN?@3 zIf0nqneP?BG3prY2Zv4n9fNw550a8P0$v=bt~ee#j7Lc z6P}U@fFn3V!JqnP^Mv;^u;fhB&`0|X%G>w`bSTLushlWq@W7kG9 zLpH)G@fv1D0h;f!`%)mFvrFb6&Io-d-Ycz7^#+$1Yu;GwAd6w9PB^i$IwD6)FSqvO zNOcOO9aY-L{cNLN@{$;(^XlUQg1CUdd{Eox2%vtlEOM6G1S8(|JrWR3v@Hkl4^ zs6xgHvS6_ok#skm2RC_`etSkl#^~Jc;0Ie_klgL;2QGx=W~n<*Gnlcc;}i!Q&R|SK zv&FVyLbQiF+QS|9dbp$MWJi0Nqdm>hp5|yzbF`;9+S42jdF^SA_NhJpvEwcl_gKf8 z`@i<~hU)ysaCi5~BRv1{Xf$fie|(N-v-2N6M#+_vczy44#NTHeWDxg+4lpf*1rj?w zDn}LK&Lb%%9avbpxN6W7S6B50#!-mE)m4KYU0!7XlrFCt^yvC31EBW$DwSD(d6f=v zD#rg0dh+tB0)1c3iI5V|pjWuIY8}z)+A0NfoAV|2b#!FK_KtvoF2FspK{e=9wHXtP z{Msw4D3d%Tw>&l?!)={8to+NTM-Fk-Xo~s)6+YVkE$%D%T^;^vKJ>+4S>x#`q3;!vvFYkQk2tFI!G=gAY)8(Da0}O2grs5_L zZk(YnL#p?QgK>0GSnzW=iF@v6T=ftlky*i#qFmxAO?|M$MZ_6)Xf)&5IXNX@`Usqm zLae|dj*_fZ{jKW%eX4$`JV=ihvX#5EVe`AI_9?N%Y5k!!Dh+vErwSI=fW=iXuMVtu zPho}rb(a%X3RG?+Jk>|u)&`HSjs+LJQFo=F*l}Eap`cN@JZx91X4js>((u;>$s!i0 zgHwkl7kZTwR>s<1SG@X6RYi#FPWc?_`g=^olsBEMD(r03kg)Cq)rRB%5tls44o;7( z;m)w&;L87{h^?=>t36$${Y42xV#Y4SK@(cRnPsX z9x0X$?vTyAA)7fv6x3>OMzb@c^(VCcgw~(HzHooS7k?LBdkI~830->$-F;j_*WM}D z-YM7KDc9a9*WM}D-YM7KDc9a9*WM}D-YIwIJLQ_MyE3Z$gWP(hw#MW|R|YHU!rC-v zZg~xs$IndsO)!u~F-h|cOgw@`JReIW+u>QEQ;OzxjZEX3M8(aR{ zU69oVlefX-Z7_KoOx^~Qx54CJ>QjySz2Cc^@L3)IJ$yVGJw9C;4n$Y z`rXC%5@3BOqs%0Ou$1N+m|X+2s~0~jB1H?F_&^?s^i+behGDo@UXUUNN7^s8{rn|nY0QP78LPk$7kjhoCW^NVqQ zqnPJ|X10rIq_V5Z`QP7zF)o6xgi;bfeob~U^kwW!GoGjo6>YobcW4)4U}~TZ2N?~P z74F3x(k(((yov>!ZUauYg41olS%Tc?%6+{Zk09r`DWw~2jAL2@T{$uOP;f&^OnzI` zzIB+~b(ma%$<>&u@+EwIwHvBpmFwpFsI*5>zv~LRSuAsL-6J-+D%-DBAG0mzdUKXv zv7K9(@K+H)yK3v|cV+#qT`-71ZT!~8Z2R2$sj>R+y^b~Ozv1p^n6LkaBV7LtA3bi@ zf1l&Y7kDt*Q@Zl^TYrj^|9I@i!92yk9ln6cCu9$wKUDuA+`Uev`>pEut)B9-{y$O2 z+WWsBKN9Evcb`1nd$b4R|MA{#JN`e*qsRY&iz0B8@Aw=z!C92}+4;;&MB5$BC_hV& zI-MHrQNJxB4&6rv;w_6*Ue_m1y?1D4XJ}TmT6K7%UPVUbRNDi-fHSn^K!w`gg?QlO=wk3<0z-xwEU8l+GY_6z20 z)Fs&GEBXVon>&LtR<@qXH+KVNtVht;1Dg`i$orgt^_)jkHEgFE*6`dG2k4(ip{EC2 zp*LYACkGykH+9Ft0m9qGUWb}&gBti;G5QMC*2s`QZ5AWQJUF$f=xT9N?ID6FHXWl9 zv|ZMu9t#gr&gDVm`)7|N&E6%?l_r0RL*yE#1NC43fLLbES*Cn~*WpKcvd3#`*EwlL^j5UqD7KxW}0Qh;E)x08&;$WQB6&S$~B2jHrpjq zHf>#@Jk`9qe0}|e_R6GN5m0FRY6O~s4TVU~&W7ZG9;fq=r+1gky188A74(+`|5`NN zW_{6wo5t$wT&sM{-w)}|8l^GhCN^z1HEOUNO>H^QQCPuxhOwd&_v++-$FGl1|1y62?t_((M~0De z1zj>fE8R;bjsv~RgKVso>1qXbRc~KaTT|P0zpC*|Z73q!-h}_xi|wUE zyW9}#27JF>ZYyFprF8>-%PCJ2EKqN6Mt6TRZuAxM{-$U*VTT~Rzuc4^{adi3Uu4II zboVx6M^7W~ZHjjz@#rb>z_(lCUCR+kH0w*2NYd@%;A9$@2I+PVjLGZh8LZ?VX>mN~ zM|%@E;mw?dMK8`IwL}$NITwl)th?LYHxY!YupdNc6N~T_A(lz1U@&(QT3UR=5~N>+ z6ig{YdOVEKg}jvPfhxA8Zs;v>ZCDYJBppy8)JzB>zSeVa^N};qBWJL@1>Qe)&sYl+ zZlU42s<(>6!BP+a-q(XegEz1}-$&S%bLLh-l>#b`Un%=B;d&Gmy8Za;$Ft|el y1{*bC(15tpsaVRi4`MUD`vHb^1x9RD*m{4Wby2oY`?Swj_xwK>uoF1|$N&I`%W-S~ literal 0 HcmV?d00001 diff --git a/test/obj_t_makeover/syrk_diagonal_example2.c b/test/obj_t_makeover/syrk_diagonal_example2.c new file mode 100644 index 0000000000..4937f45138 --- /dev/null +++ b/test/obj_t_makeover/syrk_diagonal_example2.c @@ -0,0 +1,351 @@ +#include "syrk_diagonal_ref.h" + +/* + * Structure which includes all additional information beyond what is + * already stored in the obj_t structure. + * + * This structure is **read-only** during the operation! 
+ */ +typedef struct packm_diag_params_t +{ + void* d; + inc_t incd; +} packm_diag_params_t; + +typedef void (*packm_diag_ukr_vft) + ( + bool conja, + dim_t panel_dim, + dim_t panel_len, + dim_t panel_dim_max, + dim_t panel_len_max, + void* restrict kappa, + void* restrict d, inc_t incd, + void* restrict a, inc_t inca, inc_t lda, + void* restrict p, inc_t ldp + ); + +/* + * Declare the pack kernel type and set up and array of + * packing kernels, one for each data type. + */ +#undef GENTFUNC +#define GENTFUNC(ctype,ch,op) \ +void PASTEMAC(ch,op) \ + ( \ + bool conja, \ + dim_t panel_dim, \ + dim_t panel_len, \ + dim_t panel_dim_max, \ + dim_t panel_len_max, \ + void* restrict kappa, \ + void* restrict d, inc_t incd, \ + void* restrict a, inc_t inca, inc_t lda, \ + void* restrict p, inc_t ldp \ + ) \ +{ \ + ctype* restrict a_cast = a; \ + ctype* restrict p_cast = p; \ + ctype* restrict d_cast = d; \ + ctype kappa_cast = *( ctype* )kappa; \ +\ + if ( conja ) \ + { \ + for ( dim_t j = 0; j < panel_len; j++ ) \ + { \ + ctype kappa_d; \ + PASTEMAC(ch,scal2s)( kappa_cast, d_cast[ j*incd ], kappa_d ); \ +\ + for (dim_t i = 0;i < panel_dim;i++) \ + PASTEMAC(ch,scal2js)( kappa_d, a_cast[ i*inca + j*lda ], p_cast[ i + j*ldp ] ); \ +\ + for (dim_t i = panel_dim;i < panel_dim_max;i++) \ + PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \ + } \ + } \ + else \ + { \ + for ( dim_t j = 0; j < panel_len; j++ ) \ + { \ + ctype kappa_d; \ + PASTEMAC(ch,scal2s)( kappa_cast, d_cast[ j*incd ], kappa_d ); \ +\ + for (dim_t i = 0;i < panel_dim;i++) \ + PASTEMAC(ch,scal2s)( kappa_d, a_cast[ i*inca + j*lda ], p_cast[ i + j*ldp ] ); \ +\ + for (dim_t i = panel_dim;i < panel_dim_max;i++) \ + PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \ + } \ + } \ +\ + for (dim_t j = panel_len;j < panel_len_max;j++) \ + for (dim_t i = 0;i < panel_dim_max;i++) \ + PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \ +} + +INSERT_GENTFUNC_BASIC0(packm_diag_ukr); + +static packm_diag_ukr_vft GENARRAY( packm_diag_ukrs, packm_diag_ukr ); + +void packm_diag + ( + obj_t* a, + obj_t* p, + cntx_t* cntx, + rntm_t* rntm, + cntl_t* cntl, + thrinfo_t* thread + ) +{ +#if 1 + + // We begin by copying the fields of A. + bli_obj_alias_to( a, p ); + + // Get information about data types. + num_t dt = bli_obj_dt( a ); + num_t dt_tar = bli_obj_target_dt( a ); + num_t dt_scalar = bli_obj_scalar_dt( a ); + dim_t dt_size = bli_dt_size( dt ); + + if ( dt_scalar != dt || dt_tar != dt ) + bli_abort(); + + // Extract various fields from the control tree. + bszid_t bmult_id_m = bli_cntl_packm_params_bmid_m( cntl ); + bszid_t bmult_id_n = bli_cntl_packm_params_bmid_n( cntl ); + pack_t schema = bli_cntl_packm_params_pack_schema( cntl ); + dim_t bmult_m_def = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_m, cntx ); + dim_t bmult_m_pack = bli_cntx_get_blksz_max_dt( dt_tar, bmult_id_m, cntx ); + dim_t bmult_n_def = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_n, cntx ); + + if ( schema != BLIS_PACKED_ROW_PANELS && + schema != BLIS_PACKED_COL_PANELS ) + bli_abort(); + + // Store the pack schema to the object. + bli_obj_set_pack_schema( schema, p ); + + // Clear the conjugation field from the object since matrix packing + // in BLIS is deemed to take care of all conjugation necessary. + bli_obj_set_conj( BLIS_NO_CONJUGATE, p ); + + // If we are packing micropanels, mark P as dense. + bli_obj_set_uplo( BLIS_DENSE, p ); + + // Reset the view offsets to (0,0). + bli_obj_set_offs( 0, 0, p ); + + // Compute the dimensions padded by the dimension multiples. 
These + // dimensions will be the dimensions of the packed matrices, including + // zero-padding, and will be used by the macro- and micro-kernels. + // We compute them by starting with the effective dimensions of A (now + // in P) and aligning them to the dimension multiples (typically equal + // to register blocksizes). This does waste a little bit of space for + // level-2 operations, but that's okay with us. + dim_t m_p = bli_obj_length( p ); + dim_t n_p = bli_obj_width( p ); + dim_t m_p_pad = bli_align_dim_to_mult( m_p, bmult_m_def ); + dim_t n_p_pad = bli_align_dim_to_mult( n_p, bmult_n_def ); + + // Save the padded dimensions into the packed object. It is important + // to save these dimensions since they represent the actual dimensions + // of the zero-padded matrix. + bli_obj_set_padded_dims( m_p_pad, n_p_pad, p ); + + // The "panel stride" of a micropanel packed object is interpreted as + // the distance between the (0,0) element of panel k and the (0,0) + // element of panel k+1. We use the padded width computed above to + // allow for zero-padding (if necessary/desired) along the far end + // of each micropanel (ie: the right edge of the matrix). Zero-padding + // can also occur along the long edge of the last micropanel if the m + // dimension of the matrix is not a whole multiple of MR. + inc_t ps_p = bmult_m_pack * n_p_pad; + + /* Compute the total number of iterations we'll need. */ + dim_t n_iter = m_p_pad / bmult_m_def; + + // Store the strides and panel dimension in P. + bli_obj_set_strides( 1, bmult_m_pack, p ); + bli_obj_set_imag_stride( 1, p ); + bli_obj_set_panel_dim( bmult_m_def, p ); + bli_obj_set_panel_stride( ps_p, p ); + bli_obj_set_panel_length( bmult_m_def, p ); + bli_obj_set_panel_width( n_p, p ); + + // Compute the size of the packed buffer. + siz_t size_p = ps_p * n_iter * dt_size; + if ( size_p == 0 ) return; + + // Update the buffer address in p to point to the buffer associated + // with the mem_t entry acquired from the memory broker (now cached in + // the control tree node). + char* p_cast = (char*)bli_packm_alloc( size_p, rntm, cntl, thread ); + bli_obj_set_buffer( p_cast, p ); + +#else + + // Every thread initializes p and determines the size of memory + // block needed (which gets embedded into the otherwise "blank" mem_t + // entry in the control tree node). Return early if no packing is required. 
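+ // (Hedged note: this branch delegates to bli_packm_init() essentially
+ // the same bookkeeping that the #if 1 branch above spells out by hand,
+ // namely aliasing A into P, recording the pack schema, computing the
+ // padded dimensions, and acquiring the pack buffer.)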
+ if ( !bli_packm_init( a, p, cntx, rntm, cntl, thread ) ) + return; + + num_t dt = bli_obj_dt( a ); + dim_t dt_size = bli_dt_size( dt ); + + bszid_t bmult_id_m = bli_cntl_packm_params_bmid_m( cntl ); + dim_t bmult_m_def = bli_cntx_get_blksz_def_dt( dt, bmult_id_m, cntx ); + dim_t bmult_m_pack = bli_cntx_get_blksz_max_dt( dt, bmult_id_m, cntx ); + + dim_t m_p = bli_obj_length( p ); + dim_t n_p = bli_obj_width( p ); + dim_t m_p_pad = bli_obj_padded_length( p ); + dim_t n_p_pad = bli_obj_padded_width( p ); + dim_t n_iter = m_p_pad / bmult_m_def; + + char* p_cast = bli_obj_buffer( p ); + inc_t ps_p = bli_obj_panel_stride( p ); + +#endif + + char* a_cast = bli_obj_buffer_at_off( a ); + inc_t inca = bli_obj_row_stride( a ); + inc_t lda = bli_obj_col_stride( a ); + dim_t panel_len_off = bli_obj_col_off( a ); + conj_t conja = bli_obj_conj_status( a ); + + packm_diag_params_t* params = bli_obj_pack_params( a ); + char* d_cast = params->d; + inc_t incd = params->incd; + + obj_t kappa_local; + char* kappa_cast = bli_packm_scalar( &kappa_local, p ); + + packm_diag_ukr_vft packm_ker_cast = packm_diag_ukrs[ dt ]; + + /* Query the number of threads and thread ids from the current thread's + packm thrinfo_t node. */ + const dim_t nt = bli_thread_n_way( thread ); + const dim_t tid = bli_thread_work_id( thread ); + + /* Determine the thread range and increment using the current thread's + packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir() + will depend on whether slab or round-robin partitioning was requested + at configure-time. */ + dim_t it_start, it_end, it_inc; + bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc ); + + /* Iterate over every logical micropanel in the source matrix. */ + for ( dim_t it = 0; it < n_iter; it += 1 ) + { + dim_t panel_dim_i = bli_min( bmult_m_def, m_p - it*bmult_m_def ); + + char* d_begin = d_cast + panel_len_off*incd*dt_size; + char* a_begin = a_cast + it* bmult_m_def*inca*dt_size; + char* p_begin = p_cast + it* ps_p*dt_size; + + if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) ) + { + packm_ker_cast( conja, + panel_dim_i, + n_p, + bmult_m_def, + n_p_pad, + kappa_cast, + d_begin, incd, + a_begin, inca, lda, + p_begin, bmult_m_pack ); + } + } +} + +/* + * Modify the object A to include information about the diagonal D, + * and imbue it with special function pointers which will take care + * of the actual work of forming (D * A^T) + */ +void attach_diagonal_factor( packm_diag_params_t* params, obj_t* d, obj_t* a ) +{ + // Assumes D is a column vector + params->d = bli_obj_buffer_at_off( d ); + params->incd = bli_obj_row_stride( d ); + + // Set the custom pack function. + bli_obj_set_pack_fn( packm_diag, a ); + + // Attach the parameters to the A object. + bli_obj_set_pack_params( params, a ); +} + +/* + * Implements C := alpha * A * D * A^T + beta * C + * + * where D is a diagonal matrix with elements taken from the "d" vector. 
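+ * Elementwise, for the stored (lower or upper) triangle of C:
+ *
+ *   C(i,j) := alpha * sum_l A(i,l) * d(l) * A(j,l) + beta * C(i,j)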
+ */ +void syrk_diag( obj_t* alpha, obj_t* a, obj_t* d, obj_t* beta, obj_t* c ) +{ + obj_t ad; // this is (D * A^T) + packm_diag_params_t params; + + bli_obj_alias_to( a, &ad ); + bli_obj_toggle_trans( &ad ); // because gemmt is A*B instead of A*B^T + attach_diagonal_factor( ¶ms, d, &ad ); + + // Does C := alpha * A * B + beta * C using B = (D + A^T) + bli_gemmt( alpha, a, &ad, beta, c ); +} + +int main() +{ + obj_t a; + obj_t d; + obj_t c; + obj_t c_copy; + obj_t norm; + + dim_t m = 10; + dim_t k = 10; + + for ( int dt_ = BLIS_DT_LO; dt_ <= BLIS_DT_HI; dt_++ ) + for ( int upper = 0; upper <= 1; upper++ ) + for ( int transa = 0; transa <= 1; transa++ ) + for ( int transc = 0; transc <= 1; transc++ ) + { + num_t dt = dt_; + uplo_t uplo = upper ? BLIS_UPPER : BLIS_LOWER; + + bli_obj_create( dt, m, k, transa ? k : 1, transa ? 1 : m, &a ); + bli_obj_create( dt, k, 1, 1, 1, &d ); + bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c ); + bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c_copy ); + bli_obj_set_struc( BLIS_SYMMETRIC , &c ); + bli_obj_set_struc( BLIS_SYMMETRIC , &c_copy ); + bli_obj_set_uplo( uplo , &c ); + bli_obj_set_uplo( uplo , &c_copy ); + bli_obj_create_1x1( bli_dt_proj_to_real( dt ), &norm ); + + bli_randm( &a ); + bli_randm( &d ); + bli_randm( &c ); + bli_copym( &c, &c_copy ); + + syrk_diag( &BLIS_ONE, &a, &d, &BLIS_ONE, &c ); + syrk_diag_ref( &BLIS_ONE, &a, &d, &BLIS_ONE, &c_copy ); + + bli_subm( &c_copy, &c ); + bli_normfm( &c, &norm ); + + double normr, normi; + bli_getsc( &norm, &normr, &normi ); + + printf("dt: %d, upper: %d, transa: %d, transc: %d, norm: %g\n", + dt, upper, transa, transc, normr); + + bli_obj_free( &a ); + bli_obj_free( &d ); + bli_obj_free( &c ); + bli_obj_free( &c_copy ); + bli_obj_free( &norm ); + } +} diff --git a/test/obj_t_makeover/syrk_diagonal_example2.cxx b/test/obj_t_makeover/syrk_diagonal_example2.cxx new file mode 100644 index 0000000000..8312a07ee8 --- /dev/null +++ b/test/obj_t_makeover/syrk_diagonal_example2.cxx @@ -0,0 +1,338 @@ +#include "syrk_diagonal_ref.h" + +/* + * Forward-declare the pack kernel type and set up and array of + * packing kernels, one for each data type. + */ +template +void packm_diag_ukr + ( + bool conja, + dim_t panel_dim, + dim_t panel_len, + dim_t panel_dim_max, + dim_t panel_len_max, + void* restrict kappa, + void* restrict d, inc_t incd, + void* restrict a, inc_t inca, inc_t lda, + void* restrict p, inc_t ldp + ); + +#undef GENTFUNC +#define GENTFUNC(ctype,ch,op) \ +static auto PASTEMAC(ch,op) = &packm_diag_ukr; + +INSERT_GENTFUNC_BASIC0(packm_diag_ukr); + +using packm_diag_ukr_vft = decltype(&packm_diag_ukr); +static packm_diag_ukr_vft GENARRAY( packm_diag_ukrs, packm_diag_ukr ); + +/* + * Structure which includes all additional information beyond what is + * already stored in the obj_t structure. + * + * This structure is **read-only** during the operation! + */ +struct packm_diag_params_t +{ + void* d; + inc_t incd; + + packm_diag_params_t() {} + + packm_diag_params_t( void* d, inc_t incd ) + : d(d), incd(incd) {} +}; + +/* + * Selecting a different kernel based on the current architecture is + * currently not possible, but is something we plan to support. 
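+ * (For now, dispatch is by datatype only, through the packm_diag_ukrs
+ * table defined above.)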
diff --git a/test/obj_t_makeover/syrk_diagonal_example2.cxx b/test/obj_t_makeover/syrk_diagonal_example2.cxx
new file mode 100644
index 0000000000..8312a07ee8
--- /dev/null
+++ b/test/obj_t_makeover/syrk_diagonal_example2.cxx
@@ -0,0 +1,338 @@
+#include "syrk_diagonal_ref.h"
+
+/*
+ * Forward-declare the pack kernel type and set up an array of
+ * packing kernels, one for each data type.
+ */
+template <typename T>
+void packm_diag_ukr
+     (
+       bool  conja,
+       dim_t panel_dim,
+       dim_t panel_len,
+       dim_t panel_dim_max,
+       dim_t panel_len_max,
+       void* restrict kappa,
+       void* restrict d, inc_t incd,
+       void* restrict a, inc_t inca, inc_t lda,
+       void* restrict p, inc_t ldp
+     );
+
+#undef GENTFUNC
+#define GENTFUNC(ctype,ch,op) \
+static auto PASTEMAC(ch,op) = &packm_diag_ukr<ctype>;
+
+INSERT_GENTFUNC_BASIC0(packm_diag_ukr);
+
+using packm_diag_ukr_vft = decltype(&packm_diag_ukr<float>);
+static packm_diag_ukr_vft GENARRAY( packm_diag_ukrs, packm_diag_ukr );
+
+/*
+ * Structure which includes all additional information beyond what is
+ * already stored in the obj_t structure.
+ *
+ * This structure is **read-only** during the operation!
+ */
+struct packm_diag_params_t
+{
+    void* d;
+    inc_t incd;
+
+    packm_diag_params_t() {}
+
+    packm_diag_params_t( void* d, inc_t incd )
+    : d(d), incd(incd) {}
+};
+
+/*
+ * Selecting a different kernel based on the current architecture is
+ * currently not possible, but is something we plan to support.
+ */
+template <typename T>
+void packm_diag_ukr
+     (
+       bool  conja,
+       dim_t panel_dim,
+       dim_t panel_len,
+       dim_t panel_dim_max,
+       dim_t panel_len_max,
+       void* restrict kappa,
+       void* restrict d, inc_t incd,
+       void* restrict a, inc_t inca, inc_t lda,
+       void* restrict p, inc_t ldp
+     )
+{
+    T* restrict a_cast = ( T* )a;
+    T* restrict p_cast = ( T* )p;
+    T* restrict d_cast = ( T* )d;
+    auto kappa_cast = *( T* )kappa;
+
+    if ( conja )
+    {
+        for ( dim_t j = 0; j < panel_len; j++ )
+        {
+            auto kappa_d = kappa_cast * d_cast[ j*incd ];
+
+            for ( dim_t i = 0; i < panel_dim; i++ )
+                p_cast[ i + j*ldp ] = kappa_d * conj( a_cast[ i*inca + j*lda ] );
+
+            for ( dim_t i = panel_dim; i < panel_dim_max; i++ )
+                p_cast[ i + j*ldp ] = convert<T>(0.0);
+        }
+    }
+    else
+    {
+        for ( dim_t j = 0; j < panel_len; j++ )
+        {
+            auto kappa_d = kappa_cast * d_cast[ j*incd ];
+
+            for ( dim_t i = 0; i < panel_dim; i++ )
+                p_cast[ i + j*ldp ] = kappa_d * a_cast[ i*inca + j*lda ];
+
+            for ( dim_t i = panel_dim; i < panel_dim_max; i++ )
+                p_cast[ i + j*ldp ] = convert<T>(0.0);
+        }
+    }
+
+    for ( dim_t j = panel_len; j < panel_len_max; j++ )
+        for ( dim_t i = 0; i < panel_dim_max; i++ )
+            p_cast[ i + j*ldp ] = convert<T>(0.0);
+}
+
+void packm_diag
+     (
+       obj_t*     a,
+       obj_t*     p,
+       cntx_t*    cntx,
+       rntm_t*    rntm,
+       cntl_t*    cntl,
+       thrinfo_t* thread
+     )
+{
+    // We begin by copying the fields of A.
+    bli_obj_alias_to( a, p );
+
+    // Get information about data types.
+    num_t dt        = bli_obj_dt( a );
+    num_t dt_tar    = bli_obj_target_dt( a );
+    num_t dt_scalar = bli_obj_scalar_dt( a );
+    dim_t dt_size   = bli_dt_size( dt );
+
+    if ( dt_scalar != dt || dt_tar != dt )
+        bli_abort();
+
+    // Extract various fields from the control tree.
+    bszid_t bmult_id_m   = bli_cntl_packm_params_bmid_m( cntl );
+    bszid_t bmult_id_n   = bli_cntl_packm_params_bmid_n( cntl );
+    pack_t  schema       = bli_cntl_packm_params_pack_schema( cntl );
+    dim_t   bmult_m_def  = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_m, cntx );
+    dim_t   bmult_m_pack = bli_cntx_get_blksz_max_dt( dt_tar, bmult_id_m, cntx );
+    dim_t   bmult_n_def  = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_n, cntx );
+
+    if ( schema != BLIS_PACKED_ROW_PANELS &&
+         schema != BLIS_PACKED_COL_PANELS )
+        bli_abort();
+
+    // Store the pack schema to the object.
+    bli_obj_set_pack_schema( schema, p );
+
+    // Clear the conjugation field from the object since matrix packing
+    // in BLIS is deemed to take care of all conjugation necessary.
+    bli_obj_set_conj( BLIS_NO_CONJUGATE, p );
+
+    // If we are packing micropanels, mark P as dense.
+    bli_obj_set_uplo( BLIS_DENSE, p );
+
+    // Reset the view offsets to (0,0).
+    bli_obj_set_offs( 0, 0, p );
+
+    // Compute the dimensions padded by the dimension multiples. These
+    // dimensions will be the dimensions of the packed matrices, including
+    // zero-padding, and will be used by the macro- and micro-kernels.
+    // We compute them by starting with the effective dimensions of A (now
+    // in P) and aligning them to the dimension multiples (typically equal
+    // to register blocksizes). This does waste a little bit of space for
+    // level-2 operations, but that's okay with us.
+    dim_t m_p     = bli_obj_length( p );
+    dim_t n_p     = bli_obj_width( p );
+    dim_t m_p_pad = bli_align_dim_to_mult( m_p, bmult_m_def );
+    dim_t n_p_pad = bli_align_dim_to_mult( n_p, bmult_n_def );
+
+    // Save the padded dimensions into the packed object. It is important
+    // to save these dimensions since they represent the actual dimensions
+    // of the zero-padded matrix.
+ bli_obj_set_padded_dims( m_p_pad, n_p_pad, p ); + + // The "panel stride" of a micropanel packed object is interpreted as + // the distance between the (0,0) element of panel k and the (0,0) + // element of panel k+1. We use the padded width computed above to + // allow for zero-padding (if necessary/desired) along the far end + // of each micropanel (ie: the right edge of the matrix). Zero-padding + // can also occur along the long edge of the last micropanel if the m + // dimension of the matrix is not a whole multiple of MR. + inc_t ps_p = bmult_m_pack * n_p_pad; + + /* Compute the total number of iterations we'll need. */ + dim_t n_iter = m_p_pad / bmult_m_def; + + // Store the strides and panel dimension in P. + bli_obj_set_strides( 1, bmult_m_pack, p ); + bli_obj_set_imag_stride( 1, p ); + bli_obj_set_panel_dim( bmult_m_def, p ); + bli_obj_set_panel_stride( ps_p, p ); + bli_obj_set_panel_length( bmult_m_def, p ); + bli_obj_set_panel_width( n_p, p ); + + // Compute the size of the packed buffer. + siz_t size_p = ps_p * n_iter * dt_size; + if ( size_p == 0 ) return; + + // Update the buffer address in p to point to the buffer associated + // with the mem_t entry acquired from the memory broker (now cached in + // the control tree node). + char* p_cast = (char*)bli_packm_alloc( size_p, rntm, cntl, thread ); + bli_obj_set_buffer( p_cast, p ); + + char* a_cast = (char*)bli_obj_buffer_at_off( a ); + inc_t inca = bli_obj_row_stride( a ); + inc_t lda = bli_obj_col_stride( a ); + dim_t panel_len_off = bli_obj_col_off( a ); + conj_t conja = bli_obj_conj_status( a ); + + auto params = (packm_diag_params_t*)bli_obj_pack_params( a ); + char* d_cast = (char*)params->d; + inc_t incd = params->incd; + + obj_t kappa_local; + char* kappa_cast = (char*)bli_packm_scalar( &kappa_local, p ); + + auto packm_ker_cast = packm_diag_ukrs[ dt ]; + + /* Query the number of threads and thread ids from the current thread's + packm thrinfo_t node. */ + const dim_t nt = bli_thread_n_way( thread ); + const dim_t tid = bli_thread_work_id( thread ); + + /* Determine the thread range and increment using the current thread's + packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir() + will depend on whether slab or round-robin partitioning was requested + at configure-time. */ + dim_t it_start, it_end, it_inc; + bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc ); + + /* Iterate over every logical micropanel in the source matrix. */ + for ( dim_t it = 0; it < n_iter; it += 1 ) + { + dim_t panel_dim_i = bli_min( bmult_m_def, m_p - it*bmult_m_def ); + + char* d_begin = d_cast + panel_len_off*incd*dt_size; + char* a_begin = a_cast + it* bmult_m_def*inca*dt_size; + char* p_begin = p_cast + it* ps_p*dt_size; + + if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) ) + { + packm_ker_cast( conja, + panel_dim_i, + n_p, + bmult_m_def, + n_p_pad, + kappa_cast, + d_begin, incd, + a_begin, inca, lda, + p_begin, bmult_m_pack ); + } + } +} + +/* + * Modify the object A to include information about the diagonal D, + * and imbue it with special function pointers which will take care + * of the actual work of forming (D * A^T) + */ +void attach_diagonal_factor( packm_diag_params_t* params, obj_t* d, obj_t* a ) +{ + // Assumes D is a column vector + new (params) packm_diag_params_t + ( + bli_obj_buffer_at_off( d ), + bli_obj_row_stride( d ) + ); + + // Set the custom pack function. + bli_obj_set_pack_fn( packm_diag, a ); + + // Attach the parameters to the A object. 
+    bli_obj_set_pack_params( params, a );
+}
+
+/*
+ * Implements C := alpha * A * D * A^T + beta * C
+ *
+ * where D is a diagonal matrix with elements taken from the "d" vector.
+ */
+void syrk_diag( obj_t* alpha, obj_t* a, obj_t* d, obj_t* beta, obj_t* c )
+{
+    obj_t ad; // this is (D * A^T)
+    packm_diag_params_t params;
+
+    bli_obj_alias_to( a, &ad );
+    bli_obj_toggle_trans( &ad ); // because gemmt is A*B instead of A*B^T
+    attach_diagonal_factor( &params, d, &ad );
+
+    // Does C := alpha * A * B + beta * C using B = (D * A^T)
+    bli_gemmt( alpha, a, &ad, beta, c );
+}
+
+int main()
+{
+    obj_t a;
+    obj_t d;
+    obj_t c;
+    obj_t c_copy;
+    obj_t norm;
+
+    auto m = 10;
+    auto k = 10;
+
+    for ( int dt_ = BLIS_DT_LO; dt_ <= BLIS_DT_HI; dt_++ )
+    for ( int upper = 0; upper <= 1; upper++ )
+    for ( int transa = 0; transa <= 1; transa++ )
+    for ( int transc = 0; transc <= 1; transc++ )
+    {
+        auto dt   = ( num_t )dt_;
+        auto uplo = upper ? BLIS_UPPER : BLIS_LOWER;
+
+        bli_obj_create( dt, m, k, transa ? k : 1, transa ? 1 : m, &a );
+        bli_obj_create( dt, k, 1, 1, 1, &d );
+        bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c );
+        bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c_copy );
+        bli_obj_set_struc( BLIS_SYMMETRIC, &c );
+        bli_obj_set_struc( BLIS_SYMMETRIC, &c_copy );
+        bli_obj_set_uplo( uplo, &c );
+        bli_obj_set_uplo( uplo, &c_copy );
+        bli_obj_create_1x1( bli_dt_proj_to_real( dt ), &norm );
+
+        bli_randm( &a );
+        bli_randm( &d );
+        bli_randm( &c );
+        bli_copym( &c, &c_copy );
+
+        syrk_diag( &BLIS_ONE, &a, &d, &BLIS_ONE, &c );
+        syrk_diag_ref( &BLIS_ONE, &a, &d, &BLIS_ONE, &c_copy );
+
+        bli_subm( &c_copy, &c );
+        bli_normfm( &c, &norm );
+
+        double normr, normi;
+        bli_getsc( &norm, &normr, &normi );
+
+        printf("dt: %d, upper: %d, transa: %d, transc: %d, norm: %g\n",
+               dt, upper, transa, transc, normr);
+
+        bli_obj_free( &a );
+        bli_obj_free( &d );
+        bli_obj_free( &c );
+        bli_obj_free( &c_copy );
+        bli_obj_free( &norm );
+    }
+}
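Written elementwise, the update that syrk_diag drives through bli_gemmt (and that the reference implementation in the next file checks against) is, for each (i,j) in the stored triangle of C:

    C(i,j) := alpha * sum_p A(i,p) * d(p) * A(j,p) + beta * C(i,j)

gemmt only touches the triangle selected by the uplo of C, which is what makes the D-scaled transpose of A usable as the B operand here.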
diff --git a/test/obj_t_makeover/syrk_diagonal_ref.cxx b/test/obj_t_makeover/syrk_diagonal_ref.cxx
new file mode 100644
index 0000000000..1d7c5d96e5
--- /dev/null
+++ b/test/obj_t_makeover/syrk_diagonal_ref.cxx
@@ -0,0 +1,102 @@
+#include "syrk_diagonal_ref.h"
+#include "complex_math.hpp"
+
+typedef void (*syrk_diag_ref_vft)
+    (
+      uplo_t uplo,
+      dim_t  m,
+      dim_t  k,
+      void*  alpha,
+      void*  a, inc_t rs_a, inc_t cs_a,
+      void*  d, inc_t incd,
+      void*  beta,
+      void*  c, inc_t rs_c, inc_t cs_c
+    );
+
+template <typename T>
+void syrk_diag_ref
+    (
+      uplo_t uplo,
+      dim_t  m,
+      dim_t  k,
+      void*  alpha,
+      void*  a, inc_t rs_a, inc_t cs_a,
+      void*  d, inc_t incd,
+      void*  beta,
+      void*  c, inc_t rs_c, inc_t cs_c
+    )
+{
+    auto alpha_cast = *( T* )alpha;
+    auto beta_cast  = *( T* )beta;
+    auto a_cast     =  ( T* )a;
+    auto d_cast     =  ( T* )d;
+    auto c_cast     =  ( T* )c;
+
+    for ( dim_t i = 0; i < m; i++ )
+    {
+        dim_t j_min = uplo == BLIS_UPPER ? i : 0;
+        dim_t j_max = uplo == BLIS_UPPER ? m : i+1;
+
+        for ( dim_t j = j_min; j < j_max; j++ )
+        {
+            auto ada = convert<T>(0.0);
+
+            for ( dim_t p = 0; p < k; p++ )
+            {
+                ada += a_cast[ i*rs_a + p*cs_a ] *
+                       d_cast[ p*incd ] *
+                       a_cast[ j*rs_a + p*cs_a ];
+            }
+
+            if ( beta_cast == convert<T>(0.0) )
+            {
+                c_cast[ i*rs_c + j*cs_c ] = alpha_cast * ada;
+            }
+            else
+            {
+                c_cast[ i*rs_c + j*cs_c ] = alpha_cast * ada +
+                                            beta_cast * c_cast[ i*rs_c + j*cs_c ];
+            }
+        }
+    }
+}
+
+#undef GENTFUNC
+#define GENTFUNC(ctype,ch,op) \
+static auto PASTEMAC(ch,op) = &syrk_diag_ref<ctype>;
+
+INSERT_GENTFUNC_BASIC0(syrk_diag_ref);
+
+static syrk_diag_ref_vft GENARRAY( syrk_diag_ref_impl, syrk_diag_ref );
+
+void syrk_diag_ref( obj_t* alpha, obj_t* a, obj_t* d, obj_t* beta, obj_t* c )
+{
+    num_t dt = bli_obj_dt( a );
+
+    dim_t m = bli_obj_length_after_trans( a );
+    dim_t k = bli_obj_width_after_trans( a );
+
+    inc_t rs_a = bli_obj_row_stride( a );
+    inc_t cs_a = bli_obj_col_stride( a );
+    inc_t rs_c = bli_obj_row_stride( c );
+    inc_t cs_c = bli_obj_col_stride( c );
+    inc_t incd = bli_obj_row_stride( d );
+
+    if ( bli_obj_has_trans( a ) )
+        bli_swap_incs( &rs_a, &cs_a );
+
+    if ( bli_obj_has_trans( c ) )
+        bli_swap_incs( &rs_c, &cs_c );
+
+    syrk_diag_ref_impl[ dt ]
+    (
+      bli_obj_uplo( c ),
+      m, k,
+      bli_obj_buffer_for_1x1( dt, alpha ),
+      bli_obj_buffer_at_off( a ), rs_a, cs_a,
+      bli_obj_buffer_at_off( d ), incd,
+      bli_obj_buffer_for_1x1( dt, beta ),
+      bli_obj_buffer_at_off( c ), rs_c, cs_c
+    );
+}
+
diff --git a/test/obj_t_makeover/syrk_diagonal_ref.h b/test/obj_t_makeover/syrk_diagonal_ref.h
new file mode 100644
index 0000000000..a6864caec8
--- /dev/null
+++ b/test/obj_t_makeover/syrk_diagonal_ref.h
@@ -0,0 +1,8 @@
+#include "blis.h"
+
+#ifdef __cplusplus
+#include "complex_math.hpp"
+extern "C"
+#endif
+void syrk_diag_ref( obj_t* alpha, obj_t* a, obj_t* d, obj_t* beta, obj_t* c );
+
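The templated kernels above are instantiated for BLIS's scomplex and dcomplex, which are plain C structs with no arithmetic of their own; complex_math.hpp (below, an identical copy of the header used by the syrk_diagonal examples) supplies the operators and conversion helpers that let real and complex types be handled uniformly. A small usage sketch, assuming the header is included:

    dcomplex x  = { 1.0, 2.0 };
    dcomplex y  = { 3.0, -1.0 };
    dcomplex z  = x * y + conj( x );            // {5,5} + {1,-2} == {6,3}
    bool     ok = ( z == dcomplex{ 6.0, 3.0 } );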
diff --git a/test/tensor_contraction/complex_math.hpp b/test/tensor_contraction/complex_math.hpp
new file mode 100644
index 0000000000..9c68e730aa
--- /dev/null
+++ b/test/tensor_contraction/complex_math.hpp
@@ -0,0 +1,267 @@
+#include <algorithm>
+#include <cmath>
+#include <type_traits>
+
+#include "blis.h"
+
+template <typename T>
+struct is_complex : std::false_type {};
+
+template <>
+struct is_complex<scomplex> : std::true_type {};
+
+template <>
+struct is_complex<dcomplex> : std::true_type {};
+
+template <typename T>
+struct is_real : std::integral_constant<bool,!is_complex<T>::value> {};
+
+template <typename T> struct make_complex;
+
+template <> struct make_complex<float>    { using type = scomplex; };
+template <> struct make_complex<double>   { using type = dcomplex; };
+template <> struct make_complex<scomplex> { using type = scomplex; };
+template <> struct make_complex<dcomplex> { using type = dcomplex; };
+
+template <typename T>
+using make_complex_t = typename make_complex<T>::type;
+
+template <typename T> struct make_real;
+
+template <> struct make_real<float>    { using type = float; };
+template <> struct make_real<double>   { using type = double; };
+template <> struct make_real<scomplex> { using type = float; };
+template <> struct make_real<dcomplex> { using type = double; };
+
+template <typename T>
+using make_real_t = typename make_real<T>::type;
+
+template <typename T, bool Complex>
+struct make_complex_if : std::conditional<Complex,make_complex_t<T>,make_real_t<T>> {};
+
+template <typename T, bool Complex>
+using make_complex_if_t = typename make_complex_if<T,Complex>::type;
+
+template <typename T>
+struct real_imag_part
+{
+    real_imag_part& operator=(T) { return *this; }
+
+    operator T() const { return T(); }
+};
+
+template <typename T>
+std::enable_if_t<is_real<typename std::decay<T>::type>::value,T&> real(T& x) { return x; }
+
+template <typename T>
+std::enable_if_t<is_real<T>::value,real_imag_part<T>> imag(T x) { return {}; }
+
+inline float& real(scomplex& x) { return x.real; }
+
+inline float& imag(scomplex& x) { return x.imag; }
+
+inline double& real(dcomplex& x) { return x.real; }
+
+inline double& imag(dcomplex& x) { return x.imag; }
+
+inline const float& real(const scomplex& x) { return x.real; }
+
+inline const float& imag(const scomplex& x) { return x.imag; }
+
+inline const double& real(const dcomplex& x) { return x.real; }
+
+inline const double& imag(const dcomplex& x) { return x.imag; }
+
+template <typename T>
+std::enable_if_t<is_real<T>::value,T> conj(T x) { return x; }
+
+template <typename T>
+std::enable_if_t<is_complex<T>::value,T> conj(const T& x) { return {x.real, -x.imag}; }
+
+template <typename T, typename U, typename Enable = void>
+struct convert_impl;
+
+template <typename T, typename U>
+struct convert_impl<T,U,std::enable_if_t<is_real<T>::value && is_real<U>::value>>
+{
+    void operator()(T x, U& y) const { y = x; }
+};
+
+template <typename T, typename U>
+struct convert_impl<T,U,std::enable_if_t<is_real<T>::value && is_complex<U>::value>>
+{
+    void operator()(T x, U& y) const { y.real = x; y.imag = 0; }
+};
+
+template <typename T, typename U>
+struct convert_impl<T,U,std::enable_if_t<is_complex<T>::value && is_real<U>::value>>
+{
+    void operator()(T x, U& y) const { y = x.real; }
+};
+
+template <typename T, typename U>
+struct convert_impl<T,U,std::enable_if_t<is_complex<T>::value && is_complex<U>::value>>
+{
+    void operator()(T x, U& y) const { y.real = x.real; y.imag = x.imag; }
+};
+
+template <typename U, typename T>
+U convert(T x)
+{
+    U y;
+    convert_impl<T,U>{}(x,y);
+    return y;
+}
+
+template <typename U, typename T>
+auto convert_prec(T x) -> make_complex_if_t<U,is_complex<T>::value>
+{
+    return convert<make_complex_if_t<U,is_complex<T>::value>>(x);
+}
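+
+// NOTE (illustrative sketch, not used by this file): the enable_if-constrained
+// partial specializations of convert_impl select the conversion at compile
+// time, e.g.
+//
+//     dcomplex z = convert<dcomplex>( 2.0 );  // real -> complex: z == {2.0, 0.0}
+//     double   r = convert<double>( z );      // complex -> real: keeps z.real only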
+
+#define COMPLEX_MATH_OPS(rtype, ctype) \
+\
+inline bool operator==(rtype x, ctype y) \
+{ \
+    return x == y.real && y.imag == 0; \
+} \
+\
+inline bool operator==(ctype x, rtype y) \
+{ \
+    return y == x.real && x.imag == 0; \
+} \
+\
+inline bool operator==(ctype x, ctype y) \
+{ \
+    return x.real == y.real && \
+           x.imag == y.imag; \
+} \
+\
+inline ctype operator-(ctype x) \
+{ \
+    return {-x.real, -x.imag}; \
+} \
+\
+inline ctype operator+(rtype x, ctype y) \
+{ \
+    return {x+y.real, y.imag}; \
+} \
+\
+inline ctype operator+(ctype x, rtype y) \
+{ \
+    return {y+x.real, x.imag}; \
+} \
+\
+inline ctype operator+(ctype x, ctype y) \
+{ \
+    return {x.real+y.real, x.imag+y.imag}; \
+} \
+\
+inline ctype operator-(rtype x, ctype y) \
+{ \
+    return {x-y.real, -y.imag}; \
+} \
+\
+inline ctype operator-(ctype x, rtype y) \
+{ \
+    return {x.real-y, x.imag}; \
+} \
+\
+inline ctype operator-(ctype x, ctype y) \
+{ \
+    return {x.real-y.real, x.imag-y.imag}; \
+} \
+\
+inline ctype operator*(rtype x, ctype y) \
+{ \
+    return {x*y.real, x*y.imag}; \
+} \
+\
+inline ctype operator*(ctype x, rtype y) \
+{ \
+    return {y*x.real, y*x.imag}; \
+} \
+\
+inline ctype operator*(ctype x, ctype y) \
+{ \
+    return {x.real*y.real - x.imag*y.imag, \
+            x.real*y.imag + x.imag*y.real}; \
+} \
+\
+inline ctype operator/(rtype x, ctype y) \
+{ \
+    auto scale = std::max(std::abs(y.real), std::abs(y.imag)); \
+    auto n = std::ilogb(scale); \
+    auto yrs = std::scalbn(y.real, -n); \
+    auto yis = std::scalbn(y.imag, -n); \
+    auto denom = y.real*yrs + y.imag*yis; \
+    return {x*yrs/denom, -x*yis/denom}; \
+} \
+\
+inline ctype operator/(ctype x, rtype y) \
+{ \
+    return {x.real/y, x.imag/y}; \
+} \
+\
+inline ctype operator/(ctype x, ctype y) \
+{ \
+    auto scale = std::max(std::abs(y.real), std::abs(y.imag)); \
+    auto n = std::ilogb(scale); \
+    auto yrs = std::scalbn(y.real, -n); \
+    auto yis = std::scalbn(y.imag, -n); \
+    auto denom = y.real*yrs + y.imag*yis; \
+    return {(x.real*yrs + x.imag*yis)/denom, \
+            (x.imag*yrs - x.real*yis)/denom}; \
+} \
+\
+inline ctype& operator+=(ctype& x, rtype y) \
+{ \
+    x.real += y; \
+    return x; \
+} \
+\
+inline ctype& operator+=(ctype& x, ctype y) \
+{ \
+    x.real += y.real; x.imag += y.imag; \
+    return x; \
+} \
+\
+inline ctype& operator-=(ctype& x, rtype y) \
+{ \
+    x.real -= y; \
+    return x; \
+} \
+\
+inline ctype& operator-=(ctype& x, ctype y) \
+{ \
+    x.real -= y.real; x.imag -= y.imag; \
+    return x; \
+} \
+\
+inline ctype& operator*=(ctype& x, rtype y) \
+{ \
+    x.real *= y; x.imag *= y; \
+    return x; \
+} \
+\
+inline ctype& operator*=(ctype& x, ctype y) \
+{ \
+    x = x * y; \
+    return x; \
+} \
+\
+inline ctype& operator/=(ctype& x, rtype y) \
+{ \
+    x.real /= y; x.imag /= y; \
+    return x; \
+} \
+\
+inline ctype& operator/=(ctype& x, ctype y) \
+{ \
+    x = x / y; \
+    return x; \
+}
+
+COMPLEX_MATH_OPS(float, scomplex);
+COMPLEX_MATH_OPS(double, dcomplex);
+
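The example below maps a general tensor contraction onto a single BLIS gemm call: each operand's indices are split into m-, n-, and k-groups, every group is flattened into one matrix dimension, and scatter/block-scatter vectors stand in for the row and column strides a true matrix would have. With two indices per group, as in its main(), the computation is

    C[m0,m1,n0,n1] := sum over (k0,k1) of A[m0,m1,k0,k1] * B[k0,k1,n0,n1]

viewed as an (|m0|*|m1|) x (|n0|*|n1|) matrix product with inner dimension |k0|*|k1|.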
diff --git a/test/tensor_contraction/tcontract_example.cxx b/test/tensor_contraction/tcontract_example.cxx
new file mode 100644
index 0000000000..82afc219e0
--- /dev/null
+++ b/test/tensor_contraction/tcontract_example.cxx
@@ -0,0 +1,987 @@
+
+#include "tcontract_ref.hpp"
+
+#include <algorithm>
+#include <numeric>
+
+static constexpr dim_t BS_K = 8;
+
+struct packm_tensor_params_t
+{
+    gint_t ndim_m, ndim_n;
+    const dim_t *len_m, *len_n;
+    const inc_t *stride_m, *stride_n;
+
+    packm_tensor_params_t() {}
+
+    packm_tensor_params_t( gint_t ndim_m, const dim_t* len_m, const inc_t* stride_m,
+                           gint_t ndim_n, const dim_t* len_n, const inc_t* stride_n )
+    : ndim_m(ndim_m), ndim_n(ndim_n),
+      len_m(len_m), len_n(len_n),
+      stride_m(stride_m), stride_n(stride_n) {}
+};
+
+using gemm_tensor_params_t = packm_tensor_params_t;
+
+template <typename T>
+void packm_ckx_nb
+     (
+       bool  conja,
+       dim_t panel_dim,
+       dim_t panel_len,
+       dim_t panel_dim_max,
+       dim_t panel_len_max,
+       void* kappa,
+       void* a, inc_t inca, inc_t* bsa, inc_t* scata,
+       void* p, inc_t ldp
+     )
+{
+    T* restrict a_cast = ( T* )a;
+    T* restrict p_cast = ( T* )p;
+    auto kappa_cast = *( T* )kappa;
+
+    if ( conja )
+    {
+        for ( auto j0 = 0; j0 < panel_len; j0 += BS_K, bsa += BS_K, scata += BS_K )
+        {
+            auto lda = *bsa;
+            auto panel_len_j = std::min( panel_len-j0, BS_K );
+
+            if ( lda )
+            {
+                T* restrict aj = a_cast + *scata;
+
+                for ( auto j = 0; j < panel_len_j; j++ )
+                {
+                    for ( auto i = 0; i < panel_dim; i++ )
+                        p_cast[ i ] = kappa_cast * conj( aj[ i*inca + j*lda ] );
+
+                    for ( auto i = panel_dim; i < panel_dim_max; i++ )
+                        p_cast[ i ] = convert<T>(0.0);
+
+                    p_cast += ldp;
+                }
+            }
+            else
+            {
+                for ( auto j = 0; j < panel_len_j; j++ )
+                {
+                    for ( auto i = 0; i < panel_dim; i++ )
+                        p_cast[ i ] = kappa_cast * conj( a_cast[ i*inca + scata[j] ] );
+
+                    for ( auto i = panel_dim; i < panel_dim_max; i++ )
+                        p_cast[ i ] = convert<T>(0.0);
+
+                    p_cast += ldp;
+                }
+            }
+        }
+    }
+    else
+    {
+        for ( auto j0 = 0; j0 < panel_len; j0 += BS_K, bsa += BS_K, scata += BS_K )
+        {
+            auto lda = *bsa;
+            auto panel_len_j = std::min( panel_len-j0, BS_K );
+
+            if ( lda )
+            {
+                T* restrict aj = a_cast + *scata;
+
+                for ( auto j = 0; j < panel_len_j; j++ )
+                {
+                    for ( auto i = 0; i < panel_dim; i++ )
+                        p_cast[ i ] = kappa_cast * aj[ i*inca + j*lda ];
+
+                    for ( auto i = panel_dim; i < panel_dim_max; i++ )
+                        p_cast[ i ] = convert<T>(0.0);
+
+                    p_cast += ldp;
+                }
+            }
+            else
+            {
+                for ( auto j = 0; j < panel_len_j; j++ )
+                {
+                    for ( auto i = 0; i < panel_dim; i++ )
+                        p_cast[ i ] = kappa_cast * a_cast[ i*inca + scata[j] ];
+
+                    for ( auto i = panel_dim; i < panel_dim_max; i++ )
+                        p_cast[ i ] = convert<T>(0.0);
+
+                    p_cast += ldp;
+                }
+            }
+        }
+    }
+
+    for ( auto j = panel_len; j < panel_len_max; j++ )
+    {
+        for ( auto i = 0; i < panel_dim_max; i++ )
+            p_cast[ i ] = convert<T>(0.0);
+
+        p_cast += ldp;
+    }
+}
+
+template <typename T>
+void packm_ckx_ss
+     (
+       bool  conja,
+       dim_t panel_dim,
+       dim_t panel_len,
+       dim_t panel_dim_max,
+       dim_t panel_len_max,
+       void* kappa,
+       void* a, inc_t* inca, inc_t* scata,
+       void* p, inc_t ldp
+     )
+{
+    T* restrict a_cast = ( T* )a;
+    T* restrict p_cast = ( T* )p;
+    auto kappa_cast = *( T* )kappa;
+
+    if ( conja )
+    {
+        for ( dim_t j = 0; j < panel_len; j++ )
+        {
+            for ( dim_t i = 0; i < panel_dim; i++ )
+                p_cast[ i ] = kappa_cast * conj( a_cast[ inca[i] + scata[j] ] );
+
+            for ( dim_t i = panel_dim; i < panel_dim_max; i++ )
+                p_cast[ i ] = convert<T>(0.0);
+
+            p_cast += ldp;
+        }
+    }
+    else
+    {
+        for ( dim_t j = 0; j < panel_len; j++ )
+        {
+            for ( dim_t i = 0; i < panel_dim; i++ )
+                p_cast[ i ] = kappa_cast * a_cast[ inca[i] + scata[j] ];
+
+            for ( dim_t i = panel_dim; i < panel_dim_max; i++ )
+                p_cast[ i ] = convert<T>(0.0);
+
+            p_cast += ldp;
+        }
+    }
+
+    for ( dim_t j = panel_len; j < panel_len_max; j++ )
+    {
+        for ( dim_t i = 0; i < panel_dim_max; i++ )
+            p_cast[ i ] = convert<T>(0.0);
+
+        p_cast += ldp;
+    }
+}
+
+#undef GENTFUNC
+#define GENTFUNC(ctype,ch,op) \
+static auto PASTEMAC(ch,op) = &packm_ckx_nb<ctype>;
+
+INSERT_GENTFUNC_BASIC0(packm_ckx_nb);
+
+#undef GENTFUNC
+#define GENTFUNC(ctype,ch,op) \
+static auto PASTEMAC(ch,op) = &packm_ckx_ss<ctype>;
+
+INSERT_GENTFUNC_BASIC0(packm_ckx_ss);
+
+static decltype(&packm_ckx_nb<float>) GENARRAY( packm_ckx_nb_ukrs, packm_ckx_nb );
+static decltype(&packm_ckx_ss<float>) GENARRAY( packm_ckx_ss_ukrs, packm_ckx_ss );
+
+static void fill_scatter
+     (
+       gint_t ndim,
+       const dim_t* restrict len,
+       const inc_t* restrict stride,
+       dim_t BS,
+       inc_t off,
+       dim_t size,
+       inc_t* restrict scat,
+       inc_t* restrict bs
+     )
+{
+    if ( size == 0 ) return;
+
+    if ( ndim == 0 )
+    {
+        *scat = 0;
+        *bs = 0;
+        return;
+    }
+
+    if ( ndim == 1 )
+    {
+        // NOTE: honor the requested offset and size here, and return early so
+        // that the general case below doesn't traverse the vector again.
+        auto s = *stride;
+        for ( dim_t i = 0; i < size; i++ )
+            scat[i] = (off+i)*s;
+        for ( dim_t i = 0; i < size; i += BS )
+            bs[i] = s;
+        return;
+    }
+
+    dim_t tot_len = 1;
+    for ( auto i = 0; i < ndim; i++ )
+        tot_len *= len[i];
+
+    assert(off >= 0);
+    assert(size >= 0);
+    assert(off+size <= tot_len);
+
+    auto len0 = len[0];
+    auto stride0 = stride[0];
+    auto off0 = off % len0;
+    auto off1 = off / len0;
+    auto size1 = ( size + off0 + len0 - 1 ) / len0;
+
+    inc_t pos1 = 0;
+    inc_t idx = 0;
+    for_each( ndim-1, len+1, off1, size1, pos1, stride+1,
+    [&]
+    {
+        auto pos = pos1 + off0 * stride0;
+        auto len_i = std::min( len0-off0, size-idx );
+        for ( auto i = 0; i < len_i; i++ )
+        {
+            scat[idx++] = pos;
+            pos += stride0;
+        }
+        off0 = 0;
+    });
+    assert(idx == size);
+
+    for ( idx = 0; idx < size; idx += BS )
+    {
+        auto len_i = std::min( BS, size-idx );
+        auto s = stride0;
+
+        for ( auto i = idx; i < idx+len_i-1; i++ )
+        {
+            if (scat[i+1]-scat[i] != s)
+            {
+                s = 0;
+                break;
+            }
+        }
+
+        bs[idx] = s;
+    }
+}
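+
+// NOTE (illustrative sketch, not used by this file): for a dimension pair with
+// lengths (2,3) and strides (3,1), flattened to m = 6 rows with the first
+// index varying fastest, fill_scatter yields
+//
+//     scat = { 0, 3, 1, 4, 2, 5 }
+//
+// No single leading dimension describes this layout, but within each block of
+// BS = 2 rows the spacing is uniformly 3, so the block-scatter vector holds
+// bs[0] = bs[2] = bs[4] = 3 and the packing kernel can use plain strided
+// loads inside each block (a zero entry would force the scatter path).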
+
+void packm_tensor
+     (
+       obj_t*     a,
+       obj_t*     p,
+       cntx_t*    cntx,
+       rntm_t*    rntm,
+       cntl_t*    cntl,
+       thrinfo_t* thread
+     )
+{
+    // We begin by copying the fields of A.
+    bli_obj_alias_to( a, p );
+
+    // Get information about data types.
+    auto dt        = bli_obj_dt( a );
+    auto dt_tar    = bli_obj_target_dt( a );
+    auto dt_scalar = bli_obj_scalar_dt( a );
+    auto dt_size   = bli_dt_size( dt );
+
+    if ( dt_scalar != dt || dt_tar != dt )
+        bli_abort();
+
+    // Extract various fields from the control tree.
+    auto bmult_id_m   = bli_cntl_packm_params_bmid_m( cntl );
+    auto bmult_id_n   = bli_cntl_packm_params_bmid_n( cntl );
+    auto schema       = bli_cntl_packm_params_pack_schema( cntl );
+    auto bmult_m_def  = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_m, cntx );
+    auto bmult_m_pack = bli_cntx_get_blksz_max_dt( dt_tar, bmult_id_m, cntx );
+    auto bmult_n_def  = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_n, cntx );
+
+    if ( schema != BLIS_PACKED_ROW_PANELS &&
+         schema != BLIS_PACKED_COL_PANELS )
+        bli_abort();
+
+    // Store the pack schema to the object.
+    bli_obj_set_pack_schema( schema, p );
+
+    // Clear the conjugation field from the object since matrix packing
+    // in BLIS is deemed to take care of all conjugation necessary.
+    bli_obj_set_conj( BLIS_NO_CONJUGATE, p );
+
+    // If we are packing micropanels, mark P as dense.
+    bli_obj_set_uplo( BLIS_DENSE, p );
+
+    // Reset the view offsets to (0,0).
+    bli_obj_set_offs( 0, 0, p );
+
+    // Compute the dimensions padded by the dimension multiples. These
+    // dimensions will be the dimensions of the packed matrices, including
+    // zero-padding, and will be used by the macro- and micro-kernels.
+    // We compute them by starting with the effective dimensions of A (now
+    // in P) and aligning them to the dimension multiples (typically equal
+    // to register blocksizes). This does waste a little bit of space for
+    // level-2 operations, but that's okay with us.
+    auto m_p     = bli_obj_length( p );
+    auto n_p     = bli_obj_width( p );
+    auto m_p_pad = bli_align_dim_to_mult( m_p, bmult_m_def );
+    auto n_p_pad = bli_align_dim_to_mult( n_p, bmult_n_def );
+
+    // Save the padded dimensions into the packed object. It is important
+    // to save these dimensions since they represent the actual dimensions
+    // of the zero-padded matrix.
+    bli_obj_set_padded_dims( m_p_pad, n_p_pad, p );
+
+    // The "panel stride" of a micropanel packed object is interpreted as
+    // the distance between the (0,0) element of panel k and the (0,0)
+    // element of panel k+1. We use the padded width computed above to
+    // allow for zero-padding (if necessary/desired) along the far end
+    // of each micropanel (ie: the right edge of the matrix). Zero-padding
+    // can also occur along the long edge of the last micropanel if the m
+    // dimension of the matrix is not a whole multiple of MR.
+    auto ps_p = bmult_m_pack * n_p_pad;
+
+    /* Compute the total number of iterations we'll need. */
+    auto n_iter = m_p_pad / bmult_m_def;
+
+    // Store the strides and panel dimension in P.
+    bli_obj_set_strides( 1, bmult_m_pack, p );
+    bli_obj_set_imag_stride( 1, p );
+    bli_obj_set_panel_dim( bmult_m_def, p );
+    bli_obj_set_panel_stride( ps_p, p );
+    bli_obj_set_panel_length( bmult_m_def, p );
+    bli_obj_set_panel_width( n_p, p );
+
+    // Compute the size of the packed buffer.
+    auto size_p = ps_p * n_iter * dt_size;
+    if ( size_p == 0 ) return;
+
+    // Add the size of the scatter and block-scatter vectors to the total.
+    // It is never necessary to add padding for alignment because:
+    // 1) ps_p is always even
+    // 2) dt_size is a power of two >= 4
+    // 3) the alignment of the scatter vectors is at most 8
+    auto scat_size = 2 * (m_p + n_p) * sizeof(inc_t);
+
+    // Update the buffer address in p to point to the buffer associated
+    // with the mem_t entry acquired from the memory broker (now cached in
+    // the control tree node).
+    auto p_cast = (char*)bli_packm_alloc( size_p + scat_size, rntm, cntl, thread );
+    bli_obj_set_buffer( p_cast, p );
+
+    // Get the addresses of the scatter and block-scatter vectors.
These are + // placed directly after the packed matrix buffer. + auto rscat = (inc_t*)(p_cast + size_p); + auto rbs = rscat + m_p; + auto cscat = rbs + m_p; + auto cbs = cscat + n_p; + + auto a_cast = (char*)bli_obj_buffer_at_off( a ); + auto panel_dim_off = bli_obj_row_off( a ); + auto panel_len_off = bli_obj_col_off( a ); + auto conja = bli_obj_conj_status( a ); + + auto params = (packm_tensor_params_t*)bli_obj_pack_params( a ); + auto ndim_m = params->ndim_m; + auto ndim_n = params->ndim_n; + auto len_m = params->len_m; + auto len_n = params->len_n; + auto stride_m = params->stride_m; + auto stride_n = params->stride_n; + + obj_t kappa_local; + auto kappa_cast = (char*)bli_packm_scalar( &kappa_local, p ); + + auto packm_nb_ker = packm_ckx_nb_ukrs[ dt ]; + auto packm_ss_ker = packm_ckx_ss_ukrs[ dt ]; + + a_cast -= ( panel_dim_off * stride_m[0] + + panel_len_off * stride_n[0] ) * dt_size; + + /* Fill in the scatter and block-scatter vectors. This is done single-threaded for now. */ + if ( bli_thread_am_ochief( thread ) ) + { + fill_scatter + ( + ndim_m, + len_m, + stride_m, + bmult_m_def, + panel_dim_off, + m_p, + rscat, + rbs + ); + + fill_scatter + ( + ndim_n, + len_n, + stride_n, + BS_K, + panel_len_off, + n_p, + cscat, + cbs + ); + } + + /* Wait for the scatter vectors to be done. */ + bli_thread_barrier( thread ); + + /* Query the number of threads and thread ids from the current thread's + packm thrinfo_t node. */ + auto nt = bli_thread_n_way( thread ); + auto tid = bli_thread_work_id( thread ); + + /* Determine the thread range and increment using the current thread's + packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir() + will depend on whether slab or round-robin partitioning was requested + at configure-time. */ + dim_t it_start, it_end, it_inc; + bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc ); + + /* Iterate over every logical micropanel in the source matrix. 
*/ + for ( auto it = 0; it < n_iter; it += 1 ) + if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) ) + { + auto panel_dim_i = bli_min( bmult_m_def, m_p - it*bmult_m_def ); + + auto p_begin = p_cast + it*ps_p*dt_size; + auto inca = rbs[ it*bmult_m_def ]; + + if ( inca ) + { + auto a_begin = a_cast + rscat[ it*bmult_m_def ]*dt_size; + + packm_nb_ker( conja, + panel_dim_i, + n_p, + bmult_m_def, + n_p_pad, + kappa_cast, + a_begin, inca, cbs, cscat, + p_begin, bmult_m_pack ); + } + else + { + auto a_begin = a_cast; + auto rscat_use = rscat + it*bmult_m_def; + + packm_ss_ker( conja, + panel_dim_i, + n_p, + bmult_m_def, + n_p_pad, + kappa_cast, + a_begin, rscat_use, cscat, + p_begin, bmult_m_pack ); + } + } +} + +#undef GENTFUNC +#define GENTFUNC(ctype,ch,op) \ +void PASTEMAC(ch,op) \ + ( \ + dim_t m, \ + dim_t n, \ + void* x, inc_t rs_x, inc_t cs_x, \ + void* b, \ + void* y, inc_t* rs_y, inc_t* cs_y \ + ) \ +{ \ + ctype* restrict x_cast = (ctype*)x; \ + ctype b_cast = *(ctype*)b; \ + ctype* restrict y_cast = (ctype*)y; \ +\ + if ( PASTEMAC(ch,eq0)( b_cast ) ) \ + { \ + for ( auto i = 0; i < m; i++ ) \ + for ( auto j = 0; j < n; j++ ) \ + PASTEMAC(ch,copys)( x_cast[ i*rs_x + j*cs_x ], y_cast[ rs_y[i] + cs_y[j] ] ); \ + } \ + else \ + { \ + for ( auto i = 0; i < m; i++ ) \ + for ( auto j = 0; j < n; j++ ) \ + PASTEMAC(ch,xpbys)( x_cast[ i*rs_x + j*cs_x ], b_cast, y_cast[ rs_y[i] + cs_y[j] ] ); \ + } \ +} + +INSERT_GENTFUNC_BASIC0(scatter_mxn); + +static decltype(&bli_sscatter_mxn) GENARRAY(scatter_mxn, scatter_mxn); + +void gemm_tensor + ( + obj_t* a, + obj_t* b, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + cntl_t* cntl, + thrinfo_t* thread + ) +{ + auto dt = bli_obj_dt( c ); + auto dt_size = bli_dt_size( dt ); + + auto m = bli_obj_length( c ); + auto n = bli_obj_width( c ); + auto k = bli_obj_width( a ); + + auto a_cast = (char*)bli_obj_buffer_at_off( a ); + auto pd_a = bli_obj_panel_dim( a ); + auto ps_a = bli_obj_panel_stride( a ); + + auto b_cast = (char*)bli_obj_buffer_at_off( b ); + auto pd_b = bli_obj_panel_dim( b ); + auto ps_b = bli_obj_panel_stride( b ); + + auto c_cast = (char*)bli_obj_buffer_at_off( c ); + auto rs_c0 = bli_obj_row_stride( c ); + auto cs_c0 = bli_obj_col_stride( c ); + auto off_m = bli_obj_row_off( c ); + auto off_n = bli_obj_col_off( c ); + + auto params = (gemm_tensor_params_t*)bli_obj_ker_params( c ); + auto ndim_m = params->ndim_m; + auto ndim_n = params->ndim_n; + auto len_m = params->len_m; + auto len_n = params->len_n; + auto stride_m = params->stride_m; + auto stride_n = params->stride_n; + + if ( rs_c0 != stride_m[0] || cs_c0 != stride_n[0] ) + { + std::swap( ndim_m, ndim_n ); + std::swap( len_m, len_n ); + std::swap( stride_m, stride_n ); + } + + /* If any dimension is zero, return immediately. */ + if ( bli_zero_dim3( m, n, k ) ) return; + + c_cast -= ( off_m * stride_m[0] + + off_n * stride_n[0] ) * dt_size; + + // Detach and multiply the scalars attached to A and B. + // NOTE: We know that the internal scalars of A and B are already of the + // target datatypes because the necessary typecasting would have already + // taken place during bli_packm_init(). + obj_t scalar_a; + obj_t scalar_b; + bli_obj_scalar_detach( a, &scalar_a ); + bli_obj_scalar_detach( b, &scalar_b ); + bli_mulsc( &scalar_a, &scalar_b ); + + // Grab the addresses of the internal scalar buffers for the scalar + // merged above and the scalar attached to C. 
+ // NOTE: We know that scalar_b is of type dt due to the above code + // that casts the scalars of A and B to dt via scalar_a and scalar_b, + // and we know that the internal scalar in C is already of the type dt + // due to the casting in the implementation of bli_obj_scalar_attach(). + auto alpha_cast = (char*)bli_obj_internal_scalar_buffer( &scalar_b ); + auto beta_cast = (char*)bli_obj_internal_scalar_buffer( c ); + + /* Alias some constants to simpler names. */ + auto MR = pd_a; + auto NR = pd_b; + + /* Query the context for the micro-kernel address and cast it to its + function pointer type. */ + auto gemm_ukr = (gemm_ukr_vft)bli_cntx_get_l3_nat_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); + + /* Temporary C buffer for edge cases. Note that the strides of this + temporary buffer are set so that they match the storage of the + original C matrix. For example, if C is column-stored, ct will be + column-stored as well. */ + char ct[ BLIS_STACK_BUF_MAX_SIZE ] __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); + auto col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); + auto rs_ct = ( col_pref ? 1 : NR ); + auto cs_ct = ( col_pref ? MR : 1 ); + auto zero = (char*)bli_obj_buffer_for_const( dt, &BLIS_ZERO ); + + /* + Assumptions/assertions: + rs_a == 1 + cs_a == PACKMR + pd_a == MR + ps_a == stride to next micro-panel of A + rs_b == PACKNR + cs_b == 1 + pd_b == NR + ps_b == stride to next micro-panel of B + rs_c == (no assumptions) + cs_c == (no assumptions) + */ + + auto scat_size = 2 * (m + n) * sizeof(inc_t); + auto rscat_c = (inc_t*)bli_packm_alloc_ex( scat_size, BLIS_BUFFER_FOR_GEN_USE, rntm, cntl, thread ); + auto rbs_c = rscat_c + m; + auto cscat_c = rbs_c + m; + auto cbs_c = cscat_c + n; + + /* Fill in the scatter and block-scatter vectors. This is done single-threaded for now. */ + if ( bli_thread_am_ochief( thread ) ) + { + fill_scatter + ( + ndim_m, + len_m, + stride_m, + MR, + off_m, + m, + rscat_c, + rbs_c + ); + + fill_scatter + ( + ndim_n, + len_n, + stride_n, + NR, + off_n, + n, + cscat_c, + cbs_c + ); + } + + /* Wait for the scatter vectors to be done. */ + bli_thread_barrier( thread ); + + /* Compute number of primary and leftover components of the m and n + dimensions. */ + auto n_iter = n / NR; + auto n_left = n % NR; + + auto m_iter = m / MR; + auto m_left = m % MR; + + if ( n_left ) ++n_iter; + if ( m_left ) ++m_iter; + + /* Determine some increments used to step through A, B, and C. */ + auto rstep_a = ps_a * dt_size; + auto cstep_b = ps_b * dt_size; + + /* Save the virtual microkernel address and the params. */ + auxinfo_t aux; + bli_auxinfo_set_ukr( (void*)gemm_ukr, &aux ); + bli_auxinfo_set_params( params, &aux ); + + /* The 'thread' argument points to the thrinfo_t node for the 2nd (jr) + loop around the microkernel. Here we query the thrinfo_t node for the + 1st (ir) loop around the microkernel. */ + auto caucus = bli_thrinfo_sub_node( thread ); + + /* Query the number of threads and thread ids for each loop. */ + auto jr_nt = bli_thread_n_way( thread ); + auto jr_tid = bli_thread_work_id( thread ); + auto ir_nt = bli_thread_n_way( caucus ); + auto ir_tid = bli_thread_work_id( caucus ); + + /* Determine the thread range and increment for the 2nd and 1st loops. + NOTE: The definition of bli_thread_range_jrir() will depend on whether + slab or round-robin partitioning was requested at configure-time. 
*/
+    dim_t jr_start, jr_end;
+    dim_t ir_start, ir_end;
+    dim_t jr_inc,   ir_inc;
+    bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc );
+    bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc );
+
+    /* Loop over the n dimension (NR columns at a time). */
+    for ( auto j = jr_start; j < jr_end; j += jr_inc )
+    {
+        auto b1 = b_cast + j * cstep_b;
+
+        auto n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left );
+
+        /* Initialize our next panel of B to be the current panel of B. */
+        auto b2 = b1;
+
+        /* Loop over the m dimension (MR rows at a time). */
+        for ( auto i = ir_start; i < ir_end; i += ir_inc )
+        {
+            auto a1       = a_cast  + i * rstep_a;
+            auto rscat_c1 = rscat_c + i * MR;
+            auto rbs_c1   = rbs_c   + i * MR;
+            auto cscat_c1 = cscat_c + j * NR;
+            auto cbs_c1   = cbs_c   + j * NR;
+
+            auto m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left );
+
+            /* Compute the addresses of the next panels of A and B. */
+            auto a2 = bli_gemm_get_next_a_upanel( a1, rstep_a, ir_inc );
+            if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) )
+            {
+                a2 = a_cast;
+                b2 = bli_gemm_get_next_b_upanel( b1, cstep_b, jr_inc );
+                if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) )
+                    b2 = b_cast;
+            }
+
+            /* Save addresses of next panels of A and B to the auxinfo_t
+               object. */
+            bli_auxinfo_set_next_a( a2, &aux );
+            bli_auxinfo_set_next_b( b2, &aux );
+
+            auto rs_c = *rbs_c1;
+            auto cs_c = *cbs_c1;
+
+            if ( rs_c && cs_c )
+            {
+                auto c11 = c_cast + ( *rscat_c1 + *cscat_c1 ) * dt_size;
+
+                /* Invoke the gemm micro-kernel. */
+                gemm_ukr
+                (
+                  m_cur,
+                  n_cur,
+                  k,
+                  alpha_cast,
+                  a1,
+                  b1,
+                  beta_cast,
+                  c11, rs_c, cs_c,
+                  &aux,
+                  cntx
+                );
+            }
+            else
+            {
+                /* Invoke the gemm micro-kernel. */
+                gemm_ukr
+                (
+                  MR,
+                  NR,
+                  k,
+                  alpha_cast,
+                  a1,
+                  b1,
+                  zero,
+                  &ct, rs_ct, cs_ct,
+                  &aux,
+                  cntx
+                );
+
+                /* Scatter to C. */
+                scatter_mxn[ dt ]
+                (
+                  m_cur, n_cur,
+                  &ct, rs_ct, cs_ct,
+                  beta_cast,
+                  c_cast, rscat_c1, cscat_c1
+                );
+            }
+        }
+    }
+}
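+
+// NOTE (illustrative sketch): each MR x NR tile of C takes one of two paths
+// above. If both block-scatter entries for the tile are nonzero, that tile of
+// C is regularly strided and the microkernel updates it in place; otherwise
+// the tile is computed into ct with beta = 0 and then scattered out, applying
+// the caller's beta as
+//
+//     C[ rscat[i] + cscat[j] ] := ct[ i*rs_ct + j*cs_ct ]
+//                               + beta * C[ rscat[i] + cscat[j] ]
+//
+// which is what the scatter_mxn kernels generated earlier implement.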
+
+static bool has_unit_stride( const std::vector<inc_t>& stride )
+{
+    for ( auto s : stride )
+        if ( s == 1 )
+            return true;
+    return false;
+}
+
+void tcontract( num_t dt, const std::vector<dim_t>& m, const std::vector<dim_t>& n, const std::vector<dim_t>& k,
+                const void* alpha, const void* a, std::vector<inc_t> rs_a, std::vector<inc_t> cs_a,
+                                   const void* b, std::vector<inc_t> rs_b, std::vector<inc_t> cs_b,
+                const void* beta,        void* c, std::vector<inc_t> rs_c, std::vector<inc_t> cs_c )
+{
+    if ( rs_a.size() != m.size() ||
+         rs_b.size() != k.size() ||
+         rs_c.size() != m.size() )
+        bli_check_error_code( BLIS_INVALID_ROW_STRIDE );
+
+    if ( cs_a.size() != k.size() ||
+         cs_b.size() != n.size() ||
+         cs_c.size() != n.size() )
+        bli_check_error_code( BLIS_INVALID_COL_STRIDE );
+
+    dim_t m_mat = 1;
+    dim_t n_mat = 1;
+    dim_t k_mat = 1;
+    for ( auto& i : m ) m_mat *= i;
+    for ( auto& i : n ) n_mat *= i;
+    for ( auto& i : k ) k_mat *= i;
+
+    auto& stride_m = has_unit_stride( rs_c ) ? rs_c : rs_a;
+    for ( int i = 1; i < m.size(); i++ )
+    for ( int j = 0; j < m.size()-i; j++ )
+        if ( stride_m[j] > stride_m[j+1] )
+        {
+            std::swap( rs_a[j], rs_a[j+1] );
+            std::swap( rs_c[j], rs_c[j+1] );
+        }
+
+    auto& stride_n = has_unit_stride( cs_c ) ? cs_c : cs_b;
+    for ( int i = 1; i < n.size(); i++ )
+    for ( int j = 0; j < n.size()-i; j++ )
+        if ( stride_n[j] > stride_n[j+1] )
+        {
+            std::swap( cs_b[j], cs_b[j+1] );
+            std::swap( cs_c[j], cs_c[j+1] );
+        }
+
+    auto& stride_k = has_unit_stride( cs_a ) ? cs_a : rs_b;
+    for ( int i = 1; i < k.size(); i++ )
+    for ( int j = 0; j < k.size()-i; j++ )
+        if ( stride_k[j] > stride_k[j+1] )
+        {
+            std::swap( cs_a[j], cs_a[j+1] );
+            std::swap( rs_b[j], rs_b[j+1] );
+        }
+
+    if ( rs_a.empty() ) rs_a.push_back( 1 );
+    if ( cs_a.empty() ) cs_a.push_back( 1 );
+    if ( rs_b.empty() ) rs_b.push_back( 1 );
+    if ( cs_b.empty() ) cs_b.push_back( 1 );
+    if ( rs_c.empty() ) rs_c.push_back( 1 );
+    if ( cs_c.empty() ) cs_c.push_back( 1 );
+
+    obj_t a_o, b_o, c_o;
+    bli_obj_create_with_attached_buffer( dt, m_mat, k_mat, const_cast<void*>(a), rs_a[0], cs_a[0], &a_o );
+    bli_obj_create_with_attached_buffer( dt, k_mat, n_mat, const_cast<void*>(b), rs_b[0], cs_b[0], &b_o );
+    bli_obj_create_with_attached_buffer( dt, m_mat, n_mat, c, rs_c[0], cs_c[0], &c_o );
+
+    packm_tensor_params_t params_a( m.size(), m.data(), rs_a.data(),
+                                    k.size(), k.data(), cs_a.data() );
+    packm_tensor_params_t params_b( n.size(), n.data(), cs_b.data(),
+                                    k.size(), k.data(), rs_b.data() );
+    gemm_tensor_params_t  params_c( m.size(), m.data(), rs_c.data(),
+                                    n.size(), n.data(), cs_c.data() );
+
+    bli_obj_set_pack_fn( packm_tensor, &a_o );
+    bli_obj_set_pack_fn( packm_tensor, &b_o );
+    bli_obj_set_ker_fn( gemm_tensor, &c_o );
+    bli_obj_set_pack_params( &params_a, &a_o );
+    bli_obj_set_pack_params( &params_b, &b_o );
+    bli_obj_set_ker_params( &params_c, &c_o );
+
+    obj_t alpha_o, beta_o;
+    bli_obj_create_1x1_with_attached_buffer( dt, const_cast<void*>(alpha), &alpha_o );
+    bli_obj_create_1x1_with_attached_buffer( dt, const_cast<void*>(beta), &beta_o );
+
+    rntm_t rntm;
+    bli_rntm_init_from_global( &rntm );
+    bli_rntm_disable_l3_sup( &rntm );
+
+    bli_gemm_ex( &alpha_o, &a_o, &b_o, &beta_o, &c_o, NULL, &rntm );
+}
+
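+// NOTE (illustrative sketch): the bubble sorts in tcontract order each
+// dimension group by increasing stride, permuting the paired stride arrays
+// identically so that entry 0 carries the smallest (ideally unit) stride,
+// which is what gets passed to bli_obj_create_with_attached_buffer. For
+// example (assumed values):
+//
+//     rs_c = {4,1}, rs_a = {2,10}  ->  rs_c = {1,4}, rs_a = {10,2}
+//
+// C owns the unit stride here, so rs_c drives the sort and rs_a follows.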
+int main()
+{
+    auto N = 5;
+
+    gint_t ndim_a = 4;
+    gint_t ndim_b = 4;
+    gint_t ndim_c = 4;
+
+    std::vector<dim_t> len_a(ndim_a, N);
+    std::vector<dim_t> len_b(ndim_b, N);
+    std::vector<dim_t> len_c(ndim_c, N);
+
+    std::vector<inc_t> stride_a(ndim_a, 1);
+    std::vector<inc_t> stride_b(ndim_b, 1);
+    std::vector<inc_t> stride_c(ndim_c, 1);
+    for ( gint_t i = 1; i < ndim_a; i++ )
+        stride_a[i] = stride_a[i-1] * len_a[i-1];
+    for ( gint_t i = 1; i < ndim_b; i++ )
+        stride_b[i] = stride_b[i-1] * len_b[i-1];
+    for ( gint_t i = 1; i < ndim_c; i++ )
+        stride_c[i] = stride_c[i-1] * len_c[i-1];
+
+    std::vector<int> dim_a(ndim_a);
+    std::vector<int> dim_b(ndim_b);
+    std::vector<int> dim_c(ndim_c);
+    std::iota(dim_a.begin(), dim_a.end(), 0);
+    std::iota(dim_b.begin(), dim_b.end(), 0);
+    std::iota(dim_c.begin(), dim_c.end(), 0);
+
+    for ( int dt_ = BLIS_DT_LO; dt_ <= BLIS_DT_HI; dt_++ )
+    do
+    do
+    do
+    {
+        auto dt = ( num_t )dt_;
+
+        auto ndim_m = (ndim_a + ndim_c - ndim_b)/2;
+        auto ndim_k = (ndim_a + ndim_b - ndim_c)/2;
+
+        std::vector<dim_t> m(len_a.begin(), len_a.begin()+ndim_m);
+        std::vector<dim_t> n(len_b.begin()+ndim_k, len_b.end());
+        std::vector<dim_t> k(len_b.begin(), len_b.begin()+ndim_k);
+
+        std::vector<inc_t> rs_a(stride_a.begin(), stride_a.begin()+ndim_m);
+        std::vector<inc_t> cs_a(stride_a.begin()+ndim_m, stride_a.end());
+        std::vector<inc_t> rs_b(stride_b.begin(), stride_b.begin()+ndim_k);
+        std::vector<inc_t> cs_b(stride_b.begin()+ndim_k, stride_b.end());
+        std::vector<inc_t> rs_c(stride_c.begin(), stride_c.begin()+ndim_m);
+        std::vector<inc_t> cs_c(stride_c.begin()+ndim_m, stride_c.end());
+
+        dim_t m_tot = 1;
+        dim_t n_tot = 1;
+        dim_t k_tot = 1;
+        for ( auto i : m ) m_tot *= i;
+        for ( auto i : n ) n_tot *= i;
+        for ( auto i : k ) k_tot *= i;
+
+        obj_t a, b, c, c_ref, norm;
+
+        bli_obj_create( dt, m_tot*k_tot, 1, 1, 1, &a );
+        bli_obj_create( dt, k_tot*n_tot, 1, 1, 1, &b );
+        bli_obj_create( dt, m_tot*n_tot, 1, 1, 1, &c );
+        bli_obj_create( dt, m_tot*n_tot, 1, 1, 1, &c_ref );
+        bli_obj_create_1x1( bli_dt_proj_to_real( dt ), &norm );
+
+        bli_randv( &a );
+        bli_randv( &b );
+        bli_randv( &c );
+        bli_copyv( &c, &c_ref );
+
+        tcontract( dt, m, n, k,
+                   bli_obj_buffer_for_const( dt, &BLIS_ONE ),
+                   bli_obj_buffer( &a ), rs_a, cs_a,
+                   bli_obj_buffer( &b ), rs_b, cs_b,
+                   bli_obj_buffer_for_const( dt, &BLIS_ZERO ),
+                   bli_obj_buffer( &c ), rs_c, cs_c );
+
+        tcontract_ref( dt, m, n, k,
+                       bli_obj_buffer_for_const( dt, &BLIS_ONE ),
+                       bli_obj_buffer( &a ), rs_a, cs_a,
+                       bli_obj_buffer( &b ), rs_b, cs_b,
+                       bli_obj_buffer_for_const( dt, &BLIS_ZERO ),
+                       bli_obj_buffer( &c_ref ), rs_c, cs_c );
+
+        bli_subv( &c_ref, &c );
+        bli_normfv( &c, &norm );
+
+        double normr, normi;
+        bli_getsc( &norm, &normr, &normi );
+
+        printf("dt: %d, dim_a: [%d,%d,%d,%d], dim_b: [%d,%d,%d,%d], dim_c: [%d,%d,%d,%d], norm: %g\n",
+               dt, dim_a[0], dim_a[1], dim_a[2], dim_a[3],
+               dim_b[0], dim_b[1], dim_b[2], dim_b[3],
+               dim_c[0], dim_c[1], dim_c[2], dim_c[3],
+               normr / std::sqrt( bli_obj_vector_dim( &c ) ) );
+
+        bli_obj_free( &a );
+        bli_obj_free( &b );
+        bli_obj_free( &c );
+        bli_obj_free( &c_ref );
+        bli_obj_free( &norm );
+    }
+    while (std::next_permutation(dim_a.begin(), dim_a.end()));
+    while (std::next_permutation(dim_b.begin(), dim_b.end()));
+    while (std::next_permutation(dim_c.begin(), dim_c.end()));
+}
diff --git a/test/tensor_contraction/tcontract_ref.cxx b/test/tensor_contraction/tcontract_ref.cxx
new file mode 100644
index 0000000000..b4cd07f903
--- /dev/null
+++ b/test/tensor_contraction/tcontract_ref.cxx
@@ -0,0 +1,67 @@
+#include "tcontract_ref.hpp"
+
+template <typename T>
+void tcontract_ref( const std::vector<dim_t>& m, const std::vector<dim_t>& n, const std::vector<dim_t>& k,
+                    const void* alpha, const void* a, const std::vector<inc_t>& rs_a, const std::vector<inc_t>& cs_a,
+                                       const void* b, const std::vector<inc_t>& rs_b, const std::vector<inc_t>& cs_b,
+                    const void* beta,        void* c, const std::vector<inc_t>& rs_c, const std::vector<inc_t>& cs_c )
+{
+    auto alpha_cast = *( T* )alpha;
+    auto beta_cast  = *( T* )beta;
+    auto a_cast     =  ( T* )a;
+    auto b_cast     =  ( T* )b;
+    auto c_cast     =  ( T* )c;
+
+    for_each(m.size(), m.data(), a_cast, rs_a.data(), c_cast, rs_c.data(),
+    [&]
+    {
+        for_each(n.size(), n.data(), b_cast, cs_b.data(), c_cast, cs_c.data(),
+        [&]
+        {
+            auto ab = convert<T>(0.0);
+
+            for_each(k.size(), k.data(), a_cast, cs_a.data(), b_cast, rs_b.data(),
+            [&]
+            {
+                ab += (*a_cast) * (*b_cast);
+            });
+
+            if ( beta_cast == convert<T>(0.0) )
+            {
+                *c_cast = alpha_cast * ab;
+            }
+            else
+            {
+                *c_cast = alpha_cast * ab + beta_cast * (*c_cast);
+            }
+        });
+
+        assert(b_cast == b);
+    });
+
+    assert(a_cast == a);
+    assert(c_cast == c);
+}
+
+#undef GENTFUNC
+#define GENTFUNC(ctype,ch,op) \
+static auto PASTEMAC(ch,op) = &tcontract_ref<ctype>;
+
+INSERT_GENTFUNC_BASIC0(tcontract_ref);
+
+static decltype(&tcontract_ref<float>) GENARRAY( tcontract_ref_impl, tcontract_ref );
+
+void tcontract_ref( num_t dt, const std::vector<dim_t>& m, const std::vector<dim_t>& n, const std::vector<dim_t>& k,
+                    const void* alpha, const void* a, const std::vector<inc_t>& rs_a, const std::vector<inc_t>& cs_a,
+                                       const void* b, const std::vector<inc_t>& rs_b, const std::vector<inc_t>& cs_b,
+                    const void* beta,        void* c, const std::vector<inc_t>& rs_c, const std::vector<inc_t>& cs_c )
+{
+    tcontract_ref_impl[ dt ]
+    (
+      m, n, k,
+      alpha, a, rs_a, cs_a,
+             b, rs_b, cs_b,
+      beta,  c, rs_c, cs_c
+    );
+}
+
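A standalone sketch (dimensions assumed; not part of the patch) of the odometer traversal that for_each_impl in the header below implements: the first index varies fastest, and a "carry" resets it and advances the next dimension.

    #include <cstdio>

    int main()
    {
        const int ndim = 2;
        const int n[2] = { 2, 3 };
        int i[2] = { 0, 0 };

        for ( int pos = 0; pos < 6; pos++ )
        {
            printf( "(%d,%d) ", i[0], i[1] );   // (0,0) (1,0) (0,1) (1,1) (0,2) (1,2)

            for ( int k = 0; k < ndim; k++ )
            {
                if ( i[k] == n[k]-1 ) i[k] = 0;   // carry into dimension k+1
                else { i[k]++; break; }
            }
        }
        printf( "\n" );
    }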
"complex_math.hpp" + +#include +#include +#include + +inline void increment(inc_t, gint_t) {} + +template +void increment(inc_t n, gint_t i, T& off, const inc_t* s, Args&... args) +{ + off += s[i]*n; + increment(n, i, args...); +} + +template +void for_each_impl(gint_t ndim, const dim_t* n, + dim_t off, dim_t len, + Body& body, + Args&... args) +{ + std::array i = {}; + assert( ndim <= i.size() ); + + if ( off ) + { + for ( gint_t k = 0; k < ndim; k++ ) + { + i[k] = off % n[k]; + off /= n[k]; + increment(i[k], k, args...); + } + } + + for ( dim_t pos = 0; pos < len; pos++ ) + { + body(); + + for ( gint_t k = 0; k < ndim; k++ ) + { + if ( i[k] == n[k]-1 ) + { + increment(-i[k], k, args...); + i[k] = 0; + } + else + { + increment(1, k, args...); + i[k]++; + break; + } + } + } +} + +template +void for_each(gint_t ndim, const dim_t* n, + dim_t off, dim_t len, + T& a, const inc_t* s_a, + Body&& body) +{ + for_each_impl( ndim, n, off, len, body, a, s_a ); +} + +template +void for_each(gint_t ndim, const dim_t* n, + dim_t off, dim_t len, + T& a, const inc_t* s_a, + T& b, const inc_t* s_b, + Body&& body) +{ + for_each_impl( ndim, n, off, len, body, a, s_a, b, s_b ); +} + +template +void for_each(gint_t ndim, const dim_t* n, + T& a, const inc_t* s_a, + Body&& body) +{ + dim_t len = 1; + for ( gint_t i = 0;i < ndim;i++ ) len *= n[i]; + for_each_impl( ndim, n, 0, len, body, a, s_a ); +} + +template +void for_each(gint_t ndim, const dim_t* n, + T& a, const inc_t* s_a, + T& b, const inc_t* s_b, + Body&& body) +{ + dim_t len = 1; + for ( gint_t i = 0;i < ndim;i++ ) len *= n[i]; + for_each_impl( ndim, n, 0, len, body, a, s_a, b, s_b ); +} + +void tcontract_ref( num_t dt, const std::vector& m, const std::vector& n, const std::vector& k, + const void* alpha, const void* a, const std::vector& rs_a, const std::vector& cs_a, + const void* b, const std::vector& rs_b, const std::vector& cs_b, + const void* beta, void* c, const std::vector& rs_c, const std::vector& cs_c ); From 5f5741f453c3fda85fe915a648145244e231e77a Mon Sep 17 00:00:00 2001 From: Devin Matthews Date: Tue, 7 Dec 2021 13:56:01 -0600 Subject: [PATCH 02/12] Pick up missed changes for gemm 1m/md. --- frame/3/gemm/bli_gemm_md.h | 61 +++-------------------------- frame/3/gemm/ind/bli_gemm_ind_opt.h | 2 + 2 files changed, 8 insertions(+), 55 deletions(-) diff --git a/frame/3/gemm/bli_gemm_md.h b/frame/3/gemm/bli_gemm_md.h index 8fcf6bd21d..751e271eaf 100644 --- a/frame/3/gemm/bli_gemm_md.h +++ b/frame/3/gemm/bli_gemm_md.h @@ -154,7 +154,7 @@ BLIS_INLINE void bli_gemm_md_ker_var2_recast num_t* dt_comp, num_t dt_a, num_t dt_b, - num_t dt_c, + num_t* dt_c, dim_t* m, dim_t* n, dim_t* k, @@ -164,7 +164,7 @@ BLIS_INLINE void bli_gemm_md_ker_var2_recast inc_t* rs_c, inc_t* cs_c ) { - if ( bli_is_real( dt_c ) && + if ( bli_is_real( *dt_c ) && bli_is_complex( dt_a ) && bli_is_complex( dt_b ) ) { @@ -177,7 +177,7 @@ BLIS_INLINE void bli_gemm_md_ker_var2_recast *ps_a *= 2; *ps_b *= 2; } - else if ( bli_is_complex( dt_c ) && + else if ( bli_is_complex( *dt_c ) && bli_is_real( dt_a ) && bli_is_complex( dt_b ) ) { @@ -197,6 +197,7 @@ BLIS_INLINE void bli_gemm_md_ker_var2_recast // to the real virtual microkernel slots of the context) instead of // the complex macrokernel and c2r virtual microkernel. 
*dt_comp = bli_dt_proj_to_real( *dt_comp ); + *dt_c = bli_dt_proj_to_real( *dt_c ); *n *= 2; *pd_b *= 2; *ps_b *= 2; *rs_c *= 2; @@ -211,7 +212,7 @@ BLIS_INLINE void bli_gemm_md_ker_var2_recast *ps_a /= 2; } } - else if ( bli_is_complex( dt_c ) && + else if ( bli_is_complex( *dt_c ) && bli_is_complex( dt_a ) && bli_is_real( dt_b ) ) { @@ -231,6 +232,7 @@ BLIS_INLINE void bli_gemm_md_ker_var2_recast // to the real virtual microkernel slots of the context) instead of // the complex macrokernel and c2r virtual microkernel. *dt_comp = bli_dt_proj_to_real( *dt_comp ); + *dt_c = bli_dt_proj_to_real( *dt_c ); *m *= 2; *pd_a *= 2; *ps_a *= 2; *cs_c *= 2; @@ -274,54 +276,3 @@ BLIS_INLINE void bli_gemm_md_ker_var2_recast #endif } -// ----------------------------------------------------------------------------- - -// -// Prototype object-based interfaces. -// - -#undef GENPROT -#define GENPROT( opname ) \ -\ -void PASTEMAC0(opname) \ - ( \ - obj_t* a, \ - obj_t* b, \ - obj_t* c, \ - cntx_t* cntx, \ - rntm_t* rntm, \ - cntl_t* cntl, \ - thrinfo_t* thread \ - ); - -GENPROT( gemm_ker_var2_md ) - -// -// Prototype BLAS-like interfaces with void pointer operands. -// - -#undef GENTPROT2 -#define GENTPROT2( ctype_c, ctype_e, chc, che, varname ) \ -\ -void PASTEMAC2(chc,che,varname) \ - ( \ - pack_t schema_a, \ - pack_t schema_b, \ - dim_t m, \ - dim_t n, \ - dim_t k, \ - void* alpha, \ - void* a, inc_t cs_a, inc_t is_a, \ - dim_t pd_a, inc_t ps_a, \ - void* b, inc_t rs_b, inc_t is_b, \ - dim_t pd_b, inc_t ps_b, \ - void* beta, \ - void* c, inc_t rs_c, inc_t cs_c, \ - cntx_t* cntx, \ - rntm_t* rntm, \ - thrinfo_t* thread \ - ); - -INSERT_GENTPROT2_BASIC0( gemm_ker_var2_md ) -INSERT_GENTPROT2_MIXDP0( gemm_ker_var2_md ) - diff --git a/frame/3/gemm/ind/bli_gemm_ind_opt.h b/frame/3/gemm/ind/bli_gemm_ind_opt.h index 7528c4f03e..77213d0fa3 100644 --- a/frame/3/gemm/ind/bli_gemm_ind_opt.h +++ b/frame/3/gemm/ind/bli_gemm_ind_opt.h @@ -35,6 +35,7 @@ BLIS_INLINE void bli_gemm_ind_recast_1m_params ( num_t* dt_exec, + num_t* dt_c, pack_t schema_a, obj_t* c, dim_t* m, @@ -57,6 +58,7 @@ BLIS_INLINE void bli_gemm_ind_recast_1m_params !bli_is_gen_stored( *rs_c, *cs_c ) ) { *dt_exec = bli_dt_proj_to_real( *dt_exec ); + *dt_c = bli_dt_proj_to_real( *dt_c ); if ( bli_is_1e_packed( schema_a ) ) { From b5e3fa28da759053069fdd08d0a6ca8f8d6bcafa Mon Sep 17 00:00:00 2001 From: Devin Matthews Date: Wed, 8 Dec 2021 10:52:44 -0600 Subject: [PATCH 03/12] All working now. --- frame/1m/packm/bli_packm_alloc.c | 19 +++++++++++++++++++ frame/1m/packm/bli_packm_alloc.h | 9 +++++++++ ref_kernels/ind/bli_gemmtrsm1m_ref.c | 2 +- 3 files changed, 29 insertions(+), 1 deletion(-) diff --git a/frame/1m/packm/bli_packm_alloc.c b/frame/1m/packm/bli_packm_alloc.c index df6750d7ac..5015984848 100644 --- a/frame/1m/packm/bli_packm_alloc.c +++ b/frame/1m/packm/bli_packm_alloc.c @@ -46,6 +46,25 @@ void* bli_packm_alloc // Query the pack buffer type from the control tree node. packbuf_t pack_buf_type = bli_cntl_packm_params_pack_buf_type( cntl ); + return bli_packm_alloc_ex + ( + size_needed, + pack_buf_type, + rntm, + cntl, + thread + ); +} + +void* bli_packm_alloc_ex + ( + siz_t size_needed, + packbuf_t pack_buf_type, + rntm_t* rntm, + cntl_t* cntl, + thrinfo_t* thread + ) +{ // Query the address of the mem_t entry within the control tree node. 
mem_t* cntl_mem_p = bli_cntl_pack_mem( cntl ); diff --git a/frame/1m/packm/bli_packm_alloc.h b/frame/1m/packm/bli_packm_alloc.h index b433be350a..f9870562d4 100644 --- a/frame/1m/packm/bli_packm_alloc.h +++ b/frame/1m/packm/bli_packm_alloc.h @@ -40,3 +40,12 @@ BLIS_EXPORT_BLIS void* bli_packm_alloc thrinfo_t* thread ); +BLIS_EXPORT_BLIS void* bli_packm_alloc_ex + ( + siz_t size_needed, + packbuf_t pack_buf_type, + rntm_t* rntm, + cntl_t* cntl, + thrinfo_t* thread + ); + diff --git a/ref_kernels/ind/bli_gemmtrsm1m_ref.c b/ref_kernels/ind/bli_gemmtrsm1m_ref.c index e35a1bf13d..cf904a1924 100644 --- a/ref_kernels/ind/bli_gemmtrsm1m_ref.c +++ b/ref_kernels/ind/bli_gemmtrsm1m_ref.c @@ -78,7 +78,7 @@ void PASTEMAC3(ch,opname,arch,suf) \ \ const dim_t packnr = bli_cntx_get_blksz_max_dt( dt, BLIS_NR, cntx ); \ \ - const pack_t schema_b = bli_cntx_schema_b_panel( cntx ); \ + const pack_t schema_b = bli_auxinfo_schema_b( data ); \ \ const dim_t k2 = 2 * k; \ \ From ea6cc1fb52b0e73451c0dc8b3c8464b0e2a1286c Mon Sep 17 00:00:00 2001 From: Devin Matthews Date: Wed, 8 Dec 2021 10:57:37 -0600 Subject: [PATCH 04/12] Move syrk_diagonal example. --- test/{obj_t_makeover => syrk_diagonal}/complex_math.hpp | 0 test/{obj_t_makeover => syrk_diagonal}/syrk_diagonal_example.c | 0 test/{obj_t_makeover => syrk_diagonal}/syrk_diagonal_example.cxx | 0 test/{obj_t_makeover => syrk_diagonal}/syrk_diagonal_example2.c | 0 test/{obj_t_makeover => syrk_diagonal}/syrk_diagonal_example2.cxx | 0 test/{obj_t_makeover => syrk_diagonal}/syrk_diagonal_ref.cxx | 0 test/{obj_t_makeover => syrk_diagonal}/syrk_diagonal_ref.h | 0 7 files changed, 0 insertions(+), 0 deletions(-) rename test/{obj_t_makeover => syrk_diagonal}/complex_math.hpp (100%) rename test/{obj_t_makeover => syrk_diagonal}/syrk_diagonal_example.c (100%) rename test/{obj_t_makeover => syrk_diagonal}/syrk_diagonal_example.cxx (100%) rename test/{obj_t_makeover => syrk_diagonal}/syrk_diagonal_example2.c (100%) rename test/{obj_t_makeover => syrk_diagonal}/syrk_diagonal_example2.cxx (100%) rename test/{obj_t_makeover => syrk_diagonal}/syrk_diagonal_ref.cxx (100%) rename test/{obj_t_makeover => syrk_diagonal}/syrk_diagonal_ref.h (100%) diff --git a/test/obj_t_makeover/complex_math.hpp b/test/syrk_diagonal/complex_math.hpp similarity index 100% rename from test/obj_t_makeover/complex_math.hpp rename to test/syrk_diagonal/complex_math.hpp diff --git a/test/obj_t_makeover/syrk_diagonal_example.c b/test/syrk_diagonal/syrk_diagonal_example.c similarity index 100% rename from test/obj_t_makeover/syrk_diagonal_example.c rename to test/syrk_diagonal/syrk_diagonal_example.c diff --git a/test/obj_t_makeover/syrk_diagonal_example.cxx b/test/syrk_diagonal/syrk_diagonal_example.cxx similarity index 100% rename from test/obj_t_makeover/syrk_diagonal_example.cxx rename to test/syrk_diagonal/syrk_diagonal_example.cxx diff --git a/test/obj_t_makeover/syrk_diagonal_example2.c b/test/syrk_diagonal/syrk_diagonal_example2.c similarity index 100% rename from test/obj_t_makeover/syrk_diagonal_example2.c rename to test/syrk_diagonal/syrk_diagonal_example2.c diff --git a/test/obj_t_makeover/syrk_diagonal_example2.cxx b/test/syrk_diagonal/syrk_diagonal_example2.cxx similarity index 100% rename from test/obj_t_makeover/syrk_diagonal_example2.cxx rename to test/syrk_diagonal/syrk_diagonal_example2.cxx diff --git a/test/obj_t_makeover/syrk_diagonal_ref.cxx b/test/syrk_diagonal/syrk_diagonal_ref.cxx similarity index 100% rename from test/obj_t_makeover/syrk_diagonal_ref.cxx rename to 
test/syrk_diagonal/syrk_diagonal_ref.cxx
diff --git a/test/obj_t_makeover/syrk_diagonal_ref.h b/test/syrk_diagonal/syrk_diagonal_ref.h
similarity index 100%
rename from test/obj_t_makeover/syrk_diagonal_ref.h
rename to test/syrk_diagonal/syrk_diagonal_ref.h

From 0c913237ff77635d5d97d8c889ceca9e5257a64b Mon Sep 17 00:00:00 2001
From: Devin Matthews
Date: Wed, 8 Dec 2021 10:57:59 -0600
Subject: [PATCH 05/12] Cleanup. [ci skip]

---
 test/obj_t_makeover/syrk_diagonal_example.tgz | Bin 6904 -> 0 bytes
 1 file changed, 0 insertions(+), 0 deletions(-)
 delete mode 100644 test/obj_t_makeover/syrk_diagonal_example.tgz

diff --git a/test/obj_t_makeover/syrk_diagonal_example.tgz b/test/obj_t_makeover/syrk_diagonal_example.tgz
deleted file mode 100644
index 79ac14959103c7a43631814e16d4d9d23dd21efa..0000000000000000000000000000000000000000
GIT binary patch
[6904 bytes of base85-encoded binary data omitted]

From e5c72e01ba1c4dcf64dcd47fe88679f84d05c3c2 Mon Sep 17 00:00:00 2001
From: Devin Matthews
Date: Wed, 8 Dec 2021 11:09:34 -0600
Subject: [PATCH 06/12] Trigger Travis build.
--- test/tensor_contraction/tcontract_example.cxx | 1 + 1 file changed, 1 insertion(+) diff --git a/test/tensor_contraction/tcontract_example.cxx b/test/tensor_contraction/tcontract_example.cxx index 82afc219e0..0b935c54d4 100644 --- a/test/tensor_contraction/tcontract_example.cxx +++ b/test/tensor_contraction/tcontract_example.cxx @@ -985,3 +985,4 @@ int main() while (std::next_permutation(dim_b.begin(), dim_b.end())); while (std::next_permutation(dim_c.begin(), dim_c.end())); } + From 4638cc6f665348c852cc4e54915eddae3b07dd8b Mon Sep 17 00:00:00 2001 From: Devin Matthews Date: Wed, 8 Dec 2021 11:34:12 -0600 Subject: [PATCH 07/12] Missed some kernels. [ci skip] until I fix the ARM ones. --- kernels/knl/3/bli_dgemm_knl_asm_24x8.c | 85 ++------- kernels/knl/3/bli_sgemm_knl_asm_24x16.c | 85 ++------- .../3/bli_gemm_sandybridge_int_d8x4.c | 1 - kernels/skx/3/bli_dgemm_skx_asm_16x12_l2.c | 84 +++------ kernels/skx/3/bli_dgemm_skx_asm_16x14.c | 162 +++++------------- kernels/skx/3/bli_sgemm_skx_asm_32x12_l2.c | 88 +++------- 6 files changed, 117 insertions(+), 388 deletions(-) diff --git a/kernels/knl/3/bli_dgemm_knl_asm_24x8.c b/kernels/knl/3/bli_dgemm_knl_asm_24x8.c index b794e7c059..a7f860ae02 100644 --- a/kernels/knl/3/bli_dgemm_knl_asm_24x8.c +++ b/kernels/knl/3/bli_dgemm_knl_asm_24x8.c @@ -185,6 +185,8 @@ static int32_t offsets[32] __attribute__((aligned(64))) = //#define LOOPMON void bli_dgemm_knl_asm_24x8 ( + dim_t m, + dim_t n, dim_t k_, double* restrict alpha, double* restrict a, @@ -201,10 +203,12 @@ void bli_dgemm_knl_asm_24x8 const double * a_next = bli_auxinfo_next_a( data ); const double * b_next = bli_auxinfo_next_b( data ); - const int32_t * offsetPtr = &offsets[0]; - const int64_t k = k_; - const int64_t rs_c = rs_c_; - const int64_t cs_c = cs_c_; + int32_t * offsetPtr = &offsets[0]; + int64_t k = k_; + int64_t rs_c = rs_c_; + int64_t cs_c = cs_c_; + + GEMM_UKR_SETUP_CT( d, 24, 8, true ); #ifdef MONITORS int toph, topl, both, botl, midl, midh, mid2l, mid2h; @@ -565,10 +569,7 @@ void bli_dgemm_knl_asm_24x8 // Check if C is row stride. 
If not, jump to the slow scattered update MOV(RAX, VAR(rs_c)) LEA(RAX, MEM(,RAX,8)) - MOV(RBX, VAR(cs_c)) LEA(RDI, MEM(RAX,RAX,2)) - CMP(RBX, IMM(1)) - JNE(SCATTEREDUPDATE) VMOVQ(RDX, XMM(1)) SAL(RDX) //shift out sign bit @@ -592,74 +593,6 @@ void bli_dgemm_knl_asm_24x8 UPDATE_C_BZ_FOUR_ROWS(24,25,26,27) UPDATE_C_BZ_FOUR_ROWS(28,29,30,31) - JMP(END) - - LABEL(SCATTEREDUPDATE) - - MOV(RDI, VAR(offsetPtr)) - VMOVAPS(ZMM(2), MEM(RDI)) - /* Note that this ignores the upper 32 bits in cs_c */ - VPBROADCASTD(ZMM(3), EBX) - VPMULLD(ZMM(2), ZMM(3), ZMM(2)) - - VMOVQ(RDX, XMM(1)) - SAL(RDX) //shift out sign bit - JZ(SCATTERBZ) - - UPDATE_C_ROW_SCATTERED( 8) - UPDATE_C_ROW_SCATTERED( 9) - UPDATE_C_ROW_SCATTERED(10) - UPDATE_C_ROW_SCATTERED(11) - UPDATE_C_ROW_SCATTERED(12) - UPDATE_C_ROW_SCATTERED(13) - UPDATE_C_ROW_SCATTERED(14) - UPDATE_C_ROW_SCATTERED(15) - UPDATE_C_ROW_SCATTERED(16) - UPDATE_C_ROW_SCATTERED(17) - UPDATE_C_ROW_SCATTERED(18) - UPDATE_C_ROW_SCATTERED(19) - UPDATE_C_ROW_SCATTERED(20) - UPDATE_C_ROW_SCATTERED(21) - UPDATE_C_ROW_SCATTERED(22) - UPDATE_C_ROW_SCATTERED(23) - UPDATE_C_ROW_SCATTERED(24) - UPDATE_C_ROW_SCATTERED(25) - UPDATE_C_ROW_SCATTERED(26) - UPDATE_C_ROW_SCATTERED(27) - UPDATE_C_ROW_SCATTERED(28) - UPDATE_C_ROW_SCATTERED(29) - UPDATE_C_ROW_SCATTERED(30) - UPDATE_C_ROW_SCATTERED(31) - - JMP(END) - - LABEL(SCATTERBZ) - - UPDATE_C_BZ_ROW_SCATTERED( 8) - UPDATE_C_BZ_ROW_SCATTERED( 9) - UPDATE_C_BZ_ROW_SCATTERED(10) - UPDATE_C_BZ_ROW_SCATTERED(11) - UPDATE_C_BZ_ROW_SCATTERED(12) - UPDATE_C_BZ_ROW_SCATTERED(13) - UPDATE_C_BZ_ROW_SCATTERED(14) - UPDATE_C_BZ_ROW_SCATTERED(15) - UPDATE_C_BZ_ROW_SCATTERED(16) - UPDATE_C_BZ_ROW_SCATTERED(17) - UPDATE_C_BZ_ROW_SCATTERED(18) - UPDATE_C_BZ_ROW_SCATTERED(19) - UPDATE_C_BZ_ROW_SCATTERED(20) - UPDATE_C_BZ_ROW_SCATTERED(21) - UPDATE_C_BZ_ROW_SCATTERED(22) - UPDATE_C_BZ_ROW_SCATTERED(23) - UPDATE_C_BZ_ROW_SCATTERED(24) - UPDATE_C_BZ_ROW_SCATTERED(25) - UPDATE_C_BZ_ROW_SCATTERED(26) - UPDATE_C_BZ_ROW_SCATTERED(27) - UPDATE_C_BZ_ROW_SCATTERED(28) - UPDATE_C_BZ_ROW_SCATTERED(29) - UPDATE_C_BZ_ROW_SCATTERED(30) - UPDATE_C_BZ_ROW_SCATTERED(31) - LABEL(END) #ifdef MONITORS @@ -701,6 +634,8 @@ void bli_dgemm_knl_asm_24x8 "zmm30", "zmm31", "memory" ) + GEMM_UKR_FLUSH_CT( d ); + #ifdef LOOPMON printf("looptime = \t%d\n", bloopl - tloopl); #endif diff --git a/kernels/knl/3/bli_sgemm_knl_asm_24x16.c b/kernels/knl/3/bli_sgemm_knl_asm_24x16.c index 6d485b5308..64feba09f1 100644 --- a/kernels/knl/3/bli_sgemm_knl_asm_24x16.c +++ b/kernels/knl/3/bli_sgemm_knl_asm_24x16.c @@ -182,6 +182,8 @@ static int32_t offsets[32] __attribute__((aligned(64))) = //#define LOOPMON void bli_sgemm_knl_asm_24x16 ( + dim_t m, + dim_t n, dim_t k_, float* restrict alpha, float* restrict a, @@ -198,10 +200,12 @@ void bli_sgemm_knl_asm_24x16 const double * a_next = bli_auxinfo_next_a( data ); const double * b_next = bli_auxinfo_next_b( data ); - const int32_t * offsetPtr = &offsets[0]; - const int64_t k = k_; - const int64_t rs_c = rs_c_; - const int64_t cs_c = cs_c_; + int32_t * offsetPtr = &offsets[0]; + int64_t k = k_; + int64_t rs_c = rs_c_; + int64_t cs_c = cs_c_; + + GEMM_UKR_SETUP_CT( s, 24, 16, true ); #ifdef MONITORS int toph, topl, both, botl, midl, midh, mid2l, mid2h; @@ -562,10 +566,7 @@ void bli_sgemm_knl_asm_24x16 // Check if C is row stride. 
If not, jump to the slow scattered update MOV(RAX, VAR(rs_c)) LEA(RAX, MEM(,RAX,4)) - MOV(RBX, VAR(cs_c)) LEA(RDI, MEM(RAX,RAX,2)) - CMP(RBX, IMM(1)) - JNE(SCATTEREDUPDATE) VMOVD(EDX, XMM(1)) SAL(EDX) //shift out sign bit @@ -589,74 +590,6 @@ void bli_sgemm_knl_asm_24x16 UPDATE_C_BZ_FOUR_ROWS(24,25,26,27) UPDATE_C_BZ_FOUR_ROWS(28,29,30,31) - JMP(END) - - LABEL(SCATTEREDUPDATE) - - MOV(RDI, VAR(offsetPtr)) - VMOVAPS(ZMM(2), MEM(RDI)) - /* Note that this ignores the upper 32 bits in cs_c */ - VPBROADCASTD(ZMM(3), EBX) - VPMULLD(ZMM(2), ZMM(3), ZMM(2)) - - VMOVD(EDX, XMM(1)) - SAL(EDX) //shift out sign bit - JZ(SCATTERBZ) - - UPDATE_C_ROW_SCATTERED( 8) - UPDATE_C_ROW_SCATTERED( 9) - UPDATE_C_ROW_SCATTERED(10) - UPDATE_C_ROW_SCATTERED(11) - UPDATE_C_ROW_SCATTERED(12) - UPDATE_C_ROW_SCATTERED(13) - UPDATE_C_ROW_SCATTERED(14) - UPDATE_C_ROW_SCATTERED(15) - UPDATE_C_ROW_SCATTERED(16) - UPDATE_C_ROW_SCATTERED(17) - UPDATE_C_ROW_SCATTERED(18) - UPDATE_C_ROW_SCATTERED(19) - UPDATE_C_ROW_SCATTERED(20) - UPDATE_C_ROW_SCATTERED(21) - UPDATE_C_ROW_SCATTERED(22) - UPDATE_C_ROW_SCATTERED(23) - UPDATE_C_ROW_SCATTERED(24) - UPDATE_C_ROW_SCATTERED(25) - UPDATE_C_ROW_SCATTERED(26) - UPDATE_C_ROW_SCATTERED(27) - UPDATE_C_ROW_SCATTERED(28) - UPDATE_C_ROW_SCATTERED(29) - UPDATE_C_ROW_SCATTERED(30) - UPDATE_C_ROW_SCATTERED(31) - - JMP(END) - - LABEL(SCATTERBZ) - - UPDATE_C_BZ_ROW_SCATTERED( 8) - UPDATE_C_BZ_ROW_SCATTERED( 9) - UPDATE_C_BZ_ROW_SCATTERED(10) - UPDATE_C_BZ_ROW_SCATTERED(11) - UPDATE_C_BZ_ROW_SCATTERED(12) - UPDATE_C_BZ_ROW_SCATTERED(13) - UPDATE_C_BZ_ROW_SCATTERED(14) - UPDATE_C_BZ_ROW_SCATTERED(15) - UPDATE_C_BZ_ROW_SCATTERED(16) - UPDATE_C_BZ_ROW_SCATTERED(17) - UPDATE_C_BZ_ROW_SCATTERED(18) - UPDATE_C_BZ_ROW_SCATTERED(19) - UPDATE_C_BZ_ROW_SCATTERED(20) - UPDATE_C_BZ_ROW_SCATTERED(21) - UPDATE_C_BZ_ROW_SCATTERED(22) - UPDATE_C_BZ_ROW_SCATTERED(23) - UPDATE_C_BZ_ROW_SCATTERED(24) - UPDATE_C_BZ_ROW_SCATTERED(25) - UPDATE_C_BZ_ROW_SCATTERED(26) - UPDATE_C_BZ_ROW_SCATTERED(27) - UPDATE_C_BZ_ROW_SCATTERED(28) - UPDATE_C_BZ_ROW_SCATTERED(29) - UPDATE_C_BZ_ROW_SCATTERED(30) - UPDATE_C_BZ_ROW_SCATTERED(31) - LABEL(END) #ifdef MONITORS @@ -698,6 +631,8 @@ void bli_sgemm_knl_asm_24x16 "zmm30", "zmm31", "memory" ) + GEMM_UKR_FLUSH_CT( s ); + #ifdef LOOPMON printf("looptime = \t%d\n", bloopl - tloopl); #endif diff --git a/kernels/sandybridge/3/bli_gemm_sandybridge_int_d8x4.c b/kernels/sandybridge/3/bli_gemm_sandybridge_int_d8x4.c index 7d093afade..8aeb734ddc 100644 --- a/kernels/sandybridge/3/bli_gemm_sandybridge_int_d8x4.c +++ b/kernels/sandybridge/3/bli_gemm_sandybridge_int_d8x4.c @@ -34,7 +34,6 @@ #include #include -#include "bli_misc_macro_defs.h" #include "blis.h" diff --git a/kernels/skx/3/bli_dgemm_skx_asm_16x12_l2.c b/kernels/skx/3/bli_dgemm_skx_asm_16x12_l2.c index 3a20cd8618..96b53194af 100644 --- a/kernels/skx/3/bli_dgemm_skx_asm_16x12_l2.c +++ b/kernels/skx/3/bli_dgemm_skx_asm_16x12_l2.c @@ -288,6 +288,8 @@ static int64_t offsets[16] __attribute__((aligned(64))) = void bli_dgemm_skx_asm_16x12_l2( + dim_t m, + dim_t n, dim_t k_, double* restrict alpha, double* restrict a, @@ -301,10 +303,11 @@ void bli_dgemm_skx_asm_16x12_l2( (void)data; (void)cntx; - const int64_t* offsetPtr = &offsets[0]; - const int64_t k = k_; - const int64_t rs_c = rs_c_; - const int64_t cs_c = cs_c_; + int64_t k = k_; + int64_t rs_c = rs_c_; + int64_t cs_c = cs_c_; + + GEMM_UKR_SETUP_CT( d, 16, 12, false ); BEGIN_ASM() @@ -464,62 +467,26 @@ void bli_dgemm_skx_asm_16x12_l2( MOV(RAX, VAR(cs_c)) LEA(RAX, MEM(,RAX,8)) - MOV(RBX, 
VAR(rs_c)) - LEA(RBX, MEM(,RBX,8)) - - // Check if C is column stride. If not, jump to the slow scattered update - CMP(RBX, IMM(1)) - JNE(SCATTEREDUPDATE) - - VCOMISD(XMM(1), XMM(7)) - JE(COLSTORBZ) - UPDATE_C( 8, 9,10,11) - UPDATE_C(12,13,14,15) - UPDATE_C(16,17,18,19) - UPDATE_C(20,21,22,23) - UPDATE_C(24,25,26,27) - UPDATE_C(28,29,30,31) + VCOMISD(XMM(1), XMM(7)) + JE(COLSTORBZ) - JMP(END) - LABEL(COLSTORBZ) - - UPDATE_C_BZ( 8, 9,10,11) - UPDATE_C_BZ(12,13,14,15) - UPDATE_C_BZ(16,17,18,19) - UPDATE_C_BZ(20,21,22,23) - UPDATE_C_BZ(24,25,26,27) - UPDATE_C_BZ(28,29,30,31) + UPDATE_C( 8, 9,10,11) + UPDATE_C(12,13,14,15) + UPDATE_C(16,17,18,19) + UPDATE_C(20,21,22,23) + UPDATE_C(24,25,26,27) + UPDATE_C(28,29,30,31) JMP(END) - LABEL(SCATTEREDUPDATE) - - MOV(RDI, VAR(offsetPtr)) - VMOVDQA64(ZMM(2), MEM(RDI,0*64)) - VMOVDQA64(ZMM(3), MEM(RDI,1*64)) - VPBROADCASTQ(ZMM(6), RBX) - VPMULLQ(ZMM(2), ZMM(6), ZMM(2)) - VPMULLQ(ZMM(3), ZMM(6), ZMM(3)) - - VCOMISD(XMM(1), XMM(7)) - JE(SCATTERBZ) - - UPDATE_C_ROW_SCATTERED( 8, 9,10,11) - UPDATE_C_ROW_SCATTERED(12,13,14,15) - UPDATE_C_ROW_SCATTERED(16,17,18,19) - UPDATE_C_ROW_SCATTERED(20,21,22,23) - UPDATE_C_ROW_SCATTERED(24,25,26,27) - UPDATE_C_ROW_SCATTERED(28,29,30,31) - - JMP(END) - LABEL(SCATTERBZ) - - UPDATE_C_BZ_ROW_SCATTERED( 8, 9,10,11) - UPDATE_C_BZ_ROW_SCATTERED(12,13,14,15) - UPDATE_C_BZ_ROW_SCATTERED(16,17,18,19) - UPDATE_C_BZ_ROW_SCATTERED(20,21,22,23) - UPDATE_C_BZ_ROW_SCATTERED(24,25,26,27) - UPDATE_C_BZ_ROW_SCATTERED(28,29,30,31) + LABEL(COLSTORBZ) + + UPDATE_C_BZ( 8, 9,10,11) + UPDATE_C_BZ(12,13,14,15) + UPDATE_C_BZ(16,17,18,19) + UPDATE_C_BZ(20,21,22,23) + UPDATE_C_BZ(24,25,26,27) + UPDATE_C_BZ(28,29,30,31) LABEL(END) @@ -535,8 +502,7 @@ void bli_dgemm_skx_asm_16x12_l2( [beta] "m" (beta), [c] "m" (c), [rs_c] "m" (rs_c), - [cs_c] "m" (cs_c), - [offsetPtr] "m" (offsetPtr) + [cs_c] "m" (cs_c) : // register clobber list "rax", "rbx", "rcx", "rdx", "rdi", "rsi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", @@ -545,4 +511,6 @@ void bli_dgemm_skx_asm_16x12_l2( "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", "zmm30", "zmm31", "memory" ) + + GEMM_UKR_FLUSH_CT( d ); } diff --git a/kernels/skx/3/bli_dgemm_skx_asm_16x14.c b/kernels/skx/3/bli_dgemm_skx_asm_16x14.c index 136f315323..cb1124bad5 100644 --- a/kernels/skx/3/bli_dgemm_skx_asm_16x14.c +++ b/kernels/skx/3/bli_dgemm_skx_asm_16x14.c @@ -154,6 +154,8 @@ static int64_t offsets[16] __attribute__((aligned(64))) = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15}; void bli_dgemm_skx_asm_16x14( + dim_t m, + dim_t n, dim_t k_, double* restrict alpha, double* restrict a, @@ -167,10 +169,11 @@ void bli_dgemm_skx_asm_16x14( (void)data; (void)cntx; - const int64_t* offsetPtr = &offsets[0]; - const int64_t k = k_; - const int64_t rs_c = rs_c_*8; - const int64_t cs_c = cs_c_*8; + int64_t k = k_; + int64_t rs_c = rs_c_; + int64_t cs_c = cs_c_; + + GEMM_UKR_SETUP_CT( d, 16, 14, false ); BEGIN_ASM() @@ -220,6 +223,8 @@ void bli_dgemm_skx_asm_16x14( MOV(R12, VAR(rs_c)) MOV(R10, VAR(cs_c)) + LEA(R12, MEM(,R12,8)) + LEA(R10, MEM(,R10,8)) MOV(RDI, RSI) AND(RSI, IMM(3)) @@ -320,119 +325,41 @@ void bli_dgemm_skx_asm_16x14( MOV(RAX, R12) MOV(RBX, R10) - // Check if C is column stride. 
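// (Note: the GEMM_UKR_SETUP_CT( d, 16, 14, false ) call added above routes
// any C that the kernel cannot store to directly into a temporary microtile,
// and GEMM_UKR_FLUSH_CT( d ) copies the result back afterward. The runtime
// column-stride check and the scattered-update code paths that follow are
// therefore dead and are removed below.)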
- CMP(RAX, IMM(8)) - JNE(SCATTEREDUPDATE) - - VCOMISD(XMM(1), XMM(2)) - JE(COLSTORBZ) - - UPDATE_C( 4, 5) - UPDATE_C( 6, 7) - UPDATE_C( 8, 9) - UPDATE_C(10,11) - UPDATE_C(12,13) - UPDATE_C(14,15) - UPDATE_C(16,17) - UPDATE_C(18,19) - UPDATE_C(20,21) - UPDATE_C(22,23) - UPDATE_C(24,25) - UPDATE_C(26,27) - UPDATE_C(28,29) - UPDATE_C(30,31) - - JMP(END) - LABEL(COLSTORBZ) - - UPDATE_C_BZ( 4, 5) - UPDATE_C_BZ( 6, 7) - UPDATE_C_BZ( 8, 9) - UPDATE_C_BZ(10,11) - UPDATE_C_BZ(12,13) - UPDATE_C_BZ(14,15) - UPDATE_C_BZ(16,17) - UPDATE_C_BZ(18,19) - UPDATE_C_BZ(20,21) - UPDATE_C_BZ(22,23) - UPDATE_C_BZ(24,25) - UPDATE_C_BZ(26,27) - UPDATE_C_BZ(28,29) - UPDATE_C_BZ(30,31) + VCOMISD(XMM(1), XMM(2)) + JE(COLSTORBZ) + + UPDATE_C( 4, 5) + UPDATE_C( 6, 7) + UPDATE_C( 8, 9) + UPDATE_C(10,11) + UPDATE_C(12,13) + UPDATE_C(14,15) + UPDATE_C(16,17) + UPDATE_C(18,19) + UPDATE_C(20,21) + UPDATE_C(22,23) + UPDATE_C(24,25) + UPDATE_C(26,27) + UPDATE_C(28,29) + UPDATE_C(30,31) JMP(END) - LABEL(SCATTEREDUPDATE) - - VMULPD(ZMM( 4), ZMM( 4), ZMM(0)) - VMULPD(ZMM( 5), ZMM( 5), ZMM(0)) - VMULPD(ZMM( 6), ZMM( 6), ZMM(0)) - VMULPD(ZMM( 7), ZMM( 7), ZMM(0)) - VMULPD(ZMM( 8), ZMM( 8), ZMM(0)) - VMULPD(ZMM( 9), ZMM( 9), ZMM(0)) - VMULPD(ZMM(10), ZMM(10), ZMM(0)) - VMULPD(ZMM(11), ZMM(11), ZMM(0)) - VMULPD(ZMM(12), ZMM(12), ZMM(0)) - VMULPD(ZMM(13), ZMM(13), ZMM(0)) - VMULPD(ZMM(14), ZMM(14), ZMM(0)) - VMULPD(ZMM(15), ZMM(15), ZMM(0)) - VMULPD(ZMM(16), ZMM(16), ZMM(0)) - VMULPD(ZMM(17), ZMM(17), ZMM(0)) - VMULPD(ZMM(18), ZMM(18), ZMM(0)) - VMULPD(ZMM(19), ZMM(19), ZMM(0)) - VMULPD(ZMM(20), ZMM(20), ZMM(0)) - VMULPD(ZMM(21), ZMM(21), ZMM(0)) - VMULPD(ZMM(22), ZMM(22), ZMM(0)) - VMULPD(ZMM(23), ZMM(23), ZMM(0)) - VMULPD(ZMM(24), ZMM(24), ZMM(0)) - VMULPD(ZMM(25), ZMM(25), ZMM(0)) - VMULPD(ZMM(26), ZMM(26), ZMM(0)) - VMULPD(ZMM(27), ZMM(27), ZMM(0)) - VMULPD(ZMM(28), ZMM(28), ZMM(0)) - VMULPD(ZMM(29), ZMM(29), ZMM(0)) - VMULPD(ZMM(30), ZMM(30), ZMM(0)) - VMULPD(ZMM(31), ZMM(31), ZMM(0)) - - VCOMISD(XMM(1), XMM(2)) - - MOV(RDI, VAR(offsetPtr)) - VPBROADCASTQ(ZMM(0), RAX) - VPMULLQ(ZMM(2), ZMM(0), MEM(RDI)) - VPMULLQ(ZMM(3), ZMM(0), MEM(RDI,64)) - - JE(SCATTERBZ) - - UPDATE_C_COL_SCATTERED( 4, 5) - UPDATE_C_COL_SCATTERED( 6, 7) - UPDATE_C_COL_SCATTERED( 8, 9) - UPDATE_C_COL_SCATTERED(10,11) - UPDATE_C_COL_SCATTERED(12,13) - UPDATE_C_COL_SCATTERED(14,15) - UPDATE_C_COL_SCATTERED(16,17) - UPDATE_C_COL_SCATTERED(18,19) - UPDATE_C_COL_SCATTERED(20,21) - UPDATE_C_COL_SCATTERED(22,23) - UPDATE_C_COL_SCATTERED(24,25) - UPDATE_C_COL_SCATTERED(26,27) - UPDATE_C_COL_SCATTERED(28,29) - UPDATE_C_COL_SCATTERED(30,31) - - JMP(END) - LABEL(SCATTERBZ) - - UPDATE_C_BZ_COL_SCATTERED( 4, 5) - UPDATE_C_BZ_COL_SCATTERED( 6, 7) - UPDATE_C_BZ_COL_SCATTERED( 8, 9) - UPDATE_C_BZ_COL_SCATTERED(10,11) - UPDATE_C_BZ_COL_SCATTERED(12,13) - UPDATE_C_BZ_COL_SCATTERED(14,15) - UPDATE_C_BZ_COL_SCATTERED(16,17) - UPDATE_C_BZ_COL_SCATTERED(18,19) - UPDATE_C_BZ_COL_SCATTERED(20,21) - UPDATE_C_BZ_COL_SCATTERED(22,23) - UPDATE_C_BZ_COL_SCATTERED(24,25) - UPDATE_C_BZ_COL_SCATTERED(26,27) - UPDATE_C_BZ_COL_SCATTERED(28,29) - UPDATE_C_BZ_COL_SCATTERED(30,31) + LABEL(COLSTORBZ) + + UPDATE_C_BZ( 4, 5) + UPDATE_C_BZ( 6, 7) + UPDATE_C_BZ( 8, 9) + UPDATE_C_BZ(10,11) + UPDATE_C_BZ(12,13) + UPDATE_C_BZ(14,15) + UPDATE_C_BZ(16,17) + UPDATE_C_BZ(18,19) + UPDATE_C_BZ(20,21) + UPDATE_C_BZ(22,23) + UPDATE_C_BZ(24,25) + UPDATE_C_BZ(26,27) + UPDATE_C_BZ(28,29) + UPDATE_C_BZ(30,31) LABEL(END) @@ -449,8 +376,7 @@ void bli_dgemm_skx_asm_16x14( [beta] "m" (beta), [c] "m" (c), [rs_c] "m" 
(rs_c), - [cs_c] "m" (cs_c), - [offsetPtr] "m" (offsetPtr) + [cs_c] "m" (cs_c) : // register clobber list "rax", "rbx", "rcx", "rdx", "rdi", "rsi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", @@ -459,4 +385,6 @@ void bli_dgemm_skx_asm_16x14( "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", "zmm30", "zmm31", "memory" ) + + GEMM_UKR_FLUSH_CT( d ); } diff --git a/kernels/skx/3/bli_sgemm_skx_asm_32x12_l2.c b/kernels/skx/3/bli_sgemm_skx_asm_32x12_l2.c index 40af496140..63bb802153 100644 --- a/kernels/skx/3/bli_sgemm_skx_asm_32x12_l2.c +++ b/kernels/skx/3/bli_sgemm_skx_asm_32x12_l2.c @@ -318,6 +318,8 @@ static int64_t offsets[16] __attribute__((aligned(64))) = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15}; void bli_sgemm_skx_asm_32x12_l2( + dim_t m, + dim_t n, dim_t k_, float* restrict alpha, float* restrict a, @@ -331,10 +333,11 @@ void bli_sgemm_skx_asm_32x12_l2( (void)data; (void)cntx; - const int64_t* offsetPtr = &offsets[0]; - const int64_t k = k_; - const int64_t rs_c = rs_c_; - const int64_t cs_c = cs_c_; + int64_t k = k_; + int64_t rs_c = rs_c_; + int64_t cs_c = cs_c_; + + GEMM_UKR_SETUP_CT( s, 32, 12, false ); BEGIN_ASM() @@ -485,66 +488,26 @@ void bli_sgemm_skx_asm_32x12_l2( MOV(RAX, VAR(cs_c)) LEA(RAX, MEM(,RAX,4)) - MOV(RBX, VAR(rs_c)) - LEA(RBX, MEM(,RBX,4)) - - - // Check if C is column major (rs_c = 1). If not, jump to the slow scattered update - CMP(RBX, IMM(4)) - JNE(SCATTEREDUPDATE) - - VCOMISS(XMM(1), XMM(7)) - JE(COLSTORBZ) - UPDATE_C( 8, 9,10,11) - UPDATE_C(12,13,14,15) - UPDATE_C(16,17,18,19) - UPDATE_C(20,21,22,23) - UPDATE_C(24,25,26,27) - UPDATE_C(28,29,30,31) + VCOMISS(XMM(1), XMM(7)) + JE(COLSTORBZ) - JMP(END) - LABEL(COLSTORBZ) - - UPDATE_C_BZ( 8, 9,10,11) - UPDATE_C_BZ(12,13,14,15) - UPDATE_C_BZ(16,17,18,19) - UPDATE_C_BZ(20,21,22,23) - UPDATE_C_BZ(24,25,26,27) - UPDATE_C_BZ(28,29,30,31) + UPDATE_C( 8, 9,10,11) + UPDATE_C(12,13,14,15) + UPDATE_C(16,17,18,19) + UPDATE_C(20,21,22,23) + UPDATE_C(24,25,26,27) + UPDATE_C(28,29,30,31) JMP(END) - LABEL(SCATTEREDUPDATE) - - LEA(RDX, MEM(RCX,RBX,8)) - LEA(RDX, MEM(RDX,RBX,8)) - - MOV(RDI, VAR(offsetPtr)) - VMOVDQA64(ZMM(2), MEM(RDI,0*64)) - VMOVDQA64(ZMM(3), MEM(RDI,1*64)) - VPBROADCASTQ(ZMM(6), RBX) - VPMULLQ(ZMM(2), ZMM(6), ZMM(2)) - VPMULLQ(ZMM(3), ZMM(6), ZMM(3)) - - VCOMISS(XMM(1), XMM(7)) - JE(SCATTERBZ) - - UPDATE_C_ROW_SCATTERED( 8, 9,10,11) - UPDATE_C_ROW_SCATTERED(12,13,14,15) - UPDATE_C_ROW_SCATTERED(16,17,18,19) - UPDATE_C_ROW_SCATTERED(20,21,22,23) - UPDATE_C_ROW_SCATTERED(24,25,26,27) - UPDATE_C_ROW_SCATTERED(28,29,30,31) - - JMP(END) - LABEL(SCATTERBZ) - - UPDATE_C_BZ_ROW_SCATTERED( 8, 9,10,11) - UPDATE_C_BZ_ROW_SCATTERED(12,13,14,15) - UPDATE_C_BZ_ROW_SCATTERED(16,17,18,19) - UPDATE_C_BZ_ROW_SCATTERED(20,21,22,23) - UPDATE_C_BZ_ROW_SCATTERED(24,25,26,27) - UPDATE_C_BZ_ROW_SCATTERED(28,29,30,31) + LABEL(COLSTORBZ) + + UPDATE_C_BZ( 8, 9,10,11) + UPDATE_C_BZ(12,13,14,15) + UPDATE_C_BZ(16,17,18,19) + UPDATE_C_BZ(20,21,22,23) + UPDATE_C_BZ(24,25,26,27) + UPDATE_C_BZ(28,29,30,31) LABEL(END) @@ -560,8 +523,7 @@ void bli_sgemm_skx_asm_32x12_l2( [beta] "m" (beta), [c] "m" (c), [rs_c] "m" (rs_c), - [cs_c] "m" (cs_c), - [offsetPtr] "m" (offsetPtr) + [cs_c] "m" (cs_c) : // register clobber list "rax", "rbx", "rcx", "rdx", "rdi", "rsi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", @@ -570,4 +532,6 @@ void bli_sgemm_skx_asm_32x12_l2( "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", 
"zmm27", "zmm28", "zmm29", "zmm30", "zmm31", "memory" ) + + GEMM_UKR_FLUSH_CT( s ); } From 0b21b021c3ec6b0071dbe2a4628c55bc13a66789 Mon Sep 17 00:00:00 2001 From: Devin Matthews Date: Wed, 8 Dec 2021 19:37:35 +0000 Subject: [PATCH 08/12] Fix missed issues in ARM32 and ARMSVE kernels. --- .../3/bli_gemm_armsve_asm_c2vx10_unindexed.c | 2 + .../3/bli_gemm_armsve_asm_s2vx10_unindexed.c | 2 + .../3/bli_gemm_armsve_asm_z2vx10_unindexed.c | 2 + .../3/bli_gemm_armsve_asm_z2vx7_unindexed.c | 2 + .../3/bli_gemm_armsve_asm_z2vx8_unindexed.c | 2 + kernels/armv7a/3/bli_gemm_armv7a_asm_d4x4.c | 24 ++-- kernels/armv7a/3/bli_gemm_armv7a_asm_wrap.c | 113 ------------------ kernels/armv7a/bli_kernels_armv7a.h | 29 +---- 8 files changed, 26 insertions(+), 150 deletions(-) delete mode 100644 kernels/armv7a/3/bli_gemm_armv7a_asm_wrap.c diff --git a/kernels/armsve/3/bli_gemm_armsve_asm_c2vx10_unindexed.c b/kernels/armsve/3/bli_gemm_armsve_asm_c2vx10_unindexed.c index b86c901847..6941558967 100644 --- a/kernels/armsve/3/bli_gemm_armsve_asm_c2vx10_unindexed.c +++ b/kernels/armsve/3/bli_gemm_armsve_asm_c2vx10_unindexed.c @@ -46,6 +46,8 @@ void bli_cgemm_armsve_asm_2vx10_unindexed ( + dim_t m, + dim_t n, dim_t k0, scomplex* restrict alpha, scomplex* restrict a, diff --git a/kernels/armsve/3/bli_gemm_armsve_asm_s2vx10_unindexed.c b/kernels/armsve/3/bli_gemm_armsve_asm_s2vx10_unindexed.c index e89daa0e32..74c4779d73 100644 --- a/kernels/armsve/3/bli_gemm_armsve_asm_s2vx10_unindexed.c +++ b/kernels/armsve/3/bli_gemm_armsve_asm_s2vx10_unindexed.c @@ -42,6 +42,8 @@ // 2vx10 microkernels. #include "armsve_asm_2vx10.h" +#include "arm_sve.h" + void bli_sgemm_armsve_asm_2vx10_unindexed ( dim_t m, diff --git a/kernels/armsve/3/bli_gemm_armsve_asm_z2vx10_unindexed.c b/kernels/armsve/3/bli_gemm_armsve_asm_z2vx10_unindexed.c index 9ae666358b..d12e7c8353 100644 --- a/kernels/armsve/3/bli_gemm_armsve_asm_z2vx10_unindexed.c +++ b/kernels/armsve/3/bli_gemm_armsve_asm_z2vx10_unindexed.c @@ -46,6 +46,8 @@ void bli_zgemm_armsve_asm_2vx10_unindexed ( + dim_t m, + dim_t n, dim_t k0, dcomplex* restrict alpha, dcomplex* restrict a, diff --git a/kernels/armsve/3/bli_gemm_armsve_asm_z2vx7_unindexed.c b/kernels/armsve/3/bli_gemm_armsve_asm_z2vx7_unindexed.c index fb0b596a31..b34bc5b1bd 100644 --- a/kernels/armsve/3/bli_gemm_armsve_asm_z2vx7_unindexed.c +++ b/kernels/armsve/3/bli_gemm_armsve_asm_z2vx7_unindexed.c @@ -46,6 +46,8 @@ void bli_zgemm_armsve_asm_2vx7_unindexed ( + dim_t m, + dim_t n, dim_t k0, dcomplex* restrict alpha, dcomplex* restrict a, diff --git a/kernels/armsve/3/bli_gemm_armsve_asm_z2vx8_unindexed.c b/kernels/armsve/3/bli_gemm_armsve_asm_z2vx8_unindexed.c index 13fc26c8c2..a6f2c0481f 100644 --- a/kernels/armsve/3/bli_gemm_armsve_asm_z2vx8_unindexed.c +++ b/kernels/armsve/3/bli_gemm_armsve_asm_z2vx8_unindexed.c @@ -46,6 +46,8 @@ void bli_zgemm_armsve_asm_2vx8_unindexed ( + dim_t m, + dim_t n, dim_t k0, dcomplex* restrict alpha, dcomplex* restrict a, diff --git a/kernels/armv7a/3/bli_gemm_armv7a_asm_d4x4.c b/kernels/armv7a/3/bli_gemm_armv7a_asm_d4x4.c index 9c9b691fc4..c248285c38 100644 --- a/kernels/armv7a/3/bli_gemm_armv7a_asm_d4x4.c +++ b/kernels/armv7a/3/bli_gemm_armv7a_asm_d4x4.c @@ -37,8 +37,6 @@ extern void bli_sgemm_armv7a_ker_4x4 ( - uint32_t m, - uint32_t n, uint32_t k, float* restrict alpha, float* restrict a, @@ -64,7 +62,9 @@ void bli_sgemm_armv7a_asm_4x4 { // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. 
- bli_sgemm_armv7a_ker_4x4( m, n, k, alpha, a, b, beta, c, rs_c, cs_c, data ); + GEMM_UKR_SETUP_CT_ANY( s, 4, 4, false ); + bli_sgemm_armv7a_ker_4x4( k, alpha, a, b, beta, c, rs_c, cs_c, data ); + GEMM_UKR_FLUSH_CT( s ); } @@ -72,8 +72,6 @@ void bli_sgemm_armv7a_asm_4x4 extern void bli_dgemm_armv7a_ker_4x4 ( - uint32_t m, - uint32_t n, uint32_t k, double* restrict alpha, double* restrict a, @@ -99,7 +97,9 @@ void bli_dgemm_armv7a_asm_4x4 { // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - bli_dgemm_armv7a_ker_4x4( m, n, k, alpha, a, b, beta, c, rs_c, cs_c, data ); + GEMM_UKR_SETUP_CT_ANY( d, 4, 4, false ); + bli_dgemm_armv7a_ker_4x4( k, alpha, a, b, beta, c, rs_c, cs_c, data ); + GEMM_UKR_FLUSH_CT( d ); } @@ -107,8 +107,6 @@ void bli_dgemm_armv7a_asm_4x4 extern void bli_cgemm_armv7a_ker_2x2 ( - uint32_t m, - uint32_t n, uint32_t k, scomplex* restrict alpha, scomplex* restrict a, @@ -134,7 +132,9 @@ void bli_cgemm_armv7a_asm_2x2 { // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - bli_cgemm_armv7a_ker_2x2( m, n, k, alpha, a, b, beta, c, rs_c, cs_c, data ); + GEMM_UKR_SETUP_CT_ANY( c, 2, 2, false ); + bli_cgemm_armv7a_ker_2x2( k, alpha, a, b, beta, c, rs_c, cs_c, data ); + GEMM_UKR_FLUSH_CT( c ); } @@ -142,8 +142,6 @@ void bli_cgemm_armv7a_asm_2x2 extern void bli_zgemm_armv7a_ker_2x2 ( - uint32_t m, - uint32_t n, uint32_t k, dcomplex* restrict alpha, dcomplex* restrict a, @@ -169,6 +167,8 @@ void bli_zgemm_armv7a_asm_2x2 { // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - bli_zgemm_armv7a_ker_2x2( m, n, k, alpha, a, b, beta, c, rs_c, cs_c, data ); + GEMM_UKR_SETUP_CT_ANY( z, 2, 2, false ); + bli_zgemm_armv7a_ker_2x2( k, alpha, a, b, beta, c, rs_c, cs_c, data ); + GEMM_UKR_FLUSH_CT( z ); } diff --git a/kernels/armv7a/3/bli_gemm_armv7a_asm_wrap.c b/kernels/armv7a/3/bli_gemm_armv7a_asm_wrap.c deleted file mode 100644 index ab2f91ced7..0000000000 --- a/kernels/armv7a/3/bli_gemm_armv7a_asm_wrap.c +++ /dev/null @@ -1,113 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" -#include "../bli_kernels_armv7a.h" - -void bli_sgemm_armv7a_asm_wrap_4x4 - ( - dim_t m, - dim_t n, - dim_t k, - float* restrict alpha, - float* restrict a, - float* restrict b, - float* restrict beta, - float* restrict c, inc_t rs_c, inc_t cs_c, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ) -{ - GEMM_UKR_SETUP_CT( s, 4, 4, false ); - bli_sgemm_armv7a_asm_4x4(k, alpha, a, b, beta, c, rs_c, cs_c, data, cntx); - GEMM_UKR_FLUSH_CT( s ); -} - -void bli_dgemm_armv7a_asm_wrap_4x4 - ( - dim_t m, - dim_t n, - dim_t k, - double* restrict alpha, - double* restrict a, - double* restrict b, - double* restrict beta, - double* restrict c, inc_t rs_c, inc_t cs_c, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ) -{ - GEMM_UKR_SETUP_CT( d, 4, 4, false ); - bli_dgemm_armv7a_asm_4x4(k, alpha, a, b, beta, c, rs_c, cs_c, data, cntx); - GEMM_UKR_FLUSH_CT( d ); -} - -void bli_cgemm_armv7a_asm_wrap_2x2 - ( - dim_t m, - dim_t n, - dim_t k, - scomplex* restrict alpha, - scomplex* restrict a, - scomplex* restrict b, - scomplex* restrict beta, - scomplex* restrict c, inc_t rs_c, inc_t cs_c, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ) -{ - GEMM_UKR_SETUP_CT( c, 2, 2, false ); - bli_cgemm_armv7a_asm_2x2(k, alpha, a, b, beta, c, rs_c, cs_c, data, cntx); - GEMM_UKR_FLUSH_CT( c ); -} - -void bli_zgemm_armv7a_asm_wrap_2x2 - ( - dim_t m, - dim_t n, - dim_t k, - dcomplex* restrict alpha, - dcomplex* restrict a, - dcomplex* restrict b, - dcomplex* restrict beta, - dcomplex* restrict c, inc_t rs_c, inc_t cs_c, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ) -{ - GEMM_UKR_SETUP_CT( z, 2, 2, false ); - bli_zgemm_armv7a_asm_2x2(k, alpha, a, b, beta, c, rs_c, cs_c, data, cntx); - GEMM_UKR_FLUSH_CT( z ); -} - diff --git a/kernels/armv7a/bli_kernels_armv7a.h b/kernels/armv7a/bli_kernels_armv7a.h index 9fe3b4cf42..7eaf16e655 100644 --- a/kernels/armv7a/bli_kernels_armv7a.h +++ b/kernels/armv7a/bli_kernels_armv7a.h @@ -32,31 +32,10 @@ */ -void bli_sgemm_armv7a_asm_4x4(dim_t k, float *restrict alpha, - float *restrict a, float *restrict b, - float *restrict beta, float *restrict c, - inc_t rs_c, inc_t cs_c, auxinfo_t *restrict data, - cntx_t *restrict cntx); -void bli_dgemm_armv7a_asm_4x4(dim_t k, double *restrict alpha, - double *restrict a, double *restrict b, - double *restrict beta, double *restrict c, - inc_t rs_c, inc_t cs_c, auxinfo_t *restrict data, - cntx_t *restrict cntx); -void bli_cgemm_armv7a_asm_2x2(dim_t k, scomplex *restrict alpha, - scomplex *restrict a, scomplex *restrict b, - scomplex *restrict beta, scomplex *restrict c, - inc_t rs_c, inc_t cs_c, auxinfo_t *restrict data, - cntx_t *restrict cntx); -void bli_zgemm_armv7a_asm_2x2(dim_t k, dcomplex *restrict alpha, - dcomplex *restrict a, dcomplex *restrict b, - dcomplex *restrict beta, dcomplex *restrict c, - inc_t rs_c, inc_t cs_c, auxinfo_t *restrict data, - cntx_t *restrict cntx); - -GEMM_UKR_PROT( float, s, gemm_armv7a_asm_wrap_4x4 ) -GEMM_UKR_PROT( double, d, 
gemm_armv7a_asm_wrap_4x4 ) -GEMM_UKR_PROT( scomplex, c, gemm_armv7a_asm_wrap_2x2 ) -GEMM_UKR_PROT( dcomplex, z, gemm_armv7a_asm_wrap_2x2 ) +GEMM_UKR_PROT( float, s, gemm_armv7a_asm_4x4 ) +GEMM_UKR_PROT( double, d, gemm_armv7a_asm_4x4 ) +GEMM_UKR_PROT( scomplex, c, gemm_armv7a_asm_2x2 ) +GEMM_UKR_PROT( dcomplex, z, gemm_armv7a_asm_2x2 ) GEMM_UKR_PROT( float, s, gemm_armv7a_int_4x4 ) GEMM_UKR_PROT( double, d, gemm_armv7a_int_4x4 ) From 550bff706834bd4acde051e99f4a88ab7e326107 Mon Sep 17 00:00:00 2001 From: Devin Matthews Date: Wed, 8 Dec 2021 13:53:56 -0600 Subject: [PATCH 09/12] Fix munged Sandybridge sgemm ukr. --- kernels/sandybridge/3/bli_gemm_sandybridge_asm_d8x4.c | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/kernels/sandybridge/3/bli_gemm_sandybridge_asm_d8x4.c b/kernels/sandybridge/3/bli_gemm_sandybridge_asm_d8x4.c index 0cf1dd20c3..ade7656252 100644 --- a/kernels/sandybridge/3/bli_gemm_sandybridge_asm_d8x4.c +++ b/kernels/sandybridge/3/bli_gemm_sandybridge_asm_d8x4.c @@ -79,7 +79,7 @@ void bli_sgemm_sandybridge_asm_8x8 mov(var(c), rcx) // load address of c mov(var(cs_c), rdi) // load cs_c - lea(mem( rdi, 4), rdi) // cs_c *= sizeof(float) + lea(mem(, rdi, 4), rdi) // cs_c *= sizeof(float) lea(mem(rcx, rdi, 4), r10) // load address of c + 4*cs_c; lea(mem(rdi, rdi, 2), r14) // r14 = 3*cs_c; @@ -412,11 +412,11 @@ void bli_sgemm_sandybridge_asm_8x8 mov(var(rs_c), rsi) // load rs_c - lea(mem( rsi, 4), rsi) // rsi = rs_c * sizeof(float) + lea(mem(, rsi, 4), rsi) // rsi = rs_c * sizeof(float) lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*rs_c; - lea(mem( rsi, 2), r12) // r12 = 2*rs_c; + lea(mem(, rsi, 2), r12) // r12 = 2*rs_c; lea(mem(r12, rsi, 1), r13) // r13 = 3*rs_c; @@ -529,9 +529,7 @@ void bli_sgemm_sandybridge_asm_8x8 "memory" ) - \ -if ( _use_ct ) bli_sxpbys_mxn ( m , n , _ct , _rs_ct , _cs_ct , _beta , _c , _rs_c , _cs_c ) ; - ; + GEMM_UKR_FLUSH_CT( s ); } void bli_dgemm_sandybridge_asm_8x4 From dda657f743b057cb005cf07762b08d962978bef3 Mon Sep 17 00:00:00 2001 From: "Field G. Van Zee" Date: Fri, 17 Dec 2021 16:53:55 -0600 Subject: [PATCH 10/12] Trivial whitespace, comment, and code changes. 
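The params hook touched up in the bli_gemm_ker_var2.c hunks below is the
mechanism the test/syrk_diagonal examples build on. A hedged sketch of the
usage pattern follows; the setter name bli_obj_set_ker_params() and the
kernel my_dgemm_ukr are assumptions for illustration, not code from this
commit.

// Sketch: route gemm through a user-supplied microkernel by attaching a
// gemm_ker_params_t to the obj_t for C. bli_gemm_ker_var2() retrieves it
// with bli_obj_ker_params() and, when the ukr field is non-NULL, invokes
// it in place of the context's default microkernel.
gemm_ker_params_t params;
params.ukr = ( gemm_ukr_vft )my_dgemm_ukr;   // my_dgemm_ukr: hypothetical
bli_obj_set_ker_params( &params, &c );       // assumed setter; &c is C's obj_t
bli_gemm( &alpha, &a, &b, &beta, &c );       // macrokernel calls my_dgemm_ukr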
--- .../3/bli_gemmtrsm_l_template_noopt_mxn.c | 4 +- .../3/bli_gemmtrsm_u_template_noopt_mxn.c | 4 +- frame/1m/packm/bli_packm_alloc.c | 67 +- frame/1m/packm/bli_packm_alloc.h | 28 +- frame/3/bli_l3_cntl.c | 18 +- frame/3/bli_l3_ukr_oapi.c | 4 +- frame/3/bli_l3_ukr_tapi.c | 63 +- frame/3/gemm/bli_gemm_cntl.c | 8 +- frame/3/gemm/bli_gemm_ker_var2.c | 231 +- frame/3/gemm/bli_gemm_md_c2r_ref.c | 42 +- frame/3/gemm/bli_gemm_var.h | 4 +- frame/3/gemm/ind/bli_gemm_ind_opt.h | 2 +- frame/3/gemmt/bli_gemmt_l_ker_var2.c | 12 +- frame/3/gemmt/bli_gemmt_u_ker_var2.c | 12 +- frame/3/gemmt/other/bli_gemmt_l_ker_var2.c | 409 ++ frame/3/gemmt/other/bli_gemmt_u_ker_var2.c | 409 ++ frame/3/trmm/bli_trmm_ll_ker_var2.c | 12 +- frame/3/trmm/bli_trmm_lu_ker_var2.c | 8 +- frame/3/trmm/bli_trmm_rl_ker_var2.c | 14 +- frame/3/trmm/bli_trmm_ru_ker_var2.c | 16 +- frame/3/trsm/bli_trsm_cntl.c | 12 +- frame/3/trsm/bli_trsm_ll_ker_var2.c | 4 +- frame/3/trsm/bli_trsm_lu_ker_var2.c | 4 +- frame/3/trsm/bli_trsm_rl_ker_var2.c | 6 +- frame/3/trsm/bli_trsm_ru_ker_var2.c | 6 +- frame/include/bli_misc_macro_defs.h | 90 +- frame/include/bli_type_defs.h | 6 +- .../3/bli_gemm_armsve_asm_c2vx10_unindexed.c | 6 +- .../3/bli_gemm_armsve_asm_z2vx10_unindexed.c | 6 +- .../3/bli_gemm_armsve_asm_z2vx7_unindexed.c | 6 +- .../3/bli_gemm_armsve_asm_z2vx8_unindexed.c | 6 +- kernels/armv8a/3/bli_gemm_armv8a_asm_d6x8.c | 3914 ++++++++--------- .../3/bli_gemm_bulldozer_asm_d4x6_fma4.c | 584 +-- kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c | 1250 +++--- kernels/haswell/3/bli_gemm_haswell_asm_d8x6.c | 382 +- kernels/penryn/3/bli_gemm_penryn_asm_d4x4.c | 364 +- .../3/bli_gemm_piledriver_asm_d8x3.c | 366 +- kernels/power10/3/bli_i4gemm_power10_mma.c | 28 +- kernels/power10/3/bli_i8gemm_power10_mma.c | 28 +- kernels/power9/3/bli_gemm_power9_asm_d12x6.c | 202 +- .../3/bli_gemm_sandybridge_asm_d8x4.c | 794 ++-- .../3/bli_gemm_sandybridge_int_d8x4.c | 52 +- kernels/skx/3/bli_dgemm_skx_asm_16x12_l2.c | 25 +- kernels/skx/3/bli_dgemm_skx_asm_16x14.c | 25 +- kernels/skx/3/bli_sgemm_skx_asm_32x12_l2.c | 27 +- ref_kernels/3/bb/bli_gemmtrsmbb_ref.c | 4 +- ref_kernels/3/bli_gemmtrsm_ref.c | 4 +- ref_kernels/ind/bli_gemm1m_ref.c | 14 +- ref_kernels/ind/bli_gemmtrsm1m_ref.c | 4 +- test/syrk_diagonal/syrk_diagonal_example.c | 214 +- test/syrk_diagonal/syrk_diagonal_example2.c | 233 +- 51 files changed, 5447 insertions(+), 4586 deletions(-) create mode 100644 frame/3/gemmt/other/bli_gemmt_l_ker_var2.c create mode 100644 frame/3/gemmt/other/bli_gemmt_u_ker_var2.c diff --git a/config/template/kernels/3/bli_gemmtrsm_l_template_noopt_mxn.c b/config/template/kernels/3/bli_gemmtrsm_l_template_noopt_mxn.c index 1582566ae6..87c21f7edf 100644 --- a/config/template/kernels/3/bli_gemmtrsm_l_template_noopt_mxn.c +++ b/config/template/kernels/3/bli_gemmtrsm_l_template_noopt_mxn.c @@ -86,8 +86,8 @@ void bli_zgemmtrsm_l_template_noopt /* b11 = alpha * b11 - a10 * b01; */ bli_zgemm_template_noopt ( - mr, - nr, + mr, + nr, k, minus_one, a10, diff --git a/config/template/kernels/3/bli_gemmtrsm_u_template_noopt_mxn.c b/config/template/kernels/3/bli_gemmtrsm_u_template_noopt_mxn.c index 18f288f719..0b4544ae1d 100644 --- a/config/template/kernels/3/bli_gemmtrsm_u_template_noopt_mxn.c +++ b/config/template/kernels/3/bli_gemmtrsm_u_template_noopt_mxn.c @@ -86,8 +86,8 @@ void bli_zgemmtrsm_u_template_noopt /* b11 = alpha * b11 - a12 * b21; */ bli_zgemm_template_noopt ( - mr, - nr, + mr, + nr, k, minus_one, a10, diff --git a/frame/1m/packm/bli_packm_alloc.c 
b/frame/1m/packm/bli_packm_alloc.c index 5015984848..c316932145 100644 --- a/frame/1m/packm/bli_packm_alloc.c +++ b/frame/1m/packm/bli_packm_alloc.c @@ -36,34 +36,34 @@ #include "blis.h" void* bli_packm_alloc - ( - siz_t size_needed, - rntm_t* rntm, - cntl_t* cntl, - thrinfo_t* thread - ) + ( + siz_t size_needed, + rntm_t* rntm, + cntl_t* cntl, + thrinfo_t* thread + ) { // Query the pack buffer type from the control tree node. packbuf_t pack_buf_type = bli_cntl_packm_params_pack_buf_type( cntl ); - return bli_packm_alloc_ex - ( - size_needed, - pack_buf_type, - rntm, - cntl, - thread - ); + return bli_packm_alloc_ex + ( + size_needed, + pack_buf_type, + rntm, + cntl, + thread + ); } void* bli_packm_alloc_ex - ( - siz_t size_needed, - packbuf_t pack_buf_type, - rntm_t* rntm, - cntl_t* cntl, - thrinfo_t* thread - ) + ( + siz_t size_needed, + packbuf_t pack_buf_type, + rntm_t* rntm, + cntl_t* cntl, + thrinfo_t* thread + ) { // Query the address of the mem_t entry within the control tree node. mem_t* cntl_mem_p = bli_cntl_pack_mem( cntl ); @@ -74,7 +74,7 @@ void* bli_packm_alloc_ex siz_t cntl_mem_size = 0; if ( bli_mem_is_alloc( cntl_mem_p ) ) - cntl_mem_size = bli_mem_size( cntl_mem_p ); + cntl_mem_size = bli_mem_size( cntl_mem_p ); if ( cntl_mem_size < size_needed ) { @@ -83,14 +83,15 @@ void* bli_packm_alloc_ex // The chief thread releases the existing block associated with // the mem_t entry in the control tree, and then re-acquires a // new block, saving the associated mem_t entry to local_mem_s. - if ( bli_mem_is_alloc( cntl_mem_p ) ) - { - bli_pba_release - ( - rntm, - cntl_mem_p - ); - } + if ( bli_mem_is_alloc( cntl_mem_p ) ) + { + bli_pba_release + ( + rntm, + cntl_mem_p + ); + } + bli_pba_acquire_m ( rntm, @@ -108,9 +109,9 @@ void* bli_packm_alloc_ex // this thread's control tree node. *cntl_mem_p = *local_mem_p; - // Barrier so that the master thread doesn't return from the function - // before we are done reading. - bli_thread_barrier( thread ); + // Barrier so that the master thread doesn't return from the function + // before we are done reading. 
+ bli_thread_barrier( thread ); } return bli_mem_buffer( cntl_mem_p ); diff --git a/frame/1m/packm/bli_packm_alloc.h b/frame/1m/packm/bli_packm_alloc.h index f9870562d4..5a5cf126b1 100644 --- a/frame/1m/packm/bli_packm_alloc.h +++ b/frame/1m/packm/bli_packm_alloc.h @@ -32,20 +32,20 @@ */ -BLIS_EXPORT_BLIS void* bli_packm_alloc - ( - siz_t size_needed, - rntm_t* rntm, - cntl_t* cntl, - thrinfo_t* thread - ); +BLIS_EXPORT_BLIS void* bli_packm_alloc + ( + siz_t size_needed, + rntm_t* rntm, + cntl_t* cntl, + thrinfo_t* thread + ); BLIS_EXPORT_BLIS void* bli_packm_alloc_ex - ( - siz_t size_needed, - packbuf_t pack_buf_type, - rntm_t* rntm, - cntl_t* cntl, - thrinfo_t* thread - ); + ( + siz_t size_needed, + packbuf_t pack_buf_type, + rntm_t* rntm, + cntl_t* cntl, + thrinfo_t* thread + ); diff --git a/frame/3/bli_l3_cntl.c b/frame/3/bli_l3_cntl.c index c9207c226a..83ff8e5af5 100644 --- a/frame/3/bli_l3_cntl.c +++ b/frame/3/bli_l3_cntl.c @@ -57,7 +57,14 @@ void bli_l3_cntl_create_if family == BLIS_GEMMT || family == BLIS_TRMM ) { - *cntl_use = bli_gemm_cntl_create( rntm, family, schema_a, schema_b, bli_obj_ker_fn( c ) ); + *cntl_use = bli_gemm_cntl_create + ( + rntm, + family, + schema_a, + schema_b, + bli_obj_ker_fn( c ) + ); } else // if ( family == BLIS_TRSM ) { @@ -66,7 +73,14 @@ void bli_l3_cntl_create_if if ( bli_obj_is_triangular( a ) ) side = BLIS_LEFT; else side = BLIS_RIGHT; - *cntl_use = bli_trsm_cntl_create( rntm, side, schema_a, schema_b, bli_obj_ker_fn( c ) ); + *cntl_use = bli_trsm_cntl_create + ( + rntm, + side, + schema_a, + schema_b, + bli_obj_ker_fn( c ) + ); } } else diff --git a/frame/3/bli_l3_ukr_oapi.c b/frame/3/bli_l3_ukr_oapi.c index 805a3f8f0c..b8f2e00e6a 100644 --- a/frame/3/bli_l3_ukr_oapi.c +++ b/frame/3/bli_l3_ukr_oapi.c @@ -77,8 +77,8 @@ void PASTEMAC0(opname) \ \ f \ ( \ - m, \ - n, \ + m, \ + n, \ k, \ buf_alpha, \ buf_a, \ diff --git a/frame/3/bli_l3_ukr_tapi.c b/frame/3/bli_l3_ukr_tapi.c index 178981976e..ab745d12b3 100644 --- a/frame/3/bli_l3_ukr_tapi.c +++ b/frame/3/bli_l3_ukr_tapi.c @@ -60,18 +60,19 @@ void PASTEMAC(ch,opname) \ PASTECH2(ch,tname,_ukr_ft) f = bli_cntx_get_l3_vir_ukr_dt( dt, kerid, cntx ); \ \ /* Invoke the typed function for the given datatype. */ \ - f( \ - m, \ - n, \ - k, \ - alpha, \ - a, \ - b, \ - beta, \ - c, rs_c, cs_c, \ - data, \ - cntx \ - ); \ + f \ + ( \ + m, \ + n, \ + k, \ + alpha, \ + a, \ + b, \ + beta, \ + c, rs_c, cs_c, \ + data, \ + cntx \ + ); \ } \ INSERT_GENTFUNC_BASIC2( gemm_ukernel, gemm, BLIS_GEMM_UKR ) @@ -102,17 +103,18 @@ void PASTEMAC(ch,opname) \ PASTECH2(ch,tname,_ukr_ft) f = bli_cntx_get_l3_vir_ukr_dt( dt, kerid, cntx ); \ \ /* Invoke the typed function for the given datatype. */ \ - f( \ - k, \ - alpha, \ - a1x, \ - a11, \ - bx1, \ - b11, \ - c11, rs_c, cs_c, \ - data, \ - cntx \ - ); \ + f \ + ( \ + k, \ + alpha, \ + a1x, \ + a11, \ + bx1, \ + b11, \ + c11, rs_c, cs_c, \ + data, \ + cntx \ + ); \ } \ INSERT_GENTFUNC_BASIC2( gemmtrsm_l_ukernel, gemmtrsm, BLIS_GEMMTRSM_L_UKR ) @@ -140,13 +142,14 @@ void PASTEMAC(ch,opname) \ PASTECH2(ch,tname,_ukr_ft) f = bli_cntx_get_l3_vir_ukr_dt( dt, kerid, cntx ); \ \ /* Invoke the typed function for the given datatype. 
*/ \ - f( \ - a, \ - b, \ - c, rs_c, cs_c, \ - data, \ - cntx \ - ); \ + f \ + ( \ + a, \ + b, \ + c, rs_c, cs_c, \ + data, \ + cntx \ + ); \ } \ INSERT_GENTFUNC_BASIC2( trsm_l_ukernel, trsm, BLIS_TRSM_L_UKR ) diff --git a/frame/3/gemm/bli_gemm_cntl.c b/frame/3/gemm/bli_gemm_cntl.c index 71dad78f30..052c812a33 100644 --- a/frame/3/gemm/bli_gemm_cntl.c +++ b/frame/3/gemm/bli_gemm_cntl.c @@ -60,13 +60,15 @@ cntl_t* bli_gemmbp_cntl_create { void_fp macro_kernel_fp; - // Use the function pointers to the macrokernels that use slab - // assignment of micropanels to threads in the jr and ir loops. + // Choose the default macrokernel based on the operation family... if ( family == BLIS_GEMM ) macro_kernel_fp = bli_gemm_ker_var2; else if ( family == BLIS_GEMMT ) macro_kernel_fp = bli_gemmt_x_ker_var2; else if ( family == BLIS_TRMM ) macro_kernel_fp = bli_trmm_xx_ker_var2; else /* should never execute */ macro_kernel_fp = NULL; - if ( ker ) macro_kernel_fp = ker; + + // ...unless a non-NULL kernel function pointer is passed in, in which + // case we use that instead. + if ( ker ) macro_kernel_fp = ker; // Create two nodes for the macro-kernel. cntl_t* gemm_cntl_bu_ke = bli_gemm_cntl_create_node diff --git a/frame/3/gemm/bli_gemm_ker_var2.c b/frame/3/gemm/bli_gemm_ker_var2.c index d49c9823e2..6de361194d 100644 --- a/frame/3/gemm/bli_gemm_ker_var2.c +++ b/frame/3/gemm/bli_gemm_ker_var2.c @@ -46,6 +46,7 @@ typedef void (*xpbys_mxn_vft) #undef GENTFUNC2 #define GENTFUNC2(ctypex,ctypey,chx,chy,op) \ +\ void PASTEMAC2(chx,chy,op) \ ( \ dim_t m, \ @@ -55,9 +56,9 @@ void PASTEMAC2(chx,chy,op) \ void* y, inc_t rs_y, inc_t cs_y \ ) \ { \ - ctypex* restrict x_cast = x; \ - ctypey* restrict b_cast = b; \ - ctypey* restrict y_cast = y; \ + ctypex* restrict x_cast = x; \ + ctypey* restrict b_cast = b; \ + ctypey* restrict y_cast = y; \ \ PASTEMAC3(chx,chy,chy,xpbys_mxn) \ ( \ @@ -109,7 +110,7 @@ void bli_gemm_ker_var2 inc_t rs_c = bli_obj_row_stride( c ); inc_t cs_c = bli_obj_col_stride( c ); - /* If any dimension is zero, return immediately. */ + // If any dimension is zero, return immediately. if ( bli_zero_dim3( m, n, k ) ) return; // Detach and multiply the scalars attached to A and B. @@ -158,65 +159,69 @@ void bli_gemm_ker_var2 // Tweak parameters in select mixed domain cases (rcc, crc, ccr). if ( bli_cntx_method( cntx ) == BLIS_NAT ) { - bli_gemm_md_ker_var2_recast - ( - &dt_exec, - bli_obj_dt( a ), - bli_obj_dt( b ), - &dt_c, - &m, &n, &k, - &pd_a, &ps_a, - &pd_b, &ps_b, - c, - &rs_c, &cs_c - ); - } + bli_gemm_md_ker_var2_recast + ( + &dt_exec, + bli_obj_dt( a ), + bli_obj_dt( b ), + &dt_c, + &m, &n, &k, + &pd_a, &ps_a, + &pd_b, &ps_b, + c, + &rs_c, &cs_c + ); + } #endif - siz_t dt_size = bli_dt_size( dt_exec ); - siz_t dt_c_size = bli_dt_size( dt_c ); + siz_t dt_size = bli_dt_size( dt_exec ); + siz_t dt_c_size = bli_dt_size( dt_c ); - /* Alias some constants to simpler names. */ + // Alias some constants to simpler names. const dim_t MR = pd_a; const dim_t NR = pd_b; - /*const dim_t PACKMR = cs_a;*/ - /*const dim_t PACKNR = rs_b;*/ + //const dim_t PACKMR = cs_a; + //const dim_t PACKNR = rs_b; - /* Query the context for the micro-kernel address and cast it to its - function pointer type. */ + // Query the context for the micro-kernel address and cast it to its + // function pointer type. gemm_ukr_vft gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt_exec, BLIS_GEMM_UKR, cntx ); - gemm_ker_params_t* params = bli_obj_ker_params( c ); - gemm_ukr_vft user_ukr = params ? 
params->ukr : NULL; - if ( user_ukr ) gemm_ukr = user_ukr; - - /* Temporary C buffer for edge cases. Note that the strides of this - temporary buffer are set so that they match the storage of the - original C matrix. For example, if C is column-stored, ct will be - column-stored as well. */ + // Query the params field from the obj_t. If it is non-NULL, grab the ukr + // field of the params struct. If that function pointer is non-NULL, use it + // as our microkernel instead of the default microkernel queried from the + // cntx above. + gemm_ker_params_t* params = bli_obj_ker_params( c ); + gemm_ukr_vft user_ukr = params ? params->ukr : NULL; + if ( user_ukr ) gemm_ukr = user_ukr; + + // Temporary C buffer for edge cases. Note that the strides of this + // temporary buffer are set so that they match the storage of the + // original C matrix. For example, if C is column-stored, ct will be + // column-stored as well. char ct[ BLIS_STACK_BUF_MAX_SIZE ] - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt_exec, BLIS_GEMM_UKR, cntx ); const inc_t rs_ct = ( col_pref ? 1 : NR ); const inc_t cs_ct = ( col_pref ? MR : 1 ); char* zero = bli_obj_buffer_for_const( dt_exec, &BLIS_ZERO ); - /* - Assumptions/assertions: - rs_a == 1 - cs_a == PACKMR - pd_a == MR - ps_a == stride to next micro-panel of A - rs_b == PACKNR - cs_b == 1 - pd_b == NR - ps_b == stride to next micro-panel of B - rs_c == (no assumptions) - cs_c == (no assumptions) - */ - - /* Compute number of primary and leftover components of the m and n - dimensions. */ + // + // Assumptions/assertions: + // rs_a == 1 + // cs_a == PACKMR + // pd_a == MR + // ps_a == stride to next micro-panel of A + // rs_b == PACKNR + // cs_b == 1 + // pd_b == NR + // ps_b == stride to next micro-panel of B + // rs_c == (no assumptions) + // cs_c == (no assumptions) + // + + // Compute number of primary and leftover components of the m and n + // dimensions. dim_t n_iter = n / NR; dim_t n_left = n % NR; @@ -226,7 +231,7 @@ void bli_gemm_ker_var2 if ( n_left ) ++n_iter; if ( m_left ) ++m_iter; - /* Determine some increments used to step through A, B, and C. */ + // Determine some increments used to step through A, B, and C. inc_t rstep_a = ps_a * dt_size; inc_t cstep_b = ps_b * dt_size; @@ -236,24 +241,24 @@ void bli_gemm_ker_var2 auxinfo_t aux; - /* Save the pack schemas of A and B to the auxinfo_t object. */ + // Save the pack schemas of A and B to the auxinfo_t object. bli_auxinfo_set_schema_a( schema_a, &aux ); bli_auxinfo_set_schema_b( schema_b, &aux ); - /* Save the imaginary stride of A and B to the auxinfo_t object. */ + // Save the imaginary stride of A and B to the auxinfo_t object. bli_auxinfo_set_is_a( is_a, &aux ); bli_auxinfo_set_is_b( is_b, &aux ); - /* Save the virtual microkernel address and the params. */ - bli_auxinfo_set_ukr( gemm_ukr, &aux ); - bli_auxinfo_set_params( params, &aux ); + // Save the virtual microkernel address and the params. + bli_auxinfo_set_ukr( gemm_ukr, &aux ); + bli_auxinfo_set_params( params, &aux ); - /* The 'thread' argument points to the thrinfo_t node for the 2nd (jr) - loop around the microkernel. Here we query the thrinfo_t node for the - 1st (ir) loop around the microkernel. */ + // The 'thread' argument points to the thrinfo_t node for the 2nd (jr) + // loop around the microkernel. Here we query the thrinfo_t node for the + // 1st (ir) loop around the microkernel. 
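// (bli_thrinfo_sub_node() descends one level in the thrinfo_t tree. Under
// slab partitioning each loop then hands every thread one contiguous range
// of micro-panel indices, while round-robin partitioning interleaves the
// indices across threads; the bli_thread_range_jrir() calls below encode
// whichever policy was selected at configure time.)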
thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); - /* Query the number of threads and thread ids for each loop. */ + // Query the number of threads and thread ids for each loop. dim_t jr_nt = bli_thread_n_way( thread ); dim_t jr_tid = bli_thread_work_id( thread ); dim_t ir_nt = bli_thread_n_way( caucus ); @@ -263,13 +268,13 @@ void bli_gemm_ker_var2 dim_t ir_start, ir_end; dim_t jr_inc, ir_inc; - /* Determine the thread range and increment for the 2nd and 1st loops. - NOTE: The definition of bli_thread_range_jrir() will depend on whether - slab or round-robin partitioning was requested at configure-time. */ + // Determine the thread range and increment for the 2nd and 1st loops. + // NOTE: The definition of bli_thread_range_jrir() will depend on whether + // slab or round-robin partitioning was requested at configure-time. bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); - /* Loop over the n dimension (NR columns at a time). */ + // Loop over the n dimension (NR columns at a time). for ( dim_t j = jr_start; j < jr_end; j += jr_inc ) { char* b1 = b_cast + j * cstep_b; @@ -277,10 +282,10 @@ void bli_gemm_ker_var2 dim_t n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); - /* Initialize our next panel of B to be the current panel of B. */ + // Initialize our next panel of B to be the current panel of B. char* b2 = b1; - /* Loop over the m dimension (MR rows at a time). */ + // Loop over the m dimension (MR rows at a time). for ( dim_t i = ir_start; i < ir_end; i += ir_inc ) { char* a1 = a_cast + i * rstep_a; @@ -288,7 +293,7 @@ void bli_gemm_ker_var2 dim_t m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); - /* Compute the addresses of the next panels of A and B. */ + // Compute the addresses of the next panels of A and B. char* a2 = bli_gemm_get_next_a_upanel( a1, rstep_a, ir_inc ); if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) ) { @@ -298,54 +303,58 @@ void bli_gemm_ker_var2 b2 = b_cast; } - /* Save addresses of next panels of A and B to the auxinfo_t - object. */ + // Save addresses of next panels of A and B to the auxinfo_t + // object. bli_auxinfo_set_next_a( a2, &aux ); bli_auxinfo_set_next_b( b2, &aux ); - if ( dt_exec == dt_c ) - { - /* Invoke the gemm micro-kernel. */ - gemm_ukr - ( - m_cur, - n_cur, - k, - alpha_cast, - a1, - b1, - beta_cast, - c11, rs_c, cs_c, - &aux, - cntx - ); - } - else - { - /* Invoke the gemm micro-kernel. */ - gemm_ukr - ( - MR, - NR, - k, - alpha_cast, - a1, - b1, - zero, - &ct, rs_ct, cs_ct, - &aux, - cntx - ); - - /* Accumulate to C with type-casting. */ - xbpys_mxn[ dt_exec ][ dt_c ] - ( - m_cur, n_cur, - &ct, rs_ct, cs_ct, - beta_cast, - c11, rs_c, cs_c - ); - } + // Edge case handling now occurs within the microkernel itself, but + // we must still explicitly accumulate to a temporary microtile in + // situations where a virtual microkernel is being used, such as + // during the 1m method or some cases of mixed datatypes. + if ( dt_exec == dt_c ) + { + // Invoke the gemm micro-kernel. + gemm_ukr + ( + m_cur, + n_cur, + k, + alpha_cast, + a1, + b1, + beta_cast, + c11, rs_c, cs_c, + &aux, + cntx + ); + } + else + { + // Invoke the gemm micro-kernel. + gemm_ukr + ( + MR, + NR, + k, + alpha_cast, + a1, + b1, + zero, + &ct, rs_ct, cs_ct, + &aux, + cntx + ); + + // Accumulate to C with type-casting. 
+ xbpys_mxn[ dt_exec ][ dt_c ] + ( + m_cur, n_cur, + &ct, rs_ct, cs_ct, + beta_cast, + c11, rs_c, cs_c + ); + } } } diff --git a/frame/3/gemm/bli_gemm_md_c2r_ref.c b/frame/3/gemm/bli_gemm_md_c2r_ref.c index 13d66ae9a2..bbd9190a9a 100644 --- a/frame/3/gemm/bli_gemm_md_c2r_ref.c +++ b/frame/3/gemm/bli_gemm_md_c2r_ref.c @@ -64,8 +64,8 @@ void PASTEMAC2(ch,opname,suf) \ const dim_t mr = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \ const dim_t nr = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \ \ - dim_t mr_r = mr; \ - dim_t nr_r = nr; \ + dim_t mr_r = mr; \ + dim_t nr_r = nr; \ \ ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ / sizeof( ctype_r ) ] \ @@ -154,18 +154,16 @@ PASTEMAC(chr,fprintm)( stdout, "gemm_ukr: c before", mr, nr, \ rs_c_use = rs_ct; \ cs_c_use = cs_ct; \ \ - /* Convert the strides from being in units of complex elements to - be in units of real elements. Note that we don't need to check for - general storage here because that case corresponds to the scenario - where we are using the ct buffer and its rs_ct/cs_ct strides. */ \ + /* Convert the strides and corresponding microtile dimension from being + in units of complex elements to be in units of real elements. */ \ if ( bli_is_col_stored( rs_c_use, cs_c_use ) ) { cs_c_use *= 2; mr_r *= 2; } \ else { rs_c_use *= 2; nr_r *= 2; }\ \ /* c = beta * c + alpha_r * a * b; */ \ rgemm_ukr \ ( \ - mr_r, \ - nr_r, \ + mr_r, \ + nr_r, \ k, \ alpha_r, \ a_r, \ @@ -175,14 +173,12 @@ PASTEMAC(chr,fprintm)( stdout, "gemm_ukr: c before", mr, nr, \ data, \ cntx \ ); \ -\ - dim_t i, j; \ \ /* Accumulate the final result in ct back to c. */ \ if ( PASTEMAC(ch,eq1)( *beta ) ) \ { \ - for ( j = 0; j < n; ++j ) \ - for ( i = 0; i < m; ++i ) \ + for ( dim_t j = 0; j < n; ++j ) \ + for ( dim_t i = 0; i < m; ++i ) \ { \ PASTEMAC(ch,adds)( *(ct + i*rs_ct + j*cs_ct), \ *(c + i*rs_c + j*cs_c ) ); \ @@ -190,8 +186,8 @@ PASTEMAC(chr,fprintm)( stdout, "gemm_ukr: c before", mr, nr, \ } \ else if ( PASTEMAC(ch,eq0)( *beta ) ) \ { \ - for ( j = 0; j < n; ++j ) \ - for ( i = 0; i < m; ++i ) \ + for ( dim_t j = 0; j < n; ++j ) \ + for ( dim_t i = 0; i < m; ++i ) \ { \ PASTEMAC(ch,copys)( *(ct + i*rs_ct + j*cs_ct), \ *(c + i*rs_c + j*cs_c ) ); \ @@ -199,8 +195,8 @@ PASTEMAC(chr,fprintm)( stdout, "gemm_ukr: c before", mr, nr, \ } \ else \ { \ - for ( j = 0; j < n; ++j ) \ - for ( i = 0; i < m; ++i ) \ + for ( dim_t j = 0; j < n; ++j ) \ + for ( dim_t i = 0; i < m; ++i ) \ { \ PASTEMAC(ch,xpbys)( *(ct + i*rs_ct + j*cs_ct), \ *beta, \ @@ -216,21 +212,19 @@ PASTEMAC(chr,fprintm)( stdout, "gemm_ukr: c before", mr, nr, \ c_use = ( ctype_r* )c; \ rs_c_use = rs_c; \ cs_c_use = cs_c; \ - m_use = m; \ - n_use = n; \ + m_use = m; \ + n_use = n; \ \ - /* Convert the strides from being in units of complex elements to - be in units of real elements. Note that we don't need to check for - general storage here because that case corresponds to the scenario - where we are using the ct buffer and its rs_ct/cs_ct strides. */ \ + /* Convert the strides and corresponding microtile dimension from being + in units of complex elements to be in units of real elements. 
*/ \ if ( bli_is_col_stored( rs_c_use, cs_c_use ) ) { cs_c_use *= 2; m_use *= 2; } \ else { rs_c_use *= 2; n_use *= 2; } \ \ /* c = beta * c + alpha_r * a * b; */ \ rgemm_ukr \ ( \ - m_use, \ - n_use, \ + m_use, \ + n_use, \ k, \ alpha_r, \ a_r, \ diff --git a/frame/3/gemm/bli_gemm_var.h b/frame/3/gemm/bli_gemm_var.h index 9fb8510101..888181bad6 100644 --- a/frame/3/gemm/bli_gemm_var.h +++ b/frame/3/gemm/bli_gemm_var.h @@ -35,12 +35,12 @@ // -// GEMM kernel parameter struct. +// gemm kernel parameter struct. // typedef struct { - gemm_ukr_vft ukr; + gemm_ukr_vft ukr; } gemm_ker_params_t; diff --git a/frame/3/gemm/ind/bli_gemm_ind_opt.h b/frame/3/gemm/ind/bli_gemm_ind_opt.h index 77213d0fa3..52ea81a5e8 100644 --- a/frame/3/gemm/ind/bli_gemm_ind_opt.h +++ b/frame/3/gemm/ind/bli_gemm_ind_opt.h @@ -58,7 +58,7 @@ BLIS_INLINE void bli_gemm_ind_recast_1m_params !bli_is_gen_stored( *rs_c, *cs_c ) ) { *dt_exec = bli_dt_proj_to_real( *dt_exec ); - *dt_c = bli_dt_proj_to_real( *dt_c ); + *dt_c = bli_dt_proj_to_real( *dt_c ); if ( bli_is_1e_packed( schema_a ) ) { diff --git a/frame/3/gemmt/bli_gemmt_l_ker_var2.c b/frame/3/gemmt/bli_gemmt_l_ker_var2.c index 6db72fd55e..fea4efec0a 100644 --- a/frame/3/gemmt/bli_gemmt_l_ker_var2.c +++ b/frame/3/gemmt/bli_gemmt_l_ker_var2.c @@ -387,8 +387,8 @@ void PASTEMAC(ch,varname) \ /* Invoke the gemm micro-kernel. */ \ gemm_ukr \ ( \ - m_cur, \ - n_cur, \ + m_cur, \ + n_cur, \ k, \ alpha_cast, \ a1, \ @@ -470,8 +470,8 @@ void PASTEMAC(ch,varname) \ /* Invoke the gemm micro-kernel. */ \ gemm_ukr \ ( \ - MR, \ - NR, \ + MR, \ + NR, \ k, \ alpha_cast, \ a1, \ @@ -494,8 +494,8 @@ void PASTEMAC(ch,varname) \ /* Invoke the gemm micro-kernel. */ \ gemm_ukr \ ( \ - m_cur, \ - n_cur, \ + m_cur, \ + n_cur, \ k, \ alpha_cast, \ a1, \ diff --git a/frame/3/gemmt/bli_gemmt_u_ker_var2.c b/frame/3/gemmt/bli_gemmt_u_ker_var2.c index 0518cc5416..4b849bbc6d 100644 --- a/frame/3/gemmt/bli_gemmt_u_ker_var2.c +++ b/frame/3/gemmt/bli_gemmt_u_ker_var2.c @@ -388,8 +388,8 @@ void PASTEMAC(ch,varname) \ /* Invoke the gemm micro-kernel. */ \ gemm_ukr \ ( \ - MR, \ - NR, \ + MR, \ + NR, \ k, \ alpha_cast, \ a1, \ @@ -412,8 +412,8 @@ void PASTEMAC(ch,varname) \ /* Invoke the gemm micro-kernel. */ \ gemm_ukr \ ( \ - m_cur, \ - n_cur, \ + m_cur, \ + n_cur, \ k, \ alpha_cast, \ a1, \ @@ -497,8 +497,8 @@ void PASTEMAC(ch,varname) \ /* Invoke the gemm micro-kernel. */ \ gemm_ukr \ ( \ - m_cur, \ - n_cur, \ + m_cur, \ + n_cur, \ k, \ alpha_cast, \ a1, \ diff --git a/frame/3/gemmt/other/bli_gemmt_l_ker_var2.c b/frame/3/gemmt/other/bli_gemmt_l_ker_var2.c new file mode 100644 index 0000000000..0bf4b1a0fb --- /dev/null +++ b/frame/3/gemmt/other/bli_gemmt_l_ker_var2.c @@ -0,0 +1,409 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
+ - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#define FUNCPTR_T gemmt_fp + +typedef void (*FUNCPTR_T) + ( + doff_t diagoffc, + pack_t schema_a, + pack_t schema_b, + dim_t m, + dim_t n, + dim_t k, + void* alpha, + void* a, inc_t cs_a, inc_t is_a, + dim_t pd_a, inc_t ps_a, + void* b, inc_t rs_b, inc_t is_b, + dim_t pd_b, inc_t ps_b, + void* beta, + void* c, inc_t rs_c, inc_t cs_c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ); + +static FUNCPTR_T GENARRAY(ftypes,gemmt_l_ker_var2); + + +void bli_gemmt_l_ker_var2 + ( + obj_t* a, + obj_t* b, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + cntl_t* cntl, + thrinfo_t* thread + ) +{ + num_t dt_exec = bli_obj_exec_dt( c ); + + doff_t diagoffc = bli_obj_diag_offset( c ); + + pack_t schema_a = bli_obj_pack_schema( a ); + pack_t schema_b = bli_obj_pack_schema( b ); + + dim_t m = bli_obj_length( c ); + dim_t n = bli_obj_width( c ); + dim_t k = bli_obj_width( a ); + + void* buf_a = bli_obj_buffer_at_off( a ); + inc_t cs_a = bli_obj_col_stride( a ); + inc_t is_a = bli_obj_imag_stride( a ); + dim_t pd_a = bli_obj_panel_dim( a ); + inc_t ps_a = bli_obj_panel_stride( a ); + + void* buf_b = bli_obj_buffer_at_off( b ); + inc_t rs_b = bli_obj_row_stride( b ); + inc_t is_b = bli_obj_imag_stride( b ); + dim_t pd_b = bli_obj_panel_dim( b ); + inc_t ps_b = bli_obj_panel_stride( b ); + + void* buf_c = bli_obj_buffer_at_off( c ); + inc_t rs_c = bli_obj_row_stride( c ); + inc_t cs_c = bli_obj_col_stride( c ); + + obj_t scalar_a; + obj_t scalar_b; + + void* buf_alpha; + void* buf_beta; + + FUNCPTR_T f; + + // Detach and multiply the scalars attached to A and B. + bli_obj_scalar_detach( a, &scalar_a ); + bli_obj_scalar_detach( b, &scalar_b ); + bli_mulsc( &scalar_a, &scalar_b ); + + // Grab the addresses of the internal scalar buffers for the scalar + // merged above and the scalar attached to C. + buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b ); + buf_beta = bli_obj_internal_scalar_buffer( c ); + + // Index into the type combination array to extract the correct + // function pointer. + f = ftypes[dt_exec]; + + // Invoke the function. 
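+	// Since ftypes is generated by GENARRAY above, indexing it by the
+	// execution datatype selects the matching typed variant; e.g., for
+	// dt_exec == BLIS_DOUBLE the call below resolves, roughly, to
+	//
+	//   bli_dgemmt_l_ker_var2( diagoffc, schema_a, schema_b, m, n, k, ... );
+	//
+	// where bli_dgemmt_l_ker_var2 is instantiated by the GENTFUNC macro
+	// further down in this file.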
+ f( diagoffc, + schema_a, + schema_b, + m, + n, + k, + buf_alpha, + buf_a, cs_a, is_a, + pd_a, ps_a, + buf_b, rs_b, is_b, + pd_b, ps_b, + buf_beta, + buf_c, rs_c, cs_c, + cntx, + rntm, + thread ); +} + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + doff_t diagoffc, \ + pack_t schema_a, \ + pack_t schema_b, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* alpha, \ + void* a, inc_t cs_a, inc_t is_a, \ + dim_t pd_a, inc_t ps_a, \ + void* b, inc_t rs_b, inc_t is_b, \ + dim_t pd_b, inc_t ps_b, \ + void* beta, \ + void* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx, \ + rntm_t* rntm, \ + thrinfo_t* thread \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* Alias some constants to simpler names. */ \ + const dim_t MR = pd_a; \ + const dim_t NR = pd_b; \ + /*const dim_t PACKMR = cs_a;*/ \ + /*const dim_t PACKNR = rs_b;*/ \ +\ + /* Query the context for the micro-kernel address and cast it to its + function pointer type. */ \ + PASTECH(ch,gemm_ukr_ft) \ + gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ +\ + /* Temporary C buffer for edge cases. Note that the strides of this + temporary buffer are set so that they match the storage of the + original C matrix. For example, if C is column-stored, ct will be + column-stored as well. */ \ + ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ + / sizeof( ctype ) ] \ + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ + const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const inc_t rs_ct = ( col_pref ? 1 : NR ); \ + const inc_t cs_ct = ( col_pref ? MR : 1 ); \ +\ + ctype* restrict zero = PASTEMAC(ch,0); \ + ctype* restrict a_cast = a; \ + ctype* restrict b_cast = b; \ + ctype* restrict c_cast = c; \ + ctype* restrict alpha_cast = alpha; \ + ctype* restrict beta_cast = beta; \ + ctype* restrict b1; \ + ctype* restrict c1; \ +\ + doff_t diagoffc_ij; \ + dim_t m_iter, m_left; \ + dim_t n_iter, n_left; \ + dim_t m_cur; \ + dim_t n_cur; \ + dim_t i, j, ip; \ + inc_t rstep_a; \ + inc_t cstep_b; \ + inc_t rstep_c, cstep_c; \ + auxinfo_t aux; \ +\ + /* + Assumptions/assertions: + rs_a == 1 + cs_a == PACKMR + pd_a == MR + ps_a == stride to next micro-panel of A + rs_b == PACKNR + cs_b == 1 + pd_b == NR + ps_b == stride to next micro-panel of B + rs_c == (no assumptions) + cs_c == (no assumptions) + */ \ +\ + /* If any dimension is zero, return immediately. */ \ + if ( bli_zero_dim3( m, n, k ) ) return; \ +\ + /* Safeguard: If the current panel of C is entirely above the diagonal, + it is not stored. So we do nothing. */ \ + if ( bli_is_strictly_above_diag_n( diagoffc, m, n ) ) return; \ +\ + /* If there is a zero region above where the diagonal of C intersects + the left edge of the panel, adjust the pointer to C and A and treat + this case as if the diagonal offset were zero. */ \ + if ( diagoffc < 0 ) \ + { \ + ip = -diagoffc / MR; \ + i = ip * MR; \ + m = m - i; \ + diagoffc = -diagoffc % MR; \ + c_cast = c_cast + (i )*rs_c; \ + a_cast = a_cast + (ip )*ps_a; \ + } \ +\ + /* If there is a zero region to the right of where the diagonal + of C intersects the bottom of the panel, shrink it to prevent + "no-op" iterations from executing. */ \ + if ( diagoffc + m < n ) \ + { \ + n = diagoffc + m; \ + } \ +\ + /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ + PASTEMAC(ch,set0s_mxn)( MR, NR, \ + ct, rs_ct, cs_ct ); \ +\ + /* Compute number of primary and leftover components of the m and n + dimensions. 
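+	   For example, with m = 25 and MR = 8, the computation below yields
+	   m_iter = 4 micropanel iterations (after the increment for m_left),
+	   the last of which covers only m_left = 1 row.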
*/ \ + n_iter = n / NR; \ + n_left = n % NR; \ +\ + m_iter = m / MR; \ + m_left = m % MR; \ +\ + if ( n_left ) ++n_iter; \ + if ( m_left ) ++m_iter; \ +\ + /* Determine some increments used to step through A, B, and C. */ \ + rstep_a = ps_a; \ +\ + cstep_b = ps_b; \ +\ + rstep_c = rs_c * MR; \ + cstep_c = cs_c * NR; \ +\ + /* Save the pack schemas of A and B to the auxinfo_t object. */ \ + bli_auxinfo_set_schema_a( schema_a, &aux ); \ + bli_auxinfo_set_schema_b( schema_b, &aux ); \ +\ + /* Save the imaginary stride of A and B to the auxinfo_t object. */ \ + bli_auxinfo_set_is_a( is_a, &aux ); \ + bli_auxinfo_set_is_b( is_b, &aux ); \ +\ + b1 = b_cast; \ + c1 = c_cast; \ +\ + thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); \ + dim_t jr_num_threads = bli_thread_n_way( thread ); \ + dim_t jr_thread_id = bli_thread_work_id( thread ); \ + dim_t ir_num_threads = bli_thread_n_way( caucus ); \ + dim_t ir_thread_id = bli_thread_work_id( caucus ); \ +\ + /* Loop over the n dimension (NR columns at a time). */ \ + for ( j = jr_thread_id; j < n_iter; j += jr_num_threads ) \ + { \ + ctype* restrict a1; \ + ctype* restrict c11; \ + ctype* restrict b2; \ +\ + b1 = b_cast + j * cstep_b; \ + c1 = c_cast + j * cstep_c; \ +\ + n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \ +\ + /* Initialize our next panel of B to be the current panel of B. */ \ + b2 = b1; \ +\ + /* Interior loop over the m dimension (MR rows at a time). */ \ + for ( i = ir_thread_id; i < m_iter; i += ir_num_threads ) \ + { \ + ctype* restrict a2; \ +\ + a1 = a_cast + i * rstep_a; \ + c11 = c1 + i * rstep_c; \ +\ + /* Compute the diagonal offset for the submatrix at (i,j). */ \ + diagoffc_ij = diagoffc - (doff_t)j*NR + (doff_t)i*MR; \ +\ + m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ +\ + /* Compute the addresses of the next panels of A and B. */ \ + a2 = bli_gemmt_get_next_a_upanel( caucus, a1, rstep_a ); \ + if ( bli_is_last_iter( i, m_iter, ir_thread_id, ir_num_threads ) ) \ + { \ + a2 = a_cast; \ + b2 = bli_gemmt_get_next_b_upanel( thread, b1, cstep_b ); \ + if ( bli_is_last_iter( j, n_iter, jr_thread_id, jr_num_threads ) ) \ + b2 = b_cast; \ + } \ +\ + /* Save addresses of next panels of A and B to the auxinfo_t + object. */ \ + bli_auxinfo_set_next_a( a2, &aux ); \ + bli_auxinfo_set_next_b( b2, &aux ); \ +\ + /* If the diagonal intersects the current MR x NR submatrix, we + compute it the temporary buffer and then add in the elements + on or below the diagonal. + Otherwise, if the submatrix is strictly below the diagonal, + we compute and store as we normally would. + And if we're strictly above the diagonal, we do nothing and + continue. */ \ + if ( bli_intersects_diag_n( diagoffc_ij, m_cur, n_cur ) ) \ + { \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + zero, \ + ct, rs_ct, cs_ct, \ + &aux, \ + cntx \ + ); \ +\ + /* Scale C and add the result to only the stored part. */ \ + PASTEMAC(ch,xpbys_mxn_l)( diagoffc_ij, \ + m_cur, n_cur, \ + ct, rs_ct, cs_ct, \ + beta_cast, \ + c11, rs_c, cs_c ); \ + } \ + else if ( bli_is_strictly_below_diag_n( diagoffc_ij, m_cur, n_cur ) ) \ + { \ + /* Handle interior and edge cases separately. */ \ + if ( m_cur == MR && n_cur == NR ) \ + { \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + beta_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ + } \ + else \ + { \ + /* Invoke the gemm micro-kernel. 
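+			   This edge-case call computes the full MR x NR product into the
+			   ct scratch tile with beta = zero; the xpbys_mxn call below then
+			   scales C by beta and adds back only the valid m_cur x n_cur
+			   portion.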
*/ \ + gemm_ukr \ + ( \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + zero, \ + ct, rs_ct, cs_ct, \ + &aux, \ + cntx \ + ); \ +\ + /* Scale the edge of C and add the result. */ \ + PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ + ct, rs_ct, cs_ct, \ + beta_cast, \ + c11, rs_c, cs_c ); \ + } \ + } \ + } \ + } \ +} + +INSERT_GENTFUNC_BASIC0( gemmt_l_ker_var2 ) + diff --git a/frame/3/gemmt/other/bli_gemmt_u_ker_var2.c b/frame/3/gemmt/other/bli_gemmt_u_ker_var2.c new file mode 100644 index 0000000000..1655bea555 --- /dev/null +++ b/frame/3/gemmt/other/bli_gemmt_u_ker_var2.c @@ -0,0 +1,409 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#include "blis.h" + +#define FUNCPTR_T gemmt_fp + +typedef void (*FUNCPTR_T) + ( + doff_t diagoffc, + pack_t schema_a, + pack_t schema_b, + dim_t m, + dim_t n, + dim_t k, + void* alpha, + void* a, inc_t cs_a, inc_t is_a, + dim_t pd_a, inc_t ps_a, + void* b, inc_t rs_b, inc_t is_b, + dim_t pd_b, inc_t ps_b, + void* beta, + void* c, inc_t rs_c, inc_t cs_c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ); + +static FUNCPTR_T GENARRAY(ftypes,gemmt_u_ker_var2); + + +void bli_gemmt_u_ker_var2 + ( + obj_t* a, + obj_t* b, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + cntl_t* cntl, + thrinfo_t* thread + ) +{ + num_t dt_exec = bli_obj_exec_dt( c ); + + doff_t diagoffc = bli_obj_diag_offset( c ); + + pack_t schema_a = bli_obj_pack_schema( a ); + pack_t schema_b = bli_obj_pack_schema( b ); + + dim_t m = bli_obj_length( c ); + dim_t n = bli_obj_width( c ); + dim_t k = bli_obj_width( a ); + + void* buf_a = bli_obj_buffer_at_off( a ); + inc_t cs_a = bli_obj_col_stride( a ); + inc_t is_a = bli_obj_imag_stride( a ); + dim_t pd_a = bli_obj_panel_dim( a ); + inc_t ps_a = bli_obj_panel_stride( a ); + + void* buf_b = bli_obj_buffer_at_off( b ); + inc_t rs_b = bli_obj_row_stride( b ); + inc_t is_b = bli_obj_imag_stride( b ); + dim_t pd_b = bli_obj_panel_dim( b ); + inc_t ps_b = bli_obj_panel_stride( b ); + + void* buf_c = bli_obj_buffer_at_off( c ); + inc_t rs_c = bli_obj_row_stride( c ); + inc_t cs_c = bli_obj_col_stride( c ); + + obj_t scalar_a; + obj_t scalar_b; + + void* buf_alpha; + void* buf_beta; + + FUNCPTR_T f; + + // Detach and multiply the scalars attached to A and B. + bli_obj_scalar_detach( a, &scalar_a ); + bli_obj_scalar_detach( b, &scalar_b ); + bli_mulsc( &scalar_a, &scalar_b ); + + // Grab the addresses of the internal scalar buffers for the scalar + // merged above and the scalar attached to C. + buf_alpha = bli_obj_internal_scalar_buffer( &scalar_b ); + buf_beta = bli_obj_internal_scalar_buffer( c ); + + // Index into the type combination array to extract the correct + // function pointer. + f = ftypes[dt_exec]; + + // Invoke the function. + f( diagoffc, + schema_a, + schema_b, + m, + n, + k, + buf_alpha, + buf_a, cs_a, is_a, + pd_a, ps_a, + buf_b, rs_b, is_b, + pd_b, ps_b, + buf_beta, + buf_c, rs_c, cs_c, + cntx, + rntm, + thread ); +} + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + doff_t diagoffc, \ + pack_t schema_a, \ + pack_t schema_b, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* alpha, \ + void* a, inc_t cs_a, inc_t is_a, \ + dim_t pd_a, inc_t ps_a, \ + void* b, inc_t rs_b, inc_t is_b, \ + dim_t pd_b, inc_t ps_b, \ + void* beta, \ + void* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx, \ + rntm_t* rntm, \ + thrinfo_t* thread \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* Alias some constants to simpler names. */ \ + const dim_t MR = pd_a; \ + const dim_t NR = pd_b; \ + /*const dim_t PACKMR = cs_a;*/ \ + /*const dim_t PACKNR = rs_b;*/ \ +\ + /* Query the context for the micro-kernel address and cast it to its + function pointer type. */ \ + PASTECH(ch,gemm_ukr_ft) \ + gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ +\ + /* Temporary C buffer for edge cases. Note that the strides of this + temporary buffer are set so that they match the storage of the + original C matrix. For example, if C is column-stored, ct will be + column-stored as well. 
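+	   Concretely, with MR = 6, NR = 8, and a column-preferring microkernel,
+	   the definitions below give rs_ct = 1 and cs_ct = 6, i.e. ct is
+	   treated as a packed 6 x 8 column-major tile.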
*/ \ + ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ + / sizeof( ctype ) ] \ + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ + const bool col_pref = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const inc_t rs_ct = ( col_pref ? 1 : NR ); \ + const inc_t cs_ct = ( col_pref ? MR : 1 ); \ +\ + ctype* restrict zero = PASTEMAC(ch,0); \ + ctype* restrict a_cast = a; \ + ctype* restrict b_cast = b; \ + ctype* restrict c_cast = c; \ + ctype* restrict alpha_cast = alpha; \ + ctype* restrict beta_cast = beta; \ + ctype* restrict b1; \ + ctype* restrict c1; \ +\ + doff_t diagoffc_ij; \ + dim_t m_iter, m_left; \ + dim_t n_iter, n_left; \ + dim_t m_cur; \ + dim_t n_cur; \ + dim_t i, j, jp; \ + inc_t rstep_a; \ + inc_t cstep_b; \ + inc_t rstep_c, cstep_c; \ + auxinfo_t aux; \ +\ + /* + Assumptions/assertions: + rs_a == 1 + cs_a == PACKMR + pd_a == MR + ps_a == stride to next micro-panel of A + rs_b == PACKNR + cs_b == 1 + pd_b == NR + ps_b == stride to next micro-panel of B + rs_c == (no assumptions) + cs_c == (no assumptions) + */ \ +\ + /* If any dimension is zero, return immediately. */ \ + if ( bli_zero_dim3( m, n, k ) ) return; \ +\ + /* Safeguard: If the current panel of C is entirely below the diagonal, + it is not stored. So we do nothing. */ \ + if ( bli_is_strictly_below_diag_n( diagoffc, m, n ) ) return; \ +\ + /* If there is a zero region to the left of where the diagonal of C + intersects the top edge of the panel, adjust the pointer to C and B + and treat this case as if the diagonal offset were zero. */ \ + if ( diagoffc > 0 ) \ + { \ + jp = diagoffc / NR; \ + j = jp * NR; \ + n = n - j; \ + diagoffc = diagoffc % NR; \ + c_cast = c_cast + (j )*cs_c; \ + b_cast = b_cast + (jp )*ps_b; \ + } \ +\ + /* If there is a zero region below where the diagonal of C intersects + the right edge of the panel, shrink it to prevent "no-op" iterations + from executing. */ \ + if ( -diagoffc + n < m ) \ + { \ + m = -diagoffc + n; \ + } \ +\ + /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ + PASTEMAC(ch,set0s_mxn)( MR, NR, \ + ct, rs_ct, cs_ct ); \ +\ + /* Compute number of primary and leftover components of the m and n + dimensions. */ \ + n_iter = n / NR; \ + n_left = n % NR; \ +\ + m_iter = m / MR; \ + m_left = m % MR; \ +\ + if ( n_left ) ++n_iter; \ + if ( m_left ) ++m_iter; \ +\ + /* Determine some increments used to step through A, B, and C. */ \ + rstep_a = ps_a; \ +\ + cstep_b = ps_b; \ +\ + rstep_c = rs_c * MR; \ + cstep_c = cs_c * NR; \ +\ + /* Save the pack schemas of A and B to the auxinfo_t object. */ \ + bli_auxinfo_set_schema_a( schema_a, &aux ); \ + bli_auxinfo_set_schema_b( schema_b, &aux ); \ +\ + /* Save the imaginary stride of A and B to the auxinfo_t object. */ \ + bli_auxinfo_set_is_a( is_a, &aux ); \ + bli_auxinfo_set_is_b( is_b, &aux ); \ +\ + b1 = b_cast; \ + c1 = c_cast; \ +\ + thrinfo_t* caucus = bli_thrinfo_sub_node( thread ); \ + dim_t jr_num_threads = bli_thread_n_way( thread ); \ + dim_t jr_thread_id = bli_thread_work_id( thread ); \ + dim_t ir_num_threads = bli_thread_n_way( caucus ); \ + dim_t ir_thread_id = bli_thread_work_id( caucus ); \ +\ + /* Loop over the n dimension (NR columns at a time). */ \ + for ( j = jr_thread_id; j < n_iter; j += jr_num_threads ) \ + { \ + ctype* restrict a1; \ + ctype* restrict c11; \ + ctype* restrict b2; \ +\ + b1 = b_cast + j * cstep_b; \ + c1 = c_cast + j * cstep_c; \ +\ + n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? 
NR : n_left ); \ +\ + /* Initialize our next panel of B to be the current panel of B. */ \ + b2 = b1; \ +\ + /* Interior loop over the m dimension (MR rows at a time). */ \ + for ( i = ir_thread_id; i < m_iter; i += ir_num_threads ) \ + { \ + ctype* restrict a2; \ +\ + a1 = a_cast + i * rstep_a; \ + c11 = c1 + i * rstep_c; \ +\ + /* Compute the diagonal offset for the submatrix at (i,j). */ \ + diagoffc_ij = diagoffc - (doff_t)j*NR + (doff_t)i*MR; \ +\ + m_cur = ( bli_is_not_edge_f( i, m_iter, m_left ) ? MR : m_left ); \ +\ + /* Compute the addresses of the next panels of A and B. */ \ + a2 = bli_gemmt_get_next_a_upanel( caucus, a1, rstep_a ); \ + if ( bli_is_last_iter( i, m_iter, ir_thread_id, ir_num_threads ) ) \ + { \ + a2 = a_cast; \ + b2 = bli_gemmt_get_next_b_upanel( thread, b1, cstep_b ); \ + if ( bli_is_last_iter( j, n_iter, jr_thread_id, jr_num_threads ) ) \ + b2 = b_cast; \ + } \ +\ + /* Save addresses of next panels of A and B to the auxinfo_t + object. */ \ + bli_auxinfo_set_next_a( a2, &aux ); \ + bli_auxinfo_set_next_b( b2, &aux ); \ +\ + /* If the diagonal intersects the current MR x NR submatrix, we + compute it the temporary buffer and then add in the elements + on or below the diagonal. + Otherwise, if the submatrix is strictly above the diagonal, + we compute and store as we normally would. + And if we're strictly below the diagonal, we do nothing and + continue. */ \ + if ( bli_intersects_diag_n( diagoffc_ij, m_cur, n_cur ) ) \ + { \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + zero, \ + ct, rs_ct, cs_ct, \ + &aux, \ + cntx \ + ); \ +\ + /* Scale C and add the result to only the stored part. */ \ + PASTEMAC(ch,xpbys_mxn_u)( diagoffc_ij, \ + m_cur, n_cur, \ + ct, rs_ct, cs_ct, \ + beta_cast, \ + c11, rs_c, cs_c ); \ + } \ + else if ( bli_is_strictly_above_diag_n( diagoffc_ij, m_cur, n_cur ) ) \ + { \ + /* Handle interior and edge cases separately. */ \ + if ( m_cur == MR && n_cur == NR ) \ + { \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + beta_cast, \ + c11, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ + } \ + else \ + { \ + /* Invoke the gemm micro-kernel. */ \ + gemm_ukr \ + ( \ + k, \ + alpha_cast, \ + a1, \ + b1, \ + zero, \ + ct, rs_ct, cs_ct, \ + &aux, \ + cntx \ + ); \ +\ + /* Scale the edge of C and add the result. */ \ + PASTEMAC(ch,xpbys_mxn)( m_cur, n_cur, \ + ct, rs_ct, cs_ct, \ + beta_cast, \ + c11, rs_c, cs_c ); \ + } \ + } \ + } \ + } \ +} + +INSERT_GENTFUNC_BASIC0( gemmt_u_ker_var2 ) + diff --git a/frame/3/trmm/bli_trmm_ll_ker_var2.c b/frame/3/trmm/bli_trmm_ll_ker_var2.c index faf6ca100b..646287f931 100644 --- a/frame/3/trmm/bli_trmm_ll_ker_var2.c +++ b/frame/3/trmm/bli_trmm_ll_ker_var2.c @@ -291,8 +291,8 @@ void PASTEMAC(ch,varname) \ dim_t jr_inc; \ \ /* Determine the thread range and increment for the 2nd loop. - NOTE: The definition of bli_thread_range_jrir() will depend on whether - slab or round-robin partitioning was requested at configure-time. \ + NOTE: The definition of bli_thread_range_jrir() will depend on whether + slab or round-robin partitioning was requested at configure-time. \ NOTE: Parallelism in the 1st loop is disabled for now. */ \ bli_thread_range_jrir( thread, n_iter, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \ /*bli_thread_range_jrir_rr( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc );*/ \ @@ -366,8 +366,8 @@ void PASTEMAC(ch,varname) \ /* Invoke the gemm micro-kernel. 
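+			   Since the diagonal cuts through this MR x NR tile, the product
+			   is formed in ct with beta = zero, and the xpbys_mxn_u call
+			   below writes back only the elements on or above the diagonal.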
*/ \ gemm_ukr \ ( \ - m_cur, \ - n_cur, \ + m_cur, \ + n_cur, \ k_a1011, \ alpha_cast, \ a1, \ @@ -406,8 +406,8 @@ void PASTEMAC(ch,varname) \ /* Invoke the gemm micro-kernel. */ \ gemm_ukr \ ( \ - m_cur, \ - n_cur, \ + m_cur, \ + n_cur, \ k, \ alpha_cast, \ a1, \ diff --git a/frame/3/trmm/bli_trmm_lu_ker_var2.c b/frame/3/trmm/bli_trmm_lu_ker_var2.c index f7a3a717ce..9ef2a475de 100644 --- a/frame/3/trmm/bli_trmm_lu_ker_var2.c +++ b/frame/3/trmm/bli_trmm_lu_ker_var2.c @@ -373,8 +373,8 @@ void PASTEMAC(ch,varname) \ /* Invoke the gemm micro-kernel. */ \ gemm_ukr \ ( \ - m_cur, \ - n_cur, \ + m_cur, \ + n_cur, \ k_a1112, \ alpha_cast, \ a1, \ @@ -413,8 +413,8 @@ void PASTEMAC(ch,varname) \ /* Invoke the gemm micro-kernel. */ \ gemm_ukr \ ( \ - m_cur, \ - n_cur, \ + m_cur, \ + n_cur, \ k, \ alpha_cast, \ a1, \ diff --git a/frame/3/trmm/bli_trmm_rl_ker_var2.c b/frame/3/trmm/bli_trmm_rl_ker_var2.c index 195f8577db..f6b20af2e5 100644 --- a/frame/3/trmm/bli_trmm_rl_ker_var2.c +++ b/frame/3/trmm/bli_trmm_rl_ker_var2.c @@ -319,9 +319,9 @@ void PASTEMAC(ch,varname) \ \ /* Determine the thread range and increment for the 2nd and 1st loops for the initial rectangular region of B (if it exists). - NOTE: The definition of bli_thread_range_jrir() will depend on whether - slab or round-robin partitioning was requested at configure-time. \ - NOTE: Parallelism in the 1st loop is disabled for now. */ \ + NOTE: The definition of bli_thread_range_jrir() will depend on whether + slab or round-robin partitioning was requested at configure-time. \ + NOTE: Parallelism in the 1st loop is disabled for now. */ \ bli_thread_range_jrir( thread, n_iter_rct, 1, FALSE, &jr_start, &jr_end, &jr_inc ); \ bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); \ \ @@ -369,8 +369,8 @@ void PASTEMAC(ch,varname) \ /* Invoke the gemm micro-kernel. */ \ gemm_ukr \ ( \ - m_cur, \ - n_cur, \ + m_cur, \ + n_cur, \ k, \ alpha_cast, \ a1, \ @@ -466,8 +466,8 @@ void PASTEMAC(ch,varname) \ /* Invoke the gemm micro-kernel. */ \ gemm_ukr \ ( \ - m_cur, \ - n_cur, \ + m_cur, \ + n_cur, \ k_b1121, \ alpha_cast, \ a1_i, \ diff --git a/frame/3/trmm/bli_trmm_ru_ker_var2.c b/frame/3/trmm/bli_trmm_ru_ker_var2.c index 47df210f71..f71fb3c4d8 100644 --- a/frame/3/trmm/bli_trmm_ru_ker_var2.c +++ b/frame/3/trmm/bli_trmm_ru_ker_var2.c @@ -397,8 +397,8 @@ void PASTEMAC(ch,varname) \ /* Invoke the gemm micro-kernel. */ \ gemm_ukr \ ( \ - m_cur, \ - n_cur, \ + m_cur, \ + n_cur, \ k_b0111, \ alpha_cast, \ a1_i, \ @@ -433,9 +433,9 @@ void PASTEMAC(ch,varname) \ bli_thread_range_jrir( caucus, m_iter, 1, FALSE, &ir_start, &ir_end, &ir_inc ); \ \ /* Advance the start and end iteration offsets for the rectangular region - by the number of iterations used for the triangular region. */ \ - jr_start += n_iter_tri; \ - jr_end += n_iter_tri; \ + by the number of iterations used for the triangular region. */ \ + jr_start += n_iter_tri; \ + jr_end += n_iter_tri; \ jb0 = n_iter_tri; \ \ /* Save the resulting value of b1 from the previous loop since it represents @@ -453,7 +453,7 @@ void PASTEMAC(ch,varname) \ the starting address of the rectangular region (which is already n_iter_tri logical iterations through B). */ \ b1 = b_cast + (j-jb0) * cstep_b; \ - c1 = c_cast + j * cstep_c; \ + c1 = c_cast + j * cstep_c; \ \ n_cur = ( bli_is_not_edge_f( j, n_iter, n_left ) ? NR : n_left ); \ \ @@ -493,8 +493,8 @@ void PASTEMAC(ch,varname) \ /* Invoke the gemm micro-kernel. 
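	   This call covers the rectangular region of B, where the full depth k
	   of the packed panels applies, in contrast to the k_b0111-limited
	   updates over the triangular region earlier in this function.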
*/ \ gemm_ukr \ ( \ - m_cur, \ - n_cur, \ + m_cur, \ + n_cur, \ k, \ alpha_cast, \ a1, \ diff --git a/frame/3/trsm/bli_trsm_cntl.c b/frame/3/trsm/bli_trsm_cntl.c index e37e1117c8..0a3be87f74 100644 --- a/frame/3/trsm/bli_trsm_cntl.c +++ b/frame/3/trsm/bli_trsm_cntl.c @@ -60,10 +60,10 @@ cntl_t* bli_trsm_l_cntl_create { void_fp macro_kernel_p; - // Use the function pointer to the macrokernels that use slab - // assignment of micropanels to threads in the jr and ir loops. + // Set the default macrokernel. If a non-NULL kernel function pointer is + // passed in, we use that instead. macro_kernel_p = bli_trsm_xx_ker_var2; - if ( ker ) macro_kernel_p = ker; + if ( ker ) macro_kernel_p = ker; const opid_t family = BLIS_TRSM; @@ -203,15 +203,17 @@ cntl_t* bli_trsm_l_cntl_create cntl_t* bli_trsm_r_cntl_create ( - rntm_t* rntm, + rntm_t* rntm, pack_t schema_a, pack_t schema_b, void_fp ker ) { // NOTE: trsm macrokernels are presently disabled for right-side execution. + // Set the default macrokernel. If a non-NULL kernel function pointer is + // passed in, we use that instead. void_fp macro_kernel_p = bli_trsm_xx_ker_var2; - if ( ker ) macro_kernel_p = ker; + if ( ker ) macro_kernel_p = ker; const opid_t family = BLIS_TRSM; diff --git a/frame/3/trsm/bli_trsm_ll_ker_var2.c b/frame/3/trsm/bli_trsm_ll_ker_var2.c index cf6a0cee48..b503efa5bf 100644 --- a/frame/3/trsm/bli_trsm_ll_ker_var2.c +++ b/frame/3/trsm/bli_trsm_ll_ker_var2.c @@ -472,8 +472,8 @@ void PASTEMAC(ch,varname) \ /* Invoke the gemm micro-kernel. */ \ gemm_ukr \ ( \ - m_cur, \ - n_cur, \ + m_cur, \ + n_cur, \ k, \ minus_one, \ a1, \ diff --git a/frame/3/trsm/bli_trsm_lu_ker_var2.c b/frame/3/trsm/bli_trsm_lu_ker_var2.c index 38d5ab0df1..55ceafb91d 100644 --- a/frame/3/trsm/bli_trsm_lu_ker_var2.c +++ b/frame/3/trsm/bli_trsm_lu_ker_var2.c @@ -482,8 +482,8 @@ void PASTEMAC(ch,varname) \ /* Invoke the gemm micro-kernel. */ \ gemm_ukr \ ( \ - m_cur, \ - n_cur, \ + m_cur, \ + n_cur, \ k, \ minus_one, \ a1, \ diff --git a/frame/3/trsm/bli_trsm_rl_ker_var2.c b/frame/3/trsm/bli_trsm_rl_ker_var2.c index 3fedc9295f..23d4dd7289 100644 --- a/frame/3/trsm/bli_trsm_rl_ker_var2.c +++ b/frame/3/trsm/bli_trsm_rl_ker_var2.c @@ -383,7 +383,7 @@ void PASTEMAC(ch,varname) \ \ /* Compute the addresses of the triangular block B11 and the panel B21. */ \ - b11 = b1; \ + b11 = b1; \ b21 = b1 + k_b11 * PACKNR; \ /*b21 = bli_ptr_inc_by_frac( b1, sizeof( ctype ), k_b11 * PACKNR, 1 );*/ \ \ @@ -501,8 +501,8 @@ void PASTEMAC(ch,varname) \ /* Invoke the gemm micro-kernel. */ \ gemm_ukr \ ( \ - m_cur, \ - n_cur, \ + m_cur, \ + n_cur, \ k, \ minus_one, \ b1, \ diff --git a/frame/3/trsm/bli_trsm_ru_ker_var2.c b/frame/3/trsm/bli_trsm_ru_ker_var2.c index 72b917b2c4..71381707c4 100644 --- a/frame/3/trsm/bli_trsm_ru_ker_var2.c +++ b/frame/3/trsm/bli_trsm_ru_ker_var2.c @@ -376,7 +376,7 @@ void PASTEMAC(ch,varname) \ \ /* Compute the addresses of the panel B10 and the triangular block B11. */ \ - b01 = b1; \ + b01 = b1; \ b11 = b1 + k_b01 * PACKNR; \ /*b11 = bli_ptr_inc_by_frac( b1, sizeof( ctype ), k_b01 * PACKNR, 1 );*/ \ \ @@ -494,8 +494,8 @@ void PASTEMAC(ch,varname) \ /* Invoke the gemm micro-kernel. 
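	   This is the gemm subproblem of the right-sided trsm update; alpha is
	   minus_one, so the partial product involving the packed panel b1 is
	   subtracted from the scaled tile of C.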
*/ \ gemm_ukr \ ( \ - m_cur, \ - n_cur, \ + m_cur, \ + n_cur, \ k, \ minus_one, \ b1, \ diff --git a/frame/include/bli_misc_macro_defs.h b/frame/include/bli_misc_macro_defs.h index b166b7a171..7e1b93b944 100644 --- a/frame/include/bli_misc_macro_defs.h +++ b/frame/include/bli_misc_macro_defs.h @@ -167,58 +167,70 @@ BLIS_INLINE void bli_toggle_bool( bool* b ) // helper macros for gemm microkernels #define GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major) \ -PASTEMAC(ch,ctype)* restrict _beta = beta; \ -PASTEMAC(ch,ctype)* restrict _c = c; \ -const inc_t _rs_c = rs_c; \ -const inc_t _cs_c = cs_c; \ -PASTEMAC(ch,ctype) _ct[ BLIS_STACK_BUF_MAX_SIZE / sizeof( PASTEMAC(ch,type) ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ -const inc_t _rs_ct = row_major ? nr : 1; \ -const inc_t _cs_ct = row_major ? 1 : mr; +\ + PASTEMAC(ch,ctype)* restrict _beta = beta; \ + PASTEMAC(ch,ctype)* restrict _c = c; \ + const inc_t _rs_c = rs_c; \ + const inc_t _cs_c = cs_c; \ + PASTEMAC(ch,ctype) _ct[ BLIS_STACK_BUF_MAX_SIZE / sizeof( PASTEMAC(ch,type) ) ] \ + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ + const inc_t _rs_ct = row_major ? nr : 1; \ + const inc_t _cs_ct = row_major ? 1 : mr; #define GEMM_UKR_SETUP_CT_POST(ch) \ -PASTEMAC(ch,ctype) _zero; \ -PASTEMAC(ch,set0s)( _zero ); \ \ -if ( _use_ct ) \ -{ \ - c = _ct; \ - rs_c = _rs_ct; \ - cs_c = _cs_ct; \ - beta = &_zero; \ -} + PASTEMAC(ch,ctype) _zero; \ + PASTEMAC(ch,set0s)( _zero ); \ + \ + if ( _use_ct ) \ + { \ + c = _ct; \ + rs_c = _rs_ct; \ + cs_c = _cs_ct; \ + beta = &_zero; \ + } #define GEMM_UKR_SETUP_CT(ch,mr,nr,row_major) \ -GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major); \ -const bool _use_ct = ( row_major ? cs_c != 1 : rs_c != 1 ) || \ - m != mr || n != nr; \ -GEMM_UKR_SETUP_CT_POST(ch); +\ + GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major); \ + const bool _use_ct = ( row_major ? cs_c != 1 : rs_c != 1 ) || \ + m != mr || n != nr; \ + GEMM_UKR_SETUP_CT_POST(ch); #define GEMM_UKR_SETUP_CT_AMBI(ch,mr,nr,row_major) \ -GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major); \ -const bool _use_ct = ( cs_c != 1 && rs_c != 1 ) || \ - m != mr || n != nr; \ -GEMM_UKR_SETUP_CT_POST(ch); +\ + GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major); \ + const bool _use_ct = ( cs_c != 1 && rs_c != 1 ) || \ + m != mr || n != nr; \ + GEMM_UKR_SETUP_CT_POST(ch); #define GEMM_UKR_SETUP_CT_ANY(ch,mr,nr,row_major) \ -GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major); \ -const bool _use_ct = m != mr || n != nr; \ -GEMM_UKR_SETUP_CT_POST(ch); +\ + GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major); \ + const bool _use_ct = m != mr || n != nr; \ + GEMM_UKR_SETUP_CT_POST(ch); #define GEMM_UKR_SETUP_CT_ALIGNED(ch,mr,nr,row_major,alignment) \ -GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major); \ -const bool _use_ct = ( row_major ? cs_c != 1 : rs_c != 1 ) || \ - m != mr || n != nr || \ - ( (uintptr_t)_c % alignment ) || \ - ( ( ( row_major ? _rs_c : _cs_c )*sizeof( PASTEMAC(ch,ctype) ) ) % alignment ); \ -GEMM_UKR_SETUP_CT_POST(ch); +\ + GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major); \ + const bool _use_ct = ( row_major ? cs_c != 1 : rs_c != 1 ) || \ + m != mr || n != nr || \ + ( (uintptr_t)_c % alignment ) || \ + ( ( ( row_major ? 
_rs_c : _cs_c )*sizeof( PASTEMAC(ch,ctype) ) ) % alignment ); \ + GEMM_UKR_SETUP_CT_POST(ch); #define GEMM_UKR_FLUSH_CT(ch) \ -if ( _use_ct ) \ - PASTEMAC(ch,xpbys_mxn)( m, n, \ - _ct, _rs_ct, _cs_ct, \ - _beta, \ - _c, _rs_c, _cs_c); +\ + if ( _use_ct ) \ + { \ + PASTEMAC(ch,xpbys_mxn) \ + ( \ + m, n, \ + _ct, _rs_ct, _cs_ct, \ + _beta, \ + _c, _rs_c, _cs_c \ + ); \ + } \ #endif diff --git a/frame/include/bli_type_defs.h b/frame/include/bli_type_defs.h index ba6db2af89..c66505bde8 100644 --- a/frame/include/bli_type_defs.h +++ b/frame/include/bli_type_defs.h @@ -1147,9 +1147,9 @@ typedef struct // The type to convert to on output. //num_t dt_on_output; - // (Virtual) microkernel address and additional parameters - void_fp ukr; - void* params; + // (Virtual) microkernel address and additional parameters. + void_fp ukr; + void* params; } auxinfo_t; diff --git a/kernels/armsve/3/bli_gemm_armsve_asm_c2vx10_unindexed.c b/kernels/armsve/3/bli_gemm_armsve_asm_c2vx10_unindexed.c index 6941558967..913abd1f6c 100644 --- a/kernels/armsve/3/bli_gemm_armsve_asm_c2vx10_unindexed.c +++ b/kernels/armsve/3/bli_gemm_armsve_asm_c2vx10_unindexed.c @@ -48,7 +48,7 @@ void bli_cgemm_armsve_asm_2vx10_unindexed ( dim_t m, dim_t n, - dim_t k0, + dim_t k, scomplex* restrict alpha, scomplex* restrict a, scomplex* restrict b, @@ -63,8 +63,8 @@ void bli_cgemm_armsve_asm_2vx10_unindexed // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_mker = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_mker = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; uint64_t info = 0; diff --git a/kernels/armsve/3/bli_gemm_armsve_asm_z2vx10_unindexed.c b/kernels/armsve/3/bli_gemm_armsve_asm_z2vx10_unindexed.c index d12e7c8353..ee041b3c40 100644 --- a/kernels/armsve/3/bli_gemm_armsve_asm_z2vx10_unindexed.c +++ b/kernels/armsve/3/bli_gemm_armsve_asm_z2vx10_unindexed.c @@ -48,7 +48,7 @@ void bli_zgemm_armsve_asm_2vx10_unindexed ( dim_t m, dim_t n, - dim_t k0, + dim_t k, dcomplex* restrict alpha, dcomplex* restrict a, dcomplex* restrict b, @@ -63,8 +63,8 @@ void bli_zgemm_armsve_asm_2vx10_unindexed // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_mker = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_mker = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; uint64_t info = 0; diff --git a/kernels/armsve/3/bli_gemm_armsve_asm_z2vx7_unindexed.c b/kernels/armsve/3/bli_gemm_armsve_asm_z2vx7_unindexed.c index b34bc5b1bd..641944ecd4 100644 --- a/kernels/armsve/3/bli_gemm_armsve_asm_z2vx7_unindexed.c +++ b/kernels/armsve/3/bli_gemm_armsve_asm_z2vx7_unindexed.c @@ -48,7 +48,7 @@ void bli_zgemm_armsve_asm_2vx7_unindexed ( dim_t m, dim_t n, - dim_t k0, + dim_t k, dcomplex* restrict alpha, dcomplex* restrict a, dcomplex* restrict b, @@ -63,8 +63,8 @@ void bli_zgemm_armsve_asm_2vx7_unindexed // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. 
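	// Like the unindexed SVE kernels above, this one unrolls the k loop by
	// four; e.g. k = 11 yields k_mker = 2 four-step passes plus
	// k_left = 3 remainder iterations.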
- uint64_t k_mker = k0 / 4; - uint64_t k_left = k0 % 4; + uint64_t k_mker = k / 4; + uint64_t k_left = k % 4; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; uint64_t info = 0; diff --git a/kernels/armsve/3/bli_gemm_armsve_asm_z2vx8_unindexed.c b/kernels/armsve/3/bli_gemm_armsve_asm_z2vx8_unindexed.c index a6f2c0481f..4272f72c02 100644 --- a/kernels/armsve/3/bli_gemm_armsve_asm_z2vx8_unindexed.c +++ b/kernels/armsve/3/bli_gemm_armsve_asm_z2vx8_unindexed.c @@ -48,7 +48,7 @@ void bli_zgemm_armsve_asm_2vx8_unindexed ( dim_t m, dim_t n, - dim_t k0, + dim_t k, dcomplex* restrict alpha, dcomplex* restrict a, dcomplex* restrict b, @@ -63,8 +63,8 @@ void bli_zgemm_armsve_asm_2vx8_unindexed // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. - uint64_t k_mker = k0 / 6; - uint64_t k_left = k0 % 6; + uint64_t k_mker = k / 6; + uint64_t k_left = k % 6; uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; uint64_t info = 0; diff --git a/kernels/armv8a/3/bli_gemm_armv8a_asm_d6x8.c b/kernels/armv8a/3/bli_gemm_armv8a_asm_d6x8.c index 2b188da842..7b420f202f 100644 --- a/kernels/armv8a/3/bli_gemm_armv8a_asm_d6x8.c +++ b/kernels/armv8a/3/bli_gemm_armv8a_asm_d6x8.c @@ -75,1018 +75,1018 @@ void bli_sgemm_armv8a_asm_8x12 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; - GEMM_UKR_SETUP_CT( s, 8, 12, false ); + GEMM_UKR_SETUP_CT( s, 8, 12, false ); - __asm__ volatile - ( - " \n\t" - " \n\t" - " ldr x0,%[aaddr] \n\t" // Load address of A. - " ldr x1,%[baddr] \n\t" // Load address of B. - " ldr x2,%[caddr] \n\t" // Load address of C. - " \n\t" - " ldr x5,%[k_iter] \n\t" // Number of unrolled iterations (k_iter). - " ldr x6,%[k_left] \n\t" // Number of remaining iterations (k_left). - " \n\t" - " ldr x10,%[cs_c] \n\t" // Load cs_c. - " lsl x10,x10,#2 \n\t" // cs_c * sizeof(float) -- AUX. - " \n\t" - " ldr x14,%[rs_c] \n\t" // Load rs_c. - " lsl x14,x14,#2 \n\t" // rs_c * sizeof(float). - " \n\t" - " add x16,x2,x10 \n\t" //Load address Column 1 of C - " add x17,x16,x10 \n\t" //Load address Column 2 of C - " add x19,x17,x10 \n\t" //Load address Column 3 of C - " add x20,x19,x10 \n\t" //Load address Column 4 of C - " add x21,x20,x10 \n\t" //Load address Column 5 of C - " add x22,x21,x10 \n\t" //Load address Column 6 of C - " add x23,x22,x10 \n\t" //Load address Column 7 of C - " add x24,x23,x10 \n\t" //Load address Column 8 of C - " add x25,x24,x10 \n\t" //Load address Column 9 of C - " add x26,x25,x10 \n\t" //Load address Column 10 of C - " add x27,x26,x10 \n\t" //Load address Column 11 of C - " \n\t" - " prfm pldl1keep,[x2] \n\t" // Prefetch c. - " prfm pldl1keep,[x16] \n\t" // Prefetch c. - " prfm pldl1keep,[x17] \n\t" // Prefetch c. - " prfm pldl1keep,[x19] \n\t" // Prefetch c. - " prfm pldl1keep,[x20] \n\t" // Prefetch c. - " prfm pldl1keep,[x21] \n\t" // Prefetch c. - " prfm pldl1keep,[x22] \n\t" // Prefetch c. - " prfm pldl1keep,[x23] \n\t" // Prefetch c. - " prfm pldl1keep,[x24] \n\t" // Prefetch c. - " prfm pldl1keep,[x25] \n\t" // Prefetch c. - " prfm pldl1keep,[x26] \n\t" // Prefetch c. - " prfm pldl1keep,[x27] \n\t" // Prefetch c. 
- " \n\t" - " dup v8.4s, wzr \n\t" // Vector for accummulating column 0 - " prfm PLDL1KEEP, [x1, #192] \n\t" - " dup v9.4s, wzr \n\t" // Vector for accummulating column 0 - " prfm PLDL1KEEP, [x1, #256] \n\t" - " dup v10.4s, wzr \n\t" // Vector for accummulating column 1 - " prfm PLDL1KEEP, [x1, #320] \n\t" - " dup v11.4s, wzr \n\t" // Vector for accummulating column 1 - " dup v12.4s, wzr \n\t" // Vector for accummulating column 2 - " dup v13.4s, wzr \n\t" // Vector for accummulating column 2 - " \n\t" - " dup v14.4s, wzr \n\t" // Vector for accummulating column 3 - " prfm PLDL1KEEP, [x0, #128] \n\t" - " dup v15.4s, wzr \n\t" // Vector for accummulating column 3 - " prfm PLDL1KEEP, [x0, #192] \n\t" - " dup v16.4s, wzr \n\t" // Vector for accummulating column 4 - " dup v17.4s, wzr \n\t" // Vector for accummulating column 4 - " dup v18.4s, wzr \n\t" // Vector for accummulating column 5 - " dup v19.4s, wzr \n\t" // Vector for accummulating column 5 - " \n\t" - " dup v20.4s, wzr \n\t" // Vector for accummulating column 6 - " dup v21.4s, wzr \n\t" // Vector for accummulating column 6 - " dup v22.4s, wzr \n\t" // Vector for accummulating column 7 - " dup v23.4s, wzr \n\t" // Vector for accummulating column 7 - " dup v24.4s, wzr \n\t" // Vector for accummulating column 8 - " dup v25.4s, wzr \n\t" // Vector for accummulating column 8 - " \n\t" - " dup v26.4s, wzr \n\t" // Vector for accummulating column 9 - " dup v27.4s, wzr \n\t" // Vector for accummulating column 9 - " dup v28.4s, wzr \n\t" // Vector for accummulating column 10 - " dup v29.4s, wzr \n\t" // Vector for accummulating column 10 - " dup v30.4s, wzr \n\t" // Vector for accummulating column 11 - " dup v31.4s, wzr \n\t" // Vector for accummulating column 11 - " \n\t" - " cmp x5,#0 \n\t" // If k_iter == 0, jump to k_left. - BEQ(SCONSIDERKLEFT) - " \n\t" - " ldr q0, [x0] \n\t" - " ldr q1, [x0, #16] \n\t" // Load a - " \n\t" - " ldr q2, [x1] \n\t" // Load b - " ldr q3, [x1, #16] \n\t" - " ldr q4, [x1, #32] \n\t" - " \n\t" - " add x0, x0, #32 \n\t" //update address of A - " add x1, x1, #48 \n\t" //update address of B - " \n\t" - " cmp x5,1 \n\t" // If there is just one k_iter, jump to that one. - BEQ(SLASTITER) // (as loop is do-while-like). - " \n\t" - LABEL(SLOOPKITER) // Body of the k_iter loop. - " \n\t" - " ldr q5, [x0] \n\t" - " fmla v8.4s, v0.4s,v2.s[0] \n\t" // Accummulate. - " fmla v9.4s, v1.4s,v2.s[0] \n\t" // Accummulate. - " ldr q6, [x0, #16] \n\t" - " fmla v10.4s,v0.4s,v2.s[1] \n\t" // Accummulate. - " fmla v11.4s,v1.4s,v2.s[1] \n\t" // Accummulate. - " fmla v12.4s,v0.4s,v2.s[2] \n\t" // Accummulate. - " fmla v13.4s,v1.4s,v2.s[2] \n\t" // Accummulate. - " fmla v14.4s,v0.4s,v2.s[3] \n\t" // Accummulate. - " fmla v15.4s,v1.4s,v2.s[3] \n\t" // Accummulate. - " ldr q2, [x1] \n\t" - " \n\t" - " fmla v16.4s,v0.4s,v3.s[0] \n\t" // Accummulate. - " prfm PLDL1KEEP, [x1, #336] \n\t" - " fmla v17.4s,v1.4s,v3.s[0] \n\t" // Accummulate. - " prfm PLDL1KEEP, [x1, #400] \n\t" - " fmla v18.4s,v0.4s,v3.s[1] \n\t" // Accummulate. - " fmla v19.4s,v1.4s,v3.s[1] \n\t" // Accummulate. - " prfm PLDL1KEEP, [x1, #464] \n\t" - " fmla v20.4s,v0.4s,v3.s[2] \n\t" // Accummulate. - " fmla v21.4s,v1.4s,v3.s[2] \n\t" // Accummulate. - " fmla v22.4s,v0.4s,v3.s[3] \n\t" // Accummulate. - " fmla v23.4s,v1.4s,v3.s[3] \n\t" // Accummulate. - " \n\t" - " fmla v24.4s,v0.4s,v4.s[0] \n\t" // Accummulate. - " fmla v26.4s,v0.4s,v4.s[1] \n\t" // Accummulate. - " fmla v28.4s,v0.4s,v4.s[2] \n\t" // Accummulate. - " fmla v30.4s,v0.4s,v4.s[3] \n\t" // Accummulate. 
- " ldr q3, [x1, #16] \n\t" - " \n\t" - " fmla v25.4s,v1.4s,v4.s[0] \n\t" // Accummulate. - " fmla v27.4s,v1.4s,v4.s[1] \n\t" // Accummulate. - " fmla v29.4s,v1.4s,v4.s[2] \n\t" // Accummulate. - " fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. - " ldr q4, [x1, #32] \n\t" - " \n\t" //End It 1 - " \n\t" - " ldr q0, [x0, #32] \n\t" - " fmla v8.4s,v5.4s,v2.s[0] \n\t" // Accummulate. - " fmla v9.4s,v6.4s,v2.s[0] \n\t" // Accummulate. - " ldr q1, [x0, #48] \n\t" - " fmla v10.4s,v5.4s,v2.s[1] \n\t" // Accummulate. - " fmla v11.4s,v6.4s,v2.s[1] \n\t" // Accummulate. - " fmla v12.4s,v5.4s,v2.s[2] \n\t" // Accummulate. - " fmla v13.4s,v6.4s,v2.s[2] \n\t" // Accummulate. - " fmla v14.4s,v5.4s,v2.s[3] \n\t" // Accummulate. - " fmla v15.4s,v6.4s,v2.s[3] \n\t" // Accummulate. - " ldr q2, [x1, #48] \n\t" - " \n\t" - " fmla v16.4s,v5.4s,v3.s[0] \n\t" // Accummulate. - " prfm PLDL1KEEP, [x0, #224] \n\t" - " fmla v17.4s,v6.4s,v3.s[0] \n\t" // Accummulate. - " prfm PLDL1KEEP, [x0, #288] \n\t" - " fmla v18.4s,v5.4s,v3.s[1] \n\t" // Accummulate. - " fmla v19.4s,v6.4s,v3.s[1] \n\t" // Accummulate. - " fmla v20.4s,v5.4s,v3.s[2] \n\t" // Accummulate. - " fmla v21.4s,v6.4s,v3.s[2] \n\t" // Accummulate. - " fmla v22.4s,v5.4s,v3.s[3] \n\t" // Accummulate. - " fmla v23.4s,v6.4s,v3.s[3] \n\t" // Accummulate. - " \n\t" - " fmla v24.4s,v5.4s,v4.s[0] \n\t" // Accummulate. - " fmla v26.4s,v5.4s,v4.s[1] \n\t" // Accummulate. - " fmla v28.4s,v5.4s,v4.s[2] \n\t" // Accummulate. - " fmla v30.4s,v5.4s,v4.s[3] \n\t" // Accummulate. - " ldr q3, [x1, #64] \n\t" - " \n\t" - " fmla v25.4s,v6.4s,v4.s[0] \n\t" // Accummulate. - " fmla v27.4s,v6.4s,v4.s[1] \n\t" // Accummulate. - " fmla v29.4s,v6.4s,v4.s[2] \n\t" // Accummulate. - " fmla v31.4s,v6.4s,v4.s[3] \n\t" // Accummulate. - " ldr q4, [x1, #80] \n\t" - " \n\t" //End It 2 - " \n\t" - " ldr q5, [x0, #64] \n\t" - " fmla v8.4s,v0.4s,v2.s[0] \n\t" // Accummulate. - " fmla v9.4s,v1.4s,v2.s[0] \n\t" // Accummulate. - " ldr q6, [x0, #80] \n\t" - " fmla v10.4s,v0.4s,v2.s[1] \n\t" // Accummulate. - " fmla v11.4s,v1.4s,v2.s[1] \n\t" // Accummulate. - " fmla v12.4s,v0.4s,v2.s[2] \n\t" // Accummulate. - " fmla v13.4s,v1.4s,v2.s[2] \n\t" // Accummulate. - " fmla v14.4s,v0.4s,v2.s[3] \n\t" // Accummulate. - " fmla v15.4s,v1.4s,v2.s[3] \n\t" // Accummulate. - " ldr q2, [x1, #96] \n\t" - " \n\t" - " fmla v16.4s,v0.4s,v3.s[0] \n\t" // Accummulate. - " fmla v17.4s,v1.4s,v3.s[0] \n\t" // Accummulate. - " fmla v18.4s,v0.4s,v3.s[1] \n\t" // Accummulate. - " fmla v19.4s,v1.4s,v3.s[1] \n\t" // Accummulate. - " fmla v20.4s,v0.4s,v3.s[2] \n\t" // Accummulate. - " fmla v21.4s,v1.4s,v3.s[2] \n\t" // Accummulate. - " fmla v22.4s,v0.4s,v3.s[3] \n\t" // Accummulate. - " fmla v23.4s,v1.4s,v3.s[3] \n\t" // Accummulate. - " \n\t" - " fmla v24.4s,v0.4s,v4.s[0] \n\t" // Accummulate. - " fmla v26.4s,v0.4s,v4.s[1] \n\t" // Accummulate. - " fmla v28.4s,v0.4s,v4.s[2] \n\t" // Accummulate. - " fmla v30.4s,v0.4s,v4.s[3] \n\t" // Accummulate. - " ldr q3, [x1, #112] \n\t" - " \n\t" - " fmla v25.4s,v1.4s,v4.s[0] \n\t" // Accummulate. - " fmla v27.4s,v1.4s,v4.s[1] \n\t" // Accummulate. - " fmla v29.4s,v1.4s,v4.s[2] \n\t" // Accummulate. - " fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. - " ldr q4, [x1, #128] \n\t" - " \n\t" //End It 3 - " \n\t" - " ldr q0, [x0, #96] \n\t" - " fmla v8.4s,v5.4s,v2.s[0] \n\t" // Accummulate. - " fmla v9.4s,v6.4s,v2.s[0] \n\t" // Accummulate. - " ldr q1, [x0, #112] \n\t" - " fmla v10.4s,v5.4s,v2.s[1] \n\t" // Accummulate. - " fmla v11.4s,v6.4s,v2.s[1] \n\t" // Accummulate. 
- " fmla v12.4s,v5.4s,v2.s[2] \n\t" // Accummulate. - " fmla v13.4s,v6.4s,v2.s[2] \n\t" // Accummulate. - " fmla v14.4s,v5.4s,v2.s[3] \n\t" // Accummulate. - " fmla v15.4s,v6.4s,v2.s[3] \n\t" // Accummulate. - " ldr q2, [x1, #144] \n\t" - " \n\t" - " fmla v16.4s,v5.4s,v3.s[0] \n\t" // Accummulate. - " fmla v17.4s,v6.4s,v3.s[0] \n\t" // Accummulate. - " fmla v18.4s,v5.4s,v3.s[1] \n\t" // Accummulate. - " fmla v19.4s,v6.4s,v3.s[1] \n\t" // Accummulate. - " fmla v20.4s,v5.4s,v3.s[2] \n\t" // Accummulate. - " fmla v21.4s,v6.4s,v3.s[2] \n\t" // Accummulate. - " fmla v22.4s,v5.4s,v3.s[3] \n\t" // Accummulate. - " fmla v23.4s,v6.4s,v3.s[3] \n\t" // Accummulate. - " \n\t" - " fmla v24.4s,v5.4s,v4.s[0] \n\t" // Accummulate. - " fmla v26.4s,v5.4s,v4.s[1] \n\t" // Accummulate. - " fmla v28.4s,v5.4s,v4.s[2] \n\t" // Accummulate. - " fmla v30.4s,v5.4s,v4.s[3] \n\t" // Accummulate. - " ldr q3, [x1, #160] \n\t" - " \n\t" - " fmla v25.4s,v6.4s,v4.s[0] \n\t" // Accummulate. - " fmla v27.4s,v6.4s,v4.s[1] \n\t" // Accummulate. - " fmla v29.4s,v6.4s,v4.s[2] \n\t" // Accummulate. - " fmla v31.4s,v6.4s,v4.s[3] \n\t" // Accummulate. - " ldr q4, [x1, #176] \n\t" - " add x1, x1, #192 \n\t" - " add x0, x0, #128 \n\t" - " \n\t" //End It 4 - " sub x5,x5,1 \n\t" // i-=1. - " cmp x5,1 \n\t" // Iterate again if we are not in k_iter == 1. - BNE(SLOOPKITER) - " \n\t" - LABEL(SLASTITER) // Last iteration of k_iter loop. - " \n\t" - " \n\t" - " ldr q5, [x0] \n\t" - " fmla v8.4s,v0.4s,v2.s[0] \n\t" // Accummulate. - " fmla v9.4s,v1.4s,v2.s[0] \n\t" // Accummulate. - " ldr q6, [x0, #16] \n\t" - " fmla v10.4s,v0.4s,v2.s[1] \n\t" // Accummulate. - " fmla v11.4s,v1.4s,v2.s[1] \n\t" // Accummulate. - " fmla v12.4s,v0.4s,v2.s[2] \n\t" // Accummulate. - " fmla v13.4s,v1.4s,v2.s[2] \n\t" // Accummulate. - " fmla v14.4s,v0.4s,v2.s[3] \n\t" // Accummulate. - " fmla v15.4s,v1.4s,v2.s[3] \n\t" // Accummulate. - " ldr q2, [x1] \n\t" - " \n\t" - " fmla v16.4s,v0.4s,v3.s[0] \n\t" // Accummulate. - " fmla v17.4s,v1.4s,v3.s[0] \n\t" // Accummulate. - " fmla v18.4s,v0.4s,v3.s[1] \n\t" // Accummulate. - " fmla v19.4s,v1.4s,v3.s[1] \n\t" // Accummulate. - " fmla v20.4s,v0.4s,v3.s[2] \n\t" // Accummulate. - " fmla v21.4s,v1.4s,v3.s[2] \n\t" // Accummulate. - " fmla v22.4s,v0.4s,v3.s[3] \n\t" // Accummulate. - " fmla v23.4s,v1.4s,v3.s[3] \n\t" // Accummulate. - " \n\t" - " fmla v24.4s,v0.4s,v4.s[0] \n\t" // Accummulate. - " fmla v26.4s,v0.4s,v4.s[1] \n\t" // Accummulate. - " fmla v28.4s,v0.4s,v4.s[2] \n\t" // Accummulate. - " fmla v30.4s,v0.4s,v4.s[3] \n\t" // Accummulate. - " ldr q3, [x1, #16] \n\t" - " \n\t" - " fmla v25.4s,v1.4s,v4.s[0] \n\t" // Accummulate. - " fmla v27.4s,v1.4s,v4.s[1] \n\t" // Accummulate. - " fmla v29.4s,v1.4s,v4.s[2] \n\t" // Accummulate. - " fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. - " ldr q4, [x1, #32] \n\t" - " \n\t" //End It 1 - " \n\t" - " ldr q0, [x0, #32] \n\t" - " fmla v8.4s,v5.4s,v2.s[0] \n\t" // Accummulate. - " fmla v9.4s,v6.4s,v2.s[0] \n\t" // Accummulate. - " ldr q1, [x0, #48] \n\t" - " fmla v10.4s,v5.4s,v2.s[1] \n\t" // Accummulate. - " fmla v11.4s,v6.4s,v2.s[1] \n\t" // Accummulate. - " fmla v12.4s,v5.4s,v2.s[2] \n\t" // Accummulate. - " fmla v13.4s,v6.4s,v2.s[2] \n\t" // Accummulate. - " fmla v14.4s,v5.4s,v2.s[3] \n\t" // Accummulate. - " fmla v15.4s,v6.4s,v2.s[3] \n\t" // Accummulate. - " ldr q2, [x1, #48] \n\t" - " \n\t" - " fmla v16.4s,v5.4s,v3.s[0] \n\t" // Accummulate. - " fmla v17.4s,v6.4s,v3.s[0] \n\t" // Accummulate. - " fmla v18.4s,v5.4s,v3.s[1] \n\t" // Accummulate. 
- " fmla v19.4s,v6.4s,v3.s[1] \n\t" // Accummulate. - " fmla v20.4s,v5.4s,v3.s[2] \n\t" // Accummulate. - " fmla v21.4s,v6.4s,v3.s[2] \n\t" // Accummulate. - " fmla v22.4s,v5.4s,v3.s[3] \n\t" // Accummulate. - " fmla v23.4s,v6.4s,v3.s[3] \n\t" // Accummulate. - " \n\t" - " fmla v24.4s,v5.4s,v4.s[0] \n\t" // Accummulate. - " fmla v26.4s,v5.4s,v4.s[1] \n\t" // Accummulate. - " fmla v28.4s,v5.4s,v4.s[2] \n\t" // Accummulate. - " fmla v30.4s,v5.4s,v4.s[3] \n\t" // Accummulate. - " ldr q3, [x1, #64] \n\t" - " \n\t" - " fmla v25.4s,v6.4s,v4.s[0] \n\t" // Accummulate. - " fmla v27.4s,v6.4s,v4.s[1] \n\t" // Accummulate. - " fmla v29.4s,v6.4s,v4.s[2] \n\t" // Accummulate. - " fmla v31.4s,v6.4s,v4.s[3] \n\t" // Accummulate. - " ldr q4, [x1, #80] \n\t" - " \n\t" //End It 2 - " \n\t" - " ldr q5, [x0, #64] \n\t" - " fmla v8.4s,v0.4s,v2.s[0] \n\t" // Accummulate. - " fmla v9.4s,v1.4s,v2.s[0] \n\t" // Accummulate. - " ldr q6, [x0, #80] \n\t" - " fmla v10.4s,v0.4s,v2.s[1] \n\t" // Accummulate. - " fmla v11.4s,v1.4s,v2.s[1] \n\t" // Accummulate. - " fmla v12.4s,v0.4s,v2.s[2] \n\t" // Accummulate. - " fmla v13.4s,v1.4s,v2.s[2] \n\t" // Accummulate. - " fmla v14.4s,v0.4s,v2.s[3] \n\t" // Accummulate. - " fmla v15.4s,v1.4s,v2.s[3] \n\t" // Accummulate. - " ldr q2, [x1, #96] \n\t" - " \n\t" - " fmla v16.4s,v0.4s,v3.s[0] \n\t" // Accummulate. - " fmla v17.4s,v1.4s,v3.s[0] \n\t" // Accummulate. - " fmla v18.4s,v0.4s,v3.s[1] \n\t" // Accummulate. - " fmla v19.4s,v1.4s,v3.s[1] \n\t" // Accummulate. - " fmla v20.4s,v0.4s,v3.s[2] \n\t" // Accummulate. - " fmla v21.4s,v1.4s,v3.s[2] \n\t" // Accummulate. - " fmla v22.4s,v0.4s,v3.s[3] \n\t" // Accummulate. - " fmla v23.4s,v1.4s,v3.s[3] \n\t" // Accummulate. - " \n\t" - " fmla v24.4s,v0.4s,v4.s[0] \n\t" // Accummulate. - " fmla v26.4s,v0.4s,v4.s[1] \n\t" // Accummulate. - " fmla v28.4s,v0.4s,v4.s[2] \n\t" // Accummulate. - " fmla v30.4s,v0.4s,v4.s[3] \n\t" // Accummulate. - " ldr q3, [x1, #112] \n\t" - " \n\t" - " fmla v25.4s,v1.4s,v4.s[0] \n\t" // Accummulate. - " fmla v27.4s,v1.4s,v4.s[1] \n\t" // Accummulate. - " fmla v29.4s,v1.4s,v4.s[2] \n\t" // Accummulate. - " fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. - " ldr q4, [x1, #128] \n\t" - " \n\t" //End It 3 - " \n\t" - " fmla v8.4s,v5.4s,v2.s[0] \n\t" // Accummulate. - " fmla v9.4s,v6.4s,v2.s[0] \n\t" // Accummulate. - " fmla v10.4s,v5.4s,v2.s[1] \n\t" // Accummulate. - " fmla v11.4s,v6.4s,v2.s[1] \n\t" // Accummulate. - " fmla v12.4s,v5.4s,v2.s[2] \n\t" // Accummulate. - " fmla v13.4s,v6.4s,v2.s[2] \n\t" // Accummulate. - " fmla v14.4s,v5.4s,v2.s[3] \n\t" // Accummulate. - " fmla v15.4s,v6.4s,v2.s[3] \n\t" // Accummulate. - " \n\t" - " fmla v16.4s,v5.4s,v3.s[0] \n\t" // Accummulate. - " fmla v17.4s,v6.4s,v3.s[0] \n\t" // Accummulate. - " fmla v18.4s,v5.4s,v3.s[1] \n\t" // Accummulate. - " fmla v19.4s,v6.4s,v3.s[1] \n\t" // Accummulate. - " fmla v20.4s,v5.4s,v3.s[2] \n\t" // Accummulate. - " fmla v21.4s,v6.4s,v3.s[2] \n\t" // Accummulate. - " fmla v22.4s,v5.4s,v3.s[3] \n\t" // Accummulate. - " fmla v23.4s,v6.4s,v3.s[3] \n\t" // Accummulate. - " \n\t" - " fmla v24.4s,v5.4s,v4.s[0] \n\t" // Accummulate. - " fmla v26.4s,v5.4s,v4.s[1] \n\t" // Accummulate. - " fmla v28.4s,v5.4s,v4.s[2] \n\t" // Accummulate. - " fmla v30.4s,v5.4s,v4.s[3] \n\t" // Accummulate. - " \n\t" - " fmla v25.4s,v6.4s,v4.s[0] \n\t" // Accummulate. - " fmla v27.4s,v6.4s,v4.s[1] \n\t" // Accummulate. - " fmla v29.4s,v6.4s,v4.s[2] \n\t" // Accummulate. - " fmla v31.4s,v6.4s,v4.s[3] \n\t" // Accummulate. 
- " add x1, x1, #144 \n\t" - " add x0, x0, #96 \n\t" - " \n\t" //End It 4 - " \n\t" - LABEL(SCONSIDERKLEFT) - " cmp x6,0 \n\t" // If k_left == 0, we are done. - BEQ(SPOSTACCUM) // else, we enter the k_left loop. - " \n\t" - LABEL(SLOOPKLEFT) // Body of the left iterations - " \n\t" - " ldr q0, [x0],#16 \n\t" - " ldr q1, [x0],#16 \n\t" // Load a - " \n\t" - " ldr q2, [x1],#16 \n\t" // Load b - " ldr q3, [x1],#16 \n\t" - " ldr q4, [x1],#16 \n\t" - " \n\t" - " sub x6,x6,1 \n\t" // i = i-1. - " \n\t" - " fmla v8.4s,v0.4s,v2.s[0] \n\t" // Accummulate. - " fmla v9.4s,v1.4s,v2.s[0] \n\t" // Accummulate. - " fmla v10.4s,v0.4s,v2.s[1] \n\t" // Accummulate. - " fmla v11.4s,v1.4s,v2.s[1] \n\t" // Accummulate. - " fmla v12.4s,v0.4s,v2.s[2] \n\t" // Accummulate. - " fmla v13.4s,v1.4s,v2.s[2] \n\t" // Accummulate. - " fmla v14.4s,v0.4s,v2.s[3] \n\t" // Accummulate. - " fmla v15.4s,v1.4s,v2.s[3] \n\t" // Accummulate. - " \n\t" - " fmla v16.4s,v0.4s,v3.s[0] \n\t" // Accummulate. - " fmla v17.4s,v1.4s,v3.s[0] \n\t" // Accummulate. - " fmla v18.4s,v0.4s,v3.s[1] \n\t" // Accummulate. - " fmla v19.4s,v1.4s,v3.s[1] \n\t" // Accummulate. - " fmla v20.4s,v0.4s,v3.s[2] \n\t" // Accummulate. - " fmla v21.4s,v1.4s,v3.s[2] \n\t" // Accummulate. - " fmla v22.4s,v0.4s,v3.s[3] \n\t" // Accummulate. - " fmla v23.4s,v1.4s,v3.s[3] \n\t" // Accummulate. - " \n\t" - " fmla v24.4s,v0.4s,v4.s[0] \n\t" // Accummulate. - " fmla v26.4s,v0.4s,v4.s[1] \n\t" // Accummulate. - " fmla v28.4s,v0.4s,v4.s[2] \n\t" // Accummulate. - " fmla v30.4s,v0.4s,v4.s[3] \n\t" // Accummulate. - " fmla v25.4s,v1.4s,v4.s[0] \n\t" // Accummulate. - " fmla v27.4s,v1.4s,v4.s[1] \n\t" // Accummulate. - " fmla v29.4s,v1.4s,v4.s[2] \n\t" // Accummulate. - " fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. - " \n\t" - " cmp x6,0 \n\t" // Iterate again. - BNE(SLOOPKLEFT) // if i!=0. - " \n\t" - LABEL(SPOSTACCUM) - " \n\t" - " ldr x0,%[alpha] \n\t" // Alpha address. - " ldr x1,%[beta] \n\t" // Beta address. - " \n\t" - " ld1r {v6.4s},[x0] \n\t" // Load alpha. - " ld1r {v7.4s},[x1] \n\t" // Load beta - " \n\t" - " ldr x0,%[a_next] \n\t" // Pointer to next block of A. - " ldr x1,%[b_next] \n\t" // Pointer to next pointer of B. - " \n\t" -" cmp x14,#4 \n\t" // If rs_c != 1 (column-major) -BNE(SGENSTORED) -" \n\t" - LABEL(SCOLSTORED) // C is column-major. - " \n\t" - " dup v0.4s, wzr \n\t" - " dup v1.4s, wzr \n\t" - " dup v2.4s, wzr \n\t" - " dup v3.4s, wzr \n\t" - " dup v4.4s, wzr \n\t" - " dup v5.4s, wzr \n\t" - " \n\t" - " fcmp s7,#0.0 \n\t" - BEQ(SBETAZEROCOLSTOREDS1) // Taking care of the beta==0 case. 
- " \n\t" - " ldr q0, [x2] \n\t" //Load column 0 of C - " ldr q1, [x2, #16] \n\t" - " ldr q2, [x16] \n\t" //Load column 1 of C - " ldr q3, [x16, #16] \n\t" - " ldr q4, [x17] \n\t" //Load column 2 of C - " ldr q5, [x17, #16] \n\t" - " \n\t" - " fmul v0.4s,v0.4s,v7.s[0] \n\t" // Scale by beta - " fmul v1.4s,v1.4s,v7.s[0] \n\t" // Scale by beta - " fmul v2.4s,v2.4s,v7.s[0] \n\t" // Scale by beta - " fmul v3.4s,v3.4s,v7.s[0] \n\t" // Scale by beta - " fmul v4.4s,v4.4s,v7.s[0] \n\t" // Scale by beta - " fmul v5.4s,v5.4s,v7.s[0] \n\t" // Scale by beta - " \n\t" - LABEL(SBETAZEROCOLSTOREDS1) - " \n\t" - " fmla v0.4s,v8.4s,v6.s[0] \n\t" // Scale by alpha - " fmla v1.4s,v9.4s,v6.s[0] \n\t" // Scale by alpha - " fmla v2.4s,v10.4s,v6.s[0] \n\t" // Scale by alpha - " fmla v3.4s,v11.4s,v6.s[0] \n\t" // Scale by alpha - " fmla v4.4s,v12.4s,v6.s[0] \n\t" // Scale by alpha - " fmla v5.4s,v13.4s,v6.s[0] \n\t" // Scale by alpha - " \n\t" - " str q0, [x2] \n\t" //Store column 0 of C - " str q1, [x2, #16] \n\t" - " str q2, [x16] \n\t" //Store column 1 of C - " str q3, [x16, #16] \n\t" - " str q4, [x17] \n\t" //Store column 2 of C - " str q5, [x17, #16] \n\t" - " \n\t" - " dup v8.4s, wzr \n\t" - " dup v9.4s, wzr \n\t" - " dup v10.4s, wzr \n\t" - " dup v11.4s, wzr \n\t" - " dup v12.4s, wzr \n\t" - " dup v13.4s, wzr \n\t" - " \n\t" - " fcmp s7,#0.0 \n\t" - BEQ(SBETAZEROCOLSTOREDS2) // Taking care of the beta==0 case. - " \n\t" - " ldr q8, [x19] \n\t" //Load column 3 of C - " ldr q9, [x19, #16] \n\t" - " ldr q10, [x20] \n\t" //Load column 4 of C - " ldr q11, [x20, #16] \n\t" - " ldr q12, [x21] \n\t" //Load column 5 of C - " ldr q13, [x21, #16] \n\t" - " \n\t" - " fmul v8.4s, v8.4s, v7.s[0] \n\t" // Scale by beta - " fmul v9.4s, v9.4s, v7.s[0] \n\t" // Scale by beta - " fmul v10.4s,v10.4s,v7.s[0] \n\t" // Scale by beta - " fmul v11.4s,v11.4s,v7.s[0] \n\t" // Scale by beta - " fmul v12.4s,v12.4s,v7.s[0] \n\t" // Scale by beta - " fmul v13.4s,v13.4s,v7.s[0] \n\t" // Scale by beta - " \n\t" - LABEL(SBETAZEROCOLSTOREDS2) - " \n\t" - " fmla v8.4s, v14.4s,v6.s[0] \n\t" // Scale by alpha - " fmla v9.4s, v15.4s,v6.s[0] \n\t" // Scale by alpha - " fmla v10.4s,v16.4s,v6.s[0] \n\t" // Scale by alpha - " fmla v11.4s,v17.4s,v6.s[0] \n\t" // Scale by alpha - " fmla v12.4s,v18.4s,v6.s[0] \n\t" // Scale by alpha - " fmla v13.4s,v19.4s,v6.s[0] \n\t" // Scale by alpha - " \n\t" - " str q8, [x19] \n\t" //Store column 3 of C - " str q9, [x19, #16] \n\t" - " str q10, [x20] \n\t" //Store column 4 of C - " str q11, [x20, #16] \n\t" - " str q12, [x21] \n\t" //Store column 5 of C - " str q13, [x21, #16] \n\t" - " \n\t" - " dup v0.4s, wzr \n\t" - " dup v1.4s, wzr \n\t" - " dup v2.4s, wzr \n\t" - " dup v3.4s, wzr \n\t" - " dup v4.4s, wzr \n\t" - " dup v5.4s, wzr \n\t" - " \n\t" - " fcmp s7,#0.0 \n\t" - BEQ(SBETAZEROCOLSTOREDS3) // Taking care of the beta==0 case. 
- " \n\t" - " ldr q0, [x22] \n\t" //Load column 6 of C - " ldr q1, [x22, #16] \n\t" - " ldr q2, [x23] \n\t" //Load column 7 of C - " ldr q3, [x23, #16] \n\t" - " ldr q4, [x24] \n\t" //Load column 8 of C - " ldr q5, [x24, #16] \n\t" - " \n\t" - " fmul v0.4s,v0.4s,v7.s[0] \n\t" // Scale by beta - " fmul v1.4s,v1.4s,v7.s[0] \n\t" // Scale by beta - " fmul v2.4s,v2.4s,v7.s[0] \n\t" // Scale by beta - " fmul v3.4s,v3.4s,v7.s[0] \n\t" // Scale by beta - " fmul v4.4s,v4.4s,v7.s[0] \n\t" // Scale by beta - " fmul v5.4s,v5.4s,v7.s[0] \n\t" // Scale by beta - " \n\t" - LABEL(SBETAZEROCOLSTOREDS3) - " \n\t" - " fmla v0.4s,v20.4s,v6.s[0] \n\t" // Scale by alpha - " fmla v1.4s,v21.4s,v6.s[0] \n\t" // Scale by alpha - " fmla v2.4s,v22.4s,v6.s[0] \n\t" // Scale by alpha - " fmla v3.4s,v23.4s,v6.s[0] \n\t" // Scale by alpha - " fmla v4.4s,v24.4s,v6.s[0] \n\t" // Scale by alpha - " fmla v5.4s,v25.4s,v6.s[0] \n\t" // Scale by alpha - " \n\t" - " str q0, [x22] \n\t" //Store column 6 of C - " str q1, [x22, #16] \n\t" - " str q2, [x23] \n\t" //Store column 7 of C - " str q3, [x23, #16] \n\t" - " str q4, [x24] \n\t" //Store column 8 of C - " str q5, [x24, #16] \n\t" - " \n\t" - " dup v8.4s, wzr \n\t" - " dup v9.4s, wzr \n\t" - " dup v10.4s, wzr \n\t" - " dup v11.4s, wzr \n\t" - " dup v12.4s, wzr \n\t" - " dup v13.4s, wzr \n\t" - " \n\t" - " fcmp s7,#0.0 \n\t" - BEQ(SBETAZEROCOLSTOREDS4) // Taking care of the beta==0 case. - " \n\t" - " ldr q8, [x25] \n\t" //Load column 9 of C - " ldr q9, [x25, #16] \n\t" - " ldr q10, [x26] \n\t" //Load column 10 of C - " ldr q11, [x26, #16] \n\t" - " ldr q12, [x27] \n\t" //Load column 11 of C - " ldr q13, [x27, #16] \n\t" - " \n\t" - " fmul v8.4s, v8.4s, v7.s[0] \n\t" // Scale by beta - " fmul v9.4s, v9.4s, v7.s[0] \n\t" // Scale by beta - " fmul v10.4s,v10.4s,v7.s[0] \n\t" // Scale by beta - " fmul v11.4s,v11.4s,v7.s[0] \n\t" // Scale by beta - " fmul v12.4s,v12.4s,v7.s[0] \n\t" // Scale by beta - " fmul v13.4s,v13.4s,v7.s[0] \n\t" // Scale by beta - " \n\t" - LABEL(SBETAZEROCOLSTOREDS4) - " \n\t" - " prfm pldl2keep,[x0] \n\t" - " prfm pldl2keep,[x1] \n\t" - " \n\t" - " fmla v8.4s, v26.4s,v6.s[0] \n\t" // Scale by alpha - " fmla v9.4s, v27.4s,v6.s[0] \n\t" // Scale by alpha - " fmla v10.4s,v28.4s,v6.s[0] \n\t" // Scale by alpha - " fmla v11.4s,v29.4s,v6.s[0] \n\t" // Scale by alpha - " fmla v12.4s,v30.4s,v6.s[0] \n\t" // Scale by alpha - " fmla v13.4s,v31.4s,v6.s[0] \n\t" // Scale by alpha - " \n\t" - " str q8, [x25] \n\t" //Store column 9 of C - " str q9, [x25, #16] \n\t" - " str q10, [x26] \n\t" //Store column 10 of C - " str q11, [x26, #16] \n\t" - " str q12, [x27] \n\t" //Store column 11 of C - " str q13, [x27, #16] \n\t" - " \n\t" -" \n\t" -BRANCH(SEND) // Done. -" \n\t" -" \n\t" -LABEL(SGENSTORED) // C is general-stride stored. -" \n\t" -" \n\t" -" dup v0.4s, wzr \n\t" -" dup v1.4s, wzr \n\t" -" dup v2.4s, wzr \n\t" -" dup v3.4s, wzr \n\t" -" dup v4.4s, wzr \n\t" -" dup v5.4s, wzr \n\t" -" \n\t" -" fcmp s7,#0.0 \n\t" -BEQ(SBETAZEROGENSTOREDS1) // Taking care of the beta==0 case. -" \n\t" -" mov x5, x2 \n\t" -" \n\t" -" ld1 {v0.s}[0],[x5],x14 \n\t" // Load c00 into quad and increment by rs_c. -" ld1 {v0.s}[1],[x5],x14 \n\t" // Load c01 into quad and increment by rs_c. -" ld1 {v0.s}[2],[x5],x14 \n\t" // Load c02 into quad and increment by rs_c. -" ld1 {v0.s}[3],[x5],x14 \n\t" // Load c03 into quad and increment by rs_c. -" ld1 {v1.s}[0],[x5],x14 \n\t" // Load c04 into quad and increment by rs_c. -" ld1 {v1.s}[1],[x5],x14 \n\t" // Load c05 into quad and increment by rs_c. 
-" ld1 {v1.s}[2],[x5],x14 \n\t" // Load c06 into quad and increment by rs_c. -" ld1 {v1.s}[3],[x5],x14 \n\t" // Load c07 into quad and increment by rs_c. -" \n\t" -" mov x5, x16 \n\t" -" \n\t" -" ld1 {v2.s}[0],[x5],x14 \n\t" // Load c10 into quad and increment by rs_c. -" ld1 {v2.s}[1],[x5],x14 \n\t" // Load c11 into quad and increment by rs_c. -" ld1 {v2.s}[2],[x5],x14 \n\t" // Load c12 into quad and increment by rs_c. -" ld1 {v2.s}[3],[x5],x14 \n\t" // Load c13 into quad and increment by rs_c. -" ld1 {v3.s}[0],[x5],x14 \n\t" // Load c14 into quad and increment by rs_c. -" ld1 {v3.s}[1],[x5],x14 \n\t" // Load c15 into quad and increment by rs_c. -" ld1 {v3.s}[2],[x5],x14 \n\t" // Load c16 into quad and increment by rs_c. -" ld1 {v3.s}[3],[x5],x14 \n\t" // Load c17 into quad and increment by rs_c. -" \n\t" -" mov x5, x17 \n\t" -" \n\t" -" ld1 {v4.s}[0],[x5],x14 \n\t" // Load c20 into quad and increment by rs_c. -" ld1 {v4.s}[1],[x5],x14 \n\t" // Load c21 into quad and increment by rs_c. -" ld1 {v4.s}[2],[x5],x14 \n\t" // Load c22 into quad and increment by rs_c. -" ld1 {v4.s}[3],[x5],x14 \n\t" // Load c23 into quad and increment by rs_c. -" ld1 {v5.s}[0],[x5],x14 \n\t" // Load c24 into quad and increment by rs_c. -" ld1 {v5.s}[1],[x5],x14 \n\t" // Load c25 into quad and increment by rs_c. -" ld1 {v5.s}[2],[x5],x14 \n\t" // Load c26 into quad and increment by rs_c. -" ld1 {v5.s}[3],[x5],x14 \n\t" // Load c27 into quad and increment by rs_c. -" \n\t" -" fmul v0.4s,v0.4s,v7.s[0] \n\t" // Scale by beta -" fmul v1.4s,v1.4s,v7.s[0] \n\t" // Scale by beta -" fmul v2.4s,v2.4s,v7.s[0] \n\t" // Scale by beta -" fmul v3.4s,v3.4s,v7.s[0] \n\t" // Scale by beta -" fmul v4.4s,v4.4s,v7.s[0] \n\t" // Scale by beta -" fmul v5.4s,v5.4s,v7.s[0] \n\t" // Scale by beta -" \n\t" -LABEL(SBETAZEROGENSTOREDS1) -" \n\t" -" fmla v0.4s, v8.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v1.4s, v9.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v2.4s,v10.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v3.4s,v11.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v4.4s,v12.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v5.4s,v13.4s,v6.s[0] \n\t" // Scale by alpha -" \n\t" -" mov x5, x2 \n\t" -" \n\t" -" st1 {v0.s}[0],[x5],x14 \n\t" // Store c00 into quad and increment by rs_c. -" st1 {v0.s}[1],[x5],x14 \n\t" // Store c01 into quad and increment by rs_c. -" st1 {v0.s}[2],[x5],x14 \n\t" // Store c02 into quad and increment by rs_c. -" st1 {v0.s}[3],[x5],x14 \n\t" // Store c03 into quad and increment by rs_c. -" st1 {v1.s}[0],[x5],x14 \n\t" // Store c04 into quad and increment by rs_c. -" st1 {v1.s}[1],[x5],x14 \n\t" // Store c05 into quad and increment by rs_c. -" st1 {v1.s}[2],[x5],x14 \n\t" // Store c06 into quad and increment by rs_c. -" st1 {v1.s}[3],[x5],x14 \n\t" // Store c07 into quad and increment by rs_c. -" \n\t" -" mov x5, x16 \n\t" -" \n\t" -" st1 {v2.s}[0],[x5],x14 \n\t" // Store c10 into quad and increment by rs_c. -" st1 {v2.s}[1],[x5],x14 \n\t" // Store c11 into quad and increment by rs_c. -" st1 {v2.s}[2],[x5],x14 \n\t" // Store c12 into quad and increment by rs_c. -" st1 {v2.s}[3],[x5],x14 \n\t" // Store c13 into quad and increment by rs_c. -" st1 {v3.s}[0],[x5],x14 \n\t" // Store c14 into quad and increment by rs_c. -" st1 {v3.s}[1],[x5],x14 \n\t" // Store c15 into quad and increment by rs_c. -" st1 {v3.s}[2],[x5],x14 \n\t" // Store c16 into quad and increment by rs_c. -" st1 {v3.s}[3],[x5],x14 \n\t" // Store c17 into quad and increment by rs_c. 
-" \n\t" -" mov x5, x17 \n\t" -" \n\t" -" st1 {v4.s}[0],[x5],x14 \n\t" // Store c20 into quad and increment by rs_c. -" st1 {v4.s}[1],[x5],x14 \n\t" // Store c21 into quad and increment by rs_c. -" st1 {v4.s}[2],[x5],x14 \n\t" // Store c22 into quad and increment by rs_c. -" st1 {v4.s}[3],[x5],x14 \n\t" // Store c23 into quad and increment by rs_c. -" st1 {v5.s}[0],[x5],x14 \n\t" // Store c24 into quad and increment by rs_c. -" st1 {v5.s}[1],[x5],x14 \n\t" // Store c25 into quad and increment by rs_c. -" st1 {v5.s}[2],[x5],x14 \n\t" // Store c26 into quad and increment by rs_c. -" st1 {v5.s}[3],[x5],x14 \n\t" // Store c27 into quad and increment by rs_c. -" \n\t" -" dup v8.4s, wzr \n\t" -" dup v9.4s, wzr \n\t" -" dup v10.4s, wzr \n\t" -" dup v11.4s, wzr \n\t" -" dup v12.4s, wzr \n\t" -" dup v13.4s, wzr \n\t" -" \n\t" -" fcmp s7,#0.0 \n\t" -BEQ(SBETAZEROGENSTOREDS2) // Taking care of the beta==0 case. -" \n\t" -" mov x5, x19 \n\t" -" \n\t" -" ld1 {v8.s}[0],[x5],x14 \n\t" // Load c30 into quad and increment by rs_c. -" ld1 {v8.s}[1],[x5],x14 \n\t" // Load c31 into quad and increment by rs_c. -" ld1 {v8.s}[2],[x5],x14 \n\t" // Load c32 into quad and increment by rs_c. -" ld1 {v8.s}[3],[x5],x14 \n\t" // Load c33 into quad and increment by rs_c. -" ld1 {v9.s}[0],[x5],x14 \n\t" // Load c34 into quad and increment by rs_c. -" ld1 {v9.s}[1],[x5],x14 \n\t" // Load c35 into quad and increment by rs_c. -" ld1 {v9.s}[2],[x5],x14 \n\t" // Load c36 into quad and increment by rs_c. -" ld1 {v9.s}[3],[x5],x14 \n\t" // Load c37 into quad and increment by rs_c. -" \n\t" -" mov x5, x20 \n\t" -" \n\t" -" ld1 {v10.s}[0],[x5],x14 \n\t" // Load c40 into quad and increment by rs_c. -" ld1 {v10.s}[1],[x5],x14 \n\t" // Load c41 into quad and increment by rs_c. -" ld1 {v10.s}[2],[x5],x14 \n\t" // Load c42 into quad and increment by rs_c. -" ld1 {v10.s}[3],[x5],x14 \n\t" // Load c43 into quad and increment by rs_c. -" ld1 {v11.s}[0],[x5],x14 \n\t" // Load c44 into quad and increment by rs_c. -" ld1 {v11.s}[1],[x5],x14 \n\t" // Load c45 into quad and increment by rs_c. -" ld1 {v11.s}[2],[x5],x14 \n\t" // Load c46 into quad and increment by rs_c. -" ld1 {v11.s}[3],[x5],x14 \n\t" // Load c47 into quad and increment by rs_c. -" \n\t" -" mov x5, x21 \n\t" -" \n\t" -" ld1 {v12.s}[0],[x5],x14 \n\t" // Load c50 into quad and increment by rs_c. -" ld1 {v12.s}[1],[x5],x14 \n\t" // Load c51 into quad and increment by rs_c. -" ld1 {v12.s}[2],[x5],x14 \n\t" // Load c52 into quad and increment by rs_c. -" ld1 {v12.s}[3],[x5],x14 \n\t" // Load c53 into quad and increment by rs_c. -" ld1 {v13.s}[0],[x5],x14 \n\t" // Load c54 into quad and increment by rs_c. -" ld1 {v13.s}[1],[x5],x14 \n\t" // Load c55 into quad and increment by rs_c. -" ld1 {v13.s}[2],[x5],x14 \n\t" // Load c56 into quad and increment by rs_c. -" ld1 {v13.s}[3],[x5],x14 \n\t" // Load c57 into quad and increment by rs_c. 
-" \n\t" -" fmul v8.4s, v8.4s, v7.s[0] \n\t" // Scale by beta -" fmul v9.4s, v9.4s, v7.s[0] \n\t" // Scale by beta -" fmul v10.4s,v10.4s,v7.s[0] \n\t" // Scale by beta -" fmul v11.4s,v11.4s,v7.s[0] \n\t" // Scale by beta -" fmul v12.4s,v12.4s,v7.s[0] \n\t" // Scale by beta -" fmul v13.4s,v13.4s,v7.s[0] \n\t" // Scale by beta -" \n\t" -LABEL(SBETAZEROGENSTOREDS2) -" \n\t" -" fmla v8.4s, v14.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v9.4s, v15.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v10.4s,v16.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v11.4s,v17.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v12.4s,v18.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v13.4s,v19.4s,v6.s[0] \n\t" // Scale by alpha -" \n\t" -" mov x5, x19 \n\t" -" \n\t" -" st1 {v8.s}[0],[x5],x14 \n\t" // Store c30 into quad and increment by rs_c. -" st1 {v8.s}[1],[x5],x14 \n\t" // Store c31 into quad and increment by rs_c. -" st1 {v8.s}[2],[x5],x14 \n\t" // Store c32 into quad and increment by rs_c. -" st1 {v8.s}[3],[x5],x14 \n\t" // Store c33 into quad and increment by rs_c. -" st1 {v9.s}[0],[x5],x14 \n\t" // Store c34 into quad and increment by rs_c. -" st1 {v9.s}[1],[x5],x14 \n\t" // Store c35 into quad and increment by rs_c. -" st1 {v9.s}[2],[x5],x14 \n\t" // Store c36 into quad and increment by rs_c. -" st1 {v9.s}[3],[x5],x14 \n\t" // Store c37 into quad and increment by rs_c. -" \n\t" -" mov x5, x20 \n\t" -" \n\t" -" st1 {v10.s}[0],[x5],x14 \n\t" // Store c40 into quad and increment by rs_c. -" st1 {v10.s}[1],[x5],x14 \n\t" // Store c41 into quad and increment by rs_c. -" st1 {v10.s}[2],[x5],x14 \n\t" // Store c42 into quad and increment by rs_c. -" st1 {v10.s}[3],[x5],x14 \n\t" // Store c43 into quad and increment by rs_c. -" st1 {v11.s}[0],[x5],x14 \n\t" // Store c44 into quad and increment by rs_c. -" st1 {v11.s}[1],[x5],x14 \n\t" // Store c45 into quad and increment by rs_c. -" st1 {v11.s}[2],[x5],x14 \n\t" // Store c46 into quad and increment by rs_c. -" st1 {v11.s}[3],[x5],x14 \n\t" // Store c47 into quad and increment by rs_c. -" \n\t" -" mov x5, x21 \n\t" -" \n\t" -" st1 {v12.s}[0],[x5],x14 \n\t" // Store c50 into quad and increment by rs_c. -" st1 {v12.s}[1],[x5],x14 \n\t" // Store c51 into quad and increment by rs_c. -" st1 {v12.s}[2],[x5],x14 \n\t" // Store c52 into quad and increment by rs_c. -" st1 {v12.s}[3],[x5],x14 \n\t" // Store c53 into quad and increment by rs_c. -" st1 {v13.s}[0],[x5],x14 \n\t" // Store c54 into quad and increment by rs_c. -" st1 {v13.s}[1],[x5],x14 \n\t" // Store c55 into quad and increment by rs_c. -" st1 {v13.s}[2],[x5],x14 \n\t" // Store c56 into quad and increment by rs_c. -" st1 {v13.s}[3],[x5],x14 \n\t" // Store c57 into quad and increment by rs_c. -" \n\t" -" dup v0.4s, wzr \n\t" -" dup v1.4s, wzr \n\t" -" dup v2.4s, wzr \n\t" -" dup v3.4s, wzr \n\t" -" dup v4.4s, wzr \n\t" -" dup v5.4s, wzr \n\t" -" \n\t" -" fcmp s7,#0.0 \n\t" -BEQ(SBETAZEROGENSTOREDS3) // Taking care of the beta==0 case. -" \n\t" -" mov x5, x22 \n\t" -" \n\t" -" ld1 {v0.s}[0],[x5],x14 \n\t" // Load c60 into quad and increment by rs_c. -" ld1 {v0.s}[1],[x5],x14 \n\t" // Load c61 into quad and increment by rs_c. -" ld1 {v0.s}[2],[x5],x14 \n\t" // Load c62 into quad and increment by rs_c. -" ld1 {v0.s}[3],[x5],x14 \n\t" // Load c63 into quad and increment by rs_c. -" ld1 {v1.s}[0],[x5],x14 \n\t" // Load c64 into quad and increment by rs_c. -" ld1 {v1.s}[1],[x5],x14 \n\t" // Load c65 into quad and increment by rs_c. -" ld1 {v1.s}[2],[x5],x14 \n\t" // Load c66 into quad and increment by rs_c. 
-" ld1 {v1.s}[3],[x5],x14 \n\t" // Load c67 into quad and increment by rs_c. -" \n\t" -" mov x5, x23 \n\t" -" \n\t" -" ld1 {v2.s}[0],[x5],x14 \n\t" // Load c70 into quad and increment by rs_c. -" ld1 {v2.s}[1],[x5],x14 \n\t" // Load c71 into quad and increment by rs_c. -" ld1 {v2.s}[2],[x5],x14 \n\t" // Load c72 into quad and increment by rs_c. -" ld1 {v2.s}[3],[x5],x14 \n\t" // Load c73 into quad and increment by rs_c. -" ld1 {v3.s}[0],[x5],x14 \n\t" // Load c74 into quad and increment by rs_c. -" ld1 {v3.s}[1],[x5],x14 \n\t" // Load c75 into quad and increment by rs_c. -" ld1 {v3.s}[2],[x5],x14 \n\t" // Load c76 into quad and increment by rs_c. -" ld1 {v3.s}[3],[x5],x14 \n\t" // Load c77 into quad and increment by rs_c. -" \n\t" -" mov x5, x24 \n\t" -" \n\t" -" ld1 {v4.s}[0],[x5],x14 \n\t" // Load c80 into quad and increment by rs_c. -" ld1 {v4.s}[1],[x5],x14 \n\t" // Load c81 into quad and increment by rs_c. -" ld1 {v4.s}[2],[x5],x14 \n\t" // Load c82 into quad and increment by rs_c. -" ld1 {v4.s}[3],[x5],x14 \n\t" // Load c83 into quad and increment by rs_c. -" ld1 {v5.s}[0],[x5],x14 \n\t" // Load c84 into quad and increment by rs_c. -" ld1 {v5.s}[1],[x5],x14 \n\t" // Load c85 into quad and increment by rs_c. -" ld1 {v5.s}[2],[x5],x14 \n\t" // Load c86 into quad and increment by rs_c. -" ld1 {v5.s}[3],[x5],x14 \n\t" // Load c87 into quad and increment by rs_c. -" \n\t" -" fmul v0.4s,v0.4s,v7.s[0] \n\t" // Scale by beta -" fmul v1.4s,v1.4s,v7.s[0] \n\t" // Scale by beta -" fmul v2.4s,v2.4s,v7.s[0] \n\t" // Scale by beta -" fmul v3.4s,v3.4s,v7.s[0] \n\t" // Scale by beta -" fmul v4.4s,v4.4s,v7.s[0] \n\t" // Scale by beta -" fmul v5.4s,v5.4s,v7.s[0] \n\t" // Scale by beta -" \n\t" -LABEL(SBETAZEROGENSTOREDS3) -" \n\t" -" fmla v0.4s,v20.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v1.4s,v21.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v2.4s,v22.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v3.4s,v23.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v4.4s,v24.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v5.4s,v25.4s,v6.s[0] \n\t" // Scale by alpha -" \n\t" -" mov x5, x22 \n\t" -" \n\t" -" st1 {v0.s}[0],[x5],x14 \n\t" // Store c60 into quad and increment by rs_c. -" st1 {v0.s}[1],[x5],x14 \n\t" // Store c61 into quad and increment by rs_c. -" st1 {v0.s}[2],[x5],x14 \n\t" // Store c62 into quad and increment by rs_c. -" st1 {v0.s}[3],[x5],x14 \n\t" // Store c63 into quad and increment by rs_c. -" st1 {v1.s}[0],[x5],x14 \n\t" // Store c64 into quad and increment by rs_c. -" st1 {v1.s}[1],[x5],x14 \n\t" // Store c65 into quad and increment by rs_c. -" st1 {v1.s}[2],[x5],x14 \n\t" // Store c66 into quad and increment by rs_c. -" st1 {v1.s}[3],[x5],x14 \n\t" // Store c67 into quad and increment by rs_c. -" \n\t" -" mov x5, x23 \n\t" -" \n\t" -" st1 {v2.s}[0],[x5],x14 \n\t" // Store c70 into quad and increment by rs_c. -" st1 {v2.s}[1],[x5],x14 \n\t" // Store c71 into quad and increment by rs_c. -" st1 {v2.s}[2],[x5],x14 \n\t" // Store c72 into quad and increment by rs_c. -" st1 {v2.s}[3],[x5],x14 \n\t" // Store c73 into quad and increment by rs_c. -" st1 {v3.s}[0],[x5],x14 \n\t" // Store c74 into quad and increment by rs_c. -" st1 {v3.s}[1],[x5],x14 \n\t" // Store c75 into quad and increment by rs_c. -" st1 {v3.s}[2],[x5],x14 \n\t" // Store c76 into quad and increment by rs_c. -" st1 {v3.s}[3],[x5],x14 \n\t" // Store c77 into quad and increment by rs_c. -" \n\t" -" mov x5, x24 \n\t" -" \n\t" -" st1 {v4.s}[0],[x5],x14 \n\t" // Store c80 into quad and increment by rs_c. 
-" st1 {v4.s}[1],[x5],x14 \n\t" // Store c81 into quad and increment by rs_c. -" st1 {v4.s}[2],[x5],x14 \n\t" // Store c82 into quad and increment by rs_c. -" st1 {v4.s}[3],[x5],x14 \n\t" // Store c83 into quad and increment by rs_c. -" st1 {v5.s}[0],[x5],x14 \n\t" // Store c84 into quad and increment by rs_c. -" st1 {v5.s}[1],[x5],x14 \n\t" // Store c85 into quad and increment by rs_c. -" st1 {v5.s}[2],[x5],x14 \n\t" // Store c86 into quad and increment by rs_c. -" st1 {v5.s}[3],[x5],x14 \n\t" // Store c87 into quad and increment by rs_c. -" \n\t" -" dup v8.4s, wzr \n\t" -" dup v9.4s, wzr \n\t" -" dup v10.4s, wzr \n\t" -" dup v11.4s, wzr \n\t" -" dup v12.4s, wzr \n\t" -" dup v13.4s, wzr \n\t" -" \n\t" -" fcmp s7,#0.0 \n\t" -BEQ(SBETAZEROGENSTOREDS4) // Taking care of the beta==0 case. -" \n\t" -" mov x5, x25 \n\t" -" \n\t" -" ld1 {v8.s}[0],[x5],x14 \n\t" // Load c90 into quad and increment by rs_c. -" ld1 {v8.s}[1],[x5],x14 \n\t" // Load c91 into quad and increment by rs_c. -" ld1 {v8.s}[2],[x5],x14 \n\t" // Load c92 into quad and increment by rs_c. -" ld1 {v8.s}[3],[x5],x14 \n\t" // Load c93 into quad and increment by rs_c. -" ld1 {v9.s}[0],[x5],x14 \n\t" // Load c94 into quad and increment by rs_c. -" ld1 {v9.s}[1],[x5],x14 \n\t" // Load c95 into quad and increment by rs_c. -" ld1 {v9.s}[2],[x5],x14 \n\t" // Load c96 into quad and increment by rs_c. -" ld1 {v9.s}[3],[x5],x14 \n\t" // Load c97 into quad and increment by rs_c. -" \n\t" -" mov x5, x26 \n\t" -" \n\t" -" ld1 {v10.s}[0],[x5],x14 \n\t" // Load c100 into quad and increment by rs_c. -" ld1 {v10.s}[1],[x5],x14 \n\t" // Load c101 into quad and increment by rs_c. -" ld1 {v10.s}[2],[x5],x14 \n\t" // Load c102 into quad and increment by rs_c. -" ld1 {v10.s}[3],[x5],x14 \n\t" // Load c103 into quad and increment by rs_c. -" ld1 {v11.s}[0],[x5],x14 \n\t" // Load c104 into quad and increment by rs_c. -" ld1 {v11.s}[1],[x5],x14 \n\t" // Load c105 into quad and increment by rs_c. -" ld1 {v11.s}[2],[x5],x14 \n\t" // Load c106 into quad and increment by rs_c. -" ld1 {v11.s}[3],[x5],x14 \n\t" // Load c107 into quad and increment by rs_c. -" \n\t" -" mov x5, x27 \n\t" -" \n\t" -" ld1 {v12.s}[0],[x5],x14 \n\t" // Load c110 into quad and increment by rs_c. -" ld1 {v12.s}[1],[x5],x14 \n\t" // Load c111 into quad and increment by rs_c. -" ld1 {v12.s}[2],[x5],x14 \n\t" // Load c112 into quad and increment by rs_c. -" ld1 {v12.s}[3],[x5],x14 \n\t" // Load c113 into quad and increment by rs_c. -" ld1 {v13.s}[0],[x5],x14 \n\t" // Load c114 into quad and increment by rs_c. -" ld1 {v13.s}[1],[x5],x14 \n\t" // Load c115 into quad and increment by rs_c. -" ld1 {v13.s}[2],[x5],x14 \n\t" // Load c116 into quad and increment by rs_c. -" ld1 {v13.s}[3],[x5],x14 \n\t" // Load c117 into quad and increment by rs_c. 
-" \n\t" -" fmul v8.4s, v8.4s, v7.s[0] \n\t" // Scale by beta -" fmul v9.4s, v9.4s, v7.s[0] \n\t" // Scale by beta -" fmul v10.4s,v10.4s,v7.s[0] \n\t" // Scale by beta -" fmul v11.4s,v11.4s,v7.s[0] \n\t" // Scale by beta -" fmul v12.4s,v12.4s,v7.s[0] \n\t" // Scale by beta -" fmul v13.4s,v13.4s,v7.s[0] \n\t" // Scale by beta -" \n\t" -LABEL(SBETAZEROGENSTOREDS4) -" \n\t" -" prfm pldl2keep,[x0] \n\t" -" prfm pldl2keep,[x1] \n\t" -" \n\t" -" fmla v8.4s, v26.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v9.4s, v27.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v10.4s,v28.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v11.4s,v29.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v12.4s,v30.4s,v6.s[0] \n\t" // Scale by alpha -" fmla v13.4s,v31.4s,v6.s[0] \n\t" // Scale by alpha -" \n\t" -" mov x5, x25 \n\t" -" \n\t" -" st1 {v8.s}[0],[x5],x14 \n\t" // Store c90 into quad and increment by rs_c. -" st1 {v8.s}[1],[x5],x14 \n\t" // Store c91 into quad and increment by rs_c. -" st1 {v8.s}[2],[x5],x14 \n\t" // Store c92 into quad and increment by rs_c. -" st1 {v8.s}[3],[x5],x14 \n\t" // Store c93 into quad and increment by rs_c. -" st1 {v9.s}[0],[x5],x14 \n\t" // Store c94 into quad and increment by rs_c. -" st1 {v9.s}[1],[x5],x14 \n\t" // Store c95 into quad and increment by rs_c. -" st1 {v9.s}[2],[x5],x14 \n\t" // Store c96 into quad and increment by rs_c. -" st1 {v9.s}[3],[x5],x14 \n\t" // Store c97 into quad and increment by rs_c. -" \n\t" -" mov x5, x26 \n\t" -" \n\t" -" st1 {v10.s}[0],[x5],x14 \n\t" // Store c100 into quad and increment by rs_c. -" st1 {v10.s}[1],[x5],x14 \n\t" // Store c101 into quad and increment by rs_c. -" st1 {v10.s}[2],[x5],x14 \n\t" // Store c102 into quad and increment by rs_c. -" st1 {v10.s}[3],[x5],x14 \n\t" // Store c103 into quad and increment by rs_c. -" st1 {v11.s}[0],[x5],x14 \n\t" // Store c104 into quad and increment by rs_c. -" st1 {v11.s}[1],[x5],x14 \n\t" // Store c105 into quad and increment by rs_c. -" st1 {v11.s}[2],[x5],x14 \n\t" // Store c106 into quad and increment by rs_c. -" st1 {v11.s}[3],[x5],x14 \n\t" // Store c107 into quad and increment by rs_c. -" \n\t" -" mov x5, x27 \n\t" -" \n\t" -" st1 {v12.s}[0],[x5],x14 \n\t" // Store c110 into quad and increment by rs_c. -" st1 {v12.s}[1],[x5],x14 \n\t" // Store c111 into quad and increment by rs_c. -" st1 {v12.s}[2],[x5],x14 \n\t" // Store c112 into quad and increment by rs_c. -" st1 {v12.s}[3],[x5],x14 \n\t" // Store c113 into quad and increment by rs_c. -" st1 {v13.s}[0],[x5],x14 \n\t" // Store c114 into quad and increment by rs_c. -" st1 {v13.s}[1],[x5],x14 \n\t" // Store c115 into quad and increment by rs_c. -" st1 {v13.s}[2],[x5],x14 \n\t" // Store c116 into quad and increment by rs_c. -" st1 {v13.s}[3],[x5],x14 \n\t" // Store c147 into quad and increment by rs_c. -" \n\t" - LABEL(SEND) // Done! 
- " \n\t" - :// output operands (none) - :// input operands - [aaddr] "m" (a), // 0 - [baddr] "m" (b), // 1 - [caddr] "m" (c), // 2 - [k_iter] "m" (k_iter), // 3 - [k_left] "m" (k_left), // 4 - [alpha] "m" (alpha), // 5 - [beta] "m" (beta), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c), // 8 - [a_next] "m" (a_next), // 9 - [b_next] "m" (b_next) // 10 - :// Register clobber list - "x0", "x1", "x2", - "x5", "x6", "x10","x14", - "x16","x17","x19","x20", - "x21","x22","x23","x24", - "x25","x26","x27", - "v0", "v1", "v2", "v3", - "v4", "v5", "v6", "v7", - "v8", "v9", "v10","v11", - "v12","v13","v14","v15", - "v16","v17","v18","v19", - "v20","v21","v22","v23", - "v24","v25","v26","v27", - "v28","v29","v30","v31" - ); + __asm__ volatile + ( + " \n\t" + " \n\t" + " ldr x0,%[aaddr] \n\t" // Load address of A. + " ldr x1,%[baddr] \n\t" // Load address of B. + " ldr x2,%[caddr] \n\t" // Load address of C. + " \n\t" + " ldr x5,%[k_iter] \n\t" // Number of unrolled iterations (k_iter). + " ldr x6,%[k_left] \n\t" // Number of remaining iterations (k_left). + " \n\t" + " ldr x10,%[cs_c] \n\t" // Load cs_c. + " lsl x10,x10,#2 \n\t" // cs_c * sizeof(float) -- AUX. + " \n\t" + " ldr x14,%[rs_c] \n\t" // Load rs_c. + " lsl x14,x14,#2 \n\t" // rs_c * sizeof(float). + " \n\t" + " add x16,x2,x10 \n\t" //Load address Column 1 of C + " add x17,x16,x10 \n\t" //Load address Column 2 of C + " add x19,x17,x10 \n\t" //Load address Column 3 of C + " add x20,x19,x10 \n\t" //Load address Column 4 of C + " add x21,x20,x10 \n\t" //Load address Column 5 of C + " add x22,x21,x10 \n\t" //Load address Column 6 of C + " add x23,x22,x10 \n\t" //Load address Column 7 of C + " add x24,x23,x10 \n\t" //Load address Column 8 of C + " add x25,x24,x10 \n\t" //Load address Column 9 of C + " add x26,x25,x10 \n\t" //Load address Column 10 of C + " add x27,x26,x10 \n\t" //Load address Column 11 of C + " \n\t" + " prfm pldl1keep,[x2] \n\t" // Prefetch c. + " prfm pldl1keep,[x16] \n\t" // Prefetch c. + " prfm pldl1keep,[x17] \n\t" // Prefetch c. + " prfm pldl1keep,[x19] \n\t" // Prefetch c. + " prfm pldl1keep,[x20] \n\t" // Prefetch c. + " prfm pldl1keep,[x21] \n\t" // Prefetch c. + " prfm pldl1keep,[x22] \n\t" // Prefetch c. + " prfm pldl1keep,[x23] \n\t" // Prefetch c. + " prfm pldl1keep,[x24] \n\t" // Prefetch c. + " prfm pldl1keep,[x25] \n\t" // Prefetch c. + " prfm pldl1keep,[x26] \n\t" // Prefetch c. + " prfm pldl1keep,[x27] \n\t" // Prefetch c. 
+ " \n\t" + " dup v8.4s, wzr \n\t" // Vector for accummulating column 0 + " prfm PLDL1KEEP, [x1, #192] \n\t" + " dup v9.4s, wzr \n\t" // Vector for accummulating column 0 + " prfm PLDL1KEEP, [x1, #256] \n\t" + " dup v10.4s, wzr \n\t" // Vector for accummulating column 1 + " prfm PLDL1KEEP, [x1, #320] \n\t" + " dup v11.4s, wzr \n\t" // Vector for accummulating column 1 + " dup v12.4s, wzr \n\t" // Vector for accummulating column 2 + " dup v13.4s, wzr \n\t" // Vector for accummulating column 2 + " \n\t" + " dup v14.4s, wzr \n\t" // Vector for accummulating column 3 + " prfm PLDL1KEEP, [x0, #128] \n\t" + " dup v15.4s, wzr \n\t" // Vector for accummulating column 3 + " prfm PLDL1KEEP, [x0, #192] \n\t" + " dup v16.4s, wzr \n\t" // Vector for accummulating column 4 + " dup v17.4s, wzr \n\t" // Vector for accummulating column 4 + " dup v18.4s, wzr \n\t" // Vector for accummulating column 5 + " dup v19.4s, wzr \n\t" // Vector for accummulating column 5 + " \n\t" + " dup v20.4s, wzr \n\t" // Vector for accummulating column 6 + " dup v21.4s, wzr \n\t" // Vector for accummulating column 6 + " dup v22.4s, wzr \n\t" // Vector for accummulating column 7 + " dup v23.4s, wzr \n\t" // Vector for accummulating column 7 + " dup v24.4s, wzr \n\t" // Vector for accummulating column 8 + " dup v25.4s, wzr \n\t" // Vector for accummulating column 8 + " \n\t" + " dup v26.4s, wzr \n\t" // Vector for accummulating column 9 + " dup v27.4s, wzr \n\t" // Vector for accummulating column 9 + " dup v28.4s, wzr \n\t" // Vector for accummulating column 10 + " dup v29.4s, wzr \n\t" // Vector for accummulating column 10 + " dup v30.4s, wzr \n\t" // Vector for accummulating column 11 + " dup v31.4s, wzr \n\t" // Vector for accummulating column 11 + " \n\t" + " cmp x5,#0 \n\t" // If k_iter == 0, jump to k_left. + BEQ(SCONSIDERKLEFT) + " \n\t" + " ldr q0, [x0] \n\t" + " ldr q1, [x0, #16] \n\t" // Load a + " \n\t" + " ldr q2, [x1] \n\t" // Load b + " ldr q3, [x1, #16] \n\t" + " ldr q4, [x1, #32] \n\t" + " \n\t" + " add x0, x0, #32 \n\t" //update address of A + " add x1, x1, #48 \n\t" //update address of B + " \n\t" + " cmp x5,1 \n\t" // If there is just one k_iter, jump to that one. + BEQ(SLASTITER) // (as loop is do-while-like). + " \n\t" + LABEL(SLOOPKITER) // Body of the k_iter loop. + " \n\t" + " ldr q5, [x0] \n\t" + " fmla v8.4s, v0.4s,v2.s[0] \n\t" // Accummulate. + " fmla v9.4s, v1.4s,v2.s[0] \n\t" // Accummulate. + " ldr q6, [x0, #16] \n\t" + " fmla v10.4s,v0.4s,v2.s[1] \n\t" // Accummulate. + " fmla v11.4s,v1.4s,v2.s[1] \n\t" // Accummulate. + " fmla v12.4s,v0.4s,v2.s[2] \n\t" // Accummulate. + " fmla v13.4s,v1.4s,v2.s[2] \n\t" // Accummulate. + " fmla v14.4s,v0.4s,v2.s[3] \n\t" // Accummulate. + " fmla v15.4s,v1.4s,v2.s[3] \n\t" // Accummulate. + " ldr q2, [x1] \n\t" + " \n\t" + " fmla v16.4s,v0.4s,v3.s[0] \n\t" // Accummulate. + " prfm PLDL1KEEP, [x1, #336] \n\t" + " fmla v17.4s,v1.4s,v3.s[0] \n\t" // Accummulate. + " prfm PLDL1KEEP, [x1, #400] \n\t" + " fmla v18.4s,v0.4s,v3.s[1] \n\t" // Accummulate. + " fmla v19.4s,v1.4s,v3.s[1] \n\t" // Accummulate. + " prfm PLDL1KEEP, [x1, #464] \n\t" + " fmla v20.4s,v0.4s,v3.s[2] \n\t" // Accummulate. + " fmla v21.4s,v1.4s,v3.s[2] \n\t" // Accummulate. + " fmla v22.4s,v0.4s,v3.s[3] \n\t" // Accummulate. + " fmla v23.4s,v1.4s,v3.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v24.4s,v0.4s,v4.s[0] \n\t" // Accummulate. + " fmla v26.4s,v0.4s,v4.s[1] \n\t" // Accummulate. + " fmla v28.4s,v0.4s,v4.s[2] \n\t" // Accummulate. + " fmla v30.4s,v0.4s,v4.s[3] \n\t" // Accummulate. 
+ " ldr q3, [x1, #16] \n\t" + " \n\t" + " fmla v25.4s,v1.4s,v4.s[0] \n\t" // Accummulate. + " fmla v27.4s,v1.4s,v4.s[1] \n\t" // Accummulate. + " fmla v29.4s,v1.4s,v4.s[2] \n\t" // Accummulate. + " fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. + " ldr q4, [x1, #32] \n\t" + " \n\t" //End It 1 + " \n\t" + " ldr q0, [x0, #32] \n\t" + " fmla v8.4s,v5.4s,v2.s[0] \n\t" // Accummulate. + " fmla v9.4s,v6.4s,v2.s[0] \n\t" // Accummulate. + " ldr q1, [x0, #48] \n\t" + " fmla v10.4s,v5.4s,v2.s[1] \n\t" // Accummulate. + " fmla v11.4s,v6.4s,v2.s[1] \n\t" // Accummulate. + " fmla v12.4s,v5.4s,v2.s[2] \n\t" // Accummulate. + " fmla v13.4s,v6.4s,v2.s[2] \n\t" // Accummulate. + " fmla v14.4s,v5.4s,v2.s[3] \n\t" // Accummulate. + " fmla v15.4s,v6.4s,v2.s[3] \n\t" // Accummulate. + " ldr q2, [x1, #48] \n\t" + " \n\t" + " fmla v16.4s,v5.4s,v3.s[0] \n\t" // Accummulate. + " prfm PLDL1KEEP, [x0, #224] \n\t" + " fmla v17.4s,v6.4s,v3.s[0] \n\t" // Accummulate. + " prfm PLDL1KEEP, [x0, #288] \n\t" + " fmla v18.4s,v5.4s,v3.s[1] \n\t" // Accummulate. + " fmla v19.4s,v6.4s,v3.s[1] \n\t" // Accummulate. + " fmla v20.4s,v5.4s,v3.s[2] \n\t" // Accummulate. + " fmla v21.4s,v6.4s,v3.s[2] \n\t" // Accummulate. + " fmla v22.4s,v5.4s,v3.s[3] \n\t" // Accummulate. + " fmla v23.4s,v6.4s,v3.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v24.4s,v5.4s,v4.s[0] \n\t" // Accummulate. + " fmla v26.4s,v5.4s,v4.s[1] \n\t" // Accummulate. + " fmla v28.4s,v5.4s,v4.s[2] \n\t" // Accummulate. + " fmla v30.4s,v5.4s,v4.s[3] \n\t" // Accummulate. + " ldr q3, [x1, #64] \n\t" + " \n\t" + " fmla v25.4s,v6.4s,v4.s[0] \n\t" // Accummulate. + " fmla v27.4s,v6.4s,v4.s[1] \n\t" // Accummulate. + " fmla v29.4s,v6.4s,v4.s[2] \n\t" // Accummulate. + " fmla v31.4s,v6.4s,v4.s[3] \n\t" // Accummulate. + " ldr q4, [x1, #80] \n\t" + " \n\t" //End It 2 + " \n\t" + " ldr q5, [x0, #64] \n\t" + " fmla v8.4s,v0.4s,v2.s[0] \n\t" // Accummulate. + " fmla v9.4s,v1.4s,v2.s[0] \n\t" // Accummulate. + " ldr q6, [x0, #80] \n\t" + " fmla v10.4s,v0.4s,v2.s[1] \n\t" // Accummulate. + " fmla v11.4s,v1.4s,v2.s[1] \n\t" // Accummulate. + " fmla v12.4s,v0.4s,v2.s[2] \n\t" // Accummulate. + " fmla v13.4s,v1.4s,v2.s[2] \n\t" // Accummulate. + " fmla v14.4s,v0.4s,v2.s[3] \n\t" // Accummulate. + " fmla v15.4s,v1.4s,v2.s[3] \n\t" // Accummulate. + " ldr q2, [x1, #96] \n\t" + " \n\t" + " fmla v16.4s,v0.4s,v3.s[0] \n\t" // Accummulate. + " fmla v17.4s,v1.4s,v3.s[0] \n\t" // Accummulate. + " fmla v18.4s,v0.4s,v3.s[1] \n\t" // Accummulate. + " fmla v19.4s,v1.4s,v3.s[1] \n\t" // Accummulate. + " fmla v20.4s,v0.4s,v3.s[2] \n\t" // Accummulate. + " fmla v21.4s,v1.4s,v3.s[2] \n\t" // Accummulate. + " fmla v22.4s,v0.4s,v3.s[3] \n\t" // Accummulate. + " fmla v23.4s,v1.4s,v3.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v24.4s,v0.4s,v4.s[0] \n\t" // Accummulate. + " fmla v26.4s,v0.4s,v4.s[1] \n\t" // Accummulate. + " fmla v28.4s,v0.4s,v4.s[2] \n\t" // Accummulate. + " fmla v30.4s,v0.4s,v4.s[3] \n\t" // Accummulate. + " ldr q3, [x1, #112] \n\t" + " \n\t" + " fmla v25.4s,v1.4s,v4.s[0] \n\t" // Accummulate. + " fmla v27.4s,v1.4s,v4.s[1] \n\t" // Accummulate. + " fmla v29.4s,v1.4s,v4.s[2] \n\t" // Accummulate. + " fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. + " ldr q4, [x1, #128] \n\t" + " \n\t" //End It 3 + " \n\t" + " ldr q0, [x0, #96] \n\t" + " fmla v8.4s,v5.4s,v2.s[0] \n\t" // Accummulate. + " fmla v9.4s,v6.4s,v2.s[0] \n\t" // Accummulate. + " ldr q1, [x0, #112] \n\t" + " fmla v10.4s,v5.4s,v2.s[1] \n\t" // Accummulate. + " fmla v11.4s,v6.4s,v2.s[1] \n\t" // Accummulate. 
+ " fmla v12.4s,v5.4s,v2.s[2] \n\t" // Accummulate. + " fmla v13.4s,v6.4s,v2.s[2] \n\t" // Accummulate. + " fmla v14.4s,v5.4s,v2.s[3] \n\t" // Accummulate. + " fmla v15.4s,v6.4s,v2.s[3] \n\t" // Accummulate. + " ldr q2, [x1, #144] \n\t" + " \n\t" + " fmla v16.4s,v5.4s,v3.s[0] \n\t" // Accummulate. + " fmla v17.4s,v6.4s,v3.s[0] \n\t" // Accummulate. + " fmla v18.4s,v5.4s,v3.s[1] \n\t" // Accummulate. + " fmla v19.4s,v6.4s,v3.s[1] \n\t" // Accummulate. + " fmla v20.4s,v5.4s,v3.s[2] \n\t" // Accummulate. + " fmla v21.4s,v6.4s,v3.s[2] \n\t" // Accummulate. + " fmla v22.4s,v5.4s,v3.s[3] \n\t" // Accummulate. + " fmla v23.4s,v6.4s,v3.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v24.4s,v5.4s,v4.s[0] \n\t" // Accummulate. + " fmla v26.4s,v5.4s,v4.s[1] \n\t" // Accummulate. + " fmla v28.4s,v5.4s,v4.s[2] \n\t" // Accummulate. + " fmla v30.4s,v5.4s,v4.s[3] \n\t" // Accummulate. + " ldr q3, [x1, #160] \n\t" + " \n\t" + " fmla v25.4s,v6.4s,v4.s[0] \n\t" // Accummulate. + " fmla v27.4s,v6.4s,v4.s[1] \n\t" // Accummulate. + " fmla v29.4s,v6.4s,v4.s[2] \n\t" // Accummulate. + " fmla v31.4s,v6.4s,v4.s[3] \n\t" // Accummulate. + " ldr q4, [x1, #176] \n\t" + " add x1, x1, #192 \n\t" + " add x0, x0, #128 \n\t" + " \n\t" //End It 4 + " sub x5,x5,1 \n\t" // i-=1. + " cmp x5,1 \n\t" // Iterate again if we are not in k_iter == 1. + BNE(SLOOPKITER) + " \n\t" + LABEL(SLASTITER) // Last iteration of k_iter loop. + " \n\t" + " \n\t" + " ldr q5, [x0] \n\t" + " fmla v8.4s,v0.4s,v2.s[0] \n\t" // Accummulate. + " fmla v9.4s,v1.4s,v2.s[0] \n\t" // Accummulate. + " ldr q6, [x0, #16] \n\t" + " fmla v10.4s,v0.4s,v2.s[1] \n\t" // Accummulate. + " fmla v11.4s,v1.4s,v2.s[1] \n\t" // Accummulate. + " fmla v12.4s,v0.4s,v2.s[2] \n\t" // Accummulate. + " fmla v13.4s,v1.4s,v2.s[2] \n\t" // Accummulate. + " fmla v14.4s,v0.4s,v2.s[3] \n\t" // Accummulate. + " fmla v15.4s,v1.4s,v2.s[3] \n\t" // Accummulate. + " ldr q2, [x1] \n\t" + " \n\t" + " fmla v16.4s,v0.4s,v3.s[0] \n\t" // Accummulate. + " fmla v17.4s,v1.4s,v3.s[0] \n\t" // Accummulate. + " fmla v18.4s,v0.4s,v3.s[1] \n\t" // Accummulate. + " fmla v19.4s,v1.4s,v3.s[1] \n\t" // Accummulate. + " fmla v20.4s,v0.4s,v3.s[2] \n\t" // Accummulate. + " fmla v21.4s,v1.4s,v3.s[2] \n\t" // Accummulate. + " fmla v22.4s,v0.4s,v3.s[3] \n\t" // Accummulate. + " fmla v23.4s,v1.4s,v3.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v24.4s,v0.4s,v4.s[0] \n\t" // Accummulate. + " fmla v26.4s,v0.4s,v4.s[1] \n\t" // Accummulate. + " fmla v28.4s,v0.4s,v4.s[2] \n\t" // Accummulate. + " fmla v30.4s,v0.4s,v4.s[3] \n\t" // Accummulate. + " ldr q3, [x1, #16] \n\t" + " \n\t" + " fmla v25.4s,v1.4s,v4.s[0] \n\t" // Accummulate. + " fmla v27.4s,v1.4s,v4.s[1] \n\t" // Accummulate. + " fmla v29.4s,v1.4s,v4.s[2] \n\t" // Accummulate. + " fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. + " ldr q4, [x1, #32] \n\t" + " \n\t" //End It 1 + " \n\t" + " ldr q0, [x0, #32] \n\t" + " fmla v8.4s,v5.4s,v2.s[0] \n\t" // Accummulate. + " fmla v9.4s,v6.4s,v2.s[0] \n\t" // Accummulate. + " ldr q1, [x0, #48] \n\t" + " fmla v10.4s,v5.4s,v2.s[1] \n\t" // Accummulate. + " fmla v11.4s,v6.4s,v2.s[1] \n\t" // Accummulate. + " fmla v12.4s,v5.4s,v2.s[2] \n\t" // Accummulate. + " fmla v13.4s,v6.4s,v2.s[2] \n\t" // Accummulate. + " fmla v14.4s,v5.4s,v2.s[3] \n\t" // Accummulate. + " fmla v15.4s,v6.4s,v2.s[3] \n\t" // Accummulate. + " ldr q2, [x1, #48] \n\t" + " \n\t" + " fmla v16.4s,v5.4s,v3.s[0] \n\t" // Accummulate. + " fmla v17.4s,v6.4s,v3.s[0] \n\t" // Accummulate. + " fmla v18.4s,v5.4s,v3.s[1] \n\t" // Accummulate. 
+ " fmla v19.4s,v6.4s,v3.s[1] \n\t" // Accummulate. + " fmla v20.4s,v5.4s,v3.s[2] \n\t" // Accummulate. + " fmla v21.4s,v6.4s,v3.s[2] \n\t" // Accummulate. + " fmla v22.4s,v5.4s,v3.s[3] \n\t" // Accummulate. + " fmla v23.4s,v6.4s,v3.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v24.4s,v5.4s,v4.s[0] \n\t" // Accummulate. + " fmla v26.4s,v5.4s,v4.s[1] \n\t" // Accummulate. + " fmla v28.4s,v5.4s,v4.s[2] \n\t" // Accummulate. + " fmla v30.4s,v5.4s,v4.s[3] \n\t" // Accummulate. + " ldr q3, [x1, #64] \n\t" + " \n\t" + " fmla v25.4s,v6.4s,v4.s[0] \n\t" // Accummulate. + " fmla v27.4s,v6.4s,v4.s[1] \n\t" // Accummulate. + " fmla v29.4s,v6.4s,v4.s[2] \n\t" // Accummulate. + " fmla v31.4s,v6.4s,v4.s[3] \n\t" // Accummulate. + " ldr q4, [x1, #80] \n\t" + " \n\t" //End It 2 + " \n\t" + " ldr q5, [x0, #64] \n\t" + " fmla v8.4s,v0.4s,v2.s[0] \n\t" // Accummulate. + " fmla v9.4s,v1.4s,v2.s[0] \n\t" // Accummulate. + " ldr q6, [x0, #80] \n\t" + " fmla v10.4s,v0.4s,v2.s[1] \n\t" // Accummulate. + " fmla v11.4s,v1.4s,v2.s[1] \n\t" // Accummulate. + " fmla v12.4s,v0.4s,v2.s[2] \n\t" // Accummulate. + " fmla v13.4s,v1.4s,v2.s[2] \n\t" // Accummulate. + " fmla v14.4s,v0.4s,v2.s[3] \n\t" // Accummulate. + " fmla v15.4s,v1.4s,v2.s[3] \n\t" // Accummulate. + " ldr q2, [x1, #96] \n\t" + " \n\t" + " fmla v16.4s,v0.4s,v3.s[0] \n\t" // Accummulate. + " fmla v17.4s,v1.4s,v3.s[0] \n\t" // Accummulate. + " fmla v18.4s,v0.4s,v3.s[1] \n\t" // Accummulate. + " fmla v19.4s,v1.4s,v3.s[1] \n\t" // Accummulate. + " fmla v20.4s,v0.4s,v3.s[2] \n\t" // Accummulate. + " fmla v21.4s,v1.4s,v3.s[2] \n\t" // Accummulate. + " fmla v22.4s,v0.4s,v3.s[3] \n\t" // Accummulate. + " fmla v23.4s,v1.4s,v3.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v24.4s,v0.4s,v4.s[0] \n\t" // Accummulate. + " fmla v26.4s,v0.4s,v4.s[1] \n\t" // Accummulate. + " fmla v28.4s,v0.4s,v4.s[2] \n\t" // Accummulate. + " fmla v30.4s,v0.4s,v4.s[3] \n\t" // Accummulate. + " ldr q3, [x1, #112] \n\t" + " \n\t" + " fmla v25.4s,v1.4s,v4.s[0] \n\t" // Accummulate. + " fmla v27.4s,v1.4s,v4.s[1] \n\t" // Accummulate. + " fmla v29.4s,v1.4s,v4.s[2] \n\t" // Accummulate. + " fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. + " ldr q4, [x1, #128] \n\t" + " \n\t" //End It 3 + " \n\t" + " fmla v8.4s,v5.4s,v2.s[0] \n\t" // Accummulate. + " fmla v9.4s,v6.4s,v2.s[0] \n\t" // Accummulate. + " fmla v10.4s,v5.4s,v2.s[1] \n\t" // Accummulate. + " fmla v11.4s,v6.4s,v2.s[1] \n\t" // Accummulate. + " fmla v12.4s,v5.4s,v2.s[2] \n\t" // Accummulate. + " fmla v13.4s,v6.4s,v2.s[2] \n\t" // Accummulate. + " fmla v14.4s,v5.4s,v2.s[3] \n\t" // Accummulate. + " fmla v15.4s,v6.4s,v2.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v16.4s,v5.4s,v3.s[0] \n\t" // Accummulate. + " fmla v17.4s,v6.4s,v3.s[0] \n\t" // Accummulate. + " fmla v18.4s,v5.4s,v3.s[1] \n\t" // Accummulate. + " fmla v19.4s,v6.4s,v3.s[1] \n\t" // Accummulate. + " fmla v20.4s,v5.4s,v3.s[2] \n\t" // Accummulate. + " fmla v21.4s,v6.4s,v3.s[2] \n\t" // Accummulate. + " fmla v22.4s,v5.4s,v3.s[3] \n\t" // Accummulate. + " fmla v23.4s,v6.4s,v3.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v24.4s,v5.4s,v4.s[0] \n\t" // Accummulate. + " fmla v26.4s,v5.4s,v4.s[1] \n\t" // Accummulate. + " fmla v28.4s,v5.4s,v4.s[2] \n\t" // Accummulate. + " fmla v30.4s,v5.4s,v4.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v25.4s,v6.4s,v4.s[0] \n\t" // Accummulate. + " fmla v27.4s,v6.4s,v4.s[1] \n\t" // Accummulate. + " fmla v29.4s,v6.4s,v4.s[2] \n\t" // Accummulate. + " fmla v31.4s,v6.4s,v4.s[3] \n\t" // Accummulate. 
+ " add x1, x1, #144 \n\t" + " add x0, x0, #96 \n\t" + " \n\t" //End It 4 + " \n\t" + LABEL(SCONSIDERKLEFT) + " cmp x6,0 \n\t" // If k_left == 0, we are done. + BEQ(SPOSTACCUM) // else, we enter the k_left loop. + " \n\t" + LABEL(SLOOPKLEFT) // Body of the left iterations + " \n\t" + " ldr q0, [x0],#16 \n\t" + " ldr q1, [x0],#16 \n\t" // Load a + " \n\t" + " ldr q2, [x1],#16 \n\t" // Load b + " ldr q3, [x1],#16 \n\t" + " ldr q4, [x1],#16 \n\t" + " \n\t" + " sub x6,x6,1 \n\t" // i = i-1. + " \n\t" + " fmla v8.4s,v0.4s,v2.s[0] \n\t" // Accummulate. + " fmla v9.4s,v1.4s,v2.s[0] \n\t" // Accummulate. + " fmla v10.4s,v0.4s,v2.s[1] \n\t" // Accummulate. + " fmla v11.4s,v1.4s,v2.s[1] \n\t" // Accummulate. + " fmla v12.4s,v0.4s,v2.s[2] \n\t" // Accummulate. + " fmla v13.4s,v1.4s,v2.s[2] \n\t" // Accummulate. + " fmla v14.4s,v0.4s,v2.s[3] \n\t" // Accummulate. + " fmla v15.4s,v1.4s,v2.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v16.4s,v0.4s,v3.s[0] \n\t" // Accummulate. + " fmla v17.4s,v1.4s,v3.s[0] \n\t" // Accummulate. + " fmla v18.4s,v0.4s,v3.s[1] \n\t" // Accummulate. + " fmla v19.4s,v1.4s,v3.s[1] \n\t" // Accummulate. + " fmla v20.4s,v0.4s,v3.s[2] \n\t" // Accummulate. + " fmla v21.4s,v1.4s,v3.s[2] \n\t" // Accummulate. + " fmla v22.4s,v0.4s,v3.s[3] \n\t" // Accummulate. + " fmla v23.4s,v1.4s,v3.s[3] \n\t" // Accummulate. + " \n\t" + " fmla v24.4s,v0.4s,v4.s[0] \n\t" // Accummulate. + " fmla v26.4s,v0.4s,v4.s[1] \n\t" // Accummulate. + " fmla v28.4s,v0.4s,v4.s[2] \n\t" // Accummulate. + " fmla v30.4s,v0.4s,v4.s[3] \n\t" // Accummulate. + " fmla v25.4s,v1.4s,v4.s[0] \n\t" // Accummulate. + " fmla v27.4s,v1.4s,v4.s[1] \n\t" // Accummulate. + " fmla v29.4s,v1.4s,v4.s[2] \n\t" // Accummulate. + " fmla v31.4s,v1.4s,v4.s[3] \n\t" // Accummulate. + " \n\t" + " cmp x6,0 \n\t" // Iterate again. + BNE(SLOOPKLEFT) // if i!=0. + " \n\t" + LABEL(SPOSTACCUM) + " \n\t" + " ldr x0,%[alpha] \n\t" // Alpha address. + " ldr x1,%[beta] \n\t" // Beta address. + " \n\t" + " ld1r {v6.4s},[x0] \n\t" // Load alpha. + " ld1r {v7.4s},[x1] \n\t" // Load beta + " \n\t" + " ldr x0,%[a_next] \n\t" // Pointer to next block of A. + " ldr x1,%[b_next] \n\t" // Pointer to next pointer of B. + " \n\t" + " cmp x14,#4 \n\t" // If rs_c != 1 (column-major) + BNE(SGENSTORED) + " \n\t" + LABEL(SCOLSTORED) // C is column-major. + " \n\t" + " dup v0.4s, wzr \n\t" + " dup v1.4s, wzr \n\t" + " dup v2.4s, wzr \n\t" + " dup v3.4s, wzr \n\t" + " dup v4.4s, wzr \n\t" + " dup v5.4s, wzr \n\t" + " \n\t" + " fcmp s7,#0.0 \n\t" + BEQ(SBETAZEROCOLSTOREDS1) // Taking care of the beta==0 case. 
+ " \n\t" + " ldr q0, [x2] \n\t" //Load column 0 of C + " ldr q1, [x2, #16] \n\t" + " ldr q2, [x16] \n\t" //Load column 1 of C + " ldr q3, [x16, #16] \n\t" + " ldr q4, [x17] \n\t" //Load column 2 of C + " ldr q5, [x17, #16] \n\t" + " \n\t" + " fmul v0.4s,v0.4s,v7.s[0] \n\t" // Scale by beta + " fmul v1.4s,v1.4s,v7.s[0] \n\t" // Scale by beta + " fmul v2.4s,v2.4s,v7.s[0] \n\t" // Scale by beta + " fmul v3.4s,v3.4s,v7.s[0] \n\t" // Scale by beta + " fmul v4.4s,v4.4s,v7.s[0] \n\t" // Scale by beta + " fmul v5.4s,v5.4s,v7.s[0] \n\t" // Scale by beta + " \n\t" + LABEL(SBETAZEROCOLSTOREDS1) + " \n\t" + " fmla v0.4s,v8.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v1.4s,v9.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v2.4s,v10.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v3.4s,v11.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v4.4s,v12.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v5.4s,v13.4s,v6.s[0] \n\t" // Scale by alpha + " \n\t" + " str q0, [x2] \n\t" //Store column 0 of C + " str q1, [x2, #16] \n\t" + " str q2, [x16] \n\t" //Store column 1 of C + " str q3, [x16, #16] \n\t" + " str q4, [x17] \n\t" //Store column 2 of C + " str q5, [x17, #16] \n\t" + " \n\t" + " dup v8.4s, wzr \n\t" + " dup v9.4s, wzr \n\t" + " dup v10.4s, wzr \n\t" + " dup v11.4s, wzr \n\t" + " dup v12.4s, wzr \n\t" + " dup v13.4s, wzr \n\t" + " \n\t" + " fcmp s7,#0.0 \n\t" + BEQ(SBETAZEROCOLSTOREDS2) // Taking care of the beta==0 case. + " \n\t" + " ldr q8, [x19] \n\t" //Load column 3 of C + " ldr q9, [x19, #16] \n\t" + " ldr q10, [x20] \n\t" //Load column 4 of C + " ldr q11, [x20, #16] \n\t" + " ldr q12, [x21] \n\t" //Load column 5 of C + " ldr q13, [x21, #16] \n\t" + " \n\t" + " fmul v8.4s, v8.4s, v7.s[0] \n\t" // Scale by beta + " fmul v9.4s, v9.4s, v7.s[0] \n\t" // Scale by beta + " fmul v10.4s,v10.4s,v7.s[0] \n\t" // Scale by beta + " fmul v11.4s,v11.4s,v7.s[0] \n\t" // Scale by beta + " fmul v12.4s,v12.4s,v7.s[0] \n\t" // Scale by beta + " fmul v13.4s,v13.4s,v7.s[0] \n\t" // Scale by beta + " \n\t" + LABEL(SBETAZEROCOLSTOREDS2) + " \n\t" + " fmla v8.4s, v14.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v9.4s, v15.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v10.4s,v16.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v11.4s,v17.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v12.4s,v18.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v13.4s,v19.4s,v6.s[0] \n\t" // Scale by alpha + " \n\t" + " str q8, [x19] \n\t" //Store column 3 of C + " str q9, [x19, #16] \n\t" + " str q10, [x20] \n\t" //Store column 4 of C + " str q11, [x20, #16] \n\t" + " str q12, [x21] \n\t" //Store column 5 of C + " str q13, [x21, #16] \n\t" + " \n\t" + " dup v0.4s, wzr \n\t" + " dup v1.4s, wzr \n\t" + " dup v2.4s, wzr \n\t" + " dup v3.4s, wzr \n\t" + " dup v4.4s, wzr \n\t" + " dup v5.4s, wzr \n\t" + " \n\t" + " fcmp s7,#0.0 \n\t" + BEQ(SBETAZEROCOLSTOREDS3) // Taking care of the beta==0 case. 
+ " \n\t" + " ldr q0, [x22] \n\t" //Load column 6 of C + " ldr q1, [x22, #16] \n\t" + " ldr q2, [x23] \n\t" //Load column 7 of C + " ldr q3, [x23, #16] \n\t" + " ldr q4, [x24] \n\t" //Load column 8 of C + " ldr q5, [x24, #16] \n\t" + " \n\t" + " fmul v0.4s,v0.4s,v7.s[0] \n\t" // Scale by beta + " fmul v1.4s,v1.4s,v7.s[0] \n\t" // Scale by beta + " fmul v2.4s,v2.4s,v7.s[0] \n\t" // Scale by beta + " fmul v3.4s,v3.4s,v7.s[0] \n\t" // Scale by beta + " fmul v4.4s,v4.4s,v7.s[0] \n\t" // Scale by beta + " fmul v5.4s,v5.4s,v7.s[0] \n\t" // Scale by beta + " \n\t" + LABEL(SBETAZEROCOLSTOREDS3) + " \n\t" + " fmla v0.4s,v20.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v1.4s,v21.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v2.4s,v22.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v3.4s,v23.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v4.4s,v24.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v5.4s,v25.4s,v6.s[0] \n\t" // Scale by alpha + " \n\t" + " str q0, [x22] \n\t" //Store column 6 of C + " str q1, [x22, #16] \n\t" + " str q2, [x23] \n\t" //Store column 7 of C + " str q3, [x23, #16] \n\t" + " str q4, [x24] \n\t" //Store column 8 of C + " str q5, [x24, #16] \n\t" + " \n\t" + " dup v8.4s, wzr \n\t" + " dup v9.4s, wzr \n\t" + " dup v10.4s, wzr \n\t" + " dup v11.4s, wzr \n\t" + " dup v12.4s, wzr \n\t" + " dup v13.4s, wzr \n\t" + " \n\t" + " fcmp s7,#0.0 \n\t" + BEQ(SBETAZEROCOLSTOREDS4) // Taking care of the beta==0 case. + " \n\t" + " ldr q8, [x25] \n\t" //Load column 9 of C + " ldr q9, [x25, #16] \n\t" + " ldr q10, [x26] \n\t" //Load column 10 of C + " ldr q11, [x26, #16] \n\t" + " ldr q12, [x27] \n\t" //Load column 11 of C + " ldr q13, [x27, #16] \n\t" + " \n\t" + " fmul v8.4s, v8.4s, v7.s[0] \n\t" // Scale by beta + " fmul v9.4s, v9.4s, v7.s[0] \n\t" // Scale by beta + " fmul v10.4s,v10.4s,v7.s[0] \n\t" // Scale by beta + " fmul v11.4s,v11.4s,v7.s[0] \n\t" // Scale by beta + " fmul v12.4s,v12.4s,v7.s[0] \n\t" // Scale by beta + " fmul v13.4s,v13.4s,v7.s[0] \n\t" // Scale by beta + " \n\t" + LABEL(SBETAZEROCOLSTOREDS4) + " \n\t" + " prfm pldl2keep,[x0] \n\t" + " prfm pldl2keep,[x1] \n\t" + " \n\t" + " fmla v8.4s, v26.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v9.4s, v27.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v10.4s,v28.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v11.4s,v29.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v12.4s,v30.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v13.4s,v31.4s,v6.s[0] \n\t" // Scale by alpha + " \n\t" + " str q8, [x25] \n\t" //Store column 9 of C + " str q9, [x25, #16] \n\t" + " str q10, [x26] \n\t" //Store column 10 of C + " str q11, [x26, #16] \n\t" + " str q12, [x27] \n\t" //Store column 11 of C + " str q13, [x27, #16] \n\t" + " \n\t" + " \n\t" + BRANCH(SEND) // Done. + " \n\t" + " \n\t" + LABEL(SGENSTORED) // C is general-stride stored. + " \n\t" + " \n\t" + " dup v0.4s, wzr \n\t" + " dup v1.4s, wzr \n\t" + " dup v2.4s, wzr \n\t" + " dup v3.4s, wzr \n\t" + " dup v4.4s, wzr \n\t" + " dup v5.4s, wzr \n\t" + " \n\t" + " fcmp s7,#0.0 \n\t" + BEQ(SBETAZEROGENSTOREDS1) // Taking care of the beta==0 case. + " \n\t" + " mov x5, x2 \n\t" + " \n\t" + " ld1 {v0.s}[0],[x5],x14 \n\t" // Load c00 into quad and increment by rs_c. + " ld1 {v0.s}[1],[x5],x14 \n\t" // Load c01 into quad and increment by rs_c. + " ld1 {v0.s}[2],[x5],x14 \n\t" // Load c02 into quad and increment by rs_c. + " ld1 {v0.s}[3],[x5],x14 \n\t" // Load c03 into quad and increment by rs_c. + " ld1 {v1.s}[0],[x5],x14 \n\t" // Load c04 into quad and increment by rs_c. 
+ " ld1 {v1.s}[1],[x5],x14 \n\t" // Load c05 into quad and increment by rs_c. + " ld1 {v1.s}[2],[x5],x14 \n\t" // Load c06 into quad and increment by rs_c. + " ld1 {v1.s}[3],[x5],x14 \n\t" // Load c07 into quad and increment by rs_c. + " \n\t" + " mov x5, x16 \n\t" + " \n\t" + " ld1 {v2.s}[0],[x5],x14 \n\t" // Load c10 into quad and increment by rs_c. + " ld1 {v2.s}[1],[x5],x14 \n\t" // Load c11 into quad and increment by rs_c. + " ld1 {v2.s}[2],[x5],x14 \n\t" // Load c12 into quad and increment by rs_c. + " ld1 {v2.s}[3],[x5],x14 \n\t" // Load c13 into quad and increment by rs_c. + " ld1 {v3.s}[0],[x5],x14 \n\t" // Load c14 into quad and increment by rs_c. + " ld1 {v3.s}[1],[x5],x14 \n\t" // Load c15 into quad and increment by rs_c. + " ld1 {v3.s}[2],[x5],x14 \n\t" // Load c16 into quad and increment by rs_c. + " ld1 {v3.s}[3],[x5],x14 \n\t" // Load c17 into quad and increment by rs_c. + " \n\t" + " mov x5, x17 \n\t" + " \n\t" + " ld1 {v4.s}[0],[x5],x14 \n\t" // Load c20 into quad and increment by rs_c. + " ld1 {v4.s}[1],[x5],x14 \n\t" // Load c21 into quad and increment by rs_c. + " ld1 {v4.s}[2],[x5],x14 \n\t" // Load c22 into quad and increment by rs_c. + " ld1 {v4.s}[3],[x5],x14 \n\t" // Load c23 into quad and increment by rs_c. + " ld1 {v5.s}[0],[x5],x14 \n\t" // Load c24 into quad and increment by rs_c. + " ld1 {v5.s}[1],[x5],x14 \n\t" // Load c25 into quad and increment by rs_c. + " ld1 {v5.s}[2],[x5],x14 \n\t" // Load c26 into quad and increment by rs_c. + " ld1 {v5.s}[3],[x5],x14 \n\t" // Load c27 into quad and increment by rs_c. + " \n\t" + " fmul v0.4s,v0.4s,v7.s[0] \n\t" // Scale by beta + " fmul v1.4s,v1.4s,v7.s[0] \n\t" // Scale by beta + " fmul v2.4s,v2.4s,v7.s[0] \n\t" // Scale by beta + " fmul v3.4s,v3.4s,v7.s[0] \n\t" // Scale by beta + " fmul v4.4s,v4.4s,v7.s[0] \n\t" // Scale by beta + " fmul v5.4s,v5.4s,v7.s[0] \n\t" // Scale by beta + " \n\t" + LABEL(SBETAZEROGENSTOREDS1) + " \n\t" + " fmla v0.4s, v8.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v1.4s, v9.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v2.4s,v10.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v3.4s,v11.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v4.4s,v12.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v5.4s,v13.4s,v6.s[0] \n\t" // Scale by alpha + " \n\t" + " mov x5, x2 \n\t" + " \n\t" + " st1 {v0.s}[0],[x5],x14 \n\t" // Store c00 into quad and increment by rs_c. + " st1 {v0.s}[1],[x5],x14 \n\t" // Store c01 into quad and increment by rs_c. + " st1 {v0.s}[2],[x5],x14 \n\t" // Store c02 into quad and increment by rs_c. + " st1 {v0.s}[3],[x5],x14 \n\t" // Store c03 into quad and increment by rs_c. + " st1 {v1.s}[0],[x5],x14 \n\t" // Store c04 into quad and increment by rs_c. + " st1 {v1.s}[1],[x5],x14 \n\t" // Store c05 into quad and increment by rs_c. + " st1 {v1.s}[2],[x5],x14 \n\t" // Store c06 into quad and increment by rs_c. + " st1 {v1.s}[3],[x5],x14 \n\t" // Store c07 into quad and increment by rs_c. + " \n\t" + " mov x5, x16 \n\t" + " \n\t" + " st1 {v2.s}[0],[x5],x14 \n\t" // Store c10 into quad and increment by rs_c. + " st1 {v2.s}[1],[x5],x14 \n\t" // Store c11 into quad and increment by rs_c. + " st1 {v2.s}[2],[x5],x14 \n\t" // Store c12 into quad and increment by rs_c. + " st1 {v2.s}[3],[x5],x14 \n\t" // Store c13 into quad and increment by rs_c. + " st1 {v3.s}[0],[x5],x14 \n\t" // Store c14 into quad and increment by rs_c. + " st1 {v3.s}[1],[x5],x14 \n\t" // Store c15 into quad and increment by rs_c. + " st1 {v3.s}[2],[x5],x14 \n\t" // Store c16 into quad and increment by rs_c. 
+ " st1 {v3.s}[3],[x5],x14 \n\t" // Store c17 into quad and increment by rs_c. + " \n\t" + " mov x5, x17 \n\t" + " \n\t" + " st1 {v4.s}[0],[x5],x14 \n\t" // Store c20 into quad and increment by rs_c. + " st1 {v4.s}[1],[x5],x14 \n\t" // Store c21 into quad and increment by rs_c. + " st1 {v4.s}[2],[x5],x14 \n\t" // Store c22 into quad and increment by rs_c. + " st1 {v4.s}[3],[x5],x14 \n\t" // Store c23 into quad and increment by rs_c. + " st1 {v5.s}[0],[x5],x14 \n\t" // Store c24 into quad and increment by rs_c. + " st1 {v5.s}[1],[x5],x14 \n\t" // Store c25 into quad and increment by rs_c. + " st1 {v5.s}[2],[x5],x14 \n\t" // Store c26 into quad and increment by rs_c. + " st1 {v5.s}[3],[x5],x14 \n\t" // Store c27 into quad and increment by rs_c. + " \n\t" + " dup v8.4s, wzr \n\t" + " dup v9.4s, wzr \n\t" + " dup v10.4s, wzr \n\t" + " dup v11.4s, wzr \n\t" + " dup v12.4s, wzr \n\t" + " dup v13.4s, wzr \n\t" + " \n\t" + " fcmp s7,#0.0 \n\t" + BEQ(SBETAZEROGENSTOREDS2) // Taking care of the beta==0 case. + " \n\t" + " mov x5, x19 \n\t" + " \n\t" + " ld1 {v8.s}[0],[x5],x14 \n\t" // Load c30 into quad and increment by rs_c. + " ld1 {v8.s}[1],[x5],x14 \n\t" // Load c31 into quad and increment by rs_c. + " ld1 {v8.s}[2],[x5],x14 \n\t" // Load c32 into quad and increment by rs_c. + " ld1 {v8.s}[3],[x5],x14 \n\t" // Load c33 into quad and increment by rs_c. + " ld1 {v9.s}[0],[x5],x14 \n\t" // Load c34 into quad and increment by rs_c. + " ld1 {v9.s}[1],[x5],x14 \n\t" // Load c35 into quad and increment by rs_c. + " ld1 {v9.s}[2],[x5],x14 \n\t" // Load c36 into quad and increment by rs_c. + " ld1 {v9.s}[3],[x5],x14 \n\t" // Load c37 into quad and increment by rs_c. + " \n\t" + " mov x5, x20 \n\t" + " \n\t" + " ld1 {v10.s}[0],[x5],x14 \n\t" // Load c40 into quad and increment by rs_c. + " ld1 {v10.s}[1],[x5],x14 \n\t" // Load c41 into quad and increment by rs_c. + " ld1 {v10.s}[2],[x5],x14 \n\t" // Load c42 into quad and increment by rs_c. + " ld1 {v10.s}[3],[x5],x14 \n\t" // Load c43 into quad and increment by rs_c. + " ld1 {v11.s}[0],[x5],x14 \n\t" // Load c44 into quad and increment by rs_c. + " ld1 {v11.s}[1],[x5],x14 \n\t" // Load c45 into quad and increment by rs_c. + " ld1 {v11.s}[2],[x5],x14 \n\t" // Load c46 into quad and increment by rs_c. + " ld1 {v11.s}[3],[x5],x14 \n\t" // Load c47 into quad and increment by rs_c. + " \n\t" + " mov x5, x21 \n\t" + " \n\t" + " ld1 {v12.s}[0],[x5],x14 \n\t" // Load c50 into quad and increment by rs_c. + " ld1 {v12.s}[1],[x5],x14 \n\t" // Load c51 into quad and increment by rs_c. + " ld1 {v12.s}[2],[x5],x14 \n\t" // Load c52 into quad and increment by rs_c. + " ld1 {v12.s}[3],[x5],x14 \n\t" // Load c53 into quad and increment by rs_c. + " ld1 {v13.s}[0],[x5],x14 \n\t" // Load c54 into quad and increment by rs_c. + " ld1 {v13.s}[1],[x5],x14 \n\t" // Load c55 into quad and increment by rs_c. + " ld1 {v13.s}[2],[x5],x14 \n\t" // Load c56 into quad and increment by rs_c. + " ld1 {v13.s}[3],[x5],x14 \n\t" // Load c57 into quad and increment by rs_c. 
+ " \n\t" + " fmul v8.4s, v8.4s, v7.s[0] \n\t" // Scale by beta + " fmul v9.4s, v9.4s, v7.s[0] \n\t" // Scale by beta + " fmul v10.4s,v10.4s,v7.s[0] \n\t" // Scale by beta + " fmul v11.4s,v11.4s,v7.s[0] \n\t" // Scale by beta + " fmul v12.4s,v12.4s,v7.s[0] \n\t" // Scale by beta + " fmul v13.4s,v13.4s,v7.s[0] \n\t" // Scale by beta + " \n\t" + LABEL(SBETAZEROGENSTOREDS2) + " \n\t" + " fmla v8.4s, v14.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v9.4s, v15.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v10.4s,v16.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v11.4s,v17.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v12.4s,v18.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v13.4s,v19.4s,v6.s[0] \n\t" // Scale by alpha + " \n\t" + " mov x5, x19 \n\t" + " \n\t" + " st1 {v8.s}[0],[x5],x14 \n\t" // Store c30 into quad and increment by rs_c. + " st1 {v8.s}[1],[x5],x14 \n\t" // Store c31 into quad and increment by rs_c. + " st1 {v8.s}[2],[x5],x14 \n\t" // Store c32 into quad and increment by rs_c. + " st1 {v8.s}[3],[x5],x14 \n\t" // Store c33 into quad and increment by rs_c. + " st1 {v9.s}[0],[x5],x14 \n\t" // Store c34 into quad and increment by rs_c. + " st1 {v9.s}[1],[x5],x14 \n\t" // Store c35 into quad and increment by rs_c. + " st1 {v9.s}[2],[x5],x14 \n\t" // Store c36 into quad and increment by rs_c. + " st1 {v9.s}[3],[x5],x14 \n\t" // Store c37 into quad and increment by rs_c. + " \n\t" + " mov x5, x20 \n\t" + " \n\t" + " st1 {v10.s}[0],[x5],x14 \n\t" // Store c40 into quad and increment by rs_c. + " st1 {v10.s}[1],[x5],x14 \n\t" // Store c41 into quad and increment by rs_c. + " st1 {v10.s}[2],[x5],x14 \n\t" // Store c42 into quad and increment by rs_c. + " st1 {v10.s}[3],[x5],x14 \n\t" // Store c43 into quad and increment by rs_c. + " st1 {v11.s}[0],[x5],x14 \n\t" // Store c44 into quad and increment by rs_c. + " st1 {v11.s}[1],[x5],x14 \n\t" // Store c45 into quad and increment by rs_c. + " st1 {v11.s}[2],[x5],x14 \n\t" // Store c46 into quad and increment by rs_c. + " st1 {v11.s}[3],[x5],x14 \n\t" // Store c47 into quad and increment by rs_c. + " \n\t" + " mov x5, x21 \n\t" + " \n\t" + " st1 {v12.s}[0],[x5],x14 \n\t" // Store c50 into quad and increment by rs_c. + " st1 {v12.s}[1],[x5],x14 \n\t" // Store c51 into quad and increment by rs_c. + " st1 {v12.s}[2],[x5],x14 \n\t" // Store c52 into quad and increment by rs_c. + " st1 {v12.s}[3],[x5],x14 \n\t" // Store c53 into quad and increment by rs_c. + " st1 {v13.s}[0],[x5],x14 \n\t" // Store c54 into quad and increment by rs_c. + " st1 {v13.s}[1],[x5],x14 \n\t" // Store c55 into quad and increment by rs_c. + " st1 {v13.s}[2],[x5],x14 \n\t" // Store c56 into quad and increment by rs_c. + " st1 {v13.s}[3],[x5],x14 \n\t" // Store c57 into quad and increment by rs_c. + " \n\t" + " dup v0.4s, wzr \n\t" + " dup v1.4s, wzr \n\t" + " dup v2.4s, wzr \n\t" + " dup v3.4s, wzr \n\t" + " dup v4.4s, wzr \n\t" + " dup v5.4s, wzr \n\t" + " \n\t" + " fcmp s7,#0.0 \n\t" + BEQ(SBETAZEROGENSTOREDS3) // Taking care of the beta==0 case. + " \n\t" + " mov x5, x22 \n\t" + " \n\t" + " ld1 {v0.s}[0],[x5],x14 \n\t" // Load c60 into quad and increment by rs_c. + " ld1 {v0.s}[1],[x5],x14 \n\t" // Load c61 into quad and increment by rs_c. + " ld1 {v0.s}[2],[x5],x14 \n\t" // Load c62 into quad and increment by rs_c. + " ld1 {v0.s}[3],[x5],x14 \n\t" // Load c63 into quad and increment by rs_c. + " ld1 {v1.s}[0],[x5],x14 \n\t" // Load c64 into quad and increment by rs_c. + " ld1 {v1.s}[1],[x5],x14 \n\t" // Load c65 into quad and increment by rs_c. 
+ " ld1 {v1.s}[2],[x5],x14 \n\t" // Load c66 into quad and increment by rs_c. + " ld1 {v1.s}[3],[x5],x14 \n\t" // Load c67 into quad and increment by rs_c. + " \n\t" + " mov x5, x23 \n\t" + " \n\t" + " ld1 {v2.s}[0],[x5],x14 \n\t" // Load c70 into quad and increment by rs_c. + " ld1 {v2.s}[1],[x5],x14 \n\t" // Load c71 into quad and increment by rs_c. + " ld1 {v2.s}[2],[x5],x14 \n\t" // Load c72 into quad and increment by rs_c. + " ld1 {v2.s}[3],[x5],x14 \n\t" // Load c73 into quad and increment by rs_c. + " ld1 {v3.s}[0],[x5],x14 \n\t" // Load c74 into quad and increment by rs_c. + " ld1 {v3.s}[1],[x5],x14 \n\t" // Load c75 into quad and increment by rs_c. + " ld1 {v3.s}[2],[x5],x14 \n\t" // Load c76 into quad and increment by rs_c. + " ld1 {v3.s}[3],[x5],x14 \n\t" // Load c77 into quad and increment by rs_c. + " \n\t" + " mov x5, x24 \n\t" + " \n\t" + " ld1 {v4.s}[0],[x5],x14 \n\t" // Load c80 into quad and increment by rs_c. + " ld1 {v4.s}[1],[x5],x14 \n\t" // Load c81 into quad and increment by rs_c. + " ld1 {v4.s}[2],[x5],x14 \n\t" // Load c82 into quad and increment by rs_c. + " ld1 {v4.s}[3],[x5],x14 \n\t" // Load c83 into quad and increment by rs_c. + " ld1 {v5.s}[0],[x5],x14 \n\t" // Load c84 into quad and increment by rs_c. + " ld1 {v5.s}[1],[x5],x14 \n\t" // Load c85 into quad and increment by rs_c. + " ld1 {v5.s}[2],[x5],x14 \n\t" // Load c86 into quad and increment by rs_c. + " ld1 {v5.s}[3],[x5],x14 \n\t" // Load c87 into quad and increment by rs_c. + " \n\t" + " fmul v0.4s,v0.4s,v7.s[0] \n\t" // Scale by beta + " fmul v1.4s,v1.4s,v7.s[0] \n\t" // Scale by beta + " fmul v2.4s,v2.4s,v7.s[0] \n\t" // Scale by beta + " fmul v3.4s,v3.4s,v7.s[0] \n\t" // Scale by beta + " fmul v4.4s,v4.4s,v7.s[0] \n\t" // Scale by beta + " fmul v5.4s,v5.4s,v7.s[0] \n\t" // Scale by beta + " \n\t" + LABEL(SBETAZEROGENSTOREDS3) + " \n\t" + " fmla v0.4s,v20.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v1.4s,v21.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v2.4s,v22.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v3.4s,v23.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v4.4s,v24.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v5.4s,v25.4s,v6.s[0] \n\t" // Scale by alpha + " \n\t" + " mov x5, x22 \n\t" + " \n\t" + " st1 {v0.s}[0],[x5],x14 \n\t" // Store c60 into quad and increment by rs_c. + " st1 {v0.s}[1],[x5],x14 \n\t" // Store c61 into quad and increment by rs_c. + " st1 {v0.s}[2],[x5],x14 \n\t" // Store c62 into quad and increment by rs_c. + " st1 {v0.s}[3],[x5],x14 \n\t" // Store c63 into quad and increment by rs_c. + " st1 {v1.s}[0],[x5],x14 \n\t" // Store c64 into quad and increment by rs_c. + " st1 {v1.s}[1],[x5],x14 \n\t" // Store c65 into quad and increment by rs_c. + " st1 {v1.s}[2],[x5],x14 \n\t" // Store c66 into quad and increment by rs_c. + " st1 {v1.s}[3],[x5],x14 \n\t" // Store c67 into quad and increment by rs_c. + " \n\t" + " mov x5, x23 \n\t" + " \n\t" + " st1 {v2.s}[0],[x5],x14 \n\t" // Store c70 into quad and increment by rs_c. + " st1 {v2.s}[1],[x5],x14 \n\t" // Store c71 into quad and increment by rs_c. + " st1 {v2.s}[2],[x5],x14 \n\t" // Store c72 into quad and increment by rs_c. + " st1 {v2.s}[3],[x5],x14 \n\t" // Store c73 into quad and increment by rs_c. + " st1 {v3.s}[0],[x5],x14 \n\t" // Store c74 into quad and increment by rs_c. + " st1 {v3.s}[1],[x5],x14 \n\t" // Store c75 into quad and increment by rs_c. + " st1 {v3.s}[2],[x5],x14 \n\t" // Store c76 into quad and increment by rs_c. + " st1 {v3.s}[3],[x5],x14 \n\t" // Store c77 into quad and increment by rs_c. 
+ " \n\t" + " mov x5, x24 \n\t" + " \n\t" + " st1 {v4.s}[0],[x5],x14 \n\t" // Store c80 into quad and increment by rs_c. + " st1 {v4.s}[1],[x5],x14 \n\t" // Store c81 into quad and increment by rs_c. + " st1 {v4.s}[2],[x5],x14 \n\t" // Store c82 into quad and increment by rs_c. + " st1 {v4.s}[3],[x5],x14 \n\t" // Store c83 into quad and increment by rs_c. + " st1 {v5.s}[0],[x5],x14 \n\t" // Store c84 into quad and increment by rs_c. + " st1 {v5.s}[1],[x5],x14 \n\t" // Store c85 into quad and increment by rs_c. + " st1 {v5.s}[2],[x5],x14 \n\t" // Store c86 into quad and increment by rs_c. + " st1 {v5.s}[3],[x5],x14 \n\t" // Store c87 into quad and increment by rs_c. + " \n\t" + " dup v8.4s, wzr \n\t" + " dup v9.4s, wzr \n\t" + " dup v10.4s, wzr \n\t" + " dup v11.4s, wzr \n\t" + " dup v12.4s, wzr \n\t" + " dup v13.4s, wzr \n\t" + " \n\t" + " fcmp s7,#0.0 \n\t" + BEQ(SBETAZEROGENSTOREDS4) // Taking care of the beta==0 case. + " \n\t" + " mov x5, x25 \n\t" + " \n\t" + " ld1 {v8.s}[0],[x5],x14 \n\t" // Load c90 into quad and increment by rs_c. + " ld1 {v8.s}[1],[x5],x14 \n\t" // Load c91 into quad and increment by rs_c. + " ld1 {v8.s}[2],[x5],x14 \n\t" // Load c92 into quad and increment by rs_c. + " ld1 {v8.s}[3],[x5],x14 \n\t" // Load c93 into quad and increment by rs_c. + " ld1 {v9.s}[0],[x5],x14 \n\t" // Load c94 into quad and increment by rs_c. + " ld1 {v9.s}[1],[x5],x14 \n\t" // Load c95 into quad and increment by rs_c. + " ld1 {v9.s}[2],[x5],x14 \n\t" // Load c96 into quad and increment by rs_c. + " ld1 {v9.s}[3],[x5],x14 \n\t" // Load c97 into quad and increment by rs_c. + " \n\t" + " mov x5, x26 \n\t" + " \n\t" + " ld1 {v10.s}[0],[x5],x14 \n\t" // Load c100 into quad and increment by rs_c. + " ld1 {v10.s}[1],[x5],x14 \n\t" // Load c101 into quad and increment by rs_c. + " ld1 {v10.s}[2],[x5],x14 \n\t" // Load c102 into quad and increment by rs_c. + " ld1 {v10.s}[3],[x5],x14 \n\t" // Load c103 into quad and increment by rs_c. + " ld1 {v11.s}[0],[x5],x14 \n\t" // Load c104 into quad and increment by rs_c. + " ld1 {v11.s}[1],[x5],x14 \n\t" // Load c105 into quad and increment by rs_c. + " ld1 {v11.s}[2],[x5],x14 \n\t" // Load c106 into quad and increment by rs_c. + " ld1 {v11.s}[3],[x5],x14 \n\t" // Load c107 into quad and increment by rs_c. + " \n\t" + " mov x5, x27 \n\t" + " \n\t" + " ld1 {v12.s}[0],[x5],x14 \n\t" // Load c110 into quad and increment by rs_c. + " ld1 {v12.s}[1],[x5],x14 \n\t" // Load c111 into quad and increment by rs_c. + " ld1 {v12.s}[2],[x5],x14 \n\t" // Load c112 into quad and increment by rs_c. + " ld1 {v12.s}[3],[x5],x14 \n\t" // Load c113 into quad and increment by rs_c. + " ld1 {v13.s}[0],[x5],x14 \n\t" // Load c114 into quad and increment by rs_c. + " ld1 {v13.s}[1],[x5],x14 \n\t" // Load c115 into quad and increment by rs_c. + " ld1 {v13.s}[2],[x5],x14 \n\t" // Load c116 into quad and increment by rs_c. + " ld1 {v13.s}[3],[x5],x14 \n\t" // Load c117 into quad and increment by rs_c. 
+ " \n\t" + " fmul v8.4s, v8.4s, v7.s[0] \n\t" // Scale by beta + " fmul v9.4s, v9.4s, v7.s[0] \n\t" // Scale by beta + " fmul v10.4s,v10.4s,v7.s[0] \n\t" // Scale by beta + " fmul v11.4s,v11.4s,v7.s[0] \n\t" // Scale by beta + " fmul v12.4s,v12.4s,v7.s[0] \n\t" // Scale by beta + " fmul v13.4s,v13.4s,v7.s[0] \n\t" // Scale by beta + " \n\t" + LABEL(SBETAZEROGENSTOREDS4) + " \n\t" + " prfm pldl2keep,[x0] \n\t" + " prfm pldl2keep,[x1] \n\t" + " \n\t" + " fmla v8.4s, v26.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v9.4s, v27.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v10.4s,v28.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v11.4s,v29.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v12.4s,v30.4s,v6.s[0] \n\t" // Scale by alpha + " fmla v13.4s,v31.4s,v6.s[0] \n\t" // Scale by alpha + " \n\t" + " mov x5, x25 \n\t" + " \n\t" + " st1 {v8.s}[0],[x5],x14 \n\t" // Store c90 into quad and increment by rs_c. + " st1 {v8.s}[1],[x5],x14 \n\t" // Store c91 into quad and increment by rs_c. + " st1 {v8.s}[2],[x5],x14 \n\t" // Store c92 into quad and increment by rs_c. + " st1 {v8.s}[3],[x5],x14 \n\t" // Store c93 into quad and increment by rs_c. + " st1 {v9.s}[0],[x5],x14 \n\t" // Store c94 into quad and increment by rs_c. + " st1 {v9.s}[1],[x5],x14 \n\t" // Store c95 into quad and increment by rs_c. + " st1 {v9.s}[2],[x5],x14 \n\t" // Store c96 into quad and increment by rs_c. + " st1 {v9.s}[3],[x5],x14 \n\t" // Store c97 into quad and increment by rs_c. + " \n\t" + " mov x5, x26 \n\t" + " \n\t" + " st1 {v10.s}[0],[x5],x14 \n\t" // Store c100 into quad and increment by rs_c. + " st1 {v10.s}[1],[x5],x14 \n\t" // Store c101 into quad and increment by rs_c. + " st1 {v10.s}[2],[x5],x14 \n\t" // Store c102 into quad and increment by rs_c. + " st1 {v10.s}[3],[x5],x14 \n\t" // Store c103 into quad and increment by rs_c. + " st1 {v11.s}[0],[x5],x14 \n\t" // Store c104 into quad and increment by rs_c. + " st1 {v11.s}[1],[x5],x14 \n\t" // Store c105 into quad and increment by rs_c. + " st1 {v11.s}[2],[x5],x14 \n\t" // Store c106 into quad and increment by rs_c. + " st1 {v11.s}[3],[x5],x14 \n\t" // Store c107 into quad and increment by rs_c. + " \n\t" + " mov x5, x27 \n\t" + " \n\t" + " st1 {v12.s}[0],[x5],x14 \n\t" // Store c110 into quad and increment by rs_c. + " st1 {v12.s}[1],[x5],x14 \n\t" // Store c111 into quad and increment by rs_c. + " st1 {v12.s}[2],[x5],x14 \n\t" // Store c112 into quad and increment by rs_c. + " st1 {v12.s}[3],[x5],x14 \n\t" // Store c113 into quad and increment by rs_c. + " st1 {v13.s}[0],[x5],x14 \n\t" // Store c114 into quad and increment by rs_c. + " st1 {v13.s}[1],[x5],x14 \n\t" // Store c115 into quad and increment by rs_c. + " st1 {v13.s}[2],[x5],x14 \n\t" // Store c116 into quad and increment by rs_c. + " st1 {v13.s}[3],[x5],x14 \n\t" // Store c147 into quad and increment by rs_c. + " \n\t" + LABEL(SEND) // Done! 
+ " \n\t" + :// output operands (none) + :// input operands + [aaddr] "m" (a), // 0 + [baddr] "m" (b), // 1 + [caddr] "m" (c), // 2 + [k_iter] "m" (k_iter), // 3 + [k_left] "m" (k_left), // 4 + [alpha] "m" (alpha), // 5 + [beta] "m" (beta), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c), // 8 + [a_next] "m" (a_next), // 9 + [b_next] "m" (b_next) // 10 + :// Register clobber list + "x0", "x1", "x2", + "x5", "x6", "x10","x14", + "x16","x17","x19","x20", + "x21","x22","x23","x24", + "x25","x26","x27", + "v0", "v1", "v2", "v3", + "v4", "v5", "v6", "v7", + "v8", "v9", "v10","v11", + "v12","v13","v14","v15", + "v16","v17","v18","v19", + "v20","v21","v22","v23", + "v24","v25","v26","v27", + "v28","v29","v30","v31" + ); - GEMM_UKR_FLUSH_CT( s ); + GEMM_UKR_FLUSH_CT( s ); } @@ -1133,956 +1133,956 @@ void bli_dgemm_armv8a_asm_6x8 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; - GEMM_UKR_SETUP_CT( d, 6, 8, false ); + GEMM_UKR_SETUP_CT( d, 6, 8, false ); - __asm__ volatile - ( - " \n\t" - " ldr x0,%[aaddr] \n\t" // Load address of A - " ldr x1,%[baddr] \n\t" // Load address of B - " ldr x2,%[caddr] \n\t" // Load address of C - " \n\t" - " ldr x5,%[k_iter] \n\t" // Init guard (k_iter) - " ldr x6,%[k_left] \n\t" // Init guard (k_iter) - " \n\t" - " ldr x10,%[cs_c] \n\t" // Load cs_c - " lsl x10,x10,#3 \n\t" // cs_c * sizeof(double) - " \n\t" - " ldr x14,%[rs_c] \n\t" // Load rs_c. - " lsl x14,x14,#3 \n\t" // rs_c * sizeof(double). - " \n\t" - " add x20,x2,x10 \n\t" //Load address Column 1 of C - " add x21,x20,x10 \n\t" //Load address Column 2 of C - " add x22,x21,x10 \n\t" //Load address Column 3 of C - " add x23,x22,x10 \n\t" //Load address Column 4 of C - " add x24,x23,x10 \n\t" //Load address Column 5 of C - " add x25,x24,x10 \n\t" //Load address Column 6 of C - " add x26,x25,x10 \n\t" //Load address Column 7 of C - " \n\t" - " prfm pldl1keep,[x2] \n\t" // Prefetch c. - " prfm pldl1keep,[x20] \n\t" // Prefetch c. - " prfm pldl1keep,[x21] \n\t" // Prefetch c. - " prfm pldl1keep,[x22] \n\t" // Prefetch c. - " prfm pldl1keep,[x23] \n\t" // Prefetch c. - " prfm pldl1keep,[x24] \n\t" // Prefetch c. - " prfm pldl1keep,[x25] \n\t" // Prefetch c. - " prfm pldl1keep,[x26] \n\t" // Prefetch c. 
- " \n\t" - " dup v8.2d, xzr \n\t" // Vector for accummulating column 0 - " prfm PLDL1KEEP, [x1, #256] \n\t" - " dup v9.2d, xzr \n\t" // Vector for accummulating column 0 - " prfm PLDL1KEEP, [x1, #320] \n\t" - " dup v10.2d, xzr \n\t" // Vector for accummulating column 0 - " prfm PLDL1KEEP, [x1, #384] \n\t" - " dup v11.2d, xzr \n\t" // Vector for accummulating column 1 - " prfm PLDL1KEEP, [x1, #448] \n\t" - " dup v12.2d, xzr \n\t" // Vector for accummulating column 1 - " dup v13.2d, xzr \n\t" // Vector for accummulating column 1 - " \n\t" - " dup v14.2d, xzr \n\t" // Vector for accummulating column 2 - " prfm PLDL1KEEP, [x0, #192] \n\t" - " dup v15.2d, xzr \n\t" // Vector for accummulating column 2 - " prfm PLDL1KEEP, [x0, #256] \n\t" - " dup v16.2d, xzr \n\t" // Vector for accummulating column 2 - " prfm PLDL1KEEP, [x0, #320] \n\t" - " dup v17.2d, xzr \n\t" // Vector for accummulating column 3 - " dup v18.2d, xzr \n\t" // Vector for accummulating column 3 - " dup v19.2d, xzr \n\t" // Vector for accummulating column 3 - " \n\t" - " dup v20.2d, xzr \n\t" // Vector for accummulating column 4 - " dup v21.2d, xzr \n\t" // Vector for accummulating column 4 - " dup v22.2d, xzr \n\t" // Vector for accummulating column 4 - " dup v23.2d, xzr \n\t" // Vector for accummulating column 5 - " dup v24.2d, xzr \n\t" // Vector for accummulating column 5 - " dup v25.2d, xzr \n\t" // Vector for accummulating column 5 - " \n\t" - " dup v26.2d, xzr \n\t" // Vector for accummulating column 6 - " dup v27.2d, xzr \n\t" // Vector for accummulating column 6 - " dup v28.2d, xzr \n\t" // Vector for accummulating column 6 - " dup v29.2d, xzr \n\t" // Vector for accummulating column 7 - " dup v30.2d, xzr \n\t" // Vector for accummulating column 7 - " dup v31.2d, xzr \n\t" // Vector for accummulating column 7 - " \n\t" - " \n\t" - " cmp x5,#0 \n\t" // If k_iter == 0, jump to k_left. - BEQ(DCONSIDERKLEFT) - " \n\t" - " ldr q0, [x0] \n\t" // Load a - " ldr q1, [x0, #16] \n\t" - " ldr q2, [x0, #32] \n\t" - " \n\t" - " ldr q3, [x1] \n\t" // Load b - " ldr q4, [x1, #16] \n\t" - " ldr q5, [x1, #32] \n\t" - " ldr q6, [x1, #48] \n\t" - " \n\t" - " add x0, x0, #48 \n\t" //update address of A - " add x1, x1, #64 \n\t" //update address of B - " \n\t" - " cmp x5,1 \n\t" // If there is just one k_iter, jump to that one. - BEQ(DLASTITER) // (as loop is do-while-like). 
- " \n\t" - LABEL(DLOOP) // Body - " \n\t" - " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate - " prfm PLDL1KEEP, [x1, #448] \n\t" //512-64=448 - " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate - " prfm PLDL1KEEP, [x1, #512] \n\t" - " fmla v10.2d,v2.2d,v3.d[0] \n\t" // Accummulate - " prfm PLDL1KEEP, [x1, #576] \n\t" - " \n\t" - " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate - " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate - " fmla v13.2d,v2.2d,v3.d[1] \n\t" // Accummulate - " \n\t" - " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate - " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate - " fmla v16.2d,v2.2d,v4.d[0] \n\t" // Accummulate - " ldr q3, [x1] \n\t" - " \n\t" - " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate - " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate - " fmla v19.2d,v2.2d,v4.d[1] \n\t" // Accummulate - " ldr q7, [x0, #32] \n\t" - " \n\t" - " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate - " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate - " fmla v22.2d,v2.2d,v5.d[0] \n\t" // Accummulate - " ldr q4, [x1, #16] \n\t" - " \n\t" - " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate - " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate - " fmla v25.2d,v2.2d,v5.d[1] \n\t" // Accummulate - " ldr q5, [x1, #32] \n\t" - " \n\t" - " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate - " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate - " ldr q0, [x0] \n\t" - " \n\t" - " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate - " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate - " ldr q1, [x0, #16] \n\t" - " \n\t" - " fmla v28.2d,v2.2d,v6.d[0] \n\t" // Accummulate - " fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate - " ldr q6, [x1, #48] \n\t" - " \n\t" // End it 1 - " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate - " prfm PLDL1KEEP, [x1, #640] \n\t" - " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate - " prfm PLDL1KEEP, [x0, #336] \n\t" - " fmla v10.2d,v7.2d,v3.d[0] \n\t" // Accummulate - " prfm PLDL1KEEP, [x0, #400] \n\t" - " \n\t" - " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate - " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate - " fmla v13.2d,v7.2d,v3.d[1] \n\t" // Accummulate - " \n\t" - " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate - " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate - " fmla v16.2d,v7.2d,v4.d[0] \n\t" // Accummulate - " ldr q3, [x1, #64] \n\t" - " \n\t" - " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate - " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate - " fmla v19.2d,v7.2d,v4.d[1] \n\t" // Accummulate - " ldr q2, [x0, #80] \n\t" - " \n\t" - " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate - " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate - " fmla v22.2d,v7.2d,v5.d[0] \n\t" // Accummulate - " ldr q4, [x1, #80] \n\t" - " \n\t" - " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate - " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate - " fmla v25.2d,v7.2d,v5.d[1] \n\t" // Accummulate - " ldr q5, [x1, #96] \n\t" - " \n\t" - " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate - " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate - " ldr q0, [x0, #48] \n\t" - " \n\t" - " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate - " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate - " ldr q1, [x0, #64] \n\t" - " \n\t" - " fmla v28.2d,v7.2d,v6.d[0] \n\t" // Accummulate - " fmla v31.2d,v7.2d,v6.d[1] \n\t" // Accummulate - " ldr q6, [x1, #112] \n\t" - " \n\t" //End it 2 - " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate - " prfm PLDL1KEEP, [x0, #464] \n\t" - " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate - " fmla v10.2d,v2.2d,v3.d[0] \n\t" // Accummulate - " \n\t" - " fmla v11.2d,v0.2d,v3.d[1] \n\t" // 
Accummulate - " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate - " fmla v13.2d,v2.2d,v3.d[1] \n\t" // Accummulate - " \n\t" - " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate - " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate - " fmla v16.2d,v2.2d,v4.d[0] \n\t" // Accummulate - " ldr q3, [x1, #128] \n\t" - " \n\t" - " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate - " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate - " fmla v19.2d,v2.2d,v4.d[1] \n\t" // Accummulate - " ldr q7, [x0, #128] \n\t" - " \n\t" - " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate - " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate - " fmla v22.2d,v2.2d,v5.d[0] \n\t" // Accummulate - " ldr q4, [x1, #144] \n\t" - " \n\t" - " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate - " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate - " fmla v25.2d,v2.2d,v5.d[1] \n\t" // Accummulate - " ldr q5, [x1, #160] \n\t" - " \n\t" - " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate - " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate - " ldr q0, [x0, #96] \n\t" - " \n\t" - " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate - " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate - " ldr q1, [x0, #112] \n\t" - " \n\t" - " fmla v28.2d,v2.2d,v6.d[0] \n\t" // Accummulate - " fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate - " ldr q6, [x1, #176] \n\t" - " \n\t" // End it 3 - " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate - " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate - " fmla v10.2d,v7.2d,v3.d[0] \n\t" // Accummulate - " \n\t" - " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate - " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate - " fmla v13.2d,v7.2d,v3.d[1] \n\t" // Accummulate - " ldr q3, [x1, #192] \n\t" - " \n\t" - " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate - " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate - " fmla v16.2d,v7.2d,v4.d[0] \n\t" // Accummulate - " ldr q2, [x0, #176] \n\t" - " \n\t" - " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate - " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate - " fmla v19.2d,v7.2d,v4.d[1] \n\t" // Accummulate - " ldr q4, [x1, #208] \n\t" - " \n\t" - " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate - " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate - " fmla v22.2d,v7.2d,v5.d[0] \n\t" // Accummulate - " \n\t" - " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate - " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate - " fmla v25.2d,v7.2d,v5.d[1] \n\t" // Accummulate - " ldr q5, [x1, #224] \n\t" - " \n\t" - " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate - " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate - " ldr q0, [x0, #144] \n\t" - " \n\t" - " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate - " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate - " ldr q1, [x0, #160] \n\t" - " \n\t" - " fmla v28.2d,v7.2d,v6.d[0] \n\t" // Accummulate - " fmla v31.2d,v7.2d,v6.d[1] \n\t" // Accummulate - " ldr q6, [x1, #240] \n\t" - " \n\t" //End it 4 - " add x0, x0, #192 \n\t" - " add x1, x1, #256 \n\t" - " \n\t" - " sub x5,x5,1 \n\t" // i-=1 - " cmp x5,1 \n\t" // Iterate again if we are not in k_iter == 1. 
- BNE(DLOOP) - " \n\t" - LABEL(DLASTITER) - " \n\t" - " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate - " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate - " fmla v10.2d,v2.2d,v3.d[0] \n\t" // Accummulate - " \n\t" - " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate - " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate - " fmla v13.2d,v2.2d,v3.d[1] \n\t" // Accummulate - " ldr q3, [x1] \n\t" - " \n\t" - " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate - " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate - " fmla v16.2d,v2.2d,v4.d[0] \n\t" // Accummulate - " ldr q7, [x0, #32] \n\t" - " \n\t" - " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate - " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate - " fmla v19.2d,v2.2d,v4.d[1] \n\t" // Accummulate - " ldr q4, [x1, #16] \n\t" - " \n\t" - " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate - " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate - " fmla v22.2d,v2.2d,v5.d[0] \n\t" // Accummulate - " \n\t" - " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate - " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate - " fmla v25.2d,v2.2d,v5.d[1] \n\t" // Accummulate - " ldr q5, [x1, #32] \n\t" - " \n\t" - " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate - " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate - " ldr q0, [x0] \n\t" - " \n\t" - " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate - " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate - " ldr q1, [x0, #16] \n\t" - " \n\t" - " fmla v28.2d,v2.2d,v6.d[0] \n\t" // Accummulate - " fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate - " ldr q6, [x1, #48] \n\t" - " \n\t" // End it 1 - " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate - " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate - " fmla v10.2d,v7.2d,v3.d[0] \n\t" // Accummulate - " \n\t" - " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate - " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate - " fmla v13.2d,v7.2d,v3.d[1] \n\t" // Accummulate - " ldr q3, [x1, #64] \n\t" - " \n\t" - " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate - " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate - " fmla v16.2d,v7.2d,v4.d[0] \n\t" // Accummulate - " ldr q2, [x0, #80] \n\t" - " \n\t" - " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate - " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate - " fmla v19.2d,v7.2d,v4.d[1] \n\t" // Accummulate - " ldr q4, [x1, #80] \n\t" - " \n\t" - " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate - " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate - " fmla v22.2d,v7.2d,v5.d[0] \n\t" // Accummulate - " \n\t" - " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate - " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate - " fmla v25.2d,v7.2d,v5.d[1] \n\t" // Accummulate - " ldr q5, [x1, #96] \n\t" - " \n\t" - " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate - " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate - " ldr q0, [x0, #48] \n\t" - " \n\t" - " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate - " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate - " ldr q1, [x0, #64] \n\t" - " \n\t" - " fmla v28.2d,v7.2d,v6.d[0] \n\t" // Accummulate - " fmla v31.2d,v7.2d,v6.d[1] \n\t" // Accummulate - " ldr q6, [x1, #112] \n\t" - " \n\t" //End it 2 - " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate - " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate - " fmla v10.2d,v2.2d,v3.d[0] \n\t" // Accummulate - " \n\t" - " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate - " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate - " fmla v13.2d,v2.2d,v3.d[1] \n\t" // Accummulate - " ldr q3, [x1, #128] \n\t" - " \n\t" - " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate - " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate - " fmla 
v16.2d,v2.2d,v4.d[0] \n\t" // Accummulate - " ldr q7, [x0, #128] \n\t" - " \n\t" - " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate - " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate - " fmla v19.2d,v2.2d,v4.d[1] \n\t" // Accummulate - " ldr q4, [x1, #144] \n\t" - " \n\t" - " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate - " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate - " fmla v22.2d,v2.2d,v5.d[0] \n\t" // Accummulate - " \n\t" - " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate - " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate - " fmla v25.2d,v2.2d,v5.d[1] \n\t" // Accummulate - " ldr q5, [x1, #160] \n\t" - " \n\t" - " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate - " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate - " ldr q0, [x0, #96] \n\t" - " \n\t" - " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate - " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate - " ldr q1, [x0, #112] \n\t" - " \n\t" - " fmla v28.2d,v2.2d,v6.d[0] \n\t" // Accummulate - " fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate - " ldr q6, [x1, #176] \n\t" - " \n\t" // End it 3 - " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate - " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate - " fmla v10.2d,v7.2d,v3.d[0] \n\t" // Accummulate - " \n\t" - " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate - " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate - " fmla v13.2d,v7.2d,v3.d[1] \n\t" // Accummulate - " \n\t" - " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate - " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate - " fmla v16.2d,v7.2d,v4.d[0] \n\t" // Accummulate - " \n\t" - " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate - " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate - " fmla v19.2d,v7.2d,v4.d[1] \n\t" // Accummulate - " \n\t" - " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate - " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate - " fmla v22.2d,v7.2d,v5.d[0] \n\t" // Accummulate - " \n\t" - " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate - " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate - " fmla v25.2d,v7.2d,v5.d[1] \n\t" // Accummulate - " \n\t" - " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate - " add x1, x1, #192 \n\t" - " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate - " \n\t" - " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate - " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate - " \n\t" - " fmla v28.2d,v7.2d,v6.d[0] \n\t" // Accummulate - " fmla v31.2d,v7.2d,v6.d[1] \n\t" // Accummulate - " \n\t" //End it 4 - " add x0, x0, #144 \n\t" - " \n\t" - LABEL(DCONSIDERKLEFT) - " cmp x6,0 \n\t" // If k_left == 0, we are done. - BEQ(DPOSTACCUM) // else, we enter the k_left loop. 
- " \n\t" - LABEL(DLOOPKLEFT) - " \n\t" - " ldr q0, [x0],#16 \n\t" - " ldr q1, [x0],#16 \n\t" // Load a - " ldr q2, [x0],#16 \n\t" - " \n\t" - " ldr q3, [x1],#16 \n\t" // Load b - " ldr q4, [x1],#16 \n\t" - " ldr q5, [x1],#16 \n\t" - " ldr q6, [x1],#16 \n\t" - " \n\t" - " sub x6,x6,1 \n\t" - " \n\t" - " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate - " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate - " fmla v10.2d,v2.2d,v3.d[0] \n\t" // Accummulate - " \n\t" - " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate - " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate - " fmla v13.2d,v2.2d,v3.d[1] \n\t" // Accummulate - " \n\t" - " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate - " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate - " fmla v16.2d,v2.2d,v4.d[0] \n\t" // Accummulate - " \n\t" - " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate - " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate - " fmla v19.2d,v2.2d,v4.d[1] \n\t" // Accummulate - " \n\t" - " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate - " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate - " fmla v22.2d,v2.2d,v5.d[0] \n\t" // Accummulate - " \n\t" - " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate - " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate - " fmla v25.2d,v2.2d,v5.d[1] \n\t" // Accummulate - " \n\t" - " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate - " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate - " \n\t" - " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate - " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate - " \n\t" - " fmla v28.2d,v2.2d,v6.d[0] \n\t" // Accummulate - " fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate - " \n\t" - " cmp x6,0 \n\t" // Iterate again. - BNE(DLOOPKLEFT) // if i!=0. - " \n\t" - LABEL(DPOSTACCUM) - " \n\t" - " ldr x0,%[alpha] \n\t" // Alpha address - " ldr x1,%[beta] \n\t" // Beta address - " \n\t" - " ld1r {v6.2d},[x0] \n\t" // Load alpha. - " ld1r {v7.2d},[x1] \n\t" // Load beta - " \n\t" - " ldr x0,%[a_next] \n\t" // Next A address for later use. - " ldr x1,%[b_next] \n\t" // Next B address for later use. - " \n\t" -" cmp x14,#8 \n\t" // If rs_c != 1 (column-major) -BNE(DGENSTORED) -" \n\t" - LABEL(DCOLSTORED) // C is column-major. - " \n\t" - " dup v0.2d, xzr \n\t" - " dup v1.2d, xzr \n\t" - " dup v2.2d, xzr \n\t" - " dup v3.2d, xzr \n\t" - " dup v4.2d, xzr \n\t" - " dup v5.2d, xzr \n\t" - " \n\t" - " fcmp d7,#0.0 \n\t" - BEQ(DBETAZEROCOLSTOREDS1) // Taking care of the beta==0 case. 
- " \n\t" - " ldr q0, [x2] \n\t" //Load column 0 of C - " ldr q1, [x2, #16] \n\t" - " ldr q2, [x2, #32] \n\t" - " \n\t" - " ldr q3, [x20] \n\t" //Load column 1 of C - " ldr q4, [x20, #16] \n\t" - " ldr q5, [x20, #32] \n\t" - " \n\t" - " fmul v0.2d,v0.2d,v7.d[0] \n\t" // Scale by beta - " fmul v1.2d,v1.2d,v7.d[0] \n\t" // Scale by beta - " fmul v2.2d,v2.2d,v7.d[0] \n\t" // Scale by beta - " fmul v3.2d,v3.2d,v7.d[0] \n\t" // Scale by beta - " fmul v4.2d,v4.2d,v7.d[0] \n\t" // Scale by beta - " fmul v5.2d,v5.2d,v7.d[0] \n\t" // Scale by beta - " \n\t" - LABEL(DBETAZEROCOLSTOREDS1) - " \n\t" - " fmla v0.2d,v8.2d,v6.d[0] \n\t" // Scale by alpha - " fmla v1.2d,v9.2d,v6.d[0] \n\t" // Scale by alpha - " fmla v2.2d,v10.2d,v6.d[0] \n\t" // Scale by alpha - " fmla v3.2d,v11.2d,v6.d[0] \n\t" // Scale by alpha - " fmla v4.2d,v12.2d,v6.d[0] \n\t" // Scale by alpha - " fmla v5.2d,v13.2d,v6.d[0] \n\t" // Scale by alpha - " \n\t" - " str q0, [x2] \n\t" //Store column 0 of C - " str q1, [x2, #16] \n\t" - " str q2, [x2, #32] \n\t" - " \n\t" - " str q3, [x20] \n\t" //Store column 1 of C - " str q4, [x20, #16] \n\t" - " str q5, [x20, #32] \n\t" - " \n\t" - " dup v8.2d, xzr \n\t" - " dup v9.2d, xzr \n\t" - " dup v10.2d, xzr \n\t" - " dup v11.2d, xzr \n\t" - " dup v12.2d, xzr \n\t" - " dup v13.2d, xzr \n\t" - " \n\t" - " fcmp d7,#0.0 \n\t" - BEQ(DBETAZEROCOLSTOREDS2) // Taking care of the beta==0 case. - " \n\t" - " ldr q8, [x21] \n\t" //Load column 2 of C - " ldr q9, [x21, #16] \n\t" - " ldr q10, [x21, #32] \n\t" - " \n\t" - " ldr q11, [x22] \n\t" //Load column 3 of C - " ldr q12, [x22, #16] \n\t" - " ldr q13, [x22, #32] \n\t" - " \n\t" - " fmul v8.2d, v8.2d, v7.d[0] \n\t" // Scale by beta - " fmul v9.2d, v9.2d, v7.d[0] \n\t" // Scale by beta - " fmul v10.2d,v10.2d,v7.d[0] \n\t" // Scale by beta - " fmul v11.2d,v11.2d,v7.d[0] \n\t" // Scale by beta - " fmul v12.2d,v12.2d,v7.d[0] \n\t" // Scale by beta - " fmul v13.2d,v13.2d,v7.d[0] \n\t" // Scale by beta - " \n\t" - LABEL(DBETAZEROCOLSTOREDS2) - " \n\t" - " fmla v8.2d, v14.2d,v6.d[0] \n\t" // Scale by alpha - " fmla v9.2d, v15.2d,v6.d[0] \n\t" // Scale by alpha - " fmla v10.2d,v16.2d,v6.d[0] \n\t" // Scale by alpha - " fmla v11.2d,v17.2d,v6.d[0] \n\t" // Scale by alpha - " fmla v12.2d,v18.2d,v6.d[0] \n\t" // Scale by alpha - " fmla v13.2d,v19.2d,v6.d[0] \n\t" // Scale by alpha - " \n\t" - " str q8, [x21] \n\t" //Store column 2 of C - " str q9, [x21, #16] \n\t" - " str q10, [x21, #32] \n\t" - " \n\t" - " str q11, [x22] \n\t" //Store column 3 of C - " str q12, [x22, #16] \n\t" - " str q13, [x22, #32] \n\t" - " \n\t" - " dup v0.2d, xzr \n\t" - " dup v1.2d, xzr \n\t" - " dup v2.2d, xzr \n\t" - " dup v3.2d, xzr \n\t" - " dup v4.2d, xzr \n\t" - " dup v5.2d, xzr \n\t" - " \n\t" - " fcmp d7,#0.0 \n\t" - BEQ(DBETAZEROCOLSTOREDS3) // Taking care of the beta==0 case. 
- " \n\t" - " ldr q0, [x23] \n\t" //Load column 4 of C - " ldr q1, [x23, #16] \n\t" - " ldr q2, [x23, #32] \n\t" - " \n\t" - " ldr q3, [x24] \n\t" //Load column 5 of C - " ldr q4, [x24, #16] \n\t" - " ldr q5, [x24, #32] \n\t" - " \n\t" - " fmul v0.2d,v0.2d,v7.d[0] \n\t" // Scale by beta - " fmul v1.2d,v1.2d,v7.d[0] \n\t" // Scale by beta - " fmul v2.2d,v2.2d,v7.d[0] \n\t" // Scale by beta - " fmul v3.2d,v3.2d,v7.d[0] \n\t" // Scale by beta - " fmul v4.2d,v4.2d,v7.d[0] \n\t" // Scale by beta - " fmul v5.2d,v5.2d,v7.d[0] \n\t" // Scale by beta - " \n\t" - LABEL(DBETAZEROCOLSTOREDS3) - " \n\t" - " fmla v0.2d,v20.2d,v6.d[0] \n\t" // Scale by alpha - " fmla v1.2d,v21.2d,v6.d[0] \n\t" // Scale by alpha - " fmla v2.2d,v22.2d,v6.d[0] \n\t" // Scale by alpha - " fmla v3.2d,v23.2d,v6.d[0] \n\t" // Scale by alpha - " fmla v4.2d,v24.2d,v6.d[0] \n\t" // Scale by alpha - " fmla v5.2d,v25.2d,v6.d[0] \n\t" // Scale by alpha - " \n\t" - " str q0, [x23] \n\t" //Store column 4 of C - " str q1, [x23, #16] \n\t" - " str q2, [x23, #32] \n\t" - " \n\t" - " str q3, [x24] \n\t" //Store column 5 of C - " str q4, [x24, #16] \n\t" - " str q5, [x24, #32] \n\t" - " \n\t" - " dup v8.2d, xzr \n\t" - " dup v9.2d, xzr \n\t" - " dup v10.2d, xzr \n\t" - " dup v11.2d, xzr \n\t" - " dup v12.2d, xzr \n\t" - " dup v13.2d, xzr \n\t" - " \n\t" - " fcmp d7,#0.0 \n\t" - BEQ(DBETAZEROCOLSTOREDS4) // Taking care of the beta==0 case. - " \n\t" - " ldr q8, [x25] \n\t" //Load column 6 of C - " ldr q9, [x25, #16] \n\t" - " ldr q10, [x25, #32] \n\t" - " \n\t" - " ldr q11, [x26] \n\t" //Load column 7 of C - " ldr q12, [x26, #16] \n\t" - " ldr q13, [x26, #32] \n\t" - " \n\t" - " fmul v8.2d, v8.2d, v7.d[0] \n\t" // Scale by beta - " fmul v9.2d, v9.2d, v7.d[0] \n\t" // Scale by beta - " fmul v10.2d,v10.2d,v7.d[0] \n\t" // Scale by beta - " fmul v11.2d,v11.2d,v7.d[0] \n\t" // Scale by beta - " fmul v12.2d,v12.2d,v7.d[0] \n\t" // Scale by beta - " fmul v13.2d,v13.2d,v7.d[0] \n\t" // Scale by beta - " \n\t" - LABEL(DBETAZEROCOLSTOREDS4) - " \n\t" - " prfm pldl2keep,[x0] \n\t" - " prfm pldl2keep,[x1] \n\t" - " \n\t" - " fmla v8.2d, v26.2d,v6.d[0] \n\t" // Scale by alpha - " fmla v9.2d, v27.2d,v6.d[0] \n\t" // Scale by alpha - " fmla v10.2d,v28.2d,v6.d[0] \n\t" // Scale by alpha - " fmla v11.2d,v29.2d,v6.d[0] \n\t" // Scale by alpha - " fmla v12.2d,v30.2d,v6.d[0] \n\t" // Scale by alpha - " fmla v13.2d,v31.2d,v6.d[0] \n\t" // Scale by alpha - " \n\t" - " str q8, [x25] \n\t" //Store column 6 of C - " str q9, [x25, #16] \n\t" - " str q10, [x25, #32] \n\t" - " \n\t" - " str q11, [x26] \n\t" //Store column 7 of C - " str q12, [x26, #16] \n\t" - " str q13, [x26, #32] \n\t" - " \n\t" -BRANCH(DEND) -" \n\t" -LABEL(DGENSTORED) // C is general-stride stored. -" \n\t" -" dup v0.2d, xzr \n\t" -" dup v1.2d, xzr \n\t" -" dup v2.2d, xzr \n\t" -" dup v3.2d, xzr \n\t" -" dup v4.2d, xzr \n\t" -" dup v5.2d, xzr \n\t" -" \n\t" -" fcmp d7,#0.0 \n\t" -BEQ(DBETAZEROGENSTOREDS1) // Taking care of the beta==0 case. -" \n\t" -" mov x27, x2 \n\t" -" \n\t" // Load address of C. -" ld1 {v0.d}[0],[x27],x14 \n\t" // Load c00 into quad and increment by rs_c. -" ld1 {v0.d}[1],[x27],x14 \n\t" // Load c01 into quad and increment by rs_c. -" ld1 {v1.d}[0],[x27],x14 \n\t" // Load c02 into quad and increment by rs_c. -" ld1 {v1.d}[1],[x27],x14 \n\t" // Load c03 into quad and increment by rs_c. -" ld1 {v2.d}[0],[x27],x14 \n\t" // Load c04 into quad and increment by rs_c. -" ld1 {v2.d}[1],[x27],x14 \n\t" // Load c05 into quad and increment by rs_c. 
-" \n\t" -" mov x27, x20 \n\t" // Load address of C. -" \n\t" -" ld1 {v3.d}[0],[x27],x14 \n\t" // Load c10 into quad and increment by rs_c. -" ld1 {v3.d}[1],[x27],x14 \n\t" // Load c11 into quad and increment by rs_c. -" ld1 {v4.d}[0],[x27],x14 \n\t" // Load c12 into quad and increment by rs_c. -" ld1 {v4.d}[1],[x27],x14 \n\t" // Load c13 into quad and increment by rs_c. -" ld1 {v5.d}[0],[x27],x14 \n\t" // Load c14 into quad and increment by rs_c. -" ld1 {v5.d}[1],[x27],x14 \n\t" // Load c15 into quad and increment by rs_c. -" \n\t" -" fmul v0.2d,v0.2d,v7.d[0] \n\t" // Scale by beta -" fmul v1.2d,v1.2d,v7.d[0] \n\t" // Scale by beta -" fmul v2.2d,v2.2d,v7.d[0] \n\t" // Scale by beta -" fmul v3.2d,v3.2d,v7.d[0] \n\t" // Scale by beta -" fmul v4.2d,v4.2d,v7.d[0] \n\t" // Scale by beta -" fmul v5.2d,v5.2d,v7.d[0] \n\t" // Scale by beta -" \n\t" -LABEL(DBETAZEROGENSTOREDS1) -" \n\t" -" fmla v0.2d,v8.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v1.2d,v9.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v2.2d,v10.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v3.2d,v11.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v4.2d,v12.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v5.2d,v13.2d,v6.d[0] \n\t" // Scale by alpha -" \n\t" -" mov x27, x2 \n\t" // Load address of C. -" \n\t" -" st1 {v0.d}[0],[x27],x14 \n\t" // Store c00 into quad and increment by rs_c. -" st1 {v0.d}[1],[x27],x14 \n\t" // Store c01 into quad and increment by rs_c. -" st1 {v1.d}[0],[x27],x14 \n\t" // Store c02 into quad and increment by rs_c. -" st1 {v1.d}[1],[x27],x14 \n\t" // Store c03 into quad and increment by rs_c. -" st1 {v2.d}[0],[x27],x14 \n\t" // Store c04 into quad and increment by rs_c. -" st1 {v2.d}[1],[x27],x14 \n\t" // Store c05 into quad and increment by rs_c. -" \n\t" -" mov x27, x20 \n\t" // Load address of C. -" \n\t" -" st1 {v3.d}[0],[x27],x14 \n\t" // Store c10 into quad and increment by rs_c. -" st1 {v3.d}[1],[x27],x14 \n\t" // Store c11 into quad and increment by rs_c. -" st1 {v4.d}[0],[x27],x14 \n\t" // Store c12 into quad and increment by rs_c. -" st1 {v4.d}[1],[x27],x14 \n\t" // Store c13 into quad and increment by rs_c. -" st1 {v5.d}[0],[x27],x14 \n\t" // Store c14 into quad and increment by rs_c. -" st1 {v5.d}[1],[x27],x14 \n\t" // Store c15 into quad and increment by rs_c. -" \n\t" -" dup v8.2d, xzr \n\t" -" dup v9.2d, xzr \n\t" -" dup v10.2d, xzr \n\t" -" dup v11.2d, xzr \n\t" -" dup v12.2d, xzr \n\t" -" dup v13.2d, xzr \n\t" -" \n\t" -" fcmp d7,#0.0 \n\t" -BEQ(DBETAZEROGENSTOREDS2) // Taking care of the beta==0 case. -" \n\t" -" mov x27, x21 \n\t" // Load address of C. -" \n\t" -" ld1 {v8.d}[0], [x27],x14 \n\t" // Load c20 into quad and increment by rs_c. -" ld1 {v8.d}[1], [x27],x14 \n\t" // Load c21 into quad and increment by rs_c. -" ld1 {v9.d}[0], [x27],x14 \n\t" // Load c22 into quad and increment by rs_c. -" ld1 {v9.d}[1], [x27],x14 \n\t" // Load c23 into quad and increment by rs_c. -" ld1 {v10.d}[0],[x27],x14 \n\t" // Load c24 into quad and increment by rs_c. -" ld1 {v10.d}[1],[x27],x14 \n\t" // Load c25 into quad and increment by rs_c. -" \n\t" -" mov x27, x22 \n\t" // Load address of C. -" \n\t" -" ld1 {v11.d}[0],[x27],x14 \n\t" // Load c30 into quad and increment by rs_c. -" ld1 {v11.d}[1],[x27],x14 \n\t" // Load c31 into quad and increment by rs_c. -" ld1 {v12.d}[0],[x27],x14 \n\t" // Load c32 into quad and increment by rs_c. -" ld1 {v12.d}[1],[x27],x14 \n\t" // Load c33 into quad and increment by rs_c. -" ld1 {v13.d}[0],[x27],x14 \n\t" // Load c34 into quad and increment by rs_c. 
-" ld1 {v13.d}[1],[x27],x14 \n\t" // Load c35 into quad and increment by rs_c. -" \n\t" -" fmul v8.2d, v8.2d, v7.d[0] \n\t" // Scale by beta -" fmul v9.2d, v9.2d, v7.d[0] \n\t" // Scale by beta -" fmul v10.2d,v10.2d,v7.d[0] \n\t" // Scale by beta -" fmul v11.2d,v11.2d,v7.d[0] \n\t" // Scale by beta -" fmul v12.2d,v12.2d,v7.d[0] \n\t" // Scale by beta -" fmul v13.2d,v13.2d,v7.d[0] \n\t" // Scale by beta -" \n\t" -LABEL(DBETAZEROGENSTOREDS2) -" \n\t" -" fmla v8.2d, v14.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v9.2d, v15.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v10.2d,v16.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v11.2d,v17.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v12.2d,v18.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v13.2d,v19.2d,v6.d[0] \n\t" // Scale by alpha -" \n\t" -" mov x27, x21 \n\t" // Load address of C. -" \n\t" -" st1 {v8.d}[0], [x27],x14 \n\t" // Store c20 into quad and increment by rs_c. -" st1 {v8.d}[1], [x27],x14 \n\t" // Store c21 into quad and increment by rs_c. -" st1 {v9.d}[0], [x27],x14 \n\t" // Store c22 into quad and increment by rs_c. -" st1 {v9.d}[1], [x27],x14 \n\t" // Store c23 into quad and increment by rs_c. -" st1 {v10.d}[0],[x27],x14 \n\t" // Store c24 into quad and increment by rs_c. -" st1 {v10.d}[1],[x27],x14 \n\t" // Store c25 into quad and increment by rs_c. -" \n\t" -" mov x27, x22 \n\t" // Load address of C. -" \n\t" -" st1 {v11.d}[0],[x27],x14 \n\t" // Store c30 into quad and increment by rs_c. -" st1 {v11.d}[1],[x27],x14 \n\t" // Store c31 into quad and increment by rs_c. -" st1 {v12.d}[0],[x27],x14 \n\t" // Store c32 into quad and increment by rs_c. -" st1 {v12.d}[1],[x27],x14 \n\t" // Store c33 into quad and increment by rs_c. -" st1 {v13.d}[0],[x27],x14 \n\t" // Store c34 into quad and increment by rs_c. -" st1 {v13.d}[1],[x27],x14 \n\t" // Store c35 into quad and increment by rs_c. -" \n\t" -" dup v0.2d, xzr \n\t" -" dup v1.2d, xzr \n\t" -" dup v2.2d, xzr \n\t" -" dup v3.2d, xzr \n\t" -" dup v4.2d, xzr \n\t" -" dup v5.2d, xzr \n\t" -" \n\t" -" fcmp d7,#0.0 \n\t" -BEQ(DBETAZEROGENSTOREDS3) // Taking care of the beta==0 case. -" \n\t" -" mov x27, x23 \n\t" // Load address of C. -" \n\t" -" ld1 {v0.d}[0],[x27],x14 \n\t" // Load c40 into quad and increment by rs_c. -" ld1 {v0.d}[1],[x27],x14 \n\t" // Load c41 into quad and increment by rs_c. -" ld1 {v1.d}[0],[x27],x14 \n\t" // Load c42 into quad and increment by rs_c. -" ld1 {v1.d}[1],[x27],x14 \n\t" // Load c43 into quad and increment by rs_c. -" ld1 {v2.d}[0],[x27],x14 \n\t" // Load c44 into quad and increment by rs_c. -" ld1 {v2.d}[1],[x27],x14 \n\t" // Load c45 into quad and increment by rs_c. -" \n\t" -" mov x27, x24 \n\t" // Load address of C. -" \n\t" -" ld1 {v3.d}[0],[x27],x14 \n\t" // Load c50 into quad and increment by rs_c. -" ld1 {v3.d}[1],[x27],x14 \n\t" // Load c51 into quad and increment by rs_c. -" ld1 {v4.d}[0],[x27],x14 \n\t" // Load c52 into quad and increment by rs_c. -" ld1 {v4.d}[1],[x27],x14 \n\t" // Load c53 into quad and increment by rs_c. -" ld1 {v5.d}[0],[x27],x14 \n\t" // Load c54 into quad and increment by rs_c. -" ld1 {v5.d}[1],[x27],x14 \n\t" // Load c55 into quad and increment by rs_c. 
-" \n\t" -" fmul v0.2d,v0.2d,v7.d[0] \n\t" // Scale by beta -" fmul v1.2d,v1.2d,v7.d[0] \n\t" // Scale by beta -" fmul v2.2d,v2.2d,v7.d[0] \n\t" // Scale by beta -" fmul v3.2d,v3.2d,v7.d[0] \n\t" // Scale by beta -" fmul v4.2d,v4.2d,v7.d[0] \n\t" // Scale by beta -" fmul v5.2d,v5.2d,v7.d[0] \n\t" // Scale by beta -" \n\t" -LABEL(DBETAZEROGENSTOREDS3) -" \n\t" -" fmla v0.2d,v20.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v1.2d,v21.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v2.2d,v22.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v3.2d,v23.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v4.2d,v24.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v5.2d,v25.2d,v6.d[0] \n\t" // Scale by alpha -" \n\t" -" mov x27, x23 \n\t" // Load address of C. -" \n\t" -" st1 {v0.d}[0],[x27],x14 \n\t" // Store c40 into quad and increment by rs_c. -" st1 {v0.d}[1],[x27],x14 \n\t" // Store c41 into quad and increment by rs_c. -" st1 {v1.d}[0],[x27],x14 \n\t" // Store c42 into quad and increment by rs_c. -" st1 {v1.d}[1],[x27],x14 \n\t" // Store c43 into quad and increment by rs_c. -" st1 {v2.d}[0],[x27],x14 \n\t" // Store c44 into quad and increment by rs_c. -" st1 {v2.d}[1],[x27],x14 \n\t" // Store c45 into quad and increment by rs_c. -" \n\t" -" mov x27, x24 \n\t" // Load address of C. -" \n\t" -" st1 {v3.d}[0],[x27],x14 \n\t" // Store c50 into quad and increment by rs_c. -" st1 {v3.d}[1],[x27],x14 \n\t" // Store c51 into quad and increment by rs_c. -" st1 {v4.d}[0],[x27],x14 \n\t" // Store c52 into quad and increment by rs_c. -" st1 {v4.d}[1],[x27],x14 \n\t" // Store c53 into quad and increment by rs_c. -" st1 {v5.d}[0],[x27],x14 \n\t" // Store c54 into quad and increment by rs_c. -" st1 {v5.d}[1],[x27],x14 \n\t" // Store c55 into quad and increment by rs_c. -" \n\t" -" dup v8.2d, xzr \n\t" -" dup v9.2d, xzr \n\t" -" dup v10.2d, xzr \n\t" -" dup v11.2d, xzr \n\t" -" dup v12.2d, xzr \n\t" -" dup v13.2d, xzr \n\t" -" \n\t" -" fcmp d7,#0.0 \n\t" -BEQ(DBETAZEROGENSTOREDS4) // Taking care of the beta==0 case. -" \n\t" -" mov x27, x25 \n\t" -" \n\t" -" ld1 {v8.d}[0], [x27],x14 \n\t" // Load c60 into quad and increment by rs_c. -" ld1 {v8.d}[1], [x27],x14 \n\t" // Load c61 into quad and increment by rs_c. -" ld1 {v9.d}[0], [x27],x14 \n\t" // Load c62 into quad and increment by rs_c. -" ld1 {v9.d}[1], [x27],x14 \n\t" // Load c63 into quad and increment by rs_c. -" ld1 {v10.d}[0],[x27],x14 \n\t" // Load c64 into quad and increment by rs_c. -" ld1 {v10.d}[1],[x27],x14 \n\t" // Load c65 into quad and increment by rs_c. -" \n\t" -" mov x27, x26 \n\t" // Load address of C. -" \n\t" -" ld1 {v11.d}[0],[x27],x14 \n\t" // Load c70 into quad and increment by rs_c. -" ld1 {v11.d}[1],[x27],x14 \n\t" // Load c71 into quad and increment by rs_c. -" ld1 {v12.d}[0],[x27],x14 \n\t" // Load c72 into quad and increment by rs_c. -" ld1 {v12.d}[1],[x27],x14 \n\t" // Load c73 into quad and increment by rs_c. -" ld1 {v13.d}[0],[x27],x14 \n\t" // Load c74 into quad and increment by rs_c. -" ld1 {v13.d}[1],[x27],x14 \n\t" // Load c75 into quad and increment by rs_c. 
-" \n\t" -" fmul v8.2d, v8.2d, v7.d[0] \n\t" // Scale by beta -" fmul v9.2d, v9.2d, v7.d[0] \n\t" // Scale by beta -" fmul v10.2d,v10.2d,v7.d[0] \n\t" // Scale by beta -" fmul v11.2d,v11.2d,v7.d[0] \n\t" // Scale by beta -" fmul v12.2d,v12.2d,v7.d[0] \n\t" // Scale by beta -" fmul v13.2d,v13.2d,v7.d[0] \n\t" // Scale by beta -" \n\t" -LABEL(DBETAZEROGENSTOREDS4) -" \n\t" -" prfm pldl2keep,[x0] \n\t" -" prfm pldl2keep,[x1] \n\t" -" \n\t" -" fmla v8.2d, v26.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v9.2d, v27.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v10.2d,v28.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v11.2d,v29.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v12.2d,v30.2d,v6.d[0] \n\t" // Scale by alpha -" fmla v13.2d,v31.2d,v6.d[0] \n\t" // Scale by alpha -" \n\t" -" mov x27, x25 \n\t" // Load address of C. -" \n\t" -" st1 {v8.d}[0], [x27],x14 \n\t" // Store c60 into quad and increment by rs_c. -" st1 {v8.d}[1], [x27],x14 \n\t" // Store c61 into quad and increment by rs_c. -" st1 {v9.d}[0], [x27],x14 \n\t" // Store c62 into quad and increment by rs_c. -" st1 {v9.d}[1], [x27],x14 \n\t" // Store c63 into quad and increment by rs_c. -" st1 {v10.d}[0],[x27],x14 \n\t" // Store c64 into quad and increment by rs_c. -" st1 {v10.d}[1],[x27],x14 \n\t" // Store c65 into quad and increment by rs_c. -" \n\t" -" mov x27, x26 \n\t" // Load address of C. -" \n\t" -" st1 {v11.d}[0],[x27],x14 \n\t" // Store c70 into quad and increment by rs_c. -" st1 {v11.d}[1],[x27],x14 \n\t" // Store c71 into quad and increment by rs_c. -" st1 {v12.d}[0],[x27],x14 \n\t" // Store c72 into quad and increment by rs_c. -" st1 {v12.d}[1],[x27],x14 \n\t" // Store c73 into quad and increment by rs_c. -" st1 {v13.d}[0],[x27],x14 \n\t" // Store c74 into quad and increment by rs_c. -" st1 {v13.d}[1],[x27],x14 \n\t" // Store c75 into quad and increment by rs_c. -" \n\t" - LABEL(DEND) // Done! - " \n\t" - :// output operands (none) - :// input operands - [aaddr] "m" (a), // 0 - [baddr] "m" (b), // 1 - [caddr] "m" (c), // 2 - [k_iter] "m" (k_iter), // 3 - [k_left] "m" (k_left), // 4 - [alpha] "m" (alpha), // 5 - [beta] "m" (beta), // 6 - [rs_c] "m" (rs_c), // 6 - [cs_c] "m" (cs_c), // 7 - [a_next] "m" (a_next), // 8 - [b_next] "m" (b_next) // 9 - :// Register clobber list - "x0","x1","x2", - "x5","x6","x10", - "x14","x16","x17", - "x20","x21","x22","x23","x24","x25","x26","x27", - "v0","v1","v2", - "v3","v4","v5", - "v6","v7","v8", - "v9","v10","v11", - "v12","v13","v14", - "v15","v16","v17","v18","v19", - "v20","v21","v22","v23", - "v24","v25","v26","v27", - "v28","v29","v30","v31" - ); + __asm__ volatile + ( + " \n\t" + " ldr x0,%[aaddr] \n\t" // Load address of A + " ldr x1,%[baddr] \n\t" // Load address of B + " ldr x2,%[caddr] \n\t" // Load address of C + " \n\t" + " ldr x5,%[k_iter] \n\t" // Init guard (k_iter) + " ldr x6,%[k_left] \n\t" // Init guard (k_iter) + " \n\t" + " ldr x10,%[cs_c] \n\t" // Load cs_c + " lsl x10,x10,#3 \n\t" // cs_c * sizeof(double) + " \n\t" + " ldr x14,%[rs_c] \n\t" // Load rs_c. + " lsl x14,x14,#3 \n\t" // rs_c * sizeof(double). + " \n\t" + " add x20,x2,x10 \n\t" //Load address Column 1 of C + " add x21,x20,x10 \n\t" //Load address Column 2 of C + " add x22,x21,x10 \n\t" //Load address Column 3 of C + " add x23,x22,x10 \n\t" //Load address Column 4 of C + " add x24,x23,x10 \n\t" //Load address Column 5 of C + " add x25,x24,x10 \n\t" //Load address Column 6 of C + " add x26,x25,x10 \n\t" //Load address Column 7 of C + " \n\t" + " prfm pldl1keep,[x2] \n\t" // Prefetch c. 
+ " prfm pldl1keep,[x20] \n\t" // Prefetch c. + " prfm pldl1keep,[x21] \n\t" // Prefetch c. + " prfm pldl1keep,[x22] \n\t" // Prefetch c. + " prfm pldl1keep,[x23] \n\t" // Prefetch c. + " prfm pldl1keep,[x24] \n\t" // Prefetch c. + " prfm pldl1keep,[x25] \n\t" // Prefetch c. + " prfm pldl1keep,[x26] \n\t" // Prefetch c. + " \n\t" + " dup v8.2d, xzr \n\t" // Vector for accummulating column 0 + " prfm PLDL1KEEP, [x1, #256] \n\t" + " dup v9.2d, xzr \n\t" // Vector for accummulating column 0 + " prfm PLDL1KEEP, [x1, #320] \n\t" + " dup v10.2d, xzr \n\t" // Vector for accummulating column 0 + " prfm PLDL1KEEP, [x1, #384] \n\t" + " dup v11.2d, xzr \n\t" // Vector for accummulating column 1 + " prfm PLDL1KEEP, [x1, #448] \n\t" + " dup v12.2d, xzr \n\t" // Vector for accummulating column 1 + " dup v13.2d, xzr \n\t" // Vector for accummulating column 1 + " \n\t" + " dup v14.2d, xzr \n\t" // Vector for accummulating column 2 + " prfm PLDL1KEEP, [x0, #192] \n\t" + " dup v15.2d, xzr \n\t" // Vector for accummulating column 2 + " prfm PLDL1KEEP, [x0, #256] \n\t" + " dup v16.2d, xzr \n\t" // Vector for accummulating column 2 + " prfm PLDL1KEEP, [x0, #320] \n\t" + " dup v17.2d, xzr \n\t" // Vector for accummulating column 3 + " dup v18.2d, xzr \n\t" // Vector for accummulating column 3 + " dup v19.2d, xzr \n\t" // Vector for accummulating column 3 + " \n\t" + " dup v20.2d, xzr \n\t" // Vector for accummulating column 4 + " dup v21.2d, xzr \n\t" // Vector for accummulating column 4 + " dup v22.2d, xzr \n\t" // Vector for accummulating column 4 + " dup v23.2d, xzr \n\t" // Vector for accummulating column 5 + " dup v24.2d, xzr \n\t" // Vector for accummulating column 5 + " dup v25.2d, xzr \n\t" // Vector for accummulating column 5 + " \n\t" + " dup v26.2d, xzr \n\t" // Vector for accummulating column 6 + " dup v27.2d, xzr \n\t" // Vector for accummulating column 6 + " dup v28.2d, xzr \n\t" // Vector for accummulating column 6 + " dup v29.2d, xzr \n\t" // Vector for accummulating column 7 + " dup v30.2d, xzr \n\t" // Vector for accummulating column 7 + " dup v31.2d, xzr \n\t" // Vector for accummulating column 7 + " \n\t" + " \n\t" + " cmp x5,#0 \n\t" // If k_iter == 0, jump to k_left. + BEQ(DCONSIDERKLEFT) + " \n\t" + " ldr q0, [x0] \n\t" // Load a + " ldr q1, [x0, #16] \n\t" + " ldr q2, [x0, #32] \n\t" + " \n\t" + " ldr q3, [x1] \n\t" // Load b + " ldr q4, [x1, #16] \n\t" + " ldr q5, [x1, #32] \n\t" + " ldr q6, [x1, #48] \n\t" + " \n\t" + " add x0, x0, #48 \n\t" //update address of A + " add x1, x1, #64 \n\t" //update address of B + " \n\t" + " cmp x5,1 \n\t" // If there is just one k_iter, jump to that one. + BEQ(DLASTITER) // (as loop is do-while-like). 
+ " \n\t" + LABEL(DLOOP) // Body + " \n\t" + " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate + " prfm PLDL1KEEP, [x1, #448] \n\t" //512-64=448 + " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate + " prfm PLDL1KEEP, [x1, #512] \n\t" + " fmla v10.2d,v2.2d,v3.d[0] \n\t" // Accummulate + " prfm PLDL1KEEP, [x1, #576] \n\t" + " \n\t" + " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate + " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate + " fmla v13.2d,v2.2d,v3.d[1] \n\t" // Accummulate + " \n\t" + " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate + " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate + " fmla v16.2d,v2.2d,v4.d[0] \n\t" // Accummulate + " ldr q3, [x1] \n\t" + " \n\t" + " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate + " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate + " fmla v19.2d,v2.2d,v4.d[1] \n\t" // Accummulate + " ldr q7, [x0, #32] \n\t" + " \n\t" + " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate + " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate + " fmla v22.2d,v2.2d,v5.d[0] \n\t" // Accummulate + " ldr q4, [x1, #16] \n\t" + " \n\t" + " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate + " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate + " fmla v25.2d,v2.2d,v5.d[1] \n\t" // Accummulate + " ldr q5, [x1, #32] \n\t" + " \n\t" + " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate + " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate + " ldr q0, [x0] \n\t" + " \n\t" + " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate + " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate + " ldr q1, [x0, #16] \n\t" + " \n\t" + " fmla v28.2d,v2.2d,v6.d[0] \n\t" // Accummulate + " fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate + " ldr q6, [x1, #48] \n\t" + " \n\t" // End it 1 + " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate + " prfm PLDL1KEEP, [x1, #640] \n\t" + " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate + " prfm PLDL1KEEP, [x0, #336] \n\t" + " fmla v10.2d,v7.2d,v3.d[0] \n\t" // Accummulate + " prfm PLDL1KEEP, [x0, #400] \n\t" + " \n\t" + " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate + " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate + " fmla v13.2d,v7.2d,v3.d[1] \n\t" // Accummulate + " \n\t" + " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate + " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate + " fmla v16.2d,v7.2d,v4.d[0] \n\t" // Accummulate + " ldr q3, [x1, #64] \n\t" + " \n\t" + " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate + " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate + " fmla v19.2d,v7.2d,v4.d[1] \n\t" // Accummulate + " ldr q2, [x0, #80] \n\t" + " \n\t" + " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate + " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate + " fmla v22.2d,v7.2d,v5.d[0] \n\t" // Accummulate + " ldr q4, [x1, #80] \n\t" + " \n\t" + " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate + " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate + " fmla v25.2d,v7.2d,v5.d[1] \n\t" // Accummulate + " ldr q5, [x1, #96] \n\t" + " \n\t" + " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate + " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate + " ldr q0, [x0, #48] \n\t" + " \n\t" + " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate + " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate + " ldr q1, [x0, #64] \n\t" + " \n\t" + " fmla v28.2d,v7.2d,v6.d[0] \n\t" // Accummulate + " fmla v31.2d,v7.2d,v6.d[1] \n\t" // Accummulate + " ldr q6, [x1, #112] \n\t" + " \n\t" //End it 2 + " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate + " prfm PLDL1KEEP, [x0, #464] \n\t" + " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate + " fmla v10.2d,v2.2d,v3.d[0] \n\t" // Accummulate + " \n\t" + " fmla v11.2d,v0.2d,v3.d[1] \n\t" // 
Accummulate + " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate + " fmla v13.2d,v2.2d,v3.d[1] \n\t" // Accummulate + " \n\t" + " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate + " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate + " fmla v16.2d,v2.2d,v4.d[0] \n\t" // Accummulate + " ldr q3, [x1, #128] \n\t" + " \n\t" + " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate + " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate + " fmla v19.2d,v2.2d,v4.d[1] \n\t" // Accummulate + " ldr q7, [x0, #128] \n\t" + " \n\t" + " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate + " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate + " fmla v22.2d,v2.2d,v5.d[0] \n\t" // Accummulate + " ldr q4, [x1, #144] \n\t" + " \n\t" + " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate + " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate + " fmla v25.2d,v2.2d,v5.d[1] \n\t" // Accummulate + " ldr q5, [x1, #160] \n\t" + " \n\t" + " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate + " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate + " ldr q0, [x0, #96] \n\t" + " \n\t" + " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate + " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate + " ldr q1, [x0, #112] \n\t" + " \n\t" + " fmla v28.2d,v2.2d,v6.d[0] \n\t" // Accummulate + " fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate + " ldr q6, [x1, #176] \n\t" + " \n\t" // End it 3 + " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate + " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate + " fmla v10.2d,v7.2d,v3.d[0] \n\t" // Accummulate + " \n\t" + " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate + " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate + " fmla v13.2d,v7.2d,v3.d[1] \n\t" // Accummulate + " ldr q3, [x1, #192] \n\t" + " \n\t" + " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate + " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate + " fmla v16.2d,v7.2d,v4.d[0] \n\t" // Accummulate + " ldr q2, [x0, #176] \n\t" + " \n\t" + " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate + " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate + " fmla v19.2d,v7.2d,v4.d[1] \n\t" // Accummulate + " ldr q4, [x1, #208] \n\t" + " \n\t" + " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate + " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate + " fmla v22.2d,v7.2d,v5.d[0] \n\t" // Accummulate + " \n\t" + " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate + " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate + " fmla v25.2d,v7.2d,v5.d[1] \n\t" // Accummulate + " ldr q5, [x1, #224] \n\t" + " \n\t" + " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate + " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate + " ldr q0, [x0, #144] \n\t" + " \n\t" + " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate + " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate + " ldr q1, [x0, #160] \n\t" + " \n\t" + " fmla v28.2d,v7.2d,v6.d[0] \n\t" // Accummulate + " fmla v31.2d,v7.2d,v6.d[1] \n\t" // Accummulate + " ldr q6, [x1, #240] \n\t" + " \n\t" //End it 4 + " add x0, x0, #192 \n\t" + " add x1, x1, #256 \n\t" + " \n\t" + " sub x5,x5,1 \n\t" // i-=1 + " cmp x5,1 \n\t" // Iterate again if we are not in k_iter == 1. 
+ BNE(DLOOP) + " \n\t" + LABEL(DLASTITER) + " \n\t" + " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate + " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate + " fmla v10.2d,v2.2d,v3.d[0] \n\t" // Accummulate + " \n\t" + " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate + " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate + " fmla v13.2d,v2.2d,v3.d[1] \n\t" // Accummulate + " ldr q3, [x1] \n\t" + " \n\t" + " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate + " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate + " fmla v16.2d,v2.2d,v4.d[0] \n\t" // Accummulate + " ldr q7, [x0, #32] \n\t" + " \n\t" + " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate + " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate + " fmla v19.2d,v2.2d,v4.d[1] \n\t" // Accummulate + " ldr q4, [x1, #16] \n\t" + " \n\t" + " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate + " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate + " fmla v22.2d,v2.2d,v5.d[0] \n\t" // Accummulate + " \n\t" + " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate + " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate + " fmla v25.2d,v2.2d,v5.d[1] \n\t" // Accummulate + " ldr q5, [x1, #32] \n\t" + " \n\t" + " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate + " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate + " ldr q0, [x0] \n\t" + " \n\t" + " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate + " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate + " ldr q1, [x0, #16] \n\t" + " \n\t" + " fmla v28.2d,v2.2d,v6.d[0] \n\t" // Accummulate + " fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate + " ldr q6, [x1, #48] \n\t" + " \n\t" // End it 1 + " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate + " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate + " fmla v10.2d,v7.2d,v3.d[0] \n\t" // Accummulate + " \n\t" + " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate + " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate + " fmla v13.2d,v7.2d,v3.d[1] \n\t" // Accummulate + " ldr q3, [x1, #64] \n\t" + " \n\t" + " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate + " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate + " fmla v16.2d,v7.2d,v4.d[0] \n\t" // Accummulate + " ldr q2, [x0, #80] \n\t" + " \n\t" + " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate + " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate + " fmla v19.2d,v7.2d,v4.d[1] \n\t" // Accummulate + " ldr q4, [x1, #80] \n\t" + " \n\t" + " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate + " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate + " fmla v22.2d,v7.2d,v5.d[0] \n\t" // Accummulate + " \n\t" + " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate + " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate + " fmla v25.2d,v7.2d,v5.d[1] \n\t" // Accummulate + " ldr q5, [x1, #96] \n\t" + " \n\t" + " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate + " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate + " ldr q0, [x0, #48] \n\t" + " \n\t" + " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate + " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate + " ldr q1, [x0, #64] \n\t" + " \n\t" + " fmla v28.2d,v7.2d,v6.d[0] \n\t" // Accummulate + " fmla v31.2d,v7.2d,v6.d[1] \n\t" // Accummulate + " ldr q6, [x1, #112] \n\t" + " \n\t" //End it 2 + " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate + " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate + " fmla v10.2d,v2.2d,v3.d[0] \n\t" // Accummulate + " \n\t" + " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate + " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate + " fmla v13.2d,v2.2d,v3.d[1] \n\t" // Accummulate + " ldr q3, [x1, #128] \n\t" + " \n\t" + " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate + " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate + " fmla 
v16.2d,v2.2d,v4.d[0] \n\t" // Accummulate + " ldr q7, [x0, #128] \n\t" + " \n\t" + " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate + " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate + " fmla v19.2d,v2.2d,v4.d[1] \n\t" // Accummulate + " ldr q4, [x1, #144] \n\t" + " \n\t" + " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate + " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate + " fmla v22.2d,v2.2d,v5.d[0] \n\t" // Accummulate + " \n\t" + " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate + " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate + " fmla v25.2d,v2.2d,v5.d[1] \n\t" // Accummulate + " ldr q5, [x1, #160] \n\t" + " \n\t" + " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate + " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate + " ldr q0, [x0, #96] \n\t" + " \n\t" + " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate + " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate + " ldr q1, [x0, #112] \n\t" + " \n\t" + " fmla v28.2d,v2.2d,v6.d[0] \n\t" // Accummulate + " fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate + " ldr q6, [x1, #176] \n\t" + " \n\t" // End it 3 + " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate + " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate + " fmla v10.2d,v7.2d,v3.d[0] \n\t" // Accummulate + " \n\t" + " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate + " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate + " fmla v13.2d,v7.2d,v3.d[1] \n\t" // Accummulate + " \n\t" + " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate + " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate + " fmla v16.2d,v7.2d,v4.d[0] \n\t" // Accummulate + " \n\t" + " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate + " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate + " fmla v19.2d,v7.2d,v4.d[1] \n\t" // Accummulate + " \n\t" + " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate + " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate + " fmla v22.2d,v7.2d,v5.d[0] \n\t" // Accummulate + " \n\t" + " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate + " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate + " fmla v25.2d,v7.2d,v5.d[1] \n\t" // Accummulate + " \n\t" + " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate + " add x1, x1, #192 \n\t" + " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate + " \n\t" + " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate + " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate + " \n\t" + " fmla v28.2d,v7.2d,v6.d[0] \n\t" // Accummulate + " fmla v31.2d,v7.2d,v6.d[1] \n\t" // Accummulate + " \n\t" //End it 4 + " add x0, x0, #144 \n\t" + " \n\t" + LABEL(DCONSIDERKLEFT) + " cmp x6,0 \n\t" // If k_left == 0, we are done. + BEQ(DPOSTACCUM) // else, we enter the k_left loop. 
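[Note: a reader-oriented sketch, not part of the patch. The k dimension is
split as k = 4*k_iter + k_left before entering the assembly (x5 holds k_iter,
x6 holds k_left); the main loop above retires four rank-1 updates per trip,
and the loop below finishes the remainder one update at a time:

    dim_t k_iter = k / 4;  // trips through the 4x-unrolled DLOOP
    dim_t k_left = k % 4;  // single updates handled in DLOOPKLEFT
]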
+ " \n\t" + LABEL(DLOOPKLEFT) + " \n\t" + " ldr q0, [x0],#16 \n\t" + " ldr q1, [x0],#16 \n\t" // Load a + " ldr q2, [x0],#16 \n\t" + " \n\t" + " ldr q3, [x1],#16 \n\t" // Load b + " ldr q4, [x1],#16 \n\t" + " ldr q5, [x1],#16 \n\t" + " ldr q6, [x1],#16 \n\t" + " \n\t" + " sub x6,x6,1 \n\t" + " \n\t" + " fmla v8.2d ,v0.2d,v3.d[0] \n\t" // Accummulate + " fmla v9.2d ,v1.2d,v3.d[0] \n\t" // Accummulate + " fmla v10.2d,v2.2d,v3.d[0] \n\t" // Accummulate + " \n\t" + " fmla v11.2d,v0.2d,v3.d[1] \n\t" // Accummulate + " fmla v12.2d,v1.2d,v3.d[1] \n\t" // Accummulate + " fmla v13.2d,v2.2d,v3.d[1] \n\t" // Accummulate + " \n\t" + " fmla v14.2d,v0.2d,v4.d[0] \n\t" // Accummulate + " fmla v15.2d,v1.2d,v4.d[0] \n\t" // Accummulate + " fmla v16.2d,v2.2d,v4.d[0] \n\t" // Accummulate + " \n\t" + " fmla v17.2d,v0.2d,v4.d[1] \n\t" // Accummulate + " fmla v18.2d,v1.2d,v4.d[1] \n\t" // Accummulate + " fmla v19.2d,v2.2d,v4.d[1] \n\t" // Accummulate + " \n\t" + " fmla v20.2d,v0.2d,v5.d[0] \n\t" // Accummulate + " fmla v21.2d,v1.2d,v5.d[0] \n\t" // Accummulate + " fmla v22.2d,v2.2d,v5.d[0] \n\t" // Accummulate + " \n\t" + " fmla v23.2d,v0.2d,v5.d[1] \n\t" // Accummulate + " fmla v24.2d,v1.2d,v5.d[1] \n\t" // Accummulate + " fmla v25.2d,v2.2d,v5.d[1] \n\t" // Accummulate + " \n\t" + " fmla v26.2d,v0.2d,v6.d[0] \n\t" // Accummulate + " fmla v29.2d,v0.2d,v6.d[1] \n\t" // Accummulate + " \n\t" + " fmla v27.2d,v1.2d,v6.d[0] \n\t" // Accummulate + " fmla v30.2d,v1.2d,v6.d[1] \n\t" // Accummulate + " \n\t" + " fmla v28.2d,v2.2d,v6.d[0] \n\t" // Accummulate + " fmla v31.2d,v2.2d,v6.d[1] \n\t" // Accummulate + " \n\t" + " cmp x6,0 \n\t" // Iterate again. + BNE(DLOOPKLEFT) // if i!=0. + " \n\t" + LABEL(DPOSTACCUM) + " \n\t" + " ldr x0,%[alpha] \n\t" // Alpha address + " ldr x1,%[beta] \n\t" // Beta address + " \n\t" + " ld1r {v6.2d},[x0] \n\t" // Load alpha. + " ld1r {v7.2d},[x1] \n\t" // Load beta + " \n\t" + " ldr x0,%[a_next] \n\t" // Next A address for later use. + " ldr x1,%[b_next] \n\t" // Next B address for later use. + " \n\t" + " cmp x14,#8 \n\t" // If rs_c != 1 (column-major) + BNE(DGENSTORED) + " \n\t" + LABEL(DCOLSTORED) // C is column-major. + " \n\t" + " dup v0.2d, xzr \n\t" + " dup v1.2d, xzr \n\t" + " dup v2.2d, xzr \n\t" + " dup v3.2d, xzr \n\t" + " dup v4.2d, xzr \n\t" + " dup v5.2d, xzr \n\t" + " \n\t" + " fcmp d7,#0.0 \n\t" + BEQ(DBETAZEROCOLSTOREDS1) // Taking care of the beta==0 case. 
+ " \n\t" + " ldr q0, [x2] \n\t" //Load column 0 of C + " ldr q1, [x2, #16] \n\t" + " ldr q2, [x2, #32] \n\t" + " \n\t" + " ldr q3, [x20] \n\t" //Load column 1 of C + " ldr q4, [x20, #16] \n\t" + " ldr q5, [x20, #32] \n\t" + " \n\t" + " fmul v0.2d,v0.2d,v7.d[0] \n\t" // Scale by beta + " fmul v1.2d,v1.2d,v7.d[0] \n\t" // Scale by beta + " fmul v2.2d,v2.2d,v7.d[0] \n\t" // Scale by beta + " fmul v3.2d,v3.2d,v7.d[0] \n\t" // Scale by beta + " fmul v4.2d,v4.2d,v7.d[0] \n\t" // Scale by beta + " fmul v5.2d,v5.2d,v7.d[0] \n\t" // Scale by beta + " \n\t" + LABEL(DBETAZEROCOLSTOREDS1) + " \n\t" + " fmla v0.2d,v8.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v1.2d,v9.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v2.2d,v10.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v3.2d,v11.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v4.2d,v12.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v5.2d,v13.2d,v6.d[0] \n\t" // Scale by alpha + " \n\t" + " str q0, [x2] \n\t" //Store column 0 of C + " str q1, [x2, #16] \n\t" + " str q2, [x2, #32] \n\t" + " \n\t" + " str q3, [x20] \n\t" //Store column 1 of C + " str q4, [x20, #16] \n\t" + " str q5, [x20, #32] \n\t" + " \n\t" + " dup v8.2d, xzr \n\t" + " dup v9.2d, xzr \n\t" + " dup v10.2d, xzr \n\t" + " dup v11.2d, xzr \n\t" + " dup v12.2d, xzr \n\t" + " dup v13.2d, xzr \n\t" + " \n\t" + " fcmp d7,#0.0 \n\t" + BEQ(DBETAZEROCOLSTOREDS2) // Taking care of the beta==0 case. + " \n\t" + " ldr q8, [x21] \n\t" //Load column 2 of C + " ldr q9, [x21, #16] \n\t" + " ldr q10, [x21, #32] \n\t" + " \n\t" + " ldr q11, [x22] \n\t" //Load column 3 of C + " ldr q12, [x22, #16] \n\t" + " ldr q13, [x22, #32] \n\t" + " \n\t" + " fmul v8.2d, v8.2d, v7.d[0] \n\t" // Scale by beta + " fmul v9.2d, v9.2d, v7.d[0] \n\t" // Scale by beta + " fmul v10.2d,v10.2d,v7.d[0] \n\t" // Scale by beta + " fmul v11.2d,v11.2d,v7.d[0] \n\t" // Scale by beta + " fmul v12.2d,v12.2d,v7.d[0] \n\t" // Scale by beta + " fmul v13.2d,v13.2d,v7.d[0] \n\t" // Scale by beta + " \n\t" + LABEL(DBETAZEROCOLSTOREDS2) + " \n\t" + " fmla v8.2d, v14.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v9.2d, v15.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v10.2d,v16.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v11.2d,v17.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v12.2d,v18.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v13.2d,v19.2d,v6.d[0] \n\t" // Scale by alpha + " \n\t" + " str q8, [x21] \n\t" //Store column 2 of C + " str q9, [x21, #16] \n\t" + " str q10, [x21, #32] \n\t" + " \n\t" + " str q11, [x22] \n\t" //Store column 3 of C + " str q12, [x22, #16] \n\t" + " str q13, [x22, #32] \n\t" + " \n\t" + " dup v0.2d, xzr \n\t" + " dup v1.2d, xzr \n\t" + " dup v2.2d, xzr \n\t" + " dup v3.2d, xzr \n\t" + " dup v4.2d, xzr \n\t" + " dup v5.2d, xzr \n\t" + " \n\t" + " fcmp d7,#0.0 \n\t" + BEQ(DBETAZEROCOLSTOREDS3) // Taking care of the beta==0 case. 
+ " \n\t" + " ldr q0, [x23] \n\t" //Load column 4 of C + " ldr q1, [x23, #16] \n\t" + " ldr q2, [x23, #32] \n\t" + " \n\t" + " ldr q3, [x24] \n\t" //Load column 5 of C + " ldr q4, [x24, #16] \n\t" + " ldr q5, [x24, #32] \n\t" + " \n\t" + " fmul v0.2d,v0.2d,v7.d[0] \n\t" // Scale by beta + " fmul v1.2d,v1.2d,v7.d[0] \n\t" // Scale by beta + " fmul v2.2d,v2.2d,v7.d[0] \n\t" // Scale by beta + " fmul v3.2d,v3.2d,v7.d[0] \n\t" // Scale by beta + " fmul v4.2d,v4.2d,v7.d[0] \n\t" // Scale by beta + " fmul v5.2d,v5.2d,v7.d[0] \n\t" // Scale by beta + " \n\t" + LABEL(DBETAZEROCOLSTOREDS3) + " \n\t" + " fmla v0.2d,v20.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v1.2d,v21.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v2.2d,v22.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v3.2d,v23.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v4.2d,v24.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v5.2d,v25.2d,v6.d[0] \n\t" // Scale by alpha + " \n\t" + " str q0, [x23] \n\t" //Store column 4 of C + " str q1, [x23, #16] \n\t" + " str q2, [x23, #32] \n\t" + " \n\t" + " str q3, [x24] \n\t" //Store column 5 of C + " str q4, [x24, #16] \n\t" + " str q5, [x24, #32] \n\t" + " \n\t" + " dup v8.2d, xzr \n\t" + " dup v9.2d, xzr \n\t" + " dup v10.2d, xzr \n\t" + " dup v11.2d, xzr \n\t" + " dup v12.2d, xzr \n\t" + " dup v13.2d, xzr \n\t" + " \n\t" + " fcmp d7,#0.0 \n\t" + BEQ(DBETAZEROCOLSTOREDS4) // Taking care of the beta==0 case. + " \n\t" + " ldr q8, [x25] \n\t" //Load column 6 of C + " ldr q9, [x25, #16] \n\t" + " ldr q10, [x25, #32] \n\t" + " \n\t" + " ldr q11, [x26] \n\t" //Load column 7 of C + " ldr q12, [x26, #16] \n\t" + " ldr q13, [x26, #32] \n\t" + " \n\t" + " fmul v8.2d, v8.2d, v7.d[0] \n\t" // Scale by beta + " fmul v9.2d, v9.2d, v7.d[0] \n\t" // Scale by beta + " fmul v10.2d,v10.2d,v7.d[0] \n\t" // Scale by beta + " fmul v11.2d,v11.2d,v7.d[0] \n\t" // Scale by beta + " fmul v12.2d,v12.2d,v7.d[0] \n\t" // Scale by beta + " fmul v13.2d,v13.2d,v7.d[0] \n\t" // Scale by beta + " \n\t" + LABEL(DBETAZEROCOLSTOREDS4) + " \n\t" + " prfm pldl2keep,[x0] \n\t" + " prfm pldl2keep,[x1] \n\t" + " \n\t" + " fmla v8.2d, v26.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v9.2d, v27.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v10.2d,v28.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v11.2d,v29.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v12.2d,v30.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v13.2d,v31.2d,v6.d[0] \n\t" // Scale by alpha + " \n\t" + " str q8, [x25] \n\t" //Store column 6 of C + " str q9, [x25, #16] \n\t" + " str q10, [x25, #32] \n\t" + " \n\t" + " str q11, [x26] \n\t" //Store column 7 of C + " str q12, [x26, #16] \n\t" + " str q13, [x26, #32] \n\t" + " \n\t" + BRANCH(DEND) + " \n\t" + LABEL(DGENSTORED) // C is general-stride stored. + " \n\t" + " dup v0.2d, xzr \n\t" + " dup v1.2d, xzr \n\t" + " dup v2.2d, xzr \n\t" + " dup v3.2d, xzr \n\t" + " dup v4.2d, xzr \n\t" + " dup v5.2d, xzr \n\t" + " \n\t" + " fcmp d7,#0.0 \n\t" + BEQ(DBETAZEROGENSTOREDS1) // Taking care of the beta==0 case. + " \n\t" + " mov x27, x2 \n\t" + " \n\t" // Load address of C. + " ld1 {v0.d}[0],[x27],x14 \n\t" // Load c00 into quad and increment by rs_c. + " ld1 {v0.d}[1],[x27],x14 \n\t" // Load c01 into quad and increment by rs_c. + " ld1 {v1.d}[0],[x27],x14 \n\t" // Load c02 into quad and increment by rs_c. + " ld1 {v1.d}[1],[x27],x14 \n\t" // Load c03 into quad and increment by rs_c. + " ld1 {v2.d}[0],[x27],x14 \n\t" // Load c04 into quad and increment by rs_c. + " ld1 {v2.d}[1],[x27],x14 \n\t" // Load c05 into quad and increment by rs_c. 
+ " \n\t" + " mov x27, x20 \n\t" // Load address of C. + " \n\t" + " ld1 {v3.d}[0],[x27],x14 \n\t" // Load c10 into quad and increment by rs_c. + " ld1 {v3.d}[1],[x27],x14 \n\t" // Load c11 into quad and increment by rs_c. + " ld1 {v4.d}[0],[x27],x14 \n\t" // Load c12 into quad and increment by rs_c. + " ld1 {v4.d}[1],[x27],x14 \n\t" // Load c13 into quad and increment by rs_c. + " ld1 {v5.d}[0],[x27],x14 \n\t" // Load c14 into quad and increment by rs_c. + " ld1 {v5.d}[1],[x27],x14 \n\t" // Load c15 into quad and increment by rs_c. + " \n\t" + " fmul v0.2d,v0.2d,v7.d[0] \n\t" // Scale by beta + " fmul v1.2d,v1.2d,v7.d[0] \n\t" // Scale by beta + " fmul v2.2d,v2.2d,v7.d[0] \n\t" // Scale by beta + " fmul v3.2d,v3.2d,v7.d[0] \n\t" // Scale by beta + " fmul v4.2d,v4.2d,v7.d[0] \n\t" // Scale by beta + " fmul v5.2d,v5.2d,v7.d[0] \n\t" // Scale by beta + " \n\t" + LABEL(DBETAZEROGENSTOREDS1) + " \n\t" + " fmla v0.2d,v8.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v1.2d,v9.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v2.2d,v10.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v3.2d,v11.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v4.2d,v12.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v5.2d,v13.2d,v6.d[0] \n\t" // Scale by alpha + " \n\t" + " mov x27, x2 \n\t" // Load address of C. + " \n\t" + " st1 {v0.d}[0],[x27],x14 \n\t" // Store c00 into quad and increment by rs_c. + " st1 {v0.d}[1],[x27],x14 \n\t" // Store c01 into quad and increment by rs_c. + " st1 {v1.d}[0],[x27],x14 \n\t" // Store c02 into quad and increment by rs_c. + " st1 {v1.d}[1],[x27],x14 \n\t" // Store c03 into quad and increment by rs_c. + " st1 {v2.d}[0],[x27],x14 \n\t" // Store c04 into quad and increment by rs_c. + " st1 {v2.d}[1],[x27],x14 \n\t" // Store c05 into quad and increment by rs_c. + " \n\t" + " mov x27, x20 \n\t" // Load address of C. + " \n\t" + " st1 {v3.d}[0],[x27],x14 \n\t" // Store c10 into quad and increment by rs_c. + " st1 {v3.d}[1],[x27],x14 \n\t" // Store c11 into quad and increment by rs_c. + " st1 {v4.d}[0],[x27],x14 \n\t" // Store c12 into quad and increment by rs_c. + " st1 {v4.d}[1],[x27],x14 \n\t" // Store c13 into quad and increment by rs_c. + " st1 {v5.d}[0],[x27],x14 \n\t" // Store c14 into quad and increment by rs_c. + " st1 {v5.d}[1],[x27],x14 \n\t" // Store c15 into quad and increment by rs_c. + " \n\t" + " dup v8.2d, xzr \n\t" + " dup v9.2d, xzr \n\t" + " dup v10.2d, xzr \n\t" + " dup v11.2d, xzr \n\t" + " dup v12.2d, xzr \n\t" + " dup v13.2d, xzr \n\t" + " \n\t" + " fcmp d7,#0.0 \n\t" + BEQ(DBETAZEROGENSTOREDS2) // Taking care of the beta==0 case. + " \n\t" + " mov x27, x21 \n\t" // Load address of C. + " \n\t" + " ld1 {v8.d}[0], [x27],x14 \n\t" // Load c20 into quad and increment by rs_c. + " ld1 {v8.d}[1], [x27],x14 \n\t" // Load c21 into quad and increment by rs_c. + " ld1 {v9.d}[0], [x27],x14 \n\t" // Load c22 into quad and increment by rs_c. + " ld1 {v9.d}[1], [x27],x14 \n\t" // Load c23 into quad and increment by rs_c. + " ld1 {v10.d}[0],[x27],x14 \n\t" // Load c24 into quad and increment by rs_c. + " ld1 {v10.d}[1],[x27],x14 \n\t" // Load c25 into quad and increment by rs_c. + " \n\t" + " mov x27, x22 \n\t" // Load address of C. + " \n\t" + " ld1 {v11.d}[0],[x27],x14 \n\t" // Load c30 into quad and increment by rs_c. + " ld1 {v11.d}[1],[x27],x14 \n\t" // Load c31 into quad and increment by rs_c. + " ld1 {v12.d}[0],[x27],x14 \n\t" // Load c32 into quad and increment by rs_c. + " ld1 {v12.d}[1],[x27],x14 \n\t" // Load c33 into quad and increment by rs_c. 
+ " ld1 {v13.d}[0],[x27],x14 \n\t" // Load c34 into quad and increment by rs_c. + " ld1 {v13.d}[1],[x27],x14 \n\t" // Load c35 into quad and increment by rs_c. + " \n\t" + " fmul v8.2d, v8.2d, v7.d[0] \n\t" // Scale by beta + " fmul v9.2d, v9.2d, v7.d[0] \n\t" // Scale by beta + " fmul v10.2d,v10.2d,v7.d[0] \n\t" // Scale by beta + " fmul v11.2d,v11.2d,v7.d[0] \n\t" // Scale by beta + " fmul v12.2d,v12.2d,v7.d[0] \n\t" // Scale by beta + " fmul v13.2d,v13.2d,v7.d[0] \n\t" // Scale by beta + " \n\t" + LABEL(DBETAZEROGENSTOREDS2) + " \n\t" + " fmla v8.2d, v14.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v9.2d, v15.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v10.2d,v16.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v11.2d,v17.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v12.2d,v18.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v13.2d,v19.2d,v6.d[0] \n\t" // Scale by alpha + " \n\t" + " mov x27, x21 \n\t" // Load address of C. + " \n\t" + " st1 {v8.d}[0], [x27],x14 \n\t" // Store c20 into quad and increment by rs_c. + " st1 {v8.d}[1], [x27],x14 \n\t" // Store c21 into quad and increment by rs_c. + " st1 {v9.d}[0], [x27],x14 \n\t" // Store c22 into quad and increment by rs_c. + " st1 {v9.d}[1], [x27],x14 \n\t" // Store c23 into quad and increment by rs_c. + " st1 {v10.d}[0],[x27],x14 \n\t" // Store c24 into quad and increment by rs_c. + " st1 {v10.d}[1],[x27],x14 \n\t" // Store c25 into quad and increment by rs_c. + " \n\t" + " mov x27, x22 \n\t" // Load address of C. + " \n\t" + " st1 {v11.d}[0],[x27],x14 \n\t" // Store c30 into quad and increment by rs_c. + " st1 {v11.d}[1],[x27],x14 \n\t" // Store c31 into quad and increment by rs_c. + " st1 {v12.d}[0],[x27],x14 \n\t" // Store c32 into quad and increment by rs_c. + " st1 {v12.d}[1],[x27],x14 \n\t" // Store c33 into quad and increment by rs_c. + " st1 {v13.d}[0],[x27],x14 \n\t" // Store c34 into quad and increment by rs_c. + " st1 {v13.d}[1],[x27],x14 \n\t" // Store c35 into quad and increment by rs_c. + " \n\t" + " dup v0.2d, xzr \n\t" + " dup v1.2d, xzr \n\t" + " dup v2.2d, xzr \n\t" + " dup v3.2d, xzr \n\t" + " dup v4.2d, xzr \n\t" + " dup v5.2d, xzr \n\t" + " \n\t" + " fcmp d7,#0.0 \n\t" + BEQ(DBETAZEROGENSTOREDS3) // Taking care of the beta==0 case. + " \n\t" + " mov x27, x23 \n\t" // Load address of C. + " \n\t" + " ld1 {v0.d}[0],[x27],x14 \n\t" // Load c40 into quad and increment by rs_c. + " ld1 {v0.d}[1],[x27],x14 \n\t" // Load c41 into quad and increment by rs_c. + " ld1 {v1.d}[0],[x27],x14 \n\t" // Load c42 into quad and increment by rs_c. + " ld1 {v1.d}[1],[x27],x14 \n\t" // Load c43 into quad and increment by rs_c. + " ld1 {v2.d}[0],[x27],x14 \n\t" // Load c44 into quad and increment by rs_c. + " ld1 {v2.d}[1],[x27],x14 \n\t" // Load c45 into quad and increment by rs_c. + " \n\t" + " mov x27, x24 \n\t" // Load address of C. + " \n\t" + " ld1 {v3.d}[0],[x27],x14 \n\t" // Load c50 into quad and increment by rs_c. + " ld1 {v3.d}[1],[x27],x14 \n\t" // Load c51 into quad and increment by rs_c. + " ld1 {v4.d}[0],[x27],x14 \n\t" // Load c52 into quad and increment by rs_c. + " ld1 {v4.d}[1],[x27],x14 \n\t" // Load c53 into quad and increment by rs_c. + " ld1 {v5.d}[0],[x27],x14 \n\t" // Load c54 into quad and increment by rs_c. + " ld1 {v5.d}[1],[x27],x14 \n\t" // Load c55 into quad and increment by rs_c. 
+ " \n\t" + " fmul v0.2d,v0.2d,v7.d[0] \n\t" // Scale by beta + " fmul v1.2d,v1.2d,v7.d[0] \n\t" // Scale by beta + " fmul v2.2d,v2.2d,v7.d[0] \n\t" // Scale by beta + " fmul v3.2d,v3.2d,v7.d[0] \n\t" // Scale by beta + " fmul v4.2d,v4.2d,v7.d[0] \n\t" // Scale by beta + " fmul v5.2d,v5.2d,v7.d[0] \n\t" // Scale by beta + " \n\t" + LABEL(DBETAZEROGENSTOREDS3) + " \n\t" + " fmla v0.2d,v20.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v1.2d,v21.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v2.2d,v22.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v3.2d,v23.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v4.2d,v24.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v5.2d,v25.2d,v6.d[0] \n\t" // Scale by alpha + " \n\t" + " mov x27, x23 \n\t" // Load address of C. + " \n\t" + " st1 {v0.d}[0],[x27],x14 \n\t" // Store c40 into quad and increment by rs_c. + " st1 {v0.d}[1],[x27],x14 \n\t" // Store c41 into quad and increment by rs_c. + " st1 {v1.d}[0],[x27],x14 \n\t" // Store c42 into quad and increment by rs_c. + " st1 {v1.d}[1],[x27],x14 \n\t" // Store c43 into quad and increment by rs_c. + " st1 {v2.d}[0],[x27],x14 \n\t" // Store c44 into quad and increment by rs_c. + " st1 {v2.d}[1],[x27],x14 \n\t" // Store c45 into quad and increment by rs_c. + " \n\t" + " mov x27, x24 \n\t" // Load address of C. + " \n\t" + " st1 {v3.d}[0],[x27],x14 \n\t" // Store c50 into quad and increment by rs_c. + " st1 {v3.d}[1],[x27],x14 \n\t" // Store c51 into quad and increment by rs_c. + " st1 {v4.d}[0],[x27],x14 \n\t" // Store c52 into quad and increment by rs_c. + " st1 {v4.d}[1],[x27],x14 \n\t" // Store c53 into quad and increment by rs_c. + " st1 {v5.d}[0],[x27],x14 \n\t" // Store c54 into quad and increment by rs_c. + " st1 {v5.d}[1],[x27],x14 \n\t" // Store c55 into quad and increment by rs_c. + " \n\t" + " dup v8.2d, xzr \n\t" + " dup v9.2d, xzr \n\t" + " dup v10.2d, xzr \n\t" + " dup v11.2d, xzr \n\t" + " dup v12.2d, xzr \n\t" + " dup v13.2d, xzr \n\t" + " \n\t" + " fcmp d7,#0.0 \n\t" + BEQ(DBETAZEROGENSTOREDS4) // Taking care of the beta==0 case. + " \n\t" + " mov x27, x25 \n\t" + " \n\t" + " ld1 {v8.d}[0], [x27],x14 \n\t" // Load c60 into quad and increment by rs_c. + " ld1 {v8.d}[1], [x27],x14 \n\t" // Load c61 into quad and increment by rs_c. + " ld1 {v9.d}[0], [x27],x14 \n\t" // Load c62 into quad and increment by rs_c. + " ld1 {v9.d}[1], [x27],x14 \n\t" // Load c63 into quad and increment by rs_c. + " ld1 {v10.d}[0],[x27],x14 \n\t" // Load c64 into quad and increment by rs_c. + " ld1 {v10.d}[1],[x27],x14 \n\t" // Load c65 into quad and increment by rs_c. + " \n\t" + " mov x27, x26 \n\t" // Load address of C. + " \n\t" + " ld1 {v11.d}[0],[x27],x14 \n\t" // Load c70 into quad and increment by rs_c. + " ld1 {v11.d}[1],[x27],x14 \n\t" // Load c71 into quad and increment by rs_c. + " ld1 {v12.d}[0],[x27],x14 \n\t" // Load c72 into quad and increment by rs_c. + " ld1 {v12.d}[1],[x27],x14 \n\t" // Load c73 into quad and increment by rs_c. + " ld1 {v13.d}[0],[x27],x14 \n\t" // Load c74 into quad and increment by rs_c. + " ld1 {v13.d}[1],[x27],x14 \n\t" // Load c75 into quad and increment by rs_c. 
+ " \n\t" + " fmul v8.2d, v8.2d, v7.d[0] \n\t" // Scale by beta + " fmul v9.2d, v9.2d, v7.d[0] \n\t" // Scale by beta + " fmul v10.2d,v10.2d,v7.d[0] \n\t" // Scale by beta + " fmul v11.2d,v11.2d,v7.d[0] \n\t" // Scale by beta + " fmul v12.2d,v12.2d,v7.d[0] \n\t" // Scale by beta + " fmul v13.2d,v13.2d,v7.d[0] \n\t" // Scale by beta + " \n\t" + LABEL(DBETAZEROGENSTOREDS4) + " \n\t" + " prfm pldl2keep,[x0] \n\t" + " prfm pldl2keep,[x1] \n\t" + " \n\t" + " fmla v8.2d, v26.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v9.2d, v27.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v10.2d,v28.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v11.2d,v29.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v12.2d,v30.2d,v6.d[0] \n\t" // Scale by alpha + " fmla v13.2d,v31.2d,v6.d[0] \n\t" // Scale by alpha + " \n\t" + " mov x27, x25 \n\t" // Load address of C. + " \n\t" + " st1 {v8.d}[0], [x27],x14 \n\t" // Store c60 into quad and increment by rs_c. + " st1 {v8.d}[1], [x27],x14 \n\t" // Store c61 into quad and increment by rs_c. + " st1 {v9.d}[0], [x27],x14 \n\t" // Store c62 into quad and increment by rs_c. + " st1 {v9.d}[1], [x27],x14 \n\t" // Store c63 into quad and increment by rs_c. + " st1 {v10.d}[0],[x27],x14 \n\t" // Store c64 into quad and increment by rs_c. + " st1 {v10.d}[1],[x27],x14 \n\t" // Store c65 into quad and increment by rs_c. + " \n\t" + " mov x27, x26 \n\t" // Load address of C. + " \n\t" + " st1 {v11.d}[0],[x27],x14 \n\t" // Store c70 into quad and increment by rs_c. + " st1 {v11.d}[1],[x27],x14 \n\t" // Store c71 into quad and increment by rs_c. + " st1 {v12.d}[0],[x27],x14 \n\t" // Store c72 into quad and increment by rs_c. + " st1 {v12.d}[1],[x27],x14 \n\t" // Store c73 into quad and increment by rs_c. + " st1 {v13.d}[0],[x27],x14 \n\t" // Store c74 into quad and increment by rs_c. + " st1 {v13.d}[1],[x27],x14 \n\t" // Store c75 into quad and increment by rs_c. + " \n\t" + LABEL(DEND) // Done! 
+ " \n\t" + :// output operands (none) + :// input operands + [aaddr] "m" (a), // 0 + [baddr] "m" (b), // 1 + [caddr] "m" (c), // 2 + [k_iter] "m" (k_iter), // 3 + [k_left] "m" (k_left), // 4 + [alpha] "m" (alpha), // 5 + [beta] "m" (beta), // 6 + [rs_c] "m" (rs_c), // 6 + [cs_c] "m" (cs_c), // 7 + [a_next] "m" (a_next), // 8 + [b_next] "m" (b_next) // 9 + :// Register clobber list + "x0","x1","x2", + "x5","x6","x10", + "x14","x16","x17", + "x20","x21","x22","x23","x24","x25","x26","x27", + "v0","v1","v2", + "v3","v4","v5", + "v6","v7","v8", + "v9","v10","v11", + "v12","v13","v14", + "v15","v16","v17","v18","v19", + "v20","v21","v22","v23", + "v24","v25","v26","v27", + "v28","v29","v30","v31" + ); - GEMM_UKR_FLUSH_CT( d ); + GEMM_UKR_FLUSH_CT( d ); } diff --git a/kernels/bulldozer/3/bli_gemm_bulldozer_asm_d4x6_fma4.c b/kernels/bulldozer/3/bli_gemm_bulldozer_asm_d4x6_fma4.c index 7eba05a138..3a75d61d73 100644 --- a/kernels/bulldozer/3/bli_gemm_bulldozer_asm_d4x6_fma4.c +++ b/kernels/bulldozer/3/bli_gemm_bulldozer_asm_d4x6_fma4.c @@ -109,7 +109,7 @@ void bli_sgemm_bulldozer_asm_8x8_fma4 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; - GEMM_UKR_SETUP_CT_ALIGNED( s, 8, 8, false, 32 ); + GEMM_UKR_SETUP_CT_ALIGNED( s, 8, 8, false, 32 ); begin_asm() @@ -274,16 +274,16 @@ void bli_sgemm_bulldozer_asm_8x8_fma4 vfmaddps(ymm11, ymm0, ymm4, ymm11) vfmaddps(ymm9, ymm0, ymm5, ymm9) - vfmaddps(ymm14, ymm0, ymm2, ymm14) + vfmaddps(ymm14, ymm0, ymm2, ymm14) vperm2f128(imm(0x3), ymm2, ymm2, ymm4) vmovsldup(mem(rbx, 1*32), ymm2) add(imm(8*1*4), rbx) // b += 8 (1 x nr) - vfmaddps(ymm12, ymm0, ymm3, ymm12) + vfmaddps(ymm12, ymm0, ymm3, ymm12) vperm2f128(imm(0x3), ymm3, ymm3, ymm5) vpermilps(imm(0x4e), ymm2, ymm3) - vfmaddps(ymm10, ymm0, ymm4, ymm10) - vfmaddps(ymm8, ymm0, ymm5, ymm8) + vfmaddps(ymm10, ymm0, ymm4, ymm10) + vfmaddps(ymm8, ymm0, ymm5, ymm8) vmovaps(ymm1, ymm0) @@ -372,105 +372,105 @@ void bli_sgemm_bulldozer_asm_8x8_fma4 vucomiss(xmm0, xmm4) // set ZF if beta == 0. je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case - vmovaps(mem(rcx), ymm0) // load c00:c70, - // vmulps(ymm4, ymm0, ymm0) // scale by beta, - // vaddps(ymm15, ymm0, ymm0) // add the gemm result, - vfmaddps(ymm15, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, - vmovaps(ymm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(mem(rcx), ymm1) // load c01:c71, - // vmulps(ymm4, ymm1, ymm1) // scale by beta, - // vaddps(ymm14, ymm1, ymm1) // add the gemm result, - vfmaddps(ymm14, ymm1, ymm4, ymm1) // scale by beta and add the gemm result, - vmovaps(ymm1, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(mem(rcx), ymm0) // load c02:c72, - // vmulps(ymm4, ymm0, ymm0) // scale by beta, - // vaddps(ymm13, ymm0, ymm0) // add the gemm result, - vfmaddps(ymm13, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, - vmovaps(ymm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(mem(rcx), ymm1) // load c03:c73, - // vmulps(ymm4, ymm1, ymm1) // scale by beta, - // vaddps(ymm12, ymm1, ymm1) // add the gemm result, - vfmaddps(ymm12, ymm1, ymm4, ymm1) // scale by beta and add the gemm result, - vmovaps(ymm1, mem(rcx)) // and store back to memory. 
- add(rdi, rcx) // c += cs_c; - - vmovaps(mem(rcx), ymm0) // load c04:c74, - // vmulps(ymm4, ymm0, ymm0) // scale by beta, - // vaddps(ymm11, ymm0, ymm0) // add the gemm result, - vfmaddps(ymm11, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, - vmovaps(ymm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(mem(rcx), ymm1) // load c05:c75, - // vmulps(ymm4, ymm1, ymm1) // scale by beta, - // vaddps(ymm10, ymm1, ymm1) // add the gemm result, - vfmaddps(ymm10, ymm1, ymm4, ymm1) // scale by beta and add the gemm result, - vmovaps(ymm1, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(mem(rcx), ymm0) // load c06:c76, - // vmulps(ymm4, ymm0, ymm0) // scale by beta, - // vaddps(ymm9, ymm0, ymm0) // add the gemm result, - vfmaddps(ymm9, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, - vmovaps(ymm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; - - vmovaps(mem(rcx), ymm1) // load c07:c77, - // vmulps(ymm4, ymm1, ymm1) // scale by beta, - // vaddps(ymm8, ymm1, ymm1) // add the gemm result, - vfmaddps(ymm8, ymm1, ymm4, ymm1) // scale by beta and add the gemm result, - vmovaps(ymm1, mem(rcx)) // and store back to memory. - - jmp(.SDONE) // jump to end. + vmovaps(mem(rcx), ymm0) // load c00:c70, + //vmulps(ymm4, ymm0, ymm0) // scale by beta, + //vaddps(ymm15, ymm0, ymm0) // add the gemm result, + vfmaddps(ymm15, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, + vmovaps(ymm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(mem(rcx), ymm1) // load c01:c71, + //vmulps(ymm4, ymm1, ymm1) // scale by beta, + //vaddps(ymm14, ymm1, ymm1) // add the gemm result, + vfmaddps(ymm14, ymm1, ymm4, ymm1) // scale by beta and add the gemm result, + vmovaps(ymm1, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(mem(rcx), ymm0) // load c02:c72, + //vmulps(ymm4, ymm0, ymm0) // scale by beta, + //vaddps(ymm13, ymm0, ymm0) // add the gemm result, + vfmaddps(ymm13, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, + vmovaps(ymm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(mem(rcx), ymm1) // load c03:c73, + //vmulps(ymm4, ymm1, ymm1) // scale by beta, + //vaddps(ymm12, ymm1, ymm1) // add the gemm result, + vfmaddps(ymm12, ymm1, ymm4, ymm1) // scale by beta and add the gemm result, + vmovaps(ymm1, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(mem(rcx), ymm0) // load c04:c74, + //vmulps(ymm4, ymm0, ymm0) // scale by beta, + //vaddps(ymm11, ymm0, ymm0) // add the gemm result, + vfmaddps(ymm11, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, + vmovaps(ymm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(mem(rcx), ymm1) // load c05:c75, + //vmulps(ymm4, ymm1, ymm1) // scale by beta, + //vaddps(ymm10, ymm1, ymm1) // add the gemm result, + vfmaddps(ymm10, ymm1, ymm4, ymm1) // scale by beta and add the gemm result, + vmovaps(ymm1, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; + + vmovaps(mem(rcx), ymm0) // load c06:c76, + //vmulps(ymm4, ymm0, ymm0) // scale by beta, + //vaddps(ymm9, ymm0, ymm0) // add the gemm result, + vfmaddps(ymm9, ymm0, ymm4, ymm0) // scale by beta and add the gemm result, + vmovaps(ymm0, mem(rcx)) // and store back to memory. 
+ add(rdi, rcx) // c += cs_c; + + vmovaps(mem(rcx), ymm1) // load c07:c77, + //vmulps(ymm4, ymm1, ymm1) // scale by beta, + //vaddps(ymm8, ymm1, ymm1) // add the gemm result, + vfmaddps(ymm8, ymm1, ymm4, ymm1) // scale by beta and add the gemm result, + vmovaps(ymm1, mem(rcx)) // and store back to memory. + + jmp(.SDONE) // jump to end. label(.SBETAZERO) - vmovaps(ymm15, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; + vmovaps(ymm15, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; - vmovaps(ymm14, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; + vmovaps(ymm14, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; - vmovaps(ymm13, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; + vmovaps(ymm13, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; - vmovaps(ymm12, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; + vmovaps(ymm12, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; - vmovaps(ymm11, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; + vmovaps(ymm11, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; - vmovaps(ymm10, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; + vmovaps(ymm10, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; - vmovaps(ymm9, mem(rcx)) // and store back to memory. - add(rdi, rcx) // c += cs_c; + vmovaps(ymm9, mem(rcx)) // and store back to memory. + add(rdi, rcx) // c += cs_c; - vmovaps(ymm8, mem(rcx)) // and store back to memory. + vmovaps(ymm8, mem(rcx)) // and store back to memory. label(.SDONE) - end_asm( + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c)/*, // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next)*/ // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c)/*, // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next)*/ // 10 : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", @@ -481,7 +481,7 @@ void bli_sgemm_bulldozer_asm_8x8_fma4 "memory" ) - GEMM_UKR_FLUSH_CT( s ); + GEMM_UKR_FLUSH_CT( s ); } #undef KERNEL4x6_1 @@ -601,7 +601,7 @@ void bli_dgemm_bulldozer_asm_4x6_fma4 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; - GEMM_UKR_SETUP_CT_ANY( d, 4, 6, false ); + GEMM_UKR_SETUP_CT_ANY( d, 4, 6, false ); begin_asm() @@ -625,18 +625,18 @@ void bli_dgemm_bulldozer_asm_4x6_fma4 ALIGN32 label(.LOOPKITER) // MAIN LOOP - KERNEL4x6_1(xx) - KERNEL4x6_2(xx) - KERNEL4x6_3(xx) - KERNEL4x6_4(xx) - KERNEL4x6_1(xx) - KERNEL4x6_2(xx) - KERNEL4x6_3(xx) - KERNEL4x6_4(xx) - KERNEL4x6_1(xx) - KERNEL4x6_2(xx) - KERNEL4x6_3(xx) - KERNEL4x6_4(xx) + KERNEL4x6_1(xx) + KERNEL4x6_2(xx) + KERNEL4x6_3(xx) + KERNEL4x6_4(xx) + KERNEL4x6_1(xx) + KERNEL4x6_2(xx) + KERNEL4x6_3(xx) + KERNEL4x6_4(xx) + KERNEL4x6_1(xx) + KERNEL4x6_2(xx) + KERNEL4x6_3(xx) + KERNEL4x6_4(xx) dec(rsi) jne(.LOOPKITER) @@ -648,7 +648,7 @@ void bli_dgemm_bulldozer_asm_4x6_fma4 label(.LOOPKLEFT) je(.POSTACCUM) - KERNEL4x6_1(xx) + KERNEL4x6_1(xx) add(imm(6*8), rbx) add(imm(4*8), rax) @@ -757,30 +757,30 @@ void 
bli_dgemm_bulldozer_asm_4x6_fma4 vmovhpd(xmm14, mem(rdx, rdi, 1)) vmovhpd(xmm15, mem(r8, rdi, 1)) - end_asm( - : // output operands (none) - : // input operands - [k_iter] "r" (k_iter), // 0 - [k_left] "r" (k_left), // 1 - [a] "r" (a), // 2 - [b] "r" (b), // 3 - [alpha] "r" (alpha), // 4 - [beta] "r" (beta), // 5 - [c] "r" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c)/*, // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next)*/ // 10 - : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", - "xmm0", "xmm1", "xmm2", "xmm3", - "xmm4", "xmm5", "xmm6", "xmm7", - "xmm8", "xmm9", "xmm10", "xmm11", - "xmm12", "xmm13", "xmm14", "xmm15", - "memory" + end_asm( + : // output operands (none) + : // input operands + [k_iter] "r" (k_iter), // 0 + [k_left] "r" (k_left), // 1 + [a] "r" (a), // 2 + [b] "r" (b), // 3 + [alpha] "r" (alpha), // 4 + [beta] "r" (beta), // 5 + [c] "r" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c)/*, // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next)*/ // 10 + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" ) - GEMM_UKR_FLUSH_CT( d ); + GEMM_UKR_FLUSH_CT( d ); } //The parameter "i" is the iteration number, i.e. the B values to read #define MADD_TO_YMM(i) \ @@ -824,7 +824,7 @@ void bli_cgemm_bulldozer_asm_8x4_fma4 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; - GEMM_UKR_SETUP_CT_ALIGNED( c, 8, 4, false, 32 ); + GEMM_UKR_SETUP_CT_ALIGNED( c, 8, 4, false, 32 ); begin_asm() @@ -1170,123 +1170,123 @@ void bli_cgemm_bulldozer_asm_8x4_fma4 and(r8b, r9b) // set ZF if r8b & r9b == 1. jne(.CBETAZERO) // if ZF = 0, jump to beta == 0 case - // update c00:c70 - - vmovaps(mem(rcx), ymm0) // load c00:c70 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm15, ymm0, ymm0) // add the gemm result to ymm0 - vmovaps(ymm0, mem(rcx)) // store c00:c70 - - // update c80:cf0 - - vmovaps(mem(rcx,32), ymm0) // load c80:f0 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm14, ymm0, ymm0) // add the gemm result to ymm0 - vmovaps(ymm0, mem(rcx,32)) // store c80:cf0 - add(rdi, rcx) // c += cs_c; - - // update c00:c70 - - vmovaps(mem(rcx), ymm0) // load c01:c71 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm13, ymm0, ymm0) // add the gemm result to ymm0 - vmovaps(ymm0, mem(rcx)) // store c01:c71 - - // update c81:cf1 - - vmovaps(mem(rcx,32), ymm0) // load c81:f1 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm12, ymm0, ymm0) // add the gemm result to ymm0 - vmovaps(ymm0, mem(rcx,32)) // store c81:cf1 - add(rdi, rcx) // c += cs_c; - - // update c02:c72 - vmovaps(mem(rcx), ymm0) // load c02:c72 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm11, ymm0, ymm0) // add the gemm result to ymm0 - vmovaps(ymm0, mem(rcx)) // store c02:c72 - - // update c82:cf2 - vmovaps(mem(rcx,32), ymm0) // load c82:f2 into ymm0 - 
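[Note: a reader-oriented sketch, not part of the patch. The
vpermilps/vmulps/vaddsubps triplets in the update below implement scaling of
complex C by beta with interleaved (real, imaginary) lanes: ymm7 and ymm6
appear to hold the broadcast real and imaginary parts of beta, the 0xb1
permute swaps each (re, im) pair, and vaddsubps subtracts in even lanes and
adds in odd lanes. Per complex element:

    float re = c_re, im = c_im;
    c_re = re * beta_r - im * beta_i;
    c_im = im * beta_r + re * beta_i;
]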
vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm10, ymm0, ymm0) // add the gemm result to ymm0 - vmovaps(ymm0, mem(rcx,32)) // store c82:cf2 - add(rdi, rcx) // c += cs_c; - - // update c03:c73 - vmovaps(mem(rcx), ymm0) // load c03:c73 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm9, ymm0, ymm0) // add the gemm result to ymm0 - vmovaps(ymm0, mem(rcx)) // store c03:c73 - - // update c83:cf3 - vmovaps(mem(rcx,32), ymm0) // load c83:f3 into ymm0 - vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta - vmulps(ymm7, ymm0, ymm0) - vmulps(ymm6, ymm2, ymm2) - vaddsubps(ymm2, ymm0, ymm0) - vaddps(ymm8, ymm0, ymm0) // add the gemm result to ymm0 - vmovaps(ymm0, mem(rcx,32)) // store c83:cf3 - //add(rdi, rcx) // c += cs_c; - - jmp(.CDONE) // jump to end. + // update c00:c70 + + vmovaps(mem(rcx), ymm0) // load c00:c70 into ymm0 + vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta + vmulps(ymm7, ymm0, ymm0) + vmulps(ymm6, ymm2, ymm2) + vaddsubps(ymm2, ymm0, ymm0) + vaddps(ymm15, ymm0, ymm0) // add the gemm result to ymm0 + vmovaps(ymm0, mem(rcx)) // store c00:c70 + + // update c80:cf0 + + vmovaps(mem(rcx,32), ymm0) // load c80:f0 into ymm0 + vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta + vmulps(ymm7, ymm0, ymm0) + vmulps(ymm6, ymm2, ymm2) + vaddsubps(ymm2, ymm0, ymm0) + vaddps(ymm14, ymm0, ymm0) // add the gemm result to ymm0 + vmovaps(ymm0, mem(rcx,32)) // store c80:cf0 + add(rdi, rcx) // c += cs_c; + + // update c00:c70 + + vmovaps(mem(rcx), ymm0) // load c01:c71 into ymm0 + vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta + vmulps(ymm7, ymm0, ymm0) + vmulps(ymm6, ymm2, ymm2) + vaddsubps(ymm2, ymm0, ymm0) + vaddps(ymm13, ymm0, ymm0) // add the gemm result to ymm0 + vmovaps(ymm0, mem(rcx)) // store c01:c71 + + // update c81:cf1 + + vmovaps(mem(rcx,32), ymm0) // load c81:f1 into ymm0 + vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta + vmulps(ymm7, ymm0, ymm0) + vmulps(ymm6, ymm2, ymm2) + vaddsubps(ymm2, ymm0, ymm0) + vaddps(ymm12, ymm0, ymm0) // add the gemm result to ymm0 + vmovaps(ymm0, mem(rcx,32)) // store c81:cf1 + add(rdi, rcx) // c += cs_c; + + // update c02:c72 + vmovaps(mem(rcx), ymm0) // load c02:c72 into ymm0 + vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta + vmulps(ymm7, ymm0, ymm0) + vmulps(ymm6, ymm2, ymm2) + vaddsubps(ymm2, ymm0, ymm0) + vaddps(ymm11, ymm0, ymm0) // add the gemm result to ymm0 + vmovaps(ymm0, mem(rcx)) // store c02:c72 + + // update c82:cf2 + vmovaps(mem(rcx,32), ymm0) // load c82:f2 into ymm0 + vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta + vmulps(ymm7, ymm0, ymm0) + vmulps(ymm6, ymm2, ymm2) + vaddsubps(ymm2, ymm0, ymm0) + vaddps(ymm10, ymm0, ymm0) // add the gemm result to ymm0 + vmovaps(ymm0, mem(rcx,32)) // store c82:cf2 + add(rdi, rcx) // c += cs_c; + + // update c03:c73 + vmovaps(mem(rcx), ymm0) // load c03:c73 into ymm0 + vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta + vmulps(ymm7, ymm0, ymm0) + vmulps(ymm6, ymm2, ymm2) + vaddsubps(ymm2, ymm0, ymm0) + vaddps(ymm9, ymm0, ymm0) // add the gemm result to ymm0 + vmovaps(ymm0, mem(rcx)) // store c03:c73 + + // update c83:cf3 + vmovaps(mem(rcx,32), ymm0) // load c83:f3 into ymm0 + vpermilps(imm(0xb1), ymm0, ymm2) // scale ymm0 by beta + vmulps(ymm7, ymm0, ymm0) + vmulps(ymm6, ymm2, ymm2) + vaddsubps(ymm2, ymm0, ymm0) + vaddps(ymm8, ymm0, ymm0) // add 
the gemm result to ymm0 + vmovaps(ymm0, mem(rcx,32)) // store c83:cf3 + //add(rdi, rcx) // c += cs_c; + + jmp(.CDONE) // jump to end. label(.CBETAZERO) - vmovaps(ymm15, mem(rcx)) // store c00:c70 - vmovaps(ymm14, mem(rcx,32)) // store c80:cf0 - add(rdi, rcx) // c += cs_c; + vmovaps(ymm15, mem(rcx)) // store c00:c70 + vmovaps(ymm14, mem(rcx,32)) // store c80:cf0 + add(rdi, rcx) // c += cs_c; - vmovaps(ymm13, mem(rcx)) // store c01:c71 - vmovaps(ymm12, mem(rcx,32)) // store c81:cf1 - add(rdi, rcx) // c += cs_c; + vmovaps(ymm13, mem(rcx)) // store c01:c71 + vmovaps(ymm12, mem(rcx,32)) // store c81:cf1 + add(rdi, rcx) // c += cs_c; - vmovaps(ymm11, mem(rcx)) // store c02:c72 - vmovaps(ymm10, mem(rcx,32)) // store c82:cf2 - add(rdi, rcx) // c += cs_c; + vmovaps(ymm11, mem(rcx)) // store c02:c72 + vmovaps(ymm10, mem(rcx,32)) // store c82:cf2 + add(rdi, rcx) // c += cs_c; - vmovaps(ymm9, mem(rcx)) // store c03:c73 - vmovaps(ymm8, mem(rcx,32)) // store c83:cf3 - add(rdi, rcx) // c += cs_c; + vmovaps(ymm9, mem(rcx)) // store c03:c73 + vmovaps(ymm8, mem(rcx,32)) // store c83:cf3 + add(rdi, rcx) // c += cs_c; label(.CDONE) - end_asm( + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c), // 8 - [b_next] "m" (b_next)/*, // 9 - [a_next] "m" (a_next)*/ // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c), // 8 + [b_next] "m" (b_next)/*, // 9 + [a_next] "m" (a_next)*/ // 10 : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", @@ -1297,7 +1297,7 @@ void bli_cgemm_bulldozer_asm_8x4_fma4 "memory" ) - GEMM_UKR_FLUSH_CT( c ); + GEMM_UKR_FLUSH_CT( c ); } #define MADDSUBPD_TO_YMM \ @@ -1348,7 +1348,7 @@ void bli_zgemm_bulldozer_asm_4x4_fma4 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; - GEMM_UKR_SETUP_CT_ALIGNED( z, 4, 4, false, 32 ); + GEMM_UKR_SETUP_CT_ALIGNED( z, 4, 4, false, 32 ); begin_asm() @@ -1649,90 +1649,90 @@ void bli_zgemm_bulldozer_asm_4x4_fma4 and(r8b, r9b) // set ZF if r8b & r9b == 1. 
jne(.ZBETAZERO) // if ZF = 0, jump to beta == 0 case - // update c00:c30 + // update c00:c30 - vmovapd(mem(rcx), ymm0) // load c00:c30 into ymm0 - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm15, ymm0, ymm0) // add the gemm result to ymm0 - vmovapd(ymm0, mem(rcx)) // store c00:c30 + vmovapd(mem(rcx), ymm0) // load c00:c30 into ymm0 + Z_ALPHA(0, 2) // scale ymm0 by beta + vaddpd(ymm15, ymm0, ymm0) // add the gemm result to ymm0 + vmovapd(ymm0, mem(rcx)) // store c00:c30 - // update c40:c70 + // update c40:c70 - vmovapd(mem(rcx,32), ymm0) // load c40:c70 into ymm0 - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm14, ymm0, ymm0) // add the gemm result to ymm0 - vmovapd(ymm0, mem(rcx,32)) // store c40:c70 - add(rdi, rcx) // c += cs_c; + vmovapd(mem(rcx,32), ymm0) // load c40:c70 into ymm0 + Z_ALPHA(0, 2) // scale ymm0 by beta + vaddpd(ymm14, ymm0, ymm0) // add the gemm result to ymm0 + vmovapd(ymm0, mem(rcx,32)) // store c40:c70 + add(rdi, rcx) // c += cs_c; - // update c01:c31 + // update c01:c31 - vmovapd(mem(rcx), ymm0) // load c01:c31 into ymm0 - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm13, ymm0, ymm0) // add the gemm result to ymm0 - vmovapd(ymm0, mem(rcx)) // store c01:c31 + vmovapd(mem(rcx), ymm0) // load c01:c31 into ymm0 + Z_ALPHA(0, 2) // scale ymm0 by beta + vaddpd(ymm13, ymm0, ymm0) // add the gemm result to ymm0 + vmovapd(ymm0, mem(rcx)) // store c01:c31 - // update c41:c71 + // update c41:c71 - vmovapd(mem(rcx,32), ymm0) // load c41:c71 into ymm0 - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm12, ymm0, ymm0) // add the gemm result to ymm0 - vmovapd(ymm0, mem(rcx,32)) // store c41:c71 - add(rdi, rcx) // c += cs_c; + vmovapd(mem(rcx,32), ymm0) // load c41:c71 into ymm0 + Z_ALPHA(0, 2) // scale ymm0 by beta + vaddpd(ymm12, ymm0, ymm0) // add the gemm result to ymm0 + vmovapd(ymm0, mem(rcx,32)) // store c41:c71 + add(rdi, rcx) // c += cs_c; - // update c02:c32 + // update c02:c32 - vmovapd(mem(rcx), ymm0) // load c02:c32 into ymm0 - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm11, ymm0, ymm0) // add the gemm result to ymm0 - vmovapd(ymm0, mem(rcx)) // store c02:c32 + vmovapd(mem(rcx), ymm0) // load c02:c32 into ymm0 + Z_ALPHA(0, 2) // scale ymm0 by beta + vaddpd(ymm11, ymm0, ymm0) // add the gemm result to ymm0 + vmovapd(ymm0, mem(rcx)) // store c02:c32 - // update c42:c72 + // update c42:c72 - vmovapd(mem(rcx,32), ymm0) // load c42:c72 into ymm0 - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm10, ymm0, ymm0) // add the gemm result to ymm0 - vmovapd(ymm0, mem(rcx,32)) // store c42:c72 - add(rdi, rcx) // c += cs_c; + vmovapd(mem(rcx,32), ymm0) // load c42:c72 into ymm0 + Z_ALPHA(0, 2) // scale ymm0 by beta + vaddpd(ymm10, ymm0, ymm0) // add the gemm result to ymm0 + vmovapd(ymm0, mem(rcx,32)) // store c42:c72 + add(rdi, rcx) // c += cs_c; - // update c03:c33 + // update c03:c33 - vmovapd(mem(rcx), ymm0) // load c03:c33 into ymm0 - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm9, ymm0, ymm0) // add the gemm result to ymm0 - vmovapd(ymm0, mem(rcx)) // store c03:c33 + vmovapd(mem(rcx), ymm0) // load c03:c33 into ymm0 + Z_ALPHA(0, 2) // scale ymm0 by beta + vaddpd(ymm9, ymm0, ymm0) // add the gemm result to ymm0 + vmovapd(ymm0, mem(rcx)) // store c03:c33 - // update c43:c73 + // update c43:c73 - vmovapd(mem(rcx,32), ymm0) // load c43:c73 into ymm0 - Z_ALPHA(0, 2) // scale ymm0 by beta - vaddpd(ymm8, ymm0, ymm0) // add the gemm result to ymm0 - vmovapd(ymm0, mem(rcx,32)) // store c43:c73 - add(rdi, rcx) // c += cs_c; + vmovapd(mem(rcx,32), ymm0) // load c43:c73 into 
ymm0 + Z_ALPHA(0, 2) // scale ymm0 by beta + vaddpd(ymm8, ymm0, ymm0) // add the gemm result to ymm0 + vmovapd(ymm0, mem(rcx,32)) // store c43:c73 + add(rdi, rcx) // c += cs_c; - jmp(.ZDONE) // jump to end. + jmp(.ZDONE) // jump to end. label(.ZBETAZERO) - vmovapd(ymm15, mem(rcx)) // store c00:c30 - vmovapd(ymm14, mem(rcx,32)) // store c40:c70 - add(rdi, rcx) // c += cs_c; + vmovapd(ymm15, mem(rcx)) // store c00:c30 + vmovapd(ymm14, mem(rcx,32)) // store c40:c70 + add(rdi, rcx) // c += cs_c; - vmovapd(ymm13, mem(rcx)) // store c01:c31 - vmovapd(ymm12, mem(rcx,32)) // store c41:c71 - add(rdi, rcx) // c += cs_c; + vmovapd(ymm13, mem(rcx)) // store c01:c31 + vmovapd(ymm12, mem(rcx,32)) // store c41:c71 + add(rdi, rcx) // c += cs_c; - vmovapd(ymm11, mem(rcx)) // store c02:c32 - vmovapd(ymm10, mem(rcx,32)) // store c42:c72 - add(rdi, rcx) // c += cs_c; + vmovapd(ymm11, mem(rcx)) // store c02:c32 + vmovapd(ymm10, mem(rcx,32)) // store c42:c72 + add(rdi, rcx) // c += cs_c; - vmovapd(ymm9, mem(rcx)) // store c03:c33 - vmovapd(ymm8, mem(rcx,32)) // store c43:c73 - //add(rdi, rcx) // c += cs_c; + vmovapd(ymm9, mem(rcx)) // store c03:c33 + vmovapd(ymm8, mem(rcx,32)) // store c43:c73 + //add(rdi, rcx) // c += cs_c; label(.ZDONE) - end_asm( + end_asm( : // output operands (none) : // input operands [k_iter] "m" (k_iter), // 0 @@ -1756,6 +1756,6 @@ void bli_zgemm_bulldozer_asm_4x4_fma4 "memory" ) - GEMM_UKR_FLUSH_CT( z ); + GEMM_UKR_FLUSH_CT( z ); } diff --git a/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c b/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c index 897f6fd778..d0e7938678 100644 --- a/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c +++ b/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c @@ -101,7 +101,7 @@ void bli_sgemm_haswell_asm_6x16 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; - GEMM_UKR_SETUP_CT_AMBI( s, 6, 16, true ); + GEMM_UKR_SETUP_CT_AMBI( s, 6, 16, true ); begin_asm() @@ -113,7 +113,7 @@ void bli_sgemm_haswell_asm_6x16 //mov(%9, r15) // load address of b_next. add(imm(32*4), rbx) - // initialize loop by pre-loading + // initialize loop by pre-loading vmovaps(mem(rbx, -4*32), ymm0) vmovaps(mem(rbx, -3*32), ymm1) @@ -123,55 +123,55 @@ void bli_sgemm_haswell_asm_6x16 mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 4), rsi) // cs_c *= sizeof(float) - cmp(imm(4), rsi) // set ZF if (4*cs_c) == 4. 
- jz(.SCOLPREFETCH) // jump to column prefetch case - - lea(mem(rdi, rdi, 2), r13) // r13 = 3*rs_c; - lea(mem(rcx, r13, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c - prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c - prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c - prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c - prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c - prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c - - jmp(.SPREFETCHDONE) - - label(.SCOLPREFETCH) - - lea(mem(rsi, rsi, 2), r13) // r13 = 3*cs_c; - lea(mem(rcx, r13, 1), rdx) // rdx = c + 3*cs_c; - prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*cs_c - prefetch(0, mem(rcx, rsi, 1, 7*8)) // prefetch c + 1*cs_c - prefetch(0, mem(rcx, rsi, 2, 7*8)) // prefetch c + 2*cs_c - prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*cs_c - prefetch(0, mem(rdx, rsi, 1, 7*8)) // prefetch c + 4*cs_c - prefetch(0, mem(rdx, rsi, 2, 7*8)) // prefetch c + 5*cs_c - prefetch(0, mem(rdx, r13, 1, 7*8)) // prefetch c + 6*cs_c - prefetch(0, mem(rdx, rsi, 4, 7*8)) // prefetch c + 7*cs_c - lea(mem(rcx, rsi, 8), r14) // r14 = c + 8*cs_c; - lea(mem(r14, r13, 1), rdx) // rdx = c + 11*cs_c; - prefetch(0, mem(r14, 7*8)) // prefetch c + 8*cs_c - prefetch(0, mem(r14, rsi, 1, 7*8)) // prefetch c + 9*cs_c - prefetch(0, mem(r14, rsi, 2, 7*8)) // prefetch c + 10*cs_c - prefetch(0, mem(rdx, 7*8)) // prefetch c + 11*cs_c - prefetch(0, mem(rdx, rsi, 1, 7*8)) // prefetch c + 12*cs_c - prefetch(0, mem(rdx, rsi, 2, 7*8)) // prefetch c + 13*cs_c - prefetch(0, mem(rdx, r13, 1, 7*8)) // prefetch c + 14*cs_c - prefetch(0, mem(rdx, rsi, 4, 7*8)) // prefetch c + 15*cs_c - - label(.SPREFETCHDONE) + cmp(imm(4), rsi) // set ZF if (4*cs_c) == 4. + jz(.SCOLPREFETCH) // jump to column prefetch case + + lea(mem(rdi, rdi, 2), r13) // r13 = 3*rs_c; + lea(mem(rcx, r13, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c + + jmp(.SPREFETCHDONE) + + label(.SCOLPREFETCH) + + lea(mem(rsi, rsi, 2), r13) // r13 = 3*cs_c; + lea(mem(rcx, r13, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 7*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 7*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 7*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 7*8)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, r13, 1, 7*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 4, 7*8)) // prefetch c + 7*cs_c + lea(mem(rcx, rsi, 8), r14) // r14 = c + 8*cs_c; + lea(mem(r14, r13, 1), rdx) // rdx = c + 11*cs_c; + prefetch(0, mem(r14, 7*8)) // prefetch c + 8*cs_c + prefetch(0, mem(r14, rsi, 1, 7*8)) // prefetch c + 9*cs_c + prefetch(0, mem(r14, rsi, 2, 7*8)) // prefetch c + 10*cs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 11*cs_c + prefetch(0, mem(rdx, rsi, 1, 7*8)) // prefetch c + 12*cs_c + prefetch(0, mem(rdx, rsi, 2, 7*8)) // prefetch c + 13*cs_c + prefetch(0, mem(rdx, r13, 1, 7*8)) // prefetch c + 14*cs_c + prefetch(0, mem(rdx, rsi, 4, 7*8)) // prefetch c + 15*cs_c + + label(.SPREFETCHDONE) mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. 
je(.SCONSIDKLEFT) // if i == 0, jump to code that - // contains the k_left loop. + // contains the k_left loop. label(.SLOOPKITER) // MAIN LOOP - // iteration 0 + // iteration 0 prefetch(0, mem(rax, 64*4)) vbroadcastss(mem(rax, 0*4), ymm2) @@ -198,7 +198,7 @@ void bli_sgemm_haswell_asm_6x16 vmovaps(mem(rbx, -2*32), ymm0) vmovaps(mem(rbx, -1*32), ymm1) - // iteration 1 + // iteration 1 vbroadcastss(mem(rax, 6*4), ymm2) vbroadcastss(mem(rax, 7*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) @@ -223,7 +223,7 @@ void bli_sgemm_haswell_asm_6x16 vmovaps(mem(rbx, 0*32), ymm0) vmovaps(mem(rbx, 1*32), ymm1) - // iteration 2 + // iteration 2 prefetch(0, mem(rax, 76*4)) vbroadcastss(mem(rax, 12*4), ymm2) @@ -250,7 +250,7 @@ void bli_sgemm_haswell_asm_6x16 vmovaps(mem(rbx, 2*32), ymm0) vmovaps(mem(rbx, 3*32), ymm1) - // iteration 3 + // iteration 3 vbroadcastss(mem(rax, 18*4), ymm2) vbroadcastss(mem(rax, 19*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) @@ -292,7 +292,7 @@ void bli_sgemm_haswell_asm_6x16 mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.SPOSTACCUM) // if i == 0, we're done; jump to end. - // else, we prepare to enter k_left loop. + // else, we prepare to enter k_left loop. label(.SLOOPKLEFT) // EDGE LOOP @@ -371,330 +371,330 @@ void bli_sgemm_haswell_asm_6x16 lea(mem(r13, rsi, 4), r10) // r10 = 7*cs_c; - // now avoid loading C if beta == 0 + // now avoid loading C if beta == 0 vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomiss(xmm0, xmm3) // set ZF if beta == 0. je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case - cmp(imm(4), rdi) // set ZF if (4*cs_c) == 4. - jz(.SCOLSTORED) // jump to column storage case - - vfmadd231ps(mem(rcx), ymm3, ymm4) - vmovups(ymm4, mem(rcx)) - vfmadd231ps(mem(rcx,32), ymm3, ymm5) - vmovups(ymm5, mem(rcx,32)) - add(rdi, rcx) - - - vfmadd231ps(mem(rcx), ymm3, ymm6) - vmovups(ymm6, mem(rcx)) - vfmadd231ps(mem(rcx,32), ymm3, ymm7) - vmovups(ymm7, mem(rcx,32)) - add(rdi, rcx) - - - vfmadd231ps(mem(rcx), ymm3, ymm8) - vmovups(ymm8, mem(rcx)) - vfmadd231ps(mem(rcx,32), ymm3, ymm9) - vmovups(ymm9, mem(rcx,32)) - add(rdi, rcx) - - - vfmadd231ps(mem(rcx), ymm3, ymm10) - vmovups(ymm10, mem(rcx)) - vfmadd231ps(mem(rcx,32), ymm3, ymm11) - vmovups(ymm11, mem(rcx,32)) - add(rdi, rcx) - - - vfmadd231ps(mem(rcx), ymm3, ymm12) - vmovups(ymm12, mem(rcx)) - vfmadd231ps(mem(rcx,32), ymm3, ymm13) - vmovups(ymm13, mem(rcx,32)) - add(rdi, rcx) - - - vfmadd231ps(mem(rcx), ymm3, ymm14) - vmovups(ymm14, mem(rcx)) - vfmadd231ps(mem(rcx,32), ymm3, ymm15) - vmovups(ymm15, mem(rcx,32)) - //add(rdi, rcx) - - jmp(.SDONE) // jump to end. 
- - label(.SCOLSTORED) - - vunpcklps(ymm6, ymm4, ymm0) - vunpcklps(ymm10, ymm8, ymm1) - vshufps(imm(0x4e), ymm1, ymm0, ymm2) - vblendps(imm(0xcc), ymm2, ymm0, ymm0) - vblendps(imm(0x33), ymm2, ymm1, ymm1) - - vextractf128(imm(0x1), ymm0, xmm2) - vfmadd231ps(mem(rcx), xmm3, xmm0) - vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) - vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) - vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) - - vextractf128(imm(0x1), ymm1, xmm2) - vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) - vfmadd231ps(mem(rcx, r15, 1), xmm3, xmm2) - vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) - vmovups(xmm2, mem(rcx, r15, 1)) // store ( gamma05..gamma35 ) - - - vunpckhps(ymm6, ymm4, ymm0) - vunpckhps(ymm10, ymm8, ymm1) - vshufps(imm(0x4e), ymm1, ymm0, ymm2) - vblendps(imm(0xcc), ymm2, ymm0, ymm0) - vblendps(imm(0x33), ymm2, ymm1, ymm1) - - vextractf128(imm(0x1), ymm0, xmm2) - vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) - vfmadd231ps(mem(rcx, r13, 2), xmm3, xmm2) - vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) - vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma06..gamma36 ) - - vextractf128(imm(0x1), ymm1, xmm2) - vfmadd231ps(mem(rcx, r13, 1), xmm3, xmm1) - vfmadd231ps(mem(rcx, r10, 1), xmm3, xmm2) - vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma03..gamma33 ) - vmovups(xmm2, mem(rcx, r10, 1)) // store ( gamma07..gamma37 ) - - lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c - - vunpcklps(ymm14, ymm12, ymm0) - vextractf128(imm(0x1), ymm0, xmm2) - vmovlpd(mem(r14), xmm1, xmm1) - vmovhpd(mem(r14, rsi, 1), xmm1, xmm1) - vfmadd231ps(xmm1, xmm3, xmm0) - vmovlpd(xmm0, mem(r14)) // store ( gamma40..gamma50 ) - vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma41..gamma51 ) - vmovlpd(mem(r14, rsi, 4), xmm1, xmm1) - vmovhpd(mem(r14, r15, 1), xmm1, xmm1) - vfmadd231ps(xmm1, xmm3, xmm2) - vmovlpd(xmm2, mem(r14, rsi, 4)) // store ( gamma44..gamma54 ) - vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma45..gamma55 ) - - vunpckhps(ymm14, ymm12, ymm0) - vextractf128(imm(0x1), ymm0, xmm2) - vmovlpd(mem(r14, rsi, 2), xmm1, xmm1) - vmovhpd(mem(r14, r13, 1), xmm1, xmm1) - vfmadd231ps(xmm1, xmm3, xmm0) - vmovlpd(xmm0, mem(r14, rsi, 2)) // store ( gamma42..gamma52 ) - vmovhpd(xmm0, mem(r14, r13, 1)) // store ( gamma43..gamma53 ) - vmovlpd(mem(r14, r13, 2), xmm1, xmm1) - vmovhpd(mem(r14, r10, 1), xmm1, xmm1) - vfmadd231ps(xmm1, xmm3, xmm2) - vmovlpd(xmm2, mem(r14, r13, 2)) // store ( gamma46..gamma56 ) - vmovhpd(xmm2, mem(r14, r10, 1)) // store ( gamma47..gamma57 ) - - lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_c - - - - vunpcklps(ymm7, ymm5, ymm0) - vunpcklps(ymm11, ymm9, ymm1) - vshufps(imm(0x4e), ymm1, ymm0, ymm2) - vblendps(imm(0xcc), ymm2, ymm0, ymm0) - vblendps(imm(0x33), ymm2, ymm1, ymm1) - - vextractf128(imm(0x1), ymm0, xmm2) - vfmadd231ps(mem(rcx), xmm3, xmm0) - vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) - vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) - vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) - - vextractf128(imm(0x1), ymm1, xmm2) - vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) - vfmadd231ps(mem(rcx, r15, 1), xmm3, xmm2) - vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) - vmovups(xmm2, mem(rcx, r15, 1)) // store ( gamma05..gamma35 ) - - - vunpckhps(ymm7, ymm5, ymm0) - vunpckhps(ymm11, ymm9, ymm1) - vshufps(imm(0x4e), ymm1, ymm0, ymm2) - vblendps(imm(0xcc), ymm2, ymm0, ymm0) - vblendps(imm(0x33), ymm2, ymm1, ymm1) - - vextractf128(imm(0x1), ymm0, xmm2) - vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) - vfmadd231ps(mem(rcx, r13, 2), 
xmm3, xmm2) - vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) - vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma06..gamma36 ) - - vextractf128(imm(0x1), ymm1, xmm2) - vfmadd231ps(mem(rcx, r13, 1), xmm3, xmm1) - vfmadd231ps(mem(rcx, r10, 1), xmm3, xmm2) - vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma03..gamma33 ) - vmovups(xmm2, mem(rcx, r10, 1)) // store ( gamma07..gamma37 ) - - //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c - - vunpcklps(ymm15, ymm13, ymm0) - vextractf128(imm(0x1), ymm0, xmm2) - vmovlpd(mem(r14), xmm1, xmm1) - vmovhpd(mem(r14, rsi, 1), xmm1, xmm1) - vfmadd231ps(xmm1, xmm3, xmm0) - vmovlpd(xmm0, mem(r14)) // store ( gamma40..gamma50 ) - vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma41..gamma51 ) - vmovlpd(mem(r14, rsi, 4), xmm1, xmm1) - vmovhpd(mem(r14, r15, 1), xmm1, xmm1) - vfmadd231ps(xmm1, xmm3, xmm2) - vmovlpd(xmm2, mem(r14, rsi, 4)) // store ( gamma44..gamma54 ) - vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma45..gamma55 ) - - vunpckhps(ymm15, ymm13, ymm0) - vextractf128(imm(0x1), ymm0, xmm2) - vmovlpd(mem(r14, rsi, 2), xmm1, xmm1) - vmovhpd(mem(r14, r13, 1), xmm1, xmm1) - vfmadd231ps(xmm1, xmm3, xmm0) - vmovlpd(xmm0, mem(r14, rsi, 2)) // store ( gamma42..gamma52 ) - vmovhpd(xmm0, mem(r14, r13, 1)) // store ( gamma43..gamma53 ) - vmovlpd(mem(r14, r13, 2), xmm1, xmm1) - vmovhpd(mem(r14, r10, 1), xmm1, xmm1) - vfmadd231ps(xmm1, xmm3, xmm2) - vmovlpd(xmm2, mem(r14, r13, 2)) // store ( gamma46..gamma56 ) - vmovhpd(xmm2, mem(r14, r10, 1)) // store ( gamma47..gamma57 ) - - //lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_c - - jmp(.SDONE) // jump to end. + cmp(imm(4), rdi) // set ZF if (4*cs_c) == 4. + jz(.SCOLSTORED) // jump to column storage case + + vfmadd231ps(mem(rcx), ymm3, ymm4) + vmovups(ymm4, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm5) + vmovups(ymm5, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm6) + vmovups(ymm6, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm7) + vmovups(ymm7, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm8) + vmovups(ymm8, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm9) + vmovups(ymm9, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm10) + vmovups(ymm10, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm11) + vmovups(ymm11, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm12) + vmovups(ymm12, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm13) + vmovups(ymm13, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231ps(mem(rcx), ymm3, ymm14) + vmovups(ymm14, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm15) + vmovups(ymm15, mem(rcx,32)) + //add(rdi, rcx) + + jmp(.SDONE) // jump to end. 
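/* A minimal C sketch (not a BLIS routine) of the update that the store
   paths above and below both compute, assuming ab[] already carries the
   alpha scaling. rdi holds rs_c scaled to bytes, so the cmp(imm(4), rdi)
   test selects rs_c == 1: the row-stored path above handles the general
   case, while the .SCOLSTORED block that follows handles rs_c == 1 by
   transposing in registers. bli_sketch_update_c and its arguments are
   illustrative names only. */

#include <stddef.h>

static void bli_sketch_update_c
     (
       size_t m, size_t n,
       const float* restrict ab, /* m x n, row-major, alpha applied */
       float                 beta,
       float*       restrict c, size_t rs_c, size_t cs_c
     )
{
    for ( size_t i = 0; i < m; i++ )
    for ( size_t j = 0; j < n; j++ )
        c[ i*rs_c + j*cs_c ] = beta * c[ i*rs_c + j*cs_c ]
                             + ab[ i*n + j ];
}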
+ + label(.SCOLSTORED) + + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, r15, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, r15, 1)) // store ( gamma05..gamma35 ) + + + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vfmadd231ps(mem(rcx, r13, 2), xmm3, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, r13, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, r10, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, r10, 1)) // store ( gamma07..gamma37 ) + + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + vunpcklps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(r14), xmm1, xmm1) + vmovhpd(mem(r14, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(r14)) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma41..gamma51 ) + vmovlpd(mem(r14, rsi, 4), xmm1, xmm1) + vmovhpd(mem(r14, r15, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(r14, rsi, 4)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma45..gamma55 ) + + vunpckhps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(r14, rsi, 2), xmm1, xmm1) + vmovhpd(mem(r14, r13, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(r14, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(r14, r13, 1)) // store ( gamma43..gamma53 ) + vmovlpd(mem(r14, r13, 2), xmm1, xmm1) + vmovhpd(mem(r14, r10, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(r14, r13, 2)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(r14, r10, 1)) // store ( gamma47..gamma57 ) + + lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_c + + + + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx), xmm3, xmm0) + vfmadd231ps(mem(rcx, rsi, 4), xmm3, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, rsi, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, r15, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, r15, 1)) // store ( gamma05..gamma35 ) + + + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) + + vextractf128(imm(0x1), ymm0, xmm2) + vfmadd231ps(mem(rcx, rsi, 2), xmm3, xmm0) + vfmadd231ps(mem(rcx, r13, 2), 
xmm3, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma06..gamma36 ) + + vextractf128(imm(0x1), ymm1, xmm2) + vfmadd231ps(mem(rcx, r13, 1), xmm3, xmm1) + vfmadd231ps(mem(rcx, r10, 1), xmm3, xmm2) + vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, r10, 1)) // store ( gamma07..gamma37 ) + + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + + vunpcklps(ymm15, ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(r14), xmm1, xmm1) + vmovhpd(mem(r14, rsi, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(r14)) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma41..gamma51 ) + vmovlpd(mem(r14, rsi, 4), xmm1, xmm1) + vmovhpd(mem(r14, r15, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(r14, rsi, 4)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma45..gamma55 ) + + vunpckhps(ymm15, ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(mem(r14, rsi, 2), xmm1, xmm1) + vmovhpd(mem(r14, r13, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm0) + vmovlpd(xmm0, mem(r14, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(r14, r13, 1)) // store ( gamma43..gamma53 ) + vmovlpd(mem(r14, r13, 2), xmm1, xmm1) + vmovhpd(mem(r14, r10, 1), xmm1, xmm1) + vfmadd231ps(xmm1, xmm3, xmm2) + vmovlpd(xmm2, mem(r14, r13, 2)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(r14, r10, 1)) // store ( gamma47..gamma57 ) + + //lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_c + + jmp(.SDONE) // jump to end. label(.SBETAZERO) - cmp(imm(4), rdi) // set ZF if (4*cs_c) == 4. - jz(.SCOLSTORBZ) // jump to column storage case + cmp(imm(4), rdi) // set ZF if (4*cs_c) == 4. + jz(.SCOLSTORBZ) // jump to column storage case - vmovups(ymm4, mem(rcx)) - vmovups(ymm5, mem(rcx,32)) - add(rdi, rcx) + vmovups(ymm4, mem(rcx)) + vmovups(ymm5, mem(rcx,32)) + add(rdi, rcx) - vmovups(ymm6, mem(rcx)) - vmovups(ymm7, mem(rcx,32)) - add(rdi, rcx) + vmovups(ymm6, mem(rcx)) + vmovups(ymm7, mem(rcx,32)) + add(rdi, rcx) - vmovups(ymm8, mem(rcx)) - vmovups(ymm9, mem(rcx,32)) - add(rdi, rcx) + vmovups(ymm8, mem(rcx)) + vmovups(ymm9, mem(rcx,32)) + add(rdi, rcx) - vmovups(ymm10, mem(rcx)) - vmovups(ymm11, mem(rcx,32)) - add(rdi, rcx) + vmovups(ymm10, mem(rcx)) + vmovups(ymm11, mem(rcx,32)) + add(rdi, rcx) - vmovups(ymm12, mem(rcx)) - vmovups(ymm13, mem(rcx,32)) - add(rdi, rcx) + vmovups(ymm12, mem(rcx)) + vmovups(ymm13, mem(rcx,32)) + add(rdi, rcx) - vmovups(ymm14, mem(rcx)) - vmovups(ymm15, mem(rcx,32)) - //add(rdi, rcx) + vmovups(ymm14, mem(rcx)) + vmovups(ymm15, mem(rcx,32)) + //add(rdi, rcx) - jmp(.SDONE) // jump to end. + jmp(.SDONE) // jump to end. 
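/* Rows 4 and 5 of the 6x16 microtile live interleaved in ymm12..ymm15, so
   after vunpcklps/vunpckhps each column j owns one contiguous 64-bit float
   pair ( gamma4j, gamma5j ) when rs_c == 1. The vmovlpd/vmovhpd stores
   above (and in the .SCOLSTORBZ block below) each write one such pair
   without touching the rest of the register. A hedged intrinsics sketch of
   that pair store; bli_sketch_store_pairs is an illustrative name only. */

#include <immintrin.h>

static inline void bli_sketch_store_pairs
     (
       float* col_j,   /* &c[ 4 + (j  )*cs_c ], rs_c == 1          */
       float* col_j1,  /* &c[ 4 + (j+1)*cs_c ]                     */
       __m128 pairs    /* ( g4j, g5j, g4j1, g5j1 ) after unpacking */
     )
{
    _mm_storel_pi( (__m64*)col_j,  pairs ); /* low  64 bits: ( gamma4j,  gamma5j  ) */
    _mm_storeh_pi( (__m64*)col_j1, pairs ); /* high 64 bits: ( gamma4j1, gamma5j1 ) */
}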
- label(.SCOLSTORBZ) + label(.SCOLSTORBZ) - vunpcklps(ymm6, ymm4, ymm0) - vunpcklps(ymm10, ymm8, ymm1) - vshufps(imm(0x4e), ymm1, ymm0, ymm2) - vblendps(imm(0xcc), ymm2, ymm0, ymm0) - vblendps(imm(0x33), ymm2, ymm1, ymm1) + vunpcklps(ymm6, ymm4, ymm0) + vunpcklps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) - vextractf128(imm(0x1), ymm0, xmm2) - vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) - vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) - vextractf128(imm(0x1), ymm1, xmm2) - vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) - vmovups(xmm2, mem(rcx, r15, 1)) // store ( gamma05..gamma35 ) + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, r15, 1)) // store ( gamma05..gamma35 ) - vunpckhps(ymm6, ymm4, ymm0) - vunpckhps(ymm10, ymm8, ymm1) - vshufps(imm(0x4e), ymm1, ymm0, ymm2) - vblendps(imm(0xcc), ymm2, ymm0, ymm0) - vblendps(imm(0x33), ymm2, ymm1, ymm1) + vunpckhps(ymm6, ymm4, ymm0) + vunpckhps(ymm10, ymm8, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) - vextractf128(imm(0x1), ymm0, xmm2) - vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) - vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma06..gamma36 ) + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma06..gamma36 ) - vextractf128(imm(0x1), ymm1, xmm2) - vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma03..gamma33 ) - vmovups(xmm2, mem(rcx, r10, 1)) // store ( gamma07..gamma37 ) + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, r10, 1)) // store ( gamma07..gamma37 ) - lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c - vunpcklps(ymm14, ymm12, ymm0) - vextractf128(imm(0x1), ymm0, xmm2) - vmovlpd(xmm0, mem(r14)) // store ( gamma40..gamma50 ) - vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma41..gamma51 ) - vmovlpd(xmm2, mem(r14, rsi, 4)) // store ( gamma44..gamma54 ) - vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma45..gamma55 ) + vunpcklps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(r14)) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma41..gamma51 ) + vmovlpd(xmm2, mem(r14, rsi, 4)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma45..gamma55 ) - vunpckhps(ymm14, ymm12, ymm0) - vextractf128(imm(0x1), ymm0, xmm2) - vmovlpd(xmm0, mem(r14, rsi, 2)) // store ( gamma42..gamma52 ) - vmovhpd(xmm0, mem(r14, r13, 1)) // store ( gamma43..gamma53 ) - vmovlpd(xmm2, mem(r14, r13, 2)) // store ( gamma46..gamma56 ) - vmovhpd(xmm2, mem(r14, r10, 1)) // store ( gamma47..gamma57 ) + vunpckhps(ymm14, ymm12, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(r14, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(r14, r13, 1)) // store ( gamma43..gamma53 ) + vmovlpd(xmm2, mem(r14, r13, 2)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(r14, r10, 1)) // store ( gamma47..gamma57 ) - lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_c + lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_c - vunpcklps(ymm7, 
ymm5, ymm0) - vunpcklps(ymm11, ymm9, ymm1) - vshufps(imm(0x4e), ymm1, ymm0, ymm2) - vblendps(imm(0xcc), ymm2, ymm0, ymm0) - vblendps(imm(0x33), ymm2, ymm1, ymm1) + vunpcklps(ymm7, ymm5, ymm0) + vunpcklps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) - vextractf128(imm(0x1), ymm0, xmm2) - vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) - vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx)) // store ( gamma00..gamma30 ) + vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma04..gamma34 ) - vextractf128(imm(0x1), ymm1, xmm2) - vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) - vmovups(xmm2, mem(rcx, r15, 1)) // store ( gamma05..gamma35 ) + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma31 ) + vmovups(xmm2, mem(rcx, r15, 1)) // store ( gamma05..gamma35 ) - vunpckhps(ymm7, ymm5, ymm0) - vunpckhps(ymm11, ymm9, ymm1) - vshufps(imm(0x4e), ymm1, ymm0, ymm2) - vblendps(imm(0xcc), ymm2, ymm0, ymm0) - vblendps(imm(0x33), ymm2, ymm1, ymm1) + vunpckhps(ymm7, ymm5, ymm0) + vunpckhps(ymm11, ymm9, ymm1) + vshufps(imm(0x4e), ymm1, ymm0, ymm2) + vblendps(imm(0xcc), ymm2, ymm0, ymm0) + vblendps(imm(0x33), ymm2, ymm1, ymm1) - vextractf128(imm(0x1), ymm0, xmm2) - vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) - vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma06..gamma36 ) + vextractf128(imm(0x1), ymm0, xmm2) + vmovups(xmm0, mem(rcx, rsi, 2)) // store ( gamma02..gamma32 ) + vmovups(xmm2, mem(rcx, r13, 2)) // store ( gamma06..gamma36 ) - vextractf128(imm(0x1), ymm1, xmm2) - vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma03..gamma33 ) - vmovups(xmm2, mem(rcx, r10, 1)) // store ( gamma07..gamma37 ) + vextractf128(imm(0x1), ymm1, xmm2) + vmovups(xmm1, mem(rcx, r13, 1)) // store ( gamma03..gamma33 ) + vmovups(xmm2, mem(rcx, r10, 1)) // store ( gamma07..gamma37 ) - //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c + //lea(mem(rcx, rsi, 8), rcx) // rcx += 8*cs_c - vunpcklps(ymm15, ymm13, ymm0) - vextractf128(imm(0x1), ymm0, xmm2) - vmovlpd(xmm0, mem(r14)) // store ( gamma40..gamma50 ) - vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma41..gamma51 ) - vmovlpd(xmm2, mem(r14, rsi, 4)) // store ( gamma44..gamma54 ) - vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma45..gamma55 ) + vunpcklps(ymm15, ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(r14)) // store ( gamma40..gamma50 ) + vmovhpd(xmm0, mem(r14, rsi, 1)) // store ( gamma41..gamma51 ) + vmovlpd(xmm2, mem(r14, rsi, 4)) // store ( gamma44..gamma54 ) + vmovhpd(xmm2, mem(r14, r15, 1)) // store ( gamma45..gamma55 ) - vunpckhps(ymm15, ymm13, ymm0) - vextractf128(imm(0x1), ymm0, xmm2) - vmovlpd(xmm0, mem(r14, rsi, 2)) // store ( gamma42..gamma52 ) - vmovhpd(xmm0, mem(r14, r13, 1)) // store ( gamma43..gamma53 ) - vmovlpd(xmm2, mem(r14, r13, 2)) // store ( gamma46..gamma56 ) - vmovhpd(xmm2, mem(r14, r10, 1)) // store ( gamma47..gamma57 ) + vunpckhps(ymm15, ymm13, ymm0) + vextractf128(imm(0x1), ymm0, xmm2) + vmovlpd(xmm0, mem(r14, rsi, 2)) // store ( gamma42..gamma52 ) + vmovhpd(xmm0, mem(r14, r13, 1)) // store ( gamma43..gamma53 ) + vmovlpd(xmm2, mem(r14, r13, 2)) // store ( gamma46..gamma56 ) + vmovhpd(xmm2, mem(r14, r10, 1)) // store ( gamma47..gamma57 ) - //lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_ + //lea(mem(r14, rsi, 8), r14) // r14 += 8*cs_ label(.SDONE) @@ -727,7 +727,7 @@ void bli_sgemm_haswell_asm_6x16 "memory" 
) - GEMM_UKR_FLUSH_CT( s ); + GEMM_UKR_FLUSH_CT( s ); } @@ -781,7 +781,7 @@ void bli_dgemm_haswell_asm_6x8 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; - GEMM_UKR_SETUP_CT_AMBI( d, 6, 8, true ); + GEMM_UKR_SETUP_CT_AMBI( d, 6, 8, true ); begin_asm() @@ -793,7 +793,7 @@ void bli_dgemm_haswell_asm_6x8 //mov(%9, r15) // load address of b_next. add(imm(32*4), rbx) - // initialize loop by pre-loading + // initialize loop by pre-loading vmovapd(mem(rbx, -4*32), ymm0) vmovapd(mem(rbx, -3*32), ymm1) @@ -803,46 +803,46 @@ void bli_dgemm_haswell_asm_6x8 mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) - cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8. - jz(.SCOLPREFETCH) // jump to column prefetch case + cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8. + jz(.SCOLPREFETCH) // jump to column prefetch case - lea(mem(rdi, rdi, 2), r13) // r13 = 3*rs_c; - lea(mem(rcx, r13, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c - prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c - prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c - prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c - prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c - prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c + lea(mem(rdi, rdi, 2), r13) // r13 = 3*rs_c; + lea(mem(rcx, r13, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c - jmp(.SPREFETCHDONE) + jmp(.SPREFETCHDONE) - label(.SCOLPREFETCH) + label(.SCOLPREFETCH) - lea(mem(rsi, rsi, 2), r13) // r13 = 3*cs_c; - lea(mem(rcx, r13, 1), rdx) // rdx = c + 3*cs_c; - prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*cs_c - prefetch(0, mem(rcx, rsi, 1, 7*8)) // prefetch c + 1*cs_c - prefetch(0, mem(rcx, rsi, 2, 7*8)) // prefetch c + 2*cs_c - prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*cs_c - prefetch(0, mem(rdx, rsi, 1, 7*8)) // prefetch c + 4*cs_c - prefetch(0, mem(rdx, rsi, 2, 7*8)) // prefetch c + 5*cs_c - prefetch(0, mem(rdx, r13, 1, 7*8)) // prefetch c + 6*cs_c - prefetch(0, mem(rdx, rsi, 4, 7*8)) // prefetch c + 7*cs_c + lea(mem(rsi, rsi, 2), r13) // r13 = 3*cs_c; + lea(mem(rcx, r13, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(rcx, 7*8)) // prefetch c + 0*cs_c + prefetch(0, mem(rcx, rsi, 1, 7*8)) // prefetch c + 1*cs_c + prefetch(0, mem(rcx, rsi, 2, 7*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 7*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 7*8)) // prefetch c + 5*cs_c + prefetch(0, mem(rdx, r13, 1, 7*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 4, 7*8)) // prefetch c + 7*cs_c - label(.SPREFETCHDONE) + label(.SPREFETCHDONE) mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that - // contains the k_left loop. + // contains the k_left loop. 
label(.DLOOPKITER) // MAIN LOOP - // iteration 0 + // iteration 0 prefetch(0, mem(rax, 64*8)) vbroadcastsd(mem(rax, 0*8), ymm2) @@ -869,7 +869,7 @@ void bli_dgemm_haswell_asm_6x8 vmovapd(mem(rbx, -2*32), ymm0) vmovapd(mem(rbx, -1*32), ymm1) - // iteration 1 + // iteration 1 prefetch(0, mem(rax, 72*8)) vbroadcastsd(mem(rax, 6*8), ymm2) @@ -896,7 +896,7 @@ void bli_dgemm_haswell_asm_6x8 vmovapd(mem(rbx, 0*32), ymm0) vmovapd(mem(rbx, 1*32), ymm1) - // iteration 2 + // iteration 2 prefetch(0, mem(rax, 80*8)) vbroadcastsd(mem(rax, 12*8), ymm2) @@ -923,7 +923,7 @@ void bli_dgemm_haswell_asm_6x8 vmovapd(mem(rbx, 2*32), ymm0) vmovapd(mem(rbx, 3*32), ymm1) - // iteration 3 + // iteration 3 vbroadcastsd(mem(rax, 18*8), ymm2) vbroadcastsd(mem(rax, 19*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) @@ -965,7 +965,7 @@ void bli_dgemm_haswell_asm_6x8 mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. - // else, we prepare to enter k_left loop. + // else, we prepare to enter k_left loop. label(.DLOOPKLEFT) // EDGE LOOP @@ -1044,232 +1044,232 @@ void bli_dgemm_haswell_asm_6x8 //lea(mem(r13, rsi, 4), r10) // r10 = 7*cs_c; - // now avoid loading C if beta == 0 + // now avoid loading C if beta == 0 vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. - jz(.DCOLSTORED) // jump to column storage case - - vfmadd231pd(mem(rcx), ymm3, ymm4) - vmovupd(ymm4, mem(rcx)) - vfmadd231pd(mem(rcx,32), ymm3, ymm5) - vmovupd(ymm5, mem(rcx,32)) - add(rdi, rcx) + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx,32)) + add(rdi, rcx) - vfmadd231pd(mem(rcx), ymm3, ymm6) - vmovupd(ymm6, mem(rcx)) - vfmadd231pd(mem(rcx,32), ymm3, ymm7) - vmovupd(ymm7, mem(rcx,32)) - add(rdi, rcx) - - - vfmadd231pd(mem(rcx), ymm3, ymm8) - vmovupd(ymm8, mem(rcx)) - vfmadd231pd(mem(rcx,32), ymm3, ymm9) - vmovupd(ymm9, mem(rcx,32)) - add(rdi, rcx) - - - vfmadd231pd(mem(rcx), ymm3, ymm10) - vmovupd(ymm10, mem(rcx)) - vfmadd231pd(mem(rcx,32), ymm3, ymm11) - vmovupd(ymm11, mem(rcx,32)) - add(rdi, rcx) - - - vfmadd231pd(mem(rcx), ymm3, ymm12) - vmovupd(ymm12, mem(rcx)) - vfmadd231pd(mem(rcx,32), ymm3, ymm13) - vmovupd(ymm13, mem(rcx,32)) - add(rdi, rcx) - - - vfmadd231pd(mem(rcx), ymm3, ymm14) - vmovupd(ymm14, mem(rcx)) - vfmadd231pd(mem(rcx,32), ymm3, ymm15) - vmovupd(ymm15, mem(rcx,32)) - //add(rdi, rcx) - - jmp(.DDONE) // jump to end. 
- - label(.DCOLSTORED) - - vunpcklpd(ymm6, ymm4, ymm0) - vunpckhpd(ymm6, ymm4, ymm1) - vunpcklpd(ymm10, ymm8, ymm2) - vunpckhpd(ymm10, ymm8, ymm3) - vinsertf128(imm(0x1), xmm2, ymm0, ymm4) - vinsertf128(imm(0x1), xmm3, ymm1, ymm6) - vperm2f128(imm(0x31), ymm2, ymm0, ymm8) - vperm2f128(imm(0x31), ymm3, ymm1, ymm10) - - vbroadcastsd(mem(rbx), ymm3) - - vfmadd231pd(mem(rcx), ymm3, ymm4) - vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) - vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) - vfmadd231pd(mem(rcx, r13, 1), ymm3, ymm10) - vmovupd(ymm4, mem(rcx)) - vmovupd(ymm6, mem(rcx, rsi, 1)) - vmovupd(ymm8, mem(rcx, rsi, 2)) - vmovupd(ymm10, mem(rcx, r13, 1)) - - lea(mem(rcx, rsi, 4), rcx) - - vunpcklpd(ymm14, ymm12, ymm0) - vunpckhpd(ymm14, ymm12, ymm1) - vextractf128(imm(0x1), ymm0, xmm2) - vextractf128(imm(0x1), ymm1, xmm4) - - vfmadd231pd(mem(r14), xmm3, xmm0) - vfmadd231pd(mem(r14, rsi, 1), xmm3, xmm1) - vfmadd231pd(mem(r14, rsi, 2), xmm3, xmm2) - vfmadd231pd(mem(r14, r13, 1), xmm3, xmm4) - vmovupd(xmm0, mem(r14)) - vmovupd(xmm1, mem(r14, rsi, 1)) - vmovupd(xmm2, mem(r14, rsi, 2)) - vmovupd(xmm4, mem(r14, r13, 1)) - - lea(mem(r14, rsi, 4), r14) - - - vunpcklpd(ymm7, ymm5, ymm0) - vunpckhpd(ymm7, ymm5, ymm1) - vunpcklpd(ymm11, ymm9, ymm2) - vunpckhpd(ymm11, ymm9, ymm3) - vinsertf128(imm(0x1), xmm2, ymm0, ymm5) - vinsertf128(imm(0x1), xmm3, ymm1, ymm7) - vperm2f128(imm(0x31), ymm2, ymm0, ymm9) - vperm2f128(imm(0x31), ymm3, ymm1, ymm11) - - vbroadcastsd(mem(rbx), ymm3) - - vfmadd231pd(mem(rcx), ymm3, ymm5) - vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) - vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) - vfmadd231pd(mem(rcx, r13, 1), ymm3, ymm11) - vmovupd(ymm5, mem(rcx)) - vmovupd(ymm7, mem(rcx, rsi, 1)) - vmovupd(ymm9, mem(rcx, rsi, 2)) - vmovupd(ymm11, mem(rcx, r13, 1)) - - //lea(mem(rcx, rsi, 4), rcx) + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm8) + vmovupd(ymm8, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm9) + vmovupd(ymm9, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm10) + vmovupd(ymm10, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm11) + vmovupd(ymm11, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm12) + vmovupd(ymm12, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm13) + vmovupd(ymm13, mem(rcx,32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx), ymm3, ymm14) + vmovupd(ymm14, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm15) + vmovupd(ymm15, mem(rcx,32)) + //add(rdi, rcx) + + jmp(.DDONE) // jump to end. 
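/* The .DCOLSTORED block that follows keeps the column-major case fast by
   transposing the 4x4 register quadrants in place: two vunpck{l,h}pd
   stages feed vinsertf128/vperm2f128 so that each ymm ends up holding one
   column of c. An equivalent intrinsics sketch, with illustrative names
   (a sketch of the shuffle pattern, not code lifted from the kernel): */

#include <immintrin.h>

static inline void bli_sketch_transpose_4x4_pd
     (
       __m256d  r0, __m256d r1, __m256d r2, __m256d r3,   /* rows    */
       __m256d* c0, __m256d* c1, __m256d* c2, __m256d* c3 /* columns */
     )
{
    __m256d u0 = _mm256_unpacklo_pd( r0, r1 ); /* ( r0[0], r1[0], r0[2], r1[2] ) */
    __m256d u1 = _mm256_unpackhi_pd( r0, r1 ); /* ( r0[1], r1[1], r0[3], r1[3] ) */
    __m256d u2 = _mm256_unpacklo_pd( r2, r3 );
    __m256d u3 = _mm256_unpackhi_pd( r2, r3 );

    *c0 = _mm256_insertf128_pd( u0, _mm256_castpd256_pd128( u2 ), 1 ); /* column 0 */
    *c1 = _mm256_insertf128_pd( u1, _mm256_castpd256_pd128( u3 ), 1 ); /* column 1 */
    *c2 = _mm256_permute2f128_pd( u0, u2, 0x31 );                      /* column 2 */
    *c3 = _mm256_permute2f128_pd( u1, u3, 0x31 );                      /* column 3 */
}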
+ + label(.DCOLSTORED) + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, r13, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, r13, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(r14), xmm3, xmm0) + vfmadd231pd(mem(r14, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(r14, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(r14, r13, 1), xmm3, xmm4) + vmovupd(xmm0, mem(r14)) + vmovupd(xmm1, mem(r14, rsi, 1)) + vmovupd(xmm2, mem(r14, rsi, 2)) + vmovupd(xmm4, mem(r14, r13, 1)) + + lea(mem(r14, rsi, 4), r14) + + + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + vfmadd231pd(mem(rcx, r13, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, r13, 1)) + + //lea(mem(rcx, rsi, 4), rcx) - vunpcklpd(ymm15, ymm13, ymm0) - vunpckhpd(ymm15, ymm13, ymm1) - vextractf128(imm(0x1), ymm0, xmm2) - vextractf128(imm(0x1), ymm1, xmm4) - - vfmadd231pd(mem(r14), xmm3, xmm0) - vfmadd231pd(mem(r14, rsi, 1), xmm3, xmm1) - vfmadd231pd(mem(r14, rsi, 2), xmm3, xmm2) - vfmadd231pd(mem(r14, r13, 1), xmm3, xmm4) - vmovupd(xmm0, mem(r14)) - vmovupd(xmm1, mem(r14, rsi, 1)) - vmovupd(xmm2, mem(r14, rsi, 2)) - vmovupd(xmm4, mem(r14, r13, 1)) - - //lea(mem(r14, rsi, 4), r14) - - jmp(.DDONE) // jump to end. + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(r14), xmm3, xmm0) + vfmadd231pd(mem(r14, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(r14, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(r14, r13, 1), xmm3, xmm4) + vmovupd(xmm0, mem(r14)) + vmovupd(xmm1, mem(r14, rsi, 1)) + vmovupd(xmm2, mem(r14, rsi, 2)) + vmovupd(xmm4, mem(r14, r13, 1)) + + //lea(mem(r14, rsi, 4), r14) + + jmp(.DDONE) // jump to end. label(.DBETAZERO) - cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. - jz(.DCOLSTORBZ) // jump to column storage case + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLSTORBZ) // jump to column storage case - vmovupd(ymm4, mem(rcx)) - vmovupd(ymm5, mem(rcx,32)) - add(rdi, rcx) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm5, mem(rcx,32)) + add(rdi, rcx) - vmovupd(ymm6, mem(rcx)) - vmovupd(ymm7, mem(rcx,32)) - add(rdi, rcx) + vmovupd(ymm6, mem(rcx)) + vmovupd(ymm7, mem(rcx,32)) + add(rdi, rcx) - vmovupd(ymm8, mem(rcx)) - vmovupd(ymm9, mem(rcx,32)) - add(rdi, rcx) + vmovupd(ymm8, mem(rcx)) + vmovupd(ymm9, mem(rcx,32)) + add(rdi, rcx) - vmovupd(ymm10, mem(rcx)) - vmovupd(ymm11, mem(rcx,32)) - add(rdi, rcx) + vmovupd(ymm10, mem(rcx)) + vmovupd(ymm11, mem(rcx,32)) + add(rdi, rcx) - vmovupd(ymm12, mem(rcx)) - vmovupd(ymm13, mem(rcx,32)) - add(rdi, rcx) + vmovupd(ymm12, mem(rcx)) + vmovupd(ymm13, mem(rcx,32)) + add(rdi, rcx) - vmovupd(ymm14, mem(rcx)) - vmovupd(ymm15, mem(rcx,32)) - //add(rdi, rcx) + vmovupd(ymm14, mem(rcx)) + vmovupd(ymm15, mem(rcx,32)) + //add(rdi, rcx) - jmp(.DDONE) // jump to end. + jmp(.DDONE) // jump to end. - label(.DCOLSTORBZ) + label(.DCOLSTORBZ) - vunpcklpd(ymm6, ymm4, ymm0) - vunpckhpd(ymm6, ymm4, ymm1) - vunpcklpd(ymm10, ymm8, ymm2) - vunpckhpd(ymm10, ymm8, ymm3) - vinsertf128(imm(0x1), xmm2, ymm0, ymm4) - vinsertf128(imm(0x1), xmm3, ymm1, ymm6) - vperm2f128(imm(0x31), ymm2, ymm0, ymm8) - vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) - vmovupd(ymm4, mem(rcx)) - vmovupd(ymm6, mem(rcx, rsi, 1)) - vmovupd(ymm8, mem(rcx, rsi, 2)) - vmovupd(ymm10, mem(rcx, r13, 1)) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, r13, 1)) - lea(mem(rcx, rsi, 4), rcx) + lea(mem(rcx, rsi, 4), rcx) - vunpcklpd(ymm14, ymm12, ymm0) - vunpckhpd(ymm14, ymm12, ymm1) - vextractf128(imm(0x1), ymm0, xmm2) - vextractf128(imm(0x1), ymm1, xmm4) + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) - vmovupd(xmm0, mem(r14)) - vmovupd(xmm1, mem(r14, rsi, 1)) - vmovupd(xmm2, mem(r14, rsi, 2)) - vmovupd(xmm4, mem(r14, r13, 1)) + vmovupd(xmm0, mem(r14)) + vmovupd(xmm1, mem(r14, rsi, 1)) + vmovupd(xmm2, mem(r14, rsi, 2)) + vmovupd(xmm4, mem(r14, r13, 1)) - lea(mem(r14, rsi, 4), r14) + lea(mem(r14, rsi, 4), r14) - vunpcklpd(ymm7, ymm5, ymm0) - vunpckhpd(ymm7, ymm5, ymm1) - vunpcklpd(ymm11, ymm9, ymm2) - vunpckhpd(ymm11, ymm9, ymm3) - vinsertf128(imm(0x1), xmm2, ymm0, ymm5) - vinsertf128(imm(0x1), xmm3, ymm1, ymm7) - vperm2f128(imm(0x31), ymm2, ymm0, ymm9) - vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) - vmovupd(ymm5, mem(rcx)) - vmovupd(ymm7, mem(rcx, rsi, 1)) - vmovupd(ymm9, mem(rcx, rsi, 2)) - vmovupd(ymm11, mem(rcx, r13, 1)) + vmovupd(ymm5, mem(rcx)) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, r13, 1)) - //lea(mem(rcx, rsi, 4), rcx) + //lea(mem(rcx, rsi, 4), rcx) - vunpcklpd(ymm15, ymm13, ymm0) - vunpckhpd(ymm15, ymm13, ymm1) - vextractf128(imm(0x1), ymm0, xmm2) - vextractf128(imm(0x1), ymm1, xmm4) + 
vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) - vmovupd(xmm0, mem(r14)) - vmovupd(xmm1, mem(r14, rsi, 1)) - vmovupd(xmm2, mem(r14, rsi, 2)) - vmovupd(xmm4, mem(r14, r13, 1)) + vmovupd(xmm0, mem(r14)) + vmovupd(xmm1, mem(r14, rsi, 1)) + vmovupd(xmm2, mem(r14, rsi, 2)) + vmovupd(xmm4, mem(r14, r13, 1)) - //lea(mem(r14, rsi, 4), r14) + //lea(mem(r14, rsi, 4), r14) label(.DDONE) @@ -1302,7 +1302,7 @@ void bli_dgemm_haswell_asm_6x8 "memory" ) - GEMM_UKR_FLUSH_CT( d ); + GEMM_UKR_FLUSH_CT( d ); } @@ -1340,7 +1340,7 @@ void bli_cgemm_haswell_asm_3x8 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; - GEMM_UKR_SETUP_CT( c, 3, 8, true ); + GEMM_UKR_SETUP_CT( c, 3, 8, true ); begin_asm() @@ -1352,7 +1352,7 @@ void bli_cgemm_haswell_asm_3x8 //mov(%9, r15) // load address of b_next. add(imm(32*4), rbx) - // initialize loop by pre-loading + // initialize loop by pre-loading vmovaps(mem(rbx, -4*32), ymm0) vmovaps(mem(rbx, -3*32), ymm1) @@ -1373,13 +1373,13 @@ void bli_cgemm_haswell_asm_3x8 mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.CCONSIDKLEFT) // if i == 0, jump to code that - // contains the k_left loop. + // contains the k_left loop. label(.CLOOPKITER) // MAIN LOOP - // iteration 0 + // iteration 0 prefetch(0, mem(rax, 32*8)) vbroadcastss(mem(rax, 0*4), ymm2) @@ -1406,7 +1406,7 @@ void bli_cgemm_haswell_asm_3x8 vmovaps(mem(rbx, -2*32), ymm0) vmovaps(mem(rbx, -1*32), ymm1) - // iteration 1 + // iteration 1 vbroadcastss(mem(rax, 6*4), ymm2) vbroadcastss(mem(rax, 7*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) @@ -1431,7 +1431,7 @@ void bli_cgemm_haswell_asm_3x8 vmovaps(mem(rbx, 0*32), ymm0) vmovaps(mem(rbx, 1*32), ymm1) - // iteration 2 + // iteration 2 prefetch(0, mem(rax, 38*8)) vbroadcastss(mem(rax, 12*4), ymm2) @@ -1458,7 +1458,7 @@ void bli_cgemm_haswell_asm_3x8 vmovaps(mem(rbx, 2*32), ymm0) vmovaps(mem(rbx, 3*32), ymm1) - // iteration 3 + // iteration 3 vbroadcastss(mem(rax, 18*4), ymm2) vbroadcastss(mem(rax, 19*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) @@ -1500,7 +1500,7 @@ void bli_cgemm_haswell_asm_3x8 mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.CPOSTACCUM) // if i == 0, we're done; jump to end. - // else, we prepare to enter k_left loop. + // else, we prepare to enter k_left loop. label(.CLOOPKLEFT) // EDGE LOOP @@ -1543,8 +1543,8 @@ void bli_cgemm_haswell_asm_3x8 label(.CPOSTACCUM) - // permute even and odd elements - // of ymm6/7, ymm10/11, ymm/14/15 + // permute even and odd elements + // of ymm6/7, ymm10/11, ymm/14/15 vpermilps(imm(0xb1), ymm6, ymm6) vpermilps(imm(0xb1), ymm7, ymm7) vpermilps(imm(0xb1), ymm10, ymm10) @@ -1553,7 +1553,7 @@ void bli_cgemm_haswell_asm_3x8 vpermilps(imm(0xb1), ymm15, ymm15) - // subtract/add even/odd elements + // subtract/add even/odd elements vaddsubps(ymm6, ymm4, ymm4) vaddsubps(ymm7, ymm5, ymm5) @@ -1612,7 +1612,7 @@ void bli_cgemm_haswell_asm_3x8 vbroadcastss(mem(rbx, 4), ymm2) // load beta_i and duplicate - // now avoid loading C if beta == 0 + // now avoid loading C if beta == 0 vxorps(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomiss(xmm0, xmm1) // set ZF if beta_r == 0. sete(r8b) // r8b = ( ZF == 1 ? 1 : 0 ); @@ -1621,49 +1621,49 @@ void bli_cgemm_haswell_asm_3x8 and(r8b, r9b) // set ZF if r8b & r9b == 1. 
jne(.CBETAZERO) // if ZF = 1, jump to beta == 0 case - CGEMM_INPUT_SCALE_RS_BETA_NZ(mem(rcx)) - vaddps(ymm4, ymm0, ymm0) - vmovups(ymm0, mem(rcx)) + CGEMM_INPUT_SCALE_RS_BETA_NZ(mem(rcx)) + vaddps(ymm4, ymm0, ymm0) + vmovups(ymm0, mem(rcx)) - CGEMM_INPUT_SCALE_RS_BETA_NZ(mem(rcx,32)) - vaddps(ymm5, ymm0, ymm0) - vmovups(ymm0, mem(rcx,32)) + CGEMM_INPUT_SCALE_RS_BETA_NZ(mem(rcx,32)) + vaddps(ymm5, ymm0, ymm0) + vmovups(ymm0, mem(rcx,32)) - CGEMM_INPUT_SCALE_RS_BETA_NZ(mem(r11)) - vaddps(ymm8, ymm0, ymm0) - vmovups(ymm0, mem(r11)) + CGEMM_INPUT_SCALE_RS_BETA_NZ(mem(r11)) + vaddps(ymm8, ymm0, ymm0) + vmovups(ymm0, mem(r11)) - CGEMM_INPUT_SCALE_RS_BETA_NZ(mem(r11,32)) - vaddps(ymm9, ymm0, ymm0) - vmovups(ymm0, mem(r11,32)) + CGEMM_INPUT_SCALE_RS_BETA_NZ(mem(r11,32)) + vaddps(ymm9, ymm0, ymm0) + vmovups(ymm0, mem(r11,32)) - CGEMM_INPUT_SCALE_RS_BETA_NZ(mem(r12)) - vaddps(ymm12, ymm0, ymm0) - vmovups(ymm0, mem(r12)) + CGEMM_INPUT_SCALE_RS_BETA_NZ(mem(r12)) + vaddps(ymm12, ymm0, ymm0) + vmovups(ymm0, mem(r12)) - CGEMM_INPUT_SCALE_RS_BETA_NZ(mem(r12,32)) - vaddps(ymm13, ymm0, ymm0) - vmovups(ymm0, mem(r12,32)) + CGEMM_INPUT_SCALE_RS_BETA_NZ(mem(r12,32)) + vaddps(ymm13, ymm0, ymm0) + vmovups(ymm0, mem(r12,32)) - jmp(.CDONE) // jump to end. + jmp(.CDONE) // jump to end. label(.CBETAZERO) - vmovups(ymm4, mem(rcx)) - vmovups(ymm5, mem(rcx,32)) + vmovups(ymm4, mem(rcx)) + vmovups(ymm5, mem(rcx,32)) - vmovups(ymm8, mem(r11)) - vmovups(ymm9, mem(r11,32)) + vmovups(ymm8, mem(r11)) + vmovups(ymm9, mem(r11,32)) - vmovups(ymm12, mem(r12)) - vmovups(ymm13, mem(r12,32)) + vmovups(ymm12, mem(r12)) + vmovups(ymm13, mem(r12,32)) label(.CDONE) @@ -1696,7 +1696,7 @@ void bli_cgemm_haswell_asm_3x8 "memory" ) - GEMM_UKR_FLUSH_CT( c ); + GEMM_UKR_FLUSH_CT( c ); } @@ -1733,7 +1733,7 @@ void bli_zgemm_haswell_asm_3x4 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; - GEMM_UKR_SETUP_CT( z, 3, 4, true ); + GEMM_UKR_SETUP_CT( z, 3, 4, true ); begin_asm() @@ -1745,7 +1745,7 @@ void bli_zgemm_haswell_asm_3x4 //mov(%9, r15) // load address of b_next. add(imm(32*4), rbx) - // initialize loop by pre-loading + // initialize loop by pre-loading vmovapd(mem(rbx, -4*32), ymm0) vmovapd(mem(rbx, -3*32), ymm1) @@ -1767,13 +1767,13 @@ void bli_zgemm_haswell_asm_3x4 mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.ZCONSIDKLEFT) // if i == 0, jump to code that - // contains the k_left loop. + // contains the k_left loop. label(.ZLOOPKITER) // MAIN LOOP - // iteration 0 + // iteration 0 prefetch(0, mem(rax, 32*16)) vbroadcastsd(mem(rax, 0*8), ymm2) @@ -1800,7 +1800,7 @@ void bli_zgemm_haswell_asm_3x4 vmovapd(mem(rbx, -2*32), ymm0) vmovapd(mem(rbx, -1*32), ymm1) - // iteration 1 + // iteration 1 prefetch(0, mem(rax, 36*16)) vbroadcastsd(mem(rax, 6*8), ymm2) @@ -1827,7 +1827,7 @@ void bli_zgemm_haswell_asm_3x4 vmovapd(mem(rbx, 0*32), ymm0) vmovapd(mem(rbx, 1*32), ymm1) - // iteration 2 + // iteration 2 prefetch(0, mem(rax, 40*16)) vbroadcastsd(mem(rax, 12*8), ymm2) @@ -1854,7 +1854,7 @@ void bli_zgemm_haswell_asm_3x4 vmovapd(mem(rbx, 2*32), ymm0) vmovapd(mem(rbx, 3*32), ymm1) - // iteration 3 + // iteration 3 vbroadcastsd(mem(rax, 18*8), ymm2) vbroadcastsd(mem(rax, 19*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) @@ -1896,7 +1896,7 @@ void bli_zgemm_haswell_asm_3x4 mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.ZPOSTACCUM) // if i == 0, we're done; jump to end. - // else, we prepare to enter k_left loop. + // else, we prepare to enter k_left loop. 
label(.ZLOOPKLEFT) // EDGE LOOP @@ -1938,8 +1938,8 @@ void bli_zgemm_haswell_asm_3x4 label(.ZPOSTACCUM) - // permute even and odd elements - // of ymm6/7, ymm10/11, ymm/14/15 + // permute even and odd elements + // of ymm6/7, ymm10/11, ymm/14/15 vpermilpd(imm(0x5), ymm6, ymm6) vpermilpd(imm(0x5), ymm7, ymm7) vpermilpd(imm(0x5), ymm10, ymm10) @@ -1948,7 +1948,7 @@ void bli_zgemm_haswell_asm_3x4 vpermilpd(imm(0x5), ymm15, ymm15) - // subtract/add even/odd elements + // subtract/add even/odd elements vaddsubpd(ymm6, ymm4, ymm4) vaddsubpd(ymm7, ymm5, ymm5) @@ -2008,7 +2008,7 @@ void bli_zgemm_haswell_asm_3x4 - // now avoid loading C if beta == 0 + // now avoid loading C if beta == 0 vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm1) // set ZF if beta_r == 0. sete(r8b) // r8b = ( ZF == 1 ? 1 : 0 ); @@ -2017,49 +2017,49 @@ void bli_zgemm_haswell_asm_3x4 and(r8b, r9b) // set ZF if r8b & r9b == 1. jne(.ZBETAZERO) // if ZF = 1, jump to beta == 0 case - ZGEMM_INPUT_SCALE_RS_BETA_NZ(mem(rcx)) - vaddpd(ymm4, ymm0, ymm0) - vmovupd(ymm0, mem(rcx)) + ZGEMM_INPUT_SCALE_RS_BETA_NZ(mem(rcx)) + vaddpd(ymm4, ymm0, ymm0) + vmovupd(ymm0, mem(rcx)) - ZGEMM_INPUT_SCALE_RS_BETA_NZ(mem(rcx,32)) - vaddpd(ymm5, ymm0, ymm0) - vmovupd(ymm0, mem(rcx,32)) + ZGEMM_INPUT_SCALE_RS_BETA_NZ(mem(rcx,32)) + vaddpd(ymm5, ymm0, ymm0) + vmovupd(ymm0, mem(rcx,32)) - ZGEMM_INPUT_SCALE_RS_BETA_NZ(mem(r11)) - vaddpd(ymm8, ymm0, ymm0) - vmovupd(ymm0, mem(r11)) + ZGEMM_INPUT_SCALE_RS_BETA_NZ(mem(r11)) + vaddpd(ymm8, ymm0, ymm0) + vmovupd(ymm0, mem(r11)) - ZGEMM_INPUT_SCALE_RS_BETA_NZ(mem(r11,32)) - vaddpd(ymm9, ymm0, ymm0) - vmovupd(ymm0, mem(r11,32)) + ZGEMM_INPUT_SCALE_RS_BETA_NZ(mem(r11,32)) + vaddpd(ymm9, ymm0, ymm0) + vmovupd(ymm0, mem(r11,32)) - ZGEMM_INPUT_SCALE_RS_BETA_NZ(mem(r12)) - vaddpd(ymm12, ymm0, ymm0) - vmovupd(ymm0, mem(r12)) + ZGEMM_INPUT_SCALE_RS_BETA_NZ(mem(r12)) + vaddpd(ymm12, ymm0, ymm0) + vmovupd(ymm0, mem(r12)) - ZGEMM_INPUT_SCALE_RS_BETA_NZ(mem(r12,32)) - vaddpd(ymm13, ymm0, ymm0) - vmovupd(ymm0, mem(r12,32)) + ZGEMM_INPUT_SCALE_RS_BETA_NZ(mem(r12,32)) + vaddpd(ymm13, ymm0, ymm0) + vmovupd(ymm0, mem(r12,32)) - jmp(.ZDONE) // jump to end. + jmp(.ZDONE) // jump to end. label(.ZBETAZERO) - vmovupd(ymm4, mem(rcx)) - vmovupd(ymm5, mem(rcx,32)) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm5, mem(rcx,32)) - vmovupd(ymm8, mem(r11)) - vmovupd(ymm9, mem(r11,32)) + vmovupd(ymm8, mem(r11)) + vmovupd(ymm9, mem(r11,32)) - vmovupd(ymm12, mem(r12)) - vmovupd(ymm13, mem(r12,32)) + vmovupd(ymm12, mem(r12)) + vmovupd(ymm13, mem(r12,32)) label(.ZDONE) @@ -2092,7 +2092,7 @@ void bli_zgemm_haswell_asm_3x4 "memory" ) - GEMM_UKR_FLUSH_CT( z ); + GEMM_UKR_FLUSH_CT( z ); } diff --git a/kernels/haswell/3/bli_gemm_haswell_asm_d8x6.c b/kernels/haswell/3/bli_gemm_haswell_asm_d8x6.c index 27dc99f52a..a3a8b0b09f 100644 --- a/kernels/haswell/3/bli_gemm_haswell_asm_d8x6.c +++ b/kernels/haswell/3/bli_gemm_haswell_asm_d8x6.c @@ -100,7 +100,7 @@ void bli_sgemm_haswell_asm_16x6 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; - GEMM_UKR_SETUP_CT( s, 16, 6, true ); + GEMM_UKR_SETUP_CT( s, 16, 6, true ); begin_asm() @@ -332,78 +332,78 @@ void bli_sgemm_haswell_asm_16x6 vucomiss(xmm0, xmm3) // set ZF if beta == 0. 
je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case - vfmadd231ps(mem(rcx), ymm3, ymm4) - vmovups(ymm4, mem(rcx)) - vfmadd231ps(mem(rcx,32), ymm3, ymm5) - vmovups(ymm5, mem(rcx,32)) - add(rdi, rcx) + vfmadd231ps(mem(rcx), ymm3, ymm4) + vmovups(ymm4, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm5) + vmovups(ymm5, mem(rcx,32)) + add(rdi, rcx) - vfmadd231ps(mem(rcx), ymm3, ymm6) - vmovups(ymm6, mem(rcx)) - vfmadd231ps(mem(rcx,32), ymm3, ymm7) - vmovups(ymm7, mem(rcx,32)) - add(rdi, rcx) + vfmadd231ps(mem(rcx), ymm3, ymm6) + vmovups(ymm6, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm7) + vmovups(ymm7, mem(rcx,32)) + add(rdi, rcx) - vfmadd231ps(mem(rcx), ymm3, ymm8) - vmovups(ymm8, mem(rcx)) - vfmadd231ps(mem(rcx,32), ymm3, ymm9) - vmovups(ymm9, mem(rcx,32)) - add(rdi, rcx) + vfmadd231ps(mem(rcx), ymm3, ymm8) + vmovups(ymm8, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm9) + vmovups(ymm9, mem(rcx,32)) + add(rdi, rcx) - vfmadd231ps(mem(rcx), ymm3, ymm10) - vmovups(ymm10, mem(rcx)) - vfmadd231ps(mem(rcx,32), ymm3, ymm11) - vmovups(ymm11, mem(rcx,32)) - add(rdi, rcx) + vfmadd231ps(mem(rcx), ymm3, ymm10) + vmovups(ymm10, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm11) + vmovups(ymm11, mem(rcx,32)) + add(rdi, rcx) - vfmadd231ps(mem(rcx), ymm3, ymm12) - vmovups(ymm12, mem(rcx)) - vfmadd231ps(mem(rcx,32), ymm3, ymm13) - vmovups(ymm13, mem(rcx,32)) - add(rdi, rcx) + vfmadd231ps(mem(rcx), ymm3, ymm12) + vmovups(ymm12, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm13) + vmovups(ymm13, mem(rcx,32)) + add(rdi, rcx) - vfmadd231ps(mem(rcx), ymm3, ymm14) - vmovups(ymm14, mem(rcx)) - vfmadd231ps(mem(rcx,32), ymm3, ymm15) - vmovups(ymm15, mem(rcx,32)) - //add(rdi, rcx) + vfmadd231ps(mem(rcx), ymm3, ymm14) + vmovups(ymm14, mem(rcx)) + vfmadd231ps(mem(rcx,32), ymm3, ymm15) + vmovups(ymm15, mem(rcx,32)) + //add(rdi, rcx) jmp(.SDONE) // jump to end. 
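/* The .SBETAZERO branch below exists for correctness as well as speed:
   when beta == 0, c may be uninitialized, so the kernel must store the
   accumulated product without ever reading c (otherwise 0 * NaN or
   0 * Inf would leak into the result). A one-element C sketch of the
   rule, hoisted per-tile in the asm via the vucomiss test above;
   bli_sketch_cij_update is an illustrative name only. */

static inline void bli_sketch_cij_update( float beta, float* cij, float ab_ij )
{
    if ( beta == 0.0f ) *cij = ab_ij;                 /* store only; *cij never read */
    else                *cij = beta * (*cij) + ab_ij; /* read, scale, accumulate     */
}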
label(.SBETAZERO) - vmovups(ymm4, mem(rcx)) - vmovups(ymm5, mem(rcx,32)) - add(rdi, rcx) + vmovups(ymm4, mem(rcx)) + vmovups(ymm5, mem(rcx,32)) + add(rdi, rcx) - vmovups(ymm6, mem(rcx)) - vmovups(ymm7, mem(rcx,32)) - add(rdi, rcx) + vmovups(ymm6, mem(rcx)) + vmovups(ymm7, mem(rcx,32)) + add(rdi, rcx) - vmovups(ymm8, mem(rcx)) - vmovups(ymm9, mem(rcx,32)) - add(rdi, rcx) + vmovups(ymm8, mem(rcx)) + vmovups(ymm9, mem(rcx,32)) + add(rdi, rcx) - vmovups(ymm10, mem(rcx)) - vmovups(ymm11, mem(rcx,32)) - add(rdi, rcx) + vmovups(ymm10, mem(rcx)) + vmovups(ymm11, mem(rcx,32)) + add(rdi, rcx) - vmovups(ymm12, mem(rcx)) - vmovups(ymm13, mem(rcx,32)) - add(rdi, rcx) + vmovups(ymm12, mem(rcx)) + vmovups(ymm13, mem(rcx,32)) + add(rdi, rcx) - vmovups(ymm14, mem(rcx)) - vmovups(ymm15, mem(rcx,32)) - //add(rdi, rcx) + vmovups(ymm14, mem(rcx)) + vmovups(ymm15, mem(rcx,32)) + //add(rdi, rcx) label(.SDONE) @@ -412,17 +412,17 @@ void bli_sgemm_haswell_asm_16x6 end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c)/*, // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next)*/ // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c)/*, // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next)*/ // 10 : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", @@ -433,7 +433,7 @@ void bli_sgemm_haswell_asm_16x6 "memory" ) - GEMM_UKR_FLUSH_CT( s ); + GEMM_UKR_FLUSH_CT( s ); } #define DGEMM_INPUT_GS_BETA_NZ \ @@ -484,7 +484,7 @@ void bli_dgemm_haswell_asm_8x6 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; - GEMM_UKR_SETUP_CT( d, 8, 6, false ); + GEMM_UKR_SETUP_CT( d, 8, 6, false ); begin_asm() @@ -716,97 +716,97 @@ void bli_dgemm_haswell_asm_8x6 vucomisd(xmm0, xmm3) // set ZF if beta == 0. 
je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - vfmadd231pd(mem(rcx), ymm3, ymm4) - vmovupd(ymm4, mem(rcx)) - vfmadd231pd(mem(rcx,32), ymm3, ymm5) - vmovupd(ymm5, mem(rcx,32)) - add(rdi, rcx) + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx,32)) + add(rdi, rcx) - vfmadd231pd(mem(rcx), ymm3, ymm6) - vmovupd(ymm6, mem(rcx)) - vfmadd231pd(mem(rcx,32), ymm3, ymm7) - vmovupd(ymm7, mem(rcx,32)) - add(rdi, rcx) + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx,32)) + add(rdi, rcx) - vfmadd231pd(mem(rcx), ymm3, ymm8) - vmovupd(ymm8, mem(rcx)) - vfmadd231pd(mem(rcx,32), ymm3, ymm9) - vmovupd(ymm9, mem(rcx,32)) - add(rdi, rcx) + vfmadd231pd(mem(rcx), ymm3, ymm8) + vmovupd(ymm8, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm9) + vmovupd(ymm9, mem(rcx,32)) + add(rdi, rcx) - vfmadd231pd(mem(rcx), ymm3, ymm10) - vmovupd(ymm10, mem(rcx)) - vfmadd231pd(mem(rcx,32), ymm3, ymm11) - vmovupd(ymm11, mem(rcx,32)) - add(rdi, rcx) + vfmadd231pd(mem(rcx), ymm3, ymm10) + vmovupd(ymm10, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm11) + vmovupd(ymm11, mem(rcx,32)) + add(rdi, rcx) - vfmadd231pd(mem(rcx), ymm3, ymm12) - vmovupd(ymm12, mem(rcx)) - vfmadd231pd(mem(rcx,32), ymm3, ymm13) - vmovupd(ymm13, mem(rcx,32)) - add(rdi, rcx) + vfmadd231pd(mem(rcx), ymm3, ymm12) + vmovupd(ymm12, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm13) + vmovupd(ymm13, mem(rcx,32)) + add(rdi, rcx) - vfmadd231pd(mem(rcx), ymm3, ymm14) - vmovupd(ymm14, mem(rcx)) - vfmadd231pd(mem(rcx,32), ymm3, ymm15) - vmovupd(ymm15, mem(rcx,32)) - //add(rdi, rcx) + vfmadd231pd(mem(rcx), ymm3, ymm14) + vmovupd(ymm14, mem(rcx)) + vfmadd231pd(mem(rcx,32), ymm3, ymm15) + vmovupd(ymm15, mem(rcx,32)) + //add(rdi, rcx) - jmp(.DDONE) // jump to end. + jmp(.DDONE) // jump to end. 
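/* Each vfmadd231pd in the beta != 0 path above folds three steps into a
   single instruction: load four doubles of c from memory, scale them by
   the broadcast beta held in ymm3, and accumulate into the register tile
   ( acc += beta * c ). A hedged intrinsics sketch of that fused update;
   the function name is illustrative only. */

#include <immintrin.h>

static inline __m256d bli_sketch_fused_beta_update
     (
       __m256d       acc,  /* alpha*a*b tile, already in registers */
       __m256d       beta, /* beta broadcast to all four lanes     */
       const double* c
     )
{
    return _mm256_fmadd_pd( beta, _mm256_loadu_pd( c ), acc );
}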
label(.DBETAZERO) - vmovupd(ymm4, mem(rcx)) - vmovupd(ymm5, mem(rcx,32)) - add(rdi, rcx) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm5, mem(rcx,32)) + add(rdi, rcx) - vmovupd(ymm6, mem(rcx)) - vmovupd(ymm7, mem(rcx,32)) - add(rdi, rcx) + vmovupd(ymm6, mem(rcx)) + vmovupd(ymm7, mem(rcx,32)) + add(rdi, rcx) - vmovupd(ymm8, mem(rcx)) - vmovupd(ymm9, mem(rcx,32)) - add(rdi, rcx) + vmovupd(ymm8, mem(rcx)) + vmovupd(ymm9, mem(rcx,32)) + add(rdi, rcx) - vmovupd(ymm10, mem(rcx)) - vmovupd(ymm11, mem(rcx,32)) - add(rdi, rcx) + vmovupd(ymm10, mem(rcx)) + vmovupd(ymm11, mem(rcx,32)) + add(rdi, rcx) - vmovupd(ymm12, mem(rcx)) - vmovupd(ymm13, mem(rcx,32)) - add(rdi, rcx) + vmovupd(ymm12, mem(rcx)) + vmovupd(ymm13, mem(rcx,32)) + add(rdi, rcx) - vmovupd(ymm14, mem(rcx)) - vmovupd(ymm15, mem(rcx,32)) - //add(rdi, rcx) + vmovupd(ymm14, mem(rcx)) + vmovupd(ymm15, mem(rcx,32)) + //add(rdi, rcx) label(.DDONE) - end_asm( + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c)/*, // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next)*/ // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c)/*, // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next)*/ // 10 : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", @@ -817,7 +817,7 @@ void bli_dgemm_haswell_asm_8x6 "memory" ) - GEMM_UKR_FLUSH_CT( d ); + GEMM_UKR_FLUSH_CT( d ); } @@ -854,7 +854,7 @@ void bli_cgemm_haswell_asm_8x3 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; - GEMM_UKR_SETUP_CT( c, 8, 3, false ); + GEMM_UKR_SETUP_CT( c, 8, 3, false ); begin_asm() @@ -1136,68 +1136,68 @@ void bli_cgemm_haswell_asm_8x3 and(r8b, r9b) // set ZF if r8b & r9b == 1. jne(.CBETAZERO) // if ZF = 1, jump to beta == 0 case - CGEMM_INPUT_SCALE_CS_BETA_NZ(mem(rcx)) - vaddps(ymm4, ymm0, ymm0) - vmovups(ymm0, mem(rcx)) + CGEMM_INPUT_SCALE_CS_BETA_NZ(mem(rcx)) + vaddps(ymm4, ymm0, ymm0) + vmovups(ymm0, mem(rcx)) - CGEMM_INPUT_SCALE_CS_BETA_NZ(mem(rcx,32)) - vaddps(ymm5, ymm0, ymm0) - vmovups(ymm0, mem(rcx,32)) + CGEMM_INPUT_SCALE_CS_BETA_NZ(mem(rcx,32)) + vaddps(ymm5, ymm0, ymm0) + vmovups(ymm0, mem(rcx,32)) - CGEMM_INPUT_SCALE_CS_BETA_NZ(mem(r11)) - vaddps(ymm8, ymm0, ymm0) - vmovups(ymm0, mem(r11)) + CGEMM_INPUT_SCALE_CS_BETA_NZ(mem(r11)) + vaddps(ymm8, ymm0, ymm0) + vmovups(ymm0, mem(r11)) - CGEMM_INPUT_SCALE_CS_BETA_NZ(mem(r11,32)) - vaddps(ymm9, ymm0, ymm0) - vmovups(ymm0, mem(r11,32)) + CGEMM_INPUT_SCALE_CS_BETA_NZ(mem(r11,32)) + vaddps(ymm9, ymm0, ymm0) + vmovups(ymm0, mem(r11,32)) - CGEMM_INPUT_SCALE_CS_BETA_NZ(mem(r12)) - vaddps(ymm12, ymm0, ymm0) - vmovups(ymm0, mem(r12)) + CGEMM_INPUT_SCALE_CS_BETA_NZ(mem(r12)) + vaddps(ymm12, ymm0, ymm0) + vmovups(ymm0, mem(r12)) - CGEMM_INPUT_SCALE_CS_BETA_NZ(mem(r12,32)) - vaddps(ymm13, ymm0, ymm0) - vmovups(ymm0, mem(r12,32)) + CGEMM_INPUT_SCALE_CS_BETA_NZ(mem(r12,32)) + vaddps(ymm13, ymm0, ymm0) + vmovups(ymm0, mem(r12,32)) - jmp(.CDONE) // jump to end. + jmp(.CDONE) // jump to end. 
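/* For the complex kernels the beta scaling itself is a complex multiply:
   CGEMM_INPUT_SCALE_CS_BETA_NZ above gathers a run of c, forms a
   vpermilps-swapped copy, and uses vaddsubps so one multiply pass yields
   both halves of beta*c. A scalar C sketch of what is computed per
   element (illustrative only, not the macro's actual body): */

#include <complex.h>

static inline float complex bli_sketch_cscale_beta( float complex beta,
                                                    float complex c )
{
    float br = crealf( beta ), bi = cimagf( beta );
    float cr = crealf( c ),    ci = cimagf( c );

    /* vaddsubps pairs: ( br*cr - bi*ci, br*ci + bi*cr ) */
    return ( br*cr - bi*ci ) + ( br*ci + bi*cr ) * I;
}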
label(.CBETAZERO) - vmovups(ymm4, mem(rcx)) - vmovups(ymm5, mem(rcx,32)) + vmovups(ymm4, mem(rcx)) + vmovups(ymm5, mem(rcx,32)) - vmovups(ymm8, mem(r11)) - vmovups(ymm9, mem(r11,32)) + vmovups(ymm8, mem(r11)) + vmovups(ymm9, mem(r11,32)) - vmovups(ymm12, mem(r12)) - vmovups(ymm13, mem(r12,32)) + vmovups(ymm12, mem(r12)) + vmovups(ymm13, mem(r12,32)) label(.CDONE) - end_asm( + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c)/*, // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next)*/ // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c)/*, // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next)*/ // 10 : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", @@ -1208,7 +1208,7 @@ void bli_cgemm_haswell_asm_8x3 "memory" ) - GEMM_UKR_FLUSH_CT( c ); + GEMM_UKR_FLUSH_CT( c ); } @@ -1245,7 +1245,7 @@ void bli_zgemm_haswell_asm_4x3 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; - GEMM_UKR_SETUP_CT( z, 4, 3, false ); + GEMM_UKR_SETUP_CT( z, 4, 3, false ); begin_asm() @@ -1528,55 +1528,55 @@ void bli_zgemm_haswell_asm_4x3 and(r8b, r9b) // set ZF if r8b & r9b == 1. jne(.ZBETAZERO) // if ZF = 1, jump to beta == 0 case - ZGEMM_INPUT_SCALE_CS_BETA_NZ(mem(rcx)) - vaddpd(ymm4, ymm0, ymm0) - vmovupd(ymm0, mem(rcx)) + ZGEMM_INPUT_SCALE_CS_BETA_NZ(mem(rcx)) + vaddpd(ymm4, ymm0, ymm0) + vmovupd(ymm0, mem(rcx)) - ZGEMM_INPUT_SCALE_CS_BETA_NZ(mem(rcx,32)) - vaddpd(ymm5, ymm0, ymm0) - vmovupd(ymm0, mem(rcx,32)) + ZGEMM_INPUT_SCALE_CS_BETA_NZ(mem(rcx,32)) + vaddpd(ymm5, ymm0, ymm0) + vmovupd(ymm0, mem(rcx,32)) - ZGEMM_INPUT_SCALE_CS_BETA_NZ(mem(r11)) - vaddpd(ymm8, ymm0, ymm0) - vmovupd(ymm0, mem(r11)) + ZGEMM_INPUT_SCALE_CS_BETA_NZ(mem(r11)) + vaddpd(ymm8, ymm0, ymm0) + vmovupd(ymm0, mem(r11)) - ZGEMM_INPUT_SCALE_CS_BETA_NZ(mem(r11,32)) - vaddpd(ymm9, ymm0, ymm0) - vmovupd(ymm0, mem(r11,32)) + ZGEMM_INPUT_SCALE_CS_BETA_NZ(mem(r11,32)) + vaddpd(ymm9, ymm0, ymm0) + vmovupd(ymm0, mem(r11,32)) - ZGEMM_INPUT_SCALE_CS_BETA_NZ(mem(r12)) - vaddpd(ymm12, ymm0, ymm0) - vmovupd(ymm0, mem(r12)) + ZGEMM_INPUT_SCALE_CS_BETA_NZ(mem(r12)) + vaddpd(ymm12, ymm0, ymm0) + vmovupd(ymm0, mem(r12)) - ZGEMM_INPUT_SCALE_CS_BETA_NZ(mem(r12,32)) - vaddpd(ymm13, ymm0, ymm0) - vmovupd(ymm0, mem(r12,32)) + ZGEMM_INPUT_SCALE_CS_BETA_NZ(mem(r12,32)) + vaddpd(ymm13, ymm0, ymm0) + vmovupd(ymm0, mem(r12,32)) - jmp(.ZDONE) // jump to end. + jmp(.ZDONE) // jump to end. 
label(.ZBETAZERO) - vmovupd(ymm4, mem(rcx)) - vmovupd(ymm5, mem(rcx,32)) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm5, mem(rcx,32)) - vmovupd(ymm8, mem(r11)) - vmovupd(ymm9, mem(r11,32)) + vmovupd(ymm8, mem(r11)) + vmovupd(ymm9, mem(r11,32)) - vmovupd(ymm12, mem(r12)) - vmovupd(ymm13, mem(r12,32)) + vmovupd(ymm12, mem(r12)) + vmovupd(ymm13, mem(r12,32)) label(.ZDONE) - end_asm( + end_asm( : // output operands (none) : // input operands [k_iter] "m" (k_iter), // 0 @@ -1600,7 +1600,7 @@ void bli_zgemm_haswell_asm_4x3 "memory" ) - GEMM_UKR_FLUSH_CT( z ); + GEMM_UKR_FLUSH_CT( z ); } diff --git a/kernels/penryn/3/bli_gemm_penryn_asm_d4x4.c b/kernels/penryn/3/bli_gemm_penryn_asm_d4x4.c index 5d24f6e86e..a3e39c3ac1 100644 --- a/kernels/penryn/3/bli_gemm_penryn_asm_d4x4.c +++ b/kernels/penryn/3/bli_gemm_penryn_asm_d4x4.c @@ -61,7 +61,7 @@ void bli_sgemm_penryn_asm_8x4 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; - GEMM_UKR_SETUP_CT_ALIGNED( s, 8, 4, false, 16 ); + GEMM_UKR_SETUP_CT_ALIGNED( s, 8, 4, false, 16 ); begin_asm() @@ -380,126 +380,126 @@ void bli_sgemm_penryn_asm_8x4 ucomisd(xmm0, xmm7) // check if beta == 0. je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case - movaps(mem(rcx), xmm0) // load c00 ~ c30, - mulps(xmm6, xmm8) // scale by alpha, - mulps(xmm7, xmm0) // scale by beta, - addps(xmm8, xmm0) // add the gemm result, - movaps(xmm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) + movaps(mem(rcx), xmm0) // load c00 ~ c30, + mulps(xmm6, xmm8) // scale by alpha, + mulps(xmm7, xmm0) // scale by beta, + addps(xmm8, xmm0) // add the gemm result, + movaps(xmm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) - movaps(mem(rdx), xmm1) // load c40 ~ c70, - mulps(xmm6, xmm12) // scale by alpha, - mulps(xmm7, xmm1) // scale by beta, - addps(xmm12, xmm1) // add the gemm result, - movaps(xmm1, mem(rdx)) // and store back to memory. - add(rdi, rdx) + movaps(mem(rdx), xmm1) // load c40 ~ c70, + mulps(xmm6, xmm12) // scale by alpha, + mulps(xmm7, xmm1) // scale by beta, + addps(xmm12, xmm1) // add the gemm result, + movaps(xmm1, mem(rdx)) // and store back to memory. + add(rdi, rdx) - movaps(mem(rcx), xmm0) // load c01 ~ c31, - mulps(xmm6, xmm9) // scale by alpha, - mulps(xmm7, xmm0) // scale by beta, - addps(xmm9, xmm0) // add the gemm result, - movaps(xmm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) + movaps(mem(rcx), xmm0) // load c01 ~ c31, + mulps(xmm6, xmm9) // scale by alpha, + mulps(xmm7, xmm0) // scale by beta, + addps(xmm9, xmm0) // add the gemm result, + movaps(xmm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) - movaps(mem(rdx), xmm1) // load c41 ~ c71, - mulps(xmm6, xmm13) // scale by alpha, - mulps(xmm7, xmm1) // scale by beta, - addps(xmm13, xmm1) // add the gemm result, - movaps(xmm1, mem(rdx)) // and store back to memory. - add(rdi, rdx) + movaps(mem(rdx), xmm1) // load c41 ~ c71, + mulps(xmm6, xmm13) // scale by alpha, + mulps(xmm7, xmm1) // scale by beta, + addps(xmm13, xmm1) // add the gemm result, + movaps(xmm1, mem(rdx)) // and store back to memory. + add(rdi, rdx) - movaps(mem(rcx), xmm0) // load c02 ~ c32, - mulps(xmm6, xmm10) // scale by alpha, - mulps(xmm7, xmm0) // scale by beta, - addps(xmm10, xmm0) // add the gemm result, - movaps(xmm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) + movaps(mem(rcx), xmm0) // load c02 ~ c32, + mulps(xmm6, xmm10) // scale by alpha, + mulps(xmm7, xmm0) // scale by beta, + addps(xmm10, xmm0) // add the gemm result, + movaps(xmm0, mem(rcx)) // and store back to memory. 
+ add(rdi, rcx) - movaps(mem(rdx), xmm1) // load c42 ~ c72, - mulps(xmm6, xmm14) // scale by alpha, - mulps(xmm7, xmm1) // scale by beta, - addps(xmm14, xmm1) // add the gemm result, - movaps(xmm1, mem(rdx)) // and store back to memory. - add(rdi, rdx) + movaps(mem(rdx), xmm1) // load c42 ~ c72, + mulps(xmm6, xmm14) // scale by alpha, + mulps(xmm7, xmm1) // scale by beta, + addps(xmm14, xmm1) // add the gemm result, + movaps(xmm1, mem(rdx)) // and store back to memory. + add(rdi, rdx) - movaps(mem(rcx), xmm0) // load c03 ~ c33, - mulps(xmm6, xmm11) // scale by alpha, - mulps(xmm7, xmm0) // scale by beta, - addps(xmm11, xmm0) // add the gemm result, - movaps(xmm0, mem(rcx)) // and store back to memory. + movaps(mem(rcx), xmm0) // load c03 ~ c33, + mulps(xmm6, xmm11) // scale by alpha, + mulps(xmm7, xmm0) // scale by beta, + addps(xmm11, xmm0) // add the gemm result, + movaps(xmm0, mem(rcx)) // and store back to memory. - movaps(mem(rdx), xmm1) // load c43 ~ c73, - mulps(xmm6, xmm15) // scale by alpha, - mulps(xmm7, xmm1) // scale by beta, - addps(xmm15, xmm1) // add the gemm result, - movaps(xmm1, mem(rdx)) // and store back to memory. + movaps(mem(rdx), xmm1) // load c43 ~ c73, + mulps(xmm6, xmm15) // scale by alpha, + mulps(xmm7, xmm1) // scale by beta, + addps(xmm15, xmm1) // add the gemm result, + movaps(xmm1, mem(rdx)) // and store back to memory. - jmp(.SDONE) // jump to end. + jmp(.SDONE) // jump to end. label(.SBETAZERO) - // skip loading c00 ~ c30, - mulps(xmm6, xmm8) // scale by alpha, - movaps(xmm8, mem(rcx)) // and store back to memory. - add(rdi, rcx) - // skip loading c40 ~ c70, - mulps(xmm6, xmm12) // scale by alpha, - movaps(xmm12, mem(rdx)) // and store back to memory. - add(rdi, rdx) + // skip loading c00 ~ c30, + mulps(xmm6, xmm8) // scale by alpha, + movaps(xmm8, mem(rcx)) // and store back to memory. + add(rdi, rcx) + // skip loading c40 ~ c70, + mulps(xmm6, xmm12) // scale by alpha, + movaps(xmm12, mem(rdx)) // and store back to memory. + add(rdi, rdx) - // skip loading c01 ~ c31, - mulps(xmm6, xmm9) // scale by alpha, - movaps(xmm9, mem(rcx)) // and store back to memory. - add(rdi, rcx) - // skip loading c41 ~ c71, - mulps(xmm6, xmm13) // scale by alpha, - movaps(xmm13, mem(rdx)) // and store back to memory. - add(rdi, rdx) + // skip loading c01 ~ c31, + mulps(xmm6, xmm9) // scale by alpha, + movaps(xmm9, mem(rcx)) // and store back to memory. + add(rdi, rcx) + // skip loading c41 ~ c71, + mulps(xmm6, xmm13) // scale by alpha, + movaps(xmm13, mem(rdx)) // and store back to memory. + add(rdi, rdx) - // skip loading c02 ~ c32, - mulps(xmm6, xmm10) // scale by alpha, - movaps(xmm10, mem(rcx)) // and store back to memory. - add(rdi, rcx) - // skip loading c42 ~ c72, - mulps(xmm6, xmm14) // scale by alpha, - movaps(xmm14, mem(rdx)) // and store back to memory. - add(rdi, rdx) + // skip loading c02 ~ c32, + mulps(xmm6, xmm10) // scale by alpha, + movaps(xmm10, mem(rcx)) // and store back to memory. + add(rdi, rcx) + // skip loading c42 ~ c72, + mulps(xmm6, xmm14) // scale by alpha, + movaps(xmm14, mem(rdx)) // and store back to memory. + add(rdi, rdx) - // skip loading c03 ~ c33, - mulps(xmm6, xmm11) // scale by alpha, - movaps(xmm11, mem(rcx)) // and store back to memory. + // skip loading c03 ~ c33, + mulps(xmm6, xmm11) // scale by alpha, + movaps(xmm11, mem(rcx)) // and store back to memory. - // skip loading c43 ~ c73, - mulps(xmm6, xmm15) // scale by alpha, - movaps(xmm15, mem(rdx)) // and store back to memory. 
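/*
 * Note on the .SBETAZERO path being re-indented here: skipping the loads of
 * C when beta == 0 is not just a bandwidth optimization but a correctness
 * requirement.  Under BLAS semantics C may be uninitialized when beta == 0,
 * so it can hold NaN or Inf, and in IEEE arithmetic 0.0f * NaN is still NaN.
 * A minimal sketch of the guarded update (illustrative names):
 */
static inline float update_c( float c_old, float alpha_ab, float beta )
{
    /* The unselected branch is never evaluated, so c_old is not read
       when beta == 0. */
    return ( beta == 0.0f ) ? alpha_ab : beta * c_old + alpha_ab;
}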
+ // skip loading c43 ~ c73, + mulps(xmm6, xmm15) // scale by alpha, + movaps(xmm15, mem(rdx)) // and store back to memory. label(.SDONE) - end_asm( + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c), // 8 - [b_next] "m" (b_next)/*, // 9 - [a_next] "m" (a_next)*/ // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c), // 8 + [b_next] "m" (b_next)/*, // 9 + [a_next] "m" (a_next)*/ // 10 : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "xmm0", "xmm1", "xmm2", "xmm3", @@ -509,7 +509,7 @@ void bli_sgemm_penryn_asm_8x4 "memory" ) - GEMM_UKR_FLUSH_CT( s ); + GEMM_UKR_FLUSH_CT( s ); } void bli_dgemm_penryn_asm_4x4 @@ -536,7 +536,7 @@ void bli_dgemm_penryn_asm_4x4 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; - GEMM_UKR_SETUP_CT_ALIGNED( d, 4, 4, false, 16 ); + GEMM_UKR_SETUP_CT_ALIGNED( d, 4, 4, false, 16 ); begin_asm() @@ -853,126 +853,126 @@ void bli_dgemm_penryn_asm_4x4 ucomisd(xmm0, xmm7) // check if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - movaps(mem(rcx), xmm0) // load c00 and c10, - mulpd(xmm6, xmm8) // scale by alpha, - mulpd(xmm7, xmm0) // scale by beta, - addpd(xmm8, xmm0) // add the gemm result, - movaps(xmm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) + movaps(mem(rcx), xmm0) // load c00 and c10, + mulpd(xmm6, xmm8) // scale by alpha, + mulpd(xmm7, xmm0) // scale by beta, + addpd(xmm8, xmm0) // add the gemm result, + movaps(xmm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) - movaps(mem(rdx), xmm1) // load c20 and c30, - mulpd(xmm6, xmm12) // scale by alpha, - mulpd(xmm7, xmm1) // scale by beta, - addpd(xmm12, xmm1) // add the gemm result, - movaps(xmm1, mem(rdx)) // and store back to memory. - add(rdi, rdx) + movaps(mem(rdx), xmm1) // load c20 and c30, + mulpd(xmm6, xmm12) // scale by alpha, + mulpd(xmm7, xmm1) // scale by beta, + addpd(xmm12, xmm1) // add the gemm result, + movaps(xmm1, mem(rdx)) // and store back to memory. + add(rdi, rdx) - movaps(mem(rcx), xmm0) // load c01 and c11, - mulpd(xmm6, xmm9) // scale by alpha, - mulpd(xmm7, xmm0) // scale by beta, - addpd(xmm9, xmm0) // add the gemm result, - movaps(xmm0, mem(rcx)) // and store back to memory. - add(rdi, rcx) + movaps(mem(rcx), xmm0) // load c01 and c11, + mulpd(xmm6, xmm9) // scale by alpha, + mulpd(xmm7, xmm0) // scale by beta, + addpd(xmm9, xmm0) // add the gemm result, + movaps(xmm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) - movaps(mem(rdx), xmm1) // load c21 and c31, - mulpd(xmm6, xmm13) // scale by alpha, - mulpd(xmm7, xmm1) // scale by beta, - addpd(xmm13, xmm1) // add the gemm result, - movaps(xmm1, mem(rdx)) // and store back to memory. - add(rdi, rdx) + movaps(mem(rdx), xmm1) // load c21 and c31, + mulpd(xmm6, xmm13) // scale by alpha, + mulpd(xmm7, xmm1) // scale by beta, + addpd(xmm13, xmm1) // add the gemm result, + movaps(xmm1, mem(rdx)) // and store back to memory. + add(rdi, rdx) - movaps(mem(rcx), xmm0) // load c02 and c12, - mulpd(xmm6, xmm10) // scale by alpha, - mulpd(xmm7, xmm0) // scale by beta, - addpd(xmm10, xmm0) // add the gemm result, - movaps(xmm0, mem(rcx)) // and store back to memory. 
- add(rdi, rcx) + movaps(mem(rcx), xmm0) // load c02 and c12, + mulpd(xmm6, xmm10) // scale by alpha, + mulpd(xmm7, xmm0) // scale by beta, + addpd(xmm10, xmm0) // add the gemm result, + movaps(xmm0, mem(rcx)) // and store back to memory. + add(rdi, rcx) - movaps(mem(rdx), xmm1) // load c22 and c32, - mulpd(xmm6, xmm14) // scale by alpha, - mulpd(xmm7, xmm1) // scale by beta, - addpd(xmm14, xmm1) // add the gemm result, - movaps(xmm1, mem(rdx)) // and store back to memory. - add(rdi, rdx) + movaps(mem(rdx), xmm1) // load c22 and c32, + mulpd(xmm6, xmm14) // scale by alpha, + mulpd(xmm7, xmm1) // scale by beta, + addpd(xmm14, xmm1) // add the gemm result, + movaps(xmm1, mem(rdx)) // and store back to memory. + add(rdi, rdx) - movaps(mem(rcx), xmm0) // load c03 and c13, - mulpd(xmm6, xmm11) // scale by alpha, - mulpd(xmm7, xmm0) // scale by beta, - addpd(xmm11, xmm0) // add the gemm result, - movaps(xmm0, mem(rcx)) // and store back to memory. + movaps(mem(rcx), xmm0) // load c03 and c13, + mulpd(xmm6, xmm11) // scale by alpha, + mulpd(xmm7, xmm0) // scale by beta, + addpd(xmm11, xmm0) // add the gemm result, + movaps(xmm0, mem(rcx)) // and store back to memory. - movaps(mem(rdx), xmm1) // load c23 and c33, - mulpd(xmm6, xmm15) // scale by alpha, - mulpd(xmm7, xmm1) // scale by beta, - addpd(xmm15, xmm1) // add the gemm result, - movaps(xmm1, mem(rdx)) // and store back to memory. + movaps(mem(rdx), xmm1) // load c23 and c33, + mulpd(xmm6, xmm15) // scale by alpha, + mulpd(xmm7, xmm1) // scale by beta, + addpd(xmm15, xmm1) // add the gemm result, + movaps(xmm1, mem(rdx)) // and store back to memory. - jmp(.DDONE) // jump to end. + jmp(.DDONE) // jump to end. label(.DBETAZERO) - // skip loading c00 and c10, - mulpd(xmm6, xmm8) // scale by alpha, - movaps(xmm8, mem(rcx)) // and store back to memory. - add(rdi, rcx) - // skip loading c20 and c30, - mulpd(xmm6, xmm12) // scale by alpha, - movaps(xmm12, mem(rdx)) // and store back to memory. - add(rdi, rdx) + // skip loading c00 and c10, + mulpd(xmm6, xmm8) // scale by alpha, + movaps(xmm8, mem(rcx)) // and store back to memory. + add(rdi, rcx) + // skip loading c20 and c30, + mulpd(xmm6, xmm12) // scale by alpha, + movaps(xmm12, mem(rdx)) // and store back to memory. + add(rdi, rdx) - // skip loading c01 and c11, - mulpd(xmm6, xmm9) // scale by alpha, - movaps(xmm9, mem(rcx)) // and store back to memory. - add(rdi, rcx) - // skip loading c21 and c31, - mulpd(xmm6, xmm13) // scale by alpha, - movaps(xmm13, mem(rdx)) // and store back to memory. - add(rdi, rdx) + // skip loading c01 and c11, + mulpd(xmm6, xmm9) // scale by alpha, + movaps(xmm9, mem(rcx)) // and store back to memory. + add(rdi, rcx) + // skip loading c21 and c31, + mulpd(xmm6, xmm13) // scale by alpha, + movaps(xmm13, mem(rdx)) // and store back to memory. + add(rdi, rdx) - // skip loading c02 and c12, - mulpd(xmm6, xmm10) // scale by alpha, - movaps(xmm10, mem(rcx)) // and store back to memory. - add(rdi, rcx) - // skip loading c22 and c32, - mulpd(xmm6, xmm14) // scale by alpha, - movaps(xmm14, mem(rdx)) // and store back to memory. - add(rdi, rdx) + // skip loading c02 and c12, + mulpd(xmm6, xmm10) // scale by alpha, + movaps(xmm10, mem(rcx)) // and store back to memory. + add(rdi, rcx) + // skip loading c22 and c32, + mulpd(xmm6, xmm14) // scale by alpha, + movaps(xmm14, mem(rdx)) // and store back to memory. + add(rdi, rdx) - // skip loading c03 and c13, - mulpd(xmm6, xmm11) // scale by alpha, - movaps(xmm11, mem(rcx)) // and store back to memory. 
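/*
 * The movaps instructions used for these stores require 16-byte-aligned
 * addresses, which is presumably why the penryn kernels invoke
 * GEMM_UKR_SETUP_CT_ALIGNED( ..., 16 ) rather than plain GEMM_UKR_SETUP_CT:
 * per the macro's definition (relocated in patch 11 below), output is
 * redirected to the aligned stack temporary whenever C itself, or its
 * leading dimension in bytes, is not a multiple of the requested alignment.
 * A sketch of that extra test for the column-major double-precision case:
 */
#include <stdint.h>
#include <stddef.h>

static inline int needs_ct_for_alignment( const double* c, ptrdiff_t cs_c )
{
    return ( (uintptr_t)c % 16 ) != 0
        || ( ( (size_t)cs_c * sizeof( double ) ) % 16 ) != 0;
}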
+ // skip loading c03 and c13, + mulpd(xmm6, xmm11) // scale by alpha, + movaps(xmm11, mem(rcx)) // and store back to memory. - // skip loading c23 and c33, - mulpd(xmm6, xmm15) // scale by alpha, - movaps(xmm15, mem(rdx)) // and store back to memory. + // skip loading c23 and c33, + mulpd(xmm6, xmm15) // scale by alpha, + movaps(xmm15, mem(rdx)) // and store back to memory. label(.DDONE) - end_asm( + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c), // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next) // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c), // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next) // 10 : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "xmm0", "xmm1", "xmm2", "xmm3", @@ -982,7 +982,7 @@ void bli_dgemm_penryn_asm_4x4 "memory" ) - GEMM_UKR_FLUSH_CT( d ); + GEMM_UKR_FLUSH_CT( d ); } diff --git a/kernels/piledriver/3/bli_gemm_piledriver_asm_d8x3.c b/kernels/piledriver/3/bli_gemm_piledriver_asm_d8x3.c index 709c2b40b2..e65ce7178a 100644 --- a/kernels/piledriver/3/bli_gemm_piledriver_asm_d8x3.c +++ b/kernels/piledriver/3/bli_gemm_piledriver_asm_d8x3.c @@ -64,7 +64,7 @@ void bli_sgemm_piledriver_asm_16x3 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; - GEMM_UKR_SETUP_CT( s, 16, 3, false ); + GEMM_UKR_SETUP_CT( s, 16, 3, false ); begin_asm() @@ -457,57 +457,57 @@ void bli_sgemm_piledriver_asm_16x3 vucomiss(xmm0, xmm2) // set ZF if beta == 0. 
je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case - vfmadd231ps(mem(rcx, 0*16), xmm2, xmm4) - vfmadd231ps(mem(rcx, 1*16), xmm2, xmm7) - vfmadd231ps(mem(rcx, 2*16), xmm2, xmm10) - vfmadd231ps(mem(rcx, 3*16), xmm2, xmm13) + vfmadd231ps(mem(rcx, 0*16), xmm2, xmm4) + vfmadd231ps(mem(rcx, 1*16), xmm2, xmm7) + vfmadd231ps(mem(rcx, 2*16), xmm2, xmm10) + vfmadd231ps(mem(rcx, 3*16), xmm2, xmm13) - vfmadd231ps(mem(r10, 0*16), xmm2, xmm5) - vfmadd231ps(mem(r10, 1*16), xmm2, xmm8) - vfmadd231ps(mem(r10, 2*16), xmm2, xmm11) - vfmadd231ps(mem(r10, 3*16), xmm2, xmm14) + vfmadd231ps(mem(r10, 0*16), xmm2, xmm5) + vfmadd231ps(mem(r10, 1*16), xmm2, xmm8) + vfmadd231ps(mem(r10, 2*16), xmm2, xmm11) + vfmadd231ps(mem(r10, 3*16), xmm2, xmm14) - vfmadd231ps(mem(r11, 0*16), xmm2, xmm6) - vfmadd231ps(mem(r11, 1*16), xmm2, xmm9) - vfmadd231ps(mem(r11, 2*16), xmm2, xmm12) - vfmadd231ps(mem(r11, 3*16), xmm2, xmm15) + vfmadd231ps(mem(r11, 0*16), xmm2, xmm6) + vfmadd231ps(mem(r11, 1*16), xmm2, xmm9) + vfmadd231ps(mem(r11, 2*16), xmm2, xmm12) + vfmadd231ps(mem(r11, 3*16), xmm2, xmm15) - // fall through + // fall through label(.SBETAZERO) - vmovups(xmm4, mem(rcx, 0*16)) - vmovups(xmm7, mem(rcx, 1*16)) - vmovups(xmm10, mem(rcx, 2*16)) - vmovups(xmm13, mem(rcx, 3*16)) + vmovups(xmm4, mem(rcx, 0*16)) + vmovups(xmm7, mem(rcx, 1*16)) + vmovups(xmm10, mem(rcx, 2*16)) + vmovups(xmm13, mem(rcx, 3*16)) - vmovups(xmm5, mem(r10, 0*16)) - vmovups(xmm8, mem(r10, 1*16)) - vmovups(xmm11, mem(r10, 2*16)) - vmovups(xmm14, mem(r10, 3*16)) + vmovups(xmm5, mem(r10, 0*16)) + vmovups(xmm8, mem(r10, 1*16)) + vmovups(xmm11, mem(r10, 2*16)) + vmovups(xmm14, mem(r10, 3*16)) - vmovups(xmm6, mem(r11, 0*16)) - vmovups(xmm9, mem(r11, 1*16)) - vmovups(xmm12, mem(r11, 2*16)) - vmovups(xmm15, mem(r11, 3*16)) + vmovups(xmm6, mem(r11, 0*16)) + vmovups(xmm9, mem(r11, 1*16)) + vmovups(xmm12, mem(r11, 2*16)) + vmovups(xmm15, mem(r11, 3*16)) label(.SDONE) - end_asm( + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c), // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next) // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c), // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next) // 10 : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", @@ -518,7 +518,7 @@ void bli_sgemm_piledriver_asm_16x3 "memory" ) - GEMM_UKR_FLUSH_CT( s ); + GEMM_UKR_FLUSH_CT( s ); } void bli_dgemm_piledriver_asm_8x3 @@ -545,7 +545,7 @@ void bli_dgemm_piledriver_asm_8x3 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; - GEMM_UKR_SETUP_CT( d, 8, 3, false ); + GEMM_UKR_SETUP_CT( d, 8, 3, false ); begin_asm() @@ -897,73 +897,73 @@ void bli_dgemm_piledriver_asm_8x3 vucomisd(xmm0, xmm2) // set ZF if beta == 0. 
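/*
 * Structural note on the piledriver store phases: unlike the haswell and
 * penryn kernels, there is no jump over the beta == 0 block.  When
 * beta != 0 the kernel folds beta*C into the accumulators with vfmadd231,
 * then falls through (see the "fall through" comments) into the
 * .SBETAZERO/.DBETAZERO label, whose unconditional vmovups stores serve
 * both cases.  In scalar form (illustrative names):
 */
static inline double store_fallthrough( double c_old, double acc, double beta )
{
    if ( beta != 0.0 )
        acc += beta * c_old;   /* the vfmadd231 block */
    return acc;                /* the shared, unconditional store */
}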
je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - // xmm4: xmm5: xmm6: - // ( ab00 ( ab01 ( ab02 - // ab10 ) ab11 ) ab12 ) - // - // xmm7: xmm8: xmm9: - // ( ab20 ( ab21 ( ab22 - // ab30 ) ab31 ) ab32 ) - // - // xmm10: xmm11: xmm12: - // ( ab40 ( ab41 ( ab42 - // ab50 ) ab51 ) ab52 ) - // - // xmm13: xmm14: xmm15: - // ( ab60 ( ab61 ( ab62 - // ab70 ) ab71 ) ab72 ) - - vfmadd231pd(mem(rcx, 0*16), xmm2, xmm4) - vfmadd231pd(mem(rcx, 1*16), xmm2, xmm7) - vfmadd231pd(mem(rcx, 2*16), xmm2, xmm10) - vfmadd231pd(mem(rcx, 3*16), xmm2, xmm13) - - vfmadd231pd(mem(r10, 0*16), xmm2, xmm5) - vfmadd231pd(mem(r10, 1*16), xmm2, xmm8) - vfmadd231pd(mem(r10, 2*16), xmm2, xmm11) - vfmadd231pd(mem(r10, 3*16), xmm2, xmm14) - - vfmadd231pd(mem(r11, 0*16), xmm2, xmm6) - vfmadd231pd(mem(r11, 1*16), xmm2, xmm9) - vfmadd231pd(mem(r11, 2*16), xmm2, xmm12) - vfmadd231pd(mem(r11, 3*16), xmm2, xmm15) - - // fall through + // xmm4: xmm5: xmm6: + // ( ab00 ( ab01 ( ab02 + // ab10 ) ab11 ) ab12 ) + // + // xmm7: xmm8: xmm9: + // ( ab20 ( ab21 ( ab22 + // ab30 ) ab31 ) ab32 ) + // + // xmm10: xmm11: xmm12: + // ( ab40 ( ab41 ( ab42 + // ab50 ) ab51 ) ab52 ) + // + // xmm13: xmm14: xmm15: + // ( ab60 ( ab61 ( ab62 + // ab70 ) ab71 ) ab72 ) + + vfmadd231pd(mem(rcx, 0*16), xmm2, xmm4) + vfmadd231pd(mem(rcx, 1*16), xmm2, xmm7) + vfmadd231pd(mem(rcx, 2*16), xmm2, xmm10) + vfmadd231pd(mem(rcx, 3*16), xmm2, xmm13) + + vfmadd231pd(mem(r10, 0*16), xmm2, xmm5) + vfmadd231pd(mem(r10, 1*16), xmm2, xmm8) + vfmadd231pd(mem(r10, 2*16), xmm2, xmm11) + vfmadd231pd(mem(r10, 3*16), xmm2, xmm14) + + vfmadd231pd(mem(r11, 0*16), xmm2, xmm6) + vfmadd231pd(mem(r11, 1*16), xmm2, xmm9) + vfmadd231pd(mem(r11, 2*16), xmm2, xmm12) + vfmadd231pd(mem(r11, 3*16), xmm2, xmm15) + + // fall through label(.DBETAZERO) - vmovups(xmm4, mem(rcx, 0*16)) - vmovups(xmm7, mem(rcx, 1*16)) - vmovups(xmm10, mem(rcx, 2*16)) - vmovups(xmm13, mem(rcx, 3*16)) + vmovups(xmm4, mem(rcx, 0*16)) + vmovups(xmm7, mem(rcx, 1*16)) + vmovups(xmm10, mem(rcx, 2*16)) + vmovups(xmm13, mem(rcx, 3*16)) - vmovups(xmm5, mem(r10, 0*16)) - vmovups(xmm8, mem(r10, 1*16)) - vmovups(xmm11, mem(r10, 2*16)) - vmovups(xmm14, mem(r10, 3*16)) + vmovups(xmm5, mem(r10, 0*16)) + vmovups(xmm8, mem(r10, 1*16)) + vmovups(xmm11, mem(r10, 2*16)) + vmovups(xmm14, mem(r10, 3*16)) - vmovups(xmm6, mem(r11, 0*16)) - vmovups(xmm9, mem(r11, 1*16)) - vmovups(xmm12, mem(r11, 2*16)) - vmovups(xmm15, mem(r11, 3*16)) + vmovups(xmm6, mem(r11, 0*16)) + vmovups(xmm9, mem(r11, 1*16)) + vmovups(xmm12, mem(r11, 2*16)) + vmovups(xmm15, mem(r11, 3*16)) label(.DDONE) - end_asm( + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c), // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next) // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c), // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next) // 10 : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", @@ -974,7 +974,7 @@ void bli_dgemm_piledriver_asm_8x3 "memory" ) - GEMM_UKR_FLUSH_CT( d ); + GEMM_UKR_FLUSH_CT( d ); } void bli_cgemm_piledriver_asm_4x2 @@ -1001,7 +1001,7 @@ void bli_cgemm_piledriver_asm_4x2 
uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; - GEMM_UKR_SETUP_CT( c, 4, 2, false ); + GEMM_UKR_SETUP_CT( c, 4, 2, false ); begin_asm() @@ -1317,63 +1317,63 @@ void bli_cgemm_piledriver_asm_4x2 and(r8b, r9b) // set ZF if r8b & r9b == 1. jne(.CBETAZERO) // if ZF = 0, jump to beta == 0 case - vmovups(mem(rcx), xmm0) // load c00:c10 - vmovups(mem(rcx, 16), xmm2) // load c20:c30 - vpermilps(imm(0xb1), xmm0, xmm1) - vpermilps(imm(0xb1), xmm2, xmm3) + vmovups(mem(rcx), xmm0) // load c00:c10 + vmovups(mem(rcx, 16), xmm2) // load c20:c30 + vpermilps(imm(0xb1), xmm0, xmm1) + vpermilps(imm(0xb1), xmm2, xmm3) - vmulps(xmm6, xmm0, xmm0) - vmulps(xmm7, xmm1, xmm1) - vaddsubps(xmm1, xmm0, xmm0) - vaddps(xmm8, xmm0, xmm0) + vmulps(xmm6, xmm0, xmm0) + vmulps(xmm7, xmm1, xmm1) + vaddsubps(xmm1, xmm0, xmm0) + vaddps(xmm8, xmm0, xmm0) - vmulps(xmm6, xmm2, xmm2) - vmulps(xmm7, xmm3, xmm3) - vaddsubps(xmm3, xmm2, xmm2) - vaddps(xmm12, xmm2, xmm2) + vmulps(xmm6, xmm2, xmm2) + vmulps(xmm7, xmm3, xmm3) + vaddsubps(xmm3, xmm2, xmm2) + vaddps(xmm12, xmm2, xmm2) - vmovups(mem(r10), xmm0) // load c01:c11 - vmovups(mem(r10, 16), xmm2) // load c21:c31 - vpermilps(imm(0xb1), xmm0, xmm1) - vpermilps(imm(0xb1), xmm2, xmm3) + vmovups(mem(r10), xmm0) // load c01:c11 + vmovups(mem(r10, 16), xmm2) // load c21:c31 + vpermilps(imm(0xb1), xmm0, xmm1) + vpermilps(imm(0xb1), xmm2, xmm3) - vmulps(xmm6, xmm0, xmm0) - vmulps(xmm7, xmm1, xmm1) - vaddsubps(xmm1, xmm0, xmm0) - vaddps(xmm10, xmm0, xmm0) + vmulps(xmm6, xmm0, xmm0) + vmulps(xmm7, xmm1, xmm1) + vaddsubps(xmm1, xmm0, xmm0) + vaddps(xmm10, xmm0, xmm0) - vmulps(xmm6, xmm2, xmm2) - vmulps(xmm7, xmm3, xmm3) - vaddsubps(xmm3, xmm2, xmm2) - vaddps(xmm14, xmm2, xmm2) + vmulps(xmm6, xmm2, xmm2) + vmulps(xmm7, xmm3, xmm3) + vaddsubps(xmm3, xmm2, xmm2) + vaddps(xmm14, xmm2, xmm2) - // fall through + // fall through label(.CBETAZERO) - vmovups(xmm8, mem(rcx)) // store c00:c10 - vmovups(xmm12, mem(rcx, 16)) // store c20:c30 + vmovups(xmm8, mem(rcx)) // store c00:c10 + vmovups(xmm12, mem(rcx, 16)) // store c20:c30 - vmovups(xmm10, mem(r10)) // store c01:c11 - vmovups(xmm14, mem(r10, 16)) // store c21:c31 + vmovups(xmm10, mem(r10)) // store c01:c11 + vmovups(xmm14, mem(r10, 16)) // store c21:c31 label(.CDONE) - end_asm( + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c), // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next) // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c), // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next) // 10 : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", @@ -1384,7 +1384,7 @@ void bli_cgemm_piledriver_asm_4x2 "memory" ) - GEMM_UKR_FLUSH_CT( c ); + GEMM_UKR_FLUSH_CT( c ); } void bli_zgemm_piledriver_asm_2x2 @@ -1411,7 +1411,7 @@ void bli_zgemm_piledriver_asm_2x2 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; - GEMM_UKR_SETUP_CT( z, 2, 2, false ); + GEMM_UKR_SETUP_CT( z, 2, 2, false ); begin_asm() @@ -1729,63 +1729,63 @@ void bli_zgemm_piledriver_asm_2x2 and(r8b, r9b) // set ZF if r8b & r9b == 1. 
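/*
 * The vpermilps / vmulps / vaddsubps triple in the cgemm store path above
 * is the standard SSE/AVX idiom for multiplying interleaved complex data
 * by beta.  With x = (xr, xi) in a register, and assuming xmm6 and xmm7
 * hold the duplicated real and imaginary parts of beta (set up earlier in
 * the kernel), the sequence computes
 *
 *     t0 = ( br*xr, br*xi )                        vmulps with xmm6
 *     t1 = ( bi*xi, bi*xr )                        vpermilps swap, vmulps with xmm7
 *     beta*x = ( t0[0] - t1[0], t0[1] + t1[1] )    vaddsubps
 *
 * which is exactly ( br*xr - bi*xi, br*xi + bi*xr ).  A scalar sketch:
 */
static inline void cmulf( float br, float bi, float xr, float xi,
                          float* yr, float* yi )
{
    *yr = br*xr - bi*xi;
    *yi = br*xi + bi*xr;
}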
jne(.ZBETAZERO) // if ZF = 0, jump to beta == 0 case - vmovups(mem(rcx), xmm0) // load c00 - vmovups(mem(rcx, 16), xmm2) // load c10 - vpermilpd(imm(0x1), xmm0, xmm1) - vpermilpd(imm(0x1), xmm2, xmm3) + vmovups(mem(rcx), xmm0) // load c00 + vmovups(mem(rcx, 16), xmm2) // load c10 + vpermilpd(imm(0x1), xmm0, xmm1) + vpermilpd(imm(0x1), xmm2, xmm3) - vmulpd(xmm6, xmm0, xmm0) - vmulpd(xmm7, xmm1, xmm1) - vaddsubpd(xmm1, xmm0, xmm0) - vaddpd(xmm8, xmm0, xmm0) + vmulpd(xmm6, xmm0, xmm0) + vmulpd(xmm7, xmm1, xmm1) + vaddsubpd(xmm1, xmm0, xmm0) + vaddpd(xmm8, xmm0, xmm0) - vmulpd(xmm6, xmm2, xmm2) - vmulpd(xmm7, xmm3, xmm3) - vaddsubpd(xmm3, xmm2, xmm2) - vaddpd(xmm12, xmm2, xmm2) + vmulpd(xmm6, xmm2, xmm2) + vmulpd(xmm7, xmm3, xmm3) + vaddsubpd(xmm3, xmm2, xmm2) + vaddpd(xmm12, xmm2, xmm2) - vmovups(mem(r10), xmm0) // load c01 - vmovups(mem(r10, 16), xmm2) // load c11 - vpermilpd(imm(0x1), xmm0, xmm1) - vpermilpd(imm(0x1), xmm2, xmm3) + vmovups(mem(r10), xmm0) // load c01 + vmovups(mem(r10, 16), xmm2) // load c11 + vpermilpd(imm(0x1), xmm0, xmm1) + vpermilpd(imm(0x1), xmm2, xmm3) - vmulpd(xmm6, xmm0, xmm0) - vmulpd(xmm7, xmm1, xmm1) - vaddsubpd(xmm1, xmm0, xmm0) - vaddpd(xmm10, xmm0, xmm0) + vmulpd(xmm6, xmm0, xmm0) + vmulpd(xmm7, xmm1, xmm1) + vaddsubpd(xmm1, xmm0, xmm0) + vaddpd(xmm10, xmm0, xmm0) - vmulpd(xmm6, xmm2, xmm2) - vmulpd(xmm7, xmm3, xmm3) - vaddsubpd(xmm3, xmm2, xmm2) - vaddpd(xmm14, xmm2, xmm2) + vmulpd(xmm6, xmm2, xmm2) + vmulpd(xmm7, xmm3, xmm3) + vaddsubpd(xmm3, xmm2, xmm2) + vaddpd(xmm14, xmm2, xmm2) - // fall through + // fall through label(.ZBETAZERO) - vmovups(xmm8, mem(rcx)) // store c00 - vmovups(xmm12, mem(rcx, 16)) // store c10 + vmovups(xmm8, mem(rcx)) // store c00 + vmovups(xmm12, mem(rcx, 16)) // store c10 - vmovups(xmm10, mem(r10)) // store c01 - vmovups(xmm14, mem(r10, 16)) // store c11 + vmovups(xmm10, mem(r10)) // store c01 + vmovups(xmm14, mem(r10, 16)) // store c11 label(.ZDONE) - end_asm( + end_asm( : // output operands (none) : // input operands - [k_iter] "m" (k_iter), // 0 - [k_left] "m" (k_left), // 1 - [a] "m" (a), // 2 - [b] "m" (b), // 3 - [alpha] "m" (alpha), // 4 - [beta] "m" (beta), // 5 - [c] "m" (c), // 6 - [rs_c] "m" (rs_c), // 7 - [cs_c] "m" (cs_c), // 8 - [b_next] "m" (b_next), // 9 - [a_next] "m" (a_next) // 10 + [k_iter] "m" (k_iter), // 0 + [k_left] "m" (k_left), // 1 + [a] "m" (a), // 2 + [b] "m" (b), // 3 + [alpha] "m" (alpha), // 4 + [beta] "m" (beta), // 5 + [c] "m" (c), // 6 + [rs_c] "m" (rs_c), // 7 + [cs_c] "m" (cs_c), // 8 + [b_next] "m" (b_next), // 9 + [a_next] "m" (a_next) // 10 : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", @@ -1796,7 +1796,7 @@ void bli_zgemm_piledriver_asm_2x2 "memory" ) - GEMM_UKR_FLUSH_CT( z ); + GEMM_UKR_FLUSH_CT( z ); } diff --git a/kernels/power10/3/bli_i4gemm_power10_mma.c b/kernels/power10/3/bli_i4gemm_power10_mma.c index 6cfedafd08..7527f271ff 100644 --- a/kernels/power10/3/bli_i4gemm_power10_mma.c +++ b/kernels/power10/3/bli_i4gemm_power10_mma.c @@ -69,7 +69,7 @@ void bli_i4gemm_power10_mma_8x16 { uint64_t k_iter = (k-1) / 4; - uint64_t k_left = (k-1) % 4; + uint64_t k_left = (k-1) % 4; uint64_t rs_c = rs_c0; @@ -102,19 +102,19 @@ void bli_i4gemm_power10_mma_8x16 I4_INCREMENT // k loop (unrolled by 4) - for (int k = 0; kd; \ - inc_t incd = params_cast->incd; \ - ctype kappa_cast = *( ctype* )kappa; \ + packm_diag_params_t* params_cast = params; \ + ctype* restrict a_cast = a; \ + ctype* restrict p_cast = p; \ + ctype* restrict 
d_cast = params_cast->d; \ + inc_t incd = params_cast->incd; \ + ctype kappa_cast = *( ctype* )kappa; \ \ - if ( schema != BLIS_PACKED_ROW_PANELS && \ - schema != BLIS_PACKED_COL_PANELS ) \ - bli_abort(); \ + if ( schema != BLIS_PACKED_ROW_PANELS && \ + schema != BLIS_PACKED_COL_PANELS ) \ + bli_abort(); \ \ - /* Apply the offset */ \ - d_cast += panel_len_off * incd; \ + /* Apply the offset */ \ + d_cast += panel_len_off * incd; \ \ - if ( conja ) \ - { \ - for ( dim_t j = 0; j < panel_len; j++ ) \ - { \ - ctype kappa_d; \ - PASTEMAC(ch,scal2s)( kappa_cast, d_cast[ j*incd ], kappa_d ); \ + if ( conja ) \ + { \ + for ( dim_t j = 0; j < panel_len; j++ ) \ + { \ + ctype kappa_d; \ + PASTEMAC(ch,scal2s)( kappa_cast, d_cast[ j*incd ], kappa_d ); \ \ - for (dim_t i = 0;i < panel_dim;i++) \ - PASTEMAC(ch,scal2js)( kappa_d, a_cast[ i*inca + j*lda ], p_cast[ i + j*ldp ] ); \ + for (dim_t i = 0;i < panel_dim;i++) \ + PASTEMAC(ch,scal2js)( kappa_d, a_cast[ i*inca + j*lda ], p_cast[ i + j*ldp ] ); \ \ - for (dim_t i = panel_dim;i < panel_dim_max;i++) \ - PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \ - } \ - } \ - else \ - { \ - for ( dim_t j = 0; j < panel_len; j++ ) \ - { \ - ctype kappa_d; \ - PASTEMAC(ch,scal2s)( kappa_cast, d_cast[ j*incd ], kappa_d ); \ + for (dim_t i = panel_dim;i < panel_dim_max;i++) \ + PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \ + } \ + } \ + else \ + { \ + for ( dim_t j = 0; j < panel_len; j++ ) \ + { \ + ctype kappa_d; \ + PASTEMAC(ch,scal2s)( kappa_cast, d_cast[ j*incd ], kappa_d ); \ \ - for (dim_t i = 0;i < panel_dim;i++) \ - PASTEMAC(ch,scal2s)( kappa_d, a_cast[ i*inca + j*lda ], p_cast[ i + j*ldp ] ); \ + for (dim_t i = 0;i < panel_dim;i++) \ + PASTEMAC(ch,scal2s)( kappa_d, a_cast[ i*inca + j*lda ], p_cast[ i + j*ldp ] ); \ \ - for (dim_t i = panel_dim;i < panel_dim_max;i++) \ - PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \ - } \ - } \ + for (dim_t i = panel_dim;i < panel_dim_max;i++) \ + PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \ + } \ + } \ \ - for (dim_t j = panel_len;j < panel_len_max;j++) \ - for (dim_t i = 0;i < panel_dim_max;i++) \ - PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \ + for (dim_t j = panel_len;j < panel_len_max;j++) \ + for (dim_t i = 0;i < panel_dim_max;i++) \ + PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \ } INSERT_GENTFUNC_BASIC0(packm_diag_ukr); @@ -100,17 +100,17 @@ static packm_ker_vft GENARRAY( packm_diag_ukrs, packm_diag_ukr ); */ void attach_diagonal_factor( packm_diag_params_t* params, obj_t* d, obj_t* a ) { - memset( params, 0, sizeof(*params) ); + memset( params, 0, sizeof(*params) ); - // Assumes D is a column vector - params->d = bli_obj_buffer_at_off( d ); - params->incd = bli_obj_row_stride( d ); + // Assumes D is a column vector + params->d = bli_obj_buffer_at_off( d ); + params->incd = bli_obj_row_stride( d ); - for ( int i = BLIS_DT_LO; i <= BLIS_DT_HI; i++ ) - params->super.ukr_fn[i][i] = packm_diag_ukrs[i]; + for ( int i = BLIS_DT_LO; i <= BLIS_DT_HI; i++ ) + params->super.ukr_fn[i][i] = packm_diag_ukrs[i]; - // Attach the parameters to the A object. - bli_obj_set_pack_params( params, a ); + // Attach the parameters to the A object. 
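/*
 * In effect the pack microkernel above computes, for one packed panel of
 * the (implicitly transposed) matrix A,
 *
 *     p(i,j) = kappa * d(j) * a(i,j)        0 <= i < panel_dim,
 *                                           0 <= j < panel_len,
 *
 * (conjugating a(i,j) in the conja branch), zero-padding out to
 * panel_dim_max x panel_len_max, so that the packed operand already
 * carries the diagonal factor.  The driver syrk_diag below thus computes
 * C := alpha * A * (D * A^T) + beta * C with an ordinary gemmt.  A plain-C
 * reference of the packing step (illustrative, double precision, without
 * the conjugation case):
 */
static void pack_diag_ref( int pd, int pd_max, int pl, int pl_max,
                           double kappa, const double* d, int incd,
                           const double* a, int inca, int lda,
                           double* p, int ldp )
{
    for ( int j = 0; j < pl_max; j++ )
        for ( int i = 0; i < pd_max; i++ )
            p[ i + j*ldp ] = ( i < pd && j < pl )
                           ? kappa * d[ j*incd ] * a[ i*inca + j*lda ]
                           : 0.0;
}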
+ bli_obj_set_pack_params( params, a ); } /* @@ -120,67 +120,67 @@ void attach_diagonal_factor( packm_diag_params_t* params, obj_t* d, obj_t* a ) */ void syrk_diag( obj_t* alpha, obj_t* a, obj_t* d, obj_t* beta, obj_t* c ) { - obj_t ad; // this is (D * A^T) - packm_diag_params_t params; + obj_t ad; // this is (D * A^T) + packm_diag_params_t params; - bli_obj_alias_to( a, &ad ); - bli_obj_toggle_trans( &ad ); // because gemmt is A*B instead of A*B^T - attach_diagonal_factor( ¶ms, d, &ad ); + bli_obj_alias_to( a, &ad ); + bli_obj_toggle_trans( &ad ); // because gemmt is A*B instead of A*B^T + attach_diagonal_factor( ¶ms, d, &ad ); - // Does C := alpha * A * B + beta * C using B = (D + A^T) - bli_gemmtnat( alpha, a, &ad, beta, c, NULL, NULL ); + // Does C := alpha * A * B + beta * C using B = (D + A^T) + bli_gemmtnat( alpha, a, &ad, beta, c, NULL, NULL ); } -int main() +int main( void ) { - obj_t a; - obj_t d; - obj_t c; - obj_t c_copy; - obj_t norm; - - dim_t m = 10; - dim_t k = 10; - - for ( int dt_ = BLIS_DT_LO; dt_ <= BLIS_DT_HI; dt_++ ) - for ( int upper = 0; upper <= 1; upper++ ) - for ( int transa = 0; transa <= 1; transa++ ) - for ( int transc = 0; transc <= 1; transc++ ) - { - num_t dt = dt_; - uplo_t uplo = upper ? BLIS_UPPER : BLIS_LOWER; - - bli_obj_create( dt, m, k, transa ? k : 1, transa ? 1 : m, &a ); - bli_obj_create( dt, k, 1, 1, 1, &d ); - bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c ); - bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c_copy ); - bli_obj_set_struc( BLIS_SYMMETRIC , &c ); - bli_obj_set_struc( BLIS_SYMMETRIC , &c_copy ); - bli_obj_set_uplo( uplo , &c ); - bli_obj_set_uplo( uplo , &c_copy ); - bli_obj_create_1x1( bli_dt_proj_to_real( dt ), &norm ); - - bli_randm( &a ); - bli_randm( &d ); - bli_randm( &c ); - bli_copym( &c, &c_copy ); - - syrk_diag( &BLIS_ONE, &a, &d, &BLIS_ONE, &c ); - syrk_diag_ref( &BLIS_ONE, &a, &d, &BLIS_ONE, &c_copy ); - - bli_subm( &c_copy, &c ); - bli_normfm( &c, &norm ); - - double normr, normi; - bli_getsc( &norm, &normr, &normi ); - - printf("dt: %d, upper: %d, transa: %d, transc: %d, norm: %g\n", - dt, upper, transa, transc, normr); - - bli_obj_free( &a ); - bli_obj_free( &d ); - bli_obj_free( &c ); - bli_obj_free( &c_copy ); - bli_obj_free( &norm ); - } + obj_t a; + obj_t d; + obj_t c; + obj_t c_copy; + obj_t norm; + + dim_t m = 10; + dim_t k = 10; + + for ( int dt_ = BLIS_DT_LO; dt_ <= BLIS_DT_HI; dt_++ ) + for ( int upper = 0; upper <= 1; upper++ ) + for ( int transa = 0; transa <= 1; transa++ ) + for ( int transc = 0; transc <= 1; transc++ ) + { + num_t dt = dt_; + uplo_t uplo = upper ? BLIS_UPPER : BLIS_LOWER; + + bli_obj_create( dt, m, k, transa ? k : 1, transa ? 1 : m, &a ); + bli_obj_create( dt, k, 1, 1, 1, &d ); + bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c ); + bli_obj_create( dt, m, m, transc ? m : 1, transc ? 
1 : m, &c_copy ); + bli_obj_set_struc( BLIS_SYMMETRIC , &c ); + bli_obj_set_struc( BLIS_SYMMETRIC , &c_copy ); + bli_obj_set_uplo( uplo , &c ); + bli_obj_set_uplo( uplo , &c_copy ); + bli_obj_create_1x1( bli_dt_proj_to_real( dt ), &norm ); + + bli_randm( &a ); + bli_randm( &d ); + bli_randm( &c ); + bli_copym( &c, &c_copy ); + + syrk_diag( &BLIS_ONE, &a, &d, &BLIS_ONE, &c ); + syrk_diag_ref( &BLIS_ONE, &a, &d, &BLIS_ONE, &c_copy ); + + bli_subm( &c_copy, &c ); + bli_normfm( &c, &norm ); + + double normr, normi; + bli_getsc( &norm, &normr, &normi ); + + printf( "dt: %d, upper: %d, transa: %d, transc: %d, norm: %g\n", + dt, upper, transa, transc, normr ); + + bli_obj_free( &a ); + bli_obj_free( &d ); + bli_obj_free( &c ); + bli_obj_free( &c_copy ); + bli_obj_free( &norm ); + } } diff --git a/test/syrk_diagonal/syrk_diagonal_example2.c b/test/syrk_diagonal/syrk_diagonal_example2.c index 4937f45138..92371f48b0 100644 --- a/test/syrk_diagonal/syrk_diagonal_example2.c +++ b/test/syrk_diagonal/syrk_diagonal_example2.c @@ -44,43 +44,43 @@ void PASTEMAC(ch,op) \ void* restrict p, inc_t ldp \ ) \ { \ - ctype* restrict a_cast = a; \ - ctype* restrict p_cast = p; \ - ctype* restrict d_cast = d; \ - ctype kappa_cast = *( ctype* )kappa; \ + ctype* restrict a_cast = a; \ + ctype* restrict p_cast = p; \ + ctype* restrict d_cast = d; \ + ctype kappa_cast = *( ctype* )kappa; \ \ - if ( conja ) \ - { \ - for ( dim_t j = 0; j < panel_len; j++ ) \ - { \ - ctype kappa_d; \ - PASTEMAC(ch,scal2s)( kappa_cast, d_cast[ j*incd ], kappa_d ); \ + if ( conja ) \ + { \ + for ( dim_t j = 0; j < panel_len; j++ ) \ + { \ + ctype kappa_d; \ + PASTEMAC(ch,scal2s)( kappa_cast, d_cast[ j*incd ], kappa_d ); \ \ - for (dim_t i = 0;i < panel_dim;i++) \ - PASTEMAC(ch,scal2js)( kappa_d, a_cast[ i*inca + j*lda ], p_cast[ i + j*ldp ] ); \ + for (dim_t i = 0;i < panel_dim;i++) \ + PASTEMAC(ch,scal2js)( kappa_d, a_cast[ i*inca + j*lda ], p_cast[ i + j*ldp ] ); \ \ - for (dim_t i = panel_dim;i < panel_dim_max;i++) \ - PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \ - } \ - } \ - else \ - { \ - for ( dim_t j = 0; j < panel_len; j++ ) \ - { \ - ctype kappa_d; \ - PASTEMAC(ch,scal2s)( kappa_cast, d_cast[ j*incd ], kappa_d ); \ + for (dim_t i = panel_dim;i < panel_dim_max;i++) \ + PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \ + } \ + } \ + else \ + { \ + for ( dim_t j = 0; j < panel_len; j++ ) \ + { \ + ctype kappa_d; \ + PASTEMAC(ch,scal2s)( kappa_cast, d_cast[ j*incd ], kappa_d ); \ \ - for (dim_t i = 0;i < panel_dim;i++) \ - PASTEMAC(ch,scal2s)( kappa_d, a_cast[ i*inca + j*lda ], p_cast[ i + j*ldp ] ); \ + for (dim_t i = 0;i < panel_dim;i++) \ + PASTEMAC(ch,scal2s)( kappa_d, a_cast[ i*inca + j*lda ], p_cast[ i + j*ldp ] ); \ \ - for (dim_t i = panel_dim;i < panel_dim_max;i++) \ - PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \ - } \ - } \ + for (dim_t i = panel_dim;i < panel_dim_max;i++) \ + PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \ + } \ + } \ \ - for (dim_t j = panel_len;j < panel_len_max;j++) \ - for (dim_t i = 0;i < panel_dim_max;i++) \ - PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \ + for (dim_t j = panel_len;j < panel_len_max;j++) \ + for (dim_t i = 0;i < panel_dim_max;i++) \ + PASTEMAC(ch,set0s)( p_cast[ i + j*ldp ] ); \ } INSERT_GENTFUNC_BASIC0(packm_diag_ukr); @@ -109,7 +109,7 @@ void packm_diag dim_t dt_size = bli_dt_size( dt ); if ( dt_scalar != dt || dt_tar != dt ) - bli_abort(); + bli_abort(); // Extract various fields from the control tree. 
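/*
 * The two example programs hook into the framework at different levels.
 * The first (syrk_diagonal_example.c above) keeps the stock packing driver
 * and only swaps in per-datatype pack microkernels:
 *
 *     for ( int i = BLIS_DT_LO; i <= BLIS_DT_HI; i++ )
 *         params->super.ukr_fn[i][i] = packm_diag_ukrs[i];
 *
 * This second example instead replaces the entire packm variant with
 *
 *     bli_obj_set_pack_fn( packm_diag, a );
 *
 * which is why the packm_diag function defined here must itself extract
 * blocksizes from the context, set the pack schema, and iterate over
 * panels before invoking the microkernel.
 */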
bszid_t bmult_id_m = bli_cntl_packm_params_bmid_m( cntl ); @@ -119,9 +119,9 @@ void packm_diag dim_t bmult_m_pack = bli_cntx_get_blksz_max_dt( dt_tar, bmult_id_m, cntx ); dim_t bmult_n_def = bli_cntx_get_blksz_def_dt( dt_tar, bmult_id_n, cntx ); - if ( schema != BLIS_PACKED_ROW_PANELS && - schema != BLIS_PACKED_COL_PANELS ) - bli_abort(); + if ( schema != BLIS_PACKED_ROW_PANELS && + schema != BLIS_PACKED_COL_PANELS ) + bli_abort(); // Store the pack schema to the object. bli_obj_set_pack_schema( schema, p ); @@ -215,9 +215,9 @@ void packm_diag dim_t panel_len_off = bli_obj_col_off( a ); conj_t conja = bli_obj_conj_status( a ); - packm_diag_params_t* params = bli_obj_pack_params( a ); - char* d_cast = params->d; - inc_t incd = params->incd; + packm_diag_params_t* params = bli_obj_pack_params( a ); + char* d_cast = params->d; + inc_t incd = params->incd; obj_t kappa_local; char* kappa_cast = bli_packm_scalar( &kappa_local, p ); @@ -241,22 +241,25 @@ void packm_diag { dim_t panel_dim_i = bli_min( bmult_m_def, m_p - it*bmult_m_def ); - char* d_begin = d_cast + panel_len_off*incd*dt_size; + char* d_begin = d_cast + panel_len_off*incd*dt_size; char* a_begin = a_cast + it* bmult_m_def*inca*dt_size; - char* p_begin = p_cast + it* ps_p*dt_size; + char* p_begin = p_cast + it* ps_p*dt_size; if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) ) { - packm_ker_cast( conja, - panel_dim_i, - n_p, - bmult_m_def, - n_p_pad, - kappa_cast, - d_begin, incd, - a_begin, inca, lda, - p_begin, bmult_m_pack ); - } + packm_ker_cast + ( + conja, + panel_dim_i, + n_p, + bmult_m_def, + n_p_pad, + kappa_cast, + d_begin, incd, + a_begin, inca, lda, + p_begin, bmult_m_pack + ); + } } } @@ -267,15 +270,15 @@ void packm_diag */ void attach_diagonal_factor( packm_diag_params_t* params, obj_t* d, obj_t* a ) { - // Assumes D is a column vector - params->d = bli_obj_buffer_at_off( d ); - params->incd = bli_obj_row_stride( d ); + // Assumes D is a column vector + params->d = bli_obj_buffer_at_off( d ); + params->incd = bli_obj_row_stride( d ); - // Set the custom pack function. - bli_obj_set_pack_fn( packm_diag, a ); + // Set the custom pack function. + bli_obj_set_pack_fn( packm_diag, a ); - // Attach the parameters to the A object. - bli_obj_set_pack_params( params, a ); + // Attach the parameters to the A object. + bli_obj_set_pack_params( params, a ); } /* @@ -285,67 +288,67 @@ void attach_diagonal_factor( packm_diag_params_t* params, obj_t* d, obj_t* a ) */ void syrk_diag( obj_t* alpha, obj_t* a, obj_t* d, obj_t* beta, obj_t* c ) { - obj_t ad; // this is (D * A^T) - packm_diag_params_t params; + obj_t ad; // this is (D * A^T) + packm_diag_params_t params; - bli_obj_alias_to( a, &ad ); - bli_obj_toggle_trans( &ad ); // because gemmt is A*B instead of A*B^T - attach_diagonal_factor( ¶ms, d, &ad ); + bli_obj_alias_to( a, &ad ); + bli_obj_toggle_trans( &ad ); // because gemmt is A*B instead of A*B^T + attach_diagonal_factor( ¶ms, d, &ad ); - // Does C := alpha * A * B + beta * C using B = (D + A^T) - bli_gemmt( alpha, a, &ad, beta, c ); + // Does C := alpha * A * B + beta * C using B = (D + A^T) + bli_gemmt( alpha, a, &ad, beta, c ); } -int main() +int main( void ) { - obj_t a; - obj_t d; - obj_t c; - obj_t c_copy; - obj_t norm; - - dim_t m = 10; - dim_t k = 10; - - for ( int dt_ = BLIS_DT_LO; dt_ <= BLIS_DT_HI; dt_++ ) - for ( int upper = 0; upper <= 1; upper++ ) - for ( int transa = 0; transa <= 1; transa++ ) - for ( int transc = 0; transc <= 1; transc++ ) - { - num_t dt = dt_; - uplo_t uplo = upper ? 
BLIS_UPPER : BLIS_LOWER; - - bli_obj_create( dt, m, k, transa ? k : 1, transa ? 1 : m, &a ); - bli_obj_create( dt, k, 1, 1, 1, &d ); - bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c ); - bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c_copy ); - bli_obj_set_struc( BLIS_SYMMETRIC , &c ); - bli_obj_set_struc( BLIS_SYMMETRIC , &c_copy ); - bli_obj_set_uplo( uplo , &c ); - bli_obj_set_uplo( uplo , &c_copy ); - bli_obj_create_1x1( bli_dt_proj_to_real( dt ), &norm ); - - bli_randm( &a ); - bli_randm( &d ); - bli_randm( &c ); - bli_copym( &c, &c_copy ); - - syrk_diag( &BLIS_ONE, &a, &d, &BLIS_ONE, &c ); - syrk_diag_ref( &BLIS_ONE, &a, &d, &BLIS_ONE, &c_copy ); - - bli_subm( &c_copy, &c ); - bli_normfm( &c, &norm ); - - double normr, normi; - bli_getsc( &norm, &normr, &normi ); - - printf("dt: %d, upper: %d, transa: %d, transc: %d, norm: %g\n", - dt, upper, transa, transc, normr); - - bli_obj_free( &a ); - bli_obj_free( &d ); - bli_obj_free( &c ); - bli_obj_free( &c_copy ); - bli_obj_free( &norm ); - } + obj_t a; + obj_t d; + obj_t c; + obj_t c_copy; + obj_t norm; + + dim_t m = 10; + dim_t k = 10; + + for ( int dt_ = BLIS_DT_LO; dt_ <= BLIS_DT_HI; dt_++ ) + for ( int upper = 0; upper <= 1; upper++ ) + for ( int transa = 0; transa <= 1; transa++ ) + for ( int transc = 0; transc <= 1; transc++ ) + { + num_t dt = dt_; + uplo_t uplo = upper ? BLIS_UPPER : BLIS_LOWER; + + bli_obj_create( dt, m, k, transa ? k : 1, transa ? 1 : m, &a ); + bli_obj_create( dt, k, 1, 1, 1, &d ); + bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c ); + bli_obj_create( dt, m, m, transc ? m : 1, transc ? 1 : m, &c_copy ); + bli_obj_set_struc( BLIS_SYMMETRIC , &c ); + bli_obj_set_struc( BLIS_SYMMETRIC , &c_copy ); + bli_obj_set_uplo( uplo , &c ); + bli_obj_set_uplo( uplo , &c_copy ); + bli_obj_create_1x1( bli_dt_proj_to_real( dt ), &norm ); + + bli_randm( &a ); + bli_randm( &d ); + bli_randm( &c ); + bli_copym( &c, &c_copy ); + + syrk_diag( &BLIS_ONE, &a, &d, &BLIS_ONE, &c ); + syrk_diag_ref( &BLIS_ONE, &a, &d, &BLIS_ONE, &c_copy ); + + bli_subm( &c_copy, &c ); + bli_normfm( &c, &norm ); + + double normr, normi; + bli_getsc( &norm, &normr, &normi ); + + printf( "dt: %d, upper: %d, transa: %d, transc: %d, norm: %g\n", + dt, upper, transa, transc, normr ); + + bli_obj_free( &a ); + bli_obj_free( &d ); + bli_obj_free( &c ); + bli_obj_free( &c_copy ); + bli_obj_free( &norm ); + } } From e8ce74adfe89525dceaa4f7810c57b3700dec973 Mon Sep 17 00:00:00 2001 From: "Field G. Van Zee" Date: Tue, 21 Dec 2021 12:28:10 -0600 Subject: [PATCH 11/12] Relocate edge-case gemm ukr macros to new file. --- frame/include/bli_edge_case_macro_defs.h | 109 +++++++++++++++++++++++ frame/include/bli_macro_defs.h | 1 + frame/include/bli_misc_macro_defs.h | 69 -------------- 3 files changed, 110 insertions(+), 69 deletions(-) create mode 100644 frame/include/bli_edge_case_macro_defs.h diff --git a/frame/include/bli_edge_case_macro_defs.h b/frame/include/bli_edge_case_macro_defs.h new file mode 100644 index 0000000000..242045a029 --- /dev/null +++ b/frame/include/bli_edge_case_macro_defs.h @@ -0,0 +1,109 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. 
+ + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_EDGE_CASE_MACRO_DEFS_H +#define BLIS_EDGE_CASE_MACRO_DEFS_H + + +// Helper macros for edge-case handling within gemm microkernels. + +#define GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major) \ +\ + PASTEMAC(ch,ctype)* restrict _beta = beta; \ + PASTEMAC(ch,ctype)* restrict _c = c; \ + const inc_t _rs_c = rs_c; \ + const inc_t _cs_c = cs_c; \ + PASTEMAC(ch,ctype) _ct[ BLIS_STACK_BUF_MAX_SIZE / sizeof( PASTEMAC(ch,type) ) ] \ + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ + const inc_t _rs_ct = row_major ? nr : 1; \ + const inc_t _cs_ct = row_major ? 1 : mr; + +#define GEMM_UKR_SETUP_CT_POST(ch) \ +\ + PASTEMAC(ch,ctype) _zero; \ + PASTEMAC(ch,set0s)( _zero ); \ + \ + if ( _use_ct ) \ + { \ + c = _ct; \ + rs_c = _rs_ct; \ + cs_c = _cs_ct; \ + beta = &_zero; \ + } + +#define GEMM_UKR_SETUP_CT(ch,mr,nr,row_major) \ +\ + GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major); \ + const bool _use_ct = ( row_major ? cs_c != 1 : rs_c != 1 ) || \ + m != mr || n != nr; \ + GEMM_UKR_SETUP_CT_POST(ch); + +#define GEMM_UKR_SETUP_CT_AMBI(ch,mr,nr,row_major) \ +\ + GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major); \ + const bool _use_ct = ( cs_c != 1 && rs_c != 1 ) || \ + m != mr || n != nr; \ + GEMM_UKR_SETUP_CT_POST(ch); + +#define GEMM_UKR_SETUP_CT_ANY(ch,mr,nr,row_major) \ +\ + GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major); \ + const bool _use_ct = m != mr || n != nr; \ + GEMM_UKR_SETUP_CT_POST(ch); + +#define GEMM_UKR_SETUP_CT_ALIGNED(ch,mr,nr,row_major,alignment) \ +\ + GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major); \ + const bool _use_ct = ( row_major ? cs_c != 1 : rs_c != 1 ) || \ + m != mr || n != nr || \ + ( (uintptr_t)_c % alignment ) || \ + ( ( ( row_major ? 
_rs_c : _cs_c )*sizeof( PASTEMAC(ch,ctype) ) ) % alignment ); \ + GEMM_UKR_SETUP_CT_POST(ch); + +#define GEMM_UKR_FLUSH_CT(ch) \ +\ + if ( _use_ct ) \ + { \ + PASTEMAC(ch,xpbys_mxn) \ + ( \ + m, n, \ + _ct, _rs_ct, _cs_ct, \ + _beta, \ + _c, _rs_c, _cs_c \ + ); \ + } \ + + +#endif + diff --git a/frame/include/bli_macro_defs.h b/frame/include/bli_macro_defs.h index 03451d4407..be45a12e3f 100644 --- a/frame/include/bli_macro_defs.h +++ b/frame/include/bli_macro_defs.h @@ -98,6 +98,7 @@ #include "bli_gentprot_macro_defs.h" #include "bli_misc_macro_defs.h" +#include "bli_edge_case_macro_defs.h" #include "bli_param_macro_defs.h" #include "bli_obj_macro_defs.h" #include "bli_complex_macro_defs.h" diff --git a/frame/include/bli_misc_macro_defs.h b/frame/include/bli_misc_macro_defs.h index 7e1b93b944..120338beba 100644 --- a/frame/include/bli_misc_macro_defs.h +++ b/frame/include/bli_misc_macro_defs.h @@ -164,74 +164,5 @@ BLIS_INLINE void bli_toggle_bool( bool* b ) #define bli_iformatspec() "%6d" -// helper macros for gemm microkernels - -#define GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major) \ -\ - PASTEMAC(ch,ctype)* restrict _beta = beta; \ - PASTEMAC(ch,ctype)* restrict _c = c; \ - const inc_t _rs_c = rs_c; \ - const inc_t _cs_c = cs_c; \ - PASTEMAC(ch,ctype) _ct[ BLIS_STACK_BUF_MAX_SIZE / sizeof( PASTEMAC(ch,type) ) ] \ - __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ - const inc_t _rs_ct = row_major ? nr : 1; \ - const inc_t _cs_ct = row_major ? 1 : mr; - -#define GEMM_UKR_SETUP_CT_POST(ch) \ -\ - PASTEMAC(ch,ctype) _zero; \ - PASTEMAC(ch,set0s)( _zero ); \ - \ - if ( _use_ct ) \ - { \ - c = _ct; \ - rs_c = _rs_ct; \ - cs_c = _cs_ct; \ - beta = &_zero; \ - } - -#define GEMM_UKR_SETUP_CT(ch,mr,nr,row_major) \ -\ - GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major); \ - const bool _use_ct = ( row_major ? cs_c != 1 : rs_c != 1 ) || \ - m != mr || n != nr; \ - GEMM_UKR_SETUP_CT_POST(ch); - -#define GEMM_UKR_SETUP_CT_AMBI(ch,mr,nr,row_major) \ -\ - GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major); \ - const bool _use_ct = ( cs_c != 1 && rs_c != 1 ) || \ - m != mr || n != nr; \ - GEMM_UKR_SETUP_CT_POST(ch); - -#define GEMM_UKR_SETUP_CT_ANY(ch,mr,nr,row_major) \ -\ - GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major); \ - const bool _use_ct = m != mr || n != nr; \ - GEMM_UKR_SETUP_CT_POST(ch); - -#define GEMM_UKR_SETUP_CT_ALIGNED(ch,mr,nr,row_major,alignment) \ -\ - GEMM_UKR_SETUP_CT_PRE(ch,mr,nr,row_major); \ - const bool _use_ct = ( row_major ? cs_c != 1 : rs_c != 1 ) || \ - m != mr || n != nr || \ - ( (uintptr_t)_c % alignment ) || \ - ( ( ( row_major ? _rs_c : _cs_c )*sizeof( PASTEMAC(ch,ctype) ) ) % alignment ); \ - GEMM_UKR_SETUP_CT_POST(ch); - -#define GEMM_UKR_FLUSH_CT(ch) \ -\ - if ( _use_ct ) \ - { \ - PASTEMAC(ch,xpbys_mxn) \ - ( \ - m, n, \ - _ct, _rs_ct, _cs_ct, \ - _beta, \ - _c, _rs_c, _cs_c \ - ); \ - } \ - - #endif From 72e649959e5c0a04081ae11a59ea9b787982cdd6 Mon Sep 17 00:00:00 2001 From: "Field G. Van Zee" Date: Tue, 21 Dec 2021 15:42:06 -0600 Subject: [PATCH 12/12] Trivial comment/whitespace changes. 
--- frame/1m/packm/bli_packm_alloc.c | 2 +- ref_kernels/ind/bli_gemm1m_ref.c | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/frame/1m/packm/bli_packm_alloc.c b/frame/1m/packm/bli_packm_alloc.c index c316932145..b12a93ddc0 100644 --- a/frame/1m/packm/bli_packm_alloc.c +++ b/frame/1m/packm/bli_packm_alloc.c @@ -114,6 +114,6 @@ void* bli_packm_alloc_ex bli_thread_barrier( thread ); } - return bli_mem_buffer( cntl_mem_p ); + return bli_mem_buffer( cntl_mem_p ); } diff --git a/ref_kernels/ind/bli_gemm1m_ref.c b/ref_kernels/ind/bli_gemm1m_ref.c index 51ff28c41a..fbd15d695b 100644 --- a/ref_kernels/ind/bli_gemm1m_ref.c +++ b/ref_kernels/ind/bli_gemm1m_ref.c @@ -124,7 +124,7 @@ void PASTEMAC3(ch,opname,arch,suf) \ else using_ct = FALSE; \ \ \ - /* If we are not packing a full micro-tile, then we must write to + /* If we are not computing a full micro-tile, then we must write to ct and then accumulate to c afterwards. */ \ if ( mr != m || nr != n ) using_ct = TRUE; \ \
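To make the relocated macros of patch 11 concrete, here is a minimal sketch of
how a microkernel is expected to use them. The kernel name and exact signature
are illustrative (modeled on the kernels edited throughout this series), not
code from the patch:

#include "blis.h"

void bli_dgemm_toy_asm_4x4
     (
       dim_t               m,
       dim_t               n,
       dim_t               k,
       double*    restrict alpha,
       double*    restrict a,
       double*    restrict b,
       double*    restrict beta,
       double*    restrict c, inc_t rs_c0, inc_t cs_c0,
       auxinfo_t* restrict data,
       cntx_t*    restrict cntx
     )
{
    inc_t rs_c = rs_c0;
    inc_t cs_c = cs_c0;

    /* If (m,n) is a partial tile, or C is not stored the way this kernel
       prefers (column-major here, since row_major = false), repoint c at
       an aligned stack temporary _ct and substitute beta = 0. */
    GEMM_UKR_SETUP_CT( d, 4, 4, false );

    /* ... the full 4x4 computation, c := beta*c + alpha*a*b, writes
       through the possibly redirected c/rs_c/cs_c ... */

    /* If the temporary was used, apply the caller's beta now:
       C := beta*C + _ct over the true m x n, via xpbys_mxn. */
    GEMM_UKR_FLUSH_CT( d );
}

Per the definitions above, the _ALIGNED variant adds the alignment test, the
_AMBI variant accepts either row- or column-major C directly (only general
stride triggers the temporary), and the _ANY variant falls back to the
temporary only for partial tiles.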