Skip to content

Commit 0e22ac1

Browse files
Merge pull request #110 from sinhrks/cln_subclass
CLN: Cleanup subclass
2 parents 2aaf9ca + 71df260 commit 0e22ac1

File tree

14 files changed

+509
-429
lines changed

14 files changed

+509
-429
lines changed

pandas_datareader/_utils.py

-122
Original file line numberDiff line numberDiff line change
@@ -1,134 +1,12 @@
1-
import time
21
import warnings
3-
import numpy as np
4-
import datetime as dt
52

6-
from pandas import to_datetime
7-
import pandas.compat as compat
83
from pandas.core.common import PandasError
9-
from pandas import Panel, DataFrame
10-
from pandas.io.common import urlopen
11-
from pandas import read_csv
12-
from pandas.compat import StringIO, bytes_to_str
13-
from pandas.util.testing import _network_error_classes
144

15-
if compat.PY3:
16-
from urllib.parse import urlencode
17-
else:
18-
from urllib import urlencode
195

206
class SymbolWarning(UserWarning):
    """Warning category used when data for a symbol cannot be read."""
228

239
class RemoteDataError(PandasError, IOError):
    """Error raised when data cannot be fetched from a remote source."""
2511

26-
def _get_data_from(symbols, start, end, interval, retry_count, pause,
                   chunksize, src_fn):
    """
    Fetch data for a single symbol or for a collection of symbols.

    A single symbol (a string or an int) is passed straight to ``src_fn``.
    For a DataFrame, its index supplies the symbols; any other value is
    used as the collection of symbols itself.  Collections are delegated
    to ``_dl_mult_symbols``.
    """
    # A single symbol, (e.g., 'GOOG')
    if isinstance(symbols, (compat.string_types, int)):
        return src_fn(symbols, start, end, interval, retry_count, pause)

    # Multiple symbols, (e.g., ['GOOG', 'AAPL', 'MSFT'])
    if isinstance(symbols, DataFrame):
        symbols = symbols.index
    return _dl_mult_symbols(symbols, start, end, interval, chunksize,
                            retry_count, pause, src_fn)
40-
41-
def _dl_mult_symbols(symbols, start, end, interval, chunksize, retry_count, pause,
                     method):
    """
    Download every symbol with ``method`` and combine the results in a Panel.

    A symbol whose download raises IOError triggers a SymbolWarning and is
    filled with an all-NaN copy of the first successful result.

    Raises RemoteDataError when no symbol could be read, or when building
    the Panel raises AttributeError.
    """
    frames = {}
    missing = []
    fetched = []
    for group in _in_chunks(symbols, chunksize):
        for sym in group:
            try:
                frames[sym] = method(sym, start, end, interval,
                                     retry_count, pause)
            except IOError:
                warnings.warn('Failed to read symbol: {0!r}, replacing with '
                              'NaN.'.format(sym), SymbolWarning)
                missing.append(sym)
            else:
                fetched.append(sym)

    if not fetched:
        raise RemoteDataError("No data fetched using "
                              "{0!r}".format(method.__name__))

    try:
        # At least one symbol succeeded here, so only failures need checking.
        if missing:
            placeholder = frames[fetched[0]].copy()
            placeholder[:] = np.nan
            frames.update(dict.fromkeys(missing, placeholder))
        return Panel(frames).swapaxes('items', 'minor')
    except AttributeError:
        # cannot construct a panel with just 1D nans indicating no data
        raise RemoteDataError("No data fetched using "
                              "{0!r}".format(method.__name__))
70-
71-
72-
def _sanitize_dates(start, end):
    """
    Convert ``start`` and ``end`` with ``to_datetime`` and fill in defaults.

    Returns a ``(start, end)`` tuple.  A start that converts to None
    becomes 2010/01/01; an end that converts to None becomes today.
    """
    begin = to_datetime(start)
    finish = to_datetime(end)
    begin = dt.datetime(2010, 1, 1) if begin is None else begin
    finish = dt.datetime.today() if finish is None else finish
    return begin, finish
85-
86-
def _in_chunks(seq, size):
    """
    Return a generator over consecutive slices of ``seq``, each at most
    ``size`` items long.
    """
    # Build the range outside the generator body so problems with
    # ``seq`` or ``size`` surface at call time, not on first iteration.
    starts = range(0, len(seq), size)
    return (seq[begin:begin + size] for begin in starts)
91-
92-
def _encode_url(url, params):
    """
    Return ``url`` with ``params`` appended as an encoded query string.

    When the encoded parameters are empty, ``url`` is returned unchanged.
    """
    query = urlencode(params)
    return (url + '?' + query) if query else url
101-
102-
def _retry_read_url(url, retry_count, pause, name):
    """
    Open ``url`` and parse the response as CSV, making up to
    ``retry_count`` attempts.

    After an attempt that fails with one of ``_network_error_classes``
    the function sleeps for ``pause`` seconds before trying again.  The
    parsed frame is returned in reversed row order.  IOError is raised
    once every attempt has failed.
    """
    for _ in range(retry_count):

        # kludge to close the socket ASAP
        try:
            with urlopen(url) as resp:
                raw = resp.read()
        except _network_error_classes:
            time.sleep(pause)
            continue

        text = StringIO(bytes_to_str(raw))
        frame = read_csv(text, index_col=0, parse_dates=True, na_values='-')
        frame = frame[::-1]
        # Yahoo! Finance sometimes does this awesome thing where they
        # return 2 rows for the most recent business day
        if len(frame) > 2 and frame.index[-1] == frame.index[-2]:  # pragma: no cover
            frame = frame[:-1]

        # Get rid of unicode characters in index name.
        index_name = frame.index.name
        try:
            frame.index.name = index_name.decode(
                'unicode_escape').encode('ascii', 'ignore')
        except AttributeError:
            # Python 3 string has no decode method.
            frame.index.name = index_name.encode('ascii', 'ignore').decode()

        return frame

    raise IOError("after %d tries, %s did not "
                  "return a 200 for url %r" % (retry_count, name, url))

pandas_datareader/base.py

+200
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
import time
2+
import warnings
3+
import numpy as np
4+
import datetime as dt
5+
6+
import requests
7+
8+
from pandas import to_datetime
9+
import pandas.compat as compat
10+
from pandas.core.common import PandasError
11+
from pandas import Panel, DataFrame
12+
from pandas import read_csv
13+
from pandas.compat import StringIO, bytes_to_str
14+
from pandas.util.testing import _network_error_classes
15+
16+
from pandas_datareader._utils import RemoteDataError, SymbolWarning
17+
18+
19+
class _BaseReader(object):
    """
    Base class for readers that fetch remote data over HTTP.

    Subclasses must override ``url`` and may override ``params``.

    Parameters
    ----------
    symbols : string
        Symbol(s) to read, e.g. a single stock symbol (ticker).
    start : string, (defaults to '1/1/2010')
        Starting date, timestamp. Parses many different kind of date
        representations (e.g., 'JAN-01-2010', '1/1/10', 'Jan, 1, 1980')
    end : string, (defaults to today)
        Ending date, timestamp. Same format as starting date.
    retry_count : int, default 3
        Number of times to retry query request. Must not be negative;
        0 means a single attempt without retries.
    pause : float, default 0.001
        Time, in seconds, of the pause between retries.
    session : Session, default None
        requests.sessions.Session instance to be used
    """

    # NOTE(review): not referenced by any code visible in this file.
    _chunk_size = 1024 * 1024

    def __init__(self, symbols, start=None, end=None,
                 retry_count=3, pause=0.001, session=None):
        self.symbols = symbols

        start, end = self._sanitize_dates(start, end)
        self.start = start
        self.end = end

        if not isinstance(retry_count, int) or retry_count < 0:
            # 0 is accepted (initial attempt only), so the message must
            # not claim the value has to be larger than 0.
            raise ValueError("'retry_count' must be a non-negative integer")
        self.retry_count = retry_count
        self.pause = pause
        self.session = self._init_session(session, retry_count)

    def _init_session(self, session, retry_count):
        """ Return ``session``, creating a new requests.Session if None """
        if session is None:
            session = requests.Session()
            # do not set requests max_retries here to support arbitrary pause
        return session

    @property
    def url(self):
        """ URL to read from """
        # must be overridden in subclass
        raise NotImplementedError

    @property
    def params(self):
        """ Parameters sent with the request; None in the base class """
        return None

    def read(self):
        """ read data """
        return self._read_one_data(self.url, self.params)

    def _read_one_data(self, url, params):
        """ read one data from specified URL """
        # Use the ``url`` argument rather than ``self.url`` so callers can
        # actually choose the URL that is read.
        out = self._read_url_as_StringIO(url, params=params)
        return self._read_lines(out)

    def _read_url_as_StringIO(self, url, params=None):
        """
        Open url (and retry) and return the response content as a StringIO
        positioned at its start
        """
        response = self._get_response(url, params=params)
        out = StringIO()
        if isinstance(response.content, compat.binary_type):
            out.write(bytes_to_str(response.content))
        else:
            out.write(response.content)
        out.seek(0)
        return out

    def _get_response(self, url, params=None):
        """ send raw HTTP request to get requests.Response from the specified url

        Parameters
        ----------
        url : str
            target URL
        params : dict or None
            parameters passed to the URL

        Raises
        ------
        RemoteDataError
            If no attempt returned an OK status code.
        """

        # initial attempt + retry
        attempts = self.retry_count + 1
        for attempt in range(attempts):
            response = self.session.get(url, params=params)
            if response.status_code == requests.codes.ok:
                return response
            # pause only when another attempt follows
            if attempt < attempts - 1:
                time.sleep(self.pause)

        raise RemoteDataError('Unable to read URL: {0}'.format(url))

    def _read_lines(self, out):
        """ Parse ``out`` as CSV and return the frame in reversed row order """
        rs = read_csv(out, index_col=0, parse_dates=True, na_values='-')[::-1]
        # Yahoo! Finance sometimes does this awesome thing where they
        # return 2 rows for the most recent business day
        if len(rs) > 2 and rs.index[-1] == rs.index[-2]:  # pragma: no cover
            rs = rs[:-1]
        # Get rid of unicode characters in index name.
        try:
            rs.index.name = rs.index.name.decode('unicode_escape').encode('ascii', 'ignore')
        except AttributeError:
            # Python 3 string has no decode method.
            rs.index.name = rs.index.name.encode('ascii', 'ignore').decode()
        return rs

    def _sanitize_dates(self, start, end):
        """
        Return (datetime_start, datetime_end) tuple
        if start is None - default is 2010/01/01
        if end is None - default is today
        """
        start = to_datetime(start)
        end = to_datetime(end)
        if start is None:
            start = dt.datetime(2010, 1, 1)
        if end is None:
            end = dt.datetime.today()
        return start, end
138+
139+
140+
class _DailyBaseReader(_BaseReader):
    """ Base class for Google / Yahoo daily reader """

    def __init__(self, symbols=None, start=None, end=None, retry_count=3,
                 pause=0.001, session=None, chunksize=25):
        base = super(_DailyBaseReader, self)
        base.__init__(symbols=symbols, start=start, end=end,
                      retry_count=retry_count, pause=pause, session=session)
        self.chunksize = chunksize

    def _get_params(self, *args, **kwargs):
        """ Build the request parameters; no implementation in this class """
        raise NotImplementedError

    def read(self):
        """ read data """
        symbols = self.symbols
        # A single symbol, (e.g., 'GOOG')
        if isinstance(symbols, (compat.string_types, int)):
            return self._read_one_data(self.url,
                                       params=self._get_params(symbols))
        # Multiple symbols, (e.g., ['GOOG', 'AAPL', 'MSFT']); a DataFrame
        # supplies them through its index
        if isinstance(symbols, DataFrame):
            symbols = symbols.index
        return self._dl_mult_symbols(symbols)

    def _dl_mult_symbols(self, symbols):
        """
        Read every symbol and combine the results in a Panel.

        A symbol whose read raises IOError triggers a SymbolWarning and is
        filled with an all-NaN copy of the first successful result.
        RemoteDataError is raised when no symbol could be read, or when
        building the Panel raises AttributeError.
        """
        frames = {}
        missing = []
        fetched = []
        for group in _in_chunks(symbols, self.chunksize):
            for sym in group:
                try:
                    frames[sym] = self._read_one_data(self.url,
                                                      self._get_params(sym))
                except IOError:
                    template = 'Failed to read symbol: {0!r}, replacing with NaN.'
                    warnings.warn(template.format(sym), SymbolWarning)
                    missing.append(sym)
                else:
                    fetched.append(sym)

        if not fetched:
            raise RemoteDataError(
                "No data fetched using {0!r}".format(self.__class__.__name__))

        try:
            # At least one symbol succeeded here, so only failures need checking.
            if missing:
                placeholder = frames[fetched[0]].copy()
                placeholder[:] = np.nan
                frames.update(dict.fromkeys(missing, placeholder))
            return Panel(frames).swapaxes('items', 'minor')
        except AttributeError:
            # cannot construct a panel with just 1D nans indicating no data
            raise RemoteDataError(
                "No data fetched using {0!r}".format(self.__class__.__name__))
194+
195+
196+
def _in_chunks(seq, size):
    """
    Return sequence in 'chunks' of size defined by size

    Parameters
    ----------
    seq : sequence
        Object supporting ``len()`` and slicing.
    size : int
        Maximum length of each chunk; must be positive.

    Returns
    -------
    generator
        Yields consecutive slices of ``seq``; the last one may be shorter
        than ``size``.

    Raises
    ------
    ValueError
        If ``size`` is not positive.
    """
    # A negative size would silently produce no chunks at all (dropping
    # every item), and zero gives an unhelpful range() error, so reject
    # both explicitly.
    if size <= 0:
        raise ValueError(
            "'size' must be a positive integer, got {0!r}".format(size))
    return (seq[pos:pos + size] for pos in range(0, len(seq), size))

0 commit comments

Comments
 (0)