Spaces:
Running
on
Zero
Running
on
Zero
| # Importing the requirements | |
| import warnings | |
| warnings.filterwarnings("ignore") | |
| import os | |
| import base64 | |
| import subprocess | |
| from io import BytesIO | |
| from tqdm import tqdm | |
| from pdf2image import convert_from_path | |
| import torch | |
| from torch.utils.data import DataLoader | |
| from transformers.utils.import_utils import is_flash_attn_2_available | |
| from colpali_engine.models import ColModernVBert, ColModernVBertProcessor | |
| from openai import OpenAI | |
| import spaces | |
| import gradio as gr | |
# FlashAttention-2 is only requested when the `flash-attn` package is importable.
# This script does not install it; it can be installed with
# `pip install flash-attn --no-build-isolation` (env FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE).
_MODEL_ID = "ModernVBERT/colmodernvbert"
_ATTN_IMPLEMENTATION = "flash_attention_2" if is_flash_attn_2_available() else None

# Load the visual document retrieval model (fp32 weights on cuda:0) and put it in
# evaluation mode, together with its matching processor.
# NOTE(review): transformers documents FlashAttention-2 as fp16/bf16-only, yet the
# dtype below stays fp32 even when `_ATTN_IMPLEMENTATION` selects it — confirm before
# installing flash-attn.
# NOTE(review): device_map is hard-coded to "cuda:0", so loading presumably fails on a
# CPU-only host even though index_gpu/search choose "cpu" when CUDA is unavailable.
model = ColModernVBert.from_pretrained(
    _MODEL_ID,
    torch_dtype=torch.float32,
    device_map="cuda:0",
    attn_implementation=_ATTN_IMPLEMENTATION,
).eval()
processor = ColModernVBertProcessor.from_pretrained(_MODEL_ID)
| ################################################ | |
| # Helper functions | |
| ################################################ | |
def encode_image_to_base64(image):
    """Encode an image as a base64 JPEG string.

    JPEG cannot store alpha or palette data, so images in any mode other than
    ``RGB`` or ``L`` (greyscale) are converted to ``RGB`` first. Saving e.g. an
    ``RGBA`` image as JPEG would otherwise raise ``OSError``.

    Args:
        image: A ``PIL.Image.Image`` (anything exposing ``mode``, ``convert``
            and ``save``).

    Returns:
        str: The base64-encoded JPEG bytes, as UTF-8 text.
    """
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")
    buffer = BytesIO()
    image.save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
def convert_files(files, max_pages=150):
    """Convert uploaded PDF files into a flat list of page images.

    Args:
        files: Paths of the uploaded PDF files, as passed in by the Gradio
            ``File`` component. ``None`` or an empty list is rejected.
        max_pages: Exclusive upper bound on the total number of pages across
            all files. Defaults to 150.

    Returns:
        list: Page images from ``pdf2image.convert_from_path``, in file order
        and then page order.

    Raises:
        gr.Error: If no file was uploaded, or if the documents contain
            ``max_pages`` pages or more.
    """
    if not files:
        raise gr.Error("Please upload at least one PDF file.")
    images = []
    for path in files:
        images.extend(convert_from_path(path, thread_count=4))
        # Stop as soon as the limit is hit instead of converting every remaining file first.
        if len(images) >= max_pages:
            raise gr.Error(
                f"The number of images in the dataset should be less than {max_pages}."
            )
    return images
| ################################################ | |
| # Model Inference with ModernVBERT and Qwen | |
| ################################################ | |
def index_gpu(images, ds):
    """Embed page images with the retrieval model and return a fresh index.

    Args:
        images: Page images (as produced by ``convert_files``) to embed.
        ds: The previous embedding list held in Gradio state. It is no longer
            extended. Re-indexing replaces the index so that embedding ``i``
            always corresponds to ``images[i]`` (the images state is replaced
            too). Kept for signature compatibility.

    Returns:
        tuple: ``(status_message, embeddings, images)``, where ``embeddings``
        holds one CPU tensor per image, in the same order as ``images``.
    """
    # NOTE(review): `spaces` is imported but no function in the visible source is
    # decorated with `@spaces.GPU`; on ZeroGPU hardware this presumably needs it — confirm.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # Compare two torch.device objects rather than a str against model.device.
    if model.device != torch.device(device):
        model.to(device)

    # Batch the images; tensors are moved to the device inside the loop.
    dataloader = DataLoader(
        images,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda batch: processor.process_images(batch),
    )

    embeddings = []
    with torch.no_grad():
        for batch_doc in tqdm(dataloader):
            batch_doc = {key: value.to(device) for key, value in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
            # One CPU tensor per page, so the index survives device moves.
            embeddings.extend(torch.unbind(embeddings_doc.to("cpu")))
    return f"Uploaded and converted {len(images)} pages", embeddings, images
# Prompt sent ahead of the retrieved pages; `{query}` is filled in per request.
# NOTE(review): the prompt asks for both "a short response" and "detailed and extensive
# answers" — consider settling on one.
_QWEN_PROMPT = """
You are a smart assistant designed to answer questions about a PDF document.
You are given relevant information in the form of PDF pages. Use them to construct a short response to the question, and cite your sources (page numbers, etc).
If it is not possible to answer using the provided pages, do not attempt to provide an answer and simply say the answer is not present within the documents.
Give detailed and extensive answers, only containing info in the pages you are given.
You can answer using information contained in plots and figures if necessary.
Answer in the same language as the query.
Query: {query}
PDF pages:
"""


def query_qwen(query, images, api_key):
    """Ask Qwen3-VL, via the Hugging Face router, to answer *query* from page images.

    Args:
        query: The user's question, substituted into the prompt.
        images: ``(image, caption)`` pairs as built by ``search``. Each caption is
            sent as text just before its image, so the model can cite pages.
        api_key: Hugging Face token. ``None``, empty, and whitespace-only values
            are all treated as missing.

    Returns:
        str: The model's answer, or a user-facing message when no token is given
        or the API call fails.
    """
    import logging

    token = (api_key or "").strip()
    if not token:
        # A whitespace-only token used to pass the truthiness check and then reach
        # the API as an empty key.
        return "Enter your Hugging Face token to get a custom response."

    # Interleave each page caption with its image. The prompt asks the model to cite
    # page numbers, which it can only do if it is told which image is which page.
    content = [{"type": "text", "text": _QWEN_PROMPT.format(query=query)}]
    for image, caption in images:
        content.append({"type": "text", "text": caption})
        content.append(
            {
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{encode_image_to_base64(image)}"
                },
            }
        )

    try:
        client = OpenAI(api_key=token, base_url="https://router.huggingface.co/v1")
        response = client.chat.completions.create(
            model="Qwen/Qwen3-VL-30B-A3B-Instruct",
            reasoning_effort="none",
            messages=[{"role": "user", "content": content}],
            max_tokens=500,
        )
        return response.choices[0].message.content
    except Exception:
        # Keep the UI responsive on API failures, but record the cause for debugging.
        logging.getLogger(__name__).exception("Qwen inference request failed")
        return "API connection error! Please check your API token and try again."
| ################################################ | |
| # Document Indexing and Search | |
| ################################################ | |
def index(files, ds):
    """Gradio handler for the "Index documents" button.

    Rasterizes the uploaded PDFs with ``convert_files`` and hands the page images
    to ``index_gpu``, returning its ``(status, embeddings, images)`` result as-is.
    """
    return index_gpu(convert_files(files), ds)
def search(query: str, ds, images, k, api_key):
    """Retrieve the ``k`` pages most relevant to *query* and ask Qwen about them.

    Args:
        query: Free-text search query.
        ds: Page embeddings produced by ``index_gpu`` (one CPU tensor per page).
        images: Page images, aligned with ``ds`` by position.
        k: Number of pages to retrieve. Clamped to the number of indexed pages.
        api_key: Hugging Face token, forwarded to ``query_qwen``.

    Returns:
        tuple: ``(results, ai_response)``. ``results`` is a list of
        ``(image_copy, "Page <idx>")`` pairs for the gallery, highest score first.
        ``ai_response`` is the text returned by ``query_qwen``.

    Raises:
        gr.Error: If no document has been indexed yet.
    """
    if not ds:
        raise gr.Error("Please index your documents before searching.")
    # The slider value may arrive as a float; topk needs an int.
    k = min(int(k), len(ds))

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # Compare two torch.device objects rather than a str against model.device.
    if model.device != torch.device(device):
        model.to(device)

    with torch.no_grad():
        batch_query = processor.process_queries([query]).to(device)
        embeddings_query = model(**batch_query)
        query_embeddings = list(torch.unbind(embeddings_query.to("cpu")))
        # One row of scores per query; there is a single query here.
        scores = processor.score(query_embeddings, ds, device=device)
    top_k_indices = scores[0].topk(k).indices.tolist()

    # Captions use the 0-based position of the page in the indexed list.
    results = [(images[idx].copy(), f"Page {idx}") for idx in top_k_indices]

    ai_response = query_qwen(query, results, api_key)
    return results, ai_response
| ################################################ | |
| # Gradio UI | |
| ################################################ | |
# Build the Gradio UI: an upload/index column and a search column, followed by the
# retrieved-pages gallery and the Qwen answer, wired to `index` and `search`.
with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    gr.Markdown(
        "# Multimodal RAG with ModernVBERT & Qwen 📚"
    )
    gr.Markdown(
        """Demo to test ColModernVBERT (ModernVBERT) on PDF documents.
ModernVBERT is a model implemented from the paper [ModernVBERT: Towards Smaller Visual Document Retrievers](https://arxiv.org/abs/2510.01149).
This demo allows you to upload PDF files and search for the most relevant pages based on your query.
Refresh the page if you change documents!
⚠️ This demo uses a model trained exclusively on A4 PDFs in portrait mode, containing English text. Performance is expected to drop for other page formats and languages.
Other models will be released with better robustness towards different languages and document formats!
"""
    )
    with gr.Row():
        # Left column: upload the PDFs and trigger indexing.
        with gr.Column(scale=2):
            gr.Markdown("## 1️⃣ Upload PDFs")
            file = gr.File(
                file_types=[".pdf"], file_count="multiple", label="Upload PDFs"
            )
            gr.Markdown("## 2️⃣ Index the PDFs")
            message = gr.Textbox("Files not yet uploaded", label="Status")
            convert_button = gr.Button("🔄 Index documents")
            # State holders for the page embeddings and page images returned by
            # `index`; both are passed back into `search`.
            embeds = gr.State(value=[])
            imgs = gr.State(value=[])
        # Right column: token, query, and number of pages to retrieve.
        with gr.Column(scale=3):
            gr.Markdown("## 3️⃣ Search")
            api_key = gr.Textbox(
                placeholder="Enter your Hugging Face token here (must be valid)",
                label="API token",
            )
            query = gr.Textbox(placeholder="Enter your query here", label="Query")
            k = gr.Slider(
                minimum=1,
                maximum=10,
                step=1,
                label="Number of results",
                value=3,
                info="Number of pages to retrieve",
            )
            search_button = gr.Button("🔍 Search", variant="primary")
    # Define the output components: the retrieved pages and the generated answer.
    gr.Markdown("## 4️⃣ Retrieved Image")
    output_gallery = gr.Gallery(
        label="Retrieved Documents", height=600, show_label=True
    )
    gr.Markdown("## 5️⃣ Qwen Response")
    output_text = gr.Textbox(
        label="AI Response",
        placeholder="Generated response based on retrieved documents",
        show_copy_button=True,
    )
    # Define the button actions: `index` fills the status box and both states;
    # `search` reads the states and fills the gallery and the answer box.
    convert_button.click(index, inputs=[file, embeds], outputs=[message, embeds, imgs])
    search_button.click(
        search,
        inputs=[query, embeds, imgs, k, api_key],
        outputs=[output_gallery, output_text],
    )
# Entry point when run as a script: enable request queuing (at most 10 pending
# events) and start the Gradio server.
if __name__ == "__main__":
    app = demo.queue(max_size=10)
    app.launch()