|
15 | 15 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
|
16 | 16 | # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
17 | 17 |
|
| 18 | +import contextlib |
18 | 19 | import socket
|
19 | 20 | import sys
|
20 | 21 | import time
|
|
32 | 33 | import dns.message
|
33 | 34 | import dns.name
|
34 | 35 | import dns.query
|
| 36 | +import dns.rcode |
35 | 37 | import dns.rdataclass
|
36 | 38 | import dns.rdatatype
|
37 | 39 | import dns.tsigkeyring
|
@@ -659,3 +661,141 @@ def test_matches_destination(self):
|
659 | 661 | dns.query._matches_destination(
|
660 | 662 | socket.AF_INET, ("10.0.0.1", 1234), ("10.0.0.1", 1235), False
|
661 | 663 | )
|
| 664 | + |
| 665 | + |
| 666 | +@contextlib.contextmanager |
| 667 | +def mock_udp_recv(wire1, from1, wire2, from2): |
| 668 | + saved = dns.query._udp_recv |
| 669 | + first_time = True |
| 670 | + |
| 671 | + def mock(sock, max_size, expiration): |
| 672 | + nonlocal first_time |
| 673 | + if first_time: |
| 674 | + first_time = False |
| 675 | + return wire1, from1 |
| 676 | + else: |
| 677 | + return wire2, from2 |
| 678 | + |
| 679 | + try: |
| 680 | + dns.query._udp_recv = mock |
| 681 | + yield None |
| 682 | + finally: |
| 683 | + dns.query._udp_recv = saved |
| 684 | + |
| 685 | + |
| 686 | +class IgnoreErrors(unittest.TestCase): |
| 687 | + def setUp(self): |
| 688 | + self.q = dns.message.make_query("example.", "A") |
| 689 | + self.good_r = dns.message.make_response(self.q) |
| 690 | + self.good_r.set_rcode(dns.rcode.NXDOMAIN) |
| 691 | + self.good_r_wire = self.good_r.to_wire() |
| 692 | + |
| 693 | + def mock_receive( |
| 694 | + self, |
| 695 | + wire1, |
| 696 | + from1, |
| 697 | + wire2, |
| 698 | + from2, |
| 699 | + ignore_unexpected=True, |
| 700 | + ignore_errors=True, |
| 701 | + ): |
| 702 | + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) |
| 703 | + try: |
| 704 | + with mock_udp_recv(wire1, from1, wire2, from2): |
| 705 | + (r, when) = dns.query.receive_udp( |
| 706 | + s, |
| 707 | + ("127.0.0.1", 53), |
| 708 | + time.time() + 2, |
| 709 | + ignore_unexpected=ignore_unexpected, |
| 710 | + ignore_errors=ignore_errors, |
| 711 | + query=self.q, |
| 712 | + ) |
| 713 | + self.assertEqual(r, self.good_r) |
| 714 | + finally: |
| 715 | + s.close() |
| 716 | + |
| 717 | + def test_good_mock(self): |
| 718 | + self.mock_receive(self.good_r_wire, ("127.0.0.1", 53), None, None) |
| 719 | + |
| 720 | + def test_bad_address(self): |
| 721 | + self.mock_receive( |
| 722 | + self.good_r_wire, ("127.0.0.2", 53), self.good_r_wire, ("127.0.0.1", 53) |
| 723 | + ) |
| 724 | + |
| 725 | + def test_bad_address_not_ignored(self): |
| 726 | + def bad(): |
| 727 | + self.mock_receive( |
| 728 | + self.good_r_wire, |
| 729 | + ("127.0.0.2", 53), |
| 730 | + self.good_r_wire, |
| 731 | + ("127.0.0.1", 53), |
| 732 | + ignore_unexpected=False, |
| 733 | + ) |
| 734 | + |
| 735 | + self.assertRaises(dns.query.UnexpectedSource, bad) |
| 736 | + |
| 737 | + def test_bad_id(self): |
| 738 | + bad_r = dns.message.make_response(self.q) |
| 739 | + bad_r.id += 1 |
| 740 | + bad_r_wire = bad_r.to_wire() |
| 741 | + self.mock_receive( |
| 742 | + bad_r_wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53) |
| 743 | + ) |
| 744 | + |
| 745 | + def test_bad_id_not_ignored(self): |
| 746 | + bad_r = dns.message.make_response(self.q) |
| 747 | + bad_r.id += 1 |
| 748 | + bad_r_wire = bad_r.to_wire() |
| 749 | + |
| 750 | + def bad(): |
| 751 | + (r, wire) = self.mock_receive( |
| 752 | + bad_r_wire, |
| 753 | + ("127.0.0.1", 53), |
| 754 | + self.good_r_wire, |
| 755 | + ("127.0.0.1", 53), |
| 756 | + ignore_errors=False, |
| 757 | + ) |
| 758 | + |
| 759 | + self.assertRaises(AssertionError, bad) |
| 760 | + |
| 761 | + def test_bad_wire(self): |
| 762 | + bad_r = dns.message.make_response(self.q) |
| 763 | + bad_r.id += 1 |
| 764 | + bad_r_wire = bad_r.to_wire() |
| 765 | + self.mock_receive( |
| 766 | + bad_r_wire[:10], ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53) |
| 767 | + ) |
| 768 | + |
| 769 | + def test_bad_wire_not_ignored(self): |
| 770 | + bad_r = dns.message.make_response(self.q) |
| 771 | + bad_r.id += 1 |
| 772 | + bad_r_wire = bad_r.to_wire() |
| 773 | + |
| 774 | + def bad(): |
| 775 | + self.mock_receive( |
| 776 | + bad_r_wire[:10], |
| 777 | + ("127.0.0.1", 53), |
| 778 | + self.good_r_wire, |
| 779 | + ("127.0.0.1", 53), |
| 780 | + ignore_errors=False, |
| 781 | + ) |
| 782 | + |
| 783 | + self.assertRaises(dns.message.ShortHeader, bad) |
| 784 | + |
| 785 | + def test_trailing_wire(self): |
| 786 | + wire = self.good_r_wire + b"abcd" |
| 787 | + self.mock_receive(wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)) |
| 788 | + |
| 789 | + def test_trailing_wire_not_ignored(self): |
| 790 | + wire = self.good_r_wire + b"abcd" |
| 791 | + |
| 792 | + def bad(): |
| 793 | + self.mock_receive( |
| 794 | + wire, |
| 795 | + ("127.0.0.1", 53), |
| 796 | + self.good_r_wire, |
| 797 | + ("127.0.0.1", 53), |
| 798 | + ignore_errors=False, |
| 799 | + ) |
| 800 | + |
| 801 | + self.assertRaises(dns.message.TrailingJunk, bad) |
0 commit comments