Skip to content

Commit e08e606

Browse files
Fix mapping for choice values (#8968)
* fix mapping for choice values * fix tests for ChoiceField IntegerChoices * fix imports * fix imports in tests * Check for multiple choice enums * fix formatting * add tests for text choices class
1 parent d14eb75 commit e08e606

File tree

2 files changed

+62
-1
lines changed

2 files changed

+62
-1
lines changed

rest_framework/fields.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
MinValueValidator, ProhibitNullCharactersValidator, RegexValidator,
1717
URLValidator, ip_address_validators
1818
)
19+
from django.db.models import IntegerChoices, TextChoices
1920
from django.forms import FilePathField as DjangoFilePathField
2021
from django.forms import ImageField as DjangoImageField
2122
from django.utils import timezone
@@ -1397,6 +1398,10 @@ def to_internal_value(self, data):
13971398
if data == '' and self.allow_blank:
13981399
return ''
13991400

1401+
if isinstance(data, (IntegerChoices, TextChoices)) and str(data) != \
1402+
str(data.value):
1403+
data = data.value
1404+
14001405
try:
14011406
return self.choice_strings_to_values[str(data)]
14021407
except KeyError:
@@ -1405,6 +1410,11 @@ def to_internal_value(self, data):
14051410
def to_representation(self, value):
14061411
if value in ('', None):
14071412
return value
1413+
1414+
if isinstance(value, (IntegerChoices, TextChoices)) and str(value) != \
1415+
str(value.value):
1416+
value = value.value
1417+
14081418
return self.choice_strings_to_values.get(str(value), value)
14091419

14101420
def iter_options(self):
@@ -1428,7 +1438,8 @@ def _set_choices(self, choices):
14281438
# Allows us to deal with eg. integer choices while supporting either
14291439
# integer or string input, but still get the correct datatype out.
14301440
self.choice_strings_to_values = {
1431-
str(key): key for key in self.choices
1441+
str(key.value) if isinstance(key, (IntegerChoices, TextChoices))
1442+
and str(key) != str(key.value) else str(key): key for key in self.choices
14321443
}
14331444

14341445
choices = property(_get_choices, _set_choices)

tests/test_fields.py

+50
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55
import sys
66
import uuid
77
from decimal import ROUND_DOWN, ROUND_UP, Decimal
8+
from enum import auto
89
from unittest.mock import patch
910

1011
import pytest
1112
import pytz
1213
from django.core.exceptions import ValidationError as DjangoValidationError
14+
from django.db.models import IntegerChoices, TextChoices
1315
from django.http import QueryDict
1416
from django.test import TestCase, override_settings
1517
from django.utils.timezone import activate, deactivate, override
@@ -1824,6 +1826,54 @@ def test_edit_choices(self):
18241826
field.run_validation(2)
18251827
assert exc_info.value.detail == ['"2" is not a valid choice.']
18261828

1829+
def test_integer_choices(self):
1830+
class ChoiceCase(IntegerChoices):
1831+
first = auto()
1832+
second = auto()
1833+
# Enum validate
1834+
choices = [
1835+
(ChoiceCase.first, "1"),
1836+
(ChoiceCase.second, "2")
1837+
]
1838+
1839+
field = serializers.ChoiceField(choices=choices)
1840+
assert field.run_validation(1) == 1
1841+
assert field.run_validation(ChoiceCase.first) == 1
1842+
assert field.run_validation("1") == 1
1843+
1844+
choices = [
1845+
(ChoiceCase.first.value, "1"),
1846+
(ChoiceCase.second.value, "2")
1847+
]
1848+
1849+
field = serializers.ChoiceField(choices=choices)
1850+
assert field.run_validation(1) == 1
1851+
assert field.run_validation(ChoiceCase.first) == 1
1852+
assert field.run_validation("1") == 1
1853+
1854+
def test_text_choices(self):
1855+
class ChoiceCase(TextChoices):
1856+
first = auto()
1857+
second = auto()
1858+
# Enum validate
1859+
choices = [
1860+
(ChoiceCase.first, "first"),
1861+
(ChoiceCase.second, "second")
1862+
]
1863+
1864+
field = serializers.ChoiceField(choices=choices)
1865+
assert field.run_validation(ChoiceCase.first) == "first"
1866+
assert field.run_validation("first") == "first"
1867+
1868+
choices = [
1869+
(ChoiceCase.first.value, "first"),
1870+
(ChoiceCase.second.value, "second")
1871+
]
1872+
1873+
field = serializers.ChoiceField(choices=choices)
1874+
assert field.run_validation(ChoiceCase.first) == "first"
1875+
assert field.run_validation("first") == "first"
1876+
18271877

18281878
class TestChoiceFieldWithType(FieldValues):
18291879
"""

0 commit comments

Comments
 (0)