
[Transform] Factory classes with shared memory and offloading #316


Open · kylesayrs wants to merge 11 commits into kylesayrs/transform_utils from kylesayrs/transform_factory

Conversation

@kylesayrs (Contributor) commented May 15, 2025

Purpose

  • Support applying transforms to models while maximizing the memory shared across transform applications

Prerequisites

Transform Factory and Submodules

  • Transform weights are cached using the ParameterizedDefaultDict, typically keyed by matrix size, dtype, and device
  • Transforms are defined as submodules which apply their weights to a module
    • Transforms are applied to modules either through hooks (activations) or by updating the module’s weight
  • Transforms are added as submodules to the modules they are applied to (see the sketch after this list)
    • This enables easy serialization with a LoRA-style checkpoint structure
    • This makes it easy to use PyTorch’s parametrization feature to learn transforms during training
    • If the parent module is offloaded, then the transform submodule is offloaded as well
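
A minimal sketch of the attachment mechanics above. The `Transform` class, the `attach_transform` helper, and the `"input_transform"` submodule name are hypothetical, not the PR's actual API; the point is a weight-holding submodule registered on its parent and applied to activations via a pre-forward hook:

```python
import torch


class Transform(torch.nn.Module):
    """Hypothetical transform submodule holding a (possibly shared) weight."""

    def __init__(self, weight: torch.nn.Parameter):
        super().__init__()
        self.weight = weight  # the same Parameter may be shared across instances

    def forward(self, value: torch.Tensor) -> torch.Tensor:
        return value @ self.weight


def attach_transform(module: torch.nn.Module, transform: Transform) -> None:
    # Registering as a submodule yields the LoRA-style checkpoint structure,
    # and if `module` is offloaded, the transform is offloaded with it
    module.register_module("input_transform", transform)
    # Apply to activations through a pre-forward hook; fusing into
    # module.weight is the other application path described above
    module.register_forward_pre_hook(lambda mod, args: (transform(args[0]),))
```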

Transform Types

TransformFactory

  • Base transform factory class
  • Each transform factory corresponds to one transform scheme in the transform config
  • Subclasses must implement create_transform and are responsible for caching/sharing memory
  • Implements apply_to_model, which applies transforms created via create_transform to target modules in the model (sketched below)
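
A rough sketch of the factory contract under the assumptions above; `_is_target`, the submodule name, and the exact ParameterizedDefaultDict behavior are placeholders rather than the PR's implementation:

```python
from abc import ABC, abstractmethod

import torch


class ParameterizedDefaultDict(dict):
    """Sketch: a defaultdict whose factory receives the missing key."""

    def __init__(self, factory):
        super().__init__()
        self._factory = factory

    def __missing__(self, key):
        value = self._factory(*key)
        self[key] = value
        return value


class TransformFactory(ABC):
    """One factory per transform scheme in the transform config."""

    @abstractmethod
    def create_transform(self, module: torch.nn.Module) -> torch.nn.Module:
        """Create (or fetch from cache) a transform for this module."""

    def apply_to_model(self, model: torch.nn.Module) -> None:
        for name, module in model.named_modules():
            if self._is_target(name, module):  # placeholder matching logic
                module.register_module("transform", self.create_transform(module))

    def _is_target(self, name: str, module: torch.nn.Module) -> bool:
        raise NotImplementedError  # scheme-specific targeting
```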

HadamardFactory

  • Creates Hadamard matrix transforms
  • Weights are cached according to matrix size, dtype, and device

RandomMatrixFactory

  • Creates random matrix transforms
  • Weights are cached according to matrix size, dtype, and device
  • To save runtime, inverses are also cached (see the sketch below)
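
A sketch of the caching described above; the key layout and method names are assumptions, not the PR's exact code:

```python
import torch


class RandomMatrixFactory:
    """Sketch: random matrix transforms with cached weights and inverses."""

    def __init__(self, seed: int = 0):
        self._generator = torch.Generator().manual_seed(seed)
        self._weights = {}   # keyed by (size, dtype, device)
        self._inverses = {}  # keyed by the forward weight, to save runtime

    def _get_weight(self, size: int, dtype: torch.dtype, device: torch.device):
        key = (size, dtype, device)
        if key not in self._weights:
            # a random square matrix is invertible with probability 1
            data = torch.rand((size, size), generator=self._generator)
            self._weights[key] = data.to(dtype=dtype, device=device)
        return self._weights[key]

    def _get_inverse(self, weight: torch.Tensor) -> torch.Tensor:
        key = id(weight)
        if key not in self._inverses:
            # invert in high precision once; every module sharing this
            # weight reuses the cached inverse
            inverse = torch.linalg.inv(weight.to(torch.float64))
            self._inverses[key] = inverse.to(weight.dtype)
        return self._inverses[key]
```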

RandomHadamardFactory

  • Subclasses HadamardFactory, but uses random_hadamard_matrix to create matrix weights (sketched below)
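
Put together, a sketch of the pair: HadamardFactory owns the caching, and the random subclass only swaps out matrix construction. `create_weight` and `_make_matrix` are hypothetical hook names, and `random_hadamard_matrix` is stubbed since it comes from the prerequisite transform-utils branch:

```python
import torch


def random_hadamard_matrix(size: int) -> torch.Tensor:
    """Stub: provided by the prerequisite transform-utils branch."""
    raise NotImplementedError


class HadamardFactory:
    """Sketch: caches Hadamard weights by (size, dtype, device)."""

    def __init__(self):
        self._weights = {}

    def create_weight(self, size, dtype, device) -> torch.nn.Parameter:
        key = (size, dtype, device)
        if key not in self._weights:
            data = self._make_matrix(size).to(dtype=dtype, device=device)
            self._weights[key] = torch.nn.Parameter(data, requires_grad=False)
        return self._weights[key]

    def _make_matrix(self, size: int) -> torch.Tensor:
        raise NotImplementedError  # deterministic construction elided here


class RandomHadamardFactory(HadamardFactory):
    """Caching is inherited; only the matrix construction differs."""

    def _make_matrix(self, size: int) -> torch.Tensor:
        return random_hadamard_matrix(size)
```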

Testing

  • Add tests for correctness when applying transforms to models
  • Add tests for transform weight memory sharing, both with and without offloading (a sketch of the check follows)
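
A sketch of the sharing check; the `transform` attribute name is hypothetical. Tensor equality alone would also pass for independent copies, so the test should compare identity or storage:

```python
import torch


def assert_shared(a: torch.Tensor, b: torch.Tensor) -> None:
    # Identity or a shared data pointer proves memory sharing;
    # torch.equal would also pass for two independent copies
    assert a is b or a.data_ptr() == b.data_ptr()


def check_weight_sharing(model: torch.nn.Module) -> None:
    # Assumes each targeted module received a submodule named "transform"
    weights = [
        module.transform.weight
        for module in model.modules()
        if hasattr(module, "transform")
    ]
    for other in weights[1:]:
        assert_shared(weights[0], other)
```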

@kylesayrs kylesayrs changed the base branch from main to kylesayrs/transform-config May 21, 2025 14:15
Base automatically changed from kylesayrs/transform-config to main May 28, 2025 19:42
@kylesayrs kylesayrs changed the base branch from main to kylesayrs/transform-accelerate-utilities May 30, 2025 19:23
@kylesayrs kylesayrs changed the base branch from kylesayrs/transform-accelerate-utilities to kylesayrs/transform_utils May 30, 2025 19:56
@kylesayrs kylesayrs changed the title from [WIP] Transform Factory to [Transform] Factory classes with offloading support May 30, 2025
@kylesayrs kylesayrs force-pushed the kylesayrs/transform_factory branch from 7b36d2e to 809e367 May 30, 2025 20:06
@kylesayrs kylesayrs changed the title from [Transform] Factory classes with offloading support to [Transform] Factory classes with shared memory and offloading May 30, 2025
@kylesayrs kylesayrs marked this pull request as ready for review May 30, 2025 20:19