File size: 9,482 Bytes
23573b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
#!/usr/bin/env python3
"""
JoyCaption - Advanced Image Captioning with LLaVA
Uses fancyfeast/llama-joycaption-alpha-two-hf-llava model for high-quality image descriptions
Free, open, and uncensored model for training Diffusion models
"""

import gradio as gr
import torch
import spaces
from transformers import AutoProcessor, LlavaForConditionalGeneration
from PIL import Image
import tempfile
import os
from pathlib import Path

# Initialize the JoyCaption model once, at import time.
print("Loading JoyCaption model...")
try:
    # Hugging Face Hub id of the JoyCaption checkpoint (LLaVA architecture)
    model_name = "fancyfeast/llama-joycaption-alpha-two-hf-llava"

    # Load processor (tokenizer + image preprocessing) for the checkpoint
    processor = AutoProcessor.from_pretrained(model_name)

    # Load model with bfloat16 (native dtype of Llama 3.1). Let accelerate
    # place weights automatically when CUDA is present; otherwise stay on CPU.
    llava_model = LlavaForConditionalGeneration.from_pretrained(
        model_name,
        torch_dtype="bfloat16",
        device_map="auto" if torch.cuda.is_available() else None,
    )
    llava_model.eval()  # inference only: disable dropout etc.

    print("JoyCaption model loaded successfully!")

except Exception as e:
    print(f"Error loading model: {e}")
    # Bind both handles to None so callers can detect the failure explicitly
    # instead of hitting a NameError; this also discards a half-finished load
    # (e.g. processor loaded but model failed).
    processor = None
    llava_model = None

    # Fallback kept for backward compatibility with any external caller.
    def process_image_with_caption(*args, **kwargs):
        """Report that the model is unavailable, whatever the arguments."""
        return "Error: Model not loaded. Please check the model availability."

# Caption prompts keyed by the "Caption Style" dropdown value.
_CAPTION_PROMPTS = {
    "formal_detailed": "Write a long descriptive caption for this image in a formal tone.",
    "creative": "Write a creative and artistic caption for this image, capturing its essence and mood.",
    "simple": "Write a simple, concise caption describing what you see in this image.",
    "technical": "Provide a detailed technical description of this image including composition, lighting, and visual elements.",
}
# Prompt for the "custom" style when the user leaves the custom prompt blank.
_DEFAULT_CUSTOM_PROMPT = "Write a descriptive caption for this image."


def _resolve_image_path(image_file):
    """Return the filesystem path for *image_file*, or None if its type is unsupported."""
    # Test str/PathLike first: pathlib.Path also has a ``.name`` attribute,
    # but that is only the final path component, not the full path.
    if isinstance(image_file, (str, os.PathLike)):
        return os.fspath(image_file)
    if hasattr(image_file, 'name'):
        # File-like upload object (e.g. a Gradio temp file) exposing its path as .name
        return image_file.name
    return None


def _select_caption_prompt(prompt_type, custom_prompt):
    """Return the user prompt for *prompt_type*, using *custom_prompt* for "custom"."""
    if prompt_type == "custom":
        # Whitespace-only custom text counts as "not provided"
        custom = str(custom_prompt).strip() if custom_prompt else ""
        return custom or _DEFAULT_CUSTOM_PROMPT
    # Unknown styles fall back to the formal detailed prompt
    return _CAPTION_PROMPTS.get(prompt_type, _CAPTION_PROMPTS["formal_detailed"])


@spaces.GPU
def generate_image_caption(image_file, prompt_type="formal_detailed", custom_prompt=""):
    """
    Generate an image caption with the JoyCaption model.

    Args:
        image_file: Path to the image (``str`` or ``os.PathLike``), or an
            object exposing its full path as ``.name`` (e.g. an uploaded file).
        prompt_type: Caption style: "formal_detailed", "creative", "simple",
            "technical" or "custom". Unknown values fall back to
            "formal_detailed".
        custom_prompt: Prompt used when ``prompt_type == "custom"``; if empty
            or whitespace-only, a generic descriptive prompt is used instead.

    Returns:
        str: The generated caption. On failure, a human-readable error message
        is returned instead of raising, so the UI can display it.
    """
    try:
        if not image_file:
            return "Please upload an image file."

        image_path = _resolve_image_path(image_file)
        if image_path is None:
            return "Invalid image file format."

        if not os.path.exists(image_path):
            return "Image file not found."

        # Look the handles up via globals() so a failed import-time load is
        # reported clearly whether it left the names unset or set to None.
        caption_processor = globals().get("processor")
        model = globals().get("llava_model")
        if caption_processor is None or model is None:
            return "Error: Model not loaded. Please check the model availability."

        print(f"Processing image: {image_path}")

        try:
            # The context manager closes the file handle; convert() returns an
            # independent in-memory copy, so the image stays usable afterwards.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert('RGB')
        except Exception as e:
            return f"Error loading image: {str(e)}"

        prompt = _select_caption_prompt(prompt_type, custom_prompt)

        # Build conversation following JoyCaption's recommended format
        convo = [
            {"role": "system", "content": "You are a helpful image captioner."},
            {"role": "user", "content": prompt},
        ]

        # WARNING: HF's handling of chat's on Llava models is very fragile
        convo_string = caption_processor.apply_chat_template(
            convo,
            tokenize=False,
            add_generation_prompt=True
        )
        # Explicit check instead of `assert`, which is stripped under `python -O`
        if not isinstance(convo_string, str):
            raise TypeError(
                f"apply_chat_template returned {type(convo_string).__name__}, expected str"
            )

        inputs = caption_processor(
            text=[convo_string],
            images=[image],
            return_tensors="pt"
        ).to('cuda' if torch.cuda.is_available() else 'cpu')

        # The model is loaded in bfloat16, so match the image tensor's dtype
        if 'pixel_values' in inputs:
            inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16)

        # Generate with JoyCaption's recommended sampling parameters
        with torch.no_grad():
            generate_ids = model.generate(
                **inputs,
                max_new_tokens=300,
                do_sample=True,
                suppress_tokens=None,
                use_cache=True,
                temperature=0.6,
                top_k=None,
                top_p=0.9,
                repetition_penalty=1.1
            )[0]

        # generate() returns prompt + completion; keep only the new tokens
        generate_ids = generate_ids[inputs['input_ids'].shape[1]:]

        caption = caption_processor.tokenizer.decode(
            generate_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        ).strip()

        print(f"Caption generated successfully: {caption[:100]}...")
        return caption

    except Exception as e:
        error_msg = f"Error during caption generation: {str(e)}"
        print(error_msg)
        return error_msg

def create_demo_image():
    """Render a synthetic test pattern and save it as ``demo_image.png``.

    The picture is a 512x512 white canvas with small coloured squares placed
    on a 50-pixel grid. Each square's RGB colour is derived from its top-left
    corner coordinates (not a smooth gradient).

    Returns:
        The relative path ``"demo_image.png"`` (in the current working
        directory) on success, or ``None`` if any step fails. The error is
        printed, not raised.
    """
    try:
        from PIL import Image, ImageDraw

        side = 512
        step = 50
        canvas = Image.new('RGB', (side, side), color='white')
        painter = ImageDraw.Draw(canvas)

        # Column-major walk over the grid corners (x outer, y inner)
        corners = ((x, y) for x in range(0, side, step) for y in range(0, side, step))
        for x, y in corners:
            tint = (x % 255, y % 255, (x + y) % 255)
            painter.rectangle([x, y, x + 25, y + 25], fill=tint)

        target = "demo_image.png"
        canvas.save(target)
        return target
    except Exception as err:
        print(f"Error creating demo image: {err}")
        return None

# Create Gradio interface. Inputs map positionally onto
# generate_image_caption(image_file, prompt_type, custom_prompt).
demo = gr.Interface(
    fn=generate_image_caption,
    inputs=[
        gr.Image(
            label="Upload Image for Captioning",
            type="filepath",
            format="png"
        ),
        gr.Dropdown(
            choices=["formal_detailed", "creative", "simple", "technical", "custom"],
            value="formal_detailed",
            label="Caption Style",
            info="Choose the style of caption generation"
        ),
        # Must be visible: gr.Interface has no hook to reveal it when "custom"
        # is chosen. A hidden textbox always submits its empty default, which
        # made the "custom" caption style unusable.
        gr.Textbox(
            label="Custom Prompt (Optional)",
            placeholder="Enter custom prompt for specialized captioning...",
            info="Only used when Caption Style is set to \"custom\"",
            lines=3,
        )
    ],
    outputs=[
        gr.Textbox(
            label="Generated Caption",
            lines=8,
            placeholder="The generated caption will appear here..."
        )
    ],
    title="🎨 JoyCaption - Advanced Image Captioning",
    description="""
    This application uses the **JoyCaption** model to generate high-quality, detailed captions for images.
    
    **Key Features:**
    - 🆓 **Free & Open**: No restrictions, open weights, training scripts included
    - 🔓 **Uncensored**: Equal coverage of SFW and NSFW concepts
    - 🌈 **Diversity**: Supports digital art, photoreal, anime, furry, and all styles
    - 🎯 **High Performance**: Near GPT4o-level captioning quality
    - 🔧 **Minimal Filtering**: Trained on diverse images for broad understanding
    
    **Supported image formats:** PNG, JPG, JPEG, WEBP
    
    **Caption Styles:**
    - **Formal Detailed**: Long descriptive captions in formal tone
    - **Creative**: Artistic and expressive descriptions
    - **Simple**: Concise, straightforward descriptions
    - **Technical**: Detailed technical analysis of composition and elements
    - **Custom**: User-defined prompts for specialized captioning
    
    **Model**: fancyfeast/llama-joycaption-alpha-two-hf-llava
    **Architecture**: LLaVA with Llama 3.1 base
    """,
    # No `examples`: the previous entries were placeholder sentences rather than
    # image paths, and gave one value for three inputs, so they could not work.
    theme=gr.themes.Soft(
        primary_hue="purple",
        secondary_hue="slate",
        neutral_hue="slate"
    ),
    css="""
    .gradio-container {max-width: 900px !important; margin: auto !important;}
    .title {text-align: center; color: #7c3aed;}
    .description {text-align: center; font-size: 1.1em;}
    """,
    flagging_mode="never",
    submit_btn="🎨 Generate Caption",
    stop_btn="⏹️ Stop"
)

if __name__ == "__main__":
    # Startup banner (emoji restored from mis-decoded UTF-8 in the original)
    print("🚀 Starting JoyCaption App...")
    print("📱 Interface will be available at: http://localhost:7860")
    print("🎨 Using JoyCaption model by fancyfeast")
    print("🔓 Free, Open, and Uncensored Image Captioning")

    # Listen on all network interfaces at port 7860; no public share link.
    # show_error=True surfaces server-side exceptions in the browser UI.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        debug=False,
        show_error=True
    )