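Parameter-Efficient Fine-Tuning of Large Models: LoRA (Low-Rank Adaptation)
In recent years, large language models (LLMs) such as GPT and LLaMA, pretrained on massive amounts of data, have shown strong language understanding and generation capabilities. Adapting these general-purpose models to a specific task or domain, however, usually requires fine-tuning. Full-parameter fine-tuning is effective but expensive: it consumes a great deal of compute, needs a lot of storage (a full copy of the weights per task), takes a long time to train, and demands high-end hardware.
LoRA (Low-Rank Adaptation) is a parameter-efficient fine-tuning technique. Its core idea is to model the change in the model's parameters with a low-rank matrix factorization instead of fine-tuning the original parameters directly, which greatly reduces the resources needed for training while maintaining model performance.
LoRA: Low-Rank Adaptation of Large Language Models
I. Mathematical Foundations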

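For a pretrained weight matrix $W_0 \in \mathbb{R}^{d \times k}$, full-parameter fine-tuning learns an update matrix $\Delta W \in \mathbb{R}^{d \times k}$, so that the fine-tuned weight is:
$$W = W_0 + \Delta W$$
Low intrinsic dimension hypothesis: research suggests that although large neural networks are heavily over-parameterized, with billions or even trillions of parameters, the substantive change they need to adapt to a new task (the update matrix $\Delta W$) lies in a low-dimensional subspace. In other words, there exists a low-dimensional parameter update whose effect is comparable to updating all parameters.
LoRA builds on this hypothesis: it introduces two low-rank matrices $B$ and $A$ and uses their product $BA$ to approximate the update $\Delta W$ needed for the new task. Only these two small matrices are trained, not the huge original parameters $W_0$, which makes adapting a large model to downstream tasks efficient. Compared with full-parameter fine-tuning, this typically reduces the trainable parameters to about 0.1% to 2.0% of the original model, and because gradients and optimizer states are only kept for the small matrices, it becomes possible to fine-tune a large model on a single consumer GPU.
$$W = W_0 + \frac{\alpha}{r} BA$$
where $B \in \mathbb{R}^{d \times r}$ and $A \in \mathbb{R}^{r \times k}$ are the trainable low-rank matrices, the rank satisfies $r \ll \min(d, k)$ and is typically between 4 and 128, and $\alpha$ is a scaling hyperparameter: the ratio $\alpha / r$ controls the magnitude of the LoRA update.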
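To make the parameter savings concrete: a $d \times k$ weight has $dk$ entries, while its LoRA factors $B$ and $A$ together have $r(d + k)$. The sketch below computes the ratio for a single square projection; the sizes ($d = k = 4096$, $r = 8$) are illustrative assumptions, not values taken from a specific model.

```python
def lora_param_ratio(d: int, k: int, r: int) -> float:
    """Trainable LoRA parameters r*(d+k) as a fraction of the full d*k weight."""
    full = d * k        # entries in W0 (and in a full update ΔW)
    lora = r * (d + k)  # entries in B (d x r) plus A (r x k)
    return lora / full


# Illustrative sizes: one 4096 x 4096 projection adapted with rank 8
d, k, r = 4096, 4096, 8
print(f"full: {d * k:,}  lora: {r * (d + k):,}  ratio: {lora_param_ratio(d, k, r):.4%}")
```

That is roughly 0.39% for this one matrix; since usually only some of a model's matrices are adapted, the fraction of the whole model is smaller still.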
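II. Forward Pass and Initialization
During the forward pass, the pretrained weight matrix $W_0$ stays completely frozen (it is not updated), while an adaptation path built from the low-rank matrices $B$ and $A$ (that is, $\frac{\alpha}{r} BA$) makes fine-grained, task-specific adjustments to the model's output. In the original fully connected layer, the forward pass computes:
$$h = W_0 x$$
With the LoRA adapter added, the forward computation becomes:
$$h = W_0 x + \Delta W x = W_0 x + \frac{\alpha}{r} B A x$$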
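Because matrix multiplication is associative, the adapter can be evaluated as $B(Ax)$ without ever forming the $d \times k$ product $BA$, and the result equals applying the merged weight $W_0 + \frac{\alpha}{r} BA$. The following PyTorch sketch checks this numerically with random matrices; the sizes are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(0)
d, k, r, alpha = 64, 32, 4, 8.0
W0 = torch.randn(d, k, dtype=torch.float64)  # frozen pretrained weight, d x k
B = torch.randn(d, r, dtype=torch.float64)   # trainable, d x r (random here so the check is non-trivial)
A = torch.randn(r, k, dtype=torch.float64)   # trainable, r x k
x = torch.randn(k, dtype=torch.float64)      # a single input vector
scaling = alpha / r

# LoRA forward pass: frozen path plus the low-rank path evaluated right to left as B(Ax)
h_lora = W0 @ x + scaling * (B @ (A @ x))

# The same output from one merged weight matrix W0 + (alpha/r) BA
h_merged = (W0 + scaling * (B @ A)) @ x

print(torch.allclose(h_lora, h_merged))  # True
```

The same identity is what allows the adapter to be merged into $W_0$ after training, so inference needs no extra computation.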
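To keep training stable and converge efficiently, LoRA uses a specific initialization scheme for the low-rank matrices:
Matrix $A$ is initialized with random samples from a zero-mean Gaussian distribution.
Matrix $B$ is initialized to all zeros.
This initialization guarantees that the low-rank term $BA$ is zero when training starts. The model's initial output is therefore determined entirely by the pretrained weight $W_0$, and the adapter is effectively "silent". The main advantages are:
Training stability: it avoids injecting noise from random initialization that would disturb the strong representations the pretrained model has already learned, so fine-tuning starts from a stable, reliable point.
Efficient convergence: gradients are informative from the first step (because $A$ is random, the gradient with respect to $B$ is non-zero even though $BA = 0$), so the model moves smoothly from pretrained knowledge to task-specific knowledge, which helps subsequent convergence.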
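The zero-initialization argument can be checked directly: with $B = 0$ the adapted layer reproduces the frozen layer exactly, and after one backward pass only $B$ receives a non-zero gradient (the gradient of $A$ is multiplied by $B = 0$). A minimal PyTorch sketch, using an `nn.Linear` as $W_0$ and bare parameters for $A$ and $B$; the layer sizes and the 0.01 standard deviation are arbitrary assumptions:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
in_dim, out_dim, r, alpha = 16, 8, 4, 8.0

base = nn.Linear(in_dim, out_dim)  # stands in for the frozen W0 (plus bias)
base.requires_grad_(False)
A = nn.Parameter(torch.randn(r, in_dim) * 0.01)  # zero-mean Gaussian init
B = nn.Parameter(torch.zeros(out_dim, r))        # zero init

x = torch.randn(5, in_dim)
h = base(x)
y_lora = h + (alpha / r) * (x @ A.T @ B.T)

# At initialization the adapter contributes nothing
print(torch.equal(h, y_lora))  # True

# One backward pass: B gets a gradient, A does not (its gradient is scaled by B = 0)
y_lora.sum().backward()
print(B.grad.abs().sum().item() > 0, A.grad.abs().sum().item() == 0)
```

If $A$ were also zero, both gradients would vanish and the adapter could never start learning, which is why only one of the two matrices is zeroed.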
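III. LoRA Code Implementation
Working through concrete code helps in understanding LoRA more deeply. Below are a LoRA layer implementation and how to apply it to a Transformer model:
1. Core LoRA Layer Implementation
$A$ uses a zero-mean random (Kaiming-uniform) initialization and $B$ is zero-initialized, which keeps training stable and leaves the initial output unchanged; alpha/rank controls the adaptation strength, and optional Dropout acts as regularization against overfitting.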
```python
import math

import torch
import torch.nn as nn


class LoRALayer(nn.Module):
    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        rank: int = 8,
        alpha: float = 16.0,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank  # scaling factor controls the adaptation strength
        # A matrix (rank x input_dim): zero-mean Kaiming-uniform init, scaled by fan-in
        self.lora_A = nn.Parameter(torch.empty(rank, input_dim))
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        # B matrix (output_dim x rank): zero init so that ΔW = BA = 0 and the original output is preserved
        self.lora_B = nn.Parameter(torch.zeros(output_dim, rank))
        self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the LoRA path: output = (alpha / rank) * B @ A @ x (over the last dim of x)."""
        x = self.dropout(x)
        # Multiply by A first, then B, so the full output_dim x input_dim matrix is never formed
        lora_output = (x @ self.lora_A.T) @ self.lora_B.T
        return self.scaling * lora_output
```
2. Applying LoRA in a Transformer Model
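Applying LoRA to the $W_q$ and $W_v$ projections in Transformer attention usually works best. The rank $r$ typically ranges from 4 to 128: choose a small $r$ (e.g. 4, 8 or 16) when resources are limited and a larger $r$ (e.g. 64 or 128) for complex tasks.
In Transformer attention, applying LoRA preferentially to the $W_q$ (query) and $W_v$ (value) projections is an empirical strategy for balancing parameter efficiency against task performance. An intuitive explanation: $W_q$ determines "where the model looks" and $W_v$ determines "what information is extracted", so both affect task adaptation directly, while $W_k$ (key) is assumed to be tied more to intrinsic properties of the input sequence and to be relatively stable, so leaving it untouched saves parameters and lowers the risk of overfitting. This is consistent with the LoRA paper's GPT-3 experiments, where adapting $W_q$ and $W_v$ together performed best under a fixed parameter budget.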
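The saving from skipping $W_k$ (and $W_o$) is easy to quantify, since each adapted $h \times h$ projection adds $2hr$ trainable parameters. The sketch below assumes a hypothetical model with hidden size 4096, 32 layers and square attention projections (models with grouped-query attention have smaller k/v projections); these numbers are illustrative, not taken from a specific checkpoint.

```python
from typing import List


def lora_budget(hidden: int, layers: int, rank: int, targets: List[str]) -> int:
    """Trainable LoRA parameters when every target projection is a hidden x hidden matrix."""
    per_matrix = rank * (hidden + hidden)  # B (hidden x r) + A (r x hidden)
    return per_matrix * len(targets) * layers


hidden, layers, rank = 4096, 32, 8
for targets in (["q_proj", "v_proj"], ["q_proj", "k_proj", "v_proj", "o_proj"]):
    print(targets, f"{lora_budget(hidden, layers, rank, targets):,}")
```

Under these assumptions, adapting only `q_proj` and `v_proj` needs about 4.2M trainable parameters, half of the roughly 8.4M needed for all four projections.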
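```python
from typing import List, Optional

import torch.nn as nn


class LinearWithLoRA(nn.Module):
    """Wrap a frozen nn.Linear and add a trainable LoRA path (LoRALayer is defined in the previous block)."""

    def __init__(self, linear_layer: nn.Linear, rank: int = 8, alpha: float = 16.0, dropout: float = 0.0):
        super().__init__()
        self.linear = linear_layer
        # Freeze the pretrained weight W0 (and bias)
        for param in self.linear.parameters():
            param.requires_grad = False
        # Create the adapter on the same device and dtype as the wrapped layer
        self.lora = LoRALayer(
            linear_layer.in_features, linear_layer.out_features, rank=rank, alpha=alpha, dropout=dropout
        ).to(device=linear_layer.weight.device, dtype=linear_layer.weight.dtype)

    def forward(self, x):
        # h = W0 x + b + (alpha / r) B A x
        return self.linear(x) + self.lora(x)


def apply_lora_to_transformer_model(
    model: nn.Module,
    target_modules: Optional[List[str]] = None,
    rank: int = 8,
    alpha: float = 16.0,
    dropout: float = 0.0,
) -> nn.Module:
    """Apply LoRA adapters to the specified modules of a Transformer model.

    Example:
        >>> model = TransformerModel(...)
        >>> model = apply_lora_to_transformer_model(
        ...     model,
        ...     target_modules=["q_proj", "v_proj"],  # attention projections
        ...     rank=8,
        ...     alpha=16,
        ... )
    """
    if target_modules is None:
        # Default: the query and value projections, as recommended above
        target_modules = ["q_proj", "v_proj"]
    # Freeze all pretrained parameters; only the LoRA matrices created below stay trainable
    for param in model.parameters():
        param.requires_grad = False
    lora_count = 0
    # Snapshot the module list so that replacing layers does not interfere with the iteration
    for module_name, module in list(model.named_modules()):
        if not isinstance(module, nn.Linear):
            continue
        # Only wrap Linear layers whose qualified name contains one of the target names
        if not any(target in module_name for target in target_modules):
            continue
        # Walk to the parent module so the target layer can be replaced
        *parent_path, attr_name = module_name.split(".")
        parent_module = model
        for part in parent_path:
            parent_module = getattr(parent_module, part)
        # Replace the layer with its LoRA-augmented version
        lora_linear = LinearWithLoRA(linear_layer=module, rank=rank, alpha=alpha, dropout=dropout)
        setattr(parent_module, attr_name, lora_linear)
        lora_count += 1
    print(f"Applied LoRA to {lora_count} linear layers")
    return model
```
LoRA Weight Extraction / Merging Scripts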
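LoRA Weight Extraction (lora_extractor)
lora_extractor computes the parameter difference between a fine-tuned model and its original base model, then uses singular value decomposition (SVD) to approximate each high-dimensional weight update as the product of two low-rank matrices, producing a lightweight LoRA adapter weight file. (In the script, 1-D tensors and tensors with fewer than 1,000,000 elements are stored as raw differences instead.)
Diff mode stores the full weight difference $\Delta W = W_{\text{target}} - W_{\text{source}}$ directly. LoRA mode compresses it with a rank-$r$ truncated SVD, $\Delta W \approx U_r \Sigma_r V_r^{\top} = BA$, where $B = U_r \Sigma_r^{1/2} \in \mathbb{R}^{a \times r}$ (saved as `lora_up`) and $A = \Sigma_r^{1/2} V_r^{\top} \in \mathbb{R}^{r \times b}$ (saved as `lora_down`); the rank $r$ determines the compression ratio, since $r(a + b)$ values are stored instead of $ab$.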
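The SVD step can be tried in isolation. The sketch below builds a synthetic update $\Delta W$ of known rank, truncates its SVD to rank $r$, and splits $\sqrt{\Sigma}$ between the two factors the same way `_decompose_to_lora` in the script does. The matrix sizes and the true rank of 6 are arbitrary assumptions, and `svd_to_lora` is a standalone helper written for this demo.

```python
import torch

torch.manual_seed(0)
a, b, true_rank = 128, 96, 6
# Synthetic weight difference with rank exactly 6
delta_w = torch.randn(a, true_rank, dtype=torch.float64) @ torch.randn(true_rank, b, dtype=torch.float64)


def svd_to_lora(diff: torch.Tensor, rank: int):
    """Truncated SVD split into (lora_up, lora_down) so that lora_up @ lora_down ≈ diff."""
    U, S, Vh = torch.linalg.svd(diff, full_matrices=False)
    U, S, Vh = U[:, :rank], S[:rank], Vh[:rank, :]
    s_sqrt = S.sqrt()
    lora_up = U * s_sqrt.unsqueeze(0)     # (a, rank): U_r Σ_r^{1/2}
    lora_down = s_sqrt.unsqueeze(1) * Vh  # (rank, b): Σ_r^{1/2} V_r^T
    return lora_up, lora_down


for r in (2, 6, 16):
    up, down = svd_to_lora(delta_w, r)
    rel_err = torch.linalg.norm(delta_w - up @ down) / torch.linalg.norm(delta_w)
    print(f"rank={r:2d}  relative error={rel_err.item():.2e}")
```

Once $r$ reaches the true rank of the update, the reconstruction error drops to floating-point noise; a real fine-tuning difference is only approximately low-rank, so `--rank` trades file size against fidelity.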
| Parameter | Type | Default | Allowed values | Description |
|---|---|---|---|---|
| --source-model | str | required | file/directory path | Source model path (base model) |
| --source-type | str | safetensors | safetensors, pytorch | Source model format |
| --target-model | str | required | file/directory path | Target model path (fine-tuned model) |
| --target-type | str | safetensors | safetensors, pytorch | Target model format |
| --output | str | required | file path | Output path for the LoRA file |
| --output-format | str | safetensors | safetensors, pytorch | Output LoRA format |
| --rank | int | 32 | positive integer | LoRA rank |
| --output-dtype | str | bf16 | float32, fp32, float16, fp16, bfloat16, bf16 | Output weight data type |
| --diff-only | bool (flag) | False | - | Save raw differences only, without LoRA decomposition |
```bash
# Example 1: standard LoRA extraction
python lora_extractor.py \
    --source-model "path/to/base_model.safetensors" \
    --target-model "path/to/finetuned_model.safetensors" \
    --output "path/to/output/lora.safetensors" \
    --rank 32 \
    --output-dtype bf16

# Example 2: diff-only mode
python lora_extractor.py \
    --source-model "path/to/base_model.safetensors" \
    --target-model "path/to/finetuned_model.safetensors" \
    --output "path/to/output/diff.safetensors" \
    --diff-only \
    --output-dtype fp16

# Example 3: PyTorch-format inputs, SafeTensors output
python lora_extractor.py \
    --source-model "path/to/base_model.pth" \
    --source-type pytorch \
    --target-model "path/to/finetuned_model.pth" \
    --target-type pytorch \
    --output "path/to/output/lora.safetensors" \
    --output-format safetensors
```
```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
LoRA Extractor Script
Extract LoRA weights from the difference between two models
"""
import argparse
import os
from typing import Dict, Optional

import torch
from safetensors import safe_open
from safetensors import torch as st
from tqdm import tqdm


def _get_torch_dtype(dtype_str: str) -> torch.dtype:
    """
    Convert string to torch data type
    Args:
        dtype_str: Data type string
    Returns:
        Torch data type
    """
    dtype_mapping = {
        "float32": torch.float32,
        "fp32": torch.float32,
        "float16": torch.float16,
        "fp16": torch.float16,
        "bfloat16": torch.bfloat16,
        "bf16": torch.bfloat16,
    }
    if dtype_str not in dtype_mapping:
        raise ValueError(f"Unsupported data type: {dtype_str}")
    return dtype_mapping[dtype_str]


def parse_args():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(
        description="Extract LoRA weights from the difference between source and target models",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Source model parameters
    parser.add_argument("--source-model", type=str, required=True, help="Path to source model")
    parser.add_argument("--source-type", type=str, choices=["safetensors", "pytorch"], default="safetensors", help="Source model format type")
    # Target model parameters
    parser.add_argument("--target-model", type=str, required=True, help="Path to target model (fine-tuned model)")
    parser.add_argument("--target-type", type=str, choices=["safetensors", "pytorch"], default="safetensors", help="Target model format type")
    # Output parameters
    parser.add_argument("--output", type=str, required=True, help="Path to output LoRA model")
    parser.add_argument("--output-format", type=str, choices=["safetensors", "pytorch"], default="safetensors", help="Output LoRA model format")
    # LoRA related parameters
    parser.add_argument("--rank", type=int, default=32, help="LoRA rank value")
    parser.add_argument("--output-dtype", type=str, choices=["float32", "fp32", "float16", "fp16", "bfloat16", "bf16"], default="bf16", help="Output weight data type")
    parser.add_argument("--diff-only", action="store_true", help="Save all weights as direct diff without LoRA decomposition")
    return parser.parse_args()


def load_model_weights(model_path: str, model_type: str) -> Dict[str, torch.Tensor]:
    """
    Load model weights (using fp32 precision)
    Args:
        model_path: Model file path or directory path
        model_type: Model type ("safetensors" or "pytorch")
    Returns:
        Model weights dictionary (fp32 precision)
    """
    print(f"Loading model: {model_path} (type: {model_type}, precision: fp32)")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model path does not exist: {model_path}")
    weights = {}
    if model_type == "safetensors":
        if os.path.isdir(model_path):
            # If it's a directory, load all .safetensors files in the directory
            safetensors_files = []
            for file in os.listdir(model_path):
                if file.endswith(".safetensors"):
                    safetensors_files.append(os.path.join(model_path, file))
            if not safetensors_files:
                raise ValueError(f"No .safetensors files found in directory: {model_path}")
            print(f"Found {len(safetensors_files)} safetensors files")
            # Load all files and merge weights
            for file_path in sorted(safetensors_files):
                print(f"  Loading file: {os.path.basename(file_path)}")
                with safe_open(file_path, framework="pt", device="cpu") as f:
                    for key in f.keys():
                        if key in weights:
                            print(f"Warning: weight key '{key}' is duplicated in multiple files, will be overwritten")
                        weights[key] = f.get_tensor(key)
        elif os.path.isfile(model_path):
            # If it's a single file
            if model_path.endswith(".safetensors"):
                with safe_open(model_path, framework="pt", device="cpu") as f:
                    for key in f.keys():
                        weights[key] = f.get_tensor(key)
            else:
                raise ValueError(f"safetensors type file should end with .safetensors: {model_path}")
        else:
            raise ValueError(f"Invalid path type: {model_path}")
    elif model_type == "pytorch":
        # Load pytorch format (.pt, .pth)
        if model_path.endswith((".pt", ".pth")):
            checkpoint = torch.load(model_path, map_location="cpu")
            # Handle possible nested structure
            if isinstance(checkpoint, dict):
                if "state_dict" in checkpoint:
                    weights = checkpoint["state_dict"]
                elif "model" in checkpoint:
                    weights = checkpoint["model"]
                else:
                    weights = checkpoint
            else:
                weights = checkpoint
        else:
            raise ValueError(f"pytorch type file should end with .pt or .pth: {model_path}")
    else:
        raise ValueError(f"Unsupported model type: {model_type}")
    # Convert all floating point weights to fp32 to ensure computational precision
    print("Converting weights to fp32 to ensure computational precision...")
    converted_weights = {}
    for key, tensor in weights.items():
        # Skip non-tensor entries (e.g. metadata saved alongside a state_dict)
        if not isinstance(tensor, torch.Tensor):
            continue
        # Only convert floating point tensors, keep integer tensors unchanged
        if tensor.dtype.is_floating_point:
            converted_weights[key] = tensor.to(torch.float32)
        else:
            converted_weights[key] = tensor
    print(f"Successfully loaded model with {len(converted_weights)} weight tensors")
    return converted_weights


def save_lora_weights(lora_weights: Dict[str, torch.Tensor], output_path: str, output_format: str, output_dtype: str = "bf16"):
    """
    Save LoRA weights
    Args:
        lora_weights: LoRA weights dictionary
        output_path: Output path
        output_format: Output format
        output_dtype: Output data type
    """
    print(f"Saving LoRA weights to: {output_path} (format: {output_format}, data type: {output_dtype})")
    # Ensure output directory exists
    output_dir = os.path.dirname(output_path)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)
    # Convert data type
    target_dtype = _get_torch_dtype(output_dtype)
    print(f"Converting LoRA weights to {output_dtype} type...")
    converted_weights = {}
    with tqdm(lora_weights.items(), desc="Converting data type", unit="weights") as pbar:
        for key, tensor in pbar:
            # Only convert floating point tensors, keep integer tensors unchanged
            if tensor.dtype.is_floating_point:
                converted_weights[key] = tensor.to(target_dtype).contiguous()
            else:
                converted_weights[key] = tensor.contiguous()
    if output_format == "safetensors":
        # Save as safetensors format
        if not output_path.endswith(".safetensors"):
            output_path += ".safetensors"
        st.save_file(converted_weights, output_path)
    elif output_format == "pytorch":
        # Save as pytorch format
        if not output_path.endswith((".pt", ".pth")):
            output_path += ".pt"
        torch.save(converted_weights, output_path)
    else:
        raise ValueError(f"Unsupported output format: {output_format}")
    print(f"LoRA weights saved to: {output_path}")


def _compute_weight_diff(source_tensor: torch.Tensor, target_tensor: torch.Tensor, key: str) -> Optional[torch.Tensor]:
    """
    Compute the difference between two weight tensors
    Args:
        source_tensor: Source weight tensor
        target_tensor: Target weight tensor
        key: Weight key name (for logging)
    Returns:
        Difference tensor, returns None if no change
    """
    # Check if tensor shapes match
    if source_tensor.shape != target_tensor.shape:
        return None
    # Check if tensor data types match
    if source_tensor.dtype != target_tensor.dtype:
        target_tensor = target_tensor.to(source_tensor.dtype)
    # Compute difference
    diff = target_tensor - source_tensor
    # Check if there are actual changes
    if torch.allclose(diff, torch.zeros_like(diff), atol=1e-8):
        # No change
        return None
    return diff


def _decompose_to_lora(diff: torch.Tensor, key: str, rank: int) -> Dict[str, torch.Tensor]:
    """
    Decompose weight difference into LoRA format
    Args:
        diff: Weight difference tensor
        key: Original weight key name
        rank: LoRA rank
    Returns:
        LoRA weights dictionary (containing lora_up and lora_down)
    """
    # Ensure it's a 2D tensor
    if len(diff.shape) != 2:
        raise ValueError(f"LoRA decomposition only supports 2D weights, but got {len(diff.shape)}D tensor: {key}")
    a, b = diff.shape
    # Check if rank is reasonable
    max_rank = min(a, b)
    if rank > max_rank:
        rank = max_rank
    # Choose compute device (prefer GPU, fallback to CPU)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    diff_device = diff.to(device)
    # SVD decomposition: diff = U @ diag(S) @ Vh
    U, S, Vh = torch.linalg.svd(diff_device, full_matrices=False)
    # Take the first rank components
    U = U[:, :rank]  # (a, rank)
    S = S[:rank]  # (rank,)
    Vh = Vh[:rank, :]  # (rank, b)
    # Distribute square root of singular values to both matrices
    S_sqrt = S.sqrt()
    lora_up = U * S_sqrt.unsqueeze(0)  # (a, rank) * (1, rank) = (a, rank)
    lora_down = S_sqrt.unsqueeze(1) * Vh  # (rank, 1) * (rank, b) = (rank, b)
    # Move back to CPU and convert to original data type, ensure contiguous
    lora_up = lora_up.cpu().to(diff.dtype).contiguous()
    lora_down = lora_down.cpu().to(diff.dtype).contiguous()
    # Generate LoRA weight key names (strip only the trailing ".weight", not earlier matches)
    base_key = key[: -len(".weight")] if key.endswith(".weight") else key
    lora_up_key = f"diffusion_model.{base_key}.lora_up.weight"
    lora_down_key = f"diffusion_model.{base_key}.lora_down.weight"
    # Return the decomposed weights
    lora_weights = {lora_up_key: lora_up, lora_down_key: lora_down}
    return lora_weights


def extract_lora_from_diff(source_weights: Dict[str, torch.Tensor], target_weights: Dict[str, torch.Tensor], rank: int = 16, diff_only: bool = False) -> Dict[str, torch.Tensor]:
    """
    Extract LoRA weights from model difference
    Args:
        source_weights: Source model weights
        target_weights: Target model weights
        rank: LoRA rank
        diff_only: If True, save all weights as direct diff without LoRA decomposition
    Returns:
        LoRA weights dictionary
    """
    print("Starting LoRA weight extraction...")
    if diff_only:
        print("Mode: Direct diff only (no LoRA decomposition)")
    else:
        print(f"Mode: Smart extraction - rank: {rank}")
    print(f"Source model weight count: {len(source_weights)}")
    print(f"Target model weight count: {len(target_weights)}")
    lora_weights = {}
    processed_count = 0
    diff_count = 0
    lora_count = 0
    similar_count = 0
    skipped_count = 0
    fail_count = 0
    # Find common keys between two models
    common_keys = set(source_weights.keys()) & set(target_weights.keys())
    source_only_keys = set(source_weights.keys()) - set(target_weights.keys())
    target_only_keys = set(target_weights.keys()) - set(source_weights.keys())
    if source_only_keys:
        print(f"Warning: Source model exclusive weight keys ({len(source_only_keys)} keys): {list(source_only_keys)[:5]}...")
    if target_only_keys:
        print(f"Warning: Target model exclusive weight keys ({len(target_only_keys)} keys): {list(target_only_keys)[:5]}...")
    print(f"Common weight keys count: {len(common_keys)}")
    # Process common keys, extract LoRA weights
    common_keys_sorted = sorted(common_keys)
    pbar = tqdm(common_keys_sorted, desc="Extracting LoRA weights", unit="layer")
    for key in pbar:
        source_tensor = source_weights[key]
        target_tensor = target_weights[key]
        # Update progress bar description
        short_key = key.split(".")[-2:] if "." in key else [key]
        pbar.set_postfix_str(f"Processing: {'.'.join(short_key)}")
        # Compute weight difference
        diff = _compute_weight_diff(source_tensor, target_tensor, key)
        if diff is None:
            # No change or shape mismatch
            if source_tensor.shape == target_tensor.shape:
                similar_count += 1
            else:
                skipped_count += 1
            continue
        # Calculate parameter count
        param_count = source_tensor.numel()
        is_1d = len(source_tensor.shape) == 1
        # Decide whether to save diff directly or perform LoRA decomposition
        if diff_only or is_1d or param_count < 1000000:
            # Save diff directly
            lora_key = _generate_lora_diff_key(key)
            if lora_key == "skip":
                skipped_count += 1
                continue
            lora_weights[lora_key] = diff
            diff_count += 1
        else:
            # Perform LoRA decomposition
            if len(diff.shape) == 2 and key.endswith(".weight"):
                try:
                    decomposed_weights = _decompose_to_lora(diff, key, rank)
                    lora_weights.update(decomposed_weights)
                    lora_count += 1
                except Exception as e:
                    print(f"Error: {e}")
                    fail_count += 1
            else:
                print(f"Error: {key} is not a 2D weight tensor")
                fail_count += 1
        processed_count += 1
    # Close progress bar
    pbar.close()
    print("\nExtraction statistics:")
    print(f"  Processed weights: {processed_count}")
    print(f"  Direct diff: {diff_count}")
    print(f"  LoRA decomposition: {lora_count}")
    print(f"  Skipped weights: {skipped_count}")
    print(f"  Similar weights: {similar_count}")
    print(f"  Failed weights: {fail_count}")
    print(f"  Total extracted LoRA weights: {len(lora_weights)}")
    print("LoRA weight extraction completed")
    return lora_weights


def _generate_lora_diff_key(original_key: str) -> str:
    """
    Generate LoRA weight key based on original weight key
    Args:
        original_key: Original weight key name
    Returns:
        LoRA weight key name
    """
    ret_key = "diffusion_model." + original_key
    # Swap only the trailing suffix; str.replace would also rewrite matches earlier in the key
    suffix_map = {".weight": ".diff", ".bias": ".diff_b", ".modulation": ".diff_m"}
    for suffix, diff_suffix in suffix_map.items():
        if original_key.endswith(suffix):
            return ret_key[: -len(suffix)] + diff_suffix
    # If no matching suffix, skip
    return "skip"


if __name__ == "__main__":
    args = parse_args()
    print("=" * 50)
    print("LoRA Extractor Started")
    print("=" * 50)
    print(f"Source model: {args.source_model} ({args.source_type})")
    print(f"Target model: {args.target_model} ({args.target_type})")
    print(f"Output path: {args.output} ({args.output_format})")
    print(f"Output data type: {args.output_dtype}")
    print(f"LoRA parameters: rank={args.rank}")
    print(f"Diff only mode: {args.diff_only}")
    print("=" * 50)
    try:
        # Load source and target models
        source_weights = load_model_weights(args.source_model, args.source_type)
        target_weights = load_model_weights(args.target_model, args.target_type)
        # Extract LoRA weights
        lora_weights = extract_lora_from_diff(source_weights, target_weights, rank=args.rank, diff_only=args.diff_only)
        # Save LoRA weights
        save_lora_weights(lora_weights, args.output, args.output_format, args.output_dtype)
        print("=" * 50)
        print("LoRA extraction completed!")
        print("=" * 50)
    except Exception as e:
        print(f"Error: {e}")
        raise
```
LoRA Weight Merging (lora_merger)
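lora_merger fuses lightweight LoRA adapter weights into the original base model. It computes $\Delta W = \alpha_{\text{merge}} \cdot BA$ from each `lora_up` ($B$) / `lora_down` ($A$) pair (LoRA mode) or $\Delta W = \alpha_{\text{merge}} \cdot \Delta W_{\text{diff}}$ from a stored difference (diff mode), adds it to the base weight $W_0$ as $W = W_0 + \Delta W$, and writes a complete model file that contains the fine-tuning effect. Here $\alpha_{\text{merge}}$ is the `--alpha` merge strength; it is applied directly and is separate from the training-time scaling $\alpha / r$.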
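A minimal sketch of the merge rule on a single matrix, independent of the script: merging with strength 1.0 gives the same output as running the frozen weight plus the un-merged adapter path, and strength 0.5 lands exactly halfway between the base and the fully merged weight. The `merge` helper, sizes and values are hypothetical choices for the demo.

```python
import torch

torch.manual_seed(0)
W0 = torch.randn(32, 24, dtype=torch.float64)        # base weight
lora_up = torch.randn(32, 4, dtype=torch.float64)    # B
lora_down = torch.randn(4, 24, dtype=torch.float64)  # A
x = torch.randn(24, dtype=torch.float64)


def merge(w: torch.Tensor, up: torch.Tensor, down: torch.Tensor, alpha: float) -> torch.Tensor:
    """W = W0 + alpha * (up @ down), the rule the merger applies to each LoRA pair."""
    return w + alpha * (up @ down)


full = merge(W0, lora_up, lora_down, 1.0)
half = merge(W0, lora_up, lora_down, 0.5)

# The merged weight reproduces the un-merged adapter path
print(torch.allclose(full @ x, W0 @ x + lora_up @ (lora_down @ x)))
# alpha = 0.5 is the midpoint between the base and the fully merged weight
print(torch.allclose(half, (W0 + full) / 2))
```

Because the merge is linear in $\alpha_{\text{merge}}$, intermediate strengths interpolate between the base model and the full fine-tune for each merged matrix.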
| Parameter | Type | Default | Allowed values | Description |
|---|---|---|---|---|
| --source-model | str | required | file/directory path | Source model path (base model) |
| --source-type | str | safetensors | safetensors, pytorch | Source model format |
| --lora-model | str | required | file/directory path | LoRA weights path |
| --lora-type | str | safetensors | safetensors, pytorch | LoRA weights format |
| --output | str | required | file path | Output path for the merged model |
| --output-format | str | safetensors | safetensors, pytorch | Output model format |
| --alpha | float | 1.0 | typically 0.0 to 2.0 (not enforced) | LoRA merge strength |
| --output-dtype | str | bf16 | float32, fp32, float16, fp16, bfloat16, bf16 | Output weight data type |
```bash
# Example 1: standard LoRA merge (full strength)
python lora_merger.py \
    --source-model "path/to/base_model.safetensors" \
    --lora-model "path/to/lora_weights.safetensors" \
    --output "path/to/output/merged_model.safetensors" \
    --alpha 1.0 \
    --output-dtype bf16

# Example 2: partial merge (50% strength)
python lora_merger.py \
    --source-model "path/to/base_model.safetensors" \
    --lora-model "path/to/lora_weights.safetensors" \
    --output "path/to/output/merged_model_half.safetensors" \
    --alpha 0.5 \
    --output-dtype bf16

# Example 3: PyTorch-format inputs, SafeTensors output
python lora_merger.py \
    --source-model "path/to/base_model.pth" \
    --source-type pytorch \
    --lora-model "path/to/lora_weights.pth" \
    --lora-type pytorch \
    --output "path/to/output/merged_model.safetensors" \
    --output-format safetensors \
    --alpha 1.0

# Example 4: merge a diff-mode file
python lora_merger.py \
    --source-model "path/to/base_model.safetensors" \
    --lora-model "path/to/diff_weights.safetensors" \
    --output "path/to/output/merged_model.safetensors" \
    --alpha 1.0 \
    --output-dtype fp32
```
```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
LoRA Merger Script
Merge a source model with LoRA weights to create a new model
"""
import argparse
import os
from typing import Dict, Optional

import torch
from safetensors import safe_open
from safetensors import torch as st
from tqdm import tqdm


def _get_torch_dtype(dtype_str: str) -> torch.dtype:
    """
    Convert string to torch data type
    Args:
        dtype_str: Data type string
    Returns:
        Torch data type
    """
    dtype_mapping = {
        "float32": torch.float32,
        "fp32": torch.float32,
        "float16": torch.float16,
        "fp16": torch.float16,
        "bfloat16": torch.bfloat16,
        "bf16": torch.bfloat16,
    }
    if dtype_str not in dtype_mapping:
        raise ValueError(f"Unsupported data type: {dtype_str}")
    return dtype_mapping[dtype_str]


def parse_args():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(
        description="Merge a source model with LoRA weights to create a new model",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Source model parameters
    parser.add_argument("--source-model", type=str, required=True, help="Path to source model")
    parser.add_argument("--source-type", type=str, choices=["safetensors", "pytorch"], default="safetensors", help="Source model format type")
    # LoRA parameters
    parser.add_argument("--lora-model", type=str, required=True, help="Path to LoRA weights")
    parser.add_argument("--lora-type", type=str, choices=["safetensors", "pytorch"], default="safetensors", help="LoRA weights format type")
    # Output parameters
    parser.add_argument("--output", type=str, required=True, help="Path to output merged model")
    parser.add_argument("--output-format", type=str, choices=["safetensors", "pytorch"], default="safetensors", help="Output model format")
    # Merge parameters
    parser.add_argument("--alpha", type=float, default=1.0, help="LoRA merge strength (alpha value)")
    parser.add_argument("--output-dtype", type=str, choices=["float32", "fp32", "float16", "fp16", "bfloat16", "bf16"], default="bf16", help="Output weight data type")
    return parser.parse_args()


def load_model_weights(model_path: str, model_type: str) -> Dict[str, torch.Tensor]:
    """
    Load model weights (using fp32 precision)
    Args:
        model_path: Model file path or directory path
        model_type: Model type ("safetensors" or "pytorch")
    Returns:
        Model weights dictionary (fp32 precision)
    """
    print(f"Loading model: {model_path} (type: {model_type}, precision: fp32)")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model path does not exist: {model_path}")
    weights = {}
    if model_type == "safetensors":
        if os.path.isdir(model_path):
            # If it's a directory, load all .safetensors files in the directory
            safetensors_files = []
            for file in os.listdir(model_path):
                if file.endswith(".safetensors"):
                    safetensors_files.append(os.path.join(model_path, file))
            if not safetensors_files:
                raise ValueError(f"No .safetensors files found in directory: {model_path}")
            print(f"Found {len(safetensors_files)} safetensors files")
            # Load all files and merge weights
            for file_path in sorted(safetensors_files):
                print(f"  Loading file: {os.path.basename(file_path)}")
                with safe_open(file_path, framework="pt", device="cpu") as f:
                    for key in f.keys():
                        if key in weights:
                            print(f"Warning: weight key '{key}' is duplicated in multiple files, will be overwritten")
                        weights[key] = f.get_tensor(key)
        elif os.path.isfile(model_path):
            # If it's a single file
            if model_path.endswith(".safetensors"):
                with safe_open(model_path, framework="pt", device="cpu") as f:
                    for key in f.keys():
                        weights[key] = f.get_tensor(key)
            else:
                raise ValueError(f"safetensors type file should end with .safetensors: {model_path}")
        else:
            raise ValueError(f"Invalid path type: {model_path}")
    elif model_type == "pytorch":
        # Load pytorch format (.pt, .pth)
        if model_path.endswith((".pt", ".pth")):
            checkpoint = torch.load(model_path, map_location="cpu")
            # Handle possible nested structure
            if isinstance(checkpoint, dict):
                if "state_dict" in checkpoint:
                    weights = checkpoint["state_dict"]
                elif "model" in checkpoint:
                    weights = checkpoint["model"]
                else:
                    weights = checkpoint
            else:
                weights = checkpoint
        else:
            raise ValueError(f"pytorch type file should end with .pt or .pth: {model_path}")
    else:
        raise ValueError(f"Unsupported model type: {model_type}")
    # Convert all floating point weights to fp32 to ensure computational precision
    print("Converting weights to fp32 to ensure computational precision...")
    converted_weights = {}
    for key, tensor in weights.items():
        # Skip non-tensor entries (e.g. metadata saved alongside a state_dict)
        if not isinstance(tensor, torch.Tensor):
            continue
        # Only convert floating point tensors, keep integer tensors unchanged
        if tensor.dtype.is_floating_point:
            converted_weights[key] = tensor.to(torch.float32)
        else:
            converted_weights[key] = tensor
    print(f"Successfully loaded model with {len(converted_weights)} weight tensors")
    return converted_weights


def save_model_weights(model_weights: Dict[str, torch.Tensor], output_path: str, output_format: str, output_dtype: str = "bf16"):
    """
    Save model weights
    Args:
        model_weights: Model weights dictionary
        output_path: Output path
        output_format: Output format
        output_dtype: Output data type
    """
    print(f"Saving merged model to: {output_path} (format: {output_format}, data type: {output_dtype})")
    # Ensure output directory exists
    output_dir = os.path.dirname(output_path)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)
    # Convert data type
    target_dtype = _get_torch_dtype(output_dtype)
    print(f"Converting model weights to {output_dtype} type...")
    converted_weights = {}
    with tqdm(model_weights.items(), desc="Converting data type", unit="weights") as pbar:
        for key, tensor in pbar:
            # Only convert floating point tensors, keep integer tensors unchanged
            if tensor.dtype.is_floating_point:
                converted_weights[key] = tensor.to(target_dtype).contiguous()
            else:
                converted_weights[key] = tensor.contiguous()
    if output_format == "safetensors":
        # Save as safetensors format
        if not output_path.endswith(".safetensors"):
            output_path += ".safetensors"
        st.save_file(converted_weights, output_path)
    elif output_format == "pytorch":
        # Save as pytorch format
        if not output_path.endswith((".pt", ".pth")):
            output_path += ".pt"
        torch.save(converted_weights, output_path)
    else:
        raise ValueError(f"Unsupported output format: {output_format}")
    print(f"Merged model saved to: {output_path}")


def merge_lora_weights(source_weights: Dict[str, torch.Tensor], lora_weights: Dict[str, torch.Tensor], alpha: float = 1.0) -> Dict[str, torch.Tensor]:
    """
    Merge source model with LoRA weights
    Args:
        source_weights: Source model weights
        lora_weights: LoRA weights
        alpha: LoRA merge strength
    Returns:
        Merged model weights
    """
    print("Starting LoRA merge...")
    print(f"Merge parameters - alpha: {alpha}")
    print(f"Source model weight count: {len(source_weights)}")
    print(f"LoRA weight count: {len(lora_weights)}")
    merged_weights = source_weights.copy()
    processed_count = 0
    lora_merged_count = 0
    diff_merged_count = 0
    skipped_source_count = 0
    skipped_lora_count = 0
    skipped_source_keys = []
    skipped_lora_keys = []
    # Group LoRA weights by base key. Keys with other suffixes (e.g. per-layer ".alpha"
    # scales written by some training tools) are ignored, so only this extractor's format is fully supported.
    lora_pairs = {}
    diff_weights = {}
    for lora_key, lora_tensor in lora_weights.items():
        if lora_key.endswith(".lora_up.weight"):
            base_key = lora_key[: -len(".lora_up.weight")]
            lora_pairs.setdefault(base_key, {})["up"] = lora_tensor
        elif lora_key.endswith(".lora_down.weight"):
            base_key = lora_key[: -len(".lora_down.weight")]
            lora_pairs.setdefault(base_key, {})["down"] = lora_tensor
        elif lora_key.endswith((".diff", ".diff_b", ".diff_m")):
            diff_weights[lora_key] = lora_tensor
    print(f"Found {len(lora_pairs)} LoRA pairs and {len(diff_weights)} diff weights")
    # Process with progress bar
    all_items = list(lora_pairs.items()) + list(diff_weights.items())
    pbar = tqdm(all_items, desc="Merging LoRA weights", unit="weight")
    for item in pbar:
        if isinstance(item[1], dict):  # LoRA pair
            base_key, lora_pair = item
            if "up" in lora_pair and "down" in lora_pair:
                # Find corresponding source weight
                source_key = _find_source_key(base_key, source_weights)
                if source_key:
                    expected_shape = (lora_pair["up"].shape[0], lora_pair["down"].shape[1])
                    if source_weights[source_key].shape != expected_shape:
                        skipped_source_count += 1
                        skipped_source_keys.append(source_key)
                        continue
                    lora_up = lora_pair["up"]
                    lora_down = lora_pair["down"]
                    # Compute LoRA delta: alpha * (lora_up @ lora_down)
                    lora_delta = alpha * (lora_up @ lora_down)
                    # Apply to source weight
                    merged_weights[source_key] = source_weights[source_key] + lora_delta
                    lora_merged_count += 1
                    pbar.set_postfix_str(f"LoRA: {source_key.split('.')[-1]}")
                else:
                    skipped_source_count += 1
                    skipped_source_keys.append(base_key)
            else:
                print(f"Warning: Incomplete LoRA pair for: {base_key}")
                skipped_lora_count += 1
                skipped_lora_keys.append(base_key)
        else:  # Diff weight
            diff_key, diff_tensor = item
            # Find corresponding source weight
            source_key = _find_source_key_from_diff(diff_key, source_weights)
            if source_key:
                if source_weights[source_key].shape != diff_tensor.shape:
                    skipped_source_count += 1
                    skipped_source_keys.append(source_key)
                    continue
                # Apply diff: source + alpha * diff
                merged_weights[source_key] = source_weights[source_key] + alpha * diff_tensor
                diff_merged_count += 1
                pbar.set_postfix_str(f"Diff: {source_key.split('.')[-1]}")
            else:
                skipped_lora_count += 1
                skipped_lora_keys.append(diff_key)
        processed_count += 1
    pbar.close()
    print("\nMerge statistics:")
    print(f"  Processed weights: {processed_count}")
    print(f"  LoRA merged: {lora_merged_count}")
    print(f"  Diff merged: {diff_merged_count}")
    print(f"  Skipped source weights: {skipped_source_count}")
    if skipped_source_count > 0:
        print("  Skipped source keys:")
        for key in skipped_source_keys:
            print(f"    {key}")
    print(f"  Skipped LoRA weights: {skipped_lora_count}")
    if skipped_lora_count > 0:
        print("  Skipped LoRA keys:")
        for key in skipped_lora_keys:
            print(f"    {key}")
    print(f"  Total merged model weights: {len(merged_weights)}")
    print("LoRA merge completed")
    return merged_weights


def _find_source_key(lora_base_key: str, source_weights: Dict[str, torch.Tensor]) -> Optional[str]:
    """
    Find corresponding source weight key for LoRA base key
    Args:
        lora_base_key: LoRA base key (e.g., "diffusion_model.input_blocks.0.0")
        source_weights: Source model weights
    Returns:
        Corresponding source key or None
    """
    prefix = "diffusion_model."
    # Remove the diffusion_model prefix if present
    stripped_key = lora_base_key[len(prefix):] if lora_base_key.startswith(prefix) else lora_base_key
    # LoRA pairs normally correspond to "<base>.weight"
    source_key = stripped_key + ".weight"
    if source_key in source_weights:
        return source_key
    # Try without adding .weight (in case it's already included)
    if stripped_key in source_weights:
        return stripped_key
    return None


def _find_source_key_from_diff(diff_key: str, source_weights: Dict[str, torch.Tensor]) -> Optional[str]:
    """
    Find corresponding source weight key for diff key
    Args:
        diff_key: Diff key (e.g., "diffusion_model.input_blocks.0.diff")
        source_weights: Source model weights
    Returns:
        Corresponding source key or None
    """
    prefix = "diffusion_model."
    # Remove diffusion_model prefix if present
    base_key = diff_key[len(prefix):] if diff_key.startswith(prefix) else diff_key
    # Map diff suffixes back to the original parameter suffixes
    if base_key.endswith(".diff"):
        source_key = base_key[:-5] + ".weight"  # Replace ".diff" with ".weight"
    elif base_key.endswith(".diff_b"):
        source_key = base_key[:-7] + ".bias"  # Replace ".diff_b" with ".bias"
    elif base_key.endswith(".diff_m"):
        source_key = base_key[:-7] + ".modulation"  # Replace ".diff_m" with ".modulation"
    else:
        source_key = base_key
    if source_key in source_weights:
        return source_key
    return None


if __name__ == "__main__":
    args = parse_args()
    print("=" * 50)
    print("LoRA Merger Started")
    print("=" * 50)
    print(f"Source model: {args.source_model} ({args.source_type})")
    print(f"LoRA weights: {args.lora_model} ({args.lora_type})")
    print(f"Output path: {args.output} ({args.output_format})")
    print(f"Output data type: {args.output_dtype}")
    print(f"Merge parameters: alpha={args.alpha}")
    print("=" * 50)
    try:
        # Load source model and LoRA weights
        source_weights = load_model_weights(args.source_model, args.source_type)
        lora_weights = load_model_weights(args.lora_model, args.lora_type)
        # Merge LoRA weights with source model
        merged_weights = merge_lora_weights(source_weights, lora_weights, alpha=args.alpha)
        # Save merged model
        save_model_weights(merged_weights, args.output, args.output_format, args.output_dtype)
        print("=" * 50)
        print("LoRA merge completed!")
        print("=" * 50)
    except Exception as e:
        print(f"Error: {e}")
        raise
```
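The extractor and merger only interoperate because their key names mirror each other: a base-model key `X.weight` becomes `diffusion_model.X.lora_up.weight` / `diffusion_model.X.lora_down.weight` in LoRA mode, or `diffusion_model.X.diff` in diff mode, with `.diff_b` and `.diff_m` standing for `.bias` and `.modulation`. The sketch below restates that convention as three standalone, hypothetical helpers (`to_diff_key`, `from_diff_key`, `to_lora_keys`, not part of either script) so the round trip can be checked without loading a model.

```python
from typing import Optional, Tuple

PREFIX = "diffusion_model."
DIFF_SUFFIXES = {".weight": ".diff", ".bias": ".diff_b", ".modulation": ".diff_m"}


def to_diff_key(source_key: str) -> Optional[str]:
    """Source key -> diff key, as the extractor names raw differences (None = skipped)."""
    for src_suffix, diff_suffix in DIFF_SUFFIXES.items():
        if source_key.endswith(src_suffix):
            return PREFIX + source_key[: -len(src_suffix)] + diff_suffix
    return None


def from_diff_key(diff_key: str) -> str:
    """Diff key -> source key, as the merger resolves raw differences."""
    base = diff_key[len(PREFIX):] if diff_key.startswith(PREFIX) else diff_key
    for src_suffix, diff_suffix in DIFF_SUFFIXES.items():
        if base.endswith(diff_suffix):
            return base[: -len(diff_suffix)] + src_suffix
    return base


def to_lora_keys(source_key: str) -> Tuple[str, str]:
    """Source ".weight" key -> (lora_up key, lora_down key) for a decomposed layer."""
    base = source_key[: -len(".weight")]
    return PREFIX + base + ".lora_up.weight", PREFIX + base + ".lora_down.weight"


for key in ["blocks.0.attn.q.weight", "blocks.0.norm.bias", "blocks.0.modulation"]:
    print(key, "->", to_diff_key(key), "->", from_diff_key(to_diff_key(key)))
print(to_lora_keys("blocks.0.attn.q.weight"))
```

A LoRA file produced by another tool with a different prefix or suffix scheme would not round-trip this way, and the merger would report its keys as skipped.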