Python API

AITraining provides a Python API for programmatic access to all training functionality.

Installation

pip install aitraining torch

Quick Start

from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject

# Configure training
params = LLMTrainingParams(
    model="google/gemma-3-270m",
    data_path="./data.jsonl",
    project_name="my-model",
    trainer="sft",
    epochs=3,
    batch_size=4,
    lr=2e-5,
    peft=True,
    lora_r=16,
)

# Start training
project = AutoTrainProject(params=params, backend="local", process=True)
job_id = project.create()
print(f"Training started: {job_id}")
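
The quick start points data_path at a JSONL file but does not show what that file contains. As a minimal sketch, assuming each line is a JSON object with a single "text" field (the column named by text_column="text" in the complete training script on this page), a small file could be generated as follows. The helper name write_jsonl and the sample records are illustrative and not part of AITraining; check the data format your trainer actually expects before relying on this layout.

```python
import json
from pathlib import Path


def write_jsonl(path, records):
    """Write one JSON object per line (the JSONL convention)."""
    path = Path(path)
    with path.open("w", encoding="utf-8") as f:
        for record in records:
            # ensure_ascii=False keeps non-ASCII text readable in the file
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
    return path


# Hypothetical training samples; the "text" field name is an assumption.
samples = [
    {"text": "### Question: What is LoRA?\n### Answer: A parameter-efficient fine-tuning method."},
    {"text": "### Question: What does SFT stand for?\n### Answer: Supervised fine-tuning."},
]

data_file = write_jsonl("data.jsonl", samples)
print(f"Wrote {len(samples)} records to {data_file}")
```

Because json.dumps escapes the newlines inside each string, every record stays on a single physical line, which is what line-oriented JSONL readers require.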

API Structure

Training Parameters

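Each task type has its own parameter class:

Task                   Parameter class
LLM training           LLMTrainingParams
Text classification    TextClassificationParams
Image classification   ImageClassificationParams
Token classification   TokenClassificationParams
Seq2Seq                Seq2SeqParams
Tabular data           TabularParams
Object detection       ObjectDetectionParams
VLM                    VLMTrainingParams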
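
When the task is only known at runtime, the task-to-class mapping above can be expressed as a small lookup that imports the parameter class on demand. This is a sketch under stated assumptions: only the autotrain.trainers.clm.params and autotrain.trainers.text_classification.params module paths appear on this page; the other six paths are guesses that follow the same naming pattern, and the task keys and the load_params_class helper are hypothetical names, not part of the library.

```python
import importlib

# Task key -> (module path, class name). The task keys are hypothetical.
# Only the first two module paths are documented on this page; the rest
# are ASSUMED to follow the same autotrain.trainers.<task>.params pattern.
PARAMS_REGISTRY = {
    "llm": ("autotrain.trainers.clm.params", "LLMTrainingParams"),
    "text-classification": (
        "autotrain.trainers.text_classification.params",
        "TextClassificationParams",
    ),
    "image-classification": (
        "autotrain.trainers.image_classification.params",
        "ImageClassificationParams",
    ),
    "token-classification": (
        "autotrain.trainers.token_classification.params",
        "TokenClassificationParams",
    ),
    "seq2seq": ("autotrain.trainers.seq2seq.params", "Seq2SeqParams"),
    "tabular": ("autotrain.trainers.tabular.params", "TabularParams"),
    "object-detection": (
        "autotrain.trainers.object_detection.params",
        "ObjectDetectionParams",
    ),
    "vlm": ("autotrain.trainers.vlm.params", "VLMTrainingParams"),
}


def load_params_class(task):
    """Import and return the parameter class registered for a task key."""
    try:
        module_path, class_name = PARAMS_REGISTRY[task]
    except KeyError:
        known = ", ".join(sorted(PARAMS_REGISTRY))
        raise ValueError(f"Unknown task {task!r}; expected one of: {known}") from None
    # The import happens here, so the package is only needed when a class is requested.
    module = importlib.import_module(module_path)
    return getattr(module, class_name)
```

With the package installed, load_params_class("llm") would return LLMTrainingParams, which can then be instantiated with the same keyword arguments as in the quick start. An unknown key fails early with a ValueError that lists the registered tasks.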

Project Execution

from autotrain.project import AutoTrainProject

# Create project
project = AutoTrainProject(
    params=params,
    backend="local",  # or "spaces"
    process=True      # Start immediately
)

# Run training
job_id = project.create()

Example: Complete Training Script

from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject

def train_model():
    # Configure parameters
    params = LLMTrainingParams(
        # Model
        model="meta-llama/Llama-3.2-1B",
        project_name="llama-sft",

        # Data
        data_path="./conversations.jsonl",
        train_split="train",
        text_column="text",
        block_size=2048,

        # Training
        trainer="sft",
        epochs=3,
        batch_size=2,
        gradient_accumulation=4,
        lr=2e-5,
        mixed_precision="bf16",

        # LoRA
        peft=True,
        lora_r=16,
        lora_alpha=32,
        lora_dropout=0.05,

        # Logging
        log="wandb",
        logging_steps=10,
    )

    # Start training
    project = AutoTrainProject(
        params=params,
        backend="local",
        process=True
    )

    return project.create()

if __name__ == "__main__":
    job_id = train_model()
    print(f"Training complete: {job_id}")
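
Two of the settings in the script interact arithmetically, and it can help to check the numbers before launching a run. The sketch below assumes the usual conventions: gradient accumulation multiplies the per-device batch size, and the LoRA update is scaled by lora_alpha / lora_r as in standard LoRA. Whether AITraining applies exactly these conventions is not stated on this page, and both helper names are illustrative.

```python
def effective_batch_size(batch_size, gradient_accumulation, num_devices=1):
    """Number of samples contributing to each optimizer step."""
    return batch_size * gradient_accumulation * num_devices


def lora_scaling(lora_alpha, lora_r):
    """Scaling factor applied to the low-rank update in standard LoRA."""
    return lora_alpha / lora_r


# Values taken from the complete training script (single device assumed).
print(effective_batch_size(batch_size=2, gradient_accumulation=4))  # 8
print(lora_scaling(lora_alpha=32, lora_r=16))  # 2.0
```

Under these assumptions the script updates the weights once per 8 samples, and doubling lora_r without also raising lora_alpha would halve the scaling factor.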

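Core Modules

Module                                          Description
autotrain.project                               Project execution
autotrain.trainers.clm.params                   LLM parameters
autotrain.trainers.text_classification.params   Text classification parameters
autotrain.dataset                               Dataset handling
autotrain.generation                            Inference utilities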

Next Steps

LLM Endpoints: LLM training API

Python SDK: Complete SDK reference