logger_connector has edge case where step can be a float #20692

Open

ryxli opened this issue Apr 2, 2025 · 0 comments

Labels: bug (Something isn't working) · needs triage (Waiting to be triaged by maintainers) · ver: 2.4.x · ver: 2.5.x

Bug description

This causes an issue in MLflow: mlflow/mlflow#15180

There is logic in `logger_connector` where, when `step` is part of the recorded metrics, it gets popped out and passed into `logger.log_metrics`:

ref

        # turn all tensors to scalars
        scalar_metrics = convert_tensors_to_scalars(metrics)  # <-- a 'step' cast to a float tensor becomes a float scalar

        if step is None:
            step = scalar_metrics.pop("step", None)
        ...

        if step is None:
            # added metrics for convenience
            scalar_metrics.setdefault("epoch", self.trainer.current_epoch)
            step = self.trainer.fit_loop.epoch_loop._batches_that_stepped

        # log actual metrics
        for logger in self.trainer.loggers:
            logger.log_metrics(metrics=scalar_metrics, step=step)  # <-- passed down to MLflow here
            logger.save()
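
For illustration, here is a minimal sketch of how a tensor-valued `step` leaks through as a float. The `convert_tensors_to_scalars` below is a simplified stand-in for the Lightning utility, assumed here to reduce tensors via `.item()`:

    import torch

    def convert_tensors_to_scalars(metrics):
        # simplified stand-in: reduce every tensor to a Python scalar via .item()
        return {k: v.item() if isinstance(v, torch.Tensor) else v for k, v in metrics.items()}

    # 'step' was logged as an int but was cast to a float tensor during metric reduction
    metrics = {"loss": torch.tensor(0.25), "step": torch.tensor(100.0)}
    scalar_metrics = convert_tensors_to_scalars(metrics)

    step = scalar_metrics.pop("step", None)
    print(type(step), step)  # <class 'float'> 100.0 -- this float is what reaches log_metrics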

The metric always gets cast to a float here:

        if self.is_tensor:
            value = cast(Tensor, value)
            dtype = _get_default_dtype()
            if not torch.is_floating_point(value):
                warning_cache.warn(
                    # do not include the value to avoid cache misses
                    f"You called `self.log({self.meta.name!r}, ...)` in your `{self.meta.fx}` but the value needs to"
                    f" be floating to be reduced. Converting it to {dtype}."
                    " You can silence this warning by converting the value to floating point yourself."
                    " If you don't intend to reduce the value (for instance when logging the global step or epoch) then"
                    f" you can use `self.logger.log_metrics({{{self.meta.name!r}: ...}})` instead."
                )
                value = value.to(dtype)
            if value.dtype not in (torch.float32, torch.float64):
                value = value.to(dtype)
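
For reference, the cast path above boils down to the following (a sketch using the public `torch.get_default_dtype()` in place of the private `_get_default_dtype()`):

    import torch

    value = torch.tensor(100)              # integer tensor, e.g. the global step
    print(torch.is_floating_point(value))  # False -> the warning above fires
    value = value.to(torch.get_default_dtype())
    print(value, value.dtype)              # tensor(100.) torch.float32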

This can cause issues, for example in NeMo, which relies on this logic to pass the step into `logger.log_metrics(metrics=scalar_metrics, step=step)`:

    @override
    def training_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        ...
        self.lightning_module.log(
            "step",
            self.trainer.global_step,
        )
        ...

I'm aware that you're not really supposed to use `lightning_module.log(value)` for integers, but this is what NeMo does above (#18739).
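
As the warning text itself suggests, code that only wants to record the step without reducing it could call the logger directly instead of `self.log` (a sketch, inside a `LightningModule` hook):

    # bypasses metric reduction entirely, so no float cast happens
    self.logger.log_metrics({"step": self.trainer.global_step})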

I'm not exactly clear on why this is done in NeMo, but I have seen cases in the past where, in the fallback path, `self.trainer.fit_loop.epoch_loop._batches_that_stepped` does not always reflect `trainer.global_step`, which is why I'm assuming that logic is included:

        if step is None:
            # added metrics for convenience
            scalar_metrics.setdefault("epoch", self.trainer.current_epoch)
            step = self.trainer.fit_loop.epoch_loop._batches_that_stepped

What version are you seeing the problem on?

v2.5, v2.4

How to reproduce the bug

This does end up causing issues in downstream loggers such as MLflow:

    @staticmethod
    def _get_metric_from_line(metric_name, metric_line, exp_id):
        ...
        step = int(metric_parts[2]) if len(metric_parts) == 3 else 0
        return Metric(key=metric_name, value=val, timestamp=ts, step=step)
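
The failure is easy to reproduce in isolation: once a float step has been serialized, the `int(...)` parse on read-back raises (a minimal sketch, not MLflow's actual file handling):

    metric_line = "1743552000000 0.25 100.0"  # timestamp, value, step -- step serialized as a float
    metric_parts = metric_line.split(" ")
    step = int(metric_parts[2])  # ValueError: invalid literal for int() with base 10: '100.0'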


But it seems like this should be fixed in Lightning instead, since the [function signature](https://lightning.ai/docs/pytorch/stable/_modules/lightning/pytorch/loggers/mlflow.html#MLFlowLogger.log_metrics) of `log_metrics` does indeed declare `step` as an `int`.

Error messages and logs

No response

Environment

No response

More info

Seems like an easy fix to add a check or an `int` cast in `logger_connector` (guarded, since `scalar_metrics.pop("step", None)` can return `None`):

        if step is None:
            step = scalar_metrics.pop("step", None)
            if step is not None:
                step = int(step)
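
A quick sanity check of the guarded cast:

    scalar_metrics = {"loss": 0.25, "step": 100.0}
    step = scalar_metrics.pop("step", None)
    if step is not None:
        step = int(step)
    assert isinstance(step, int) and step == 100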