Bug description

Causes an issue in mlflow: mlflow/mlflow#15180

There is some logic in `logger_connector` where, when `step` is part of the recorded metrics, it can be popped out and passed into `logger.log_metrics` (ref):
```python
# turn all tensors to scalars
scalar_metrics = convert_tensors_to_scalars(metrics)  # <-- 'step' cast to a float tensor gets cast to a float scalar
if step is None:
    step = scalar_metrics.pop("step", None)
...
if step is None:
    # added metrics for convenience
    scalar_metrics.setdefault("epoch", self.trainer.current_epoch)
    step = self.trainer.fit_loop.epoch_loop._batches_that_stepped
# log actual metrics
for logger in self.trainer.loggers:
    logger.log_metrics(metrics=scalar_metrics, step=step)  # <-- passed down to mlflow here
    logger.save()
```
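To see why that matters: `convert_tensors_to_scalars` calls `.item()` on every tensor, so a `step` that was already converted to a float tensor comes out as a Python `float`. A minimal sketch of the mechanism (the stand-in function below is hypothetical, mimicking what the lightning utility does):

```python
import torch

# hypothetical stand-in mimicking lightning's convert_tensors_to_scalars:
# it calls .item() on every tensor value in the mapping
def convert_tensors_to_scalars(metrics):
    return {k: v.item() if isinstance(v, torch.Tensor) else v for k, v in metrics.items()}

metrics = {"loss": torch.tensor(0.25), "step": torch.tensor(10.0)}  # step already cast to float by self.log
scalar_metrics = convert_tensors_to_scalars(metrics)
step = scalar_metrics.pop("step", None)
print(type(step), step)  # <class 'float'> 10.0 -- a float, not an int
```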
The metric always gets cast to a float here:

```python
if self.is_tensor:
    value = cast(Tensor, value)
    dtype = _get_default_dtype()
    if not torch.is_floating_point(value):
        warning_cache.warn(
            # do not include the value to avoid cache misses
            f"You called `self.log({self.meta.name!r}, ...)` in your `{self.meta.fx}` but the value needs to"
            f" be floating to be reduced. Converting it to {dtype}."
            " You can silence this warning by converting the value to floating point yourself."
            " If you don't intend to reduce the value (for instance when logging the global step or epoch) then"
            f" you can use `self.logger.log_metrics({{{self.meta.name!r}: ...}})` instead."
        )
        value = value.to(dtype)
    if value.dtype not in (torch.float32, torch.float64):
        value = value.to(dtype)
```
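As a minimal illustration of that branch (a sketch, not lightning's actual code path), logging an integer step triggers the conversion:

```python
import torch

value = torch.tensor(10)                # e.g. self.log("step", trainer.global_step)
if not torch.is_floating_point(value):  # mirrors the branch above
    value = value.to(torch.float32)     # the integer step silently becomes a float
print(value, value.dtype)               # tensor(10.) torch.float32
```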
This can cause issues, for example in NeMo, which uses this logic to pass the step to `logger.log_metrics(metrics=scalar_metrics, step=step)`. I'm aware that you're not really supposed to use `lightning_module.log(value)` for integers, but this is something done in NeMo (#18739).

I'm not exactly clear on why this is done, but I have seen in the past that, in the fallback case, `self.trainer.fit_loop.epoch_loop._batches_that_stepped` does not always reflect `trainer.global_step`, which is why I'm assuming that fallback logic is included:
```python
if step is None:
    # added metrics for convenience
    scalar_metrics.setdefault("epoch", self.trainer.current_epoch)
    step = self.trainer.fit_loop.epoch_loop._batches_that_stepped
```
What version are you seeing the problem on?
v2.5, v2.4
How to reproduce the bug
This does end up causing issues in downstream loggers such as mlflow:
```python
@staticmethod
def _get_metric_from_line(metric_name, metric_line, exp_id):
    ...
    step = int(metric_parts[2]) if len(metric_parts) == 3 else 0
    return Metric(key=metric_name, value=val, timestamp=ts, step=step)
```
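Concretely, if a float step ever reaches that parser as text, the `int()` call fails (a hypothetical illustration of the failure mode, not a trace from mlflow):

```python
step_field = "10.0"  # a float step serialized into mlflow's metric line
int(step_field)      # ValueError: invalid literal for int() with base 10: '10.0'
```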
But it seems like it should be fixed in lightning instead, since the [function signature](https://lightning.ai/docs/pytorch/stable/_modules/lightning/pytorch/loggers/mlflow.html#MLFlowLogger.log_metrics) of `log_metrics` is indeed an `int`.
Error messages and logs
No response
Environment
No response
More info
Seems like an easy fix to do a double check or cast to `int` in `logger_connector`:
pytorch-lightning/src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py, line 109 in ca13f77
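A minimal sketch of what that double check could look like (a hypothetical patch against the snippet quoted above, not the actual fix):

```python
if step is None:
    step = scalar_metrics.pop("step", None)
if step is not None:
    # guard against the float cast introduced by self.log: downstream loggers
    # such as mlflow expect step to be an int
    step = int(step)
```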