Skip to content

Commit 46f718d

Browse files
authored
Fix typing in pl.plugins.environments (#10943)
1 parent 6bfc0bb commit 46f718d

File tree

6 files changed

+33
-42
lines changed

6 files changed

+33
-42
lines changed

pyproject.toml

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ module = [
4747
"pytorch_lightning.callbacks.finetuning",
4848
"pytorch_lightning.callbacks.lr_monitor",
4949
"pytorch_lightning.callbacks.model_checkpoint",
50-
"pytorch_lightning.callbacks.prediction_writer",
5150
"pytorch_lightning.callbacks.progress.base",
5251
"pytorch_lightning.callbacks.progress.progress",
5352
"pytorch_lightning.callbacks.progress.rich_progress",
@@ -70,10 +69,6 @@ module = [
7069
"pytorch_lightning.loggers.test_tube",
7170
"pytorch_lightning.loggers.wandb",
7271
"pytorch_lightning.loops.epoch.training_epoch_loop",
73-
"pytorch_lightning.plugins.environments.lightning_environment",
74-
"pytorch_lightning.plugins.environments.lsf_environment",
75-
"pytorch_lightning.plugins.environments.slurm_environment",
76-
"pytorch_lightning.plugins.environments.torchelastic_environment",
7772
"pytorch_lightning.plugins.training_type.ddp",
7873
"pytorch_lightning.plugins.training_type.ddp2",
7974
"pytorch_lightning.plugins.training_type.ddp_spawn",

pytorch_lightning/plugins/environments/lightning_environment.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ class LightningEnvironment(ClusterEnvironment):
3434
training as it provides a convenient way to launch the training script.
3535
"""
3636

37-
def __init__(self):
37+
def __init__(self) -> None:
3838
super().__init__()
39-
self._main_port = None
39+
self._main_port: int = -1
4040
self._global_rank: int = 0
4141
self._world_size: int = 1
4242

@@ -55,9 +55,9 @@ def main_address(self) -> str:
5555

5656
@property
5757
def main_port(self) -> int:
58-
if self._main_port is None:
59-
self._main_port = os.environ.get("MASTER_PORT", find_free_network_port())
60-
return int(self._main_port)
58+
if self._main_port == -1:
59+
self._main_port = int(os.environ.get("MASTER_PORT", find_free_network_port()))
60+
return self._main_port
6161

6262
@staticmethod
6363
def detect() -> bool:

pytorch_lightning/plugins/environments/lsf_environment.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import os
1616
import socket
17+
from typing import Dict, List
1718

1819
from pytorch_lightning import _logger as log
1920
from pytorch_lightning.plugins.environments import ClusterEnvironment
@@ -41,7 +42,7 @@ class LSFEnvironment(ClusterEnvironment):
4142
The world size for the task. This environment variable is set by jsrun
4243
"""
4344

44-
def __init__(self):
45+
def __init__(self) -> None:
4546
super().__init__()
4647
# TODO: remove in 1.7
4748
if hasattr(self, "is_using_lsf") and callable(self.is_using_lsf):
@@ -74,7 +75,7 @@ def detect() -> bool:
7475
required_env_vars = {"LSB_JOBID", "LSB_HOSTS", "JSM_NAMESPACE_LOCAL_RANK", "JSM_NAMESPACE_SIZE"}
7576
return required_env_vars.issubset(os.environ.keys())
7677

77-
def world_size(self):
78+
def world_size(self) -> int:
7879
"""The world size is read from the environment variable `JSM_NAMESPACE_SIZE`."""
7980
var = "JSM_NAMESPACE_SIZE"
8081
world_size = os.environ.get(var)
@@ -88,7 +89,7 @@ def world_size(self):
8889
def set_world_size(self, size: int) -> None:
8990
log.debug("LSFEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")
9091

91-
def global_rank(self):
92+
def global_rank(self) -> int:
9293
"""The world size is read from the environment variable `JSM_NAMESPACE_RANK`."""
9394
var = "JSM_NAMESPACE_RANK"
9495
global_rank = os.environ.get(var)
@@ -102,7 +103,7 @@ def global_rank(self):
102103
def set_global_rank(self, rank: int) -> None:
103104
log.debug("LSFEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")
104105

105-
def local_rank(self):
106+
def local_rank(self) -> int:
106107
"""The local rank is read from the environment variable `JSM_NAMESPACE_LOCAL_RANK`."""
107108
var = "JSM_NAMESPACE_LOCAL_RANK"
108109
local_rank = os.environ.get(var)
@@ -113,11 +114,11 @@ def local_rank(self):
113114
)
114115
return int(local_rank)
115116

116-
def node_rank(self):
117+
def node_rank(self) -> int:
117118
"""The node rank is determined by the position of the current hostname in the list of hosts stored in the
118119
environment variable `LSB_HOSTS`."""
119120
hosts = self._read_hosts()
120-
count = {}
121+
count: Dict[str, int] = {}
121122
for host in hosts:
122123
if "batch" in host or "login" in host:
123124
continue
@@ -126,7 +127,7 @@ def node_rank(self):
126127
return count[socket.gethostname()]
127128

128129
@staticmethod
129-
def _read_hosts():
130+
def _read_hosts() -> List[str]:
130131
hosts = os.environ.get("LSB_HOSTS")
131132
if not hosts:
132133
raise ValueError("Could not find hosts in environment variable LSB_HOSTS")
@@ -148,15 +149,13 @@ def _get_main_port() -> int:
148149
Uses the LSF job ID so all ranks can compute the main port.
149150
"""
150151
# check for user-specified main port
151-
port = os.environ.get("MASTER_PORT")
152-
if not port:
153-
jobid = os.environ.get("LSB_JOBID")
154-
if not jobid:
155-
raise ValueError("Could not find job id in environment variable LSB_JOBID")
156-
port = int(jobid)
152+
if "MASTER_PORT" in os.environ:
153+
log.debug(f"Using externally specified main port: {os.environ['MASTER_PORT']}")
154+
return int(os.environ["MASTER_PORT"])
155+
if "LSB_JOBID" in os.environ:
156+
port = int(os.environ["LSB_JOBID"])
157157
# all ports should be in the 10k+ range
158-
port = int(port) % 1000 + 10000
158+
port = port % 1000 + 10000
159159
log.debug(f"calculated LSF main port: {port}")
160-
else:
161-
log.debug(f"using externally specified main port: {port}")
162-
return int(port)
160+
return port
161+
raise ValueError("Could not find job id in environment variable LSB_JOBID")

pytorch_lightning/plugins/environments/slurm_environment.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,10 @@ def main_port(self) -> int:
5858
# SLURM JOB = PORT number
5959
# -----------------------
6060
# this way every process knows what port to use
61-
default_port = os.environ.get("SLURM_JOB_ID")
62-
if default_port:
61+
job_id = os.environ.get("SLURM_JOB_ID")
62+
if job_id is not None:
6363
# use the last 4 numbers in the job id as the id
64-
default_port = default_port[-4:]
64+
default_port = job_id[-4:]
6565
# all ports should be in the 10k+ range
6666
default_port = int(default_port) + 15000
6767
else:
@@ -72,13 +72,12 @@ def main_port(self) -> int:
7272
# -----------------------
7373
# in case the user passed it in
7474
if "MASTER_PORT" in os.environ:
75-
default_port = os.environ["MASTER_PORT"]
75+
default_port = int(os.environ["MASTER_PORT"])
7676
else:
7777
os.environ["MASTER_PORT"] = str(default_port)
7878

7979
log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")
80-
81-
return int(default_port)
80+
return default_port
8281

8382
@staticmethod
8483
def detect() -> bool:

pytorch_lightning/plugins/environments/torchelastic_environment.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
import logging
1616
import os
17-
from typing import Optional
1817

1918
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
2019
from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
@@ -45,8 +44,7 @@ def main_address(self) -> str:
4544
rank_zero_warn("MASTER_ADDR environment variable is not defined. Set as localhost")
4645
os.environ["MASTER_ADDR"] = "127.0.0.1"
4746
log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
48-
main_address = os.environ.get("MASTER_ADDR")
49-
return main_address
47+
return os.environ["MASTER_ADDR"]
5048

5149
@property
5250
def main_port(self) -> int:
@@ -55,18 +53,16 @@ def main_port(self) -> int:
5553
os.environ["MASTER_PORT"] = "12910"
5654
log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")
5755

58-
port = int(os.environ.get("MASTER_PORT"))
59-
return port
56+
return int(os.environ["MASTER_PORT"])
6057

6158
@staticmethod
6259
def detect() -> bool:
6360
"""Returns ``True`` if the current process was launched using the torchelastic command."""
6461
required_env_vars = {"RANK", "GROUP_RANK", "LOCAL_RANK", "LOCAL_WORLD_SIZE"}
6562
return required_env_vars.issubset(os.environ.keys())
6663

67-
def world_size(self) -> Optional[int]:
68-
world_size = os.environ.get("WORLD_SIZE")
69-
return int(world_size) if world_size is not None else world_size
64+
def world_size(self) -> int:
65+
return int(os.environ["WORLD_SIZE"])
7066

7167
def set_world_size(self, size: int) -> None:
7268
log.debug("TorchElasticEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")

tests/plugins/environments/test_torchelastic_environment.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ def test_default_attributes():
2727
assert env.creates_processes_externally
2828
assert env.main_address == "127.0.0.1"
2929
assert env.main_port == 12910
30-
assert env.world_size() is None
30+
with pytest.raises(KeyError):
31+
# world size is required to be passed as env variable
32+
env.world_size()
3133
with pytest.raises(KeyError):
3234
# local rank is required to be passed as env variable
3335
env.local_rank()

0 commit comments

Comments (0)