A hierarchical JEPA framework supporting masked-patch SSL, end-to-end SIGReg-only training (LeWM-style), and action-conditioned world modeling
H-JEPA is a hierarchical Joint-Embedding Predictive Architecture framework with three opt-in training regimes that share the same encoder and SIGReg primitives:
- Masked-patch SSL (default) — extends Meta's I-JEPA with a Feature Pyramid Network that fuses representations across multiple hierarchy levels for both fine-grained and coarse semantic features.
- End-to-end SIGReg-only training — drops the EMA target encoder and relies on LeJEPA / LeWorldModel-style SIGReg regularization to prevent collapse, halving encoder parameters (the non-trainable EMA copy is no longer stored) and reducing tunable loss hyperparameters from six to one.
- Action-conditioned world modeling — adds an AdaLN-Zero predictor and CEM latent-space planner so the same encoder backbone can be trained on action-annotated trajectories for control.
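To make the SIGReg idea concrete, here is a minimal pure-Python sketch of the characteristic-function variant. Embeddings are projected onto random unit directions, and each 1-D projection is scored with an Epps–Pulley-style statistic: the weighted squared distance between its empirical characteristic function and that of N(0, 1). The function names, the Gaussian frequency weight, the quadrature grid (17 trapezoid points on [0, 3]) and the number of directions are illustrative assumptions, not the repository's implementation, which operates on batched, differentiable tensors.

```python
import math
import random


def epps_pulley(samples, t_max=3.0, n_points=17):
    """Weighted squared distance between the empirical characteristic function
    of 1-D samples and the N(0, 1) characteristic function exp(-t^2 / 2),
    integrated with the trapezoid rule and scaled by the sample count."""
    n = len(samples)
    step = t_max / (n_points - 1)
    values = []
    for k in range(n_points):
        t = k * step
        real = sum(math.cos(t * x) for x in samples) / n
        imag = sum(math.sin(t * x) for x in samples) / n
        gauss_cf = math.exp(-0.5 * t * t)  # real-valued for N(0, 1)
        weight = math.exp(-0.5 * t * t)    # Gaussian weight on frequencies (assumed)
        values.append(((real - gauss_cf) ** 2 + imag ** 2) * weight)
    integral = step * (sum(values) - 0.5 * (values[0] + values[-1]))
    # The integrand is even in t, so integrate over [0, t_max] and double it.
    return n * 2.0 * integral


def sigreg(embeddings, num_directions=16, seed=0):
    """Average Epps-Pulley statistic over random unit-norm projections (the sketch)."""
    rng = random.Random(seed)
    dim = len(embeddings[0])
    total = 0.0
    for _ in range(num_directions):
        direction = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        norm = math.sqrt(sum(v * v for v in direction))
        direction = [v / norm for v in direction]
        projected = [sum(e * d for e, d in zip(emb, direction)) for emb in embeddings]
        total += epps_pulley(projected)
    return total / num_directions


rng = random.Random(1)
isotropic = [[rng.gauss(0.0, 1.0) for _ in range(8)] for _ in range(256)]
collapsed = [[0.0] * 8 for _ in range(256)]
print(sigreg(isotropic), sigreg(collapsed))  # small vs. much larger
```

A batch drawn from an isotropic standard Gaussian yields a small statistic, while a collapsed batch (every embedding identical) yields a much larger one. Minimizing this quantity is what pushes the encoder away from collapse without an EMA target.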
All three regimes are selectable via config — defaults preserve the original H-JEPA behavior. The architecture uses Rotary Position Embeddings (RoPE) for spatial awareness, multi-crop augmentation for scale invariance, and supports CUDA, Apple Silicon (MPS), and CPU backends.
The implementation is validated end-to-end with 1400+ tests. No pretrained weights are published yet — this is a research codebase for training from scratch.
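In the EMA-based configurations, the target encoder is not trained by gradient descent; its weights track an exponential moving average of the context encoder, as in I-JEPA. The sketch below stores parameters as plain lists of floats. The linear momentum ramp from 0.996 to 1.0 and the function names are assumptions chosen for illustration; the repository's configs define the values actually used.

```python
def momentum_at(step, total_steps, start=0.996, end=1.0):
    """Illustrative linear momentum ramp from `start` to `end` over training."""
    frac = min(step / max(total_steps, 1), 1.0)
    return start + (end - start) * frac


def ema_update(target, online, momentum):
    """target <- m * target + (1 - m) * online, parameter by parameter.

    In a PyTorch trainer this would run under torch.no_grad(), so the
    target encoder never receives gradients."""
    for name, weights in online.items():
        target[name] = [momentum * t + (1.0 - momentum) * o
                        for t, o in zip(target[name], weights)]
    return target


online = {"w": [1.0, 0.0]}
target = {"w": [0.0, 2.0]}
ema_update(target, online, momentum=0.5)  # target["w"] is now [0.5, 1.0]
```

End-to-end mode skips this update entirely: the same encoder produces both context and target embeddings, and SIGReg alone guards against collapse.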
- Multi-scale hierarchy — learn representations at multiple levels with configurable depth
- Feature Pyramid Network — fuse features across hierarchy levels
- Rotary Position Embeddings — 2D spatial encoding without learned position parameters
- Flash Attention — fused attention kernels on CUDA
- SIGReg loss — sketched isotropic Gaussian regularization with two Epps–Pulley variants (reference points and characteristic-function quadrature)
- End-to-end mode — opt out of the EMA target encoder, following LeWorldModel
- BatchNorm projection heads — preserve the variance signal that SIGReg shapes
- AdaLN-Zero action predictor — zero-initialized action conditioning for stable world-model training
- CEM latent-space planner — MPC + Cross-Entropy Method for downstream control
- Multi-crop augmentation — multiple views at different scales per image
- VICReg + prediction loss — combined objective for the classic SSL regime
- Apple Silicon support — runs on MPS with automatic fallbacks
- Mixed precision & gradient checkpointing — efficient training on limited hardware
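AdaLN-Zero, introduced with the DiT architecture, conditions a residual block by regressing a per-channel shift, scale and gate from the conditioning vector (here, an action) and zero-initializing that regression, so every block starts out as the identity. The pure-Python sketch below shows one conditioned sublayer. The toy sublayer `f`, the single linear modulation layer (DiT applies a SiLU before it) and the class name are hypothetical simplifications, not the repository's predictor.

```python
import math


def layer_norm(x, eps=1e-6):
    """Parameter-free LayerNorm over a single vector."""
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)
    return [(v - mu) / math.sqrt(var + eps) for v in x]


def linear(weight, bias, x):
    """y = W x + b, with W given as a list of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


class AdaLNZeroSublayer:
    """Hypothetical sketch: out = x + gate * f(LN(x) * (1 + scale) + shift),
    where (shift, scale, gate) are a linear function of the action."""

    def __init__(self, dim, action_dim, f):
        self.dim = dim
        self.f = f  # wrapped sublayer (attention or MLP in a real predictor)
        # Zero init: shift = scale = gate = 0 for every action at the start.
        self.w_mod = [[0.0] * action_dim for _ in range(3 * dim)]
        self.b_mod = [0.0] * (3 * dim)

    def __call__(self, x, action):
        mod = linear(self.w_mod, self.b_mod, action)
        d = self.dim
        shift, scale, gate = mod[:d], mod[d:2 * d], mod[2 * d:]
        h = [n * (1.0 + s) + t for n, s, t in zip(layer_norm(x), scale, shift)]
        return [xi + g * yi for xi, g, yi in zip(x, gate, self.f(h))]


block = AdaLNZeroSublayer(dim=4, action_dim=2, f=lambda h: [2.0 * v for v in h])
x = [1.0, 2.0, 3.0, 4.0]
print(block(x, [0.5, -1.0]))  # -> [1.0, 2.0, 3.0, 4.0]
```

Because the gate starts at zero, the predictor initially passes latents through unchanged, and action conditioning is learned gradually as the modulation weights move away from zero.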
| Config | EMA target encoder | Loss | Use case |
|---|---|---|---|
| configs/default.yaml | Yes | VICReg + prediction (current default) | Hierarchical masked-patch SSL |
| configs/lewm_tier1.yaml | Yes | SIGReg + prediction | SSL with vectorized SIGReg, characteristic-function test, BatchNorm projector |
| configs/lewm_tier2.yaml | No | SIGReg-only, no detach | SSL without EMA: LeWM-style end-to-end training of the hierarchical stack |
| configs/lewm_world_model.yaml | No | SIGReg + next-latent prediction | Action-conditioned world model with AdaLN-Zero predictor + CEM planner |
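
The CEM planner searches over action sequences in latent space: sample candidate sequences from a Gaussian, roll each out through the predictor, keep the lowest-cost elites, refit the Gaussian to them, and repeat. In MPC fashion, only the first action of the final plan is executed before re-planning. The sketch below assumes a hypothetical additive `toy_dynamics` function standing in for the learned AdaLN-Zero predictor and uses squared latent distance to a goal embedding as the cost; the population size, elite count and iteration count are illustrative.

```python
import random


def cem_plan(z0, z_goal, dynamics, horizon=5, action_dim=2,
             population=64, n_elite=8, iterations=10, seed=0):
    """Cross-Entropy Method over action sequences; returns (mean plan, its cost)."""
    rng = random.Random(seed)
    mean = [[0.0] * action_dim for _ in range(horizon)]
    std = [[1.0] * action_dim for _ in range(horizon)]

    def rollout_cost(plan):
        z = z0
        for action in plan:
            z = dynamics(z, action)  # predicted next latent
        return sum((zi - gi) ** 2 for zi, gi in zip(z, z_goal))

    for _ in range(iterations):
        candidates = [
            [[rng.gauss(mean[t][k], std[t][k]) for k in range(action_dim)]
             for t in range(horizon)]
            for _ in range(population)
        ]
        elites = sorted(candidates, key=rollout_cost)[:n_elite]
        # Refit a diagonal Gaussian to the elite sequences.
        for t in range(horizon):
            for k in range(action_dim):
                vals = [plan[t][k] for plan in elites]
                m = sum(vals) / n_elite
                mean[t][k] = m
                std[t][k] = (sum((v - m) ** 2 for v in vals) / n_elite) ** 0.5 + 1e-3
    return mean, rollout_cost(mean)


def toy_dynamics(z, a):
    """Hypothetical stand-in for the learned predictor: z_{t+1} = z_t + a_t."""
    return [zi + ai for zi, ai in zip(z, a)]


plan, cost = cem_plan([0.0, 0.0], [2.0, -1.0], toy_dynamics)
first_action = plan[0]  # MPC: execute this action, observe, then re-plan
```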
- Python 3.11 or higher
- PyTorch 2.0+
- CUDA 11.7+ (optional, for GPU) or Apple Silicon Mac
```bash
# Clone the repository
git clone https://github.com/jonwiggins/H-JEPA.git
cd H-JEPA

# Create virtual environment
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate

# Install package
pip install -e .

# Verify the installation and the detected device
python -c "import torch; print('PyTorch:', torch.__version__); print('Device:', 'CUDA' if torch.cuda.is_available() else 'MPS' if torch.backends.mps.is_available() else 'CPU')"
```

Note (MPS): Flash Attention and mixed precision (AMP) are unavailable on Apple Silicon; the code falls back to standard attention and full precision automatically. SVD operations also fall back to CPU due to PyTorch MPS limitations.
```bash
# Train on CIFAR-10 (auto-downloads)
python scripts/train.py --config configs/default.yaml

# Train on ImageNet-100
python scripts/train.py --config configs/imagenet100.yaml

# Apple Silicon optimized
python scripts/train.py --config configs/mps_optimized.yaml

# Debug/test configuration (minimal)
python scripts/train.py --config configs/debug_minimal.yaml
```

```bash
# Linear probe evaluation
python scripts/eval_linear_probe.py --checkpoint path/to/checkpoint.pth

# k-NN evaluation
python scripts/eval_knn.py --checkpoint path/to/checkpoint.pth
```

See docs/TRAINING.md for the full training guide and docs/EVALUATION.md for evaluation details.
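k-NN evaluation measures how well frozen embeddings separate classes without training a classifier: each test embedding is labeled by a vote among its k most similar training embeddings. The pure-Python sketch below uses cosine similarity and an unweighted majority vote; these choices and the function names are assumptions, and scripts/eval_knn.py may use a different similarity, weighting scheme or k.

```python
import math
from collections import Counter


def normalize(v):
    """Scale a vector to unit L2 norm (zero vectors are left unchanged)."""
    n = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / n for x in v]


def knn_predict(train_feats, train_labels, query, k=3):
    """Majority vote among the k training embeddings most cosine-similar to `query`."""
    q = normalize(query)
    sims = [(sum(a * b for a, b in zip(q, normalize(f))), y)
            for f, y in zip(train_feats, train_labels)]
    sims.sort(key=lambda s: s[0], reverse=True)
    votes = Counter(y for _, y in sims[:k])
    return votes.most_common(1)[0][0]


def knn_accuracy(train_feats, train_labels, test_feats, test_labels, k=3):
    correct = sum(knn_predict(train_feats, train_labels, f, k) == y
                  for f, y in zip(test_feats, test_labels))
    return correct / len(test_labels)


train_x = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
train_y = [0, 0, 1, 1]
print(knn_accuracy(train_x, train_y, [[1.0, 0.2], [0.2, 1.0]], [0, 1]))  # -> 1.0
```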
```text
H-JEPA/
├── src/
│   ├── models/          # Model architectures (encoder, predictor, H-JEPA)
│   ├── losses/          # Loss functions (VICReg, SIGReg, combined)
│   ├── masks/           # Masking strategies
│   ├── data/            # Datasets and transforms
│   ├── trainers/        # Training loops
│   ├── evaluation/      # Evaluation protocols
│   ├── visualization/   # Attention and feature visualization
│   ├── serving/         # Model serving utilities
│   ├── inference/       # Inference pipelines
│   └── utils/           # Utilities (logging, checkpointing)
├── configs/             # YAML configuration files
├── scripts/             # Training and evaluation scripts
├── tests/               # Unit tests
└── docs/                # Documentation
```
Training is configured via YAML files in configs/. Key parameters:
```yaml
model:
  encoder_type: "vit_base_patch16_224"
  num_hierarchies: 3        # Number of hierarchy levels
  use_fpn: true             # Feature Pyramid Network
  use_rope: true            # Rotary Position Embeddings

training:
  epochs: 100
  batch_size: 256
  learning_rate: 1.5e-4
  use_amp: true             # Mixed precision (CUDA only)

loss:
  type: "combined"          # or "vicreg", "sigreg", "mse"
  hierarchy_weights: [1.0, 0.7, 0.5]
```

The test suite contains 1400+ tests across 30+ test modules, using pytest.
```bash
# Quick subset (skip slow tests)
pytest tests/ -m "not slow" -v

# Full suite
pytest tests/ -v

# With coverage report
pytest tests/ --cov=src --cov-report=term-missing
```

Tests cover all core modules (models, losses, masks, data, trainers, evaluation, and utilities) with mocked hardware backends, so the full suite runs on CPU, CUDA, or MPS.
CI pipeline (GitHub Actions, every push/PR):
- black — code formatting
- ruff — linting
- mypy — static type checking
- pytest — full test suite with coverage
Pre-commit hooks (install with pre-commit install): black, ruff, mypy.
See docs/testing.md for the full testing guide.
| Component | Details |
|---|---|
| Encoder | ViT-Tiny (5.5M params context + 5.5M target EMA) |
| Predictor | 4-layer transformer (2.8M params) |
| Total | ~13.8M parameters (8.3M trainable) |
| Hierarchies | 3 levels with FPN fusion (128-ch) |
| Embed dim | 192 |
See docs/ARCHITECTURE.md for a detailed architecture description.
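The totals in the table follow from simple arithmetic: the EMA target is a full copy of the ViT-Tiny context encoder but receives no gradients, so it counts toward the total but not toward the trainable parameters. The sketch below also recomputes the ViT-Tiny backbone size from its layer shapes, assuming standard ViT-Tiny settings (embedding width 192 from the table, 12 blocks, MLP ratio 4, 16×16 patches, 3 input channels), no learned position embeddings (RoPE is used instead), and no class token or classification head. The repository's encoder may differ in these details.

```python
def vit_backbone_params(embed_dim=192, depth=12, mlp_ratio=4,
                        patch_size=16, in_chans=3):
    """Count ViT parameters: patch embedding, transformer blocks, final norm.
    Excludes position embeddings, class token and any classification head."""
    d, hidden = embed_dim, embed_dim * mlp_ratio
    patch_embed = in_chans * patch_size * patch_size * d + d  # conv weight + bias
    attn = (d * 3 * d + 3 * d) + (d * d + d)                  # qkv + output projection
    mlp = (d * hidden + hidden) + (hidden * d + d)            # fc1 + fc2
    norms = 2 * (2 * d)                                       # two LayerNorms (weight + bias)
    final_norm = 2 * d
    return patch_embed + depth * (attn + mlp + norms) + final_norm


encoder = vit_backbone_params()        # 5,486,400, i.e. about 5.5M
encoder_m, predictor_m = 5.5, 2.8      # millions, from the table above
total_m = 2 * encoder_m + predictor_m  # context + EMA target + predictor, about 13.8M
trainable_m = encoder_m + predictor_m  # EMA target excluded, about 8.3M
```

Dropping the EMA copy, as the end-to-end regime does, would leave the total equal to the trainable count.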
```bash
# Build Docker image
docker build -t hjepa .

# Run training in container
docker run --gpus all -v $(pwd)/data:/app/data hjepa python scripts/train.py

# Run with Docker Compose
docker-compose up
```

Contributions are welcome; see CONTRIBUTING.md for guidelines.
```bibtex
@software{hjepa2025,
  title={H-JEPA: Hierarchical Joint-Embedding Predictive Architecture},
  author={Wiggins, Jon and Contributors},
  year={2025},
  url={https://github.com/jonwiggins/H-JEPA}
}
```

MIT License. See LICENSE for details.