Spaces:
Runtime error
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoProcessor, VisionEncoderDecoderModel, TrOCRProcessor | |
| from vllm import LLM, SamplingParams | |
| from PIL import Image | |
# --- Text generation -------------------------------------------------------
# Small OPT checkpoint pulled from the Hugging Face Hub.
model_name = "facebook/opt-125m"
# NOTE(review): `tokenizer` is not referenced anywhere else in this file;
# vLLM loads its own tokenizer for `model_name`. Confirm before removing.
tokenizer = AutoTokenizer.from_pretrained(model_name)
# vLLM engine on CPU, single worker (no tensor parallelism).
vllm_model = LLM(
    model=model_name,
    tensor_parallel_size=1,
    device="cpu",
)

# --- Handwriting OCR -------------------------------------------------------
# TrOCR encoder-decoder model plus the processor that matches it.
ocr_model_name = "microsoft/trocr-small-handwritten"
ocr_model = VisionEncoderDecoderModel.from_pretrained(ocr_model_name)
ocr_processor = TrOCRProcessor.from_pretrained(ocr_model_name)
def generate_response(prompt, max_tokens, temperature, top_p):
    """Generate a continuation of ``prompt`` with the module-level vLLM engine.

    Args:
        prompt: Raw text typed into the "Prompt" textbox.
        max_tokens: Upper bound on generated tokens. Gradio sliders pass
            their value as a float, so it is coerced to ``int`` before it
            reaches ``SamplingParams``.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling cumulative probability threshold.

    Returns:
        The generated text, or an empty string when ``prompt`` is empty
        (nothing is sent to the engine in that case).
    """
    # An empty prompt has nothing to continue; skip the engine call rather
    # than surface an error in the UI.
    if not prompt:
        return ""

    sampling_params = SamplingParams(
        max_tokens=int(max_tokens),  # slider values arrive as floats
        temperature=temperature,
        top_p=top_p,
    )
    # A single string prompt yields a single request output; take the first
    # (only) completion of that request.
    outputs = vllm_model.generate(prompt, sampling_params)
    return outputs[0].outputs[0].text
def ocr_image(image_path):
    """Transcribe handwritten text from an image file using TrOCR.

    Args:
        image_path: Filesystem path of the uploaded image, as delivered by
            ``gr.Image(type="filepath")``. ``None`` (or empty) when the user
            clicks "Extract Text" without uploading anything.

    Returns:
        The decoded text for the image, or an empty string when no image
        was provided.
    """
    # No upload -> nothing to read; Image.open(None) would raise.
    if not image_path:
        return ""

    # Close the file handle deterministically. convert() returns a new image
    # (a copy when already RGB), so it stays valid after the file is closed.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    pixel_values = ocr_processor(images=image, return_tensors="pt").pixel_values
    generated_ids = ocr_model.generate(pixel_values)
    # One image in -> one decoded sequence out.
    return ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
# ---------------------------------------------------------------------------
# Gradio UI: one tab per capability, each wired to the matching handler
# function defined earlier in this file.
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# π Hugging Face Integration with vLLM and OCR (CPU)")
    gr.Markdown("Upload an image to extract text using OCR or generate text using the vLLM integration.")

    # Tab 1: free-form text generation through vLLM.
    with gr.Tab("Text Generation"):
        with gr.Row():
            # Left column: prompt and sampling controls.
            with gr.Column():
                prompt_input = gr.Textbox(
                    lines=3,
                    placeholder="Enter your prompt here...",
                    label="Prompt",
                )
                max_tokens = gr.Slider(
                    minimum=10, maximum=500, step=10, value=100,
                    label="Max Tokens",
                )
                temperature = gr.Slider(
                    minimum=0.1, maximum=1.0, step=0.1, value=0.7,
                    label="Temperature",
                )
                top_p = gr.Slider(
                    minimum=0.1, maximum=1.0, step=0.1, value=0.9,
                    label="Top P",
                )
                submit_button = gr.Button("Generate")
            # Right column: read-only result box.
            with gr.Column():
                output_text = gr.Textbox(
                    interactive=False,
                    lines=10,
                    label="Generated Text",
                )
        submit_button.click(
            fn=generate_response,
            inputs=[prompt_input, max_tokens, temperature, top_p],
            outputs=output_text,
        )

    # Tab 2: handwriting OCR on an uploaded image.
    with gr.Tab("OCR"):
        with gr.Row():
            # Left column: image upload; the handler receives a file path.
            with gr.Column():
                image_input = gr.Image(
                    image_mode="RGB",
                    type="filepath",
                    label="Upload Image",
                )
                ocr_submit_button = gr.Button("Extract Text")
            # Right column: read-only transcription.
            with gr.Column():
                ocr_output = gr.Textbox(
                    interactive=False,
                    lines=10,
                    label="Extracted Text",
                )
        ocr_submit_button.click(
            fn=ocr_image,
            inputs=[image_input],
            outputs=ocr_output,
        )

# Start the Gradio server.
demo.launch()