
I find all the writeups of self-attention very confusing. They seem to inevitably explain small bits in enormous detail and then leave all kinds of giant areas unexplained, and they are usually replete with jargon that they don't bother to clarify. I feel like most of these are written by people who are confused themselves, trying to clarify the parts they don't understand as if that's what everybody else must be confused about too, so they inevitably gloss over everything they already know.

Fundamentally, I have never understood why a conventional neural network cannot learn self-attention. If layers are fully connected, then they definitely have the ability to create a weighting of the inputs, accommodating learning of relative spatial features in the data. In fact, that's almost a definition of what a neural network is. If relative positional learning is important, it could be added the same way it is in Transformers, without the explicit self-attention layer. So what is self-attention really doing beyond this? Why do we need a Query-Key-Value construct here? I am sure I am missing something very basic and fundamental here.



If by "conventional neural network" you mean a stack off fully connected layers, then yes, in theory one of those could learn a similar mechanism because of the universal approximation theorem. However, training one might be intractable.

It's good to ignore self-attention for a moment and take a look at a convolutional network (a CNN). Why is a CNN more effective than just stacks of fully connected layers? Well, instead of just throwing data at the network and telling it to figure out what to do with it, we've built some prior knowledge into the network. We tell it, "you know, a cup is going to be a cup even if it is 10 pixels up or 10 pixels down; even if it is in the upper right of the image or the lower left." We also tell it, "you know, the pixels near a given pixel are going to be pretty correlated with that pixel, much more so than pixels far away." Convolutions help us express that kind of knowledge in the form of a neural network.
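
As a toy illustration of that prior (a contrived 1D example of my own, not anything from the thread): a convolution's response to a shifted input is just the shifted response, so "a cup is a cup wherever it appears" is baked into the layer itself.

    import numpy as np

    # Toy 1D "image" with a single bump, and a 3-tap convolution kernel.
    signal = np.array([0.0, 0.0, 1.0, 2.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0])
    kernel = np.array([1.0, -2.0, 1.0])

    out = np.convolve(signal, kernel, mode="valid")
    out_shifted = np.convolve(np.roll(signal, 2), kernel, mode="valid")

    # The response to the shifted input equals the shifted response:
    # translation equivariance, built into the architecture for free.
    print(np.roll(out, 2))
    print(out_shifted)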

Self-attention plays a similar role. We are imbuing our network with an architecture that is aware of the data it is about to receive. We tell it "hey, elements in this sequence have a relation with one another, and that relative location might be important for the answer". Similar to convolutions, we also tell it that the location of various tokens in a sequence is going to vary: there shouldn't be much difference between "Today the dog went to the park" and "The dog went to the park today." Like convolutions, self-attention builds in certain assumptions we have about the data we are going to train the network on.

So yes, you are right that fully-connected layers can emulate similar behavior, but training them to do that isn't easy. With self-attention, we've started with more prior knowledge about the problem at hand, so it is easier to solve.


Great answer. Imbuing a deep learning model with well thought out inductive biases is one of the strongest ways of guiding your model to interpret the data the way you want it to. Otherwise it’s kind of shooting in the dark and hoping to get lucky.

I can’t stand it when people lazily personify ML models, but it’s akin to giving someone with no experience some wood and then pointing to a shed and saying “make one of those from this”. Instead you’d expect them to be much more successful if you also give them a saw, a drill, some screws etc.


Good explanation. Which is why the success of transformers, LLMs etc. is still not the final word in Rich Sutton's "The Bitter Lesson" -- no learning method is free of inductive biases.


Inductive biases can work even if they're wrong, because they allow for simple and quick action, simpler reasoning. They don't need to be correct for that to pay off, they just need a positive expected value.


Are there actually theorems from which you take those explanations, or are these just plausible-sounding hypothetical explanations?


You can verify the reduction of the problem space. Think of it this way: suppose the data has some property, for example, it's mirrored about an axis, so if X is a data point, then so is -X.

Well, a model that is aware of this symmetry only has half as much data to look at and one less thing to learn.

But that's only half of it. The truth is, assuming symmetries works pretty well even if the assumption is wrong. Why? Generalization. A model with less data will generalize more (whether it generalizes better is perhaps debatable, but it will definitely generalize more).
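
A contrived sketch of that idea (my own toy example): if you know the target is symmetric under x -> -x, you can fold the input so the symmetry holds by construction and the model only ever has to cover half the input space.

    import numpy as np

    rng = np.random.default_rng(0)

    # Contrived even target: f(-x) = f(x).
    x = rng.uniform(-1.0, 1.0, size=200)
    y = np.cos(3.0 * x)

    # Symmetry-aware model: fold the input onto x >= 0, so the fit only
    # has to cover half the input space and the symmetry is exact by design.
    symmetric_fit = np.polyfit(np.abs(x), y, deg=4)

    # Symmetry-unaware model: has to rediscover the symmetry from data alone.
    naive_fit = np.polyfit(x, y, deg=4)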

This is the basic idea behind "geometric deep learning". There's loads of papers, but here's a presentation.

https://www.youtube.com/watch?v=w6Pw4MOzMuo


Great answer!


The attention mechanism started as a simple trick to not use recurrent neural networks.

Read the intro in the original paper "Attention is all you need" (https://arxiv.org/abs/1706.03762)

This video explains the drawbacks of RNNs and how transformers address them: https://youtu.be/S27pHKBEp30?t=394

Andrej Karpathy explains attention here: https://youtu.be/kCc8FmEb1nY?t=3719

He explains how attention is seen as a communication network: https://youtu.be/kCc8FmEb1nY?t=4298


> Read the intro in the original paper "Attention is all you need"

I wouldn't call this the original "attention" paper. It's definitely not the first paper to use the phrase. If you want clear proof of this, let's read the paper:

> Attention mechanisms have become an integral part of compelling sequence modeling and transduction models in various tasks, allowing modeling of dependencies without regard to their distance in the input or output sequences.

I do think a lot of people's lack of understanding of attention comes from being so focused on dot-product (self-)attention that they miss a lot of the broader picture. And the math. Not enough people dig into the math.


I think your misunderstanding is that fully connected layers can operate in the same way that attention does — they can’t. A fully connected layer operates on one dimension at a time. Typically language models have two (plus a batch dimension). One dimension is your token/word dimension. The second is the hidden dimension.

The hidden dimension is constructed inside the model to create a space where it can embed each token into a vector and then enrich that space with contextual information derived from the sequence. In order for that to occur, the model must have a means of transferring information along the token dimension.

One way to accomplish this is to use a 2d convolution; however, the scope of a convolution is limited to the size of its kernel. A fully connected layer is the same as a 2d convolution with a kernel size of 1. So you can see that no information from neighboring tokens can be applied to the hidden space.

The standard self-attention equation has a global scope because of the full matrix multiplication of the (projected) input tensor with its transpose. Each element of the resulting N x N matrix captures the interaction between a pair of tokens, so every token gets scored against every other token in the sequence. Next, a softmax operation is applied, which acts as a gating or relevance function. Finally, this is multiplied back with the original input to build that information into the hidden dimension of each token.

There have been attempts to do similar operations using fully connected layers. Look at the architecture of SGUs (spatial gating units). In some applications they have good performance, but because fully connected layers operate on each dimension independently and serially, they are not equivalent to attention.

Last, my best recommendation for anybody trying to understand attention is to stop reading articles and instead spend your time looking at the math. It's usually much less confusing than any of the dozens of explanations floating around the web, including the one I just gave. The math is not too complicated, especially once you know the reasons why we need it.
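
In that spirit, here's a minimal numpy sketch of the math above (toy shapes of my own choosing, a single head, no masking): a per-token fully connected layer never mixes tokens, while scaled dot-product self-attention mixes all of them.

    import numpy as np

    N, D = 6, 16                       # sequence length, hidden dimension
    rng = np.random.default_rng(0)
    X = rng.standard_normal((N, D))    # one sequence of N token embeddings

    # A fully connected layer applied per token (same as a kernel-size-1 conv):
    # row i of the output depends only on row i of the input, no token mixing.
    W = rng.standard_normal((D, D))
    per_token = X @ W                  # (N, D)

    # Self-attention: every output row is a weighted mix of *all* value rows.
    Wq, Wk, Wv = (rng.standard_normal((D, D)) for _ in range(3))
    Q, K, V = X @ Wq, X @ Wk, X @ Wv   # (N, D) each

    scores = Q @ K.T / np.sqrt(D)      # (N, N): every token scored against every other
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)   # row-wise softmax (the gate)
    out = weights @ V                  # (N, D): contextualized token representations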


There's a bit more to it, but you can partly view (self-)attention as a trick that makes optimization easier; that is, it improves gradient flow, similar to skip connections, to make learning easier. That was easier to see when we used attention with RNNs, where attention can be viewed as being equivalent to "dynamic skip connections".

While fully connected layers can in theory learn anything, that's a very hard optimization problem in terms of gradient flow. Attention adds inductive biases (prior/domain knowledge) about what you want the network to learn, which makes the optimization of that specific aspect easier for the optimization algorithm.

In general, you can view almost anything in ML/DL as improving either optimization or generalization, and while it's a spectrum, attention falls more into the optimization category.


> Fundamentally, I have never understood why a conventional neural network cannot learn self-attention.

Here's how I think about it: Yes, you can learn everything with an MLP, since universal approximation and so on, but it is not efficient. And (modern) ML is all about scale.

Here's one way to see the difference: self-attention takes as input N tokens of dimension D. It maps them to queries, keys, and values with DxD matrices: time = O(ND^2). Next it computes all pairwise query-key inner products: time = O(N^2D). Finally, the softmax takes time = O(N^2) and computing "probabilities times values" takes O(DN^2).

All in all self-attention takes O(DN^2+ND^2) time to map ND values to ND values. How long would an MLP take to do the same thing? O((ND)^2).

So the saving is roughly a factor of ND/(N+D). In a typical case of D ~ 1000 and N ~ 1000, that's about 500: for the price of one MLP layer, you could afford hundreds of self-attention layers. It's a pretty major difference.
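
To make that concrete, a back-of-the-envelope operation count (my own toy numbers, constants and the cheap O(N^2) softmax dropped):

    # Rough operation counts for one layer over N tokens of dimension D.
    N, D = 1000, 1000

    attention = N * D * D + N * N * D   # projections O(N D^2) + pairwise scores/mixing O(N^2 D)
    mlp = (N * D) * (N * D)             # fully connected layer over all N*D activations

    print(mlp / attention)              # ~ N*D / (N + D) = 500.0 here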


I think there's some room for the second factor too: generalization. If you have two models of the same phenomenon and they're equally good, the model with fewer parameters is the better one. Not because it calculates faster, but because it will actually be more accurate.


A conventional neural network, i.e. one using a stack of dense layers, can't unroll across a sequence the way the transformer does. So while it could compute the relative importance and interaction of the features it sees, it wouldn't be able to compute that across arbitrary-length sequences without a mechanism for the sequence elements to interact, which is what self-attention provides.


Practical attention implementations don't work over arbitrary-length sequences either. The universal approximation theorem holds, IMO: information will mix as you go through fully connected MLP layers. Attention is apparently a prior structure that is needed to really reduce training costs.


> I feel like most of these are written by people who are confused themselves

I teach the course at my uni and I'm highly confident this is true, even in the research community. Part of this is that people are hyper-concentrated on dot-product attention (<softmax(<q,k>),v>) (DPA). There are a lot more forms of attention than this. It does help to go back to early attention mechanisms like those discussed by Bahdanau (Bengio's student) and Graves (DeepMind). When you look at these you'll find a clearer definition: a learned weighting function (Bahdanau specifies it as a probability), conditioned on some input, applied to a learned embedding conditioned on another input. If the two inputs are the same then it's self-attention, otherwise cross-attention.

You'll see some people refer to the learned weighting as a score (not to be confused with the score used in diffusion training, which is the gradient of the log-density). Understanding this, you'll see that the definition is broader than DPA, and also what makes attention powerful. But lots of people don't catch this because they don't have the history of RNNs. DPA has become the de facto choice because the dot product between the two embeddings creates a more powerful score function without introducing lots of parameters (there are people exploring more complex structures).
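
To make the broader definition concrete, here's a small sketch (toy shapes of my own choosing) of two score functions that both fit it: a Bahdanau-style additive score and the dot-product score.

    import numpy as np

    rng = np.random.default_rng(0)
    N, D, H = 6, 16, 32                 # sequence length, embed dim, score hidden dim
    query = rng.standard_normal(D)      # the conditioning input
    keys = rng.standard_normal((N, D))  # embeddings to be weighted

    def softmax(z):
        e = np.exp(z - z.max())
        return e / e.sum()

    # Dot-product score: no extra parameters beyond the embeddings themselves.
    dot_scores = keys @ query / np.sqrt(D)

    # Additive (Bahdanau-style) score: a tiny learned MLP produces each score.
    W1, W2, v = (rng.standard_normal(s) for s in [(D, H), (D, H), H])
    add_scores = np.tanh(query @ W1 + keys @ W2) @ v

    # Either way, the scores become a learned weighting over the embeddings.
    weights_dot, weights_add = softmax(dot_scores), softmax(add_scores)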

> I have never understood why a conventional neural network cannot learn self-attention.

I'll even ask another question: why can't densely connected (linear) networks learn self-attention? Both convolutions and dense layers are universal approximators, right? But we often see convolutions preferred over dense layers (note: they are equivalent when the convolution is 1D with kernel size 1, no padding, and stride 1). Well, the power is in how information is encoded and connected. CNNs in essence capture a form of positional encoding, since they operate on a structured order. Transformers need a bit more help (positional encodings, sometimes a relative positional bias, and they also tend to require augmentations), and that information helps them create very powerful graphs. One big advantage of DPA is the multiple heads, which can take different "views." Importantly, transformers scale extremely well. Essentially, attention creates a more complex connection between pieces of information. That can make transformers harder to train, but it also allows for more efficient encoding of information.
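
And a rough sketch of the "multiple views" point (again, toy shapes of my own): multi-head DPA just splits the hidden dimension into independent heads, each computing its own attention pattern.

    import numpy as np

    rng = np.random.default_rng(0)
    N, D, heads = 6, 16, 4
    d_head = D // heads
    X = rng.standard_normal((N, D))
    Wq, Wk, Wv = (rng.standard_normal((D, D)) for _ in range(3))

    # Split each projection into `heads` chunks of size d_head: (heads, N, d_head).
    def split(M):
        return M.reshape(N, heads, d_head).transpose(1, 0, 2)

    Q, K, V = split(X @ Wq), split(X @ Wk), split(X @ Wv)

    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_head)   # (heads, N, N)
    weights = np.exp(scores - scores.max(-1, keepdims=True))
    weights /= weights.sum(-1, keepdims=True)             # softmax per head
    out = weights @ V                                      # each head: its own "view"
    out = out.transpose(1, 0, 2).reshape(N, D)             # concatenate heads back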

I could talk a lot more but I'll stop here (even though this is a very incomplete explanation). There are two resources that I really like to hand to my students:

Lilian Weng's blog (everything she does is great) has good coverage of many different attention mechanisms and discusses the RNN history: https://lilianweng.github.io/posts/2018-06-24-attention/

This Medium blog isn't as in-depth, but it is good, and you end up with a model that students can actually train to something useful. It'll also hopefully answer some of your questions: https://medium.com/pytorch/training-compact-transformers-fro...

Bonus: Softmax tempering can help you understand why we scale the score https://aclanthology.org/2021.mtsummit-research.10/
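
A tiny illustration of why that scaling matters (contrived scores of my own): as the magnitude of the scores grows, the softmax saturates toward a one-hot distribution and the gradients through it vanish, while scaling the scores down pushes it toward uniform.

    import numpy as np

    def softmax(z):
        e = np.exp(z - z.max())
        return e / e.sum()

    scores = np.array([2.0, 1.0, 0.5, -1.0])

    print(softmax(scores))        # moderately peaked
    print(softmax(scores * 8.0))  # large scores: nearly one-hot, tiny gradients
    print(softmax(scores / 8.0))  # small scores: nearly uniform, weak selection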



