Tensor Product Attention: Curiosities abound

A recent paper, Tensor Product Attention Is All You Need1, grabbed my attention. Over the last year, I have been exploring ways to reinterpret the attention mechanism, mainly for my own edification. What correlations does a transformer really capture? Unsurprisingly, I have been using intuition from the physics of correlated systems.

Firstly, the attention mechanism is often written in a mathematically confusing and redundant way in the machine learning literature. The notation is often obfuscated by implementation quirks of matrix multiplications on GPUs. So let’s set up the notation, and simplify.

In the notes below, I will ignore position encoding. RoPE or learnable additive position encodings do not change the foundational mathematical intuitions I am trying to convey here — they are a distraction.

I use ℓ\ell for layer index and hh for head index.

The key quantity is the residual stream, XℓX^\ell. This matrix is getting transformed by attention and MLP blocks. The embedding dimension dmodeld_\textrm{model} is the size of the vector space in which tokens are being embedded.

We need a few other matrices to really explain what’s going on.

Note that in ML/AI papers the Query, Key and Value matrices are always written separately, but in essence we are low-rank decomposing (as products of rectangular matrices) two matrices, \mathbf{W}_{QK}^{\ell,h}, \mathbf{W}_{OV}^{\ell,h}. This will be clear when we write attention in terms of these matrices — 

Attnℓ(𝐱i)=∑h=1H∑j=1n[softmaxj(𝐱i⊤𝐖QKℓ,h𝐱jdhead)]𝐖OVℓ,h𝐱j\begin{aligned} \text{Attn}^{\ell}(\mathbf{x}_i) = \sum_{h=1}^{H} \sum_{j=1}^{n} \left[ \text{softmax}_j \left( \frac{\mathbf{x}_i^\top \mathbf{W}_{\text{QK}}^{\ell,h} \mathbf{x}_{j}}{\sqrt{d_{\text{head}}}} \right) \right] \mathbf{W}_{\text{OV}}^{\ell,h} \mathbf{x}_j \end{aligned}
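The expression above can be sketched directly in NumPy. This is a toy illustration with made-up sizes and random weights (and causal masking omitted); I build \mathbf{W}_{QK} and \mathbf{W}_{OV} explicitly from the per-head rectangular factors to emphasize the low-rank structure —

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_head, H, n = 16, 4, 4, 6  # toy sizes (assumed)

X = rng.normal(size=(n, d_model))            # residual stream: n token embeddings
W_Q = rng.normal(size=(H, d_head, d_model))
W_K = rng.normal(size=(H, d_head, d_model))
W_V = rng.normal(size=(H, d_head, d_model))
W_O = rng.normal(size=(H, d_model, d_head))

def attn(X):
    out = np.zeros_like(X)
    for h in range(H):
        # the two d_model x d_model matrices, each of rank <= d_head
        W_QK = W_Q[h].T @ W_K[h]
        W_OV = W_O[h] @ W_V[h]
        scores = (X @ W_QK @ X.T) / np.sqrt(d_head)      # x_i^T W_QK x_j
        A = np.exp(scores - scores.max(axis=1, keepdims=True))
        A /= A.sum(axis=1, keepdims=True)                # softmax over j
        out += A @ (X @ W_OV.T)                          # sum_j A_ij (W_OV x_j)
    return out

Y = attn(X)   # same shape as the residual stream
```

The per-head loop is deliberately unfused — real implementations batch the heads, which is exactly the implementation quirk that obscures the two-matrix structure.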

The attention operator \textrm{Attn}^\ell at layer \ell is a sum over individual attention heads h, with H total heads. Note that here I choose to call the operator the net function that returns a vector of the same size as \mathbf{x}_i — one can choose to add this back to the residual X^\ell. Some architectures do so; others send it through the MLP operator. There are a lot of different transformer architectures out there in the various LLMs, and for the purpose of this discussion it’s unimportant. Moreover, papers have a bewildering range of definitions of what part is called attention, which is why I bored you with setting up notation. You are welcome.

Note that the number of heads and head dimensions are chosen such that we always have dmodel×dmodeld_{\text{model}} \times d_{\text{model}} matrices in the above expression.

The only correlation between tokens explored in a transformer is pairwise. The MLP operator acts on the per-token embedding \mathbf{x}_i and does not mix \mathbf{x}_i and \mathbf{x}_j. In the attention operator, the \textrm{softmax}_j term is a normalized weight — every other token embedding \mathbf{x}_j in the context window is summed over, weighted by this term multiplied by a linear transformation matrix. It is really quite simple.

Well, one may wonder — why only pairwise correlations? And, why only the above functional form for pairwise correlations?

A digression — for physicists like me, any time we see pairwise correlations, we think about Potts model, a generalization of the Ising Model which is perhaps better known. In the q-state Potts model the “spins” are unit vectors that point in q symmetric directions of a hypertetrahedron in q-1 dimensions, see here2. In the classical Potts model these vectors interact only if their “spins” (state) are the same.

Can we draw an analogy with Potts Model? Yes, of course! Well, a paper3 already did a version of it—with a Potts Model where the interactions are not restricted to same “spins” but mix “spins”. It’s an enticing direction to study the dynamics of transformers using such mappings.

OK, end of digression.

The Memory Bottleneck in Modern Transformers

Large language models face a critical scalability challenge: the Key-Value (KV) cache. During autoregressive generation, standard Multi-Head Attention (MHA) stores keys and values for all previously generated tokens, consuming memory that grows linearly with sequence length:

\text{Memory}_{\text{MHA}} \sim n \times H \times d_{\text{head}}

See the notation above. For a model with H = 32 and d_\text{head} = 128 processing an n = 10^5 token context, this amounts to over 800MB just for the KV cache of a single layer!

The fundamental question is whether we must store the full H \times d_\text{head} representation for each token, or whether a more compact factorized representation can capture the essential structure with minimal information loss.

Tensor Decompositions: A Primer

Before diving into Tensor Product Attention (TPA), we need to understand the landscape of tensor decomposition methods. A tensor is simply a multi-dimensional array—scalars are 0-order tensors, vectors are 1st-order, matrices are 2nd-order, and so on.

CP Decomposition (CANDECOMP/PARAFAC)

The most common Tensor Decomposition is probably the CP decomposition.

Definition (CP Decomposition): A third-order tensor \mathcal{X} \in \mathbb{R}^{I \times J \times K} has a rank-R CP decomposition if it can be written as:

đ’ŗ=∑r=1R𝐚r∘𝐛r∘𝐜r \mathcal{X} = \sum_{r=1}^{R} \mathbf{a}_r \circ \mathbf{b}_r \circ \mathbf{c}_r where 𝐚r∈ℝI\mathbf{a}_r \in \mathbb{R}^I, 𝐛r∈ℝJ\mathbf{b}_r \in \mathbb{R}^J, 𝐜r∈ℝK\mathbf{c}_r \in \mathbb{R}^K and ∘\circ denotes the outer product.

Equivalently, element-wise, for indices i, j, k:

đ’ŗijk=∑r=1Rairbjrckr\mathcal{X}_{ijk} = \sum_{r=1}^{R} a_{ir} b_{jr} c_{kr}

The CP decomposition represents a tensor as a sum of rank-1 tensors (outer products of vectors). This is the natural generalization of matrix SVD to higher orders, though unlike SVD, computing the optimal CP decomposition is NP-hard. Yeah, sucks, right?
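The element-wise formula maps directly onto an einsum. A small sanity check with arbitrary toy dimensions, confirming that the einsum form and an explicit sum of R outer products agree —

```python
import numpy as np

rng = np.random.default_rng(1)
I, J, K, R = 3, 4, 5, 2
A = rng.normal(size=(I, R))
B = rng.normal(size=(J, R))
C = rng.normal(size=(K, R))

# X_ijk = sum_r A_ir B_jr C_kr
X = np.einsum('ir,jr,kr->ijk', A, B, C)

# the same tensor as an explicit sum of R rank-1 (outer-product) tensors
X2 = sum(
    np.multiply.outer(np.multiply.outer(A[:, r], B[:, r]), C[:, r])
    for r in range(R)
)
assert np.allclose(X, X2)
```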

Tucker Decomposition

Another popular tensor decomposition method is the Tucker Decomposition.

Definition (Tucker Decomposition): A Tucker decomposition factorizes a tensor into a core tensor \mathcal{G} \in \mathbb{R}^{R_1 \times R_2 \times R_3} and factor matrices along each mode: \mathcal{X} = \mathcal{G} \times_1 \mathbf{A} \times_2 \mathbf{B} \times_3 \mathbf{C} where \mathbf{A} \in \mathbb{R}^{I \times R_1}, \mathbf{B} \in \mathbb{R}^{J \times R_2}, \mathbf{C} \in \mathbb{R}^{K \times R_3} and \times_n denotes the mode-n product.

More directly, the decomposition is — 

đ’ŗpqr=∑iR1∑jR2∑kR3đ’ĸijk𝐀pi𝐁qj𝐂rk\mathcal{X}_{p q r} = \sum_{i}^{R_1} \sum_{j}^{R_2} \sum_{k}^{R_3}\mathcal{G}_{i j k}\, \mathbf{A}_{pi} \,\mathbf{B}_{qj} \mathbf{C}_{rk}

The Tucker decomposition generalizes CP by allowing a dense core tensor. Note that the sizes R_1, R_2, R_3 are bounded by the tensor dimensions I, J, K — a common choice is R_1 = R_2 = R_3 = \min(I, J, K). When the core tensor \mathcal{G} is super-diagonal (non-zero only when all indices are equal), Tucker reduces to CP.
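The super-diagonal reduction is easy to verify numerically. A toy check (random factors, arbitrary small sizes) that a Tucker contraction with an identity super-diagonal core reproduces the CP formula —

```python
import numpy as np

rng = np.random.default_rng(2)
I, J, K, R = 4, 5, 6, 2

# super-diagonal core: G_rrr = 1, zero elsewhere
G = np.zeros((R, R, R))
for r in range(R):
    G[r, r, r] = 1.0

A = rng.normal(size=(I, R))
B = rng.normal(size=(J, R))
C = rng.normal(size=(K, R))

# Tucker: X_pqr = sum_ijk G_ijk A_pi B_qj C_rk
X_tucker = np.einsum('ijk,pi,qj,rk->pqr', G, A, B, C)

# CP with the same factors
X_cp = np.einsum('pa,qa,ka->pqk', A, B, C)
assert np.allclose(X_tucker, X_cp)
```

With a dense random core instead, the two would differ — the core is exactly what buys Tucker its extra mixing between factor columns.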

Tensor Train Decomposition

The tensor decomposition most familiar to physicists is probably the tensor train decomposition.

Definition (Tensor Train): A tensor train (TT) or Matrix Product State (MPS) represents a dd-dimensional tensor as a product of matrices —

đ’ŗi1,i2,â€Ļ,id=𝐆i1[1]𝐆i2[2]â‹¯đ†id[d]\mathcal{X}_{i_1, i_2, \ldots, i_d} = \mathbf{G}^{[1]}_{i_1} \mathbf{G}^{[2]}_{i_2} \cdots \mathbf{G}^{[d]}_{i_d}

where \mathbf{G}^{[k]}_{i_k} \in \mathbb{R}^{r_{k-1} \times r_k} with r_0 = r_d = 1. The parameters \{r_1, \ldots, r_k, \ldots, r_{d-1}\} are called bond dimensions or TT-ranks.

This is the same structure used to represent quantum many-body states in physics.
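A minimal TT sketch (toy dimensions and random cores, all assumed): each entry of the tensor is a product of matrices selected by the physical indices, and contracting all the bonds reconstructs the full tensor —

```python
import numpy as np

rng = np.random.default_rng(3)
dims = [3, 4, 3, 2]       # a 4th-order tensor
ranks = [1, 2, 3, 2, 1]   # bond dimensions, with r_0 = r_d = 1

# core G^[k] stored as an array of shape (r_{k-1}, dim_k, r_k)
cores = [rng.normal(size=(ranks[k], dims[k], ranks[k + 1])) for k in range(4)]

def tt_element(idx):
    """X_{i1..id} = G^[1]_{i1} G^[2]_{i2} ... G^[d]_{id} (a 1x1 matrix product)."""
    M = np.eye(1)
    for k, i in enumerate(idx):
        M = M @ cores[k][:, i, :]
    return M[0, 0]

# contract all bond indices at once to get the dense tensor
X = np.einsum('aib,bjc,ckd,dle->ijkl', *cores)
assert np.isclose(X[1, 2, 0, 1], tt_element((1, 2, 0, 1)))
```

Storage is sum_k r_{k-1} d_k r_k numbers instead of prod_k d_k — the same economy an MPS buys for quantum states.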

Tensor Product Attention: The Core Claim

Now we arrive at the key contribution of the TPA paper. Instead of storing full query, key, and value matrices, TPA represents them using contextual low-rank factorizations.

Standard Multi-head Attention

For token ii with embedding 𝐱i\mathbf{x}_i, layer ℓ\ell and head h∈{1,â€Ļ,H}h \in \{ 1, \dots, H \}

đĒiℓ,h=𝐖Qℓ,h𝐱i∈ℝdhead𝐤iℓ,h=𝐖Kℓ,h𝐱i∈ℝdheadđ¯iℓ,h=𝐖Vℓ,h𝐱t∈ℝdhead\begin{align} \mathbf{q}_i^{\ell,h} = \mathbf{W}_Q^{\ell,h} \mathbf{x}_i \in \mathbb{R}^{d_{\text{head}}} \\ \mathbf{k}_i^{\ell,h} = \mathbf{W}_K^{\ell,h} \mathbf{x}_i \in \mathbb{R}^{d_{\text{head}}} \\ \mathbf{v}_i^{\ell,h} = \mathbf{W}_V^{\ell,h} \mathbf{x}_t \in \mathbb{R}^{d_{\text{head}}} \end{align}

We can stack all the heads into matrices, note that now the matrices are not just weights, but weights multiplied by embeddings—

𝐐i=[đĒi1đĒi2⋮đĒiH]∈ℝH×dhead\begin{equation} \mathbf{Q}_i = \begin{bmatrix} \mathbf{q}_i^1 \\ \mathbf{q}_i^2 \\ \vdots \\ \mathbf{q}_i^H \end{bmatrix} \in \mathbb{R}^{H \times d_{\text{head}}} \end{equation}

TPA

TPA factorizes the stacked query/key/value matrices as rank-RR sums of outer products.

𝐐i=1RQ∑r=1RQ𝐚rQ(𝐱i)⊗𝐛rQ(𝐱i)∈ℝH×dhead\begin{equation} \mathbf{Q}_i = \frac{1}{R_Q} \sum_{r=1}^{R_Q} \mathbf{a}^Q_r(\mathbf{x}_i) \otimes \mathbf{b}^Q_r(\mathbf{x}_i) \in \mathbb{R}^{H \times d_{\text{head}}} \end{equation}

Note that the dimensions work out, for clarity — 

\begin{align} \mathbf{x}_i \in \mathbb{R}^{d_{\text{model}}} \quad \text{(input)} \\ \mathbf{W}^{a,Q}_r \mathbf{x}_i = \mathbf{a}^Q_r(\mathbf{x}_i) \in \mathbb{R}^{H} \quad \text{(head factor)} \\ \mathbf{W}^{b,Q}_r \mathbf{x}_i = \mathbf{b}^Q_r(\mathbf{x}_i) \in \mathbb{R}^{d_{\text{head}}} \quad \text{(feature factor)} \\ \mathbf{a}^Q_r \otimes \mathbf{b}^Q_r \in \mathbb{R}^{H \times d_{\text{head}}} \quad \text{(outer product)} \\ \frac{1}{R_Q}\sum_{r=1}^{R_Q} \mathbf{a}^Q_r \otimes \mathbf{b}^Q_r \, = \mathbf{Q}_i \in \mathbb{R}^{H \times d_{\text{head}}} \quad \checkmark \end{align}
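The shape bookkeeping above can be sketched in a few NumPy lines (toy sizes and random projection matrices, all assumed) — the contextual factors are just linear maps of \mathbf{x}_i, and the outer-product sum assembles \mathbf{Q}_i —

```python
import numpy as np

rng = np.random.default_rng(4)
d_model, d_head, H, R_Q = 16, 8, 4, 2
x = rng.normal(size=d_model)                   # one token embedding x_i

# per-rank projection matrices: the factors are contextual (functions of x)
W_a = rng.normal(size=(R_Q, H, d_model))       # head factors a_r(x) in R^H
W_b = rng.normal(size=(R_Q, d_head, d_model))  # feature factors b_r(x) in R^{d_head}

a = W_a @ x        # shape (R_Q, H)
b = W_b @ x        # shape (R_Q, d_head)

# Q_i = (1/R_Q) sum_r a_r ⊗ b_r
Q = np.einsum('rh,rd->hd', a, b) / R_Q
assert Q.shape == (H, d_head)
```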

So for standard MHA, each head independently projects the input—

đĒih=𝐖Qh𝐱i\begin{equation} \mathbf{q}_i^h = \mathbf{W}_Q^h \mathbf{x}_i \end{equation}

whereas for TPA, all heads share RQR_Q feature vectors, weighted differently per head,

đĒih=1RQ∑r=1RQ[𝐚rQ(𝐱i)]h⏟head-specific weight⋅𝐛rQ(𝐱i)⏟shared feature vector\begin{equation} \mathbf{q}_i^h = \frac{1}{R_Q} \sum_{r=1}^{R_Q} \underbrace{[\mathbf{a}^Q_r(\mathbf{x}_i)]_h}_{\text{head-specific weight}} \cdot \underbrace{\mathbf{b}^Q_r(\mathbf{x}_i)}_{\text{shared feature vector}} \end{equation}

The Key Idea: Instead of H independent dheadd_\text{head} -dimensional vectors (one per head), TPA uses— 

  • RQR_Q shared feature vectors 𝐛rQ∈ℝdhead\mathbf{b}^Q_r \in \mathbb{R}^{d_{\text{head}}}
  • RQR_Q weight vectors 𝐚rQ∈ℝH\mathbf{a}^Q_r \in \mathbb{R}^H— one scalar per head, determining how much each head uses each feature

where RQâ‰ĒHR_Q \ll H, therefore leading to parameter efficiency. Obviously, we have similar things going on for 𝐊i\mathbf{K}_i and 𝐕i\mathbf{V}_i.

Parameter counts

For MHA, the total number of parameters for queries alone (and similarly for keys and values) is H \times d_\text{head} \times d_\text{model} = d^2_\text{model}

For TPA we have— 

  • Head factors: RQR_Q matrices of size H×dmodelH \times d_\text{model}
  • Feature factors: RQR_Q matrices of size dhead×dmodeld_\text{head} \times d_\text{model}
  • Total parameters— RQ(H+dhead)dmodelR_Q (H + d_\text{head} ) d_\text{model}

Example with typical paper values: H=32H=32, dhead=128d_{\text{head}}=128, dmodel=4096d_{\text{model}}=4096, RQ=6\boxed{R_Q=6}:

  • MHA: 32×128×4096=16,777,21632 \times 128 \times 4096 = 16{,}777{,}216 parameters
  • TPA: 6×4096×(32+128)=3,932,1606 \times 4096 \times (32 + 128) = 3{,}932{,}160 parameters
  • TPA uses ~23% of MHA’s parameters

Note: Unlike LoRA which factorizes weights, TPA factorizes activations. This means the factorization is contextual—it depends on the input token 𝐱i\mathbf{x}_i. It’s a very interesting idea in how to capture input-dependent structure while maintaining compression!

Memory Reduction

The major advantage claimed by the paper is the memory saving in KV cache. My interest in this paper is beyond this, to study other forms of attention, but it’s useful to note the memory arguments.

From standard MHA we have— 

  • Store 𝐊i∈ℝH×dhead\mathbf{K}_i \in \mathbb{R}^{H \times d_{\text{head}}} and 𝐕i∈ℝH×dhead\mathbf{V}_i \in \mathbb{R}^{H \times d_{\text{head}}}
  • Total: 2×H×dhead=2dmodel2 \times H \times d_{\text{head}} = 2d_{\text{model}}

TPA stores only the factors— 

  • Store {𝐚rK(𝐱i)}r=1RK\{\mathbf{a}^K_r(\mathbf{x}_i)\}_{r=1}^{R_K} and {𝐛rK(𝐱i)}r=1RK\{\mathbf{b}^K_r(\mathbf{x}_i)\}_{r=1}^{R_K}for keys
  • Store {𝐚rV(𝐱i)}r=1RV\{\mathbf{a}^V_r(\mathbf{x}_i)\}_{r=1}^{R_V} and {𝐛rV(𝐱i)}r=1RV\{\mathbf{b}^V_r(\mathbf{x}_i)\}_{r=1}^{R_V}for values
  • Total: (RK+RV)(H+dhead)(R_K + R_V)(H + d_{\text{head}})

The compression ratio is

΁=(RK+RV)(H+dhead)2Hdhead\rho = \frac{(R_K + R_V)(H + d_{\text{head}})}{2H \, d_{\text{head}}}

Concrete example: H=32,dhead=128,RK=RV=1H = 32, d_{\text{head}} = 128, R_K = R_V = 1:

  • TPA cache =2×(32+128)=320= 2 \times (32 + 128) = 320 values per token
  • MHA cache =2×32×128=8192= 2 \times 32 \times 128 = 8192 values per token

so TPA leads to a 96\% memory reduction! For a context window of 100,000 tokens, MHA needs 1.6 GB of memory whereas TPA needs 64 MB (both per layer)!
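Again, a quick check of the numbers (assuming 2 bytes per value, i.e. fp16 storage):

```python
H, d_head, n = 32, 128, 100_000
R_K = R_V = 1
bytes_per_value = 2                           # fp16 (assumed)

mha_per_token = 2 * H * d_head                # K and V, all heads
tpa_per_token = (R_K + R_V) * (H + d_head)    # just the factors

print(mha_per_token, tpa_per_token)           # 8192 320
print(1 - tpa_per_token / mha_per_token)      # ~0.96 reduction

mha_gb = n * mha_per_token * bytes_per_value / 1e9
tpa_mb = n * tpa_per_token * bytes_per_value / 1e6
print(round(mha_gb, 2), round(tpa_mb, 1))     # 1.64 64.0  (GB vs MB, per layer)
```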

Connection to MPS

Another way to look at TPA is to recast it as an MPS. Per head, instead of the term \mathbf{x}_i^\top \mathbf{W}_{\text{QK}}^{\ell,h} \mathbf{x}_j in MHA, for TPA we have

(đĒih)⊤⋅𝐤jh=(1RQ∑r=1RQ[𝐚rQ]h⋅𝐛rQ)⊤⋅(1RK∑s=1RK[𝐚sK]h⋅𝐛sK)=1RQRK∑r=1RQ∑s=1RK([𝐚rQ]h⋅𝐛rQ)⊤⋅([𝐚sK]h⋅𝐛sK)=∑r=1RQ∑s=1RK[𝐚rQ(𝐱i)]h⋅[𝐚sK(𝐱j)]h⏟head-space mixing⋅(𝐛rQ(𝐱i))⊤⋅𝐛sK(𝐱j)⏟feature-space contraction\begin{align} (\mathbf{q}_i^h)^\top \cdot \mathbf{k}_j^h = \left(\frac{1}{R_Q} \sum_{r=1}^{R_Q} [\mathbf{a}^Q_r]_h \cdot \mathbf{b}^Q_r\right)^\top \cdot \left(\frac{1}{R_K} \sum_{s=1}^{R_K} [\mathbf{a}^K_s]_h \cdot \mathbf{b}^K_s\right) \\ = \frac{1}{R_Q R_K} \sum_{r=1}^{R_Q} \sum_{s=1}^{R_K} ([\mathbf{a}^Q_r]_h \cdot \mathbf{b}^Q_r)^\top \cdot ([\mathbf{a}^K_s]_h \cdot \mathbf{b}^K_s) \\ =\sum_{r=1}^{R_Q} \sum_{s=1}^{R_K} \underbrace{[\mathbf{a}^Q_r(\mathbf{x}_i)]_h \cdot [\mathbf{a}^K_s(\mathbf{x}_j)]_h}_{\text{head-space mixing}} \cdot \underbrace{(\mathbf{b}^Q_r(\mathbf{x}_i))^\top \cdot \mathbf{b}^K_s(\mathbf{x}_j)}_{\text{feature-space contraction}} \end{align}

Now we are getting somewhere, right? That’s a very different take on the attention matrix capturing token-token correlations!

  • Rank indices (r,s)(r,s) play the role of bond indices in MPS
  • \sum_{r=1}^{R_Q} \sum_{s=1}^{R_K} is the bond contraction
  • Low ranks R_Q, R_K are equivalent to a low bond dimension and increased efficiency, while a high bond dimension gives more expressiveness
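The identity above is easy to confirm numerically. A toy check (random factors, arbitrary small sizes) that the double-sum over bond indices reproduces the plain dot product of the reconstructed q and k for one head —

```python
import numpy as np

rng = np.random.default_rng(5)
d_head, H, R_Q, R_K = 8, 4, 2, 3

# contextual factors for one (i, j) token pair, drawn at random here
aQ = rng.normal(size=(R_Q, H)); bQ = rng.normal(size=(R_Q, d_head))
aK = rng.normal(size=(R_K, H)); bK = rng.normal(size=(R_K, d_head))

h = 1  # pick a head

# reconstruct q_i^h and k_j^h from the factors
q_h = (aQ[:, h, None] * bQ).sum(axis=0) / R_Q
k_h = (aK[:, h, None] * bK).sum(axis=0) / R_K

# factorized form: head-space mixing times feature-space contraction,
# summed over the bond indices (r, s)
dot = np.einsum('r,s,rd,sd->', aQ[:, h], aK[:, h], bQ, bK) / (R_Q * R_K)
assert np.isclose(dot, q_h @ k_h)
```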

Copy Tensor

We can look at the above expression in terms of copy tensors in Tensor Networks. A copy tensor4 allows for reusing information. For a vector 𝐚∈ℝd\mathbf{a} \in \mathbb{R}^d, the copy operation is represented by a diagonal tensor, 𝒞ij=δij\mathcal{C}_{ij} = \delta_{ij} , the Kronecker delta. In other words, a copy tensor allows a single input to be reused in multiple tensor contractions.

Note what’s happening in TPA! The same input vector 𝐱i\mathbf{x}_i is used 2RQ2 R_Q times for Query, and so on for Key and Value — 

𝐱i→𝐖1a,Q𝐚1Q(𝐱i)∈ℝH𝐱i→𝐖1b,Q𝐛1Q(𝐱i)∈ℝdhead⋮𝐱i→𝐖RQa,Q𝐚RQQ(𝐱i)∈ℝH𝐱i→𝐖RQb,Q𝐛RQQ(𝐱i)∈ℝdhead\begin{align} \mathbf{x}_i \xrightarrow{\mathbf{W}^{a,Q}_1} \mathbf{a}^Q_1(\mathbf{x}_i) \in \mathbb{R}^H \\ \mathbf{x}_i \xrightarrow{\mathbf{W}^{b,Q}_1} \mathbf{b}^Q_1(\mathbf{x}_i) \in \mathbb{R}^{d_{\text{head}}} \\ \vdots \\ \mathbf{x}_i \xrightarrow{\mathbf{W}^{a,Q}_{R_Q}} \mathbf{a}^Q_{R_Q}(\mathbf{x}_i) \in \mathbb{R}^H \\ \mathbf{x}_i \xrightarrow{\mathbf{W}^{b,Q}_{R_Q}} \mathbf{b}^Q_{R_Q}(\mathbf{x}_i) \in \mathbb{R}^{d_{\text{head}}} \end{align}

Instead of computing H independent projections (standard MHA), TPA computes 2RQ2 R_Q projections and cleverly recombines them. When RQâ‰ĒHR_Q \ll H, this architecture is much more efficient while maintaining expressiveness of a Tensor Network (outer product).

Few other things…

  • The paper shows that TPA is compatible with RoPE embedding. RoPE only acts on the 𝐛\mathbf{b} vectors. The keys are pre-rotated and stored, so no rotation is needed during decoding. Only the current query needs to be rotated. Neat!
  • Remarkably, standard attention mechanisms are non-contextual variants of TPA! They show that both GQA (Grouped Query Attention) and MQA (Multi-Query Attention) are simply poor man’s versions of TPA, with 𝐚\mathbf{a} being independent of 𝐱i\mathbf{x}_i !

I loved the paper. The key lessons:

  1. Structure matters: Exploiting low-rank structure in attention patterns enables massive compression
  2. Contextual factorization: Factorizing activations (not weights) is a very interesting concept
  3. Model performance and memory needs: As with several other recent works, the belief that a larger context window either means larger models or requires compromising on the expressivity of the correlations captured in attention may be incorrect

As we push toward longer contexts and larger models, principled compression techniques like TPA are a fruitful area of research. The tensor network perspective suggests we’ve only begun to explore the space of possible architectures!

References

  1. Zhang, Yifan, et al. “Tensor Product Attention Is All You Need.” arXiv preprint arXiv:2501.06425 (2025).
  2. Wu, Fa-Yueh. “The Potts Model.” Reviews of Modern Physics 54.1 (1982): 235.
  3. Rende, Riccardo, et al. “Mapping of attention mechanisms to a generalized Potts model.” Physical Review Research 6.2 (2024): 023057.
  4. Glasser, Ivan, Nicola Pancotti, and J. Ignacio Cirac. “From probabilistic graphical models to generalized tensor networks for supervised learning.” IEEE Access 8 (2020): 68169-68182.
