Building Embedding Models for Large-Scale Real-World Applications
Transcript
Sahil Dua: Let's start with a simple scenario: show me cute dogs. You go on any search engine and write, show me cute dogs. It's very likely that you will get a very nice photo like this. What happens under the hood? How is the search engine able to take that simple query and look through the billions, even trillions, of images that are available online? How is it able to find this one or similar photos from all that? Usually, there is an embedding model doing this work under the hood. Today, we'll dig deep into what an embedding model is, what it does, how it works, and where it is used. We'll cover some practical tips on how to put these models in production and the challenges we face at large scale, and we'll also look at how we can mitigate those issues and use these models reliably.
I'm co-leading the team at Google that's building the Gemini embedding models, as well as the infrastructure. Recently, I had the pleasure to work on the Gemini Embedding paper. I'm really proud of this team, because together, we have built the best embedding model that's available on all the known benchmarks. Before Google, I was working at booking.com. I was building machine learning infrastructure. This actually was the topic of my talk at QCon 2018. Besides that, I wrote a book called, "The Kubernetes Workshop."
Outline
Let's look at the topics. Let's look at what we are going to cover today. We'll start with embedding models. What are they? What is their importance? What are the use cases? We'll look at the architecture. How are these models formed? How are they able to generate these embeddings? Next, we'll look at the training techniques. How are these models trained? Then we'll see, once you have trained these larger size models, how are you going to distill them into smaller models that can actually be used in production? Next, we will see how we can evaluate these models. It might be non-trivial. Then we will look at, once you have these models, once you are happy with the quality, how do you put these models in production? How do you make sure that they are running reliably without any issues? Then in the end, we'll summarize with the key takeaways that you can take home or to the office and start working on your applications immediately.
Embedding Models and Their Applications
Let's start with embedding models. What are they, and where are they used? An embedding model is basically a model that takes any kind of input. It could be a string input. It could be an image. It will generate a digital fingerprint of that input. That's what we call a vector or an embedding. It's a list of numbers that uniquely represents the meaning of a given input. For example, we have, show me cute dogs. It will have an embedding.
Similarly, any other input, like an actual picture of dogs, will also have an embedding. The key idea for embedding models is that the embeddings of similar inputs are going to be closer to each other in the embedding space. Usually, we use cosine similarity to find the similarity or closeness between any two given vectors or given embeddings. Now, on the other side, it will also make sure that embeddings of different inputs are going to be far apart from each other. For example, if you have a query called, show me cute cats, but there is an image of cute dogs, it will generate its embeddings and will make sure that these embeddings are going to be far apart from each other using the same cosine similarity or similar similarity measure.
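To make the closeness idea concrete, here is a minimal sketch of cosine similarity between embeddings. The 4-dimensional vectors are made-up toy values purely for illustration; real embedding models produce hundreds or thousands of dimensions.

```python
import numpy as np

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine of the angle between two vectors: near 1.0 means very
    similar direction, near 0 means unrelated."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# Toy 4-dimensional embeddings (hypothetical values for illustration).
dogs_query = np.array([0.9, 0.1, 0.3, 0.0])   # "show me cute dogs"
dogs_image = np.array([0.8, 0.2, 0.4, 0.1])   # a photo of cute dogs
cats_image = np.array([0.1, 0.9, 0.0, 0.4])   # a photo of cute cats

print(cosine_similarity(dogs_query, dogs_image))  # high: similar meaning
print(cosine_similarity(dogs_query, cats_image))  # low: different meaning
```

The query's embedding lands much closer to the matching image than to the mismatched one, which is exactly the property the training process enforces.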
Let's look at some of the common applications. The most fundamental application is retrieving the best matching documents, passages, images, or videos, whatever the use case is. Just like we saw in the example in the beginning, you would write, show me cute dogs. It's able to look through billions of web pages or images and find just the right one that matches your given query. Embedding models are usually the ones doing this retrieval task. This applies to most search engines, and it doesn't need to be a full-scale search engine. It could be something like search on Facebook, for example. All of these are powered by embedding models, which are able to sift through a huge amount of data and find just the right information for your query.
The second use case that's very common is generating personalized recommendations. We are able to capture the user preferences in these embedding models and generate the outputs that are very specific to what the user wants. For example, if you have a shopping website and a user buys an iPhone, now if the user comes, the user is more likely to buy an accessory which is related to the iPhone. Using these embedding models, we are able to capture the past behavior, the history, and predict the right products that are relevant. Similarly, for example, Snapchat. I recently read the blog that they are using exactly these embedding models to power the search for what stories to show. For a given user, what is the most relevant story that we should show?
The next and one of the most popular use cases these days is RAG applications. RAG stands for Retrieval-Augmented Generation. As the name suggests, we are augmenting the generation of large language models using retrieval. What happens in a RAG application is that you use a large language model to generate the responses, but before you do so, you find the most relevant pieces of information that are useful for the model to produce the output. These retrieved passages or documents are added to the model's context. This guides the whole generation process so that the model is able to generate more factually correct and accurate results, and it helps reduce hallucinations. Last but not least, there is a use case that's more behind the scenes: training the large language models themselves.
Usually, we have huge amounts of data that's used to train them. What we can do is generate embeddings for all of the data points and find the near duplicates based on their similarities, using cosine similarity. Then we are able to remove the redundant data. This helps improve the quality as well as the efficiency of large language model training. These are the common use cases. Not an exhaustive list, but some of the main use cases that I've seen in recent times.
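The deduplication idea above can be sketched as a greedy filter over precomputed embeddings. This is an illustrative toy, not how any production pipeline does it; at real scale you would use approximate nearest-neighbor search rather than an all-pairs comparison, and the threshold here is an arbitrary example value.

```python
import numpy as np

def dedup_by_embedding(embeddings: np.ndarray, threshold: float = 0.95) -> list[int]:
    """Greedy near-duplicate removal: keep a row only if its cosine
    similarity to every already-kept row is below the threshold."""
    # L2-normalise so a plain dot product equals cosine similarity.
    normed = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    kept: list[int] = []
    for i, vec in enumerate(normed):
        if all(float(vec @ normed[j]) < threshold for j in kept):
            kept.append(i)
    return kept

# Rows 0 and 1 are near-duplicates; row 2 is distinct.
data = np.array([[1.0, 0.0, 0.0],
                 [0.99, 0.05, 0.0],
                 [0.0, 1.0, 0.0]])
print(dedup_by_embedding(data))  # rows 0 and 2 survive
```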
Architecture of Embedding Models
Now that we know what embedding models are at a high level, and we know their importance and their applications, let's look at the architecture. This is the architecture of an embedding model. Let's look at each of the components one by one. We don't need to look at everything together. The first component is a tokenizer. It takes an input string, breaks it down into multiple parts, and each of these parts is called a token. Then it replaces these tokens with their corresponding token IDs. The input is a string, and the output is a list of token IDs. Next, we have an embedding projection. Now that we have broken the string into multiple tokens, we are going to replace each token ID with its corresponding embedding, or vector, because on its own, a token ID means nothing to the model. The embedding projection is a huge vocabulary table: for every token in the vocabulary, there is a corresponding representation, and we swap the token ID for it, so that the output of the embedding projection is a list of token embeddings.
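Here is a minimal sketch of those first two stages. The whole-word vocabulary and the 8-dimensional random table are hypothetical stand-ins; a real tokenizer (for example, a learned subword tokenizer) and a trained embedding table would take their place.

```python
import numpy as np

# Hypothetical toy vocabulary; real tokenizers learn subword pieces.
vocab = {"show": 0, "me": 1, "cute": 2, "dogs": 3}

def tokenize(text: str) -> list[int]:
    """Tokenizer stage: string in, list of token IDs out."""
    return [vocab[word] for word in text.lower().split()]

# Embedding projection stage: one (here randomly initialised) vector
# per vocabulary entry, of shape (vocab_size, model_dim).
rng = np.random.default_rng(0)
embedding_table = rng.normal(size=(len(vocab), 8))

token_ids = tokenize("show me cute dogs")
token_embeddings = embedding_table[token_ids]  # look up each ID's vector

print(token_ids)               # [0, 1, 2, 3]
print(token_embeddings.shape)  # (4, 8): one 8-dim vector per token
```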
The next component is actually the heart of most of the models we're using these days: the transformer. What the transformer does is take these token-level embeddings, which have no context of what's around those tokens, and output a much richer representation. For each token, it looks at the surrounding tokens and adds that information to enrich the embedding, so that the output is a set of token-level activations, which you can also consider to be embeddings. At this point, each one contains the context of the whole sequence, not just the one token.
Next, we have a pooler. The pooler's job is very simple: take these token-level embeddings and generate a single embedding. There are a lot of different techniques we can use. For example, mean pooling, where we take the average of all of these token-level embeddings and generate a single average embedding. This is the most commonly used method. There are a few other methods. For example, we can take only the first token's embedding, discard all the others, and consider that to be the representation of the entire sequence.
Similarly, we can take the last token's embedding as the representation of the entire sequence. Most commonly, we just use mean pooling, because it allows us to combine information from all of the tokens into a single embedding. At this point, we have already gone from an input string to an embedding. There is another optional component, called the output projection layer. This is a linear layer that takes the pooled embedding and generates another embedding of a different size. A lot of times, you want your embedding model to generate an embedding of a fixed dimension. You can control that dimension using this component. There is one more technique: if you don't want to fix the output embedding size, you can co-train multiple embedding sizes.
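The three pooling strategies described above can be sketched in a few lines. This is an illustrative helper, not any library's API; the input stands in for the transformer's per-token activations.

```python
import numpy as np

def pool(token_embeddings: np.ndarray, strategy: str = "mean") -> np.ndarray:
    """Collapse per-token embeddings (seq_len x dim) into one embedding."""
    if strategy == "mean":          # average over all tokens (most common)
        return token_embeddings.mean(axis=0)
    if strategy == "first_token":   # keep only the first token's embedding
        return token_embeddings[0]
    if strategy == "last_token":    # keep only the last token's embedding
        return token_embeddings[-1]
    raise ValueError(f"unknown pooling strategy: {strategy}")

tokens = np.array([[1.0, 2.0],
                   [3.0, 4.0],
                   [5.0, 6.0]])   # 3 tokens, dim 2

print(pool(tokens, "mean"))         # [3. 4.]
print(pool(tokens, "first_token"))  # [1. 2.]
print(pool(tokens, "last_token"))   # [5. 6.]
```

Mean pooling is the usual default because, as noted above, it mixes in information from every token rather than betting on one position.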
For example, in this case, what we do is take a dimension d, whatever that number is, and co-train smaller embeddings along with it: d/2, d/4, down to d/16. This allows us to co-train these multiple embeddings so that, at production time, we can decide which embedding size to use. Research shows that these smaller embeddings can be almost as good as the larger embeddings. It's like getting smaller embeddings almost for free, and with high quality.
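At serving time, using one of these co-trained smaller sizes amounts to slicing off a prefix of the full embedding. A minimal sketch, assuming the model was trained so that prefixes are meaningful (this only works with that nested co-training; truncating an ordinary embedding this way would discard information):

```python
import numpy as np

def truncate_embedding(embedding: np.ndarray, target_dim: int) -> np.ndarray:
    """Keep the first target_dim dimensions and re-normalise, so cosine
    similarity still behaves sensibly on the smaller vectors."""
    small = embedding[:target_dim]
    return small / np.linalg.norm(small)

d = 16
full = np.random.default_rng(1).normal(size=d)  # stand-in for a model output

for k in (d, d // 2, d // 4, d // 16):
    print(k, truncate_embedding(full, k).shape)
```

The production system can then pick the size per use case: the full d dimensions when quality matters most, d/16 when storage and latency dominate.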
Now, putting it all together: the input was a string, and in the end, we get an embedding. The same logic applies to other modalities. For example, if you have an image, instead of going through the text tokenizer, we'll have a vision encoder. A vision encoder is a special type of model that takes an image and breaks it apart into multiple patches. You can think of patches as the tokens: for text, we break the input down into tokens; for images, we break it down into patches. Again, it will replace those patches with their corresponding vectors, and the same steps follow: the transformer, the pooler, and the projection layer. The same thing happens with video. Most commonly, a video is represented as a list of frames, a list of images. Each frame is replaced with its corresponding patch embeddings, and we create a final embedding that captures all of the information in a single video.
Now, we're going to simplify a bit. We're not going to look at each of these components separately. For the rest of the slides, we're going to look at this whole box as an embedding model, embedding model that takes an input and generates a final embedding. Usually, we have two sides of inputs. One is a query, and the other one is a document. What we do is we create two embeddings, one for the query and one for the document. We want to make sure that if the query and documents are similar, their embeddings should be closer to each other in the vector space.
Training Techniques
Now that we know what embedding models are made of, let's look at how we can train them. The most common technique we use is called contrastive learning. As I said earlier, we want to make sure that if any two inputs are similar, their embeddings are closer, and if any two inputs are not similar, their embeddings are far apart. This is what training data usually looks like. We have pairs of queries and documents, where each example has a given query and its corresponding relevant, good document. For any given query and document pair, for example query1 and document1, we want to make sure that their embeddings are closer, that the similarity score is higher. To challenge the model more, we also consider all the other documents in the batch and treat those as negatives. In short, we take query1 and make sure that its similarity with document1 is high.
At the same time, we want to make sure that its similarity with all the other documents in the batch is minimized. This is captured very well by a loss called the in-batch cross-entropy loss. It's very simple. In this simplified representation, we want to maximize the similarities between positives and minimize the similarities with the in-batch negatives, so called because these are negatives that we just take from within the batch. There is another addition to that. We can challenge the model more by adding a hard negative for each example. The way it works is, let's say your query is, find me the best Italian restaurants in London. If we have a restaurant that's an Asian restaurant in London, it's easy for the model to know that this is not an Italian restaurant, so it's not a good match.
To make things challenging, we can add a hard negative that is semantically similar, say an Italian restaurant in New York. This teaches the model that being an Italian restaurant is not enough; it needs to pay attention to the location as well. So we add some hard negatives for each example, and we modify the loss to maximize the positive similarity, minimize similarity with the in-batch negatives, and also minimize similarity with the hard negatives.
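The loss described above can be sketched as follows. The random vectors, batch size, and temperature value are all illustrative choices; in practice the embeddings come from the model being trained and the loss is computed in an autodiff framework so it can be backpropagated.

```python
import numpy as np

def l2_normalize(x: np.ndarray) -> np.ndarray:
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

def contrastive_loss(q: np.ndarray, d_pos: np.ndarray, d_hard: np.ndarray,
                     temperature: float = 0.05) -> float:
    """In-batch cross-entropy: for query i, document i is the positive;
    every other document in the batch plus all hard negatives count as
    negatives. Inputs are L2-normalised arrays of shape (batch, dim)."""
    docs = np.concatenate([d_pos, d_hard], axis=0)   # (2B, dim) candidate pool
    logits = (q @ docs.T) / temperature              # (B, 2B) scaled cosine sims
    logits -= logits.max(axis=1, keepdims=True)      # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    batch = np.arange(len(q))
    return float(-log_probs[batch, batch].mean())    # target class = diagonal

rng = np.random.default_rng(0)
B, dim = 4, 8
q = l2_normalize(rng.normal(size=(B, dim)))
d_pos = l2_normalize(q + 0.1 * rng.normal(size=(B, dim)))  # positives near queries
d_hard = l2_normalize(rng.normal(size=(B, dim)))           # unrelated hard negatives
print(contrastive_loss(q, d_pos, d_hard))
```

Minimizing this pushes each query toward its own document and away from everything else in the batch, including the hard negatives.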
This is the training technique. How do we actually prepare the data? Let's say we have a bunch of text data. How do we prepare it to train these embedding models? There are two techniques. One is supervised learning, and the other is unsupervised learning. In supervised learning, we use next sentence prediction. Let's say we take this text from Wikipedia. I just searched for London, and these are the first two lines of the article. We split those into two separate sentences, and we say that the left input, the query, is going to be the first sentence, and the document that we need to match is going to be the next sentence. This means that you can take any text corpus you have and convert it into this next sentence prediction task to train your embedding models.
The other method is unsupervised learning. In this case, we use what's called span corruption. We take the same sentence and corrupt some span of it. For example, in this case, we mask out London is the capital, and only keep, of both England and the United Kingdom. On the other side, we take the same sentence but corrupt a different span. We feed these as a positive pair, so the model learns that even though the spans are corrupted and masked, it still needs to predict embeddings such that both views are close to each other in the embedding space. The second sentence can be handled the same way.
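Generating such a span-corruption positive pair is simple to sketch. The mask token and span positions here are arbitrary example choices; a real pipeline would sample spans randomly and use the tokenizer's own mask convention.

```python
def corrupt_span(words: list[str], start: int, length: int,
                 mask: str = "[MASK]") -> str:
    """Replace a contiguous span of words with a single mask token."""
    return " ".join(words[:start] + [mask] + words[start + length:])

sentence = "London is the capital of both England and the United Kingdom"
words = sentence.split()

# Two different corruptions of the same sentence form a positive pair:
view_a = corrupt_span(words, 0, 4)   # masks "London is the capital"
view_b = corrupt_span(words, 5, 4)   # masks "both England and the"
print(view_a)  # "[MASK] of both England and the United Kingdom"
print(view_b)  # "London is the capital of [MASK] United Kingdom"
```

The training objective then treats (view_a, view_b) as a query-document pair whose embeddings should be close, with no human labels required.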
Let's look at how we can convert the large language models that you see everywhere these days into an embedding model. The first stage is preparing the data. We covered two techniques, supervised and unsupervised. Optionally, you can also add hard negatives. The second stage is choosing the architecture. I will cover this in more detail in later slides. What's most important is choosing the model size as well as the output embedding dimension. The next stage is very important, because here we are taking a large language model that's good at generating text and converting it into an embedding model. We load the model weights into the embedding model and change the attention to be bidirectional so that it can look at the whole input sequence. The next stage is training. We usually have two-stage training. The first stage is called pre-training. The goal of this stage is to take the large language model and convert it into an embedding model. We train it on a lot of data, which is usually noisy and of somewhat lower quality. The main goal is that instead of generating text tokens, the model learns to generate embeddings. The next stage is usually fine-tuning, where we take data that is very specific to the task we have. For example, let's say your task is a RAG application. For a given input, you would have the best matching document that needs to be retrieved, so that the downstream model is able to generate truthful results.
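The causal-to-bidirectional change can be illustrated with the attention masks involved. This is a conceptual sketch only, not any framework's API: a decoder-style LLM uses a lower-triangular (causal) mask so each token sees only the past, while the embedding model's bidirectional mask lets every token attend to the whole sequence.

```python
import numpy as np

def attention_mask(seq_len: int, bidirectional: bool) -> np.ndarray:
    """Entry [i, j] is 1 if token i may attend to token j, else 0."""
    if bidirectional:
        # Embedding model: every token sees the entire sequence.
        return np.ones((seq_len, seq_len), dtype=int)
    # Causal LLM: token i sees only tokens 0..i (lower triangle).
    return np.tril(np.ones((seq_len, seq_len), dtype=int))

print(attention_mask(4, bidirectional=False))
print(attention_mask(4, bidirectional=True))
```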
Distilling Large Models for Production
Next, let's look at how we can distill these large models into smaller models for production. What is distillation? Distillation is basically the process of training a large model and then distilling it into a smaller one. We have a large model, and we train the smaller model using this large model. There are three techniques we use for distillation. The first one is scoring distillation. In scoring distillation, we use the teacher model's similarity scores to train the student model. This is what it looks like. We have a query and a document. We generate the embeddings using the teacher model and compute the similarity score that the teacher model predicts. Then we pass the same input through the student model that's being trained, and we make sure that whatever similarity score it produces is close to the score that the teacher model generated. We usually use a loss that can compare these two scores, for example mean squared error loss, which teaches the student to predict similar scores.
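A minimal sketch of that scoring-distillation loss, assuming the similarity scores for a batch of query-document pairs have already been computed by each model (the numbers below are made-up examples):

```python
import numpy as np

def scoring_distillation_loss(teacher_scores: np.ndarray,
                              student_scores: np.ndarray) -> float:
    """Mean squared error between teacher and student similarity scores."""
    return float(np.mean((teacher_scores - student_scores) ** 2))

# Hypothetical cosine similarities for three (query, document) pairs.
teacher = np.array([0.92, 0.15, 0.70])
student = np.array([0.85, 0.25, 0.65])
print(scoring_distillation_loss(teacher, student))
```

Driving this loss toward zero makes the student rank documents the same way the teacher does, even though the student is much smaller.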
The second approach is embedding distillation. Instead of using only the final score, we use the embeddings themselves. For example, we have a teacher model and a student model. We put an input through the teacher model, and it generates some embedding. We do the same thing with the student model, and it generates another embedding. We teach the student model that its embedding should be very close to the teacher model's embedding. We can also combine both of these together: scoring plus embedding distillation. In this case, it combines the strengths of both, taking the actual embeddings that the models generate plus the final similarity score, and using both of them to train the student model. This is what it looks like. This is what we saw for scoring distillation, where we take the scores between the query and document for the teacher as well as the student, and we try to match them. On top of this, we add another component, which compares the query embedding from the teacher with the query embedding from the student. Similarly, we add another component, which compares the document embedding from the teacher with the document embedding from the student. This is called the embedding distillation loss.
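The combined objective can be sketched as the sum of those three components. The equal weighting and the random stand-in embeddings are illustrative assumptions; in practice each term would be weighted and computed in an autodiff framework.

```python
import numpy as np

def cos(a: np.ndarray, b: np.ndarray) -> float:
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

def combined_distillation_loss(t_q, t_d, s_q, s_d) -> float:
    """Scoring term matches the teacher's query-document similarity;
    the two embedding terms pull the student's query and document
    embeddings toward the teacher's (each term is 0 at a perfect match)."""
    score_loss = (cos(t_q, t_d) - cos(s_q, s_d)) ** 2
    query_embed_loss = 1.0 - cos(t_q, s_q)
    doc_embed_loss = 1.0 - cos(t_d, s_d)
    return score_loss + query_embed_loss + doc_embed_loss

rng = np.random.default_rng(0)
t_q, t_d = rng.normal(size=8), rng.normal(size=8)  # teacher embeddings
s_q = t_q + 0.1 * rng.normal(size=8)               # student close to teacher
s_d = t_d + 0.1 * rng.normal(size=8)
print(combined_distillation_loss(t_q, t_d, s_q, s_d))
```

A student that exactly reproduces the teacher's embeddings drives all three terms to zero, which is why the combined loss gives a stronger training signal than either piece alone.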