Hu et al., 2021
paper: https://arxiv.org/pdf/2106.09685.pdf
code: https://github.com/microsoft/LoRA
unstructured notes
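- LoRA (Low-Rank Adaptation) is a method for fine-tuning LLMs without updating the pre-trained weights directly: the pre-trained weights are frozen and only small low-rank matrices injected alongside them are trained
- fine-tuning is used to get a pre-trained base model to perform well at a more specific task/domain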
- LoRA is parameter-efficient and modular
- the same frozen pre-trained model can be shared by many small LoRA modules (pairs of low-rank matrices), each of which adapts it to a different task
- LoRA looks for a small set of parameters $\Theta$ (with $|\Theta| \ll |\Phi_0|$) that encodes the update $\Delta\Phi(\Theta)$ applied to the initial params of the model ($\Phi_0$), so as to achieve maximum performance on the downstream task
- a downstream task is the specialization that the model is being tuned to perform
- existing parameter-efficient adaptation methods (before LoRA) involve trade-offs: adapter layers introduce latency at inference time, while prefix/prompt-tuning approaches are difficult to optimize and reserve part of the input sequence, which limits the effectiveness of the tuning
- past research has shown that the number of params needed to adapt a pre-trained LM to a new task is much lower than its total number of params (pre-trained LMs have a low intrinsic dimension)
- the authors hypothesize that the update to the weights during adaptation also has a low rank (low intrinsic rank); this lets them train and store the two low-rank factors $B \in \mathbb{R}^{d \times r}$ and $A \in \mathbb{R}^{r \times k}$, with $r \ll \min(d, k)$, instead of the full update matrix $\Delta W = BA \in \mathbb{R}^{d \times k}$
- when LoRA is applied to all weight matrices of the pre-trained LM (and all biases are trained) with the rank set to the rank of the pre-trained weight matrices, training LoRA roughly recovers the expressiveness of full fine-tuning
- deploying a LoRA adaptation is as simple as computing and storing $W = W_0 + BA$, where $W_0$ are the params of the pre-trained model
- switching to another adaptation is as simple as subtracting one LoRA module and adding another: $W' = W - BA + B'A' = W_0 + B'A'$
- because the update is merged into $W$, no additional inference latency is added compared to a fully fine-tuned model
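- in the paper's transformer experiments, LoRA is only applied to the attention weight matrices of the self-attention (MHA) module, namely $W_q, W_k, W_v, W_o$; the MLP modules are kept frozen (a choice made for simplicity and parameter efficiency)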
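- during training, LoRA can reduce VRAM usage by up to $\frac{2}{3}$ (when $r \ll d_{model}$), since optimizer states do not need to be stored for the frozen params; for GPT-3 175B, training VRAM went from 1.2TB to 350GB
- LoRA also drastically reduces the size of the checkpoints saved while tuning the model, sometimes by up to $10000\times$ (GPT-3 175B with $r=4$, adapting only $W_q$ and $W_v$: roughly 350GB down to 35MB)
- LoRA training can be about $25\%$ faster than full fine-tuning (measured on GPT-3 175B), since gradients are not computed for the vast majority of the params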
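- when batching inputs that need different LoRA modules, it is non-trivial to compute the results in a single forward pass if each $A$ and $B$ has been merged into $W$ to avoid extra latency; leaving the weights unmerged and choosing the module per sample is possible when latency is not critical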
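The bullet about finding $\Theta$ can be written out as the training objective given in the paper: conditional language modelling over a training set $\mathcal{Z}$ of (context, target) pairs $(x, y)$, where only $\Theta$ is optimized and $\Phi_0$ stays fixed:

$$
\max_{\Theta} \sum_{(x,y)\in\mathcal{Z}} \sum_{t=1}^{|y|} \log\left(p_{\Phi_0+\Delta\Phi(\Theta)}\left(y_t \mid x, y_{<t}\right)\right), \qquad |\Theta| \ll |\Phi_0|
$$

Full fine-tuning solves the same problem over all of $\Phi$, i.e. it learns a $\Delta\Phi$ with $|\Delta\Phi| = |\Phi_0|$.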
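The bullets on the low-rank update $\Delta W = BA$ can be illustrated with a toy pure-Python layer. It is a minimal sketch, not the implementation in the linked repo: it follows the paper's forward pass $h = W_0x + BAx$, with $A$ Gaussian-initialised, $B$ zero-initialised (so the layer starts out identical to $W_0$) and the update scaled by $\alpha/r$. The class name `LoRALinear` and the list-of-rows matrices are illustrative assumptions.

```python
import random


def matvec(M, x):
    # Multiply a matrix stored as a list of rows by a vector.
    return [sum(m * v for m, v in zip(row, x)) for row in M]


class LoRALinear:
    """Toy LoRA layer (hypothetical helper): h = W0 x + (alpha / r) * B A x."""

    def __init__(self, W0, r, alpha, seed=0):
        d, k = len(W0), len(W0[0])  # W0 is d x k and is never updated
        rng = random.Random(seed)
        self.W0 = W0
        # A (r x k) gets a random Gaussian init, B (d x r) starts at zero,
        # so Delta W = B A is zero and the layer initially matches W0 exactly.
        self.A = [[rng.gauss(0.0, 1.0) for _ in range(k)] for _ in range(r)]
        self.B = [[0.0] * r for _ in range(d)]
        self.scale = alpha / r

    def forward(self, x):
        base = matvec(self.W0, x)                  # frozen path: W0 x
        delta = matvec(self.B, matvec(self.A, x))  # low-rank path: B (A x)
        return [h + self.scale * u for h, u in zip(base, delta)]

    def num_trainable(self):
        # Only A and B are trained: r*k + d*r params instead of d*k.
        return sum(len(row) for row in self.A) + sum(len(row) for row in self.B)


W0 = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]  # d = 2, k = 3
layer = LoRALinear(W0, r=1, alpha=1.0)
print(layer.forward([1.0, 0.0, -1.0]))  # [-2.0, -2.0]: same as W0 x at init
print(layer.num_trainable())            # 5 (= 1*3 + 2*1)
```

The toy shapes hide the saving; for a hypothetical $4096 \times 4096$ matrix with $r = 8$, $A$ and $B$ hold $8 \cdot (4096 + 4096) = 65{,}536$ params versus $16{,}777{,}216$ for a full $\Delta W$.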
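The deployment and switching bullets ($W = W_0 + BA$, then $W' = W - BA + B'A'$) can be checked on tiny matrices. This is a sketch with hypothetical helper names (`merge`, `switch`); it uses integer entries so the equalities are exact, whereas with floating-point weights subtracting one module and adding another can leave small rounding differences relative to $W_0 + B'A'$.

```python
def matmul(X, Y):
    # Matrix product of two list-of-rows matrices.
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*Y)] for row in X]


def matvec(M, x):
    # Matrix-vector product.
    return [sum(a * b for a, b in zip(row, x)) for row in M]


def add(X, Y, sign=1):
    # Element-wise X + sign * Y.
    return [[a + sign * b for a, b in zip(rx, ry)] for rx, ry in zip(X, Y)]


def merge(W0, B, A):
    # Deploy an adapter by folding it into the weights: W = W0 + B A.
    return add(W0, matmul(B, A))


def switch(W, B, A, B_new, A_new):
    # Swap adapters on merged weights: W' = W - B A + B' A'.
    return add(add(W, matmul(B, A), sign=-1), matmul(B_new, A_new))


W0 = [[1, 2], [3, 4]]          # frozen pre-trained weights (d = k = 2)
B, A = [[1], [2]], [[3, 4]]    # adapter for task 1 (r = 1)
B2, A2 = [[0], [1]], [[1, 1]]  # adapter for task 2 (r = 1)

W = merge(W0, B, A)
x = [1, -1]
unmerged = [a + b for a, b in zip(matvec(W0, x), matvec(B, matvec(A, x)))]
print(matvec(W, x), unmerged)                        # [-2, -3] [-2, -3]
print(switch(W, B, A, B2, A2) == merge(W0, B2, A2))  # True
```

The first print shows why merging adds no inference latency: a single product with the merged $W$ gives the same output as the two-path computation $W_0x + B(Ax)$.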
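The checkpoint-size bullet can be sanity-checked with back-of-the-envelope arithmetic. Each adapted $d \times k$ matrix adds $r(d + k)$ params. The figures below are assumptions for illustration: GPT-3 175B shapes (96 layers, $d_{model} = 12288$, with $W_q$ and $W_v$ each $d_{model} \times d_{model}$), $r = 4$, and 2 bytes per param (fp16). The result lands in the same ballpark as the reported ~$10000\times$ reduction.

```python
def lora_param_count(n_matrices, d, k, r):
    # Each adapted d x k matrix gets B (d x r) and A (r x k): r * (d + k) params.
    return n_matrices * r * (d + k)


# Assumed GPT-3 175B shapes: 96 layers, d_model = 12288.
N_LAYERS = 96
D_MODEL = 12288
FULL_PARAMS = 175e9
BYTES_PER_PARAM = 2  # assumes fp16 storage

# Adapt only W_q and W_v in every layer, with rank r = 4.
lora_params = lora_param_count(n_matrices=2 * N_LAYERS, d=D_MODEL, k=D_MODEL, r=4)
full_bytes = FULL_PARAMS * BYTES_PER_PARAM
lora_bytes = lora_params * BYTES_PER_PARAM

print(f"LoRA params: {lora_params:,}")                 # 18,874,368
print(f"full checkpoint: {full_bytes / 1e9:.0f} GB")   # 350 GB
print(f"LoRA checkpoint: {lora_bytes / 1e6:.1f} MB")   # 37.7 MB
print(f"reduction: ~{full_bytes / lora_bytes:,.0f}x")  # ~9,272x
```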
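The batching caveat in the last bullet can be sketched as an unmerged forward pass: the frozen $W_0$ is shared by every sample, and each sample is routed through its own $(B, A)$ pair. This is a toy sketch, not the paper's or any library's implementation; `mixed_batch_forward` and the adapter ids are hypothetical names. It loops per sample for clarity, whereas a real implementation would run the shared $W_0$ path as one batched matmul. The per-sample low-rank path is the extra work that merging would otherwise remove.

```python
def matvec(M, x):
    # Multiply a matrix stored as a list of rows by a vector.
    return [sum(a * b for a, b in zip(row, x)) for row in M]


def mixed_batch_forward(W0, adapters, xs, adapter_ids):
    """Hypothetical unmerged forward pass for a batch that mixes tasks.

    W0 is the shared frozen weight; adapters maps an adapter id to its (B, A)
    pair; sample i is routed through adapters[adapter_ids[i]].
    """
    outputs = []
    for x, task in zip(xs, adapter_ids):
        B, A = adapters[task]
        base = matvec(W0, x)             # shared path (one batched matmul in practice)
        delta = matvec(B, matvec(A, x))  # per-sample low-rank path: the extra cost
        outputs.append([h + u for h, u in zip(base, delta)])
    return outputs


W0 = [[1, 0], [0, 1]]
adapters = {
    "task_a": ([[1], [0]], [[1, 1]]),  # (B, A) with r = 1
    "task_b": ([[0], [1]], [[2, 0]]),
}
print(mixed_batch_forward(W0, adapters, [[1, 2], [1, 2]], ["task_a", "task_b"]))
# [[4, 2], [1, 4]]
```

The same input produces different outputs depending on which module its sample is routed to, without building a separate merged $W$ per task.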