module sumfln_stochy_mod

      implicit none

      contains

      subroutine sumfln_stochy(flnev,flnod,lat1s,plnev,plnod,
     &                         nvars,ls_node,latl2,
     &                         workdim,nvarsdim,four_gr,
     &                         ls_nodes,max_ls_nodes,
     &                         lats_nodes,global_lats,
     &                         lats_node,ipt_lats_node,
     &                         lons_lat,londi,latl,nvars_0)
!
!    purpose
!    -------
!    for every wavenumber l listed in ls_node(:,1):
!      1. form apev = transpose(flnev block) * plnev block and
!         apod = transpose(flnod block) * plnod block for the
!         latitude columns lat1s(l)..latl2,
!      2. pack apev+apod (lat <= latl2) or apev-apod (otherwise)
!         into the send buffer works, one slot per (node,latitude).
!    the packed buffers are then exchanged with mpp_alltoall and the
!    received values are unpacked into four_gr.  entries of four_gr
!    from 2*min(jcap,lonl/2)+3 up to lonl+2 are set to zero first.
!
!    arguments
!    ---------
!    flnev, flnod  (in)  even / odd coefficient arrays; columns
!                        2*nvar-1 and 2*nvar form the pair that
!                        belongs to variable nvar (the original
!                        notes call them the real and imaginary
!                        terms).
!    lat1s         (in)  indexed by l; first latitude column used
!                        for that l and threshold for packing.
!    plnev, plnod  (in)  even / odd arrays multiplied with
!                        flnev / flnod; second extent is latl2.
!    nvars         (in)  number of variables to process.
!    ls_node       (in)  column 1: values of l
!                        column 2: values of jbasev
!                        column 3: values of jbasod
!    latl2         (in)  latitudes with lat <= latl2 get even+odd,
!                        the remaining ones get even-odd.
!    workdim, nvarsdim, londi (in)  extents of four_gr.
!    four_gr       (inout) result array, indexed
!                        (coefficient, nvars_0+nvar, local lat).
!    ls_nodes      (in)  l values held by each node.
!    max_ls_nodes  (in)  number of l values held by each node.
!    lats_nodes    (in)  number of latitudes held by each node.
!    global_lats   (in)  latitude index for each global slot.
!    lats_node     (in)  number of latitudes unpacked locally.
!    ipt_lats_node (in)  first slot of the local latitudes in
!                        global_lats.
!    lons_lat      (in)  indexed by latitude; supplies lonl.
!    latl          (in)  used to fold a latitude index:
!                        min(lat,latl-lat+1).
!    nvars_0       (in)  offset added to nvar in four_gr.
!
      use stochy_resol_def , only : jcap,latgd
      use spectral_layout_mod   , only : len_trie_ls,len_trio_ls,
     &                               ls_dim,ls_max_node,me,nodes
      use machine
      use spectral_layout_mod, only : num_parthds_stochy => ompthreads
      !or : use fv_mp_mod ?
      use mpp_mod, only: mpp_pe,mpp_npes, mpp_alltoall,
     &                   mpp_get_current_pelist

      implicit none
!
      external esmf_dgemm
!
!    dummy arguments
!    ---------------
!
      integer, intent(in) :: latl2
      integer, intent(in) :: nvars, nvars_0
      integer, intent(in) :: nvarsdim, latl, workdim, londi
      integer, intent(in) :: lats_node, ipt_lats_node
      integer, intent(in) :: lat1s(0:jcap)
      integer, intent(in) :: ls_node(ls_dim,3)
      integer, intent(in) :: ls_nodes(ls_dim,nodes)
      integer, intent(in), dimension(nodes) :: max_ls_nodes
      integer, intent(in), dimension(nodes) :: lats_nodes
      integer, intent(in), dimension(latl)  :: global_lats
      integer, intent(in), dimension(latl)  :: lons_lat
!
      real(kind=kind_dbl_prec), intent(in) ::
     &                          flnev(len_trie_ls,2*nvars)
      real(kind=kind_dbl_prec), intent(in) ::
     &                          flnod(len_trio_ls,2*nvars)
      real(kind=kind_dbl_prec), intent(in) ::
     &                          plnev(len_trie_ls,latl2)
      real(kind=kind_dbl_prec), intent(in) ::
     &                          plnod(len_trio_ls,latl2)
      real(kind=kind_dbl_prec), intent(inout) ::
     &                          four_gr(londi,nvarsdim,workdim)
!
!    local scalars
!    -------------
!    NOTE(review): some of these look unused here, but the included
!    statement-function files may take the type of their dummy
!    arguments from them, so every original declaration is kept.
!
      integer              j, k, l, lat, lat1, n, kn, n2,indev,indod
      integer              ierr,ilat,ipt_ls, lmax,lval,i,jj,lonl,nv
      integer              node,nvar,arrsz,my_pe
      integer :: npes
      integer               num_threads, nvar_thread_max, nvar_1, nvar_2
     &,                     thread
!     inner dimensions of the even / odd products for the current l
      integer              nrow_ev, nrow_od
!     first / last column of the chunk handled by one thread
      integer              col_1, col_2
!
!    local arrays
!    ------------
!
      integer,                allocatable :: pelist(:)
      real(kind=kind_dbl_prec), dimension(nvars*2,latl2) ::  apev, apod
!
!     send (works) and receive (workr) buffers; the rank-1 pointers
!     are flat views of the same storage handed to mpp_alltoall
      real(kind=4),target,dimension(2,nvars,ls_dim*workdim,nodes)::
     &                               workr,works
      real(kind=4),pointer:: work1dr(:),work1ds(:)
      integer, dimension(nodes) :: kpts, kptr, sendcounts, recvcounts,
     &                              sdispls
!     offset of each node's first latitude slot in global_lats
      integer              ilat_list(nodes)
!
!    statement functions
!    -------------------
!
      integer              indlsev, jbasev, indlsod, jbasod
!
      include 'function_indlsev'
      include 'function_indlsod'
!
      real(kind=kind_dbl_prec), parameter ::  cons0=0.0d0, cons1=1.0d0
!
! xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
!
      arrsz=2*nvars*ls_dim*workdim*nodes
!
!     clamp to at least one chunk: with nvars == 0 (or a thread count
!     that is not positive) the unclamped value would be <= 0 and the
!     division below would fail.
      num_threads     = max(1,min(num_parthds_stochy,nvars))
      nvar_thread_max = (nvars+num_threads-1)/num_threads
      npes = mpp_npes()
      my_pe=mpp_pe()
      allocate(pelist(0:npes-1))
      call mpp_get_current_pelist(pelist)
      kpts   = 0
!
!     depends only on lats_nodes, so build it once before the j loop
      ilat_list(1) = 0
      do node = 1, nodes - 1
        ilat_list(node+1) = ilat_list(node) + lats_nodes(node)
      end do
!
      do j = 1, ls_max_node   ! start of do j loop #####################
!
             l = ls_node(j,1)
        jbasev = ls_node(j,2)
        jbasod = ls_node(j,3)

        indev  = indlsev(l,l)
        indod  = indlsod(l+1,l)
!
        lat1    = lat1s(l)
        nrow_ev = (jcap+3-l)/2
        nrow_od = (jcap+2-l)/2
!
        if ( kind_dbl_prec == 8 ) then !--------------------------------
!
!         8-byte reals: each thread handles a contiguous chunk of
!         variables with one dgemm call per parity.
!
!$omp parallel do private(thread,nvar_1,nvar_2,n2)
          do thread=1,num_threads   ! start of thread loop ..............
            nvar_1 = (thread-1)*nvar_thread_max + 1
            nvar_2 = min(nvar_1+nvar_thread_max-1,nvars)

            if (nvar_2 >= nvar_1) then
              n2 = 2*(nvar_2-nvar_1+1)
!
!             even part:
!             apev(chunk,lat1:latl2) = flnev(block)^T * plnev(block)
!
              call esmf_dgemm(
     &                   't',
     &                   'n',
     &                   n2,
     &                   latl2-lat1+1,
     &                   nrow_ev,
     &                   cons1,
     &                   flnev(indev,2*nvar_1-1),
     &                   len_trie_ls,
     &                   plnev(indev,lat1),
     &                   len_trie_ls,
     &                   cons0,
     &                   apev(2*nvar_1-1,lat1),
     &                   2*nvars
     &                   )
!
!             odd part:
!             apod(chunk,lat1:latl2) = flnod(block)^T * plnod(block)
!
              call esmf_dgemm(
     &                   't',
     &                   'n',
     &                   n2,
     &                   latl2-lat1+1,
     &                   nrow_od,
     &                   cons1,
     &                   flnod(indod,2*nvar_1-1),
     &                   len_trio_ls,
     &                   plnod(indod,lat1),
     &                   len_trio_ls,
     &                   cons0,
     &                   apod(2*nvar_1-1,lat1),
     &                   2*nvars
     &                   )
!
            endif
          enddo   ! end of thread loop ..................................
        else !----------------------------------------------------------
!
!         other real kinds: esmf_dgemm cannot be used, so form the
!         same transposed products with the matmul intrinsic.  this
!         branch used to be empty, which left apev/apod undefined
!         before they were packed below.
!
!$omp parallel do private(thread,nvar_1,nvar_2,col_1,col_2)
          do thread=1,num_threads   ! start of thread loop ..............
            nvar_1 = (thread-1)*nvar_thread_max + 1
            nvar_2 = min(nvar_1+nvar_thread_max-1,nvars)

            if (nvar_2 >= nvar_1) then
              col_1 = 2*nvar_1 - 1
              col_2 = 2*nvar_2
!
              apev(col_1:col_2,lat1:latl2) = matmul(
     &          transpose(flnev(indev:indev+nrow_ev-1,col_1:col_2)),
     &          plnev(indev:indev+nrow_ev-1,lat1:latl2))
!
              apod(col_1:col_2,lat1:latl2) = matmul(
     &          transpose(flnod(indod:indod+nrow_od-1,col_1:col_2)),
     &          plnod(indod:indod+nrow_od-1,lat1:latl2))
            endif
          enddo   ! end of thread loop ..................................
        endif !---------------------------------------------------------
!
! xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
!
!       compute the fourier coefficients for each level
!       -----------------------------------------------
!       kpts(node) is only touched by the iteration that owns that
!       node, so the node loop can run in parallel.
!
!$omp parallel do private(node,jj,ilat,lat,ipt_ls,nvar,kn,n2)
        do node=1,nodes
          do jj=1,lats_nodes(node)
            ilat   = ilat_list(node) + jj
            lat    = global_lats(ilat)
!           fold the latitude onto the range covered by apev/apod
            ipt_ls = min(lat,latl-lat+1)
            if ( ipt_ls >= lat1s(ls_nodes(j,me+1)) ) then
              kpts(node) = kpts(node) + 1
              kn = kpts(node)
!
              if ( lat <= latl2 ) then
!                                                northern hemisphere
                do nvar=1,nvars
                  n2 = nvar + nvar
                  works(1,nvar,kn,node) = apev(n2-1,ipt_ls)
     &                                  + apod(n2-1,ipt_ls)
                  works(2,nvar,kn,node) = apev(n2,  ipt_ls)
     &                                  + apod(n2,  ipt_ls)
                enddo
              else
!                                                southern hemisphere
                do nvar=1,nvars
                  n2 = nvar + nvar
                  works(1,nvar,kn,node) = apev(n2-1,ipt_ls)
     &                                  - apod(n2-1,ipt_ls)
                  works(2,nvar,kn,node) = apev(n2,  ipt_ls)
     &                                  - apod(n2,  ipt_ls)
                enddo
              endif
            endif
          enddo
        enddo
!
      enddo   ! end of do j loop #######################################
!
!     count how many slots will arrive from each node; the test is
!     the same one used when unpacking further down.
!
      kptr = 0
      do node=1,nodes
         do l=1,max_ls_nodes(node)
            lval = ls_nodes(l,node)+1
            do j=1,lats_node
               lat = global_lats(ipt_lats_node-1+j)
               if ( min(lat,latl-lat+1) >= lat1s(lval-1) ) then
                  kptr(node) = kptr(node) + 1
               endif
            enddo
         enddo
      enddo
!
!     every slot carries 2*nvars values; each node owns a fixed
!     stride of 2*nvars*ls_dim*workdim values in both buffers, so
!     sdispls serves as send and receive displacement.
!
      n2 = nvars + nvars
!$omp parallel do private(node)
      do node=1,nodes
         sendcounts(node) = kpts(node) * n2
         recvcounts(node) = kptr(node) * n2
         sdispls(node)    = (node-1)   * n2 * ls_dim * workdim
      end do
      work1dr(1:arrsz)=>workr
      work1ds(1:arrsz)=>works
      call mpp_alltoall(work1ds, sendcounts, sdispls,
     &                  work1dr,recvcounts,sdispls,pelist)
      nullify(work1dr)
      nullify(work1ds)
!
!     zero the entries of four_gr above the last coefficient pair
!     kept for each local latitude.
!
!$omp parallel do private(j,lat,lmax,nvar,lval,n2,lonl,nv)
      do j=1,lats_node
         lat  = global_lats(ipt_lats_node-1+j)
         lonl = lons_lat(lat)
         lmax = min(jcap,lonl/2)
         n2   = lmax + lmax + 3
         if ( n2 <= lonl+2 ) then
           do nvar=1,nvars
             nv = nvars_0 + nvar
             do lval = n2, lonl+2
               four_gr(lval,nv,j) = cons0
             enddo
           enddo
         endif
      enddo
!
!     unpack: slot kn of node holds the pair for wavenumber
!     ls_nodes(l,node), stored at rows 2*lval-1 and 2*lval.
!
      kptr = 0
!$omp parallel do private(node,l,lval,j,lat,nvar,kn,n2)
      do node=1,nodes
        do l=1,max_ls_nodes(node)
          lval = ls_nodes(l,node)+1
          n2   = lval + lval
          do j=1,lats_node
            lat = global_lats(ipt_lats_node-1+j)
            if ( min(lat,latl-lat+1) >= lat1s(lval-1) ) then
              kptr(node) = kptr(node) + 1
              kn = kptr(node)

              do nvar=1,nvars
                four_gr(n2-1,nvars_0+nvar,j) = workr(1,nvar,kn,node)
                four_gr(n2,  nvars_0+nvar,j) = workr(2,nvar,kn,node)
              enddo
            endif
          enddo
        enddo
      enddo
!
      return
      end subroutine sumfln_stochy

      end module sumfln_stochy_mod