3-class classification {0/0, 0/1, 1/1} with probabilities (genotype likelihoods)
Project Guide & Proposal
September 25, 2025
Ancient DNA (aDNA) presents unique statistical and engineering challenges: ultra-short fragments, characteristic chemical damage (e.g., C→T changes near read ends), low and uneven coverage, and modern contamination. Classical tools such as ATLAS solve these with explicit (often hand-tuned) generative models and likelihood-based inference. While robust, these methods can be conservative at very low depth and may not optimally integrate the spectrum of read-level cues (base quality, mapping quality, end-position, fragment length, strand, library prep).
This proposal reframes genotype calling for aDNA as a per-site, permutation-invariant, multi-instance, multi-output multiclass classification problem with calibrated probabilities. Each site (a reference coordinate) is a "bag" of reads; the model ingests read-level features and outputs a probability distribution over genotypes (e.g., 0/0, 0/1, 1/1) or over alleles (A/C/G/T for pseudo-haploid), plus optional auxiliary estimates (damage rates, contamination fraction, site usability).
The design emphasizes set-based inductive bias (order of reads must not matter), calibration (probabilities reflect reality at ultra-low coverage), interoperability (VCF/BCF outputs), and reproducibility (containerized pipelines, deterministic configs).
DNA breaks into 30-70 bp pieces; mapping and pileup context are limited.
Deamination produces predictable substitution patterns near fragment ends (e.g., C→T at read 5' ends).
Many sites have 0-1 reads; variance dominates; naive majority votes fail.
Modern human DNA can mix with the ancient signal.
Aligners and downstream callers may over-favor the reference allele in ambiguous evidence.
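To make the damage and low-depth points concrete, the following is a minimal sketch of a damage-aware genotype posterior at a bi-allelic C/T site. It is not the model used by ATLAS or ANGSD; the function names, the symmetric error rate `err`, the per-read deamination probability `deam`, and the flat prior are all illustrative assumptions.

```python
def base_likelihood(obs, allele, err, deam):
    """P(observed base | true allele) at a C/T site.

    err:  symmetric sequencing error rate (assumed value)
    deam: chance that a true C reads as T at this read position (assumed value)
    """
    if allele == "C":
        p_t = deam * (1 - err) + (1 - deam) * err
    else:  # allele == "T"
        p_t = 1 - err
    return p_t if obs == "T" else 1 - p_t


def genotype_posterior(observations, err=0.01, prior=(1 / 3, 1 / 3, 1 / 3)):
    """Posterior over (0/0, 0/1, 1/1) with REF=C, ALT=T.

    observations: list of (base, deam) pairs, one per read.
    """
    likes = []
    for alt_copies in (0, 1, 2):
        frac_alt = alt_copies / 2
        like = 1.0
        for base, deam in observations:
            like *= ((1 - frac_alt) * base_likelihood(base, "C", err, deam)
                     + frac_alt * base_likelihood(base, "T", err, deam))
        likes.append(like)
    joint = [l * p for l, p in zip(likes, prior)]
    total = sum(joint)
    return [j / total for j in joint]


# One read showing T: at an undamaged position vs. near a damaged read end
clean = genotype_posterior([("T", 0.0)])
damaged = genotype_posterior([("T", 0.3)])
print([round(p, 3) for p in clean])
print([round(p, 3) for p in damaged])
```

With a single T read, a majority vote would call the ALT allele in both cases. Under these assumed rates, the posterior for 0/0 rises from under 1% at an undamaged position to roughly 16% near a damaged read end, which is the kind of uncertainty a calibrated caller needs to report.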
Damage-aware tools (e.g., ATLAS, ANGSD) estimate damage and contamination parameters and incorporate them into likelihoods for genotype inference. This proposal builds a learned inference module that:
Inputs are sets of reads (variable number, unordered) aligned to a fixed coordinate (a site). Outputs are per-site probability tables on a predefined grid (SNP panel or known variants), often with sparse evidence. Seq2seq assumes contiguous, length-coupled sequences; it is a poor inductive bias here. Set-based or local grid (pileup) models are more appropriate.
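The set-based inductive bias can be illustrated with a small pooling example. This is a sketch only: `encode_read` is a hand-written stand-in for a learned per-read network, and the feature names (`is_alt`, `base_qual`, `dist_5p`) are hypothetical.

```python
import math
import random


def encode_read(read):
    """Toy per-read encoder (stand-in for a learned network)."""
    return [
        1.0 if read["is_alt"] else 0.0,
        read["base_qual"] / 40.0,
        math.exp(-read["dist_5p"] / 5.0),  # decays away from the 5' end
    ]


def pool_site(reads, dim=3):
    """Mean-pool read encodings; an empty bag maps to a zero vector."""
    if not reads:
        return [0.0] * dim
    encoded = [encode_read(r) for r in reads]
    # math.fsum is exactly rounded, so the result does not depend on read order
    return [math.fsum(col) / len(encoded) for col in zip(*encoded)]


reads = [
    {"is_alt": True, "base_qual": 37, "dist_5p": 1},
    {"is_alt": False, "base_qual": 30, "dist_5p": 12},
    {"is_alt": True, "base_qual": 22, "dist_5p": 40},
]
shuffled = reads[:]
random.Random(0).shuffle(shuffled)
print(pool_site(reads) == pool_site(shuffled))
```

Because the pooled vector is a symmetric function of the reads, any classifier applied on top of it is permutation-invariant and accepts any number of reads, including zero. Attention over read tokens without genomic-order positional encodings has the same property.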
3-class classification {0/0, 0/1, 1/1} with probabilities (genotype likelihoods)
4-class classification {A, C, G, T} (optionally no-call) at a site
Continuous regression for damage rates and contamination fraction; binary classification for site usability
Invaluable for final validation (ground truth limited; use consensus across tools/labs)
Modern truth sets provide high-coverage genotypes to downsample & corrupt to mimic aDNA. Simulation from modern truth genomes should generate reads with:
Diploid task uses true genotypes at each site; pseudo-haploid task derives single-allele targets from diploid truth. Use sample-level splits; avoid leakage across replicates/libraries; stratify by coverage/damage regimes.
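One way the target derivation and sample-level splitting could look is sketched below. The helper names, the random choice of allele at heterozygous sites, and the hash-based split fractions are assumptions for illustration, not a fixed part of the design.

```python
import hashlib
import random


def pseudo_haploid_target(alt_copies, ref, alt, rng):
    """Map a diploid truth genotype (0, 1 or 2 ALT copies) to one allele."""
    if alt_copies == 0:
        return ref
    if alt_copies == 2:
        return alt
    return rng.choice([ref, alt])  # heterozygote: sample one allele at random


def assign_split(sample_id, val_frac=0.1, test_frac=0.1):
    """Deterministic split keyed on the individual, not the library or replicate."""
    digest = hashlib.sha256(sample_id.encode("utf-8")).hexdigest()
    u = int(digest[:8], 16) / 16**8  # pseudo-uniform value in [0, 1)
    if u < test_frac:
        return "test"
    if u < test_frac + val_frac:
        return "val"
    return "train"


rng = random.Random(42)
records = [
    # (sample_id, library_id, alt_copies): hypothetical example rows
    ("ind01", "libA", 0),
    ("ind01", "libB", 1),
    ("ind02", "libA", 2),
]
for sample_id, library_id, alt_copies in records:
    target = pseudo_haploid_target(alt_copies, "G", "A", rng)
    print(sample_id, library_id, assign_split(sample_id), target)
```

Keying the split on the individual means that all libraries and replicates of one sample land in the same partition, which is what prevents leakage. Stratification by coverage or damage regime would be layered on top of this assignment.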
If no read passes, write a no-call and log the reason.
Tensor shape (R_max, W, C) with top-K reads, window width W (e.g., ±16 bp), and channels C including:
Each read is a token with features; add a [SITE] token to pool. Positional encodings reflect distance from read ends (damage signal), not genomic order. Attention summarizes variable numbers of reads; outputs site-level logits and auxiliary heads.
Transformer over read tokens + shallow CNN over a small reference context window; fuse via cross-attention or concatenation.
Cross-entropy/NLL on genotype (diploid) or allele (pseudo-haploid)
MSE/Huber for damage/contamination; BCE for usability
Weighted sum of all objectives
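The three objectives above can be combined as in the following sketch for a single site. The weights are hypothetical placeholders (the training loop later in this document uses 0.1 for the damage term), and the dictionary keys are illustrative names.

```python
import math


def cross_entropy(probs, label, eps=1e-12):
    return -math.log(max(probs[label], eps))


def huber(pred, target, delta=1.0):
    err = abs(pred - target)
    return 0.5 * err * err if err <= delta else delta * (err - 0.5 * delta)


def binary_cross_entropy(p, y, eps=1e-12):
    p = min(max(p, eps), 1 - eps)
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


# Hypothetical weights; in practice these would be tuned on validation data
LOSS_WEIGHTS = {"genotype": 1.0, "damage": 0.1, "contamination": 0.1, "usable": 0.1}


def total_loss(pred, truth, weights=LOSS_WEIGHTS):
    """Weighted sum of the per-task losses; also returns the unweighted parts."""
    parts = {
        "genotype": cross_entropy(pred["genotype"], truth["genotype"]),
        "damage": huber(pred["damage"], truth["damage"]),
        "contamination": huber(pred["contamination"], truth["contamination"]),
        "usable": binary_cross_entropy(pred["usable"], truth["usable"]),
    }
    return sum(weights[k] * v for k, v in parts.items()), parts


pred = {"genotype": [0.2, 0.7, 0.1], "damage": 0.25, "contamination": 0.02, "usable": 0.9}
truth = {"genotype": 1, "damage": 0.30, "contamination": 0.00, "usable": 1}
loss, parts = total_loss(pred, truth)
print(round(loss, 4))
```

Returning the unweighted parts alongside the total makes it easy to log each objective separately and to check that an auxiliary head is not dominating the genotype loss.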
Start with moderate-coverage simulated data (~1×), then mix lower coverage (0.05×-0.5×).
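A coverage curriculum of this kind could be implemented as a per-example sampler. The sketch below is one possible schedule; the warm-up length, ramp length, and maximum low-coverage fraction are assumed values, not settings from the proposal.

```python
import math
import random


def sample_coverage(epoch, rng, warmup_epochs=5, ramp_epochs=10,
                    max_low_frac=0.75, base_cov=1.0, low=0.05, high=0.5):
    """Pick a target coverage (in x) for one simulated training example."""
    if epoch < warmup_epochs:
        p_low = 0.0  # warm-up: moderate coverage only
    else:
        ramp = (epoch - warmup_epochs + 1) / ramp_epochs
        p_low = min(max_low_frac, ramp * max_low_frac)
    if rng.random() < p_low:
        # log-uniform draw so the lowest coverages are well represented
        return math.exp(rng.uniform(math.log(low), math.log(high)))
    return base_cov


rng = random.Random(0)
for epoch in (0, 5, 20):
    draws = [sample_coverage(epoch, rng) for _ in range(1000)]
    low_share = sum(c < 1.0 for c in draws) / len(draws)
    print(epoch, round(low_share, 2))
```

The returned value would be passed to the simulator as the downsampling target, so early epochs see only ~1× data and later epochs see a growing share of 0.05×-0.5× examples.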
Inputs: BAM/CRAM + reference + site list
Outputs: per-site tensors (NPZ/Parquet) & metadata
Applies filters and masks
Loads tensors, trains model, logs metrics, checkpoints
Supports DDP and mixed precision
Fits temperature/Dirichlet/isotonic
Exports calibration parameters
Inputs: BAM/CRAM + reference + site list + model + calibration
Outputs: VCF/BCF with GT/PL/GL/DP/AD or pseudo-haploid VCF, QC JSON/TSV
Plots: fragment length, damage profiles, calibration curves
Tables: coverage, callability, error by bin
Docker images, Conda package, PyPI wheel (optional)
WDL/Nextflow/Cromwell wrappers
Repo, containers, featurizer skeleton, unit tests
aDNA simulator, modern truth ingestion
Pseudo-haploid CNN, featurization, training loop
Calibration, coverage-bin metrics, baselines
Transformer over reads, damage head, REF/ALT aug
3-class head, VCF writer, temp scaling
Small real aDNA eval, downstream checks
CLI/docs/workflows/CI, reproducibility audits
Final benchmarks, ablations, user guide
20-22 weeks
| Risk | Mitigation Strategies |
|---|---|
| Simulation-Reality Gap | |
| Reference Bias | |
| Calibration Drift | |
| Data Governance | |
| Compute/Costs | |
GT, PL/GL, DP, AD, GQ
Single allele calls for low coverage
JSON/TSV format
PDF/PNG visualizations
Filtered regions
Configs, git/container hashes
Supply GL/PL to off-the-shelf imputation tools
Inputs per site:
Shape: (R_max, W, C)
Transformer variant: uses one token per read plus a [SITE] token
#CHROM POS ID REF ALT QUAL FILTER INFO FORMAT SAMPLE1
chr1 10583 . G A . PASS DP=3 GT:PL:DP:AD 0/1:10,0,12:3:2,1
#CHROM POS ID REF ALT QUAL FILTER INFO FORMAT SAMPLE1
chr1 10583 . G A . PASS DP=1 GT:GP 0:0.18,0.12,0.70,0.00
# Pipeline sketch: load_sites, is_mappable, in_blacklist, fetch_reads,
# filter_bases, build_pileup_tensor, calibrate, write_no_call and
# write_vcf_record are project helpers that are not defined in this listing.

# 1) Load site list
sites = load_sites("sites.vcf_or_bed")  # [(chrom, pos, ref, alt)]

# 2) Optional: filter to good regions
sites = [s for s in sites if is_mappable(s) and not in_blacklist(s)]

# 3) For each site, decide callability and build inputs
for s in sites:
    reads = fetch_reads(bam, s.chrom, s.pos, window=33)
    bases = filter_bases(reads, min_mapq=30, min_bq=30, snp_only=True)
    if len(bases) == 0:
        write_no_call(s)  # no read passed the filters; log the reason
        continue
    X = build_pileup_tensor(bases, R_max=48, W=33, channels=spec)
    logits = model(X)          # or model(read_tokens)
    probs = calibrate(logits)  # temperature scaling / isotonic
    write_vcf_record(s, probs, depths=bases)
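The loop above accesses `s.chrom`, `s.pos`, `s.ref` and `s.alt`, so the site list needs named fields. A minimal sketch of a parser is shown here, assuming a simple tab-separated `chrom pos ref alt` layout with 1-based positions; the `Site` tuple and `parse_sites` function are hypothetical, and real VCF or BED input would need a proper reader.

```python
from collections import namedtuple

Site = namedtuple("Site", ["chrom", "pos", "ref", "alt"])


def parse_sites(lines):
    """Parse tab-separated 'chrom pos ref alt' lines into Site tuples."""
    sites = []
    for line in lines:
        line = line.strip()
        if not line or line.startswith("#"):
            continue  # skip blank lines and header/comment lines
        chrom, pos, ref, alt = line.split("\t")[:4]
        ref, alt = ref.upper(), alt.upper()
        # keep bi-allelic SNPs only
        is_snp = len(ref) == 1 and len(alt) == 1 and ref in "ACGT" and alt in "ACGT"
        if is_snp and ref != alt:
            sites.append(Site(chrom, int(pos), ref, alt))
    return sites


example = [
    "#chrom\tpos\tref\talt",
    "chr1\t10583\tG\tA",
    "chr1\t10611\tC\tCT",   # indel: skipped
    "chr2\t752566\tg\ta",
]
for site in parse_sites(example):
    print(site.chrom, site.pos, site.ref, site.alt)
```

Restricting the list to bi-allelic SNPs up front keeps the 3-class genotype head well defined for every site that reaches the model.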
import numpy as np


class PileupFeaturizer:
    # Channel layout at the site column: 0-3 = A/C/G/T one-hot, 4 = other (N),
    # 5 = base quality, 6 = MAPQ, 7/8 = distance from 5'/3' end,
    # 9 = reverse strand, 10 = duplicate flag
    N_CHANNELS = 11

    def __init__(self, config):
        self.max_reads = config.max_reads    # e.g., 48
        self.window_size = config.window     # e.g., 33
        self.min_mapq = config.min_mapq      # e.g., 30
        self.min_bq = config.min_bq          # e.g., 30
        self.n_channels = self.N_CHANNELS

    def featurize_site(self, bam, chrom, pos, ref, alt=None):
        """Extract features for a single genomic site (pos is 1-based)."""
        reads = []
        # truncate=True limits columns to the requested interval;
        # min_base_quality=0 leaves base-quality filtering to the check below
        for pileupcolumn in bam.pileup(chrom, pos - 1, pos, truncate=True,
                                       min_base_quality=0):
            if pileupcolumn.reference_pos != pos - 1:
                continue
            for pileupread in pileupcolumn.pileups:
                if pileupread.is_del or pileupread.is_refskip:
                    continue
                feats = self._extract_read_features(pileupread)
                if feats['mapq'] >= self.min_mapq and feats['qual'] >= self.min_bq:
                    reads.append(feats)

        # Sort by base quality and keep the top max_reads
        reads = sorted(reads, key=lambda x: x['qual'], reverse=True)[:self.max_reads]

        # Build tensor
        return self._build_tensor(reads, ref, alt)

    def _extract_read_features(self, pileupread):
        """Extract features from a single read."""
        alignment = pileupread.alignment
        qpos = pileupread.query_position
        read_len = len(alignment.query_sequence)
        # query_sequence is stored in reference orientation, so for
        # reverse-strand reads the molecule's 5' end is the last query index
        if alignment.is_reverse:
            pos_from_5p, pos_from_3p = read_len - qpos - 1, qpos
        else:
            pos_from_5p, pos_from_3p = qpos, read_len - qpos - 1
        return {
            'base': alignment.query_sequence[qpos],
            'qual': alignment.query_qualities[qpos],
            'mapq': alignment.mapping_quality,
            'pos_from_5p': pos_from_5p,
            'pos_from_3p': pos_from_3p,
            'strand': alignment.is_reverse,
            'frag_len': abs(alignment.template_length),
            'is_duplicate': alignment.is_duplicate,
            'library': alignment.get_tag('RG') if alignment.has_tag('RG') else 'unknown'
        }

    def _build_tensor(self, reads, ref, alt):
        """Convert read features to a (max_reads, window, channels) tensor."""
        tensor = np.zeros((self.max_reads, self.window_size, self.n_channels),
                          dtype=np.float32)
        center = self.window_size // 2  # only the site column is filled here
        for i, read in enumerate(reads):
            # One-hot encode base (index 4 = anything other than A/C/G/T)
            base_idx = {'A': 0, 'C': 1, 'G': 2, 'T': 3}.get(read['base'], 4)
            tensor[i, center, base_idx] = 1
            # Add quality scores
            tensor[i, center, 5] = read['qual'] / 40.0
            tensor[i, center, 6] = read['mapq'] / 60.0
            # Add position features
            tensor[i, center, 7] = read['pos_from_5p'] / 100.0
            tensor[i, center, 8] = read['pos_from_3p'] / 100.0
            # Add binary flags
            tensor[i, center, 9] = float(read['strand'])
            tensor[i, center, 10] = float(read['is_duplicate'])
        return tensor
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class TransformerGenotypeCaller(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.d_model = config.d_model  # must be even for the sin/cos encoding
        self.n_heads = config.n_heads
        self.n_layers = config.n_layers

        # Input embedding
        self.read_embedding = nn.Linear(config.n_features, self.d_model)
        self.site_token = nn.Parameter(torch.randn(1, 1, self.d_model))

        # Positional encoding for damage signal
        self.pos_encoding = DamagePositionalEncoding(self.d_model)

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.d_model,
            nhead=self.n_heads,
            dim_feedforward=config.d_ff,
            dropout=config.dropout,
            batch_first=True,  # tensors are (batch, tokens, d_model)
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, self.n_layers)

        # Output heads
        self.genotype_head = nn.Linear(self.d_model, 3)         # 0/0, 0/1, 1/1
        self.pseudo_haploid_head = nn.Linear(self.d_model, 4)   # A, C, G, T
        self.damage_head = nn.Linear(self.d_model, 2)           # C->T, G->A rates
        self.contamination_head = nn.Linear(self.d_model, 1)    # contamination fraction

        # Calibration temperature
        self.temperature = nn.Parameter(torch.ones(1))

    def forward(self, read_features, read_mask=None):
        """read_features: (batch, n_reads, n_features).
        read_mask: optional (batch, n_reads) bool tensor, True = padding row.
        """
        batch_size = read_features.size(0)

        # Embed reads, then add the read-end encoding while the shapes
        # still match (before the site token is prepended)
        read_embeds = self.read_embedding(read_features)
        read_embeds = self.pos_encoding(read_embeds, read_features)

        # Prepend the site token
        site_token = self.site_token.expand(batch_size, -1, -1)
        x = torch.cat([site_token, read_embeds], dim=1)

        # Extend the padding mask so the site token is never masked
        if read_mask is not None:
            site_mask = torch.zeros(batch_size, 1, dtype=torch.bool,
                                    device=read_mask.device)
            read_mask = torch.cat([site_mask, read_mask.bool()], dim=1)

        # Transform
        x = self.transformer(x, src_key_padding_mask=read_mask)

        # Extract site representation (first token)
        site_repr = x[:, 0, :]

        # Compute outputs
        genotype_logits = self.genotype_head(site_repr) / self.temperature
        pseudo_haploid_logits = self.pseudo_haploid_head(site_repr)
        damage_estimates = torch.sigmoid(self.damage_head(site_repr))
        contamination = torch.sigmoid(self.contamination_head(site_repr))

        # Note: 'genotype' and 'pseudo_haploid' are probabilities, not logits
        return {
            'genotype': F.softmax(genotype_logits, dim=-1),
            'pseudo_haploid': F.softmax(pseudo_haploid_logits, dim=-1),
            'damage': damage_estimates,
            'contamination': contamination
        }


class DamagePositionalEncoding(nn.Module):
    """Encodes position from read ends (damage signal)."""

    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model

    def forward(self, x, read_features):
        # Distance from 5' and 3' ends; assumes feature indices 7 and 8
        dist_5p = read_features[..., 7:8]
        dist_3p = read_features[..., 8:9]

        # Sinusoidal encoding: sin terms carry the 5' distance,
        # cos terms carry the 3' distance
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, device=x.device, dtype=x.dtype)
            * -(math.log(10000.0) / self.d_model)
        )
        pe = torch.zeros_like(x)
        pe[..., 0::2] = torch.sin(dist_5p * div_term)
        pe[..., 1::2] = torch.cos(dist_3p * div_term)
        return x + pe
import torch
import torch.nn.functional as F
from tqdm import tqdm


def train_epoch(model, dataloader, optimizer, criterion, device):
    """Train for one epoch.

    criterion['genotype'] is expected to take log-probabilities and return
    one loss per sample, e.g. nn.NLLLoss(reduction='none').
    """
    model.train()
    total_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    for batch in tqdm(dataloader):
        # Move to device
        features = batch['features'].to(device)
        genotypes = batch['genotype'].to(device)
        coverage = batch['coverage'].to(device)

        # Forward pass
        outputs = model(features)

        # The model returns probabilities, so take the log before the NLL loss
        log_probs = torch.log(outputs['genotype'].clamp_min(1e-8))
        genotype_loss = criterion['genotype'](log_probs, genotypes)  # shape (batch,)

        # Coverage-aware weighting
        weights = compute_coverage_weights(coverage)
        weighted_loss = (genotype_loss * weights).mean()

        # Add auxiliary losses
        if 'damage' in batch:
            damage_loss = F.mse_loss(outputs['damage'], batch['damage'].to(device))
            weighted_loss = weighted_loss + 0.1 * damage_loss

        # Backward pass
        optimizer.zero_grad()
        weighted_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Track metrics
        total_loss += weighted_loss.item()
        predictions = outputs['genotype'].argmax(dim=-1)
        correct_predictions += (predictions == genotypes).sum().item()
        total_samples += genotypes.size(0)

    return {
        'loss': total_loss / len(dataloader),
        'accuracy': correct_predictions / total_samples
    }


def compute_coverage_weights(coverage):
    """Weight samples by inverse log-coverage to balance training."""
    # Higher weight for low-coverage samples; normalized to mean 1
    weights = 1.0 / (1.0 + torch.log1p(coverage.float()))
    return weights / weights.mean()
class TemperatureScaling(nn.Module):
    """Temperature scaling for calibration."""

    def __init__(self, n_bins=10):
        super().__init__()
        self.temperature = nn.Parameter(torch.ones(1))
        self.n_bins = n_bins

    def fit(self, logits, labels, coverage=None):
        """Fit temperature on validation set."""
        self.temperature.requires_grad = True
        optimizer = torch.optim.LBFGS([self.temperature], lr=0.01, max_iter=50)

        def eval_loss():
            optimizer.zero_grad()
            scaled_logits = logits / self.temperature
            loss = F.cross_entropy(scaled_logits, labels)
            loss.backward()
            return loss

        optimizer.step(eval_loss)
        self.temperature.requires_grad = False

    def forward(self, logits):
        return logits / self.temperature

    def compute_ece(self, probs, labels):
        """Expected Calibration Error."""
        confidences, predictions = probs.max(dim=1)
        accuracies = predictions.eq(labels)
        ece = 0
        for bin_idx in range(self.n_bins):
            bin_lower = bin_idx / self.n_bins
            bin_upper = (bin_idx + 1) / self.n_bins
            in_bin = (confidences > bin_lower) & (confidences <= bin_upper)
            prop_in_bin = in_bin.float().mean()
            if prop_in_bin > 0:
                accuracy_in_bin = accuracies[in_bin].float().mean()
                avg_confidence_in_bin = confidences[in_bin].mean()
                ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin
        return ece
import os
from datetime import datetime

import pysam


class VCFWriter:
    def __init__(self, output_path, reference, samples, mode='diploid'):
        """reference is an open pysam.FastaFile; samples is a list of names."""
        self.output_path = output_path
        self.reference = reference
        self.samples = samples
        self.mode = mode

        # Create VCF header
        self.header = pysam.VariantHeader()

        # Add metadata
        self.header.add_meta('fileformat', 'VCFv4.3')
        self.header.add_meta('fileDate', datetime.now().strftime('%Y%m%d'))
        self.header.add_meta('source', 'aDNA_ML_Genotyper_v1.0')
        # Header values must be strings, so record the FASTA path
        self.header.add_meta('reference', os.fsdecode(reference.filename))

        # Add INFO fields
        self.header.add_meta('INFO', items=[
            ('ID', 'DP'),
            ('Number', 1),
            ('Type', 'Integer'),
            ('Description', 'Total depth')
        ])

        # Add FORMAT fields
        if self.mode == 'diploid':
            self.header.add_meta('FORMAT', items=[
                ('ID', 'GT'),
                ('Number', 1),
                ('Type', 'String'),
                ('Description', 'Genotype')
            ])
            self.header.add_meta('FORMAT', items=[
                ('ID', 'PL'),
                ('Number', 'G'),
                ('Type', 'Integer'),
                ('Description', 'Phred-scaled genotype likelihoods')
            ])
        else:  # pseudo-haploid
            self.header.add_meta('FORMAT', items=[
                ('ID', 'GT'),
                ('Number', 1),
                ('Type', 'String'),
                ('Description', 'Pseudo-haploid genotype')
            ])
            self.header.add_meta('FORMAT', items=[
                ('ID', 'GP'),
                ('Number', 4),
                ('Type', 'Float'),
                ('Description', 'Genotype probabilities for A,C,G,T')
            ])

        # Add samples
        for sample in samples:
            self.header.add_sample(sample)

        # Add contigs
        for contig in self.reference.references:
            self.header.add_meta('contig', items=[
                ('ID', contig),
                ('length', self.reference.get_reference_length(contig))
            ])

        # Open VCF for writing
        self.vcf = pysam.VariantFile(self.output_path, 'w', header=self.header)

    def write_variant(self, chrom, pos, ref, alt, sample_data):
        """Write a single variant to VCF (pos is 1-based)."""
        record = self.vcf.new_record(
            contig=chrom,
            start=pos - 1,               # pysam start is 0-based
            stop=pos - 1 + len(ref),
            alleles=(ref, alt if alt else '.'),
        )

        # Add INFO fields
        total_depth = sum(data.get('DP', 0) for data in sample_data.values())
        record.info['DP'] = total_depth

        # Add sample data
        for sample, data in sample_data.items():
            gt = data['GT']
            if isinstance(gt, str):
                # A base letter is mapped to its allele index; a base that is
                # neither REF nor ALT becomes a no-call
                gt = (record.alleles.index(gt),) if gt in record.alleles else (None,)
            record.samples[sample]['GT'] = gt  # tuple of allele indices
            key = 'PL' if self.mode == 'diploid' else 'GP'
            record.samples[sample][key] = data[key]

        self.vcf.write(record)

    def close(self):
        self.vcf.close()
import argparse

import pysam
import torch
import torch.nn.functional as F
from tqdm import tqdm


def run_inference(args):
    """Main inference pipeline.

    load_config, load_sites, probs_to_phred and call_genotype are project
    helpers that are not defined in this listing.
    """
    # Load configuration
    config = load_config(args.config)

    # Initialize model
    model = TransformerGenotypeCaller(config.model)
    model.load_state_dict(torch.load(args.model_checkpoint, map_location=args.device))
    model = model.to(args.device)
    model.eval()

    # Load calibration
    calibrator = TemperatureScaling()
    calibrator.load_state_dict(
        torch.load(args.calibration_checkpoint, map_location=args.device))
    calibrator = calibrator.to(args.device)

    # Initialize featurizer
    featurizer = PileupFeaturizer(config.featurization)

    # Open BAM and reference
    bam = pysam.AlignmentFile(args.bam, 'rb')
    reference = pysam.FastaFile(args.reference)

    # Load sites
    sites = load_sites(args.sites)

    # Initialize VCF writer
    vcf_writer = VCFWriter(
        args.output, reference, [args.sample_name], mode=args.mode
    )

    # Process sites
    for site in tqdm(sites, desc="Processing sites"):
        # Featurize
        features = featurizer.featurize_site(
            bam, site.chrom, site.pos, site.ref, site.alt
        )

        # Skip if no reads
        if features is None or features.sum() == 0:
            continue

        # Depth = number of non-empty read rows (the tensor is zero-padded)
        depth = int((features.reshape(features.shape[0], -1).sum(axis=1) > 0).sum())

        # The model takes one feature vector per read: use the site (centre)
        # column and drop the padding rows, which come after the real reads
        center = features.shape[1] // 2
        read_features = torch.tensor(features[:depth, center, :], dtype=torch.float32)

        # Run model
        with torch.no_grad():
            outputs = model(read_features.unsqueeze(0).to(args.device))

            if args.mode == 'diploid':
                # The model returns probabilities; their logs act as logits
                # for the calibrator, followed by a single softmax
                logits = torch.log(outputs['genotype'].clamp_min(1e-8))
                genotype_probs = F.softmax(calibrator(logits), dim=-1)[0].tolist()

                # Convert to Phred-scaled likelihoods
                pl = probs_to_phred(genotype_probs)
                gt = call_genotype(genotype_probs)
                sample_data = {
                    args.sample_name: {'GT': gt, 'PL': pl, 'DP': depth}
                }
            else:  # pseudo-haploid
                allele_probs = outputs['pseudo_haploid'][0]
                called_allele = 'ACGT'[int(allele_probs.argmax())]
                # VCF GT is an allele index: 0 = REF, 1 = ALT, else no-call
                if called_allele == site.ref:
                    gt = (0,)
                elif called_allele == site.alt:
                    gt = (1,)
                else:
                    gt = (None,)
                sample_data = {
                    args.sample_name: {
                        'GT': gt,
                        'GP': allele_probs.tolist(),
                        'DP': depth
                    }
                }

        # Write to VCF
        vcf_writer.write_variant(
            site.chrom, site.pos, site.ref, site.alt, sample_data
        )

    # Clean up
    vcf_writer.close()
    bam.close()
    reference.close()
    print(f"Genotype calling complete. Output written to {args.output}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='aDNA ML Genotype Caller')
    parser.add_argument('--bam', required=True, help='Input BAM file')
    parser.add_argument('--reference', required=True, help='Reference FASTA')
    parser.add_argument('--sites', required=True, help='Sites to genotype')
    parser.add_argument('--model-checkpoint', required=True)
    parser.add_argument('--calibration-checkpoint', required=True)
    parser.add_argument('--config', required=True)
    parser.add_argument('--output', required=True)
    parser.add_argument('--sample-name', default='SAMPLE1')
    parser.add_argument('--mode', choices=['diploid', 'pseudo-haploid'],
                        default='diploid')
    parser.add_argument('--device',
                        default='cuda' if torch.cuda.is_available() else 'cpu')
    args = parser.parse_args()
    run_inference(args)
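The inference pipeline calls `probs_to_phred` and `call_genotype` without defining them. One plausible implementation is sketched below. It assumes the calibrated genotype probabilities can stand in for likelihoods when filling PL (which strictly holds only under a flat genotype prior), and the `max_pl` cap, `eps` floor and `min_prob` no-call threshold are illustrative choices.

```python
import math

GENOTYPES = ((0, 0), (0, 1), (1, 1))  # allele-index tuples for 0/0, 0/1, 1/1


def probs_to_phred(probs, max_pl=255, eps=1e-10):
    """Convert genotype probabilities to PL-style integers (best genotype = 0)."""
    raw = [-10.0 * math.log10(max(float(p), eps)) for p in probs]
    best = min(raw)
    return [min(max_pl, int(round(q - best))) for q in raw]


def call_genotype(probs, min_prob=0.0):
    """Return the most probable genotype, or a no-call below min_prob."""
    probs = [float(p) for p in probs]
    best = max(range(len(probs)), key=probs.__getitem__)
    if probs[best] < min_prob:
        return (None, None)
    return GENOTYPES[best]


probs = [0.18, 0.12, 0.70]
print(probs_to_phred(probs))            # [6, 8, 0]
print(call_genotype(probs))             # (1, 1)
print(call_genotype(probs, min_prob=0.9))  # (None, None)
```

Returning allele-index tuples matches what pysam expects for the GT field, and a `(None, None)` genotype is written as `./.`. A nonzero `min_prob` gives a simple way to trade call rate against error at ultra-low coverage.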
This proposal transforms aDNA genotype calling into a modern ML classification problem that respects the set-structured nature of evidence, the chemistry-driven artifacts (damage, contamination), and downstream needs. By focusing on calibrated, per-site probabilities and clean I/O (VCF/BCF, QC), the system can serve as a drop-in alternative or companion to classical callers, with tangible benefits at the coverage regimes where aDNA is most challenging.