
Few-shot text classification

Overview

Few-shot text classification is the task of assigning a text to one of several predefined classes, given only a few labeled examples of each class. For example, given a few examples of the classes positive, negative, and neutral, the model should be able to assign a new text to one of these classes.
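To make the idea concrete, the sketch below shows one way a few labeled examples could be rendered into a single prompt for an LLM. The helper `build_few_shot_prompt` and the prompt wording are hypothetical illustrations, not Scikit-LLM's built-in template; the library builds its own prompt and also accepts a custom one via `prompt_template`.

```python
# Hypothetical sketch: render labeled (text, label) pairs into one few-shot prompt.
# The wording below is an assumption, NOT Scikit-LLM's built-in prompt template.

def build_few_shot_prompt(examples, labels, text):
    """List the allowed labels, then each example, then the text to classify."""
    lines = [f"Classify the text into one of these categories: {', '.join(labels)}.", ""]
    for sample, label in examples:
        lines.append(f"Text: {sample}")
        lines.append(f"Label: {label}")
        lines.append("")  # blank line between examples
    # The new text comes last, with an open "Label:" for the model to complete
    lines.append(f"Text: {text}")
    lines.append("Label:")
    return "\n".join(lines)


examples = [
    ("I loved this movie!", "positive"),
    ("The plot made no sense.", "negative"),
    ("It was released in 2019.", "neutral"),
]
prompt = build_few_shot_prompt(
    examples, ["positive", "negative", "neutral"], "Great acting throughout."
)
print(prompt)
```

Every example adds to the prompt length, which is one reason to keep the number of examples per class small.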

The estimators provided by Scikit-LLM do not automatically select a subset of the training data; instead, they use the entire training set to construct the examples. Therefore, if your training set is large, consider splitting it into training and validation sets while keeping the training set small (we recommend no more than 10 examples per class). Additionally, it is advisable to shuffle the order of the samples to avoid recency bias.
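A minimal sketch of that advice in plain Python, assuming single-label data where each label is a hashable string: keep at most `n_per_class` examples per class for `fit`, move the remainder into a validation set, and shuffle the kept examples so they are not grouped by class. The helper `split_few_shot` and its parameters are hypothetical and not part of Scikit-LLM.

```python
import random
from collections import defaultdict


def split_few_shot(X, y, n_per_class=10, seed=42):
    """Hypothetical helper: cap examples per class and shuffle the kept ones.

    Returns (X_train, y_train, X_val, y_val). Assumes single-label data.
    """
    rng = random.Random(seed)  # fixed seed keeps the split reproducible
    by_label = defaultdict(list)
    for text, label in zip(X, y):
        by_label[label].append(text)

    train, val = [], []
    for label, texts in by_label.items():
        texts = texts[:]  # copy before shuffling in place
        rng.shuffle(texts)
        train.extend((t, label) for t in texts[:n_per_class])
        val.extend((t, label) for t in texts[n_per_class:])

    # Mix the classes so the examples do not end with a block of one label,
    # which is the ordering effect the shuffling advice is meant to avoid
    rng.shuffle(train)
    X_train = [t for t, _ in train]
    y_train = [label for _, label in train]
    X_val = [t for t, _ in val]
    y_val = [label for _, label in val]
    return X_train, y_train, X_val, y_val
```

The resulting `X_train, y_train` could then be passed to `fit`, while `X_val, y_val` are kept aside to check the predictions.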

Example using GPT-4o:

from skllm.models.gpt.classification.few_shot import (
    FewShotGPTClassifier,
    MultiLabelFewShotGPTClassifier,
)
from skllm.datasets import (
    get_classification_dataset,
    get_multilabel_classification_dataset,
)

# Assumes the API key is already set in the global config
# (alternatively, pass key=... to the estimator).

# single label
X, y = get_classification_dataset()
clf = FewShotGPTClassifier(model="gpt-4o")
clf.fit(X, y)  # the training samples become the few-shot examples
labels = clf.predict(X)  # predicting on the same texts for brevity

# multi-label
X, y = get_multilabel_classification_dataset()
clf = MultiLabelFewShotGPTClassifier(max_labels=2, model="gpt-4o")
clf.fit(X, y)
labels = clf.predict(X)  # at most 2 labels per sample

API Reference

The API reference below lists only the parameters used to initialize each estimator. All other methods (such as fit and predict) follow the interface of a scikit-learn classifier.

FewShotGPTClassifier

from skllm.models.gpt.classification.few_shot import FewShotGPTClassifier
| Parameter | Type | Description |
|---|---|---|
| model | str | Model to use, by default "gpt-3.5-turbo". |
| default_label | str | Default label for failed predictions; if "Random", a label is selected randomly based on class frequencies. By default "Random". |
| prompt_template | Optional[str] | Custom prompt template to use, by default None. |
| key | Optional[str] | Estimator-specific API key; if None, it is retrieved from the global config. By default None. |
| org | Optional[str] | Estimator-specific organization (ORG) key; if None, it is retrieved from the global config. By default None. |
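The `default_label="Random"` option can be pictured as a frequency-weighted draw over the training labels. The sketch below is a hypothetical illustration of that behavior, not Scikit-LLM's actual code: the helper names `random_fallback_label` and `resolve_prediction`, and the rule that an output counts as "failed" when it is not one of the known labels, are assumptions.

```python
import random
from collections import Counter


def random_fallback_label(y_train, rng=None):
    """Draw one label with probability proportional to its frequency in y_train."""
    if rng is None:
        rng = random.Random()
    counts = Counter(y_train)
    labels = list(counts)
    weights = [counts[label] for label in labels]
    return rng.choices(labels, weights=weights, k=1)[0]


def resolve_prediction(raw_output, y_train, default_label="Random", rng=None):
    """Hypothetical: keep a known label, otherwise fall back to default_label."""
    if raw_output in set(y_train):
        return raw_output
    if default_label == "Random":
        return random_fallback_label(y_train, rng)
    return default_label  # a fixed fallback label chosen by the user
```

With `y_train = ["positive", "positive", "positive", "negative"]`, an unusable model answer would fall back to "positive" about three times out of four.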

MultiLabelFewShotGPTClassifier

from skllm.models.gpt.classification.few_shot import MultiLabelFewShotGPTClassifier
| Parameter | Type | Description |
|---|---|---|
| model | str | Model to use, by default "gpt-3.5-turbo". |
| default_label | str | Default label for failed predictions; if "Random", a label is selected randomly based on class frequencies. By default "Random". |
| max_labels | Optional[int] | Maximum number of labels per sample, by default 5. |
| prompt_template | Optional[str] | Custom prompt template to use, by default None. |
| key | Optional[str] | Estimator-specific API key; if None, it is retrieved from the global config. By default None. |
| org | Optional[str] | Estimator-specific organization (ORG) key; if None, it is retrieved from the global config. By default None. |
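To illustrate what `max_labels` caps, here is a hypothetical post-processing step (`cap_labels` is not a Scikit-LLM function) that keeps only known labels from a model's answer, drops duplicates, and returns at most `max_labels` of them in the order the model gave them. How the library itself parses and validates multi-label output is not described on this page.

```python
def cap_labels(predicted, known_labels, max_labels=5):
    """Hypothetical: keep valid, de-duplicated labels, at most max_labels of them."""
    known = set(known_labels)
    result = []
    for label in predicted:
        if len(result) >= max_labels:
            break  # cap reached; ignore any further labels
        if label in known and label not in result:
            result.append(label)
    return result


labels = ["quality", "price", "delivery", "service"]
print(cap_labels(["quality", "price", "price", "unknown", "delivery"], labels, max_labels=2))
# → ['quality', 'price']
```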