Skip to content

Commit cf9fda4

Browse files
committed
Fix mypy type errors across service-related files in Selenium WebDriver
1 parent 0b19300 commit cf9fda4

File tree

15 files changed

+117
-91
lines changed

15 files changed

+117
-91
lines changed

py/selenium/webdriver/chromium/service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def __init__(
4949

5050
if isinstance(log_output, str):
5151
self.service_args.append(f"--log-path={log_output}")
52-
self.log_output: Optional[IOBase] = None
52+
self.log_output: cast(IOBase, None)
5353
elif isinstance(log_output, IOBase):
5454
self.log_output = log_output
5555
else:

py/selenium/webdriver/chromium/webdriver.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,12 @@ def __init__(
3333
self,
3434
browser_name: Optional[str] = None,
3535
vendor_prefix: Optional[str] = None,
36-
options: ArgOptions = ArgOptions(),
36+
options: ArgOptions = None,
3737
service: Optional[Service] = None,
3838
keep_alive: bool = True,
3939
) -> None:
40+
if options is None:
41+
options = ArgOptions()
4042
"""Creates a new WebDriver instance of the ChromiumDriver. Starts the
4143
service and then creates new WebDriver instance of ChromiumDriver.
4244
@@ -49,6 +51,9 @@ def __init__(
4951
"""
5052
self.service = service
5153

54+
if self.service is None:
55+
raise ValueError("Service must be provided and cannot be None")
56+
5257
finder = DriverFinder(self.service, options)
5358
if finder.get_browser_path():
5459
options.binary_location = finder.get_browser_path()
@@ -59,8 +64,8 @@ def __init__(
5964

6065
executor = ChromiumRemoteConnection(
6166
remote_server_addr=self.service.service_url,
62-
browser_name=browser_name,
63-
vendor_prefix=vendor_prefix,
67+
browser_name=browser_name or "",
68+
vendor_prefix=vendor_prefix or "",
6469
keep_alive=keep_alive,
6570
ignore_proxy=options._ignore_local_proxy,
6671
)
@@ -221,7 +226,8 @@ def quit(self) -> None:
221226
# We don't care about the message because something probably has gone wrong
222227
pass
223228
finally:
224-
self.service.stop()
229+
if self.service is not None:
230+
self.service.stop()
225231

226232
def download_file(self, *args, **kwargs):
227233
raise NotImplementedError

py/selenium/webdriver/common/bidi/browser.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -125,13 +125,13 @@ def from_dict(cls, data: dict) -> "ClientWindowInfo":
125125
ClientWindowInfo: A new instance of ClientWindowInfo.
126126
"""
127127
return cls(
128-
client_window=data.get("clientWindow"),
129-
state=data.get("state"),
130-
width=data.get("width"),
131-
height=data.get("height"),
132-
x=data.get("x"),
133-
y=data.get("y"),
134-
active=data.get("active"),
128+
client_window=str(data.get("clientWindow")),
129+
state=str(data.get("state")),
130+
width=int(data.get("width") or 0),
131+
height=int(data.get("height") or 0),
132+
x=int(data.get("x") or 0),
133+
y=int(data.get("y") or 0),
134+
active=bool(data.get("active")),
135135
)
136136

137137

py/selenium/webdriver/common/bidi/browsing_context.py

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18-
from typing import Optional, Union
18+
from typing import Optional, Union, Callable
1919

2020
from selenium.webdriver.common.bidi.common import command_builder
2121

@@ -67,10 +67,10 @@ def from_json(cls, json: dict) -> "NavigationInfo":
6767
NavigationInfo: A new instance of NavigationInfo.
6868
"""
6969
return cls(
70-
context=json.get("context"),
70+
context=str(json.get("context")),
7171
navigation=json.get("navigation"),
72-
timestamp=json.get("timestamp"),
73-
url=json.get("url"),
72+
timestamp=int(json.get("timestamp") or 0),
73+
url=str(json.get("url")),
7474
)
7575

7676

@@ -109,11 +109,11 @@ def from_json(cls, json: dict) -> "BrowsingContextInfo":
109109
"""
110110
children = None
111111
if json.get("children") is not None:
112-
children = [BrowsingContextInfo.from_json(child) for child in json.get("children")]
112+
children = [BrowsingContextInfo.from_json(child) for child in json.get("children", [])]
113113

114114
return cls(
115-
context=json.get("context"),
116-
url=json.get("url"),
115+
context=str(json.get("context")),
116+
url=str(json.get("url")),
117117
children=children,
118118
parent=json.get("parent"),
119119
user_context=json.get("userContext"),
@@ -149,11 +149,11 @@ def from_json(cls, json: dict) -> "DownloadWillBeginParams":
149149
DownloadWillBeginParams: A new instance of DownloadWillBeginParams.
150150
"""
151151
return cls(
152-
context=json.get("context"),
152+
context=str(json.get("context")),
153153
navigation=json.get("navigation"),
154-
timestamp=json.get("timestamp"),
155-
url=json.get("url"),
156-
suggested_filename=json.get("suggestedFilename"),
154+
timestamp=int(json.get("timestamp") or 0),
155+
url=str(json.get("url")),
156+
suggested_filename=str(json.get("suggestedFilename")),
157157
)
158158

159159

@@ -187,10 +187,10 @@ def from_json(cls, json: dict) -> "UserPromptOpenedParams":
187187
UserPromptOpenedParams: A new instance of UserPromptOpenedParams.
188188
"""
189189
return cls(
190-
context=json.get("context"),
191-
handler=json.get("handler"),
192-
message=json.get("message"),
193-
type=json.get("type"),
190+
context=str(json.get("context")),
191+
handler=str(json.get("handler")),
192+
message=str(json.get("message")),
193+
type=str(json.get("type")),
194194
default_value=json.get("defaultValue"),
195195
)
196196

@@ -223,9 +223,9 @@ def from_json(cls, json: dict) -> "UserPromptClosedParams":
223223
UserPromptClosedParams: A new instance of UserPromptClosedParams.
224224
"""
225225
return cls(
226-
context=json.get("context"),
227-
accepted=json.get("accepted"),
228-
type=json.get("type"),
226+
context=str(json.get("context")),
227+
accepted=bool(json.get("accepted")),
228+
type=str(json.get("type")),
229229
user_text=json.get("userText"),
230230
)
231231

@@ -254,8 +254,8 @@ def from_json(cls, json: dict) -> "HistoryUpdatedParams":
254254
HistoryUpdatedParams: A new instance of HistoryUpdatedParams.
255255
"""
256256
return cls(
257-
context=json.get("context"),
258-
url=json.get("url"),
257+
context=str(json.get("context")),
258+
url=str(json.get("url")),
259259
)
260260

261261

@@ -278,7 +278,7 @@ def from_json(cls, json: dict) -> "BrowsingContextEvent":
278278
-------
279279
BrowsingContextEvent: A new instance of BrowsingContextEvent.
280280
"""
281-
return cls(event_class=json.get("event_class"), **json)
281+
return cls(event_class=str(json.get("event_class")), **json)
282282

283283

284284
class BrowsingContext:
@@ -341,9 +341,9 @@ def capture_screenshot(
341341
"""
342342
params = {"context": context, "origin": origin}
343343
if format is not None:
344-
params["format"] = format
344+
params["format"] = str(format)
345345
if clip is not None:
346-
params["clip"] = clip
346+
params["clip"] = str(clip)
347347

348348
result = self.conn.execute(command_builder("browsingContext.captureScreenshot", params))
349349
return result["data"]
@@ -387,7 +387,7 @@ def create(
387387
if reference_context is not None:
388388
params["referenceContext"] = reference_context
389389
if background is not None:
390-
params["background"] = background
390+
params["background"] = str(background)
391391
if user_context is not None:
392392
params["userContext"] = user_context
393393

@@ -415,7 +415,7 @@ def get_tree(
415415
if max_depth is not None:
416416
params["maxDepth"] = max_depth
417417
if root is not None:
418-
params["root"] = root
418+
params["root"] = int(root or 0)
419419

420420
result = self.conn.execute(command_builder("browsingContext.getTree", params))
421421
return [BrowsingContextInfo.from_json(context) for context in result["contexts"]]
@@ -436,7 +436,7 @@ def handle_user_prompt(
436436
"""
437437
params = {"context": context}
438438
if accept is not None:
439-
params["accept"] = accept
439+
params["accept"] = str(accept)
440440
if user_text is not None:
441441
params["userText"] = user_text
442442

@@ -466,7 +466,7 @@ def locate_nodes(
466466
"""
467467
params = {"context": context, "locator": locator}
468468
if max_node_count is not None:
469-
params["maxNodeCount"] = max_node_count
469+
params["maxNodeCount"] = [int(max_node_count or 0)]
470470
if serialization_options is not None:
471471
params["serializationOptions"] = serialization_options
472472
if start_nodes is not None:
@@ -566,7 +566,7 @@ def reload(
566566
"""
567567
params = {"context": context}
568568
if ignore_cache is not None:
569-
params["ignoreCache"] = ignore_cache
569+
params["ignoreCache"] = str(ignore_cache)
570570
if wait is not None:
571571
params["wait"] = wait
572572

@@ -597,11 +597,11 @@ def set_viewport(
597597
if context is not None:
598598
params["context"] = context
599599
if viewport is not None:
600-
params["viewport"] = viewport
600+
params["viewport"] = str(viewport)
601601
if device_pixel_ratio is not None:
602-
params["devicePixelRatio"] = device_pixel_ratio
602+
params["devicePixelRatio"] = str(device_pixel_ratio)
603603
if user_contexts is not None:
604-
params["userContexts"] = user_contexts
604+
params["userContexts"] = str(user_contexts)
605605

606606
self.conn.execute(command_builder("browsingContext.setViewport", params))
607607

@@ -621,7 +621,7 @@ def traverse_history(self, context: str, delta: int) -> dict:
621621
result = self.conn.execute(command_builder("browsingContext.traverseHistory", params))
622622
return result
623623

624-
def _on_event(self, event_name: str, callback: callable) -> int:
624+
def _on_event(self, event_name: str, callback: Callable) -> int:
625625
"""Set a callback function to subscribe to a browsing context event.
626626
627627
Parameters:
@@ -665,7 +665,7 @@ def _callback(event_data):
665665

666666
return callback_id
667667

668-
def add_event_handler(self, event: str, callback: callable, contexts: Optional[list[str]] = None) -> int:
668+
def add_event_handler(self, event: str, callback: Callable, contexts: Optional[list[str]] = None) -> int:
669669
"""Add an event handler to the browsing context.
670670
671671
Parameters:
@@ -710,7 +710,7 @@ def remove_event_handler(self, event: str, callback_id: int) -> None:
710710
except KeyError:
711711
raise Exception(f"Event {event} not found")
712712

713-
event = BrowsingContextEvent(event_name)
713+
event = str(BrowsingContextEvent(event_name))
714714

715715
self.conn.remove_callback(event, callback_id)
716716
self.subscriptions[event_name].remove(callback_id)

py/selenium/webdriver/common/bidi/cdp.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,10 @@ async def connect_session(self, target_id) -> "CdpSession":
427427
"""Returns a new :class:`CdpSession` connected to the specified
428428
target."""
429429
global devtools
430+
if devtools and devtools.target:
431+
session_id = await self.execute(devtools.target.attach_to_target(target_id, True))
432+
else:
433+
raise RuntimeError("devtools.target is not available.")
430434
session_id = await self.execute(devtools.target.attach_to_target(target_id, True))
431435
session = CdpSession(self.ws, session_id, target_id)
432436
self.sessions[session_id] = session

py/selenium/webdriver/common/bidi/storage.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,9 @@ def from_dict(cls, data: dict) -> "Cookie":
8888
value = BytesValue(data.get("value", {}).get("type"), data.get("value", {}).get("value"))
8989

9090
return cls(
91-
name=data.get("name"),
91+
name=str(data.get("name")),
9292
value=value,
93-
domain=data.get("domain"),
93+
domain=str(data.get("domain")),
9494
path=data.get("path"),
9595
size=data.get("size"),
9696
http_only=data.get("httpOnly"),
@@ -136,21 +136,21 @@ def to_dict(self) -> dict:
136136
if self.name is not None:
137137
result["name"] = self.name
138138
if self.value is not None:
139-
result["value"] = self.value.to_dict()
139+
result["value"] = str(self.value.to_dict())
140140
if self.domain is not None:
141141
result["domain"] = self.domain
142142
if self.path is not None:
143143
result["path"] = self.path
144144
if self.size is not None:
145-
result["size"] = self.size
145+
result["size"] = str(self.size)
146146
if self.http_only is not None:
147-
result["httpOnly"] = self.http_only
147+
result["httpOnly"] = str(self.http_only)
148148
if self.secure is not None:
149-
result["secure"] = self.secure
149+
result["secure"] = str(self.secure)
150150
if self.same_site is not None:
151151
result["sameSite"] = self.same_site
152152
if self.expiry is not None:
153-
result["expiry"] = self.expiry
153+
result["expiry"] = str(self.expiry)
154154
return result
155155

156156

@@ -257,13 +257,13 @@ def to_dict(self) -> dict:
257257
if self.path is not None:
258258
result["path"] = self.path
259259
if self.http_only is not None:
260-
result["httpOnly"] = self.http_only
260+
result["httpOnly"] = [self.http_only]
261261
if self.secure is not None:
262-
result["secure"] = self.secure
262+
result["secure"] = [self.secure]
263263
if self.same_site is not None:
264264
result["sameSite"] = self.same_site
265265
if self.expiry is not None:
266-
result["expiry"] = self.expiry
266+
result["expiry"] = [self.expiry]
267267
return result
268268

269269

py/selenium/webdriver/common/options.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,7 @@ def __init__(self) -> None:
422422
self._caps = self.default_capabilities
423423
self._proxy = None
424424
self.set_capability("pageLoadStrategy", PageLoadStrategy.normal)
425-
self.mobile_options = None
425+
self.mobile_options: Optional[dict[str, str]] = None
426426
self._ignore_local_proxy = False
427427

428428
@property
@@ -475,6 +475,7 @@ class ArgOptions(BaseOptions):
475475
def __init__(self) -> None:
476476
super().__init__()
477477
self._arguments: list[str] = []
478+
self.binary_location: Optional[str] = None
478479

479480
@property
480481
def arguments(self):

py/selenium/webdriver/common/service.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,12 @@ def __init__(
5858
) -> None:
5959
if isinstance(log_output, str):
6060
self.log_output = cast(IOBase, open(log_output, "a+", encoding="utf-8"))
61-
elif log_output == subprocess.STDOUT:
62-
self.log_output = cast(Optional[Union[int, IOBase]], None)
63-
elif log_output is None or log_output == subprocess.DEVNULL:
64-
self.log_output = cast(Optional[Union[int, IOBase]], subprocess.DEVNULL)
65-
else:
61+
elif log_output in {subprocess.STDOUT, subprocess.DEVNULL, None}:
62+
self.log_output = cast(IOBase, subprocess.DEVNULL)
63+
elif isinstance(log_output, IOBase):
6664
self.log_output = log_output
65+
else:
66+
raise TypeError("log_output must be a string, IOBase, or a valid subprocess constant")
6767

6868
self.port = port or utils.free_port()
6969
# Default value for every python subprocess: subprocess.Popen(..., creationflags=0)

0 commit comments

Comments
 (0)