In a nutshell
- Alignment requires making the preferred answer more likely and making all unwanted answers less likely. Finetuning only boosts the likelihood of the preferred response.
- 3 Step process for RLHF
- Instruction Tuning
- Start with Base Pretrained LM
- Create (instruction, input, output) dataset
- Use instruction tuning to finetune the base LM to get a SFT LM
- Reward Modeling
- Start with SFT LM
- The model generates an (x, y_1, y_2) dataset and experts label y_1, y_2 as “better” or “worse”
- Use the Bradley-Terry loss function to finetune a copy of SFT LM to get a Reward Model
- RL w/ Policy Gradients
- Start with SFT LM and the Reward Model
- Use PPO to finetune a copy of the SFT LM by maximising expected rewards (rewards determined by the Reward Model), which requires rollouts from the policy, to get an RLHF-aligned model
- Caveat: We don’t want the aligned model to deviate too much from the base model (and forget useful things like instruction following), so we use a KL divergence penalty to keep overall changes in check
- 2 Step process for DPO
- Instruction Tuning
- Start with Base Pretrained LM
- Create (instruction, input, output) dataset
- Use instruction tuning to finetune the base LM to get a SFT LM
- Direct Preference Optimization
- Start with SFT LM
- The model generates an (x, y_1, y_2) dataset and experts label y_1, y_2 as “better” or “worse”
- Use the DPO loss function to finetune a copy of the SFT LM to get the DPO LM
Key Ideas
Alignment
- Alignment isn’t just limited to morality; it also covers the utility and format of generated text
Reward Modeling
- No oracle exists to determine what kind of output is more favoured, so a model needs to be used to generalize specific human feedback
Policy Gradients
- RL can be used to nudge a model to always generate outputs which are highly rated (highly rewarded).
Loss function selection in Fine Tuning
- The loss function can be (maybe even should be) different between pretraining, fine tuning, and alignment, to allow specific knowledge or behaviours to be learnt
- The factors that drive value (or that measure deviation from ideal behaviour) are key considerations when selecting the loss function
Notes
LLM Alignment
- Finetuning: one approach is to create a large dataset of disallowed (“illegal”) instructions + valid instructions, with the desired outputs
- Not the most effective approach; it is not comprehensive enough to prevent ‘jailbreak attacks’
RLHF
- Why RLHF / Limitations of SFT
- The trainer always gives positive feedback, never negative feedback for mistakes
- Multiple correct outputs (esp. creative tasks), but dataset only has one provided response
- The model does not have enough data to know what it doesn’t know / it is hard to encourage abstaining when the model does not know something
- SFT does not directly incorporate human preferences, so it may not impart preference information to the LLM
- With RLHF we can get a flywheel of human queries and human preferences for future fine-tuning
- Why not RLHF
- Expensive to collect human annotations
- Model loses out on creativity
- Caveats
- Reward model needs to be very carefully trained, to prevent it from learning shortcuts
- e.g. longer answers are better
- Detailed Steps
- Instruction Tuning (SFT): Start with a large pretrained, instruction-tuned LLM
- Have humans rank a set of sampled results for a query from the LLM
  i. Needs knowledgeable rankers and domain experts
- Develop a reward model (Bradley-Terry pairwise preference model) with input `{prompt x, output y_i}` and output `{scalar score}`
  i. Learn `r(x, y)`: the reward for output `y` given `x`
  ii. Compare pairs of outputs and calculate `P(y_W > y_L | x)`, where `y_W` is the output preferred by humans (the winner) and `y_L` is the one not preferred (the loser):
  `P(y_W > y_L | x) = exp(r(x, y_W)) / (exp(r(x, y_W)) + exp(r(x, y_L)))`
  (Looks like softmax!)
- Calculate a loss function for this reward model! This nudges the model to prefer `y_W` over `y_L`:
  `Loss = -log(P(y_W > y_L | x))`
  which simplifies to
  `Loss = -log sigmoid(r(x, y_W) - r(x, y_L))`, where `sigmoid(x) = 1/(1 + e^{-x})`
  - i.e. the reward `r` of `y_W` needs to be higher than that of `y_L`
- Take the copy of the fine-tuned LLM, remove the last linear layer (i.e. we’re interested in the token representations from the final layer), then apply a new linear layer of `(num_samples, hidden_dim) => (num_samples, 1)` on some/all of these encodings to get a score for each `(x, y)` pair. This becomes `R_pred`, which is then trained to match the human preference annotations via the loss function defined above. This will tweak the model’s weights. (A code sketch follows this list.)
- Create a pipeline where `x -> LLM(x) -> [y_1, y_2, y_3, ...] -> Reward Model -> [r_1, r_2, r_3, ...]`
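A minimal PyTorch sketch of the reward head and pairwise loss described above. The `backbone` module, `hidden_dim`, and the choice of scoring from the final token’s representation are illustrative assumptions, not details from the lecture.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RewardModel(nn.Module):
    """Copy of the SFT LLM with the LM head swapped for a scalar score head."""

    def __init__(self, backbone: nn.Module, hidden_dim: int):
        super().__init__()
        # Assumed: `backbone(input_ids)` returns token representations of
        # shape (batch, seq_len, hidden_dim) from the final transformer layer.
        self.backbone = backbone
        self.score_head = nn.Linear(hidden_dim, 1)  # hidden_dim -> scalar reward

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        hidden = self.backbone(input_ids)               # (batch, seq_len, hidden_dim)
        last_token = hidden[:, -1, :]                   # score from the last token's encoding
        return self.score_head(last_token).squeeze(-1)  # (batch,) = R_pred

def bradley_terry_loss(r_w: torch.Tensor, r_l: torch.Tensor) -> torch.Tensor:
    """Loss = -log sigmoid(r(x, y_W) - r(x, y_L)); pushes the winner's reward above the loser's."""
    return -F.logsigmoid(r_w - r_l).mean()
```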
Strategies
- “best-of-n” sampling (rejection sampling)
  - a pipeline where `x -> LLM(x) -> [y_1, y_2, y_3, ...] -> Reward Model -> [r_1, r_2, r_3, ...] -> argmax_y([r_1, r_2, r_3, ...]) -> y_{max_r}` (sketched in code after this list)
  - pros: Simple!
  - cons: Slow/Expensive :(
- Finetune the original LM on `y_{max_r}`, i.e. maximise `p(y_W | x)`
  - pros: Simple!
  - cons: Not learning from the negative/lower-reward options
- Use RL to increase `p(y_W | x)` and decrease `p(y_L | x)` by small amounts
  - Needed when there are non-differentiable aspects of the loss function; in this case, the fact that there are multiple `y`s for a single `x` and sampling is non-differentiable
  - pros: Learning from all types of responses
  - cons: Difficult for long outputs, since we cannot determine which part is incorrect + the reward model can learn incorrect behaviours
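A small sketch of the “best-of-n” pipeline above in plain Python; `generate` and `reward_fn` are hypothetical stand-ins for the SFT model and the trained reward model.

```python
from typing import Callable

def best_of_n(prompt: str,
              generate: Callable[[str], str],
              reward_fn: Callable[[str, str], float],
              n: int = 8) -> str:
    """Sample n candidate responses and return the one the reward model scores highest."""
    candidates = [generate(prompt) for _ in range(n)]           # y_1 ... y_n ~ LLM(x)
    return max(candidates, key=lambda y: reward_fn(prompt, y))  # argmax over r_i
```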
- Steps (contd.)
- Create a copy of the SFT model to be the policy model `pi`. The base SFT model is referred to as `pi_ref` and serves as a frozen reference model.
- Create an objective which aims to find the policy that, for any given `x`, gives the `y` which maximizes expected reward, without deviating too much from the reference model (a code sketch follows this list):
  `objective = max_pi E_{x, y}[r(x, y) - beta * D_KL(pi(y|x) || pi_ref(y|x))]`
  - `D_KL(pi(y|x) || pi_ref(y|x))` (KL divergence) `= E_{y ~ pi}[sum_i log(pi(w_i | w_{1..i-1}, x) / pi_ref(w_i | w_{1..i-1}, x))]`. In practice, only a subsequence is taken (of length 0, i.e. `p(w_i | x)`)
- Optimize with PPO/REINFORCE/any policy gradient algorithm
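A hedged sketch of the KL-penalized reward used in the PPO step; it assumes we already have per-token log-probabilities of the sampled response under the policy and the reference model (the names and shapes are my own, not from the lecture).

```python
import torch

def kl_penalized_reward(reward: torch.Tensor,       # (batch,) scores from the reward model
                        logp_policy: torch.Tensor,  # (batch, seq_len) log pi(w_i | w_<i, x)
                        logp_ref: torch.Tensor,     # (batch, seq_len) log pi_ref(w_i | w_<i, x)
                        beta: float = 0.1) -> torch.Tensor:
    """Return r(x, y) - beta * KL(pi(y|x) || pi_ref(y|x)), estimated from the sampled tokens."""
    kl_per_token = logp_policy - logp_ref   # per-token log-ratio
    kl = kl_per_token.sum(dim=-1)           # sequence-level KL estimate
    return reward - beta * kl               # (batch,) penalized rewards to maximize with PPO
```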
Direct Preference Optimization (DPO)
- Removes the need for an explicit reward model
- Does not sample outputs `y|x` from the SFT model for the RL step
- `loss_DPO(pi_theta; pi_ref) = -E_{x, y_W, y_L}[log sigmoid(beta*(log(pi_theta(y_W|x)/pi_ref(y_W|x)) - log(pi_theta(y_L|x)/pi_ref(y_L|x))))]`
Steps
- Introduce a policy `pi*` which merges `pi_ref` with rewards:
  `pi*(y|x) = pi_ref(y|x) exp(r(x, y)/beta) / sum_y(pi_ref(y|x) exp(r(x, y)/beta))`
  - `sum_y(pi_ref(y|x) exp(r(x, y)/beta))` becomes `Z(x)` and is termed the normalizer/partition function; it runs over all `y` for a given `x`
- Rewrite the RLHF objective (maximize reward with a KL penalty) as a minimization and substitute `Z(x)` into it:
  `loss = min_pi E_{x, y}[ log( pi(y|x) / (pi_ref(y|x) exp(r(x,y)/beta) / Z(x)) ) - log Z(x) ]`
- Which simplifies (with `pi*`) to become
  `loss = min_pi E_{x, y}[ log( pi(y|x) / pi*(y|x) ) ] - log Z(x)`
- The first term is essentially a KL divergence, and it is minimized (i.e. `loss = -log Z(x)`) when `pi(y|x) = pi*(y|x) = pi_ref(y|x) exp(r(x, y)/beta) / Z(x)`, which gives us `r(x, y) = beta * (log(pi*(y|x)/pi_ref(y|x)) + log Z(x))`
- Since `log Z(x)` is intractable, we put `r(x, y)` back into the Bradley-Terry preference model (where the `log Z(x)` terms cancel) and make a loss function out of it (a code sketch follows this list):
  `p(y_W > y_L | x) = sigmoid(beta*(log(pi*(y_W|x)/pi_ref(y_W|x)) - log(pi*(y_L|x)/pi_ref(y_L|x))))`
  `loss_DPO(pi_theta; pi_ref) = -E_{x, y_W, y_L}[log sigmoid(beta*(log(pi_theta(y_W|x)/pi_ref(y_W|x)) - log(pi_theta(y_L|x)/pi_ref(y_L|x))))]`
- It has no reward function in it!
- It has no need for sampling! Can just work off human annotations
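A minimal PyTorch sketch of the DPO loss above; each input is assumed to be the summed log-probability of the whole response under the policy or the frozen reference model.

```python
import torch
import torch.nn.functional as F

def dpo_loss(logp_policy_w: torch.Tensor,  # (batch,) log pi_theta(y_W | x)
             logp_policy_l: torch.Tensor,  # (batch,) log pi_theta(y_L | x)
             logp_ref_w: torch.Tensor,     # (batch,) log pi_ref(y_W | x)
             logp_ref_l: torch.Tensor,     # (batch,) log pi_ref(y_L | x)
             beta: float = 0.1) -> torch.Tensor:
    """-E[log sigmoid(beta * (log(pi_theta(y_W|x)/pi_ref(y_W|x)) - log(pi_theta(y_L|x)/pi_ref(y_L|x))))]"""
    log_ratio_w = logp_policy_w - logp_ref_w   # log(pi_theta(y_W|x) / pi_ref(y_W|x))
    log_ratio_l = logp_policy_l - logp_ref_l   # log(pi_theta(y_L|x) / pi_ref(y_L|x))
    return -F.logsigmoid(beta * (log_ratio_w - log_ratio_l)).mean()
```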
Misc
- Meta’s Llama-2 base (language) model is unaligned! The Chat model has guardrails in place.
- MT-Bench performs automatic evaluation of generated text for several standardized prompts and provides a score for each response, judged by a strong LLM (the LLM-as-a-judge paradigm)
Resources
- Dr. Mohit Iyyer’s UMass CS685 S24 Lecture 11
- Dr. Mohit Iyyer’s UMass CS685 S24 Lecture 12
- Direct Preference Optimization paper