#!/usr/bin/env python3
#
# Purpose: To convert tdlpack data into NetCDF for the AWC_DATA job
# History:  
#          Jun. 2022   Schnapp    -  newly created for LMP 
# Usage:
#          python lmp_tdlp2nc.py $tdlpack_filename
#
import datetime
import sys
import pandas as pd
import pytdlpack
import xarray as xr
from dataclasses import dataclass

# argv[1] is the path of the TDLPACK file to convert.
# (Earlier revision took a PDY date instead and derived the file labeled
# with PDY_p1; retained below for reference.)
#PDY = sys.argv[1]
#dt = datetime.datetime.strptime(PDY, '%Y%m%d')
#PDY_p1 = (dt + datetime.timedelta(days=1)).strftime('%Y%m%d')
tdlp_file = sys.argv[1]
print(tdlp_file)

@dataclass
class TdlpId:
    """Decode the four MOS-2000 TDLPACK ID words into their digit subfields.

    Each property isolates a fixed decimal digit group of one of the
    32-bit ID words (CCCFFFBDD layout for word 1, etc.).
    """
    word1: int
    word2: int
    word3: int
    word4: int

    @property
    def ccc(self):
        # leading three digits of word 1
        return self.word1 // 1_000_000

    @property
    def fff(self):
        return (self.word1 // 1000) % 1000

    @property
    def b(self):
        return (self.word1 // 100) % 10

    @property
    def dd(self):
        # trailing two digits of word 1
        return self.word1 % 100

    @property
    def rr(self):
        return (self.word3 // 1_000_000) % 100

    @property
    def ttt(self):
        # trailing three digits of word 3 (forecast projection, hours)
        return self.word3 % 1000

    @property
    def w(self):
        return self.word4 // 1_000_000_000

    @property
    def xxxx(self):
        return (self.word4 // 100_000) % 10_000

    @property
    def yy(self):
        return (self.word4 // 1000) % 100

    @property
    def isg(self):
        # combined trailing I/S/G digits of word 4
        return self.word4 % 1000

    @property
    def i(self):
        return (self.word4 // 100) % 10

    @property
    def s(self):
        return (self.word4 // 10) % 10

    @property
    def g(self):
        return self.word4 % 10

    @property
    def thresh(self):
        # threshold rendered as W.XXXXEYY scientific-notation text
        return f'{self.w}.{self.xxxx:04}E{self.yy:02}'


def south_to_negative(lat):
    """Convert a hemisphere-suffixed latitude string to a signed float.

    '45.5N' -> 45.5, '30S' -> -30.0.  Case-insensitive and tolerant of
    surrounding whitespace.  Returns None when no hemisphere letter is
    present (keeps the original best-effort behavior for bad rows).
    """
    lat = lat.strip().upper()
    if 'N' in lat:
        return float(lat.strip('N'))
    if 'S' in lat:
        return -float(lat.strip('S'))
    return None  # unrecognized format — explicit, was implicit before

def west_to_negative(lon):
    """Convert a hemisphere-suffixed longitude string to a signed float.

    '10E' -> 10.0, '120.5W' -> -120.5.  Case-insensitive and tolerant of
    surrounding whitespace.  Returns None when no hemisphere letter is
    present (keeps the original best-effort behavior for bad rows).
    """
    lon = lon.strip().upper()
    if 'E' in lon:
        return float(lon.strip('E'))
    if 'W' in lon:
        return -float(lon.strip('W'))
    return None  # unrecognized format — explicit, was implicit before


def from_mos2ktbl(tbl):
    """Read a MOS-2000 ':'-separated station table into a DataFrame.

    Call letters are stripped of padding and lat/lon are converted to
    signed floats via the hemisphere-suffix helpers.
    """
    columns = ['call', 'link', 'name', 'state', 'elev', 'lat', 'lon', 'comment']
    # quoting=3 (csv.QUOTE_NONE) keeps unclosed quotes from swallowing
    # the ':' separators and newlines during parsing.
    stations = pd.read_csv(
        tbl,
        sep=':',
        usecols=[0, 1, 2, 3, 5, 6, 7, 18],
        names=columns,
        quoting=3,
    )
    stations['call'] = stations['call'].str.strip()
    stations['lat'] = stations['lat'].apply(south_to_negative)
    stations['lon'] = stations['lon'].apply(west_to_negative)
    return stations

# MOS-2000 ID word 1 -> output variable name for the records we keep.
long_name = {
    228161095: 'visibility',
    228081095: 'cloud_ceiling_AGL',
    204225005: 'wind_direction',
    204325005: 'wind_speed',
    204355005: 'wind_gust',
    208381035: 'sky_cover',
}

# Units for each of the variables above, keyed by the same ID word 1.
unit_dict = {
    228161095: 'miles',
    228081095: '100 ft',
    204225005: 'degree',
    204325005: 'knots',
    204355005: 'knots',
    208381035: 'category',
}

names = list(long_name.values())
# One list of DataArrays per variable, filled while scanning the file.
variables = {name: [] for name in names}

# Station metadata (call letters, signed lat/lon, ...) loaded from the
# local MOS-2000 station table in the working directory.
station_tbl = 'lmp_station.tbl'
station_directory = from_mos2ktbl(station_tbl)
# Scan the TDLPACK file.  A station-call-letter record establishes the
# station ordering; every subsequent data record we care about becomes a
# DataArray with station / lead_time / reference_time coordinates.
# NOTE(review): assumes a station record precedes the first data record,
# otherwise `stations` is unbound — confirm against the file layout.
with pytdlpack.open(tdlp_file) as f:
    while True:
        rec = f.read()
        if rec is None:  # end of file
            break
        if isinstance(rec, pytdlpack.TdlpackStationRecord):
            rec.unpack()
            # re-index the station directory into this record's station order
            stations = station_directory.set_index('call').loc[rec.stations].reset_index()
            continue

        # inspect id and map to standard_name; skip records we don't collect
        rec_id = TdlpId(*rec.id)
        if rec_id.word1 not in long_name:
            continue
        name = long_name[rec_id.word1]
        rec.unpack(data=True)

        # make DataArray over the station dimension
        da = xr.DataArray(rec.data, dims='station')

        # assign names
        da.name = name
        da.attrs['long_name'] = name

        # assign units from the table keyed by ID word 1
        da.attrs['units'] = unit_dict[rec_id.word1]

        # assign lat, lon and call letters as station coordinates
        lats, lons = stations.lat, stations.lon
        lat = xr.DataArray(lats, dims='station')
        lat.attrs['standard_name'] = 'latitude'
        lon = xr.DataArray(lons, dims='station')
        lon.attrs['standard_name'] = 'longitude'
        station_call = xr.DataArray(stations.call, dims='station')
        station_call.attrs['standard_name'] = 'platform_id'
        da = da.assign_coords({'longitude':lon, 'latitude':lat, 'station_call':station_call})

        # assign some additional metadata as attrs
        # (rec_id was already built above — the duplicate construction is gone)
        attrs = dict(mos2k_word1=str(rec_id.word1),
                mos2k_word2=str(rec_id.word2),
                plain=str(rec.plain))
        da.attrs.update(attrs)

        # CF forecast_reference_time (cycle) and forecast_period (lead time);
        # rr hours are subtracted from reference_date to recover the cycle —
        # presumably reference_date carries the rr offset; TODO confirm.
        dt = datetime.datetime.strptime(str(rec.reference_date),"%Y%m%d%H")
        dt = dt - datetime.timedelta(hours=int(rec_id.rr))
        ttt = datetime.timedelta(hours=int(rec_id.ttt))
        da = da.expand_dims('lead_time')
        lead_time = xr.DataArray([ttt], dims='lead_time')
        lead_time.attrs['standard_name'] = 'forecast_period'
        da = da.assign_coords({'lead_time':lead_time})
        da = da.expand_dims('reference_time')
        reference_time = xr.DataArray([dt], dims='reference_time')
        reference_time.attrs['standard_name'] = 'forecast_reference_time'
        da = da.assign_coords({'reference_time':reference_time})

        variables[name].append(da)


# Collapse each variable's list of per-record arrays into one DataArray
# (merge aligns along the lead_time / reference_time coordinates).
for name, list_of_arrays in variables.items():
    variables[name] = xr.merge(list_of_arrays)[name]

ds = xr.Dataset(variables)

# Per-variable NetCDF encoding.
# NOTE(review): with zlib=False the shuffle/complevel settings are inert
# and the file is written uncompressed; set zlib=True to actually compress.
encoding = dict(zlib=False, shuffle=True, complevel=4)
for var in ds.variables:
    ds[var].encoding.update(encoding)

ds.to_netcdf('output_file.nc')
print(ds)