Unify checks #12069
Changes from all commits
@@ -30,6 +30,8 @@
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.precision import PrecisionPlugin
+from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin
+from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
 from pytorch_lightning.strategies.ddp import DDPStrategy
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import GradClipAlgorithmType

@@ -39,7 +41,7 @@
     get_default_process_group_backend_for_device,
     log,
 )
-from pytorch_lightning.utilities.enums import AMPType, PrecisionType
+from pytorch_lightning.utilities.enums import PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
 from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -651,7 +653,7 @@ def _auto_select_batch_size(self):
 
     def _format_precision_config(self) -> None:
         if self.precision_plugin.precision in (PrecisionType.HALF, PrecisionType.MIXED):
-            if "fp16" not in self.config and self.precision_plugin.amp_type == AMPType.NATIVE:
+            if "fp16" not in self.config and isinstance(self.precision_plugin, NativeMixedPrecisionPlugin):
                 # FP16 is a DeepSpeed standalone AMP implementation
                 rank_zero_info("Enabling DeepSpeed FP16.")
                 self.config["fp16"] = {

Review comments on the isinstance(self.precision_plugin, NativeMixedPrecisionPlugin) check:

Note, the test (test_trainer_model_hook_system_fit) failed because the type was actually DeepSpeedPrecisionPlugin. Since the DeepSpeedPrecisionPlugin is a weird one, basically bundling both native and apex together, we probably have to revert the changes here or make a more radical change.

DeepSpeed only supports native AMP, right?

It appears so, yes. I too think it would be worth investigating #12323
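A minimal sketch of why the isinstance check misfires under DeepSpeed, assuming (as the discussion above suggests) that DeepSpeedPrecisionPlugin derives from the base PrecisionPlugin rather than from NativeMixedPrecisionPlugin; the stubs are illustrative, not the real classes:

```python
class PrecisionPlugin:
    """Simplified base class stand-in."""


class NativeMixedPrecisionPlugin(PrecisionPlugin):
    """Stand-in for the native AMP plugin."""


class DeepSpeedPrecisionPlugin(PrecisionPlugin):
    """Stand-in for DeepSpeed's own plugin, which bundles native and apex behaviour
    but is not a subclass of NativeMixedPrecisionPlugin."""


plugin = DeepSpeedPrecisionPlugin()

# The type check fails for the DeepSpeed plugin, so the "fp16" section
# would never be written for DeepSpeed runs, which is what the failing
# test surfaced.
print(isinstance(plugin, NativeMixedPrecisionPlugin))  # False
```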
@@ -662,7 +664,7 @@ def _format_precision_config(self) -> None:
                     "hysteresis": self.hysteresis,
                     "min_loss_scale": self.min_loss_scale,
                 }
-            elif "amp" not in self.config and self.precision_plugin.amp_type == AMPType.APEX:
+            elif "amp" not in self.config and isinstance(self.precision_plugin, ApexMixedPrecisionPlugin):
                 rank_zero_info("Enabling DeepSpeed APEX Implementation.")
                 self.config["amp"] = {"enabled": True, "opt_level": self.precision_plugin.amp_level}
 
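For orientation, a sketch of the DeepSpeed config sections that _format_precision_config fills in for the two branches above. Only hysteresis and min_loss_scale are visible in the diff; the remaining fp16 keys come from DeepSpeed's documented fp16 options, the numeric values are illustrative placeholders rather than Lightning's defaults, and opt_level mirrors the plugin's amp_level:

```python
# Sketch of the resulting DeepSpeed config sections (values are placeholders).

# Native mixed precision -> DeepSpeed's standalone FP16 implementation:
fp16_section = {
    "fp16": {
        "enabled": True,
        "loss_scale": 0,            # 0 selects dynamic loss scaling in DeepSpeed
        "initial_scale_power": 16,
        "loss_scale_window": 1000,
        "hysteresis": 2,
        "min_loss_scale": 1,
    }
}

# Apex mixed precision -> DeepSpeed's apex integration:
amp_section = {
    "amp": {
        "enabled": True,
        "opt_level": "O2",          # taken from precision_plugin.amp_level
    }
}
```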