#! /usr/bin/env python

##@namespace scripts.exhwrf_products
# Runs regribbing operations on the output of exhwrf_post, and
# runs the GFDL vortex tracker on the regribbed output.  Note that this
# script is restartable: if it fails, and you call it again, it will
# pick up where it left off.  To force a reprocessing of the entire
# post-processing system, call exhwrf_unpost first.

import logging, os, sys
import produtil.sigsafety, produtil.cd, produtil.run, produtil.setup
import produtil.log, produtil.ecflow
import hwrf_expt
import hwrf_alerts
import hwrf_wcoss

from produtil.ecflow import set_ecflow_meter
from produtil.log import jlogger
from produtil.cd import NamedDir
from produtil.run import mpi, mpirun, checkrun

def gribber(ecflow_meter=None):
    """!Runs the regribbing task, hwrf_expt.gribber, in this process.

    The work is done from inside the hwrf_expt.WORKhwrf directory.
    The gribber's uncomplete() is called first; then FORT_BUFFERED is
    set to TRUE and OMP_NUM_THREADS and MKL_NUM_THREADS are set to 1
    in this process's environment, and the gribber's run() is called
    with raiseall=True.

    @param ecflow_meter optional ecFlow meter name.  It is forwarded
      to the gribber's run() and, if given, the meter is set to 126
      after run() returns."""
    conf=hwrf_expt.conf
    jlogger.info(conf.strinterp(
        'config','{stormlabel}: starting regribbing job for {out_prefix}'))

    with NamedDir(hwrf_expt.WORKhwrf,logger=logging.getLogger()):
        hwrf_expt.gribber.uncomplete()

        # These settings stay in os.environ after this function returns.
        for varname,value in (('FORT_BUFFERED','TRUE'),
                              ('OMP_NUM_THREADS','1'),
                              ('MKL_NUM_THREADS','1')):
            os.environ[varname]=value

        hwrf_expt.gribber.run(raiseall=True,ecflow_meter=ecflow_meter)

        # The gribber never sets the ecflow meter to full, so do it here.
        if ecflow_meter:
            set_ecflow_meter(ecflow_meter,126)

    jlogger.info(conf.strinterp(
        'config','{stormlabel}: completed regribbing job for {out_prefix}'))

def tracker(n):
    """!Runs one of the tracker tasks in this process.

    The selected task is run from inside the hwrf_expt.WORKhwrf
    directory, and a message is sent to the jlogger before and after.

    @param n the domain of interest: 3 runs hwrf_expt.tracker, 2 runs
      hwrf_expt.trackerd02 and 1 runs hwrf_expt.trackerd01
    @throws ValueError if n is not 1, 2 or 3"""
    # Reject unknown domains before doing anything else; otherwise the
    # job would log a start and a completion without running a tracker.
    if n not in (1,2,3):
        raise ValueError(
            'Invalid tracker domain %r: expected 1, 2 or 3.'%(n,))

    jlogger.info(hwrf_expt.conf.strinterp('config',
            '{stormlabel}: starting domain {dom} tracker job for {out_prefix}',
            dom=n))
    with NamedDir(hwrf_expt.WORKhwrf,logger=logging.getLogger()):
        if n==3:
            hwrf_expt.tracker.run()
        elif n==2:
            hwrf_expt.trackerd02.run()
        else:
            hwrf_expt.trackerd01.run()
    jlogger.info(hwrf_expt.conf.strinterp(
            'config','{stormlabel}: completed domain {dom} tracker job '
            'for {out_prefix}',dom=n))

def copier():
    """!Runs the WRF copy task if needed, and then the gribber().

    When post_runs_wrfcopier in the [config] section is false (the
    default), hwrf_expt.wrfcopier is run from inside the
    hwrf_expt.WORKhwrf directory to copy native WRF input and output
    files to COM.  When it is true, the copy is left to the post job
    and only a message is logged.  In both cases gribber() is called
    afterwards, in this same process."""
    conf=hwrf_expt.conf
    if conf.getbool('config','post_runs_wrfcopier',False):
        jlogger.info('Products job will not run wrfcopier, post will do it.')
    else:
        jlogger.info(conf.strinterp(
            'config','{stormlabel}: starting wrfcopier job for {out_prefix}'))
        with NamedDir(hwrf_expt.WORKhwrf,logger=logging.getLogger()):
            hwrf_expt.wrfcopier.run(check_all=True,raise_all=True)
        jlogger.info(conf.strinterp(
            'config','{stormlabel}: completed wrfcopier job for {out_prefix}'))

    # This rank joins the regribbing work once any copying is done.
    gribber()

def products():
    """!Runs the NHC products task, hwrf_expt.nhcp.

    The task is run from inside the hwrf_expt.WORKhwrf directory, and
    a message is sent to the jlogger before and after it."""
    conf=hwrf_expt.conf
    workdir=hwrf_expt.WORKhwrf

    started=conf.strinterp(
        'config','{stormlabel}: starting nhc_products job for {out_prefix}')
    jlogger.info(started)

    with NamedDir(workdir,logger=logging.getLogger()):
        hwrf_expt.nhcp.run()

    finished=conf.strinterp(
        'config','{stormlabel}: completed nhc_products job for {out_prefix}')
    jlogger.info(finished)

def starter(dryrun):
    """!Main program for subprocesses: picks this rank's task and runs it.

    The rank of this process and the number of ranks are read from
    the SCR_COMM_RANK and SCR_COMM_SIZE environment variables.  Tasks
    are assigned by rank:
    - rank 0 calls tracker(3);
    - ranks 1 and 2 call tracker(2) and tracker(1), but only if
      extra_trackers=yes in the [config] section;
    - the last rank, unless it is one of the tracker ranks, calls
      copier();
    - all other ranks call gribber(); the first of them also passes
      the "gribber" ecFlow meter name.

    If the ranks do not provide at least one copier and one gribber,
    or if this process's rank is not below the number of ranks, a
    critical message is logged and the process exits with status 2.

    @param dryrun if True, log the assignment and return the name of
      this rank's task without running it; if False, run the task
    @returns the task name ("tracker", "d02tracker", "d01tracker",
      "copier1", "gribber1", ...) if dryrun is True; None otherwise"""
    conf=hwrf_expt.conf
    myrank=int(os.environ['SCR_COMM_RANK'])
    count=int(os.environ['SCR_COMM_SIZE'])
    logger=conf.log('exhwrf_products')
    extra_trackers=conf.getbool('config','extra_trackers',False)
    ngribbers=0
    ncopiers=0
    run=None
    whoami=None

    # Every rank is visited, not only our own, so that the copier and
    # gribber totals are known and each gribber gets its number.
    for rank in range(count):
        if rank==0:
            if rank==myrank:
                if dryrun: logger.info('Rank %d runs d03 tracker'%rank)
                run=lambda: tracker(3)
                whoami='tracker'
        elif rank==1 and extra_trackers:
            if rank==myrank:
                if dryrun: logger.info('Rank %d runs d02 tracker'%rank)
                run=lambda: tracker(2)
                whoami='d02tracker'
        elif rank==2 and extra_trackers:
            if rank==myrank:
                if dryrun: logger.info('Rank %d runs d01 tracker'%rank)
                run=lambda: tracker(1)
                whoami='d01tracker'
        elif rank==count-1:
            ncopiers+=1
            if rank==myrank:
                if dryrun: logger.info('Rank %d runs wrfcopier'%rank)
                run=copier
                whoami='copier%d'%ncopiers
        else:
            ngribbers+=1
            if rank==myrank:
                if dryrun: logger.info('Rank %d runs gribber'%rank)
                if ngribbers==1:
                    # Only the first gribber reports to the ecFlow meter.
                    run=lambda: gribber(ecflow_meter='gribber')
                else:
                    run=gribber
                whoami='gribber%d'%ngribbers

    if ncopiers<1 or ngribbers<1:
        # The smallest layout that passes the check above is one d03
        # tracker, one gribber and one copier; extra_trackers adds the
        # d02 and d01 trackers.
        need=1+1+1
        if extra_trackers: need+=2
        msg='Cannot run products job with %d processors with these settings.'\
            ' I require at least %d.'%(count,need)
        logger.critical(msg)
        sys.exit(2)

    if run is None:
        # No task was assigned: SCR_COMM_RANK is not in 0..count-1.
        logger.critical('Rank %d is not one of the %d ranks of this job.'
                        %(myrank,count))
        sys.exit(2)

    if dryrun:
        return whoami
    run()

def slave_main():
    """!Entry point of each subprocess started by mpiserial.

    This is run once in every rank.  It initializes hwrf_expt and the
    regrib and tracker alerts, asks starter() (as a dry run) which
    task belongs to this rank, redirects stdout and stderr
    accordingly, and finally calls starter() again to run the task.

    Log file names come from the TRACKER_LOGS environment variable
    for tracker ranks and from REGRIBBER_LOGS for all other ranks;
    if the variable is unset, WORKhwrf/jobid-THREAD_WHOAMI.log is
    used.  The name is a %-style template that may refer to RANK,
    COUNT, WHO, STREAM, jobid, WORKhwrf and THREAD_WHOAMI."""
    rank=int(os.environ['SCR_COMM_RANK'])
    count=int(os.environ['SCR_COMM_SIZE'])
    print('MPI communicator: rank=%d size=%d'%(rank,count))

    hwrf_expt.init_module(preload=hwrf_expt.argv_preload)
    hwrf_expt.conf.add_fallback_callback(hwrf_alerts.fallback_callback)
    hwrf_alerts.add_regrib_alerts()
    hwrf_alerts.add_tracker_alerts()

    fields=dict(RANK=rank,COUNT=count,WHO='regribber',
                jobid=produtil.batchsystem.jobid(),
                WORKhwrf=hwrf_expt.conf.getdir('WORKhwrf'))

    # A dry run of starter() reports which task this rank will run.
    whoami=starter(dryrun=True)
    fields['THREAD_WHOAMI']=whoami

    def expand(template,who,stream):
        """Fills in a log file name template for one output stream."""
        return template % dict(fields,WHO=who,STREAM=stream)

    if 'tracker' in whoami:
        # Tracker jobs send stdout and stderr to one stream, so no
        # separate stderr file is named.
        template=os.environ.get('TRACKER_LOGS')
        if template is None:
            template=hwrf_expt.conf.strinterp(
                'config','%(WORKhwrf)s/%(jobid)s-%(THREAD_WHOAMI)s.log')
        produtil.log.mpi_redirect(
            stdoutfile=expand(template,'tracker','out'),
            stderrfile=None,threadname='tracker')
    else:
        # Regribber and copier name one file per stream (out, err).
        template=os.environ.get('REGRIBBER_LOGS')
        if template is None:
            template=hwrf_expt.conf.strinterp(
                'config',
                '%(WORKhwrf)s/%(jobid)s-%(THREAD_WHOAMI)s.log',
                threadwhoami=whoami)
        out_path=expand(template,'regribber','out')
        err_path=expand(template,'regribber','err')
        logging.getLogger('hwrf').warning(
            'Redirecting regribber %d to: stderr=%s stdout=%s'%
            (rank,err_path,out_path))
        produtil.log.mpi_redirect(
            stdoutfile=out_path,stderrfile=err_path,
            threadname='regrib%d'%(rank,))

    # Output is redirected; now run this rank's task for real.
    starter(dryrun=False)

def launchself():
    """!Launches copies of this script on all MPI ranks via mpiserial.

    LAUNCH_SELF=no is exported first so that the copies do not try to
    launch themselves again.  When produtil.cluster.name() is "surge"
    or "luna", TOTAL_TASKS and PRODUTIL_RUN_NODESIZE are rewritten so
    that the job keeps its number of nodes but places 8 tasks on each
    of them.  SCR_IMMEDIATE_EXIT and SCR_VERBOSITY are set for
    mpiserial, and the launch goes through checkrun so that an
    abnormal exit of the MPI program makes this script fail too."""
    # Instruct further processes not to re-launch scripts via mpirun:
    os.environ['LAUNCH_SELF']='no'

    logger=logging.getLogger('exhwrf_products')

    if produtil.cluster.name() in ['surge','luna']:
        nodesize=int(os.environ.get('PRODUTIL_RUN_NODESIZE','24'))
        assert(nodesize>=1)
        totaltasks=int(os.environ['TOTAL_TASKS'])
        assert(totaltasks>=1)
        newnodesize=8

        # Round the requested tasks up to whole nodes, then put
        # newnodesize tasks on each of those nodes.
        nodes=(totaltasks+nodesize-1)//nodesize
        totaltasks=nodes*newnodesize
        assert(totaltasks>=4)

        # Report the per-node count the job will really use
        # (newnodesize), not the one found in the environment.
        logger.info('Running on Cray.  Will run with %d tasks and %d per node.'%(
            totaltasks,newnodesize))

        os.environ['TOTAL_TASKS']='%d'%totaltasks
        os.environ['PRODUTIL_RUN_NODESIZE']='%d'%newnodesize

        assert(totaltasks<24) # safety check in case above code has logic errors

    os.environ["SCR_IMMEDIATE_EXIT"]='yes' # kill processes on first error
    os.environ["SCR_VERBOSITY"]='2' # verb 2 = major MPI calls & all significant events

    # Launch multiple copies of myself.  We must use mpiserial for this
    # because we need the SCR_COMM_RANK variable.  Calling checkrun
    # ensures the program exits abnormally if mpirun.lsf (or whatever
    # you use) exits abnormally.
    checkrun(mpirun(mpi(hwrf_expt.conf.getexe('mpiserial','mpiserial'))
                    [os.path.realpath(__file__)][sys.argv[1:]],allranks=True),logger=logger)

def doit():
    """!Main entry point, for the top-level job and for its subprocesses.

    A process counts as the top-level job if SCR_COMM_RANK is not in
    the environment and LAUNCH_SELF is unset or "yes".  Any other
    process just calls slave_main(), which passes control on to
    tracker(), gribber() or copier().

    The top-level job initializes hwrf_expt, registers the alerts,
    re-calls the callbacks of already completed gribber products,
    and runs launchself(), which starts the subprocesses through
    mpiserial.  Afterwards it calls products(), but only if
    hwrf_expt.fcstlen is 126."""
    produtil.setup.setup(masterdomain='exhwrf_products')

    top_level=('SCR_COMM_RANK' not in os.environ
               and os.environ.get('LAUNCH_SELF','yes')=='yes')
    if not top_level:
        # We're in a subprocess.  Just run the gribber and tracker and return:
        slave_main()
        return

    # This is the top level of the job: we are NOT inside an
    # mpi_serial call.  Initialize the hwrf_expt and re-call any
    # callbacks for completed products:
    hwrf_expt.init_module(preload=hwrf_expt.argv_preload)
    hwrf_expt.conf.add_fallback_callback(hwrf_alerts.fallback_callback)
    logger=logging.getLogger('exhwrf_products')
    hwrf_wcoss.set_vars_for_products(logger)

    logger.info('Ensure incomplete products are marked as such...')
    hwrf_expt.gribber.uncomplete()

    logger.info('Add alerts and delveries...')
    hwrf_alerts.add_nhc_alerts()
    hwrf_alerts.add_regrib_alerts()
    hwrf_alerts.add_wave_alerts()

    logger.warning("Rerunning dbn_alert for prior jobs' posted files.")
    hwrf_expt.gribber.call_completed_callbacks()

    # Launch copies of ourself to run the gribber and tracker, with a
    # banner in the log before and after the parallel portion:
    rule='---------------------------------------------------'
    for line in (rule,
                 'LAUNCH PARALLEL PORTION OF SCRIPT------------------',
                 rule):
        logger.warning(line)
    launchself()
    for line in (rule,
                 'PARALLEL PORTION OF SCRIPT HAS ENDED---------------',
                 rule):
        logger.warning(line)

    # Gribber and tracker succeeded.  Run the products job, which is
    # only done for a forecast length of 126:
    if hwrf_expt.fcstlen!=126:
        logger.info('Forecast length is: %d ; Not running the products job.'%hwrf_expt.fcstlen)
        return
    products()

# Run the job as soon as this file is executed.  launchself() passes
# this same file to mpiserial, so this call is also where every
# subprocess starts.
# To profile the job, import cProfile and replace the call below with:
#cProfile.run('doit()')
doit()