"""Authenticating reverse proxy in front of a Hugging Face Space.

Every HTTP request and WebSocket connection received by this app is relayed to
``UPSTREAM`` with ``Authorization: Bearer $HF_TOKEN`` injected, so clients can
reach the upstream Space without holding the token themselves.
"""

import asyncio
import contextlib
import inspect
import os

import httpx
import websockets
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.responses import PlainTextResponse, StreamingResponse
from starlette.background import BackgroundTask

# Base URL of the proxied Space; trailing "/" stripped so paths join cleanly.
UPSTREAM = os.getenv("UPSTREAM", "https://cjerzak-policyview.hf.space").rstrip("/")
TOKEN = os.getenv("HF_TOKEN")
if not TOKEN:
    raise RuntimeError("Set HF_TOKEN as a Secret in the Space settings.")

app = FastAPI()

# Hop-by-hop headers describe a single connection and are never relayed.
_HOP_BY_HOP = frozenset({
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailers",
    "transfer-encoding",
    "upgrade",
})

# Client request headers not forwarded: host/content-length are recomputed by
# httpx, and authorization / x-forwarded-* are replaced by the proxy's values.
_REQUEST_DROP = _HOP_BY_HOP | {
    "host",
    "content-length",
    "authorization",
    "x-forwarded-host",
    "x-forwarded-proto",
}

# Upstream response headers not returned to the client. content-length is left
# to the streaming response; set-cookie is withheld (presumably so upstream
# cookies are not set on the proxy's origin -- NOTE(review): confirm intent).
_RESPONSE_DROP = _HOP_BY_HOP | {"content-length", "set-cookie"}

_TIMEOUT = httpx.Timeout(60.0, connect=60.0)

# websockets >= 14 defaults to its new asyncio client, which renamed the
# ``extra_headers`` argument to ``additional_headers``; use whichever exists.
_WS_HEADERS_KWARG = (
    "additional_headers"
    if "additional_headers" in inspect.signature(websockets.connect).parameters
    else "extra_headers"
)


@app.get("/healthz")
def healthz():
    """Liveness probe answered locally, without contacting the upstream."""
    return PlainTextResponse("ok")


async def stream_bytes(resp, client=None):
    """Yield the upstream body exactly as received, then release it.

    ``aiter_raw`` is used instead of ``aiter_bytes`` because the upstream
    ``Content-Encoding`` header is passed through unchanged; decoding the body
    here would make that header describe bytes the client never receives.

    Args:
        resp: A streaming ``httpx.Response``.
        client: Optional ``httpx.AsyncClient`` owning ``resp``; closed together
            with the response once streaming ends or fails.
    """
    try:
        async for chunk in resp.aiter_raw():
            yield chunk
    finally:
        await resp.aclose()
        if client is not None:
            await client.aclose()


async def _close_upstream(resp, client):
    """Close the upstream response and client; no-op if already closed."""
    await resp.aclose()
    await client.aclose()


def _rewrite_location(loc: str) -> str:
    """Map an absolute redirect into UPSTREAM onto a path on this proxy.

    Only a true prefix match (followed by "/", "?", "#" or end of string) is
    rewritten, so hosts that merely start with the same text, and upstream URLs
    embedded in query strings, are left untouched. The bare upstream root maps
    to "/" rather than an empty Location.
    """
    if loc.startswith(UPSTREAM):
        rest = loc[len(UPSTREAM):]
        if not rest or rest[0] in "/?#":
            return rest if rest.startswith("/") else "/" + rest
    return loc


def _to_ws_url(url: str) -> str:
    """Convert only the scheme of an http(s) URL to ws(s)."""
    scheme, sep, rest = url.partition("://")
    return {"https": "wss", "http": "ws"}.get(scheme, scheme) + sep + rest


@app.api_route(
    "/{path:path}",
    methods=["GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
)
async def proxy(request: Request, path: str):
    """Forward an HTTP request to UPSTREAM and stream the response back.

    Returns a 502 when the upstream cannot be reached (connect/read errors).
    """
    url = f"{UPSTREAM}/{path}"
    if request.url.query:
        url += f"?{request.url.query}"

    # A list (not a dict) so repeated client headers survive forwarding.
    headers = [
        (k, v) for k, v in request.headers.items() if k.lower() not in _REQUEST_DROP
    ]
    headers += [
        ("Authorization", f"Bearer {TOKEN}"),
        ("x-forwarded-host", request.headers.get("host", "")),
        ("x-forwarded-proto", request.url.scheme),
    ]
    body = await request.body()

    # The client must outlive this function: the body is streamed after we
    # return, so it is closed by stream_bytes / the background task instead of
    # an ``async with`` block.
    client = httpx.AsyncClient(follow_redirects=False, timeout=_TIMEOUT)
    try:
        upstream = await client.send(
            client.build_request(request.method, url, headers=headers, content=body),
            stream=True,
        )
    except Exception as exc:
        await client.aclose()
        if isinstance(exc, httpx.RequestError):
            return PlainTextResponse("Upstream unavailable", status_code=502)
        raise

    out_headers = {
        k: v for k, v in upstream.headers.items() if k.lower() not in _RESPONSE_DROP
    }
    loc = upstream.headers.get("location")
    if loc:
        # Send redirects back through this proxy rather than to UPSTREAM.
        out_headers["location"] = _rewrite_location(loc)

    return StreamingResponse(
        stream_bytes(upstream, client),
        status_code=upstream.status_code,
        headers=out_headers,
        # Also covers the case where the body iterator never runs to completion
        # (e.g. the client disconnects before or during streaming).
        background=BackgroundTask(_close_upstream, upstream, client),
    )


async def _close_quietly(ws: WebSocket, code: int = 1000) -> None:
    """Best-effort close of the client socket.

    The client may already be gone, in which case Starlette or the ASGI server
    raise; there is nobody left to notify, so the error is deliberately ignored.
    """
    with contextlib.suppress(Exception):
        await ws.close(code=code)


async def _bridge(ws: WebSocket, ups) -> None:
    """Relay messages both ways until either side stops, then stop the other.

    Any exception raised by the direction that finished first is re-raised.
    """

    async def client_to_up():
        while True:
            msg = await ws.receive()
            if msg["type"] == "websocket.disconnect":
                return
            # ASGI receive events carry "text" or "bytes"; the other key may be
            # present with a None value.
            data = msg.get("text")
            if data is None:
                data = msg.get("bytes")
            if data is None:
                return
            await ups.send(data)

    async def up_to_client():
        # Iteration ends when the upstream closes normally.
        async for data in ups:
            if isinstance(data, (bytes, bytearray)):
                await ws.send_bytes(data)
            else:
                await ws.send_text(data)

    tasks = {
        asyncio.create_task(client_to_up()),
        asyncio.create_task(up_to_client()),
    }
    try:
        done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    finally:
        # Without this the surviving direction would block forever and keep
        # the upstream connection open.
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
    for task in done:
        task.result()


@app.websocket("/{path:path}")
async def ws_proxy(ws: WebSocket, path: str):
    """Bridge a client WebSocket to the same path on the upstream Space.

    The client socket is closed with 1000 when the relay ends normally and with
    1011 when connecting to or relaying with the upstream fails.
    """
    await ws.accept()
    target = f"{_to_ws_url(UPSTREAM)}/{path}"
    if ws.url.query:
        target += f"?{ws.url.query}"
    connect_kwargs = {_WS_HEADERS_KWARG: [("Authorization", f"Bearer {TOKEN}")]}

    try:
        async with websockets.connect(target, origin=UPSTREAM, **connect_kwargs) as ups:
            await _bridge(ws, ups)
    except WebSocketDisconnect:
        return
    except Exception:
        await _close_quietly(ws, code=1011)
        return
    await _close_quietly(ws)