
Mini-Fold — Protein Structure Diffusion Model

Full architecture history, training tables, loss functions, and analysis for the protein backbone diffusion model.

v13 Architecture

Protein Backbone Diffusion Model v13

Training
Goal: Learn beta-sheet topology by enabling long-range pair information propagation. v12 pair stack is blind to sequence separations >32 residues, which is exactly where beta-sheet hydrogen bonds live. v13 fixes this with Evoformer-style triangle updates and log-scaled relative position encoding.
Architecture Enhanced pair stack (8 blocks, triangle multiplicative updates) + IPA denoiser (8 layers, 8 heads, 8 query points, 512-dim) + frozen ContactClassifier encoder — ~35–38M params
Training Single GPU, batch_size=4, grad_accum=4 (eff=16), LR=2e-5, pair_stack 3x LR, CFG p_uncond=0.1, SC=0.25
Initialization v12b EMA weights (b30 best) via strict=False; denoiser/aux/rg: exact match; pair_stack: trains from scratch (new architecture)
Dataset CATH 4.2: 18,024 train / 608 val proteins, max 125 residues

Key Architectural Changes

Deeper Pair Stack
4 → 8 EnhancedPairBlock blocks. Each block applies: (1) triangle multiplicative update (outgoing), (2) triangle multiplicative update (incoming), (3) row + column axial attention + FFN. The triangle updates implement transitivity: "if i contacts k and j contacts k, then i and j are structurally related." This is exactly how beta-sheet topology is encoded — two strands share contacts via loop residues. Hidden dim: tri_mul_dim=64, d_pair=128.
Log-Scaled RPE
Replaces linear-clipped max_rel_pos=32 with 128 log-spaced bins covering 0–512 residues. The old RPE was zero for |i-j| > 32, making long-range sheet contacts invisible. New encoding: fine linear bins for positions 0–8 (helices: i+3, i+4), log-spaced bins for positions 8–512 (sheets: i+20..i+100+), sign-aware (separate embeddings for upstream/downstream). Additionally, sinusoidal continuous RPE features (32-dim sin/cos encoding projected to d_pair) provide smooth interpolation between bins.
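The binning described above can be sketched as follows. This is a minimal illustration of a sign-aware linear-then-log binning scheme, not the model's actual implementation — the exact bin edges, counts, and the `log_rpe_bins` name are assumptions:

```python
import torch

def log_rpe_bins(rel_pos, n_linear=8, n_bins=128, max_sep=512):
    """Map signed relative positions to RPE bin indices (sketch).

    Positions 0..n_linear get fine linear bins (helix contacts i+3, i+4);
    n_linear..max_sep get log-spaced bins (sheet contacts i+20..i+100+);
    sign is encoded by offsetting downstream pairs into the second half
    of the embedding table.
    """
    sep = rel_pos.abs().clamp(max=max_sep).float()
    n_log = n_bins // 2 - n_linear                 # log bins per direction
    linear = sep.clamp(max=n_linear)               # 0..8 -> bins 0..8
    # log-spaced region: log2(sep / n_linear) rescaled into n_log bins
    log_part = (torch.log2(sep.clamp(min=n_linear) / n_linear)
                / torch.log2(torch.tensor(max_sep / n_linear)) * n_log)
    idx = torch.where(sep <= n_linear, linear, n_linear + log_part).long()
    idx = idx.clamp(max=n_bins // 2 - 1)
    # sign-aware: downstream (j > i) uses the second half of the table
    return idx + (rel_pos > 0).long() * (n_bins // 2)
```

With these assumed edges, |i−j| ≤ 8 resolves exactly while separations up to 512 still land in distinct, progressively coarser bins.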
Classifier-Free Guidance
10% conditioning dropout during training (p_uncond=0.1). When triggered, residue tokens are replaced with MASK tokens (preserving CLS/EOS/PAD structure) so the model learns unconditional generation. At sampling time, enables CFG: eps = eps_uncond + w * (eps_cond - eps_uncond). Critical fix: using PAD as the null token caused mask = ids.ne(PAD) to return all-False, creating degenerate all-masked pair representations that produced NaN. Using MASK instead preserves valid attention masks.
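The sampling-time combination is the standard CFG formula quoted above; a sketch, with `model` and the token arguments standing in for the real denoiser interface:

```python
import torch

def cfg_eps(model, x_t, t, cond_tokens, mask_tokens, w=1.5):
    """Classifier-free guidance: eps = eps_uncond + w * (eps_cond - eps_uncond).

    w=1 recovers the conditional model; w>1 extrapolates away from the
    unconditional prediction. The null condition uses MASK tokens, not
    PAD, so attention masks stay valid (see the fix described above).
    """
    eps_cond = model(x_t, t, cond_tokens)
    eps_uncond = model(x_t, t, mask_tokens)
    return eps_uncond + w * (eps_cond - eps_uncond)
```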
LR Schedule
Three-phase schedule motivated by v12b finding that breakthroughs happen in narrow LR windows: (1) Linear warmup: 3 epochs (0.01x → 1x LR), (2) Constant LR: 20 epochs at peak LR (the key change — give the model more time at productive learning rates), (3) Slow cosine decay: remaining ~37 epochs down to eta_min=1e-6. Per-module LR groups: pair_stack at 3x base LR (6e-5), aux_pair at 2x.
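The three phases can be expressed as a single LR-multiplier function (a sketch suitable for `torch.optim.lr_scheduler.LambdaLR`; `total=60` and the floor ratio are taken from the numbers above — eta_min=1e-6 on a 2e-5 base is a 0.05 ratio):

```python
import math

def three_phase_lambda(epoch, warmup=3, constant=20, total=60, floor=0.05):
    """LR multiplier: linear warmup 0.01x -> 1x, constant plateau at 1x,
    then cosine decay to `floor` (sketch of the v13 schedule)."""
    if epoch < warmup:
        return 0.01 + (1.0 - 0.01) * epoch / warmup          # phase 1
    if epoch < warmup + constant:
        return 1.0                                           # phase 2
    # phase 3: slow cosine decay over the remaining ~37 epochs
    progress = (epoch - warmup - constant) / max(1, total - warmup - constant)
    return floor + (1 - floor) * 0.5 * (1 + math.cos(math.pi * min(progress, 1.0)))
```

Per-module groups (pair_stack 3x, aux_pair 2x) would then scale this multiplier per parameter group.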
Temperature/Eta Sampling
DDIM with configurable eta parameter to fix mode collapse observed in pure DDIM (eta=0) sampling. Adding stochasticity (eta > 0) diversifies generated structures while maintaining quality.
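For reference, one step of the textbook DDIM update with the eta term (this is the standard formula, not the project's sampler code; eta=0 is deterministic, eta=1 recovers DDPM-level variance):

```python
import torch

def ddim_step(x_t, eps, alpha_bar_t, alpha_bar_prev, eta=0.3):
    """One DDIM update with stochasticity eta (standard formula, sketch)."""
    # implied clean sample from the predicted noise
    x0 = (x_t - (1 - alpha_bar_t).sqrt() * eps) / alpha_bar_t.sqrt()
    # eta controls how much variance is re-injected each step
    sigma = (eta * ((1 - alpha_bar_prev) / (1 - alpha_bar_t)).sqrt()
                 * (1 - alpha_bar_t / alpha_bar_prev).sqrt())
    dir_xt = (1 - alpha_bar_prev - sigma ** 2).clamp(min=0).sqrt() * eps
    return alpha_bar_prev.sqrt() * x0 + dir_xt + sigma * torch.randn_like(x_t)
```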
Parameters
~35–38M total (up from 31.1M in v12b). Extra ~5–8M from deeper pair stack + triangle multiplicative updates. Denoiser, aux_pair_stack, rg_pred, dist_head: exact match with v12b. Pair stack: entirely new architecture, trains from scratch.
The NaN Debugging Story

Three sources of numerical instability discovered and fixed during v12/v13 development:

  • fp16 epsilon underflow in Gram-Schmidt: Frame initialization computes cross products and normalizations on noised coordinates. Under fp16 autocast, near-zero vectors produced NaN during normalization. Fixed with explicit fp32 casting and epsilon guards.
  • IPA point attention overflow: The 3D point attention distances can become very large for distant residues, causing exp() overflow in softmax. Relatedly, the triangle multiplicative einsum accumulates L terms that can overflow to inf under fp16 (and inf * sigmoid(0) = NaN); it now runs in fp32, with the output clamped to [-1e4, 1e4] before norm/projection.
  • Unsafe torch.cdist backward: torch.cdist backward produces NaN when distance is exactly zero (self-distances and coincident masked-out coords at the origin). Replaced with manual computation: (diff.pow(2).sum(-1) + 1e-10).sqrt().
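The cdist replacement from the last bullet, written out (the formula is as stated above; the function name is ours):

```python
import torch

def safe_pairwise_dist(coords, eps=1e-10):
    """Pairwise distances with a finite backward at zero distance.

    torch.cdist's backward divides by the distance, producing NaN when
    two points coincide (self-distances, masked-out coords at the origin).
    Adding eps inside the sqrt keeps the gradient finite everywhere.
    """
    diff = coords.unsqueeze(-2) - coords.unsqueeze(-3)   # (..., L, L, 3)
    return (diff.pow(2).sum(-1) + eps).sqrt()
```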

v13 Training Progress

Epoch Val Total FAPE Frame Rot Dist MSE Bond Aux Dist Chirality Angle Rg DDIM TM Status
1 2.571 1.657 0.846 0.353 0.009 0.323 0.347 0.130 0.036 0.114 NEW BEST
2 2.494 1.610 0.792 0.359 0.007 0.324 0.336 0.115 0.036 NEW BEST
3 2.533 1.611 0.829 0.374 0.008 0.324 0.336 0.120 0.036 pat 1
4 2.373 1.546 0.771 0.312 0.008 0.323 0.319 0.113 0.036 NEW BEST
5 2.467 1.595 0.809 0.338 0.006 0.324 0.332 0.113 0.036 pat 1
6 2.447 1.588 0.780 0.343 0.006 0.324 0.319 0.108 0.035 pat 2
7 2.539 1.622 0.834 0.365 0.007 0.350 0.334 0.116 0.035 pat 3
8 2.550 1.651 0.841 0.346 0.005 0.350 0.347 0.112 0.035 pat 4
9 2.568 1.623 0.826 0.391 0.007 0.353 0.344 0.117 0.035 pat 5
10 2.501 1.588 0.810 0.377 0.005 0.360 0.341 0.107 0.035 pat 6
11 2.504 1.611 0.815 0.355 0.005 0.384 0.333 0.106 0.035 pat 7
12 2.431 1.555 0.783 0.353 0.006 0.391 0.329 0.107 0.035 pat 8
13 2.494 1.591 0.811 0.363 0.005 0.400 0.342 0.110 0.035 pat 9
14 2.534 1.622 0.810 0.375 0.006 0.362 0.336 0.108 0.035 pat 10
15 2.390 1.556 0.770 0.318 0.006 0.396 0.354 0.110 0.035 0.105 pat 11

Comparison with Previous Architectures

E4 breakthrough: v13 now surpasses v12b’s best frame_rot (b16: 0.780) after only 4 epochs — frame_rot 0.771, FAPE 1.546, val_total 2.373. The dist_mse of 0.312 matches v12b b30’s all-time best. For context, v12b required 16 epochs with per-module gradient clipping to reach frame_rot 0.780; v13 surpassed that milestone 4× faster. The deeper pair stack with triangle multiplicative updates and log-scaled RPE provides much stronger structural signal, propagating long-range contacts that v12b’s 4-block stack (max_rel=32) could not represent. DDIM sampling runs every 5 epochs, so TM-score will first be reported at E5.

v13 Loss Curves

v12/v12b/v13 loss curves

Loss curves from the plotting script. v13 overlay will appear as training epochs complete.

v13 Pseudo-Ramachandran (E5)

v13 pseudo-Ramachandran plots at E5

Interpretation

Ground truth (left, 102K residues) shows the expected pseudo-Ramachandran landscape: a dominant α-helix cluster at ~(50°, 50°) and a β-sheet ridge at ~(−120°, 120°), with the characteristic L-shaped density that reflects the strong preference for these secondary structure elements in globular proteins.

DDIM samples (center, 3.2K residues, 100 steps, η=0.3, guidance=1.5) at E5 show early structural signal: density is beginning to concentrate in the α-helix and β-sheet regions rather than being uniformly scattered. However, the distribution is still diffuse with substantial density in physically disfavored regions, indicating the model has not yet fully learned the local geometry constraints that produce sharp Ramachandran basins. This is expected at epoch 5 — v12b required ~15 epochs before DDIM samples showed tight clustering.

x0-prediction (right, single-step denoising from t=100) shows comparable quality to DDIM, confirming the denoiser has learned meaningful structure. The similar density between DDIM and x0-pred suggests the iterative refinement in DDIM is not yet adding much beyond single-step prediction — a gap that should widen as training continues and the model learns to leverage the multi-step denoising process.

v14 — Full Backbone Diffusion Model

Full Backbone Diffusion Model v14

Initializing

Why Full Backbone?

  • Cα-only is lossy: v13 predicts a single point per residue — backbone geometry (N-Cα-C-O) must be reconstructed post-hoc using tools like pNeRF or idealization heuristics.
  • Post-hoc reconstruction introduces errors: Peptide bond planarity, bond angles, and omega dihedrals are approximated, not learned. Small errors compound over the chain, especially for longer proteins.
  • Learnable peptide geometry: Full backbone prediction lets the model learn actual peptide plane geometry, enforcing physically valid backbone structures during training rather than hoping reconstruction recovers them.
  • More natural frame representation: Frames are built from actual N-Cα-C planes instead of Gram–Schmidt on Cα triplets. This is more numerically stable, especially at high noise levels where consecutive Cα atoms can become nearly collinear.
  • Direct angle supervision: Enables direct φ/ψ angle supervision and omega planarity loss — the model is penalized for non-planar peptide bonds, not just Cα position errors.

Architecture Overview

Model Size
~35–40M params. Same IPA denoiser as v13 (8 layers, 8 heads, 8 query points, 512-dim), with a new backbone reconstruction head that maps per-residue single representations to full backbone atom positions.
Key Change
Model predicts per-residue rigid frames (rotation + translation) → places N, Cα, C, O using ideal geometry + learned offsets. The frame defines the local coordinate system; backbone atoms are positioned relative to it using idealized bond lengths as initialization.
Ideal Bond Lengths
N–Cα ≈ 1.458Å, Cα–C ≈ 1.523Å, C–N(next) ≈ 1.329Å, C=O ≈ 1.231Å. These serve as initialization for the backbone head; the model learns residual offsets from these idealized values.
Noise Strategy
Noise is added to Cα positions (frame translations) — backbone atoms (N, C, O) are derived from the denoised frames, not independently noised. This ensures backbone geometry stays consistent throughout the diffusion process.
Initialization
Initialized from v13/v12b weights via strict=False. All existing modules (denoiser, pair stack, aux heads) load exact weights; the backbone reconstruction head trains from scratch.
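The frame-to-atoms step described under "Key Change" amounts to a rigid transform per residue. A sketch under the assumption that `local_offsets` holds each residue's idealized atom coordinates (plus any learned residual) in its local frame — the real model's frame convention and head are not shown:

```python
import torch

def place_backbone(R, t, local_offsets):
    """Place backbone atoms from per-residue rigid frames (sketch).

    R: (L, 3, 3) frame rotations
    t: (L, 3) frame translations (Calpha positions)
    local_offsets: (L, 4, 3) N/CA/C/O positions in each local frame
    Returns (L, 4, 3) global atom positions: R @ offset + t.
    """
    # rotate each local offset into the global frame, then translate
    return torch.einsum('lij,laj->lai', R, local_offsets) + t[:, None, :]
```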

New Loss Functions

Backbone FAPE
Frame-aligned point error computed over all 4 backbone atoms (N, Cα, C, O), not just Cα. Measures how well the predicted backbone superimposes onto the ground truth in each residue’s local frame.
Backbone Bond Loss
MSE on predicted N–Cα, Cα–C, C–N(next), and C=O bond lengths vs ideal values. Ensures the model produces chemically valid bond distances.
Backbone Angle Loss
MSE on predicted N–Cα–C, Cα–C–N(next), and C–N–Cα bond angles vs ideal values. Penalizes deviations from expected tetrahedral/planar geometry.
Omega Dihedral Loss
Peptide bond planarity — cosine loss enforcing ω ≈ 180°. The peptide bond is planar due to partial double-bond character; this loss directly penalizes non-planar (cis-like) peptide bonds.

v14 Training Progress

Epoch Val Total FAPE Frame Rot Dist MSE BB FAPE BB Bond Omega DDIM TM Status
1 2.728 1.550 0.769 0.347 1.542 0.004 0.076 0.103 NEW BEST
2 3.049 1.559 0.787 0.327 1.550 0.004 0.047 pat 1 (bb_ramp 0.4)

v14 Pseudo-Ramachandran

v14 Restarted (Mar 11, 2026)

The previous v14 run (E2–E7) was stopped after an audit revealed three bugs causing diverging total loss despite improving backbone sub-losses:

  1. Ideal offsets 10× too large: BackboneAtomHead registered ideal N/C/O offsets in Ångstroms (~1.5Å) but the model operates in normalized space (÷10). Atoms were placed ~15Å away instead of ~1.5Å, producing huge FAPE losses.
  2. Backbone head LR group misrouted: hasattr(model, "backbone_head") was always False (head lives at model.denoiser.backbone_head), so it got 1× LR instead of the intended 5×. The new head trained too slowly to compensate.
  3. bb_angle loss invisible: Missing from loss_keys, so it was never logged or tracked — contributing to total loss but invisible in CSV/logs.

All three bugs are now fixed. New Ramachandran plots will be generated after the restarted training produces enough epochs.

Relationship to v13

v14 runs in parallel with v13, not as a replacement. v13 explores pair stack improvements for long-range contacts and beta-sheet topology (deeper triangle updates, log-scaled RPE). v14 explores a richer output representation — predicting full backbone atom positions instead of Cα-only. These are orthogonal improvements: v13 improves how the model reasons about structure (pair representation), while v14 improves what the model outputs (backbone fidelity). Insights from both lines will be combined in a future version.

Side Chain Prediction (Future)

Side Chain Reconstruction

Planned
Goal: Extend v14’s full backbone prediction to include Cβ atoms and eventually full side chain reconstruction, enabling the model to produce all-atom protein structures directly from sequence.

Planned Approach

  • Chi angle prediction: Predict side chain dihedral angles (χ1–χ4) conditioned on backbone frames, residue type, and pair representation.
  • Hierarchical placement: Backbone atoms are placed first (from v14 frames), then side chain atoms are built outward using predicted chi angles and ideal bond geometry.
  • Rotamer-aware loss: Side chain loss will account for rotamer distributions — penalizing chi angles that fall outside known rotamer basins for each residue type.
  • Clash penalty: Steric clash loss between predicted side chain atoms to enforce physically valid packing.

Prerequisites

This module depends on v14’s full backbone prediction being stable and well-converged first. Accurate backbone frames are essential — side chain atoms are placed relative to the backbone frame, so errors in backbone geometry propagate directly to side chain positions. Work will begin once v14 demonstrates consistent bond/angle geometry and competitive FAPE scores.

v12b Training Archive (Archived — superseded by v13)

Protein Backbone Diffusion Model v12b

Complete
Goal: Generate realistic protein backbone structures (Cα coordinates) conditioned on sequence, using contact-aware embeddings from Stage 1. v12 scales IPA capacity 4.1x over v11b, to 31.1M params, to resolve the gradient competition collapse that limited v11b — same validated loss design, wider single representation, deeper frame update MLP.
Architecture IPA denoiser (8 layers, 8 heads, 8 query points) + independent aux pair stack (64-dim) + frozen ContactClassifier encoder — 31.1M params (29.9M trainable, 1.2M frozen); d_ipa_hidden=512, d_ipa_ffn=1024, 2-layer FrameUpdate MLP
Training Single RTX 2080 Ti (11.5 GB), batch_size=4, grad_accum=4 (eff=16), LR=2e-5, gradient checkpointing, T=1000, DDIM-50 eval, SC=0.25, w_frame_rot=0.5
Dataset CATH 4.2: 18,024 train / 608 val proteins, max 125 residues

Why IPA? The v10 Ceiling

v10 used an 8-layer EGNN denoiser that learned pairwise distance statistics well (dist_mse 60% below random) but could not learn protein topology. FAPE stayed at its random baseline (~1.31) across 21 epochs and TM-score peaked at 0.131 (random ~0.10). EGNN has no concept of local reference frames — it reasons about distances, not backbone geometry. IPA solves this by maintaining and refining per-residue rigid-body frames (rotation + translation) through 3D point attention in local coordinate systems.

v11b Key Changes (from v11)

  • Frame rotation loss (w=0.5) — direct angular distance: 1 - cos(angle) between learned R and Gram-Schmidt ground truth. Gives rotation quaternions direct gradient for the first time.
  • FAPE with learned frames — fape_loss_with_frames uses R_pred from IPA layers instead of rebuilding from coordinates. The R → FAPE gradient path is now intact.
  • LR halved — 2e-5 (from 5e-5) to prevent overshooting now that rotations receive gradient
  • Self-conditioning reduced — SC probability 0.25 (from 0.5) to let the model learn from scratch more
  • Initialized from v11 E10 best — EMA weights only, fresh optimizer + scheduler
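The frame rotation loss named in the first bullet — 1 − cos(angle) between learned and ground-truth rotations — can be computed directly from the rotation matrices via the trace identity (a sketch; the project's actual function signature may differ):

```python
import torch

def frame_rotation_loss(R_pred, R_true):
    """Angular distance between rotations (sketch of the v11b loss).

    cos(angle) = (trace(R_pred^T @ R_true) - 1) / 2, so the loss
    1 - cos(angle) is 0 when aligned, ~1 at 90 degrees, 2 at 180.
    """
    rel = torch.einsum('...ji,...jk->...ik', R_pred, R_true)  # R_pred^T R_true
    cos = (rel.diagonal(dim1=-2, dim2=-1).sum(-1) - 1.0) / 2.0
    return (1.0 - cos.clamp(-1.0, 1.0)).mean()
```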

v12b Training Fixes (over v12)

  • Per-module gradient clipping: denoiser max_norm=1.0, pair_stack=0.5, aux_pair=0.5
  • Per-module LR groups: pair_stack gets 3x base LR (6e-5) to compensate gradient attenuation
  • Pair stack tripwire: if grad norm < 0.01 for 50 consecutive steps, inject gradient noise (scale=0.01)
  • Hard halt: if grad norm < 0.001 for 500 total steps, stop training
  • Atomic checkpoint saves: write to tmp file then rename to survive SLURM preemption
  • NaN loss batches: skipped entirely instead of poisoning gradient accumulator
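The per-module clipping in the first bullet replaces a single global clip, which let the 28.5M-param denoiser consume nearly the whole norm budget. A sketch (submodule attribute names follow the text; the real paths may differ):

```python
import torch

def clip_per_module(model, max_norms=None):
    """Clip each submodule's gradients against its own budget (sketch).

    Returns the pre-clip norms per module, which is what a tripwire
    diagnostic like the one described above would monitor.
    """
    if max_norms is None:
        max_norms = {'denoiser': 1.0, 'pair_stack': 0.5, 'aux_pair': 0.5}
    norms = {}
    for name, max_norm in max_norms.items():
        module = getattr(model, name, None)
        if module is not None:
            norms[name] = torch.nn.utils.clip_grad_norm_(
                module.parameters(), max_norm)
    return norms
```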

v12b Loss Weights

Loss v11 v11b/v12b Rationale
FAPE (w/ learned R) 1.0 1.0 v11b uses learned R_pred instead of Gram-Schmidt rebuilt frames
Frame Rot — 0.5 NEW: direct angular loss on learned R vs ground truth. The v11b fix.
Bond 3.0 3.0
Clash 0.1 0.1
Aux dist 0.03 0.03
Dist MSE 1.0 1.0
Chirality 0.1 0.1
Angle 0.5 0.5
Rg 0.5 0.5

v12/v12b Loss Curves

V12 diffusion training loss curves

Pseudo-Ramachandran Analysis

Pseudo-Ramachandran plots

Pseudo-dihedrals computed from consecutive Cα positions. Ground truth (left) shows clear α-helix (~50°,50°) and β-sheet (~−120°,120°) clusters.

v12/v12b Full Training Table

Epoch Val Total Val FAPE Val Frame Rot Val Dist Val Bond DDIM TM DDIM RMSD Status
1 4.928 2.074 1.199 0.857 0.135 0.101 15.28Å NEW BEST
2 3.538 2.002 1.138 0.581 0.028 0.132 13.73Å NEW BEST
3 3.041 1.859 1.022 0.449 0.016 0.118 15.33Å NEW BEST
4 2.856 1.790 0.942 0.399 0.013 0.103 16.56Å NEW BEST
5 2.852 1.763 0.929 0.428 0.013 0.101 16.88Å NEW BEST
6 2.771 1.732 0.910 0.403 0.010 0.102 16.99Å NEW BEST
7 2.707 1.749 0.881 0.343 0.010 0.099 17.26Å NEW BEST
8 2.719 1.708 0.884 0.399 0.009 0.099 17.23Å pat 1
9 2.720 1.728 0.876 0.386 0.008 0.101 17.04Å pat 2
10 2.581 1.651 0.833 0.352 0.008 0.102 17.09Å NEW BEST
11 2.685 1.684 0.864 0.391 0.009 0.098 17.41Å pat 1
12 2.831 1.804 0.905 0.391 0.009 0.087 18.46Å pat 2
13 2.822 1.814 0.913 0.370 0.009 0.078 19.54Å pat 3
— v12b rollback to E10 EMA — per-module grad clipping, pair_stack 3x LR, tripwire
b1 2.541 1.627 0.818 0.366 0.008 0.102 16.98Å NEW BEST
b2 2.577 1.662 0.831 0.358 0.008 0.105 16.89Å pat 1
b3 2.555 1.652 0.837 0.339 0.009 0.103 17.09Å pat 2
b4 2.648 1.680 0.870 0.386 0.009 0.107 16.88Å pat 3
b5 2.565 1.644 0.835 0.356 0.008 0.104 17.02Å pat 4
b6 2.540 1.638 0.827 0.343 0.007 0.108 16.80Å pat 5
b7 2.641 1.668 0.852 0.396 0.008 0.113 16.47Å pat 6
b8 2.607 1.651 0.851 0.379 0.008 0.113 16.39Å pat 7
b9 2.564 1.632 0.829 0.374 0.007 0.112 16.40Å pat 8
b10 2.521 1.630 0.813 0.341 0.006 0.113 16.42Å NEW BEST
b11 2.553 1.633 0.829 0.362 0.006 0.110 16.52Å pat 1
b12 2.558 1.628 0.835 0.370 0.007 0.111 16.56Å pat 2
b13 2.534 1.635 0.826 0.343 0.007 0.110 16.64Å pat 3
b14 2.500 1.606 0.814 0.349 0.006 0.107 17.13Å NEW BEST
b15 2.480 1.586 0.807 0.391 0.006 0.112 16.62Å NEW BEST
b16 2.469 1.564 0.780 0.347 0.006 0.112 16.62Å NEW BEST
b17 2.506 1.608 0.815 0.355 0.006 0.110 16.88Å pat 1
b18 2.497 1.626 0.809 0.335 0.005 0.108 16.83Å pat 2
b19 2.517 1.607 0.820 0.365 0.006 0.110 16.92Å pat 3
b20 2.446 1.578 0.802 0.334 0.005 0.103 17.53Å NEW BEST
b21 2.482 1.589 0.797 0.362 0.006 0.104 17.54Å pat 1
b22 2.488 1.607 0.801 0.347 0.005 0.107 17.42Å pat 2
b23 2.609 1.641 0.842 0.408 0.006 0.106 17.61Å pat 3
b24 2.450 1.589 0.799 0.329 0.006 0.102 17.80Å pat 4
b25 2.456 1.604 0.807 0.320 0.005 0.101 18.02Å pat 5
b26 2.460 1.597 0.796 0.335 0.005 0.101 18.02Å pat 6
b27 2.548 1.635 0.839 0.364 0.005 0.105 17.69Å pat 7
b28 2.501 1.626 0.813 0.342 0.005 0.107 17.53Å pat 8
b29 2.443 1.578 0.791 0.340 0.005 0.110 17.02Å NEW BEST
b30 2.413 1.584 0.790 0.312 0.005 0.107 17.38Å NEW BEST
b31 2.462 1.587 0.808 0.342 0.005 0.108 17.43Å pat 1
Gradient cosine similarity diagnostic (E10 vs E12 vs E13)

Per-loss gradient directions on shared parameters (denoiser + pair_stack). Negative cosine = direct competition. Positive = aligned.

Loss Pair E10 (best) E12 E13 Verdict
FAPE vs frame_rot +0.59 +0.57 +0.50 Aligned
dist_mse vs FAPE +0.16 +0.14 +0.41 Near-orthogonal
dist_mse vs frame_rot +0.24 +0.00 -0.02 Near-orthogonal
FAPE vs bond_geom -0.04 -0.23 +0.02 Near-orthogonal

Key finding: No gradient competition between any loss pair. FAPE and frame_rot are strongly aligned (+0.5 to +0.6). The E11-E13 regression was caused entirely by gradient starvation of the pair_stack module — not conflicting loss objectives.

v11 post-mortem (different failure mode)
Hypothesis Result
Frame confidence starving updates NO — 57.7% conf > 0.5
Gradients dead at frame_update NO — highest grad norm (4.017)
FAPE gradient reaches frame_update YES — 13.03 (largest)
Learned R used in output NO — R discarded

Root cause (v11): x0_pred = t_vec discarded learned R. v11b fix: frame_rotation_loss + fape_loss_with_frames using learned R.

Loss Reference & Targets

Metric w Type v11b E1 v11b E5 v12 E1 v12 E10 Target / Interpretation
FAPE 1.0 L1 1.934 1.756 2.074 1.651 Primary metric. Uses learned R_pred frames. v10 ceiling=1.31, untrained >2.0. <1.0 = correct folds.
Frame Rot 0.5 1−cosθ 1.067 0.888 1.199 0.833 Angular error of learned R vs Gram-Schmidt truth. 1.0 = ~90° (random), 0 = perfect. Target: <0.5 by E15.
Dist MSE 1.0 MSE 0.409 0.384 0.857 0.352 Pairwise Cα distance error. <0.1 = sub-Å accuracy. Plateauing ~0.35–0.40.
Bond 3.0* MSE 0.015 0.010 0.135 0.008 Cα–Cα bond error. *annealed 1→3. Solved.
Rg 0.5 MSE 0.038 0.038 1.805 0.038 Radius of gyration error. Converged by E3.
TM-score — DDIM 0.094 0.091 0.101 0.102 50-step DDIM. Target: >0.15 by E10, >0.30 by E20. >0.17 = recognizable folds.
RMSD — DDIM 17.22Å 17.62Å 15.28Å 17.09Å <10Å = partial fold. <5Å = high quality.
Last updated: 2026-03-10
v11b Historical Results (14 epochs, gradient competition collapse at E8 — superseded by v12)

v11b validated the IPA + frame rotation loss design. Best epoch E8: FAPE 1.655, frame_rot 0.830, TM 0.093. After E8, gradient competition between dist_mse and frame_rot through shared 128-dim single representation caused collapse — TM dropped to 0.061, frame_rot reverted to 0.926. Stopped at E14.

Epoch Val Total Val FAPE Frame Rot Val Dist Val Bond DDIM TM Status
1 3.057 1.934 1.067 0.409 0.015 0.094 BEST
5 2.736 1.756 0.888 0.384 0.010 0.091 BEST
8 2.580 1.655 0.830 0.351 0.009 0.093 BEST (peak)
11 2.797 1.813 0.880 0.374 0.008 0.079 pat 3
14 2.973 1.898 0.926 0.423 0.011 0.061 pat 6 (stopped)
v11 Historical Results (15 epochs, no frame rotation supervision — superseded)

v11 used the same IPA architecture but had a critical bug: x0_pred = t_vec discarded learned rotation matrices R. FAPE peaked at 1.818 (E10) then degraded to 2.052 (E15). DDIM TM-score collapsed from 0.095 to 0.046.

V11 diffusion training loss curves (historical)
Epoch Val Total Val FAPE Val Dist Val Bond DDIM TM Status
1 4.338 1.950 1.163 0.171 0.099 BEST
3 2.741 1.988 0.479 0.022 0.128 BEST
6 2.558 1.865 0.448 0.019 0.095 BEST
10 2.431 1.818 0.382 0.015 0.087 BEST (peak)
13 2.539 1.919 0.380 0.016 0.053 pat 3
15 2.726 2.052 0.411 0.019 0.046 pat 5 (stopped)
v10 Historical Results (21 epochs, EGNN — superseded)

v10 used an 8-layer EGNN denoiser (14.6M params). After 21 epochs: dist_mse 60% below random (0.218), bond essentially solved (0.006), but FAPE stuck at random (~1.31) and TM-score peaked at 0.131. Best val structural = 0.805 (E16). DDIM best: TM=0.131, RMSD=14.53Å.

V10 diffusion training loss curves (historical)

Contact Classifier (Stage 1 — Multi-task Encoder)

Contact Classifier (Stage 1 — Multi-task Encoder)

Complete
Goal: Train a transformer encoder from scratch on CATH 4.2 dataset (18k proteins) to jointly predict CATH class/architecture labels AND inter-residue contact maps from sequence alone. The learned embeddings encode spatial proximity information needed for Stage 2.
Architecture ContactClassifier — 1.2M params, dim=128, 2 transformer towers, d_pair=64, 2 contact prediction blocks with outer product mean
Training Single GPU (1080 Ti), batch_size=24, lr=2e-4 with warmup cosine schedule, patience=15 early stopping
Resilience Per-epoch checkpoints with auto-resume, CSV loss logging, self-resubmitting watchdog system on SLURM

Training Progress (all 25 epochs — early stopped)

Epoch Val Total Loss Train Class Acc Val Class Acc Train Arch Acc Contact Recall (Val) Contact BCE (Val) LR Status
1 4.846 47.5% 41.8% 16.1% 69.6% 0.759 3.33e-05 NEW BEST
2 4.626 58.2% 54.9% 25.4% 72.3% 0.725 6.67e-05 NEW BEST
3 4.454 65.9% 53.8% 33.8% 73.9% 0.702 1.00e-04 NEW BEST
4 4.356 70.0% 58.4% 39.2% 73.8% 0.693 1.33e-04 NEW BEST
5 4.354 72.9% 60.7% 42.4% 73.7% 0.691 1.67e-04 NEW BEST
6 4.220 74.4% 65.1% 43.7% 75.5% 0.668 2.00e-04 NEW BEST
7 4.321 75.9% 64.8% 44.5% 75.9% 0.686 2.00e-04 pat 1
8 3.994 77.3% 65.5% 46.4% 77.4% 0.660 1.99e-04 NEW BEST
9 3.998 77.2% 66.0% 47.0% 78.0% 0.665 1.98e-04 pat 1
10 3.988 78.1% 66.3% 47.8% 77.3% 0.659 1.97e-04 BEST (final)
11 4.199 78.9% 66.1% 48.5% 78.0% 0.655 1.96e-04 pat 1
12 4.073 79.4% 68.9% 49.1% 78.7% 0.652 1.94e-04 pat 2
13 4.128 79.5% 66.8% 49.9% 77.4% 0.652 1.92e-04 pat 3
14 4.014 80.1% 67.8% 50.1% 77.2% 0.653 1.89e-04 pat 4
15 4.083 80.4% 65.6% 50.7% 78.1% 0.652 1.87e-04 pat 5
16 4.010 81.3% 68.8% 51.5% 77.1% 0.666 1.84e-04 pat 6
17 4.139 80.9% 66.3% 52.1% 77.8% 0.659 1.80e-04 pat 7
18 4.062 81.8% 68.1% 52.8% 77.5% 0.659 1.77e-04 pat 8
19 4.134 81.7% 66.0% 53.4% 77.8% 0.654 1.73e-04 pat 9
20 4.208 82.2% 69.2% 53.8% 77.5% 0.660 1.69e-04 pat 10
21 4.112 82.4% 68.9% 54.5% 77.5% 0.653 1.64e-04 pat 11
22 4.138 83.3% 68.6% 55.1% 78.0% 0.650 1.60e-04 pat 12
23 4.208 82.8% 67.6% 55.2% 77.3% 0.649 1.55e-04 pat 13
24 4.169 83.1% 68.9% 55.6% 77.0% 0.654 1.50e-04 pat 14
25 4.358 83.8% 65.1% 56.4% 78.0% 0.652 1.45e-04 EARLY STOP

Final Results

  • Best val loss: 3.988 at epoch 10 (best weights saved)
  • Val class accuracy: 66.3% (4-way CATH class), Val architecture accuracy: 31.8% (38+ architectures)
  • Contact recall: 77.3%, Contact BCE: 0.659 — model successfully learned spatial proximity from sequence
  • Train class accuracy: 78.1%, Train arch accuracy: 47.8%
  • Early stopped at epoch 25 (patience 15) — val loss plateaued after epoch 10
  • 1.2M params, trained from scratch on CATH 4.2 (18k proteins)
  • Training survived 7 SLURM job allocations with checkpoint resume

Training Curves

Contact classifier training curves

Contact Map Predictions (3 test proteins)

Ground truth vs predicted contact maps

Each row shows a held-out test protein from a different CATH structural class. The left column is the ground truth contact map (binary: two Cα atoms < 8Å apart), and the right column is the model’s predicted probability of contact from sequence alone. Metrics (precision P, recall R, and Top-L long-range accuracy) are annotated on each prediction panel.

  • 1bf0.A (L=60, Few Secondary Structure): A small protein with sparse, irregular contacts. The model captures the overall topology despite limited structural regularity.
  • 3ggm.A (L=81, Mainly Beta): Beta-sheet proteins produce characteristic off-diagonal block patterns from strand–strand hydrogen bonding. The model recovers these long-range parallel and anti-parallel strand pairings well.
  • 1f9x.A (L=120, Mainly Alpha): Alpha-helical proteins show strong banded diagonal patterns from helix-internal i→i+4 contacts. The model reproduces both the local helical periodicity and inter-helix contacts at larger separations.

These results demonstrate that a 1.2M-parameter transformer encoder trained from scratch on CATH 4.2 (~18k proteins) can learn meaningful spatial proximity signals across all major fold classes — without any pretrained language model or evolutionary information.
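The ground-truth contact definition used throughout this section (two Cα atoms closer than 8Å) is straightforward to compute; a sketch, where the `min_sep` mask for excluding trivial near-diagonal contacts is our addition, not stated in the text:

```python
import torch

def contact_map(ca_coords, cutoff=8.0, min_sep=0):
    """Binary contact map from Calpha coordinates: dist(i, j) < cutoff.

    ca_coords: (L, 3). Optional min_sep masks pairs with |i - j| < min_sep
    (an assumption for filtering trivial local contacts).
    """
    diff = ca_coords[:, None, :] - ca_coords[None, :, :]
    contacts = diff.norm(dim=-1) < cutoff            # (L, L) bool
    if min_sep:
        idx = torch.arange(ca_coords.shape[0])
        contacts &= (idx[None, :] - idx[:, None]).abs() >= min_sep
    return contacts
```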

Learned Embedding Space (PCA & UMAP)

PCA and UMAP of learned protein embeddings colored by CATH class and architecture

Attention-pooled protein embeddings (128-dim) from the encoder’s val+test set, projected via PCA and UMAP. The encoder learns to separate CATH classes without explicit contrastive loss — mainly-alpha and mainly-beta proteins form distinct clusters, while alpha-beta proteins span the intermediate region. UMAP reveals finer sub-structure at the architecture level, with several CATH architectures forming tight, well-separated clusters (e.g., 3.40 Rossmann fold, 1.10 orthogonal bundle). This confirms the multi-task training objective (classification + contact prediction) produces structurally meaningful representations suitable for conditioning the downstream diffusion model.

Model Evolution Timeline

v8
Foundation: EGNN Denoiser + Contact Conditioning
First working diffusion model. 8-layer EGNN denoiser with contact-conditioned pair stack and outer product mean. Established the core pipeline: frozen ContactClassifier encoder → pair representation → coordinate denoiser. Learned basic protein compactness (dist_mse well below random) but struggled with topology.
Key learning: Contact conditioning works. Pair stack provides meaningful structural signal.
v9
Loss Function Engineering
Introduced chirality loss (signed volumes), bond angle loss, and steric clash penalty. Discovered the clash loss was catastrophically miscalibrated in Rg-normalized coordinates (a 3.0Å threshold mapped to 30Å in real space, penalizing everything). Disabled clash, stabilized training.
Key learning: Loss calibration in normalized coordinate space is critical. Removed clash, added bond annealing.
v10
Scaling Up EGNN + Hitting the Ceiling
14.6M params, 8-layer EGNN. Fixed-scale coordinates (divide by 10Å instead of per-protein Rg). Self-conditioning. After 21 epochs: dist_mse 60% below random, bonds solved, but FAPE stuck at random baseline (~1.31) and TM-score peaked at 0.131. Diagnosed the root cause: EGNN has no concept of local reference frames, so it cannot optimize frame-aligned metrics.
Key learning: Distance-based denoisers hit a hard ceiling. Local reference frames are required for topology.
v11
IPA Architecture (Bug: Rotations Discarded)
Replaced EGNN with Invariant Point Attention (AlphaFold2-style). 6 IPA layers, 4 heads, 4 query points. FAPE improved to 1.818 (first time below random). But discovered a critical bug: x0_pred = t_vec discarded learned rotation matrices R. FAPE rebuilt frames via Gram-Schmidt, so rotations got no direct gradient. After E10, FAPE degraded to 2.052, TM collapsed to 0.046.
Key learning: Learned rotations must participate in the loss. Gram-Schmidt rebuilding is not enough.
v11b
Frame Rotation Loss (The v11b Fix)
Added frame_rotation_loss: direct angular distance (1 - cos theta) between learned R and Gram-Schmidt ground truth. FAPE now uses learned R_pred instead of rebuilt frames. LR halved to 2e-5. Best at E8: FAPE 1.655, frame_rot 0.830. Then gradient competition between dist_mse and frame_rot through the shared 128-dim single representation caused collapse — TM dropped to 0.061 by E14.
Key learning: Frame rotation loss works, but 128-dim bottleneck causes gradient competition. Need more capacity.
v12
4.1x Capacity Scaling
31.1M params. d_ipa_hidden: 256→512, 8 heads, 8 query points, 2-layer FrameUpdate MLP. Resolved the gradient competition by giving the model enough capacity to represent both distance and frame information without interference. Best E10: val_total 2.581, FAPE 1.651, frame_rot 0.833. Then E11-E13 regressed due to gradient starvation.
Key learning: Capacity scaling works. But global grad clipping lets the denoiser (28.5M params) consume 99%+ of gradient budget, starving the pair stack.
v12b
Per-Module Gradient Clipping + Breakthrough
Rolled back to v12 E10 EMA weights. Applied per-module gradient clipping (denoiser=1.0, pair_stack=0.5), pair_stack 3x LR multiplier, gradient tripwire system, atomic checkpoint saves. 31 epochs trained with 6 NEW BESTs (b14, b15, b16, b20, b29, b30). Confirmed the "consolidate → breakthrough" pattern: long patience plateaus followed by sudden improvements.
Best result: b30 — val_total 2.413, FAPE 1.584, frame_rot 0.790, dist_mse 0.312. All-time record.
v13
Deeper Pair Stack + Triangle Updates for Beta-Sheet Learning
The highest-impact architectural change. v12 learns helices but not beta sheets because its pair stack (4 blocks, max_rel_pos=32) cannot propagate information between residues separated by >32 positions — exactly where sheet contacts live. v13 doubles the pair stack to 8 blocks with Evoformer-style triangle multiplicative updates, replaces linear-clipped RPE with 128 log-scaled bins covering 0-512 residues, and adds classifier-free guidance training (10% conditioning dropout). Initialized from v12b EMA weights.
Status: Training in progress (E12, pat 8/15). E4 best — val_total 2.373, FAPE 1.546, frame_rot 0.771. Surpasses v12b b16 frame_rot (0.780) after only 4 epochs. LR constant at 2e-5; decay starts at E20.

Architecture & Loss Function Details

Diffusion v13 — Triangle Pair Stack + CFG (Current)

Why v13: The Beta-Sheet Problem

v12b achieved record structural metrics (FAPE 1.584, frame_rot 0.790) but visual inspection of generated structures revealed a systematic failure: the model learns alpha-helices well but cannot form beta-sheets. The root cause is the pair stack’s limited receptive field — with max_rel_pos=32, residue pairs separated by more than 32 positions in sequence have zero relative position signal. Beta-sheet hydrogen bonds typically connect residues 20–100+ positions apart, making them invisible to v12b’s pair representation.

v13 Overview (~35–38M params)

v13 makes three high-impact changes to solve the beta-sheet problem while preserving v12b’s proven IPA denoiser. The denoiser, aux pair stack, Rg predictor, and distance head are initialized from v12b b30 EMA weights (exact match via strict=False). The pair stack is entirely new architecture and trains from scratch.

1. Deeper Pair Stack with Triangle Multiplicative Updates (8 blocks)

The pair stack doubles from 4 to 8 EnhancedPairBlock blocks. Each block now applies Evoformer-style triangle multiplicative updates (outgoing + incoming) before the existing row/column axial attention + FFN. The triangle updates implement transitivity: “if residue i contacts residue k, and residue j contacts residue k, then i and j are structurally related.” This is exactly how beta-sheet topology is encoded — two strands share contacts through loop residues.

Implementation: each triangle update projects the pair representation to gate/value tensors (tri_mul_dim=64), computes einsum('bikd,bjkd->bijd') (outgoing) or 'bkid,bkjd->bijd' (incoming), then projects back to d_pair=128. The einsum is forced to fp32 to prevent fp16 overflow from the L-dimensional accumulation. Output is clamped to [-1e4, 1e4] and returned as a residual delta to avoid catastrophic cancellation.
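The update above can be sketched as a small PyTorch module. This is a minimal illustration under assumed names (`TriangleMultUpdate`, the `outgoing` flag, the gating layout); the project's actual normalization and gating arrangement may differ:

```python
import torch
import torch.nn as nn

class TriangleMultUpdate(nn.Module):
    """Sketch of an Evoformer-style triangle multiplicative update.
    Names and the exact gating layout are illustrative, not the project's code."""
    def __init__(self, d_pair=128, tri_mul_dim=64, outgoing=True):
        super().__init__()
        self.outgoing = outgoing
        self.norm = nn.LayerNorm(d_pair)
        self.left = nn.Linear(d_pair, tri_mul_dim)
        self.right = nn.Linear(d_pair, tri_mul_dim)
        self.left_gate = nn.Linear(d_pair, tri_mul_dim)
        self.right_gate = nn.Linear(d_pair, tri_mul_dim)
        self.out_norm = nn.LayerNorm(tri_mul_dim)
        self.out_proj = nn.Linear(tri_mul_dim, d_pair)
        self.out_gate = nn.Linear(d_pair, d_pair)

    def forward(self, z):                      # z: (B, L, L, d_pair)
        z_n = self.norm(z)
        a = torch.sigmoid(self.left_gate(z_n)) * self.left(z_n)
        b = torch.sigmoid(self.right_gate(z_n)) * self.right(z_n)
        # Force fp32: the k-dimensional accumulation can overflow fp16 (max ~65504).
        with torch.amp.autocast("cuda", enabled=False):
            eq = 'bikd,bjkd->bijd' if self.outgoing else 'bkid,bkjd->bijd'
            x = torch.einsum(eq, a.float(), b.float())
        x = x.clamp(-1e4, 1e4)                 # guard against blow-ups
        # Return a residual delta (caller does z = z + delta) to avoid
        # catastrophic cancellation in the pair representation.
        return torch.sigmoid(self.out_gate(z_n)) * self.out_proj(self.out_norm(x))
```

For the outgoing direction, entry (i, j) aggregates products of edges (i, k) and (j, k) over all k, which is the transitivity pattern described above.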

2. Log-Scaled Relative Position Encoding (128 bins, max_sep=512)

Replaces the linear-clipped RPE (max_rel_pos=32, 65 bins) with 128 log-spaced bins covering separations from 0 to 512 residues. The encoding is sign-aware (separate embeddings for upstream and downstream). Bin spacing:

  • Bins 0–8: Linear spacing (1-residue resolution) for helix contacts (i+3, i+4)
  • Bins 8–128: Log-spaced for long-range sheet contacts (i+20 to i+512)

Additionally, 32-dimensional sinusoidal continuous RPE features (sin/cos encoding projected to d_pair) provide smooth interpolation between discrete bins.
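The binning scheme can be sketched as a pure function. This is an illustrative reconstruction (cutoffs and the exact log-spacing formula are assumptions), showing fine linear bins at short range, log-spaced bins at long range, and sign-aware indexing:

```python
import math

def rel_pos_bin(i, j, n_bins=64, linear_cutoff=8, max_sep=512):
    """Map a signed sequence separation to a log-scaled bin index (sketch).
    n_bins counts magnitude bins per sign; sign-awareness doubles this
    to 128 total embedding slots. Cutoffs here are illustrative.
    """
    sep = j - i
    mag = min(abs(sep), max_sep)
    if mag < linear_cutoff:
        bin_idx = mag                      # 1-residue resolution for helices (i+3, i+4)
    else:
        # log-spaced bins from linear_cutoff..max_sep over the remaining slots
        frac = math.log(mag / linear_cutoff) / math.log(max_sep / linear_cutoff)
        bin_idx = linear_cutoff + int(frac * (n_bins - 1 - linear_cutoff))
        bin_idx = min(bin_idx, n_bins - 1)
    # separate embedding tables for upstream vs downstream neighbours
    return bin_idx if sep >= 0 else bin_idx + n_bins
```

Unlike the old linear-clipped RPE, a separation of 100 residues now lands in its own bin rather than collapsing into the same bucket as everything beyond 32.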

3. Classifier-Free Guidance Training

10% of training batches (p_uncond=0.1) replace residue tokens with MASK tokens (token ID 1), preserving CLS/EOS/PAD structure so attention masks remain valid. This trains the model for both conditional and unconditional generation. At sampling time, CFG enables guided generation: \(\hat{\epsilon} = \epsilon_{\text{uncond}} + w \cdot (\epsilon_{\text{cond}} - \epsilon_{\text{uncond}})\).

Critical bug found and fixed: Initially used PAD (token 0) as the null token. Since the attention mask is computed as ids.ne(PAD), this produced an all-False mask, creating degenerate pair representations that caused NaN. Using MASK (token 1) instead preserves valid masks.
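A minimal sketch of both halves of CFG, assuming illustrative token IDs (MASK=1, PAD=0, CLS=2, EOS=3) and function names not taken from the project's code:

```python
import torch

MASK_ID = 1  # null token; must NOT be PAD, or ids.ne(PAD) goes all-False

def drop_conditioning(ids, special_ids=(0, 2, 3), p_uncond=0.1):
    """With probability p_uncond per batch element, replace residue tokens
    with MASK while keeping PAD/CLS/EOS (illustrative ids 0/2/3) intact,
    so the attention mask ids.ne(PAD) stays valid."""
    drop = torch.rand(ids.shape[0], device=ids.device) < p_uncond
    is_special = torch.zeros_like(ids, dtype=torch.bool)
    for s in special_ids:
        is_special |= ids.eq(s)
    null_ids = torch.where(is_special, ids, torch.full_like(ids, MASK_ID))
    return torch.where(drop[:, None], null_ids, ids)

def cfg_combine(eps_cond, eps_uncond, w=2.0):
    """Sampling-time guidance: eps = eps_uncond + w * (eps_cond - eps_uncond)."""
    return eps_uncond + w * (eps_cond - eps_uncond)
```

With w=1 this reduces to the plain conditional prediction; w>1 extrapolates away from the unconditional one.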

v13 LR Schedule: Warmup → Constant → Cosine Decay

Three-phase schedule motivated by v12b’s finding that breakthroughs happen in narrow LR windows:

  1. Warmup (3 epochs): Linear 0.01x → 1x base LR
  2. Constant (20 epochs): Hold at peak LR=2e-5 (pair_stack at 3x = 6e-5)
  3. Cosine decay (~37 epochs): Slow decay to eta_min=1e-6
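The three phases above can be expressed as a single LR-multiplier function (a sketch; the epoch counts mirror the schedule above, and eta_min_ratio = 1e-6 / 2e-5 = 0.05):

```python
import math

def lr_multiplier(epoch, warmup=3, constant=20, total=60, eta_min_ratio=0.05):
    """Three-phase LR multiplier: linear warmup -> constant -> cosine decay.
    Returns a factor to multiply the base LR (2e-5) by."""
    if epoch < warmup:
        return 0.01 + (1.0 - 0.01) * epoch / warmup           # linear 0.01x -> 1x
    if epoch < warmup + constant:
        return 1.0                                             # hold at peak
    t = (epoch - warmup - constant) / max(1, total - warmup - constant)
    return eta_min_ratio + 0.5 * (1.0 - eta_min_ratio) * (1.0 + math.cos(math.pi * min(t, 1.0)))
```

The pair_stack's 3x multiplier would be applied on top of this factor via its own parameter group.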

Per-module gradient clipping carried from v12b: denoiser max_norm=1.0, pair_stack=0.5.
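Per-module clipping is a small wrapper around `torch.nn.utils.clip_grad_norm_`, applied to each submodule's parameters separately instead of the whole model at once. A sketch, assuming the submodules are reachable as attributes named `denoiser` and `pair_stack`:

```python
import torch

def clip_per_module(model, max_norms=None):
    """Clip gradients separately per named submodule so a large module
    (the 28.5M-param denoiser) cannot consume the whole gradient budget
    under a single global clip. Returns the pre-clip norm per module."""
    max_norms = max_norms or {"denoiser": 1.0, "pair_stack": 0.5}
    norms = {}
    for name, max_norm in max_norms.items():
        module = getattr(model, name)
        norms[name] = torch.nn.utils.clip_grad_norm_(module.parameters(), max_norm)
    return norms
```

The returned per-module norms are also what a gradient tripwire system would monitor for spikes.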

v13 Numerical Stability Fixes

v13 development uncovered three classes of fp16 instability present since v12 (see NaN Debugging Story in the Structure Folding tab). All fixes are applied in both v12 and v13 codebases:

  • fp16 epsilon underflow: clamp(min=1e-6) underflows to 0 in fp16 (min positive ~6e-5). Fixed by forcing fp32 in Gram-Schmidt, slerp, IPA point attention, and all loss functions.
  • Triangle einsum overflow: L-dimensional accumulation exceeds fp16 max (65504). Forced fp32 with torch.amp.autocast("cuda", enabled=False).
  • Unsafe torch.cdist backward: Produces NaN when distance is exactly 0. Replaced with manual (diff.pow(2).sum(-1) + 1e-10).sqrt().
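The cdist replacement from the last bullet can be written as a short helper (a sketch of the stated fix, not necessarily the project's exact function):

```python
import torch

def safe_pairwise_dist(x, eps=1e-10):
    """NaN-safe pairwise distances over the second-to-last dim.
    torch.cdist's backward produces NaN when a distance is exactly 0
    (the gradient of sqrt at 0); adding eps inside the sqrt keeps the
    gradient finite. Note eps must live in fp32: 1e-10 underflows in fp16."""
    diff = x[..., :, None, :] - x[..., None, :, :]
    return (diff.pow(2).sum(-1) + eps).sqrt()
```

Forward results match `torch.cdist` to within ~1e-5; only the zero-distance gradient behavior differs.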

Shared Components (v11–v13)

These components were introduced in earlier versions and carried forward into v13. Understanding them is essential context for the v13 design decisions above.

IPA Denoiser (8 layers in v12/v13, 6 in v11)

The core structure module, adapted from AlphaFold2. Each IPA block performs three operations:

  1. Invariant Point Attention — multi-head attention on the single representation, augmented with pair bias and 3D point attention. Each head generates query/key/value points in \(\mathbb{R}^3\) transformed into each residue’s local frame. Attention weights depend on geometric distances between learned points — invariant to global rotation/translation. Why it matters: EGNN (v8–v10) had no concept of local frames and could not optimize frame-aligned metrics like FAPE. IPA was the architectural pivot that broke through the topology ceiling.
  2. Transition MLP — 2-layer feedforward on the single representation.
  3. Frame update — predicts a small quaternion + translation update per residue, composed in the local frame (right-multiplication for SE(3) equivariance). Initialized near-zero so frames are approximately preserved in early training.

SNR-Gated Frame Initialization

Per-residue rigid frames are built from noised C\(\alpha\) coordinates via Gram-Schmidt orthogonalization on consecutive backbone triplets. At high noise (SNR < 1.0, roughly t > 700), the noised coordinates are near-isotropic and Gram-Schmidt is numerically unstable. The frame confidence mechanism smoothly blends toward identity frames:

$$\text{conf}(t) = \text{clamp}\!\left(\frac{\text{SNR}(t) - 0.2}{1.0 - 0.2},\, 0,\, 1\right), \qquad R_{\text{init}} = \text{slerp}(I,\, R_{\text{GS}},\, \text{conf})$$

Why it matters: Without this, the first IPA layer receives garbage frames at high noise levels, producing cascading errors through all 8 layers. This was a major source of training instability in early v11.
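The confidence gate itself is a one-line clamp, shown here as a sketch (the slerp toward identity frames would consume its output):

```python
import torch

def frame_confidence(snr, lo=0.2, hi=1.0):
    """Blend weight toward Gram-Schmidt frames: 0 at SNR <= 0.2 (fall back
    to identity frames), 1 at SNR >= 1.0 (trust frames built from the
    noised coordinates), linear in between."""
    return ((snr - lo) / (hi - lo)).clamp(0.0, 1.0)
```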

Frame-Aware Self-Conditioning (SC=0.25 in v13)

25% of training steps (50% in v11): run a no-grad forward pass to get x\(_{0}^{\text{prev}}\), build clean frames from it (treated as t=0), and use those as the initial frames for the second pass. At high noise where x\(_t\) frames are identity, self-conditioning provides the model’s best guess at clean local geometry — the IPA layers refine good frames instead of building them from scratch.

Why reduced to 0.25 in v13: Higher SC probability means fewer “cold start” training steps. The model needs enough cold-start experience to generalize at inference when no previous prediction exists.

Fixed-Scale Coordinates (introduced v11)

All coordinates divided by a fixed constant (10\(\text{\AA}\)) instead of per-protein R\(_g\). Why: R\(_g\) normalization made the noise schedule protein-size-dependent — a protein with R\(_g\)=5\(\text{\AA}\) had coordinate values ~1.0 while R\(_g\)=25\(\text{\AA}\) gave ~0.2–0.5, meaning the same noise level destroyed more signal for larger proteins. This silently capped TM-scores and looked like a “plateau” rather than a systematic bias. All successful protein diffusion models (FrameDiff, RFDiffusion, Genie) use fixed-scale coordinates.


Loss Functions

The total loss is a weighted combination of eight components. All losses are shared across v11–v13; only the weights and frame handling differ between versions.

$$\mathcal{L}_{\text{total}} = w_{\text{fape}} \cdot \mathcal{L}_{\text{fape}} + w_{\text{rot}} \cdot \mathcal{L}_{\text{rot}} + w_{\text{dist}} \cdot \mathcal{L}_{\text{dist}} + \beta(e) \cdot \mathcal{L}_{\text{bond}} + w_{\chi} \cdot \mathcal{L}_{\chi} + w_{\theta} \cdot \mathcal{L}_{\theta} + w_{\text{rg}} \cdot \mathcal{L}_{\text{rg}} + w_{\text{aux}} \cdot \mathcal{L}_{\text{aux}}$$

1. FAPE (Frame-Aligned Point Error)  \(w = 1.0\)

The primary metric. Measures local structural consistency by computing point error in each residue’s local coordinate frame:

$$\mathcal{L}_{\text{fape}} = \frac{1}{N_f \cdot L} \sum_{f=1}^{N_f} \sum_{j=1}^{L} \min\!\Big( \| R_f^{\top}(\hat{x}_j - o_f) - R_f^{*\top}(x_j - o_f^*) \|,\; d_{\text{clamp}} \Big)$$

Why it’s the hardest loss: Requires global structural correctness, not just local geometry. Random baseline ~1.31; drops below 1.0 only when the model learns correct fold topology. v11b’s key fix: use learned R\(_{\text{pred}}\) from IPA layers instead of Gram-Schmidt rebuilt frames, giving rotations direct gradient for the first time.
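The FAPE formula can be sketched directly from the definition above, taking the frame origins \(o_f\) to be the C\(\alpha\) positions (an assumption of this sketch; the project may use other frame origins):

```python
import torch

def fape_loss(x_pred, x_true, R_pred, R_true, d_clamp=10.0):
    """Frame-aligned point error sketch. x: (L, 3) CA coordinates;
    R: (L, 3, 3) per-residue rotations; origins o_f are the CA positions.
    Every point is expressed in every residue's local frame, for both
    prediction and ground truth; the clamped error is averaged over
    frames x points."""
    # local coords (L_frames, L_points, 3): R_f^T (x_p - o_f)
    loc_pred = torch.einsum('fij,fpi->fpj', R_pred, x_pred[None] - x_pred[:, None])
    loc_true = torch.einsum('fij,fpi->fpj', R_true, x_true[None] - x_true[:, None])
    err = (loc_pred - loc_true).norm(dim=-1).clamp(max=d_clamp)
    return err.mean()
```

Because each point is re-expressed in local frames, the loss is invariant to any global rotation/translation applied consistently to coordinates and frames.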

2. Frame Rotation Loss  \(w = 0.5\) (introduced v11b)

Direct angular distance between learned rotation matrices and Gram-Schmidt ground truth: \(\mathcal{L}_{\text{rot}} = 1 - \cos\theta\), where \(\theta\) is the rotation angle between R\(_{\text{pred}}\) and R\(_{\text{GT}}\). Random baseline ~1.0 (~90\(^\circ\)), target <0.5.

Why it was added: v11 had a critical bug where x0_pred = t_vec discarded learned R. FAPE rebuilt frames via Gram-Schmidt, so rotations got no direct gradient. Adding this loss was the single change that enabled topology learning.
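The angular loss follows from the trace identity \(\mathrm{tr}(R) = 1 + 2\cos\theta\) applied to the relative rotation. A minimal sketch:

```python
import torch

def frame_rotation_loss(R_pred, R_true):
    """L_rot = 1 - cos(theta), where theta is the angle of the relative
    rotation R_pred^T R_true, recovered via trace(R) = 1 + 2 cos(theta).
    R_*: (..., 3, 3) rotation matrices."""
    rel = R_pred.transpose(-1, -2) @ R_true
    trace = rel.diagonal(dim1=-2, dim2=-1).sum(-1)
    cos_theta = ((trace - 1.0) / 2.0).clamp(-1.0, 1.0)  # clamp for fp noise
    return (1.0 - cos_theta).mean()
```

Identical frames give 0, a 90° error gives 1 (matching the ~1.0 random baseline), and a 180° error gives the maximum of 2.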

3. Distance MSE  \(w = 1.0\)

MSE on all pairwise C\(\alpha\) distances. Random baseline ~0.54. The easiest structural loss — even v8’s EGNN could reduce this 60% below random. But distance alone cannot encode topology (many different folds have similar distance distributions).

$$\mathcal{L}_{\text{dist}} = \frac{1}{|\mathcal{M}|} \sum_{(i,j) \in \mathcal{M}} \left( \| \hat{x}_i - \hat{x}_j \| - \| x_i - x_j \| \right)^2$$

4. Bond Geometry  \(\beta(e) = \min(3.0,\; 1.0 + 2.0 \cdot \min(e/10, 1))\)

MSE on consecutive C\(\alpha\)–C\(\alpha\) distances vs ideal 3.8\(\text{\AA}\). Weight is annealed from 1.0 to 3.0 over 10 epochs. Random baseline ~0.17; below 0.02 = bonds within 0.1\(\text{\AA}\) of ideal.

Why annealed: Starting high prevents the model from learning global structure (it just chains beads at 3.8\(\text{\AA}\) apart). Starting low lets the model explore, then gradually enforces physical backbone geometry.
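The annealing formula from the heading evaluates as:

```python
def bond_weight(epoch):
    """Annealed bond-geometry weight: beta(e) = min(3.0, 1.0 + 2.0 * min(e/10, 1)),
    i.e. linear from 1.0 at epoch 0 to 3.0 at epoch 10, then flat."""
    return min(3.0, 1.0 + 2.0 * min(epoch / 10, 1.0))
```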

5. Chirality  \(w = 0.1\)

MSE on normalized signed volumes (scalar triple products) of C\(\alpha\) quartets. Ensures correct backbone handedness — natural proteins are L-amino acids with consistent chirality. Without this loss the model can generate mirror-image structures that score well on all other metrics.
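A sketch of the signed-volume loss, taking the three vectors of each quartet from its first residue (consistent with the derivation in the Loss Function Derivations section; the exact vector convention is an assumption):

```python
import torch

def chirality_loss(x_pred, x_true):
    """MSE between normalized signed volumes (scalar triple products) of
    consecutive CA quartets (i, i+1, i+2, i+3). A mirror image flips the
    sign of every triple product, so this term penalizes wrong handedness."""
    def signed_vol(x):                          # x: (L, 3), returns (L-3,)
        v1 = x[1:-2] - x[:-3]
        v2 = x[2:-1] - x[:-3]
        v3 = x[3:] - x[:-3]
        num = (v1 * torch.cross(v2, v3, dim=-1)).sum(-1)
        den = v1.norm(dim=-1) * v2.norm(dim=-1) * v3.norm(dim=-1)
        return num / den.clamp(min=1e-8)
    return ((signed_vol(x_pred) - signed_vol(x_true)) ** 2).mean()
```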

6. Bond Angle  \(w = 0.5\)

MSE on cosines of C\(\alpha\)–C\(\alpha\)–C\(\alpha\) bond angles. Ideal angle ~120\(^\circ\) (\(\cos\theta \approx -0.5\)). Working in cosine space avoids discontinuities at 0\(^\circ\)/360\(^\circ\). Random baseline ~0.70; below 0.1 = correct backbone geometry.
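In code, the cosine-space comparison is a sketch along these lines (illustrative, not the project's exact implementation):

```python
import torch

def cos_bond_angles(x):
    """Cosine of each consecutive CA-CA-CA angle. x: (L, 3) -> (L-2,).
    Working in cosine space avoids angle wrap-around discontinuities."""
    u = x[:-2] - x[1:-1]    # middle atom -> previous atom
    v = x[2:] - x[1:-1]     # middle atom -> next atom
    return (u * v).sum(-1) / (u.norm(dim=-1) * v.norm(dim=-1)).clamp(min=1e-8)

def bond_angle_loss(x_pred, x_true):
    return ((cos_bond_angles(x_pred) - cos_bond_angles(x_true)) ** 2).mean()
```

A straight chain gives cos θ = -1 (180°); an ideal backbone angle of ~120° gives cos θ ≈ -0.5.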

7. Radius of Gyration  \(w = 0.5\)

MSE on log-transformed R\(_g\) predictions. A separate MLP predicts absolute R\(_g\) from sequence embeddings to recover real-space coordinates at inference. Converges below 0.05 by E2 and stays solved.

8. Auxiliary Distance CE  \(w = 0.03\)

Ordinal regression on binned pairwise distances (32 bins, 2–40\(\text{\AA}\)) from an independent lightweight pair stack (64-dim, ~500K params). The pair representation is detached — aux gradients train only the distance head, not the main pair stack.

Lesson learned: In v10 the distance head read from the main pair stack through detach(), causing feature drift divergence (aux_dist_ce: 3.95→38.8 over 3 epochs). v11+ uses a completely independent pair stack to avoid this.
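Target construction for the binned distance loss can be sketched as uniform binning over 2–40 Å (the ordinal-regression head itself is omitted here, and the uniform spacing is an assumption):

```python
import torch

def aux_distance_targets(d, n_bins=32, d_min=2.0, d_max=40.0):
    """Bin pairwise CA distances (Å) into n_bins uniform bins on
    [d_min, d_max]; out-of-range values clamp to the edge bins.
    Produces integer targets for the auxiliary distance head."""
    width = (d_max - d_min) / n_bins
    return ((d - d_min) / width).long().clamp(0, n_bins - 1)
```

At training time the pair representation fed to the head would be `pair_repr.detach()`, so aux gradients touch only the head.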

DDIM Evaluation Metrics

Every 5 epochs (v13) or every epoch (v12b), we generate structures via 50-step DDIM sampling using EMA weights and evaluate against ground truth:

  • TM-score: Global fold similarity [0, 1]. <0.17 ≈ random similarity, >0.5 = same fold. Observed random baseline ~0.10.
  • RMSD: Average atomic displacement after superposition. Random ~15–16\(\text{\AA}\). <5\(\text{\AA}\) = high quality.
  • GDT: Fraction of residues within 1–8\(\text{\AA}\) of truth. Random ~3–4%.

Previous Architecture Details

v11/v12 — IPA-Based Frame Denoising (superseded by v13)

Why the pivot from v10: EGNN has no concept of local reference frames — it passes messages based on pairwise distances and updates coordinates through distance-weighted vectors. FAPE measures frame-aligned point error, which EGNN has no inductive bias to optimize. v11 replaced EGNN with IPA (AlphaFold2-style), explicitly maintaining per-residue rigid-body frames (R ∈ SO(3) + t ∈ ℝ³).

v11 → v11b: v11 had a critical bug: x0_pred = t_vec discarded learned R. v11b added frame_rotation_loss and used learned R in FAPE. Best E8: FAPE 1.655, frame_rot 0.830. Collapsed at E14 due to gradient competition in the shared 128-dim single representation.

v12: 4.1x capacity scaling (8.4M → 31.1M params). d_ipa_hidden: 256→512, 8 heads, 8 query points, 2-layer FrameUpdate MLP. Resolved gradient competition. Best E10: val_total 2.581. Then regressed E11–E13 due to gradient starvation of the pair stack.

v12b: Per-module gradient clipping (denoiser=1.0, pair_stack=0.5), pair_stack 3x LR. 31 epochs, 6 NEW BESTs. All-time record b30: val_total 2.413, FAPE 1.584, frame_rot 0.790.

Loss        v10   v11   v11b/v12b   Rationale
FAPE        0.3   1.0   1.0         IPA can optimize frame consistency
Frame Rot   -     -     0.5         v11b fix: direct angular loss on learned R
Bond        5.0   3.0   3.0         Gentler anneal avoids tug-of-war
Clash       0.0   0.1   0.1         Fixed-scale coords make threshold meaningful
Dist MSE    1.0   1.0   1.0

Training: LR 5e-5 (v11) → 2e-5 (v11b+), cosine decay, 1000 diffusion timesteps, DDIM-50 eval.

v10 — EGNN Denoiser + Full Loss Function Derivations (superseded)

8-layer SE(3)-equivariant graph neural network (EGNN), 14.6M params. Operated in R\(_g\)-normalized coordinate space. After 21 epochs: dist_mse 60% below random, bonds solved, but FAPE stuck at random baseline (~1.31) and TM-score peaked at 0.131. The model learned “proteins are compact blobs of the right size” but could not learn topology.

Loss Function Derivations

Distance MSE

$$\mathcal{L}_{\text{dist}} = \frac{1}{|\mathcal{M}|} \sum_{(i,j) \in \mathcal{M}} \left( \| \hat{x}_i^{(0)} - \hat{x}_j^{(0)} \| - \| x_i^{(0)} - x_j^{(0)} \| \right)^2$$

Bond Geometry (annealed \(\beta(e) = \min(5.0, 1.0 + 4.0 \cdot \min(e/15, 1))\))

$$\mathcal{L}_{\text{bond}} = \frac{1}{L-1} \sum_{i=1}^{L-1} \left( \| \hat{x}_i^{(0)} - \hat{x}_{i+1}^{(0)} \| - \frac{3.8}{R_g} \right)^2$$

FAPE

$$\mathcal{L}_{\text{fape}} = \frac{1}{N_f \cdot L} \sum_{f=1}^{N_f} \sum_{j=1}^{L} \min\!\Big( \| R_f^{\top}(\hat{x}_j - o_f) - R_f^{*\top}(x_j - o_f^*) \|,\; d_{\text{clamp}} \Big)$$

Chirality (signed volumes)

$$\mathcal{L}_{\chi} = \frac{1}{L-3} \sum_{i=1}^{L-3} \left( \frac{\mathbf{v}_1 \cdot (\mathbf{v}_2 \times \mathbf{v}_3)}{\|\mathbf{v}_1\| \|\mathbf{v}_2\| \|\mathbf{v}_3\|} \bigg|_{\hat{x}} - \frac{\mathbf{v}_1 \cdot (\mathbf{v}_2 \times \mathbf{v}_3)}{\|\mathbf{v}_1\| \|\mathbf{v}_2\| \|\mathbf{v}_3\|} \bigg|_{x} \right)^2$$

Bond Angle

$$\mathcal{L}_{\theta} = \frac{1}{L-2} \sum_{i=1}^{L-2} \left( \cos\hat{\theta}_i - \cos\theta_i \right)^2$$

Radius of Gyration

$$\mathcal{L}_{\text{rg}} = \left( \log \hat{R}_g - \log R_g \right)^2$$

Auxiliary Distance CE (disabled in v10, w=0.03 in v11+)

$$\mathcal{L}_{\text{aux}} = -\frac{1}{|\mathcal{M}'|} \sum_{(i,j) \in \mathcal{M}'} \log p_{ij}\big[\text{bin}(d_{ij})\big]$$

96 bins in v10 (2–40\(\text{\AA}\), \(\Delta\)=0.396\(\text{\AA}\)/bin). Replaced with 32-bin ordinal regression in v11+ after discovering detach-induced feature drift divergence.

Clash Loss (disabled, w=0)

$$\mathcal{L}_{\text{clash}} = \frac{1}{|\mathcal{N}|} \sum_{(i,j) \in \mathcal{N}} \left[ (1 + 2 c_{ij}) \cdot \text{ReLU}(3.0 - d_{ij}) \right]^2$$

Why disabled: 3.0\(\text{\AA}\) threshold in R\(_g\)-normalized coords maps to ~30\(\text{\AA}\) in real space, penalizing nearly all non-bonded pairs. Dominated ~47% of total loss in v9, drowning structural learning signal. Structural losses handle steric quality implicitly.

v10 Training Configuration
Optimizer: AdamW (\(\beta_1=0.9, \beta_2=0.999\))
Peak LR: \(10^{-4}\) with CosineAnnealingWarmRestarts (\(T_0=15\))
Batch size: 8 (grad accum=2, effective=16)
Mixed precision: AMP with GradScaler
EMA decay: 0.999
Self-conditioning: 50% probability
Hardware: Single NVIDIA A40 (48 GB)