Implementing Auth on a Websocket with FastAPI.
File this under 'technical issues that I was searching for a blog post to solve'.
I recently was tasked with adding a Websocket implementation to a FastAPI server. FastAPI supports websockets natively, so I was optimistic that this would be an easy task, but there are a number of caveats that make this tricky.
The first is authentication. In their infinite wisdom, the browser manufacturers have deemed it imprudent to allow passing auth headers in a WebSocket initialisation. This has led to a variety of workarounds, of varying levels of hackiness.
The solution that I went with seemed like the most pragmatic - it's the one that the Kubernetes client has gone with - to abuse the Sec-WebSocket-Protocol header, which is allowed in the initialization, to pass a base64-encoded bearer token.
I.e., initializing the WebSocket with:
let ws = new WebSocket(
  "ws://" + location.host + "/ws",
  ["yourprotocol", "base64.websocket.bearer." + B64_TOKEN]
);
(We'll return to yourprotocol later.)
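One detail worth spelling out: subprotocol names have to be valid HTTP tokens, so the `=` and `/` characters of standard base64 are not allowed. B64_TOKEN should therefore be base64url-encoded with the padding stripped, which is also what the server-side decoding further down expects. A minimal Python sketch of that encoding and its inverse (the helper names `to_subprotocol` and `from_subprotocol` are made up for illustration; in the browser you would do the equivalent in JavaScript):

```python
import base64

TOKEN_PREFIX = "base64.websocket.bearer."


def to_subprotocol(token: str) -> str:
    """Encode a bearer token as a value usable in Sec-WebSocket-Protocol."""
    # The URL-safe alphabet avoids "+" and "/"; stripping "=" keeps it a valid HTTP token.
    b64 = base64.urlsafe_b64encode(token.encode("utf-8")).rstrip(b"=").decode("ascii")
    return TOKEN_PREFIX + b64


def from_subprotocol(proto: str) -> str:
    """Reverse of to_subprotocol: re-add the padding, then decode."""
    b64 = proto[len(TOKEN_PREFIX):]
    return base64.urlsafe_b64decode(b64 + "=" * (-len(b64) % 4)).decode("utf-8")
```

Round-tripping a token through both helpers is a cheap way to check that client and server agree on the encoding.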
We need to implement the server-side handling of this, and in our API, which already handles bearer tokens, the easiest way to do this is to add an ASGI middleware which rewrites this token out of the Sec-WebSocket-Protocol header into the Authorization header:
import base64

from starlette.types import ASGIApp, Scope, Receive, Send

TOKEN_PREFIX = "base64.websocket.bearer."


class WebSocketProtocolBearerMiddleware:
    def __init__(
        self,
        app: ASGIApp,
        token_prefix: str = "Bearer",
        protocol_prefix: str = TOKEN_PREFIX,
    ):
        self.app = app
        self.token_prefix = token_prefix
        self.protocol_prefix = protocol_prefix

    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        if scope["type"] == "websocket":
            headers = list(scope.get("headers", []))
            for k, v in headers:
                if k == b"sec-websocket-protocol":
                    protocols = [p.strip() for p in v.decode().split(",")]
                    for proto in protocols:
                        if proto.startswith(self.protocol_prefix):
                            # Extract and decode base64 token
                            b64token = proto[len(self.protocol_prefix):]
                            try:
                                # Re-add the padding the client stripped
                                token_bytes = base64.urlsafe_b64decode(
                                    b64token + "=" * (-len(b64token) % 4)
                                )
                                token = token_bytes.decode("utf-8")
                                headers.append(
                                    (b"authorization", f"{self.token_prefix} {token}".encode())
                                )
                            except Exception:
                                # Invalid base64, ignore
                                pass
                            break
                    break
            scope["headers"] = headers
        await self.app(scope, receive, send)
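Note that yourprotocol from the JS earlier needs to be specified in your route handler. The server has to select one of the subprotocols the client offered, otherwise browsers will generally fail the handshake. I.e.:
@app.websocket("/your-endpoint/ws")
async def your_endpoint_handler(
    websocket: WebSocket,
    your_auth=Depends(get_your_auth),
):
    await websocket.accept(subprotocol="yourprotocol")
    # ... do stuff ...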
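Assuming that this middleware runs before any auth middleware, you should be golden. (With Starlette's add_middleware, the most recently added middleware ends up outermost, so in practice that means adding this one after the auth middleware.)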
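To make the ordering concrete, here is a toy sketch using plain ASGI callables and nothing but the standard library. `FakeRewrite` and `FakeAuth` are made-up stand-ins for the token-rewriting middleware and for whatever auth middleware you already have; they only illustrate which one sees the header first.

```python
import asyncio


class FakeRewrite:
    """Stand-in for WebSocketProtocolBearerMiddleware: injects an Authorization header."""

    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        headers = list(scope.get("headers", []))
        headers.append((b"authorization", b"Bearer demo-token"))
        scope["headers"] = headers
        await self.app(scope, receive, send)


class FakeAuth:
    """Stand-in for an auth middleware: records whether it saw an Authorization header."""

    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        scope["authenticated"] = any(
            k == b"authorization" for k, _ in scope.get("headers", [])
        )
        await self.app(scope, receive, send)


def run(build_stack) -> bool:
    """Send one fake websocket scope through the stack and report what auth saw."""
    seen = {}

    async def endpoint(scope, receive, send):
        seen["authenticated"] = scope["authenticated"]

    scope = {"type": "websocket", "headers": []}
    asyncio.run(build_stack(endpoint)(scope, None, None))
    return seen["authenticated"]


rewrite_outside = run(lambda app: FakeRewrite(FakeAuth(app)))  # rewrite runs first
auth_outside = run(lambda app: FakeAuth(FakeRewrite(app)))     # auth runs first
print(rewrite_outside, auth_outside)  # True False
```

Only the first arrangement, with the rewrite wrapping the auth layer, lets the auth layer see the injected header.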
Another common way of doing auth in FastAPI is with a Dependency. In this case you need to make sure your dependencies handle either a Request or a WebSocket being passed, which is easy if you default them to None, i.e.:
async def get_auth_credentials(
    request: Request = None,
    websocket: WebSocket = None,
) -> YourAuthCredentials | None:  # type: ignore
    ...
NB: We have to type: ignore this signature, as FastAPI dependency injection introspects the types, and gets confused by Request | None or Optional[Request].
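Inside such a dependency, the two cases can be treated uniformly, since both Request and WebSocket expose a headers mapping. A rough sketch of that logic follows; `bearer_from_connection` is a hypothetical helper, and the stub object stands in for a real connection so the snippet runs without FastAPI installed.

```python
from types import SimpleNamespace
from typing import Any, Optional


def bearer_from_connection(request: Any = None, websocket: Any = None) -> Optional[str]:
    """Return the bearer token from whichever connection object was passed."""
    # With the None defaults above, the parameter that does not apply stays None.
    conn = request if request is not None else websocket
    if conn is None:
        return None
    scheme, _, token = conn.headers.get("authorization", "").partition(" ")
    if scheme.lower() != "bearer" or not token:
        return None
    return token


# Stub with a plain dict for headers (real Starlette headers are case-insensitive).
ws = SimpleNamespace(headers={"authorization": "Bearer abc123"})
print(bearer_from_connection(websocket=ws))  # abc123
```

An alternative worth checking is annotating a single parameter as Starlette's HTTPConnection, the shared base class of Request and WebSocket, which may sidestep the two-parameter signature altogether.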