Transformers are a breakthrough in AI, especially in natural language processing (NLP). Renowned for their performance and scalability, they are vital in applications like language translation and conversational AI. This article explores their structure, comparisons with other neural networks, and their pros and cons.
Table of contents
- What is a transformer model?
- Transformers vs. CNNs and RNNs
- How do transformer models work?
- Examples of transformer models
- Advantages of transformer models
- Disadvantages of transformer models
What is a transformer model?
A transformer is a type of deep learning model that is widely used in NLP. Due to its task performance and scalability, it is the core of models like the GPT series (made by OpenAI), Claude (made by Anthropic), and Gemini (made by Google) and is extensively used throughout the industry.
Deep learning models consist of three main components: model architecture, training data, and training methods. Within this framework, a transformer represents one kind of model architecture. It defines the structure of the neural networks and their interactions. The key innovation that sets transformers apart from other machine learning (ML) models is the use of “attention.”
Attention is a mechanism in transformers that enables them to process inputs efficiently and maintain information over long sequences (e.g., an entire essay).
Here’s an example to illustrate. “The cat sat on the bank by the river. It then moved to the branch of the nearby tree.” You can recognize that “bank” here is not the bank at which you deposit money. You’d probably use the context clue of “river” to figure that out. Attention works similarly; it uses the other words to define what each word means. What does “it” refer to in the example? The model would look at the words “moved” and “tree” as clues to realize the answer is “cat.”
The important unanswered question is how the model knows which words to look at. We’ll get to that a bit later. But now that we’ve defined the transformer model, let’s explain further why it’s used so heavily.
Transformers vs. CNNs and RNNs
Recurrent neural networks (RNNs) and convolutional neural networks (CNNs) are two other common deep learning models. While RNNs and CNNs have their benefits, transformers are more widely used because they handle long inputs much better.
Transformers vs. RNNs
RNNs are sequential models. An apt analogy is a human reading a book. As they read, word by word, their memory and understanding of the book evolve. Astute readers might even predict what will happen next based on what came before. An RNN functions in the same way. It reads word by word, updates its memory (called a hidden state), and can then make a prediction (e.g., the next word in the sentence or the sentiment of some text). The downside is that the hidden state can’t hold very much information. If you fed a whole book into an RNN, it would not remember many details about the intro chapters because there’s only so much space in its hidden state. Later chapters, by virtue of being added into the hidden state more recently, get precedence.
Transformers don’t suffer the same memory problem. They compare every word with every other word in the input (as part of the attention mechanism) so they don’t need to use a hidden state or “remember” what happened earlier. Using the same book analogy, a transformer is like a human reading the next word in a book and then looking at every prior word in the book to understand the new word properly. If the first sentence of a book contained the phrase “He was born in France,” and the last sentence of a book contained the phrase “his native language,” the transformer would be able to deduce his native language is French. An RNN may not be able to do that, since the hidden state is not guaranteed to keep that information. Additionally, an RNN needs to read each word one at a time and then update its hidden state. A transformer can apply its attention in parallel.
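To make the memory contrast concrete, here is a minimal sketch in Python. It assumes a toy linear recurrent update (the `decay` value and the update rule are illustrative, not how any production RNN is defined) and a set of hypothetical relevance scores for the attention side. It compares how much the first word of a 100-word input can influence the final result in each case.

```python
import math

def rnn_influence(seq_len, decay=0.9):
    """Weight of each input in the final hidden state of a toy linear RNN.

    Assumed update rule (illustrative only):
        h_t = decay * h_(t-1) + (1 - decay) * x_t
    Unrolling it gives input i the weight (1 - decay) * decay ** (seq_len - 1 - i).
    """
    return [(1 - decay) * decay ** (seq_len - 1 - i) for i in range(seq_len)]

def attention_influence(scores):
    """Softmax over relevance scores; the position of a token plays no role."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

rnn_weights = rnn_influence(seq_len=100)
first_rnn, last_rnn = rnn_weights[0], rnn_weights[-1]

# Hypothetical scores: the very first token is the most relevant one.
scores = [5.0] + [0.0] * 99
first_attn = attention_influence(scores)[0]

print(f"RNN weight of first token:       {first_rnn:.2e}")
print(f"RNN weight of last token:        {last_rnn:.2e}")
print(f"Attention weight of first token: {first_attn:.2f}")
```

In the toy RNN, the first token's weight shrinks geometrically with distance, while the attention weight depends only on the score, so a relevant word at the start of the book can still dominate.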
Transformers vs. CNNs
CNNs use the surrounding context of each item in a sequence to assign meaning. For a word on a page, a CNN would look at the words immediately surrounding it to figure out the meaning of the word. It would not be able to connect the last and first pages of a book. CNNs are predominantly used with images because pixels often relate to their neighbors much more than words do. That said, CNNs can be used for NLP as well.
Transformers differ from CNNs in that they look at more than just the immediate neighbors of an item. They use an attention mechanism to compare each word with every other word in the input, providing a broader and more comprehensive understanding of the context.
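The difference in context can be sketched as the set of positions each approach can see from a given word. This is a simplified illustration: `radius` is a hypothetical window size for a single convolutional layer, and stacking several layers in a real CNN widens the window somewhat.

```python
def local_context(position, seq_len, radius=2):
    """Positions visible to one convolution window centered on `position`."""
    start = max(0, position - radius)
    stop = min(seq_len, position + radius + 1)
    return list(range(start, stop))

def global_context(position, seq_len):
    """Positions visible to attention: every token, whatever `position` is."""
    return list(range(seq_len))

seq_len = 10
print(local_context(0, seq_len))   # [0, 1, 2]
print(local_context(5, seq_len))   # [3, 4, 5, 6, 7]
print(global_context(0, seq_len))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```

From position 0, the local window never reaches position 9, while attention sees it directly.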
How do transformer models work?
Transformers have layers of attention blocks, feedforward neural networks (FNNs), and embeddings. The model takes in a text-based input and returns output text. To do this, it follows these steps:
1. Tokenization: Turns the text into tokens (similar to breaking down a sentence into individual words).
2. Embedding: Converts the tokens into vectors, incorporating positional embeddings so the model understands each token’s location in the input.
3. Attention mechanism: Processes the tokens using self-attention (for input tokens) or cross-attention (between input tokens and generated tokens). This mechanism allows the model to weigh the importance of different tokens when generating output.
4. FNNs: Passes the result through an FNN, which allows the model to capture complex patterns by introducing nonlinearity.
5. Repetition: Steps 3–4 are repeated multiple times through several layers to refine the output.
6. Output distribution: Produces a probability distribution over all possible tokens.
7. Token selection: Chooses the next token, typically the one with the highest probability or one sampled from the distribution.
This process makes up one forward pass through the transformer model. The model does this repeatedly until it has completed its output text. Within each pass, the embedding process can be performed in parallel, as can the attention mechanism and the feedforward stage. Essentially, the transformer doesn’t need to do each token one at a time. It can run attention across all tokens at the same time.
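The steps above can be sketched as a single forward pass in plain Python. This is a toy under stated assumptions, not a real model: the vocabulary, the sizes (`DIM`, `MAX_LEN`), and all weight matrices are hypothetical and randomly initialized, so the predicted token is meaningless. It uses one attention head and one layer, picks the most probable token, and leaves out details real transformers include, such as multiple heads, residual connections, and normalization.

```python
import math
import random

random.seed(0)  # fixed seed so the untrained toy weights are reproducible

VOCAB = ["the", "cat", "sat", "on", "mat"]  # hypothetical toy vocabulary
DIM = 4                                     # embedding width (arbitrary)
MAX_LEN = 16                                # maximum number of positions

def rand_matrix(rows, cols):
    """Random weights standing in for parameters a real model would learn."""
    return [[random.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]

def matvec(matrix, vec):
    """Multiply a (rows x cols) matrix by a vector of length cols."""
    return [sum(w * v for w, v in zip(row, vec)) for row in matrix]

def softmax(values):
    """Turn arbitrary scores into probabilities that sum to 1."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

TOKEN_EMB = rand_matrix(len(VOCAB), DIM)   # one vector per vocabulary entry
POS_EMB = rand_matrix(MAX_LEN, DIM)        # one vector per position
W_Q, W_K, W_V = (rand_matrix(DIM, DIM) for _ in range(3))
W_FF1 = rand_matrix(2 * DIM, DIM)
W_FF2 = rand_matrix(DIM, 2 * DIM)
W_OUT = rand_matrix(len(VOCAB), DIM)

def forward(text):
    # Step 1, tokenization: split on whitespace and look up vocabulary ids.
    ids = [VOCAB.index(word) for word in text.split()]
    # Step 2, embedding: token vector plus positional vector.
    xs = [[t + p for t, p in zip(TOKEN_EMB[tok], POS_EMB[pos])]
          for pos, tok in enumerate(ids)]
    # Step 3, self-attention: every token scores every token, itself included.
    # Each query is independent of the others, which is what allows
    # real implementations to compute them in parallel.
    queries = [matvec(W_Q, x) for x in xs]
    keys = [matvec(W_K, x) for x in xs]
    values = [matvec(W_V, x) for x in xs]
    attended = []
    for q in queries:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(DIM) for k in keys]
        weights = softmax(scores)
        attended.append([sum(w * v[d] for w, v in zip(weights, values))
                         for d in range(DIM)])
    # Step 4, feedforward network with a ReLU nonlinearity.
    outputs = [matvec(W_FF2, [max(0.0, z) for z in matvec(W_FF1, h)])
               for h in attended]
    # Step 5, repetition: a real model stacks many such layers; this uses one.
    # Step 6, output distribution over the vocabulary, read from the last token.
    probs = softmax(matvec(W_OUT, outputs[-1]))
    # Step 7, token selection: pick the most probable token.
    best = max(range(len(VOCAB)), key=lambda i: probs[i])
    return probs, VOCAB[best]

probs, next_token = forward("the cat sat")
print(next_token, [round(p, 3) for p in probs])
```

To generate a full response, a model would append `next_token` to the input and run `forward` again, repeating until the output is complete.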
We can now turn to the question from earlier: How does the model know which tokens to attend to? The answer is simply by looking at lots of training data. At first, the model will attend to the wrong tokens and so will generate the wrong outputs. Using the correct output that comes with the training data, the model’s weights, including those of the attention mechanism, are adjusted so that it is more likely to output the correct answer next time. Over billions (and even trillions) of examples, the attention mechanism learns to pick the proper tokens almost all the time.
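The idea of learning from the correct output can be sketched with a deliberately simplified training loop. As an assumption for illustration, it adjusts the model's output scores (logits) directly with gradient descent on a cross-entropy loss; a real transformer instead propagates the same error signal back through all of its weights, including the attention weights. The starting scores, `target`, and `learning_rate` are hypothetical.

```python
import math

def softmax(values):
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

def training_step(logits, target, learning_rate=1.0):
    """One gradient-descent step on the cross-entropy loss.

    For softmax outputs, the gradient of the loss with respect to each logit
    is (predicted probability - 1) for the correct token and
    (predicted probability) for every other token.
    """
    probs = softmax(logits)
    grads = [p - (1.0 if i == target else 0.0) for i, p in enumerate(probs)]
    return [z - learning_rate * g for z, g in zip(logits, grads)]

logits = [2.0, 0.5, 0.1]  # the untrained model prefers token 0
target = 2                # the training data says token 2 is correct

before = softmax(logits)[target]
for _ in range(20):
    logits = training_step(logits, target)
after = softmax(logits)[target]

print(f"probability of the correct token: {before:.2f} -> {after:.2f}")
```

Each step nudges the scores toward the correct answer, so the probability assigned to the right token rises with repeated exposure.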
Examples of transformer models
Transformers are everywhere. Although first designed for translation, transformers have scaled well into almost all language, vision, and even audio tasks.
Large language models
The transformer architecture powers almost all large language models (LLMs): GPT, Claude, Gemini, Llama, and many smaller open-source models. LLMs can handle various text (and, increasingly, image and audio) tasks, such as question-answering, classification, and free-form generation.
This is achieved by training the transformer model on billions of text examples (usually scraped from the internet). Then, companies fine-tune the model on examples of a specific task to teach it that task; for instance, fine-tuning on classification examples teaches the model how to perform classification correctly. In short, the model learns a broad knowledge base and is then “taught” skills via fine-tuning.
Vision transformers
Vision transformers are standard transformers adapted to work on images. The main difference is that the tokenization process has to work with images instead of text. Once the input is turned into tokens, the normal transformer computation occurs, and finally, the output tokens are used to classify the image (e.g., an image of a cat). Vision transformers are often merged with text LLMs to form multimodal LLMs. These multimodal models can take in an image and reason over it, for example, accepting a user interface sketch and returning the code needed to create it.
CNNs are also popular for image tasks, but transformers allow the model to use all the pixels in the image instead of just nearby pixels. As an example, if an image contained a stop sign on the far left side and a car on the far right side, the model could determine that the car needs to stop. A CNN may not be able to connect those two data points because they are far from each other in the image.
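The image tokenization step described above is commonly done by cutting the image into fixed-size patches and treating each patch as one token. Here is a minimal sketch of that idea, assuming a tiny grayscale image stored as a list of rows; the function name and the 2x2 patch size are hypothetical choices for illustration.

```python
def image_to_patches(image, patch_size):
    """Split a 2D grid of pixel values into flattened square patches.

    Patches are returned left to right, top to bottom, and each one plays
    the role that a word token plays in a text transformer.
    """
    height, width = len(image), len(image[0])
    if height % patch_size or width % patch_size:
        raise ValueError("image dimensions must be divisible by patch_size")
    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            patch = [image[row][col]
                     for row in range(top, top + patch_size)
                     for col in range(left, left + patch_size)]
            patches.append(patch)
    return patches

# A 4x4 "image" whose pixel values are simply 0..15.
image = [[row * 4 + col for col in range(4)] for row in range(4)]
patches = image_to_patches(image, patch_size=2)

print(len(patches))  # 4
print(patches[0])    # [0, 1, 4, 5]
```

Each flattened patch would then be converted into an embedding vector and processed by attention, which can relate a patch on the far left to one on the far right.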
Audio transformers
Audio transformers, like vision transformers, are standard transformers with a unique tokenization scheme tailored for audio data. These models can process both text and raw audio as input, outputting either text or audio. An example of this is Whisper, a speech-to-text model that converts raw audio into a transcript. It accomplishes this by segmenting the audio into chunks, transforming these chunks into spectrograms, and encoding the spectrograms into embeddings. These embeddings are then processed by the transformer, which generates the final transcript tokens.
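The chunk-then-spectrogram idea can be sketched in a few lines. This is an illustration of the general technique only, not Whisper's actual preprocessing: the frame size is tiny, the signal is a synthetic sine wave, and the frequency analysis is a naive discrete Fourier transform rather than the optimized routines used in practice.

```python
import cmath
import math

def split_into_frames(samples, frame_size):
    """Cut the signal into consecutive fixed-size frames, zero-padding the last."""
    frames = []
    for start in range(0, len(samples), frame_size):
        frame = samples[start:start + frame_size]
        frame = frame + [0.0] * (frame_size - len(frame))
        frames.append(frame)
    return frames

def magnitude_spectrum(frame):
    """Naive discrete Fourier transform; magnitudes for frequency bins 0..N/2."""
    n = len(frame)
    return [abs(sum(x * cmath.exp(-2j * math.pi * k * t / n)
                    for t, x in enumerate(frame)))
            for k in range(n // 2 + 1)]

# Synthetic audio: a sine wave completing 2 cycles per 8-sample frame.
samples = [math.sin(2 * math.pi * 2 * t / 8) for t in range(16)]

# One spectrum per frame; stacked together they form a small spectrogram.
spectrogram = [magnitude_spectrum(f) for f in split_into_frames(samples, 8)]
loudest_bin = spectrogram[0].index(max(spectrogram[0]))
print(len(spectrogram), loudest_bin)
```

Each row of the spectrogram would then be turned into an embedding, just as text tokens or image patches are.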
Beyond speech-to-text applications, audio transformers have various other use cases, including music generation, automatic captioning, and voice conversion. Additionally, companies are integrating audio transformers with LLMs to enable voice-based interactions, allowing users to ask questions and receive responses through voice commands.
Advantages of transformer models
Transformers have become ubiquitous in the field of machine learning due to their scalability and exceptional performance across a wide array of tasks. Their success is attributed to several key factors:
Long context
The attention mechanism can compare all tokens in the input sequence with each other. So, information throughout the entire input will be remembered and used to generate the output. In contrast, RNNs forget older information, and CNNs can only use information that is close to each token. This is why you can upload hundreds of pages to an LLM chatbot, ask it a question about any of the pages, and get an accurate response. The lack of long context in RNNs and CNNs is the biggest reason why transformers outperform them on language tasks.
Parallelizability
The attention mechanism in transformers can be executed in parallel across all tokens in the input sequence. This contrasts with RNNs, which process tokens sequentially. As a result, transformers can be trained and deployed more quickly, providing faster responses to users. This parallel processing capability significantly enhances the efficiency of transformers compared to RNNs.
Scalability
Researchers have continually increased the size of transformers and the amount of data used to train them, and they have not yet seen a limit to how much transformers can learn. The larger the transformer model, the more complex and nuanced the text it can understand and generate (GPT-3 has 175 billion parameters, while GPT-4 is reported, though not officially confirmed, to have more than 1 trillion). Remarkably, because the computation parallelizes well, scaling up a transformer, such as going from a 1-billion-parameter model to a 10-billion-parameter one, does not have to take significantly more time, provided more hardware is available; the total compute required still grows with model size. This scalability makes transformers powerful tools for various advanced applications.
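To see where parameter counts like these come from, a common rule of thumb estimates a transformer's size as roughly 12 × (number of layers) × (model width)², counting the attention and feedforward weights in each layer and ignoring embeddings. The sketch below applies it to the layer count (96) and width (12,288) published for the largest GPT-3 model. Treat the formula as an approximation, not an exact accounting.

```python
def approx_parameter_count(num_layers, model_width):
    """Rule-of-thumb size of a transformer, ignoring embeddings.

    Per layer: about 4 * width^2 weights for attention (query, key, value,
    and output projections) plus about 8 * width^2 for a feedforward
    network whose hidden layer is 4x the model width.
    """
    return 12 * num_layers * model_width ** 2

gpt3_estimate = approx_parameter_count(num_layers=96, model_width=12_288)
print(f"{gpt3_estimate / 1e9:.0f} billion parameters")
```

The estimate lands close to the 175 billion figure quoted above, and it shows that widening a model grows its size quadratically while adding layers grows it linearly.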
Disadvantages of transformer models
The downside of transformer models is that they require a lot of computational resources. The attention mechanism is quadratic: every token in the input is compared to every other token. Two tokens would have 4 comparisons, three tokens would have 9, four tokens would have 16, and so on—essentially, the computational cost is the square of the token count. This quadratic cost has a few implications:
Specialized hardware
LLMs can’t easily be run on an average computer. Due to their size, they often require dozens of gigabytes of RAM just to load the model parameters. Also, traditional CPUs are not optimized for the massively parallel computation involved, so GPUs are typically used instead. A large LLM running on a CPU could take minutes to generate a single token. Unfortunately, GPUs are not exactly the cheapest or most accessible hardware.
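A back-of-the-envelope calculation shows why memory becomes a constraint. The sketch below multiplies the parameter count by the bytes used to store each parameter. The 16-bit (2-byte) and 4-bit (0.5-byte) storage formats and the 7-billion-parameter model are assumptions chosen for illustration, and the result covers only the weights, not the extra memory needed while the model runs.

```python
def weight_memory_gb(parameter_count, bytes_per_parameter):
    """Memory needed just to hold the weights, in gigabytes (10**9 bytes)."""
    return parameter_count * bytes_per_parameter / 10**9

# GPT-3-sized model (175 billion parameters) stored as 16-bit numbers.
gpt3_16bit = weight_memory_gb(175 * 10**9, 2)

# Hypothetical smaller 7-billion-parameter model at two precisions.
small_16bit = weight_memory_gb(7 * 10**9, 2)
small_4bit = weight_memory_gb(7 * 10**9, 0.5)

print(gpt3_16bit, small_16bit, small_4bit)  # 350.0 14.0 3.5
```

Storing weights at lower precision is one way smaller models are made to fit on consumer hardware, at some cost in accuracy.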
Limited input length
Transformers can process only a limited amount of text at once (known as their context length). GPT-3 originally could only process 2,048 tokens. Advancements in attention implementations have yielded models with context lengths of up to 1 million tokens. Even so, extending context length further takes substantial research. In contrast, RNNs do not have a maximum context length. Their accuracy drops sharply as the input grows longer, but you could feed a 2-million-token-long input into one right now.
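The quadratic cost explains why longer context is hard to come by. This short sketch counts token-to-token comparisons for the context lengths mentioned above, using the simple "square of the token count" model from this article. Optimized attention implementations reduce the practical cost, so treat the numbers as an upper-level intuition, not a measurement.

```python
def attention_comparisons(token_count):
    """Every token is compared with every token, so the cost is n squared."""
    return token_count ** 2

short_context = attention_comparisons(2_048)
long_context = attention_comparisons(1_000_000)

print(attention_comparisons(2))  # 4
print(attention_comparisons(4))  # 16
print(short_context)             # 4194304
print(long_context)              # 1000000000000

# Input length grows about 488x, but comparisons grow far more.
print(1_000_000 / 2_048, long_context / short_context)
```

Going from 2,048 tokens to 1 million multiplies the input by under 500 but multiplies the comparisons by more than 200,000.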
Energy cost
The data centers powering transformer computation require energy to run them and water to cool them. By one estimate, GPT-3 required 1,300 megawatt-hours of electricity to train: the equivalent of powering 130 homes in the US for a whole year. As models get bigger, the amount of energy needed increases. By 2027, the AI industry may require as much electricity every year as the Netherlands. Significant efforts are being made to reduce the energy transformers need, but this problem has not yet been solved.
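The household comparison above can be checked with simple arithmetic. The sketch uses only the two figures quoted in this section and derives the per-home consumption they imply.

```python
TRAINING_ENERGY_MWH = 1_300  # estimated energy to train GPT-3, from the text
HOMES_POWERED = 130          # number of US homes powered for a year, from the text

per_home_mwh = TRAINING_ENERGY_MWH / HOMES_POWERED
per_home_kwh = per_home_mwh * 1_000  # 1 MWh = 1,000 kWh

print(per_home_mwh, per_home_kwh)  # 10.0 10000.0
```

The comparison therefore assumes that a US home uses about 10 megawatt-hours (10,000 kilowatt-hours) of electricity per year.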