Skip to content

Commit e517194

Browse files
committed
CLN: Cleanup subclass
1 parent 2aaf9ca commit e517194

File tree

14 files changed

+499
-429
lines changed

14 files changed

+499
-429
lines changed

pandas_datareader/_utils.py

-122
Original file line numberDiff line numberDiff line change
@@ -1,134 +1,12 @@
1-
import time
21
import warnings
3-
import numpy as np
4-
import datetime as dt
52

6-
from pandas import to_datetime
7-
import pandas.compat as compat
83
from pandas.core.common import PandasError
9-
from pandas import Panel, DataFrame
10-
from pandas.io.common import urlopen
11-
from pandas import read_csv
12-
from pandas.compat import StringIO, bytes_to_str
13-
from pandas.util.testing import _network_error_classes
144

15-
if compat.PY3:
16-
from urllib.parse import urlencode
17-
else:
18-
from urllib import urlencode
195

206
class SymbolWarning(UserWarning):
    """Warning category issued when the data for a symbol cannot be read."""
228

239
class RemoteDataError(PandasError, IOError):
    """
    Error raised when data cannot be fetched from a remote source.

    It derives from both ``PandasError`` and ``IOError``, so handlers written
    as ``except IOError`` also catch it.
    """
2511

26-
def _get_data_from(symbols, start, end, interval, retry_count, pause,
                   chunksize, src_fn):
    """
    Fetch data for one symbol or for a collection of symbols.

    A single symbol (string or int, e.g. 'GOOG') is passed straight to
    ``src_fn``.  Anything else is treated as a collection of symbols
    (e.g. ['GOOG', 'AAPL', 'MSFT']) and handed to ``_dl_mult_symbols``;
    for a DataFrame the symbols are taken from its index.
    """
    # Single symbol: one direct call to the source function.
    if isinstance(symbols, (compat.string_types, int)):
        return src_fn(symbols, start, end, interval, retry_count, pause)

    # Multiple symbols: a DataFrame contributes its index as the symbol list.
    tickers = symbols.index if isinstance(symbols, DataFrame) else symbols
    return _dl_mult_symbols(tickers, start, end, interval, chunksize,
                            retry_count, pause, src_fn)
40-
41-
def _dl_mult_symbols(symbols, start, end, interval, chunksize, retry_count, pause,
                     method):
    """
    Call ``method`` once per symbol and combine the results into a Panel.

    A symbol whose fetch raises IOError triggers a SymbolWarning and is
    filled with an all-NaN copy of the first successfully fetched frame.
    RemoteDataError is raised when no symbol could be fetched, or when
    building the Panel fails with AttributeError.
    """
    frames = {}
    missing = []
    fetched = []
    no_data = "No data fetched using {0!r}"

    # Walk the chunks lazily, one symbol at a time.
    tickers = (ticker
               for batch in _in_chunks(symbols, chunksize)
               for ticker in batch)
    for ticker in tickers:
        try:
            frame = method(ticker, start, end, interval, retry_count, pause)
        except IOError:
            template = 'Failed to read symbol: {0!r}, replacing with NaN.'
            warnings.warn(template.format(ticker), SymbolWarning)
            missing.append(ticker)
        else:
            frames[ticker] = frame
            fetched.append(ticker)

    if not fetched:
        raise RemoteDataError(no_data.format(method.__name__))

    try:
        if frames and missing and fetched:
            # Every failed symbol shares one NaN-filled placeholder frame.
            placeholder = frames[fetched[0]].copy()
            placeholder[:] = np.nan
            frames.update(dict.fromkeys(missing, placeholder))
        return Panel(frames).swapaxes('items', 'minor')
    except AttributeError:
        # cannot construct a panel with just 1D nans indicating no data
        raise RemoteDataError(no_data.format(method.__name__))
70-
71-
72-
def _sanitize_dates(start, end):
73-
"""
74-
Return (datetime_start, datetime_end) tuple
75-
if start is None - default is 2010/01/01
76-
if end is None - default is today
77-
"""
78-
start = to_datetime(start)
79-
end = to_datetime(end)
80-
if start is None:
81-
start = dt.datetime(2010, 1, 1)
82-
if end is None:
83-
end = dt.datetime.today()
84-
return start, end
85-
86-
def _in_chunks(seq, size):
87-
"""
88-
Return sequence in 'chunks' of size defined by size
89-
"""
90-
return (seq[pos:pos + size] for pos in range(0, len(seq), size))
91-
92-
def _encode_url(url, params):
93-
"""
94-
Return encoded url with parameters
95-
"""
96-
s_params = urlencode(params)
97-
if s_params:
98-
return url + '?' + s_params
99-
else:
100-
return url
101-
102-
def _retry_read_url(url, retry_count, pause, name):
    """
    Open ``url`` and parse the response as CSV, retrying on network errors.

    Up to ``retry_count`` attempts are made, sleeping ``pause`` seconds
    after each failed one.  The parsed frame is returned in reversed row
    order.  IOError is raised once all attempts have failed.
    """
    for _attempt in range(retry_count):

        # kludge to close the socket ASAP
        try:
            with urlopen(url) as resp:
                payload = resp.read()
        except _network_error_classes:
            # Network failure: fall through to the sleep and try again.
            pass
        else:
            text = bytes_to_str(payload)
            frame = read_csv(StringIO(text), index_col=0,
                             parse_dates=True, na_values='-')
            frame = frame[::-1]
            # Yahoo! Finance sometimes returns 2 rows for the most recent
            # business day; drop the duplicated trailing row.
            if len(frame) > 2 and frame.index[-1] == frame.index[-2]:  # pragma: no cover
                frame = frame[:-1]

            # Get rid of unicode characters in index name.
            label = frame.index.name
            try:
                frame.index.name = label.decode('unicode_escape').encode('ascii', 'ignore')
            except AttributeError:
                # Python 3 string has no decode method.
                frame.index.name = label.encode('ascii', 'ignore').decode()

            return frame

        time.sleep(pause)

    message = "after %d tries, %s did not return a 200 for url %r"
    raise IOError(message % (retry_count, name, url))

pandas_datareader/base.py

+195
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
import time
2+
import warnings
3+
import numpy as np
4+
import datetime as dt
5+
6+
import requests
7+
8+
from pandas import to_datetime
9+
import pandas.compat as compat
10+
from pandas.core.common import PandasError
11+
from pandas import Panel, DataFrame
12+
from pandas import read_csv
13+
from pandas.compat import StringIO, bytes_to_str
14+
from pandas.util.testing import _network_error_classes
15+
16+
from pandas_datareader._utils import RemoteDataError, SymbolWarning
17+
18+
19+
class BaseReader(object):
    """
    Base class for readers that fetch CSV data over HTTP.

    Subclasses must override the ``url`` property and may override
    ``params`` to supply query parameters.

    Parameters
    ----------
    symbols : object
        Symbol(s) (ticker) to read.  Stored as ``self.symbols``; how the
        value is interpreted is left to subclasses.
    start : string, (defaults to '1/1/2010')
        Starting date, timestamp. Parses many different kind of date
        representations (e.g., 'JAN-01-2010', '1/1/10', 'Jan, 1, 1980')
    end : string, (defaults to today)
        Ending date, timestamp. Same format as starting date.
    retry_count : int, default 3
        Passed as ``max_retries`` to the HTTP adapters mounted on a newly
        created session.  Ignored when ``session`` is supplied.
    pause : float, default 0.001
        Time, in seconds, of the pause between retries.  Stored as
        ``self.pause``; this class itself does not sleep.
    session : Session, default None
        requests.sessions.Session instance to be used.  A new session is
        created when None.
    """

    # Not referenced by any method of this class.
    _chunk_size = 1024 * 1024

    def __init__(self, symbols, start=None, end=None,
                 retry_count=3, pause=0.001, session=None):
        self.symbols = symbols

        start, end = self._sanitize_dates(start, end)
        self.start = start
        self.end = end

        # Kept on the instance so subclasses can consult it later.
        self.retry_count = retry_count
        self.pause = pause
        self.session = self._init_session(session, retry_count)

    def _init_session(self, session, retry_count):
        """
        Return ``session``, or a new requests session when it is None.

        A new session gets HTTP and HTTPS adapters configured with
        ``max_retries=retry_count``; a caller-supplied session is returned
        untouched.
        """
        if session is None:
            session = requests.Session()
            for scheme in ("http://", "https://"):
                adapter = requests.adapters.HTTPAdapter(
                    max_retries=retry_count)
                session.mount(scheme, adapter)
        return session

    @property
    def url(self):
        """ target URL; must be overridden in subclass """
        raise NotImplementedError

    @property
    def params(self):
        """ query parameters sent with the request (None by default) """
        return None

    def read(self):
        """ read data """
        return self._read_one_data(self.url, self.params)

    def _read_one_data(self, url, params):
        """ read one data from specified URL """
        # Use the url given by the caller: the original always fetched
        # self.url, silently ignoring this argument.
        out = self._read_url_as_StringIO(url, params=params)
        return self._read_lines(out)

    def _read_url_as_StringIO(self, url, params=None):
        """
        Fetch ``url`` and return the response body as a StringIO that is
        positioned at its start.
        """
        response = self._get_response(url, params=params)
        content = response.content
        if isinstance(content, compat.binary_type):
            content = bytes_to_str(content)
        out = StringIO()
        out.write(content)
        out.seek(0)
        return out

    def _get_response(self, url, params=None):
        """ send raw HTTP request to get requests.Response from the specified url

        Parameters
        ----------
        url : str
            target URL
        params : dict or None
            parameters passed to the URL

        Raises
        ------
        RemoteDataError
            If the response status code is not ``requests.codes.ok``.
        """
        response = self.session.get(url, params=params)
        if response.status_code != requests.codes.ok:
            raise RemoteDataError('Unable to read URL: {0}'.format(url))
        return response

    def _read_lines(self, out):
        """
        Parse the CSV text in ``out`` and return it as a frame in reversed
        row order, indexed by the first column.
        """
        rs = read_csv(out, index_col=0, parse_dates=True, na_values='-')[::-1]
        # Yahoo! Finance sometimes does this awesome thing where they
        # return 2 rows for the most recent business day
        if len(rs) > 2 and rs.index[-1] == rs.index[-2]:  # pragma: no cover
            rs = rs[:-1]
        # Get rid of unicode characters in index name.
        try:
            rs.index.name = rs.index.name.decode('unicode_escape').encode('ascii', 'ignore')
        except AttributeError:
            # Python 3 string has no decode method.
            rs.index.name = rs.index.name.encode('ascii', 'ignore').decode()
        return rs

    def _sanitize_dates(self, start, end):
        """
        Return (datetime_start, datetime_end) tuple
        if start is None - default is 2010/01/01
        if end is None - default is today
        """
        start = to_datetime(start)
        end = to_datetime(end)
        if start is None:
            start = dt.datetime(2010, 1, 1)
        if end is None:
            end = dt.datetime.today()
        return start, end
133+
134+
135+
class DailyBaseReader(BaseReader):
    """ Base class for Google / Yahoo daily reader """

    def __init__(self, symbols=None, start=None, end=None, retry_count=3,
                 pause=0.001, session=None, chunksize=25):
        base_kwargs = dict(symbols=symbols, start=start, end=end,
                           retry_count=retry_count, pause=pause,
                           session=session)
        super(DailyBaseReader, self).__init__(**base_kwargs)
        # Number of symbols per chunk when reading several symbols.
        self.chunksize = chunksize

    def _get_params(self, *args, **kwargs):
        # Called with a symbol to build the request parameters; must be
        # overridden in subclass.
        raise NotImplementedError

    def read(self):
        """ read data """
        targets = self.symbols
        # A single symbol, (e.g., 'GOOG'), is read with one request.
        if isinstance(targets, (compat.string_types, int)):
            return self._read_one_data(self.url,
                                       params=self._get_params(targets))
        # Multiple symbols, (e.g., ['GOOG', 'AAPL', 'MSFT']); a DataFrame
        # contributes its index as the symbol list.
        if isinstance(targets, DataFrame):
            targets = targets.index
        return self._dl_mult_symbols(targets)

    def _dl_mult_symbols(self, symbols):
        """
        Read every symbol in ``symbols`` and combine the results in a Panel.

        A symbol whose read raises IOError triggers a SymbolWarning and is
        filled with an all-NaN copy of the first successfully read frame.
        RemoteDataError is raised when no symbol could be read, or when
        building the Panel fails with AttributeError.
        """
        frames = {}
        missing = []
        fetched = []
        no_data = "No data fetched using {0!r}"

        # Walk the chunks lazily, one symbol at a time.
        tickers = (ticker
                   for batch in _in_chunks(symbols, self.chunksize)
                   for ticker in batch)
        for ticker in tickers:
            try:
                frame = self._read_one_data(self.url,
                                            self._get_params(ticker))
            except IOError:
                template = 'Failed to read symbol: {0!r}, replacing with NaN.'
                warnings.warn(template.format(ticker), SymbolWarning)
                missing.append(ticker)
            else:
                frames[ticker] = frame
                fetched.append(ticker)

        if not fetched:
            raise RemoteDataError(no_data.format(self.__class__.__name__))

        try:
            if frames and missing and fetched:
                # Every failed symbol shares one NaN-filled placeholder.
                placeholder = frames[fetched[0]].copy()
                placeholder[:] = np.nan
                frames.update(dict.fromkeys(missing, placeholder))
            return Panel(frames).swapaxes('items', 'minor')
        except AttributeError:
            # cannot construct a panel with just 1D nans indicating no data
            raise RemoteDataError(no_data.format(self.__class__.__name__))
189+
190+
191+
def _in_chunks(seq, size):
192+
"""
193+
Return sequence in 'chunks' of size defined by size
194+
"""
195+
return (seq[pos:pos + size] for pos in range(0, len(seq), size))

0 commit comments

Comments
 (0)