# Tutorial 6 - Attention and Transformers
This tutorial consists of four parts.
1. In [Part 1](#part1) you will build a better understanding of how and why softmax attention works
2. In [Part 2](#part2) you will derive and investigate the positional encoding schemes used in transformer models when encoding sequences
3. In [Part 3](#part3) you will implement the attention and position encoding components of a State-of-the-Art Object-Centric Learning architecture
4. In [Part 4](#part4) you will derive a lower-complexity form of Softmax attention introduced in the recent Performer paper

**Notes**
* Only Part 1 (Multi-headed Attention) and Part 2 (Position Encodings) are directly related to the course material (i.e. examinable). They are also very important in modern DL! See the ["10 Novel Applications using Transformers"](https://paperswithcode.com/newsletter/3#10_novel_applications_using_transformers) section
* A list of relevant and [further reading](#refs) is provided at the end of the document. In some cases this tutorial is based on these materials
* You can navigate this notebook easily through the _table of contents_ on the left sidebar
---

In [None]:
# Setup Colab - might take 30 sec
!pip install -q tensorflow-cpu tensorflow_datasets \
     jax git+https://github.com/deepmind/dm-haiku flax optax 

<a name="part1"></a>
# Part 1: Attention 
Here we will try and understand how attention came about in NLP by viewing Attention as soft-dictionary querying. If you get stuck see [Attention References 1](#refs) on which this question is based.

We'll only need numpy here (see [this cheatsheet](https://github.com/harrygcoppock/ImperialMScAIUtils/blob/main/cheatsheets/numpy-cheatsheet.pdf) for a refresher)

In [None]:
# Imports 
import numpy as np 
from flax.linen import dot_product_attention # For testing

## Implementing a Dictionary with matrices
We are all familiar with dictionaries which store `key: value` pairs, permitting the lookup of a `value` through a `query` which must match (exactly) its corresponding `key`. We will see that attention is essentially a form of soft dictionary lookup, where our queries are vectors, and our output will be a matrix containing _all_ the value vectors in our dictionary, weighted by the degree to which our query matched the corresponding key. 


We'll start by defining a matrix of values and corresponding keys:

In [None]:
values = np.array([[1,9],
                   [2,8],
                   [4,7],
                   [8,6]])

keys   = np.array([[1,0,0],
                   [0,1,0],
                   [1,0,0],
                   [0,0,1]])

In [None]:
# Now Let's perform a query using the dot product
query = np.array([0,1,0]) # This matches the second key
query_match = np.dot(query, keys.T)
query_match

We can see that when a query matches a key perfectly, the dot product of the query vector and keys matrix returns a one-hot vector.

**Question** Think about which operation you can perform using the `query_match` vector, to retrieve the corresponding value from the `values` matrix

In [None]:
matched_value = ... 
print("Well Done" if np.array_equal(matched_value, np.array([2,8])) else "Try again")

Think about the reason this worked so nicely. In particular:
1. Why did the `query_match` take the form of a one-hot vector, allowing us to extract a value perfectly? 
2. What happens when this isn't the case? (There are two separate behaviours worth investigating)

Once you have thought about these questions, we will try to address an issue with the more general case - we'll also call the `query_match` vector `scores` from now on, as it essentially gives us a smooth similarity score between queries and keys

**Your Answers**

In [None]:
# In the more general case our query or keys may not be normalized
query = np.array([0,3,0])
scores = np.dot(query, keys.T)
scores

You can see that if we extract one (or more generally, multiple) values which match this query, they will be multiplied by weights which do not sum to 1. This is not desirable, so let's fix it (I highly recommend using [einsum](https://numpy.org/doc/stable/reference/generated/numpy.einsum.html) - [see tutorial](https://rockt.github.io/2018/04/30/einsum), though it's not needed here):

In [None]:
# Modify scores so that it is always normalized 
scores = scores / np.sum(scores, axis= WHICH AXIS?) 
print("Well Done" if np.array_equal(scores, np.array([0,1,0,0])) else "Try again")

It would be better still if the `scores` vector wasn't completely soft, but instead disproportionately up-weighted those components which gave a significant match... I wonder which function we've met before that might do this?

In [None]:
# Which function might we use to get a normalized, positive and sharply directed vector?
query = np.array([200,3,-1])
scores = np.dot(query, keys.T) 
scores = ... 
scores

In [None]:
# Great! Now let's put this into a function which takes a batch of queries, keys and values
# and returns a batch of softly-matched values
def soft_lookup(queries, keys, values):

    return None

In [None]:
# We'll test this against a built-in soft attention implementation from flax
# This requires reshaping to specify single-headed attention
# Batch size 1, 3 key:value pairs, keys have 2 floats, values have 6 floats
queries = np.random.rand(1,4,1,2) # 4 queries
keys = np.random.rand(1,3,1,2)
values = np.random.rand(1,3,1,6)

ours = soft_lookup(queries[:,:,0,:]/np.sqrt(queries.shape[-1]), keys[:,:,0,:], values[:,:,0,:])
theirs = np.squeeze(dot_product_attention(queries, keys, values, deterministic=True), axis=-2)

print("Well Done" if np.isclose(ours, theirs, atol=1e-4).all() else "Try again")

If you look carefully, you'll notice that we pass in the queries divided by the square root of the key/query dimension, $n$. This is part of scaled dot-product attention,

$Attention (\mathbf{Q}, \mathbf{K}, \mathbf{V})=\operatorname{softmax}\left(\frac{\mathbf{Q K}^{\top}}{\sqrt{n}}\right) \mathbf{V},$

and is "motivated by the concern when the input is large, the softmax function may have an extremely small gradient" which would result in slow learning (from [Lilian Weng's blog post](https://lilianweng.github.io/lil-log/2018/06/24/attention-attention.html)).
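To make this motivation concrete, here is a small sketch (not part of the exercises; the helper `stable_softmax` and the random data are assumptions made purely for illustration). It compares the attention weights obtained from raw dot products with those obtained after dividing by $\sqrt{n}$ for a fairly large $n$. Without scaling, the softmax tends to saturate towards a one-hot vector, which is exactly the regime in which its gradients become tiny.

```python
import numpy as np

def stable_softmax(x, axis=-1):
    # Subtract the max before exponentiating for numerical stability
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)

rng = np.random.default_rng(0)
n = 512                                    # key/query dimension (illustrative choice)
demo_query = rng.standard_normal(n)
demo_keys = rng.standard_normal((8, n))    # 8 keys

raw_scores = demo_keys @ demo_query        # spread grows with sqrt(n)
scaled_scores = raw_scores / np.sqrt(n)    # spread stays of order 1

w_raw = stable_softmax(raw_scores)
w_scaled = stable_softmax(scaled_scores)

print("largest weight without scaling:", w_raw.max())
print("largest weight with scaling:   ", w_scaled.max())
```

Try increasing or decreasing `n` to see how quickly the unscaled weights collapse onto a single key.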

## From the Attention Operation to Self-Attention Layers

<figure>
   <img src="https://github.com/afspies/icl_dl_tut8/blob/master/figs/felix_hill_selfattention.png?raw=true" alt="Self-attention operating on word-embeddings" width="100%">
   <figcaption align="center">Figure 1 - Self-attention in Transformers. 
Taken from <a href="https://www.youtube.com/watch?v=8zAP2qWAsKg&feature=youtu.be&ab_channel=DeepMind">Felix Hill's NLP Lecture</a> - <a href="https://storage.googleapis.com/deepmind-media/UCLxDeepMind_2020/L7%20-%20UCLxDeepMind%20DL2020.pdf">slides</a>.
</figcaption>
</figure>



As you saw in the lecture, we can use this soft-attention operation on top of learned embeddings to perform *fast* sequence processing (in comparison to RNNs). Let's quickly recap the key advantages of self-attention over RNNs:

1. Applying self-attention over a sequence is faster than using an RNN on modern hardware, why?

2. Capturing relationships between elements of a sequence which are far apart is easier with self-attention than RNNs, why? 



**Your Answers**

The key idea of self-attention is to generate a query, key and value vector for every word in the _entire_ sequence by applying the same weights (often just a fully-connected 1 layer MLP, i.e. a Linear transformation). As such we are learning a projection from the word-embedding space to some sub-space which will hopefully capture meaningful aspects of words, and relationships between them.

It's clear then why this is called _self_-attention, as the output of the layer for every word will be a new embedding consisting of the score-weighted values of all embeddings in the previous layer. As such, every word in the sequence attends to every other word (and itself!), so the sequence is "self-attending".

## Multi-Head Self-Attention



If you try to intuitively grasp what might be occurring in these sub-spaces, it may seem strange that we assume one set of transformations will be rich enough to capture all kinds of interesting relationships between embeddings (especially as there is no guarantee that our embeddings are particularly semantically rich themselves). For example, if the word embeddings of "beetle" and "bee" are very similar, then so too will be their queries and keys; this may be fine if we are handling generic data, such that the downstream parts of our model will handle these well as a result. However, if we are training our model on the entomological literature, we might need to learn transformations that are very sensitive to the small differences in the "beetle" and "bee" embeddings, whilst still handling other words well - given that our transformations are just 1-layer MLPs though, this may not occur.

<figure>
   <img src="https://github.com/afspies/icl_dl_tut8/blob/master/figs/felix_hill_multihead.png?raw=true" alt="Multi-head Attention operating on word-embeddings" width="100%">
   <figcaption align="center">Figure 2 - Multi-Head Self-Attention in Transformers. 
Modified from <a href="https://www.youtube.com/watch?v=8zAP2qWAsKg&feature=youtu.be&ab_channel=DeepMind">Felix Hill's NLP Lecture</a> - <a href="https://storage.googleapis.com/deepmind-media/UCLxDeepMind_2020/L7%20-%20UCLxDeepMind%20DL2020.pdf">slides</a>.
</figcaption>
</figure>


As such, we use "multi-headed" attention, where we learn a set of linear transformations, each applied (still in parallel) across the sequence. Then our model can learn to map to distinct sub-spaces, which ideally end up capturing different aspects of the items in our sequence.



One nuance in transformers is that the linear transformations applied by each attention head are typically chosen such that the resulting concatenation of each head's outputs will equal the size of the input embedding (so $W_q$ is rectangular with shape `[dim_embedding,  dim_embedding/num_head]`).

Beyond a few tricks such as using layer normalization and skip connections, that's pretty much all there is to transformers! Now let's have a go at implementing multi-headed attention:

In [None]:
""" For this example we will suppose that we have:
    - A sequence of three words, each with an embedding dimension of 4 floats
    - Two attention heads, whose output shape will be chosen such that the 
      output of the layer will still be [3, 4] in shape
"""
import numpy as np
# First we'll create our artificial sequence
sequence = np.random.rand(3,4) # We'll ignore batching as it just adds a dim

# And some linear transformations (without bias for simplicity)
# For every head -> a W matrix for each of K, Q, V  - Overall matrix will be 4D
head_weights = np.random.rand(...) # You need to specify the shape! 

In [None]:
# Now let's carry out the multiheaded attention operation
kqv_matrix = ... 
print("Well Done" if kqv_matrix.shape == (2,3,3,2) else "Try again")

In [None]:
# We'll break up the keys, queries and values so we can apply our function
k, q, v = np.rollaxis(kqv_matrix, axis=SPECIFY) # Be careful with the axis 

In [None]:
# Now we can use the batched function from earlier, 
# using the batch dimension to treat the multiple heads
ours = soft_lookup(q/np.sqrt(q.shape[-1]), k, v)

# Now specify a function to concatenate along the head dimension
concat = lambda x: x
ours = concat(ours)

print("Well Done" if ours.shape == (3,4) else "Try again")

In [None]:
# Again, we'll compare to flax's built-in dot_product_attention
# flax.linen expects inputs shaped [batch, length, heads, depth]
from flax.linen import dot_product_attention
transform = lambda x: np.expand_dims(np.array(x).transpose(1,0,2), 0)
theirs = concat(transform(dot_product_attention(*map(transform, (q, k, v)),
                                       deterministic=True)[0])[0])
print("Well Done" if np.isclose(ours, theirs, atol=1e-4).all() else "Try again")

Great Stuff! You'll be seeing a **fun** use-case for these ideas in Part 3, but first, let's take a step back and talk about Position Encodings.

<a name="part2"></a>
# Part 2: Position Encodings

## Properties of Position Encodings
As you may have noticed in Part 1, none of the operations underlying the `soft_lookup` function were dependent on the ordering of the data. As transformers are typically used to process sequences (or more recently images), we need another way to ensure that information about position is being utilized by the network.

This is where position encodings come in - by adding or concatenating information about the position to every element in the model's input sequence we allow the network to leverage positional information. However, we can't use just any position encoding if we want to give the network a reasonable chance of learning such relationships.
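The order-agnostic behaviour described above can be checked directly. The sketch below is illustrative only: the helper names, the random projection matrices and the toy sequence are assumptions, not part of the tutorial's exercises. It applies single-head self-attention to a sequence and to a shuffled copy of it; without position encodings, the outputs are simply reordered in the same way, so the layer cannot tell which token came first.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(x, w_q, w_k, w_v):
    # Project every token to a query, key and value with shared weights
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = softmax(q @ k.T / np.sqrt(q.shape[-1]), axis=-1)
    return scores @ v

rng = np.random.default_rng(0)
x = rng.standard_normal((5, 4))                        # 5 tokens, 4-dim embeddings
w_q, w_k, w_v = (rng.standard_normal((4, 4)) for _ in range(3))

perm = rng.permutation(5)                              # shuffle the token order
out = self_attention(x, w_q, w_k, w_v)
out_shuffled = self_attention(x[perm], w_q, w_k, w_v)

# Shuffling the inputs only shuffles the outputs
print(np.allclose(out[perm], out_shuffled))
```

Adding a position-dependent vector to each row of `x` before the projections breaks this symmetry, which is the role of the encodings discussed next.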

Ideally, we'd like the following to hold for our choice of positional encoding:

1. The encodings should be deterministic and unique for every word in the sentence
2. Distance (within sequence) based differences in value should not depend on sequence length


The above two properties allow the network to reason about _absolute_ positions and _relative_ position within sequences in a consistent fashion. Finally we also want the encodings to:

3. Be well-behaved for sequences of unseen length (e.g. defined over any domain, whilst maintaining above properties)

The discussion in the remainder of this part will follow Amirhossein Kazemnejad's [blog post](https://kazemnejad.com/blog/transformer_architecture_positional_encoding/), so if you get stuck you might look there.

## Transformers - Sinusoidal Position Encoding

In the original [Transformer paper](https://arxiv.org/abs/1706.03762) Vaswani et al. introduce a specific position encoding which uses the sine and cosine functions to (practically) satisfy all of the above properties (they also tried using learnable position embeddings and found that these gave similar performance, but were less likely to generalize to sequences of unseen length). Their scheme provides a $d$-dimensional encoding, $p^{(t)}$ for every element $t$ in the input sequence, where the elements of this encoding are given by:
$$
p^{(t)}_{i\leq d} =\left\{\begin{array}{ll}
\sin \left(\omega_{k}t\right), & \text { if } i=2 k \\
\cos \left(\omega_{k} t\right), & \text { if } i=2 k+1
\end{array}\right.
$$
with 
$$
\omega_k = \frac{1}{10000^{2k/d}}=10^{-\frac{8k}{d}}
$$
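To get a feel for the frequencies involved, the short sketch below (illustrative only; the choice $d=64$ is an assumption) evaluates $\omega_k$ for every sine-cosine pair and the corresponding wavelengths $2\pi/\omega_k$. The wavelengths form a geometric progression, starting at $2\pi$ and approaching $10000\cdot 2\pi$.

```python
import numpy as np

d = 64                                  # encoding size (illustrative choice)
k = np.arange(d // 2)                   # one frequency per sine-cosine pair
omega = 1.0 / 10000 ** (2 * k / d)      # angular frequency of pair k
wavelength = 2 * np.pi / omega          # positions needed for one full cycle

print("first frequencies:", omega[:3])
print("shortest wavelength:", wavelength[0])
print("longest wavelength: ", wavelength[-1])
```

Low-index dimensions therefore oscillate quickly along the sequence, whilst high-index dimensions change very slowly.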

Let's try and understand why and how this encoding works - we'll begin by seeing whether it satisfies the properties which we require from a good position encoding (except determinism which is trivially satisfied).



<h3> Uniqueness and Absolute Positions </h3>

Before considering relative distance, let's check that this encoding is unique and preserves some notion of absolute position in the sequence. We'll do this by implementing the function and taking a look at the results.

Below there are two functions `position_encoding` and `get_angles` which you should complete. Remember that the `position_encoding` produces a $d$ dimensional vector for each $d$ dimensional word embedding (token) in the input sequence, so that these can be added before feeding into the Transformer's attention layer.

In [None]:
import numpy as np 

tokens = 10 # i.e Sequence Length
token_dim = 64

def get_angles(pos, i, d):
    """
        in:
            pos - token positions : shape [L, 1]
            i   - pos encoding indices : shape [1, d]
            d   - pos encoding size 
        out:
            encoding_angles - angles in radians : shape [L, d]

    """
    # Returns a vector of angles given by omega (vector) * token position
    
    encoding_angles = None

    return encoding_angles

def position_encoding(position, d):
    """
        in: 
            position - number of token positions, L (integer)
            d - token dimension size (integer)
        out: 
            positions encodings - array shape [L, d]

    """

    angle_rads = get_angles()
    pos_enc = None
    
    return pos_enc

p = position_encoding(tokens, token_dim)
print("Well Done" if p.shape==(tokens,token_dim) else "Try again")

Now we'll visualize these encodings and try to reason about whether, and how, they satisfy the properties we want.

In [None]:
import matplotlib.pyplot as plt
tokens = 48
token_dim = 86
p = position_encoding(tokens, token_dim) 

f, ax = plt.subplots(2, 2, figsize=(12,10))

# Plot 1 - A few components
ax[0,0].plot(np.arange(token_dim), p[[4,5,tokens-2,tokens-1], :].T)
ax[0,0].set_title("Plot 1 - Some Encoding Dimensions")
ax[0,0].legend(["Seq Posn %d"%l for l in [4,5,tokens-2,tokens-1]])
ax[0,0].set_xlabel("Encoding Dimension")
ax[0,0].set_ylabel("Encoding Magnitude")

# Plot 2 - The position encoding matrix
pcm = ax[0,1].pcolormesh(p, cmap='coolwarm')
ax[0,1].set_title("Plot 2 - Encoding Matrix")
ax[0,1].set_xlabel('Encoding Dimensions')
ax[0,1].set_ylabel('Token Position')
f.colorbar(pcm, ax=ax[0,1])

# Plot 3 - Two components and their neighbors
colors = ['magenta', 'cyan']
style = [':', "-", "--"]
for i, enc_dim in enumerate([8,token_dim-24]):
    for j, nghbor in enumerate([-4,0,4]):
        ax[1,0].plot(np.arange(tokens), (p[:, enc_dim+nghbor]-p[:, enc_dim]),
                   label=f"Dim {enc_dim}+{nghbor}", c=colors[i], linestyle=style[j])
ax[1,0].set_title("Plot 3 - Neighboring Dimension Differences")
ax[1,0].set_xlabel("Token Position")
ax[1,0].set_ylabel(r"$Neighbors - Base$ (arbitrary units)")
ax[1,0].legend()
ax[1,0].set_yticks([])

# Plot 4 - Dot Product of encodings across time
pcm2 = ax[1,1].pcolormesh(p.dot(p.T), cmap='coolwarm')
ax[1,1].set_title("Plot 4 - Dot-product Distances Across Sequence")
ax[1,1].set_xlabel('Token Position')
ax[1,1].set_ylabel('Token Position')
f.colorbar(pcm2, ax=ax[1,1])

plt.subplots_adjust(hspace=0.35)
plt.show()

Now let's try and analyze these plots. Address the following questions:
1. What does Plot 1 tell us about the correspondence between the encoding dimensions and a) Phase Differences b) Magnitude Differences
2. What do Plots 1 and 2 tell us about the overall magnitude of the encodings. Is this desirable?
3. What do Plots 1 and 2 tell us about the relevance of the first vs. last encoding dimensions?
4. What do Plots 3 and 4 tell us about the way in which relative distances are encoded?

Based on your considerations of the above, do you think that this position encoding satisfies _all_ of the criteria we laid out earlier?

**Your Answer**


<h3>Relative Distances</h3>

In order to publish our findings, and to keep you on your toes, we will now introduce some extraneous mathematical rigour. 

To do so, we'll find a linear transformation matrix $M$ which shifts a given sine-cosine pair, and is not a function of the absolute position, $t$, in the sequence. I.e. find some

$$ M = \left[\begin{array}{ll}
u_{1} & v_{1} \\
u_{2} & v_{2}
\end{array}\right]
$$

such that 

$$
M \cdot\left[\begin{array}{l}
\sin \left(\omega_{k} t\right) \\
\cos \left(\omega_{k} t\right)
\end{array}\right]=\left[\begin{array}{l}
\sin \left(\omega_{k} (t+\phi)\right) \\
\cos \left(\omega_{k} (t+\phi)\right)
\end{array}\right]
$$

_Hint_: You'll want to make use of the [trigonometric addition formula.](https://mathworld.wolfram.com/TrigonometricAdditionFormulas.html)
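If you would like to check your derivation numerically, the sketch below may help. It is an illustration under assumptions (the helper names, the frequency index and the offset $\phi=5$ are arbitrary choices) and it does not give the closed form: it simply solves for $M$ from two positions and confirms that the same matrix is recovered from a different pair of positions, i.e. that $M$ does not depend on $t$.

```python
import numpy as np

def sin_cos(omega, t):
    # The sine-cosine pair of the encoding at position t
    return np.array([np.sin(omega * t), np.cos(omega * t)])

def fit_shift_matrix(omega, phi, t_a, t_b):
    # Solve M X = Y, where the columns of X are pairs at t_a, t_b
    # and the columns of Y are the same pairs shifted by phi
    X = np.stack([sin_cos(omega, t_a), sin_cos(omega, t_b)], axis=1)
    Y = np.stack([sin_cos(omega, t_a + phi), sin_cos(omega, t_b + phi)], axis=1)
    return Y @ np.linalg.inv(X)

omega = 1.0 / 10000 ** (2 * 3 / 64)    # k=3, d=64 (illustrative choice)
phi = 5.0

M_1 = fit_shift_matrix(omega, phi, 1.0, 2.0)
M_2 = fit_shift_matrix(omega, phi, 10.0, 13.0)

print(np.allclose(M_1, M_2))           # same M from different positions
print(np.allclose(M_1 @ sin_cos(omega, 7.0), sin_cos(omega, 7.0 + phi)))
```

Compare the entries of `M_1` with the matrix you derive by hand for the same `omega` and `phi`.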

**Your Answer**


Explain how the existence of this M suggests that the self-attention layer should be able to learn about relative positions between tokens from encoding values alone.

**Your Answer**

---

<a name="part3"></a>
# Part 3: Attention in Practice 
Now that we've covered the basics of attention and position encodings, we'll put these to use by completing the back-bone of a recent Object-Centric Learning architecture: Slot Attention Modules. 

You can feel free to skip this part (even though it's the most fun, in my unbiased opinion!), but you can also complete the questions with minimal understanding of how these modules work.

## Slot Attention Modules
Slot Attention modules were introduced by Locatello et al. in [this paper](https://arxiv.org/abs/2006.15055), and Thomas Kipf gave an excellent [ICML Talk](https://slideslive.com/38930703/attentive-grouping-and-gnns-for-objectcentric-learning?ref=speaker-22634-latest) for those of you who wish to understand these a bit better (minutes 6:30-15:00 in particular). We'll be using these here as they are easy to understand and State-of-the-Art for Object-Centric Learning (different from segmentation as we are primarily interested in learning useful latent representations), and because I've been playing around with them recently!


<table><tr>
<td>
  <p align="center" style="padding: 10px">
    <img alt="Forwarding" src="https://github.com/afspies/icl_dl_tut8/blob/master/figs/thomas_kipf_slotattention.png?raw=true" height="220">
    <br>
    <em style="color: grey">Figure 3.a - Slots compete over the feature map <br>via iterative Weighted Softmax Attention. <br>  The feature map is a stack of encoded-pixel vectors, <br> with a position encoding added.</em>
  </p>
</td>
<td>
  <p align="center">
    <img alt="Routing" src="https://github.com/afspies/icl_dl_tut8/blob/master/figs/thomas_kipf_slotatt_autoencoder.png?raw=true" height="220">
    <br>
    <em style="color: grey">Figure 3.b - An Auto-Encoding architecture is used in conjunction with the module. <br> This allows the end-to-end training of the architecture through an (MSE) <br> reconstruction loss. Each slot is decoded separately into RGBA images,<br> and the Alpha masks are used to weight the (sum) combination of these.</em>
  </p>
</td>
</tr></table>






There are two key components of Slot Attention modules:
1. The architecture uses "slots" to encode objects. Each slot can represent any object as they are initialized using the same learned distribution (a side-effect of this is that there is no persistent allocation of a given object to a given slot - something addressed by [AlignNet](https://www.deepmind.com/research/publications/AlignNet-Unsupervised-Entity-Alignment)). A slot is just a vector which is used to generate queries over the image's features
2. These slots compete (via attention) for different (ideally contiguous) regions of the feature-space, thus hopefully learning to each capture distinct objects


In the next two sections you'll implement your own position encoding, as well as the weighted soft-max attention (very similar to single-headed attention). The model will then be trained for you using these.


## Creating your own Position Encoding

In order to easily allocate spatial information to the features (needed as we'll be using ordering-agnostic attention), and to allow us to visualize the attention scores, the encoder is constructed such that the encoded features have the same width/height as the input images.

As such, you'll need to create a position encoding with dimensions [width, height, X] where X represents the encoding dimension you choose. This encoding is then passed through a simple MLP to increase the encoding dimension to match the feature space, so that they can be added (X->dim_features). A naïve choice of position encoding would be the (x,y) coordinates of every point in the image (ideally normalized), such that X=2. This is not far from what the authors use (see figure 4).
<table><tr>
<td>
  <p align="center" style="padding: 10px">
    <img alt="Visualization of the Position Encodings used the Slot Attention paper." src="https://github.com/afspies/icl_dl_tut8/blob/master/figs/slotattn_posenc.png?raw=true" height="320">
    <br>
    <em style="color: grey">Figure 4 - Visualization of the 4D position encoding used in the Slot Attention paper.</em>
  </p>
</td>
</tr></table>

Their encoding assigns 4 floats to every x,y point - although 2 dimensions would suffice (x,y) to uniquely label each point, their position encoding makes it easy to learn a good transformation from the position encoding to the encoder's feature space. For which two reasons do you think that might be?

**Your Answer**
1. 
2. 

Try and come up with a sensible position encoding and complete the function below. Keep in mind that at the very least:
* Every coordinate point's encoding should be unique
* The result will be added to the embedded feature map (after passing through an MLP), not concatenated, so consider the magnitude of your embeddings carefully


You shouldn't spend too long on this. In fact, a simple normalized x,y encoding will still work fairly well.
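As a baseline to compare your own idea against, here is a minimal sketch of the naïve normalized (x,y) encoding mentioned above. The function name `naive_xy_grid` is hypothetical, and this is only the simple fallback (X=2), not the 4D encoding used in the paper.

```python
import numpy as np

def naive_xy_grid(resolution):
    """Hypothetical baseline: normalized (x, y) coordinates for every grid point."""
    width, height = resolution
    xs = np.linspace(0.0, 1.0, width)    # x coordinates scaled to [0, 1]
    ys = np.linspace(0.0, 1.0, height)   # y coordinates scaled to [0, 1]
    grid_x, grid_y = np.meshgrid(xs, ys, indexing="ij")  # each of shape [W, H]
    return np.stack([grid_x, grid_y], axis=-1)           # shape [W, H, 2]

grid = naive_xy_grid((10, 8))
print(grid.shape)
```

Every point receives a distinct pair of values in $[0, 1]$, so the magnitude stays small relative to typical feature activations.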

In [None]:
import numpy as np
def build_grid(resolution: tuple) -> np.ndarray:
    """
        Given the resolution of a 2D grid, this function returns a 3D numpy array
        consisting of position encodings for every x,y point
        in:
            resolution  (W, H)
        out:
            position encodings - shape [W, H, dim_encoding] 
    """
    x = None
    # If you can't think of anything sensible, no encoding (np.zeros)
    # will work somewhat, and normalized (x,y) is decent
    return x

In [None]:
# Let's check the output has the right shape
encoding = build_grid((10, 8))
if encoding.ndim == 3 and encoding.shape[:2] == (10,8):
    print("Well Done!")
else:
    print("Try again")

In [None]:
import matplotlib.pyplot as plt
# Visualize your position embedding
plt.imshow(encoding[:,:, CHOOSE_DIM or SUM]) 

Great! We're half way there

## Implementing Weighted Dot-Product Attention

Slightly deviating from the familiar "soft-lookup" style single-head attention from Part 1, Locatello et al. choose to combine values with a weighted _mean_, rather than a weighted _sum_ over attention scores, helping to stabilize the iterative attention mechanism.

In particular, the attention (elementwise) is the same as what we implemented in Part 1:

$$\operatorname{attn}_{i, j}=\frac{e^{M_{i, j}}}{\sum_{l} e^{M_{i, l}}} \quad \text{where} \quad M=\frac{k(inputs) \cdot q(slots)^{T} }{\sqrt{D_{slots}}} \in \mathbb{R}^{N \times K}.$$

But, rather than multiplying the attention scores directly with the values, we first normalize them over the inputs, so that each slot's update is a weighted mean of the values:
$$
\text { updates }=W^{T} \cdot v(\text { inputs }) \in \mathbb{R}^{K \times D} \quad \text { where } \quad W_{i, j}=\frac{\operatorname{attn}_{i, j}}{\sum_{l=1}^{N} \operatorname{attn}_{l, j}}.
$$

The two tricky parts here are choosing the axes over which to apply softmax and the sum  - look at the indices in the denominators carefully.
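The difference between a weighted sum and a weighted mean can be seen on a tiny hand-written example. The numbers below are made up for illustration, and the small `attn_eps` term used in the exercise is omitted for clarity. With a weighted sum, a slot that attends to many inputs receives a larger update; with a weighted mean, every update stays within the range of the values.

```python
import numpy as np

# Toy attention scores: N=3 inputs (rows), K=2 slots (columns).
# Each row sums to 1, as it would after the softmax.
attn = np.array([[0.9, 0.1],
                 [0.8, 0.2],
                 [0.3, 0.7]])
values = np.array([[1.0, 0.0],
                   [0.0, 1.0],
                   [1.0, 1.0]])           # N=3 inputs, D=2 features

weighted_sum = attn.T @ values            # shape [K, D], can exceed the value range

W = attn / attn.sum(axis=0, keepdims=True)  # normalize each slot's column over inputs
weighted_mean = W.T @ values              # shape [K, D], convex combination of values

print(weighted_sum)
print(weighted_mean)
```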

You'll now implement this in the function below. Please make sure that you assign the softmaxed attention scores to the `attn` variable, so that the attention masks get visualized properly during training.

In [None]:
import jax.numpy as jnp 
from jax.nn import softmax # Recommended, as they compute softmax on x-x_max  =>
#                            more stable, but need to be careful with gradients

# Please use jnp instead of np, as this will use GPU when the function is executed
# All functions you need from np are in jnp too
def weighted_dotproduct_attention(keys, slot_queries, values, get_attn=False):
    """ Applies Weighted Dot-Product attention, with shapes

        in:
            slot_queries  [batch, num_slots, hidden_size=slot_size]
            keys          [batch, encoding_size (image WxH), hidden_size=slot_size]
            values        [batch, encoding_size (image WxH), hidden_size=slot_size]
            get_attn FLAG -> whether to return attn scores as well as slot updates
        out:
            updates         [batch, num_slots, slot_size]
    """
    attn_eps = 1E-8 # Add this before calculating weighted mean
    slot_size = 32  # Remember to scale the query in the softmax
    
    attn = None    # Used to visualize which features slots are attending to
    updates = None # Used to update slots via residual MLP
     
    return (updates, attn) if get_attn else updates

In [None]:
# We'll load some data to test whether your function is working properly
!wget -q -O /content/test_attention.npz https://github.com/afspies/icl_dl_tut8/blob/master/test_data/test_weighted_dotprod_attn.npz?raw=true
q, k, v, test_out = np.load('/content/test_attention.npz')['arr_0']

In [None]:
out = weighted_dotproduct_attention(k, q, v)
print("Well Done!" if np.isclose(out, test_out, atol=1e-6).all() else "Try again")

If the test is running without errors but failing (so your shapes match), double check that you are scaling the queries by the sqrt of the slot dimension size, and that you are applying softmax over the slot-dimension of the attention scores.


## Training Our Model

Great! Now we'll train the model and take a look at the output - you don't need to do anything here except wait and keep your fingers crossed :)

In [None]:
# Download Data and get useful code
!git clone -q https://github.com/afspies/icl_dl_tut8.git
!wget -q https://storage.googleapis.com/multi-object-datasets/tetrominoes/tetrominoes_train.tfrecords
%load_ext tensorboard

In [None]:
from icl_dl_tut8.src.slotattention_training import train_model, load_data

We'll be using the Tetrominoes dataset from https://github.com/deepmind/multi_object_datasets to train and test our model. Here we'll load and visualize some data:

In [None]:
# Load data
ds = load_data("./tetrominoes_train.tfrecords", batch_size=64)

In [None]:
# Look at a sample of images
import matplotlib.pyplot as plt
image_batch = next(ds)
f, ax = plt.subplots(2,2,figsize=(6,6))
count = 0
for i in range(2):
    for j in range(2):
        ax[i,j].imshow(image_batch[count])
        ax[i,j].set_xticks([])
        ax[i,j].set_yticks([])
        count += 1
plt.tight_layout()

Now we'll train the model - this will take around 15 minutes, so you might want to get a cup of coffee and some popcorn: Watching your loss curves is the true DL Researcher experience. You should start to see sensible results after about 20k steps, and good looking ones around 40-50k. 

You'll want to click the settings button (cog) in tensorboard and enable auto-reload, so you don't have to keep refreshing. 

In [None]:
!rm -rf logs # Remove old logs
%tensorboard --logdir logs
train_model(ds, weighted_dotproduct_attention, build_grid)

Note that we chose to use four slots here as all images in the dataset contain exactly three tetrominoes. By looking at the attention masks throughout training, comment on:

1. What do the different slots attend to in the image? Are all of the slots needed here?
2. Given that the position encoding provides a fairly weak contribution to the feature embedding (it was simply up-transformed and added after all), can you imagine situations in which the competition between slots is not strong enough? (you may even have observed this)


**Your Answer**

---

<a name="part4"></a>
# Part 4: Making Transformers Faster
In lecture we've seen that the cost of the attention operation grows quadratically with the sequence length. Here we'll derive a faster version, and get you to work out its complexity - the fast attention operation, FAVOR+, was introduced in the recent [Performer](https://arxiv.org/abs/2009.14794) paper from Google.

The run-time analysis is the most important question in this part, and you do not need to answer the more mathematical questions to be able to carry it out.

This question _very closely_ follows Teddy Koker's excellent [blog post](https://teddykoker.com/2020/11/performers/), so if you get stuck that may be helpful.

## The Kernel Trick and The Squared Exponential Kernel


Recall that we can use the kernel trick to solve problems of a certain form in a transformed space without needing to perform the computationally expensive procedure of first mapping every datapoint to the new space; instead, we can find a kernel expressed purely in terms of dot products over the data, which can be used to solve the transformed version of the problem.
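As a small illustration of the kernel trick described above, here is a minimal sketch using the quadratic kernel $(x^Ty)^2$, which is not the kernel used in the rest of this part and is chosen only because its feature map can be written down explicitly. The names `phi_quadratic` and `k_quadratic` are made up for this example.

```python
import numpy as np

def phi_quadratic(x):
    """Explicit feature map for the kernel (x^T y)^2: all pairwise products x_a * x_b."""
    return np.outer(x, x).ravel()  # maps R^d -> R^(d*d)

def k_quadratic(x, y):
    """The same quantity, computed from a single dot product in the input space."""
    return np.dot(x, y) ** 2

rng = np.random.default_rng(0)
x, y = rng.normal(size=(2, 5))  # two random vectors in R^5

explicit = np.dot(phi_quadratic(x), phi_quadratic(y))  # inner product in the 25-dim feature space
implicit = k_quadratic(x, y)                            # never leaves the 5-dim input space
print(np.isclose(explicit, implicit))
```

Both routes give the same number, but the kernel route never builds the larger feature vectors.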

For this to be useful here we'll first need to express the Softmax operation in terms of a kernel (this is easy as we're taking the Softmax of queries dotted with keys).

Starting with our dot-product softmax attention over queries, $Q\in\mathbb{R}^{N\times d}$, keys, $K\in\mathbb{R}^{M\times d}$, and values, $V\in\mathbb{R}^{M\times l}$:

$$\text{Attention}(Q, K, V) = \text{Softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V,$$
where we'll ignore the $d^{\frac{-1}{2}}$ from here on out, and assume that $d^{\frac{-1}{4}}$ was absorbed into both the $Q$ and $K$. Now note that the row-wise Softmax operation here corresponds to

$$\text{Softmax}(QK^T)_{ij} = \frac{e^{Q_iK_j^T}}{\sum_{j'} e^{Q_i K_{j'}^T}} = \frac{A_{ij}}{\sum_{j'} A_{ij'}}, $$
where we've defined 

$$ A := \exp(QK^T) \in \mathbb{R}^{N\times M}.$$ 

In order to write the entire operation in terms of matrices we'll define the column vector of ones $\mathbf{1}_M = [1,...,1]^T \in\mathbb{R}^{M\times 1}$, and then note that 

$$ \begin{align} 
    A \mathbf{1}_M &=
        \begin{bmatrix}
           \sum_j A_{0j} \\
           \vdots \\
           \sum_j A_{Nj}
        \end{bmatrix}.
\end{align}
$$
Defining $D=\text{diag}(A\mathbf{1}_M)$ we get 


$$ \begin{align}
D^{-1} &= \begin{bmatrix}
           \frac{1}{\sum_j A_{0j}} &  \dots & 0 \\
           \vdots & \ddots & \vdots \\
           0 &\dots & \frac{1}{\sum_j A_{Nj}} \\       
        \end{bmatrix}.
\end{align}      
$$
Using $D^{-1}$ we can then write the row-wise softmax as 

$$\text{Attention}(Q,K,V)=D^{-1}AV.$$
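The identity above can also be checked numerically. The following is a minimal NumPy sketch with arbitrary small shapes; `softmax_rows` is a helper written for this check, and the $\sqrt{d}$ scaling is left out, as in the derivation.

```python
import numpy as np

def softmax_rows(x):
    """Numerically stable row-wise softmax."""
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
N, M, d, l = 4, 6, 3, 2            # arbitrary sizes chosen for the check
Q = rng.normal(size=(N, d))
K = rng.normal(size=(M, d))
V = rng.normal(size=(M, l))

A = np.exp(Q @ K.T)                    # [N, M], unnormalised attention weights
row_sums = A @ np.ones((M, 1))         # [N, 1], sum over each row of A
D_inv = np.diag(1.0 / row_sums[:, 0])  # [N, N], inverse of the diagonal normaliser

attn_matrix_form = D_inv @ A @ V
attn_softmax = softmax_rows(Q @ K.T) @ V
print(np.allclose(attn_matrix_form, attn_softmax))
```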

I would suggest convincing yourself that this holds using a generic $2\times2$ $K$ and $Q$. Armed with this reformulation, we'll try to find a way of approximating $A$ that costs less than $\mathcal{O}(NMd)$ (where usually we have $N=M=\text{seq\_length}$). To do this, we define the Softmax kernel between two vectors
$$K_{softmax}(x_i,x_j)=\exp(x_i^T x_j).$$

To get yourself warmed up, rewrite $A_{ij}$ in terms of this Softmax kernel, by considering the row-wise decomposition of $Q = [q_0,\dots, q_N]^T$ and $K = [k_0,\dots, k_M]^T$

**Your Answer**
<font color="blue">
$$A_{ij} = \dots K_{softmax}(...)^{...} \dots$$
</font>

We will see shortly that some clever people came up with a way of approximating the Gaussian (or Radial Basis Function) kernel. To leverage this, we'll need to rewrite our softmax kernel in terms of Gaussian kernels:

$$K_{gauss}(x_i, x_j) = \exp(-\gamma \|x_i-x_j\|^2),$$
where $\gamma$ is some constant whose value you will choose.

**Your Answer**
Express the softmax kernel in terms of exponential terms and the Gaussian kernel. This will take a few steps of calculation.

<font color="blue">
$$K_{softmax}(x_i,x_j) = \dots K_{gauss}(x_i, x_j)^{...} \dots $$
</font>

## Approximation via Sampling
In an [acclaimed 2007 paper](https://people.eecs.berkeley.edu/~brecht/papers/07.rah.rec.nips.pdf), Rahimi and Recht showed that it is possible to approximate a kernel corresponding to the inner product in some $d$-dimensional space with an approximate feature map to a lower-dimensional space, $D \ll d$:

$$K(x_i, x_j)=\phi(x_i)^T\phi(x_j) \approx z(x_i)^T z(x_j), $$

In particular, they showed that the Gaussian kernel can be approximated with the feature mapping

$$
z_{\omega}^{gauss}(x)=\left[\begin{array}{l}
\cos \left(\omega^{\top} x\right) \\
\sin \left(\omega^{\top} x\right)
\end{array}\right],
$$

where $\omega \sim \mathcal{N}_D (0, I)$. 
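To see this approximation at work, here is a minimal sketch that stacks the features from several sampled $\omega$ and compares the resulting inner product with the exact Gaussian kernel. It rests on two assumptions that are not spelled out above: the stacked features are divided by the square root of the number of samples so that their inner product is an average, and unit-variance $\omega$ corresponds to a bandwidth of $\gamma = \frac{1}{2}$. The helper names are made up for this example.

```python
import numpy as np

def rff_features(x, omega):
    """Random Fourier features. x: [n, d], omega: [D, d] -> features of shape [n, 2D]."""
    proj = x @ omega.T                                  # [n, D], one projection per sample
    num_samples = omega.shape[0]
    feats = np.concatenate([np.cos(proj), np.sin(proj)], axis=-1)
    return feats / np.sqrt(num_samples)                 # so z(x)^T z(y) averages over samples

def gauss_kernel(x, y, gamma=0.5):
    """Exact Gaussian kernel between all rows of x and all rows of y."""
    sq_dists = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return np.exp(-gamma * sq_dists)

rng = np.random.default_rng(0)
d, D = 4, 5000                                          # input dim, number of sampled omegas
x = 0.5 * rng.normal(size=(3, d))
y = 0.5 * rng.normal(size=(3, d))
omega = rng.normal(size=(D, d))                         # i.i.d. standard normal rows

approx = rff_features(x, omega) @ rff_features(y, omega).T
exact = gauss_kernel(x, y)
max_err = np.max(np.abs(approx - exact))
print(max_err)
```

The error shrinks as more samples are drawn, which is the behaviour explored later in this part for the softmax kernel.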

Using this result, and your expression for the softmax kernel in terms of the Gaussian kernel above, write out the form that $z^{softmax}_\omega(x_i)$ will take for our approximate softmax kernel.

**Your Answer**
<font color="blue">
$$
z_{\omega}^{softmax}(x)=...
$$</font>

In the Performer paper they show that they can do better than these _Random Fourier Features_ by:

1. Using Positive Random Features, which avoid negative values in the case where kernel outputs approach 0 (i.e. small, anti-parallel vectors) - this makes training more stable:$$
z_{\omega}^{\text {positive }}(x)=\exp \left(-\frac{\|x\|^{2}}{2}\right)\left[\exp \left(\omega^{\top} x\right)\right]
$$

2. Ensuring that random samples of $\omega$ are orthogonal in the $D$-dimensional space, by using the [Gram-Schmidt Orthonormalization procedure](https://en.wikipedia.org/wiki/Gram%E2%80%93Schmidt_process) (choose some vector, normalize it to get your first dimension, then choose another vector, subtract the components in the first dimension and normalize, and so on).
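The Gram-Schmidt procedure described in the second point can be sketched as follows. This is a minimal illustration, not the Performer's own sampling code: it assumes there are no more vectors than dimensions, and it leaves every row with unit norm, which Gaussian samples do not have, so a practical implementation would likely rescale the rows afterwards. The function name `gram_schmidt_rows` is made up for this example.

```python
import numpy as np

def gram_schmidt_rows(w):
    """Orthonormalise the rows of w (shape [n, d], with n <= d)."""
    basis = []
    for v in w:
        for b in basis:
            v = v - np.dot(v, b) * b           # remove the component along an earlier direction
        basis.append(v / np.linalg.norm(v))    # normalise what is left
    return np.stack(basis)

rng = np.random.default_rng(0)
omega = rng.normal(size=(8, 16))               # 8 random vectors in 16 dimensions
omega_orth = gram_schmidt_rows(omega)

gram = omega_orth @ omega_orth.T               # should be close to the identity
print(np.allclose(gram, np.eye(8)))
```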

As the Positive Random Features outperform Random Fourier Features, and are easier to implement, we'll use those here (you don't need to worry about orthogonalizing the features, though this does improve the approximation considerably).

In [None]:
import numpy as np
# For this question we'll assume that Q and K have the same shape

def z_positive(x, omega):
    """
        Given a matrix x (your Q OR K) and sampled frequencies omega, returns
        the positive random features, one row per row of x.
        in:
            x -     shape [M, d]
            omega - shape [rand_dim, d]
        out:
            z -     shape [M, rand_dim]
    """
    z_pos = None

    return z_pos

This can be improved even further by forcing the features to be orthogonal


In [None]:
# Implement attention using this sampling procedure (NOT BATCHED)
def approx_attention(q, k, v, random_dim):
    M, d = q.shape
    rescale = d ** -0.25              # to normalize before multiplication
    
    # Generate Random Features - i.i.d. gaussian features
    omega = None

    q_prime = z_positive(q * rescale, omega) # apply feature map z to Q
    k_prime = z_positive(k * rescale, omega) # apply feature map z to K

    # Perform Attention Operation using these, as described at start of Part
    A = None
    D_inv = None
    attn_scores = None

    return attn_scores


We'll now use the code provided below to see how the error of our approximate attention, measured against the deterministic attention equivalent to the one you implemented in Part 1, varies with the number of random features.

In [None]:
# Test this against flax's implementation again and see how the error
# varies with the number of Random Features

import matplotlib.pyplot as plt
from flax.nn import dot_product_attention

num_iterations = 25 # multiple runs of each to get some statistics

# Assume seq length 500, token size 64
q, k, v = list(np.random.rand(3,1,500,1,64))

# Again, we'll compare to flax's built-in dot_product_attention
theirs = np.squeeze(dot_product_attention(q, k, v, deterministic=True))

test_range = np.arange(5,150,5)
diffs_mean = np.zeros(len(test_range))
diffs_std = np.zeros(len(test_range))
for i, random_features in enumerate(test_range):
    diff = []
    for iteration in range(num_iterations):
        ours = approx_attention(*map(np.squeeze, (q, k, v)), random_features)
        diff.append(np.mean((ours - theirs) **2))

    diffs_mean[i] = np.mean(diff)
    diffs_std[i] = np.std(diff)
    
plt.plot(test_range, diffs_mean)
plt.fill_between(test_range, diffs_mean-diffs_std, diffs_mean+diffs_std, alpha=0.5)
plt.xlabel("Number of Random Features")
plt.ylabel("MSE")

You should observe a notable improvement when increasing the number of random features for the first 50 dimensions. Beyond that, we get significantly diminishing returns without orthonormalization. 

## Performer Complexity analysis
Given the form of the softmax attention with sampling derived above (FAVOR+), state the complexity of the operation.

Assume that the number of rows in $K$, $Q$ and $V$ all equal the sequence length $L$. Denote the dimension of the positive random features as $m$, and the token dimension as $d$.

Recall that vanilla Dot-Product softmax attention has a complexity of $$\mathcal{O}(L^2d)$$

**Your Answer**
<font color='blue'>
$$\mathcal{O}(...)$$
</font>

If you want to double check your complexity, you might try varying the input sequence length (for fixed token dimension and sampling dimension) and seeing how the execution time varies as a function of $L$. Better still, compare this to your original implementation from Part 1. 

In [None]:
# Some pseudo-code-style skeleton code 
import time

random_dim_size = 100
test_iterations = 100 # Multiple function calls to get decent statistics

attn_funcs = {'slow': ...,
              'fast': approx_attention}
timings = {'slow': [],
           'fast': []}

seq_lengths = np.arange(...)

for func_name, func in attn_funcs.items():

    for seq_length in seq_lengths:
        q, k, v = random... # initialise randomly, with correct sequence length
        inputs = (q,k,v) if func_name == 'slow' else (q,k,v,random_dim_size)

        start_time = time... # get current time
        
        for iteration in range(test_iterations): 
            _ = func(*inputs)
        
        time_taken = (time... - start_time) / test_iterations

        timings[func_name].append(time_taken)
        

In [None]:
import matplotlib.pyplot as plt

# Use a separate loop variable so the `timings` dict is not overwritten
for func_name, func_timings in timings.items():
    plt.plot(seq_lengths, func_timings, label=func_name)
plt.legend()
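Reading a scaling law off a plot by eye can be tricky. One option, sketched below, is to fit a straight line to the timings in log-log space, where the slope estimates the exponent of $L$. The timings here are synthetic and generated from a made-up quadratic cost, standing in for whatever you measured; `scaling_exponent` is a helper written for this example.

```python
import numpy as np

def scaling_exponent(seq_lengths, times):
    """Slope of log(time) against log(sequence length), via a least-squares line fit."""
    slope, _intercept = np.polyfit(np.log(seq_lengths), np.log(times), deg=1)
    return slope

# Synthetic timings standing in for measured ones (the constant 3e-9 is arbitrary)
lengths = np.array([128, 256, 512, 1024, 2048])
quadratic_times = 3e-9 * lengths ** 2

exponent = scaling_exponent(lengths, quadratic_times)
print(exponent)
```

Real measurements are noisy and include fixed overheads at short sequence lengths, so the fitted slope will only approach the theoretical exponent for large enough $L$.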

<a name="refs"></a>
# Further Reading


## Attention
1. Matt Kelcey's talk: [The Map Interpretation of Attention](https://www.youtube.com/watch?v=7wMQgveLiQ4&t=996s)
2. Jay Alammar's blog post: [Visualizing A Neural Machine Translation Model](https://jalammar.github.io/visualizing-neural-machine-translation-mechanics-of-seq2seq-models-with-attention/)
3. Lilian Weng's blog post: [Attention? Attention!](https://lilianweng.github.io/lil-log/2018/06/24/attention-attention.html)
4. Vincent Warmerdam's videos: 1 and 2 of the playlist, [Attention is all you Need](https://www.youtube.com/playlist?list=PLvaXXemMsV5CL2DNYTGtdmwyIr0PIz-S7)


## Slot-Attention 
* Paper by Locatello et al.: [Object-Centric Learning with Slot Attention](https://arxiv.org/abs/2006.15055)
* Thomas Kipf's ICML Talk: [Attentive Grouping and GNNs for Object-Centric Learning](https://slideslive.com/38930703/attentive-grouping-and-gnns-for-objectcentric-learning?ref=speaker-22634-latest)
* Yannic Kilcher's Video on the paper: [Object-Centric Learning with Slot Attention](https://www.youtube.com/watch?v=DYBmD88vpiA)

## Transformers
1. The original paper by Vaswani et al. : [Attention is All you Need](https://arxiv.org/abs/1706.03762)
2. Jay Alammar's blog post: [The Illustrated Transformer](https://jalammar.github.io/illustrated-transformer/)
3. Alexander Rush's blog post: [The Annotated Transformer](http://nlp.seas.harvard.edu/2018/04/03/attention)
4. Amirhossein Kazemnejad's blog post: [Transformer Architecture: The Positional Encoding](https://kazemnejad.com/blog/transformer_architecture_positional_encoding/)
5. Vincent Warmerdam's videos: 3 and 4 of the playlist, [Attention is all you Need](https://www.youtube.com/playlist?list=PLvaXXemMsV5CL2DNYTGtdmwyIr0PIz-S7)



## Performers
1. The Google Performer: [Paper](https://arxiv.org/abs/2009.14794) and [Blog Post](https://ai.googleblog.com/2020/10/rethinking-attention-with-performers.html)
2. Teddy Koker's blog post: [Performers: The Kernel Trick, Random Fourier Features, and Attention](https://teddykoker.com/2020/11/performers/)
3. Yannic Kilcher's Video on the paper: [Rethinking Attention with Performers (Paper Explained)](https://www.youtube.com/watch?v=xJrKIPwVwGM&feature=youtu.be)