diff --git a/model.py b/model.py index 302e111..ea9a54e 100644 --- a/model.py +++ b/model.py @@ -1491,6 +1491,8 @@ def set_trainable(self, layer_regex, model=None, indent=0, verbose=1): trainable = bool(re.fullmatch(layer_regex, layer_name)) if not trainable: param[1].requires_grad = False + else: # Fixed this bug. We need to set new parameters to train trainable. + param[1].requires_grad = True def set_log_dir(self, model_path=None): """Sets the model log directory and epoch counter.