Pular para o conteúdo principal

Documentation Index

Fetch the complete documentation index at: https://docs.monostate.ai/llms.txt

Use this file to discover all available pages before exploring further.

Referência do SDK Python

Guia abrangente da API Python do AITraining.

Instalação

pip install aitraining torch torchvision torchaudio
Nome do Pacote vs Import: Instale com pip install aitraining, mas importe com from autotrain import ...

Treinamento LLM

LLMTrainingParams

from autotrain.trainers.clm.params import LLMTrainingParams

params = LLMTrainingParams(
    # Required
    model="google/gemma-3-270m",
    data_path="./data.jsonl",
    project_name="my-model",

    # Training method
    trainer="sft",  # sft, dpo, orpo, ppo, reward, distillation

    # Training settings
    epochs=3,
    batch_size=2,
    lr=2e-5,
    gradient_accumulation=4,
    mixed_precision="bf16",

    # LoRA
    peft=True,
    lora_r=16,
    lora_alpha=32,
    lora_dropout=0.05,

    # Data processing
    text_column="text",
    block_size=2048,
    add_eos_token=True,
    save_processed_data="auto",  # auto, local, hub, both, none

    # Logging
    log="wandb",
    logging_steps=-1,  # Default: -1 (auto)

    # Hyperparameter sweep (optional)
    # use_sweep=True,
    # sweep_backend="optuna",
    # sweep_n_trials=20,
    # sweep_params='{"lr": {"type": "loguniform", "low": 1e-5, "high": 1e-3}}',
)

Parâmetros Principais

| Parâmetro | Tipo | Descrição |
|---|---|---|
| model | str | Nome ou caminho do modelo |
| data_path | str | Caminho para os dados de treinamento |
| project_name | str | Diretório de saída |
| trainer | str | Método de treinamento |
| epochs | int | Número de épocas |
| batch_size | int | Tamanho do lote |
| lr | float | Taxa de aprendizado |
| peft | bool | Habilitar LoRA |
| lora_r | int | Rank do LoRA |
| lora_alpha | int | Alpha do LoRA |
| save_processed_data | str | Salvar dados processados: auto, local, hub, both, none |

Parâmetros de Sweep de Hiperparâmetros

| Parâmetro | Tipo | Descrição |
|---|---|---|
| use_sweep | bool | Habilitar sweep de hiperparâmetros |
| sweep_backend | str | Backend: optuna, grid, random |
| sweep_n_trials | int | Número de tentativas |
| sweep_metric | str | Métrica a otimizar |
| sweep_direction | str | minimize ou maximize |
| sweep_params | str | Espaço de busca personalizado (string JSON) |
| wandb_sweep | bool | Habilitar dashboard nativo de sweeps W&B |
| wandb_sweep_project | str | Projeto W&B para sweep |
| wandb_sweep_entity | str | Entidade W&B (equipe/usuário) |
| wandb_sweep_id | str | ID de sweep existente para continuar |

Formato de sweep_params

Formatos de lista e dicionário são suportados:
import json

# Formato dict (recomendado)
sweep_params = json.dumps({
    "lr": {"type": "loguniform", "low": 1e-5, "high": 1e-3},
    "batch_size": {"type": "categorical", "values": [2, 4, 8]},
})

# Formato lista (abreviado para categóricos)
sweep_params = json.dumps({
    "batch_size": [2, 4, 8],
})
Tipos suportados: categorical, loguniform, uniform, int.

Classificação de Texto

TextClassificationParams

from autotrain.trainers.text_classification.params import TextClassificationParams

params = TextClassificationParams(
    model="bert-base-uncased",
    data_path="./reviews.csv",
    project_name="sentiment",
    text_column="text",
    target_column="label",
    epochs=5,
    batch_size=16,
    lr=2e-5,
)

Classificação de Imagem

ImageClassificationParams

from autotrain.trainers.image_classification.params import ImageClassificationParams

params = ImageClassificationParams(
    model="google/vit-base-patch16-224",
    data_path="./images/",
    project_name="classifier",
    image_column="image",
    target_column="label",
    epochs=10,
    batch_size=32,
)

Execução do Projeto

AutoTrainProject

from autotrain.project import AutoTrainProject

# Create and run project
project = AutoTrainProject(
    params=params,
    backend="local",  # "local" or "spaces"
    process=True      # Start training immediately
)

job_id = project.create()

Opções de Backend

| Backend | Descrição |
|---|---|
| local | Executar na máquina local |
| spaces-* | Executar no Hugging Face Spaces (ex.: spaces-a10g-large, spaces-t4-medium) |
| ep-* | Hugging Face Endpoints |
| ngc-* | NVIDIA NGC |
| nvcf-* | NVIDIA Cloud Functions |

Inferência

Usando Completers

from autotrain.generation import CompletionConfig, create_completer

# Configure generation
config = CompletionConfig(
    max_new_tokens=256,
    temperature=0.7,
    top_p=0.95,
    top_k=50,
)

# Create completer (first param is "model", not "model_path")
completer = create_completer(
    model="./my-trained-model",
    completer_type="message",
    config=config
)

# Generate (returns MessageCompletionResult)
result = completer.chat("Hello, how are you?")
print(result.content)  # Access the text content

Usando Transformers Diretamente

from transformers import AutoModelForCausalLM, AutoTokenizer

# Load model
model = AutoModelForCausalLM.from_pretrained("./my-model")
tokenizer = AutoTokenizer.from_pretrained("./my-model")

# Generate
inputs = tokenizer("Hello!", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0]))

Manipulação de Datasets

AutoTrainDataset

from autotrain.dataset import AutoTrainDataset

dataset = AutoTrainDataset(
    train_data=["train.csv"],
    task="text_classification",
    token="hf_...",
    project_name="my-project",
    username="my-username",
    column_mapping={
        "text": "review_text",
        "label": "sentiment"
    },
)

# Prepare dataset
data_path = dataset.prepare()

Arquivos de Configuração

Carregando de YAML

from autotrain.parser import AutoTrainConfigParser

# Parse config file
parser = AutoTrainConfigParser("config.yaml")

# Run training
parser.run()

Tratamento de Erros

from autotrain.project import AutoTrainProject
from autotrain.trainers.clm.params import LLMTrainingParams

try:
    params = LLMTrainingParams(
        model="google/gemma-3-270m",
        data_path="./data.jsonl",
        project_name="my-model",
    )

    project = AutoTrainProject(params=params, backend="local", process=True)
    job_id = project.create()

except ValueError as e:
    print(f"Configuration error: {e}")
except RuntimeError as e:
    print(f"Training error: {e}")

Exemplos Completos

Pipeline de Treinamento SFT

from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject

def train_sft():
    params = LLMTrainingParams(
        model="google/gemma-3-270m",
        data_path="./conversations.jsonl",
        project_name="gemma-sft",
        trainer="sft",
        epochs=3,
        batch_size=2,
        gradient_accumulation=8,
        lr=2e-5,
        peft=True,
        lora_r=16,
        lora_alpha=32,
        log="wandb",
    )

    project = AutoTrainProject(
        params=params,
        backend="local",
        process=True
    )

    return project.create()

if __name__ == "__main__":
    job_id = train_sft()
    print(f"Job ID: {job_id}")

Pipeline de Treinamento DPO

from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject

def train_dpo():
    params = LLMTrainingParams(
        model="meta-llama/Llama-3.2-1B",
        data_path="./preferences.jsonl",
        project_name="llama-dpo",
        trainer="dpo",
        dpo_beta=0.1,
        max_prompt_length=128,     # Default: 128
        max_completion_length=None,  # Default: None
        epochs=1,
        batch_size=2,
        lr=5e-6,
        peft=True,
        lora_r=16,
    )

    project = AutoTrainProject(
        params=params,
        backend="local",
        process=True
    )

    return project.create()

Próximos Passos

LLM Endpoints

API específica para LLM

CLI Reference

Comandos CLI