#!/usr/bin/env python3
'''
Name: subseasonal_plots_performance_diagram.py
Contact(s): Shannon Shields
Abstract: This script is run by subseasonal_plots.py in ush/subseasonal.
          This script generates a performance diagram plot.
'''

import sys
import os
import logging
import datetime
import glob
import subprocess
import pandas as pd
pd.plotting.deregister_matplotlib_converters()
import numpy as np
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import matplotlib.dates as md
import subseasonal_util as sub_util
from subseasonal_plots_specs import PlotSpecs


class PerformanceDiagram:
    """
    Create a performance diagram graphic
    """

    def __init__(self, logger, input_dir, output_dir, model_info_dict,
                 date_info_dict, plot_info_dict, met_info_dict, logo_dir):
        """! Initialize PerformanceDiagram class

             Args:
                 logger          - logger object
                 input_dir       - path to input directory
                                   (string)
                 output_dir      - path to output directory
                                   (string)
                 model_info_dict - model information dictionary
                                   (strings)
                 date_info_dict  - date information dictionary
                                   (strings)
                 plot_info_dict  - plot information dictionary
                                   (strings)
                 met_info_dict   - MET information dictionary
                                   (strings)
                 logo_dir        - directory with logo images
                                   (string)

             Returns:
        """
        self.logger = logger
        self.input_dir = input_dir
        self.output_dir = output_dir
        self.model_info_dict = model_info_dict
        self.date_info_dict = date_info_dict
        self.plot_info_dict = plot_info_dict
        self.met_info_dict = met_info_dict
        self.logo_dir = logo_dir

    def _load_logo(self, plot_specs, side, logo_file):
        """! Read a logo image and look up where to draw it

             Args:
                 plot_specs - PlotSpecs object used for the figure
                 side       - side of the figure the logo goes on,
                              passed to get_logo_location
                              ('left' or 'right')
                 logo_file  - file name of the logo inside
                              self.logo_dir (string)

             Returns:
                 None if the logo file does not exist, otherwise
                 the tuple (image array, x pixel location,
                 y pixel location, alpha) for that side
        """
        logo_path = os.path.join(self.logo_dir, logo_file)
        if not os.path.exists(logo_path):
            return None
        logo_img_array = matplotlib.image.imread(logo_path)
        logo_xpixel_loc, logo_ypixel_loc, logo_alpha = (
            plot_specs.get_logo_location(
                side, plot_specs.fig_size[0], plot_specs.fig_size[1],
                plt.rcParams['figure.dpi']
            )
        )
        return (logo_img_array, logo_xpixel_loc, logo_ypixel_loc,
                logo_alpha)

    def make_performance_diagram(self):
        """! Create the performance diagram graphic

             For every requested forecast threshold the averaged
             SRATIO, POD and CSI are computed for each model, then
             each model is drawn as a line in (SRATIO, POD) space with
             one marker per threshold, on top of filled CSI contours
             and dashed bias lines. The image is written under
             <output_dir>/images.

             Exits the process with status 0 when the requested stat
             is not PERF_DIA, and with status 1 on fatal
             configuration problems (unknown date type, no
             thresholds, differing units, more thresholds than preset
             markers).

             Args:

             Returns:
        """
        self.logger.info("Creating performance diagrams...")
        self.logger.debug(f"Input directory: {self.input_dir}")
        self.logger.debug(f"Output directory: {self.output_dir}")
        self.logger.debug(f"Model information dictionary: "
                          +f"{self.model_info_dict}")
        self.logger.debug(f"Date information dictionary: "
                          +f"{self.date_info_dict}")
        self.logger.debug(f"Plot information dictionary: "
                          +f"{self.plot_info_dict}")
        # Check stat
        if self.plot_info_dict['stat'] != 'PERF_DIA':
            self.logger.warning("Cannot make performance diagram for stat "
                                +f"{self.plot_info_dict['stat']}")
            sys.exit(0)
        fcst_var_threshs = self.plot_info_dict['fcst_var_threshs']
        obs_var_threshs = self.plot_info_dict['obs_var_threshs']
        # The averages dataframe and the model list are only built inside
        # the threshold loop, so nothing can be plotted without thresholds
        if len(fcst_var_threshs) == 0:
            self.logger.error("FATAL ERROR, NO FORECAST THRESHOLDS REQUESTED")
            sys.exit(1)
        # Make job image directory
        output_image_dir = os.path.join(self.output_dir, 'images')
        # exist_ok avoids a failure when another job creates the
        # directory between an existence check and the creation
        os.makedirs(output_image_dir, exist_ok=True)
        self.logger.info(f"Plots will be in: {output_image_dir}")
        # Set stats to calculate for diagram
        perf_dia_stat_list = ['SRATIO', 'POD', 'CSI']
        # Get dates to plot
        self.logger.info("Creating valid and init date arrays")
        valid_dates, init_dates = sub_util.get_plot_dates(
            self.logger,
            self.date_info_dict['date_type'],
            self.date_info_dict['start_date'],
            self.date_info_dict['end_date'],
            self.date_info_dict['valid_hr_start'],
            self.date_info_dict['valid_hr_end'],
            self.date_info_dict['valid_hr_inc'],
            self.date_info_dict['init_hr_start'],
            self.date_info_dict['init_hr_end'],
            self.date_info_dict['init_hr_inc'],
            self.date_info_dict['forecast_hour']
        )
        format_valid_dates = [valid_date.strftime('%Y%m%d_%H%M%S')
                              for valid_date in valid_dates]
        format_init_dates = [init_date.strftime('%Y%m%d_%H%M%S')
                             for init_date in init_dates]
        date_type = self.date_info_dict['date_type']
        if date_type == 'VALID':
            self.logger.debug("Based on date information, plot will display "
                              +"valid dates "+', '.join(format_valid_dates)+" "
                              +"for forecast hour "
                              +f"{self.date_info_dict['forecast_hour']} "
                              +"with initialization dates "
                              +', '.join(format_init_dates))
            plot_dates = valid_dates
        elif date_type == 'INIT':
            self.logger.debug("Based on date information, plot will display "
                              +"initialization dates "
                              +', '.join(format_init_dates)+" "
                              +"for forecast hour "
                              +f"{self.date_info_dict['forecast_hour']} "
                              +"with valid dates "
                              +', '.join(format_valid_dates))
            plot_dates = init_dates
        else:
            # Without this plot_dates would be unbound further down
            self.logger.error(f"FATAL ERROR, DATE TYPE {date_type} "
                              +"NOT RECOGNIZED, MUST BE VALID OR INIT")
            sys.exit(1)
        # Read in data
        self.logger.info(f"Reading in model stat files from {self.input_dir}")
        # Create dataframe for all thresholds
        self.logger.info("Building dataframe for all thresholds")
        # Line types whose statistic is averaged directly; every other
        # line type is handed to calculate_average as 'aggregation'
        mean_line_types = ('CNT', 'GRAD', 'CTS', 'NBRCTS', 'NBRCNT', 'VCNT')
        fcst_units = []
        perf_dia_stat_avg_df = None
        for thresh_idx, fcst_var_thresh in enumerate(fcst_var_threshs):
            self.logger.debug("Building data for forecast threshold "
                              +f"{fcst_var_thresh}")
            # Observation thresholds are paired with forecast
            # thresholds by position
            obs_var_thresh = obs_var_threshs[thresh_idx]
            all_model_df = sub_util.build_df(
                self.logger, self.input_dir, self.output_dir,
                self.model_info_dict, self.met_info_dict,
                self.plot_info_dict['fcst_var_name'],
                self.plot_info_dict['fcst_var_level'],
                fcst_var_thresh,
                self.plot_info_dict['obs_var_name'],
                self.plot_info_dict['obs_var_level'],
                obs_var_thresh,
                self.plot_info_dict['line_type'],
                self.plot_info_dict['grid'],
                self.plot_info_dict['vx_mask'],
                self.plot_info_dict['interp_method'],
                self.plot_info_dict['interp_points'],
                self.date_info_dict['date_type'],
                plot_dates, format_valid_dates,
                self.date_info_dict['forecast_hour']
            )
            fcst_units.extend(
                all_model_df['FCST_UNITS'].values.astype('str').tolist()
            )
            model_idx_list = (
                all_model_df.index.get_level_values(0).unique().tolist()
            )
            if perf_dia_stat_avg_df is None:
                # Rows: (model, stat); columns: one per forecast threshold
                perf_dia_stat_avg_df = pd.DataFrame(
                    np.nan,
                    index=pd.MultiIndex.from_product(
                        [model_idx_list, perf_dia_stat_list],
                        names=['model', 'stat']
                    ),
                    columns=fcst_var_threshs
                )
            # Calculate statistics mean
            for stat in perf_dia_stat_list:
                self.logger.info(f"Calculating statistic {stat} from line type "
                                 +f"{self.plot_info_dict['line_type']}")
                stat_df, stat_array = sub_util.calculate_stat(
                    self.logger, all_model_df,
                    self.plot_info_dict['line_type'], stat
                )
                model_idx_list = (
                    stat_df.index.get_level_values(0).unique().tolist()
                )
                if self.plot_info_dict['event_equalization'] == 'YES':
                    self.logger.debug("Doing event equalization")
                    # Mask every column that has an invalid value in any
                    # row, then turn the masked entries back into NaN
                    masked_stat_array = np.ma.masked_invalid(stat_array)
                    stat_array = np.ma.mask_cols(masked_stat_array)
                    stat_array = stat_array.filled(fill_value=np.nan)
                    for model_idx_num, model_idx in enumerate(model_idx_list):
                        stat_df.loc[model_idx] = stat_array[model_idx_num,:]
                        all_model_df.loc[model_idx] = (
                            all_model_df.loc[model_idx].where(
                                stat_df.loc[model_idx].notna()
                            ).values)
                for model_idx in model_idx_list:
                    if self.plot_info_dict['line_type'] in mean_line_types:
                        avg_method = 'mean'
                        calc_avg_df = stat_df.loc[model_idx]
                    else:
                        avg_method = 'aggregation'
                        calc_avg_df = all_model_df.loc[model_idx]
                    model_idx_fcst_var_thresh_avg = sub_util.calculate_average(
                        self.logger, avg_method,
                        self.plot_info_dict['line_type'], stat, calc_avg_df
                    )
                    if not np.isnan(model_idx_fcst_var_thresh_avg):
                        perf_dia_stat_avg_df.loc[
                            (model_idx, stat), fcst_var_thresh
                        ] = model_idx_fcst_var_thresh_avg
        # Set up plot
        self.logger.info("Doing plot set up")
        plot_specs_pd = PlotSpecs(self.logger, 'performance_diagram')
        plot_specs_pd.set_up_plot()
        csi_colors = ['#ffffff', '#f5f5f5', '#ececec', '#dfdfdf', '#cbcbcb',
                      '#b2b2b2','#8e8e8e', '#6f6f6f', '#545454', '#3f3f3f']
        cmap_csi = matplotlib.colors.ListedColormap(csi_colors)
        # Background fields on a (success ratio, POD) grid; the grid
        # starts at 0.001 so the divisions below never divide by zero
        pd_ticks = np.arange(0.001, 1.001, 0.001)
        pd_sr, pd_pod = np.meshgrid(pd_ticks, pd_ticks)
        pd_bias = pd_pod / pd_sr
        pd_csi = 1.0 / (1.0 / pd_sr + 1.0 / pd_pod - 1.0)
        pd_bias_clevs = [0.1, 0.2, 0.4, 0.6, 0.8, 1.,
                         1.2, 1.5, 2., 3., 5., 10.]
        # NOTE(review): the result is not used below; the call is kept in
        # case get_stat_plot_name validates the stat name - confirm
        stat_plot_name = plot_specs_pd.get_stat_plot_name(
            self.plot_info_dict['stat']
        )
        POD_plot_name = plot_specs_pd.get_stat_plot_name('POD')
        SRATIO_plot_name = plot_specs_pd.get_stat_plot_name('SRATIO')
        CSI_plot_name = plot_specs_pd.get_stat_plot_name('CSI')
        fcst_units = np.unique(fcst_units)
        fcst_units = np.delete(fcst_units, np.where(fcst_units == 'nan'))
        if len(fcst_units) > 1:
            self.logger.error("FATAL ERROR, DIFFERING UNITS")
            sys.exit(1)
        elif len(fcst_units) == 0:
            self.logger.warning("Empty dataframe")
            fcst_units = ['']
        plot_title = plot_specs_pd.get_plot_title(
            self.plot_info_dict, self.date_info_dict,
            fcst_units[0]
        )
        # Each logo is optional and carries its own location and alpha
        left_logo = self._load_logo(plot_specs_pd, 'left', 'noaa.png')
        right_logo = self._load_logo(plot_specs_pd, 'right', 'nws.png')
        image_name = plot_specs_pd.get_savefig_name(
            output_image_dir, self.plot_info_dict, self.date_info_dict
        )
        self.logger.info("Creating performance diagram")
        fig, ax = plt.subplots(1,1, figsize=(plot_specs_pd.fig_size[0],
                                             plot_specs_pd.fig_size[1]))
        fig.suptitle(plot_title)
        ax.grid(False)
        ax.set_xlabel(SRATIO_plot_name)
        ax.set_xlim([0,1])
        ax.set_xticks(np.arange(0,1.1,0.1))
        ax.set_ylabel(POD_plot_name)
        ax.set_ylim([0,1])
        ax.set_yticks(np.arange(0,1.1,0.1))
        if left_logo is not None:
            (left_logo_img_array, left_logo_xpixel_loc,
             left_logo_ypixel_loc, left_logo_alpha) = left_logo
            # Uses the left logo's own alpha; using the right logo's
            # value failed whenever only the left logo file existed
            left_logo_img = fig.figimage(
                left_logo_img_array, left_logo_xpixel_loc,
                left_logo_ypixel_loc, zorder=1, alpha=left_logo_alpha
            )
            left_logo_img.set_visible(True)
        if right_logo is not None:
            (right_logo_img_array, right_logo_xpixel_loc,
             right_logo_ypixel_loc, right_logo_alpha) = right_logo
            fig.figimage(
                right_logo_img_array, right_logo_xpixel_loc,
                right_logo_ypixel_loc, zorder=1, alpha=right_logo_alpha
            )
        CBIAS = plt.contour(pd_sr, pd_pod, pd_bias, pd_bias_clevs,
                            colors='gray', linestyles='dashed')
        # Label each bias line where it crosses the circle of this radius
        # around the origin, i.e. the point on the circle with y = bias * x
        radius = 0.75
        CBIAS_label_loc = []
        for bias_val in pd_bias_clevs:
            x = np.sqrt(np.power(radius, 2)/(np.power(bias_val, 2)+1))
            y = np.sqrt(np.power(radius, 2) - np.power(x, 2))
            CBIAS_label_loc.append((x,y))
        plt.clabel(CBIAS, fmt='%1.1f', manual=CBIAS_label_loc)
        CFCSI = plt.contourf(pd_sr, pd_pod, pd_csi,
                             np.arange(0., 1.1, 0.1), cmap=cmap_csi,
                             extend='neither')
        # Colorbar axes sit just right of the main axes, same height
        ax_position = ax.get_position()
        cbar_left = ax_position.x1 + 0.05
        cbar_bottom = ax_position.y0
        cbar_width = 0.01
        cbar_height = ax_position.y1 - ax_position.y0
        cbar_ax = fig.add_axes(
            [cbar_left, cbar_bottom, cbar_width, cbar_height]
        )
        cbar = plt.colorbar(CFCSI, orientation='vertical', cax=cbar_ax,
                            ticks=CFCSI.levels)
        cbar.set_label(CSI_plot_name)

        def legend_handle(marker, color, linestyle, linewidth, markersize,
                          markeredgecolor):
            # Empty-data line used only as a legend proxy; drawn through
            # pyplot exactly as before so it lands on the current axes
            return plt.plot(
                [], [], marker=marker, mec=markeredgecolor, mew=2.,
                c=color, ls=linestyle, lw=linewidth, ms=markersize
            )[0]

        thresh_marker_plot_settings_dict = (
            plot_specs_pd.get_marker_plot_settings()
        )
        if len(fcst_var_threshs) > len(thresh_marker_plot_settings_dict):
            self.logger.error("FATAL ERROR, REQUESTED NUMBER OF THRESHOLDS ("
                              +f"{len(fcst_var_threshs)} "
                              +", "
                              +','.join(fcst_var_threshs)
                              +") EXCEEDS PRESET MARKER SETTING, REDUCE NUMBER "
                              +"OF THRESHOLDS TO <= "
                              +f"{len(thresh_marker_plot_settings_dict)}")
            sys.exit(1)
        thresh_legend_handles = []
        thresh_mark_dict = {}
        for fcst_var_thresh_num, fcst_var_thresh in enumerate(
                fcst_var_threshs, start=1):
            # Prefer settings keyed by the threshold itself, otherwise
            # fall back to the generic 'marker<N>' entry for its position
            if fcst_var_thresh in thresh_marker_plot_settings_dict:
                fcst_var_thresh_marker_dict = (
                    thresh_marker_plot_settings_dict[fcst_var_thresh]
                )
            else:
                fcst_var_thresh_marker_dict = (
                    thresh_marker_plot_settings_dict
                    ['marker'+str(fcst_var_thresh_num)]
                )
            thresh_legend_handles.append(
                legend_handle(fcst_var_thresh_marker_dict['marker'], 'white',
                              'solid', 0,
                              fcst_var_thresh_marker_dict['markersize'],
                              'black')
            )
            thresh_mark_dict[fcst_var_thresh] = fcst_var_thresh_marker_dict
        thresh_legend_labels = [
            f'{t} {fcst_units[0]}' for t in fcst_var_threshs
        ]
        thresh_legend = ax.legend(
            thresh_legend_handles, thresh_legend_labels,
            bbox_to_anchor=(0.5, -0.075), loc = 'upper center',
            ncol = plot_specs_pd.legend_ncol,
            fontsize = plot_specs_pd.legend_font_size
        )
        # Draw now so the threshold legend has a real bounding box, and
        # re-add it as an artist because the second ax.legend() call
        # below would otherwise replace it
        plt.draw()
        ax.add_artist(thresh_legend)
        model_legend_handles = []
        model_legend_labels = []
        model_plot_settings_dict = plot_specs_pd.get_model_plot_settings()
        n_models = len(self.model_info_dict)
        for model_idx_num, model_idx in enumerate(model_idx_list):
            # Index entries have the form '<num>/<name>/<plot name>'
            model_idx_parts = model_idx.split('/')
            model_num = model_idx_parts[0]
            model_num_name = model_idx_parts[1]
            model_num_plot_name = model_idx_parts[2]
            model_num_data = perf_dia_stat_avg_df.loc[model_idx]
            if model_num_name in model_plot_settings_dict:
                model_num_plot_settings_dict = (
                    model_plot_settings_dict[model_num_name]
                )
            else:
                model_num_plot_settings_dict = (
                    model_plot_settings_dict[model_num]
                )
            model_num_SRATIO = model_num_data.loc['SRATIO']
            masked_model_num_SRATIO = np.ma.masked_invalid(model_num_SRATIO)
            model_num_npts_SRATIO = (
                len(masked_model_num_SRATIO)
                - np.ma.count_masked(masked_model_num_SRATIO)
            )
            model_num_POD = model_num_data.loc['POD']
            masked_model_num_POD = np.ma.masked_invalid(model_num_POD)
            model_num_npts_POD = (
                len(masked_model_num_POD)
                - np.ma.count_masked(masked_model_num_POD)
            )
            # Only plot models with at least one valid value of each stat
            if model_num_npts_SRATIO != 0 and model_num_npts_POD != 0:
                self.logger.debug(f"Plotting {model_num} - {model_num_name} "
                                  +f"- {model_num_plot_name}")
                ax.plot(
                    masked_model_num_SRATIO, masked_model_num_POD,
                    color = model_num_plot_settings_dict['color'],
                    linestyle = model_num_plot_settings_dict['linestyle'],
                    linewidth = 2*model_num_plot_settings_dict['linewidth'],
                    marker = 'None', markersize = 0,
                    # Earlier models are drawn on top of later ones
                    zorder = (n_models - model_idx_num + 4)
                )
                model_legend_labels.append(model_num_plot_name)
                model_legend_handles.append(
                    legend_handle('', model_num_plot_settings_dict['color'],
                                  model_num_plot_settings_dict['linestyle'],
                                  8, 0, 'white')
                )
                for fcst_var_thresh in fcst_var_threshs:
                    thresh_marker_dict = thresh_mark_dict[fcst_var_thresh]
                    ax.scatter(
                        model_num_SRATIO[fcst_var_thresh],
                        model_num_POD[fcst_var_thresh],
                        c = model_num_plot_settings_dict['color'],
                        linewidth = 2, edgecolors='white',
                        marker = thresh_marker_dict['marker'],
                        s = thresh_marker_dict['markersize']**2,
                        zorder=40
                    )
        # Place the model legend below the threshold legend: convert the
        # threshold legend frame from display to data coordinates and
        # anchor at 1.1 times its lower edge
        inv = ax.transData.inverted()
        legend_box = thresh_legend.get_frame().get_bbox()
        legend_box_inv = inv.transform([(legend_box.x0,legend_box.y0),
                                        (legend_box.x1,legend_box.y1)])
        ax.legend(
            model_legend_handles, model_legend_labels,
            bbox_to_anchor=(0.5, legend_box_inv[0][1]*1.1),
            loc = 'upper center', ncol = plot_specs_pd.legend_ncol,
            fontsize = plot_specs_pd.legend_font_size
        )
        self.logger.info("Saving image as "+image_name)
        plt.savefig(image_name)
        plt.clf()
        plt.close('all')