File size: 4,009 Bytes
ad248cf
 
 
 
 
 
 
 
 
 
 
7075532
86e4f57
ad248cf
 
 
e1ecdae
ad248cf
23eef01
ad248cf
 
 
 
adf3303
ad248cf
 
 
 
 
 
9a84bf7
ad248cf
23eef01
ad248cf
 
 
 
3d7968b
0651cc5
ad248cf
23eef01
03ba0a4
e397e4f
4a3a4ab
7fda2c6
03ba0a4
 
ad248cf
03ba0a4
 
ad248cf
 
 
 
 
6ebf269
276de44
7d4ae0d
276de44
 
5edce3a
86e4f57
 
276de44
 
e397e4f
4a3a4ab
3ec0da9
7fda2c6
4a3a4ab
934d045
 
f0175ac
 
86e4f57
 
 
 
 
 
ad248cf
 
276de44
0651cc5
 
 
 
a8d3569
0651cc5
 
 
 
 
 
 
 
 
 
 
 
 
 
6ebf269
 
 
0651cc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a8d3569
 
 
 
 
6ebf269
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import modal
from configs import (
    vllm_image,
    hf_cache_vol,
    vllm_cache_vol,
    MODEL_NAME,
    MODEL_REVISION,
    MINUTE,
    N_GPU,
    API_KEY,
    VLLM_PORT,
    flashinfer_cache_vol,
    CHAT_TEMPLATE,
)


# Modal application; `serve_llm` and the local `test` entrypoint below are registered on it.
app = modal.App("vibe-shopping-llm")


@app.function(
    image=vllm_image,
    gpu=f"H100:{N_GPU}",
    scaledown_window=(
        5 * MINUTE
        # how long should we stay up with no requests? Keep it low to minimize credit usage for now.
    ),
    timeout=10 * MINUTE,  # how long should we wait for container start?
    volumes={
        "/root/.cache/huggingface": hf_cache_vol,
        "/root/.cache/vllm": vllm_cache_vol,
        "/root/.cache/flashinfer": flashinfer_cache_vol,
    },
    secrets=[API_KEY],
)
@modal.concurrent(
    max_inputs=50  # maximum number of concurrent requests per auto-scaling replica
)
@modal.web_server(port=VLLM_PORT, startup_timeout=10 * MINUTE)
def serve_llm():
    """Spawn ``vllm serve`` for MODEL_NAME as a background subprocess.

    The server listens on ``0.0.0.0:VLLM_PORT``, the port that the
    ``modal.web_server`` decorator above exposes. This function only starts
    the process and returns; it does not wait for the server to come up.

    The server requires the API key read from the ``API_KEY`` environment
    variable (presumably injected by the Modal secret in ``secrets=[API_KEY]``).
    Raises ``KeyError`` if that variable is missing.
    """
    import json
    import os
    import subprocess

    import torch

    # Per the token counts below, one visual token corresponds to 28 * 28
    # pixels, so these bounds keep each image between 128 and 500 tokens.
    min_pixels = 128 * 28 * 28  # min 128 tokens
    max_pixels = 500 * 28 * 28  # max 500 tokens (~640x640 image)
    # json.dumps guarantees a valid JSON object (e.g. `true`, proper quoting)
    # instead of hand-escaping braces in an f-string.
    mm_processor_kwargs = json.dumps(
        {"min_pixels": min_pixels, "max_pixels": max_pixels, "use_fast": True}
    )

    # Restrict TORCH_CUDA_ARCH_LIST to the compute capability of the GPU that
    # is actually attached (presumably so JIT-built kernels target only it).
    # The child inherits the rest of the current environment unchanged.
    major, minor = torch.cuda.get_device_capability()
    child_env = {**os.environ, "TORCH_CUDA_ARCH_LIST": f"{major}.{minor}"}

    cmd = [
        "vllm",
        "serve",
        MODEL_NAME,
        "--revision",
        MODEL_REVISION,
        "--uvicorn-log-level=info",
        "--tool-call-parser",
        "hermes",
        "--enable-auto-tool-choice",
        "--limit-mm-per-prompt",
        "image=100",
        "--chat-template",
        CHAT_TEMPLATE,
        "--tensor-parallel-size",
        str(N_GPU),
        "--enforce-eager",
        # Minimize token usage
        "--mm-processor-kwargs",
        mm_processor_kwargs,
        # Optional: uncomment to extend the context to 65536 tokens via YaRN
        # rope scaling (and raise --max-model-len below to match).
        # "--rope-scaling",
        # '{"rope_type":"yarn","factor":2.0,"original_max_position_embeddings":32768}',
        "--max-model-len",
        "32768",
        "--host",
        "0.0.0.0",
        "--port",
        str(VLLM_PORT),
        # NOTE(review): passing the key in argv makes it visible in the process
        # list; consider supplying it through the environment if the vLLM
        # version in use supports that.
        "--api-key",
        os.environ["API_KEY"],
    ]

    # Fire-and-forget: the subprocess outlives this call and serves traffic
    # for the lifetime of the container.
    subprocess.Popen(cmd, env=child_env)


###### ------ FOR TESTING PURPOSES ONLY ------ ######
@app.local_entrypoint()
def test(test_timeout=25 * MINUTE, twice: bool = True):
    """Smoke-test the deployed vLLM server from the local machine.

    Polls ``<url>/health`` every 10 seconds until it answers HTTP 200 or
    ``test_timeout`` seconds have elapsed. It then POSTs one chat-completion
    request to ``<url>/v1/chat/completions`` and prints the JSON reply. A
    second, identical request is sent when ``twice`` is true.

    Args:
        test_timeout: Maximum number of seconds to keep polling ``/health``.
        twice: Whether to send the same chat request a second time
            (to observe caching behaviour).

    Raises:
        ValueError: If ``OPENAI_API_KEY`` is not set after loading ``.env``.
        AssertionError: If the server never passes the health check.
        urllib.error.HTTPError: If a chat-completion request is rejected.
    """
    import json
    import os
    import time
    # `import urllib` alone does not load the `request` submodule, so
    # `urllib.request` would only work by accident; import it explicitly.
    import urllib.request

    import dotenv

    dotenv.load_dotenv()
    if "OPENAI_API_KEY" not in os.environ:
        raise ValueError("OPENAI_API_KEY environment variable is not set.")

    base_url = serve_llm.get_web_url()

    print(f"Running health check for server at {base_url}")
    up, start, delay = False, time.monotonic(), 10
    while True:
        try:
            with urllib.request.urlopen(base_url + "/health") as response:
                up = response.getcode() == 200
        except Exception:
            # Connection errors / non-2xx statuses are expected while the
            # container is still starting; keep polling until the deadline.
            up = False
        # Every unsuccessful poll (exception OR a non-200 success code)
        # checks the deadline and sleeps, so this can never busy-spin.
        if up or time.monotonic() - start > test_timeout:
            break
        time.sleep(delay)

    assert up, f"Failed health check for server at {base_url}"

    print(f"Successful health check for server at {base_url}")

    messages = [{"role": "user", "content": "Testing! Is this thing on?"}]
    print(f"Sending a sample message to {base_url}", *messages, sep="\n")

    headers = {
        "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}",
        "Content-Type": "application/json",
    }
    payload = json.dumps({"messages": messages, "model": MODEL_NAME})
    req = urllib.request.Request(
        base_url + "/v1/chat/completions",
        data=payload.encode("utf-8"),
        headers=headers,
        method="POST",
    )

    def send_chat() -> None:
        """POST the prepared chat request and print the decoded JSON reply."""
        with urllib.request.urlopen(req) as response:
            print(json.loads(response.read().decode()))

    send_chat()

    if twice:
        print("Sending the same message again to test caching.")
        send_chat()