Skip to content

Commit 05462f2

Browse files
committed
Initial changes to allow pymc3.Data() to support both int and float input data (previously all input data was coerced to float)
WIP for pymc-devs#3813
1 parent 433c693 commit 05462f2

File tree

2 files changed

+17
-5
lines changed

2 files changed

+17
-5
lines changed

pymc3/data.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -478,10 +478,20 @@ class Data:
478478
For more information, take a look at this example notebook
479479
https://docs.pymc.io/notebooks/data_container.html
480480
"""
481-
def __new__(self, name, value):
481+
def __new__(self, name, value, dtype = None):
482+
if dtype is None:
483+
if hasattr(value, 'dtype'):
484+
# if no dtype given, but available as attr of value, use that as dtype
485+
dtype = value.dtype
486+
elif isinstance(value, int):
487+
dtype = int
488+
else:
489+
# otherwise, assume float
490+
dtype = float
491+
482492
# `pm.model.pandas_to_array` takes care of parameter `value` and
483493
# transforms it to something digestible for pymc3
484-
shared_object = theano.shared(pm.model.pandas_to_array(value), name)
494+
shared_object = theano.shared(pm.model.pandas_to_array(value, dtype = dtype), name)
485495

486496
# To draw the node for this variable in the graphviz Digraph we need
487497
# its shape.

pymc3/model.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -1473,7 +1473,7 @@ def init_value(self):
14731473
return self.tag.test_value
14741474

14751475

1476-
def pandas_to_array(data):
1476+
def pandas_to_array(data, dtype = float):
14771477
if hasattr(data, 'values'): # pandas
14781478
if data.isnull().any().any(): # missing values
14791479
ret = np.ma.MaskedArray(data.values, data.isnull().values)
@@ -1492,8 +1492,10 @@ def pandas_to_array(data):
14921492
ret = generator(data)
14931493
else:
14941494
ret = np.asarray(data)
1495-
return pm.floatX(ret)
1496-
1495+
if dtype in [float, np.float32, np.float64]:
1496+
return pm.floatX(ret)
1497+
elif dtype in [int, np.int32, np.int64]:
1498+
return pm.intX(ret)
14971499

14981500
def as_tensor(data, name, model, distribution):
14991501
dtype = distribution.dtype

0 commit comments

Comments
 (0)