Training¶
Trainer¶
Trainer(model, tokenizer, train_loader, test_loader, epochs=6, learning_rate=1e-05, save_path='best_model', device=None)
¶
Token classification trainer with epoch-based training and validation.
Trains a model using AdamW, evaluates on a test set after each epoch, and saves the best checkpoint based on hallucinated-class F1.
Initialize the trainer.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | The model to train | *required* |
| `tokenizer` | `PreTrainedTokenizer` | Tokenizer for the model | *required* |
| `train_loader` | `DataLoader` | DataLoader for training data | *required* |
| `test_loader` | `DataLoader` | DataLoader for test data | *required* |
| `epochs` | `int` | Number of training epochs | `6` |
| `learning_rate` | `float` | Learning rate for optimization | `1e-05` |
| `save_path` | `str` | Path to save the best model | `'best_model'` |
| `device` | `device \| None` | Device to train on (defaults to CUDA if available) | `None` |
train()
¶
Train the model.
Returns: Best F1 score achieved during training
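The best-checkpoint rule (keep the checkpoint with the highest hallucinated-class F1 across epochs) can be sketched in isolation. `select_best_epoch` is a hypothetical helper, not part of the library:

```python
def select_best_epoch(epoch_f1_scores):
    """Pick the epoch with the highest hallucinated-class F1.

    epoch_f1_scores: per-epoch validation F1 values, index = epoch.
    Returns (best_epoch, best_f1); this mirrors the rule the Trainer
    uses to decide which checkpoint survives at save_path.
    """
    best_epoch, best_f1 = -1, float("-inf")
    for epoch, f1 in enumerate(epoch_f1_scores):
        if f1 > best_f1:  # strictly better: earliest best epoch wins ties
            best_epoch, best_f1 = epoch, f1
    return best_epoch, best_f1
```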
Evaluator¶
evaluator
¶
evaluate_model(model, dataloader, device, verbose=True)
¶
Evaluate a model for hallucination detection.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | The model to evaluate. | *required* |
| `dataloader` | `DataLoader` | The data loader to use for evaluation. | *required* |
| `device` | `device` | The device to use for evaluation. | *required* |
| `verbose` | `bool` | If True, print the evaluation metrics. | `True` |

Returns:

| Type | Description |
|---|---|
| `dict[str, dict[str, float]]` | A dictionary containing the evaluation metrics: `{"supported": {"precision": float, "recall": float, "f1": float}, "hallucinated": {"precision": float, "recall": float, "f1": float}}` |
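A minimal sketch of consuming the returned structure; the numbers below are illustrative placeholders, not real output:

```python
# Illustrative metrics dict matching the documented return shape.
metrics = {
    "supported": {"precision": 0.91, "recall": 0.88, "f1": 0.89},
    "hallucinated": {"precision": 0.74, "recall": 0.69, "f1": 0.71},
}

# The Trainer's checkpoint-selection criterion is this value:
hallucinated_f1 = metrics["hallucinated"]["f1"]
```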
print_metrics(metrics)
¶
Print evaluation metrics in a readable format.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `metrics` | `dict[str, dict[str, float]]` | A dictionary containing the evaluation metrics. | *required* |

Returns:

| Type | Description |
|---|---|
| `None` | None |
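A rough, hypothetical formatter for the same nested dict shape (not the library's actual `print_metrics` implementation):

```python
def format_metrics(metrics):
    """Render the nested metrics dict as aligned text lines."""
    lines = []
    for cls, scores in metrics.items():
        row = "  ".join(f"{name}={value:.3f}" for name, value in scores.items())
        lines.append(f"{cls:>12}: {row}")
    return lines


def print_metrics_sketch(metrics):
    """Print the formatted lines, one class per line."""
    print("\n".join(format_metrics(metrics)))
```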
evaluate_model_example_level(model, dataloader, device, verbose=True)
¶
Evaluate a model for hallucination detection at the example level.
For each example, if any token is marked as hallucinated (label=1), then the whole example is considered hallucinated. Otherwise, it is supported.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | The model to evaluate. | *required* |
| `dataloader` | `DataLoader` | DataLoader providing the evaluation batches. | *required* |
| `device` | `device` | Device on which to perform evaluation. | *required* |
| `verbose` | `bool` | If True, prints a detailed classification report. | `True` |

Returns:

| Type | Description |
|---|---|
| `dict[str, dict[str, float]]` | A dict containing example-level metrics: `{"supported": {"precision": float, "recall": float, "f1": float}, "hallucinated": {"precision": float, "recall": float, "f1": float}}` |
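The token-to-example reduction described above can be sketched as a pure function. Treating `-100` as the ignore index for special/padding tokens is an assumption borrowed from the usual token-classification convention, not something these docs state:

```python
def example_label(token_labels):
    """Collapse per-token labels into one example-level label.

    Returns 1 (hallucinated) if any token label is 1, else 0 (supported).
    Labels of -100 (assumed ignore index for special/padding tokens)
    are skipped.
    """
    return int(any(label == 1 for label in token_labels if label != -100))
```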
create_sample_llm(sample, labels)
¶
Create a sample whose annotations / labels are based on the LLM responses.
evaluate_detector_char_level(detector, samples)
¶
Evaluate the HallucinationDetector at the character level.
This function assumes that each sample is a dictionary containing:

- "prompt": the prompt text.
- "answer": the answer text.
- "gold_spans": a list of dictionaries, each with "start" and "end" keys indicating the character indices of the gold (human-labeled) span.
It uses the detector (which should have been initialized with the appropriate model) to obtain predicted spans, compares those spans with the gold spans, and computes global precision, recall, and F1 based on character overlap.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `detector` | `HallucinationDetector` | The detector to evaluate. | *required* |
| `samples` | `list[HallucinationSample]` | A list of samples to evaluate. | *required* |

Returns:

| Type | Description |
|---|---|
| `dict[str, float]` | A dictionary with global metrics: `{"char_precision": ..., "char_recall": ..., "char_f1": ...}` |
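Character-overlap precision, recall, and F1 can be computed from the span lists alone. This is a sketch, not the library's implementation; it assumes "end" is exclusive, which the docs do not state:

```python
def char_overlap_metrics(pred_spans, gold_spans):
    """Global character-level metrics from predicted vs. gold spans.

    Each span is a dict with "start" and "end" character indices
    (end assumed exclusive).
    """
    def to_chars(spans):
        chars = set()
        for span in spans:
            chars.update(range(span["start"], span["end"]))
        return chars

    pred, gold = to_chars(pred_spans), to_chars(gold_spans)
    overlap = len(pred & gold)  # true-positive characters
    precision = overlap / len(pred) if pred else 0.0
    recall = overlap / len(gold) if gold else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0.0)
    return {"char_precision": precision, "char_recall": recall, "char_f1": f1}
```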
evaluate_detector_example_level_batch(detector, samples, batch_size=10, verbose=True)
¶
Evaluate the HallucinationDetector at the example level.
This function assumes that each sample is a dictionary containing:

- "prompt": the prompt text.
- "answer": the answer text.
- "gold_spans": a list of dictionaries, each with "start" and "end" keys indicating the character indices of the gold (human-labeled) span.
evaluate_detector_example_level(detector, samples, verbose=True)
¶
Evaluate the HallucinationDetector at the example level.
This function assumes that each sample is a dictionary containing:

- "prompt": the prompt text.
- "answer": the answer text.
- "gold_spans": a list of dictionaries, each with "start" and "end" keys indicating the character indices of the gold (human-labeled) span.
It uses the detector (which should have been initialized with the appropriate model), or reads predicted spans from the baseline file if one exists. If any span is predicted for an example, the whole example is marked as hallucinated (label = 1); otherwise, it is supported.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `detector` | `HallucinationDetector` | The detector to evaluate. | *required* |
| `samples` | `list[HallucinationSample]` | A list of samples to evaluate containing the ground-truth labels. | *required* |
| `samples_llm` | | A list of samples containing LLM-generated labels; used if the baseline file exists. | *required* |
| `baseline_file_exists` | | Indicates whether the baseline file exists or should be created. | *required* |

Returns:

| Type | Description |
|---|---|
| `dict[str, dict[str, float]]` | A dict containing example-level metrics: `{"supported": {"precision": float, "recall": float, "f1": float}, "hallucinated": {"precision": float, "recall": float, "f1": float}}` |