Mosaic: Sharding Attention Across GPUs When Your Sequence Doesn't Fit
By Nino, Senior Tech Editor
The evolution of Large Language Models (LLMs) is increasingly defined by context window size. However, as developers at n1n.ai often observe, the hardware bottleneck remains the primary obstacle to true long-context reasoning. You have likely heard that transformers suffer from a 'quadratic attention bottleneck.' While this sounds theoretical, it becomes a hard physical limit once your sequence length exceeds a few thousand tokens. In this guide, we explore Mosaic, a lightweight library that shards attention across multiple GPUs, allowing sequences as large as 150,000 tokens to be processed efficiently.
The Quadratic Reality: Why 84 GiB Isn't Enough
To understand why we need sharding attention, we must look at the math. The standard attention mechanism computes:
Attention(Q, K, V) = softmax(QKᵀ / √d) × V
The bottleneck is the QKᵀ operation, which produces a matrix of shape (sequence_length × sequence_length). For a sequence of 150,000 tokens, the calculation is as follows:
- Matrix Elements: 150,000 * 150,000 = 22.5 Billion elements.
- Memory (FP32): 22.5B * 4 bytes = 90 GB (about 84 GiB).
- Memory (BF16/FP16): 22.5B * 2 bytes = 45 GB.
Keep in mind, this is for a single attention head in a single layer. A modern A100 GPU has 80 GB of total VRAM, so once you factor in model weights, optimizer states, and activations for other layers, a naively materialized 150k attention matrix cannot fit on one card. Even FlashAttention, which tiles the computation to bring activation memory down to O(n), still requires the full K and V for the sequence to reside in a single device's memory pool.
The Limitations of Existing Solutions
Before Mosaic, two primary methods dominated the landscape:
- FlashAttention: Excellent for local GPU optimization but lacks native multi-node distribution.
- Ring Attention: Distributed via the `ring-flash-attn` library, it shards the sequence across GPUs in a 1D ring. While effective for linear sequences, it struggles with the multi-axis attention patterns found in tabular transformers and multi-modal models.
At n1n.ai, we found that developers working on tabular data (like nanoTabPFN) needed to attend over both features (small axis) and rows (massive axis). Standard libraries didn't allow for this hybrid approach without significant custom engineering. This is where Mosaic provides a critical bridge by automating sharding attention across specific dimensions.
How Mosaic Works: Multi-Axis Coordination
Mosaic acts as a thin coordination layer. It doesn't replace FlashAttention; it orchestrates it. It routes different attention axes to the most appropriate backend based on their size and the available hardware topology.
```python
import mosaic

# A small axis (e.g., features) runs locally on one GPU
feature_attn = mosaic.MultiAxisAttention(
    embed_dim=96, num_heads=4,
    attention_axis=2,   # features dimension
    backend="local",    # no cross-GPU communication needed
)

# A large axis (e.g., 150k rows) is sharded across the cluster
row_attn = mosaic.MultiAxisAttention(
    embed_dim=96, num_heads=4,
    attention_axis=1,   # rows dimension
    backend="ring",     # distributed sharding attention
)
```
By decoupling the attention logic from the tensor reshaping, Mosaic handles the permutation of the attention axis, projection of QKV, and the eventual restoration of the tensor shape automatically.
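For reference, a forward pass through the two modules might look like the sketch below. The input layout and the call signature of `MultiAxisAttention` are assumptions inferred from the snippet above rather than verified against Mosaic's documentation.

```python
import torch

# Hypothetical layout: (batch, rows, features, embed), matching the
# attention_axis values above; a small row count keeps the example runnable.
x = torch.randn(2, 4096, 12, 96)

x = feature_attn(x)  # attends over axis 2 (features), stays on the local GPU
x = row_attn(x)      # attends over axis 1 (rows), sharded around the ring
```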
Deep Dive: The Ring Sharding Algorithm
The core of sharding attention in Mosaic is the Ring Attention mechanism. Imagine you have 4 GPUs and a 150k sequence split into 4 chunks of 37.5k tokens each.
- Initial State: Each GPU holds its local Query (Q), Key (K), and Value (V) chunks.
- Local Computation: Each GPU computes attention for its local Q and local K/V.
- The Ring Pass: GPU 0 sends its K/V to GPU 1, while receiving K/V from GPU 3.
- Iterative Update: In each step, the GPU computes attention between its fixed Q and the rotating K/V chunks. It accumulates the partial softmax sums and max values (using the Online Softmax algorithm).
- Normalization: After a full rotation, each GPU has the complete attention output for its specific chunk of the sequence.
This reduces the per-GPU memory for the attention scores to O(n²/p), where p is the number of GPUs. On an 8-GPU node, that 84 GiB requirement drops to roughly 10.5 GiB, making the 150k sequence easily manageable.
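To make the accumulation step concrete, here is a minimal single-process simulation of the ring pass, assuming the standard online-softmax update. The function and chunking scheme are ours, not Mosaic's; a real implementation exchanges the K/V chunks between ranks with torch.distributed point-to-point calls instead of a Python loop.

```python
import torch

def simulate_ring_attention(q, k, v, num_gpus):
    """q, k, v: (seq, dim). Splits K/V into num_gpus chunks and accumulates
    attention one chunk at a time, exactly as each GPU would for its local Q
    while the K/V chunks rotate around the ring."""
    d = q.shape[-1]
    m = torch.full((q.shape[0], 1), float("-inf"))  # running row-wise max
    l = torch.zeros(q.shape[0], 1)                  # running softmax denominator
    acc = torch.zeros_like(q)                       # running (unnormalized) output

    for k_c, v_c in zip(k.chunk(num_gpus), v.chunk(num_gpus)):  # one ring step per chunk
        scores = q @ k_c.T / d ** 0.5
        m_new = torch.maximum(m, scores.max(dim=-1, keepdim=True).values)
        scale = torch.exp(m - m_new)                # rescale earlier partial sums
        p = torch.exp(scores - m_new)
        l = l * scale + p.sum(dim=-1, keepdim=True)
        acc = acc * scale + p @ v_c
        m = m_new

    return acc / l                                  # final normalization

q, k, v = (torch.randn(16, 8) for _ in range(3))
out = simulate_ring_attention(q, k, v, num_gpus=4)
ref = torch.softmax(q @ k.T / 8 ** 0.5, dim=-1) @ v
print(torch.allclose(out, ref, atol=1e-5))          # True: matches full attention
```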
Scaling to the Extreme: Mesh2D and Composed Attention
For sequences exceeding 1 million tokens, even Ring Attention hits limits because the communication overhead of the ring becomes a bottleneck. Mosaic introduces Mesh2D Sharding, which shards both the Query and Key dimensions in a grid.
| Sequence Length | GPUs | Backend | Memory Complexity |
|---|---|---|---|
| < 10k | 1 | Local (Flash) | O(n²) |
| 10k–100k | 2–8 | Ring | O(n²/p) |
| 100k–1M | 8–64 | Mesh2D | O(n²/p²) |
| > 1M | 64+ | Composed | O(n²/(p²·h)) |
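As a rough sanity check on the table, the snippet below divides the single-head FP32 score matrix from earlier (roughly 84 GiB at 150k tokens) by each backend's sharding factor; the choice of p = 8 GPUs and h = 4 heads is purely illustrative.

```python
# Per-GPU footprint of the dense 150k x 150k FP32 score matrix under each backend.
n = 150_000
full_gib = n * n * 4 / 2**30   # ~83.8 GiB for a single head, single layer
p, h = 8, 4                    # illustrative GPU count and head count

print(f"Local:    {full_gib:7.2f} GiB")                # O(n^2)
print(f"Ring:     {full_gib / p:7.2f} GiB")            # O(n^2 / p)       ~10.5 GiB
print(f"Mesh2D:   {full_gib / p**2:7.2f} GiB")         # O(n^2 / p^2)     ~1.3 GiB
print(f"Composed: {full_gib / (p**2 * h):7.2f} GiB")   # O(n^2 / (p^2*h)) ~0.33 GiB
```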
Mosaic also features Topology Awareness. In a real-world cluster, GPUs within a node communicate via NVLink (fast), while GPUs across nodes use InfiniBand (slower). Mosaic's ComposedAttention allows you to run Ring Attention within a node and Head Parallelism across nodes to minimize the impact of slower interconnects.
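The split can be pictured with plain torch.distributed process groups. The sketch below assumes an already-initialized job running on 2 nodes with 4 GPUs each, and only illustrates the grouping; it is not Mosaic's ComposedAttention implementation.

```python
import torch.distributed as dist

# Assumes dist.init_process_group("nccl") has already run across 2 nodes x 4 GPUs.
gpus_per_node = 4
world_size = dist.get_world_size()
rank = dist.get_rank()

# Intra-node groups: consecutive ranks share NVLink, so run Ring Attention here.
for node in range(world_size // gpus_per_node):
    ranks = list(range(node * gpus_per_node, (node + 1) * gpus_per_node))
    group = dist.new_group(ranks)   # every rank must take part in every new_group call
    if rank in ranks:
        ring_group = group

# Inter-node groups: the same local index on each node, linked via InfiniBand,
# so spread attention heads across these slower connections.
for local_idx in range(gpus_per_node):
    ranks = list(range(local_idx, world_size, gpus_per_node))
    group = dist.new_group(ranks)
    if rank in ranks:
        head_group = group
```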
Implementation Pro-Tips for n1n.ai Users
When implementing sharding attention for your production models, keep these three optimizations in mind:
- Use Pre-allocated Collectives: Avoid calling `torch.cat` inside your forward pass. Mosaic uses pre-sized buffers for `all_gather` operations to prevent memory fragmentation (see the sketch after this list).
- Fused Kernels: Always ensure `flash-attn` is installed. Mosaic is designed to dispatch to `F.scaled_dot_product_attention` whenever possible to leverage hardware-level fusions.
- Axis Permutation: If your data is not in `(batch, seq, feature)` format, use Mosaic's internal `_permute_to_seq` utility. It uses `view()` instead of `reshape()` to preserve memory contiguity and avoid expensive copies.
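For the first tip, the contrast looks roughly like the sketch below in plain PyTorch; the shapes are illustrative and this is not Mosaic's internal code.

```python
import torch
import torch.distributed as dist

# Illustrative shapes: each rank holds a (local_seq, embed) K or V shard.
world_size = dist.get_world_size()
kv_chunk = torch.randn(4096, 96, device="cuda")

# Anti-pattern: gather into a Python list, then torch.cat, which triggers a fresh
# allocation (and fragmentation) on every forward pass.
pieces = [torch.empty_like(kv_chunk) for _ in range(world_size)]
dist.all_gather(pieces, kv_chunk)
kv_full = torch.cat(pieces, dim=0)

# Preferred: allocate the full-size buffer once (e.g., in __init__) and reuse it.
kv_buffer = torch.empty(world_size * 4096, 96, device="cuda")
dist.all_gather_into_tensor(kv_buffer, kv_chunk)
```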
Setting Up Your Environment
To begin using Mosaic for sharding attention, install the core library and its high-performance dependencies:
```bash
pip install git+https://github.com/stprnvsh/mosaic.git
pip install flash-attn ring-flash-attn
```
For multi-node training, initialize the process group and let Mosaic handle the device context:
```python
import mosaic
import torch.distributed as dist

dist.init_process_group("nccl")

# Initialize Mosaic with the sequence-parallel size
ctx = mosaic.init(sp_size=dist.get_world_size())

# Your model now automatically handles sharding
model = MyModel().to(ctx.device)
```
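If you launch with a standard launcher such as `torchrun --nproc_per_node=8 train.py`, the NCCL process group picks up its rank and world size from the environment, so no extra rendezvous configuration is needed before the snippet above.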
Conclusion
As LLMs move toward "infinite" context, the ability to shard the attention mechanism is no longer optional—it is a requirement. Mosaic provides a clean, 800-line Python abstraction that solves the multi-axis problem while maintaining the performance of FlashAttention. Whether you are training tabular transformers or massive document readers, sharding attention is the key to unlocking the next generation of AI capabilities.
At n1n.ai, we provide the infrastructure and API access needed to run these high-demand workloads at scale. By leveraging libraries like Mosaic and the high-speed access provided by n1n.ai, developers can focus on model architecture rather than memory management.
Get a free API key at n1n.ai.