import argparse
from datetime import datetime, timedelta

from netCDF4 import Dataset
from cftime import num2date, date2num
import numpy as np


# Fill value declared on the output "time_series" variable.
missing = -9999.


def time_align_psurge_estofs_elevs(estofs_file, psurge_file, final_elev_out):
    """Overlay PSURGE boundary elevations onto an ESTOFS elevation series.

    The output file takes its time axis (values and attributes) and its
    ``nOpenBndNodes`` size from the ESTOFS file. For every ESTOFS time step
    that has a matching PSURGE time step, the PSURGE values are written, with
    masked PSURGE entries replaced by the ESTOFS values; every other time
    step is copied from ESTOFS unchanged.

    Args:
        estofs_file: Path to the input ESTOFS-derived elev2d.th.nc file.
        psurge_file: Path to the input PSURGE-derived elev2d.th.nc file.
        final_elev_out: Path of the merged elev2d.th.nc file to create.
    """
    # Open both inputs in a ``with`` block so they are closed even when
    # reading or writing fails part-way through.
    with Dataset(psurge_file) as pds, Dataset(estofs_file) as eds:
        ptime_var = pds['time']
        pstart = num2date(ptime_var[0], units=ptime_var.units)

        etime_var = eds['time']
        estart = num2date(etime_var.start_time, units=etime_var.units)
        # round up non-hourly ESTOFS start times
        if estart.minute != 0:
            estart = datetime(estart.year, estart.month, estart.day, estart.hour)
            estart += timedelta(hours=1)

        # Signed offset of the PSURGE start relative to the ESTOFS start.
        # total_seconds() keeps whole days and the sign; the ``.seconds``
        # attribute is only the 0..86399 remainder and would misalign any
        # offset of a day or more.
        psurge_offset = (pstart - estart).total_seconds()

        with Dataset(final_elev_out, "w", format="NETCDF4") as elev_out:
            elev_out.createDimension("time", None)
            elev_out.createDimension("nOpenBndNodes", len(eds.dimensions['nOpenBndNodes']))
            elev_out.createDimension("nLevels", 1)
            elev_out.createDimension("nComponents", 1)
            elev_out.createDimension("one", 1)

            time_step_var = elev_out.createVariable("time_step", "f8", ("one",))
            time_var = elev_out.createVariable("time", "f8", ("time",))
            time_series_var = elev_out.createVariable(
                "time_series", "f8",
                ("time", "nOpenBndNodes", "nLevels", "nComponents"),
                fill_value=missing)

            # The output time axis is a copy of the ESTOFS one.
            time_var[:] = etime_var[:]
            time_var.setncatts(etime_var.__dict__)

            time_step_var[:] = np.array(ptime_var.time_step)
            # Number of time steps separating the two start times.
            psurge_offset_idx = int(psurge_offset / ptime_var.time_step)
            print(f"psurge_offset_idx = {psurge_offset_idx}")

            base_data = eds['time_series']
            psurge_data = pds['time_series']
            psurge_len = len(psurge_data)
            for t in range(len(time_var)):
                base = base_data[t]
                ps_idx = t - psurge_offset_idx
                merged = base
                if 0 <= ps_idx < psurge_len:
                    # Masked PSURGE values fall back to the ESTOFS values.
                    merged = psurge_data[ps_idx, :].filled(base)
                    print(f"overlaying estofs timestep {t} with psurge timestep {ps_idx}")

                time_series_var[t, :, 0, 0] = merged


if __name__ == '__main__':
    # Command-line entry point: three positional file paths, passed straight
    # through to time_align_psurge_estofs_elevs.
    cli = argparse.ArgumentParser()
    for arg_name, arg_help in (
        ('estofs_file', 'Input ESTOFS-derived elev2d.th.nc file'),
        ('psurge_file', 'Input PSURGE-derived elev2d.th.nc file'),
        ('final_elev_out', 'Output elev2d.th.nc file'),
    ):
        cli.add_argument(arg_name, type=str, help=arg_help)
    opts = cli.parse_args()

    time_align_psurge_estofs_elevs(
        estofs_file=opts.estofs_file,
        psurge_file=opts.psurge_file,
        final_elev_out=opts.final_elev_out,
    )