Multi-Query Attention and Memory-Efficient Decoding for LLMs

Authors
  • Nino, Senior Tech Editor

In our previous exploration of LLM caching, we established how KV caching transforms the autoregressive decoding process by eliminating redundant attention computations. By storing the Keys (K) and Values (V) from previous tokens, Transformers reduce per-token computation from quadratic to linear relative to the sequence length. However, as we scale these models to handle longer contexts and higher throughput, KV caching introduces a significant new bottleneck: memory capacity and bandwidth.

As models scale, the memory required for the KV cache can exceed the memory required for the model weights themselves, particularly with long contexts and large batch sizes. This post examines Multi-Query Attention (MQA), an architectural modification that directly attacks this memory bottleneck by changing how attention heads share representations. For developers using high-performance APIs like n1n.ai, understanding these underlying architectures is crucial for optimizing prompt engineering and context management.

The Memory Wall: Why MHA Struggles

In standard Multi-Head Attention (MHA), each attention head has its own independent set of query, key, and value projections. This gives the model high expressiveness, but every head must cache its own keys and values, so memory grows linearly with the number of heads. Consider a model with L transformer layers, H attention heads, a sequence length of T, and a head dimension of d_h. Its KV cache memory per sequence scales as:

O(L * H * T * d_h)

KV caching removes redundant computation, but it does nothing to curb memory growth with the number of heads. Modern LLMs have 32–128 heads and context windows of 128k tokens or more. At that scale the KV cache quickly becomes the primary limit on inference throughput, for two reasons:
- It occupies GPU memory that could otherwise hold more concurrent requests.
- Every decoding step must read the entire cache from memory, so bandwidth rather than compute often sets the speed.

This is why platforms like n1n.ai prioritize models that utilize memory-efficient attention mechanisms to ensure stable, high-speed responses.

Defining Multi-Query Attention (MQA)

Multi-Query Attention (MQA) answers the scaling problem by imposing a deliberate constraint: all attention heads share a single set of keys and values, while maintaining independent queries.

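Formally, let X be the input. In MQA, the projections are defined as:

  • Q_i = X W_{Q_i}, a separate query projection for each head i
  • K = X W_K, a single key projection shared by all heads
  • V = X W_V, a single value projection shared by all heads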

Each head i computes attention as:

Attention_i = softmax((Q_i * K^T) / sqrt(d_h)) * V

In this setup, the keys and values are shared across all H heads. Note that W_K is not equal to W_V. They remain distinct projections, but they are no longer duplicated for every head. This shrinks the KV cache by a factor of H.
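The per-head formula can be checked numerically with a minimal NumPy sketch. NumPy is our choice here, and the function names and tensor sizes are hypothetical. Each head has its own Q_i, but K and V are single (T, d_h) matrices that broadcast across the head dimension. As a sanity check, the sketch confirms that MQA gives the same result as MHA in which every head is handed identical keys and values.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def mqa_attention(q, k, v):
    """q: (H, T, d_h) per-head queries; k, v: (T, d_h) shared by all heads."""
    d_h = q.shape[-1]
    # k.T is (d_h, T); matmul broadcasts it across the H heads of q
    scores = q @ k.T / np.sqrt(d_h)          # (H, T, T)
    return softmax(scores) @ v               # (H, T, d_h)

def mha_attention(q, k, v):
    """Standard MHA: q, k, v all have shape (H, T, d_h)."""
    d_h = q.shape[-1]
    scores = q @ np.swapaxes(k, -2, -1) / np.sqrt(d_h)
    return softmax(scores) @ v

rng = np.random.default_rng(0)
H, T, d_h = 4, 5, 8
q = rng.standard_normal((H, T, d_h))
k = rng.standard_normal((T, d_h))
v = rng.standard_normal((T, d_h))

out_mqa = mqa_attention(q, k, v)
# Hand every MHA head the same K and V (read-only broadcast views)
k_rep = np.broadcast_to(k, (H, T, d_h))
v_rep = np.broadcast_to(v, (H, T, d_h))
out_mha = mha_attention(q, k_rep, v_rep)

print(out_mqa.shape)                  # (4, 5, 8)
print(np.allclose(out_mqa, out_mha))  # True
```

Only k and v (one head's worth each) would need to be cached, rather than k_rep and v_rep.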

The Quantitative Impact: A 64x Reduction

To visualize the impact, let's look at a hypothetical model with the following parameters:

  • Layers (L): 80
  • Attention heads (H): 64
  • Head dimension (d_h): 128
  • Context length (T): 2048
  • Precision: FP16 (2 bytes per element)
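
| Attention Type | KV Cache Formula | KV Cache per Sequence |
|---|---|---|
| Multi-Head Attention (MHA) | 2 * L * H * T * d_h * 2 bytes | ~5.4 GB (5 GiB) |
| Multi-Query Attention (MQA) | 2 * L * 1 * T * d_h * 2 bytes | ~84 MB (80 MiB) |
| Reduction | | ~64x smaller |

In each formula, the leading factor of 2 accounts for storing both keys and values. The trailing 2 bytes is the FP16 element size.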
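The figures in the table can be reproduced with a few lines of Python. This is a small sketch under stated assumptions:
- The helper name `kv_cache_bytes` and the 40 GiB cache budget are hypothetical, chosen for illustration.
- Framework overhead, activations and the model weights themselves are ignored.

```python
def kv_cache_bytes(n_layers, n_kv_heads, seq_len, d_head, bytes_per_elem=2, batch=1):
    """KV cache size in bytes; the leading 2 covers one K and one V tensor per layer."""
    return 2 * n_layers * n_kv_heads * seq_len * d_head * bytes_per_elem * batch

L, H, T, d_h = 80, 64, 2048, 128
mha = kv_cache_bytes(L, H, T, d_h)  # MHA: every head caches its own K and V
mqa = kv_cache_bytes(L, 1, T, d_h)  # MQA: a single K/V head shared by all query heads

print(f"MHA: {mha:,} bytes ({mha / 2**30:.2f} GiB)")  # MHA: 5,368,709,120 bytes (5.00 GiB)
print(f"MQA: {mqa:,} bytes ({mqa / 2**20:.0f} MiB)")  # MQA: 83,886,080 bytes (80 MiB)
print(f"Reduction: {mha // mqa}x")                    # Reduction: 64x

# Hypothetical 40 GiB of GPU memory set aside for the KV cache
budget = 40 * 2**30
print("Concurrent 2048-token sequences, MHA:", budget // mha)  # 8
print("Concurrent 2048-token sequences, MQA:", budget // mqa)  # 512
```

The last two lines show the batch-size effect directly: the same memory budget holds 64 times as many sequences.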

This reduction is transformative for serving infrastructure. By reducing the memory footprint, we can fit significantly more concurrent requests (batch size) into the same GPU memory, which is a key reason why n1n.ai can offer such competitive throughput for LLM API access.

The Trade-off: Representational Collapse

A common misconception is that attention diversity comes solely from queries. In reality, MHA allows each head to learn a distinct attention subspace. Different heads can specialize in syntax, long-range dependencies, or positional biases because each has its own similarity metric (via K_i) and its own retrieval semantics (via V_i).

MQA forces all heads to score relevance in the same key space and retrieve from the same value manifold. This reduces the model's "point-of-view capacity." It limits the model's ability to represent multiple incompatible interpretations of the same sequence simultaneously.
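The phrase "the same value manifold" can be made concrete with a toy NumPy sketch. It uses random weights, a single query position and hypothetical sizes, so it illustrates the constraint and is not a measurement of trained models.

In MQA, every head's output is a weighted average of the same T value vectors. The H head outputs therefore span at most min(T, d_h) dimensions. With per-head values, as in MHA, they can span up to min(H, d_h). A short context (T < d_h) makes the gap visible as a difference in matrix rank.

```python
import numpy as np

def softmax(x):
    # Numerically stable softmax over the last axis
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
H, T, d_h = 8, 3, 8  # short context (T < d_h) so the effect is easy to see

q = rng.standard_normal((H, d_h))  # one query per head for a single position
k = rng.standard_normal((T, d_h))  # MQA: one shared key matrix
v = rng.standard_normal((T, d_h))  # MQA: one shared value matrix

# MQA: every head returns a convex combination of the SAME T value rows
out_mqa = softmax(q @ k.T / np.sqrt(d_h)) @ v                 # (H, d_h)

# MHA: each head scores and retrieves with its own keys and values
k_heads = rng.standard_normal((H, T, d_h))
v_heads = rng.standard_normal((H, T, d_h))
scores = np.einsum("hd,htd->ht", q, k_heads) / np.sqrt(d_h)   # (H, T)
out_mha = np.einsum("ht,htd->hd", softmax(scores), v_heads)   # (H, d_h)

rank_mqa = np.linalg.matrix_rank(out_mqa)  # bounded by T
rank_mha = np.linalg.matrix_rank(out_mha)
print("rank of MQA head outputs:", rank_mqa)
print("rank of MHA head outputs:", rank_mha)
```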

Why MQA Still Works

Despite the theoretical loss in expressiveness, MQA has been successfully adopted in models like PaLM and Falcon. This is due to several factors:

  1. Head Redundancy: In large MHA models, many heads often learn highly correlated patterns.
  2. Depth Compensation: The model's depth (many layers) and the width of the Feed-Forward Networks (FFN) can absorb some of the representational burden lost in the attention layers.
  3. Training Adaptation: When a model is trained from scratch with MQA, it learns to optimize the shared KV space effectively.

Implementation Guide: MQA in PyTorch

If you are building custom inference kernels, implementing MQA requires a change in how you handle tensor shapes. Below is a simplified conceptual implementation:

```python
import torch
import torch.nn as nn
import math

class MultiQueryAttention(nn.Module):
    def __init__(self, d_model, n_heads, d_head):
        super().__init__()
        self.n_heads = n_heads
        self.d_head = d_head

        # Independent queries for each head
        self.w_q = nn.Linear(d_model, n_heads * d_head)
        # Single shared Key and Value projection (one head's worth each)
        self.w_k = nn.Linear(d_model, d_head)
        self.w_v = nn.Linear(d_model, d_head)
        self.w_o = nn.Linear(n_heads * d_head, d_model)

    def forward(self, x, mask=None):
        # x: (batch, seq_len, d_model)
        # mask (optional): broadcastable to (batch, n_heads, seq_len, seq_len); 0 = masked out
        batch, seq_len, _ = x.shape

        # Queries: (batch, n_heads, seq_len, d_head) after the transpose
        q = self.w_q(x).view(batch, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        # Keys/Values: (batch, 1, seq_len, d_head) after the transpose
        k = self.w_k(x).view(batch, seq_len, 1, self.d_head).transpose(1, 2)
        v = self.w_v(x).view(batch, seq_len, 1, self.d_head).transpose(1, 2)

        # Scaled dot-product attention
        # k's singleton head dimension broadcasts across the n_heads dimension of q
        # scores: (batch, n_heads, seq_len, seq_len)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_head)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn = torch.softmax(scores, dim=-1)
        # v is broadcast across the n_heads dimension: (batch, n_heads, seq_len, d_head)
        out = torch.matmul(attn, v)

        # Merge heads back: (batch, seq_len, n_heads * d_head)
        out = out.transpose(1, 2).contiguous().view(batch, seq_len, -1)
        return self.w_o(out)
```
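The module above processes a whole sequence at once and keeps no KV cache, yet the cache is where MQA's savings show up during decoding. The following is a hedged sketch of a single-layer decoder that caches one key row and one value row per token and reuses them for every query head. It makes these simplifications:
- It uses NumPy rather than PyTorch, only to keep it dependency-light.
- The class name `MQADecoder`, its `step` method and the random weights are hypothetical stand-ins, not part of any library.
- There is no batching, no positional encoding and only one layer.

```python
import numpy as np

class MQADecoder:
    """Hypothetical single-layer MQA decoder with a KV cache (NumPy sketch)."""

    def __init__(self, d_model, n_heads, d_head, seed=0):
        rng = np.random.default_rng(seed)
        self.n_heads, self.d_head = n_heads, d_head
        # Random weights stand in for trained parameters
        self.w_q = rng.standard_normal((d_model, n_heads * d_head)) / np.sqrt(d_model)
        self.w_k = rng.standard_normal((d_model, d_head)) / np.sqrt(d_model)
        self.w_v = rng.standard_normal((d_model, d_head)) / np.sqrt(d_model)
        self.w_o = rng.standard_normal((n_heads * d_head, d_model)) / np.sqrt(n_heads * d_head)
        # The cache holds ONE key row and ONE value row per token, not n_heads of each
        self.k_cache = np.zeros((0, d_head))
        self.v_cache = np.zeros((0, d_head))

    def step(self, x):
        """x: (d_model,) embedding of the newest token; returns (d_model,)."""
        q = (x @ self.w_q).reshape(self.n_heads, self.d_head)   # (H, d_h)
        # Append this token's shared key/value (a real implementation would preallocate)
        self.k_cache = np.vstack([self.k_cache, x @ self.w_k])  # (t, d_h)
        self.v_cache = np.vstack([self.v_cache, x @ self.w_v])  # (t, d_h)
        # Every head scores against the same cached keys: (H, t)
        scores = q @ self.k_cache.T / np.sqrt(self.d_head)
        scores -= scores.max(axis=-1, keepdims=True)
        attn = np.exp(scores)
        attn /= attn.sum(axis=-1, keepdims=True)
        out = attn @ self.v_cache                                # (H, d_h)
        return out.reshape(-1) @ self.w_o

dec = MQADecoder(d_model=32, n_heads=8, d_head=4)
rng = np.random.default_rng(1)
for _ in range(5):
    y = dec.step(rng.standard_normal(32))
print(y.shape)            # (32,)
print(dec.k_cache.shape)  # (5, 4): one shared key per token, not (5, 8, 4)
```

The cache only ever contains the current and earlier tokens, so no causal mask is needed at decode time. Under MHA, the same cache would hold n_heads key rows and n_heads value rows per token.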

Conclusion

MQA is not a "free" optimization; it is a deliberate architectural trade-off that favors inference scalability over maximal per-layer expressiveness. By collapsing the KV cache, it allows for massive context windows and higher throughput, making it a cornerstone of modern LLM design.

For developers looking to leverage these optimizations without managing the underlying infrastructure, n1n.ai provides a unified API to access the world's most efficient models.

Get a free API key at n1n.ai.