I would like to contribute a new example that demonstrates how to implement Low-Rank Adaptation (LoRA) for fine-tuning language models using JAX and Flax.
Why this might be useful
LoRA is one of the most popular parameter-efficient fine-tuning techniques, and efficiency is a big part of why people reach for JAX in the first place. I've noticed that implementing LoRA in JAX isn't entirely straightforward: someone new to JAX would have to dig through documentation and GitHub issues to figure it out.
Implementation
The example will build on the JAX for LLM pretraining tutorial and will compare a model trained using that approach against one trained using LoRA.
I have a draft implementation ready to submit.
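To give a sense of the approach, here is a minimal sketch of the kind of LoRA layer the example would implement, written in Flax Linen. The names (`LoRADense`, `rank`, `alpha`) and the initialization details are illustrative assumptions on my part, not final code from the draft:

```python
import jax.numpy as jnp
import flax.linen as nn

class LoRADense(nn.Module):
    """Dense layer with a trainable low-rank LoRA update on top of a frozen base."""
    features: int
    rank: int = 8          # low-rank bottleneck dimension r
    alpha: float = 16.0    # LoRA scaling factor

    @nn.compact
    def __call__(self, x):
        # Frozen base projection W. In a real fine-tuning setup this would be
        # loaded from pretrained weights and excluded from the optimizer.
        base = nn.Dense(self.features, use_bias=False, name="base")(x)

        # Trainable low-rank update: delta_W = A @ B, scaled by alpha / r.
        # A starts with small random values and B with zeros, so the layer
        # is initially identical to the frozen base layer.
        a = self.param("lora_A",
                       nn.initializers.normal(stddev=0.01),
                       (x.shape[-1], self.rank))
        b = self.param("lora_B",
                       nn.initializers.zeros,
                       (self.rank, self.features))
        return base + (x @ a @ b) * (self.alpha / self.rank)
```

During fine-tuning, only `lora_A` and `lora_B` would receive gradients; one way to do that is to mask the base parameters out of the optimizer with `optax.masked`. The full example would cover that wiring as well.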