The original Transformer can only attend to a fixed, limited segment of the input when computing its attention. The major drawback of this architecture is that no information can flow across segment boundaries, which prevents the Transformer from modeling long-term dependencies. Transformer-XL is an enhancement of the vanilla Transformer that stores the most recent hidden states in a fixed-length memory and reuses them, during both training and inference, when computing the attention. Moreover, to keep positional information coherent across segments, Transformer-XL relies on relative positional encoding: each element is encoded by its position relative to the other elements of the extended segment (the current segment plus its memory). Taking the memory and the relative positional encoding into account, the formula for the attention score of Transformer-XL becomes:
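```latex
A^{\mathrm{rel}}_{i,j} =
    \underbrace{E_{x_i}^{\top} W_q^{\top} W_{k,E}\, E_{x_j}}_{(a)}
  + \underbrace{E_{x_i}^{\top} W_q^{\top} W_{k,R}\, R_{i-j}}_{(b)}
  + \underbrace{u^{\top} W_{k,E}\, E_{x_j}}_{(c)}
  + \underbrace{v^{\top} W_{k,R}\, R_{i-j}}_{(d)}
```

Here, following the notation of the Transformer-XL paper (Dai et al., 2019), E_{x_i} is the embedding of position *i*, R_{i-j} is the relative positional encoding, and *u*, *v* are learned global biases. Terms (a) and (c) share the factor W_{k,E}E_{x_j}, which is why, in the code later in this post, the learned bias *u* is added to the query before it is multiplied with the keys.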

Readers not familiar with Transformer-XL can refer to the official paper for a more detailed explanation. In this post we are going to compare the speed of two similar implementations of the Transformer-XL attention: the first one is implemented with einsum (Einstein summation), while the second one is implemented with standard matrix multiplications.

## Benchmark results

This benchmark aims to test the speed of two similar implementations of the Transformer-XL attention on both *cpu* and *cuda* devices. The first version uses standard matrix multiplication (`torch.matmul`), while the second relies on Einstein summation (`torch.einsum`). This speed benchmark might be interesting because most common implementations of attention-XL are based on one of these two methods. All times reported below are in seconds.

**Using device cpu**

- 10000 matmul attention took 23.157329082489014
- 10000 einsum attention took 56.454954624176025

**Using device cuda**

- 10000 matmul attention took 9.774099826812744
- 10000 einsum attention took 11.209516763687134

**Specifications**:

- CUDA: 9.0
- GPU: GeForce GTX 1080 Ti
- OS: CentOS Linux 7
- CPU: Intel(R) Xeon(R) CPU E5-2697A v4 @ 2.60GHz
- Torch version: 1.3.1+cuda90_cudnn7.6.3_lms

## Einsum

Einstein notation (or summation) is an elegant way to express complex operations on tensors, using a domain-specific language. Supported operations include dot products, outer products, transposes, and matrix multiplications.

**Example 1**: we want to calculate the dot product of two vectors *a*, *b* ∈ R^{I}: c = ∑_{i} a_{i}b_{i}, which in Einstein notation is written simply as c = a_{i}b_{i}.

In Einstein notation the summation sign can be dropped: the notation means that the two vectors *a* and *b* are multiplied element-wise and summed along their only axis *i*. Whenever an axis is not present in the output (*i* in this example), the result is summed over that axis. Now let's have a look at another example.
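In PyTorch this corresponds to an equation string whose output part is empty, so the only axis *i* is summed out (a minimal sketch with arbitrary sizes):

```python
import torch

a = torch.randn(5)
b = torch.randn(5)

# "i,i->": multiply a and b along axis i and sum the result (dot product)
dot = torch.einsum("i,i->", a, b)
print(torch.allclose(dot, a @ b))  # True
```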

**Example 2**: we want to multiply two matrices A ∈ R^{I×K} and B ∈ R^{K×J} and then sum each column of the product (i.e. sum over axis *i*), obtaining a vector *c* ∈ R^{J}. In standard notation this is c_{j} = ∑_{i}∑_{k} A_{ik}B_{kj}; using Einstein summation notation, we can simply write c_{j} = A_{ik}B_{kj}.

Each element *c_{j}* of c is calculated by multiplying the entries of A and B along the shared axis *k* and summing over both *i* and *k*: since neither of these two indices appears among the output axes, the notation sums over them and yields the vector c of size J. Finally, let's have a look at a slightly more complex example.

**Example 3**: we have a 3-dimensional tensor T ∈ R^{N×T×K} and we want to project the vectors along its third dimension *k* (of size K) to size Q using a projection matrix W ∈ R^{K×Q}, obtaining a tensor C ∈ R^{N×T×Q}: C_{ntq} = ∑_{k} T_{ntk}W_{kq}, or simply C_{ntq} = T_{ntk}W_{kq} in Einstein notation.

T is multiplied by W summing over *k*, thus T × W = C has shape *ntq*. If we wanted to transpose the result along its first and third dimensions, we could simply write C_{qtn} as the output instead of C_{ntq}.
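As a quick sketch of Example 3 in PyTorch (the sizes are arbitrary):

```python
import torch

N, T_len, K, Q = 2, 5, 8, 3   # T_len is the size of the second axis of the tensor T
T = torch.randn(N, T_len, K)
W = torch.randn(K, Q)

# project the last axis k -> q, summing over k
C = torch.einsum("ntk,kq->ntq", T, W)
print(C.shape)  # torch.Size([2, 5, 3])

# reordering the output indices transposes the result in the same call
C_t = torch.einsum("ntk,kq->qtn", T, W)
print(C_t.shape)  # torch.Size([3, 5, 2])
```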

## Practical implementations

Einsum is implemented in NumPy via `np.einsum`, in PyTorch via `torch.einsum`, and in TensorFlow via `tf.einsum`, and they all share the same signature.

For instance, the notation in the second example, c_{j}=∑_{i}∑_{k}A_{ik}B_{kj}, can be written as the equation string "ik,kj->j". The naming of the indices (i, j, k) is arbitrary, but it needs to be used consistently: for operations performed on axes of equal dimension in multiple tensors, we must use the same symbol. In the example, since we want the axis k to be summed over by the matrix multiplication, we have to use the same symbol to index it in both tensors.

In the notation we only specify the indices of the input matrices and of the output; the function automatically recognizes that it has to sum over any axis not mentioned in the output. Taking PyTorch as a reference, the example above would be coded as:

```python
# A is a tensor of shape i x k
# B is a tensor of shape k x j
C = torch.einsum("ik,kj->j", A, B)
# C is a tensor of shape j
```
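As a quick sanity check with concrete (arbitrary) sizes:

```python
import torch

A = torch.randn(3, 4)  # i = 3, k = 4
B = torch.randn(4, 5)  # k = 4, j = 5
C = torch.einsum("ik,kj->j", A, B)
print(C.shape)                                # torch.Size([5])
print(torch.allclose(C, (A @ B).sum(dim=0)))  # True
```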

Now, let's have a look at how we can implement a part of the attention-XL using einsum. The tensor `key` has shape (`full_seq x bs x num_head x dim_head`) and `query` has shape (`cur_seq x bs x num_head x dim_head`).

```python
# operation: (query + u) * key^T
# for simplicity rename: (cur_seq x bs x num_head x dim_head)  -> (ibhd)
# for simplicity rename: (full_seq x bs x num_head x dim_head) -> (jbhd)
query = query + u
content_attn = torch.einsum(
    "ibhd,jbhd->ijbh",
    query,
    key,
)  # final shape: cur_seq x full_seq x bs x head_num
```

Let's compare it with the version using `matmul`:

```python
# operation: (query + u) * key^T
query = (query + u).permute(1, 2, 0, 3)  # bs x num_head x cur_seq x dim_head
key = key.permute(1, 2, 3, 0)            # bs x num_head x dim_head x full_seq
content_attn = torch.matmul(query, key)  # final shape: bs x head_num x cur_seq x full_seq
```
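The two versions compute the same scores, only with the axes laid out differently. A minimal self-contained check (the shapes for `u`, `query`, and `key` below are assumptions, chosen small for illustration):

```python
import torch

cur_seq, full_seq, bs, num_head, dim_head = 4, 6, 2, 3, 8
query = torch.randn(cur_seq, bs, num_head, dim_head)
key = torch.randn(full_seq, bs, num_head, dim_head)
u = torch.randn(num_head, dim_head)  # learned global content bias

# einsum version: cur_seq x full_seq x bs x num_head
attn_einsum = torch.einsum("ibhd,jbhd->ijbh", query + u, key)

# matmul version: bs x num_head x cur_seq x full_seq
attn_matmul = torch.matmul(
    (query + u).permute(1, 2, 0, 3),
    key.permute(1, 2, 3, 0),
)

# identical values up to a permutation of the axes
print(torch.allclose(attn_einsum.permute(2, 3, 0, 1), attn_matmul, atol=1e-6))  # True
```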

**Other common operations**:

- Transpose: `torch.einsum('ij->ji', a)` # final shape: j x i
- Sum over all elements: `torch.einsum('ij->', a)` # final shape: scalar
- Sum over an axis: `torch.einsum('ij->j', a)` # final shape: j
- Matrix multiplication: `torch.einsum('ik,kj->ij', a, b)` # final shape: i x j
- Outer product: `torch.einsum('i,j->ij', a, b)` # final shape: i x j
- Element-wise product: `torch.einsum('ij,ij->ij', a, b)` # final shape: i x j
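Each of these can be checked against the equivalent native PyTorch call, for example:

```python
import torch

a = torch.randn(2, 3)
b = torch.randn(2, 3)

assert torch.equal(torch.einsum('ij->ji', a), a.t())           # transpose
assert torch.allclose(torch.einsum('ij->', a), a.sum())        # sum over all elements
assert torch.allclose(torch.einsum('ij->j', a), a.sum(dim=0))  # sum over axis i
assert torch.allclose(torch.einsum('ij,ij->ij', a, b), a * b)  # element-wise product
```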

## Common operations benchmark

**Benchmark 1**: operations are executed on the same elements `iterations` times:

```python
for i in range(iterations):
    op(a[0])
```
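As a rough sketch of how such a timing loop can be put together (the `benchmark` helper, tensor shapes, and ops here are illustrative assumptions, not the script used for the numbers below):

```python
import time
import torch

def benchmark(op, a, iterations=100_000, device="cpu"):
    start = time.time()
    for _ in range(iterations):
        op(a[0])
    if device == "cuda":
        # make sure all queued kernels have finished before stopping the clock
        torch.cuda.synchronize()
    return time.time() - start

a = torch.randn(1, 64, 64)
print(benchmark(lambda x: torch.matmul(x, x), a))
print(benchmark(lambda x: torch.einsum("ik,kj->ij", x, x), a))
```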

Results **cpu**:

- 100000 (2 dim) matmul took 0.8134260177612305
- 100000 (2 dim) einsum matmul took 5.2123119831085205
- 100000 (4 dim) matmul took 8.770070314407349
- 100000 (4 dim) einsum matmul took 11.546452045440674
- 100000 (2 dim) permute took 0.22981929779052734
- 100000 (2 dim) einsum permute took 1.0236704349517822
- 100000 (4 dim) permute took 0.24297189712524414
- 100000 (4 dim) einsum permute took 1.0838022232055664
- 100000 (4 dim) matmul and permute took 9.7267587184906
- 100000 (4 dim) einsum matmul and permute took 22.426202058792114

Results **cuda**:

- 100000 (2 dim) matmul took 1.27480149269104
- 100000 (2 dim) einsum matmul took 6.106035947799683
- 100000 (4 dim) matmul took 4.648792266845703
- 100000 (4 dim) einsum matmul took 6.548394441604614
- 100000 (2 dim) permute took 0.2555727958679199
- 100000 (2 dim) einsum permute took 1.1705482006072998
- 100000 (4 dim) permute took 0.2695293426513672
- 100000 (4 dim) einsum permute took 1.219839334487915
- 100000 (4 dim) matmul and permute took 4.64383864402771
- 100000 (4 dim) einsum matmul and permute took 10.697574853897095

**Benchmark 2**: operations are executed on different elements `iterations` times (due to memory constraints, `iterations` = 10000):

```python
for i in range(iterations):
    op(a[i])
```

Results **cpu**:

- 10000 (2 dim) matmul took 0.14002013206481934
- 10000 (2 dim) einsum matmul took 0.6052663326263428
- 10000 (4 dim) matmul took 28.781249284744263
- **10000 (4 dim) einsum matmul took 1.8268659114837646**
- 10000 (2 dim) permute took 0.04583311080932617
- 10000 (2 dim) einsum permute took 0.1613006591796875
- 10000 (4 dim) permute took 0.045726776123046875
- 10000 (4 dim) einsum permute took 0.16038727760314941
- 10000 (4 dim) matmul and permute took 1.3352277278900146
- 10000 (4 dim) einsum matmul and permute took 3.2331697940826416

Results **cuda**:

- 10000 (2 dim) matmul took 0.6923580169677734
- 10000 (2 dim) einsum matmul took 1.2274961471557617
- 10000 (4 dim) matmul took 19.534417152404785
- 10000 (4 dim) einsum matmul took 9.732558488845825
- 10000 (2 dim) permute took 0.24881267547607422
- 10000 (2 dim) einsum permute took 0.4596419334411621
- 10000 (4 dim) permute took 4.46686577796936
- 10000 (4 dim) einsum permute took 4.810505390167236
- 10000 (4 dim) matmul and permute took 21.444167137145996
- **10000 (4 dim) einsum matmul and permute took 20.451982736587524**