
Implementing Auth on a WebSocket with FastAPI.

File this under 'technical issues that I was searching for a blog post to solve'.

I was recently tasked with adding a WebSocket endpoint to a FastAPI server. FastAPI supports WebSockets natively, so I was optimistic that this would be easy, but a number of caveats make it tricky.

The first is authentication. In their infinite wisdom, the browser manufacturers have deemed it imprudent to let you pass auth headers when opening a WebSocket: the browser WebSocket API has no way to set custom headers such as Authorization on the opening handshake. This has led to a variety of workarounds, of varying levels of hackiness.

The solution I went with seemed the most pragmatic (it's the one the Kubernetes client has gone with): abuse the Sec-WebSocket-Protocol header, which the browser does let you populate through the WebSocket constructor's protocols argument, to pass the bearer token. The token is base64url-encoded so that it forms a valid subprotocol value.

I.e., initializing the WebSocket with:

let ws = new WebSocket(
    "ws://" + location.host + "/your-endpoint/ws",
    // B64_TOKEN: the bearer token, base64url-encoded with the '=' padding stripped
    ["yourprotocol", "base64.websocket.bearer." + B64_TOKEN]
);

(We'll return to yourprotocol later.)
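The token entry has to be a legal subprotocol value, which means an HTTP token: standard base64's "/" and "=" characters aren't allowed, and the browser's WebSocket constructor throws a SyntaxError if you pass them. Encoding as base64url with the padding stripped avoids that. Here is a minimal Python sketch of producing the value (handy for tests or non-browser clients). The helper name token_to_subprotocol is hypothetical; the prefix is the one used throughout this post.

```python
import base64

# Must match the prefix the server-side middleware looks for.
TOKEN_PREFIX = "base64.websocket.bearer."


def token_to_subprotocol(token: str, prefix: str = TOKEN_PREFIX) -> str:
    """Encode a bearer token as a Sec-WebSocket-Protocol entry (hypothetical helper).

    Uses the URL-safe alphabet ('-' and '_' instead of '+' and '/') and strips
    the '=' padding, so the result contains only HTTP token characters.
    """
    encoded = base64.urlsafe_b64encode(token.encode("utf-8")).decode("ascii")
    return prefix + encoded.rstrip("=")


if __name__ == "__main__":
    # Offer the application protocol first and the token-carrying entry second,
    # mirroring the JavaScript example.
    print(["yourprotocol", token_to_subprotocol("my.jwt.token")])
```

In the browser, the equivalent is btoa(token) followed by replacing "+" with "-", "/" with "_", and stripping trailing "=". Note that btoa only accepts characters in the Latin-1 range, which is fine for a JWT.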

We need to handle this on the server side. Our API already handles bearer tokens, so the easiest approach is an ASGI middleware that pulls the token out of the Sec-WebSocket-Protocol header and rewrites it into an Authorization header:

import base64
from starlette.types import ASGIApp, Receive, Scope, Send

TOKEN_PREFIX = "base64.websocket.bearer."

class WebSocketProtocolBearerMiddleware:
    """Copy a bearer token smuggled in Sec-WebSocket-Protocol into an Authorization header."""

    def __init__(
        self,
        app: ASGIApp,
        token_prefix: str = "Bearer",
        protocol_prefix: str = TOKEN_PREFIX,
    ):
        self.app = app
        self.token_prefix = token_prefix
        self.protocol_prefix = protocol_prefix

    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        if scope["type"] == "websocket":
            headers = list(scope.get("headers", []))
            for k, v in headers:
                # ASGI header names are lowercased bytes; latin-1 decoding never fails
                if k == b"sec-websocket-protocol":
                    protocols = [p.strip() for p in v.decode("latin-1").split(",")]
                    for proto in protocols:
                        if proto.startswith(self.protocol_prefix):
                            # Extract the base64url token, restoring the stripped '=' padding
                            b64token = proto[len(self.protocol_prefix):]
                            try:
                                token_bytes = base64.urlsafe_b64decode(
                                    b64token + "=" * (-len(b64token) % 4)
                                )
                                token = token_bytes.decode("utf-8")
                            except ValueError:
                                # Invalid base64 (binascii.Error) or UTF-8 (UnicodeDecodeError):
                                # leave the connection unauthenticated
                                pass
                            else:
                                headers.append(
                                    (b"authorization", f"{self.token_prefix} {token}".encode())
                                )
                            break
                    break
            scope["headers"] = headers
        await self.app(scope, receive, send)

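Note that the yourprotocol subprotocol from the JavaScript earlier needs to be selected explicitly when your route handler accepts the connection. A browser that offered subprotocols will fail the handshake if the server doesn't pick one of them, and you don't want the server echoing the token entry back. I.e.: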

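from fastapi import APIRouter, Depends, WebSocket

router = APIRouter()

@router.websocket("/your-endpoint/ws")
async def your_endpoint_handler(
    websocket: WebSocket,
    your_auth: YourAuthCredentials = Depends(get_your_auth),  # your existing auth dependency
):
    # Select the application subprotocol, never the token-carrying one
    await websocket.accept(subprotocol="yourprotocol")
    ...  # do stuff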
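Hard-coding "yourprotocol" works, but the server should only choose a subprotocol the client actually offered, and should never pick the token entry. ASGI exposes the client's offered list as websocket.scope["subprotocols"], so a small helper can make the choice safely. This is a sketch; select_subprotocol and SUPPORTED_SUBPROTOCOLS are hypothetical names.

```python
from typing import Iterable, Optional

TOKEN_PREFIX = "base64.websocket.bearer."
# Hypothetical: the application-level subprotocols this server speaks.
SUPPORTED_SUBPROTOCOLS = ("yourprotocol",)


def select_subprotocol(
    offered: Iterable[str],
    supported: Iterable[str] = SUPPORTED_SUBPROTOCOLS,
    token_prefix: str = TOKEN_PREFIX,
) -> Optional[str]:
    """Return the first offered subprotocol we support, never the token entry."""
    supported_set = set(supported)
    for proto in offered:
        if proto.startswith(token_prefix):
            continue  # the smuggled credential must not be echoed back
        if proto in supported_set:
            return proto
    return None
```

In the handler, that would look like await websocket.accept(subprotocol=select_subprotocol(websocket.scope.get("subprotocols", []))), rejecting the connection instead if it returns None.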

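Assuming this middleware runs before any auth middleware, you should be golden. Keep in mind that with app.add_middleware, the most recently added middleware is the outermost one and runs first, so add this one after your auth middleware.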

Another common way of doing auth in FastAPI is with a dependency. In that case you need to make sure your dependencies can handle being passed either a Request or a WebSocket, which is easy if you default both to None, i.e.:

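from fastapi import Request, WebSocket

async def get_auth_credentials(
    request: Request = None,  # type: ignore
    websocket: WebSocket = None,  # type: ignore
) -> YourAuthCredentials | None:
    # FastAPI injects whichever connection type the route uses; the other stays None
    ...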

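NB: We need the type: ignore because a type checker rejects a None default on a parameter typed as Request or WebSocket. We can't fix that by writing Request | None or Optional[Request] instead, because FastAPI's dependency injection introspects the annotations to decide what to inject, and gets confused by those forms.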
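Inside such a dependency, only one of the two arguments will be set, and both Request and WebSocket expose a case-insensitive headers mapping. So the body can pick whichever connection was injected and read the Authorization header the middleware added. Below is a minimal sketch of that logic as a plain helper, typed structurally so it doesn't import FastAPI; bearer_token_from and HasHeaders are hypothetical names. You would call it as bearer_token_from(request, websocket) and then validate the token however your API already does.

```python
from typing import Mapping, Optional, Protocol


class HasHeaders(Protocol):
    # Both starlette's Request and WebSocket provide a read-only `headers` mapping.
    @property
    def headers(self) -> Mapping[str, str]: ...


def bearer_token_from(
    request: Optional[HasHeaders], websocket: Optional[HasHeaders]
) -> Optional[str]:
    """Return the bearer token from whichever connection FastAPI injected."""
    conn = request if request is not None else websocket
    if conn is None:
        return None
    auth = conn.headers.get("authorization")
    if not auth:
        return None
    # Expect "Bearer <token>"; anything else is treated as unauthenticated
    scheme, _, token = auth.partition(" ")
    if scheme.lower() != "bearer" or not token:
        return None
    return token
```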