
LoggerConnector Refactor #7183

Closed
4 tasks
tchaton opened this issue Apr 23, 2021 · 5 comments · Fixed by #7631
Comments

@tchaton
Contributor

tchaton commented Apr 23, 2021

🚀 Feature

Motivation

Pitch

The LoggerConnector logic is pretty opaque and hard to follow.
The EpochResultStore and HookResult add an extra layer of complexity, and the tests are possibly too sparse to catch wrong behaviours.

One of the reasons for the complexity is the non-uniformity of how the logged data is stored.

Description of the internal functionality:

  1. An EpochResultStore is created for each Trainer RUNNING_STAGE.
  2. A new Result object is created when running a new hook.
     Result objects are enhanced dictionaries containing a key -> value mapping together with the extra logged metadata and the inferred batch_size.
  3. The Result object is stored in the associated EpochResultStore.
     How Result objects are stored differs between TRAIN and TEST/VALIDATION, making the code complex and hard to follow.
  4. On batch_end: get the latest stored Result object and provide its values to the logger and progress_bar based on the meta.
  5. On epoch_end: reduce the values and provide them to the logger and progress_bar based on the meta.
     Since a logged value can be either a Metric or a float/tensor, extra internal checks are needed to reduce properly on epoch end.

Proposal: unify logged values to simplify both storage and reduction.

TODOs:

  • Simplify the Result object internally
  • Create one Result object for the entire loop
  • Storage: RunningStage -> hook_name -> dataloader_idx -> LoggedMetric (sketched below)
  • Create a LoggedTensorMetric
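
A rough sketch of that storage layout (stage, hook, and key names here are purely illustrative):

# Rough sketch of the proposed single storage structure (illustrative values):
# RunningStage -> hook_name -> dataloader_idx -> LoggedMetric
storage = {
    "train": {
        "training_step": {
            0: LoggedMetric(key="loss", meta={"on_step": True, "on_epoch": True}),
        },
    },
    "validate": {
        "validation_step": {
            0: LoggedMetric(key="val_acc", meta={"on_step": False, "on_epoch": True}),
        },
    },
}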

Here is the pseudo code for the LoggedMetric. It would wrap both Metric objects and tensors and greatly simplify the internal way information is stored.
It would also make fault-tolerant training simpler, as the state could be reduced and stored/reloaded as 'weighted_mean, sum(batch_sizes)'.

import torch
import torchmetrics


def weighted_mean(weighted_sum, batch_sizes):
    # `weighted_sum` accumulates value * batch_size, so dividing by the total
    # number of samples gives the batch-size weighted mean.
    return weighted_sum / sum(batch_sizes)


class LoggedMetric(torchmetrics.Metric):

    def __init__(self, key, meta, wrap_metric: bool = False):
        super().__init__()
        self.key = key
        self.meta = meta
        self.wrap_metric = wrap_metric
        if not self.wrap_metric:
            self.add_state("value", default=torch.tensor(0.0))
            self.add_state("batch_sizes", default=[])

    def update(self, value, batch_size):
        if not self.wrap_metric:
            # Accumulate the batch-size weighted sum of plain tensors/floats.
            self.value += value * batch_size
            self.batch_sizes.append(batch_size)
        else:
            if not isinstance(value, torchmetrics.Metric):
                raise Mis...  # e.g. a misconfiguration error

            if not hasattr(self, "value"):
                self.value = value
            elif self.value is not value:
                # The same key should always map to the same Metric instance.
                raise Mis...

    def compute(self):
        if not self.wrap_metric:
            return weighted_mean(self.value, self.batch_sizes)
        return self.value.compute()

    @property
    def on_epoch(self) -> bool:
        return self.meta["on_epoch"]

    @property
    def on_step(self) -> bool:
        return self.meta["on_step"]

    ...
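
For example, a loop could use it roughly like this (the dataloader and step function names below are placeholders, not part of the proposal):

# Hypothetical usage of the LoggedMetric sketch above.
loss_metric = LoggedMetric(key="train_loss", meta={"on_step": True, "on_epoch": True})

for batch in dataloader:                      # `dataloader` is a placeholder
    loss = training_step(batch)               # `training_step` is a placeholder returning a tensor
    loss_metric.update(loss, batch_size=len(batch))

if loss_metric.on_epoch:
    epoch_loss = loss_metric.compute()        # batch-size weighted mean over the epoch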
@tchaton tchaton added feature Is an improvement or enhancement help wanted Open to be worked on refactor labels Apr 23, 2021
@tchaton tchaton added this to the v1.4 milestone Apr 23, 2021
@Borda Borda added the design Includes a design discussion label Apr 23, 2021
@ananthsub
Contributor

Thanks for starting this!

Simplify Result Object internally

Any more detail around what would be simplified?

@ananthsub
Contributor

ananthsub commented May 6, 2021

@tchaton I think the main tension is that, historically, the Results object (and then self.log on top of it) has been trying to do too many things in one API:

  • It accepts what to log, and we have different conditions based on whether we're logging numbers/single-item tensors vs. Metric objects
  • It accepts where to send results (progress bar and/or logger) => this points to another spot where loggers and callbacks seem to have some overlap @awaelchli, related to [RFC] Let Logger base class inherit from Callback #6606
  • It accepts a temporal dimension (on step vs. on epoch), which adds more branching and complexity to reduce across time
  • It accepts a spatial aspect (synchronizing state across ranks per step or per epoch), which again behaves differently depending on what's logged
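
For instance, a single call inside a LightningModule hook bundles all four concerns:

# One self.log call mixes what / where / when / how-to-sync in a single API:
self.log(
    "val_acc",
    value,               # what: a number, a single-item tensor, or a torchmetrics.Metric
    prog_bar=True,       # where: progress bar
    logger=True,         # where: logger
    on_step=False,       # temporal: step-level vs. epoch-level reduction
    on_epoch=True,
    sync_dist=True,      # spatial: sync across ranks (ignored when a Metric object is logged)
)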

If we don't address how to split these responsibilities up inside the Results object/LightningModule, I think the simplification on the trainer side will be limited. Here's a writeup that @maximsch2, @SkafteNicki and @Borda have:

https://docs.google.com/document/d/16HwB8QGg3khnJWmpt4UOZlTi1kG8X9EmxS5aHyC8sYo/edit

@ananthsub
Contributor

ananthsub commented May 21, 2021

@carmocca @tchaton @awaelchli what do you think of adding properties to the lightning module for train/val/test/predict(?) metrics?

I think elevating metrics to a top-level API inside the lightning module would bring a lot of benefit. Some of the pros:

  • The trainer then has direct visibility & access to these objects, which means users wouldn't need to go through self.log + the Result objects + aggregation inside of the Trainer for this. The Trainer could log these directly based on the metric state.
  • Clarifies what arguments should actually go through self.log - many of the arguments, like sync_dist, are ignored when a Metric instance is logged directly.
  • This removes a dependency on self.trainer inside the LightningModule.
  • Over time, we could make calling self.log("val", <Metric>) an error.
  • Easier to reset the metrics at the start of each epoch because we have direct access
  • Provides a strong foundation via the API for users to create separate metric module objects per running stage, so that users don't mix up metric states across train/val/test
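
A rough sketch of what that could look like (the attribute names and metrics below are illustrative, not a proposed API):

import torchmetrics
from pytorch_lightning import LightningModule


class MyModel(LightningModule):
    def __init__(self):
        super().__init__()
        # One metric collection per running stage, exposed as a top-level attribute
        # the Trainer could discover, log, and reset directly (names illustrative).
        self.train_metrics = torchmetrics.MetricCollection({"acc": torchmetrics.Accuracy()})
        self.val_metrics = torchmetrics.MetricCollection({"acc": torchmetrics.Accuracy()})

    def training_step(self, batch, batch_idx):
        x, target = batch
        preds = self(x)
        self.train_metrics.update(preds, target)
        ...

    def validation_step(self, batch, batch_idx):
        x, target = batch
        preds = self(x)
        self.val_metrics.update(preds, target)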

@carmocca
Contributor

Answering a few of the points...

Would revamping self.log also play a role? Do we want to rethink how callbacks also log here?

Plan is for it to stay the same.

what do you think of adding properties to the lightning module for train/val/test/predict(?) metrics?

I don't like giving metrics a different treatment to tensors/numbers

Provides a strong foundation via the API for users to create separate metric module objects per running stage, so that users don't mix up metric states across train/val/test

We should flush any existing logging on running stage change.

@ananthsub
Contributor

I don't like giving metrics a different treatment to tensors/numbers

self.log() already treats these differently because of sync_dist: when logging a Metric object, we ignore this field, don't we?

We should flush any existing logging on running stage change.

I meant the Metric state contained inside the torchmetrics.Metric module inside the LightningModule. This is an example: #7520

If the metric is logged with self.log("name", metric.compute()), we won't do the flush for them.
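
In other words (self.val_acc here is just an example attribute):

# Logging the Metric object itself: the Trainer can see and reset the underlying state.
self.log("val_acc", self.val_acc)

# Logging only the computed tensor: the Trainer never sees the Metric's state,
# so it cannot flush/reset it on running stage change.
self.log("val_acc", self.val_acc.compute())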
