Multi-Query Attention and Memory-Efficient Decoding for LLMs
Author: Nino, Senior Tech Editor
In our previous exploration of LLM caching, we established how KV caching transforms the autoregressive decoding process by eliminating redundant attention computations. By storing the Keys (K) and Values (V) from previous tokens, Transformers reduce per-token computation from quadratic to linear relative to the sequence length. However, as we scale these models to handle longer contexts and higher throughput, KV caching introduces a significant new bottleneck: memory capacity and bandwidth.
As context lengths and batch sizes grow, the memory required for the KV cache can exceed the memory required for the model weights themselves. This post examines Multi-Query Attention (MQA), an architectural modification that directly attacks this memory bottleneck by changing how attention heads share representations. For developers using high-performance APIs like n1n.ai, understanding these underlying architectures helps when optimizing prompt engineering and context management.
The Memory Wall: Why MHA Struggles
In standard Multi-Head Attention (MHA), each attention head has its own independent set of query, key, and value projections. While this gives the model high expressiveness, it makes cache memory grow linearly with the number of heads. For a model with $L$ transformer layers, $H$ attention heads, a sequence length of $T$, and a head dimension of $d_h$, the KV cache memory scales as:
$$O(L \cdot H \cdot T \cdot d_h)$$
KV caching removes redundant computation, but it does nothing to reduce the memory growth relative to the number of heads. For modern LLMs with 32–128 heads and context windows reaching 128k tokens or more, the KV cache memory and bandwidth quickly become the primary limiting factors in inference throughput. This is why platforms like n1n.ai prioritize models that utilize memory-efficient attention mechanisms to ensure stable, high-speed responses.
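To make the scaling concrete, the estimate above can be turned into a small calculator. This is a rough sketch rather than a measurement: the function name `kv_cache_bytes`, the leading factor of 2 for storing both keys and values, and the example configuration (32 layers, 32 heads, head dimension 128, FP16) are assumptions chosen for illustration.

```python
def kv_cache_bytes(n_layers, n_kv_heads, seq_len, d_head, bytes_per_elem=2):
    """Estimate KV cache size in bytes for one sequence.

    The leading factor of 2 accounts for storing both keys and values.
    """
    return 2 * n_layers * n_kv_heads * seq_len * d_head * bytes_per_elem


# Hypothetical MHA configuration: 32 layers, 32 heads, d_head = 128, FP16.
for seq_len in (2_048, 32_768, 131_072):
    size_gib = kv_cache_bytes(32, 32, seq_len, 128) / 2**30
    print(f"T = {seq_len:>7,}: {size_gib:6.1f} GiB per sequence")
```

Under these assumed numbers the cache costs 0.5 MiB per token, so a single 128k-token sequence needs 64 GiB before any weights or activations are counted.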
Defining Multi-Query Attention (MQA)
Multi-Query Attention (MQA) answers the scaling problem by imposing a deliberate constraint: all attention heads share a single set of keys and values, while maintaining independent queries.
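Formally, let $X \in \mathbb{R}^{T \times d_{\text{model}}}$ be the input. In MQA, the projections are defined as:
$$Q_i = X W_i^Q \quad (i = 1, \dots, H), \qquad K = X W^K, \qquad V = X W^V$$
Each head computes attention as:
$$\text{Attention}_i = \text{softmax}\left(\frac{Q_i K^\top}{\sqrt{d_h}}\right) V$$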
In this setup, the keys and values are shared across all heads. Note that $K$ is not equal to $V$; they remain distinct projections, but they are no longer duplicated for every head. This shrinks the KV cache by a factor of $H$.
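A single decoding step shows where the saving comes from: the cache holds one key and one value per token, regardless of head count. The sketch below uses NumPy with random weights. The names (`mqa_decode_step`, `w_q`, `w_k`, `w_v`) and the tiny dimensions are illustrative assumptions, and it covers one layer only, without the output projection.

```python
import numpy as np


def mqa_decode_step(x_t, w_q, w_k, w_v, k_cache, v_cache):
    """One MQA decoding step for a single new token.

    x_t:     (d_model,)                 hidden state of the new token
    w_q:     (n_heads, d_model, d_head) per-head query projections
    w_k/w_v: (d_model, d_head)          shared key/value projections
    k_cache, v_cache: (t, d_head)       shared cache, no head dimension
    """
    d_head = w_k.shape[1]
    # Append the new token's single key and value to the shared cache.
    k_cache = np.vstack([k_cache, x_t @ w_k])
    v_cache = np.vstack([v_cache, x_t @ w_v])

    q = np.einsum("d,hde->he", x_t, w_q)           # (n_heads, d_head)
    scores = q @ k_cache.T / np.sqrt(d_head)       # (n_heads, t + 1)
    scores -= scores.max(axis=-1, keepdims=True)   # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    out = weights @ v_cache                        # (n_heads, d_head)
    return out, k_cache, v_cache


rng = np.random.default_rng(0)
d_model, n_heads, d_head = 16, 4, 8
w_q = rng.normal(size=(n_heads, d_model, d_head))
w_k = rng.normal(size=(d_model, d_head))
w_v = rng.normal(size=(d_model, d_head))

k_cache = np.empty((0, d_head))
v_cache = np.empty((0, d_head))
for _ in range(5):  # decode five tokens
    x_t = rng.normal(size=d_model)
    out, k_cache, v_cache = mqa_decode_step(x_t, w_q, w_k, w_v, k_cache, v_cache)

print(out.shape)      # (4, 8): one output per head
print(k_cache.shape)  # (5, 8): cache size is independent of n_heads
```

Growing the cache with `np.vstack` keeps the example short; a production kernel would more likely preallocate the cache and write into it in place.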
The Quantitative Impact: A 64x Reduction
To visualize the impact, let's look at a hypothetical model with the following parameters:
- Layers ($L$): 80
- Attention heads ($H$): 64
- Head dimension ($d_h$): 128
- Context length ($T$): 2048
- Precision: FP16 (2 bytes per element)
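Counting both keys and values (the leading factor of 2) at 2 bytes per element (the trailing factor of 2):
| Attention Type | KV Cache Formula | KV Cache per Sequence |
|---|---|---|
| Multi-Head Attention (MHA) | $2 \cdot L \cdot H \cdot T \cdot d_h \cdot 2$ bytes | ~5.4 GB (5 GiB) |
| Multi-Query Attention (MQA) | $2 \cdot L \cdot T \cdot d_h \cdot 2$ bytes | ~84 MB (80 MiB) |
| Reduction | n/a | 64x smaller |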
This reduction is transformative for serving infrastructure. By reducing the memory footprint, we can fit significantly more concurrent requests (batch size) into the same GPU memory, which is a key reason why n1n.ai can offer such competitive throughput for LLM API access.
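The effect on batch size can be estimated directly from the example parameters. The sketch below assumes both keys and values are cached at FP16 and uses a hypothetical 40 GiB budget for cache memory; the helper name `kv_cache_bytes` and the budget are illustrative and not taken from any particular deployment.

```python
def kv_cache_bytes(n_layers, n_kv_heads, seq_len, d_head, bytes_per_elem=2):
    """KV cache bytes for one sequence (factor 2 = keys plus values)."""
    return 2 * n_layers * n_kv_heads * seq_len * d_head * bytes_per_elem


# Parameters from the example above: L = 80, H = 64, d_h = 128, T = 2048.
mha = kv_cache_bytes(80, 64, 2048, 128)  # every head keeps its own K/V
mqa = kv_cache_bytes(80, 1, 2048, 128)   # one shared K/V head

budget = 40 * 2**30  # hypothetical memory left for the cache, in bytes

print(f"MHA: {mha / 2**20:,.0f} MiB per sequence, {budget // mha} sequences fit")
print(f"MQA: {mqa / 2**20:,.0f} MiB per sequence, {budget // mqa} sequences fit")
print(f"Reduction: {mha // mqa}x")
```

With this assumed budget, 8 full-length sequences fit under MHA and 512 under MQA. Real capacity would also depend on weights, activations, and allocator overhead, which this estimate ignores.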
The Trade-off: Representational Collapse
A common misconception is that attention diversity comes solely from queries. In reality, MHA allows each head to learn a distinct attention subspace. Different heads can specialize in syntax, long-range dependencies, or positional biases by having different similarity metrics (via $K$) and different retrieval semantics (via $V$).
MQA forces all heads to score relevance in the same key space and retrieve from the same value manifold. This reduces the model's "point-of-view capacity." It limits the model's ability to represent multiple incompatible interpretations of the same sequence simultaneously.
Why MQA Still Works
Despite the theoretical loss in expressiveness, MQA has been successfully adopted in models like PaLM and Falcon. This is due to several factors:
- Head Redundancy: In large MHA models, many heads often learn highly correlated patterns.
- Depth Compensation: The model's depth (many layers) and the width of the Feed-Forward Networks (FFN) can absorb some of the representational burden lost in the attention layers.
- Training Adaptation: When a model is trained from scratch with MQA, it learns to optimize the shared KV space effectively.
Implementation Guide: MQA in PyTorch
If you are building custom inference kernels, implementing MQA requires a change in how you handle tensor shapes. Below is a simplified conceptual implementation:
import torch
import torch.nn as nn
import math


class MultiQueryAttention(nn.Module):
    def __init__(self, d_model, n_heads, d_head):
        super().__init__()
        self.n_heads = n_heads
        self.d_head = d_head
        # Independent queries for each head
        self.w_q = nn.Linear(d_model, n_heads * d_head)
        # Single shared Key and Value projection
        self.w_k = nn.Linear(d_model, d_head)
        self.w_v = nn.Linear(d_model, d_head)
        self.w_o = nn.Linear(n_heads * d_head, d_model)

    def forward(self, x, mask=None):
        batch, seq_len, _ = x.shape
        # Queries after transpose: (batch, n_heads, seq_len, d_head)
        q = self.w_q(x).view(batch, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        # Keys/Values after transpose: (batch, 1, seq_len, d_head)
        k = self.w_k(x).view(batch, seq_len, 1, self.d_head).transpose(1, 2)
        v = self.w_v(x).view(batch, seq_len, 1, self.d_head).transpose(1, 2)
        # Scaled Dot-Product Attention
        # k is broadcast across the n_heads dimension of q
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_head)
        if mask is not None:
            # mask must broadcast to (batch, n_heads, seq_len, seq_len); 0 marks masked positions
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = torch.softmax(scores, dim=-1)
        # v is broadcast across the n_heads dimension
        out = torch.matmul(attn, v)
        # Merge heads back to (batch, seq_len, n_heads * d_head)
        out = out.transpose(1, 2).contiguous().view(batch, seq_len, -1)
        return self.w_o(out)
Conclusion
MQA is not a "free" optimization; it is a deliberate architectural trade-off that favors inference scalability over maximal per-layer expressiveness. By collapsing the KV cache, it allows for massive context windows and higher throughput, making it a cornerstone of modern LLM design.
For developers looking to leverage these optimizations without managing the underlying infrastructure, n1n.ai provides a unified API to access the world's most efficient models.
Get a free API key at n1n.ai.