Skip to content

Commit 2abb2f2

Browse files
authored
Immutable QueryParams (#1600)
* Tweak QueryParams implementation * Immutable QueryParams
1 parent 8fe32c5 commit 2abb2f2

File tree

4 files changed

+188
-67
lines changed

4 files changed

+188
-67
lines changed

httpx/_client.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,7 @@ def _merge_queryparams(
385385
"""
386386
if params or self.params:
387387
merged_queryparams = QueryParams(self.params)
388-
merged_queryparams.update(params)
388+
merged_queryparams = merged_queryparams.merge(params)
389389
return merged_queryparams
390390
return params
391391

httpx/_models.py

+144-28
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import urllib.request
77
from collections.abc import MutableMapping
88
from http.cookiejar import Cookie, CookieJar
9-
from urllib.parse import parse_qsl, quote, unquote, urlencode
9+
from urllib.parse import parse_qs, quote, unquote, urlencode
1010

1111
import idna
1212
import rfc3986
@@ -48,7 +48,6 @@
4848
URLTypes,
4949
)
5050
from ._utils import (
51-
flatten_queryparams,
5251
guess_json_utf,
5352
is_known_encoding,
5453
normalize_header_key,
@@ -148,8 +147,7 @@ def __init__(
148147
# Add any query parameters, merging with any in the URL if needed.
149148
if params:
150149
if self._uri_reference.query:
151-
url_params = QueryParams(self._uri_reference.query)
152-
url_params.update(params)
150+
url_params = QueryParams(self._uri_reference.query).merge(params)
153151
query_string = str(url_params)
154152
else:
155153
query_string = str(QueryParams(params))
@@ -450,7 +448,7 @@ def join(self, url: URLTypes) -> "URL":
450448
451449
url = httpx.URL("https://www.example.com/test")
452450
url = url.join("/new/path")
453-
assert url == "https://www.example.com/test/new/path"
451+
assert url == "https://www.example.com/new/path"
454452
"""
455453
if self.is_relative_url:
456454
# Workaround to handle relative URLs, which otherwise raise
@@ -504,38 +502,79 @@ def __init__(self, *args: QueryParamTypes, **kwargs: typing.Any) -> None:
504502
items: typing.Sequence[typing.Tuple[str, PrimitiveData]]
505503
if value is None or isinstance(value, (str, bytes)):
506504
value = value.decode("ascii") if isinstance(value, bytes) else value
507-
items = parse_qsl(value)
505+
self._dict = parse_qs(value)
508506
elif isinstance(value, QueryParams):
509-
items = value.multi_items()
510-
elif isinstance(value, (list, tuple)):
511-
items = value
507+
self._dict = {k: list(v) for k, v in value._dict.items()}
512508
else:
513-
items = flatten_queryparams(value)
514-
515-
self._dict: typing.Dict[str, typing.List[str]] = {}
516-
for item in items:
517-
k, v = item
518-
if str(k) not in self._dict:
519-
self._dict[str(k)] = [primitive_value_to_str(v)]
509+
dict_value: typing.Dict[typing.Any, typing.List[typing.Any]] = {}
510+
if isinstance(value, (list, tuple)):
511+
# Convert list inputs like:
512+
# [("a", "123"), ("a", "456"), ("b", "789")]
513+
# To a dict representation, like:
514+
# {"a": ["123", "456"], "b": ["789"]}
515+
for item in value:
516+
dict_value.setdefault(item[0], []).append(item[1])
520517
else:
521-
self._dict[str(k)].append(primitive_value_to_str(v))
518+
# Convert dict inputs like:
519+
# {"a": "123", "b": ["456", "789"]}
520+
# To dict inputs where values are always lists, like:
521+
# {"a": ["123"], "b": ["456", "789"]}
522+
dict_value = {
523+
k: list(v) if isinstance(v, (list, tuple)) else [v]
524+
for k, v in value.items()
525+
}
526+
527+
# Ensure that keys and values are neatly coerced to strings.
528+
# We coerce values `True` and `False` to JSON-like "true" and "false"
529+
# representations, and coerce `None` values to the empty string.
530+
self._dict = {
531+
str(k): [primitive_value_to_str(item) for item in v]
532+
for k, v in dict_value.items()
533+
}
522534

523535
def keys(self) -> typing.KeysView:
536+
"""
537+
Return all the keys in the query params.
538+
539+
Usage:
540+
541+
q = httpx.QueryParams("a=123&a=456&b=789")
542+
assert list(q.keys()) == ["a", "b"]
543+
"""
524544
return self._dict.keys()
525545

526546
def values(self) -> typing.ValuesView:
547+
"""
548+
Return all the values in the query params. If a key occurs more than once
549+
only the first item for that key is returned.
550+
551+
Usage:
552+
553+
q = httpx.QueryParams("a=123&a=456&b=789")
554+
assert list(q.values()) == ["123", "789"]
555+
"""
527556
return {k: v[0] for k, v in self._dict.items()}.values()
528557

529558
def items(self) -> typing.ItemsView:
530559
"""
531560
Return all items in the query params. If a key occurs more than once
532561
only the first item for that key is returned.
562+
563+
Usage:
564+
565+
q = httpx.QueryParams("a=123&a=456&b=789")
566+
assert list(q.items()) == [("a", "123"), ("b", "789")]
533567
"""
534568
return {k: v[0] for k, v in self._dict.items()}.items()
535569

536570
def multi_items(self) -> typing.List[typing.Tuple[str, str]]:
537571
"""
538572
Return all items in the query params. Allow duplicate keys to occur.
573+
574+
Usage:
575+
576+
q = httpx.QueryParams("a=123&a=456&b=789")
577+
assert list(q.multi_items()) == [("a", "123"), ("a", "456"), ("b", "789")]
539578
"""
540579
multi_items: typing.List[typing.Tuple[str, str]] = []
541580
for k, v in self._dict.items():
@@ -546,31 +585,93 @@ def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
546585
"""
547586
Get a value from the query param for a given key. If the key occurs
548587
more than once, then only the first value is returned.
588+
589+
Usage:
590+
591+
q = httpx.QueryParams("a=123&a=456&b=789")
592+
assert q.get("a") == "123"
549593
"""
550594
if key in self._dict:
551-
return self._dict[key][0]
595+
return self._dict[str(key)][0]
552596
return default
553597

554598
def get_list(self, key: typing.Any) -> typing.List[str]:
555599
"""
556600
Get all values from the query param for a given key.
601+
602+
Usage:
603+
604+
q = httpx.QueryParams("a=123&a=456&b=789")
605+
assert q.get_list("a") == ["123", "456"]
557606
"""
558-
return list(self._dict.get(key, []))
607+
return list(self._dict.get(str(key), []))
559608

560-
def update(self, params: QueryParamTypes = None) -> None:
561-
if not params:
562-
return
609+
def set(self, key: typing.Any, value: typing.Any = None) -> "QueryParams":
610+
"""
611+
Return a new QueryParams instance, setting the value of a key.
612+
613+
Usage:
614+
615+
q = httpx.QueryParams("a=123")
616+
q = q.set("a", "456")
617+
assert q == httpx.QueryParams("a=456")
618+
"""
619+
q = QueryParams()
620+
q._dict = dict(self._dict)
621+
q._dict[str(key)] = [primitive_value_to_str(value)]
622+
return q
623+
624+
def add(self, key: typing.Any, value: typing.Any = None) -> "QueryParams":
625+
"""
626+
Return a new QueryParams instance, setting or appending the value of a key.
563627
564-
params = QueryParams(params)
565-
for k in params.keys():
566-
self._dict[k] = params.get_list(k)
628+
Usage:
629+
630+
q = httpx.QueryParams("a=123")
631+
q = q.add("a", "456")
632+
assert q == httpx.QueryParams("a=123&a=456")
633+
"""
634+
q = QueryParams()
635+
q._dict = dict(self._dict)
636+
q._dict[str(key)] = q.get_list(key) + [primitive_value_to_str(value)]
637+
return q
638+
639+
def remove(self, key: typing.Any) -> "QueryParams":
640+
"""
641+
Return a new QueryParams instance, removing the value of a key.
642+
643+
Usage:
644+
645+
q = httpx.QueryParams("a=123")
646+
q = q.remove("a")
647+
assert q == httpx.QueryParams("")
648+
"""
649+
q = QueryParams()
650+
q._dict = dict(self._dict)
651+
q._dict.pop(str(key), None)
652+
return q
653+
654+
def merge(self, params: QueryParamTypes = None) -> "QueryParams":
655+
"""
656+
Return a new QueryParams instance, updated with.
657+
658+
Usage:
659+
660+
q = httpx.QueryParams("a=123")
661+
q = q.merge({"b": "456"})
662+
assert q == httpx.QueryParams("a=123&b=456")
663+
664+
q = httpx.QueryParams("a=123")
665+
q = q.merge({"a": "456", "b": "789"})
666+
assert q == httpx.QueryParams("a=456&b=789")
667+
"""
668+
q = QueryParams(params)
669+
q._dict = {**self._dict, **q._dict}
670+
return q
567671

568672
def __getitem__(self, key: typing.Any) -> str:
569673
return self._dict[key][0]
570674

571-
def __setitem__(self, key: str, value: str) -> None:
572-
self._dict[key] = [value]
573-
574675
def __contains__(self, key: typing.Any) -> bool:
575676
return key in self._dict
576677

@@ -580,6 +681,9 @@ def __iter__(self) -> typing.Iterator[typing.Any]:
580681
def __len__(self) -> int:
581682
return len(self._dict)
582683

684+
def __hash__(self) -> int:
685+
return hash(str(self))
686+
583687
def __eq__(self, other: typing.Any) -> bool:
584688
if not isinstance(other, self.__class__):
585689
return False
@@ -593,6 +697,18 @@ def __repr__(self) -> str:
593697
query_string = str(self)
594698
return f"{class_name}({query_string!r})"
595699

700+
def update(self, params: QueryParamTypes = None) -> None:
701+
raise RuntimeError(
702+
"QueryParams are immutable since 0.18.0. "
703+
"Use `q = q.merge(...)` to create an updated copy."
704+
)
705+
706+
def __setitem__(self, key: str, value: str) -> None:
707+
raise RuntimeError(
708+
"QueryParams are immutable since 0.18.0. "
709+
"Use `q = q.set(key, value)` to create an updated copy."
710+
)
711+
596712

597713
class Headers(typing.MutableMapping[str, str]):
598714
"""

httpx/_utils.py

-26
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import codecs
2-
import collections
32
import logging
43
import mimetypes
54
import netrc
@@ -369,31 +368,6 @@ def peek_filelike_length(stream: typing.IO) -> int:
369368
return os.fstat(fd).st_size
370369

371370

372-
def flatten_queryparams(
373-
queryparams: typing.Mapping[
374-
str, typing.Union["PrimitiveData", typing.Sequence["PrimitiveData"]]
375-
]
376-
) -> typing.List[typing.Tuple[str, "PrimitiveData"]]:
377-
"""
378-
Convert a mapping of query params into a flat list of two-tuples
379-
representing each item.
380-
381-
Example:
382-
>>> flatten_queryparams_values({"q": "httpx", "tag": ["python", "dev"]})
383-
[("q", "httpx), ("tag", "python"), ("tag", "dev")]
384-
"""
385-
items = []
386-
387-
for k, v in queryparams.items():
388-
if isinstance(v, collections.abc.Sequence) and not isinstance(v, (str, bytes)):
389-
for u in v:
390-
items.append((k, u))
391-
else:
392-
items.append((k, typing.cast("PrimitiveData", v)))
393-
394-
return items
395-
396-
397371
class Timer:
398372
async def _get_time(self) -> float:
399373
library = sniffio.current_async_library()

tests/models/test_queryparams.py

+43-12
Original file line numberDiff line numberDiff line change
@@ -76,19 +76,50 @@ def test_queryparam_types():
7676
assert str(q) == "a=1&a=2"
7777

7878

79-
def test_queryparam_setters():
80-
q = httpx.QueryParams({"a": 1})
81-
q.update([])
79+
def test_queryparam_update_is_hard_deprecated():
80+
q = httpx.QueryParams("a=123")
81+
with pytest.raises(RuntimeError):
82+
q.update({"a": "456"})
8283

83-
assert str(q) == "a=1"
8484

85-
q = httpx.QueryParams([("a", 1), ("a", 2)])
86-
q["a"] = "3"
87-
assert str(q) == "a=3"
85+
def test_queryparam_setter_is_hard_deprecated():
86+
q = httpx.QueryParams("a=123")
87+
with pytest.raises(RuntimeError):
88+
q["a"] = "456"
8889

89-
q = httpx.QueryParams([("a", 1), ("b", 1)])
90-
u = httpx.QueryParams([("b", 2), ("b", 3)])
91-
q.update(u)
9290

93-
assert str(q) == "a=1&b=2&b=3"
94-
assert q["b"] == u["b"]
91+
def test_queryparam_set():
92+
q = httpx.QueryParams("a=123")
93+
q = q.set("a", "456")
94+
assert q == httpx.QueryParams("a=456")
95+
96+
97+
def test_queryparam_add():
98+
q = httpx.QueryParams("a=123")
99+
q = q.add("a", "456")
100+
assert q == httpx.QueryParams("a=123&a=456")
101+
102+
103+
def test_queryparam_remove():
104+
q = httpx.QueryParams("a=123")
105+
q = q.remove("a")
106+
assert q == httpx.QueryParams("")
107+
108+
109+
def test_queryparam_merge():
110+
q = httpx.QueryParams("a=123")
111+
q = q.merge({"b": "456"})
112+
assert q == httpx.QueryParams("a=123&b=456")
113+
q = q.merge({"a": "000", "c": "789"})
114+
assert q == httpx.QueryParams("a=000&b=456&c=789")
115+
116+
117+
def test_queryparams_are_hashable():
118+
params = (
119+
httpx.QueryParams("a=123"),
120+
httpx.QueryParams({"a": 123}),
121+
httpx.QueryParams("b=456"),
122+
httpx.QueryParams({"b": 456}),
123+
)
124+
125+
assert len(set(params)) == 2

0 commit comments

Comments
 (0)