Spaces:
Running
on
Zero
Running
on
Zero
| #!/usr/bin/env python3 | |
| """ | |
| JoyCaption - Advanced Image Captioning with LLaVA | |
| Uses fancyfeast/llama-joycaption-alpha-two-hf-llava model for high-quality image descriptions | |
| Free, open, and uncensored model for training Diffusion models | |
| """ | |
| import gradio as gr | |
| import torch | |
| import spaces | |
| from transformers import AutoProcessor, LlavaForConditionalGeneration | |
| from PIL import Image | |
| import tempfile | |
| import os | |
| from pathlib import Path | |
# Load the JoyCaption processor and model once, at import time, so every
# request reuses the same weights.
print("Loading JoyCaption model...")

# Hugging Face Hub id of the checkpoint used by this app.
model_name = "fancyfeast/llama-joycaption-alpha-two-hf-llava"

# Bind both handles up front so that, if loading fails, later code sees
# ``None`` instead of hitting a NameError on an unbound global.
processor = None
llava_model = None

try:
    processor = AutoProcessor.from_pretrained(model_name)
    # bfloat16 is used because it is the native dtype of the Llama 3.1 base.
    llava_model = LlavaForConditionalGeneration.from_pretrained(
        model_name,
        torch_dtype="bfloat16",
        device_map="auto" if torch.cuda.is_available() else None,
    )
    llava_model.eval()
    print("JoyCaption model loaded successfully!")
except Exception as e:
    # Best-effort startup: keep the app importable and report the problem.
    print(f"Error loading model: {e}")
    # A partial load (processor OK, model failed) must not look usable.
    processor = None
    llava_model = None

    # Fallback kept for backward compatibility with code that may look up
    # this name after a failed load.
    def process_image_with_caption(*args, **kwargs):
        """Return a fixed error message; used only when loading failed."""
        return "Error: Model not loaded. Please check the model availability."
def _select_caption_prompt(prompt_type, custom_prompt):
    """Return the instruction text for *prompt_type*.

    Unknown types fall back to the ``formal_detailed`` template. For
    ``custom``, a missing or whitespace-only *custom_prompt* falls back to a
    generic descriptive instruction.
    """
    custom_text = (custom_prompt or "").strip()
    prompt_templates = {
        "formal_detailed": "Write a long descriptive caption for this image in a formal tone.",
        "creative": "Write a creative and artistic caption for this image, capturing its essence and mood.",
        "simple": "Write a simple, concise caption describing what you see in this image.",
        "technical": "Provide a detailed technical description of this image including composition, lighting, and visual elements.",
        "custom": custom_text or "Write a descriptive caption for this image.",
    }
    return prompt_templates.get(prompt_type, prompt_templates["formal_detailed"])


# On ZeroGPU Spaces a GPU is only attached while a ``spaces.GPU``-decorated
# function runs; the decorator is documented as effect-free elsewhere.
@spaces.GPU
def generate_image_caption(image_file, prompt_type="formal_detailed", custom_prompt=""):
    """
    Generate an image caption using the JoyCaption model.

    Args:
        image_file: Path to the image (``str`` or ``os.PathLike``) or an
            uploaded-file object exposing the path as ``.name``.
        prompt_type: Caption style: ``formal_detailed``, ``creative``,
            ``simple``, ``technical`` or ``custom``. Unknown values are
            treated as ``formal_detailed``.
        custom_prompt: Instruction text used when ``prompt_type`` is
            ``custom``; ignored otherwise.

    Returns:
        str: The generated caption, or a human-readable error message.
        Errors are returned rather than raised so the UI can display them.
    """
    try:
        if not image_file:
            return "Please upload an image file."

        # Path-like values are checked first: ``pathlib.Path`` also has a
        # ``.name`` attribute, but it is only the final path component.
        if isinstance(image_file, (str, os.PathLike)):
            image_path = os.fspath(image_file)
        elif hasattr(image_file, "name"):
            # Uploaded-file object carrying its path in ``.name``.
            image_path = image_file.name
        else:
            return "Invalid image file format."

        if not os.path.exists(image_path):
            return "Image file not found."

        # Looked up through globals() so an unbound name (model loading
        # failed at import) yields a clear message instead of a NameError.
        caption_processor = globals().get("processor")
        caption_model = globals().get("llava_model")
        if caption_processor is None or caption_model is None:
            return "Error: Model not loaded. Please check the model availability."

        print(f"Processing image: {image_path}")

        try:
            # The context manager closes the file handle promptly;
            # ``convert`` returns an independent in-memory copy.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert("RGB")
        except Exception as e:
            return f"Error loading image: {str(e)}"

        prompt = _select_caption_prompt(prompt_type, custom_prompt)

        # Conversation layout recommended for JoyCaption.
        convo = [
            {
                "role": "system",
                "content": "You are a helpful image captioner.",
            },
            {
                "role": "user",
                "content": prompt,
            },
        ]

        # WARNING: HF's handling of chats on Llava models is very fragile.
        convo_string = caption_processor.apply_chat_template(
            convo,
            tokenize=False,
            add_generation_prompt=True,
        )
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not isinstance(convo_string, str):
            raise TypeError("chat template did not return a string")

        # Put the inputs on the model's device and match the pixel dtype to
        # the model's weights.
        inputs = caption_processor(
            text=[convo_string],
            images=[image],
            return_tensors="pt",
        ).to(caption_model.device)
        if "pixel_values" in inputs:
            inputs["pixel_values"] = inputs["pixel_values"].to(caption_model.dtype)

        # Sampling parameters recommended for JoyCaption.
        with torch.no_grad():
            generate_ids = caption_model.generate(
                **inputs,
                max_new_tokens=300,
                do_sample=True,
                suppress_tokens=None,
                use_cache=True,
                temperature=0.6,
                top_k=None,
                top_p=0.9,
                repetition_penalty=1.1,
            )[0]

        # The output starts with the prompt tokens; keep only the new ones.
        generate_ids = generate_ids[inputs["input_ids"].shape[1]:]

        caption = caption_processor.tokenizer.decode(
            generate_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        ).strip()

        print(f"Caption generated successfully: {caption[:100]}...")
        return caption
    except Exception as e:
        # UI boundary: report the failure as text instead of crashing.
        error_msg = f"Error during caption generation: {str(e)}"
        print(error_msg)
        return error_msg
def create_demo_image():
    """Write a small patterned test image and return its file name.

    The picture is 512x512 with a white background and small colored
    squares placed on a 50-pixel grid. It is saved as ``demo_image.png``
    in the current working directory.

    Returns:
        str | None: ``"demo_image.png"`` on success, ``None`` if anything
        fails (the error is printed).
    """
    try:
        # Imported inside the try so an import failure is handled like any
        # other error in this function.
        from PIL import Image, ImageDraw

        side = 512
        step, tile = 50, 25

        canvas = Image.new('RGB', (side, side), color='white')
        pen = ImageDraw.Draw(canvas)

        # One square per grid cell; its color depends on its position.
        offsets = range(0, side, step)
        for x in offsets:
            for y in offsets:
                pen.rectangle(
                    [x, y, x + tile, y + tile],
                    fill=(x % 255, y % 255, (x + y) % 255),
                )

        output_name = "demo_image.png"
        canvas.save(output_name)
    except Exception as e:
        print(f"Error creating demo image: {e}")
        return None
    return output_name
# Build the Gradio interface. The three inputs map positionally onto
# generate_image_caption(image_file, prompt_type, custom_prompt).
#
# Labels and headings are plain text: the emoji prefixes that used to be
# here were mis-encoded and rendered as stray Greek letters.
#
# No `examples` are registered: each example row must supply real input
# values (an image path first), and placeholder sentences are not image files.
demo = gr.Interface(
    fn=generate_image_caption,
    inputs=[
        gr.Image(
            label="Upload Image for Captioning",
            type="filepath",
            format="png",
        ),
        gr.Dropdown(
            choices=["formal_detailed", "creative", "simple", "technical", "custom"],
            value="formal_detailed",
            label="Caption Style",
            info="Choose the style of caption generation",
        ),
        # Kept visible: gr.Interface wires no event that could reveal a
        # hidden field, so hiding it would make the "custom" style unusable.
        gr.Textbox(
            label="Custom Prompt (Optional)",
            placeholder="Enter custom prompt for specialized captioning...",
            info='Only used when Caption Style is "custom".',
            lines=3,
        ),
    ],
    outputs=[
        gr.Textbox(
            label="Generated Caption",
            lines=8,
            placeholder="The generated caption will appear here...",
        ),
    ],
    title="JoyCaption - Advanced Image Captioning",
    # Blank lines separate paragraphs from lists so the markdown renders
    # each section on its own.
    description="""
This application uses the **JoyCaption** model to generate high-quality, detailed captions for images.

**Key Features:**

- **Free & Open**: No restrictions, open weights, training scripts included
- **Uncensored**: Equal coverage of SFW and NSFW concepts
- **Diversity**: Supports digital art, photoreal, anime, furry, and all styles
- **High Performance**: Near GPT4o-level captioning quality
- **Minimal Filtering**: Trained on diverse images for broad understanding

**Supported image formats:** PNG, JPG, JPEG, WEBP

**Caption Styles:**

- **Formal Detailed**: Long descriptive captions in formal tone
- **Creative**: Artistic and expressive descriptions
- **Simple**: Concise, straightforward descriptions
- **Technical**: Detailed technical analysis of composition and elements
- **Custom**: User-defined prompts for specialized captioning

**Model**: fancyfeast/llama-joycaption-alpha-two-hf-llava

**Architecture**: LLaVA with Llama 3.1 base
""",
    theme=gr.themes.Soft(
        primary_hue="purple",
        secondary_hue="slate",
        neutral_hue="slate",
    ),
    css="""
.gradio-container {max-width: 900px !important; margin: auto !important;}
.title {text-align: center; color: #7c3aed;}
.description {text-align: center; font-size: 1.1em;}
""",
    flagging_mode="never",
    submit_btn="Generate Caption",
    stop_btn="Stop",
)
def main():
    """Print the startup banner and launch the Gradio app.

    The server binds to all interfaces on a fixed port. The banner is plain
    text: the emoji prefixes that used to be here were mis-encoded and
    printed as stray Greek letters.
    """
    # Single source of truth so the printed URL cannot drift from the
    # port actually passed to launch().
    server_port = 7860

    print("Starting JoyCaption App...")
    print(f"Interface will be available at: http://localhost:{server_port}")
    print("Using JoyCaption model by fancyfeast")
    print("Free, Open, and Uncensored Image Captioning")

    # Launch the interface
    demo.launch(
        server_name="0.0.0.0",
        server_port=server_port,
        share=False,
        debug=False,
        show_error=True,
    )


if __name__ == "__main__":
    main()