Documentation Index

Fetch the complete documentation index at: https://docs.monostate.ai/llms.txt

Use this file to discover all available pages before exploring further.

Python SDK Reference

Comprehensive guide to the AITraining Python API.

Installation

pip install aitraining torch torchvision torchaudio
Package vs Import Name: Install with pip install aitraining, but import with from autotrain import ...
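
A quick way to confirm the package-versus-import naming in your environment:

# Installed as aitraining, imported as autotrain
import autotrain

print("autotrain imported successfully")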

LLM Training

LLMTrainingParams

from autotrain.trainers.clm.params import LLMTrainingParams

params = LLMTrainingParams(
    # Required
    model="google/gemma-3-270m",
    data_path="./data.jsonl",
    project_name="my-model",

    # Training method
    trainer="sft",  # sft, dpo, orpo, ppo, reward, distillation

    # Training settings
    epochs=3,
    batch_size=2,
    lr=2e-5,
    gradient_accumulation=4,
    mixed_precision="bf16",

    # LoRA
    peft=True,
    lora_r=16,
    lora_alpha=32,
    lora_dropout=0.05,

    # Data processing
    text_column="text",
    block_size=2048,
    add_eos_token=True,
    save_processed_data="auto",  # auto, local, hub, both, none

    # Logging
    log="wandb",
    logging_steps=-1,  # Default: -1 (auto)

    # Hyperparameter sweep (optional)
    # use_sweep=True,
    # sweep_backend="optuna",
    # sweep_n_trials=20,
    # sweep_params='{"lr": {"type": "loguniform", "low": 1e-5, "high": 1e-3}}',
)

Key Parameters

| Parameter | Type | Description |
| --- | --- | --- |
| model | str | Model name or path |
| data_path | str | Path to training data |
| project_name | str | Output directory |
| trainer | str | Training method |
| epochs | int | Number of epochs |
| batch_size | int | Batch size |
| lr | float | Learning rate |
| peft | bool | Enable LoRA |
| lora_r | int | LoRA rank |
| lora_alpha | int | LoRA alpha |
| save_processed_data | str | Save processed data: auto, local, hub, both, none |
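
For reference, data_path points at a JSONL file where each line is one JSON object containing the configured text_column. A minimal data.jsonl sketch (illustrative; the exact schema depends on the trainer, so verify against the LLM data docs):

{"text": "User: What is LoRA?\nAssistant: A low-rank adaptation method for efficient fine-tuning."}
{"text": "User: What does peft=True do?\nAssistant: It enables LoRA-based parameter-efficient training."}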

Hyperparameter Sweep Parameters

| Parameter | Type | Description |
| --- | --- | --- |
| use_sweep | bool | Enable hyperparameter sweeping |
| sweep_backend | str | Backend: optuna, grid, random |
| sweep_n_trials | int | Number of trials |
| sweep_metric | str | Metric to optimize |
| sweep_direction | str | minimize or maximize |
| sweep_params | str | Custom search space (JSON string) |
| wandb_sweep | bool | Enable W&B native sweep dashboard |
| wandb_sweep_project | str | W&B project for sweep |
| wandb_sweep_entity | str | W&B entity (team/username) |
| wandb_sweep_id | str | Existing sweep ID to continue |

sweep_params Format

Both list and dict formats are supported:

import json

# Dict format (recommended)
sweep_params = json.dumps({
    "lr": {"type": "loguniform", "low": 1e-5, "high": 1e-3},
    "batch_size": {"type": "categorical", "values": [2, 4, 8]},
})

# List format (categorical shorthand)
sweep_params = json.dumps({
    "batch_size": [2, 4, 8],
})

Supported types: categorical, loguniform, uniform, int.
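
Putting it together, a sketch of a sweep-enabled run using only the parameters documented above (the metric name is an assumption; use whatever metric your trainer reports):

import json

from autotrain.project import AutoTrainProject
from autotrain.trainers.clm.params import LLMTrainingParams

params = LLMTrainingParams(
    model="google/gemma-3-270m",
    data_path="./data.jsonl",
    project_name="sweep-demo",
    trainer="sft",
    use_sweep=True,
    sweep_backend="optuna",
    sweep_n_trials=10,
    sweep_metric="eval_loss",    # assumption: the trainer reports eval_loss
    sweep_direction="minimize",
    sweep_params=json.dumps({
        "lr": {"type": "loguniform", "low": 1e-5, "high": 1e-3},
        "batch_size": [2, 4, 8],  # list shorthand for categorical
    }),
)

project = AutoTrainProject(params=params, backend="local", process=True)
job_id = project.create()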

Text Classification

TextClassificationParams

from autotrain.trainers.text_classification.params import TextClassificationParams

params = TextClassificationParams(
    model="bert-base-uncased",
    data_path="./reviews.csv",
    project_name="sentiment",
    text_column="text",
    target_column="label",
    epochs=5,
    batch_size=16,
    lr=2e-5,
)
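
The CSV is expected to provide the configured text_column and target_column. A minimal reviews.csv sketch (values illustrative):

text,label
"Great product, works as advertised.",positive
"Broke after two days.",negative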

Image Classification

ImageClassificationParams

from autotrain.trainers.image_classification.params import ImageClassificationParams

params = ImageClassificationParams(
    model="google/vit-base-patch16-224",
    data_path="./images/",
    project_name="classifier",
    image_column="image",
    target_column="label",
    epochs=10,
    batch_size=32,
)
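
For data_path, a common convention is an ImageFolder-style layout with one subdirectory per class (an assumption here; verify against the image data docs):

images/
  cats/
    cat_001.jpg
    cat_002.jpg
  dogs/
    dog_001.jpg
    dog_002.jpg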

Project Execution

AutoTrainProject

from autotrain.project import AutoTrainProject

# Create and run project
project = AutoTrainProject(
    params=params,
    backend="local",  # "local" or "spaces"
    process=True      # Start training immediately
)

job_id = project.create()

Backend Options

| Backend | Description |
| --- | --- |
| local | Run on local machine |
| spaces-* | Run on Hugging Face Spaces (e.g., spaces-a10g-large, spaces-t4-medium) |
| ep-* | Hugging Face Endpoints |
| ngc-* | NVIDIA NGC |
| nvcf-* | NVIDIA Cloud Functions |

Inference

Using Completers

from autotrain.generation import CompletionConfig, create_completer

# Configure generation
config = CompletionConfig(
    max_new_tokens=256,
    temperature=0.7,
    top_p=0.95,
    top_k=50,
)

# Create completer (first param is "model", not "model_path")
completer = create_completer(
    model="./my-trained-model",
    completer_type="message",
    config=config
)

# Generate (returns MessageCompletionResult)
result = completer.chat("Hello, how are you?")
print(result.content)  # Access the text content

Using Transformers Directly

from transformers import AutoModelForCausalLM, AutoTokenizer

# Load model
model = AutoModelForCausalLM.from_pretrained("./my-model")
tokenizer = AutoTokenizer.from_pretrained("./my-model")

# Generate
inputs = tokenizer("Hello!", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0]))
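
If the saved tokenizer includes a chat template, you can format conversations with it before generating (a sketch; assumes the fine-tuned model shipped chat_template metadata, which depends on the base model):

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("./my-model")
tokenizer = AutoTokenizer.from_pretrained("./my-model")

messages = [{"role": "user", "content": "Hello!"}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
)
outputs = model.generate(input_ids, max_new_tokens=100)
# Decode only the newly generated tokens
print(tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True))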

Dataset Handling

AutoTrainDataset

from autotrain.dataset import AutoTrainDataset

dataset = AutoTrainDataset(
    train_data=["train.csv"],
    task="text_classification",
    token="hf_...",
    project_name="my-project",
    username="my-username",
    column_mapping={
        "text": "review_text",
        "label": "sentiment"
    },
)

# Prepare dataset
data_path = dataset.prepare()
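
In the example above, column_mapping is read as {expected column: column in your file} (an assumption worth verifying for your version), so train.csv would look like:

review_text,sentiment
"Loved it.",positive
"Would not buy again.",negative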

Configuration Files

Loading from YAML

from autotrain.parser import AutoTrainConfigParser

# Parse config file
parser = AutoTrainConfigParser("config.yaml")

# Run training
parser.run()
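
A minimal config.yaml sketch, assuming a schema that mirrors the upstream AutoTrain layout (task, base_model, data, params); the exact field names may differ, so check the CLI Reference:

task: llm-sft
base_model: google/gemma-3-270m
project_name: my-model
data:
  path: ./data.jsonl
params:
  epochs: 3
  batch_size: 2
  lr: 2e-5
  peft: true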

Error Handling

from autotrain.project import AutoTrainProject
from autotrain.trainers.clm.params import LLMTrainingParams

try:
    params = LLMTrainingParams(
        model="google/gemma-3-270m",
        data_path="./data.jsonl",
        project_name="my-model",
    )

    project = AutoTrainProject(params=params, backend="local", process=True)
    job_id = project.create()

except ValueError as e:
    print(f"Configuration error: {e}")
except RuntimeError as e:
    print(f"Training error: {e}")

Complete Examples

SFT Training Pipeline

from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject

def train_sft():
    params = LLMTrainingParams(
        model="google/gemma-3-270m",
        data_path="./conversations.jsonl",
        project_name="gemma-sft",
        trainer="sft",
        epochs=3,
        batch_size=2,
        gradient_accumulation=8,
        lr=2e-5,
        peft=True,
        lora_r=16,
        lora_alpha=32,
        log="wandb",
    )

    project = AutoTrainProject(
        params=params,
        backend="local",
        process=True
    )

    return project.create()

if __name__ == "__main__":
    job_id = train_sft()
    print(f"Job ID: {job_id}")

DPO Training Pipeline

from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.project import AutoTrainProject

def train_dpo():
    params = LLMTrainingParams(
        model="meta-llama/Llama-3.2-1B",
        data_path="./preferences.jsonl",
        project_name="llama-dpo",
        trainer="dpo",
        dpo_beta=0.1,
        max_prompt_length=128,     # Default: 128
        max_completion_length=None,  # Default: None
        epochs=1,
        batch_size=2,
        lr=5e-6,
        peft=True,
        lora_r=16,
    )

    project = AutoTrainProject(
        params=params,
        backend="local",
        process=True
    )

    return project.create()
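
DPO trains on preference pairs, so preferences.jsonl above would typically hold prompt/chosen/rejected fields per line (an assumed schema; verify the exact column names in the LLM data docs):

{"prompt": "Explain DPO in one sentence.", "chosen": "DPO fine-tunes a model directly on preference pairs without training a separate reward model.", "rejected": "It is a database tool."}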

Next Steps

LLM Endpoints: LLM-specific API

CLI Reference: CLI commands