Size Mismatch in LoRA Fine-Tuning

When you fine-tune a model with LoRA or QLoRA, the lora_r (LoRA rank) parameter matters because it determines the shapes of the adapter tensors stored in the saved checkpoint. For example, with lora_r = 8, the lora_A weight of a projection whose input size is 4096 has shape [8, 4096]. If the lora_r you configure when loading differs from the rank the checkpoint was trained with, you will get a size mismatch error like the one below:
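The rank appears in the tensor shapes because LoRA represents the update to a frozen weight of shape [out_features, in_features] as the product lora_B @ lora_A, where lora_A has shape [r, in_features] and lora_B has shape [out_features, r]. The plain-Python sketch below (the helper `lora_shapes` is hypothetical, not part of any library) reproduces the shapes seen in the log for a hidden size of 4096 and an MLP size of 11008.

```python
def lora_shapes(in_features: int, out_features: int, r: int):
    """Return (lora_A shape, lora_B shape) for a linear layer adapted with rank r.

    Hypothetical helper for illustration: LoRA learns delta_W = B @ A,
    with A of shape [r, in_features] and B of shape [out_features, r].
    """
    lora_a = (r, in_features)
    lora_b = (out_features, r)
    return lora_a, lora_b


# Layer sizes taken from the log: q_proj maps 4096 -> 4096,
# gate_proj maps 4096 -> 11008, down_proj maps 11008 -> 4096.
layers = {
    "q_proj": (4096, 4096),
    "gate_proj": (4096, 11008),
    "down_proj": (11008, 4096),
}
for name, (fan_in, fan_out) in layers.items():
    for r in (32, 256):
        a, b = lora_shapes(fan_in, fan_out, r)
        print(f"{name:9s} r={r:3d}  lora_A={list(a)}  lora_B={list(b)}")
```

Only the dimension equal to r changes between the two configurations, which is exactly the pattern in every line of the error.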

RuntimeError: Error(s) in loading state_dict for PeftModelForCausalLM:
	size mismatch for base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight: copying a param with shape torch.Size([256, 4096]) from checkpoint, the shape in current model is torch.Size([32, 4096]).
	size mismatch for base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight: copying a param with shape torch.Size([4096, 256]) from checkpoint, the shape in current model is torch.Size([4096, 32]).
	size mismatch for base_model.model.model.layers.0.self_attn.k_proj.lora_A.default.weight: copying a param with shape torch.Size([256, 4096]) from checkpoint, the shape in current model is torch.Size([32, 4096]).
	size mismatch for base_model.model.model.layers.0.self_attn.k_proj.lora_B.default.weight: copying a param with shape torch.Size([4096, 256]) from checkpoint, the shape in current model is torch.Size([4096, 32]).
	size mismatch for base_model.model.model.layers.0.self_attn.v_proj.lora_A.default.weight: copying a param with shape torch.Size([256, 4096]) from checkpoint, the shape in current model is torch.Size([32, 4096]).
	size mismatch for base_model.model.model.layers.0.self_attn.v_proj.lora_B.default.weight: copying a param with shape torch.Size([4096, 256]) from checkpoint, the shape in current model is torch.Size([4096, 32]).
	size mismatch for base_model.model.model.layers.0.self_attn.o_proj.lora_A.default.weight: copying a param with shape torch.Size([256, 4096]) from checkpoint, the shape in current model is torch.Size([32, 4096]).
	size mismatch for base_model.model.model.layers.0.self_attn.o_proj.lora_B.default.weight: copying a param with shape torch.Size([4096, 256]) from checkpoint, the shape in current model is torch.Size([4096, 32]).
	size mismatch for base_model.model.model.layers.0.mlp.gate_proj.lora_A.default.weight: copying a param with shape torch.Size([256, 4096]) from checkpoint, the shape in current model is torch.Size([32, 4096]).
	size mismatch for base_model.model.model.layers.0.mlp.gate_proj.lora_B.default.weight: copying a param with shape torch.Size([11008, 256]) from checkpoint, the shape in current model is torch.Size([11008, 32]).
	size mismatch for base_model.model.model.layers.0.mlp.up_proj.lora_A.default.weight: copying a param with shape torch.Size([256, 4096]) from checkpoint, the shape in current model is torch.Size([32, 4096]).
	size mismatch for base_model.model.model.layers.0.mlp.up_proj.lora_B.default.weight: copying a param with shape torch.Size([11008, 256]) from checkpoint, the shape in current model is torch.Size([11008, 32]).
	size mismatch for base_model.model.model.layers.0.mlp.down_proj.lora_A.default.weight: copying a param with shape torch.Size([256, 11008]) from checkpoint, the shape in current model is torch.Size([32, 11008]).
	size mismatch for base_model.model.model.layers.0.mlp.down_proj.lora_B.default.weight: copying a param with shape torch.Size([4096, 256]) from checkpoint, the shape in current model is torch.Size([4096, 32]).
	size mismatch for base_model.model.model.layers.1.self_attn.q_proj.lora_A.default.weight: copying a param with shape torch.Size([256, 4096]) from checkpoint, the shape in current model is torch.Size([32, 4096]).
	size mismatch for base_model.model.model.layers.1.self_attn.q_proj.lora_B.default.weight: copying a param with shape torch.Size([4096, 256]) from checkpoint, the shape in current model is torch.Size([4096, 32]).
	size mismatch for base_model.model.model.layers.1.self_attn.k_proj.lora_A.default.weight: copying a param with shape torch.Size([256, 4096]) from checkpoint, the shape in current model is torch.Size([32, 4096]).
	size mismatch for base_model.model.model.layers.1.self_attn.k_proj.lora_B.default.weight: copying a param with shape torch.Size([4096, 256]) from checkpoint, the shape in current model is torch.Size([4096, 32]).
	size mismatch for base_model.model.model.layers.1.self_attn.v_proj.lora_A.default.weight: copying a param with shape torch.Size([256, 4096]) from checkpoint, the shape in current model is torch.Size([32, 4096]).
	size mismatch for base_model.model.model.layers.1.self_attn.v_proj.lora_B.default.weight: copying a param with shape torch.Size([4096, 256]) from checkpoint, the shape in current model is torch.Size([4096, 32]).
	size mismatch for base_model.model.model.layers.1.self_attn.o_proj.lora_A.default.weight: copying a param with shape torch.Size([256, 4096]) from checkpoint, the shape in current model is torch.Size([32, 4096]).
	size mismatch for base_model.model.model.layers.1.self_attn.o_proj.lora_B.default.weight: copying a param with shape torch.Size([4096, 256]) from checkpoint, the shape in current model is torch.Size([4096, 32]).
	size mismatch for base_model.model.model.layers.1.mlp.gate_proj.lora_A.default.weight: copying a param with shape torch.Size([256, 4096]) from checkpoint, the shape in current model is torch.Size([32, 4096]).
	size mismatch for base_model.model.model.layers.1.mlp.gate_proj.lora_B.default.weight: copying a param with shape torch.Size([11008, 256]) from checkpoint, the shape in current model is torch.Size([11008, 32]).
	size mismatch for base_model.model.model.layers.1.mlp.up_proj.lora_A.default.weight: copying a param with shape torch.Size([256, 4096]) from checkpoint, the shape in current model is torch.Size([32, 4096]).
	size mismatch for base_model.model.model.layers.1.mlp.up_proj.lora_B.default.weight: copying a param with shape torch.Size([11008, 256]) from checkpoint, the shape in current model is torch.Size([11008, 32]).
	size mismatch for base_model.model.model.layers.1.mlp.down_proj.lora_A.default.weight: copying a param with shape torch.Size([256, 11008]) from checkpoint, the shape in current model is torch.Size([32, 11008]).
	size mismatch for base_model.model.model.layers.1.mlp.down_proj.lora_B.default.weight: copying a param with shape torch.Size([4096, 256]) from checkpoint, the shape in current model is torch.Size([4096, 32]).
	size mismatch for base_model.model.model.layers.2.self_attn.q_proj.lora_A.default.weight: copying a param with shape torch.Size([256, 4096]) from checkpoint, the shape in current model is torch.Size([32, 4096]).
	size mismatch for base_model.model.model.layers.2.self_attn.q_proj.lora_B.default.weight: copying a param with shape torch.Size([4096, 256]) from checkpoint, the shape in current model is torch.Size([4096, 32]).
	size mismatch for base_model.model.model.layers.2.self_attn.k_proj.lora_A.default.weight: copying a param with shape torch.Size([256, 4096]) from checkpoint, the shape in current model is torch.Size([32, 4096]).
	size mismatch for base_model.model.model.layers.2.self_attn.k_proj.lora_B.default.weight: copying a param with shape torch.Size([4096, 256]) from checkpoint, the shape in current model is torch.Size([4096, 32]).
	size mismatch for base_model.model.model.layers.2.self_attn.v_proj.lora_A.default.weight: copying a param with shape torch.Size([256, 4096]) from checkpoint, the shape in current model is torch.Size([32, 4096]).
	size mismatch for base_model.model.model.layers.2.self_attn.v_proj.lora_B.default.weight: copying a param with shape torch.Size([4096, 256]) from checkpoint, the shape in current model is torch.Size([4096, 32]).
	size mismatch for base_model.model.model.layers.2.self_attn.o_proj.lora_A.default.weight: copying a param with shape torch.Size([256, 4096]) from checkpoint, the shape in current model is torch.Size([32, 4096]).
	size mismatch for base_model.model.model.layers.2.self_attn.o_proj.lora_B.default.weight: copying a param with shape torch.Size([4096, 256]) from checkpoint, the shape in current model is torch.Size([4096, 32]).
	size mismatch for base_model.model.model.layers.2.mlp.gate_proj.lora_A.default.weight: copying a param with shape torch.Size([256, 4096]) from checkpoint, the shape in current model is torch.Size([32, 4096]).
	size mismatch for base_model.model.model.layers.2.mlp.gate_proj.lora_B.default.weight: copying a param with shape torch.Size([11008, 256]) from checkpoint, the shape in current model is torch.Size([11008, 32]).
	size mismatch for base_model.model.model.layers.2.mlp.up_proj.lora_A.default.weight: copying a param with shape torch.Size([256, 4096]) from checkpoint, the shape in current model is torch.Size([32, 4096]).
..................

The error says that some tensors in the checkpoint have different shapes from the corresponding tensors in your model. Taking the first line (q_proj.lora_A in layer 0) as an example:

  1. The checkpoint (the saved model parameters) contains a tensor with shape [256, 4096].
  2. The current model you are trying to load the checkpoint into has a tensor with shape [32, 4096].

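Because of this mismatch, the parameters from the checkpoint cannot be loaded directly into the current model. Since the first dimension of every lora_A tensor and the second dimension of every lora_B tensor is the rank, the log shows that the checkpoint was saved with lora_r = 256 while the current model was built with lora_r = 32.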
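Since every line of the error follows the same pattern, the ranks on both sides can be read straight off the message. A minimal sketch, assuming the message format shown above (the names `summarize_mismatches` and `parse_shape` are hypothetical): it extracts the checkpoint and current shapes of each lora_A entry and reports their first dimension, which is the rank.

```python
import re

# Matches lines like:
#   size mismatch for <name>: copying a param with shape torch.Size([256, 4096]) from checkpoint,
#   the shape in current model is torch.Size([32, 4096]).
MISMATCH_RE = re.compile(
    r"size mismatch for (?P<name>\S+): copying a param with shape "
    r"torch\.Size\(\[(?P<ckpt>[\d, ]+)\]\) from checkpoint, "
    r"the shape in current model is torch\.Size\(\[(?P<cur>[\d, ]+)\]\)"
)


def parse_shape(text: str) -> tuple:
    """Turn '256, 4096' into (256, 4096)."""
    return tuple(int(x) for x in text.split(","))


def summarize_mismatches(error_text: str) -> set:
    """Return the set of (checkpoint_rank, current_rank) pairs found in lora_A entries."""
    ranks = set()
    for m in MISMATCH_RE.finditer(error_text):
        if ".lora_A." in m.group("name"):
            ckpt, cur = parse_shape(m.group("ckpt")), parse_shape(m.group("cur"))
            # lora_A has shape [r, in_features], so dimension 0 is the rank
            ranks.add((ckpt[0], cur[0]))
    return ranks


log = (
    "size mismatch for base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight: "
    "copying a param with shape torch.Size([256, 4096]) from checkpoint, "
    "the shape in current model is torch.Size([32, 4096]).\n"
    "size mismatch for base_model.model.model.layers.0.mlp.down_proj.lora_A.default.weight: "
    "copying a param with shape torch.Size([256, 11008]) from checkpoint, "
    "the shape in current model is torch.Size([32, 11008]).\n"
)
print(summarize_mismatches(log))  # {(256, 32)}
```

A single pair in the result means the whole checkpoint was trained with one consistent rank, so changing lora_r to that value should be enough.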

Possible Solutions

To resolve this error, you have a few options:

  1. Reconfigure Model Architecture: Modify your current model so that its tensor shapes match the checkpoint. In this case, the tensors with shape [32, 4096] must become [256, 4096]. This may not always be feasible, depending on the specifics of your model.
    In my case, changing lora_r from 32 to 256 solved the problem.
  2. Load Weights Selectively: If you know which layers in your current model correspond to layers in the checkpoint, you can load only the weights whose shapes match and skip the mismatched ones by manipulating the model's state dictionary. Filtering by key name alone is not enough here, because the mismatched LoRA tensors have the same names in both; their shapes must be compared as well. Here’s an example using PyTorch:
    import torch
    state_dict = torch.load('your_checkpoint.pth')
    model_state = model.state_dict()
    # Keep only keys that exist in the current model AND have the same shape
    state_dict = {k: v for k, v in state_dict.items()
                  if k in model_state and v.shape == model_state[k].shape}
    # Update the current model state with the filtered checkpoint state
    model_state.update(state_dict)
    model.load_state_dict(model_state)
    Note that skipped tensors keep their current, freshly initialized values. Since every lora_A and lora_B shape depends on the rank, in the LoRA case above all adapter tensors would be skipped and none of the fine-tuned weights would actually be loaded.
  3. Fine-tuning the Model: If the architecture changes are substantial, you may need to fine-tune the model again on your data. This means training the model from a partially pre-trained state, for example training new LoRA adapters at the new rank on top of the base model.
  4. Using a Compatible Pre-trained Model: Make sure the checkpoint you load is compatible with your current model architecture. If the checkpoint and the model are supposed to match, check that you are loading the correct checkpoint file.
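For the first option, the saved rank does not have to be guessed. If the adapter was saved with PEFT's `save_pretrained`, the checkpoint directory contains an `adapter_config.json` whose `r` field records the rank it was trained with, alongside `lora_alpha` and `target_modules`. A minimal, standard-library-only sketch, assuming that layout; `read_lora_settings` and `check_rank` are hypothetical helpers:

```python
import json
from pathlib import Path


def read_lora_settings(adapter_dir: str) -> dict:
    """Read the LoRA hyperparameters stored next to a PEFT adapter checkpoint.

    Assumes the directory holds an adapter_config.json as written by PEFT's save_pretrained.
    """
    config = json.loads((Path(adapter_dir) / "adapter_config.json").read_text())
    return {
        "r": config["r"],
        "lora_alpha": config.get("lora_alpha"),
        "target_modules": config.get("target_modules"),
    }


def check_rank(adapter_dir: str, configured_r: int) -> None:
    """Fail early with a readable message if the configured lora_r differs from the saved one."""
    saved_r = read_lora_settings(adapter_dir)["r"]
    if saved_r != configured_r:
        raise ValueError(
            f"lora_r={configured_r}, but the checkpoint in {adapter_dir!r} was saved with "
            f"r={saved_r}; set lora_r={saved_r} before loading"
        )
```

With the saved values in hand, you would build the LoRA config with the same rank (in PEFT, `LoraConfig(r=settings["r"], ...)`) before loading the weights. Alternatively, `PeftModel.from_pretrained(base_model, adapter_dir)` reads `adapter_config.json` itself, so the rank always matches the saved weights.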
