"""Regrid PSURGE surge output onto SCHISM open-boundary nodes.

Reads a PSURGE netCDF file (converted from GRIB2 with wgrib2), regrids its
SURGE* field to the open-boundary nodes listed in an ``open_bnds_hgrid.nc``
mesh file using ESMF nearest-source-to-destination regridding, and writes the
result as an ``elev2D.th.nc``-style time series.
"""
import argparse
from ast import parse  # NOTE(review): unused by this script; kept in case something imports it from here
from contextlib import contextmanager
import time

import ESMF
import netCDF4
import numpy as np

# NOTE(review): not referenced in this script; the time_step written to the
# output comes from the PSURGE ``time`` variable's ``time_step`` attribute.
time_step = 3600
# Fill value used both for masked PSURGE cells and for the output time series.
missing = -9999.


@contextmanager
def tasklabel(task_name):
    """Label a block of work on stdout and report how long it took.

    Prints ``"<task_name>... "`` (without a newline) on entry, then
    ``"complete, in <N> seconds"`` when the block exits normally, or
    ``"failed, in <N> seconds"`` when it raises (the exception is re-raised).

    While the block runs, ``builtins.print`` is replaced by a wrapper. The first
    output produced inside the block therefore starts on a new line instead of
    being appended to the label. Calls that pass an ``end=`` keyword (such as a
    nested ``tasklabel``'s own label) are prefixed with a space, so nested
    labels appear indented.
    """
    import builtins

    old_print = builtins.print
    did_print = False

    def child_print(*args, **kwargs):
        nonlocal did_print
        if not did_print:
            # Terminate the "<task_name>... " label line before any child output.
            old_print()
            did_print = True
        if kwargs.get('end', None) is not None:
            old_print(" ", end='')
        old_print(*args, **kwargs)

    start = time.monotonic()
    old_print(f"{task_name}... ", end='', flush=True)
    builtins.print = child_print
    status = "complete"
    try:
        yield
    except BaseException:
        status = "failed"
        raise
    finally:
        # Restore print before reporting so the wrapper can never be left
        # installed, even if reporting itself fails.
        builtins.print = old_print
        elapsed = time.monotonic() - start
        old_print(f"{status}, in {elapsed:.03g} seconds")


def _open_boundary_lonlat(hgrid_file):
    """Return ``(lons, lats)`` lists for the open-boundary nodes of *hgrid_file*.

    Reads ``nodeCoords`` (one row per node; column 0 is used as longitude and
    column 1 as latitude) and ``openBndNodes`` (row indices into nodeCoords).
    """
    with netCDF4.Dataset(hgrid_file) as hgrid_ds:
        node_coords = np.asarray(hgrid_ds['nodeCoords'][:])
        boundary_nodes = np.asarray(hgrid_ds['openBndNodes'][:])
    # NOTE(review): openBndNodes is used as a 0-based row index into
    # nodeCoords; confirm the file is not written with 1-based node ids.
    ob_coords = node_coords[boundary_nodes]
    return ob_coords[:, 0].tolist(), ob_coords[:, 1].tolist()


def _find_surge_variable(psurge_ds, psurge_file):
    """Return the name of the first variable in *psurge_ds* starting with ``SURGE``.

    Raises
    ------
    KeyError
        If no variable name starts with ``SURGE``.
    """
    for name in psurge_ds.variables:
        if name.startswith('SURGE'):
            return name
    raise KeyError(f"SURGE variable not found in {psurge_file}")


def main(psurge_file, hgrid_file, nc_out):
    """Regrid PSURGE surge values to SCHISM open-boundary nodes and write *nc_out*.

    Parameters
    ----------
    psurge_file : str
        PSURGE netCDF file. It is loaded as an ESMF GRIDSPEC grid and must
        contain a ``time`` variable with a ``time_step`` attribute, plus a 3-D
        variable whose name starts with ``SURGE`` (read as ``[t, :, :]``).
    hgrid_file : str
        Mesh file providing ``nodeCoords`` and ``openBndNodes`` variables.
    nc_out : str
        Output path (overwritten). Receives dimensions ``time`` (unlimited),
        ``nOpenBndNodes``, ``nLevels``, ``nComponents`` and ``one``, and
        variables ``time_step``, ``time`` and ``time_series``.

    Raises
    ------
    KeyError
        If *psurge_file* has no variable whose name starts with ``SURGE``.
    """
    with tasklabel("Generating PSURGE grid"):
        psurge_grid = ESMF.Grid(filename=psurge_file,
                                filetype=ESMF.FileFormat.GRIDSPEC)
    psurge_data = ESMF.Field(psurge_grid)

    with tasklabel("Generating HGRID open-boundary locstream"):
        ob_lons, ob_lats = _open_boundary_lonlat(hgrid_file)
        ob_locstream = ESMF.LocStream(len(ob_lons),
                                      coord_sys=ESMF.CoordSys.SPH_DEG)
        ob_locstream["ESMF:Lon"] = ob_lons
        ob_locstream["ESMF:Lat"] = ob_lats
        hgrid_data = ESMF.Field(ob_locstream)

    with tasklabel("Generating Regridder object"):
        regridder = ESMF.Regrid(psurge_data, hgrid_data,
                                rh_filename="psurge-to-hgrid.rh",
                                regrid_method=ESMF.RegridMethod.NEAREST_STOD)

    with tasklabel("Writing regridded PSURGE data to SCHISM elevation time/height"):
        with netCDF4.Dataset(psurge_file) as psurge_ds:
            # Locate the input field once, before nc_out is created (and
            # truncated), so a bad input file leaves no partial output behind.
            surge_var = psurge_ds[_find_surge_variable(psurge_ds, psurge_file)]
            src_time = psurge_ds['time']

            with netCDF4.Dataset(nc_out, "w", format="NETCDF4") as elev_out:
                elev_out.createDimension("time", None)
                elev_out.createDimension("nOpenBndNodes", len(ob_lons))
                elev_out.createDimension("nLevels", 1)
                elev_out.createDimension("nComponents", 1)
                elev_out.createDimension("one", 1)

                time_step_var = elev_out.createVariable("time_step", "f8", ("one",))
                time_var = elev_out.createVariable("time", "f8", ("time",))
                time_series_var = elev_out.createVariable(
                    "time_series", "f8",
                    ("time", "nOpenBndNodes", "nLevels", "nComponents"),
                    fill_value=missing)

                # NOTE(review): times (and their units) are copied verbatim
                # from PSURGE; confirm the consumer expects that time reference.
                time_var[:] = src_time[:]
                # _FillValue can only be set when a variable is created, so
                # netCDF4 rejects it here; copy every other attribute.
                time_var.setncatts({name: src_time.getncattr(name)
                                    for name in src_time.ncattrs()
                                    if name != '_FillValue'})
                time_step_var[:] = np.array([src_time.time_step])

                for t in range(len(time_var)):
                    with tasklabel(f"Performing Regridding for timestep {t}"):
                        # Masked cells become `missing`. Nearest-neighbour
                        # regridding carries that value through, and the
                        # output's fill_value then marks it as missing.
                        # The transpose presumably maps the file's (y, x)
                        # slice to the field's (x, y) layout -- TODO confirm.
                        # Assumes a serial run where the field holds the whole grid.
                        psurge_data.data[...] = np.ma.filled(
                            surge_var[t, :, :], missing).T
                        hgrid_data = regridder(psurge_data, hgrid_data)
                        time_series_var[t, :, 0, 0] = hgrid_data.data[...]


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('psurge_file', type=str, help='Input PSURGE netCDF file (from wgrib2)')
    parser.add_argument('hgrid_file', type=str, help='open_bnds_hgrid.nc file containing SCHISM mesh data')
    parser.add_argument('nc_out', type=str, help='Output elev2d.th.nc file')
    # Parse outside tasklabel so --help and usage errors are not reported as a
    # timed task.
    args = parser.parse_args()
    with tasklabel("Regridding PSURGE to SCHISM"):
        main(**vars(args))