Skip to content

Commit cecb853

Browse files
committed
Further improve CVE fix coverage to 100% for sync and async.
(cherry picked from commit a1a9989)
1 parent 7952e31 commit cecb853

File tree

2 files changed

+204
-1
lines changed

2 files changed

+204
-1
lines changed

tests/test_async.py

Lines changed: 183 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import asyncio
1919
import random
2020
import socket
21-
import sys
2221
import time
2322
import unittest
2423

@@ -28,6 +27,7 @@
2827
import dns.message
2928
import dns.name
3029
import dns.query
30+
import dns.rcode
3131
import dns.rdataclass
3232
import dns.rdatatype
3333
import dns.resolver
@@ -664,3 +664,185 @@ def async_run(self, afunc):
664664

665665
except ImportError:
666666
pass
667+
668+
669+
class MockSock:
670+
def __init__(self, wire1, from1, wire2, from2):
671+
self.family = socket.AF_INET
672+
self.first_time = True
673+
self.wire1 = wire1
674+
self.from1 = from1
675+
self.wire2 = wire2
676+
self.from2 = from2
677+
678+
async def sendto(self, data, where, timeout):
679+
return len(data)
680+
681+
async def recvfrom(self, bufsize, expiration):
682+
if self.first_time:
683+
self.first_time = False
684+
return self.wire1, self.from1
685+
else:
686+
return self.wire2, self.from2
687+
688+
689+
class IgnoreErrors(unittest.TestCase):
690+
def setUp(self):
691+
self.q = dns.message.make_query("example.", "A")
692+
self.good_r = dns.message.make_response(self.q)
693+
self.good_r.set_rcode(dns.rcode.NXDOMAIN)
694+
self.good_r_wire = self.good_r.to_wire()
695+
dns.asyncbackend.set_default_backend("asyncio")
696+
697+
def async_run(self, afunc):
698+
return asyncio.run(afunc())
699+
700+
async def mock_receive(
701+
self,
702+
wire1,
703+
from1,
704+
wire2,
705+
from2,
706+
ignore_unexpected=True,
707+
ignore_errors=True,
708+
):
709+
s = MockSock(wire1, from1, wire2, from2)
710+
(r, when, _) = await dns.asyncquery.receive_udp(
711+
s,
712+
("127.0.0.1", 53),
713+
time.time() + 2,
714+
ignore_unexpected=ignore_unexpected,
715+
ignore_errors=ignore_errors,
716+
query=self.q,
717+
)
718+
self.assertEqual(r, self.good_r)
719+
720+
def test_good_mock(self):
721+
async def run():
722+
await self.mock_receive(self.good_r_wire, ("127.0.0.1", 53), None, None)
723+
724+
self.async_run(run)
725+
726+
def test_bad_address(self):
727+
async def run():
728+
await self.mock_receive(
729+
self.good_r_wire, ("127.0.0.2", 53), self.good_r_wire, ("127.0.0.1", 53)
730+
)
731+
732+
self.async_run(run)
733+
734+
def test_bad_address_not_ignored(self):
735+
async def abad():
736+
await self.mock_receive(
737+
self.good_r_wire,
738+
("127.0.0.2", 53),
739+
self.good_r_wire,
740+
("127.0.0.1", 53),
741+
ignore_unexpected=False,
742+
)
743+
744+
def bad():
745+
self.async_run(abad)
746+
747+
self.assertRaises(dns.query.UnexpectedSource, bad)
748+
749+
def test_not_response_not_ignored_udp_level(self):
750+
async def abad():
751+
bad_r = dns.message.make_response(self.q)
752+
bad_r.id += 1
753+
bad_r_wire = bad_r.to_wire()
754+
s = MockSock(
755+
bad_r_wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
756+
)
757+
await dns.asyncquery.udp(self.good_r, "127.0.0.1", sock=s)
758+
759+
def bad():
760+
self.async_run(abad)
761+
762+
self.assertRaises(dns.query.BadResponse, bad)
763+
764+
def test_bad_id(self):
765+
async def run():
766+
bad_r = dns.message.make_response(self.q)
767+
bad_r.id += 1
768+
bad_r_wire = bad_r.to_wire()
769+
await self.mock_receive(
770+
bad_r_wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
771+
)
772+
773+
self.async_run(run)
774+
775+
def test_bad_id_not_ignored(self):
776+
bad_r = dns.message.make_response(self.q)
777+
bad_r.id += 1
778+
bad_r_wire = bad_r.to_wire()
779+
780+
async def abad():
781+
(r, wire) = await self.mock_receive(
782+
bad_r_wire,
783+
("127.0.0.1", 53),
784+
self.good_r_wire,
785+
("127.0.0.1", 53),
786+
ignore_errors=False,
787+
)
788+
789+
def bad():
790+
self.async_run(abad)
791+
792+
self.assertRaises(AssertionError, bad)
793+
794+
def test_bad_wire(self):
795+
async def run():
796+
bad_r = dns.message.make_response(self.q)
797+
bad_r.id += 1
798+
bad_r_wire = bad_r.to_wire()
799+
await self.mock_receive(
800+
bad_r_wire[:10], ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
801+
)
802+
803+
self.async_run(run)
804+
805+
def test_bad_wire_not_ignored(self):
806+
bad_r = dns.message.make_response(self.q)
807+
bad_r.id += 1
808+
bad_r_wire = bad_r.to_wire()
809+
810+
async def abad():
811+
await self.mock_receive(
812+
bad_r_wire[:10],
813+
("127.0.0.1", 53),
814+
self.good_r_wire,
815+
("127.0.0.1", 53),
816+
ignore_errors=False,
817+
)
818+
819+
def bad():
820+
self.async_run(abad)
821+
822+
self.assertRaises(dns.message.ShortHeader, bad)
823+
824+
def test_trailing_wire(self):
825+
async def run():
826+
wire = self.good_r_wire + b"abcd"
827+
await self.mock_receive(
828+
wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
829+
)
830+
831+
self.async_run(run)
832+
833+
def test_trailing_wire_not_ignored(self):
834+
wire = self.good_r_wire + b"abcd"
835+
836+
async def abad():
837+
await self.mock_receive(
838+
wire,
839+
("127.0.0.1", 53),
840+
self.good_r_wire,
841+
("127.0.0.1", 53),
842+
ignore_errors=False,
843+
)
844+
845+
def bad():
846+
self.async_run(abad)
847+
848+
self.assertRaises(dns.message.TrailingJunk, bad)

tests/test_query.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -683,6 +683,14 @@ def mock(sock, max_size, expiration):
683683
dns.query._udp_recv = saved
684684

685685

686+
class MockSock:
687+
def __init__(self):
688+
self.family = socket.AF_INET
689+
690+
def sendto(self, data, where):
691+
return len(data)
692+
693+
686694
class IgnoreErrors(unittest.TestCase):
687695
def setUp(self):
688696
self.q = dns.message.make_query("example.", "A")
@@ -758,6 +766,19 @@ def bad():
758766

759767
self.assertRaises(AssertionError, bad)
760768

769+
def test_not_response_not_ignored_udp_level(self):
770+
def bad():
771+
bad_r = dns.message.make_response(self.q)
772+
bad_r.id += 1
773+
bad_r_wire = bad_r.to_wire()
774+
with mock_udp_recv(
775+
bad_r_wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
776+
):
777+
s = MockSock()
778+
dns.query.udp(self.good_r, "127.0.0.1", sock=s)
779+
780+
self.assertRaises(dns.query.BadResponse, bad)
781+
761782
def test_bad_wire(self):
762783
bad_r = dns.message.make_response(self.q)
763784
bad_r.id += 1

0 commit comments

Comments
 (0)