Skip to content

Commit d533b60

Browse files
committed
CLN: Cleanup WorldBank reader
1 parent ccc5aa9 commit d533b60

File tree

5 files changed

+479
-186
lines changed

5 files changed

+479
-186
lines changed

pandas_datareader/_utils.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
import warnings
2-
1+
import pandas as pd
32
from pandas.core.common import PandasError
43

54

@@ -10,3 +9,21 @@ class RemoteDataError(PandasError, IOError):
109
pass
1110

1211

12+
from distutils.version import LooseVersion
13+
14+
PANDAS_VERSION = LooseVersion(pd.__version__)
15+
16+
if PANDAS_VERSION >= LooseVersion('0.17.0'):
17+
PANDAS_0170 = True
18+
else:
19+
PANDAS_0170 = False
20+
21+
if PANDAS_VERSION >= LooseVersion('0.16.0'):
22+
PANDAS_0160 = True
23+
else:
24+
PANDAS_0160 = False
25+
26+
if PANDAS_VERSION >= LooseVersion('0.14.0'):
27+
PANDAS_0140 = True
28+
else:
29+
PANDAS_0140 = False

pandas_datareader/base.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from pandas import to_datetime
99
import pandas.compat as compat
10-
from pandas.core.common import PandasError
10+
from pandas.core.common import PandasError, is_number
1111
from pandas import Panel, DataFrame
1212
from pandas import read_csv
1313
from pandas.compat import StringIO, bytes_to_str
@@ -37,6 +37,7 @@ class _BaseReader(object):
3737
"""
3838

3939
_chunk_size = 1024 * 1024
40+
_format = 'string'
4041

4142
def __init__(self, symbols, start=None, end=None,
4243
retry_count=3, pause=0.001, session=None):
@@ -73,7 +74,12 @@ def read(self):
7374

7475
def _read_one_data(self, url, params):
7576
""" read one data from specified URL """
76-
out = self._read_url_as_StringIO(self.url, params=params)
77+
if self._format == 'string':
78+
out = self._read_url_as_StringIO(url, params=params)
79+
elif self._format == 'json':
80+
out = self._get_response(url, params=params).json()
81+
else:
82+
raise NotImplementedError(self._format)
7783
return self._read_lines(out)
7884

7985
def _read_url_as_StringIO(self, url, params=None):
@@ -128,8 +134,15 @@ def _sanitize_dates(self, start, end):
128134
if start is None - default is 2010/01/01
129135
if end is None - default is today
130136
"""
137+
if is_number(start):
138+
# regard int as year
139+
start = dt.datetime(start, 1, 1)
131140
start = to_datetime(start)
141+
142+
if is_number(end):
143+
end = dt.datetime(end, 1, 1)
132144
end = to_datetime(end)
145+
133146
if start is None:
134147
start = dt.datetime(2010, 1, 1)
135148
if end is None:

pandas_datareader/oecd.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ class OECDReader(_BaseReader):
99

1010
"""Get data for the given name from OECD."""
1111

12+
_format = 'json'
13+
1214
@property
1315
def url(self):
1416
url = 'http://stats.oecd.org/SDMX-JSON/data'
@@ -19,10 +21,9 @@ def url(self):
1921
# API: https://data.oecd.org/api/sdmx-json-documentation/
2022
return '{0}/{1}/all/all?'.format(url, self.symbols)
2123

22-
def _read_one_data(self, url, params):
24+
def _read_lines(self, out):
2325
""" read one data from specified URL """
24-
resp = self._get_response(url)
25-
df = read_jsdmx(resp.json())
26+
df = read_jsdmx(out)
2627
try:
2728
idx_name = df.index.name # hack for pandas 0.16.2
2829
df.index = pd.to_datetime(df.index)

pandas_datareader/tests/test_wb.py

+156-27
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,14 @@
11
import nose
2+
import time
23

3-
import pandas
4-
from pandas.util.testing import assert_frame_equal
4+
import numpy as np
5+
import pandas as pd
56
import pandas.util.testing as tm
7+
import requests
68

7-
from pandas_datareader.wb import search, download, get_countries
8-
9-
try:
10-
from pandas.compat import u
11-
except ImportError: # pragma: no cover
12-
try:
13-
unicode # python 2
14-
def u(s):
15-
return unicode(s, "unicode_escape")
16-
except NameError:
17-
def u(s):
18-
return s
9+
from pandas_datareader.wb import (search, download, get_countries,
10+
get_indicators, WorldBankReader)
11+
from pandas_datareader._utils import PANDAS_0170, PANDAS_0160, PANDAS_0140
1912

2013

2114
class TestWB(tm.TestCase):
@@ -29,6 +22,19 @@ def test_wdi_search(self):
2922
result = search('gdp.*capita.*constant')
3023
self.assertTrue(result.name.str.contains('GDP').any())
3124

25+
# check that the cache returns the results within 0.5 sec
26+
current_time = time.time()
27+
result = search('gdp.*capita.*constant')
28+
self.assertTrue(result.name.str.contains('GDP').any())
29+
self.assertTrue(time.time() - current_time < 0.5)
30+
31+
result2 = WorldBankReader().search('gdp.*capita.*constant')
32+
session = requests.Session()
33+
result3 = search('gdp.*capita.*constant', session=session)
34+
result4 = WorldBankReader(session=session).search('gdp.*capita.*constant')
35+
for result in [result2, result3, result4]:
36+
self.assertTrue(result.name.str.contains('GDP').any())
37+
3238
def test_wdi_download(self):
3339

3440
# Test a bad indicator with double (US), triple (USA),
@@ -43,18 +49,120 @@ def test_wdi_download(self):
4349
cntry_codes = ['CA', 'MX', 'USA', 'US', 'US', 'KSV', 'BLA']
4450
inds = ['NY.GDP.PCAP.CD','BAD.INDICATOR']
4551

46-
expected = {'NY.GDP.PCAP.CD': {('Canada', '2003'): 28026.006013044702, ('Mexico', '2003'): 6601.0420648056606, ('Canada', '2004'): 31829.522562759001, ('Kosovo', '2003'): 1969.56271307405, ('Mexico', '2004'): 7042.0247834044303, ('United States', '2004'): 41928.886136479705, ('United States', '2003'): 39682.472247320402, ('Kosovo', '2004'): 2135.3328465238301}}
47-
expected = pandas.DataFrame(expected)
52+
expected = {'NY.GDP.PCAP.CD': {('Canada', '2004'): 31829.522562759001, ('Canada', '2003'): 28026.006013044702,
53+
('Kosovo', '2004'): 2135.3328465238301, ('Kosovo', '2003'): 1969.56271307405,
54+
('Mexico', '2004'): 7042.0247834044303, ('Mexico', '2003'): 6601.0420648056606,
55+
('United States', '2004'): 41928.886136479705, ('United States', '2003'): 39682.472247320402}}
56+
expected = pd.DataFrame(expected)
4857
# Round, to ignore revisions to data.
49-
expected = pandas.np.round(expected,decimals=-3)
50-
expected.sort(inplace=True)
58+
expected = np.round(expected,decimals=-3)
59+
if PANDAS_0170:
60+
expected = expected.sort_index()
61+
else:
62+
expected = expected.sort()
63+
5164
result = download(country=cntry_codes, indicator=inds,
5265
start=2003, end=2004, errors='ignore')
53-
result.sort(inplace=True)
66+
if PANDAS_0170:
67+
result = result.sort_index()
68+
else:
69+
result = result.sort()
70+
# Round, to ignore revisions to data.
71+
result = np.round(result, decimals=-3)
72+
73+
74+
if PANDAS_0140:
75+
expected.index.names=['country', 'year']
76+
else:
77+
# prior versions don't allow setting multiple names on a MultiIndex
78+
# Thus overwrite it with the result
79+
expected.index = result.index
80+
tm.assert_frame_equal(result, expected)
81+
82+
# pass start and end as string
83+
result = download(country=cntry_codes, indicator=inds,
84+
start='2003', end='2004', errors='ignore')
85+
if PANDAS_0170:
86+
result = result.sort_index()
87+
else:
88+
result = result.sort()
5489
# Round, to ignore revisions to data.
55-
result = pandas.np.round(result,decimals=-3)
56-
expected.index = result.index
57-
assert_frame_equal(result, pandas.DataFrame(expected))
90+
result = np.round(result, decimals=-3)
91+
tm.assert_frame_equal(result, expected)
92+
93+
def test_wdi_download_str(self):
94+
95+
expected = {'NY.GDP.PCAP.CD': {('Japan', '2004'): 36441.50449394,
96+
('Japan', '2003'): 33690.93772972,
97+
('Japan', '2002'): 31235.58818439,
98+
('Japan', '2001'): 32716.41867489,
99+
('Japan', '2000'): 37299.64412913}}
100+
expected = pd.DataFrame(expected)
101+
# Round, to ignore revisions to data.
102+
expected = np.round(expected, decimals=-3)
103+
if PANDAS_0170:
104+
expected = expected.sort_index()
105+
else:
106+
expected = expected.sort()
107+
108+
cntry_codes = 'JP'
109+
inds = 'NY.GDP.PCAP.CD'
110+
result = download(country=cntry_codes, indicator=inds,
111+
start=2000, end=2004, errors='ignore')
112+
if PANDAS_0170:
113+
result = result.sort_index()
114+
else:
115+
result = result.sort()
116+
result = np.round(result, decimals=-3)
117+
118+
if PANDAS_0140:
119+
expected.index.names=['country', 'year']
120+
else:
121+
# prior versions don't allow setting multiple names on a MultiIndex
122+
# Thus overwrite it with the result
123+
expected.index = result.index
124+
125+
tm.assert_frame_equal(result, expected)
126+
127+
result = WorldBankReader(inds, countries=cntry_codes,
128+
start=2000, end=2004, errors='ignore').read()
129+
if PANDAS_0170:
130+
result = result.sort_index()
131+
else:
132+
result = result.sort()
133+
result = np.round(result, decimals=-3)
134+
tm.assert_frame_equal(result, expected)
135+
136+
def test_wdi_download_error_handling(self):
137+
cntry_codes = ['USA', 'XX']
138+
inds = 'NY.GDP.PCAP.CD'
139+
140+
with tm.assertRaisesRegexp(ValueError, "Invalid Country Code\\(s\\): XX"):
141+
result = download(country=cntry_codes, indicator=inds,
142+
start=2003, end=2004, errors='raise')
143+
144+
if PANDAS_0160:
145+
# assert_produces_warning doesn't exist in prior versions
146+
with self.assert_produces_warning():
147+
result = download(country=cntry_codes, indicator=inds,
148+
start=2003, end=2004, errors='warn')
149+
self.assertTrue(isinstance(result, pd.DataFrame))
150+
self.assertEqual(len(result), 2)
151+
152+
cntry_codes = ['USA']
153+
inds = ['NY.GDP.PCAP.CD', 'BAD_INDICATOR']
154+
155+
with tm.assertRaisesRegexp(ValueError, "The provided parameter value is not valid\\. Indicator: BAD_INDICATOR"):
156+
result = download(country=cntry_codes, indicator=inds,
157+
start=2003, end=2004, errors='raise')
158+
159+
if PANDAS_0160:
160+
with self.assert_produces_warning():
161+
result = download(country=cntry_codes, indicator=inds,
162+
start=2003, end=2004, errors='warn')
163+
self.assertTrue(isinstance(result, pd.DataFrame))
164+
self.assertEqual(len(result), 2)
165+
58166

59167
def test_wdi_download_w_retired_indicator(self):
60168

@@ -101,11 +209,32 @@ def test_wdi_download_w_crash_inducing_countrycode(self):
101209
raise nose.SkipTest("Invalid results")
102210

103211
def test_wdi_get_countries(self):
104-
result = get_countries()
105-
self.assertTrue('Zimbabwe' in list(result['name']))
106-
self.assertTrue(len(result) > 100)
107-
self.assertTrue(pandas.notnull(result.latitude.mean()))
108-
self.assertTrue(pandas.notnull(result.longitude.mean()))
212+
result1 = get_countries()
213+
result2 = WorldBankReader().get_countries()
214+
215+
session = requests.Session()
216+
result3 = get_countries(session=session)
217+
result4 = WorldBankReader(session=session).get_countries()
218+
219+
for result in [result1, result2]:
220+
self.assertTrue('Zimbabwe' in list(result['name']))
221+
self.assertTrue(len(result) > 100)
222+
self.assertTrue(pd.notnull(result.latitude.mean()))
223+
self.assertTrue(pd.notnull(result.longitude.mean()))
224+
225+
def test_wdi_get_indicators(self):
226+
result1 = get_indicators()
227+
result2 = WorldBankReader().get_indicators()
228+
229+
session = requests.Session()
230+
result3 = get_indicators(session=session)
231+
result4 = WorldBankReader(session=session).get_indicators()
232+
233+
for result in [result1, result2, result3, result4]:
234+
exp_col = pd.Index(['id', 'name', 'source', 'sourceNote', 'sourceOrganization', 'topics'])
235+
# assert_index_equal doesn't exist
236+
self.assertTrue(result.columns.equals(exp_col))
237+
self.assertTrue(len(result) > 10000)
109238

110239

111240
if __name__ == '__main__':

0 commit comments

Comments
 (0)