#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Apr  8 19:47:37 2021

@author: Camaron.George
"""
import os
import pygrib
import numpy as np
import netCDF4 as nc
from scipy.interpolate import griddata

# Working directory: holds the P-Surge GRIB2 input and the NCAR-regridded
# ESTOFS netCDF input, and receives the generated boundary file.
# (The SCHISM grid file is located through its own path further down.)
path = "/scratch2/NCEPDEV/ohd/Camaron.George/Data/Laura/"
# P-Surge GRIB2 file for this simulation.
file = "psurge.t2020082600z.al132020_e10_inc_dat.h102.conus_625m.grib2"
# ESTOFS water levels regridded by NCAR.
est = "estofs.t00z.fields.cwl.regrid.nc"
# Name of the boundary-condition file to produce.
out = "elev2D.th.nc"

# Check whether the P-Surge file exists: if it does, interpolate the P-Surge data and apply it at the boundary nodes; if not, move/rename the interpolated ESTOFS file created by NCAR to the appropriate folder.
if os.path.exists(path+file):
    # Number of output records. The original design describes this as
    # 10 days (presumably hourly records 0..240 h -- confirm against the
    # ESTOFS time_step).
    nt = 241

    # ------------------------------------------------------------------
    # 1. Open-boundary nodes and their lon/lat from the SCHISM grid file.
    #    Layout consumed here: a title line; "<ne> <nn>"; nn node lines
    #    "<id> <lon> <lat> ..."; ne element lines; two further lines
    #    (presumably the open-boundary count and total open-boundary node
    #    count -- not used); "<nb> ..." for the first open boundary; then
    #    nb node ids, one per line. Only the FIRST open boundary is used.
    # ------------------------------------------------------------------
    gridFile = '/scratch2/NCEPDEV/ohd/Camaron.George/Scripts/ModelScripts/hgrid.gr3'
    with open(gridFile) as fgrid:
        fgrid.readline()
        counts = fgrid.readline().split()
        ne, nn = int(counts[0]), int(counts[1])
        lon = np.empty(nn)
        lat = np.empty(nn)
        for k in range(nn):
            fields = fgrid.readline().split()
            lon[k] = float(fields[1])
            lat[k] = float(fields[2])
        # Element connectivity is not needed here; skip past it.
        for _ in range(ne):
            fgrid.readline()
        fgrid.readline()
        fgrid.readline()
        nb = int(fgrid.readline().split()[0])
        # split()[0] also tolerates trailing text after a node id.
        bnodes = [int(fgrid.readline().split()[0]) for _ in range(nb)]

    # Boundary-node coordinates (grid node ids are 1-based).
    bidx = np.asarray(bnodes, dtype=int) - 1
    lo = lon[bidx]
    la = lat[bidx]

    # Water levels indexed as (time record, boundary node, level, component).
    Z = np.zeros((nt, len(lo), 1, 1))

    # 2. Regridded ESTOFS water levels (already on the boundary nodes) fill
    #    the leading records; the original notes this as the first 7.5 days.
    #    The [:] reads copy the data into memory, so the file can be closed.
    with nc.Dataset(path+est, 'r') as data:
        t = data.variables['time'][:]
        tStep = data.variables['time_step'][:]
        z = data.variables['time_series'][:]
    Z[:z.shape[0], :, :, :] = z

    # 3. P-Surge: message i overwrites record i+1 at every boundary node that
    #    lies strictly inside the bounding box of that message's valid
    #    (unmasked) points, using nearest-neighbour interpolation.
    grbs = pygrib.open(path+file)
    try:
        msg = grbs.read()
        if msg:
            # All messages are assumed to share the first message's grid.
            lt, ln = msg[0].latlons()
        for i, grb in enumerate(msg):
            vals = grb.values
            # vals may be a plain ndarray, or a masked array whose mask is
            # the scalar `nomask`; getmaskarray gives a full-size mask either way.
            valid = ~np.ma.getmaskarray(vals)
            if not valid.any():
                continue  # nothing to interpolate from; keep ESTOFS values
            x1 = ln[valid]
            y1 = lt[valid]
            z1 = np.ma.getdata(vals)[valid]
            inside = np.nonzero((lo < x1.max()) & (lo > x1.min()) &
                                (la < y1.max()) & (la > y1.min()))[0]
            if inside.size == 0:
                continue
            z2 = griddata((x1, y1), z1, (lo[inside], la[inside]), method='nearest')
            Z[i+1, inside, :, :] = z2[:, np.newaxis, np.newaxis]
    finally:
        grbs.close()

    # 4. Time axis covering all nt records. The ESTOFS times are kept and then
    #    extended at the file's own time step; writing only the ESTOFS times
    #    would leave the remaining records of the unlimited 'time' axis as
    #    fill values.
    times = np.ravel(np.ma.getdata(t)).astype('f8')
    if times.size < nt:
        dt = float(np.ravel(np.ma.getdata(tStep))[0])
        last = times[-1] if times.size else -dt
        times = np.concatenate([times, last + dt*np.arange(1, nt - times.size + 1)])

    # 5. Write the boundary-condition file (closed automatically on exit).
    with nc.Dataset(path+out, 'w', format='NETCDF4') as ncout:
        # define axis size
        ncout.createDimension('time', None)
        ncout.createDimension('nOpenBndNodes', len(lo))
        ncout.createDimension('nLevels', 1)
        ncout.createDimension('nComponents', 1)
        ncout.createDimension('one', 1)

        nctstep = ncout.createVariable('time_step', 'f8', ('one',))
        nctime = ncout.createVariable('time', 'f8', ('time',))
        ncwl = ncout.createVariable('time_series', 'f8',
                                    ('time', 'nOpenBndNodes', 'nLevels', 'nComponents',))

        nctstep[:] = tStep
        nctime[:] = times
        ncwl[:] = Z
else:
    #copy interpolated ESTOFS file to folder where SCHISM runs will be done and rename it to elev2D.th.nc