#!/usr/bin/env python

from datetime import datetime, timedelta
import os
from time import monotonic
import argparse

import ESMF
import netCDF4
from cftime import num2date
import numpy as np

from mpi4py import MPI
# World communicator; used later to combine the per-PET regrid results.
comm = MPI.COMM_WORLD
# Install MPI's ERRORS_ARE_FATAL error handler on this communicator.
comm.Set_errhandler(MPI.ERRORS_ARE_FATAL)

# ESMF.Manager(debug=True)

"""
Utility that uses ESMF to regrid ESTOFS data files to the SCHISM elev2D.th.nc (time history) format.

NOTE: the input SCHISM hgrid boundary file is created using gr3_2_esmf.py with the --filter_open_bnds flag:
./gr3_2_esmf.py --filter_open_bnds hgrid.gr3 schism.hgrid.nc

Required environment variables: CYCLE_DATE (YYYYMMDD) and CYCLE_TIME (HHMM); optional: LENGTH_HRS (default 180).

Usage: ./regrid_estofs.py  /glade/scratch/rcabell/coastal/estofs.t00z.fields.cwl.nc /glade/scratch/bpetzke/ForcingEngine/gr3_2_esmf/prvi.schism.hgrid.nc /glade/scratch/bpetzke/ForcingEngine/regrid_estofs/prvi.estofs.t00z.fields.cwl.regrid.nc
"""


# Index of this PET and total number of PETs, as reported by ESMF.
local_pet = ESMF.local_pet()
pet_count = ESMF.pet_count()

# Variable names read from the SCHISM grid file.
schism_coord_name = "nodeCoords"
schism_open_boundary_name = "openBndNodes"
# Variable names read from the ESTOFS input file.
x_name = "x"
y_name = "y"
time_name = "time"
# NOTE(review): elem_name is not referenced anywhere in the visible code.
elem_name = "element"
# Spacing of the output time axis; also written to the "time_step" variable.
# Presumably seconds (one hour) -- TODO confirm against the SCHISM reader.
time_step = 3600
# Fill value for destination points; values not greater than this are
# written out as 0.
missing = -9999.

# NOTE(review): TEST and BND_BUFFER are not referenced in the visible code.
TEST = False
BND_BUFFER = 0.25


def create_locstream(lons, lats, name=""):
    """Create an ESMF LocStream holding this PET's share of the given points.

    The points are split evenly across PETs; when the total is not divisible
    by ``pet_count`` the last PET additionally takes the remainder.

    Args:
        lons: Longitudes of all points (spherical degrees).
        lats: Latitudes of all points; must have the same length as ``lons``.
        name: Accepted for the caller's convenience; currently unused.

    Returns:
        An ``ESMF.LocStream`` whose "ESMF:Lon" / "ESMF:Lat" keys are filled
        from ``lons`` / ``lats`` over the index range given by the
        LocStream's own lower/upper bounds.
    """
    # https://earthsystemmodeling.org/esmpy_doc/release/ESMF_8_1_0/html/examples.html#locstream-create

    n_points = len(lons)
    assert n_points == len(lats)

    # Even share per PET; the leftover points all go to the last PET.
    local_count, leftover = divmod(n_points, pet_count)
    if leftover and local_pet == pet_count - 1:
        local_count += leftover

    stream = ESMF.LocStream(local_count, coord_sys=ESMF.CoordSys.SPH_DEG)

    lo = stream.lower_bounds[0]
    hi = stream.upper_bounds[0]

    stream["ESMF:Lon"] = lons[lo:hi]
    stream["ESMF:Lat"] = lats[lo:hi]

    return stream


def regrid_chunk(beg_t, end_t, x, y, nt, in_field, coords):
    """Regrid input time steps ``beg_t`` .. ``end_t - 1`` onto the boundary points.

    A nearest-neighbour (source-to-destination) regrid is performed for every
    time step. The source mask is refreshed from each step's data, so the
    regrid object has to be rebuilt per step.

    Args:
        beg_t: Index of the first input time step to regrid.
        end_t: One past the index of the last input time step to regrid.
        x: Longitudes of all source points.
        y: Latitudes of all source points.
        nt: Number of rows (time steps) in the returned array.
        in_field: Source data indexable as ``in_field[t, i]`` (time, point),
            e.g. a netCDF4 variable.
        coords: Sequence of ``(lon, lat)`` pairs for the destination points.

    Returns:
        A ``(nt, len(coords))`` float array. Only the columns inside this
        PET's destination bounds are filled; all other entries are 0, so the
        per-PET arrays can be combined by summation.
    """
    bnd_lons = [c[0] for c in coords]
    bnd_lats = [c[1] for c in coords]

    mesh_in = create_locstream(x, y, "mesh_in")
    locstream_out = create_locstream(bnd_lons, bnd_lats, "bnd_out")

    field_from = ESMF.Field(mesh_in, name="EstofsIn")
    field_to = ESMF.Field(locstream_out, name="OpenBoundary")

    output = np.zeros((nt, len(coords)))

    i_lbounds = mesh_in.lower_bounds[0]
    i_ubounds = mesh_in.upper_bounds[0]

    o_lbounds = locstream_out.lower_bounds[0]
    o_ubounds = locstream_out.upper_bounds[0]

    is_root = local_pet == 0
    method = ESMF.RegridMethod.NEAREST_STOD

    for t in range(beg_t, end_t):
        start = monotonic()
        if is_root:
            print(f"Regridding ESTOFS, t = {t-beg_t} ", flush=True, end='')

        # Read only this PET's slab of the time step (and only once).
        data = in_field[t, i_lbounds:i_ubounds]
        field_from.data[...] = data
        # getmaskarray always yields a full-size boolean array, even when
        # nothing is masked or the read returned a plain ndarray.
        # 0 is unmasked, 1 is masked.
        mesh_in["ESMF:Mask"] = np.ma.getmaskarray(data).astype('i4')

        field_to.data[...] = missing

        # Named `regridder` so the module-level regrid() is not shadowed.
        regridder = ESMF.Regrid(srcfield=field_from, dstfield=field_to, regrid_method=method,
                                unmapped_action=ESMF.UnmappedAction.IGNORE, src_mask_values=[1])
        try:
            field_regridded = regridder(field_from, field_to)
            output[t - beg_t, o_lbounds:o_ubounds] = field_regridded.data[...]
        finally:
            # A fresh Regrid is built every step and never reused; destroy it
            # explicitly so regrid objects do not accumulate over the run.
            regridder.destroy()

        if is_root:
            print(f"in {monotonic()-start:.2f} sec")

    return output


def regrid(nc_in, nc_grid, nc_out, regrid_field):
    """Regrid one ESTOFS field onto the SCHISM open boundary and write it out.

    Reads the open-boundary node coordinates from ``nc_grid``, selects the
    ESTOFS time window that starts at the forecast cycle time, regrids it in
    parallel, sums the per-PET results onto PET 0 and writes the time-history
    file there.

    Environment variables:
        CYCLE_DATE: Forecast cycle date, ``%Y%m%d`` (required).
        CYCLE_TIME: Forecast cycle time, ``%H%M`` (required).
        LENGTH_HRS: Forecast length in hours (optional, default 180).

    Args:
        nc_in: Path to the ESTOFS input netCDF file.
        nc_grid: Path to the netCDF file with the SCHISM node coordinates and
            open-boundary node indices.
        nc_out: Path of the netCDF file to create (written by PET 0 only).
        regrid_field: Name of the ESTOFS variable to regrid.

    Raises:
        KeyError: If CYCLE_DATE or CYCLE_TIME is not set.
        ValueError: If the forecast cycle maps to a negative time index in
            the ESTOFS file, or no ESTOFS time steps fall in the window.
    """
    FORECAST_START = 5      # ignore spinup
    is_root = local_pet == 0

    if is_root:
        print("Reading SCHISM open boundary points...", flush=True)
    with netCDF4.Dataset(nc_grid) as f_grid:
        node_coords = f_grid[schism_coord_name][:]
        valid_indices = f_grid[schism_open_boundary_name][:]
    coords = [node_coords[i].tolist() for i in valid_indices]

    if is_root:
        print("Reading ESTOFS input...", flush=True)
    with netCDF4.Dataset(nc_in) as f_in:
        etime_var = f_in[time_name]
        estart = num2date(etime_var[FORECAST_START], units=etime_var.units)

        # round up non-hourly ESTOFS start times
        if estart.minute != 0:
            estart = datetime(estart.year, estart.month, estart.day, estart.hour)
            estart += timedelta(hours=1)
        if is_root:
            print(f"ESTOFS date is {estart}")

        cdate = os.environ["CYCLE_DATE"]
        ctime = os.environ["CYCLE_TIME"]
        total_hours = int(os.environ.get('LENGTH_HRS', 180)) + 1
        fdate = datetime.strptime(cdate+ctime, "%Y%m%d%H%M")
        if is_root:
            print(f"Forecast date is {fdate}")

        dt_h = int((fdate-estart).total_seconds() / 3600)
        start = FORECAST_START + dt_h
        if is_root:
            print(f"ESTOFS lags forecast by {dt_h} hours, forecast length is {total_hours-1} hours")

        # A negative index would silently wrap around to the end of the file
        # and regrid the wrong time steps.
        if start < 0:
            raise ValueError(
                f"Forecast date {fdate} is too early for the ESTOFS file: "
                f"computed start index {start} is negative")

        x = f_in[x_name][:]
        y = f_in[y_name][:]
        times = etime_var[start:start+total_hours]
        nt = len(times)
        # Fail on every PET before regridding instead of on PET 0 afterwards.
        if nt == 0:
            raise ValueError(
                f"No ESTOFS time steps available at index {start} or later in {nc_in}")

        # _FillValue can only be set when a variable is created, so it must
        # not be copied onto the output time variable via setncatts().
        input_time_atts = {key: value for key, value in etime_var.__dict__.items()
                           if key != "_FillValue"}
        in_field = f_in[regrid_field]

        output_local = regrid_chunk(start, nt+start, x, y, nt, in_field, coords)

    # Each PET filled only its own columns (zeros elsewhere), so a sum
    # assembles the complete array on PET 0.
    if is_root:
        output = np.zeros((nt, len(coords)))
    else:
        output = None
    comm.Reduce(output_local, output, op=MPI.SUM)

    if is_root:
        with netCDF4.Dataset(nc_out, "w", format="NETCDF4") as f_out:
            f_out.createDimension("time", None)
            f_out.createDimension("nOpenBndNodes", len(coords))
            f_out.createDimension("nLevels", 1)
            f_out.createDimension("nComponents", 1)
            f_out.createDimension("one", 1)

            time_step_var = f_out.createVariable("time_step", "f8", ("one",))
            time_var = f_out.createVariable("time", "f8", ("time",))
            time_var.setncatts(input_time_atts)
            time_var.start_time = times[0]
            time_series_var = f_out.createVariable("time_series", "f8", ("time", "nOpenBndNodes", "nLevels", "nComponents"),
                                                   fill_value=missing, zlib=True)

            time_step_var[:] = np.array([time_step])
            time_var[:] = np.arange(0, nt*time_step, time_step)
            print("Writing SCHISM elevation file...", flush=True)
            # Points left at the fill value (nothing mapped to them) are
            # written as 0.
            cleaned = np.where(output > missing, output, 0)
            for step in range(nt):
                time_series_var[step, :, 0, 0] = cleaned[step]
            print("Regridding complete")


def main():
    """Parse the command line and regrid the ESTOFS "zeta" field.

    Expects three positional arguments: the ESTOFS input file, the SCHISM
    grid file and the output file. Start/end banners are printed on PET 0
    only.
    """
    parser = argparse.ArgumentParser()
    positional = (
        ('estofs_input', 'Input .nc file'),
        ('schism_grid', '.nc file containing schism coordinates'),
        ('regrid_output', 'Output .nc file'),
    )
    for arg_name, arg_help in positional:
        parser.add_argument(arg_name, type=str, help=arg_help)
    opts = parser.parse_args()

    is_root = local_pet == 0

    if is_root:
        print(f"Starting {parser.prog}", flush=True)
    regrid(opts.estofs_input, opts.schism_grid, opts.regrid_output, "zeta")
    if is_root:
        print(f"Ending {parser.prog}", flush=True)


# Run the command-line entry point only when executed as a script.
if __name__ == '__main__':
    main()