Code Search with Vector Embeddings: A Transformer's Approach
What you will learn
- **What is the main challenge addressed by the Python script in this post?** Transforming raw code snippets from a codebase into meaningful vector representations, known as embeddings, so that the codebase can be searched with natural language queries.
- **What do you need to set up before running the script?** Python 3.7 or newer, the required libraries installed via pip, and a local clone of the repository.
- **Which libraries does the script use?** `os` and `numpy` for basic file handling and numerical operations, `torch` and `torch.nn.functional` for tensor operations, and `transformers` from Hugging Face for pre-trained transformer models.
- **How does `find_k_nearest_neighbors` find the most relevant code snippets?** It calculates the cosine similarity between the query embedding and all code snippet embeddings, then returns the indices of the top-k most similar snippets.
- **What optimizations help with large, complex codebases?** Finetuning the embedding model on domain-specific data, using vector databases such as Milvus or Faiss for efficient similarity search, and chunking the codebase to reduce memory consumption and enable parallel processing.
In today’s fast-paced development world, navigating through large codebases can be a daunting task. Wouldn’t it be great if you could search through your codebase using natural language queries? In this blog post, we’ll walk you through a basic Python script that does just that, albeit in a simple manner, leveraging the power of transformer models.
Laying the Groundwork
The core challenge we’re addressing is transforming raw code snippets from our codebase into meaningful vector representations, known as embeddings. These embeddings capture the essence of the code in a format that can be compared for similarity. By doing so, when we pose a natural language query, the system can sift through these embeddings, identify the most similar code snippets, and present them as relevant “answers” to our query.
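To make this concrete before we look at the real script, here is a toy sketch of the similarity idea using made-up three-dimensional vectors (real embeddings have hundreds of dimensions):

import numpy as np

# Hypothetical embeddings for a query and two code snippets
query = np.array([0.9, 0.1, 0.2])
snippet_a = np.array([0.8, 0.2, 0.1])   # similar direction to the query
snippet_b = np.array([0.1, 0.9, 0.3])   # unrelated to the query

def cosine_similarity(a, b):
    # Cosine of the angle between the two vectors
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

print(cosine_similarity(query, snippet_a))  # close to 1.0
print(cosine_similarity(query, snippet_b))  # noticeably smaller

The closer the cosine similarity is to 1, the more aligned the two vectors are, and that is exactly the signal we will use to rank code snippets against a query.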
Setup
Before you can run the script, you’ll need to set up your environment. Here’s a step-by-step guide:
- **Python Environment**: Ensure you have Python 3.7 or newer installed. You can check your version with `python --version`.
- **Install Required Libraries**: Install everything you need with pip: `pip install numpy torch transformers`
- **Clone the Repository**: The entire code, along with some sample codebases to test on, is available on GitHub. Clone it to your local machine: `git clone git@github.com:stephenc222/example-vectorize-codebase.git`
- **Run the Script**: Navigate to the directory containing the script and run: `python app.py`
For a more detailed walkthrough, including potential customizations and optimizations, check out the companion GitHub repository.
The Building Blocks
Our script uses the following libraries:
- `os` and `numpy`: basic Python libraries for file handling and numerical operations.
- `torch` and `torch.nn.functional`: PyTorch libraries for tensor operations.
- `transformers`: the Hugging Face library, which provides pre-trained transformer models.
import os
import numpy as np
import torch
from torch import Tensor
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
Loading the Codebase
The `load_codebase` function recursively navigates the specified directory, filtering out unwanted files and directories. It then reads the content of each allowed file and appends it to a list of code snippets.
CODEBASE_DIR = "./example-codebase"
IGNORED_DIRECTORIES = ["node_modules", "public/build"]
IGNORED_FILES = ["package-lock.json", "yarn.lock"]
ALLOWED_EXTENSIONS = [".ts", ".tsx"]
IMAGE_EXTENSIONS = [
".png",
".jpg",
".jpeg",
".gif",
".bmp",
".svg",
".ico",
]
def load_codebase(directory):
    snippets = []
    for filename in os.listdir(directory):
        # Skip hidden files and directories
        if filename.startswith('.'):
            continue
        filepath = os.path.join(directory, filename)
        if os.path.isdir(filepath):
            # If it's a directory, recursively load its contents
            snippets.extend(load_codebase(filepath))
        else:
            if any(ignored in filepath for ignored in IGNORED_DIRECTORIES):
                continue
            if filename in IGNORED_FILES:
                continue
            if not any(filepath.endswith(ext) for ext in ALLOWED_EXTENSIONS):
                continue
            with open(filepath, 'r') as file:
                content = file.read().strip()
                if content:  # Check if content is not empty
                    snippets.append(content)
    return snippets
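As a quick sanity check (assuming you cloned the repository and its `example-codebase` directory is present), you can call the loader directly and peek at what it found:

snippets = load_codebase(CODEBASE_DIR)
print(f"Loaded {len(snippets)} code snippets")
print(snippets[0][:200])  # first 200 characters of the first snippet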
Generating Embeddings with Transformers
The heart of our script is the `generate_embeddings` function. Here it is:
def generate_embeddings(snippets):
    prefix = "query: "  # Assuming all code snippets are queries
    input_texts = [prefix + snippet for snippet in snippets]
    tokenizer = AutoTokenizer.from_pretrained('thenlper/gte-base')
    model = AutoModel.from_pretrained('thenlper/gte-base')
    batch_dict = tokenizer(input_texts, max_length=512,
                           padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():  # inference only, no gradients needed
        outputs = model(**batch_dict)
    embeddings = average_pool(
        outputs.last_hidden_state, batch_dict['attention_mask'])
    # Normalize to unit length so dot products equal cosine similarity
    embeddings = F.normalize(embeddings, p=2, dim=1)
    return embeddings
The `average_pool` function it uses:
def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    last_hidden = last_hidden_states.masked_fill(
        ~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
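To see what the masking buys us, here is a small illustrative check with fabricated numbers: a batch of two sequences, the second of which ends in a padding token that should contribute nothing to the average.

import torch

# 2 sequences, 3 tokens each, hidden size 2 (made-up values)
hidden = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                       [[1.0, 1.0], [3.0, 3.0], [9.0, 9.0]]])
# Second sequence has its last token masked out (padding)
mask = torch.tensor([[1, 1, 1],
                     [1, 1, 0]])

print(average_pool(hidden, mask))
# First row averages all three tokens: [3.0, 4.0]
# Second row averages only the first two tokens: [2.0, 2.0]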
And here is how it all works:
- **Tokenization**: We prefix each code snippet with "query: " and tokenize it using the `AutoTokenizer` from the Hugging Face library. This prepares our text for the transformer model.
- **Model Inference**: We use a pre-trained transformer model (`AutoModel`) to generate embeddings for our tokenized code snippets. The model returns the last hidden states for each token.
- **Pooling**: Since we want a single vector representation for each code snippet, we use the `average_pool` function to average the token embeddings. This gives us a fixed-size vector per snippet.
- **Normalization**: Finally, we normalize the embeddings so each has a magnitude of 1. This is crucial for calculating cosine similarities later on (a quick sanity check is sketched just below this list).
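If you want to convince yourself that the normalization step did its job, a short check along these lines (the two toy snippets here are just placeholders) should show every embedding with unit length:

embeddings = generate_embeddings(["def add(a, b): return a + b",
                                  "const x = 42;"])
# After F.normalize every row should have (approximately) length 1
print(embeddings.norm(dim=1))  # ≈ tensor([1.0000, 1.0000])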
Finding the Nearest Neighbors
The `find_k_nearest_neighbors` function calculates the cosine similarity between the query embedding and all code snippet embeddings. Since our embeddings are normalized, a simple dot product gives us the cosine similarity. The function then returns the indices of the top-k most similar code snippets.
def find_k_nearest_neighbors(query_embedding, embeddings, k=5):
    # Using cosine similarity as embeddings are normalized
    similarities = np.dot(embeddings, query_embedding.T)
    sorted_indices = similarities.argsort(axis=0)[-k:][::-1]
    return sorted_indices.squeeze()
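As a quick illustration with made-up, already-normalized two-dimensional vectors (real embeddings have many more dimensions), the snippet pointing in nearly the same direction as the query comes back first:

import numpy as np

embeddings = np.array([[1.0, 0.0],    # snippet 0
                       [0.0, 1.0],    # snippet 1
                       [0.6, 0.8]])   # snippet 2
query_embedding = np.array([[0.8, 0.6]])  # unit-length query

print(find_k_nearest_neighbors(query_embedding, embeddings, k=2))
# -> [2 0]: snippet 2 is most similar, followed by snippet 0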
Bringing It All Together
In the `__main__` block, we:
- Load our codebase using `load_codebase`.
- Generate embeddings for all code snippets using `generate_embeddings`.
- Generate an embedding for our query.
- Find the nearest neighbors using `find_k_nearest_neighbors`.
Finally, we print out the top matches to see the most relevant pieces of code for our query.
if __name__ == "__main__":
    snippets = load_codebase(CODEBASE_DIR)
    embeddings = generate_embeddings(snippets)
    # example query
    query = "Where are the rules of sudoku defined?"
    query_embedding = generate_embeddings([query])
    nearest_neighbors = find_k_nearest_neighbors(query_embedding, embeddings)
    top_matches = nearest_neighbors[:2]
    print("Query:", query)
    print("Top Matches:")
    for index in top_matches:
        # print the first 500 characters to illustrate the found match
        print(f"- Matched Code:\n{snippets[index][:500]}...\n")
Next Steps
While our example provides a foundational understanding of code search with vector embeddings, it’s essential to recognize that we’ve worked with a relatively small codebase. In real-world scenarios, especially with extensive and complex codebases, there are additional considerations and optimizations to be made.
Finetuning Embedding Models
While pre-trained models offer a great starting point, they might not always capture the nuances of specific domains or applications. By finetuning these models on domain-specific data, we can achieve better performance. For instance, if you have a codebase primarily in a specific programming language or related to a particular domain (like web development or data science), finetuning your model on similar code snippets can enhance its understanding and, consequently, its search accuracy.
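As a rough sketch of what that could look like, assuming you use the sentence-transformers library and have collected (natural-language question, relevant snippet) pairs from your own project, the pairs and identifiers below being purely hypothetical:

from sentence_transformers import SentenceTransformer, InputExample, losses
from torch.utils.data import DataLoader

# Hypothetical (query, relevant code snippet) pairs mined from your codebase
train_examples = [
    InputExample(texts=["query: where is user authentication handled?",
                        "export function login(user: User) { /* ... */ }"]),
    InputExample(texts=["query: how do we validate a sudoku board?",
                        "export function isValidBoard(board: Board) { /* ... */ }"]),
]

model = SentenceTransformer("thenlper/gte-base")
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)
# Contrastive loss that pulls each query toward its paired snippet
train_loss = losses.MultipleNegativesRankingLoss(model)
model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1)
model.save("gte-base-finetuned-codebase")

The contrastive loss treats every other snippet in the batch as a negative example, which is a common and data-efficient way to adapt an embedding model to a new domain.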
Vector Databases
As the size of the codebase grows, storing and searching through embeddings in memory becomes inefficient. This is where vector databases come into play. Tools like Milvus, Faiss, and others are designed to handle large-scale vector data and provide efficient similarity search capabilities. I've also written about how to use SQLite to store vector embeddings. By integrating a vector database, you can scale your code search tool to handle much larger codebases without compromising on search speed.
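As a minimal sketch of the idea, using the `faiss-cpu` package and the `snippets` list loaded earlier (illustrative only, not the setup used in the companion repository):

import faiss  # pip install faiss-cpu

# Convert the (normalized) embeddings to float32 NumPy arrays for Faiss
embeddings_np = generate_embeddings(snippets).numpy().astype("float32")
query_np = generate_embeddings(["Where are the rules of sudoku defined?"]).numpy().astype("float32")

index = faiss.IndexFlatIP(embeddings_np.shape[1])  # inner product == cosine for unit vectors
index.add(embeddings_np)
scores, indices = index.search(query_np, 5)  # top-5 nearest snippets
print(indices[0])  # positions into the original snippets list

For genuinely large collections you would swap the exact `IndexFlatIP` for an approximate index, but a flat index is simple and plenty fast for small-to-medium codebases.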
Chunking the Codebase
Another challenge with large codebases is memory consumption during the embedding generation phase. One way to address this is by “chunking” the codebase. Instead of processing the entire codebase at once, you can divide it into smaller chunks and process each chunk separately. This ensures that only a part of the codebase is in memory at any given time, reducing memory overhead and making the embedding generation process more manageable.
By implementing chunking, you can sequentially generate embeddings for each chunk and then store them in the vector database. This approach not only conserves memory but also allows for parallel processing, where multiple chunks can be processed simultaneously on different cores or machines.
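A simple sketch of that idea, reusing the `generate_embeddings` function from above (the chunk size is arbitrary):

import torch

def generate_embeddings_chunked(snippets, chunk_size=32):
    """Embed the codebase one chunk at a time to keep memory usage bounded."""
    all_embeddings = []
    for start in range(0, len(snippets), chunk_size):
        chunk = snippets[start:start + chunk_size]
        # Each chunk could instead be written straight to a vector database
        # or handed to a separate worker for parallel processing.
        all_embeddings.append(generate_embeddings(chunk))
    return torch.cat(all_embeddings, dim=0)

In practice you would also load the tokenizer and model once, outside the loop, rather than letting `generate_embeddings` reload them for every chunk.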
Conclusion
With the power of transformer models, we've built a simple tool that demonstrates searching a codebase using natural language queries. This approach can be a game-changer for large projects, making it easier for developers to find relevant code snippets and understand the codebase faster. The next time you're lost in a sea of code, remember that transformers might just be the compass you need!
The entire codebase for this blog post is available on my GitHub repository. Feel free to fork, star, or open issues!