import argparse
from ast import parse
from contextlib import contextmanager
import time

import ESMF
import netCDF4
import numpy as np

# Not referenced anywhere else in the visible file: the time step written to
# the output is read from the PSURGE ``time`` variable's ``time_step``
# attribute instead. Presumably a default step in seconds -- TODO confirm.
time_step = 3600
# Sentinel for missing values: used as the fill value of the output
# ``time_series`` variable and to fill masked PSURGE cells before regridding.
missing = -9999.


@contextmanager
def tasklabel(task_name):
    """Print a progress label around a block of work and report its duration.

    On entry ``"<task_name>... "`` is printed without a trailing newline.
    While the block runs, ``builtins.print`` is replaced by a wrapper that
    emits one empty ``print()`` before the first nested call (terminating the
    label line) and writes a four-space prefix before any call that passes an
    explicit ``end``. On exit the previous ``print`` is restored and a status
    line with the elapsed time is written: ``complete, in <t> seconds`` when
    the block finished normally, ``failed, in <t> seconds`` when it raised.
    The exception itself propagates unchanged.

    Labels nest: an inner ``tasklabel`` treats whatever ``builtins.print`` is
    at entry (possibly an outer label's wrapper) as its "previous" print.

    Args:
        task_name: Text shown in front of the ``...`` marker.

    Yields:
        None.
    """
    import builtins

    old_print = builtins.print
    did_print = False

    def child_print(*args, **kwargs):
        nonlocal did_print
        if not did_print:
            # First output from inside the block: end the label line.
            old_print()
            did_print = True
        if kwargs.get('end') is not None:
            old_print("    ", end='')
        old_print(*args, **kwargs)

    outcome = "complete"
    start = time.monotonic()
    old_print(f"{task_name}... ", end='', flush=True)
    builtins.print = child_print
    try:
        yield
    except BaseException:
        # Do not claim success when the wrapped block raised.
        outcome = "failed"
        raise
    finally:
        # Restore first, so print is un-patched even if the status line fails.
        builtins.print = old_print
        elapsed = time.monotonic() - start
        old_print(f"{outcome}, in {elapsed:.03g} seconds")


def main(psurge_file, hgrid_file, nc_out):
    """Regrid PSURGE surge data onto SCHISM open-boundary nodes.

    The PSURGE grid is loaded with ESMF as a GRIDSPEC file, the open-boundary
    node coordinates are read from ``hgrid_file`` into an ESMF LocStream, and
    a nearest-source-to-destination regridder maps each PSURGE time step onto
    those nodes. Results are written to ``nc_out`` as a ``time_series``
    variable with dimensions ``(time, nOpenBndNodes, nLevels, nComponents)``,
    together with ``time`` and ``time_step`` variables copied from the PSURGE
    file.

    The regridder is created with ``rh_filename="psurge-to-hgrid.rh"``, which
    presumably writes a route-handle file to the working directory -- confirm
    against the ESMPy documentation.

    Args:
        psurge_file: Path to the PSURGE netCDF file. It must contain a
            ``time`` variable with a ``time_step`` attribute and a variable
            whose name starts with ``SURGE``.
        hgrid_file: Path to a netCDF file providing ``nodeCoords`` and
            ``openBndNodes`` variables.
        nc_out: Path of the netCDF4 file to create.

    Raises:
        KeyError: If no variable starting with ``SURGE`` exists in
            ``psurge_file``. This is checked before ``nc_out`` is created.
    """
    with tasklabel("Generating PSURGE grid"):
        psurge_grid = ESMF.Grid(filename=psurge_file, filetype=ESMF.FileFormat.GRIDSPEC)
        # Source field; its data is overwritten for every time step below.
        psurge_data = ESMF.Field(psurge_grid)

    with tasklabel("Generating HGRID open-boundary locstream"):
        with netCDF4.Dataset(hgrid_file) as hgrid_ds:
            coords = hgrid_ds['nodeCoords'][:]
            valid_indices = hgrid_ds['openBndNodes'][:]

        # Keep only the open-boundary nodes; columns 0 and 1 of each row are
        # used as longitude and latitude respectively.
        # NOTE(review): openBndNodes values are used directly as row indices
        # into nodeCoords -- confirm the file stores them 0-based.
        ob_coords = [coords[i].tolist() for i in valid_indices]
        ob_lons = [c[0] for c in ob_coords]
        ob_lats = [c[1] for c in ob_coords]

        # NOTE(review): longitudes are passed through as stored; confirm they
        # use the same convention (0..360 vs -180..180) as the PSURGE grid.
        ob_locstream = ESMF.LocStream(len(ob_lons), coord_sys=ESMF.CoordSys.SPH_DEG)
        ob_locstream["ESMF:Lon"] = ob_lons
        ob_locstream["ESMF:Lat"] = ob_lats

        hgrid_data = ESMF.Field(ob_locstream)

    with tasklabel("Generating Regridder object"):
        regridder = ESMF.Regrid(psurge_data, hgrid_data, rh_filename="psurge-to-hgrid.rh",
                                regrid_method=ESMF.RegridMethod.NEAREST_STOD)

    with tasklabel("Writing regridded PSURGE data to SCHISM elevation time/height"):
        with netCDF4.Dataset(psurge_file) as psurge_ds:
            # Resolve the surge variable once, up front: the lookup does not
            # depend on the time step, and failing here avoids leaving a
            # partially written output file behind.
            surge_name = next((v for v in psurge_ds.variables if v.startswith('SURGE')), None)
            if surge_name is None:
                raise KeyError(f"SURGE variable not found in {psurge_file}")
            surge_var = psurge_ds[surge_name]

            with netCDF4.Dataset(nc_out, "w", format="NETCDF4") as elev_out:
                elev_out.createDimension("time", None)
                elev_out.createDimension("nOpenBndNodes", len(ob_coords))
                elev_out.createDimension("nLevels", 1)
                elev_out.createDimension("nComponents", 1)
                elev_out.createDimension("one", 1)

                time_step_var = elev_out.createVariable("time_step", "f8", ("one",))
                time_var = elev_out.createVariable("time", "f8", ("time",))
                time_series_var = elev_out.createVariable(
                    "time_series", "f8",
                    ("time", "nOpenBndNodes", "nLevels", "nComponents"),
                    fill_value=missing)

                # Copy the time axis and all of its attributes verbatim.
                time_var[:] = psurge_ds['time'][:]
                time_var.setncatts(psurge_ds['time'].__dict__)

                time_step_var[:] = np.array([psurge_ds['time'].time_step])

                for t in range(len(time_var)):
                    with tasklabel(f"Performing Regridding for timestep {t}"):
                        # Masked cells become `missing`; the slice is transposed
                        # before being copied into the ESMF field (presumably to
                        # match the field's axis order -- confirm).
                        psurge_data.data[...] = surge_var[t, :, :].filled(missing).T
                        hgrid_data = regridder(psurge_data, hgrid_data)
                        time_series_var[t, :, 0, 0] = hgrid_data.data[...]


if __name__ == '__main__':
    # Parse arguments before starting the progress label, so that ``--help``
    # output and usage errors are not printed behind a dangling
    # "Regridding PSURGE to SCHISM... " prefix and followed by a status line.
    parser = argparse.ArgumentParser()
    parser.add_argument('psurge_file', type=str, help='Input PSURGE netCDF file (from wgrib2)')
    parser.add_argument('hgrid_file', type=str, help='open_bnds_hgrid.nc file containing SCHISM mesh data')
    parser.add_argument('nc_out', type=str, help='Output elev2d.th.nc file')
    args = parser.parse_args()

    with tasklabel("Regridding PSURGE to SCHISM"):
        # The argparse destinations match main()'s parameter names exactly.
        main(**vars(args))