The original Transformer can only attend to a fixed, limited segment of the input when computing its attention. The major drawback of this architecture is that no information can flow across separate segments, which prevents the Transformer from modeling long-term dependencies. Transformer-XL is an enhancement of the vanilla Transformer that stores the most recent hidden states in a fixed-length memory and reuses them, both at training and inference time, when computing the attention. Moreover, to keep positional information coherent across segments, Transformer-XL relies on relative positional encodings, which encode the position of each element relative to the current segment extended with its memory. Taking the memory and the relative positional encodings into account, the attention score of Transformer-XL becomes the following (in the notation of the original paper):
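$$A^{\mathrm{rel}}_{i,j} = \underbrace{E_{x_i}^{\top} W_q^{\top} W_{k,E}\,E_{x_j}}_{(a)} + \underbrace{E_{x_i}^{\top} W_q^{\top} W_{k,R}\,R_{i-j}}_{(b)} + \underbrace{u^{\top} W_{k,E}\,E_{x_j}}_{(c)} + \underbrace{v^{\top} W_{k,R}\,R_{i-j}}_{(d)}$$
where $E_{x_i}$ and $E_{x_j}$ are the hidden states at the query and key positions, $W_q$, $W_{k,E}$ and $W_{k,R}$ are projection matrices, $R_{i-j}$ is the relative positional encoding, and $u$ and $v$ are learned global bias vectors; $u$ is the same bias that appears in the code snippets later in this post.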
Readers not familiar with Transformer-XL can refer to the official paper for a more detailed explanation. In this post we are going to compare the speed of two similar implementations of the attention-XL: the first one is implemented using einsum (Einstein summation) while the second one is implemented with standard matrix multiplications.
Benchmark results
This benchmark tests the speed of two similar implementations of the Transformer-XL attention on both CPU and CUDA devices. The first version uses standard matrix multiplication (torch.matmul), while the second relies on Einstein summation (torch.einsum). This speed benchmark might be interesting because most common implementations of the attention-XL are based on one of these two methods. All timings below are reported in seconds.
Using device cpu
- 10000 matmul attention took 23.157329082489014
- 10000 einsum attention took 56.454954624176025
Using device cuda
- 10000 matmul attention took 9.774099826812744
- 10000 einsum attention took 11.209516763687134
Specifications:
- CUDA: 9.0
- GPU: GeForce GTX 1080 Ti
- OS: CentOS Linux 7
- CPU: Intel(R) Xeon(R) CPU E5-2697A v4 @ 2.60GHz
- Torch version: 1.3.1+cuda90_cudnn7.6.3_lms
Einsum
Einstein notation (or summation) is an elegant way to express complex operations on tensors, using a domain-specific language. Supported operations include dot products, outer products, transposes, and matrix multiplications.
Example 1: we want to calculate the dot product of two vectors a, b ∈ ℝ^I:

$$c = \sum_i a_i b_i \quad\Rightarrow\quad c = a_i b_i$$
In Einstein notation, the summation sign can be dropped: the two vectors a and b are multiplied element-wise and summed along their only axis i. When an axis is not present in the output (i in this example), the result is summed over that axis. Now let’s have a look at another example.
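A minimal PyTorch sketch of this example (the vector length 5 is arbitrary; torch.einsum is covered below under Practical implementations):
import torch

a = torch.randn(5)
b = torch.randn(5)
# "i,i->": multiply element-wise along i and sum; i is absent from the output
dot = torch.einsum("i,i->", a, b)
assert torch.allclose(dot, a @ b)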
Example 2: we want to multiply two matrices A ∈ ℝ^{I×K} and B ∈ ℝ^{K×J}, followed by summing each column of the result (i.e., summing over axis i), obtaining a vector c ∈ ℝ^J. Using Einstein summation notation, we can write this as:

$$c_j = \sum_i \sum_k A_{ik} B_{kj} \quad\Rightarrow\quad c_j = A_{ik} B_{kj}$$
Each element c_j of c is calculated by multiplying the entries A_{ik} with the entries B_{kj} of the j-th column of B and summing them up. The notation sums over both i and k, since neither appears in the output axes, and yields the vector c of size J. Finally, let’s have a look at a slightly more complex example.
Example 3: we have a 3-dimensional tensor T ∈ ℝ^{N×T×K} and we want to project the vectors along its third dimension, of size K, to size Q using a projection matrix W ∈ ℝ^{K×Q}, obtaining a tensor C ∈ ℝ^{N×T×Q}:

$$C_{ntq} = \sum_k T_{ntk} W_{kq} \quad\Rightarrow\quad C_{ntq} = T_{ntk} W_{kq}$$
T is multiplied by W summing over k, so T × W = C has shape ntq. If we wanted to transpose the result along the first and third dimensions, we could simply write Cqtn instead of Cntq.
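As a quick sketch of this third example (with made-up sizes for the N, T, K and Q dimensions):
import torch

n, t, k, q = 2, 3, 4, 5
T = torch.randn(n, t, k)
W = torch.randn(k, q)
# project the last axis of T with W, summing over k
C = torch.einsum("ntk,kq->ntq", T, W)    # shape: n x t x q
# same contraction, but with the output axes reordered
C_t = torch.einsum("ntk,kq->qtn", T, W)  # shape: q x t x n
assert torch.allclose(C_t, C.permute(2, 1, 0))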
Practical implementations
Einsum is implemented in NumPy via np.einsum, in PyTorch via torch.einsum, and in TensorFlow via tf.einsum, and they all share the same signature.
For instance, the notation of the second example, $c_j = \sum_i \sum_k A_{ik} B_{kj}$, can be written as the equation string "ik,kj->j". The naming of the indices (i, j, k) is arbitrary, but it must be used consistently: whenever an operation acts on axes of equal dimension in multiple tensors, those axes must be labeled with the same symbol. In the example, since the two k axes are summed over by the matrix multiplication, we have to index both with the same symbol.
In the notation, we only specify the dimensions of the input matrices and of the output; the function automatically recognizes that it has to sum over any axis not mentioned in the output. Taking PyTorch as a reference, the example would be coded as:
# A is a tensor of shape i x k
# B is a tensor of shape k x j
C = torch.einsum("ik,kj->j", A, B)
# C is a tensor of shape j
Now, let’s have a look at how we can implement a part of the attention-XL using einsum. The tensor key has shape (full_seq x bs x num_head x dim_head) and query has shape (cur_seq x bs x num_head x dim_head).
# operation: (query + u) * key^T
# for simplicity rename: (cur_seq x bs x num_head x dim_head) -> (ibhd)
# for simplicity rename: (full_seq x bs x num_head x dim_head) -> (jbhd)
query = query + u
content_attn = torch.einsum(
    "ibhd,jbhd->ijbh",
    query,
    key,
)  # final shape: cur_seq x full_seq x bs x head_num
Let’s compare it with the version using matmul:
# operation: (query + u) * key^T
query = (query + u).permute(1, 2, 0, 3)
key = key.permute(1, 2, 3, 0)
content_attn = torch.matmul(query, key) # final shape: bs x head_num x cur_seq x full_seq
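The two versions compute the same scores, only with the axes laid out differently. A minimal sanity check along these lines (with made-up sizes, and u taken as a learned bias of shape num_head x dim_head, as in Transformer-XL):
import torch

cur_seq, full_seq, bs, num_head, dim_head = 8, 16, 4, 2, 32
query = torch.randn(cur_seq, bs, num_head, dim_head)
key = torch.randn(full_seq, bs, num_head, dim_head)
u = torch.randn(num_head, dim_head)

einsum_attn = torch.einsum("ibhd,jbhd->ijbh", query + u, key)
matmul_attn = torch.matmul(
    (query + u).permute(1, 2, 0, 3),  # bs x num_head x cur_seq x dim_head
    key.permute(1, 2, 3, 0),          # bs x num_head x dim_head x full_seq
)
# matmul output is bs x num_head x cur_seq x full_seq; permute it back
assert torch.allclose(einsum_attn, matmul_attn.permute(2, 3, 0, 1), atol=1e-5)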
Other common operations:
- Transpose: torch.einsum('ij->ji', a) # final shape j x i
- Sum over all elements: torch.einsum('ij->', a) # final shape: scalar
- Sum over axis: torch.einsum('ij->j', a) # final shape j
- Matrix multiplication: torch.einsum('ik,kj->ij', a, b) # final shape i x j
- Outer product: torch.einsum('i,j->ij', a, b) # final shape i x j
- Element-wise product: torch.einsum('ij,ij->ij', a, b) # final shape i x j
Common operations benchmark
Benchmark 1: operations are executed on the same element iterations times:
for i in range(iterations):
    op(a[0])
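For reference, a minimal sketch of how such a loop can be timed; the tensor sizes, the choice of op, and the use of time.time are assumptions for illustration, not the exact script behind the numbers below (on CUDA, a torch.cuda.synchronize() call would also be needed for accurate timings):
import time
import torch

iterations = 100000
a = [torch.randn(64, 64)]            # 2-dim case; sizes are arbitrary
op = lambda x: torch.matmul(x, x)    # or: torch.einsum("ik,kj->ij", x, x)

start = time.time()
for i in range(iterations):
    op(a[0])
print(f"{iterations} (2 dim) matmul took {time.time() - start}")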
Results cpu:
100000 (2 dim) matmul took 0.8134260177612305
100000 (2 dim) einsum matmul took 5.2123119831085205
100000 (4 dim) matmul took 8.770070314407349
100000 (4 dim) einsum matmul took 11.546452045440674
100000 (2 dim) permute took 0.22981929779052734
100000 (2 dim) einsum permute took 1.0236704349517822
100000 (4 dim) permute took 0.24297189712524414
100000 (4 dim) einsum permute took 1.0838022232055664
100000 (4 dim) matmul and permute took 9.7267587184906
100000 (4 dim) einsum matmul and permute took 22.426202058792114
Results cuda:
100000 (2 dim) matmul took 1.27480149269104
100000 (2 dim) einsum matmul took 6.106035947799683
100000 (4 dim) matmul took 4.648792266845703
100000 (4 dim) einsum matmul took 6.548394441604614
100000 (2 dim) permute took 0.2555727958679199
100000 (2 dim) einsum permute took 1.1705482006072998
100000 (4 dim) permute took 0.2695293426513672
100000 (4 dim) einsum permute took 1.219839334487915
100000 (4 dim) matmul and permute took 4.64383864402771
100000 (4 dim) einsum matmul and permute took 10.697574853897095
Benchmark 2: operations are executed on a different element at each iteration (due to memory constraints, iterations=10000):
for i in range(iterations):
    op(a[i])
Results cpu:
10000 (2 dim) matmul took 0.14002013206481934
10000 (2 dim) einsum matmul took 0.6052663326263428
10000 (4 dim) matmul took 28.781249284744263
10000 (4 dim) einsum matmul took 1.8268659114837646
10000 (2 dim) permute took 0.04583311080932617
10000 (2 dim) einsum permute took 0.1613006591796875
10000 (4 dim) permute took 0.045726776123046875
10000 (4 dim) einsum permute took 0.16038727760314941
10000 (4 dim) matmul and permute took 1.3352277278900146
10000 (4 dim) einsum matmul and permute took 3.2331697940826416
Results cuda:
10000 (2 dim) matmul took 0.6923580169677734
10000 (2 dim) einsum matmul took 1.2274961471557617
10000 (4 dim) matmul took 19.534417152404785
10000 (4 dim) einsum matmul took 9.732558488845825
10000 (2 dim) permute took 0.24881267547607422
10000 (2 dim) einsum permute took 0.4596419334411621
10000 (4 dim) permute took 4.46686577796936
10000 (4 dim) einsum permute took 4.810505390167236
10000 (4 dim) matmul and permute took 21.444167137145996
10000 (4 dim) einsum matmul and permute took 20.451982736587524