|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 |
| - |
| 14 | +import hashlib |
15 | 15 | import json
|
16 | 16 | import logging
|
| 17 | +import time |
| 18 | +import uuid |
| 19 | +import warnings |
17 | 20 | from collections import deque
|
18 | 21 | from io import SEEK_END, BytesIO
|
19 | 22 | from typing import (
|
|
27 | 30 | Type,
|
28 | 31 | Union,
|
29 | 32 | )
|
| 33 | +from unittest.mock import Mock |
30 | 34 |
|
31 | 35 | import attr
|
32 | 36 | from typing_extensions import Deque
|
|
53 | 57 | from twisted.web.resource import IResource
|
54 | 58 | from twisted.web.server import Request, Site
|
55 | 59 |
|
| 60 | +from synapse.config.database import DatabaseConnectionConfig |
56 | 61 | from synapse.http.site import SynapseRequest
|
| 62 | +from synapse.server import HomeServer |
| 63 | +from synapse.storage import DataStore |
| 64 | +from synapse.storage.engines import PostgresEngine, create_engine |
57 | 65 | from synapse.types import JsonDict
|
58 | 66 | from synapse.util import Clock
|
59 | 67 |
|
| 68 | +from tests.utils import ( |
| 69 | + LEAVE_DB, |
| 70 | + POSTGRES_BASE_DB, |
| 71 | + POSTGRES_HOST, |
| 72 | + POSTGRES_PASSWORD, |
| 73 | + POSTGRES_USER, |
| 74 | + USE_POSTGRES_FOR_TESTS, |
| 75 | + MockClock, |
| 76 | + default_config, |
| 77 | +) |
60 | 78 |
|
61 | 79 | logger = logging.getLogger(__name__)
|
62 | 80 |
|
@@ -668,3 +686,168 @@ def connect_client(
|
668 | 686 | client.makeConnection(FakeTransport(server, reactor))
|
669 | 687 |
|
670 | 688 | return client, server
|
| 689 | + |
| 690 | + |
| 691 | +class TestHomeServer(HomeServer): |
| 692 | + DATASTORE_CLASS = DataStore |
| 693 | + |
| 694 | + |
| 695 | +def setup_test_homeserver( |
| 696 | + cleanup_func, |
| 697 | + name="test", |
| 698 | + config=None, |
| 699 | + reactor=None, |
| 700 | + homeserver_to_use: Type[HomeServer] = TestHomeServer, |
| 701 | + **kwargs, |
| 702 | +): |
| 703 | + """ |
| 704 | + Setup a homeserver suitable for running tests against. Keyword arguments |
| 705 | + are passed to the Homeserver constructor. |
| 706 | +
|
| 707 | + If no datastore is supplied, one is created and given to the homeserver. |
| 708 | +
|
| 709 | + Args: |
| 710 | + cleanup_func : The function used to register a cleanup routine for |
| 711 | + after the test. |
| 712 | +
|
| 713 | + Calling this method directly is deprecated: you should instead derive from |
| 714 | + HomeserverTestCase. |
| 715 | + """ |
| 716 | + if reactor is None: |
| 717 | + from twisted.internet import reactor |
| 718 | + |
| 719 | + if config is None: |
| 720 | + config = default_config(name, parse=True) |
| 721 | + |
| 722 | + config.ldap_enabled = False |
| 723 | + |
| 724 | + if "clock" not in kwargs: |
| 725 | + kwargs["clock"] = MockClock() |
| 726 | + |
| 727 | + if USE_POSTGRES_FOR_TESTS: |
| 728 | + test_db = "synapse_test_%s" % uuid.uuid4().hex |
| 729 | + |
| 730 | + database_config = { |
| 731 | + "name": "psycopg2", |
| 732 | + "args": { |
| 733 | + "database": test_db, |
| 734 | + "host": POSTGRES_HOST, |
| 735 | + "password": POSTGRES_PASSWORD, |
| 736 | + "user": POSTGRES_USER, |
| 737 | + "cp_min": 1, |
| 738 | + "cp_max": 5, |
| 739 | + }, |
| 740 | + } |
| 741 | + else: |
| 742 | + database_config = { |
| 743 | + "name": "sqlite3", |
| 744 | + "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1}, |
| 745 | + } |
| 746 | + |
| 747 | + if "db_txn_limit" in kwargs: |
| 748 | + database_config["txn_limit"] = kwargs["db_txn_limit"] |
| 749 | + |
| 750 | + database = DatabaseConnectionConfig("master", database_config) |
| 751 | + config.database.databases = [database] |
| 752 | + |
| 753 | + db_engine = create_engine(database.config) |
| 754 | + |
| 755 | + # Create the database before we actually try and connect to it, based off |
| 756 | + # the template database we generate in setupdb() |
| 757 | + if isinstance(db_engine, PostgresEngine): |
| 758 | + db_conn = db_engine.module.connect( |
| 759 | + database=POSTGRES_BASE_DB, |
| 760 | + user=POSTGRES_USER, |
| 761 | + host=POSTGRES_HOST, |
| 762 | + password=POSTGRES_PASSWORD, |
| 763 | + ) |
| 764 | + db_conn.autocommit = True |
| 765 | + cur = db_conn.cursor() |
| 766 | + cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) |
| 767 | + cur.execute( |
| 768 | + "CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB) |
| 769 | + ) |
| 770 | + cur.close() |
| 771 | + db_conn.close() |
| 772 | + |
| 773 | + hs = homeserver_to_use( |
| 774 | + name, |
| 775 | + config=config, |
| 776 | + version_string="Synapse/tests", |
| 777 | + reactor=reactor, |
| 778 | + ) |
| 779 | + |
| 780 | + # Install @cache_in_self attributes |
| 781 | + for key, val in kwargs.items(): |
| 782 | + setattr(hs, "_" + key, val) |
| 783 | + |
| 784 | + # Mock TLS |
| 785 | + hs.tls_server_context_factory = Mock() |
| 786 | + hs.tls_client_options_factory = Mock() |
| 787 | + |
| 788 | + hs.setup() |
| 789 | + if homeserver_to_use == TestHomeServer: |
| 790 | + hs.setup_background_tasks() |
| 791 | + |
| 792 | + if isinstance(db_engine, PostgresEngine): |
| 793 | + database = hs.get_datastores().databases[0] |
| 794 | + |
| 795 | + # We need to do cleanup on PostgreSQL |
| 796 | + def cleanup(): |
| 797 | + import psycopg2 |
| 798 | + |
| 799 | + # Close all the db pools |
| 800 | + database._db_pool.close() |
| 801 | + |
| 802 | + dropped = False |
| 803 | + |
| 804 | + # Drop the test database |
| 805 | + db_conn = db_engine.module.connect( |
| 806 | + database=POSTGRES_BASE_DB, |
| 807 | + user=POSTGRES_USER, |
| 808 | + host=POSTGRES_HOST, |
| 809 | + password=POSTGRES_PASSWORD, |
| 810 | + ) |
| 811 | + db_conn.autocommit = True |
| 812 | + cur = db_conn.cursor() |
| 813 | + |
| 814 | + # Try a few times to drop the DB. Some things may hold on to the |
| 815 | + # database for a few more seconds due to flakiness, preventing |
| 816 | + # us from dropping it when the test is over. If we can't drop |
| 817 | + # it, warn and move on. |
| 818 | + for _ in range(5): |
| 819 | + try: |
| 820 | + cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,)) |
| 821 | + db_conn.commit() |
| 822 | + dropped = True |
| 823 | + except psycopg2.OperationalError as e: |
| 824 | + warnings.warn( |
| 825 | + "Couldn't drop old db: " + str(e), category=UserWarning |
| 826 | + ) |
| 827 | + time.sleep(0.5) |
| 828 | + |
| 829 | + cur.close() |
| 830 | + db_conn.close() |
| 831 | + |
| 832 | + if not dropped: |
| 833 | + warnings.warn("Failed to drop old DB.", category=UserWarning) |
| 834 | + |
| 835 | + if not LEAVE_DB: |
| 836 | + # Register the cleanup hook |
| 837 | + cleanup_func(cleanup) |
| 838 | + |
| 839 | + # bcrypt is far too slow to be doing in unit tests |
| 840 | + # Need to let the HS build an auth handler and then mess with it |
| 841 | + # because AuthHandler's constructor requires the HS, so we can't make one |
| 842 | + # beforehand and pass it in to the HS's constructor (chicken / egg) |
| 843 | + async def hash(p): |
| 844 | + return hashlib.md5(p.encode("utf8")).hexdigest() |
| 845 | + |
| 846 | + hs.get_auth_handler().hash = hash |
| 847 | + |
| 848 | + async def validate_hash(p, h): |
| 849 | + return hashlib.md5(p.encode("utf8")).hexdigest() == h |
| 850 | + |
| 851 | + hs.get_auth_handler().validate_hash = validate_hash |
| 852 | + |
| 853 | + return hs |
0 commit comments