1
1
import nose
2
+ import time
2
3
3
- import pandas
4
- from pandas . util . testing import assert_frame_equal
4
+ import numpy as np
5
+ import pandas as pd
5
6
import pandas .util .testing as tm
7
+ import requests
6
8
7
- from pandas_datareader .wb import search , download , get_countries
8
-
9
- try :
10
- from pandas .compat import u
11
- except ImportError : # pragma: no cover
12
- try :
13
- unicode # python 2
14
- def u (s ):
15
- return unicode (s , "unicode_escape" )
16
- except NameError :
17
- def u (s ):
18
- return s
9
+ from pandas_datareader .wb import (search , download , get_countries ,
10
+ get_indicators , WorldBankReader )
11
+ from pandas_datareader ._utils import PANDAS_0170 , PANDAS_0160 , PANDAS_0140
19
12
20
13
21
14
class TestWB (tm .TestCase ):
@@ -29,6 +22,19 @@ def test_wdi_search(self):
29
22
result = search ('gdp.*capita.*constant' )
30
23
self .assertTrue (result .name .str .contains ('GDP' ).any ())
31
24
25
+ # check cache returns the results within 0.5 sec
26
+ current_time = time .time ()
27
+ result = search ('gdp.*capita.*constant' )
28
+ self .assertTrue (result .name .str .contains ('GDP' ).any ())
29
+ self .assertTrue (time .time () - current_time < 0.5 )
30
+
31
+ result2 = WorldBankReader ().search ('gdp.*capita.*constant' )
32
+ session = requests .Session ()
33
+ result3 = search ('gdp.*capita.*constant' , session = session )
34
+ result4 = WorldBankReader (session = session ).search ('gdp.*capita.*constant' )
35
+ for result in [result2 , result3 , result4 ]:
36
+ self .assertTrue (result .name .str .contains ('GDP' ).any ())
37
+
32
38
def test_wdi_download (self ):
33
39
34
40
# Test a bad indicator with double (US), triple (USA),
@@ -43,18 +49,120 @@ def test_wdi_download(self):
43
49
cntry_codes = ['CA' , 'MX' , 'USA' , 'US' , 'US' , 'KSV' , 'BLA' ]
44
50
inds = ['NY.GDP.PCAP.CD' ,'BAD.INDICATOR' ]
45
51
46
- expected = {'NY.GDP.PCAP.CD' : {('Canada' , '2003' ): 28026.006013044702 , ('Mexico' , '2003' ): 6601.0420648056606 , ('Canada' , '2004' ): 31829.522562759001 , ('Kosovo' , '2003' ): 1969.56271307405 , ('Mexico' , '2004' ): 7042.0247834044303 , ('United States' , '2004' ): 41928.886136479705 , ('United States' , '2003' ): 39682.472247320402 , ('Kosovo' , '2004' ): 2135.3328465238301 }}
47
- expected = pandas .DataFrame (expected )
52
+ expected = {'NY.GDP.PCAP.CD' : {('Canada' , '2004' ): 31829.522562759001 , ('Canada' , '2003' ): 28026.006013044702 ,
53
+ ('Kosovo' , '2004' ): 2135.3328465238301 , ('Kosovo' , '2003' ): 1969.56271307405 ,
54
+ ('Mexico' , '2004' ): 7042.0247834044303 , ('Mexico' , '2003' ): 6601.0420648056606 ,
55
+ ('United States' , '2004' ): 41928.886136479705 , ('United States' , '2003' ): 39682.472247320402 }}
56
+ expected = pd .DataFrame (expected )
48
57
# Round, to ignore revisions to data.
49
- expected = pandas .np .round (expected ,decimals = - 3 )
50
- expected .sort (inplace = True )
58
+ expected = np .round (expected ,decimals = - 3 )
59
+ if PANDAS_0170 :
60
+ expected = expected .sort_index ()
61
+ else :
62
+ expected = expected .sort ()
63
+
51
64
result = download (country = cntry_codes , indicator = inds ,
52
65
start = 2003 , end = 2004 , errors = 'ignore' )
53
- result .sort (inplace = True )
66
+ if PANDAS_0170 :
67
+ result = result .sort_index ()
68
+ else :
69
+ result = result .sort ()
70
+ # Round, to ignore revisions to data.
71
+ result = np .round (result , decimals = - 3 )
72
+
73
+
74
+ if PANDAS_0140 :
75
+ expected .index .names = ['country' , 'year' ]
76
+ else :
77
+ # prior versions doesn't allow to set multiple names to MultiIndex
78
+ # Thus overwrite it with the result
79
+ expected .index = result .index
80
+ tm .assert_frame_equal (result , expected )
81
+
82
+ # pass start and end as string
83
+ result = download (country = cntry_codes , indicator = inds ,
84
+ start = '2003' , end = '2004' , errors = 'ignore' )
85
+ if PANDAS_0170 :
86
+ result = result .sort_index ()
87
+ else :
88
+ result = result .sort ()
54
89
# Round, to ignore revisions to data.
55
- result = pandas .np .round (result ,decimals = - 3 )
56
- expected .index = result .index
57
- assert_frame_equal (result , pandas .DataFrame (expected ))
90
+ result = np .round (result , decimals = - 3 )
91
+ tm .assert_frame_equal (result , expected )
92
+
93
+ def test_wdi_download_str (self ):
94
+
95
+ expected = {'NY.GDP.PCAP.CD' : {('Japan' , '2004' ): 36441.50449394 ,
96
+ ('Japan' , '2003' ): 33690.93772972 ,
97
+ ('Japan' , '2002' ): 31235.58818439 ,
98
+ ('Japan' , '2001' ): 32716.41867489 ,
99
+ ('Japan' , '2000' ): 37299.64412913 }}
100
+ expected = pd .DataFrame (expected )
101
+ # Round, to ignore revisions to data.
102
+ expected = np .round (expected , decimals = - 3 )
103
+ if PANDAS_0170 :
104
+ expected = expected .sort_index ()
105
+ else :
106
+ expected = expected .sort ()
107
+
108
+ cntry_codes = 'JP'
109
+ inds = 'NY.GDP.PCAP.CD'
110
+ result = download (country = cntry_codes , indicator = inds ,
111
+ start = 2000 , end = 2004 , errors = 'ignore' )
112
+ if PANDAS_0170 :
113
+ result = result .sort_index ()
114
+ else :
115
+ result = result .sort ()
116
+ result = np .round (result , decimals = - 3 )
117
+
118
+ if PANDAS_0140 :
119
+ expected .index .names = ['country' , 'year' ]
120
+ else :
121
+ # prior versions doesn't allow to set multiple names to MultiIndex
122
+ # Thus overwrite it with the result
123
+ expected .index = result .index
124
+
125
+ tm .assert_frame_equal (result , expected )
126
+
127
+ result = WorldBankReader (inds , countries = cntry_codes ,
128
+ start = 2000 , end = 2004 , errors = 'ignore' ).read ()
129
+ if PANDAS_0170 :
130
+ result = result .sort_index ()
131
+ else :
132
+ result = result .sort ()
133
+ result = np .round (result , decimals = - 3 )
134
+ tm .assert_frame_equal (result , expected )
135
+
136
+ def test_wdi_download_error_handling (self ):
137
+ cntry_codes = ['USA' , 'XX' ]
138
+ inds = 'NY.GDP.PCAP.CD'
139
+
140
+ with tm .assertRaisesRegexp (ValueError , "Invalid Country Code\\ (s\\ ): XX" ):
141
+ result = download (country = cntry_codes , indicator = inds ,
142
+ start = 2003 , end = 2004 , errors = 'raise' )
143
+
144
+ if PANDAS_0160 :
145
+ # assert_produces_warning doesn't exists in prior versions
146
+ with self .assert_produces_warning ():
147
+ result = download (country = cntry_codes , indicator = inds ,
148
+ start = 2003 , end = 2004 , errors = 'warn' )
149
+ self .assertTrue (isinstance (result , pd .DataFrame ))
150
+ self .assertEqual (len (result ), 2 )
151
+
152
+ cntry_codes = ['USA' ]
153
+ inds = ['NY.GDP.PCAP.CD' , 'BAD_INDICATOR' ]
154
+
155
+ with tm .assertRaisesRegexp (ValueError , "The provided parameter value is not valid\\ . Indicator: BAD_INDICATOR" ):
156
+ result = download (country = cntry_codes , indicator = inds ,
157
+ start = 2003 , end = 2004 , errors = 'raise' )
158
+
159
+ if PANDAS_0160 :
160
+ with self .assert_produces_warning ():
161
+ result = download (country = cntry_codes , indicator = inds ,
162
+ start = 2003 , end = 2004 , errors = 'warn' )
163
+ self .assertTrue (isinstance (result , pd .DataFrame ))
164
+ self .assertEqual (len (result ), 2 )
165
+
58
166
59
167
def test_wdi_download_w_retired_indicator (self ):
60
168
@@ -101,11 +209,32 @@ def test_wdi_download_w_crash_inducing_countrycode(self):
101
209
raise nose .SkipTest ("Invalid results" )
102
210
103
211
def test_wdi_get_countries (self ):
104
- result = get_countries ()
105
- self .assertTrue ('Zimbabwe' in list (result ['name' ]))
106
- self .assertTrue (len (result ) > 100 )
107
- self .assertTrue (pandas .notnull (result .latitude .mean ()))
108
- self .assertTrue (pandas .notnull (result .longitude .mean ()))
212
+ result1 = get_countries ()
213
+ result2 = WorldBankReader ().get_countries ()
214
+
215
+ session = requests .Session ()
216
+ result3 = get_countries (session = session )
217
+ result4 = WorldBankReader (session = session ).get_countries ()
218
+
219
+ for result in [result1 , result2 ]:
220
+ self .assertTrue ('Zimbabwe' in list (result ['name' ]))
221
+ self .assertTrue (len (result ) > 100 )
222
+ self .assertTrue (pd .notnull (result .latitude .mean ()))
223
+ self .assertTrue (pd .notnull (result .longitude .mean ()))
224
+
225
+ def test_wdi_get_indicators (self ):
226
+ result1 = get_indicators ()
227
+ result2 = WorldBankReader ().get_indicators ()
228
+
229
+ session = requests .Session ()
230
+ result3 = get_indicators (session = session )
231
+ result4 = WorldBankReader (session = session ).get_indicators ()
232
+
233
+ for result in [result1 , result2 , result3 , result4 ]:
234
+ exp_col = pd .Index (['id' , 'name' , 'source' , 'sourceNote' , 'sourceOrganization' , 'topics' ])
235
+ # assert_index_equal doesn't exists
236
+ self .assertTrue (result .columns .equals (exp_col ))
237
+ self .assertTrue (len (result ) > 10000 )
109
238
110
239
111
240
if __name__ == '__main__' :
0 commit comments