Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 75 additions & 71 deletions src/multiply_module.f90
Original file line number Diff line number Diff line change
Expand Up @@ -148,34 +148,28 @@ subroutine mat_mult(myid,a,lena,b,lenb,c,lenc,a_b_c,debug)
integer :: icall,n_cont(2),kpart_next,ind_partN,k_off(2)
integer :: stat,ilen2(2),lenb_rem(2)
! Remote variables to be allocated
integer(integ),allocatable :: ibpart_rem(:,:)
integer(integ),allocatable :: ibpart_rem(:)
type jagged_array_r
real(double), allocatable :: values(:)
end type jagged_array_r
type(jagged_array_r) :: b_rem(2)
! Remote variables which will point to part_array
type jagged_pointer_array_i
integer(integ),pointer :: values(:)
end type jagged_pointer_array_i
type(jagged_pointer_array_i) :: nbnab_rem(2)
type(jagged_pointer_array_i) :: ibseq_rem(2)
type(jagged_pointer_array_i) :: ibind_rem(2)
type(jagged_pointer_array_i) :: ib_nd_acc_rem(2)
type(jagged_pointer_array_i) :: ibndimj_rem(2)
type(jagged_pointer_array_i) :: npxyz_rem(2)
integer(integ), dimension(:), pointer :: nbnab_rem
integer(integ), dimension(:), pointer :: ibseq_rem
integer(integ), dimension(:), pointer :: ibind_rem
integer(integ), dimension(:), pointer :: ib_nd_acc_rem
integer(integ), dimension(:), pointer :: ibndimj_rem
integer(integ), dimension(:), pointer :: npxyz_rem
! Arrays for remote variables to point to
integer, target :: part_array(3*a_b_c%parts%mx_mem_grp+ &
5*a_b_c%parts%mx_mem_grp*a_b_c%bmat(1)%mx_abs, 2)
integer :: offset
integer, dimension(:), allocatable :: nreqs
integer :: sends,i,j
integer, dimension(MPI_STATUS_SIZE) :: mpi_stat
type jagged_array_i
integer, allocatable :: values(:)
end type jagged_array_i
type(jagged_array_i) :: recv_part(2)
integer, allocatable :: recv_part(:,:)
real(double) :: t0,t1
integer :: request(2,2), index_rec, index_wait
integer :: request(2,2), index_rec, index_compute

logical :: new_partition(2)

Expand All @@ -184,16 +178,13 @@ subroutine mat_mult(myid,a,lena,b,lenb,c,lenc,a_b_c,debug)
call start_timer(tmr_std_allocation)
if(iprint_mat>3.AND.myid==0) t0 = mtime()
! Allocate memory for the elements
allocate(ibpart_rem(a_b_c%parts%mx_mem_grp*a_b_c%bmat(1)%mx_abs,2),STAT=stat)
allocate(ibpart_rem(a_b_c%parts%mx_mem_grp*a_b_c%bmat(1)%mx_abs),STAT=stat)
if(stat/=0) call cq_abort('mat_mult: error allocating ibpart_rem')
!allocate(atrans(a_b_c%amat(1)%length),STAT=stat)
allocate(atrans(lena),STAT=stat)
if(stat/=0) call cq_abort('mat_mult: error allocating atrans')
allocate(recv_part(1)%values(0:a_b_c%comms%inode),STAT=stat)
allocate(recv_part(2)%values(0:a_b_c%comms%inode),STAT=stat)
allocate(recv_part(a_b_c%comms%inode + 1, 2),STAT=stat)
if(stat/=0) call cq_abort('mat_mult: error allocating recv_part')
recv_part(1)%values = zero
recv_part(2)%values = zero
recv_part = zero
call stop_timer(tmr_std_allocation)
!write(io_lun,*) 'Sizes: ',a_b_c%comms%mx_dim3*a_b_c%comms%mx_dim2*a_b_c%parts%mx_mem_grp*a_b_c%bmat(1)%mx_abs,&
! a_b_c%parts%mx_mem_grp*a_b_c%bmat(1)%mx_abs,a_b_c%comms%mx_dim3*a_b_c%comms%mx_dim1* &
Expand Down Expand Up @@ -227,7 +218,7 @@ subroutine mat_mult(myid,a,lena,b,lenb,c,lenc,a_b_c,debug)
ncover_yz=a_b_c%gcs%ncovery*a_b_c%gcs%ncoverz

! Receive the data from the first partition - blocking
call do_comms(k_off(2), 1, part_array(:,2), n_cont(2), ilen2(2), a_b_c, b, recv_part(2)%values, &
call do_comms(k_off(2), 1, part_array(:,2), n_cont(2), ilen2(2), a_b_c, b, recv_part(:,2), &
b_rem(2)%values, lenb_rem(2), myid, ncover_yz, new_partition(2))

request = MPI_REQUEST_NULL
Expand All @@ -239,51 +230,55 @@ subroutine mat_mult(myid,a,lena,b,lenb,c,lenc,a_b_c,debug)

! These indices point to elements of all the 2-element vectors of the variables needed
! for the do_comms and m_kern_min/max calls. They alternate between the values of
! (index_rec,index_wait)=(1,2) and (2,1) from iteration to iteration.
! (index_rec,index_compute)=(1,2) and (2,1) from iteration to iteration.
! index_rec points to the values being received in the current iteration in do_comms,
! and index_wait points to the values received in the previous iteration, thus computation
! and index_compute points to the values received in the previous iteration, thus computation
! can start on them in m_kern_min/max
! These indices are also used to point to elements of the 2x2-element request() array,
! that contains the MPI request numbers for the non-blocking data receives. There are 2
! MPI_Irecv calls per call of do_comms, and request() keeps track of 2 sets of those calls,
! thus it's of size 2x2.
! request(:,index_rec) points to the requests from the current iteration MPI_Irecv,
! and request(:,index_wait) points to the requests from the previous iteration, that have
! and request(:,index_compute) points to the requests from the previous iteration, that have
! to complete in order for the computation to start (thus the MPI_Wait).
index_rec = mod(kpart,2) + 1
index_wait = mod(kpart+1,2) + 1
index_compute = mod(kpart+1,2) + 1

! Receive the data from the current partition - non-blocking
call do_comms(k_off(index_rec), kpart, part_array(:,index_rec), n_cont(index_rec), ilen2(index_rec), &
a_b_c, b, recv_part(index_rec)%values, b_rem(index_rec)%values, &
a_b_c, b, recv_part(:,index_rec), b_rem(index_rec)%values, &
lenb_rem(index_rec), myid, ncover_yz, new_partition(index_rec), .true., request(:,index_rec))

! Check that previous partition data have been received before starting computation
if (kpart.gt.2 .and. all(request(:,index_wait).ne.[MPI_REQUEST_NULL,MPI_REQUEST_NULL])) &
call MPI_Waitall(2,request(:,index_wait),MPI_STATUSES_IGNORE,ierr)
if(new_partition(index_wait)) then
if (kpart.gt.2 .and. all(request(:,index_compute).ne.[MPI_REQUEST_NULL,MPI_REQUEST_NULL])) &
call MPI_Waitall(2,request(:,index_compute),MPI_STATUSES_IGNORE,ierr)

if(new_partition(index_compute)) then
! Now point the _rem variables at the appropriate parts of
! the array where we will receive the data
offset = 0
nbnab_rem(index_wait)%values => part_array(offset+1:offset+n_cont(index_wait),index_wait)
offset = offset+n_cont(index_wait)
ibind_rem(index_wait)%values => part_array(offset+1:offset+n_cont(index_wait),index_wait)
offset = offset+n_cont(index_wait)
ib_nd_acc_rem(index_wait)%values => part_array(offset+1:offset+n_cont(index_wait),index_wait)
offset = offset+n_cont(index_wait)
ibseq_rem(index_wait)%values => part_array(offset+1:offset+ilen2(index_wait),index_wait)
offset = offset+ilen2(index_wait)
npxyz_rem(index_wait)%values => part_array(offset+1:offset+3*ilen2(index_wait),index_wait)
offset = offset+3*ilen2(index_wait)
ibndimj_rem(index_wait)%values => part_array(offset+1:offset+ilen2(index_wait),index_wait)
if(offset+ilen2(index_wait)>3*a_b_c%parts%mx_mem_grp+ &
nbnab_rem => part_array(offset+1:offset+n_cont(index_compute),index_compute)
offset = offset+n_cont(index_compute)
ibind_rem => part_array(offset+1:offset+n_cont(index_compute),index_compute)
offset = offset+n_cont(index_compute)
ib_nd_acc_rem => part_array(offset+1:offset+n_cont(index_compute),index_compute)
offset = offset+n_cont(index_compute)
ibseq_rem => part_array(offset+1:offset+ilen2(index_compute),index_compute)
offset = offset+ilen2(index_compute)
npxyz_rem => part_array(offset+1:offset+3*ilen2(index_compute),index_compute)
offset = offset+3*ilen2(index_compute)
ibndimj_rem => part_array(offset+1:offset+ilen2(index_compute),index_compute)
if(offset+ilen2(index_compute)>3*a_b_c%parts%mx_mem_grp+ &
5*a_b_c%parts%mx_mem_grp*a_b_c%bmat(1)%mx_abs) then
call cq_abort('mat_mult: error pointing to part_array ',kpart-1)
end if
! Create ibpart_rem
call end_part_comms(myid,n_cont(index_wait),nbnab_rem(index_wait)%values, &
ibind_rem(index_wait)%values,npxyz_rem(index_wait)%values,&
ibpart_rem(:,index_wait),ncover_yz,a_b_c%gcs%ncoverz)
call end_part_comms(myid,n_cont(index_compute), &
nbnab_rem, &
ibind_rem, &
npxyz_rem, &
ibpart_rem, &
ncover_yz,a_b_c%gcs%ncoverz)
end if

! Omp master doesn't include an implicit barrier. We want master
Expand All @@ -294,17 +289,31 @@ subroutine mat_mult(myid,a,lena,b,lenb,c,lenc,a_b_c,debug)

! Call the computation kernel on the previous partition
if(a_b_c%mult_type.eq.1) then ! C is full mult
call m_kern_max( k_off(index_wait),kpart,ib_nd_acc_rem(index_wait)%values, ibind_rem(index_wait)%values, &
nbnab_rem(index_wait)%values,ibpart_rem(:,index_wait),ibseq_rem(index_wait)%values, &
ibndimj_rem(index_wait)%values, atrans,b_rem(index_wait)%values,c,a_b_c%ahalo,a_b_c%chalo, &
call m_kern_max( k_off(index_compute),kpart, &
ib_nd_acc_rem, &
ibind_rem, &
nbnab_rem, &
ibpart_rem, &
ibseq_rem, &
ibndimj_rem, &
atrans, &
b_rem(index_compute)%values, &
c,a_b_c%ahalo,a_b_c%chalo, &
a_b_c%ltrans,a_b_c%bmat(1)%mx_abs,a_b_c%parts%mx_mem_grp, &
a_b_c%prim%mx_iprim, lena, lenb_rem(index_wait), lenc)
a_b_c%prim%mx_iprim, lena, lenb_rem(index_compute), lenc)
else if(a_b_c%mult_type.eq.2) then ! A is partial mult
call m_kern_min( k_off(index_wait),kpart,ib_nd_acc_rem(index_wait)%values, ibind_rem(index_wait)%values, &
nbnab_rem(index_wait)%values,ibpart_rem(:,index_wait),ibseq_rem(index_wait)%values, &
ibndimj_rem(index_wait)%values, atrans,b_rem(index_wait)%values,c,a_b_c%ahalo,a_b_c%chalo, &
call m_kern_min( k_off(index_compute),kpart,&
ib_nd_acc_rem, &
ibind_rem, &
nbnab_rem, &
ibpart_rem, &
ibseq_rem, &
ibndimj_rem, &
atrans, &
b_rem(index_compute)%values, &
c,a_b_c%ahalo,a_b_c%chalo, &
a_b_c%ltrans,a_b_c%bmat(1)%mx_abs,a_b_c%parts%mx_mem_grp, &
a_b_c%prim%mx_iprim, lena, lenb_rem(index_wait), lenc)
a_b_c%prim%mx_iprim, lena, lenb_rem(index_compute), lenc)
end if
!$omp barrier
end do main_loop
Expand Down Expand Up @@ -345,9 +354,7 @@ subroutine mat_mult(myid,a,lena,b,lenb,c,lenc,a_b_c,debug)
call start_timer(tmr_std_allocation)
deallocate(ibpart_rem,STAT=stat)
if(stat/=0) call cq_abort('mat_mult: error deallocating ibpart_rem')
deallocate(recv_part(1)%values,STAT=stat)
if(stat/=0) call cq_abort('mat_mult: error deallocating recv_part')
deallocate(recv_part(2)%values,STAT=stat)
if(allocated(recv_part)) deallocate(recv_part,STAT=stat)
if(stat/=0) call cq_abort('mat_mult: error deallocating recv_part')
call stop_timer(tmr_std_allocation)
call my_barrier
Expand Down Expand Up @@ -578,7 +585,7 @@ subroutine do_comms(k_off, kpart, part_array, n_cont, ilen2, a_b_c, b, recv_part
integer, intent(in) :: kpart
type(matrix_mult), intent(in) :: a_b_c
real(double), intent(in) :: b(:)
integer, allocatable, dimension(:), intent(inout) :: recv_part
integer, dimension(:), intent(inout) :: recv_part
real(double), allocatable, intent(inout) :: b_rem(:)
integer, intent(out) :: lenb_rem
integer, intent(in) :: myid, ncover_yz
Expand All @@ -597,10 +604,8 @@ subroutine do_comms(k_off, kpart, part_array, n_cont, ilen2, a_b_c, b, recv_part
! Set non-blocking receive flag
do_nonb_local = .false.
if (present(do_nonb)) do_nonb_local = do_nonb

if(.not.allocated(recv_part)) allocate(recv_part(0:a_b_c%comms%inode))

icall=1
icall = 1
ind_part = a_b_c%ahalo%lab_hcell(kpart)
new_partition = .true.

Expand All @@ -615,7 +620,7 @@ subroutine do_comms(k_off, kpart, part_array, n_cont, ilen2, a_b_c, b, recv_part
! Get the data
ipart = a_b_c%parts%i_cc2seq(ind_part)
nnode = a_b_c%comms%neigh_node_list(kpart)
recv_part(nnode) = recv_part(nnode)+1
recv_part(nnode+1) = recv_part(nnode)+1
if(allocated(b_rem)) deallocate(b_rem)
if(a_b_c%parts%i_cc2node(ind_part)==myid+1) then
lenb_rem = a_b_c%bmat(ipart)%part_nd_nabs
Expand All @@ -625,11 +630,11 @@ subroutine do_comms(k_off, kpart, part_array, n_cont, ilen2, a_b_c, b, recv_part
allocate(b_rem(lenb_rem))
call prefetch(kpart,a_b_c%ahalo,a_b_c%comms,a_b_c%bmat,icall,&
n_cont,part_array,a_b_c%bindex,b_rem,lenb_rem,b,myid,ilen2,&
mx_msg_per_part,a_b_c%parts,a_b_c%prim,a_b_c%gcs,(recv_part(nnode)-1)*2,do_nonb,request)
mx_msg_per_part,a_b_c%parts,a_b_c%prim,a_b_c%gcs,(recv_part(nnode))*2,do_nonb,request)
end if

k_off=a_b_c%ahalo%lab_hcover(kpart) ! --- offset for pbcs
end subroutine do_comms
k_off=a_b_c%ahalo%lab_hcover(kpart) ! --- offset for pbcs
end subroutine do_comms

!!****f* multiply_module/prefetch *
!!
Expand Down Expand Up @@ -689,7 +694,6 @@ subroutine prefetch(this_part,ahalo,a_b_c,bmat,icall,&
integer :: ncover_yz,ind_part,iskip,ind_last
integer :: inode,ipart,nnode
logical :: do_nonb_local

! Set non-blocking receive flag
do_nonb_local = .false.
if (present(do_nonb)) do_nonb_local = do_nonb
Expand All @@ -709,23 +713,23 @@ subroutine prefetch(this_part,ahalo,a_b_c,bmat,icall,&
end if
if(icall.eq.1) then ! Else fetch the data
ilen2 = a_b_c%ilen2rec(ipart,nnode)
if(.not.do_nonb_local) then ! Use blocking receive
call Mquest_get( prim%mx_ngonn, &
if (do_nonb_local) then ! Use blocking receive
if (.not.present(request)) call cq_abort('Need to provide MPI request argument for non-blocking receive.')
call Mquest_get_nonb( prim%mx_ngonn, &
a_b_c%ilen2rec(ipart,nnode),&
a_b_c%ilen3rec(ipart,nnode),&
n_cont,inode,ipart,myid,&
bind_rem,b_rem,lenb_rem,bind,&
a_b_c%istart(ipart,nnode), &
bmat(1)%mx_abs,parts%mx_mem_grp,tag)
bmat(1)%mx_abs,parts%mx_mem_grp,tag,request)
else ! Use non-blocking receive
if (.not.present(request)) call cq_abort('Need to provide MPI request argument for non-blocking receive.')
call Mquest_get_nonb( prim%mx_ngonn, &
call Mquest_get( prim%mx_ngonn, &
a_b_c%ilen2rec(ipart,nnode),&
a_b_c%ilen3rec(ipart,nnode),&
n_cont,inode,ipart,myid,&
bind_rem,b_rem,lenb_rem,bind,&
a_b_c%istart(ipart,nnode), &
bmat(1)%mx_abs,parts%mx_mem_grp,tag,request)
bmat(1)%mx_abs,parts%mx_mem_grp,tag)
end if
end if
return
Expand Down