hf-docs / app.py
julien-c's picture
julien-c HF staff
Fixes
bdbadf6 verified
raw
history blame contribute delete
No virus
3.51 kB
import time
import os
from typing import Literal, Tuple
import gradio as gr
import torch
from transformers import AutoModel, AutoTokenizer
import meilisearch
# Hugging Face checkpoint used to embed search queries.
_EMBEDDING_MODEL_NAME = "BAAI/bge-base-en-v1.5"

tokenizer = AutoTokenizer.from_pretrained(_EMBEDDING_MODEL_NAME)
# nn.Module.eval() returns the module itself, so chaining still binds the model;
# eval mode disables training-only behaviour such as dropout.
model = AutoModel.from_pretrained(_EMBEDDING_MODEL_NAME).eval()

# NOTE(review): this flag is only reported; nothing here moves the model to the
# GPU, so as far as this code shows inference runs on CPU.
cuda_available = torch.cuda.is_available()
print(f"CUDA available: {cuda_available}")

# Meilisearch endpoint and the index that search_embeddings queries.
# The API key is read from the MEILISEARCH_KEY environment variable
# (a missing variable raises KeyError at import time).
_MEILISEARCH_URL = "https://edge.meilisearch.com"
meilisearch_index_name = "docs-embed"
meilisearch_client = meilisearch.Client(
    _MEILISEARCH_URL, os.environ["MEILISEARCH_KEY"]
)
meilisearch_index = meilisearch_client.index(meilisearch_index_name)

# Choices offered by the output-format radio button in the UI.
output_options = ["RAG-friendly", "human-friendly"]
def _markdown_link(title: str, url: str) -> str:
    """Return the Markdown link ``["title"](url)`` with brackets in *title* escaped.

    An unescaped ``[`` or ``]`` inside a page title would end the link text early
    and break the rendered link, so both are backslash-escaped.
    """
    safe_title = title.replace("[", "\\[").replace("]", "\\]")
    return f'["{safe_title}"]({url})'


def search_embeddings(
    query_text: str, output_option: Literal["RAG-friendly", "human-friendly"]
) -> Tuple[str, str]:
    """Embed *query_text* and run a pure vector search against the Meilisearch index.

    Args:
        query_text: Free-text search query.
        output_option: ``"human-friendly"`` returns timing stats followed by each
            hit's text and a ``src:`` link, with ``---`` separators between hits.
            ``"RAG-friendly"`` returns only the hit texts, joined by a
            ``------------`` divider line.

    Returns:
        A ``(results, sources)`` pair. ``results`` holds the formatted hits.
        ``sources`` is a comma-separated list of Markdown links to the hits'
        source pages.

    Raises:
        ValueError: If *output_option* is not one of the two supported values.
    """
    # Validate before doing any model inference or network work; previously an
    # unknown option fell through every branch and silently returned None.
    if output_option not in ("RAG-friendly", "human-friendly"):
        raise ValueError(
            f"Unknown output_option {output_option!r}; "
            "expected 'RAG-friendly' or 'human-friendly'"
        )

    # step1: tokenize and embed the query.
    # perf_counter is monotonic, so timings are immune to wall-clock adjustments.
    start_time_embedding = time.perf_counter()
    # Instruction prefix prepended to the query text before embedding.
    query_prefix = "Represent this sentence for searching code documentation: "
    query_tokens = tokenizer(
        query_prefix + query_text,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=512,
    )
    with torch.no_grad():
        model_output = model(**query_tokens)
        # Take the hidden state of the first token of each sequence
        # (presumably [CLS] for this BERT-style model) as the sentence embedding.
        sentence_embeddings = model_output[0][:, 0]
        # L2-normalize to unit length.
        sentence_embeddings = torch.nn.functional.normalize(
            sentence_embeddings, p=2, dim=1
        )
    # A single query was encoded, so row 0 is its vector; convert to a plain list
    # so it can be sent in the search request.
    sentence_embeddings_list = sentence_embeddings[0].tolist()
    elapsed_time_embedding = time.perf_counter() - start_time_embedding

    # step2: vector search in Meilisearch.
    start_time_meilisearch = time.perf_counter()
    response = meilisearch_index.search(
        "",
        opt_params={
            "vector": sentence_embeddings_list,
            # semanticRatio 1.0 ranks purely by vector similarity; the text query is empty.
            "hybrid": {"semanticRatio": 1.0},
            "limit": 5,
            "attributesToRetrieve": [
                "text",
                "source_page_url",
                "source_page_title",
                "library",
            ],
        },
    )
    elapsed_time_meilisearch = time.perf_counter() - start_time_meilisearch
    hits = response["hits"]
    sources_md = ", ".join(
        _markdown_link(hit["source_page_title"], hit["source_page_url"])
        for hit in hits
    )

    # step3: present the results in markdown.
    if output_option == "human-friendly":
        parts = [
            f"Stats:\n\nembedding time: {elapsed_time_embedding:.2f}s\n\n"
            f"meilisearch time: {elapsed_time_meilisearch:.2f}s\n\n---\n\n"
        ]
        for hit in hits:
            link = _markdown_link(hit["source_page_title"], hit["source_page_url"])
            parts.append(f"{hit['text']}\n\nsrc: {link}\n\n---\n\n")
        return "".join(parts), sources_md

    # "RAG-friendly": bare hit texts separated by a divider line.
    hit_text_str = "\n------------\n".join(hit["text"] for hit in hits)
    return hit_text_str, sources_md
# Gradio UI: a query box and an output-format selector feed search_embeddings,
# whose (results, sources) pair fills the two Markdown outputs in order.
demo = gr.Interface(
    fn=search_embeddings,
    inputs=[
        gr.Textbox(
            label="enter your query",
            # The input is a free-text search query, not Markdown; the previous
            # placeholder ("Type Markdown here...") misdescribed it.
            placeholder="Type your search query here...",
            lines=10,
        ),
        gr.Radio(
            label="Select an output option",
            choices=output_options,
            value="RAG-friendly",
        ),
    ],
    # First output: the formatted hits; second output: links to the source pages.
    outputs=[gr.Markdown(), gr.Markdown()],
    title="HF Docs Embeddings Explorer",
    # NOTE(review): newer Gradio releases deprecate `allow_flagging` in favour of
    # `flagging_mode` — check against the pinned Gradio version.
    allow_flagging="never",
)
def main() -> None:
    """Start the Gradio server hosting the docs search demo."""
    demo.launch()


if __name__ == "__main__":
    main()