1
1
"""Utility functions"""
2
2
3
3
__all__ = ['make_array' , 'percentile' , 'plot_cdf_area' , 'plot_normal_cdf' ,
4
- 'table_apply' , 'proportions_from_distribution' , 'minimize' ]
4
+ 'table_apply' , 'proportions_from_distribution' ,
5
+ 'sample_proportions' , 'minimize' ]
5
6
6
7
import numpy as np
7
8
import pandas as pd
@@ -105,6 +106,23 @@ def plot_normal_cdf(rbound=None, lbound=None, mean=0, sd=1):
105
106
plot_cdf_area = plot_normal_cdf
106
107
107
108
109
def sample_proportions(sample_size, probabilities):
    """Return the proportion of random draws for each outcome in a distribution.

    This function is similar to np.random.multinomial, but returns proportions
    instead of counts.

    Args:
        ``sample_size``: The size of the sample to draw from the distribution.
            Expected to be a positive integer, since the counts are divided
            by it.

        ``probabilities``: An array of probabilities that forms a distribution.

    Returns:
        An array with the same length as ``probabilities`` that sums to 1.
        Each entry is the fraction of the ``sample_size`` draws that landed
        on the corresponding outcome.

    >>> sample_proportions(10, [1, 0]).tolist()
    [1.0, 0.0]
    """
    # Draw the per-outcome counts first, then rescale them to proportions.
    counts = np.random.multinomial(sample_size, probabilities)
    return counts / sample_size
124
+
125
+
108
126
def proportions_from_distribution (table , label , sample_size ,
109
127
column_name = 'Random Sample' ):
110
128
"""
@@ -115,8 +133,6 @@ def proportions_from_distribution(table, label, sample_size,
115
133
from the distribution in ``table.column(label)``, then divides by
116
134
``sample_size`` to create the resulting column of proportions.
117
135
118
- Returns a new ``Table`` and does not modify ``table``.
119
-
120
136
Args:
121
137
``table``: An instance of ``Table``.
122
138
@@ -136,8 +152,7 @@ def proportions_from_distribution(table, label, sample_size,
136
152
``ValueError``: If the ``label`` is not in the table, or if
137
153
``table.column(label)`` does not sum to 1.
138
154
"""
139
- proportions = (np .random .multinomial (sample_size , table .column (label )) /
140
- sample_size )
155
+ proportions = sample_proportions (sample_size , table .column (label ))
141
156
return table .with_column ('Random Sample' , proportions )
142
157
143
158
@@ -225,4 +240,3 @@ def objective(args):
225
240
return result .x .item (0 )
226
241
else :
227
242
return result .x
228
-
0 commit comments