diff --git a/yatsm/cli/pixel.py b/yatsm/cli/pixel.py index 56c2d018..99d01a1e 100644 --- a/yatsm/cli/pixel.py +++ b/yatsm/cli/pixel.py @@ -57,6 +57,7 @@ @click.option('--algo_kw', multiple=True, callback=options.callback_dict, help='Algorithm parameter overrides') @click.option('--result_prefix', type=str, default='', show_default=True, + multiple=True, help='Plot coef/rmse from refit that used this prefix') @click.pass_context def pixel(ctx, config, px, py, band, plot, ylim, style, cmap, @@ -67,8 +68,11 @@ def pixel(ctx, config, px, py, band, plot, ylim, style, cmap, band -= 1 # Format result prefix if result_prefix: - result_prefix = (result_prefix if result_prefix[-1] == '_' else - result_prefix + '_') + result_prefix = set((_pref if _pref[-1] == '_' else _pref + '_') + for _pref in result_prefix) + result_prefix.add('') # add in no prefix to show original fit + else: + result_prefix = ('') # Get colormap if cmap not in mpl.cm.cmap_d: @@ -157,17 +161,17 @@ def pixel(ctx, config, px, py, band, plot, ylim, style, cmap, plot_DOY(dt_dates, Y[band, :], cmap) elif _plot == 'VAL': plot_VAL(dt_dates, Y[band, :], cmap) - + if ylim: plt.ylim(ylim) plt.title('Timeseries: px={px} py={py}'.format(px=px, py=py)) plt.ylabel('Band {b}'.format(b=band + 1)) - - for _prefix in set((result_prefix, '')): + + for _prefix in set(result_prefix): plot_results(band, cfg, yatsm, design_info, result_prefix=_prefix, plot_type=_plot) - + plt.tight_layout() plt.show() @@ -266,7 +270,8 @@ def plot_results(band, cfg, model, design_info, coef_k = result_prefix + 'coef' rmse_k = result_prefix + 'rmse' if coef_k not in result_k or rmse_k not in result_k: - raise KeyError('Cannot find result prefix in results') + raise KeyError('Cannot find result prefix "{}" in results' + .format(result_prefix)) if result_prefix: click.echo('Using "{}" re-fitted results'.format(result_prefix)) @@ -283,7 +288,7 @@ def plot_results(band, cfg, model, design_info, i_coef.append(v) i_coef = np.asarray(i_coef) - _prefix = result_prefix or cfg['YATSM']['prediction'] + _prefix = result_prefix or cfg['YATSM']['prediction'] for i, r in enumerate(model.record): label = 'Model {i} ({prefix})'.format(i=i, prefix=_prefix) if plot_type == 'TS': @@ -311,7 +316,6 @@ def plot_results(band, cfg, model, design_info, mx_date = np.array([dt.datetime.fromordinal(d).timetuple().tm_yday for d in mx]) - label = 'Model {i} - {yr} ({prefix})'.format(i=i, yr=yr_mid, prefix=_prefix)