From 71df260159105d4814672b32c5fa0b236db9dd73 Mon Sep 17 00:00:00 2001 From: sinhrks Date: Fri, 9 Oct 2015 22:35:06 +0900 Subject: [PATCH] CLN: Cleanup subclass --- pandas_datareader/_utils.py | 122 ---------------- pandas_datareader/base.py | 200 +++++++++++++++++++++++++++ pandas_datareader/data.py | 67 ++++++--- pandas_datareader/fred.py | 50 +++---- pandas_datareader/google/daily.py | 48 +++---- pandas_datareader/oecd.py | 50 +++---- pandas_datareader/tests/test_data.py | 61 +++++--- pandas_datareader/tests/test_oecd.py | 7 + pandas_datareader/tests/test_wb.py | 1 + pandas_datareader/wb.py | 40 +++--- pandas_datareader/yahoo/actions.py | 124 +++++++---------- pandas_datareader/yahoo/daily.py | 90 ++++++------ pandas_datareader/yahoo/quotes.py | 76 +++++----- setup.py | 2 +- 14 files changed, 509 insertions(+), 429 deletions(-) create mode 100644 pandas_datareader/base.py diff --git a/pandas_datareader/_utils.py b/pandas_datareader/_utils.py index 97f50ee2..bd665f70 100644 --- a/pandas_datareader/_utils.py +++ b/pandas_datareader/_utils.py @@ -1,21 +1,7 @@ -import time import warnings -import numpy as np -import datetime as dt -from pandas import to_datetime -import pandas.compat as compat from pandas.core.common import PandasError -from pandas import Panel, DataFrame -from pandas.io.common import urlopen -from pandas import read_csv -from pandas.compat import StringIO, bytes_to_str -from pandas.util.testing import _network_error_classes -if compat.PY3: - from urllib.parse import urlencode -else: - from urllib import urlencode class SymbolWarning(UserWarning): pass @@ -23,112 +9,4 @@ class SymbolWarning(UserWarning): class RemoteDataError(PandasError, IOError): pass -def _get_data_from(symbols, start, end, interval, retry_count, pause, - chunksize, src_fn): - # If a single symbol, (e.g., 'GOOG') - if isinstance(symbols, (compat.string_types, int)): - hist_data = src_fn(symbols, start, end, interval, retry_count, pause) - # Or multiple symbols, (e.g., ['GOOG', 'AAPL', 'MSFT']) - elif isinstance(symbols, DataFrame): - hist_data = _dl_mult_symbols(symbols.index, start, end, interval, chunksize, - retry_count, pause, src_fn) - else: - hist_data = _dl_mult_symbols(symbols, start, end, interval, chunksize, - retry_count, pause, src_fn) - return hist_data - -def _dl_mult_symbols(symbols, start, end, interval, chunksize, retry_count, pause, - method): - stocks = {} - failed = [] - passed = [] - for sym_group in _in_chunks(symbols, chunksize): - for sym in sym_group: - try: - stocks[sym] = method(sym, start, end, interval, retry_count, pause) - passed.append(sym) - except IOError: - warnings.warn('Failed to read symbol: {0!r}, replacing with ' - 'NaN.'.format(sym), SymbolWarning) - failed.append(sym) - - if len(passed) == 0: - raise RemoteDataError("No data fetched using " - "{0!r}".format(method.__name__)) - try: - if len(stocks) > 0 and len(failed) > 0 and len(passed) > 0: - df_na = stocks[passed[0]].copy() - df_na[:] = np.nan - for sym in failed: - stocks[sym] = df_na - return Panel(stocks).swapaxes('items', 'minor') - except AttributeError: - # cannot construct a panel with just 1D nans indicating no data - raise RemoteDataError("No data fetched using " - "{0!r}".format(method.__name__)) - - -def _sanitize_dates(start, end): - """ - Return (datetime_start, datetime_end) tuple - if start is None - default is 2010/01/01 - if end is None - default is today - """ - start = to_datetime(start) - end = to_datetime(end) - if start is None: - start = dt.datetime(2010, 1, 1) - if end is None: - end = dt.datetime.today() - return start, end - -def _in_chunks(seq, size): - """ - Return sequence in 'chunks' of size defined by size - """ - return (seq[pos:pos + size] for pos in range(0, len(seq), size)) - -def _encode_url(url, params): - """ - Return encoded url with parameters - """ - s_params = urlencode(params) - if s_params: - return url + '?' + s_params - else: - return url - -def _retry_read_url(url, retry_count, pause, name): - """ - Open url (and retry) - """ - for _ in range(retry_count): - - # kludge to close the socket ASAP - try: - with urlopen(url) as resp: - lines = resp.read() - except _network_error_classes: - pass - else: - rs = read_csv(StringIO(bytes_to_str(lines)), index_col=0, - parse_dates=True, na_values='-')[::-1] - # Yahoo! Finance sometimes does this awesome thing where they - # return 2 rows for the most recent business day - if len(rs) > 2 and rs.index[-1] == rs.index[-2]: # pragma: no cover - rs = rs[:-1] - - #Get rid of unicode characters in index name. - try: - rs.index.name = rs.index.name.decode('unicode_escape').encode('ascii', 'ignore') - except AttributeError: - #Python 3 string has no decode method. - rs.index.name = rs.index.name.encode('ascii', 'ignore').decode() - - return rs - - time.sleep(pause) - - raise IOError("after %d tries, %s did not " - "return a 200 for url %r" % (retry_count, name, url)) diff --git a/pandas_datareader/base.py b/pandas_datareader/base.py new file mode 100644 index 00000000..9a3ce0cd --- /dev/null +++ b/pandas_datareader/base.py @@ -0,0 +1,200 @@ +import time +import warnings +import numpy as np +import datetime as dt + +import requests + +from pandas import to_datetime +import pandas.compat as compat +from pandas.core.common import PandasError +from pandas import Panel, DataFrame +from pandas import read_csv +from pandas.compat import StringIO, bytes_to_str +from pandas.util.testing import _network_error_classes + +from pandas_datareader._utils import RemoteDataError, SymbolWarning + + +class _BaseReader(object): + + """ + + Parameters + ---------- + sym : string with a single Single stock symbol (ticker). + start : string, (defaults to '1/1/2010') + Starting date, timestamp. Parses many different kind of date + representations (e.g., 'JAN-01-2010', '1/1/10', 'Jan, 1, 1980') + end : string, (defaults to today) + Ending date, timestamp. Same format as starting date. + retry_count : int, default 3 + Number of times to retry query request. + pause : int, default 0 + Time, in seconds, of the pause between retries. + session : Session, default None + requests.sessions.Session instance to be used + """ + + _chunk_size = 1024 * 1024 + + def __init__(self, symbols, start=None, end=None, + retry_count=3, pause=0.001, session=None): + self.symbols = symbols + + start, end = self._sanitize_dates(start, end) + self.start = start + self.end = end + + if not isinstance(retry_count, int) or retry_count < 0: + raise ValueError("'retry_count' must be integer larger than 0") + self.retry_count = retry_count + self.pause = pause + self.session = self._init_session(session, retry_count) + + def _init_session(self, session, retry_count): + if session is None: + session = requests.Session() + # do not set requests max_retries here to support arbitrary pause + return session + + @property + def url(self): + # must be overridden in subclass + raise NotImplementedError + + @property + def params(self): + return None + + def read(self): + """ read data """ + return self._read_one_data(self.url, self.params) + + def _read_one_data(self, url, params): + """ read one data from specified URL """ + out = self._read_url_as_StringIO(self.url, params=params) + return self._read_lines(out) + + def _read_url_as_StringIO(self, url, params=None): + """ + Open url (and retry) + """ + response = self._get_response(url, params=params) + out = StringIO() + if isinstance(response.content, compat.binary_type): + out.write(bytes_to_str(response.content)) + else: + out.write(response.content) + out.seek(0) + return out + + def _get_response(self, url, params=None): + """ send raw HTTP request to get requests.Response from the specified url + Parameters + ---------- + url : str + target URL + params : dict or None + parameters passed to the URL + """ + + # initial attempt + retry + for i in range(self.retry_count + 1): + response = self.session.get(url, params=params) + if response.status_code == requests.codes.ok: + return response + time.sleep(self.pause) + + raise RemoteDataError('Unable to read URL: {0}'.format(url)) + + def _read_lines(self, out): + rs = read_csv(out, index_col=0, parse_dates=True, na_values='-')[::-1] + # Yahoo! Finance sometimes does this awesome thing where they + # return 2 rows for the most recent business day + if len(rs) > 2 and rs.index[-1] == rs.index[-2]: # pragma: no cover + rs = rs[:-1] + #Get rid of unicode characters in index name. + try: + rs.index.name = rs.index.name.decode('unicode_escape').encode('ascii', 'ignore') + except AttributeError: + #Python 3 string has no decode method. + rs.index.name = rs.index.name.encode('ascii', 'ignore').decode() + return rs + + def _sanitize_dates(self, start, end): + """ + Return (datetime_start, datetime_end) tuple + if start is None - default is 2010/01/01 + if end is None - default is today + """ + start = to_datetime(start) + end = to_datetime(end) + if start is None: + start = dt.datetime(2010, 1, 1) + if end is None: + end = dt.datetime.today() + return start, end + + +class _DailyBaseReader(_BaseReader): + """ Base class for Google / Yahoo daily reader """ + + def __init__(self, symbols=None, start=None, end=None, retry_count=3, + pause=0.001, session=None, chunksize=25): + super(_DailyBaseReader, self).__init__(symbols=symbols, + start=start, end=end, + retry_count=retry_count, + pause=pause, session=session) + self.chunksize = chunksize + + def _get_params(self, *args, **kwargs): + raise NotImplementedError + + def read(self): + """ read data """ + # If a single symbol, (e.g., 'GOOG') + if isinstance(self.symbols, (compat.string_types, int)): + df = self._read_one_data(self.url, params=self._get_params(self.symbols)) + # Or multiple symbols, (e.g., ['GOOG', 'AAPL', 'MSFT']) + elif isinstance(self.symbols, DataFrame): + df = self._dl_mult_symbols(self.symbols.index) + else: + df = self._dl_mult_symbols(self.symbols) + return df + + def _dl_mult_symbols(self, symbols): + stocks = {} + failed = [] + passed = [] + for sym_group in _in_chunks(symbols, self.chunksize): + for sym in sym_group: + try: + stocks[sym] = self._read_one_data(self.url, self._get_params(sym)) + passed.append(sym) + except IOError: + msg = 'Failed to read symbol: {0!r}, replacing with NaN.' + warnings.warn(msg.format(sym), SymbolWarning) + failed.append(sym) + + if len(passed) == 0: + msg = "No data fetched using {0!r}" + raise RemoteDataError(msg.format(self.__class__.__name__)) + try: + if len(stocks) > 0 and len(failed) > 0 and len(passed) > 0: + df_na = stocks[passed[0]].copy() + df_na[:] = np.nan + for sym in failed: + stocks[sym] = df_na + return Panel(stocks).swapaxes('items', 'minor') + except AttributeError: + # cannot construct a panel with just 1D nans indicating no data + msg = "No data fetched using {0!r}" + raise RemoteDataError(msg.format(self.__class__.__name__)) + + +def _in_chunks(seq, size): + """ + Return sequence in 'chunks' of size defined by size + """ + return (seq[pos:pos + size] for pos in range(0, len(seq), size)) \ No newline at end of file diff --git a/pandas_datareader/data.py b/pandas_datareader/data.py index fe2a3d14..2f8a99c8 100644 --- a/pandas_datareader/data.py +++ b/pandas_datareader/data.py @@ -1,29 +1,42 @@ """ Module contains tools for collecting data from various remote sources - - """ import warnings -from pandas_datareader._utils import _sanitize_dates - -from pandas_datareader.google.daily import _get_data as get_data_google +from pandas_datareader.google.daily import GoogleDailyReader from pandas_datareader.google.quotes import _get_data as get_quote_google -from pandas_datareader.yahoo.daily import _get_data as get_data_yahoo -from pandas_datareader.yahoo.quotes import _get_data as get_quote_yahoo -from pandas_datareader.yahoo.actions import _get_data as get_data_yahoo_actions +from pandas_datareader.yahoo.daily import YahooDailyReader +from pandas_datareader.yahoo.quotes import YahooQuotesReader +from pandas_datareader.yahoo.actions import YahooActionReader from pandas_datareader.yahoo.components import _get_data as get_components_yahoo from pandas_datareader.yahoo.options import Options as YahooOptions -from pandas_datareader.fred import _get_data as get_data_fred +from pandas_datareader.fred import FredReader from pandas_datareader.famafrench import _get_data as get_data_famafrench -from pandas_datareader.oecd import _get_data as get_data_oecd +from pandas_datareader.oecd import OECDReader + + +# ToDo: deprecate +def get_data_fred(*args, **kwargs): + return FredReader(*args, **kwargs).read() + +def get_data_google(*args, **kwargs): + return GoogleDailyReader(*args, **kwargs).read() + +def get_data_yahoo(*args, **kwargs): + return YahooDailyReader(*args, **kwargs).read() + +def get_data_yahoo_actions(*args, **kwargs): + return YahooActionReader(*args, **kwargs).read() + +def get_quote_yahoo(*args, **kwargs): + return YahooQuotesReader(*args, **kwargs).read() def DataReader(name, data_source=None, start=None, end=None, - retry_count=3, pause=0.001): + retry_count=3, pause=0.001, session=None): """ Imports data from a number of online sources. @@ -46,6 +59,8 @@ def DataReader(name, data_source=None, start=None, end=None, pause : {numeric, 0.001} Time, in seconds, to pause between consecutive queries of chunks. If single value given for symbol, represents the pause between retries. + session : Session, default None + requests.sessions.Session instance to be used Examples ---------- @@ -68,24 +83,30 @@ def DataReader(name, data_source=None, start=None, end=None, ff = DataReader("6_Portfolios_2x3", "famafrench") ff = DataReader("F-F_ST_Reversal_Factor", "famafrench") """ - start, end = _sanitize_dates(start, end) - if data_source == "yahoo": - return get_data_yahoo(symbols=name, start=start, end=end, - adjust_price=False, chunksize=25, - retry_count=retry_count, pause=pause) + return YahooDailyReader(symbols=name, start=start, end=end, + adjust_price=False, chunksize=25, + retry_count=retry_count, pause=pause, + session=session).read() elif data_source == "yahoo-actions": - return get_data_yahoo_actions(symbol=name, start=start, end=end, - retry_count=retry_count, pause=pause) + return YahooActionReader(symbol=name, start=start, end=end, + retry_count=retry_count, pause=pause, + session=session).read() elif data_source == "google": - return get_data_google(symbols=name, start=start, end=end, - chunksize=25, retry_count=retry_count, pause=pause) + return GoogleDailyReader(symbols=name, start=start, end=end, + chunksize=25, + retry_count=retry_count, pause=pause, + session=session).read() elif data_source == "fred": - return get_data_fred(name, start, end) + return FredReader(symbols=name, start=start, end=end, + retry_count=retry_count, pause=pause, + session=session).read() elif data_source == "famafrench": return get_data_famafrench(name) elif data_source == "oecd": - return get_data_oecd(name, start, end) + return OECDReader(symbols=name, start=start, end=end, + retry_count=retry_count, pause=pause, + session=session).read() else: raise NotImplementedError( "data_source=%r is not implemented" % data_source) @@ -95,7 +116,7 @@ def DataReader(name, data_source=None, start=None, end=None, def Options(symbol, data_source=None): if data_source is None: warnings.warn("Options(symbol) is deprecated, use Options(symbol," - " data_source) instead", FutureWarning, stacklevel=2) + " data_source) instead", FutureWarning, stacklevel=2) data_source = "yahoo" if data_source == "yahoo": return YahooOptions(symbol) diff --git a/pandas_datareader/fred.py b/pandas_datareader/fred.py index fbda60bb..23e3e561 100644 --- a/pandas_datareader/fred.py +++ b/pandas_datareader/fred.py @@ -1,15 +1,11 @@ -import datetime as dt from pandas.core.common import is_list_like -from pandas.io.common import urlopen from pandas import concat, read_csv -from pandas_datareader._utils import _sanitize_dates +from pandas_datareader.base import _BaseReader -_URL = "http://research.stlouisfed.org/fred2/series/" +class FredReader(_BaseReader): -def _get_data(name, start=dt.datetime(2010, 1, 1), - end=dt.datetime.today()): """ Get data for the given name from the St. Louis FED (FRED). Date format is datetime @@ -19,28 +15,32 @@ def _get_data(name, start=dt.datetime(2010, 1, 1), If multiple names are passed for "series" then the index of the DataFrame is the outer join of the indicies of each series. """ - start, end = _sanitize_dates(start, end) - if not is_list_like(name): - names = [name] - else: - names = name + @property + def url(self): + return "http://research.stlouisfed.org/fred2/series/" - urls = [_URL + '%s' % n + '/downloaddata/%s' % n + '.csv' for - n in names] + def read(self): + if not is_list_like(self.symbols): + names = [self.symbols] + else: + names = self.symbols - def fetch_data(url, name): - with urlopen(url) as resp: + urls = [self.url + '%s' % n + '/downloaddata/%s' % n + '.csv' for + n in names] + + def fetch_data(url, name): + resp = self._read_url_as_StringIO(url) data = read_csv(resp, index_col=0, parse_dates=True, header=None, skiprows=1, names=["DATE", name], na_values='.') - try: - return data.truncate(start, end) - except KeyError: # pragma: no cover - if data.ix[3].name[7:12] == 'Error': - raise IOError("Failed to get the data. Check that {0!r} is " - "a valid FRED series.".format(name)) - raise - df = concat([fetch_data(url, n) for url, n in zip(urls, names)], - axis=1, join='outer') - return df + try: + return data.truncate(self.start, self.end) + except KeyError: # pragma: no cover + if data.ix[3].name[7:12] == 'Error': + raise IOError("Failed to get the data. Check that {0!r} is " + "a valid FRED series.".format(name)) + raise + df = concat([fetch_data(url, n) for url, n in zip(urls, names)], + axis=1, join='outer') + return df diff --git a/pandas_datareader/google/daily.py b/pandas_datareader/google/daily.py index dd8c7daa..b232a072 100644 --- a/pandas_datareader/google/daily.py +++ b/pandas_datareader/google/daily.py @@ -1,13 +1,8 @@ -from pandas.io.common import urlencode -from pandas_datareader._utils import ( - _retry_read_url, _encode_url, _sanitize_dates, _get_data_from -) +from pandas_datareader.base import _DailyBaseReader -_URL = 'http://www.google.com/finance/historical' +class GoogleDailyReader(_DailyBaseReader): -def _get_data(symbols=None, start=None, end=None, retry_count=3, - pause=0.001, chunksize=25): """ Returns DataFrame/Panel of historical stock prices from symbols, over date range, start to end. To avoid being penalized by Google Finance servers, @@ -30,30 +25,19 @@ def _get_data(symbols=None, start=None, end=None, retry_count=3, single value given for symbol, represents the pause between retries. chunksize : int, default 25 Number of symbols to download consecutively before intiating pause. - - Returns - ------- - hist_data : DataFrame (str) or Panel (array-like object, DataFrame) - """ - return _get_data_from(symbols, start, end, None, retry_count, pause, - chunksize, _get_data_one) - - -def _get_data_one(sym, start, end, interval, retry_count, pause): - """ - Get historical data for the given name from google. - Date format is datetime - - Returns a DataFrame. + session : Session, default None + requests.sessions.Session instance to be used """ - start, end = _sanitize_dates(start, end) - # www.google.com/finance/historical?q=GOOG&startdate=Jun+9%2C+2011&enddate=Jun+8%2C+2013&output=csv - params = { - 'q': sym, - 'startdate': start.strftime('%b %d, %Y'), - 'enddate': end.strftime('%b %d, %Y'), - 'output': "csv" - } - url = _encode_url(_URL, params) - return _retry_read_url(url, retry_count, pause, 'Google') + @property + def url(self): + return 'http://www.google.com/finance/historical' + + def _get_params(self, symbol): + params = { + 'q': symbol, + 'startdate': self.start.strftime('%b %d, %Y'), + 'enddate': self.end.strftime('%b %d, %Y'), + 'output': "csv" + } + return params diff --git a/pandas_datareader/oecd.py b/pandas_datareader/oecd.py index 87905691..6b692115 100644 --- a/pandas_datareader/oecd.py +++ b/pandas_datareader/oecd.py @@ -3,45 +3,37 @@ import pandas as pd from pandas.core.common import is_list_like import pandas.compat as compat -from pandas.io.common import _urlopen from pandas import concat, read_csv -from pandas_datareader._utils import _sanitize_dates from pandas_datareader.io import read_jsdmx -_URL = 'http://stats.oecd.org/SDMX-JSON/data' +from pandas_datareader.base import _BaseReader -def _get_data(name, start=dt.datetime(2010, 1, 1), - end=dt.datetime.today()): - """ - Get data for the given name from OECD. - Date format is datetime +class OECDReader(_BaseReader): - Returns a DataFrame. - """ - start, end = _sanitize_dates(start, end) + """Get data for the given name from OECD.""" - if not isinstance(name, compat.string_types): - raise ValueError('data name must be string') + @property + def url(self): + url = 'http://stats.oecd.org/SDMX-JSON/data' + if not isinstance(self.symbols, compat.string_types): + raise ValueError('data name must be string') - # API: https://data.oecd.org/api/sdmx-json-documentation/ - url = '{0}/{1}/all/all?'.format(_URL, name) - def fetch_data(url, name): - resp = _urlopen(url) - resp = resp.read() - resp = resp.decode('utf-8') - data = read_jsdmx(resp) + # API: https://data.oecd.org/api/sdmx-json-documentation/ + return '{0}/{1}/all/all?'.format(url, self.symbols) + + def _read_one_data(self, url, params): + """ read one data from specified URL """ + resp = self._get_response(url) + df = read_jsdmx(resp.json()) try: - idx_name = data.index.name # hack for pandas 0.16.2 - data.index = pd.to_datetime(data.index) - data = data.sort_index() - data = data.truncate(start, end) - data.index.name = idx_name + idx_name = df.index.name # hack for pandas 0.16.2 + df.index = pd.to_datetime(df.index) + df = df.sort_index() + df = df.truncate(self.start, self.end) + df.index.name = idx_name except ValueError: pass - return data - df = fetch_data(url, name) - return df - + return df diff --git a/pandas_datareader/tests/test_data.py b/pandas_datareader/tests/test_data.py index 65e24a4f..00dd609d 100644 --- a/pandas_datareader/tests/test_data.py +++ b/pandas_datareader/tests/test_data.py @@ -1,4 +1,3 @@ -from __future__ import print_function from pandas import compat import warnings import nose @@ -6,6 +5,8 @@ from datetime import datetime import os +import requests + import numpy as np import pandas as pd from pandas import DataFrame, Timestamp @@ -18,16 +19,14 @@ import pandas.util.testing as tm from numpy.testing import assert_array_equal -try: - from urllib.error import HTTPError -except ImportError: # pragma: no cover - from urllib2 import HTTPError - import pandas_datareader.data as web -from pandas_datareader.data import DataReader +from pandas_datareader.data import (DataReader, GoogleDailyReader, YahooDailyReader, + YahooQuotesReader, YahooActionReader, + FredReader, OECDReader) from pandas_datareader._utils import SymbolWarning, RemoteDataError from pandas_datareader.yahoo.quotes import _yahoo_codes + def _skip_if_no_lxml(): try: import lxml @@ -47,6 +46,7 @@ def assert_n_failed_equals_n_null_columns(wngs, obj, cls=SymbolWarning): class TestGoogle(tm.TestCase): + @classmethod def setUpClass(cls): super(TestGoogle, cls).setUpClass() @@ -133,7 +133,21 @@ def test_dtypes(self): def test_unicode_date(self): #GH8967 data = web.get_data_google('F', start='JAN-01-10', end='JAN-27-13') - self.assertEquals(data.index.name, 'Date') + self.assertEqual(data.index.name, 'Date') + + def test_google_reader_class(self): + r = GoogleDailyReader('GOOG') + df = r.read() + self.assertEqual(df.Volume.ix['JAN-02-2015'], 1446662) + + session = requests.Session() + r = GoogleDailyReader('GOOG', session=session) + self.assertTrue(r.session is session) + + def test_bad_retry_count(self): + + with tm.assertRaises(ValueError): + web.get_data_google('F', retry_count = -1) class TestYahoo(tm.TestCase): @@ -296,6 +310,15 @@ def test_get_data_yahoo_actions_invalid_symbol(self): self.assertRaises(IOError, web.get_data_yahoo_actions, 'UNKNOWN TICKER', start, end) + def test_yahoo_reader_class(self): + r = YahooDailyReader('GOOG') + df = r.read() + self.assertEqual(df.Volume.loc['JAN-02-2015'], 1447600) + + session = requests.Session() + r = YahooDailyReader('GOOG', session=session) + self.assertTrue(r.session is session) + class TestYahooOptions(tm.TestCase): @classmethod @@ -396,37 +419,37 @@ def test_get_underlying_price(self): quote_price = options_object._underlying_price_from_root(root) except RemoteDataError as e: # pragma: no cover raise nose.SkipTest(e) - self.assert_(isinstance(quote_price, float)) + self.assertTrue(isinstance(quote_price, float)) def test_sample_page_price_quote_time1(self): #Tests the weekend quote time format price, quote_time = self.aapl._underlying_price_and_time_from_url(self.html1) - self.assert_(isinstance(price, (int, float, complex))) - self.assert_(isinstance(quote_time, (datetime, Timestamp))) + self.assertTrue(isinstance(price, (int, float, complex))) + self.assertTrue(isinstance(quote_time, (datetime, Timestamp))) def test_chop(self): #regression test for #7625 self.aapl.chop_data(self.data1, above_below=2, underlying_price=np.nan) chopped = self.aapl.chop_data(self.data1, above_below=2, underlying_price=100) - self.assert_(isinstance(chopped, DataFrame)) + self.assertTrue(isinstance(chopped, DataFrame)) self.assertTrue(len(chopped) > 1) chopped2 = self.aapl.chop_data(self.data1, above_below=2, underlying_price=None) - self.assert_(isinstance(chopped2, DataFrame)) + self.assertTrue(isinstance(chopped2, DataFrame)) self.assertTrue(len(chopped2) > 1) def test_chop_out_of_strike_range(self): #regression test for #7625 self.aapl.chop_data(self.data1, above_below=2, underlying_price=np.nan) chopped = self.aapl.chop_data(self.data1, above_below=2, underlying_price=100000) - self.assert_(isinstance(chopped, DataFrame)) + self.assertTrue(isinstance(chopped, DataFrame)) self.assertTrue(len(chopped) > 1) def test_sample_page_price_quote_time2(self): #Tests the EDT page format #regression test for #8741 price, quote_time = self.aapl._underlying_price_and_time_from_url(self.html2) - self.assert_(isinstance(price, (int, float, complex))) - self.assert_(isinstance(quote_time, (datetime, Timestamp))) + self.assertTrue(isinstance(price, (int, float, complex))) + self.assertTrue(isinstance(quote_time, (datetime, Timestamp))) def test_sample_page_chg_float(self): #Tests that numeric columns with comma's are appropriately dealt with @@ -501,8 +524,8 @@ def test_fred(self): #self.assertEqual(int(received), 16502) self.assertEqual(int(received), 16440) - self.assertRaises(Exception, web.DataReader, "NON EXISTENT SERIES", - 'fred', start, end) + with tm.assertRaises(RemoteDataError): + web.DataReader("NON EXISTENT SERIES", 'fred', start, end) def test_fred_nan(self): start = datetime(2010, 1, 1) @@ -550,7 +573,7 @@ def test_fred_multi(self): # pragma: no cover def test_fred_multi_bad_series(self): names = ['NOTAREALSERIES', 'CPIAUCSL', "ALSO FAKE"] - with tm.assertRaises(HTTPError): + with tm.assertRaises(RemoteDataError): DataReader(names, data_source="fred") diff --git a/pandas_datareader/tests/test_oecd.py b/pandas_datareader/tests/test_oecd.py index 971dbbd2..2f9d73f2 100644 --- a/pandas_datareader/tests/test_oecd.py +++ b/pandas_datareader/tests/test_oecd.py @@ -4,6 +4,7 @@ from pandas.compat import range import pandas.util.testing as tm import pandas_datareader.data as web +from pandas_datareader._utils import SymbolWarning, RemoteDataError class TestOECD(tm.TestCase): @@ -62,6 +63,12 @@ def test_get_tourism(self): expected = pd.Series(values, index=index, name='Total international arrivals') tm.assert_series_equal(df[label]['Total international arrivals'], expected) + def test_oecd_invalid_symbol(self): + with tm.assertRaises(RemoteDataError): + web.DataReader('INVALID_KEY', 'oecd') + + with tm.assertRaises(ValueError): + web.DataReader(1234, 'oecd') if __name__ == '__main__': import nose diff --git a/pandas_datareader/tests/test_wb.py b/pandas_datareader/tests/test_wb.py index a3f85fef..1515b0f8 100644 --- a/pandas_datareader/tests/test_wb.py +++ b/pandas_datareader/tests/test_wb.py @@ -107,6 +107,7 @@ def test_wdi_get_countries(self): self.assertTrue(pandas.notnull(result.latitude.mean())) self.assertTrue(pandas.notnull(result.longitude.mean())) + if __name__ == '__main__': nose.runmodule(argv=[__file__, '-vvs', '-x', '--pdb', '--pdb-failure'], exit=False) # pragma: no cover diff --git a/pandas_datareader/wb.py b/pandas_datareader/wb.py index c975334e..864363f0 100644 --- a/pandas_datareader/wb.py +++ b/pandas_datareader/wb.py @@ -1,7 +1,5 @@ # -*- coding: utf-8 -*- -from __future__ import print_function - import warnings from pandas.compat import reduce, lrange @@ -12,8 +10,8 @@ # This list of country codes was pulled from wikipedia during October 2014. # While some exceptions do exist, it is the best proxy for countries supported -# by World Bank. It is an aggregation of the 2-digit ISO 3166-1 alpha-2, and -# 3-digit ISO 3166-1 alpha-3, codes, with 'all', 'ALL', and 'All' appended ot +# by World Bank. It is an aggregation of the 2-digit ISO 3166-1 alpha-2, and +# 3-digit ISO 3166-1 alpha-3, codes, with 'all', 'ALL', and 'All' appended ot # the end. country_codes = ['AD', 'AE', 'AF', 'AG', 'AI', 'AL', 'AM', 'AO', 'AQ', 'AR', \ @@ -84,35 +82,35 @@ def download(country=['MX', 'CA', 'US'], indicator=['NY.GDP.MKTP.CD', 'NY.GNS.IC indicator: string or list of strings taken from the ``id`` field in ``WDIsearch()`` - + country: string or list of strings. ``all`` downloads data for all countries 2 or 3 character ISO country codes select individual countries (e.g.``US``,``CA``) or (e.g.``USA``,``CAN``). The codes can be mixed. - + The two ISO lists of countries, provided by wikipedia, are hardcoded into pandas as of 11/10/2014. - + start: int First year of the data series - + end: int Last year of the data series (inclusive) - + errors: str {'ignore', 'warn', 'raise'}, default 'warn' Country codes are validated against a hardcoded list. This controls the outcome of that validation, and attempts to also apply to the results from world bank. - + errors='raise', will raise a ValueError on a bad country code. - + Returns ------- - ``pandas`` DataFrame with columns: country, iso_code, year, + ``pandas`` DataFrame with columns: country, iso_code, year, indicator value. - + """ if type(country) == str: @@ -131,7 +129,7 @@ def download(country=['MX', 'CA', 'US'], indicator=['NY.GDP.MKTP.CD', 'NY.GNS.IC # Work with a list of indicators if type(indicator) == str: indicator = [indicator] - + # Download data = [] bad_indicators = {} @@ -166,12 +164,12 @@ def download(country=['MX', 'CA', 'US'], indicator=['NY.GDP.MKTP.CD', 'NY.GNS.IC def _get_data(indicator="NY.GNS.ICTR.GN.ZS", country='US', start=2002, end=2005): - + if type(country) == str: country = [country] - + countries = ';'.join(country) - + # Build URL for api call url = ("http://api.worldbank.org/countries/" + countries + "/indicators/" + indicator + "?date=" + str(start) + ":" + str(end) + @@ -196,11 +194,11 @@ def _get_data(indicator="NY.GNS.ICTR.GN.ZS", country='US', wb_err += msg['value'] error_msg = "Problem with a World Bank Query \n %s" return None, error_msg % wb_err - + if 'total' in possible_message.keys(): if possible_message['total'] == 0: return None, "No results from world bank." - + # Parse JSON file data = json.loads(data)[1] country = [x['country']['value'] for x in data] @@ -214,8 +212,8 @@ def _get_data(indicator="NY.GNS.ICTR.GN.ZS", country='US', def get_countries(): '''Query information about countries - - Provides information such as: + + Provides information such as: country code, region, income level, capital city, latitude and longitude ''' url = 'http://api.worldbank.org/countries/?per_page=1000&format=json' diff --git a/pandas_datareader/yahoo/actions.py b/pandas_datareader/yahoo/actions.py index b4ff03ca..2617c981 100644 --- a/pandas_datareader/yahoo/actions.py +++ b/pandas_datareader/yahoo/actions.py @@ -1,87 +1,63 @@ -import time import csv from pandas import to_datetime, DataFrame -from pandas.io.common import urlopen -from pandas.util.testing import _network_error_classes from pandas.compat import StringIO, bytes_to_str -from pandas_datareader._utils import _sanitize_dates, _encode_url +from pandas_datareader.base import _BaseReader -_URL = 'http://ichart.finance.yahoo.com/x' +class YahooActionReader(_BaseReader): -def _get_data(symbol, start=None, end=None, retry_count=3, pause=0.001): """ Returns DataFrame of historical corporate actions (dividends and stock splits) from symbols, over date range, start to end. All dates in the resulting DataFrame correspond with dividend and stock split ex-dates. - - Parameters - ---------- - sym : string with a single Single stock symbol (ticker). - start : string, (defaults to '1/1/2010') - Starting date, timestamp. Parses many different kind of date - representations (e.g., 'JAN-01-2010', '1/1/10', 'Jan, 1, 1980') - end : string, (defaults to today) - Ending date, timestamp. Same format as starting date. - retry_count : int, default 3 - Number of times to retry query request. - pause : int, default 0 - Time, in seconds, of the pause between retries. """ - start, end = _sanitize_dates(start, end) - params = { - 's': symbol, - 'a': start.month - 1, - 'b': start.day, - 'c': start.year, - 'd': end.month - 1, - 'e': end.day, - 'f': end.year, - 'g': 'v' - } - url = _encode_url(_URL, params) - - for _ in range(retry_count): - - try: - with urlopen(url) as resp: - lines = resp.read() - except _network_error_classes: - pass - else: - actions_index = [] - actions_entries = [] - - for line in csv.reader(StringIO(bytes_to_str(lines))): - # Ignore lines that aren't dividends or splits (Yahoo - # add a bunch of irrelevant fields.) - if len(line) != 3 or line[0] not in ('DIVIDEND', 'SPLIT'): - continue - - action, date, value = line - if action == 'DIVIDEND': - actions_index.append(to_datetime(date)) - actions_entries.append({ - 'action': action, - 'value': float(value) - }) - elif action == 'SPLIT' and ':' in value: - # Convert the split ratio to a fraction. For example a - # 4:1 split expressed as a fraction is 1/4 = 0.25. - denominator, numerator = value.split(':', 1) - split_fraction = float(numerator) / float(denominator) - - actions_index.append(to_datetime(date)) - actions_entries.append({ - 'action': action, - 'value': split_fraction - }) - - return DataFrame(actions_entries, index=actions_index) - - time.sleep(pause) - - raise IOError("after %d tries, Yahoo! did not " \ - "return a 200 for url %r" % (retry_count, url)) + @property + def url(self): + return 'http://ichart.finance.yahoo.com/x' + + @property + def params(self): + params = { + 's': self.symbols, + 'a': self.start.month - 1, + 'b': self.start.day, + 'c': self.start.year, + 'd': self.end.month - 1, + 'e': self.end.day, + 'f': self.end.year, + 'g': 'v' + } + return params + + def _read_lines(self, out): + actions_index = [] + actions_entries = [] + + for line in csv.reader(out.readlines()): + # Ignore lines that aren't dividends or splits (Yahoo + # add a bunch of irrelevant fields.) + if len(line) != 3 or line[0] not in ('DIVIDEND', 'SPLIT'): + continue + + action, date, value = line + if action == 'DIVIDEND': + actions_index.append(to_datetime(date)) + actions_entries.append({ + 'action': action, + 'value': float(value) + }) + elif action == 'SPLIT' and ':' in value: + # Convert the split ratio to a fraction. For example a + # 4:1 split expressed as a fraction is 1/4 = 0.25. + denominator, numerator = value.split(':', 1) + split_fraction = float(numerator) / float(denominator) + + actions_index.append(to_datetime(date)) + actions_entries.append({ + 'action': action, + 'value': split_fraction + }) + + return DataFrame(actions_entries, index=actions_index) diff --git a/pandas_datareader/yahoo/daily.py b/pandas_datareader/yahoo/daily.py index 4a9708e2..2fb3c23c 100644 --- a/pandas_datareader/yahoo/daily.py +++ b/pandas_datareader/yahoo/daily.py @@ -1,12 +1,8 @@ -from pandas_datareader._utils import ( - _retry_read_url, _encode_url, _sanitize_dates, _get_data_from -) +from pandas_datareader.base import _DailyBaseReader -_URL = 'http://ichart.finance.yahoo.com/table.csv' -def _get_data(symbols=None, start=None, end=None, retry_count=3, - pause=0.001, adjust_price=False, ret_index=False, - chunksize=25, interval='d'): +class YahooDailyReader(_DailyBaseReader): + """ Returns DataFrame/Panel of historical stock prices from symbols, over date range, start to end. To avoid being penalized by Yahoo! Finance servers, @@ -27,6 +23,8 @@ def _get_data(symbols=None, start=None, end=None, retry_count=3, pause : int, default 0 Time, in seconds, to pause between consecutive queries of chunks. If single value given for symbol, represents the pause between retries. + session : Session, default None + requests.sessions.Session instance to be used adjust_price : bool, default False If True, adjusts all prices in hist_data ('Open', 'High', 'Low', 'Close') based on 'Adj Close' price. Adds 'Adj_Ratio' column and drops @@ -38,20 +36,51 @@ def _get_data(symbols=None, start=None, end=None, retry_count=3, interval : string, default 'd' Time interval code, valid values are 'd' for daily, 'w' for weekly, 'm' for monthly and 'v' for dividend. - - Returns - ------- - hist_data : DataFrame (str) or Panel (array-like object, DataFrame) """ - if interval not in ['d', 'w', 'm', 'v']: - raise ValueError("Invalid interval: valid values are 'd', 'w', 'm' and 'v'") - hist_data = _get_data_from(symbols, start, end, interval, retry_count, pause, \ - chunksize, _get_data_one) - if ret_index: - hist_data['Ret_Index'] = _calc_return_index(hist_data['Adj Close']) - if adjust_price: - hist_data = _adjust_prices(hist_data) - return hist_data + + def __init__(self, symbols=None, start=None, end=None, retry_count=3, + pause=0.001, session=None, adjust_price=False, ret_index=False, + chunksize=25, interval='d'): + super(YahooDailyReader, self).__init__(symbols=symbols, + start=start, end=end, + retry_count=retry_count, + pause=pause, session=session, + chunksize=chunksize) + self.adjust_price = adjust_price + self.ret_index = ret_index + + if interval not in ['d', 'w', 'm', 'v']: + raise ValueError("Invalid interval: valid values are 'd', 'w', 'm' and 'v'") + self.interval = interval + + @property + def url(self): + return 'http://ichart.finance.yahoo.com/table.csv' + + def _get_params(self, symbol): + params = { + 's': symbol, + 'a': self.start.month - 1, + 'b': self.start.day, + 'c': self.start.year, + 'd': self.end.month - 1, + 'e': self.end.day, + 'f': self.end.year, + 'g': self.interval, + 'ignore': '.csv' + } + return params + + + def read(self): + """ read one data from specified URL """ + df = super(YahooDailyReader, self).read() + if self.ret_index: + df['Ret_Index'] = _calc_return_index(df['Adj Close']) + if self.adjust_price: + df = _adjust_prices(df) + return df + def _adjust_prices(hist_data, price_list=None): """ @@ -88,24 +117,3 @@ def _calc_return_index(price_df): return df -def _get_data_one(sym, start, end, interval, retry_count, pause): - """ - Get historical data for the given name from yahoo. - Date format is datetime - - Returns a DataFrame. - """ - start, end = _sanitize_dates(start, end) - params = { - 's': sym, - 'a': start.month - 1, - 'b': start.day, - 'c': start.year, - 'd': end.month - 1, - 'e': end.day, - 'f': end.year, - 'g': interval, - 'ignore': '.csv' - } - url = _encode_url(_URL, params) - return _retry_read_url(url, retry_count, pause, 'Yahoo!') diff --git a/pandas_datareader/yahoo/quotes.py b/pandas_datareader/yahoo/quotes.py index 9cf54b4f..6413dda8 100644 --- a/pandas_datareader/yahoo/quotes.py +++ b/pandas_datareader/yahoo/quotes.py @@ -2,59 +2,51 @@ import csv import pandas.compat as compat -from pandas.io.common import urlopen from pandas import DataFrame -from pandas_datareader._utils import _encode_url + _yahoo_codes = {'symbol': 's', 'last': 'l1', 'change_pct': 'p2', 'PE': 'r', 'time': 't1', 'short_ratio': 's7'} -_URL = 'http://finance.yahoo.com/d/quotes.csv' - - -def _get_data(symbols): - """ - Get current yahoo quote +from pandas_datareader.base import _BaseReader - Returns a DataFrame - """ - if isinstance(symbols, compat.string_types): - sym_list = symbols - else: - sym_list = '+'.join(symbols) - # for codes see: http://www.gummy-stuff.org/Yahoo-data.htm - request = ''.join(compat.itervalues(_yahoo_codes)) # code request string - header = list(_yahoo_codes.keys()) +class YahooQuotesReader(_BaseReader): - data = defaultdict(list) + """Get current yahoo quote""" - params = { - 's': sym_list, - 'f': request - } - url = _encode_url(_URL, params) + @property + def url(self): + return 'http://finance.yahoo.com/d/quotes.csv' - with urlopen(url) as response: - lines = response.readlines() + @property + def params(self): + if isinstance(self.symbols, compat.string_types): + sym_list = self.symbols + else: + sym_list = '+'.join(self.symbols) + # for codes see: http://www.gummy-stuff.org/Yahoo-data.htm + request = ''.join(compat.itervalues(_yahoo_codes)) # code request string + params = {'s': sym_list, 'f': request} + return params - def line_gen(lines): - for line in lines: - yield line.decode('utf-8').strip() + def _read_lines(self, out): + data = defaultdict(list) + header = list(_yahoo_codes.keys()) - for line in csv.reader(line_gen(lines)): - for i, field in enumerate(line): - if field[-2:] == '%"': - v = float(field.strip('"%')) - elif field[0] == '"': - v = field.strip('"') - else: - try: - v = float(field) - except ValueError: - v = field - data[header[i]].append(v) + for line in csv.reader(out.readlines()): + for i, field in enumerate(line): + if field[-2:] == '%"': + v = float(field.strip('"%')) + elif field[0] == '"': + v = field.strip('"') + else: + try: + v = float(field) + except ValueError: + v = field + data[header[i]].append(v) - idx = data.pop('symbol') - return DataFrame(data, index=idx) + idx = data.pop('symbol') + return DataFrame(data, index=idx) diff --git a/setup.py b/setup.py index 5eb6874f..642a46cb 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,7 @@ def readme(): return f.read() INSTALL_REQUIRES = ( - ['pandas'] + ['pandas', 'requests'] ) setup(