Encoder¶
An encoder-based line classifier for fast, discriminative tool-output extraction.
Model¶
SqueezEncoderConfig(base_model_name='jhu-clsp/mmBERT-base', encoder_config=None, vocab_size=None, num_labels=2, classifier_dropout=0.1, max_length=8192, **kwargs)
¶
Bases: PretrainedConfig
Configuration for SqueezEncoderForLineClassification.
SqueezEncoderForLineClassification(config)
¶
Bases: PreTrainedModel
Token-level binary classifier on top of an encoder (mmBERT/ModernBERT).
Designed to classify each token in the answer portion as relevant (1) or irrelevant (0). At inference time, per-line scores are computed by max-pooling token scores between [LINE_SEP] markers.
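The max-pooling step can be sketched as follows. `pool_line_scores` is an illustrative helper, not part of the library; it assumes per-token relevance scores and the positions of the [LINE_SEP] markers are already available:

```python
def pool_line_scores(token_scores, sep_positions):
    """Max-pool per-token relevance scores into per-line scores.

    token_scores: relevance probabilities for each token in the answer span.
    sep_positions: indices of [LINE_SEP] tokens delimiting the lines.
    """
    line_scores = []
    start = 0
    # Treat the end of the sequence as a final line boundary.
    for sep in list(sep_positions) + [len(token_scores)]:
        segment = token_scores[start:sep]
        # Empty lines get score 0.0 so line indices stay aligned.
        line_scores.append(max(segment) if segment else 0.0)
        start = sep + 1  # skip over the [LINE_SEP] token itself
    return line_scores
```

A line is then kept when its pooled score exceeds the threshold, so a single strongly relevant token is enough to retain its line.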
The encoder is created from the config (no pretrained weights) in __init__. Use from_encoder_pretrained to initialise with pretrained encoder weights for training, or from_pretrained to load a saved checkpoint.
from_encoder_pretrained(config)
classmethod
¶
Create a model and load pretrained encoder weights (for training).
This is the entry point for new training runs. For loading a
previously-saved checkpoint, use from_pretrained instead.
extract(task, tool_output, tokenizer, threshold=0.5, window_overlap=2)
¶
High-level inference: returns the list of relevant line strings.
Handles sliding-window chunking for tool outputs that exceed the model's maximum sequence length.
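A minimal sketch of the sliding-window idea, with windows measured in lines and adjacent windows sharing window_overlap lines (the helper name and signature here are hypothetical, not the library's internals):

```python
def chunk_lines(lines, window_size, window_overlap=2):
    """Split lines into overlapping windows so long tool outputs fit
    within the model's maximum sequence length; consecutive windows
    share `window_overlap` lines so boundary lines are scored in context."""
    if len(lines) <= window_size:
        return [lines]
    # Guard against a non-positive step if overlap >= window size.
    step = max(1, window_size - window_overlap)
    windows = []
    start = 0
    while start < len(lines):
        windows.append(lines[start:start + window_size])
        if start + window_size >= len(lines):
            break
        start += step
    return windows
```

Per-line scores from overlapping windows can then be merged (e.g. by taking the maximum) before applying the threshold.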
Dataset¶
LineClassificationDataset(data_path, tokenizer, max_length=8192, max_negative_ratio=None, seed=42)
¶
Bases: Dataset
PyTorch dataset for encoder-based line classification.
Each sample is tokenized into: [CLS] task [SEP] line_0 [LINE_SEP] line_1 [LINE_SEP] ... line_n [SEP]
With labels: -100 for [CLS], task tokens, [SEP], and [LINE_SEP]; 0 or 1 for each line token.
Samples whose tool output exceeds max_length are split into
overlapping windows so every line is supervised at least once.
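The label layout above can be sketched as follows; build_labels is a hypothetical helper (the real dataset operates on tokenizer output), but the assignment of -100 versus 0/1 follows the scheme described:

```python
IGNORE = -100  # ignored by the cross-entropy loss

def build_labels(task_len, line_token_lens, line_labels):
    """Assemble per-token labels for one sample laid out as:
    [CLS] task [SEP] line_0 [LINE_SEP] line_1 ... line_n [SEP]
    Special and task tokens get -100; every token of a line
    shares that line's binary relevance label."""
    labels = [IGNORE]                # [CLS]
    labels += [IGNORE] * task_len    # task tokens
    labels += [IGNORE]               # [SEP]
    for i, (n_tok, lab) in enumerate(zip(line_token_lens, line_labels)):
        labels += [lab] * n_tok      # this line's tokens
        if i < len(line_token_lens) - 1:
            labels += [IGNORE]       # [LINE_SEP] between lines
    labels += [IGNORE]               # final [SEP]
    return labels
```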
Evaluation¶
evaluate_encoder(model_path, eval_file, max_samples=None, threshold=0.5)
¶
Evaluate the encoder model on an eval set.
Args:
    model_path: Path to the trained encoder model
    eval_file: Path to the encoder-format JSONL eval file
    max_samples: Maximum number of samples to evaluate
    threshold: Relevance score threshold
Returns: Dict with aggregate metrics (same format as generative evaluate.py)
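For illustration, line-level aggregate metrics could be computed as below. This is a hedged sketch: the actual metric set mirrors the generative evaluate.py and may include more than precision/recall/F1.

```python
def line_metrics(pred_lines, gold_lines):
    """Precision/recall/F1 over predicted vs. gold relevant lines."""
    pred, gold = set(pred_lines), set(gold_lines)
    tp = len(pred & gold)  # lines both predicted and gold-relevant
    precision = tp / len(pred) if pred else 0.0
    recall = tp / len(gold) if gold else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}
```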