v13b — CA-Only Diffusion Model


Input Cα Coordinates
[Diagram: a chain of Cα atoms, one per residue (res 1 … res L), consecutive atoms ≈ 3.8 Å apart — coordinates of shape (B, L, 3)]
v13b uses only Cα atoms — one 3D coordinate per residue. The model learns to generate the simplified backbone trace, not individual N, C, O atoms. Input includes ids (amino acid tokens), true_coords (B, L, 3), and coord_mask (resolved residues).
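A minimal sketch of this input batch; `ids`, `true_coords`, and `coord_mask` are named in the text, while the dtypes, the 20-token vocabulary, and the random fill are assumptions:

```python
import numpy as np

B, L = 2, 64                                   # batch size, residues per chain
ids = np.random.randint(0, 20, size=(B, L))    # amino acid tokens (20 types assumed)
true_coords = np.random.randn(B, L, 3)         # one Ca position per residue, in Angstrom
coord_mask = np.ones((B, L), dtype=bool)       # True where the residue is resolved
coord_mask[:, -4:] = False                     # e.g. an unresolved C-terminal tail

# Only resolved residues contribute to losses and centroids.
resolved = true_coords[0][coord_mask[0]]       # (60, 3)
```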
1 Frozen Protein Encoder
[Diagram: amino acid sequence (M E T H I K …) → Contact Classifier (🔒 frozen) → single features (B, L, 128) and contacts (B, L, L)]
Pre-trained encoder converts amino acid sequence into per-residue features (128-dim) and a contact probability map (which residues are physically close). All weights frozen — no gradients flow back.
2 Pair Stack — Triangle Updates ×8
Triangle Multiplicative Update — how edge (i, j) learns about the structure: to understand how residue i relates to residue j, look at all intermediate residues k that connect them and aggregate edge(i,j) += Σₖ edge(i,k) × edge(k,j). Repeating this 8× propagates long-range structural information.
The pair stack builds a pairwise relationship matrix (B, L, L, 128). Each entry describes how two residues relate structurally. The triangle update is inspired by AlphaFold2: if residue i is near k, and k is near j, then i is likely near j. Eight rounds of this propagate information across the whole chain. Also includes contact map conditioning (gated injection from encoder), OuterProductMean (d=32), and LogScaledRPE (128 bins, max_sep=512).
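The aggregation at the heart of the triangle update can be sketched with a single einsum; the gating, layer norms, and separate incoming/outgoing variants of the full AlphaFold2-style block are omitted here:

```python
import numpy as np

def triangle_update(edge):
    """Core triangle aggregation: for every pair (i, j), sum the per-channel
    products edge[i, k] * edge[k, j] over all intermediate residues k."""
    # edge: (L, L, d) pair representation
    return np.einsum('ikd,kjd->ijd', edge, edge)

L, d = 16, 8
edge = np.random.randn(L, L, d) * 0.1
edge = edge + triangle_update(edge)   # residual update, repeated 8x in the model
```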
3 Diffusion — Add Noise, Learn to Denoise
clean structure x₀ → add noise ε: xₜ = √α̅ₜ·x₀ + √(1−α̅ₜ)·ε → noisy structure xₜ → model predicts the denoised structure x̂₀
Training: Take a known protein structure, add random Gaussian noise at a random timestep t (out of 1000), then train the model to predict the original clean structure from the noisy version. Coordinates are normalized: x0 = (coords − centroid) / 10.0. Generation: Start from pure noise, iteratively denoise using DDIM (50 steps) to produce new protein structures.
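The training-side noising can be sketched as follows; the normalization and the xₜ formula follow the text, while the linear β schedule shape is an assumption:

```python
import numpy as np

T = 1000                                # diffusion timesteps, as in the text
betas = np.linspace(1e-4, 0.02, T)      # assumed linear beta schedule
alphas_bar = np.cumprod(1.0 - betas)    # cumulative product, alpha-bar_t

def normalize(coords):
    """x0 = (coords - centroid) / 10.0, as stated above."""
    return (coords - coords.mean(axis=0)) / 10.0

def q_sample(x0, t, rng):
    """Forward process: x_t = sqrt(abar_t)*x0 + sqrt(1 - abar_t)*eps."""
    eps = rng.standard_normal(x0.shape)
    return np.sqrt(alphas_bar[t]) * x0 + np.sqrt(1.0 - alphas_bar[t]) * eps, eps

rng = np.random.default_rng(0)
x0 = normalize(rng.standard_normal((64, 3)) * 10.0)
xt, eps = q_sample(x0, t=500, rng=rng)  # the model is trained to recover x0 from xt
```

At t = T−1, α̅ₜ is nearly zero and xₜ is essentially pure noise — which is why generation can start from a Gaussian sample and denoise from there.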
4 Frame Initialization — CA Triplets
Gram-Schmidt on Cα triplets (CAᵢ₋₁, CAᵢ, CAᵢ₊₁), origin at CAᵢ:
1. v₁ = CAᵢ₊₁ − CAᵢ → normalize → x-axis
2. v₂ = CAᵢ₋₁ − CAᵢ
3. z = v₁ × v₂ (cross product; normal to the triplet plane)
4. y = z × x (completes the right-handed frame)
⚠ Near-collinear triplets at high noise → SNR-gated confidence
Each residue gets a local coordinate frame (3 axes + origin at CAi) built from three consecutive Cα atoms via Gram-Schmidt orthogonalization. SNR-gated confidence (snr_high=1.0, snr_low=0.2) SLERPs toward identity at high noise to avoid unstable frames. Self-conditioning (50% of the time) feeds a previous no-grad prediction as extra input.
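The four Gram-Schmidt steps for a single triplet can be sketched as below; the SNR gating and SLERP toward identity are not shown:

```python
import numpy as np

def frame_from_triplet(ca_prev, ca, ca_next, eps=1e-8):
    """Local frame at residue i from three consecutive Ca positions."""
    v1 = ca_next - ca
    x = v1 / (np.linalg.norm(v1) + eps)   # 1. normalize -> x-axis
    v2 = ca_prev - ca                     # 2.
    z = np.cross(v1, v2)                  # 3. plane normal
    z = z / (np.linalg.norm(z) + eps)
    y = np.cross(z, x)                    # 4. completes the right-handed frame
    R = np.stack([x, y, z], axis=-1)      # rotation: columns are the axes
    return R, ca                          # frame = (rotation, origin at CA_i)

R, t = frame_from_triplet(np.array([0., 1., 0.]),
                          np.array([0., 0., 0.]),
                          np.array([1., 0., 0.]))
```

For a near-collinear triplet, v₁ × v₂ ≈ 0 and the normalization blows up — exactly the instability the SNR gate protects against at high noise.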
5 IPA Denoiser — 8 Layers
One IPA layer (Invariant Point Attention), stacked ×8:
  • Scalar attention: standard Q, K, V over node features h (B, L, 512); 8 heads, with pair bias added to the logits
  • Point attention: 3D query points expressed in each residue's local frame
  • Merge: concat → 1216 → Linear → 512, then FFN 512 → 1024 → 512
  • Frame update: MLP → ΔR, Δt → SE(3) compose
The core denoiser. Invariant Point Attention (IPA) combines standard sequence attention with 3D geometric attention — queries and keys are actual 3D points positioned in each residue's local coordinate frame. Because those points are read in local frames, the attention weights are invariant to global rotations and translations, and the predicted coordinates transform consistently with the input (SE(3) equivariance). After each layer, the frame update refines each residue's rotation and translation; 8 layers of refinement progressively sharpen the structure prediction. Output: x₀,pred (B, L, 3) + Rpred (B, L, 3, 3).
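The key invariance can be checked numerically: points expressed in local frames are unchanged by any global rotation plus translation. A sketch of that check (not the model code):

```python
import numpy as np

def to_local(R, t, points):
    """Express global points in the frame (R, t): p_local = R^T (p - t)."""
    return (points - t) @ R

def random_rotation(rng):
    """Random proper rotation via QR; flip the sign if det is -1."""
    Q, _ = np.linalg.qr(rng.standard_normal((3, 3)))
    return Q if np.linalg.det(Q) > 0 else -Q

rng = np.random.default_rng(1)
R, t = random_rotation(rng), rng.standard_normal(3)   # one residue's frame
pts = rng.standard_normal((5, 3))                     # points of other residues

G, g = random_rotation(rng), rng.standard_normal(3)   # arbitrary global motion
local_before = to_local(R, t, pts)
local_after = to_local(G @ R, G @ t + g, pts @ G.T + g)
# local_before == local_after: attention built on these values is SE(3)-invariant
```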
Auxiliary Heads
Aux A — Distance Head: 32-bin ordinal classification over a 2–40 Å range. AuxPairStack (d=64) + 2× AxialBlock.
Aux B — Rg Predictor: radius of gyration via attention pooling + MLP, with a softplus + 3.0 bias; output shape (B,), in Å.
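The regression target itself is cheap to compute; a sketch with mass weighting omitted. The softplus + 3.0 bias presumably keeps raw network outputs positive and near typical protein Rg values:

```python
import numpy as np

def radius_of_gyration(coords, mask):
    """Rg = sqrt(mean squared distance of resolved Ca atoms to their centroid)."""
    pts = coords[mask]
    centroid = pts.mean(axis=0)
    return np.sqrt(((pts - centroid) ** 2).sum(axis=-1).mean())

# Octahedron of unit vectors: centroid at the origin, every atom 1 A away -> Rg = 1.
coords = np.array([[1., 0, 0], [-1, 0, 0], [0, 1, 0],
                   [0, -1, 0], [0, 0, 1], [0, 0, -1]])
rg = radius_of_gyration(coords, np.ones(6, dtype=bool))   # -> 1.0
```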
6 Loss Functions — 9 Terms
FAPE (w = 1.0) — frame-aligned point error: how well do predicted points match true points in each residue's local frame? Clamped at 10.0.
Frame Rotation (w = 0.5) — angular distance between predicted and true frames, 1 − cos(θ). Clamped at 2.0.
Distance MSE (w = 1.0) — MSE on all pairwise Cα distances. Clamped at 10.0.
Bond Geometry (w = 3.0, annealed 1→3 over epochs 1–10) — consecutive Cα distance vs. the 3.8 Å target. Clamped at 10.0.
Chirality (w = 0.1) — signed volume of Cα quartets; preserves correct handedness.
Angle (w = 0.5) — MSE on Cα–Cα–Cα bond angles vs. ground truth.
Clash (w = 0.1) — ReLU²(3.8 Å − d) for non-bonded pairs. Clamped at 25.0.
Aux Distance (w = 0.03) — ordinal BCE on the 32-bin distance predictions from the aux head.
Rg Loss (w = 0.5) — MSE on log(radius of gyration); enforces correct overall compactness.
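Two of the geometric terms are easy to sketch in isolation; batching, masking, and loss weighting are omitted, and the exact clamp placement is an assumption:

```python
import numpy as np

def chirality_volumes(ca):
    """Signed volume of each consecutive Ca quartet:
    V = (a2-a1) x (a3-a1) . (a4-a1) / 6 -- mirroring the chain flips the sign."""
    a1, a2, a3, a4 = ca[:-3], ca[1:-2], ca[2:-1], ca[3:]
    return np.einsum('ij,ij->i', np.cross(a2 - a1, a3 - a1), a4 - a1) / 6.0

def clash_loss(ca, d_min=3.8):
    """ReLU^2(d_min - d) over non-bonded pairs (|i - j| >= 2), clamped at 25.0."""
    d = np.linalg.norm(ca[:, None] - ca[None, :], axis=-1)
    i, j = np.triu_indices(len(ca), k=2)     # skip bonded neighbours
    return np.minimum(np.maximum(0.0, d_min - d[i, j]) ** 2, 25.0).mean()

ca = np.random.default_rng(2).standard_normal((10, 3)) * 5.0
vols = chirality_volumes(ca)                             # (7,) signed volumes
mirrored = chirality_volumes(ca * np.array([1., 1., -1.]))
# vols == -mirrored: the loss can therefore detect the wrong enantiomer
```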
Tracked Metric (Early Stopping)
structural = total − aux_dist × 0.03
All losses except auxiliary distance prediction. Patience: 15 epochs.
Config Training Configuration
Optimizer: AdamW
Denoiser LR: 2e-5
Pair Stack LR: 6e-5 (3× the denoiser LR)
Aux Head LR: 4e-5
EMA Decay: 0.999
CFG Uncond: 10%
LR Schedule
  • Warmup: 3 epochs (linear ramp)
  • Constant: epochs 3–20
  • Cosine decay: epochs 20+
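The schedule above could be sketched as a per-epoch LR multiplier; the total epoch count and the exact ramp endpoints are assumptions, since the source does not state them:

```python
import math

def lr_scale(epoch, warmup=3, constant_until=20, total=100):
    """LR multiplier: linear warmup -> constant plateau -> cosine decay to 0."""
    if epoch < warmup:
        return (epoch + 1) / warmup                 # linear ramp over 3 epochs
    if epoch < constant_until:
        return 1.0                                  # constant, epochs 3-20
    progress = (epoch - constant_until) / max(1, total - constant_until)
    return 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))  # cosine, 20+

# e.g. denoiser lr = 2e-5 * lr_scale(epoch), pair stack lr = 6e-5 * lr_scale(epoch)
```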
Regularization
  • Per-module gradient clipping
  • CFG training: 10% of batches mask tokens for unconditional generation
  • EMA model tracked for evaluation