Add LoRA Training Example to JAX Examples #186


Open
nikolasavic3 opened this issue Mar 21, 2025 · 0 comments · May be fixed by #187

nikolasavic3 commented Mar 21, 2025

I would like to contribute a new example that demonstrates how to implement Low-Rank Adaptation (LoRA) for fine-tuning language models using JAX and Flax.

Why this might be useful

LoRA is one of the most widely used parameter-efficient fine-tuning techniques, and efficiency is a major reason people reach for JAX in the first place. I've noticed that implementing LoRA in JAX isn't straightforward: someone new to JAX would have to dig through documentation and GitHub issues to piece it together.

Implementation

The example will build on the JAX for LLM pretraining tutorial and will compare a model trained using that approach against one trained using LoRA.

I have a draft implementation ready to submit.
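For context, the core idea can be sketched in plain JAX. This is an illustrative sketch, not code from the draft PR: the function names, the rank `r`, and the scaling factor `alpha` are my own illustrative choices. The frozen base weight `W` is augmented with a trainable low-rank update `A @ B`, and only the LoRA parameters receive gradients:

```python
import jax
import jax.numpy as jnp

def init_lora(key, d_in, d_out, r):
    """Initialize LoRA factors: A small random, B zero, so the
    adapted layer starts out identical to the frozen base layer."""
    A = jax.random.normal(key, (d_in, r)) * 0.01
    B = jnp.zeros((r, d_out))
    return {"A": A, "B": B}

def lora_apply(W, lora, x, alpha=16.0, r=8):
    # Frozen path (x @ W) plus scaled trainable low-rank path.
    return x @ W + (alpha / r) * (x @ lora["A"] @ lora["B"])

def loss_fn(lora, W, x, y):
    pred = lora_apply(W, lora, x)
    return jnp.mean((pred - y) ** 2)

key = jax.random.PRNGKey(0)
W = jax.random.normal(key, (32, 16))   # stands in for a pretrained weight
lora = init_lora(key, 32, 16, r=8)
x = jax.random.normal(key, (4, 32))
y = jax.random.normal(key, (4, 16))

# Differentiate w.r.t. the LoRA parameters only; W is passed as a
# non-differentiated argument, so the base model stays frozen.
grads = jax.grad(loss_fn)(lora, W, x, y)
```

Because `B` starts at zero, the adapted layer initially reproduces the base layer exactly; the tutorial example would wrap this pattern around the attention and MLP projections of the pretrained model.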

@nikolasavic3 nikolasavic3 linked a pull request Mar 22, 2025 that will close this issue