Aligning TinyStories


Ok, so TinyStories is a fantastic paper that shows how a small transformer model can be trained to produce coherent stories. Their trick was to carefully curate the training data by generating it synthetically (using GPT). It worked!

This got me thinking: is data still necessary when training LLMs?

Probably not.

All we really need are gradients. Data is just an easy way to get them, but there are other ways.

Rather than training on data (labeled or unlabeled), we can train against another system that gives feedback. This could be a simple function that evaluates the model’s state or actions (perhaps using a simulator), or a deep-learning model trained specifically to give feedback. This is Reinforcement Learning (RL).

RL is making waves in AI research, especially around alignment.

OpenAI use RLHF (Reinforcement Learning from Human Feedback) to align GPT. Meta use it to fine-tune Llama 2.

But RLHF is pretty complex: we have to train a model just to train a model. Hugging Face have a nice library, trl, but it’s high level, and I want to understand the mechanics.

So, let’s do something simpler to intuitively understand how to use RL on a language model.

Let’s try to align TinyStories to only tell stories about cats.

https://github.com/pHaeusler/tinycatstories

We can always prompt

So TinyStories is an autoregressive language model (GPT-Neo). This means we can “influence” it (condition it) simply by prompting. If we start every story with “Once upon a time there was a cat”, we’ll probably get cat stories.

This isn’t enough: we want to fundamentally change the behavior of the LLM regardless of prompting. It should always talk about cats, even if it’s told not to.

Embedding loss

The default approach would be to fine-tune the model with some cat-based stories. We could even do what the tinystories authors did, and generate these stories using GPT. That’s not very interesting though. We want RL.

We want a simple way to check if the story is about a cat, and use that for the loss.

Let’s use embeddings. We can embed the word “cat” and compare it to the embedding of a generated story. Their cosine similarity is our reward, which we want to maximize.

from sentence_transformers import SentenceTransformer, util

# Small sentence-embedding model used only to score stories
embedding_model = SentenceTransformer("all-MiniLM-L6-v2").to("cuda")
reference_embedding = embedding_model.encode("cat", convert_to_tensor=True)

def compute_rewards(sequences):
    # Embed every generated story and score it by cosine similarity to "cat"
    sequence_embeddings = embedding_model.encode(sequences, convert_to_tensor=True)
    cosine_similarities = util.pytorch_cos_sim(
        reference_embedding.unsqueeze(0), sequence_embeddings
    ).squeeze()
    return cosine_similarities  # one similarity per sequence, in [-1, 1]
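To make the reward concrete, here is a minimal pure-Python sketch of cosine similarity, the quantity util.pytorch_cos_sim computes. The three-dimensional vectors are made up for illustration (real all-MiniLM-L6-v2 embeddings have 384 dimensions). A story whose embedding points in the same direction as “cat” scores close to 1.0, which is why the similarity is a reward to maximize rather than a loss.

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between vectors a and b: dot(a, b) / (|a| * |b|)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

# Hypothetical toy embeddings, not real model outputs
cat = [1.0, 0.0, 0.0]
story_about_cat = [0.9, 0.1, 0.0]
story_about_dog = [0.1, 0.9, 0.0]

rewards = [cosine_similarity(cat, s) for s in (story_about_cat, story_about_dog)]
print(rewards)
```

The cat-like story gets a much higher reward than the dog-like one; only the direction of the vectors matters, not their length.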

Conceptually, we want to roll out stories, evaluate them with our reward function (the embedding similarity), compute a loss, back-propagate, and learn!

First, we set up our model, tokenizer, and a starting prompt.

TinyStories uses the GPT-Neo tokenizer, so we load the tokenizer from EleutherAI/gpt-neo-125M.

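import torch
from torch.distributions import Categorical
from torch.optim import AdamW
from transformers import AutoModelForCausalLM, AutoTokenizer

NUM_EPOCHS = 100
BATCH_SIZE = 10
NUM_TOKENS = 250  # total sequence length, prompt included

model = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-33M").to("cuda")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M")  # TinyStories' tokenizer
optimizer = AdamW(model.parameters(), lr=1e-5)

prompt = "Once upon a time there was"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to("cuda")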

Then let’s implement the training loop.

For efficiency, we roll out stories in batches on the GPU. Generation is autoregressive: we start with the prompt, padded with end-of-sequence tokens, in output_ids. Then we repeatedly run the model and write each newly generated token into output_ids, one position at a time, until it holds NUM_TOKENS tokens (prompt included).
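The fixed-size buffer can be sketched without a GPU. This pure-Python version uses a hypothetical toy_next_token function as a stand-in for “run the model and sample”, and the token ids and EOS id are made up. It only shows the shape of the loop: copy the prompt into a row of EOS padding, then fill one position per step, each step conditioned on everything before it.

```python
EOS_ID = 0  # hypothetical end-of-sequence id
NUM_TOKENS = 8
BATCH_SIZE = 2

def toy_next_token(prefix):
    """Hypothetical stand-in for model + sampling: returns the last token + 1."""
    return prefix[-1] + 1

prompt_ids = [5, 6, 7]

# Every row starts as the prompt followed by EOS padding
output_ids = [
    prompt_ids + [EOS_ID] * (NUM_TOKENS - len(prompt_ids)) for _ in range(BATCH_SIZE)
]

for i in range(len(prompt_ids), NUM_TOKENS):
    for row in output_ids:
        row[i] = toy_next_token(row[:i])  # condition on everything generated so far

print(output_ids)
```

The loop runs NUM_TOKENS - len(prompt_ids) times, so with the real prompt only part of NUM_TOKENS is newly generated text.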

Each generation involves sampling from the output logits of the model.
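Sampling from the logits means turning them into a probability distribution with a softmax and drawing one token from it, which is what Categorical(probs).sample() does in the loop. A pure-Python sketch, with made-up logits for a four-token vocabulary:

```python
import math
import random

def softmax(logits):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sample_token(logits, rng):
    """Draw one token id with probability softmax(logits); also return its log prob."""
    probs = softmax(logits)
    token = rng.choices(range(len(probs)), weights=probs, k=1)[0]
    return token, math.log(probs[token])

rng = random.Random(0)
logits = [2.0, 1.0, 0.1, -1.0]  # hypothetical logits
token, log_prob = sample_token(logits, rng)
print(token, log_prob)
```

In PyTorch, Categorical(logits=logits) performs the same normalization internally, so the explicit softmax is optional.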

To build up the joint probability of the generated story, we accumulate (sum) the log probability of each selected token.
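Summing log probabilities is the same as multiplying probabilities, but numerically safer: the product of hundreds of numbers below 1 underflows to zero, while the sum of their logs stays representable. A quick check with made-up per-token probabilities:

```python
import math

token_probs = [0.5, 0.2, 0.9, 0.7]  # hypothetical probabilities of the sampled tokens

log_joint = sum(math.log(p) for p in token_probs)  # what the loop accumulates
joint = math.prod(token_probs)                      # probability of the whole sequence
print(log_joint, math.log(joint))

# With many tokens the direct product underflows to 0.0, but the log sum does not
long_story = [0.01] * 400
print(math.prod(long_story), sum(math.log(p) for p in long_story))
```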

Once complete, we compute the reward (the embedding similarity) for each sequence.

The final loss is the negative log probability of each story, normalized by NUM_TOKENS and multiplied by the story’s reward, averaged over the batch.

The log probabilities represent how likely a generated sequence is according to the model. The more probable a sequence, the higher its log probability. This will be a negative number. Taking the negative turns this value into a cost. We weight the cost by the reward and use back-propagation to minimize it, which raises the probability of high-reward stories the most.
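The loss line in the loop can be traced by hand. This sketch mirrors (-normalized_log_probs * rewards.unsqueeze(-1)).mean() in plain Python for a hypothetical batch of two stories; the numbers are invented, not taken from a training run.

```python
NUM_TOKENS = 10

# Hypothetical batch: accumulated log probabilities and embedding rewards
log_probs_accumulated = [-12.0, -15.0]
rewards = [0.8, 0.2]

normalized_log_probs = [lp / NUM_TOKENS for lp in log_probs_accumulated]
costs = [-nlp * r for nlp, r in zip(normalized_log_probs, rewards)]
loss = sum(costs) / len(costs)  # mean over the batch

print(costs, loss)
```

The gradient of each cost with respect to its log probability is -reward / NUM_TOKENS, so gradient descent pushes up the probability of the 0.8-reward story four times harder than the 0.2-reward one. Because both rewards are positive, both stories are reinforced to some degree; only the relative strength differs.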

model.train()
for epoch in range(NUM_EPOCHS):
    output_ids = torch.full((BATCH_SIZE, NUM_TOKENS), tokenizer.eos_token_id, device="cuda")
    output_ids[:, : input_ids.shape[1]] = input_ids

    log_probs_accumulated = torch.zeros((BATCH_SIZE, 1), device="cuda")
    for i in range(input_ids.shape[1], NUM_TOKENS):
        prompt = output_ids[:, :i].clone()
        logits = model(prompt).logits[:, -1, :]
        probs = torch.nn.functional.softmax(logits, dim=-1)
        dist = Categorical(probs)
        next_tokens = dist.sample()
        log_probs_accumulated += dist.log_prob(next_tokens).unsqueeze(-1)
        output_ids[:, i] = next_tokens

    # Compute rewards for the entire batch
    sequences = [tokenizer.decode(input_id, skip_special_tokens=True) for input_id in output_ids]
    rewards = compute_rewards(sequences)

    # Compute loss for the entire batch
    normalized_log_probs = log_probs_accumulated / NUM_TOKENS
    loss = (-normalized_log_probs * rewards.unsqueeze(-1)).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

For reference, this is a variant of the REINFORCE (policy gradient) algorithm.
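The core REINFORCE idea is easier to see with a single decision instead of a whole story. This is a minimal sketch, not the author’s code: a hypothetical three-token “vocabulary” with a softmax policy, where only token 0 earns a high reward. REINFORCE estimates the gradient of expected reward as the reward times the gradient of the sampled token’s log probability; for a softmax policy that gradient is one-hot(token) minus probs, so the sketch runs without PyTorch.

```python
import math
import random

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical one-step "story": only token 0 ("cat") pays off well
REWARDS = [1.0, 0.1, 0.1]
LR = 0.5

rng = random.Random(0)
logits = [0.0, 0.0, 0.0]  # start from a uniform policy

for step in range(500):
    probs = softmax(logits)
    action = rng.choices(range(3), weights=probs, k=1)[0]
    reward = REWARDS[action]
    # d log pi(action) / d logits = one_hot(action) - probs for a softmax policy
    grad_log_prob = [(1.0 if k == action else 0.0) - p for k, p in enumerate(probs)]
    # REINFORCE step: ascend reward * grad log pi (i.e. descend -reward * log pi)
    logits = [l + LR * reward * g for l, g in zip(logits, grad_log_prob)]

print(softmax(logits))
```

Standard REINFORCE often subtracts a baseline (for example, the batch-mean reward) from the reward to reduce variance; the loss in this post uses the raw reward.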

Now let’s train

I trained with the following settings on a single A100 (80 GB). It took 1-2 minutes.

NUM_EPOCHS = 100
BATCH_SIZE = 1000
NUM_TOKENS = 10
LR = 2e-6

[Plots: loss and reward over training]

We have a steady 0.6 similarity between our generated stories and the word “cat” - excellent!

Let’s try it!

Once upon a time there was a cat. The cat was a very polite cat and always said "please" and "thank you".
One day, the cat looked up and saw a big bird. The bird was very big and it had a long, pointy beak.
The cat thought it looked like a big cat and started to tease the bird. The bird didn't like this and so he flew away. The cat felt very bad.
From then on, the cat was always polite and never bothered the big bird.
His nice cat learned that teasing others was not nice and it made everyone feel bad too.

Once upon a time there was a cat. The cat was very messy. The cat always wanted to be clean.
So one day, the cat went outside and saw a nice spot. It was the perfect place to bathe.
The cat jumped right into the cool water and swam around.
After a while, it got to the spot and shook itself dry.
Then, it started to look around. It saw a small rock and jumped out of the water to take a bath.
After the cat had bathed, it felt much better.
Then, the cat ran around looking for something else to do. And it never forgot to bathe again.

It works!

But some of the original TinyStories ability has degraded. Aligning is hard.

For example, the formatting of the stories is a bit broken: they are not split into paragraphs. They also say “cat” a lot, maybe too much.

KL-Divergence

To help prevent this we can tie the optimized model to a reference model (that we don’t touch). The intuition here is that we want to change the behavior of the model - make it tell cat stories - but not break the model.

We can easily break the model with only a reward function. Some of my first training runs produced models that only said “cat”. Literally, that’s all they would output, which makes sense given the reward function.

So to control the alignment, we add a reference model.

We want to measure how different the output probability distributions of the two models are while generating stories. We do this by computing the KL-divergence.
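For two next-token distributions p and q, KL-divergence sums p * log(p / q) over the vocabulary. This pure-Python sketch uses made-up four-token distributions. It also illustrates two details relevant to the training-loop snippet that follows: torch.nn.functional.kl_div(input, target) treats target as the first distribution, so passing the reference model’s log probabilities as target gives KL(reference || model); and averaging over the vocabulary instead of summing divides the result by the vocabulary size.

```python
import math

def kl_divergence(p, q):
    """KL(p || q) = sum_i p_i * log(p_i / q_i) for discrete distributions."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

# Hypothetical next-token distributions over a 4-token vocabulary
reference = [0.4, 0.3, 0.2, 0.1]
model = [0.7, 0.1, 0.1, 0.1]

kl_ref_model = kl_divergence(reference, model)  # direction of the kl_div call
kl_model_ref = kl_divergence(model, reference)  # other direction: KL is not symmetric
kl_vocab_mean = kl_ref_model / len(reference)   # what averaging over the vocabulary gives

print(kl_ref_model, kl_model_ref, kl_vocab_mean)
```

KL-divergence is zero only when the two distributions match, so it works as a penalty for drifting away from the reference model.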

This is similar to how PPO, the algorithm behind RLHF, is typically used: a KL penalty keeps the tuned model close to a reference model. PPO adds some clipping and more optimization steps per batch, but the intuition is the same.
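For comparison, the clipping PPO adds can be sketched for a single sampled token. This follows the clipped surrogate objective from the PPO paper (Schulman et al., 2017). It is not part of this post’s training loop, and the numbers are made up. The probability ratio between the new and old policy is clipped to [1 - eps, 1 + eps], which caps how far one update can chase a single advantage estimate.

```python
import math

def ppo_clipped_objective(new_log_prob, old_log_prob, advantage, eps=0.2):
    """PPO clipped surrogate for one sample (to be maximized)."""
    ratio = math.exp(new_log_prob - old_log_prob)
    clipped_ratio = min(max(ratio, 1.0 - eps), 1.0 + eps)
    # Pessimistic choice: take the smaller of the unclipped and clipped terms
    return min(ratio * advantage, clipped_ratio * advantage)

# Positive advantage: the gain stops growing once the ratio passes 1 + eps
print(ppo_clipped_objective(math.log(0.9), math.log(0.5), advantage=1.0))
# Negative advantage: the min keeps the larger, unclipped penalty
print(ppo_clipped_objective(math.log(0.9), math.log(0.5), advantage=-1.0))
```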

We simply add this to the loss function and update the training loop.


ref_model = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-33M").to("cuda")

kl_div_accumulated = torch.zeros((BATCH_SIZE, 1), device="cuda")

for i in range(input_ids.shape[1], NUM_TOKENS):
    prompt = output_ids[:, :i].clone()
    logits = model(prompt).logits[:, -1, :]
    probs = torch.nn.functional.softmax(logits, dim=-1)
    dist = Categorical(probs)
    next_tokens = dist.sample()
    log_probs_accumulated += dist.log_prob(next_tokens).unsqueeze(-1)
    output_ids[:, i] = next_tokens

    # Compute reference model logits; it is never trained, so skip gradients
    with torch.no_grad():
        ref_logits = ref_model(prompt).logits[:, -1, :]

    # Compute KL Divergence
    # With log_target=True, kl_div(input, target) computes exp(target) * (target - input)
    # per vocabulary entry; summed over the vocabulary this is KL(reference || model).
    # .mean(dim=-1) averages over the vocabulary instead, i.e. KL / vocab_size.
    kl_div = torch.nn.functional.kl_div(
        torch.nn.functional.log_softmax(logits, dim=-1),
        torch.nn.functional.log_softmax(ref_logits, dim=-1),
        reduction="none",
        log_target=True,
    )
    kl_div_accumulated += kl_div.mean(dim=-1).unsqueeze(-1)

...

normalized_log_probs = log_probs_accumulated / NUM_TOKENS
normalized_kl_div = kl_div_accumulated / NUM_TOKENS
neg_advantage = (-normalized_log_probs * rewards.unsqueeze(-1)).mean()
loss = neg_advantage + KL_FACTOR * normalized_kl_div.mean()

KL_FACTOR is how much we weight the KL-divergence. If we weight it a lot, we won’t be able to optimize away from the reference model. If we don’t weight it much, it won’t help control the alignment, and the model might cat-astrophically break.
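A short calculation helps put KL_FACTOR in scale. If the per-token KL is averaged over the vocabulary (as .mean(dim=-1) does) rather than summed, the penalty is really KL_FACTOR / vocab_size times the summed KL. This sketch assumes the 50,257-token GPT-Neo vocabulary that TinyStories uses; the kl_penalty helper and the summed-KL values are hypothetical.

```python
VOCAB_SIZE = 50257  # GPT-Neo vocabulary size (assumed for TinyStories-33M logits)
KL_FACTOR = 6000

def kl_penalty(kl_summed, kl_factor=KL_FACTOR, vocab_size=VOCAB_SIZE):
    """Hypothetical helper: kl_factor times the vocabulary-averaged KL."""
    return kl_factor * (kl_summed / vocab_size)

effective_weight = KL_FACTOR / VOCAB_SIZE
print(f"effective weight on the summed KL: {effective_weight:.3f}")

# Hypothetical summed-KL values: the penalty grows linearly as the model drifts
for kl_summed in (0.0, 0.5, 2.0):
    print(kl_summed, kl_penalty(kl_summed))
```

Under this assumption, KL_FACTOR = 6000 is roughly a weight of 0.12 on the ordinary summed KL-divergence, which likely explains why such a large-looking factor is needed.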

Let’s train!

NUM_EPOCHS = 100
BATCH_SIZE = 1000
NUM_TOKENS = 10
LR = 2e-6
KL_FACTOR = 6000

[Plots: loss, reward, and KL-divergence over training]

It pretty much worked!

We don’t have the same beautiful decreasing loss curve, but it makes sense. We are now penalizing divergence from the reference model - it’s hard to talk only about cats, while not moving away from the reference model!

Let’s try

Once upon a time there was a small cat. He was very happy.
One day he was very tired from playing in the sun. He went to find a soft place to rest.

When he looked and looked, he could see lots of leaves on the ground.
They were colorful and made the cat look very cosy. The cat started to cover himself in the neat pile of leaves.
He curled up and slept until the sun was high in the sky.

The next morning, when the cat opened his eyes, he was not so tired anymore. He was ready for a new day.
He thanked the clouds for making the day so special and began to play in the leaves again.

The end.

It’s better!

Probably!

https://github.com/pHaeusler/tinycatstories




Aug 14, 2023