Mosaic: Sharding Attention Across GPUs When Your Sequence Doesn't Fit

By Nino, Senior Tech Editor
The evolution of Large Language Models (LLMs) is increasingly defined by context window size. However, as developers at n1n.ai often observe, the hardware bottleneck remains the primary obstacle to true long-context reasoning. You have likely heard that transformers suffer from a 'quadratic attention bottleneck.' While this sounds theoretical, it becomes a hard physical limit once your sequence length exceeds a few thousand tokens. In this guide, we explore Mosaic, a lightweight library that shards attention across multiple GPUs, allowing sequences as large as 150,000 tokens to be processed efficiently.

The Quadratic Reality: Why 84GB Isn't Enough

To understand why we need sharding attention, we must look at the math. The standard attention mechanism computes:

Attention(Q, K, V) = softmax(QKᵀ / √d) × V

The bottleneck is the QKᵀ operation, which produces a matrix of shape (sequence_length × sequence_length). For a sequence of 150,000 tokens, the calculation is as follows:

  • Matrix elements: 150,000 × 150,000 = 22.5 billion elements.
  • Memory (FP32): 22.5B × 4 bytes = 90 GB (≈84 GiB).
  • Memory (BF16/FP16): 22.5B × 2 bytes = 45 GB (≈42 GiB).

Keep in mind, this is for a single attention head in a single layer. A modern A100 GPU has 80GB of total VRAM. When you factor in model weights, optimizer states, and activations for other layers, a 150k sequence is physically impossible to fit on one card. Even FlashAttention, which avoids materializing the full score matrix by tiling the computation, still requires the entire sequence's keys and values to reside in a single device's memory.
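
To make the blow-up concrete, here is a minimal vanilla attention in PyTorch; the (seq, seq) score matrix it materializes is exactly the object that explodes. The shapes and the final arithmetic line are illustrative, not Mosaic code.

import torch

def naive_attention(q, k, v):
    d = q.shape[-1]
    scores = (q @ k.transpose(-2, -1)) / d ** 0.5   # shape (seq, seq): the quadratic bottleneck
    return torch.softmax(scores, dim=-1) @ v

seq, d = 4096, 64                     # 4k fits comfortably on one GPU; 150k would not
q = k = v = torch.randn(seq, d)
out = naive_attention(q, k, v)

# The score matrix alone for 150k tokens, before softmax or the V product:
print(150_000 ** 2 * 4 / 1024 ** 3, "GiB in FP32")   # ~83.8 GiB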

The Limitations of Existing Solutions

Before Mosaic, two primary methods dominated the landscape:

  1. FlashAttention: Excellent for local GPU optimization but lacks native multi-node distribution.
  2. Ring Attention: Distributed via the ring-flash-attn library, it shards the sequence across GPUs in a 1D ring. While effective for linear sequences, it struggles with complex multi-axis attention patterns found in tabular transformers or multi-modal models.

At n1n.ai, we found that developers working on tabular data (like nanoTabPFN) needed to attend over both features (a small axis) and rows (a massive axis). Standard libraries didn't support this hybrid approach without significant custom engineering. This is where Mosaic provides a critical bridge by automating attention sharding across specific dimensions.

How Mosaic Works: Multi-Axis Coordination

Mosaic acts as a thin coordination layer. It doesn't replace FlashAttention; it orchestrates it. It routes different attention axes to the most appropriate backend based on their size and the available hardware topology.

import mosaic

# A small axis (e.g., features) runs locally on one GPU
feature_attn = mosaic.MultiAxisAttention(
    embed_dim=96, num_heads=4,
    attention_axis=2,    # features dimension
    backend="local"      # No cross-GPU communication needed
)

# A large axis (e.g., 150k rows) is sharded across the cluster
row_attn = mosaic.MultiAxisAttention(
    embed_dim=96, num_heads=4,
    attention_axis=1,    # rows dimension
    backend="ring"       # Distributed sharding attention
)

By decoupling the attention logic from the tensor reshaping, Mosaic handles the permutation of the attention axis, projection of QKV, and the eventual restoration of the tensor shape automatically.
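
To illustrate what "attend over one axis" entails, here is a rough plain-PyTorch sketch of the permute, project, attend, restore cycle. It is illustrative only: the name attend_over_axis is not part of Mosaic's API, and Mosaic's internals differ.

import torch
import torch.nn.functional as F

def attend_over_axis(x, qkv, out_proj, axis, num_heads):
    # 1. Move the chosen axis into the sequence position: (..., seq, embed)
    x = x.movedim(axis, -2)
    lead_shape = x.shape[:-2]
    x = x.reshape(-1, x.shape[-2], x.shape[-1])           # (B*, seq, embed)

    # 2. Project to Q, K, V and split heads
    b, s, e = x.shape
    q, k, v = qkv(x).chunk(3, dim=-1)
    split = lambda t: t.reshape(b, s, num_heads, e // num_heads).transpose(1, 2)
    q, k, v = split(q), split(k), split(v)

    # 3. Attention over the chosen axis (dispatches to fused kernels when available)
    out = F.scaled_dot_product_attention(q, k, v)

    # 4. Merge heads, project, and restore the original tensor layout
    out = out_proj(out.transpose(1, 2).reshape(b, s, e))
    return out.reshape(*lead_shape, s, e).movedim(-2, axis)

# Example: attend over the rows axis of a (batch, rows, features, embed) tensor
x = torch.randn(2, 128, 10, 96)
qkv = torch.nn.Linear(96, 3 * 96)
proj = torch.nn.Linear(96, 96)
y = attend_over_axis(x, qkv, proj, axis=1, num_heads=4)
assert y.shape == x.shape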

Deep Dive: The Ring Sharding Algorithm

The core of sharding attention in Mosaic is the Ring Attention mechanism. Imagine you have 4 GPUs and a 150k sequence split into 4 chunks of 37.5k tokens each.

  1. Initial State: Each GPU holds its local Query (Q), Key (K), and Value (V) chunks.
  2. Local Computation: Each GPU computes attention for its local Q and local K/V.
  3. The Ring Pass: GPU 0 sends its K/V to GPU 1, while receiving K/V from GPU 3.
  4. Iterative Update: In each step, the GPU computes attention between its fixed Q and the rotating K/V chunks. It accumulates the partial softmax sums and max values (using the Online Softmax algorithm).
  5. Normalization: After a full rotation, each GPU has the complete attention output for its specific chunk of the sequence.

This reduces the memory requirement per GPU to O(n²/p), where p is the number of GPUs. On an 8-GPU node, that ~84 GiB requirement drops to roughly 10.5 GiB per GPU, making the 150k sequence easily manageable.
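
The steps above can be verified end to end on a single machine. The sketch below simulates the ring with plain Python lists instead of real send/recv calls, so it is a toy model of the algorithm rather than Mosaic's or ring-flash-attn's implementation; the online-softmax bookkeeping, however, is the same idea.

import torch

def ring_attention_sim(q_chunks, k_chunks, v_chunks, scale):
    """Each 'rank' keeps its Q chunk fixed while K/V chunks rotate around the ring."""
    p = len(q_chunks)
    outputs = []
    for rank in range(p):
        q = q_chunks[rank]                                    # (chunk, d)
        acc = torch.zeros_like(q)                             # running numerator
        row_sum = torch.zeros(q.shape[0], 1)                  # running softmax denominator
        row_max = torch.full((q.shape[0], 1), float("-inf"))  # running max for stability
        for step in range(p):
            src = (rank - step) % p                           # which K/V chunk arrives this step
            k, v = k_chunks[src], v_chunks[src]
            scores = (q @ k.T) * scale                        # only a (chunk, chunk) block at a time
            block_max = scores.max(dim=-1, keepdim=True).values
            new_max = torch.maximum(row_max, block_max)
            correction = torch.exp(row_max - new_max)         # rescale old partial sums (online softmax)
            probs = torch.exp(scores - new_max)
            acc = acc * correction + probs @ v
            row_sum = row_sum * correction + probs.sum(dim=-1, keepdim=True)
            row_max = new_max
        outputs.append(acc / row_sum)                         # final normalization after a full rotation
    return torch.cat(outputs, dim=0)

# Sanity check against full attention on a small example
torch.manual_seed(0)
d = 64
q, k, v = (torch.randn(16, d) for _ in range(3))
shard = lambda t: list(t.chunk(4, dim=0))
out = ring_attention_sim(shard(q), shard(k), shard(v), scale=d ** -0.5)
ref = torch.softmax((q @ k.T) * d ** -0.5, dim=-1) @ v
assert torch.allclose(out, ref, atol=1e-5)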

Scaling to the Extreme: Mesh2D and Composed Attention

For sequences exceeding 1 million tokens, even Ring Attention hits limits because the communication overhead of the ring becomes a bottleneck. Mosaic introduces Mesh2D Sharding, which shards both the Query and Key dimensions in a grid.

Sequence Length   GPUs    Backend         Memory Complexity
< 10k             1       Local (Flash)   O(n²)
10k–100k          2–8     Ring            O(n²/p)
100k–1M           8–64    Mesh2D          O(n²/p²)
> 1M              64+     Composed        O(n²/(p²·h))
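
The asymptotic columns can be turned into rough per-GPU numbers. The helper below plugs concrete values into the expressions from the table, counting only the FP32 score block and ignoring constant factors, weights, and activations; it is a back-of-the-envelope aid, not a profiler.

def score_block_gib(n_tokens, backend, p=8, h=8):
    """Approximate per-GPU score-block memory in GiB, using the table's complexity terms."""
    elements = {
        "local":    n_tokens ** 2,
        "ring":     n_tokens ** 2 / p,
        "mesh2d":   n_tokens ** 2 / p ** 2,
        "composed": n_tokens ** 2 / (p ** 2 * h),
    }[backend]
    return elements * 4 / 1024 ** 3   # 4 bytes per FP32 element

print(score_block_gib(150_000, "ring", p=8))       # ~10.5 GiB per GPU
print(score_block_gib(1_000_000, "mesh2d", p=64))  # ~0.9 GiB per GPU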

Mosaic also features Topology Awareness. In a real-world cluster, GPUs within a node communicate via NVLink (fast), while GPUs across nodes use InfiniBand (slower). Mosaic's ComposedAttention allows you to run Ring Attention within a node and Head Parallelism across nodes to minimize the impact of slower interconnects.
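
Mosaic handles this routing internally, but the underlying idea can be sketched with plain torch.distributed process groups. The snippet below (meant to run under torchrun, which sets LOCAL_WORLD_SIZE) illustrates the intra-node versus inter-node split; it does not use Mosaic's ComposedAttention API.

import os
import torch.distributed as dist

dist.init_process_group("nccl")
rank, world = dist.get_rank(), dist.get_world_size()
gpus_per_node = int(os.environ["LOCAL_WORLD_SIZE"])
num_nodes = world // gpus_per_node

# Every rank must create every group, in the same order, for NCCL to agree.
intra_groups = [dist.new_group(list(range(n * gpus_per_node, (n + 1) * gpus_per_node)))
                for n in range(num_nodes)]                    # NVLink peers within a node
inter_groups = [dist.new_group(list(range(i, world, gpus_per_node)))
                for i in range(gpus_per_node)]                # InfiniBand peers across nodes

my_intra = intra_groups[rank // gpus_per_node]   # run Ring Attention inside this group
my_inter = inter_groups[rank % gpus_per_node]    # split heads across this group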

Implementation Pro-Tips for n1n.ai Users

When implementing sharding attention for your production models, keep these three optimizations in mind:

  1. Use Pre-allocated Collectives: Avoid using torch.cat inside your forward pass. Mosaic uses pre-sized buffers for all_gather operations to prevent memory fragmentation (see the sketch after this list).
  2. Fused Kernels: Always ensure flash-attn is installed. Mosaic is designed to dispatch to F.scaled_dot_product_attention whenever possible to leverage hardware-level fusions.
  3. Axis Permutation: If your data is not in (batch, seq, feature) format, use Mosaic's internal _permute_to_seq utility. It relies on view() rather than reshape(): view() never copies (it only works on contiguous memory), whereas reshape() can silently fall back to an expensive copy.
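
For the first tip, the pattern looks roughly like the following: gather remote K/V shards into one pre-sized tensor rather than collecting a list of chunks and concatenating them. This is a generic torch.distributed sketch, not Mosaic's actual buffer management.

import torch
import torch.distributed as dist

def gather_kv(local_kv: torch.Tensor) -> torch.Tensor:
    """Gather K/V shards from all ranks into one pre-allocated, contiguous buffer."""
    world = dist.get_world_size()
    out = torch.empty(world * local_kv.shape[0], *local_kv.shape[1:],
                      device=local_kv.device, dtype=local_kv.dtype)
    # Single fused collective into the buffer: no per-rank chunks, no torch.cat
    dist.all_gather_into_tensor(out, local_kv.contiguous())
    return out

In production you would typically cache the output buffer across steps so the allocation happens once rather than every forward pass.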

Setting Up Your Environment

To begin using Mosaic for sharding attention, install the core library and its high-performance dependencies:

pip install git+https://github.com/stprnvsh/mosaic.git
pip install flash-attn ring-flash-attn

For multi-node training, initialize the process group and let Mosaic handle the device context:

import mosaic
import torch.distributed as dist

dist.init_process_group("nccl")
# Initialize Mosaic with the sequence parallel size
ctx = mosaic.init(sp_size=dist.get_world_size())

# Your model now automatically handles sharding
model = MyModel().to(ctx.device)

Conclusion

As LLMs move toward "infinite" context, the ability to shard the attention mechanism is no longer optional—it is a requirement. Mosaic provides a clean, 800-line Python abstraction that solves the multi-axis problem while maintaining the performance of FlashAttention. Whether you are training tabular transformers or massive document readers, sharding attention is the key to unlocking the next generation of AI capabilities.

At n1n.ai, we provide the infrastructure and API access needed to run these high-demand workloads at scale. By leveraging libraries like Mosaic and the high-speed access provided by n1n.ai, developers can focus on model architecture rather than memory management.

Get a free API key at n1n.ai.