-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Fix in vectorized item assignment #1746
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
9eafe36
2aff4ce
22c1295
dc135cd
b1dd0f0
c74c828
8e6e2e0
aef3d56
f10ecf4
d482f80
c2b5ac3
84e5e6f
2011140
6906eeb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -301,3 +301,21 @@ def __getitem__(self, key): | |
|
||
def __unicode__(self): | ||
return formatting.indexes_repr(self) | ||
|
||
|
||
def assert_coordinate_consistent(obj, coords): | ||
""" Maeke sure the dimension coordinate of obj is | ||
consistent with coords. | ||
|
||
obj: DataArray or Dataset | ||
coords: Dict-like of variables | ||
""" | ||
for k in obj.dims: | ||
# make sure there are no conflict in dimension coordinates | ||
if (k in coords and k in obj.coords): | ||
coord = getattr(coords[k], 'variable', coords[k]) # Variable | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be better to insist that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's really reasonable. Updated. |
||
if not coord.equals(obj[k].variable): | ||
raise IndexError( | ||
'dimension coordinate {!r} conflicts between ' | ||
'indexed and indexing objects:\n{}\nvs.\n{}' | ||
.format(k, obj[k], coords[k])) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,7 +20,7 @@ | |
from .alignment import align, reindex_like_indexers | ||
from .common import AbstractArray, BaseDataObject | ||
from .coordinates import (DataArrayCoordinates, LevelCoordinatesSource, | ||
Indexes) | ||
Indexes, assert_coordinate_consistent) | ||
from .dataset import Dataset, merge_indexes, split_indexes | ||
from .pycompat import iteritems, basestring, OrderedDict, zip, range | ||
from .variable import (as_variable, Variable, as_compatible_data, | ||
|
@@ -484,7 +484,13 @@ def __setitem__(self, key, value): | |
if isinstance(key, basestring): | ||
self.coords[key] = value | ||
else: | ||
# xarray-style array indexing | ||
# Coordinates in key, value and self[key] should be consistent. | ||
obj = self[key] | ||
if isinstance(value, DataArray): | ||
assert_coordinate_consistent(value, obj.coords) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was actually thinking of checking the consistency of coords on each DataArray argument in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think if you use |
||
# DataArray key -> Variable key | ||
key = {k: v.variable if isinstance(v, DataArray) else v | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we want to enforce consistency for coordinates here? My inclination would be that we should support exactly the same keys in setitem as are valid in getitem. Ideally we should also reuse the same code. That means we should raise errors if there are multiple indexers with inconsistent alignment. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Reasonable. I will add a validation. |
||
for k, v in self._item_key_to_dict(key).items()} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need to check that coordinates are consistent on the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is done a few lines above, But I am wondering this unnecessary indexing, though I think this implementation is the simplest. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm. This could indeed be a significant performance hit. That said, I'm OK leaving this for now, with a TODO note to optimize it later. |
||
self.variable[key] = value | ||
|
||
def __delitem__(self, key): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -637,17 +637,22 @@ def __setitem__(self, key, value): | |
""" | ||
dims, index_tuple, new_order = self._broadcast_indexes(key) | ||
|
||
if isinstance(value, Variable): | ||
value = value.set_dims(dims).data | ||
|
||
if new_order: | ||
value = duck_array_ops.asarray(value) | ||
if not isinstance(value, Variable): | ||
value = as_compatible_data(value) | ||
if value.ndim > len(dims): | ||
raise ValueError( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we still need this special case error message now that we call set_dims below? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think |
||
'shape mismatch: value array of shape %s could not be' | ||
'broadcast to indexing result with %s dimensions' | ||
% (value.shape, len(dims))) | ||
if value.ndim == 0: | ||
value = Variable((), value) | ||
else: | ||
value = Variable(dims[-value.ndim:], value) | ||
# broadcast to become assignable | ||
value = value.set_dims(dims).data | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I decided to revert |
||
|
||
if new_order: | ||
value = duck_array_ops.asarray(value) | ||
value = value[(len(dims) - value.ndim) * (np.newaxis,) + | ||
(Ellipsis,)] | ||
value = np.moveaxis(value, new_order, range(len(new_order))) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -526,6 +526,91 @@ def test_setitem(self): | |
expected[t] = 1 | ||
self.assertArrayEqual(orig.values, expected) | ||
|
||
def test_setitem_fancy(self): | ||
# vectorized indexing | ||
da = DataArray(np.ones((3, 2)), dims=['x', 'y']) | ||
ind = Variable(['a'], [0, 1]) | ||
da[dict(x=ind, y=ind)] = 0 | ||
expected = DataArray([[0, 1], [1, 0], [1, 1]], dims=['x', 'y']) | ||
self.assertDataArrayIdentical(expected, da) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For the future (no need to change this time), you can use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I always forget this, maybe because I usually make tests based on the existing ones... |
||
# assign another 0d-variable | ||
da[dict(x=ind, y=ind)] = Variable((), 0) | ||
expected = DataArray([[0, 1], [1, 0], [1, 1]], dims=['x', 'y']) | ||
self.assertDataArrayIdentical(expected, da) | ||
# assign another 1d-variable | ||
da[dict(x=ind, y=ind)] = Variable(['a'], [2, 3]) | ||
expected = DataArray([[2, 1], [1, 3], [1, 1]], dims=['x', 'y']) | ||
self.assertVariableIdentical(expected, da) | ||
|
||
# 2d-vectorized indexing | ||
da = DataArray(np.ones((3, 2)), dims=['x', 'y']) | ||
ind_x = DataArray([[0, 1]], dims=['a', 'b']) | ||
ind_y = DataArray([[1, 0]], dims=['a', 'b']) | ||
da[dict(x=ind_x, y=ind_y)] = 0 | ||
expected = DataArray([[1, 0], [0, 1], [1, 1]], dims=['x', 'y']) | ||
self.assertVariableIdentical(expected, da) | ||
|
||
da = DataArray(np.ones((3, 2)), dims=['x', 'y']) | ||
ind = Variable(['a'], [0, 1]) | ||
da[ind] = 0 | ||
expected = DataArray([[0, 0], [0, 0], [1, 1]], dims=['x', 'y']) | ||
self.assertVariableIdentical(expected, da) | ||
|
||
def test_setitem_dataarray(self): | ||
def get_data(): | ||
return DataArray(np.ones((4, 3, 2)), dims=['x', 'y', 'z'], | ||
coords={'x': np.arange(4), 'y': ['a', 'b', 'c'], | ||
'non-dim': ('x', [1, 3, 4, 2])}) | ||
|
||
da = get_data() | ||
# indexer with inconsistent coordinates. | ||
ind = DataArray(np.arange(1, 4), dims=['x'], | ||
coords={'x': np.random.randn(3)}) | ||
with raises_regex(IndexError, "dimension coordinate 'x'"): | ||
da[dict(x=ind)] = 0 | ||
|
||
# indexer with consistent coordinates. | ||
ind = DataArray(np.arange(1, 4), dims=['x'], | ||
coords={'x': np.arange(1, 4)}) | ||
da[dict(x=ind)] = 0 # should not raise | ||
assert np.allclose(da[dict(x=ind)].values, 0) | ||
self.assertDataArrayIdentical(da['x'], get_data()['x']) | ||
self.assertDataArrayIdentical(da['non-dim'], get_data()['non-dim']) | ||
|
||
da = get_data() | ||
# conflict in the assigning values | ||
value = xr.DataArray(np.zeros((3, 3, 2)), dims=['x', 'y', 'z'], | ||
coords={'x': [0, 1, 2], | ||
'non-dim': ('x', [0, 2, 4])}) | ||
with raises_regex(IndexError, "dimension coordinate 'x'"): | ||
da[dict(x=ind)] = value | ||
|
||
# consistent coordinate in the assigning values | ||
value = xr.DataArray(np.zeros((3, 3, 2)), dims=['x', 'y', 'z'], | ||
coords={'x': [1, 2, 3], | ||
'non-dim': ('x', [0, 2, 4])}) | ||
da[dict(x=ind)] = value | ||
assert np.allclose(da[dict(x=ind)].values, 0) | ||
self.assertDataArrayIdentical(da['x'], get_data()['x']) | ||
self.assertDataArrayIdentical(da['non-dim'], get_data()['non-dim']) | ||
|
||
# Conflict in the non-dimension coordinate | ||
value = xr.DataArray(np.zeros((3, 3, 2)), dims=['x', 'y', 'z'], | ||
coords={'x': [1, 2, 3], | ||
'non-dim': ('x', [0, 2, 4])}) | ||
# conflict in the assigning values | ||
value = xr.DataArray(np.zeros((3, 3, 2)), dims=['x', 'y', 'z'], | ||
coords={'x': [0, 1, 2], | ||
'non-dim': ('x', [0, 2, 4])}) | ||
with raises_regex(IndexError, "dimension coordinate 'x'"): | ||
da[dict(x=ind)] = value | ||
|
||
# consistent coordinate in the assigning values | ||
value = xr.DataArray(np.zeros((3, 3, 2)), dims=['x', 'y', 'z'], | ||
coords={'x': [1, 2, 3], | ||
'non-dim': ('x', [0, 2, 4])}) | ||
da[dict(x=ind)] = value # should not raise | ||
|
||
def test_contains(self): | ||
data_array = DataArray(1, coords={'x': 2}) | ||
with pytest.warns(FutureWarning): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: you can drop the extra parentheses here inside
if