File size: 3,171 Bytes
33509ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os, asyncio, httpx, websockets
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.responses import StreamingResponse, PlainTextResponse

# Base URL of the upstream Space, overridable through the UPSTREAM env var.
# Any trailing "/" is removed so request paths can be joined with a single "/".
UPSTREAM = os.getenv(
    "UPSTREAM",
    "https://huggingface.co/proxy/cjerzak-policyview.hf.space",
).rstrip("/")

# Bearer token attached to every upstream call. Fail fast at import time when
# it is missing or empty instead of serving requests without it.
TOKEN = os.getenv("HF_TOKEN")
if TOKEN is None or TOKEN == "":
    raise RuntimeError("Set HF_TOKEN as a Secret in the Space settings.")

# ASGI application; the routes below are registered on it.
app = FastAPI()

@app.get("/healthz")
def healthz():
    # Liveness endpoint: answers locally with a fixed plain-text body and never
    # contacts the upstream.
    body = "ok"
    return PlainTextResponse(content=body)

async def stream_bytes(resp):
    """Yield an upstream response body chunk by chunk, then release it.

    ``aiter_raw`` is used instead of ``aiter_bytes`` so the body is relayed
    exactly as the upstream sent it: the proxy forwards the upstream
    ``Content-Encoding`` header unchanged, and ``aiter_bytes`` would
    transparently decompress gzip/deflate bodies, making that header wrong.

    Args:
        resp: an ``httpx.Response`` opened with ``stream=True``.

    Yields:
        Raw body chunks as ``bytes``.
    """
    try:
        async for chunk in resp.aiter_raw():
            yield chunk
    finally:
        # Runs on normal exhaustion and when the consumer stops early
        # (e.g. the client disconnects), so the connection is not leaked.
        await resp.aclose()

# Hop-by-hop headers (RFC 7230 section 6.1) apply to a single connection and are
# not relayed through the proxy in either direction.
_HOP_BY_HOP = frozenset({
    "connection", "keep-alive", "proxy-authenticate", "proxy-authorization",
    "te", "trailers", "transfer-encoding", "upgrade",
})


async def _relay_upstream(upstream, client):
    """Yield the upstream body undecoded, then close the response and its client.

    Raw chunks (``aiter_raw``) are relayed because the upstream
    ``Content-Encoding`` header is forwarded unchanged; ``aiter_bytes`` would
    decompress the body and make that header wrong.

    Args:
        upstream: an ``httpx.Response`` obtained via ``send(..., stream=True)``.
        client: the ``httpx.AsyncClient`` that produced it; this generator owns
            it and closes it when streaming ends or is abandoned.
    """
    try:
        async for chunk in upstream.aiter_raw():
            yield chunk
    finally:
        await upstream.aclose()
        await client.aclose()


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"])
async def proxy(request: Request, path: str):
    """Forward an HTTP request to UPSTREAM with the HF token and stream back the reply.

    The method, path, query string, headers and body are relayed; the caller's
    Authorization header is replaced with ``Bearer {TOKEN}``. Redirects that
    point into UPSTREAM are rewritten to proxy-relative paths. Hop-by-hop
    headers and Set-Cookie are stripped from the response.

    Returns:
        A StreamingResponse mirroring the upstream status, headers and body, or
        a 502 plain-text response if the upstream request could not be made.
    """
    url = f"{UPSTREAM}/{path}"
    if request.url.query:
        url += f"?{request.url.query}"

    # Drop hop-by-hop headers, the ones httpx computes itself (host,
    # content-length) and the caller's credentials, which are replaced by ours.
    skip = _HOP_BY_HOP | {"host", "content-length", "authorization"}
    headers = {k: v for k, v in request.headers.items() if k.lower() not in skip}
    headers["Authorization"] = f"Bearer {TOKEN}"
    headers["x-forwarded-host"] = request.headers.get("host", "")
    headers["x-forwarded-proto"] = request.url.scheme
    body = await request.body()

    # Deliberately not an ``async with``: StreamingResponse pulls the body after
    # this function returns, so _relay_upstream closes the client when the
    # stream finishes. httpx's ``request()`` has no ``stream`` argument;
    # streaming goes through ``build_request`` + ``send(stream=True)``.
    client = httpx.AsyncClient(follow_redirects=False, timeout=httpx.Timeout(60.0, connect=60.0))
    try:
        upstream = await client.send(
            client.build_request(request.method, url, headers=headers, content=body),
            stream=True,
        )
    except httpx.RequestError:
        await client.aclose()
        return PlainTextResponse("Bad Gateway: upstream request failed", status_code=502)
    except BaseException:
        await client.aclose()
        raise

    # content-length is dropped because the body is streamed; Set-Cookie is
    # withheld from the client.
    drop = _HOP_BY_HOP | {"content-length", "set-cookie"}
    out_headers = {k: v for k, v in upstream.headers.items() if k.lower() not in drop}

    # Rewrite absolute redirects into UPSTREAM as proxy-relative paths. Strip the
    # prefix once (str.replace would also hit copies inside a query string) and
    # only when it ends at a URL boundary; a bare prefix maps to "/".
    loc = upstream.headers.get("location")
    if loc and loc.startswith(UPSTREAM):
        rest = loc[len(UPSTREAM):]
        if rest[:1] in ("", "?", "#"):
            rest = "/" + rest
        if rest.startswith("/"):
            out_headers["location"] = rest

    return StreamingResponse(
        _relay_upstream(upstream, client),
        status_code=upstream.status_code,
        headers=out_headers,
    )

@app.websocket("/{path:path}")
async def ws_proxy(ws: WebSocket, path: str):
    """Bridge a client WebSocket to the same path on UPSTREAM.

    The client connection is accepted first, then an upstream connection is
    opened with the HF bearer token. Messages are pumped in both directions
    until one side finishes; the other pump is then cancelled so neither
    coroutine outlives the session. Finally the client socket is closed if
    still open (code 1000 normally, 1011 if the bridge failed).
    """
    import inspect
    import logging

    from fastapi.websockets import WebSocketState

    await ws.accept()
    target = f"{UPSTREAM}/{path}"
    if ws.url.query:
        target += f"?{ws.url.query}"
    # Upgrade to ws(s) for an http(s) upstream. Rewrite only the scheme, not
    # URLs that may appear inside the query string.
    for http_scheme, ws_scheme in (("https://", "wss://"), ("http://", "ws://")):
        if target.startswith(http_scheme):
            target = ws_scheme + target[len(http_scheme):]
            break
    ws_headers = [("Authorization", f"Bearer {TOKEN}")]
    # websockets >= 14 maps websockets.connect to its new asyncio client, which
    # renamed ``extra_headers`` to ``additional_headers``; support both.
    connect_params = inspect.signature(websockets.connect).parameters
    header_kw = "additional_headers" if "additional_headers" in connect_params else "extra_headers"

    async def client_to_up(ups):
        # Relay client frames upstream until the client disconnects.
        while True:
            msg = await ws.receive()
            # ASGI allows both "text" and "bytes" keys with one of them None,
            # so test the values, not key presence. A disconnect carries neither.
            data = msg.get("text")
            if data is None:
                data = msg.get("bytes")
            if data is None:
                return
            await ups.send(data)

    async def up_to_client(ups):
        # Relay upstream frames to the client; iteration ends when the upstream
        # closes normally.
        async for data in ups:
            if isinstance(data, (bytes, bytearray)):
                await ws.send_bytes(data)
            else:
                await ws.send_text(data)

    async def close_client(code):
        # Only close a socket neither side has closed yet; a second close may
        # raise in Starlette.
        if (ws.client_state == WebSocketState.CONNECTED
                and ws.application_state == WebSocketState.CONNECTED):
            try:
                await ws.close(code=code)
            except RuntimeError:
                pass

    close_code = 1000
    try:
        async with websockets.connect(target, origin=UPSTREAM, **{header_kw: ws_headers}) as ups:
            pumps = [
                asyncio.create_task(client_to_up(ups)),
                asyncio.create_task(up_to_client(ups)),
            ]
            try:
                done, _ = await asyncio.wait(pumps, return_when=asyncio.FIRST_COMPLETED)
            finally:
                # Whichever direction finished first, stop the other one.
                for task in pumps:
                    task.cancel()
                await asyncio.gather(*pumps, return_exceptions=True)
            for task in done:
                task.result()  # surface an error raised by the finished pump
    except WebSocketDisconnect:
        pass
    except Exception:
        logging.getLogger(__name__).exception("WebSocket bridge to %s failed", target)
        close_code = 1011
    finally:
        await close_client(close_code)