A Permutation-Invariant, Multi-Instance ML System for Ancient DNA Genotype Calling

Project Guide & Proposal

September 25, 2025

Executive Summary

Ancient DNA (aDNA) presents unique statistical and engineering challenges: ultra-short fragments, characteristic chemical damage (e.g., C→T substitutions near read ends), low and uneven coverage, and modern contamination. Classical tools such as ATLAS address these challenges with explicit (often hand-tuned) generative models and likelihood-based inference. While robust, these methods can be conservative at very low depth and may not optimally integrate the full spectrum of read-level cues (base quality, mapping quality, position relative to read ends, fragment length, strand, library prep).

This proposal reframes genotype calling for aDNA as a per-site, permutation-invariant, multi-instance, multi-output multiclass classification problem with calibrated probabilities. Each site (a reference coordinate) is a "bag" of reads; the model ingests read-level features and outputs a probability distribution over genotypes (e.g., 0/0, 0/1, 1/1) or over alleles (A/C/G/T for pseudo-haploid), plus optional auxiliary estimates (damage rates, contamination fraction, site usability).

The design emphasizes set-based inductive bias (order of reads must not matter), calibration (probabilities reflect reality at ultra-low coverage), interoperability (VCF/BCF outputs), and reproducibility (containerized pipelines, deterministic configs).

Motivation & Background

Why Ancient DNA is Hard

🧬
Short Fragments

DNA breaks into 30-70 bp pieces; mapping and pileup context are limited.

āš ļø
Chemical Damage

Deamination produces predictable substitution patterns near fragment ends (e.g., C→T at read 5' ends).

šŸ“Š
Low Coverage

Many sites have 0-1 reads; variance dominates; naive majority votes fail.

šŸ”¬
Contamination

Modern human DNA can mix with the ancient signal.

šŸ“
Reference Bias

Aligners and downstream callers may over-favor the reference allele when the evidence is ambiguous.

What Current Tools Do

Likelihood-based tools used on ancient and low-coverage data (e.g., ATLAS, ANGSD) estimate damage and contamination parameters and incorporate them into likelihoods for genotype inference. This proposal builds a learned inference module that:

  1. Ingests richer read-level features
  2. Learns non-linear interactions
  3. Produces calibrated genotype probabilities and auxiliary estimates
  4. Remains compatible with established downstream pipelines (VCF/BCF, pseudo-haploid VCF, QC JSON/TSV)

Why Not Sequence-to-Sequence?

Inputs are sets of reads (variable number, unordered) aligned to a fixed coordinate (a site). Outputs are per-site probability tables on a predefined grid (SNP panel or known variants), often with sparse evidence. Seq2seq assumes contiguous, length-coupled sequences; it is a poor inductive bias here. Set-based or local grid (pileup) models are more appropriate.
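The set-based inductive bias can be illustrated with a minimal sketch: encode each read independently, then combine the per-read vectors with a symmetric operation (here, a mean). The encoder, the feature names (`base`, `bq`, `mapq`), and the scaling constants are hypothetical placeholders for illustration, not the proposal's actual featurizer.

```python
import random

def encode_read(read):
    """Map one read's features to a small vector (hypothetical hand-written encoder)."""
    base_onehot = [1.0 if read["base"] == b else 0.0 for b in "ACGT"]
    return base_onehot + [read["bq"] / 40.0, read["mapq"] / 60.0]

def pool_site(reads):
    """Mean-pool per-read vectors; a mean is symmetric, so read order cannot matter."""
    if not reads:
        return [0.0] * 6
    vectors = [encode_read(r) for r in reads]
    return [sum(col) / len(vectors) for col in zip(*vectors)]

reads = [
    {"base": "A", "bq": 35, "mapq": 60},
    {"base": "G", "bq": 20, "mapq": 37},
    {"base": "A", "bq": 30, "mapq": 60},
]
shuffled = reads[:]
random.Random(0).shuffle(shuffled)

site_vector = pool_site(reads)
shuffled_vector = pool_site(shuffled)
# The two vectors agree up to floating-point rounding, whatever the shuffle produced
same = all(abs(a - b) < 1e-12 for a, b in zip(site_vector, shuffled_vector))
print(same)
```

A learned model would replace the hand-written encoder with a network and may replace the mean with attention pooling, but the same property should hold: permuting the reads leaves the site-level output unchanged.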

Problem Formulation

Definitions

Reference build
A known genome (e.g., GRCh37/38)
Site
Tuple (chrom, pos, ref[, alt]) on the reference; typically biallelic SNPs
Callable site
Site with sufficient usable evidence (e.g., ≄1 read for pseudo-haploid; BQ/MAPQ thresholds; outside blacklists)
Read-level features
Observed base, base quality, mapping quality, position from read ends, strand, fragment length, duplicate flag, library/UDG flags
Pileup tensor
Fixed-shape representation of the set of reads at a site (e.g., top-K reads Ɨ feature channels Ɨ local context)

Tasks

Diploid genotype calling: 3-class classification over {0/0, 0/1, 1/1}, reported as per-genotype probabilities (exported as genotype likelihoods)

Pseudo-haploid allele calling: 4-class classification over {A, C, G, T} (optionally with a no-call) at a site

Auxiliary tasks: continuous regression for damage rates and contamination fraction; binary classification for site usability

Objectives

Data Strategy

Site Sets

  1. Curated SNP panels (e.g., ~1.24M sites): best for low coverage; high mappability; mostly biallelic
  2. Known variant catalogs (e.g., biallelic SNPs from public datasets): broader WGS coverage; apply strict mappability filters
  3. Region-based sites (all bases in a BED): exhaustive but imbalanced; optional later milestone

Real vs. Simulated Data

Real aDNA

Invaluable for final validation (ground truth limited; use consensus across tools/labs)

Simulated Data

Modern truth sets provide high-coverage genotypes to downsample & corrupt to mimic aDNA. Simulation from modern truth genomes should generate reads with:

  • Realistic length distributions
  • End-biased deamination
  • Sequencing noise
  • Mapping ambiguity
  • Contamination mixtures
  • Varied library/UDG flags
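As a rough illustration of the end-biased deamination item, the sketch below corrupts a fragment with C→T changes whose probability decays away from the 5' end and G→A changes whose probability decays away from the 3' end (the pattern typically seen in double-stranded libraries). The function names, the exponential decay form, and the parameter values (`p_max`, `decay`, `p_floor`) are assumptions chosen for illustration; a real simulator would fit them to observed damage profiles.

```python
import math
import random

def deamination_prob(dist_from_end, p_max=0.3, decay=0.25, p_floor=0.01):
    """Misincorporation probability that decays exponentially with distance from a read end."""
    return p_floor + (p_max - p_floor) * math.exp(-decay * dist_from_end)

def damage_fragment(seq, rng, **kwargs):
    """Apply C->T near the 5' end and G->A near the 3' end of one fragment."""
    out = []
    n = len(seq)
    for i, base in enumerate(seq):
        if base == "C" and rng.random() < deamination_prob(i, **kwargs):
            out.append("T")          # deaminated cytosine read as thymine
        elif base == "G" and rng.random() < deamination_prob(n - 1 - i, **kwargs):
            out.append("A")          # complementary-strand damage seen as G->A
        else:
            out.append(base)
    return "".join(out)

rng = random.Random(42)              # fixed seed for reproducible simulations
fragment = "CCGATTACAGGCTTACCGGA"
damaged = damage_fragment(fragment, rng)
print(fragment)
print(damaged)
```

Sweeping these parameters per simulated library (and setting them near zero for UDG-treated libraries) is one way to cover the range of damage regimes the model should see during training.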

Labels & Splits

Diploid task uses true genotypes at each site; pseudo-haploid task derives single-allele targets from diploid truth. Use sample-level splits; avoid leakage across replicates/libraries; stratify by coverage/damage regimes.
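The labeling and splitting rules can be sketched as follows. Two choices here are assumptions rather than details stated in this proposal: heterozygous sites yield a pseudo-haploid target by drawing one of the two alleles at random, and samples are assigned to splits by hashing the sample ID so that every library and replicate of a sample lands in the same split. The function names are hypothetical.

```python
import random
import zlib

def pseudo_haploid_target(genotype, ref, alt, rng):
    """Derive a single-allele target from a diploid truth genotype such as '0/1'."""
    alleles = [ref if a == "0" else alt for a in genotype.replace("|", "/").split("/")]
    # Homozygous sites give the same allele either way; hets are a random draw
    return rng.choice(alleles)

def split_for_sample(sample_id, val_fraction=0.2):
    """Assign a whole sample to 'train' or 'val' from a stable hash of its ID."""
    bucket = zlib.crc32(sample_id.encode("utf-8")) % 100
    return "val" if bucket < val_fraction * 100 else "train"

rng = random.Random(7)
print(pseudo_haploid_target("0/0", "G", "A", rng))   # always the REF allele
print(pseudo_haploid_target("0/1", "G", "A", rng))   # either allele
print(split_for_sample("sample_001"))
```

Keying the split on the sample ID rather than on individual files or libraries is what prevents leakage across replicates; stratification by coverage or damage regime can then be applied within each split.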

Preprocessing & Callability

Filters
  • MAPQ ≄ 30, BQ ≄ 30
  • Exclude soft-clips and adjacent indel contexts
  • Remove blacklisted/multimapping regions
  • Optional CpG masking

If no read passes, write a no-call and log the reason.
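A minimal sketch of the callability decision, assuming each read has already been reduced to a dictionary with `mapq`, `bq`, and an optional `soft_clipped` flag. These field names and the reason strings are hypothetical, and the thresholds mirror the filter list above.

```python
def check_callability(reads, min_mapq=30, min_bq=30, in_blacklist=False):
    """Return (usable_reads, reason); reason is None when the site is callable."""
    if in_blacklist:
        return [], "blacklisted_region"
    if not reads:
        return [], "no_coverage"
    usable = [
        r for r in reads
        if r["mapq"] >= min_mapq
        and r["bq"] >= min_bq
        and not r.get("soft_clipped", False)
    ]
    if not usable:
        return [], "all_reads_filtered"
    return usable, None

reads = [
    {"mapq": 60, "bq": 35},
    {"mapq": 12, "bq": 38},                         # fails MAPQ
    {"mapq": 60, "bq": 36, "soft_clipped": True},   # excluded as soft-clipped
]
usable, reason = check_callability(reads)
print(len(usable), reason)
```

Returning the reason alongside the reads makes it straightforward to write the no-call and to tabulate why sites were dropped in the QC report.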

Model Architecture

High-Level Pipeline

Inputs: overlapping reads at site (unordered set)
→ Featurizer: pileup tensor / read tokens
→ Permutation-invariant model: CNN on pileup or Transformer
→ Outputs: genotype probs, pseudo-haploid calls, VCF/BCF & QC

Representations

Pileup "Image" (CNN)

Tensor shape (R_max, W, C) with top-K reads, window width W (e.g., ±16 bp), and channels C including:

Base Features
  • Base one-hot encoding
  • Base quality (BQ)
  • Mapping quality (MAPQ)
Position Features
  • Read-end distance
  • Strand information
  • Fragment-length bin
Metadata Features
  • Duplicate flag
  • Mappability score
  • CpG flag
  • Ref/alt markers

Transformer over Reads (Set Model)

Each read is a token with features; add a [SITE] token to pool. Positional encodings reflect distance from read ends (damage signal), not genomic order. Attention summarizes variable numbers of reads; outputs site-level logits and auxiliary heads.

Hybrid Architecture

Transformer over read tokens + shallow CNN over a small reference context window; fuse via cross-attention or concatenation.

Heads & Outputs

3 diploid classes (0/0, 0/1, 1/1)
4 pseudo-haploid classes (A, C, G, T)
2 regression heads (damage, contamination)
1 binary classification head (usability)

Calibration Methods

Regularization & Stability

Training Plan

Losses

Primary

Cross-entropy/NLL on genotype (diploid) or allele (pseudo-haploid)

Auxiliary

MSE/Huber for damage/contamination; BCE for usability

Composite

Weighted sum of all objectives
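The composite objective can be sketched for a single site in plain Python. The weight values and the dictionary layout are assumptions for illustration (the proposal does not fix them), and a real training loop would use the batched PyTorch equivalents of these functions.

```python
import math

def cross_entropy(probs, target_idx, eps=1e-12):
    """Negative log-probability assigned to the true class."""
    return -math.log(max(probs[target_idx], eps))

def huber(pred, target, delta=1.0):
    """Quadratic near zero, linear for large errors (robust regression loss)."""
    err = abs(pred - target)
    return 0.5 * err * err if err <= delta else delta * (err - 0.5 * delta)

def binary_cross_entropy(p, label, eps=1e-12):
    p = min(max(p, eps), 1.0 - eps)
    return -(label * math.log(p) + (1 - label) * math.log(1.0 - p))

def composite_loss(pred, truth, weights):
    """Weighted sum of the primary and auxiliary losses; also returns each term."""
    terms = {
        "genotype": cross_entropy(pred["genotype"], truth["genotype"]),
        "damage": huber(pred["damage"], truth["damage"]),
        "contamination": huber(pred["contamination"], truth["contamination"]),
        "usability": binary_cross_entropy(pred["usability"], truth["usability"]),
    }
    total = sum(weights[name] * value for name, value in terms.items())
    return total, terms

# Illustrative weights only: the primary task dominates, auxiliaries act as regularizers
weights = {"genotype": 1.0, "damage": 0.1, "contamination": 0.1, "usability": 0.1}
pred = {"genotype": [0.2, 0.7, 0.1], "damage": 0.25, "contamination": 0.05, "usability": 0.9}
truth = {"genotype": 1, "damage": 0.30, "contamination": 0.02, "usability": 1}
total, terms = composite_loss(pred, truth, weights)
print(round(total, 4))
```

Returning the individual terms alongside the total makes it easier to log each objective separately and to notice when an auxiliary loss starts to dominate.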

Curriculum & Schedules

Start with moderate-coverage simulated data (~1Ɨ), then mix lower coverage (0.05Ɨ-0.5Ɨ).

Hyperparameters (Starting Points)

  • Transformer: 4-6 layers, 8 heads
  • d_model: 256-512
  • FFN: 1024-2048
  • Dropout: 0.1
  • CNN: 3-6 conv blocks (3Ɨ3)
  • CNN channels: 64→256
  • Activation: GELU
  • Normalization: LayerNorm
  • R_max: 32-64
  • W: ±16-32 bp
  • Optional 4/8-bit quantization for BQ/MAPQ channels

Infrastructure

Evaluation & Benchmarks

Core Metrics

Classification
Accuracy, F1, Non-ref sensitivity
Calibration
Brier, ECE/ACE, Reliability
Bias
REF/ALT balance, Allele distribution
Auxiliary
RMSE/MAE for estimates
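The two headline calibration metrics can be sketched in plain Python as below. This is an illustrative implementation under the usual definitions (multiclass Brier score; ECE with equal-width confidence bins), not the project's evaluation code; the function names and the choice of ten bins are assumptions.

```python
def brier_score(probs, labels):
    """Mean squared distance between each probability vector and the one-hot truth."""
    total = 0.0
    for p, y in zip(probs, labels):
        total += sum((pk - (1.0 if k == y else 0.0)) ** 2 for k, pk in enumerate(p))
    return total / len(labels)

def expected_calibration_error(probs, labels, n_bins=10):
    """Weighted mean of |accuracy - confidence| over equal-width confidence bins."""
    bins = [[] for _ in range(n_bins)]
    for p, y in zip(probs, labels):
        conf = max(p)
        pred = p.index(conf)
        idx = min(int(conf * n_bins), n_bins - 1)   # confidence 1.0 goes in the last bin
        bins[idx].append((conf, 1.0 if pred == y else 0.0))
    ece = 0.0
    for members in bins:
        if members:
            avg_conf = sum(c for c, _ in members) / len(members)
            acc = sum(a for _, a in members) / len(members)
            ece += abs(avg_conf - acc) * len(members) / len(labels)
    return ece

# Two sites predicted with 90% confidence, only one of them correct
probs = [[0.9, 0.1, 0.0], [0.9, 0.1, 0.0]]
labels = [0, 1]
print(brier_score(probs, labels), expected_calibration_error(probs, labels))
```

Computing both metrics separately within each coverage bin, rather than once over all sites, is what exposes calibration failures that are specific to the ultra-low-coverage regime.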

Stratified Analyses

Coverage Bins

0-1Ɨ
1-2Ɨ
2-5Ɨ
>5Ɨ

Other Stratifications

Downstream Validations

Ablations

Software & Pipeline Design

Components

šŸ”§
Featurizer

Inputs: BAM/CRAM + reference + site list
Outputs: per-site tensors (NPZ/Parquet) & metadata
Applies filters and masks

šŸŽÆ
Trainer

Loads tensors, trains model, logs metrics, checkpoints
Supports DDP and mixed precision

šŸ“Š
Calibrator

Fits temperature/Dirichlet/isotonic
Exports calibration parameters

🧬
Caller (Inference CLI)

Inputs: BAM/CRAM + reference + site list + model + calibration
Outputs: VCF/BCF with GT/PL/GL/DP/AD or pseudo-haploid VCF, QC JSON/TSV

šŸ“ˆ
QC & Reports

Plots: fragment length, damage profiles, calibration curves
Tables: coverage, callability, error by bin

šŸ“¦
Packaging & Distribution

Docker images, Conda package, PyPI wheel (optional)
WDL/Nextflow/Cromwell wrappers

Configuration & Reproducibility

Example Code

# 1) Load site list
sites = load_sites("sites.vcf_or_bed")  # [(chrom, pos, ref, alt)]

# 2) Optional: filter to good regions
sites = [s for s in sites if is_mappable(s) and not in_blacklist(s)]

# 3) For each sample, decide callability & build inputs
for s in sites:
    reads = fetch_reads(bam, s.chrom, s.pos, window=33)
    bases = filter_bases(reads, min_mapq=30, min_bq=30, snp_only=True)

    if len(bases) == 0:
        write_no_call(s)
        continue

    X = build_pileup_tensor(bases, R_max=48, W=33, channels=spec)
    logits = model(X)  # or model(read_tokens)
    probs = calibrate(logits)  # temperature scaling / isotonic
    write_vcf_record(s, probs, depths=bases)

Milestones & Timeline

M0: Project Setup (2 weeks)
Repo, containers, featurizer skeleton, unit tests

M1: Simulation v1 (2 weeks)
aDNA simulator, modern truth ingestion

M2: v0 Model (3 weeks)
Pseudo-haploid CNN, featurization, training loop

M3: Eval v0 (2 weeks)
Calibration, coverage-bin metrics, baselines

M4: v1 Model (3 weeks)
Transformer over reads, damage head, REF/ALT aug

M5: Diploid GLs (3 weeks)
3-class head, VCF writer, temp scaling

M6: Real Data (2 weeks)
Small real aDNA eval, downstream checks

M7: Hardening (2 weeks)
CLI/docs/workflows/CI, reproducibility audits

M8: Report (1-2 weeks)
Final benchmarks, ablations, user guide

Total Duration: 20-21 weeks (sum of the milestone estimates above)

Risks & Mitigations

Risk Mitigation Strategies
Simulation-Reality Gap
  • Diverse parameter sweeps
  • Mixtures of simulations
  • Real holdouts
Reference Bias
  • REF/ALT flipping augmentation
  • Allele-balance features
  • Bias-aware metrics
Calibration Drift
  • Coverage-conditioned calibration
  • Per-regime temperature scaling
Data Governance
  • Use public/approved datasets
  • Strict access controls
  • Anonymization
  • IRB/ethics compliance
Compute/Costs
  • Focus on panels for v1
  • Cache tensors
  • Mixed precision
  • Efficient batching
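The REF/ALT flipping augmentation listed under Reference Bias can be sketched as a label-and-feature swap on one training example. The example layout (`ref`, `alt`, a per-read `read_is_alt` flag, and a genotype class index) is hypothetical, and the sketch only relabels the existing evidence; it does not re-align reads against an alternate reference.

```python
# Class indices for 0/0, 0/1, 1/1: swapping REF and ALT exchanges the two homozygotes
FLIPPED_GENOTYPE = {0: 2, 1: 1, 2: 0}

def flip_ref_alt(example):
    """Return a copy of a site example with the REF and ALT roles exchanged."""
    return {
        "ref": example["alt"],
        "alt": example["ref"],
        # Per-read flag: 1 = read shows ALT, 0 = read shows REF, None = neither allele
        "read_is_alt": [None if f is None else 1 - f for f in example["read_is_alt"]],
        "genotype": FLIPPED_GENOTYPE[example["genotype"]],
    }

example = {"ref": "G", "alt": "A", "read_is_alt": [0, 1, None], "genotype": 0}
flipped = flip_ref_alt(example)
print(flipped)
```

Training on both orientations with equal probability gives the model no incentive to learn that the allele labeled REF is more likely a priori, which is the behavior the bias-aware metrics are meant to detect.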

Ethical, Legal & Social Considerations

Human Subjects & Ancestry

Privacy & Data Use

Transparency

Integration & Outputs

Outputs & Formats

VCF/BCF (diploid)

GT, PL/GL, DP, AD, GQ

Pseudo-haploid VCF

Single allele calls for low coverage

QC Reports

JSON/TSV format

Plots

PDF/PNG visualizations

BED masks

Filtered regions

Logs/provenance

Configs, git/container hashes

Interoperability

Population Genetics Tools

Imputation

Supply GL/PL to off-the-shelf imputation tools

Resources & Budget (Indicative)

Personnel

Compute

Software

Deliverables

  1. ML caller (binary + library) with CLI:
    • Inputs: BAM/CRAM, reference, site list, config
    • Outputs: VCF/BCF, QC, plots, logs
  2. Training toolkit:
    • Simulator & downsampler
    • Tensorizer
    • Training scripts
    • Calibration tool
  3. Pipelines & configs:
    • Nextflow/WDL
    • Docker/Conda
    • Example YAMLs
  4. Documentation:
    • User & developer guides
    • API docs
    • Tutorial notebooks
  5. Benchmark report:
    • Coverage/damage-bin metrics
    • Ablations
    • Calibration diagnostics
    • Downstream analyses

Glossary (Beginner-Friendly)

Site
Specific location in the reference genome (chromosome + position)
Read
Short DNA fragment sequenced from the sample
Pileup
Stack of reads overlapping a site
MAPQ
Mapping quality (confidence a read is placed correctly)
BQ
Base quality (confidence in the called letter within a read)
Pseudo-haploid
One allele per site; used when coverage is too low for diploid genotypes
Calibration
Making predicted probabilities match observed frequencies
Biallelic SNP
Variant with only one alternate allele besides the reference
UDG
Uracil-DNA glycosylase treatment; reduces certain damage patterns

Appendix

Minimal Pileup Tensor Specification

Inputs per site:

  • Top-R_max reads by quality (e.g., 48)
  • Reference window width W (e.g., 33 bp = ±16)
  • Channels C:
    • Base one-hot encoding
    • Bucketized BQ/MAPQ
    • Read-end distance
    • Strand
    • Fragment-length bin
    • Duplicate flag
    • Mappability
    • CpG flag
    • Ref/alt markers

Shape: (R_max, W, C)

Transformer variant: uses one token per read plus a [SITE] token

Example VCF Lines (Annotated)

Diploid (biallelic) VCF example (simplified):

#CHROM  POS     ID      REF     ALT     QUAL    FILTER  INFO    FORMAT  SAMPLE1
chr1    10583   .       G       A       .       PASS    DP=3    GT:PL:DP:AD     0/1:10,0,12:3:2,1

Pseudo-haploid VCF example:

#CHROM  POS     ID      REF     ALT     QUAL    FILTER  INFO    FORMAT  SAMPLE1
chr1    10583   .       G       A       .       PASS    DP=1    GT:GP   1:0.70,0.12,0.18,0.00

Pseudocode: Featurize → Predict → Write

# 1) Load site list
sites = load_sites("sites.vcf_or_bed")  # [(chrom, pos, ref, alt)]

# 2) Optional: filter to good regions
sites = [s for s in sites if is_mappable(s) and not in_blacklist(s)]

# 3) For each sample, decide callability & build inputs
for s in sites:
    reads = fetch_reads(bam, s.chrom, s.pos, window=33)
    bases = filter_bases(reads, min_mapq=30, min_bq=30, snp_only=True)
    
    if len(bases) == 0:
        write_no_call(s)
        continue
    
    X = build_pileup_tensor(bases, R_max=48, W=33, channels=spec)
    logits = model(X)  # or model(read_tokens)
    probs = calibrate(logits)  # temperature scaling / isotonic
    write_vcf_record(s, probs, depths=bases)

Extended Implementation Details

Featurization Pipeline (Detailed)

import numpy as np

class PileupFeaturizer:
    def __init__(self, config):
        self.max_reads = config.max_reads      # e.g., 48
        self.window_size = config.window       # e.g., 33
        self.min_mapq = config.min_mapq        # e.g., 30
        self.min_bq = config.min_bq            # e.g., 30
        self.n_channels = 11                   # channels 0-10 written by _build_tensor
    
    def featurize_site(self, bam, chrom, pos, ref, alt=None):
        """Extract features for a single genomic site."""
        # Fetch reads overlapping the site
        reads = []
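        for pileupcolumn in bam.pileup(chrom, pos-1, pos, truncate=True):
            if pileupcolumn.pos == pos-1:
                for pileupread in pileupcolumn.pileups:
                    if pileupread.is_del or pileupread.is_refskip:
                        continue
                    # Apply the configured MAPQ/BQ thresholds
                    aln = pileupread.alignment
                    if aln.mapping_quality < self.min_mapq:
                        continue
                    if aln.query_qualities[pileupread.query_position] < self.min_bq:
                        continue
                    reads.append(self._extract_read_features(pileupread))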
        
        # Sort by quality and truncate
        reads = sorted(reads, key=lambda x: x['qual'], reverse=True)[:self.max_reads]
        
        # Build tensor
        return self._build_tensor(reads, ref, alt)
    
    def _extract_read_features(self, pileupread):
        """Extract features from a single read."""
        alignment = pileupread.alignment
        return {
            'base': alignment.query_sequence[pileupread.query_position],
            'qual': alignment.query_qualities[pileupread.query_position],
            'mapq': alignment.mapping_quality,
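            # query_position counts from the left end of the alignment; for reverse-strand
            # reads that is the 3' end of the sequenced molecule, so swap the two distances
            'pos_from_5p': (len(alignment.query_sequence) - pileupread.query_position - 1
                            if alignment.is_reverse else pileupread.query_position),
            'pos_from_3p': (pileupread.query_position if alignment.is_reverse
                            else len(alignment.query_sequence) - pileupread.query_position - 1),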
            'strand': alignment.is_reverse,
            'frag_len': abs(alignment.template_length),
            'is_duplicate': alignment.is_duplicate,
            'library': alignment.get_tag('RG') if alignment.has_tag('RG') else 'unknown'
        }
    
    def _build_tensor(self, reads, ref, alt):
        """Convert read features to tensor format."""
        # Initialize tensor with zeros (float32 to match PyTorch's default parameter dtype)
        tensor = np.zeros((self.max_reads, self.window_size, self.n_channels), dtype=np.float32)
        
        for i, read in enumerate(reads):
            # One-hot encode base
            base_idx = {'A': 0, 'C': 1, 'G': 2, 'T': 3}.get(read['base'], 4)
            tensor[i, self.window_size//2, base_idx] = 1
            
            # Add quality scores
            tensor[i, self.window_size//2, 5] = read['qual'] / 40.0
            tensor[i, self.window_size//2, 6] = read['mapq'] / 60.0
            
            # Add position features
            tensor[i, self.window_size//2, 7] = read['pos_from_5p'] / 100.0
            tensor[i, self.window_size//2, 8] = read['pos_from_3p'] / 100.0
            
            # Add binary flags
            tensor[i, self.window_size//2, 9] = float(read['strand'])
            tensor[i, self.window_size//2, 10] = float(read['is_duplicate'])
        
        return tensor

Model Architecture Implementation

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class TransformerGenotypeCaller(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.d_model = config.d_model
        self.n_heads = config.n_heads
        self.n_layers = config.n_layers
        
        # Input embedding
        self.read_embedding = nn.Linear(config.n_features, self.d_model)
        self.site_token = nn.Parameter(torch.randn(1, 1, self.d_model))
        
        # Positional encoding for damage signal
        self.pos_encoding = DamagePositionalEncoding(self.d_model)
        
        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.d_model,
            nhead=self.n_heads,
            dim_feedforward=config.d_ff,
            dropout=config.dropout,
            batch_first=True  # inputs are (batch, tokens, d_model)
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, self.n_layers)
        
        # Output heads
        self.genotype_head = nn.Linear(self.d_model, 3)        # 0/0, 0/1, 1/1
        self.pseudo_haploid_head = nn.Linear(self.d_model, 4)  # A, C, G, T
        self.damage_head = nn.Linear(self.d_model, 2)          # C->T, G->A rates
        self.contamination_head = nn.Linear(self.d_model, 1)   # contamination %
        
        # Calibration temperature
        self.temperature = nn.Parameter(torch.ones(1))
    
    def forward(self, read_features, read_mask=None):
        batch_size = read_features.size(0)
        
        # Embed reads
        read_embeds = self.read_embedding(read_features)
        
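        # Add positional encoding to the read tokens only
        # (the site token has no read-end distance)
        read_embeds = self.pos_encoding(read_embeds, read_features)
        
        # Prepend site token
        site_token = self.site_token.expand(batch_size, 1, -1)
        x = torch.cat([site_token, read_embeds], dim=1)
        
        # Transform; extend the padding mask with a never-masked column for the site token
        if read_mask is not None:
            site_mask = torch.zeros(batch_size, 1, dtype=torch.bool, device=read_mask.device)
            read_mask = torch.cat([site_mask, read_mask], dim=1)
        x = self.transformer(x, src_key_padding_mask=read_mask)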
        
        # Extract site representation (first token)
        site_repr = x[:, 0, :]
        
        # Compute outputs
        genotype_logits = self.genotype_head(site_repr) / self.temperature
        pseudo_haploid_logits = self.pseudo_haploid_head(site_repr)
        damage_estimates = torch.sigmoid(self.damage_head(site_repr))
        contamination = torch.sigmoid(self.contamination_head(site_repr))
        
        return {
            'genotype': F.softmax(genotype_logits, dim=-1),
            'pseudo_haploid': F.softmax(pseudo_haploid_logits, dim=-1),
            'damage': damage_estimates,
            'contamination': contamination
        }

class DamagePositionalEncoding(nn.Module):
    """Encodes position from read ends (damage signal)."""
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
    
    def forward(self, x, read_features):
        # Extract distance from 5' and 3' ends
        dist_5p = read_features[..., 7:8]  # Assuming this is the position
        dist_3p = read_features[..., 8:9]
        
        # Create sinusoidal encoding based on end distances
        pe = torch.zeros_like(x)
        position = torch.cat([dist_5p, dist_3p], dim=-1)
        div_term = torch.exp(torch.arange(0, self.d_model, 2, device=x.device) * 
                           -(math.log(10000.0) / self.d_model))
        pe[..., 0::2] = torch.sin(position[..., 0:1] * div_term)
        pe[..., 1::2] = torch.cos(position[..., 1:2] * div_term)
        
        return x + pe

Training Loop

from tqdm import tqdm

def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    correct_predictions = 0
    total_samples = 0
    
    for batch in tqdm(dataloader):
        # Move to device
        features = batch['features'].to(device)
        genotypes = batch['genotype'].to(device)
        coverage = batch['coverage'].to(device)
        
        # Forward pass
        outputs = model(features)
        
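        # Compute losses. The model returns softmax probabilities, so take logs before
        # a logit-based criterion (cross-entropy of log-probabilities equals the NLL);
        # criterion['genotype'] needs reduction='none' so per-site weights can be applied
        genotype_log_probs = torch.log(outputs['genotype'].clamp_min(1e-12))
        genotype_loss = criterion['genotype'](genotype_log_probs, genotypes)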
        
        # Coverage-aware weighting
        weights = compute_coverage_weights(coverage)
        weighted_loss = (genotype_loss * weights).mean()
        
        # Add auxiliary losses
        if 'damage' in batch:
            damage_loss = F.mse_loss(outputs['damage'], batch['damage'].to(device))
            weighted_loss += 0.1 * damage_loss
        
        # Backward pass
        optimizer.zero_grad()
        weighted_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        
        # Track metrics
        total_loss += weighted_loss.item()
        predictions = outputs['genotype'].argmax(dim=-1)
        correct_predictions += (predictions == genotypes).sum().item()
        total_samples += genotypes.size(0)
    
    return {
        'loss': total_loss / len(dataloader),
        'accuracy': correct_predictions / total_samples
    }

def compute_coverage_weights(coverage):
    """Weight samples by inverse coverage to balance training."""
    # Higher weight for low-coverage samples
    weights = 1.0 / (1.0 + torch.log1p(coverage))
    return weights / weights.mean()

Calibration Module

class TemperatureScaling(nn.Module):
    """Temperature scaling for calibration."""
    def __init__(self, n_bins=10):
        super().__init__()
        self.temperature = nn.Parameter(torch.ones(1))
        self.n_bins = n_bins
    
    def fit(self, logits, labels, coverage=None):
        """Fit temperature on validation set."""
        self.temperature.requires_grad = True
        optimizer = torch.optim.LBFGS([self.temperature], lr=0.01, max_iter=50)
        
        def eval_loss():
            optimizer.zero_grad()
            scaled_logits = logits / self.temperature
            loss = F.cross_entropy(scaled_logits, labels)
            loss.backward()
            return loss
        
        optimizer.step(eval_loss)
        self.temperature.requires_grad = False
    
    def forward(self, logits):
        return logits / self.temperature
    
    def compute_ece(self, probs, labels):
        """Expected Calibration Error."""
        confidences, predictions = probs.max(dim=1)
        accuracies = predictions.eq(labels)
        
        ece = 0
        for bin_idx in range(self.n_bins):
            bin_lower = bin_idx / self.n_bins
            bin_upper = (bin_idx + 1) / self.n_bins
            in_bin = (confidences > bin_lower) & (confidences <= bin_upper)
            prop_in_bin = in_bin.float().mean()
            
            if prop_in_bin > 0:
                accuracy_in_bin = accuracies[in_bin].float().mean()
                avg_confidence_in_bin = confidences[in_bin].mean()
                ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin
        
        return ece

VCF Writer

import pysam
from datetime import datetime

class VCFWriter:
    def __init__(self, output_path, reference, samples, mode='diploid'):
        self.output_path = output_path
        self.reference = reference
        self.samples = samples
        self.mode = mode
        
        # Create VCF header
        self.header = pysam.VariantHeader()
        
        # Add metadata
        self.header.add_meta('fileformat', 'VCFv4.3')
        self.header.add_meta('fileDate', datetime.now().strftime('%Y%m%d'))
        self.header.add_meta('source', 'aDNA_ML_Genotyper_v1.0')
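        # Header values must be strings; record the FASTA path, not the file object
        ref_path = reference.filename
        if isinstance(ref_path, bytes):
            ref_path = ref_path.decode()
        self.header.add_meta('reference', ref_path)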
        
        # Add INFO fields
        self.header.add_meta('INFO', items=[
            ('ID', 'DP'),
            ('Number', 1),
            ('Type', 'Integer'),
            ('Description', 'Total depth')
        ])
        
        # Add FORMAT fields
        if self.mode == 'diploid':
            self.header.add_meta('FORMAT', items=[
                ('ID', 'GT'),
                ('Number', 1),
                ('Type', 'String'),
                ('Description', 'Genotype')
            ])
            self.header.add_meta('FORMAT', items=[
                ('ID', 'PL'),
                ('Number', 'G'),
                ('Type', 'Integer'),
                ('Description', 'Phred-scaled genotype likelihoods')
            ])
        else:  # pseudo-haploid
            self.header.add_meta('FORMAT', items=[
                ('ID', 'GT'),
                ('Number', 1),
                ('Type', 'String'),
                ('Description', 'Pseudo-haploid genotype')
            ])
            self.header.add_meta('FORMAT', items=[
                ('ID', 'GP'),
                ('Number', 4),
                ('Type', 'Float'),
                ('Description', 'Genotype probabilities for A,C,G,T')
            ])
        
        # Add samples
        for sample in samples:
            self.header.add_sample(sample)
        
        # Add contigs
        for contig in self.reference.references:
            self.header.add_meta('contig', items=[
                ('ID', contig),
                ('length', self.reference.get_reference_length(contig))
            ])
        
        # Open VCF for writing
        self.vcf = pysam.VariantFile(self.output_path, 'w', header=self.header)
    
    def write_variant(self, chrom, pos, ref, alt, sample_data):
        """Write a single variant to VCF."""
        record = self.vcf.new_record()
        record.chrom = chrom
        record.pos = pos
        record.ref = ref
        record.alts = (alt,) if alt else None
        
        # Add INFO fields
        total_depth = sum(data.get('DP', 0) for data in sample_data.values())
        record.info['DP'] = total_depth
        
        # Add sample data
        for sample, data in sample_data.items():
            if self.mode == 'diploid':
                record.samples[sample]['GT'] = data['GT']
                record.samples[sample]['PL'] = data['PL']
            else:
                record.samples[sample]['GT'] = data['GT']
                record.samples[sample]['GP'] = data['GP']
        
        self.vcf.write(record)
    
    def close(self):
        self.vcf.close()

Main Inference Pipeline

import argparse

def run_inference(args):
    """Main inference pipeline."""
    # Load configuration
    config = load_config(args.config)
    
    # Initialize model
    model = TransformerGenotypeCaller(config.model)
    model.load_state_dict(torch.load(args.model_checkpoint, map_location=args.device))
    model.eval()
    model = model.to(args.device)
    
    # Load calibration
    calibrator = TemperatureScaling()
    calibrator.load_state_dict(torch.load(args.calibration_checkpoint))
    
    # Initialize featurizer
    featurizer = PileupFeaturizer(config.featurization)
    
    # Open BAM and reference
    bam = pysam.AlignmentFile(args.bam, 'rb')
    reference = pysam.FastaFile(args.reference)
    
    # Load sites
    sites = load_sites(args.sites)
    
    # Initialize VCF writer
    vcf_writer = VCFWriter(
        args.output, reference, [args.sample_name], mode=args.mode
    )
    
    # Process sites
    for site in tqdm(sites, desc="Processing sites"):
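        # Featurize
        features = featurizer.featurize_site(
            bam, site.chrom, site.pos, site.ref, site.alt
        )
        
        # Skip if no reads
        if features is None or features.sum() == 0:
            continue
        
        # Depth = number of non-empty read rows (the tensor is zero-padded to max_reads)
        depth = int((features.sum(axis=(1, 2)) > 0).sum())
        
        # Run model on per-read feature vectors taken from the site (center) column
        with torch.no_grad():
            center = features.shape[1] // 2
            read_features = torch.tensor(features[:, center, :], dtype=torch.float32)
            outputs = model(read_features.unsqueeze(0).to(args.device))
        
        # Apply calibration
        if args.mode == 'diploid':
            # The model returns probabilities; temperature scaling acts on logits, and
            # log-probabilities are valid logits (they differ by a per-site constant)
            genotype_logits = torch.log(outputs['genotype'].clamp_min(1e-12)).cpu()
            genotype_probs = F.softmax(calibrator(genotype_logits), dim=-1)
            
            # Convert to Phred-scaled likelihoods
            pl = probs_to_phred(genotype_probs[0])
            gt = call_genotype(genotype_probs[0])
            
            sample_data = {
                args.sample_name: {
                    'GT': gt,
                    'PL': pl,
                    'DP': depth
                }
            }
        else:  # pseudo-haploid
            allele_probs = outputs['pseudo_haploid'][0]
            called_allele = ['A', 'C', 'G', 'T'][allele_probs.argmax().item()]
            
            # VCF GT stores allele indices (0 = REF, 1 = ALT); no-call if neither matches
            alleles = [site.ref, site.alt]
            gt = (alleles.index(called_allele),) if called_allele in alleles else (None,)
            
            sample_data = {
                args.sample_name: {
                    'GT': gt,
                    'GP': allele_probs.tolist(),
                    'DP': depth
                }
            }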
        
        # Write to VCF
        vcf_writer.write_variant(
            site.chrom, site.pos, site.ref, site.alt, sample_data
        )
    
    # Clean up
    vcf_writer.close()
    bam.close()
    reference.close()
    
    print(f"Genotype calling complete. Output written to {args.output}")

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='aDNA ML Genotype Caller')
    parser.add_argument('--bam', required=True, help='Input BAM file')
    parser.add_argument('--reference', required=True, help='Reference FASTA')
    parser.add_argument('--sites', required=True, help='Sites to genotype')
    parser.add_argument('--model-checkpoint', required=True)
    parser.add_argument('--calibration-checkpoint', required=True)
    parser.add_argument('--config', required=True)
    parser.add_argument('--output', required=True)
    parser.add_argument('--sample-name', default='SAMPLE1')
    parser.add_argument('--mode', choices=['diploid', 'pseudo-haploid'], 
                       default='diploid')
    parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu')
    
    args = parser.parse_args()
    run_inference(args)
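The pipeline above calls `probs_to_phred` and `call_genotype` without defining them. One possible sketch of the two helpers follows; the cap of 255 and the floor of 1e-30 are assumptions, and because the model emits posterior-like probabilities rather than true genotype likelihoods, the resulting PL values should be read as an approximation of the VCF field's usual meaning.

```python
import math

def probs_to_phred(probs, max_pl=255):
    """Convert genotype probabilities to normalized Phred-scaled integers (best = 0)."""
    probs = [float(p) for p in probs]            # also accepts a 1-D tensor or array
    raw = [-10.0 * math.log10(max(p, 1e-30)) for p in probs]
    best = min(raw)
    return [min(int(round(r - best)), max_pl) for r in raw]

def call_genotype(probs):
    """Return the most probable diploid genotype as a tuple of allele indices."""
    probs = [float(p) for p in probs]
    genotypes = [(0, 0), (0, 1), (1, 1)]         # order matches the 3-class head
    return genotypes[probs.index(max(probs))]

example_probs = [0.1, 0.8, 0.1]
print(call_genotype(example_probs), probs_to_phred(example_probs))
```

Returning the genotype as a tuple of allele indices matches what pysam expects when assigning the `GT` field of a sample.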

Conclusion

This proposal transforms aDNA genotype calling into a modern ML classification problem that respects the set-structured nature of evidence, the chemistry-driven artifacts (damage, contamination), and downstream needs. By focusing on calibrated, per-site probabilities and clean I/O (VCF/BCF, QC), the system can serve as a drop-in alternative or companion to classical callers, with tangible benefits at the coverage regimes where aDNA is most challenging.