@@ -9132,34 +9132,39 @@ def polyfit(
9132
9132
variables [sing .name ] = sing
9133
9133
9134
9134
# If we have a coordinate get its underlying dimension.
9135
- true_dim = self .coords [dim ].dims [ 0 ]
9135
+ ( true_dim ,) = self .coords [dim ].dims
9136
9136
9137
- for name , da in self .data_vars .items ():
9138
- if true_dim not in da .dims :
9137
+ other_coords = {
9138
+ dim : self ._variables [dim ]
9139
+ for dim in set (self .dims ) - {true_dim }
9140
+ if dim in self ._variables
9141
+ }
9142
+ present_dims = set ()
9143
+ for name , var in self ._variables .items ():
9144
+ if name in self ._coord_names or name in self .dims :
9145
+ continue
9146
+ if true_dim not in var .dims :
9139
9147
continue
9140
9148
9141
- if is_duck_dask_array (da . data ) and (
9149
+ if is_duck_dask_array (var . _data ) and (
9142
9150
rank != order or full or skipna is None
9143
9151
):
9144
9152
# Current algorithm with dask and skipna=False neither supports
9145
9153
# deficient ranks nor does it output the "full" info (issue dask/dask#6516)
9146
9154
skipna_da = True
9147
9155
elif skipna is None :
9148
- skipna_da = bool (np .any (da .isnull ()))
9149
-
9150
- dims_to_stack = [dimname for dimname in da .dims if dimname != true_dim ]
9151
- stacked_coords : dict [Hashable , DataArray ] = {}
9152
- if dims_to_stack :
9153
- stacked_dim = utils .get_temp_dimname (dims_to_stack , "stacked" )
9154
- rhs = da .transpose (true_dim , * dims_to_stack ).stack (
9155
- {stacked_dim : dims_to_stack }
9156
- )
9157
- stacked_coords = {stacked_dim : rhs [stacked_dim ]}
9158
- scale_da = scale [:, np .newaxis ]
9156
+ skipna_da = bool (np .any (var .isnull ()))
9157
+
9158
+ if var .ndim > 1 :
9159
+ rhs = var .transpose (true_dim , ...)
9160
+ other_dims = rhs .dims [1 :]
9161
+ scale_da = scale .reshape (- 1 , * ((1 ,) * len (other_dims )))
9159
9162
else :
9160
- rhs = da
9163
+ rhs = var
9161
9164
scale_da = scale
9165
+ other_dims = ()
9162
9166
9167
+ present_dims .update (* other_dims )
9163
9168
if w is not None :
9164
9169
rhs = rhs * w [:, np .newaxis ]
9165
9170
@@ -9179,26 +9184,15 @@ def polyfit(
9179
9184
# Thus a ReprObject => polyfit was called on a DataArray
9180
9185
name = ""
9181
9186
9182
- coeffs = DataArray (
9183
- coeffs / scale_da ,
9184
- dims = [degree_dim ] + list (stacked_coords .keys ()),
9185
- coords = {degree_dim : np .arange (order )[::- 1 ], ** stacked_coords },
9186
- name = name + "polyfit_coefficients" ,
9187
- )
9188
- if dims_to_stack :
9189
- coeffs = coeffs .unstack (stacked_dim )
9190
- variables [coeffs .name ] = coeffs
9187
+ coeffs = Variable (data = coeffs / scale_da , dims = (degree_dim ,) + other_dims )
9188
+ variables [name + "polyfit_coefficients" ] = coeffs
9191
9189
9192
9190
if full or (cov is True ):
9193
- residuals = DataArray (
9194
- residuals if dims_to_stack else residuals .squeeze (),
9195
- dims = list (stacked_coords .keys ()),
9196
- coords = stacked_coords ,
9197
- name = name + "polyfit_residuals" ,
9191
+ residuals = Variable (
9192
+ data = residuals if var .ndim > 1 else residuals .squeeze (),
9193
+ dims = other_dims ,
9198
9194
)
9199
- if dims_to_stack :
9200
- residuals = residuals .unstack (stacked_dim )
9201
- variables [residuals .name ] = residuals
9195
+ variables [name + "polyfit_residuals" ] = residuals
9202
9196
9203
9197
if cov :
9204
9198
Vbase = np .linalg .inv (np .dot (lhs .T , lhs ))
@@ -9214,7 +9208,18 @@ def polyfit(
9214
9208
covariance = DataArray (Vbase , dims = ("cov_i" , "cov_j" )) * fac
9215
9209
variables [name + "polyfit_covariance" ] = covariance
9216
9210
9217
- return type (self )(data_vars = variables , attrs = self .attrs .copy ())
9211
+ return type (self )(
9212
+ data_vars = variables ,
9213
+ coords = {
9214
+ degree_dim : np .arange (order )[::- 1 ],
9215
+ ** {
9216
+ name : coord
9217
+ for name , coord in other_coords .items ()
9218
+ if name in present_dims
9219
+ },
9220
+ },
9221
+ attrs = self .attrs .copy (),
9222
+ )
9218
9223
9219
9224
def pad (
9220
9225
self ,
0 commit comments