Spaces:
Running
on
Zero
Running
on
Zero
File size: 9,482 Bytes
23573b0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 |
#!/usr/bin/env python3
"""
JoyCaption - Advanced Image Captioning with LLaVA
Uses fancyfeast/llama-joycaption-alpha-two-hf-llava model for high-quality image descriptions
Free, open, and uncensored model for training Diffusion models
"""
import gradio as gr
import torch
import spaces
from transformers import AutoProcessor, LlavaForConditionalGeneration
from PIL import Image
import tempfile
import os
from pathlib import Path
# Initialize the JoyCaption model once at import time so all requests reuse it.
print("Loading JoyCaption model...")
# Sentinels: these stay None if loading fails, so request handlers can detect
# a missing model instead of hitting a NameError on an unbound global.
processor = None
llava_model = None
try:
    # Hugging Face Hub id of the JoyCaption LLaVA checkpoint.
    model_name = "fancyfeast/llama-joycaption-alpha-two-hf-llava"
    # The processor provides the chat template, the tokenizer and image preprocessing.
    processor = AutoProcessor.from_pretrained(model_name)
    # Load weights in bfloat16 (native dtype of Llama 3.1). Use automatic
    # device placement when CUDA is available; otherwise keep the default (CPU).
    llava_model = LlavaForConditionalGeneration.from_pretrained(
        model_name,
        torch_dtype="bfloat16",
        device_map="auto" if torch.cuda.is_available() else None,
    )
    # Inference only: disable training-time behaviour such as dropout.
    llava_model.eval()
    print("JoyCaption model loaded successfully!")
except Exception as e:
    # Top-level boundary: keep the module importable so the UI can still start
    # and report the problem on each request instead of crashing at launch.
    print(f"Error loading model: {e}")
    # Discard any partially loaded state (e.g. processor loaded, model failed).
    processor = None
    llava_model = None

    # Fallback kept for callers that expect this name when loading fails.
    def process_image_with_caption(*args, **kwargs):
        return "Error: Model not loaded. Please check the model availability."
# Fixed instruction text for each built-in caption style. "custom" is not
# listed here: it is resolved per call from the user's prompt.
_PROMPT_TEMPLATES = {
    "formal_detailed": "Write a long descriptive caption for this image in a formal tone.",
    "creative": "Write a creative and artistic caption for this image, capturing its essence and mood.",
    "simple": "Write a simple, concise caption describing what you see in this image.",
    "technical": "Provide a detailed technical description of this image including composition, lighting, and visual elements.",
}
# Used for the "custom" style when the user leaves the prompt blank.
_DEFAULT_CUSTOM_PROMPT = "Write a descriptive caption for this image."


def _resolve_image_path(image_file):
    """Return a path from an object with ``.name`` (e.g. a Gradio file) or a path string, else None."""
    if hasattr(image_file, "name"):
        return image_file.name
    if isinstance(image_file, str):
        return image_file
    return None


def _select_prompt(prompt_type, custom_prompt):
    """Return the instruction for *prompt_type*; unknown types fall back to "formal_detailed"."""
    if prompt_type == "custom":
        # Blank or whitespace-only input means "no custom prompt".
        return (custom_prompt or "").strip() or _DEFAULT_CUSTOM_PROMPT
    return _PROMPT_TEMPLATES.get(prompt_type, _PROMPT_TEMPLATES["formal_detailed"])


@spaces.GPU
def generate_image_caption(image_file, prompt_type="formal_detailed", custom_prompt=""):
    """
    Generate an image caption with the JoyCaption LLaVA model.

    Args:
        image_file: Path string, or an object whose ``.name`` attribute is a
            path (e.g. a Gradio file object).
        prompt_type: Caption style: "formal_detailed", "creative", "simple",
            "technical" or "custom". Unknown values fall back to
            "formal_detailed".
        custom_prompt: Instruction used when prompt_type is "custom"; a blank
            value falls back to a generic descriptive prompt.

    Returns:
        str: The generated caption, or a human-readable error message.
        Exceptions are caught, printed and returned as text so the UI can
        display them.
    """
    try:
        if not image_file:
            return "Please upload an image file."
        image_path = _resolve_image_path(image_file)
        if image_path is None:
            return "Invalid image file format."
        if not os.path.exists(image_path):
            return "Image file not found."

        # The model is loaded at import time and may have failed (or left the
        # globals unbound); report that explicitly instead of a NameError.
        try:
            model, proc = llava_model, processor
        except NameError:
            model = proc = None
        if model is None or proc is None:
            return "Error: Model not loaded. Please check the model availability."

        print(f"Processing image: {image_path}")
        try:
            # The context manager closes the file handle opened by Image.open;
            # convert() returns an independent in-memory copy.
            with Image.open(image_path) as src:
                image = src.convert("RGB")
        except Exception as e:
            return f"Error loading image: {str(e)}"

        prompt = _select_prompt(prompt_type, custom_prompt)
        # Conversation in JoyCaption's recommended system/user format.
        convo = [
            {"role": "system", "content": "You are a helpful image captioner."},
            {"role": "user", "content": prompt},
        ]
        # WARNING: HF's handling of chats on Llava models is very fragile, so
        # render the template to a string and tokenize it together with the image.
        convo_string = proc.apply_chat_template(
            convo,
            tokenize=False,
            add_generation_prompt=True,
        )
        # Explicit check instead of assert, which is stripped under `python -O`.
        if not isinstance(convo_string, str):
            raise TypeError(
                f"apply_chat_template returned {type(convo_string).__name__}, expected str"
            )

        device = "cuda" if torch.cuda.is_available() else "cpu"
        inputs = proc(
            text=[convo_string],
            images=[image],
            return_tensors="pt",
        ).to(device)
        # Match the image tensor to the model's bfloat16 weights.
        if "pixel_values" in inputs:
            inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)

        # Sampling parameters recommended for JoyCaption.
        with torch.no_grad():
            generate_ids = model.generate(
                **inputs,
                max_new_tokens=300,
                do_sample=True,
                suppress_tokens=None,
                use_cache=True,
                temperature=0.6,
                top_k=None,
                top_p=0.9,
                repetition_penalty=1.1,
            )[0]

        # The generated sequence starts with the prompt tokens; keep only new ones.
        prompt_length = inputs["input_ids"].shape[1]
        caption = proc.tokenizer.decode(
            generate_ids[prompt_length:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        ).strip()
        print(f"Caption generated successfully: {caption[:100]}...")
        return caption
    except Exception as e:
        error_msg = f"Error during caption generation: {str(e)}"
        print(error_msg)
        return error_msg
def create_demo_image(demo_file="demo_image.png", size=512):
    """Create and save a synthetic test image: coloured squares on a white canvas.

    Starting at the origin, a square with corners (x, y) and (x + 25, y + 25)
    is drawn every 50 px in both directions. Its colour is derived from its
    position.

    Args:
        demo_file: Destination path; defaults to "demo_image.png" in the
            current working directory.
        size: Width and height of the square canvas in pixels.

    Returns:
        str | None: ``demo_file`` on success, or None if creating or saving
        the image failed (the error is printed, not raised).
    """
    try:
        # Local import keeps this optional helper self-contained.
        from PIL import Image, ImageDraw

        width = height = size
        image = Image.new("RGB", (width, height), color="white")
        draw = ImageDraw.Draw(image)
        # Grid of small squares; each RGB channel is a function of position.
        for x in range(0, width, 50):
            for y in range(0, height, 50):
                color = (x % 255, y % 255, (x + y) % 255)
                draw.rectangle([x, y, x + 25, y + 25], fill=color)
        image.save(demo_file)
        return demo_file
    except Exception as e:
        # Best-effort helper: report and signal failure with None.
        print(f"Error creating demo image: {e}")
        return None
# Create Gradio interface. Input order must match generate_image_caption's
# positional parameters: (image_file, prompt_type, custom_prompt).
demo = gr.Interface(
    fn=generate_image_caption,
    inputs=[
        # type="filepath" hands the upload to the handler as a path string.
        gr.Image(
            label="Upload Image for Captioning",
            type="filepath",
            format="png",
        ),
        gr.Dropdown(
            choices=["formal_detailed", "creative", "simple", "technical", "custom"],
            value="formal_detailed",
            label="Caption Style",
            info="Choose the style of caption generation",
        ),
        # Must be visible: it is the only way to supply the "custom" prompt.
        gr.Textbox(
            label="Custom Prompt (Optional)",
            placeholder="Enter custom prompt for specialized captioning...",
            info="Used only when Caption Style is set to 'custom'",
            lines=3,
            visible=True,
        ),
    ],
    outputs=[
        gr.Textbox(
            label="Generated Caption",
            lines=8,
            placeholder="The generated caption will appear here...",
        )
    ],
    title="🎨 JoyCaption - Advanced Image Captioning",
    description="""
This application uses the **JoyCaption** model to generate high-quality, detailed captions for images.
**Key Features:**
- 🆓 **Free & Open**: No restrictions, open weights, training scripts included
- 🔓 **Uncensored**: Equal coverage of SFW and NSFW concepts
- 🌈 **Diversity**: Supports digital art, photoreal, anime, furry, and all styles
- 🎯 **High Performance**: Near GPT4o-level captioning quality
- 🔧 **Minimal Filtering**: Trained on diverse images for broad understanding
**Supported image formats:** PNG, JPG, JPEG, WEBP
**Caption Styles:**
- **Formal Detailed**: Long descriptive captions in formal tone
- **Creative**: Artistic and expressive descriptions
- **Simple**: Concise, straightforward descriptions
- **Technical**: Detailed technical analysis of composition and elements
- **Custom**: User-defined prompts for specialized captioning
**Model**: fancyfeast/llama-joycaption-alpha-two-hf-llava
**Architecture**: LLaVA with Llama 3.1 base
""",
    # No `examples`: each example must supply values for the inputs above
    # (an image path first); placeholder sentences are not valid images.
    theme=gr.themes.Soft(
        primary_hue="purple",
        secondary_hue="slate",
        neutral_hue="slate",
    ),
    css="""
.gradio-container {max-width: 900px !important; margin: auto !important;}
.title {text-align: center; color: #7c3aed;}
.description {text-align: center; font-size: 1.1em;}
""",
    flagging_mode="never",
    submit_btn="🎨 Generate Caption",
    stop_btn="⏹️ Stop",
)
if __name__ == "__main__":
    print("🚀 Starting JoyCaption App...")
    print("📱 Interface will be available at: http://localhost:7860")
    print("🎨 Using JoyCaption model by fancyfeast")
    print("🆓 Free, Open, and Uncensored Image Captioning")
    # Launch the interface: listen on all network interfaces at port 7860,
    # with no public share link, and show handler errors in the UI.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        debug=False,
        show_error=True,
    )