Treat all-NA extension blocks as uniform join units.

Summary of the hunks below:
  * pandas/core/internals.py: is_uniform_join_units no longer rejects a join
    unit that is all-NA when its block is an extension block.
  * pandas/tests/extension/base/reshaping.py: adds test_concat_all_na_block,
    which concatenates a valid block with an all-NA block (both as Series and
    as single-column DataFrames) and compares against a direct take().
  * pandas/tests/extension/decimal/test_decimal.py: BaseDecimal's
    assert_series_equal / assert_frame_equal become instance methods; the
    frame comparison now checks dtypes and then compares column by column.

NOTE(review): indentation and blank context lines were reconstructed from the
hunk line counts and Python conventions; confirm against the original commit
before applying.

diff --git a/pandas/core/internals.py b/pandas/core/internals.py
index bb6702b50ad3d..a0e122d390240 100644
--- a/pandas/core/internals.py
+++ b/pandas/core/internals.py
@@ -5391,7 +5391,8 @@ def is_uniform_join_units(join_units):
         # all blocks need to have the same type
         all(type(ju.block) is type(join_units[0].block) for ju in join_units) and  # noqa
         # no blocks that would get missing values (can lead to type upcasts)
-        all(not ju.is_na for ju in join_units) and
+        # unless we're an extension dtype.
+        all(not ju.is_na or ju.block.is_extension for ju in join_units) and
         # no blocks with indexers (as then the dimensions do not fit)
         all(not ju.indexers for ju in join_units) and
         # disregard Panels
diff --git a/pandas/tests/extension/base/reshaping.py b/pandas/tests/extension/base/reshaping.py
index cfb70f2291555..9b9a614889bef 100644
--- a/pandas/tests/extension/base/reshaping.py
+++ b/pandas/tests/extension/base/reshaping.py
@@ -25,6 +25,21 @@ def test_concat(self, data, in_frame):
         assert dtype == data.dtype
         assert isinstance(result._data.blocks[0], ExtensionBlock)
 
+    @pytest.mark.parametrize('in_frame', [True, False])
+    def test_concat_all_na_block(self, data_missing, in_frame):
+        valid_block = pd.Series(data_missing.take([1, 1]), index=[0, 1])
+        na_block = pd.Series(data_missing.take([0, 0]), index=[2, 3])
+        if in_frame:
+            valid_block = pd.DataFrame({"a": valid_block})
+            na_block = pd.DataFrame({"a": na_block})
+        result = pd.concat([valid_block, na_block])
+        if in_frame:
+            expected = pd.DataFrame({"a": data_missing.take([1, 1, 0, 0])})
+            self.assert_frame_equal(result, expected)
+        else:
+            expected = pd.Series(data_missing.take([1, 1, 0, 0]))
+            self.assert_series_equal(result, expected)
+
     def test_align(self, data, na_value):
         a = data[:3]
         b = data[2:5]
diff --git a/pandas/tests/extension/decimal/test_decimal.py b/pandas/tests/extension/decimal/test_decimal.py
index 01ae092bc1521..4c6ef9b4d38c8 100644
--- a/pandas/tests/extension/decimal/test_decimal.py
+++ b/pandas/tests/extension/decimal/test_decimal.py
@@ -36,31 +36,22 @@ def na_value():
 
 
 class BaseDecimal(object):
-    @staticmethod
-    def assert_series_equal(left, right, *args, **kwargs):
-        # tm.assert_series_equal doesn't handle Decimal('NaN').
-        # We will ensure that the NA values match, and then
-        # drop those values before moving on.
+
+    def assert_series_equal(self, left, right, *args, **kwargs):
 
         left_na = left.isna()
         right_na = right.isna()
 
         tm.assert_series_equal(left_na, right_na)
-        tm.assert_series_equal(left[~left_na], right[~right_na],
-                               *args, **kwargs)
-
-    @staticmethod
-    def assert_frame_equal(left, right, *args, **kwargs):
-        # TODO(EA): select_dtypes
-        decimals = (left.dtypes == 'decimal').index
-
-        for col in decimals:
-            BaseDecimal.assert_series_equal(left[col], right[col],
-                                            *args, **kwargs)
-
-        left = left.drop(columns=decimals)
-        right = right.drop(columns=decimals)
-        tm.assert_frame_equal(left, right, *args, **kwargs)
+        return tm.assert_series_equal(left[~left_na],
+                                      right[~right_na],
+                                      *args, **kwargs)
+
+    def assert_frame_equal(self, left, right, *args, **kwargs):
+        self.assert_series_equal(left.dtypes, right.dtypes)
+        for col in left.columns:
+            self.assert_series_equal(left[col], right[col],
+                                     *args, **kwargs)
 
 
 class TestDtype(BaseDecimal, base.BaseDtypeTests):