import modal

from configs import (
    vllm_image,
    hf_cache_vol,
    vllm_cache_vol,
    MODEL_NAME,
    MODEL_REVISION,
    MINUTE,
    N_GPU,
    API_KEY,
    VLLM_PORT,
    flashinfer_cache_vol,
    CHAT_TEMPLATE,
)

app = modal.App("vibe-shopping-llm")
@app.function(
    image=vllm_image,
    gpu=f"H100:{N_GPU}",
    scaledown_window=(
        5 * MINUTE
        # how long should we stay up with no requests? Keep it low to minimize credit usage for now.
    ),
    timeout=10 * MINUTE,  # how long should we wait for container start?
    volumes={
        "/root/.cache/huggingface": hf_cache_vol,
        "/root/.cache/vllm": vllm_cache_vol,
        "/root/.cache/flashinfer": flashinfer_cache_vol,
    },
    secrets=[API_KEY],
)
@modal.concurrent(
    max_inputs=50  # maximum number of concurrent requests per auto-scaling replica
)
@modal.web_server(port=VLLM_PORT, startup_timeout=10 * MINUTE)
def serve_llm():
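    """Launch vLLM's OpenAI-compatible server as a Modal web endpoint.

    ``vllm serve`` is started as a background process; ``@modal.web_server``
    exposes VLLM_PORT once the server is listening, so clients can reach the
    standard ``/v1/chat/completions`` route behind the configured API key.
    """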
    import os
    import subprocess

    import torch

    min_pixels = 128 * 28 * 28  # min 128 tokens
    max_pixels = 500 * 28 * 28  # max 500 tokens (~640x640 image)
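    # Assumption: in Qwen2-VL-style processors each visual token covers roughly a
    # 28x28 pixel patch, which is why the budgets above are written as N * 28 * 28.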
    major, minor = torch.cuda.get_device_capability()
    cmd = [
        "env",
        f"TORCH_CUDA_ARCH_LIST={major}.{minor}",
        "vllm",
        "serve",
        MODEL_NAME,
        "--revision",
        MODEL_REVISION,
        "--uvicorn-log-level=info",
        "--tool-call-parser",
        "hermes",
        "--enable-auto-tool-choice",
        "--limit-mm-per-prompt",
        "image=100",
        "--chat-template",
        CHAT_TEMPLATE,
        "--tensor-parallel-size",
        str(N_GPU),
        "--enforce-eager",
        # Minimize token usage
        "--mm-processor-kwargs",
        f'{{"min_pixels": {min_pixels}, "max_pixels": {max_pixels}, "use_fast": true}}',
        # Extend context length to 65536 tokens (currently disabled)
        # "--rope-scaling",
        # '{"rope_type":"yarn","factor":2.0,"original_max_position_embeddings":32768}',
        "--max-model-len",
        "32768",
        "--host",
        "0.0.0.0",
        "--port",
        str(VLLM_PORT),
        "--api-key",
        os.environ["API_KEY"],
    ]

    subprocess.Popen(cmd)
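

# Example client call (a sketch, not part of the deployed app): the server above
# speaks vLLM's OpenAI-compatible API, so any OpenAI SDK client works against it.
# The base_url below is a placeholder -- use the URL printed at deploy time
# (i.e. serve_llm.get_web_url()) plus the "/v1" suffix.
#
#     from openai import OpenAI
#
#     client = OpenAI(
#         base_url="https://<workspace>--vibe-shopping-llm-serve-llm.modal.run/v1",
#         api_key="<the value stored in the API_KEY Modal secret>",
#     )
#     resp = client.chat.completions.create(
#         model=MODEL_NAME,
#         messages=[{"role": "user", "content": "Hello!"}],
#     )
#     print(resp.choices[0].message.content)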
###### ------ FOR TESTING PURPOSES ONLY ------ ######
@app.local_entrypoint()
def test(test_timeout=25 * MINUTE, twice: bool = True):
    import json
    import os
    import time
    import urllib.request  # import the submodule explicitly; `import urllib` alone is not enough

    import dotenv

    dotenv.load_dotenv()
    if "OPENAI_API_KEY" not in os.environ:
        raise ValueError("OPENAI_API_KEY environment variable is not set.")

    print(f"Running health check for server at {serve_llm.get_web_url()}")
    up, start, delay = False, time.time(), 10
    while not up:
        try:
            with urllib.request.urlopen(
                serve_llm.get_web_url() + "/health"
            ) as response:
                if response.getcode() == 200:
                    up = True
        except Exception:
            if time.time() - start > test_timeout:
                break
            time.sleep(delay)

    assert up, f"Failed health check for server at {serve_llm.get_web_url()}"
    print(f"Successful health check for server at {serve_llm.get_web_url()}")
    messages = [{"role": "user", "content": "Testing! Is this thing on?"}]
    print(f"Sending a sample message to {serve_llm.get_web_url()}", *messages, sep="\n")

    headers = {
        # must match the --api-key the server was started with
        "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}",
        "Content-Type": "application/json",
    }
    payload = json.dumps({"messages": messages, "model": MODEL_NAME})
    req = urllib.request.Request(
        serve_llm.get_web_url() + "/v1/chat/completions",
        data=payload.encode("utf-8"),
        headers=headers,
        method="POST",
    )
    with urllib.request.urlopen(req) as response:
        print(json.loads(response.read().decode()))

    if twice:
        print("Sending the same message again to test caching.")
        with urllib.request.urlopen(req) as response:
            print(json.loads(response.read().decode()))
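

# Usage (a sketch; the filename is a placeholder for wherever this module lives):
#
#     modal deploy llm_server.py          # deploy the vLLM web server
#     modal run llm_server.py::test       # run the health-check + chat test above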