Estimates per-GPU memory consumption during LLM training with 5D parallelism (TP, PP, DP, CP, EP: tensor, pipeline, data, context, and expert parallelism). Supports MoE, the distributed optimizer, and NCCL buffer estimation, and fetches the model config from HuggingFace.
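
As a rough illustration of the config-fetch step, here is a minimal sketch assuming the `transformers` library and a Llama-style decoder-only config; the field names (`hidden_size`, `num_hidden_layers`, etc.) are standard HuggingFace config attributes, but the estimator's actual loader may read the raw config.json differently.

```python
# Sketch: fetch a HuggingFace config and pull out the fields a memory
# estimator needs. Assumes a Llama-style decoder-only config; field names
# on other architectures may differ.
from transformers import AutoConfig

def fetch_model_dims(repo_id: str) -> dict:
    cfg = AutoConfig.from_pretrained(repo_id)
    return {
        "hidden_size": cfg.hidden_size,
        "num_layers": cfg.num_hidden_layers,
        "num_heads": cfg.num_attention_heads,
        "num_kv_heads": getattr(cfg, "num_key_value_heads", cfg.num_attention_heads),
        "ffn_hidden_size": cfg.intermediate_size,
        "vocab_size": cfg.vocab_size,
    }

# Example (any public decoder-only repo id works):
# dims = fetch_model_dims("meta-llama/Llama-3.1-8B")
```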

Memory Breakdown (per GPU)
The breakdown table reports the per-GPU memory and the formula used for each component:
  • Parameters
  • Gradients
  • Optimizer States
  • Activations
  • NCCL Buffers (est.)
  • Total
Additional panels report the Model Parameters and a GPU Fit Check (whether the estimated total fits in the GPU's memory).
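
For intuition about the Parameters, Gradients, and Optimizer States rows, here is a hedged sketch assuming BF16 mixed precision with Adam, even parameter sharding across TP x PP, FP32 gradient accumulation, and a distributed optimizer that shards optimizer states across DP ranks; the tool's actual formulas are more detailed.

```python
# Sketch: per-GPU memory for parameters, gradients, and optimizer states
# under BF16 mixed precision with Adam. Assumes parameters are sharded
# evenly across TP * PP (a simplification); the distributed optimizer
# additionally shards FP32 master weights and Adam moments across DP ranks.
GiB = 1024 ** 3

def weight_memory_gib(total_params: float, tp: int, pp: int, dp: int,
                      distributed_optimizer: bool = True) -> dict:
    params_per_gpu = total_params / (tp * pp)      # even-sharding assumption
    param_bytes = 2 * params_per_gpu               # BF16 weights
    grad_bytes = 4 * params_per_gpu                # FP32 gradient accumulation
    optim_bytes = 12 * params_per_gpu              # FP32 master + Adam m, v
    if distributed_optimizer:
        optim_bytes /= dp                          # sharded across DP ranks
    return {
        "parameters_gib": param_bytes / GiB,
        "gradients_gib": grad_bytes / GiB,
        "optimizer_gib": optim_bytes / GiB,
    }

# Example: 70B parameters, TP=8, PP=4, DP=16 with the distributed optimizer.
print(weight_memory_gib(70e9, tp=8, pp=4, dp=16))
```
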
Roofline Analysis (Training, per layer)

Arithmetic Intensity (AI) = FLOPs / Bytes. If AI is below the GPU's ops:byte ratio (peak FLOP/s divided by peak memory bandwidth), the operation is Memory Bound; otherwise it is Compute Bound.

The per-layer table reports each Operation with its FLOPs, Bytes, AI (F/B), and the resulting Bound (memory or compute).
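
As a concrete example of the AI calculation, here is a minimal sketch for a single BF16 GEMM, using approximate H100-class peak numbers as the ops:byte reference; the tool's per-layer table breaks this down per operation.

```python
# Sketch: arithmetic intensity of one BF16 GEMM (C[M,N] = A[M,K] @ B[K,N]).
# FLOPs = 2*M*N*K; bytes = 2 bytes/element for reading A, B and writing C.
def gemm_ai(m: int, n: int, k: int, bytes_per_elem: int = 2) -> float:
    flops = 2 * m * n * k
    bytes_moved = bytes_per_elem * (m * k + k * n + m * n)
    return flops / bytes_moved

# Approximate H100-class peaks: ~989e12 BF16 FLOP/s, ~3.35e12 B/s HBM.
ops_per_byte = 989e12 / 3.35e12   # ~295 FLOPs per byte

ai = gemm_ai(m=8192, n=8192, k=8192)
print("AI:", ai, "->", "compute bound" if ai >= ops_per_byte else "memory bound")
```
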
Note
This estimator targets Transformer-based, decoder-only LLMs with a SwiGLU FFN and RMSNorm. It does not support SSM-based models (Mamba, RWKV, etc.), Gated Linear Networks, or Diffusion models. The numbers are theoretical estimates; actual memory consumption during training includes additional overheads that are only roughly modeled, or not modeled at all, here:
  • NCCL communication buffers: Each parallelism dimension creates its own communicator. Total NCCL overhead can reach 1–20 GB depending on the number of communicators and NCCL_BUFFSIZE. This estimator includes a rough estimate (~0.5 GB per communicator); see the sketch after this list.
  • CUDA context: ~300–500 MB per GPU for the CUDA runtime and driver.
  • Memory fragmentation: PyTorch's caching allocator may hold more memory than actually used.
  • Temporary buffers: All-gather buffers during distributed optimizer gather, pipeline send/recv buffers, etc.
  • MoE routing imbalance: This estimator assumes uniform token distribution across experts. In practice, routing can be imbalanced—tokens may concentrate on specific experts, causing higher activation memory on those Expert Parallel ranks and potentially triggering OOM even when the average estimate fits in memory. Load balancing loss mitigates but does not eliminate this risk.
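
A hedged sketch of the rough NCCL estimate mentioned in the first bullet, assuming one communicator per parallelism dimension with degree greater than 1 and the ~0.5 GB per-communicator figure quoted above; real overhead depends on NCCL_BUFFSIZE, channel count, and topology.

```python
# Sketch: rough NCCL buffer overhead per GPU. Counts one communicator per
# parallelism dimension whose degree is > 1, at ~0.5 GB each, following the
# estimator's stated heuristic. Real NCCL usage depends on NCCL_BUFFSIZE,
# the number of channels, and topology, and can be far larger.
def nccl_overhead_gb(tp: int, pp: int, dp: int, cp: int, ep: int,
                     gb_per_comm: float = 0.5) -> float:
    degrees = [tp, pp, dp, cp, ep]
    num_comms = sum(1 for d in degrees if d > 1)
    return num_comms * gb_per_comm

# Example: TP=8, PP=4, DP=16, CP=1, EP=8 -> 4 communicators ~= 2.0 GB.
print(nccl_overhead_gb(tp=8, pp=4, dp=16, cp=1, ep=8))
```
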
Citation
@misc{fujii2024acceleratinglargelanguagemodel,
  title={Accelerating Large Language Model Training with 4D Parallelism and Memory Consumption Estimator},
  author={Kazuki Fujii and Kohei Watanabe and Rio Yokota},
  year={2024},
  eprint={2411.06465},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2411.06465},
}