|
18 | 18 | import asyncio
|
19 | 19 | import random
|
20 | 20 | import socket
|
21 |
| -import sys |
22 | 21 | import time
|
23 | 22 | import unittest
|
24 | 23 |
|
|
28 | 27 | import dns.message
|
29 | 28 | import dns.name
|
30 | 29 | import dns.query
|
| 30 | +import dns.rcode |
31 | 31 | import dns.rdataclass
|
32 | 32 | import dns.rdatatype
|
33 | 33 | import dns.resolver
|
@@ -664,3 +664,185 @@ def async_run(self, afunc):
|
664 | 664 |
|
665 | 665 | except ImportError:
|
666 | 666 | pass
|
| 667 | + |
| 668 | + |
| 669 | +class MockSock: |
| 670 | + def __init__(self, wire1, from1, wire2, from2): |
| 671 | + self.family = socket.AF_INET |
| 672 | + self.first_time = True |
| 673 | + self.wire1 = wire1 |
| 674 | + self.from1 = from1 |
| 675 | + self.wire2 = wire2 |
| 676 | + self.from2 = from2 |
| 677 | + |
| 678 | + async def sendto(self, data, where, timeout): |
| 679 | + return len(data) |
| 680 | + |
| 681 | + async def recvfrom(self, bufsize, expiration): |
| 682 | + if self.first_time: |
| 683 | + self.first_time = False |
| 684 | + return self.wire1, self.from1 |
| 685 | + else: |
| 686 | + return self.wire2, self.from2 |
| 687 | + |
| 688 | + |
| 689 | +class IgnoreErrors(unittest.TestCase): |
| 690 | + def setUp(self): |
| 691 | + self.q = dns.message.make_query("example.", "A") |
| 692 | + self.good_r = dns.message.make_response(self.q) |
| 693 | + self.good_r.set_rcode(dns.rcode.NXDOMAIN) |
| 694 | + self.good_r_wire = self.good_r.to_wire() |
| 695 | + dns.asyncbackend.set_default_backend("asyncio") |
| 696 | + |
| 697 | + def async_run(self, afunc): |
| 698 | + return asyncio.run(afunc()) |
| 699 | + |
| 700 | + async def mock_receive( |
| 701 | + self, |
| 702 | + wire1, |
| 703 | + from1, |
| 704 | + wire2, |
| 705 | + from2, |
| 706 | + ignore_unexpected=True, |
| 707 | + ignore_errors=True, |
| 708 | + ): |
| 709 | + s = MockSock(wire1, from1, wire2, from2) |
| 710 | + (r, when, _) = await dns.asyncquery.receive_udp( |
| 711 | + s, |
| 712 | + ("127.0.0.1", 53), |
| 713 | + time.time() + 2, |
| 714 | + ignore_unexpected=ignore_unexpected, |
| 715 | + ignore_errors=ignore_errors, |
| 716 | + query=self.q, |
| 717 | + ) |
| 718 | + self.assertEqual(r, self.good_r) |
| 719 | + |
| 720 | + def test_good_mock(self): |
| 721 | + async def run(): |
| 722 | + await self.mock_receive(self.good_r_wire, ("127.0.0.1", 53), None, None) |
| 723 | + |
| 724 | + self.async_run(run) |
| 725 | + |
| 726 | + def test_bad_address(self): |
| 727 | + async def run(): |
| 728 | + await self.mock_receive( |
| 729 | + self.good_r_wire, ("127.0.0.2", 53), self.good_r_wire, ("127.0.0.1", 53) |
| 730 | + ) |
| 731 | + |
| 732 | + self.async_run(run) |
| 733 | + |
| 734 | + def test_bad_address_not_ignored(self): |
| 735 | + async def abad(): |
| 736 | + await self.mock_receive( |
| 737 | + self.good_r_wire, |
| 738 | + ("127.0.0.2", 53), |
| 739 | + self.good_r_wire, |
| 740 | + ("127.0.0.1", 53), |
| 741 | + ignore_unexpected=False, |
| 742 | + ) |
| 743 | + |
| 744 | + def bad(): |
| 745 | + self.async_run(abad) |
| 746 | + |
| 747 | + self.assertRaises(dns.query.UnexpectedSource, bad) |
| 748 | + |
| 749 | + def test_not_response_not_ignored_udp_level(self): |
| 750 | + async def abad(): |
| 751 | + bad_r = dns.message.make_response(self.q) |
| 752 | + bad_r.id += 1 |
| 753 | + bad_r_wire = bad_r.to_wire() |
| 754 | + s = MockSock( |
| 755 | + bad_r_wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53) |
| 756 | + ) |
| 757 | + await dns.asyncquery.udp(self.good_r, "127.0.0.1", sock=s) |
| 758 | + |
| 759 | + def bad(): |
| 760 | + self.async_run(abad) |
| 761 | + |
| 762 | + self.assertRaises(dns.query.BadResponse, bad) |
| 763 | + |
| 764 | + def test_bad_id(self): |
| 765 | + async def run(): |
| 766 | + bad_r = dns.message.make_response(self.q) |
| 767 | + bad_r.id += 1 |
| 768 | + bad_r_wire = bad_r.to_wire() |
| 769 | + await self.mock_receive( |
| 770 | + bad_r_wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53) |
| 771 | + ) |
| 772 | + |
| 773 | + self.async_run(run) |
| 774 | + |
| 775 | + def test_bad_id_not_ignored(self): |
| 776 | + bad_r = dns.message.make_response(self.q) |
| 777 | + bad_r.id += 1 |
| 778 | + bad_r_wire = bad_r.to_wire() |
| 779 | + |
| 780 | + async def abad(): |
| 781 | + (r, wire) = await self.mock_receive( |
| 782 | + bad_r_wire, |
| 783 | + ("127.0.0.1", 53), |
| 784 | + self.good_r_wire, |
| 785 | + ("127.0.0.1", 53), |
| 786 | + ignore_errors=False, |
| 787 | + ) |
| 788 | + |
| 789 | + def bad(): |
| 790 | + self.async_run(abad) |
| 791 | + |
| 792 | + self.assertRaises(AssertionError, bad) |
| 793 | + |
| 794 | + def test_bad_wire(self): |
| 795 | + async def run(): |
| 796 | + bad_r = dns.message.make_response(self.q) |
| 797 | + bad_r.id += 1 |
| 798 | + bad_r_wire = bad_r.to_wire() |
| 799 | + await self.mock_receive( |
| 800 | + bad_r_wire[:10], ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53) |
| 801 | + ) |
| 802 | + |
| 803 | + self.async_run(run) |
| 804 | + |
| 805 | + def test_bad_wire_not_ignored(self): |
| 806 | + bad_r = dns.message.make_response(self.q) |
| 807 | + bad_r.id += 1 |
| 808 | + bad_r_wire = bad_r.to_wire() |
| 809 | + |
| 810 | + async def abad(): |
| 811 | + await self.mock_receive( |
| 812 | + bad_r_wire[:10], |
| 813 | + ("127.0.0.1", 53), |
| 814 | + self.good_r_wire, |
| 815 | + ("127.0.0.1", 53), |
| 816 | + ignore_errors=False, |
| 817 | + ) |
| 818 | + |
| 819 | + def bad(): |
| 820 | + self.async_run(abad) |
| 821 | + |
| 822 | + self.assertRaises(dns.message.ShortHeader, bad) |
| 823 | + |
| 824 | + def test_trailing_wire(self): |
| 825 | + async def run(): |
| 826 | + wire = self.good_r_wire + b"abcd" |
| 827 | + await self.mock_receive( |
| 828 | + wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53) |
| 829 | + ) |
| 830 | + |
| 831 | + self.async_run(run) |
| 832 | + |
| 833 | + def test_trailing_wire_not_ignored(self): |
| 834 | + wire = self.good_r_wire + b"abcd" |
| 835 | + |
| 836 | + async def abad(): |
| 837 | + await self.mock_receive( |
| 838 | + wire, |
| 839 | + ("127.0.0.1", 53), |
| 840 | + self.good_r_wire, |
| 841 | + ("127.0.0.1", 53), |
| 842 | + ignore_errors=False, |
| 843 | + ) |
| 844 | + |
| 845 | + def bad(): |
| 846 | + self.async_run(abad) |
| 847 | + |
| 848 | + self.assertRaises(dns.message.TrailingJunk, bad) |
0 commit comments