As explained in our previous article, growing a model in most dimensions is quite simple, but increasing the hidden size comes with a few problems. This article dives deep and shows how it can be done.

## Why, again?

We can simply grow a model’s MLP (intermediate size) or number of layers without altering its loss. This can be done by increasing matrix sizes and padding with zeros (for the MLP), or by adding layers whose matrices are all zeros. It will, however, lead to a model whose dimensions are out of balance unless we also grow the hidden size, which we assume is not optimal for training. The underlying mathematical implementation of a model makes increasing the hidden size, while keeping the loss equal, far more complicated than increasing size in other dimensions: a few more things need to change along with the hidden size.
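
To make the zero-padding idea concrete, here is a minimal sketch (not taken from the article's session; sizes and names are arbitrary illustrations) showing that padding an MLP's up and down projections with zero rows and columns leaves the output unchanged. Because `SiLU(0) = 0`, the same argument holds with the gating activation in place.

```
import torch

# toy sizes, for illustration only
hidden_size, intermediate, intermediate_grown = 8, 16, 24
x = torch.randn(hidden_size)

up = torch.randn(intermediate, hidden_size)       # up_proj weight, (out, in)
down = torch.randn(hidden_size, intermediate)     # down_proj weight, (out, in)

# grow the intermediate size by padding with zeros
up_grown = torch.zeros(intermediate_grown, hidden_size)
up_grown[:intermediate] = up
down_grown = torch.zeros(hidden_size, intermediate_grown)
down_grown[:, :intermediate] = down

out = down @ (up @ x)
out_grown = down_grown @ (up_grown @ x)
print(torch.allclose(out, out_grown))  # True: the new rows and columns contribute nothing
```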

*Sidenotes:*

- *The Python code in this article can be easily selected for copy&paste; the prompt, output and comments are excluded from selection.*
- *This article uses a StableLm model as an example, but applies, mutatis mutandis, to similar models.*
- *Growing a model with zeros is just step one; the new parameters need non-zero initialization to be trainable.*
- *This article doesn’t explain how to handle models with* `num_key_value_heads != num_attention_heads`.

## A quick recap

Our previous article already explained that one problem lies in the line doing `.mean()` here, in `LayerNorm` code taken from Llama:

```
def forward(self, hidden_states):
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
    return self.weight * hidden_states.to(input_dtype)
```

This is because this line can be written as:

`variance = hidden_states.pow(2).sum(-1, keepdim=True)/hidden_size`

and it is easy to see that adding zeros keeps the sum equal, but increasing the hidden size reduces the calculated variance, directly impacting the values returned by this `forward()` by (implicitly) scaling them with a factor of `sqrt(hidden_new/hidden_old)`.
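
A quick numeric check of that factor (a sketch, not part of the article's session), using the RMS-style normalization quoted above:

```
import torch

old, new = 2048, 2560
x = torch.randn(old)
x_padded = torch.cat([x, torch.zeros(new - old)])

def rms_normalize(h):
    # same normalization as the quoted forward(), without the learned weight
    variance = h.pow(2).mean(-1, keepdim=True)
    return h * torch.rsqrt(variance + 1e-5)

ratio = rms_normalize(x_padded)[:old] / rms_normalize(x)
print(ratio.mean().item(), (new / old) ** 0.5)  # both ~1.118: sqrt(hidden_new/hidden_old)
```

This also shows why scaling the normalization weights by `sqrt(hidden_old/hidden_new)`, as done later in this article, approximately compensates for the growth.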

## Solving the problem within existing model code

We think it is best to solve this by modifying the code, but that breaks compatibility. To solve it, approximately, within existing model code, the values of `self.weight` of the `LayerNorm` objects may be scaled appropriately. Another thing to watch out for is that the `q_proj`, `k_proj` and `v_proj` matrices are actually used (indirectly) as 3D matrices, with the `hidden_size` dimension reshaped into two dimensions of sizes `head_dim` and `num_key_value_heads`, where `head_dim = hidden_size / num_key_value_heads`. We will get back to that.

Let’s first naively create a bigger model and see what happens:

```
$ python -i
import transformers, torch, numpy, copy, math, subprocess
torch.set_default_device('cuda:0')
torch.set_grad_enabled(False)
model = transformers.AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path='stabilityai/stablelm-2-1_6b',
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16
)
model.num_parameters()
1644515328
config = copy.deepcopy(model.config)
config.hidden_size = 2560
config._attn_implementation = 'flash_attention_2'
tp = type(model)
biggermodel = tp(config).to(dtype=torch.bfloat16)
biggermodel
StableLmForCausalLM(
  (model): StableLmModel(
    (embed_tokens): Embedding(100352, 2560)
    (layers): ModuleList(
      (0-23): 24 x StableLmDecoderLayer(
        (self_attn): StableLmFlashAttention2(
          (q_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (k_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (v_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (o_proj): Linear(in_features=2560, out_features=2560, bias=False)
          (attention_dropout): Dropout(p=0.0, inplace=False)
          (rotary_emb): StableLmRotaryEmbedding()
        )
        (mlp): StableLmMLP(
          (gate_proj): Linear(in_features=2560, out_features=5632, bias=False)
          (up_proj): Linear(in_features=2560, out_features=5632, bias=False)
          (down_proj): Linear(in_features=5632, out_features=2560, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
        (post_attention_layernorm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
    )
    (norm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=2560, out_features=100352, bias=False)
)
biggermodel.num_parameters() - model.num_parameters()
536957952
```

So we succeeded in adding 537M parameters by increasing `hidden_size`. These 537M parameters are all accounted for by simple arithmetic:

```
inc_lm_head = inc_embed_tokens = 100352 * (2560-2048)
inc_layer_attn = 4 * (2560**2 - 2048**2) + 3 * (2560-2048)
inc_layer_mlp = 3 * 5632 * (2560-2048)
inc_layer_norms = 2 * 2 * (2560-2048)
inc_model_norm = 2 * (2560-2048)
inc_embed_tokens + 24 * (inc_layer_attn + inc_layer_mlp + inc_layer_norms) + inc_model_norm + inc_lm_head
536957952
```

## Transferring parameters

The new model is initialized with random data, which clearly shows in the loss. We use the Python license text as a sample of English text for loss calculations:

```
tokenizer = transformers.AutoTokenizer.from_pretrained('stabilityai/stablelm-2-1_6b')
cmd = 'python -c license.MAXLINES=1<<30;license()'
license_bytes = subprocess.check_output(cmd.split(' '))
license_text = license_bytes.decode('utf-8')
inputs = tokenizer(license_text, return_tensors="pt")
labels = inputs['input_ids'].clone()
model(inputs['input_ids'], labels=labels).loss.item()
0.7734375
biggermodel(inputs['input_ids'], labels=labels).loss.item()
12.125
```

The low loss of the original model shows that the license text is highly predictable. The high loss of the grown model is expected, given its random initialization. Now let’s plug in the parameters of the original model, padded with zeros, and see what happens to the loss:

```
for fr,to in zip(model.parameters(),biggermodel.parameters()):
    to.data[:] = 0
    idx = [slice(n) for n in fr.data.shape]
    to.data[idx] = fr.data[:]
biggermodel(inputs['input_ids'], labels=labels).loss.item()
8.8125
```

That’s a huge loss increase; maybe better than random initialization, but nowhere near the original score. The problem lies in the way the attention blocks use their projection matrices.

## Fixing attention

The issue is that while the `q_proj`, `k_proj` and `v_proj` matrices are 2D, and projecting `hidden_states` with them results in the 2D matrices `query_states`, `key_states` and `value_states`, these states are subsequently used as 4D matrices.

Their shape is `(batch_size, sequence_length, num_key_value_heads, head_dim)`, where `num_key_value_heads` follows from the config and `head_dim = hidden_size / num_key_value_heads`. To properly grow these projection matrices without changing the resulting 4D state matrices, aside from padding some zeros, they would need to be reshaped to 3D, then grown, and shaped back.

## The equivalent attention code

To understand what transformations need to be done, the attention calculations can be analyzed. The easiest way to understand them (for StableLm) is by looking at the un-optimized Python attention implementation in class `StableLmAttention`, which is what runs with `attn_implementation="eager"`. Only what happens in `forward()` is relevant, and of that, only the lines between incoming and outgoing `hidden_states` that depend on (the shape of) `q_proj`, `k_proj` and `v_proj`. Trimmed down and simplified, the equivalent code looks like this:

```
hidden_size = 2048 # becomes 2560
num_heads = 32 # from config
batch_size = 1 # to keep it simple
seq_len = ... # arbitrary value
head_dim = hidden_size//num_heads
# hidden_states has shape (batch_size, seq_len, hidden_size)
# the [qkvo]_proj weight matrices have shape (hidden_size, hidden_size)
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
key_states = key_states.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
value_states = value_states.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(batch_size, seq_len, hidden_size)
hidden_states = self.o_proj(attn_output)
```

The goal is to expand the projection matrices with zeros in the right places, so that `hidden_states` is effectively expanded with zeros, both on the input and the output of the attention block. That way, the entire model will give the same loss for the same input (plus, of course, the numerical error originating in the necessary scaling), as the other parts of the model are easily adapted to a larger `hidden_size` by simply extending matrices with zeros.

Note that the `self.*_proj` objects are linear projections; they are torch objects of type `torch.nn.Linear`, performing a matrix multiplication (with 2D *weights*) and an addition (with 1D *biases*). In the following code we will simply perform the matrix multiplications and additions ourselves, instead of instantiating these objects, for clarity.
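
As a quick check of that equivalence (a sketch with arbitrary sizes, not part of the article's session), `torch.nn.Linear` is just a matrix multiplication with the `(out_features, in_features)` weight plus the bias:

```
import torch

lin = torch.nn.Linear(8, 8, bias=True)
x = torch.randn(3, 8)
manual = x @ lin.weight.T + lin.bias
print(torch.allclose(lin(x), manual))  # True
```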

## Cracking the transformation problem

Note that not only `hidden_size` is increased, but also the product `num_key_value_heads * head_dim`. As we did not modify `num_key_value_heads` in the config, `head_dim` increased from `2048/32=64` to `2560/32=80`. It is easily seen that this doesn’t play out well in the code shown above. Using smaller numbers to demonstrate the point more clearly, in a toy model:

```
torch.set_default_device('cpu')
hidden_size = 8
num_heads = 4
head_dim = hidden_size//num_heads
seq_len = 1
batch_size = 1
value_states = torch.arange(1,hidden_size+1)
dt = value_states.dtype
value_states.view(batch_size,seq_len,num_heads,head_dim)
tensor([[[[1, 2],
          [3, 4],
          [5, 6],
          [7, 8]]]])
hidden_size_grown = 12
head_dim_grown = hidden_size_grown//num_heads
value_states_grown = torch.zeros(hidden_size_grown,dtype=dt)
value_states_grown[:hidden_size] = value_states
value_states_grown.view(batch_size,seq_len,num_heads,head_dim_grown)
tensor([[[[1, 2, 3],
          [4, 5, 6],
          [7, 8, 0],
          [0, 0, 0]]]])
```

Comparing the two tensors shown, it is obvious that the contents of queries, keys and values are shuffled beyond recognition here, essentially because `head_dim` has grown from `2` to `3` without zeros being inserted at the right places. However, if we choose to keep `head_dim` equal and increase `num_key_value_heads` instead, things are better:

```
num_heads_grown = 6
head_dim_grown = hidden_size_grown//num_heads_grown
value_states_grown.view(batch_size,seq_len,num_heads_grown,head_dim_grown)
tensor([[[[1, 2],
          [3, 4],
          [5, 6],
          [7, 8],
          [0, 0],
          [0, 0]]]])
```

This looks like it may work out, as the value (and by extension, query and key) state matrices are padded with zeros in one dimension. It is reasonable to expect that this zero padding ends up as zero padding in the resulting `hidden_states`, even without checking all remaining steps in the attention code.
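
As a quick sanity check of that expectation (a sketch using the toy numbers above, skipping the softmax and the `sqrt(head_dim)` scaling, just like the simplified attention code), the zero-padded states can be pushed through the remaining matmul/reshape steps to confirm the zeros stay at the end:

```
import torch

bsz, seq, n_heads, hd = 1, 1, 6, 2
states = torch.zeros(bsz, seq, n_heads * hd)
states[0, 0, :8] = torch.arange(1., 9.)           # original values; the rest stays zero
s4d = states.view(bsz, seq, n_heads, hd).transpose(1, 2)
attn_weights = torch.matmul(s4d, s4d.transpose(2, 3))   # per-head dot products
attn_output = torch.matmul(attn_weights, s4d).transpose(1, 2).reshape(bsz, seq, n_heads * hd)
print(attn_output)  # the last four entries are still zero
```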

### Solution 1: increase `num_heads`

The analysis above points to one possible solution: increase `num_key_value_heads` and `num_attention_heads` in `model.config` so that keys, values and queries are simply expanded with zeros, as we intended in the first place. Repeating the model creation, but with two additional lines:

```
torch.set_default_device('cuda:0')
config.num_key_value_heads = 40    # added
config.num_attention_heads = 40    # added
biggermodel = tp(config).to(dtype=torch.bfloat16)
for fr,to in zip(model.parameters(),biggermodel.parameters()):
    to.data[:] = 0
    idx = [slice(n) for n in fr.data.shape]
    to.data[idx] = fr.data[:]
biggermodel(inputs['input_ids'], labels=labels).loss.item()
0.96875
```

That’s a 25% hit on the loss; a lot better than random initialization but not yet close to the original score. It is close enough to assume all numbers are in the right place, so let’s see what more needs to be done.

#### Scaling `LayerNorms`

We did not account for the change in the variance calculation as described above. Let’s try to fix it: in every `LayerNorm`, the variance value goes down as the `hidden_size` goes up; the `hidden_state` is multiplied by `torch.rsqrt(variance)`, and this should be compensated in the weight:

```
for name,p in biggermodel.named_parameters():
    if 'norm.weight' in name:
        p.data[:] *= math.sqrt(2048/2560)
biggermodel(inputs['input_ids'], labels=labels).loss.item()
0.77734375
```

That’s a 0.5% higher loss than the original model. The exact increase in loss seems to depend on model type, geometry, and sample content, but in general you can expect a loss increase on the order of 1%. In the model code no other factors depend on `hidden_size` or `num_key_value_heads`, so this is as good as it gets if we don’t touch the model code.

### Solution 2: increase `head_dim`

We can also increase `head_dim`. A bit more work, but it can be done. Picking up the small-numbers example where we left off, and searching for the way we would have wanted it to look, that is, expanding a 4D matrix with zeros and shaping it back:

```
torch.set_default_device('cpu')
head_dim_grown = hidden_size_grown//num_heads
value_states_4d = value_states.reshape(batch_size,seq_len,num_heads,head_dim)
shp = batch_size,seq_len,num_heads,head_dim_grown-head_dim
z = torch.zeros(shp,dtype=dt)
padded = torch.cat((value_states_4d,z),dim=3)
padded
tensor([[[[1, 2, 0],
          [3, 4, 0],
          [5, 6, 0],
          [7, 8, 0]]]])
padded.reshape(batch_size,seq_len,hidden_size_grown)
tensor([[[1, 2, 0, 3, 4, 0, 5, 6, 0, 7, 8, 0]]])
```

It is immediately clear that zeros are *inserted* into `value_states` in its lower-dimensional form when `head_dim` is increased, and no trivial padding of zeros in `v_proj` would do that. Expanding our toy example upward towards some hypothetical `hidden_states` and `v_proj` shows the way to go. Let’s first get from some `hidden_states` to the point we ended up at before, with a naively padded `v_proj` and `v_bias`:

```
hidden_states = torch.arange(1,hidden_size+1)
v_proj = torch.eye(hidden_size,dtype=dt)
v_bias = torch.ones(hidden_size,dtype=dt)*10
value_states = torch.matmul(hidden_states,v_proj.T) + v_bias
value_states
tensor([11, 12, 13, 14, 15, 16, 17, 18])
hidden_states_grown = torch.zeros(hidden_size_grown,dtype=dt)
hidden_states_grown[:hidden_size] = hidden_states
v_proj_grown = torch.zeros(hidden_size_grown,hidden_size_grown,dtype=dt)
v_proj_grown[:hidden_size,:hidden_size] = v_proj
v_bias_grown = torch.zeros(hidden_size_grown,dtype=dt)
v_bias_grown[:hidden_size] = v_bias
value_states_grown = torch.matmul(hidden_states_grown,v_proj_grown) + v_bias_grown
value_states_grown
tensor([11, 12, 13, 14, 15, 16, 17, 18, 0, 0, 0, 0])
value_states_grown.view(batch_size,seq_len,num_heads,head_dim_grown)
tensor([[[[11, 12, 13],
          [14, 15, 16],
          [17, 18, 0],
          [ 0, 0, 0]]]])
```

The problem clearly lies in how `v_proj_grown` is derived from `v_proj`. Using the same reshape-and-extend logic as before, but now applied to `v_proj`, we can fix it:

```
v_proj_3d = v_proj.reshape(num_heads,head_dim,hidden_size)
v_proj_3d
tensor([[[1, 0, 0, 0, 0, 0, 0, 0],
         [0, 1, 0, 0, 0, 0, 0, 0]],
        [[0, 0, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 1, 0, 0, 0, 0]],
        [[0, 0, 0, 0, 1, 0, 0, 0],
         [0, 0, 0, 0, 0, 1, 0, 0]],
        [[0, 0, 0, 0, 0, 0, 1, 0],
         [0, 0, 0, 0, 0, 0, 0, 1]]])
# expand in the head_dim dimension
shp = num_heads,head_dim_grown-head_dim,hidden_size
z = torch.zeros(shp,dtype=dt)
v_proj_3d_padded = torch.cat((v_proj_3d,z),dim=1)
# expand in the hidden_size dimension
shp = num_heads,head_dim_grown,hidden_size_grown-hidden_size
z = torch.zeros(shp,dtype=dt)
v_proj_3d_padded = torch.cat((v_proj_3d_padded,z),dim=2)
v_proj_3d_padded
tensor([[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
        [[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
        [[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
        [[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]])
```

And now to see if it indeed produces the desired outcome:

```
v_proj_grown = v_proj_3d_padded.reshape(hidden_size_grown,hidden_size_grown)
v_proj_grown
tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
value_states_grown = torch.matmul(hidden_states_grown,v_proj_grown.T) + v_bias_grown
value_states_grown
tensor([11, 12, 10, 13, 14, 10, 15, 16, 0, 7, 8, 0])
value_states_grown.view(batch_size,seq_len,num_heads,head_dim_grown)
tensor([[[[11, 12, 10],
          [13, 14, 10],
          [15, 16, 0],
          [ 7, 8, 0]]]])
```

Almost there: because the reshaping of the projection matrices in 3D introduces new rows in the 2D form, equivalent zeros need to be inserted in the bias vectors:

```
v_bias_grown = torch.zeros(hidden_size_grown,dtype=dt)
v = v_bias_grown.view(num_heads,head_dim_grown)
v[:,:head_dim] = v_bias.view(num_heads,head_dim)
v_bias_grown
tensor([10, 10, 0, 10, 10, 0, 10, 10, 0, 10, 10, 0])
value_states_grown = torch.matmul(hidden_states_grown,v_proj_grown.T) + v_bias_grown
value_states_grown
tensor([11, 12, 0, 13, 14, 0, 15, 16, 0, 17, 18, 0])
value_states_grown.view(batch_size,seq_len,num_heads,head_dim_grown)
tensor([[[[11, 12, 0],
          [13, 14, 0],
          [15, 16, 0],
          [17, 18, 0]]]])
```

So far, so good! We worked our way up from `value_states` to see what needed to be done to `v_proj` (and, by extension, to `q_proj` and `k_proj`), but we still need to check what happens on the way down to `attn_output`.

#### Intermezzo: a minor simplification

Extending the 2D/3D matrices as done above is a bit verbose and clutters the code. We will use this helper function in the code that follows:

```
def grow_proj(p,n_heads,hs_grown):
    # p is a projection weight (2D) or bias (1D) with hidden_size as its first dimension
    hs = p.shape[0]
    hd = hs//n_heads
    hd_grown = hs_grown//n_heads
    if len(p.shape) == 2:
        # weight: zero-pad each head's (head_dim, hidden_size) block in both dimensions
        ret = torch.zeros(hs_grown,hs_grown,dtype=p.dtype)
        v = ret.view(n_heads,hd_grown,hs_grown)
        v[:,:hd,:hs] = p.view(n_heads,hd,hs)
    if len(p.shape) == 1:
        # bias: zero-pad each head's head_dim slice
        ret = torch.zeros(hs_grown,dtype=p.dtype)
        v = ret.view(n_heads,hd_grown)
        v[:,:hd] = p.view(n_heads,hd)
    return ret
```

This code exploits the fact that reshaping and creating a view amount to the same folding of matrix values.
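
To illustrate that equivalence (a small sketch, not from the article's session), the view-and-assign approach of `grow_proj` produces exactly the same 2D matrix as the explicit `torch.cat` plus `reshape` steps used in the derivation above:

```
import torch

# a toy projection weight of shape (hidden_size, hidden_size) = (8, 8), with 4 heads
p = torch.arange(64.).view(8, 8)
n_heads, hd, hs, hd_g, hs_g = 4, 2, 8, 3, 12

# method 1: write through a view of a zero matrix, as grow_proj does
a = torch.zeros(hs_g, hs_g)
a.view(n_heads, hd_g, hs_g)[:, :hd, :hs] = p.view(n_heads, hd, hs)

# method 2: pad the explicit 3D form with torch.cat and reshape back,
# as in the step-by-step derivation above
p3d = p.reshape(n_heads, hd, hs)
p3d = torch.cat((p3d, torch.zeros(n_heads, hd_g - hd, hs)), dim=1)
p3d = torch.cat((p3d, torch.zeros(n_heads, hd_g, hs_g - hs)), dim=2)
b = p3d.reshape(hs_g, hs_g)

print(torch.equal(a, b))  # True: both fold the values the same way
```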

#### Working our way down to `attn_output`

As queries, keys and values are all derived in the same way from `hidden_states`, and have their padded zeros in the same locations, in our toy model we can simply use `value_states_grown` for each of the state matrices and see where the zeros end up.

```
value_states_grown_4d = value_states_grown.view(batch_size,seq_len,num_heads,head_dim_grown).transpose(1,2)
key_states_grown_4d = value_states_grown_4d
query_states_grown_4d = value_states_grown_4d
attn_weights = torch.matmul(query_states_grown_4d, key_states_grown_4d.transpose(2, 3))
attn_output_4d = torch.matmul(attn_weights, value_states_grown_4d)
attn_output_4d
tensor([[[[ 2915, 3180, 0]],
         [[ 4745, 5110, 0]],
         [[ 7215, 7696, 0]],
         [[10421, 11034, 0]]]])
attn_output = attn_output_4d.reshape(batch_size,seq_len,hidden_size_grown)
attn_output
tensor([[[ 2915, 3180, 0, 4745, 5110, 0, 7215, 7696, 0, 10421, 11034, 0]]])
```

This `attn_output` matrix is then projected using `o_proj`, for which we will use an identity matrix like we did for `v_proj`. Let’s first see what happens if we naively grow `o_proj` to `hidden_size_grown`, padding with zeros:

```
o_proj = torch.eye(hidden_size,dtype=dt)
o_proj_grown = torch.zeros(hidden_size_grown,hidden_size_grown,dtype=dt)
o_proj_grown[:hidden_size,:hidden_size] = o_proj
torch.matmul(attn_output,o_proj_grown.T)
tensor([[[ 2915, 3180, 0, 4745, 5110, 0, 7215, 7696, 0, 0, 0, 0]]])
```

This clearly cuts off some values (note that `10421` and `11034` are gone) and leaves zeros in the middle, where it should ideally re-arrange the matrix so that the padding ends up at the end – the end goal being that the attention block outputs the same `hidden_states` as before, but padded. We need to do the inverse of what we did going from `v_proj` to `v_proj_grown`, to reverse the insertion of zeros. The attention code doesn’t make the higher-dimensional nature of `o_proj` explicit, but as it performs the inverse operation of the other projection matrices, we can make an educated guess and follow the same steps as before for `v_proj`. We also have to realise that `o_proj` needs to be transposed to get the inserted zeros on the right dimension, with the net result that, in addition to the modifications derived above, only two transpositions (`.T`) are needed:

```
o_proj_grown = grow_proj(o_proj.T,num_heads,hidden_size_grown).T
o_proj_grown
tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
```

And now for the final test:

```
torch.matmul(attn_output,o_proj_grown.T)
tensor([[[ 2915, 3180, 4745, 5110, 7215, 7696, 10421, 11034, 0, 0, 0, 0]]])
```

Which is what we were looking for. The attention block now takes `hidden_states` padded with zeros, and returns `hidden_states` padded with zeros. The above of course only shows that it works for this special mix of identity and all-ones matrices; the same steps can be repeated with matrices filled with random numbers, and the conclusions will hold.

The toy example works out; now, can we use the above to fix the grown model?

## Increasing `head_dim` in the model

Here we create the bigger model, and iterate over the parameters to copy them, possibly altering them in the process:

```
torch.set_default_device('cuda:0')
config = copy.deepcopy(model.config)
config.hidden_size = 2560
config.partial_rotary_factor *= 64/80
config._attn_implementation = 'flash_attention_2'
tp = type(model)
biggermodel = tp(config).to(dtype=torch.bfloat16)
import re
match_qkv = re.compile('.*[qkv]_proj')
match_o = re.compile('.*o_proj')
for name,p in model.named_parameters():
    p_big = biggermodel.get_parameter(name)
    if match_qkv.match(name):
        p_big.data[:] = grow_proj(p.data,config.num_key_value_heads,config.hidden_size)
    elif match_o.match(name):
        p_big.data[:] = grow_proj(p.data.T,config.num_key_value_heads,config.hidden_size).T
    else:
        p_big.data[:] = 0
        idx = [slice(n) for n in p.data.shape]
        p_big.data[idx] = p.data
biggermodel(inputs['input_ids'], labels=labels).loss.item()
0.91015625
```

This is only 18% worse than the original.

Note that we reduced the `partial_rotary_factor`, so that the number of `head_dim` rows to which rotary embedding is applied remains equal:

`self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor)`

After deriving `query_states` and `key_states`, these states are cut in two parts, and to one part (of size `self.rotary_ndims`) rotary position embedding is applied, after which the parts are concatenated back together.
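
A quick arithmetic check of this (a sketch; the `0.25` is only an illustrative value, use whatever the original config holds):

```
# rotary_ndims before and after growing head_dim from 64 to 80
factor = 0.25
print(int(64 * factor), int(80 * (factor * 64/80)))  # equal: 80 * (factor * 64/80) == 64 * factor
```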

We skipped a few additional important details. One of them is the scaling of the `LayerNorms` described before:

```
for name,p in biggermodel.named_parameters():
    if 'norm.weight' in name:
        p.data[:] *= math.sqrt(2048/2560)
biggermodel(inputs['input_ids'], labels=labels).loss.item()
0.8125
```

This is 5% worse than the original. Another factor is the scaling by `math.sqrt(head_dim)`, which we need to account for; we do so by scaling either `k_proj` or `q_proj`, as the scaling is applied to the matrix product of `key_states` and `query_states`.

```
for name,p in biggermodel.named_parameters():
    if 'q_proj' in name:
        p.data[:] *= math.sqrt(80/64)
biggermodel(inputs['input_ids'], labels=labels).loss.item()
0.77734375
```

This value is 0.5% above the original loss value, in line with the result achieved through increasing the number of heads.
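
The arithmetic behind that compensation, in a tiny sketch (the dot-product value is arbitrary): scaling the query side by `sqrt(80/64)` exactly cancels the change of the divisor from `sqrt(64)` to `sqrt(80)`.

```
import math

qk = 3.7  # some query·key dot product; unchanged by the zero padding
original = qk / math.sqrt(64)                     # scaling in the original model
grown = (qk * math.sqrt(80/64)) / math.sqrt(80)   # q_proj scaled, new divisor
print(abs(original - grown) < 1e-12)              # True
```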

### Solution 3: hybrid approaches

Of course it is also possible to combine both solutions and increase both `num_heads` and `head_dim`, e.g. 36 heads with a head dimension of 71, for a `hidden_size` of 2556. The creation of such a model is left as an exercise to the reader – the end result is, again, a 0.5% loss increase.

## Sanity check: is it all about numerical precision?

It may be tempting to think that higher precision would allow models to grow while keeping the same loss. We tested this by increasing precision and repeating the process, and it turns out that numerical precision does not fully explain the additional loss: repeating the above steps with `float32` still leaves the loss 0.4% higher for the bigger model.

The reason for this is that StableLm uses `nn.LayerNorm` for normalization, which, in addition to what the `LayerNorm.forward()` quoted above from Llama does, subtracts the mean of the hidden states before normalizing the magnitude using the variance:

`y = (x - E[x]) / sqrt(Var[x] + eps) * γ + β`

Here `x` denotes `hidden_states`, `E[x]` is equivalent to `x.sum()/hidden_size`, and `Var[x]` to `(x-E[x]).pow(2).sum()/hidden_size`. Increasing `hidden_size` by padding `hidden_states` with zeros proportionally decreases `E[x]`. There is no way to compensate the change in `x-E[x]` with changes to parameters alone if `E[x]` is non-zero: as the hidden size increases, the mean value `E[x]` decreases, and `Var[x]` changes in ways that a fixed weight scaling cannot undo. (In Llama models, `-E[x]` is omitted and the increase in `hidden_size` can be compensated by appropriately scaling the weights `γ`.)
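
A tiny demonstration of why no weight scaling can fix this (a sketch with made-up numbers): with mean subtraction, the padded output is not just a rescaled version of the original output.

```
import torch

old, new = 4, 6
x = torch.tensor([1.0, 2.0, 3.0, 10.0])
x_padded = torch.cat([x, torch.zeros(new - old)])

out_old = torch.nn.functional.layer_norm(x, (old,))
out_new = torch.nn.functional.layer_norm(x_padded, (new,))
print(out_new[:old] / out_old)  # not a constant ratio, so no per-weight scaling can compensate
```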

Here is an attempt to fix this issue in StableLm, by implementing a local clone of `nn.LayerNorm` that keeps the calculations equal to those of the model before it was grown, by hardcoding the divisor and only applying the mean subtraction to the original, non-zero part of `hidden_states`:

```
import torch
from torch import nn

class StableLmLayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, hidden_states):
        hs_orig = 2048  # hidden size before growing; hardcoded divisor
        mask = torch.zeros_like(hidden_states)
        mask[:, :, :hs_orig] = 1
        # the padded part of hidden_states is zero, so .sum() equals the original sum;
        # the mean is subtracted only from the original (non-zero) part
        hidden_states = hidden_states - mask * hidden_states.sum(dim=-1, keepdim=True) / hs_orig
        variance = hidden_states.pow(2).sum(dim=-1, keepdim=True) / hs_orig
        hidden_states *= torch.rsqrt(variance + self.eps)
        hidden_states *= self.weight
        hidden_states += self.bias
        return hidden_states
```

Using this version of `LayerNorm` instead of `nn.LayerNorm` indeed allows us to grow the model without any loss increase. The model will of course now rely on this code, and will not be compatible with the regular transformers code base.

## Conclusion

It is possible to grow a model in the hidden size dimension, and doing so leads to a model that performs slightly worse than the original, with more capacity to train. Further research is needed to determine the optimal balance between the number of heads and the head dimension, the product of which is the hidden size.

It is clear that the numerical error induced by growing is a significant source of additional loss, but even with increased precision, implementation details can cause additional loss as well. These details can be patched, leading to (incompatible) models that can be grown without any loss increase. These findings hint at areas of interest, both regarding efficient training and model optimization in general.
