In this blog and in our Discord channel, we discuss training in detail. A topic that is often overlooked is how to grow a model, yet in incentivized, collaborative and distributed training it is a key ingredient. This post explores the concept of model growth with concrete Python code examples and describes the tweaks to transformers that are instrumental in allowing gradual model growth.

TL;DR

You should really read the whole article. But if you don’t have the time: we improved Llama and Phi3 to support model growth in the hidden size dimension without incurring a loss increase. We can now take a 7B model, turn it into a 14B model, train it, turn that into a 28B model, train it, and so on. This allows us to collectively train ever larger models without needing the massive compute required to start from scratch.

Why grow? Isn’t training enough?

An LLM is basically a collection of weights (a fancy word for numbers), organized in matrices and some vectors, plus some code to run the model, that is, to do inference. When training the model on a dataset, the weights are adjusted so that the model predicts samples of the dataset ever more accurately. It is easy to see that, given a finite number of weights, the model contains a finite amount of knowledge, and the loss value that can be achieved is bounded. Once that lower loss limit is slowly approached, enlarging the model is a good way to increase capacity and allow training to an even lower loss value.

Why grow? Can’t we just start over, but bigger?

Well, yes, you can simply start over, but it is expensive and requires ever more resources to start from scratch every time you increase the model size. We have not yet found any research indicating that restarting from scratch leads to a better model than growing an existing one, so economically speaking, the growth strategy seems best.

Is this about the 1B, 7B, 70B, 400B, … that is mentioned sometimes?

Yes, exactly. For example, a 1.6B model is a model containing (roughly) 1.6 billion parameters:

$ python -i
import transformers, torch, numpy, copy
torch.set_default_device('cuda:0')
model = transformers.AutoModelForCausalLM.from_pretrained(
   pretrained_model_name_or_path='stabilityai/stablelm-2-1_6b',
   attn_implementation="flash_attention_2",
   torch_dtype=torch.bfloat16
)
model.num_parameters()
1644515328

You can see how these weights are divided over various parts of the model:

model
StableLmForCausalLM(
  (model): StableLmModel(
    (embed_tokens): Embedding(100352, 2048)
    (layers): ModuleList(
      (0-23): 24 x StableLmDecoderLayer(
        (self_attn): StableLmFlashAttention2(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (k_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (attention_dropout): Dropout(p=0.0, inplace=False)
          (rotary_emb): StableLmRotaryEmbedding()
        )
        (mlp): StableLmMLP(
          (gate_proj): Linear(in_features=2048, out_features=5632, bias=False)
          (up_proj): Linear(in_features=2048, out_features=5632, bias=False)
          (down_proj): Linear(in_features=5632, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
        (post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
    )
    (norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=2048, out_features=100352, bias=False)
)

This shows two matrices of 2048*100352 (embed_tokens and lm_head) and 24 layers, each containing an attention block of four 2048*2048 matrices and an MLP block of three 2048*5632 matrices. This brings the total to:

2 * 2048*100352 + 24 * (4 * 2048*2048 + 3 * 2048*5632)
1644167168

So we are still 348160 short of the earlier number; this is due to the normalization vectors called LayerNorm and the bias vectors. This model has one final LayerNorm and two LayerNorms per layer, plus three bias vectors per layer (see bias=True on q/k/v_proj). Each LayerNorm consists of a bias and a weight vector, each 2048 elements long. This all amounts to:

2 * 2048 + 24 * (2*2*2048 + 3*2048)
348160

and all weights are accounted for! A shorthand to list all parameters with their shape and size is this:

for n,p in model.named_parameters():
   print(f'{n:45s} {str(list(p.shape)):15s} = {numpy.prod(p.shape)}')

model.embed_tokens.weight                     [100352, 2048]  = 205520896
model.layers.0.self_attn.q_proj.weight        [2048, 2048]    = 4194304
model.layers.0.self_attn.q_proj.bias          [2048]          = 2048
model.layers.0.self_attn.k_proj.weight        [2048, 2048]    = 4194304
model.layers.0.self_attn.k_proj.bias          [2048]          = 2048
model.layers.0.self_attn.v_proj.weight        [2048, 2048]    = 4194304

etc

So how do I grow a model?

Growing a model simply means making it larger along one or more dimensions. The following dimensions are available on Llama and StableLM models; for each item, the value for the model shown above is listed:

  • Data type, or dtype (BFloat16, so 16 bits, or 2 bytes per element)
  • Embedding size (100352)
  • Intermediate (MLP) size (5632)
  • Number of layers (24)
  • Hidden size (2048)
  • Number of attention / key-value heads (32)
  • Bias enabled True/False (2048*2 per item; note that this adds a marginal number of weights and cannot always be configured for every item)

In principle, each of these can be changed to reshape the model and increase or decrease its size. In the remainder of this post, both the number of key-value heads and the bias setting are ignored.
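For reference, these dimensions map onto attributes of the loaded model and its config. A quick sketch (attribute names are those used by recent transformers configs; the values in the comments are the ones listed above):

cfg = model.config
model.dtype                     # torch.bfloat16
cfg.vocab_size                  # 100352, the embedding size
cfg.intermediate_size           # 5632
cfg.num_hidden_layers           # 24
cfg.hidden_size                 # 2048
cfg.num_attention_heads         # 32
cfg.num_key_value_heads         # 32, no grouped-query attention in this model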

Five dimensions to grow or reduce in

Dtype: change from 16 bit to 32 bit, bigger and slower

For some items, such as the dtype, the change is trivial – you can convert the model to 32 bit, doubling its size on disk and allowing for more precision during training and inference. The 16 bit values can be converted to 32 bit and the performance of the model will be identical. Consensus seems to be that the speed gained by using 16 bit over 32 bit is by far worth the loss in precision, so we prefer to keep the model in BFloat16. Note that 32 bit calculations are not only intrinsically slower than 16 bit calculations on most GPU platforms, they also prohibit the use of Flash Attention 2, which is fast, but currently not available for 32 bit floats.
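A sketch of such a conversion, simply reloading the same checkpoint in 32 bit (note that we switch to the eager attention implementation, since Flash Attention 2 needs 16 bit floats; model32 is just an illustrative name):

model32 = transformers.AutoModelForCausalLM.from_pretrained(
   pretrained_model_name_or_path='stabilityai/stablelm-2-1_6b',
   attn_implementation="eager",     # Flash Attention 2 is not available for 32 bit
   torch_dtype=torch.float32        # twice the memory of bfloat16
)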

Embedding size: drop rare or unused tokens, or add more

Changing the embedding size also involves changing the tokenizer, as the number of tokens directly determines the embedding size. For every token removed, you save twice the hidden size in parameters (one row in embed_tokens and one in lm_head). Removing unused tokens will not impact the model loss. Changing the tokenizer to remove tokens that used to be in use will increase the loss on some samples; the model needs to be trained to learn that the content represented by those tokens is now represented differently. The same holds for adding tokens, perhaps to better match the dataset and achieve lower loss values.
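As a sketch of the adding-tokens direction: transformers can resize both embed_tokens and lm_head in one call. We work on a throwaway copy so the model above stays untouched, the number 128 is arbitrary, and in practice the tokenizer must be extended as well before the new rows are ever used:

grown_vocab = copy.deepcopy(model)                                    # leave the original model untouched
grown_vocab.resize_token_embeddings(model.config.vocab_size + 128)    # room for 128 hypothetical new tokens
grown_vocab.num_parameters() - model.num_parameters()                 # 2 * 128 * 2048 = 524288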

Intermediate (MLP) size: simply increase matrix sizes

Changing the intermediate (or MLP) size is trivial. The matrices of shape hidden_size*intermediate_size simply grow in one direction. If you look at the actual matrix calculations performed in the model (this line: transformers/models/llama/modeling_llama.py#L311), it is easy to see that there is no impact on the loss if the newly added space is filled with zeros, as the sketch below demonstrates.
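A quick check of that claim on the first decoder layer’s MLP, as a sketch: we pad copies of the three matrices with zeros and compare against the original computation. The padding width of 1000 is arbitrary:

import torch.nn.functional as F
mlp = model.model.layers[0].mlp
x = torch.randn(1, 4, 2048, dtype=torch.bfloat16)
pad = 1000
gate = torch.cat([mlp.gate_proj.weight, torch.zeros(pad, 2048, dtype=torch.bfloat16)])
up   = torch.cat([mlp.up_proj.weight,   torch.zeros(pad, 2048, dtype=torch.bfloat16)])
down = torch.cat([mlp.down_proj.weight, torch.zeros(2048, pad, dtype=torch.bfloat16)], dim=1)
small = (F.silu(x @ mlp.gate_proj.weight.T) * (x @ mlp.up_proj.weight.T)) @ mlp.down_proj.weight.T
big   = (F.silu(x @ gate.T) * (x @ up.T)) @ down.T
(small - big).abs().max()        # zero, or at worst rounding noise from the larger matmuls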

Number of layers: just add more layers

Adding layers is equally trivial. Just add an empty layer (that is, a layer with all matrices zeroed out) and the model will perform identically to the original. If you look at the actual data being passed around, it is clear that a layer that computes nothing simply passes the hidden state on to the next layer (starting from this line, see how residual passes on the state; if hidden_states is zero, the layer does nothing: transformers/models/llama/modeling_llama.py#L724).
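A minimal sketch of adding one such layer, on a throwaway copy (we assume a recent transformers version where the attention module carries a layer_idx attribute for the KV cache). With every parameter zeroed, both the attention and the MLP branch output zero, so the residual connection passes the hidden state through unchanged:

deeper = copy.deepcopy(model)
newlayer = copy.deepcopy(deeper.model.layers[-1])
for p in newlayer.parameters():
   p.data.zero_()                                          # zero out the whole layer
newlayer.self_attn.layer_idx = len(deeper.model.layers)    # 24: the index of the new last layer
deeper.model.layers.append(newlayer)
deeper.config.num_hidden_layers += 1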

Hidden size: things get complicated

Increasing the hidden state size is where things start to hurt. Badly. This is all due to this line (transformers/models/llama/modeling_llama.py#L125):

variance = hidden_states.pow(2).mean(-1, keepdim=True)

Although this line looks innocent, it ties the trained model weights to the hidden state size. Simply increasing the hidden state size and keeping the added weights at zero will dramatically change the variance value calculated here, as this line can be stated equivalently as:

variance = hidden_states.pow(2).sum(-1, keepdim=True)/hidden_size

where hidden_size of course increases while the sum stays the same if the newly added weights are zero. There seems to be an escape, looking at the context of the quoted line:

def forward(self, hidden_states):
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
    return self.weight * hidden_states.to(input_dtype)

The escape here would be to adjust self.weight, that is, the contents of the LayerNorm items we saw before, rescaling them by an appropriate factor. This kind of works, but at every step a small amount of numerical error creeps in, eventually adding up to a possibly significant impact on the loss. We tried this adjustment followed by training, to see whether the model would fall back in line, but that did not work out for all models.
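A small numeric sketch of both the problem and this weight-based compensation (the sizes are arbitrary; for the RMS-style norm quoted above, the compensating factor is the square root of old over new hidden size, and the variance epsilon is exactly where the small numerical error creeps in):

old_h, new_h = 2048, 3072
x = torch.randn(old_h, dtype=torch.float32)
x_grown = torch.cat([x, torch.zeros(new_h - old_h)])                 # zero-padded hidden state
torch.rsqrt(x_grown.pow(2).mean()) / torch.rsqrt(x.pow(2).mean())    # ~ sqrt(new_h / old_h): the normalized values blow up
w = torch.ones(old_h)                                                # stand-in for the trained norm weight
orig = w * x * torch.rsqrt(x.pow(2).mean() + 1e-5)
comp = (w * (old_h / new_h) ** 0.5) * (x_grown * torch.rsqrt(x_grown.pow(2).mean() + 1e-5))[:old_h]
(orig - comp).abs().max()                                            # small, but not exactly zero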

Let’s grow a model!

As promised, growing a model is trivial:

config = copy.deepcopy(model.config)
config.intermediate_size = 10000
config._attn_implementation = 'flash_attention_2'
tp = type(model)
biggermodel = tp(config).to(dtype=torch.bfloat16)
biggermodel
StableLmForCausalLM(
  (model): StableLmModel(
    (embed_tokens): Embedding(100352, 2048)
    (layers): ModuleList(
      (0-23): 24 x StableLmDecoderLayer(
        (self_attn): StableLmFlashAttention2(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (k_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (attention_dropout): Dropout(p=0.0, inplace=False)
          (rotary_emb): StableLmRotaryEmbedding()
        )
        (mlp): StableLmMLP(
          (gate_proj): Linear(in_features=2048, out_features=10000, bias=False)
          (up_proj): Linear(in_features=2048, out_features=10000, bias=False)
          (down_proj): Linear(in_features=10000, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
        (post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
    )
    (norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=2048, out_features=100352, bias=False)
)
biggermodel.num_parameters() - model.num_parameters()
644087808
24 * 3 * 2048*(10000-5632)
644087808

The last line shows that the model grew by exactly the number of elements needed to enlarge the three MLP matrices in every layer. Now we can compare the performance of the two models:

tokenizer = transformers.AutoTokenizer.from_pretrained('stabilityai/stablelm-2-1_6b')
inputs = tokenizer("The first letter of the alphabet is a.", return_tensors="pt")
labels = inputs['input_ids'].clone()
outputs = model(inputs['input_ids'], labels=labels)
outputs.loss.item()
3.40625

Doing the same on biggermodel as just created yields a huge loss value, but that is because it still contains random data. We need to initialize the weights by copying them over from model, like this:

for fr, to in zip(model.parameters(), biggermodel.parameters()):
   to.data[:] = 0                                  # start from all zeros
   idx = tuple(slice(n) for n in fr.data.shape)    # slices covering the original shape
   to.data[idx] = fr.data                          # copy the old weights into that corner

biggermodel(inputs['input_ids'], labels=labels).loss.item()
3.40625

So we succeeded in creating a bigger model with identical loss!

Not the whole story

The model just created is bigger, but actually provides no benefit compared to the original. Looking at the training process, it is easy to see that with this initialization the newly added weights receive zero gradients: the new down_proj columns multiply activations that are exactly zero, and the new gate_proj and up_proj rows only influence the output through those zero columns. The newly added values would therefore stay zero forever: we enlarged the model but did not yet allow the new weights to be trained. This is where we leave the reader to develop their own ideas about which initialization strategy works best. The tradeoff generally is: start with large random numbers and you train faster (because the gradients are bigger) but incur a severe loss penalty, or start with small random numbers and train slowly, but without a large loss penalty. Of course there are an infinite number of strategies to grow in small steps, to freeze parts of the model, to amplify (subsets of) gradients, and so on.
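As one illustration only (the noise scale is a hypothetical choice, not a recommendation), the small-random-numbers variant of the copy loop above could look like this:

noise_scale = 1e-3                                        # hypothetical; larger trains faster but hurts loss more
for fr, to in zip(model.parameters(), biggermodel.parameters()):
   to.data[:] = noise_scale * torch.randn_like(to.data)   # small noise so the new weights get non-zero gradients
   idx = tuple(slice(n) for n in fr.data.shape)
   to.data[idx] = fr.data                                 # the original weights are still copied in exactly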

One more thing…

As explained, the hidden size is the one problematic dimension where we cannot simply grow the model while keeping the loss equal. After careful consideration, we believe this is fixed by tweaking the offending line and adding a model configuration parameter norm_divisor:

--- variance = hidden_states.pow(2).mean(-1, keepdim=True)
+++ variance = hidden_states.pow(2).sum(-1, keepdim=True)/self.norm_divisor

The idea here is that at some point norm_divisor is frozen at the value of the hidden size. From that point on, the hidden size can be increased while keeping the loss equal. This does mean the “normalized” vectors are rescaled relative to standard mean-based normalization, by a factor that depends on the ratio between the hidden size and norm_divisor. However, as the vectors are multiplied by a trained weight vector immediately after normalization, we feel confident that this rescaling is simply corrected in the training process. We prefer this solution over compensating the weights, as it is cleaner and guarantees that the loss is not increased.
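For clarity, here is a minimal self-contained sketch of such a norm module (names and defaults are illustrative; this is not the exact code shipped with our validator):

class RMSNormWithDivisor(torch.nn.Module):
    def __init__(self, hidden_size, norm_divisor=None, eps=1e-06):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps
        # freeze the divisor at the hidden size the model was originally trained with
        self.norm_divisor = norm_divisor if norm_divisor is not None else hidden_size

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).sum(-1, keepdim=True) / self.norm_divisor
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)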

Breaking compatibility?

A model thus trained, using the new config value norm_divisor, is forever incompatible with regular Llama models. As explained above, it is possible to scale the weights of the LayerNorm items so that norm_divisor equals the hidden size again, but this does incur additional loss. One could imagine switching the LayerNorm to 32 bit and adding a penalty to the training process that steers the LayerNorm weights toward values that land nicely on BFloat16 values, so that, eventually, the model can be converted back to a regular Llama model without a significant loss increase.

What you will see in practice:

The validator code of SN29 now understands the added norm_divisor configuration value for Llama and Phi3 models. This means that miners can now grow these models, including in the hidden size dimension, without incurring loss due to model growth. Anyone using these models for inference or training should be aware that the line quoted above needs to be changed. If you run pip install -e on our validator code base, you can simply import transformers_llama and transformers_phi3 and enjoy this feature, as well as sliced Llama and Phi3 – about which a separate blog post will follow.

About the models we recently published in SN29

The parameter limit of SN29 has been at 10.5B since August 21, more than a month ago. As miners have still not uploaded successful models using the full parameter space available, we felt it was time to step in and publish good models that use the full space. Note that these models do not use the norm_divisor feature outlined above. Miners can take these models and train them to new lows. By publishing this article, we hope to enable miners to grow their models as soon as we increase the parameter limit again. As always, we are more than happy to help out and discuss code and strategies in our Discord channel.

