Skip to content

CLN: Cleanup WorldBank reader #121

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 25, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions pandas_datareader/_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import warnings

import pandas as pd
from pandas.core.common import PandasError


Expand All @@ -10,3 +9,21 @@ class RemoteDataError(PandasError, IOError):
pass


from distutils.version import LooseVersion

# Version feature flags: each PANDAS_0xx0 constant is True when the installed
# pandas is at least the corresponding 0.xx.0 release, so other modules can
# branch on it instead of re-parsing the version string.
# NOTE(review): LooseVersion comes from distutils, which newer Python
# releases deprecate -- confirm the supported interpreter range.
PANDAS_VERSION = LooseVersion(pd.__version__)

PANDAS_0170 = PANDAS_VERSION >= LooseVersion('0.17.0')
PANDAS_0160 = PANDAS_VERSION >= LooseVersion('0.16.0')
PANDAS_0140 = PANDAS_VERSION >= LooseVersion('0.14.0')
17 changes: 15 additions & 2 deletions pandas_datareader/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from pandas import to_datetime
import pandas.compat as compat
from pandas.core.common import PandasError
from pandas.core.common import PandasError, is_number
from pandas import Panel, DataFrame
from pandas import read_csv
from pandas.compat import StringIO, bytes_to_str
Expand Down Expand Up @@ -37,6 +37,7 @@ class _BaseReader(object):
"""

_chunk_size = 1024 * 1024
_format = 'string'

def __init__(self, symbols, start=None, end=None,
retry_count=3, pause=0.001, session=None):
Expand Down Expand Up @@ -73,7 +74,12 @@ def read(self):

def _read_one_data(self, url, params):
""" read one data from specified URL """
out = self._read_url_as_StringIO(self.url, params=params)
if self._format == 'string':
out = self._read_url_as_StringIO(url, params=params)
elif self._format == 'json':
out = self._get_response(url, params=params).json()
else:
raise NotImplementedError(self._format)
return self._read_lines(out)

def _read_url_as_StringIO(self, url, params=None):
Expand Down Expand Up @@ -128,8 +134,15 @@ def _sanitize_dates(self, start, end):
if start is None - default is 2010/01/01
if end is None - default is today
"""
if is_number(start):
# regard int as year
start = dt.datetime(start, 1, 1)
start = to_datetime(start)

if is_number(end):
end = dt.datetime(end, 1, 1)
end = to_datetime(end)

if start is None:
start = dt.datetime(2010, 1, 1)
if end is None:
Expand Down
7 changes: 4 additions & 3 deletions pandas_datareader/oecd.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ class OECDReader(_BaseReader):

"""Get data for the given name from OECD."""

_format = 'json'

@property
def url(self):
url = 'http://stats.oecd.org/SDMX-JSON/data'
Expand All @@ -19,10 +21,9 @@ def url(self):
# API: https://data.oecd.org/api/sdmx-json-documentation/
return '{0}/{1}/all/all?'.format(url, self.symbols)

def _read_one_data(self, url, params):
def _read_lines(self, out):
""" convert the already-fetched JSON payload into a DataFrame """
resp = self._get_response(url)
df = read_jsdmx(resp.json())
df = read_jsdmx(out)
try:
idx_name = df.index.name # hack for pandas 0.16.2
df.index = pd.to_datetime(df.index)
Expand Down
183 changes: 156 additions & 27 deletions pandas_datareader/tests/test_wb.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,14 @@
import nose
import time

import pandas
from pandas.util.testing import assert_frame_equal
import numpy as np
import pandas as pd
import pandas.util.testing as tm
import requests

from pandas_datareader.wb import search, download, get_countries

try:
from pandas.compat import u
except ImportError: # pragma: no cover
try:
unicode # python 2
def u(s):
return unicode(s, "unicode_escape")
except NameError:
def u(s):
return s
from pandas_datareader.wb import (search, download, get_countries,
get_indicators, WorldBankReader)
from pandas_datareader._utils import PANDAS_0170, PANDAS_0160, PANDAS_0140


class TestWB(tm.TestCase):
Expand All @@ -29,6 +22,19 @@ def test_wdi_search(self):
result = search('gdp.*capita.*constant')
self.assertTrue(result.name.str.contains('GDP').any())

# check cache returns the results within 0.5 sec
current_time = time.time()
result = search('gdp.*capita.*constant')
self.assertTrue(result.name.str.contains('GDP').any())
self.assertTrue(time.time() - current_time < 0.5)

result2 = WorldBankReader().search('gdp.*capita.*constant')
session = requests.Session()
result3 = search('gdp.*capita.*constant', session=session)
result4 = WorldBankReader(session=session).search('gdp.*capita.*constant')
for result in [result2, result3, result4]:
self.assertTrue(result.name.str.contains('GDP').any())

def test_wdi_download(self):

# Test a bad indicator with double (US), triple (USA),
Expand All @@ -43,18 +49,120 @@ def test_wdi_download(self):
cntry_codes = ['CA', 'MX', 'USA', 'US', 'US', 'KSV', 'BLA']
inds = ['NY.GDP.PCAP.CD','BAD.INDICATOR']

expected = {'NY.GDP.PCAP.CD': {('Canada', '2003'): 28026.006013044702, ('Mexico', '2003'): 6601.0420648056606, ('Canada', '2004'): 31829.522562759001, ('Kosovo', '2003'): 1969.56271307405, ('Mexico', '2004'): 7042.0247834044303, ('United States', '2004'): 41928.886136479705, ('United States', '2003'): 39682.472247320402, ('Kosovo', '2004'): 2135.3328465238301}}
expected = pandas.DataFrame(expected)
expected = {'NY.GDP.PCAP.CD': {('Canada', '2004'): 31829.522562759001, ('Canada', '2003'): 28026.006013044702,
('Kosovo', '2004'): 2135.3328465238301, ('Kosovo', '2003'): 1969.56271307405,
('Mexico', '2004'): 7042.0247834044303, ('Mexico', '2003'): 6601.0420648056606,
('United States', '2004'): 41928.886136479705, ('United States', '2003'): 39682.472247320402}}
expected = pd.DataFrame(expected)
# Round, to ignore revisions to data.
expected = pandas.np.round(expected,decimals=-3)
expected.sort(inplace=True)
expected = np.round(expected,decimals=-3)
if PANDAS_0170:
expected = expected.sort_index()
else:
expected = expected.sort()

result = download(country=cntry_codes, indicator=inds,
start=2003, end=2004, errors='ignore')
result.sort(inplace=True)
if PANDAS_0170:
result = result.sort_index()
else:
result = result.sort()
# Round, to ignore revisions to data.
result = np.round(result, decimals=-3)


if PANDAS_0140:
expected.index.names=['country', 'year']
else:
# earlier pandas versions do not allow setting multiple names on a
# MultiIndex, so overwrite the expected index with the result's index
expected.index = result.index
tm.assert_frame_equal(result, expected)

# pass start and end as string
result = download(country=cntry_codes, indicator=inds,
start='2003', end='2004', errors='ignore')
if PANDAS_0170:
result = result.sort_index()
else:
result = result.sort()
# Round, to ignore revisions to data.
result = pandas.np.round(result,decimals=-3)
expected.index = result.index
assert_frame_equal(result, pandas.DataFrame(expected))
result = np.round(result, decimals=-3)
tm.assert_frame_equal(result, expected)

def test_wdi_download_str(self):
    """Download a single indicator for a single country passed as strings.

    The same request is issued through ``download`` and through
    ``WorldBankReader(...).read()``; both results are compared against
    hard-coded values rounded to the nearest thousand.
    """

    def by_index(frame):
        # pandas >= 0.17 sorts with sort_index(); older versions use sort()
        if PANDAS_0170:
            return frame.sort_index()
        return frame.sort()

    country = 'JP'
    indicator = 'NY.GDP.PCAP.CD'

    gdp_per_capita = {('Japan', '2004'): 36441.50449394,
                      ('Japan', '2003'): 33690.93772972,
                      ('Japan', '2002'): 31235.58818439,
                      ('Japan', '2001'): 32716.41867489,
                      ('Japan', '2000'): 37299.64412913}
    expected = pd.DataFrame({'NY.GDP.PCAP.CD': gdp_per_capita})
    # Round, to ignore revisions to data.
    expected = by_index(np.round(expected, decimals=-3))

    downloaded = download(country=country, indicator=indicator,
                          start=2000, end=2004, errors='ignore')
    result = np.round(by_index(downloaded), decimals=-3)

    if not PANDAS_0140:
        # earlier pandas versions cannot set multiple names on a
        # MultiIndex, so take the index from the result instead
        expected.index = result.index
    else:
        expected.index.names = ['country', 'year']

    tm.assert_frame_equal(result, expected)

    # The reader class must return the same frame as the function API.
    reader = WorldBankReader(indicator, countries=country,
                             start=2000, end=2004, errors='ignore')
    result = np.round(by_index(reader.read()), decimals=-3)
    tm.assert_frame_equal(result, expected)

def test_wdi_download_error_handling(self):
# Checks how download() reports an invalid country code and an invalid
# indicator under errors='raise' and errors='warn'.
# NOTE(review): indentation is lost in this view; the two assertions after
# each errors='warn' call presumably sit inside the `with` block -- confirm.
cntry_codes = ['USA', 'XX']
inds = 'NY.GDP.PCAP.CD'

# errors='raise': the unknown country code 'XX' is expected to raise ValueError
with tm.assertRaisesRegexp(ValueError, "Invalid Country Code\\(s\\): XX"):
result = download(country=cntry_codes, indicator=inds,
start=2003, end=2004, errors='raise')

if PANDAS_0160:
# assert_produces_warning does not exist in earlier pandas versions
# errors='warn': a warning is expected and a 2-row DataFrame is still returned
with self.assert_produces_warning():
result = download(country=cntry_codes, indicator=inds,
start=2003, end=2004, errors='warn')
self.assertTrue(isinstance(result, pd.DataFrame))
self.assertEqual(len(result), 2)

# Same two scenarios, now with a valid country and one invalid indicator.
cntry_codes = ['USA']
inds = ['NY.GDP.PCAP.CD', 'BAD_INDICATOR']

with tm.assertRaisesRegexp(ValueError, "The provided parameter value is not valid\\. Indicator: BAD_INDICATOR"):
result = download(country=cntry_codes, indicator=inds,
start=2003, end=2004, errors='raise')

if PANDAS_0160:
with self.assert_produces_warning():
result = download(country=cntry_codes, indicator=inds,
start=2003, end=2004, errors='warn')
self.assertTrue(isinstance(result, pd.DataFrame))
self.assertEqual(len(result), 2)


def test_wdi_download_w_retired_indicator(self):

Expand Down Expand Up @@ -101,11 +209,32 @@ def test_wdi_download_w_crash_inducing_countrycode(self):
raise nose.SkipTest("Invalid results")

def test_wdi_get_countries(self):
result = get_countries()
self.assertTrue('Zimbabwe' in list(result['name']))
self.assertTrue(len(result) > 100)
self.assertTrue(pandas.notnull(result.latitude.mean()))
self.assertTrue(pandas.notnull(result.longitude.mean()))
result1 = get_countries()
result2 = WorldBankReader().get_countries()

session = requests.Session()
result3 = get_countries(session=session)
result4 = WorldBankReader(session=session).get_countries()

for result in [result1, result2, result3, result4]:
self.assertTrue('Zimbabwe' in list(result['name']))
self.assertTrue(len(result) > 100)
self.assertTrue(pd.notnull(result.latitude.mean()))
self.assertTrue(pd.notnull(result.longitude.mean()))

def test_wdi_get_indicators(self):
# Fetches the indicator table four ways -- module-level function and
# WorldBankReader method, each with and without an explicit requests
# session -- and applies the same checks to the results.
# NOTE(review): indentation is lost in this view; the statements after the
# `for` line are presumably all part of the loop body -- confirm.
result1 = get_indicators()
result2 = WorldBankReader().get_indicators()

session = requests.Session()
result3 = get_indicators(session=session)
result4 = WorldBankReader(session=session).get_indicators()

for result in [result1, result2, result3, result4]:
exp_col = pd.Index(['id', 'name', 'source', 'sourceNote', 'sourceOrganization', 'topics'])
# assert_index_equal is not available (presumably in the older pandas
# versions still supported), so compare with Index.equals instead
self.assertTrue(result.columns.equals(exp_col))
# the indicator table is expected to hold more than 10000 rows
self.assertTrue(len(result) > 10000)


if __name__ == '__main__':
Expand Down
Loading