Skip to content

Commit 0c3d2bd

Browse files
committed
BUG: Make DataFrame not hardcode it's own constructor in the code. Issue
Changed all occurrences of `return DataFrame` to `return self._constructor` and changed the latter to return `self.__class__` rather than `DataFrame` by default. Also made static methods like `DataFrame.from_dict` work on the `cls` passed to them rather than on hard-coded `DataFrame`. This is a minimal quick win on the cases described in Issue pandas-dev#2859. A formal approach to make classes friendly to subclassing might still be needed, see discussion in the case.
1 parent fc8de6d commit 0c3d2bd

File tree

1 file changed

+20
-20
lines changed

1 file changed

+20
-20
lines changed

pandas/core/frame.py

+20-20
Original file line numberDiff line numberDiff line change
@@ -565,7 +565,7 @@ def axes(self):
565565

566566
@property
567567
def _constructor(self):
568-
return DataFrame
568+
return self.__class__
569569

570570
# Fancy indexing
571571
_ix = None
@@ -855,15 +855,15 @@ def dot(self, other):
855855
(lvals.shape, rvals.shape))
856856

857857
if isinstance(other, DataFrame):
858-
return DataFrame(np.dot(lvals, rvals),
858+
return self._constructor(np.dot(lvals, rvals),
859859
index=left.index,
860860
columns=other.columns)
861861
elif isinstance(other, Series):
862862
return Series(np.dot(lvals, rvals), index=left.index)
863863
elif isinstance(rvals, np.ndarray):
864864
result = np.dot(lvals, rvals)
865865
if result.ndim == 2:
866-
return DataFrame(result, index=left.index)
866+
return self._constructor(result, index=left.index)
867867
else:
868868
return Series(result, index=left.index)
869869
else: # pragma: no cover
@@ -902,7 +902,7 @@ def from_dict(cls, data, orient='columns', dtype=None):
902902
elif orient != 'columns': # pragma: no cover
903903
raise ValueError('only recognize index or columns for orient')
904904

905-
return DataFrame(data, index=index, columns=columns, dtype=dtype)
905+
return cls(data, index=index, columns=columns, dtype=dtype)
906906

907907
def to_dict(self, outtype='dict'):
908908
"""
@@ -969,15 +969,15 @@ def from_records(cls, data, index=None, exclude=None, columns=None,
969969

970970
if com.is_iterator(data):
971971
if nrows == 0:
972-
return DataFrame()
972+
return cls()
973973

974974
try:
975975
if py3compat.PY3:
976976
first_row = next(data)
977977
else:
978978
first_row = data.next()
979979
except StopIteration:
980-
return DataFrame(index=index, columns=columns)
980+
return cls(index=index, columns=columns)
981981

982982
dtype = None
983983
if hasattr(first_row, 'dtype') and first_row.dtype.names:
@@ -1067,7 +1067,7 @@ def from_records(cls, data, index=None, exclude=None, columns=None,
10671067
mgr = _arrays_to_mgr(arrays, arr_columns, result_index,
10681068
columns)
10691069

1070-
return DataFrame(mgr)
1070+
return cls(mgr)
10711071

10721072
def to_records(self, index=True, convert_datetime64=True):
10731073
"""
@@ -2061,7 +2061,7 @@ def _slice(self, slobj, axis=0):
20612061
def _box_item_values(self, key, values):
20622062
items = self.columns[self.columns.get_loc(key)]
20632063
if values.ndim == 2:
2064-
return DataFrame(values.T, columns=items, index=self.index)
2064+
return self._constructor(values.T, columns=items, index=self.index)
20652065
else:
20662066
return Series.from_array(values, index=self.index, name=items)
20672067

@@ -2647,7 +2647,7 @@ def _reindex_multi(self, new_index, new_columns, copy, fill_value):
26472647
if row_indexer is not None and col_indexer is not None:
26482648
new_values = com.take_2d_multi(self.values, row_indexer,
26492649
col_indexer, fill_value=fill_value)
2650-
return DataFrame(new_values, index=new_index, columns=new_columns)
2650+
return self._constructor(new_values, index=new_index, columns=new_columns)
26512651
elif row_indexer is not None:
26522652
return self._reindex_with_indexers(new_index, row_indexer,
26532653
None, None, copy, fill_value)
@@ -2695,7 +2695,7 @@ def _reindex_with_indexers(self, index, row_indexer, columns, col_indexer,
26952695
if copy and new_data is self._data:
26962696
new_data = new_data.copy()
26972697

2698-
return DataFrame(new_data)
2698+
return self._constructor(new_data)
26992699

27002700
def reindex_like(self, other, method=None, copy=True, limit=None):
27012701
"""
@@ -2938,7 +2938,7 @@ def take(self, indices, axis=0):
29382938
if self._is_mixed_type:
29392939
if axis == 0:
29402940
new_data = self._data.take(indices, axis=1)
2941-
return DataFrame(new_data)
2941+
return self._constructor(new_data)
29422942
else:
29432943
new_columns = self.columns.take(indices)
29442944
return self.reindex(columns=new_columns)
@@ -2952,7 +2952,7 @@ def take(self, indices, axis=0):
29522952
else:
29532953
new_columns = self.columns.take(indices)
29542954
new_index = self.index
2955-
return DataFrame(new_values, index=new_index,
2955+
return self._constructor(new_values, index=new_index,
29562956
columns=new_columns)
29572957

29582958
#----------------------------------------------------------------------
@@ -4191,7 +4191,7 @@ def _apply_raw(self, func, axis):
41914191

41924192
# TODO: mixed type case
41934193
if result.ndim == 2:
4194-
return DataFrame(result, index=self.index,
4194+
return self._constructor(result, index=self.index,
41954195
columns=self.columns)
41964196
else:
41974197
return Series(result, index=self._get_agg_axis(axis))
@@ -4622,7 +4622,7 @@ def describe(self, percentile_width=50):
46224622
numdata = self._get_numeric_data()
46234623

46244624
if len(numdata.columns) == 0:
4625-
return DataFrame(dict((k, v.describe())
4625+
return self._constructor(dict((k, v.describe())
46264626
for k, v in self.iteritems()),
46274627
columns=self.columns)
46284628

@@ -5006,7 +5006,7 @@ def _get_agg_axis(self, axis_num):
50065006
def _get_numeric_data(self):
50075007
if self._is_mixed_type:
50085008
num_data = self._data.get_numeric_data()
5009-
return DataFrame(num_data, index=self.index, copy=False)
5009+
return self._constructor(num_data, index=self.index, copy=False)
50105010
else:
50115011
if (self.values.dtype != np.object_ and
50125012
not issubclass(self.values.dtype.type, np.datetime64)):
@@ -5017,7 +5017,7 @@ def _get_numeric_data(self):
50175017
def _get_bool_data(self):
50185018
if self._is_mixed_type:
50195019
bool_data = self._data.get_bool_data()
5020-
return DataFrame(bool_data, index=self.index, copy=False)
5020+
return self._constructor(bool_data, index=self.index, copy=False)
50215021
else: # pragma: no cover
50225022
if self.values.dtype == np.bool_:
50235023
return self
@@ -5127,7 +5127,7 @@ def rank(self, axis=0, numeric_only=None, method='average',
51275127
try:
51285128
ranks = algos.rank(self.values, axis=axis, method=method,
51295129
ascending=ascending, na_option=na_option)
5130-
return DataFrame(ranks, index=self.index, columns=self.columns)
5130+
return self._constructor(ranks, index=self.index, columns=self.columns)
51315131
except TypeError:
51325132
numeric_only = True
51335133

@@ -5137,7 +5137,7 @@ def rank(self, axis=0, numeric_only=None, method='average',
51375137
data = self
51385138
ranks = algos.rank(data.values, axis=axis, method=method,
51395139
ascending=ascending, na_option=na_option)
5140-
return DataFrame(ranks, index=data.index, columns=data.columns)
5140+
return self._constructor(ranks, index=data.index, columns=data.columns)
51415141

51425142
def to_timestamp(self, freq=None, how='start', axis=0, copy=True):
51435143
"""
@@ -5170,7 +5170,7 @@ def to_timestamp(self, freq=None, how='start', axis=0, copy=True):
51705170
else:
51715171
raise ValueError('Axis must be 0 or 1. Got %s' % str(axis))
51725172

5173-
return DataFrame(new_data)
5173+
return self._constructor(new_data)
51745174

51755175
def to_period(self, freq=None, axis=0, copy=True):
51765176
"""
@@ -5204,7 +5204,7 @@ def to_period(self, freq=None, axis=0, copy=True):
52045204
else:
52055205
raise ValueError('Axis must be 0 or 1. Got %s' % str(axis))
52065206

5207-
return DataFrame(new_data)
5207+
return self._constructor(new_data)
52085208

52095209
#----------------------------------------------------------------------
52105210
# Deprecated stuff

0 commit comments

Comments
 (0)