# Source: joycaption/joycaption_app.py (Hugging Face Space by kazuhina)
# Commit 23573b0: "Add JoyCaption - Advanced Image Captioning with LLaVA" (9.48 kB)
#!/usr/bin/env python3
"""
JoyCaption - Advanced Image Captioning with LLaVA
Uses fancyfeast/llama-joycaption-alpha-two-hf-llava model for high-quality image descriptions
Free, open, and uncensored captioning model, suited to captioning datasets used to train diffusion models
"""
import gradio as gr
import torch
import spaces
from transformers import AutoProcessor, LlavaForConditionalGeneration
from PIL import Image
import tempfile
import os
from pathlib import Path
# Initialize the JoyCaption model once at import time so every request reuses it.
# Both globals are bound to None up front so that, if loading fails, the names
# still exist and request handlers can detect the failure instead of hitting a
# NameError.
model_name = "fancyfeast/llama-joycaption-alpha-two-hf-llava"
processor = None
llava_model = None
print("Loading JoyCaption model...")
try:
    # Load processor and model with correct configuration
    processor = AutoProcessor.from_pretrained(model_name)
    # Load model with bfloat16 (native dtype of Llama 3.1)
    llava_model = LlavaForConditionalGeneration.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto" if torch.cuda.is_available() else None,
    )
    llava_model.eval()
    print("JoyCaption model loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
    # Reset a partially loaded state (e.g. processor loaded, model failed) so
    # both globals consistently signal "not loaded".
    processor = None
    llava_model = None

    # Fallback function for when model loading fails
    def process_image_with_caption(*args, **kwargs):
        """Return a fixed error message; used only when the model failed to load."""
        return "Error: Model not loaded. Please check the model availability."
@spaces.GPU
def generate_image_caption(image_file, prompt_type="formal_detailed", custom_prompt=""):
    """
    Generate an image caption with the JoyCaption model.

    Args:
        image_file: Path to the image (str), or an uploaded-file object whose
            ``.name`` attribute holds the path.
        prompt_type: Caption style; one of "formal_detailed", "creative",
            "simple", "technical" or "custom". Unknown values fall back to
            "formal_detailed".
        custom_prompt: Prompt text used when ``prompt_type == "custom"``.
            Surrounding whitespace is stripped; if nothing remains, a generic
            descriptive prompt is used instead.

    Returns:
        str: The generated caption, or a human-readable error message.
        Exceptions raised during processing are caught and reported through
        the return value so the UI can display them.
    """
    try:
        # Report a missing model explicitly instead of failing deep inside
        # generation.
        if processor is None or llava_model is None:
            return "Error: Model not loaded. Please check the model availability."

        if not image_file:
            return "Please upload an image file."

        # Handle different types of image inputs
        if hasattr(image_file, 'name'):
            # Gradio file object
            image_path = image_file.name
        elif isinstance(image_file, str):
            # File path string
            image_path = image_file
        else:
            return "Invalid image file format."

        if not os.path.exists(image_path):
            return "Image file not found."

        print(f"Processing image: {image_path}")

        # Load the image; the context manager closes the file handle, while
        # convert() produces an independent in-memory RGB copy.
        try:
            with Image.open(image_path) as img:
                image = img.convert('RGB')
        except Exception as e:
            return f"Error loading image: {str(e)}"

        # Whitespace-only custom prompts are treated as empty.
        custom_text = (custom_prompt or "").strip()
        prompt_templates = {
            "formal_detailed": "Write a long descriptive caption for this image in a formal tone.",
            "creative": "Write a creative and artistic caption for this image, capturing its essence and mood.",
            "simple": "Write a simple, concise caption describing what you see in this image.",
            "technical": "Provide a detailed technical description of this image including composition, lighting, and visual elements.",
            "custom": custom_text if custom_text else "Write a descriptive caption for this image.",
        }
        prompt = prompt_templates.get(prompt_type, prompt_templates["formal_detailed"])

        # Build conversation following JoyCaption's recommended format
        convo = [
            {
                "role": "system",
                "content": "You are a helpful image captioner.",
            },
            {
                "role": "user",
                "content": prompt,
            },
        ]

        # WARNING: HF's handling of chats on Llava models is very fragile.
        # An explicit check (not an assert) survives `python -O` and yields a
        # meaningful message.
        convo_string = processor.apply_chat_template(
            convo,
            tokenize=False,
            add_generation_prompt=True,
        )
        if not isinstance(convo_string, str):
            return "Error during caption generation: chat template did not produce a string."

        # Place inputs where the model lives rather than guessing a device name.
        inputs = processor(
            text=[convo_string],
            images=[image],
            return_tensors="pt",
        ).to(llava_model.device)

        # Pixel values must match the model's floating dtype (bfloat16 as loaded).
        if 'pixel_values' in inputs:
            inputs['pixel_values'] = inputs['pixel_values'].to(llava_model.dtype)

        # Generate captions with JoyCaption's recommended parameters
        with torch.no_grad():
            output_ids = llava_model.generate(
                **inputs,
                max_new_tokens=300,
                do_sample=True,
                suppress_tokens=None,
                use_cache=True,
                temperature=0.6,
                top_k=None,
                top_p=0.9,
                repetition_penalty=1.1,
            )[0]

        # generate() returns prompt + completion; keep only the new tokens.
        output_ids = output_ids[inputs['input_ids'].shape[1]:]

        caption = processor.tokenizer.decode(
            output_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        ).strip()

        print(f"Caption generated successfully: {caption[:100]}...")
        return caption

    except Exception as e:
        error_msg = f"Error during caption generation: {str(e)}"
        print(error_msg)
        return error_msg
def create_demo_image(output_path="demo_image.png", size=512):
    """
    Create a synthetic test image and save it as a PNG.

    The image is a white ``size`` x ``size`` canvas with a grid of small
    coloured squares, one every 50 px, each coloured from its x/y position.

    Args:
        output_path: Where to write the image. Defaults to "demo_image.png"
            in the current working directory.
        size: Width and height of the square image in pixels.

    Returns:
        ``output_path`` on success, or None if the image could not be created
        or saved (the error is printed).
    """
    try:
        # Local import: ImageDraw is not imported at module level.
        from PIL import Image, ImageDraw

        width = height = size
        image = Image.new('RGB', (width, height), color='white')
        draw = ImageDraw.Draw(image)
        # At the top-left of every 50px cell, fill the box (i, j)-(i+25, j+25).
        for i in range(0, width, 50):
            for j in range(0, height, 50):
                color = (i % 255, j % 255, (i + j) % 255)
                draw.rectangle([i, j, i + 25, j + 25], fill=color)

        image.save(output_path)
        return output_path
    except Exception as e:
        print(f"Error creating demo image: {e}")
        return None
# Create Gradio interface.
# Input order must match generate_image_caption(image_file, prompt_type, custom_prompt).
# No `examples` are configured: the previous entries were placeholder sentences
# rather than image paths, so they could not be loaded into the Image input.
demo = gr.Interface(
    fn=generate_image_caption,
    inputs=[
        gr.Image(
            label="Upload Image for Captioning",
            type="filepath",
            format="png",
        ),
        gr.Dropdown(
            choices=["formal_detailed", "creative", "simple", "technical", "custom"],
            value="formal_detailed",
            label="Caption Style",
            info="Choose the style of caption generation",
        ),
        # Must be visible: a hidden textbox only ever submits its default
        # value, so the "custom" style could never receive a user prompt.
        gr.Textbox(
            label="Custom Prompt (Optional)",
            placeholder="Enter custom prompt for specialized captioning...",
            info='Used only when Caption Style is "custom"',
            lines=3,
            visible=True,
        ),
    ],
    outputs=[
        gr.Textbox(
            label="Generated Caption",
            lines=8,
            placeholder="The generated caption will appear here...",
        )
    ],
    title="🎨 JoyCaption - Advanced Image Captioning",
    description="""
This application uses the **JoyCaption** model to generate high-quality, detailed captions for images.

**Key Features:**

- 🆓 **Free & Open**: No restrictions, open weights, training scripts included
- 🔓 **Uncensored**: Equal coverage of SFW and NSFW concepts
- 🌈 **Diversity**: Supports digital art, photoreal, anime, furry, and all styles
- 🎯 **High Performance**: Near GPT4o-level captioning quality
- 🔧 **Minimal Filtering**: Trained on diverse images for broad understanding

**Supported image formats:** PNG, JPG, JPEG, WEBP

**Caption Styles:**

- **Formal Detailed**: Long descriptive captions in formal tone
- **Creative**: Artistic and expressive descriptions
- **Simple**: Concise, straightforward descriptions
- **Technical**: Detailed technical analysis of composition and elements
- **Custom**: User-defined prompts for specialized captioning

**Model**: fancyfeast/llama-joycaption-alpha-two-hf-llava

**Architecture**: LLaVA with Llama 3.1 base
""",
    theme=gr.themes.Soft(
        primary_hue="purple",
        secondary_hue="slate",
        neutral_hue="slate",
    ),
    css="""
.gradio-container {max-width: 900px !important; margin: auto !important;}
.title {text-align: center; color: #7c3aed;}
.description {text-align: center; font-size: 1.1em;}
""",
    flagging_mode="never",
    submit_btn="🎨 Generate Caption",
    stop_btn="⏹️ Stop",
)
if __name__ == "__main__":
    # Single source of truth for the port shown in the banner and used by launch().
    port = 7860

    print("🚀 Starting JoyCaption App...")
    print(f"📱 Interface will be available at: http://localhost:{port}")
    print("🎨 Using JoyCaption model by fancyfeast")
    print("🔓 Free, Open, and Uncensored Image Captioning")

    # 0.0.0.0 binds all network interfaces (e.g. reachable from outside a container).
    demo.launch(
        server_name="0.0.0.0",
        server_port=port,
        share=False,
        debug=False,
        show_error=True,
    )