Skip to content

Commit

Permalink
clean up little more
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Nov 11, 2024
1 parent ee7b6d3 commit 9295f17
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9120,9 +9120,10 @@ def polyfit(
rank = np.linalg.matrix_rank(lhs)

if full:
rank = DataArray(rank, name=xname + "matrix_rank")
variables[rank.name] = rank
rank = Variable(dims=(), data=rank)
variables[xname + "matrix_rank"] = rank
_sing = np.linalg.svd(lhs, compute_uv=False)
# Using a DataArray here because `degree_dim` coordinate values need not
sing = DataArray(
_sing,
dims=(degree_dim,),
Expand Down Expand Up @@ -9184,15 +9185,15 @@ def polyfit(
# Thus a ReprObject => polyfit was called on a DataArray
name = ""

coeffs = Variable(data=coeffs / scale_da, dims=(degree_dim,) + other_dims)
variables[name + "polyfit_coefficients"] = coeffs
variables[name + "polyfit_coefficients"] = Variable(
data=coeffs / scale_da, dims=(degree_dim,) + other_dims
)

if full or (cov is True):
residuals = Variable(
variables[name + "polyfit_residuals"] = Variable(
data=residuals if var.ndim > 1 else residuals.squeeze(),
dims=other_dims,
)
variables[name + "polyfit_residuals"] = residuals

if cov:
Vbase = np.linalg.inv(np.dot(lhs.T, lhs))
Expand All @@ -9204,9 +9205,10 @@ def polyfit(
raise ValueError(
"The number of data points must exceed order to scale the covariance matrix."
)
fac = residuals / (x.shape[0] - order)
covariance = DataArray(Vbase, dims=("cov_i", "cov_j")) * fac
variables[name + "polyfit_covariance"] = covariance
fac = variables[name + "polyfit_residuals"] / (x.shape[0] - order)
variables[name + "polyfit_covariance"] = (
Variable(data=Vbase, dims=("cov_i", "cov_j")) * fac
)

return type(self)(
data_vars=variables,
Expand Down

0 comments on commit 9295f17

Please sign in to comment.