Fine-tuning de modèles multimodaux avec des données textuelles uniquement

0 Shares

Contexte et objectifs

Dans le domaine de l’intelligence artificielle, un modèle multimodal est capable de traiter plusieurs typologies de données (par exemple du texte, des images, ou encore de l’audio). Aujourd’hui, de nombreuses applications requièrent la compréhension et le traitement de divers médias. Cependant, dans notre cas d’usage spécifique, nous devons d’abord nous concentrer sur l’extraction automatique et la mise en forme de contenus exclusivement textuels. À terme, nous prévoyons d’intégrer également le traitement d’images associées à ces contenus.

Notre objectif est donc de ne ré-entraîner que la partie langage du modèle multimodal à l’aide de données exclusivement textuelles, afin de préserver les capacités déjà acquises en matière d’analyse d’images, tout en améliorant la précision et les performances pour les tâches d’extraction et de traitement du texte. Cette approche nous permettra de tirer pleinement parti des compétences existantes de grands modèles, tout en spécialisant sa composante textuelle selon nos besoins.

Défis de recherche

Dans la pratique, la principale difficulté rencontrée provient du fait que la plupart des tutoriels et exemples disponibles se concentrent exclusivement sur le fine-tuning de modèles multimodaux avec des données d’entraînement combinant texte et images. Il existe très peu de ressources décrivant l’adaptation d’un modèle multimodal exclusivement à partir de données textuelles. Pour relever ce défi, nous avons entrepris des travaux de recherches et mené des expériences approfondies sur les trois grands modèles multimodaux suivants :

  1. Llava 1.6 Vicuna 13B [6]
  2. Pixtral 12B [8]
  3. Llama-3.2-11B-Vision [5]

Lors de ces travaux, nous avons utilisé différentes plates-formes de fine-tuning (telles que Unsloth [4] ou PEFT [3] + bitsandbytes [7]) afin d’identifier les meilleures approches.

Stratégies de fine-tuning et environnement d’expérimentation

Dans cette section, nous présentons les différentes approches utilisées pour le fine-tuning de modèles multimodaux sur des données exclusivement textuelles. Chaque modèle a ses propres particularités en termes d’architecture et de compatibilité, ce qui nécessite des stratégies d’adaptation spécifiques.

  • Llama-3.2-11B-Vision
    • Modèle de base (Checkpoint) : unsloth/Llama-3.2-11B-Vision-Instruct
    • Méthode d’entraînement : utilisation du framework Unsloth
  • Pixtral-12B
    • Modèle de base (Checkpoint) : unsloth/Pixtral-12B-2409
    • Méthode d’entraînement : utilisation du framework Unsloth
  • Llava-1.6-Vicuna-13B
    • Modèle de base (Checkpoint) : llava-hf/llava-v1.6-vicuna-13b-hf
    • Méthode d’entraînement : utilisation de PEFT et de bitsandbytes (bnb)

Dans cet article, nous présenterons et détaillerons la procédure de fine-tuning de Llava-1.6-Vicuna-13B à l’aide de PEFT et bitsandbytes.

Fine-tuning de Llava-1.6-Vicuna-13B avec PEFT et bitsandbytes : focus sur les couches linéaires

1. Préparation de l’environnement et chargement du modèle

Avant de commencer le fine-tuning, nous devons installer les dépendances nécessaires pour la phase d’entraînement. Cela inclut TRL, PEFT, Accelerate et BitsAndBytes, qui permettent l’optimisation de la gestion de la mémoire et l’utilisation de techniques d’entraînement efficaces :

! pip install trl peft accelerate
! pip install -U bitsandbytes

Une fois l’environnement prêt, nous pouvons charger le modèle Llava-1.6-Vicuna-13B, accompagné de son processor :

from transformers import AutoProcessor, AutoModelForImageTextToText
import torch
from transformers import BitsAndBytesConfig

MODEL_ID = "llava-hf/llava-v1.6-vicuna-13b-hf"

# Chargement du processor
processor = AutoProcessor.from_pretrained(MODEL_ID)
processor.tokenizer.padding_side = "right"

# Configuration pour l'utilisation de QLoRA (quantization 4-bit)
USE_QLORA = True
if USE_QLORA:
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16
    )
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16,
        quantization_config=bnb_config,
    )

Nous utilisons QLoRA (4-bit quantization) [1][2] afin d’accélérer le fine-tuning et d’optimiser l’utilisation des ressources serveurs, permettant d’entraîner efficacement le modèle sur du matériel à mémoire limitée.

Le jeu de données mlabonne/FineTome-100k est également chargé et les conversations sont transformées pour préparer l’entraînement.

2. Test d’inférence en mode texte

Avant d’entamer le fine-tuning, nous testons l’inférence du modèle en mode purement textuel (sans images).

# Préparation d'un prompt de conversation
instruction = dataset[0]["conversations"][0]['value']
messages = [{"role": "user", "content": [{"type": "text", "text": instruction}]}]
input_text = processor.apply_chat_template(messages, add_generation_prompt=True)

# Inférence
inputs = processor(
    images=None,
    text=input_text,
    add_special_tokens=False,
    return_tensors="pt",
).to("cuda")
generate_ids = model.generate(**inputs, max_new_tokens=2000)
print(processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))

3. Analyse du modèle et extraction des couches linéaires du module de langage

Pour n’entraîner que la partie textuelle, nous recherchons dans le modèle les modules Linear4bit qui appartiennent au sous-module language_model (à l’exclusion de lm_head) :

import bitsandbytes as bnb

cls = bnb.nn.Linear4bit
lora_module_names = set()
for name, module in model.named_modules():
    if 'language_model' in name and isinstance(module, cls) and 'lm_head' not in name:
        lora_module_names.add(name)
print(lora_module_names)

Ensuite, nous configurons LoRA via PEFT pour cibler uniquement ces couches :

from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model

lora_config = LoraConfig(
    r=8,
    lora_alpha=8,
    lora_dropout=0.1,
    target_modules=lora_module_names,  # Ciblage exclusif des couches linéaires du module de langage
    init_lora_weights="gaussian",
)
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, lora_config)

4. Prétraitement des données et création du jeu de données

Nous convertissons les conversations en texte brut et utilisons un collateur de données personnalisé (TextDataCollator) pour créer un jeu de données pré-traité.

import torch

class TextDataCollator:
    def __init__(self, model, tokenizer, max_length=2048):
        """Initialisation du collateur de données."""
        self.model = model
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, examples):
        """Prépare un batch de données pour l'entraînement."""
        input_ids = [ex["input_ids"][:self.max_length] for ex in examples]
        attention_mask = [ex["attention_mask"][:self.max_length] for ex in examples]

        # Padding des séquences
        input_ids = torch.nn.utils.rnn.pad_sequence(
            [torch.tensor(ids) for ids in input_ids],
            batch_first=True,
            padding_value=0
        )

        attention_mask = torch.nn.utils.rnn.pad_sequence(
            [torch.tensor(mask) for mask in attention_mask],
            batch_first=True,
            padding_value=0
        )

        labels = input_ids.clone()  # Les labels sont identiques aux input_ids

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }

5. Fine-tuning du modèle

Une fois le jeu de données prêt, nous lançons l’e fine-tuning avec SFTTrainer de TRL.

from trl import SFTTrainer, SFTConfig

trainer = SFTTrainer(
    model=model,
    tokenizer=processor.tokenizer,
    data_collator=TextDataCollator(model, processor),
    train_dataset=preprocessed_dataset,  # Jeu de données prétraité
    args=SFTConfig(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        max_steps=30,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=1,
        optim="adamw_8bit",
        output_dir="outputs",
    ),
)
trainer.train()

6. Test d’inférence post-affinage

Après la phase d’entraînement, nous réalisons un nouveau test d’inférence pour vérifier l’impact du fine-tuning sur la performance du modèle en mode texte :

instruction = "Explain what boolean operators are, ..."
messages = [{"role": "user", "content": [{"type": "text", "text": instruction}]}]
input_text = processor.apply_chat_template(messages, add_generation_prompt=True)

inputs = processor(
    images=None,
    text=input_text,
    add_special_tokens=False,
    return_tensors="pt",
).to("cuda")
generate_ids = model.generate(**inputs, max_new_tokens=2000)
print(processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))

Conclusion

Cet article propose une méthode complète permettant de fine-tunné un modèle multimodal uniquement avec des données textuelles, tout en conservant sa capacité de traitement d’images. Nous avons principalement présenté l’entraînement de Llava-1.6-Vicuna-13B à l’aide de PEFT et bitsandbytes, mais avons également entraîné Llama-3.2-11B-Vision et Pixtral-12B avec Unsloth, afin d’assurer une adaptation efficace aux données textuelles dans différents frameworks.

Nous avons validé nos expériences sur des ressources de petites tailles telles que des GPU T4. Pour des tests ou démonstrations à petite échelle, un T4 est généralement suffisant. Cependant, pour un entraînement complet et de grande envergure, la puissance et la mémoire d’un T4 peuvent être insuffisantes.

Disponibilité et licence

L’ensemble des programmes mentionnés dans cet article est disponible en open source sous la licence MIT sur le GitHub de LeviatanAI.

Références

[1] T. Dettmers, A. Pagnoni, A. Holtzman, and L. Zettlemoyer, “QLoRA: Efficient finetuning of quantized llms,” arXiv.org. Accessed: Feb. 06, 2025. [Online]. Available: https://arxiv.org/abs/2305.14314

[2] E. J. Hu et al., “LoRA: Low-Rank adaptation of large language models,” arXiv.org. Accessed: Feb. 06, 2025. [Online]. Available: https://arxiv.org/abs/2106.09685

[3] huggingface, “GitHub – Huggingface/peft: 🤗 PEFT: State-of-the-art Parameter-Efficient Fine-Tuning.,” GitHub. Accessed: Feb. 06, 2025. [Online]. Available: https://github.com/huggingface/peft

[4] unslothai, “GitHub – Unslothai/unsloth: Finetune Llama 3.3, DeepSeek-R1, Mistral, Phi-4 & Gemma 2 LLMs 2-5x faster with 70% less memory,” GitHub. Accessed: Feb. 06, 2025. [Online]. Available: https://github.com/unslothai/unsloth

[5] “Llama,” Meta Llama. Accessed: Feb. 06, 2025. [Online]. Available: https://www.llama.com/

[6] H. Liu, “LLaVA-NeXT: Improved reasoning, OCR, and world knowledge,” LLaVA. Accessed: Feb. 06, 2025. [Online]. Available: https://llava-vl.github.io/blog/2024-01-30-llava-next/

[7] bitsandbytes-foundation, “GitHub – Bitsandbytes-foundation/bitsandbytes: Accessible large language models via k-bit quantization for PyTorch.,” GitHub. Accessed: Feb. 06, 2025. [Online]. Available: https://github.com/bitsandbytes-foundation/bitsandbytes

[8] “mistralai/Pixtral-12B-2409 · Hugging Face.” Accessed: Feb. 06, 2025. [Online]. Available: https://huggingface.co/mistralai/Pixtral-12B-2409

[9] “🤗 Transformers.” Accessed: Feb. 06, 2025. [Online]. Available: https://huggingface.co/docs/transformers/index