
Kared

Parameter-Efficient Fine-Tuning for Large Models: LoRA (Low-Rank Adaptation)

In recent years, large language models (LLMs) such as GPT and LLaMA, pretrained on massive corpora, have shown strong language understanding and generation abilities. However, adapting these general-purpose models to a specific task or domain usually requires fine-tuning. Full-parameter fine-tuning works, but it comes with serious challenges: enormous compute consumption, high storage requirements, long training times, and a steep hardware barrier.

LoRA (Low-Rank Adaptation) is a parameter-efficient fine-tuning technique. Its core idea is to model the change in the weights with a low-rank matrix decomposition instead of updating the original parameters directly, which drastically reduces the resources needed for training while preserving model performance.

LoRA: Low-Rank Adaptation of Large Language Models

I. Mathematical Foundations

[Figure: LoRA schematic diagram]

For a pretrained weight matrix $\mathbf{W}_0 \in \mathbb{R}^{d \times k}$, full fine-tuning learns an update matrix $\Delta\mathbf{W} \in \mathbb{R}^{d \times k}$, so the fine-tuned weights become:

$$\mathbf{W} = \mathbf{W}_0 + \Delta\mathbf{W}$$

The low intrinsic dimension hypothesis: research has shown that although large neural networks are over-parameterized, with billions or even trillions of parameters, the essential change needed to adapt to a new task (the update matrix $\Delta\mathbf{W}$) actually lives in a low-dimensional subspace. In other words, there exists a low-dimensional parameter update whose effect rivals updating all parameters.

LoRA builds directly on this hypothesis: it introduces two low-rank matrices $\mathbf{A}$ and $\mathbf{B}$ and uses their product $\mathbf{B}\mathbf{A}$ to approximate the update $\Delta\mathbf{W}$ required by the new task. Only these two small matrices are trained, not the huge original weights $\mathbf{W}_0$, which makes adapting a large model to downstream tasks efficient. Compared with full fine-tuning, this cuts the number of trainable parameters to roughly 0.1%-2.0% of the original model, making it feasible to fine-tune large models on a single consumer GPU.

$$\mathbf{W} = \mathbf{W}_0 + \alpha\,\mathbf{B}\mathbf{A}$$

Here $\mathbf{B} \in \mathbb{R}^{d \times r}$ and $\mathbf{A} \in \mathbb{R}^{r \times k}$ are the trainable low-rank matrices, the rank satisfies $r \ll \min(d, k)$ and is typically chosen between 4 and 128, and $\alpha$ is a scaling factor that controls the magnitude of the LoRA weights.
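For intuition, here is a quick back-of-the-envelope comparison of trainable parameter counts with and without the low-rank factorization (a minimal sketch; the 4096×4096 shape and r = 8 are illustrative values, not taken from the paper):

# Hypothetical example: a single 4096 x 4096 projection matrix, LoRA rank r = 8
d, k, r = 4096, 4096, 8

full_update_params = d * k      # training Delta W directly: 16,777,216 parameters
lora_params = d * r + r * k     # training B (d x r) and A (r x k): 65,536 parameters

print(full_update_params, lora_params)
print(f"LoRA trains {lora_params / full_update_params:.2%} of the full update")  # ~0.39%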

II. Forward Pass and Initialization

During the forward pass, the pretrained weight matrix $\mathbf{W}$ stays completely frozen (it is never updated), while an adapter path built from the low-rank matrices $\mathbf{B}$ and $\mathbf{A}$ (i.e. $\alpha\,\mathbf{B}\mathbf{A}\mathbf{x}$) adjusts the model's output for the specific task. In the original fully connected layer, the forward computation is:

$$\mathbf{h} = \mathbf{W}\mathbf{x} + \mathbf{b}$$

With the LoRA adapter attached, the forward computation becomes:

$$\mathbf{h} = \mathbf{W}\mathbf{x} + \alpha\,\mathbf{B}\mathbf{A}\mathbf{x} + \mathbf{b}$$

To keep training stable and converge efficiently, LoRA uses a specific initialization scheme for the low-rank matrices:

  • Matrix $\mathbf{A}$ is initialized by sampling from a zero-mean Gaussian distribution.

  • Matrix $\mathbf{B}$ is initialized to all zeros.

With this initialization, the low-rank adapter term $\mathbf{B}\mathbf{A}\mathbf{x}$ is exactly zero when training starts. The model's initial output is therefore determined entirely by the pretrained weights $\mathbf{W}$, and the adapter starts out "silent". The main advantages are:

  1. Training stability: no random-initialization noise disturbs the strong representations the pretrained model has already learned, so fine-tuning starts from a stable, reliable point.

  2. Efficient convergence: gradients are well-behaved early in training, letting the model transition smoothly from pretrained knowledge to task-specific knowledge, which helps later convergence, as the short check after this list illustrates.
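A tiny standalone check (dimensions here are illustrative only) makes the effect of this initialization concrete: with B set to zeros, the adapter branch contributes exactly nothing, so the step-0 output comes entirely from the frozen weights.

import torch

d, k, r = 16, 32, 4
x = torch.randn(5, k)

A = torch.randn(r, k) / r ** 0.5   # zero-mean Gaussian init for A
B = torch.zeros(d, r)              # zero init for B

delta = x @ (B @ A).T              # adapter path (scaling factor omitted for brevity)
print(torch.allclose(delta, torch.zeros(5, d)))  # True: output is unchanged at step 0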

III. LoRA Code Implementation

A concrete implementation helps build a deeper understanding of LoRA. Below is a complete LoRA layer implementation and its application to a Transformer model:

1. Core LoRA Layer Implementation

Initialization uses a Kaiming-style scaled Gaussian for A and zeros for B, which keeps training stable without changing the initial output; alpha/rank controls the adaptation strength, and an optional Dropout provides regularization against overfitting.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class LoRALayer(nn.Module):
    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        rank: int = 8,
        alpha: float = 16.0,
        dropout: float = 0.0
    ):
        super().__init__()

        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank  # scaling factor controls adaptation strength

        # A matrix: Gaussian init scaled by 1/sqrt(rank) for stable gradient propagation
        self.lora_A = nn.Parameter(torch.randn(rank, input_dim) / math.sqrt(rank))
        # B matrix: zero init so that delta W = 0 at the start of training, preserving the original output
        self.lora_B = nn.Parameter(torch.zeros(output_dim, rank))

        self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass computing output = (alpha/rank) * B @ A @ x"""
        if self.dropout is not None:
            x = self.dropout(x)

        # Low-rank adapter path
        lora_output = F.linear(x, self.lora_B @ self.lora_A)
        return self.scaling * lora_output
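The apply_lora_to_transformer_model helper shown in the next subsection wraps existing nn.Linear layers in a LinearWithLoRA module that the original listing does not include; the following is a minimal sketch, written only to be consistent with how that helper calls it (a frozen base linear plus the LoRALayer path defined above):

class LinearWithLoRA(nn.Module):
    """Wrap a frozen nn.Linear with a parallel LoRA adapter path."""

    def __init__(self, linear_layer: nn.Linear, rank: int = 8, alpha: float = 16.0, dropout: float = 0.0):
        super().__init__()
        self.linear = linear_layer
        # Freeze the pretrained weight (and bias) so only the LoRA matrices are trained
        self.linear.weight.requires_grad_(False)
        if self.linear.bias is not None:
            self.linear.bias.requires_grad_(False)
        self.lora = LoRALayer(
            input_dim=linear_layer.in_features,
            output_dim=linear_layer.out_features,
            rank=rank,
            alpha=alpha,
            dropout=dropout,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # h = W x + b + (alpha / r) * B A x
        return self.linear(x) + self.lora(x)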

2. Applying LoRA to a Transformer Model

Applying LoRA to the $\mathbf{Q}$ and $\mathbf{V}$ projection layers of the Transformer attention mechanism usually works best; the rank $r$ is typically 4-128: pick a small $r$ (e.g. 4/8/16) when resources are tight, and a larger $r$ (e.g. 64/128) for complex tasks.

Within Transformer attention, preferring the $\mathbf{Q}$ (query) and $\mathbf{V}$ (value) projections for LoRA is an empirical strategy that balances parameter efficiency against task performance. The $\mathbf{Q}$ matrix decides where the model "attends", and the $\mathbf{V}$ matrix supplies "what information to extract", so they affect task adaptation most directly; the $\mathbf{K}$ (key) matrix relates more to intrinsic properties of the input sequence and stays relatively stable, so leaving $\mathbf{K}$ untouched saves parameters and lowers the risk of overfitting.

from typing import List, Optional


def apply_lora_to_transformer_model(
    model: nn.Module,
    target_modules: Optional[List[str]] = None,
    rank: int = 8,
    alpha: float = 16.0,
    dropout: float = 0.0
) -> nn.Module:
    """Attach LoRA adapters to the specified modules of a Transformer model.

    Example:
        >>> model = TransformerModel(...)
        >>> model = apply_lora_to_transformer_model(
        ...     model,
        ...     target_modules=["q_proj", "v_proj"],  # target the attention projections
        ...     rank=8,
        ...     alpha=16
        ... )
    """
    if target_modules is None:
        # Default to the key attention projection layers of a Transformer
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]

    lora_count = 0

    for module_name, module in model.named_modules():
        if not isinstance(module, nn.Linear):
            continue

        # Filter: only apply LoRA to linear layers whose name matches a target
        if not any(target in module_name for target in target_modules):
            continue

        # Locate the parent module and replace the target layer
        *parent_path, attr_name = module_name.split('.')
        parent_module = model
        for part in parent_path:
            parent_module = getattr(parent_module, part)

        # Create the LoRA-augmented linear layer
        lora_linear = LinearWithLoRA(
            linear_layer=module,
            rank=rank,
            alpha=alpha,
            dropout=dropout
        )
        setattr(parent_module, attr_name, lora_linear)
        lora_count += 1

    print(f"Applied LoRA to {lora_count} linear layers")
    return model
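After wrapping the target layers, a common final step is to freeze every remaining parameter so that only the LoRA matrices receive gradients. This is a hedged sketch; TransformerModel is a placeholder for whichever model you actually use:

def mark_only_lora_trainable(model: nn.Module) -> None:
    """Freeze all parameters except the LoRA A/B matrices."""
    for name, param in model.named_parameters():
        param.requires_grad = ("lora_A" in name) or ("lora_B" in name)


# Hypothetical usage:
# model = TransformerModel(...)
# model = apply_lora_to_transformer_model(model, target_modules=["q_proj", "v_proj"])
# mark_only_lora_trainable(model)
# trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
# total = sum(p.numel() for p in model.parameters())
# print(f"Trainable parameters: {trainable} / {total} ({100 * trainable / total:.2f}%)")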

LoRA Weight Extraction / Merging Scripts

LoRA Weight Extraction (lora_extractor)

lora_extractor computes the parameter difference between a fine-tuned model and the original base model, then uses singular value decomposition (SVD) to approximate each high-dimensional weight-update matrix as a product of two low-rank matrices, producing a lightweight LoRA adapter weight file.

Diff mode saves the full weight difference $\Delta\mathbf{W} = \mathbf{W}_{\text{target}} - \mathbf{W}_{\text{source}}$ directly; LoRA mode compresses it through the low-rank decomposition $\Delta\mathbf{W} \approx \mathbf{B}\mathbf{A}$, where $\mathbf{B} \in \mathbb{R}^{m \times r}$, $\mathbf{A} \in \mathbb{R}^{r \times n}$, and the rank $r \ll \min(m, n)$ determines the compression ratio.
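The core of the extraction step fits in a few lines: take the weight difference, run a truncated SVD, and fold the square roots of the singular values into both factors. The snippet below is a minimal sketch with made-up shapes that mirrors what the _decompose_to_lora function in the full script does:

import torch

# Hypothetical weight difference between a fine-tuned layer and its base layer
d_out, d_in, rank = 128, 256, 16
delta_w = torch.randn(d_out, d_in)

U, S, Vh = torch.linalg.svd(delta_w, full_matrices=False)
S_sqrt = S[:rank].sqrt()

lora_up = U[:, :rank] * S_sqrt                   # B: (d_out, rank)
lora_down = S_sqrt.unsqueeze(1) * Vh[:rank, :]   # A: (rank, d_in)

approx = lora_up @ lora_down                       # rank-r approximation of delta_w
print(torch.linalg.matrix_rank(approx))            # at most `rank`
print((delta_w - approx).norm() / delta_w.norm())  # relative approximation error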

| Parameter | Type | Default | Options | Description |
| --- | --- | --- | --- | --- |
| --source-model | str | required | file/directory path | Source model path (base model) |
| --source-type | str | safetensors | safetensors, pytorch | Source model format |
| --target-model | str | required | file/directory path | Target model path (fine-tuned model) |
| --target-type | str | safetensors | safetensors, pytorch | Target model format |
| --output | str | required | file path | LoRA output path |
| --output-format | str | safetensors | safetensors, pytorch | Output LoRA model format |
| --rank | int | 32 | positive integer | LoRA rank value |
| --output-dtype | str | bf16 | float32, fp32, float16, fp16, bfloat16, bf16 | Output weight data type |
| --diff-only | bool | False | - | Save the raw diff only, without LoRA decomposition |

# Example 1: standard LoRA extraction
python lora_extractor.py \
  --source-model "path/to/base_model.safetensors" \
  --target-model "path/to/finetuned_model.safetensors" \
  --output "path/to/output/lora.safetensors" \
  --rank 32 \
  --output-dtype bf16

# Example 2: diff-only mode
python lora_extractor.py \
  --source-model "path/to/base_model.safetensors" \
  --target-model "path/to/finetuned_model.safetensors" \
  --output "path/to/output/diff.safetensors" \
  --diff-only \
  --output-dtype fp16

# Example 3: PyTorch format conversion
python lora_extractor.py \
  --source-model "path/to/base_model.pth" \
  --source-type pytorch \
  --target-model "path/to/finetuned_model.pth" \
  --target-type pytorch \
  --output "path/to/output/lora.safetensors" \
  --output-format safetensors
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
LoRA Extractor Script
Extract LoRA weights from the difference between two models
"""

import argparse
import os
from typing import Dict, Optional

import torch
from safetensors import safe_open
from safetensors import torch as st
from tqdm import tqdm


def _get_torch_dtype(dtype_str: str) -> torch.dtype:
    """
    Convert string to torch data type

    Args:
        dtype_str: Data type string

    Returns:
        Torch data type
    """
    dtype_mapping = {
        "float32": torch.float32,
        "fp32": torch.float32,
        "float16": torch.float16,
        "fp16": torch.float16,
        "bfloat16": torch.bfloat16,
        "bf16": torch.bfloat16,
    }

    if dtype_str not in dtype_mapping:
        raise ValueError(f"Unsupported data type: {dtype_str}")

    return dtype_mapping[dtype_str]


def parse_args():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(description="Extract LoRA weights from the difference between source and target models", formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Source model parameters
    parser.add_argument("--source-model", type=str, required=True, help="Path to source model")
    parser.add_argument("--source-type", type=str, choices=["safetensors", "pytorch"], default="safetensors", help="Source model format type")

    # Target model parameters
    parser.add_argument("--target-model", type=str, required=True, help="Path to target model (fine-tuned model)")
    parser.add_argument("--target-type", type=str, choices=["safetensors", "pytorch"], default="safetensors", help="Target model format type")

    # Output parameters
    parser.add_argument("--output", type=str, required=True, help="Path to output LoRA model")
    parser.add_argument("--output-format", type=str, choices=["safetensors", "pytorch"], default="safetensors", help="Output LoRA model format")

    # LoRA related parameters
    parser.add_argument("--rank", type=int, default=32, help="LoRA rank value")

    parser.add_argument("--output-dtype", type=str, choices=["float32", "fp32", "float16", "fp16", "bfloat16", "bf16"], default="bf16", help="Output weight data type")
    parser.add_argument("--diff-only", action="store_true", help="Save all weights as direct diff without LoRA decomposition")

    return parser.parse_args()


def load_model_weights(model_path: str, model_type: str) -> Dict[str, torch.Tensor]:
    """
    Load model weights (using fp32 precision)

    Args:
        model_path: Model file path or directory path
        model_type: Model type ("safetensors" or "pytorch")

    Returns:
        Model weights dictionary (fp32 precision)
    """
    print(f"Loading model: {model_path} (type: {model_type}, precision: fp32)")

    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model path does not exist: {model_path}")

    weights = {}

    if model_type == "safetensors":
        if os.path.isdir(model_path):
            # If it's a directory, load all .safetensors files in the directory
            safetensors_files = []
            for file in os.listdir(model_path):
                if file.endswith(".safetensors"):
                    safetensors_files.append(os.path.join(model_path, file))

            if not safetensors_files:
                raise ValueError(f"No .safetensors files found in directory: {model_path}")

            print(f"Found {len(safetensors_files)} safetensors files")

            # Load all files and merge weights
            for file_path in sorted(safetensors_files):
                print(f"  Loading file: {os.path.basename(file_path)}")
                with safe_open(file_path, framework="pt", device="cpu") as f:
                    for key in f.keys():
                        if key in weights:
                            print(f"Warning: weight key '{key}' is duplicated in multiple files, will be overwritten")
                        weights[key] = f.get_tensor(key)

        elif os.path.isfile(model_path):
            # If it's a single file
            if model_path.endswith(".safetensors"):
                with safe_open(model_path, framework="pt", device="cpu") as f:
                    for key in f.keys():
                        weights[key] = f.get_tensor(key)
            else:
                raise ValueError(f"safetensors type file should end with .safetensors: {model_path}")
        else:
            raise ValueError(f"Invalid path type: {model_path}")

    elif model_type == "pytorch":
        # Load pytorch format (.pt, .pth)
        if model_path.endswith((".pt", ".pth")):
            checkpoint = torch.load(model_path, map_location="cpu")

            # Handle possible nested structure
            if isinstance(checkpoint, dict):
                if "state_dict" in checkpoint:
                    weights = checkpoint["state_dict"]
                elif "model" in checkpoint:
                    weights = checkpoint["model"]
                else:
                    weights = checkpoint
            else:
                weights = checkpoint
        else:
            raise ValueError(f"pytorch type file should end with .pt or .pth: {model_path}")
    else:
        raise ValueError(f"Unsupported model type: {model_type}")

    # Convert all floating point weights to fp32 to ensure computational precision
    print("Converting weights to fp32 to ensure computational precision...")

    converted_weights = {}
    for key, tensor in weights.items():
        # Only convert floating point tensors, keep integer tensors unchanged
        if tensor.dtype.is_floating_point:
            converted_weights[key] = tensor.to(torch.float32)
        else:
            converted_weights[key] = tensor

    print(f"Successfully loaded model with {len(converted_weights)} weight tensors")
    return converted_weights


def save_lora_weights(lora_weights: Dict[str, torch.Tensor], output_path: str, output_format: str, output_dtype: str = "bf16"):
    """
    Save LoRA weights

    Args:
        lora_weights: LoRA weights dictionary
        output_path: Output path
        output_format: Output format
        output_dtype: Output data type
    """
    print(f"Saving LoRA weights to: {output_path} (format: {output_format}, data type: {output_dtype})")

    # Ensure output directory exists
    output_dir = os.path.dirname(output_path)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)

    # Convert data type
    target_dtype = _get_torch_dtype(output_dtype)
    print(f"Converting LoRA weights to {output_dtype} type...")

    converted_weights = {}
    with tqdm(lora_weights.items(), desc="Converting data type", unit="weights") as pbar:
        for key, tensor in pbar:
            # Only convert floating point tensors, keep integer tensors unchanged
            if tensor.dtype.is_floating_point:
                converted_weights[key] = tensor.to(target_dtype).contiguous()
            else:
                converted_weights[key] = tensor.contiguous()

    if output_format == "safetensors":
        # Save as safetensors format
        if not output_path.endswith(".safetensors"):
            output_path += ".safetensors"
        st.save_file(converted_weights, output_path)

    elif output_format == "pytorch":
        # Save as pytorch format
        if not output_path.endswith((".pt", ".pth")):
            output_path += ".pt"
        torch.save(converted_weights, output_path)
    else:
        raise ValueError(f"Unsupported output format: {output_format}")

    print(f"LoRA weights saved to: {output_path}")


def _compute_weight_diff(source_tensor: torch.Tensor, target_tensor: torch.Tensor, key: str) -> Optional[torch.Tensor]:
    """
    Compute the difference between two weight tensors

    Args:
        source_tensor: Source weight tensor
        target_tensor: Target weight tensor
        key: Weight key name (for logging)

    Returns:
        Difference tensor, returns None if no change
    """
    # Check if tensor shapes match
    if source_tensor.shape != target_tensor.shape:
        return None

    # Check if tensor data types match
    if source_tensor.dtype != target_tensor.dtype:
        target_tensor = target_tensor.to(source_tensor.dtype)

    # Compute difference
    diff = target_tensor - source_tensor

    # Check if there are actual changes
    if torch.allclose(diff, torch.zeros_like(diff), atol=1e-8):
        # No change
        return None

    return diff


def _decompose_to_lora(diff: torch.Tensor, key: str, rank: int) -> Dict[str, torch.Tensor]:
    """
    Decompose weight difference into LoRA format

    Args:
        diff: Weight difference tensor
        key: Original weight key name
        rank: LoRA rank

    Returns:
        LoRA weights dictionary (containing lora_up and lora_down)
    """
    # Ensure it's a 2D tensor
    if len(diff.shape) != 2:
        raise ValueError(f"LoRA decomposition only supports 2D weights, but got {len(diff.shape)}D tensor: {key}")

    a, b = diff.shape

    # Check if rank is reasonable
    max_rank = min(a, b)
    if rank > max_rank:
        rank = max_rank

    # Choose compute device (prefer GPU, fallback to CPU)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    diff_device = diff.to(device)

    # SVD decomposition (Vh is the transposed matrix of right singular vectors)
    U, S, Vh = torch.linalg.svd(diff_device, full_matrices=False)

    # Take the first `rank` components
    U = U[:, :rank]  # (a, rank)
    S = S[:rank]  # (rank,)
    Vh = Vh[:rank, :]  # (rank, b)

    # Distribute the square root of the singular values to both factors
    S_sqrt = S.sqrt()
    lora_up = U * S_sqrt.unsqueeze(0)  # (a, rank) * (1, rank) = (a, rank)
    lora_down = S_sqrt.unsqueeze(1) * Vh  # (rank, 1) * (rank, b) = (rank, b)

    # Move back to CPU and convert to original data type, ensure contiguous
    lora_up = lora_up.cpu().to(diff.dtype).contiguous()
    lora_down = lora_down.cpu().to(diff.dtype).contiguous()

    # Generate LoRA weight key names
    base_key = key.replace(".weight", "")
    lora_up_key = "diffusion_model." + f"{base_key}.lora_up.weight"
    lora_down_key = "diffusion_model." + f"{base_key}.lora_down.weight"

    # Return the decomposed weights
    lora_weights = {lora_up_key: lora_up, lora_down_key: lora_down}

    return lora_weights


def extract_lora_from_diff(source_weights: Dict[str, torch.Tensor], target_weights: Dict[str, torch.Tensor], rank: int = 16, diff_only: bool = False) -> Dict[str, torch.Tensor]:
    """
    Extract LoRA weights from model difference

    Args:
        source_weights: Source model weights
        target_weights: Target model weights
        rank: LoRA rank
        diff_only: If True, save all weights as direct diff without LoRA decomposition

    Returns:
        LoRA weights dictionary
    """
    print("Starting LoRA weight extraction...")
    if diff_only:
        print("Mode: Direct diff only (no LoRA decomposition)")
    else:
        print(f"Mode: Smart extraction - rank: {rank}")
    print(f"Source model weight count: {len(source_weights)}")
    print(f"Target model weight count: {len(target_weights)}")

    lora_weights = {}
    processed_count = 0
    diff_count = 0
    lora_count = 0
    similar_count = 0
    skipped_count = 0
    fail_count = 0

    # Find common keys between two models
    common_keys = set(source_weights.keys()) & set(target_weights.keys())
    source_only_keys = set(source_weights.keys()) - set(target_weights.keys())
    target_only_keys = set(target_weights.keys()) - set(source_weights.keys())

    if source_only_keys:
        print(f"Warning: Source model exclusive weight keys ({len(source_only_keys)} keys): {list(source_only_keys)[:5]}...")
    if target_only_keys:
        print(f"Warning: Target model exclusive weight keys ({len(target_only_keys)} keys): {list(target_only_keys)[:5]}...")

    print(f"Common weight keys count: {len(common_keys)}")

    # Process common keys, extract LoRA weights
    common_keys_sorted = sorted(common_keys)
    pbar = tqdm(common_keys_sorted, desc="Extracting LoRA weights", unit="layer")

    for key in pbar:
        source_tensor = source_weights[key]
        target_tensor = target_weights[key]

        # Update progress bar description
        short_key = key.split(".")[-2:] if "." in key else [key]
        pbar.set_postfix_str(f"Processing: {'.'.join(short_key)}")

        # Compute weight difference
        diff = _compute_weight_diff(source_tensor, target_tensor, key)

        if diff is None:
            # No change or shape mismatch
            if source_tensor.shape == target_tensor.shape:
                similar_count += 1
            else:
                skipped_count += 1
            continue

        # Calculate parameter count
        param_count = source_tensor.numel()
        is_1d = len(source_tensor.shape) == 1

        # Decide whether to save diff directly or perform LoRA decomposition
        if diff_only or is_1d or param_count < 1000000:
            # Save diff directly
            lora_key = _generate_lora_diff_key(key)
            if lora_key == "skip":
                skipped_count += 1
                continue
            lora_weights[lora_key] = diff
            diff_count += 1

        else:
            # Perform LoRA decomposition
            if len(diff.shape) == 2 and key.endswith(".weight"):
                try:
                    decomposed_weights = _decompose_to_lora(diff, key, rank)
                    lora_weights.update(decomposed_weights)
                    lora_count += 1
                except Exception as e:
                    print(f"Error: {e}")
                    fail_count += 1

            else:
                print(f"Error: {key} is not a 2D weight tensor")
                fail_count += 1

        processed_count += 1

    # Close progress bar
    pbar.close()

    print(f"\nExtraction statistics:")
    print(f"  Processed weights: {processed_count}")
    print(f"  Direct diff: {diff_count}")
    print(f"  LoRA decomposition: {lora_count}")
    print(f"  Skipped weights: {skipped_count}")
    print(f"  Similar weights: {similar_count}")
    print(f"  Failed weights: {fail_count}")
    print(f"  Total extracted LoRA weights: {len(lora_weights)}")
    print("LoRA weight extraction completed")

    return lora_weights


def _generate_lora_diff_key(original_key: str) -> str:
    """
    Generate LoRA weight key based on original weight key

    Args:
        original_key: Original weight key name

    Returns:
        LoRA weight key name
    """
    ret_key = "diffusion_model." + original_key
    if original_key.endswith(".weight"):
        return ret_key.replace(".weight", ".diff")
    elif original_key.endswith(".bias"):
        return ret_key.replace(".bias", ".diff_b")
    elif original_key.endswith(".modulation"):
        return ret_key.replace(".modulation", ".diff_m")
    else:
        # If no matching suffix, skip
        return "skip"


if __name__ == "__main__":
    args = parse_args()

    print("=" * 50)
    print("LoRA Extractor Started")
    print("=" * 50)
    print(f"Source model: {args.source_model} ({args.source_type})")
    print(f"Target model: {args.target_model} ({args.target_type})")
    print(f"Output path: {args.output} ({args.output_format})")
    print(f"Output data type: {args.output_dtype}")
    print(f"LoRA parameters: rank={args.rank}")
    print(f"Diff only mode: {args.diff_only}")
    print("=" * 50)

    try:
        # Load source and target models
        source_weights = load_model_weights(args.source_model, args.source_type)
        target_weights = load_model_weights(args.target_model, args.target_type)

        # Extract LoRA weights
        lora_weights = extract_lora_from_diff(source_weights, target_weights, rank=args.rank, diff_only=args.diff_only)

        # Save LoRA weights
        save_lora_weights(lora_weights, args.output, args.output_format, args.output_dtype)

        print("=" * 50)
        print("LoRA extraction completed!")
        print("=" * 50)

    except Exception as e:
        print(f"Error: {e}")
        raise

LoRA Weight Merging (lora_merger)

lora_merger fuses lightweight LoRA adapter weights back into the original base model: it computes $\Delta\mathbf{W} = \alpha \cdot (\mathbf{B} \times \mathbf{A})$ (LoRA mode) or $\Delta\mathbf{W} = \alpha \cdot \Delta\mathbf{W}_{\text{diff}}$ (diff mode) and adds it onto the base weights, $\mathbf{W}_{\text{merged}} = \mathbf{W}_{\text{source}} + \Delta\mathbf{W}$, producing a single complete model file that bakes in the fine-tuning.
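The arithmetic of the merge itself is a single addition per matched weight; the following minimal sketch (shapes are illustrative only) shows what the script does for each LoRA pair or diff tensor:

import torch

# Illustrative shapes: base weight (d_out, d_in), LoRA factors of rank r
W_source = torch.randn(128, 256)
lora_up = torch.randn(128, 16)    # B
lora_down = torch.randn(16, 256)  # A
alpha = 1.0

# LoRA mode: W_merged = W_source + alpha * (B @ A)
W_merged = W_source + alpha * (lora_up @ lora_down)

# Diff mode (per tensor): W_merged = W_source + alpha * delta_W_diff
print(W_merged.shape)  # torch.Size([128, 256])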

| Parameter | Type | Default | Options | Description |
| --- | --- | --- | --- | --- |
| --source-model | str | required | file/directory path | Source model path (base model) |
| --source-type | str | safetensors | safetensors, pytorch | Source model format |
| --lora-model | str | required | file/directory path | LoRA weights file path |
| --lora-type | str | safetensors | safetensors, pytorch | LoRA weights format |
| --output | str | required | file path | Merged model output path |
| --output-format | str | safetensors | safetensors, pytorch | Output model format |
| --alpha | float | 1.0 | 0.0 ~ 2.0 | LoRA merge strength coefficient |
| --output-dtype | str | bf16 | float32, fp32, float16, fp16, bfloat16, bf16 | Output weight data type |

# Example 1: standard LoRA merge (full strength)
python lora_merger.py \
  --source-model "path/to/base_model.safetensors" \
  --lora-model "path/to/lora_weights.safetensors" \
  --output "path/to/output/merged_model.safetensors" \
  --alpha 1.0 \
  --output-dtype bf16

# Example 2: partial merge (50% strength)
python lora_merger.py \
  --source-model "path/to/base_model.safetensors" \
  --lora-model "path/to/lora_weights.safetensors" \
  --output "path/to/output/merged_model_half.safetensors" \
  --alpha 0.5 \
  --output-dtype bf16

# Example 3: convert PyTorch format to SafeTensors
python lora_merger.py \
  --source-model "path/to/base_model.pth" \
  --source-type pytorch \
  --lora-model "path/to/lora_weights.pth" \
  --lora-type pytorch \
  --output "path/to/output/merged_model.safetensors" \
  --output-format safetensors \
  --alpha 1.0

# Example 4: diff-mode merge
python lora_merger.py \
  --source-model "path/to/base_model.safetensors" \
  --lora-model "path/to/diff_weights.safetensors" \
  --output "path/to/output/merged_model.safetensors" \
  --alpha 1.0 \
  --output-dtype fp32
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
LoRA Merger Script
Merge a source model with LoRA weights to create a new model
"""

import argparse
import os
from typing import Dict, Optional

import torch
from safetensors import safe_open
from safetensors import torch as st
from tqdm import tqdm


def _get_torch_dtype(dtype_str: str) -> torch.dtype:
    """
    Convert string to torch data type

    Args:
        dtype_str: Data type string

    Returns:
        Torch data type
    """
    dtype_mapping = {
        "float32": torch.float32,
        "fp32": torch.float32,
        "float16": torch.float16,
        "fp16": torch.float16,
        "bfloat16": torch.bfloat16,
        "bf16": torch.bfloat16,
    }

    if dtype_str not in dtype_mapping:
        raise ValueError(f"Unsupported data type: {dtype_str}")

    return dtype_mapping[dtype_str]


def parse_args():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(description="Merge a source model with LoRA weights to create a new model", formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Source model parameters
    parser.add_argument("--source-model", type=str, required=True, help="Path to source model")
    parser.add_argument("--source-type", type=str, choices=["safetensors", "pytorch"], default="safetensors", help="Source model format type")

    # LoRA parameters
    parser.add_argument("--lora-model", type=str, required=True, help="Path to LoRA weights")
    parser.add_argument("--lora-type", type=str, choices=["safetensors", "pytorch"], default="safetensors", help="LoRA weights format type")

    # Output parameters
    parser.add_argument("--output", type=str, required=True, help="Path to output merged model")
    parser.add_argument("--output-format", type=str, choices=["safetensors", "pytorch"], default="safetensors", help="Output model format")

    # Merge parameters
    parser.add_argument("--alpha", type=float, default=1.0, help="LoRA merge strength (alpha value)")
    parser.add_argument("--output-dtype", type=str, choices=["float32", "fp32", "float16", "fp16", "bfloat16", "bf16"], default="bf16", help="Output weight data type")

    return parser.parse_args()


def load_model_weights(model_path: str, model_type: str) -> Dict[str, torch.Tensor]:
    """
    Load model weights (using fp32 precision)

    Args:
        model_path: Model file path or directory path
        model_type: Model type ("safetensors" or "pytorch")

    Returns:
        Model weights dictionary (fp32 precision)
    """
    print(f"Loading model: {model_path} (type: {model_type}, precision: fp32)")

    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model path does not exist: {model_path}")

    weights = {}

    if model_type == "safetensors":
        if os.path.isdir(model_path):
            # If it's a directory, load all .safetensors files in the directory
            safetensors_files = []
            for file in os.listdir(model_path):
                if file.endswith(".safetensors"):
                    safetensors_files.append(os.path.join(model_path, file))

            if not safetensors_files:
                raise ValueError(f"No .safetensors files found in directory: {model_path}")

            print(f"Found {len(safetensors_files)} safetensors files")

            # Load all files and merge weights
            for file_path in sorted(safetensors_files):
                print(f"  Loading file: {os.path.basename(file_path)}")
                with safe_open(file_path, framework="pt", device="cpu") as f:
                    for key in f.keys():
                        if key in weights:
                            print(f"Warning: weight key '{key}' is duplicated in multiple files, will be overwritten")
                        weights[key] = f.get_tensor(key)

        elif os.path.isfile(model_path):
            # If it's a single file
            if model_path.endswith(".safetensors"):
                with safe_open(model_path, framework="pt", device="cpu") as f:
                    for key in f.keys():
                        weights[key] = f.get_tensor(key)
            else:
                raise ValueError(f"safetensors type file should end with .safetensors: {model_path}")
        else:
            raise ValueError(f"Invalid path type: {model_path}")

    elif model_type == "pytorch":
        # Load pytorch format (.pt, .pth)
        if model_path.endswith((".pt", ".pth")):
            checkpoint = torch.load(model_path, map_location="cpu")

            # Handle possible nested structure
            if isinstance(checkpoint, dict):
                if "state_dict" in checkpoint:
                    weights = checkpoint["state_dict"]
                elif "model" in checkpoint:
                    weights = checkpoint["model"]
                else:
                    weights = checkpoint
            else:
                weights = checkpoint
        else:
            raise ValueError(f"pytorch type file should end with .pt or .pth: {model_path}")
    else:
        raise ValueError(f"Unsupported model type: {model_type}")

    # Convert all floating point weights to fp32 to ensure computational precision
    print("Converting weights to fp32 to ensure computational precision...")

    converted_weights = {}
    for key, tensor in weights.items():
        # Only convert floating point tensors, keep integer tensors unchanged
        if tensor.dtype.is_floating_point:
            converted_weights[key] = tensor.to(torch.float32)
        else:
            converted_weights[key] = tensor

    print(f"Successfully loaded model with {len(converted_weights)} weight tensors")
    return converted_weights


def save_model_weights(model_weights: Dict[str, torch.Tensor], output_path: str, output_format: str, output_dtype: str = "bf16"):
    """
    Save model weights

    Args:
        model_weights: Model weights dictionary
        output_path: Output path
        output_format: Output format
        output_dtype: Output data type
    """
    print(f"Saving merged model to: {output_path} (format: {output_format}, data type: {output_dtype})")

    # Ensure output directory exists
    output_dir = os.path.dirname(output_path)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)

    # Convert data type
    target_dtype = _get_torch_dtype(output_dtype)
    print(f"Converting model weights to {output_dtype} type...")

    converted_weights = {}
    with tqdm(model_weights.items(), desc="Converting data type", unit="weights") as pbar:
        for key, tensor in pbar:
            # Only convert floating point tensors, keep integer tensors unchanged
            if tensor.dtype.is_floating_point:
                converted_weights[key] = tensor.to(target_dtype).contiguous()
            else:
                converted_weights[key] = tensor.contiguous()

    if output_format == "safetensors":
        # Save as safetensors format
        if not output_path.endswith(".safetensors"):
            output_path += ".safetensors"
        st.save_file(converted_weights, output_path)

    elif output_format == "pytorch":
        # Save as pytorch format
        if not output_path.endswith((".pt", ".pth")):
            output_path += ".pt"
        torch.save(converted_weights, output_path)
    else:
        raise ValueError(f"Unsupported output format: {output_format}")

    print(f"Merged model saved to: {output_path}")


def merge_lora_weights(source_weights: Dict[str, torch.Tensor], lora_weights: Dict[str, torch.Tensor], alpha: float = 1.0) -> Dict[str, torch.Tensor]:
    """
    Merge source model with LoRA weights

    Args:
        source_weights: Source model weights
        lora_weights: LoRA weights
        alpha: LoRA merge strength

    Returns:
        Merged model weights
    """
    print("Starting LoRA merge...")
    print(f"Merge parameters - alpha: {alpha}")
    print(f"Source model weight count: {len(source_weights)}")
    print(f"LoRA weight count: {len(lora_weights)}")

    merged_weights = source_weights.copy()
    processed_count = 0
    lora_merged_count = 0
    diff_merged_count = 0
    skipped_source_count = 0
    skipped_lora_count = 0
    skipped_source_keys = []
    skipped_lora_keys = []

    # Group LoRA weights by base key
    lora_pairs = {}
    diff_weights = {}

    for lora_key, lora_tensor in lora_weights.items():
        if lora_key.endswith(".lora_up.weight"):
            base_key = lora_key.replace(".lora_up.weight", "")
            if base_key not in lora_pairs:
                lora_pairs[base_key] = {}
            lora_pairs[base_key]["up"] = lora_tensor
        elif lora_key.endswith(".lora_down.weight"):
            base_key = lora_key.replace(".lora_down.weight", "")
            if base_key not in lora_pairs:
                lora_pairs[base_key] = {}
            lora_pairs[base_key]["down"] = lora_tensor
        elif lora_key.endswith((".diff", ".diff_b", ".diff_m")):
            diff_weights[lora_key] = lora_tensor

    print(f"Found {len(lora_pairs)} LoRA pairs and {len(diff_weights)} diff weights")

    # Process with progress bar
    all_items = list(lora_pairs.items()) + list(diff_weights.items())
    pbar = tqdm(all_items, desc="Merging LoRA weights", unit="weight")

    for item in pbar:
        if isinstance(item[1], dict):  # LoRA pair
            base_key, lora_pair = item
            if "up" in lora_pair and "down" in lora_pair:
                # Find corresponding source weight
                source_key = _find_source_key(base_key, source_weights)
                if source_key:
                    if source_weights[source_key].shape != (lora_pair["up"].shape[0], lora_pair["down"].shape[1]):
                        skipped_source_count += 1
                        skipped_source_keys.append(source_key)
                        continue
                    lora_up = lora_pair["up"]
                    lora_down = lora_pair["down"]

                    # Compute LoRA delta: alpha * (lora_up @ lora_down)
                    lora_delta = alpha * (lora_up @ lora_down)

                    # Apply to source weight
                    merged_weights[source_key] = source_weights[source_key] + lora_delta
                    lora_merged_count += 1
                    pbar.set_postfix_str(f"LoRA: {source_key.split('.')[-1]}")
                else:
                    skipped_source_count += 1
                    skipped_source_keys.append(base_key)
            else:
                print(f"Warning: Incomplete LoRA pair for: {base_key}")
                skipped_lora_count += 1
                skipped_lora_keys.append(base_key)
        else:  # Diff weight
            diff_key, diff_tensor = item
            # Find corresponding source weight
            source_key = _find_source_key_from_diff(diff_key, source_weights)
            if source_key:
                if source_weights[source_key].shape != diff_tensor.shape:
                    skipped_source_count += 1
                    skipped_source_keys.append(source_key)
                    continue
                # Apply diff: source + alpha * diff
                merged_weights[source_key] = source_weights[source_key] + alpha * diff_tensor
                diff_merged_count += 1
                pbar.set_postfix_str(f"Diff: {source_key.split('.')[-1]}")
            else:
                skipped_lora_count += 1
                skipped_lora_keys.append(diff_key)

        processed_count += 1

    pbar.close()

    print(f"\nMerge statistics:")
    print(f"  Processed weights: {processed_count}")
    print(f"  LoRA merged: {lora_merged_count}")
    print(f"  Diff merged: {diff_merged_count}")
    print(f"  Skipped source weights: {skipped_source_count}")
    if skipped_source_count > 0:
        print(f"  Skipped source keys:")
        for key in skipped_source_keys:
            print(f"    {key}")
    print(f"  Skipped LoRA weights: {skipped_lora_count}")
    if skipped_lora_count > 0:
        print(f"  Skipped LoRA keys:")
        for key in skipped_lora_keys:
            print(f"    {key}")
    print(f"  Total merged model weights: {len(merged_weights)}")
    print("LoRA merge completed")

    return merged_weights


def _find_source_key(lora_base_key: str, source_weights: Dict[str, torch.Tensor]) -> Optional[str]:
    """
    Find corresponding source weight key for LoRA base key

    Args:
        lora_base_key: LoRA base key (e.g., "diffusion_model.input_blocks.0.0.weight")
        source_weights: Source model weights

    Returns:
        Corresponding source key or None
    """
    # Remove diffusion_model prefix if present
    if lora_base_key.startswith("diffusion_model."):
        source_key = lora_base_key[16:] + ".weight"  # Remove "diffusion_model." and add ".weight"
    else:
        source_key = lora_base_key + ".weight"

    if source_key in source_weights:
        return source_key

    # Try without adding .weight (in case it's already included)
    if lora_base_key.startswith("diffusion_model."):
        source_key_alt = lora_base_key[16:]
    else:
        source_key_alt = lora_base_key

    if source_key_alt in source_weights:
        return source_key_alt

    return None


def _find_source_key_from_diff(diff_key: str, source_weights: Dict[str, torch.Tensor]) -> Optional[str]:
    """
    Find corresponding source weight key for diff key

    Args:
        diff_key: Diff key (e.g., "diffusion_model.input_blocks.0.diff")
        source_weights: Source model weights

    Returns:
        Corresponding source key or None
    """
    # Remove diffusion_model prefix and diff suffix
    if diff_key.startswith("diffusion_model."):
        base_key = diff_key[16:]  # Remove "diffusion_model."
    else:
        base_key = diff_key

    # Remove diff suffixes
    if base_key.endswith(".diff"):
        source_key = base_key[:-5] + ".weight"  # Replace ".diff" with ".weight"
    elif base_key.endswith(".diff_b"):
        source_key = base_key[:-7] + ".bias"  # Replace ".diff_b" with ".bias"
    elif base_key.endswith(".diff_m"):
        source_key = base_key[:-7] + ".modulation"  # Replace ".diff_m" with ".modulation"
    else:
        source_key = base_key

    if source_key in source_weights:
        return source_key

    return None


if __name__ == "__main__":
    args = parse_args()

    print("=" * 50)
    print("LoRA Merger Started")
    print("=" * 50)
    print(f"Source model: {args.source_model} ({args.source_type})")
    print(f"LoRA weights: {args.lora_model} ({args.lora_type})")
    print(f"Output path: {args.output} ({args.output_format})")
    print(f"Output data type: {args.output_dtype}")
    print(f"Merge parameters: alpha={args.alpha}")
    print("=" * 50)

    try:
        # Load source model and LoRA weights
        source_weights = load_model_weights(args.source_model, args.source_type)
        lora_weights = load_model_weights(args.lora_model, args.lora_type)

        # Merge LoRA weights with source model
        merged_weights = merge_lora_weights(source_weights, lora_weights, alpha=args.alpha)

        # Save merged model
        save_model_weights(merged_weights, args.output, args.output_format, args.output_dtype)

        print("=" * 50)
        print("LoRA merge completed!")
        print("=" * 50)

    except Exception as e:
        print(f"Error: {e}")
        raise
Tags: Deep Learning · PEFT · Model Fine-Tuning · LoRA