# Public API of this module: only the NHCProducts task class is exported.
__all__ = [
    'NHCProducts',
]

import math, os, re, glob
import produtil.fileop, produtil.run, produtil.tempdir, produtil.datastore
import hwrf.hwrftask, hwrf.namelist, hwrf.constants, hwrf.storminfo

from math import pi,sqrt
from hwrf.constants import Rearth
from hwrf.namelist import to_fortnml
from produtil.fileop import make_symlink, deliver_file
from produtil.run import checkrun, exe, openmp
from produtil.tempdir import TempDir
from produtil.datastore import COMPLETED, RUNNING, UpstreamFile, FileProduct

class NHCProducts(hwrf.hwrftask.HWRFTask):
    """This is a wrapper around the hwrf_nhc_products program.

It writes the products.nml namelist and a "tmpvit" tcvitals file, links
the wrfdiag files and the coupler "MDstatus" file into a work
directory, runs hwrf_nhc_products there, and delivers the files listed
in the program's "outlist" output file to the com directory."""
    def __init__(self,dstore,conf,section,wrftask,track,domains,
                 vitals,stream='auxhist1',**kwargs):
        """Creates an NHCProducts task.

Parameters:
  dstore, conf, section, kwargs -- passed to HWRFTask.__init__
  wrftask -- the task that runs WRF; its wrf() and products() are used
  track -- the final track product; must not be None
  domains -- sequence of WRF domains; the first is treated as the MOAD
    and the last as the innermost nest by run()
  vitals -- a single hwrf.storminfo.StormInfo, or an iterable of them
  stream -- WRF output stream holding the wrfdiag data"""
        super(NHCProducts,self).__init__(dstore,conf,section,**kwargs)
        self._wrftask=wrftask
        self._track=track # final track product
        assert(track is not None)
        self._domains=domains
        self._stream=stream
        self._vitals=vitals
        # wrfdiag products at the end of the simulation; canrun() waits
        # for all of them.
        simend=wrftask.wrf().simend()
        self._wrfdiag=list(wrftask.products(stream=stream,time=simend))
        self._products=self._make_products()
        # Not referenced elsewhere in this class; kept for compatibility.
        self._name_for_ext=dict()
    def _make_products(self):
        """Creates the Product objects for the files this task delivers.

Returns a dict mapping product name to Product:
  wind10m, rainfall, wind10hrly -- UpstreamFile objects whose location
      uses the lower-case "{nidymdh}." prefix in the com directory
  stats, htcf, afos -- UpstreamFile objects whose location uses the
      upper-case "{nidymdh}." prefix in the com directory
  wrfdiag_dNN -- one FileProduct per domain in self._domains"""
        com=self.getdir('com')
        prefix=self.confstrinterp('{nidymdh}.')
        lower_pre=os.path.join(com,prefix.lower())
        upper_pre=os.path.join(com,prefix.upper())
        lower_files=dict( wind10m='wind10m.ascii', rainfall='rainfall.ascii',
                          wind10hrly='wind10hrly.ascii')
        upper_files=dict( stats='stats.tpc', htcf='hwrf_d03.htcf', afos='afos')
        with self.dstore.transaction():
            prods=dict()
            for (pre,files) in ( (lower_pre,lower_files),
                                 (upper_pre,upper_files) ):
                for pname,psub in files.iteritems():
                    prod=UpstreamFile(self.dstore,prodname=pname,
                                      category=self.taskname)
                    prod.location=pre+psub
                    # Relaxed thresholds used by UpstreamFile.check();
                    # NOTE(review): exact meaning of minsize/minage is
                    # defined in produtil.datastore.
                    prod['minsize']=0
                    prod['minage']=-120
                    prods[pname]=prod
            for domain in self._domains:
                diagname='wrfdiag_d%02d'%(domain.get_grid_id(),)
                prods[diagname] = FileProduct(dstore=self.dstore, prodname=diagname,
                    category=self.taskname, location=lower_pre+diagname)
        return prods
    def canrun(self,silent=True):
        """Determines if the hwrf_nhc_products program can be run yet.

Returns True only if the track product and every wrfdiag product are
available.  If silent is False, the reason for returning False is
logged."""
        with self.dstore.transaction():
            # Check for track file:
            if not self._track.available:
                if not silent:
                    self.log().info('Cannot run: track not yet available.  Location: %s available: %d'% (
                            str(self._track.location), int(self._track.available)))
                return False
            # Check for wrfdiag files.  Run the check() method on each
            # if it is not already available just in case it is an
            # upstream file.
            for w in self._wrfdiag:
                if not w.available:
                    w.check()
                    if not w.available:
                        if not silent:
                            self.log().info('Cannot run: wrfdiag file %s not available'%(str(w),))
                        return False
        return True
    def get_res_cutoff(self,wrf,fudge_factor=0.8):
        """Calculates the outermost radius from the domain center at which the
storm center can be, while still considered to be resolved by that domain.
This is used to detect failures of the nest motion, and report which
nests actually contain the storm.  Iterates over radii for each nest,
skipping the first domain yielded by iterating over wrf (the MOAD).

Each radius is domain.dy*Rearth*pi/180 * e_sn * sqrt(2) * fudge_factor.
NOTE(review): this was previously documented as a radius in km; the
actual units follow hwrf.constants.Rearth and domain.dy -- confirm."""
        first=True
        for domain in wrf:
            if first:
                first=False
            else:
                sn=domain.nl.nl_get('domains','e_sn')
                yield domain.dy*Rearth*pi/180.*sn*sqrt(2.)*fudge_factor
    def nesting_level(self,moad,nest):
        """Determines the nesting level of the specified nest relative to the
given moad.  If nest=moad, the result is 0, if nest is the direct
child of moad, then 1, if it is the grandchild, then 2, and so on.

The walk up the .parent chain is capped at 100 levels.  If nest is not
a descendant of moad, the walk reaches a None parent and raises
AttributeError."""
        level=0
        moadid=moad.get_grid_id()
        dom=nest
        while level<100 and dom.get_grid_id()!=moadid: # cap at 100 just in case
            level+=1
            dom=dom.parent
        return level
    def write_namelist(self,f,wrf,moad,inner):
        """This is an internal implementation function; do not call it directly.
Writes the products.nml namelist to file object f.

Parameters:
  f -- writable file object
  wrf -- the WRF simulation; passed to get_res_cutoff()
  moad -- outermost domain; supplies dx, dy, e_we, e_sn and the time step
  inner -- innermost domain; supplies the high-frequency track filename
      and, with moad, the nesting level"""
        wrftask=self._wrftask
        g=moad.nl.nl_get
        basin1=self.confstr('basin1',section='config').upper()
        stnum=self.confint('stnum',section='config')
        assert(isinstance(wrftask.location,basestring))
        hifreq=inner.hifreq_file()
        assert(isinstance(hifreq,basestring))

        # Construct a dict of replacement strings to substitute into the namelist:
        repl={ 'dx':g('domains','dx'),
               'dy':g('domains','dy'),
               'ide':g('domains','e_we'),
               'jde':g('domains','e_sn'),
               'YMDH':self.confint('YMDH',section='config'),
               'inhifreq':os.path.join(str(wrftask.location),hifreq),
               'inatcf':self._track.location,
               'timestep':moad.dt,
               'domlat':self.conffloat('domlat',section='config'),
               'domlon':self.conffloat('domlon',section='config'),
               'STORM':self.confstr('STORM',section='config'),
               'ATCFID':"%02d%s"%(stnum,basin1),
               'TierI_model':to_fortnml(self.confstr('TierI_model','HWRF')),
               'TierI_submodel':to_fortnml(self.confstr('TierI_submodel','PARA')),
               'TierI_realtime':to_fortnml(self.confbool('TierI_realtime',True)),
               'swathres':to_fortnml(self.conffloat('swathres',0.05)),
               'swathpad':to_fortnml(self.conffloat('swathpad',0.3)),
               'grads_byteswap':to_fortnml(self.confbool('grads_byteswap',True)),
               'nestlev':to_fortnml(self.nesting_level(moad,inner)),
               'rescut':to_fortnml([float(x) for x in self.get_res_cutoff(wrf)])
             }
        # Guess the forecast center from the basin:
        if basin1 in ('L','E','C','Q'):
            repl['centername']='"NHC"'
        else:
            repl['centername']='"JTWC"'
        # Now generate the actual namelist.  The swath resolution and
        # padding come from the swathres/swathpad options (defaults
        # 0.05 and 0.3) instead of being hard-coded.
        f.write('''
&nhc_products
    intcvitals='tmpvit'
    inatcf='{inatcf}'
    inhifreq='{inhifreq}'
    inwrfdiag='wrfdiag_d<DOMAIN>'
    outpre='{STORM}{ATCFID}.{YMDH}.'
    mdstatus='MDstatus'
    resolution_cutoffs = {rescut}
    want_ymdh={YMDH}
    want_stid={ATCFID}
    want_centername={centername}
    coupler_dt=540.
    fcst_len=126
    ide_moad={ide}
    jde_moad={jde}
    dlmd_moad={dx}
    dphd_moad={dy}
    clat={domlat}
    clon={domlon}
    nesting_level={nestlev}
    grads_byteswap={grads_byteswap}
    time_step={timestep}
    model={TierI_model}
    submodel={TierI_submodel}
    realtime={TierI_realtime}
    swath_latres={swathres}
    swath_lonres={swathres}
    swath_latpad={swathpad}
    swath_lonpad={swathpad}
/
'''.format(**repl))
    def product(self,name):
        """Convenience function that returns the product with the given name, or
raises KeyError if none is found."""
        return self._products[name]
    def wrfdiag_products(self,what=None):
        """Iterates over the wrfdiag_dNN products of this task.

The "what" argument is accepted for interface compatibility and is
ignored."""
        for name,prod in self._products.iteritems():
            assert(isinstance(name,basestring))
            if name.startswith('wrfdiag'):
                yield prod
    def products(self,what=None):
        """Iterates over Product objects describing files produced by this Task.

If "what" is given, yields only the product with that name (nothing if
there is no such product).  Otherwise yields every product created by
_make_products()."""
        if what is not None:
            if what in self._products:
                yield self._products[what]
        else:
            for product in self._products.itervalues():
                yield product

    def run(self):
        """Runs hwrf_nhc_products in a work directory and delivers its output.

Sets self.state to RUNNING at the start and to COMPLETED after the
output has been delivered."""
        self.state=RUNNING
        wrf=self._wrftask.wrf()
        moad=wrf.get(self._domains[0])
        inner=wrf.get(self._domains[-1])
        logger=self.log()
        with TempDir(prefix='%s.'%(self.taskname,),dir=self.getdir("WORKhwrf"),logger=logger,keep=True):
            runme=self.getexe('hwrf_nhc_products')
            # Write the namelist:
            with open('products.nml','wt') as f:
                self.write_namelist(f,wrf,moad,inner)
            # Write the tcvitals.  self._vitals is either a single
            # StormInfo or an iterable of them; each element must be
            # converted individually.
            if isinstance(self._vitals,hwrf.storminfo.StormInfo):
                vitals=[self._vitals]
            else:
                vitals=self._vitals
            with open('tmpvit','wt') as f:
                for vital in vitals:
                    f.write("%s\n"%(vital.as_tcvitals(),))

            # Link to the last wrfdiag file of each domain:
            for domain in self._domains:
                (start,end,interval)=domain.get_output_range(self._stream) # get the last wrfdiag time
                orig=list(self._wrftask.products(stream=self._stream,domains=[domain],time=end)) # get the wrfdiag file
                orig=orig[-1].location # get the path of the last wrfdiag file in that list
                here='wrfdiag_d%02d'%(domain.get_grid_id(),) # local filename needed by program
                make_symlink(orig,here,force=True,logger=logger) # make the symlink

            # Link to the coupling status file:
            make_symlink(os.path.join(self._wrftask.location,'MDstatus'),
                         'MDstatus',force=True,logger=logger)
            nthreads=self.confint('threads',int(os.environ.get('NHC_PRODUCTS_NTHREADS','1')))
            checkrun(openmp(exe(runme),threads=nthreads))
            self.deliver_outlist()
            self.state=COMPLETED
    def rewrite_swath_ctl(self,ctlfile):
        """Modifies the swath.ctl file to point to a lower-case swath.dat filename.

Writes a copy of ctlfile named ctlfile+".lowerdat" in which the text
after "DSET ^" on a line is lower-cased; all other lines are copied
unchanged.  Returns the name of the new file."""
        logger=self.log()
        dset=re.compile(r'^(.*DSET +\^)(.*)$')
        newfile='%s.lowerdat'%(ctlfile,)
        with open(ctlfile,'rt') as fi:
            with open(newfile,'wt') as fo:
                for line in fi:
                    logger.debug('LINE: %s'%(line.rstrip(),))
                    m=dset.match(line)
                    if m:
                        logger.debug(' - contains DSET')
                        line="%s%s\n"%(m.group(1),m.group(2).lower())
                    else:
                        logger.debug(' - does not contain DSET')
                    fo.write(line)
        return newfile
    def deliver_outlist(self):
        """Reads the "outlist" output file from hwrf_nhc_products and delivers
the listed files to the com directory.

Delivery rules, by basename:
  *.swath.ctl -- rewritten by rewrite_swath_ctl, delivered lower-case
  *.afos, *.stats.tpc, *.htcf, *.resolution, *.htcfstats -- delivered
      twice: once lower-case and once in the original case
  a*.dat -- delivered in the original case
  anything else -- delivered lower-case
Afterwards, every product is check()ed and the wrfdiag files are
delivered to com."""
        logger=self.log()
        com=self.getdir('com')
        with open('outlist','rt') as outlist:
            # Remove end-of-line characters, and skip blank lines, which
            # would otherwise be "delivered" as an empty filename.
            outfiles=[line.rstrip() for line in outlist if line.strip()]
        for outfile in outfiles:
            bn=os.path.basename(outfile)
            lower_dest=os.path.join(com,outfile.lower())
            if re.search(r'\.swath\.ctl',bn):
                # Change swath.ctl file to lower-case, and change
                # the swath.dat filename inside to lower-case:
                newctl=self.rewrite_swath_ctl(outfile)
                with open(newctl,'rt') as f:
                    for line in f:
                        logger.info('NEWCTL: '+repr(line.rstrip()))
                deliver_file(newctl,lower_dest,keep=False,logger=logger)
            elif re.search(r'\.(afos|stats\.tpc|htcf|resolution|htcfstats)$',bn):
                # Deliver these twice: once in original (upper)
                # case, and once in lower-case:
                assert(outfile.find('swath')<0)
                logger.info('%s: deliver twice: as upper- and lower-case'%(outfile,))
                deliver_file(outfile,lower_dest,keep=True,logger=logger)
                deliver_file(outfile,os.path.join(com,outfile),keep=False,logger=logger)
            elif re.search(r'^a.*\.dat$',bn):
                # Deliver these files in original case
                assert(outfile.find('swath')<0)
                deliver_file(outfile,os.path.join(com,outfile),keep=False,logger=logger)
                logger.info('%s: deliver as upper-case'%(outfile,))
            else:
                # Deliver remaining files in lower-case:
                logger.info('%s: deliver as lower-case'%(bn,))
                deliver_file(outfile,lower_dest,keep=False,logger=logger)
        for (name,prod) in self._products.iteritems():
            logger.info('%s: checking for product at %s'%(prod.did,prod.location))
            # Note: check instead of deliver because the non-wrfdiag
            # products are UpstreamFile objects written by the program.
            prod.check(logger=logger)
        for prod in self.wrfdiag_products():
            dest=self.confstrinterp('{com}/{nidymdh}.{prodname}',prodname=prod.prodname)
            logger.info("%s: deliver to %s"%(prod.prodname,dest))
            # The local wrfdiag_dNN symlink made by run() is the source.
            prod.deliver(frominfo=prod.prodname,location=dest)
