Skip to content

Commit

Permalink
reformat
Browse files Browse the repository at this point in the history
  • Loading branch information
kashif committed Oct 27, 2023
1 parent 788c7dc commit aa54d13
Showing 1 changed file with 4 additions and 11 deletions.
15 changes: 4 additions & 11 deletions metrics/crps/crps.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,17 +84,13 @@ def _info(self):
citation=_CITATION,
inputs_description=_KWARGS_DESCRIPTION,
features=datasets.Features(self._get_feature_types()),
reference_urls=[
"https://www.lokad.com/continuous-ranked-probability-score/"
],
reference_urls=["https://www.lokad.com/continuous-ranked-probability-score/"],
)

def _get_feature_types(self):
if self.config_name == "multilist":
return {
"predictions": datasets.Sequence(
datasets.Sequence(datasets.Value("float"))
),
"predictions": datasets.Sequence(datasets.Sequence(datasets.Value("float"))),
"references": datasets.Sequence(datasets.Value("float")),
}
else:
Expand Down Expand Up @@ -125,16 +121,13 @@ def _compute(
weighted_quantile_loss = []
for q in quantiles:
forecast_quantile = np.quantile(predictions, q, axis=0)
weighted_quantile_loss.append(
self.quantile_loss(references, forecast_quantile, q) / abs_target_sum
)
weighted_quantile_loss.append(self.quantile_loss(references, forecast_quantile, q) / abs_target_sum)

if multioutput == "raw_values":
return {"crps": weighted_quantile_loss}
elif multioutput == "uniform_average":
return {"crps": np.average(weighted_quantile_loss)}
else:
raise ValueError(
"The multioutput parameter should be one of the following: "
+ "'raw_values', 'uniform_average'"
"The multioutput parameter should be one of the following: " + "'raw_values', 'uniform_average'"
)

0 comments on commit aa54d13

Please sign in to comment.