|
20 | 20 | import hmac
|
21 | 21 | import logging
|
22 | 22 | import sys
|
23 |
| -from typing import Callable, Optional |
| 23 | +from typing import Any, Callable, Dict, Optional |
24 | 24 |
|
25 | 25 | import requests
|
26 | 26 | import yaml
|
27 | 27 |
|
| 28 | +_DEFAULT_SERVER_URL = "http://localhost:8008" |
| 29 | + |
28 | 30 |
|
29 | 31 | def request_registration(
|
30 | 32 | user: str,
|
@@ -203,31 +205,74 @@ def main() -> None:
|
203 | 205 |
|
204 | 206 | parser.add_argument(
|
205 | 207 | "server_url",
|
206 |
| - default="https://localhost:8448", |
207 | 208 | nargs="?",
|
208 |
| - help="URL to use to talk to the homeserver. Defaults to " |
209 |
| - " 'https://localhost:8448'.", |
| 209 | + help="URL to use to talk to the homeserver. By default, tries to find a " |
| 210 | + "suitable URL from the configuration file. Otherwise, defaults to " |
| 211 | + f"'{_DEFAULT_SERVER_URL}'.", |
210 | 212 | )
|
211 | 213 |
|
212 | 214 | args = parser.parse_args()
|
213 | 215 |
|
214 | 216 | if "config" in args and args.config:
|
215 | 217 | config = yaml.safe_load(args.config)
|
| 218 | + |
| 219 | + if args.shared_secret: |
| 220 | + secret = args.shared_secret |
| 221 | + else: |
| 222 | + # argparse should check that we have either config or shared secret |
| 223 | + assert config |
| 224 | + |
216 | 225 | secret = config.get("registration_shared_secret", None)
|
217 | 226 | if not secret:
|
218 | 227 | print("No 'registration_shared_secret' defined in config.")
|
219 | 228 | sys.exit(1)
|
| 229 | + |
| 230 | + if args.server_url: |
| 231 | + server_url = args.server_url |
| 232 | + elif config: |
| 233 | + server_url = _find_client_listener(config) |
| 234 | + if not server_url: |
| 235 | + server_url = _DEFAULT_SERVER_URL |
| 236 | + print( |
| 237 | + "Unable to find a suitable HTTP listener in the configuration file. " |
| 238 | + f"Trying {server_url} as a last resort.", |
| 239 | + file=sys.stderr, |
| 240 | + ) |
220 | 241 | else:
|
221 |
| - secret = args.shared_secret |
| 242 | + server_url = _DEFAULT_SERVER_URL |
| 243 | + print( |
| 244 | + f"No server url or configuration file given. Defaulting to {server_url}.", |
| 245 | + file=sys.stderr, |
| 246 | + ) |
222 | 247 |
|
223 | 248 | admin = None
|
224 | 249 | if args.admin or args.no_admin:
|
225 | 250 | admin = args.admin
|
226 | 251 |
|
227 | 252 | register_new_user(
|
228 |
| - args.user, args.password, args.server_url, secret, admin, args.user_type |
| 253 | + args.user, args.password, server_url, secret, admin, args.user_type |
229 | 254 | )
|
230 | 255 |
|
231 | 256 |
|
| 257 | +def _find_client_listener(config: Dict[str, Any]) -> Optional[str]: |
| 258 | + # try to find a listener in the config. Returns a host:port pair |
| 259 | + for listener in config.get("listeners", []): |
| 260 | + if listener.get("type") != "http" or listener.get("tls", False): |
| 261 | + continue |
| 262 | + |
| 263 | + if not any( |
| 264 | + name == "client" |
| 265 | + for resource in listener.get("resources", []) |
| 266 | + for name in resource.get("names", []) |
| 267 | + ): |
| 268 | + continue |
| 269 | + |
| 270 | + # TODO: consider bind_addresses |
| 271 | + return f"http://localhost:{listener['port']}" |
| 272 | + |
| 273 | + # no suitable listeners? |
| 274 | + return None |
| 275 | + |
| 276 | + |
232 | 277 | if __name__ == "__main__":
|
233 | 278 | main()
|
0 commit comments