Skip to content

Commit 3927427

Browse files
authored
Update accelerator.py (#7318)
1 parent badd0bb commit 3927427

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

pytorch_lightning/accelerators/accelerator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
import contextlib
1515
from collections import defaultdict
16-
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union
16+
from typing import Any, Callable, DefaultDict, Dict, Generator, Iterable, List, Optional, Union
1717

1818
import torch
1919
from torch import Tensor
@@ -114,7 +114,7 @@ def pre_dispatch(self, trainer: 'pl.Trainer') -> None:
114114
def _move_optimizer_state(self) -> None:
115115
""" Moves the state of the optimizers to the GPU if needed. """
116116
for opt in self.optimizers:
117-
state = defaultdict(dict)
117+
state: DefaultDict = defaultdict(dict)
118118
for p, v in opt.state.items():
119119
state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, self.root_device)
120120
opt.state = state

0 commit comments

Comments
 (0)