Skip to content

Commit 7952e31

Browse files
committed
test IgnoreErrors
(cherry picked from commit ac6763f)
1 parent e093299 commit 7952e31

File tree

1 file changed

+140
-0
lines changed

1 file changed

+140
-0
lines changed

tests/test_query.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
1616
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
1717

18+
import contextlib
1819
import socket
1920
import sys
2021
import time
@@ -32,6 +33,7 @@
3233
import dns.message
3334
import dns.name
3435
import dns.query
36+
import dns.rcode
3537
import dns.rdataclass
3638
import dns.rdatatype
3739
import dns.tsigkeyring
@@ -659,3 +661,141 @@ def test_matches_destination(self):
659661
dns.query._matches_destination(
660662
socket.AF_INET, ("10.0.0.1", 1234), ("10.0.0.1", 1235), False
661663
)
664+
665+
666+
@contextlib.contextmanager
667+
def mock_udp_recv(wire1, from1, wire2, from2):
668+
saved = dns.query._udp_recv
669+
first_time = True
670+
671+
def mock(sock, max_size, expiration):
672+
nonlocal first_time
673+
if first_time:
674+
first_time = False
675+
return wire1, from1
676+
else:
677+
return wire2, from2
678+
679+
try:
680+
dns.query._udp_recv = mock
681+
yield None
682+
finally:
683+
dns.query._udp_recv = saved
684+
685+
686+
class IgnoreErrors(unittest.TestCase):
687+
def setUp(self):
688+
self.q = dns.message.make_query("example.", "A")
689+
self.good_r = dns.message.make_response(self.q)
690+
self.good_r.set_rcode(dns.rcode.NXDOMAIN)
691+
self.good_r_wire = self.good_r.to_wire()
692+
693+
def mock_receive(
694+
self,
695+
wire1,
696+
from1,
697+
wire2,
698+
from2,
699+
ignore_unexpected=True,
700+
ignore_errors=True,
701+
):
702+
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
703+
try:
704+
with mock_udp_recv(wire1, from1, wire2, from2):
705+
(r, when) = dns.query.receive_udp(
706+
s,
707+
("127.0.0.1", 53),
708+
time.time() + 2,
709+
ignore_unexpected=ignore_unexpected,
710+
ignore_errors=ignore_errors,
711+
query=self.q,
712+
)
713+
self.assertEqual(r, self.good_r)
714+
finally:
715+
s.close()
716+
717+
def test_good_mock(self):
718+
self.mock_receive(self.good_r_wire, ("127.0.0.1", 53), None, None)
719+
720+
def test_bad_address(self):
721+
self.mock_receive(
722+
self.good_r_wire, ("127.0.0.2", 53), self.good_r_wire, ("127.0.0.1", 53)
723+
)
724+
725+
def test_bad_address_not_ignored(self):
726+
def bad():
727+
self.mock_receive(
728+
self.good_r_wire,
729+
("127.0.0.2", 53),
730+
self.good_r_wire,
731+
("127.0.0.1", 53),
732+
ignore_unexpected=False,
733+
)
734+
735+
self.assertRaises(dns.query.UnexpectedSource, bad)
736+
737+
def test_bad_id(self):
738+
bad_r = dns.message.make_response(self.q)
739+
bad_r.id += 1
740+
bad_r_wire = bad_r.to_wire()
741+
self.mock_receive(
742+
bad_r_wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
743+
)
744+
745+
def test_bad_id_not_ignored(self):
746+
bad_r = dns.message.make_response(self.q)
747+
bad_r.id += 1
748+
bad_r_wire = bad_r.to_wire()
749+
750+
def bad():
751+
(r, wire) = self.mock_receive(
752+
bad_r_wire,
753+
("127.0.0.1", 53),
754+
self.good_r_wire,
755+
("127.0.0.1", 53),
756+
ignore_errors=False,
757+
)
758+
759+
self.assertRaises(AssertionError, bad)
760+
761+
def test_bad_wire(self):
762+
bad_r = dns.message.make_response(self.q)
763+
bad_r.id += 1
764+
bad_r_wire = bad_r.to_wire()
765+
self.mock_receive(
766+
bad_r_wire[:10], ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
767+
)
768+
769+
def test_bad_wire_not_ignored(self):
770+
bad_r = dns.message.make_response(self.q)
771+
bad_r.id += 1
772+
bad_r_wire = bad_r.to_wire()
773+
774+
def bad():
775+
self.mock_receive(
776+
bad_r_wire[:10],
777+
("127.0.0.1", 53),
778+
self.good_r_wire,
779+
("127.0.0.1", 53),
780+
ignore_errors=False,
781+
)
782+
783+
self.assertRaises(dns.message.ShortHeader, bad)
784+
785+
def test_trailing_wire(self):
786+
wire = self.good_r_wire + b"abcd"
787+
self.mock_receive(wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53))
788+
789+
def test_trailing_wire_not_ignored(self):
790+
wire = self.good_r_wire + b"abcd"
791+
792+
def bad():
793+
self.mock_receive(
794+
wire,
795+
("127.0.0.1", 53),
796+
self.good_r_wire,
797+
("127.0.0.1", 53),
798+
ignore_errors=False,
799+
)
800+
801+
self.assertRaises(dns.message.TrailingJunk, bad)

0 commit comments

Comments
 (0)