궁금한게 많은 개발자 노트

[ FastAPI ] Middleware에서 request body 사용 불가 본문

Back End

[ FastAPI ] Middleware에서 request body 사용 불가

궁금한게 많은 개발자 2022. 10. 5. 10:13

현재 FastAPI를 사용하여 User API를 개발하고 있는 도중, Register시 들어오는 field들에 대한 validation기능 구현이 필요

schema에 지정한 UserCreate Model에서 @validator decorator를 사용하여 구현이 가능하지만, 여러 필드들에 대한 검증을 한번에 front-end로 전달할 수 없는 점이 있고 각 필드 validation의 우선순위를 임의로 지정할 수 없어 새로운 방안 필요

 

추가로, email의 경우 해당 schema의 vaildation을 거치기 전 의도하지 않은 fastapi의 RequestValidationError이 발생

해당 Error에 대해서만 exception handler를 구현하기도 애매한 상황

 

그래서 Custom Middleware를 구현하여, request가 들어올 때 한번에 각 필드들에 대한 validation을 진행하고,

front-end에 일괄적으로 전달해줄 수 있도록 구현

from collections import OrderedDict
import re
from typing import Any, Dict, Optional

import email_validator
from fastapi import Request, status
from starlette.middleware.base import (
    BaseHTTPMiddleware,
    RequestResponseEndpoint,
)
from starlette.responses import Response

from sbc_user_api.config import settings
from sbc_user_api.error import UserErrorResponse, UserException


class RegisterValidationMiddleware(BaseHTTPMiddleware):
    def validate_field(
        self, payload: Dict[str, Any]
    ) -> Optional[UserErrorResponse]:
        errors = OrderedDict()

        for field, value in payload.items():
            if isinstance(value, (str)) and not len(value):
                errors[field] = ["This field may not be blank."]
                continue

            if field == "username":
                if not re.compile(r"^[a-z]([a-z0-9]|[-_\.]){3,19}").fullmatch(
                    value
                ):
                    errors[field] = [
                        "The username consists of 4 or more and 20 or less "
                        "characters, starts with a lowercase letter, and must "
                        "consist of lowercase letters, numbers, and "
                        "special characters (-, _,.)."
                    ]
            elif field == "email":
                try:
                    email_validator.validate_email(value)
                except email_validator.EmailNotValidError:
                    errors[field] = ["Enter a valid email address."]
            elif field in ("first_name", "last_name"):
                if not re.compile(r"^[A-Za-z]+").fullmatch(value):
                    errors[field] = [f"{field} must consist only of letters."]

        if errors:
            return UserException(
                http_code=status.HTTP_400_BAD_REQUEST,
                message=errors,
            ).to_response()

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        if (
            request.method == "POST"
            and "origin" in request.headers
            and request.headers["origin"] in settings.cors_allow_origins
            and request.url.path == f"{settings.version_prefix}/users/"
        ):
            payload: Dict[str, Any] = await request.json()
            response = self.validate_field(payload)
            if isinstance(response, UserErrorResponse):
                return response

        return await call_next(request)

해당 코드에서는 의도한 대로 적절한 필드들이 들어오지 않았을 경우의 처리는 잘 동작하였으나,

정상 동작이 진행되는 경우 다음 call_next를 호출하고 난 후 block이 걸리는 상황이 발생

 

 

구글링을 통해 알아본 결과, fastapi에서 미들웨어 구현 시 상속받아 사용하는 BaseHTTPMiddleware에서는 http만 제공하여 request headers에 대한 접근은 자유로운 반면 body에 대한 접근은 정상적으로 동작하지 않는 문제점을 발견

 

검색해본 결과로는 미리 request의 body를 load하여 접근할 경우, call_next로 들어간 이후에 내부에서 body를 사용할때 무한 wait이 걸리거나 에러가 발생 하게 되는 것 같습니다.

 

FastAPI내부에서는 message에서 body를 얻을 때, _stream_consumed라는 변수를 True로 변경하고, body를 멤버 변수인 _body에 저장하고, 이후 접근이 필요한 경우에는 _body에 있는 값을 사용하게 설계되어 있었습니다.

async def stream(self) -> typing.AsyncGenerator[bytes, None]:
        if hasattr(self, "_body"):
            yield self._body
            yield b""
            return

        if self._stream_consumed:
            raise RuntimeError("Stream consumed")

        self._stream_consumed = True
        while True:
            message = await self._receive()
            if message["type"] == "http.request":
                body = message.get("body", b"")
                if body:
                    yield body
                if not message.get("more_body", False):
                    break
            elif message["type"] == "http.disconnect":
                self._is_disconnected = True
                raise ClientDisconnect()
        yield b""

    async def body(self) -> bytes:
        if not hasattr(self, "_body"):
            chunks = []
            async for chunk in self.stream():
                chunks.append(chunk)
            self._body = b"".join(chunks)
        return self._body

async 함수인 body()를 호출하면, stream()함수 내부의 receive()를 호출하게 되고, 해당 함수는 아래 처럼 동작한다.

async def receive(self) -> "ASGIReceiveEvent":
        if self.waiting_for_100_continue and not self.transport.is_closing():
            event = h11.InformationalResponse(
                status_code=100, headers=[], reason="Continue"
            )
            output = self.conn.send(event)
            self.transport.write(output)
            self.waiting_for_100_continue = False

        if not self.disconnected and not self.response_complete:
            self.flow.resume_reading()
            await self.message_event.wait()
            self.message_event.clear()

        message: "Union[HTTPDisconnectEvent, HTTPRequestEvent]"
        if self.disconnected or self.response_complete:
            message = {"type": "http.disconnect"}
        else:
            message = {
                "type": "http.request",
                "body": self.body,
                "more_body": self.more_body,
            }
            self.body = b""

        return

uvicorn 웹서버의 receive 함수를 살펴보면 12줄에서 asyncio.Event() 객체인 message_event.wait() 로직은 message_event.set()이 되지 않으면 계속 block되는 함수입니다.

uvicorn은 client에서 더 이상 받을 data가 없는 경우, message_event.set()을 사용 할 수 없는 분기문으로 빠지고 이미 들어온 Task는 callback(response 후처리)을 통해서 비동기적으로 마무리 됩니다.

 

 

즉, 이미 생성한 커스텀 미들웨어에서 receive를 호출후 다음 ASGIapp 계층에서 receive를 다시 호출하게되면 message_event.set()이 더 이상 발생하지 않기 때문에 message_event.wait()에서 block되어 버립니다.

그래서 FastAPI(Starlette)의 Request 객체는 한번 body를 얻고나면 receive를 재호출하지 않고 저장한 _body 멤버를 재사용하게 설계하였습니다.

 

 

그렇다면 이후 ASGIapp 계층에서 request에 저장된 _body를 사용하지 않고 왜 receive를 재 호출하게 되는지 알아보자.

 

ASGIapp(ASGI 인터페이스를 따르는 객체)는 기본적으로 scope, receive, send 파라미터를 갖고 있습니다.

Fast API (Starlette)는 최초의 ASGIapp 객체이며 self.app = ASGIapp방식으로 하위에 ASGIapp 객체를 갖고 

계층(layer) 구조를 갖습니다.

 

scope : 전반적인 메타데이터를 가지고 있는 dict 구조체입니다. (header, url path, protocal type, asgi version, type ...etc...)

receive : 웹서버쪽에서 데이터(request body)를 받는 함수

send : 웹서버로 데이터(response)를 보내는 함수

 

 

 

아래에서 router에서 endpoint request에 대한 response를 반환할 때 새로운 Request객체를 할당하여 사용하는 것을 확인할 수 있습니다. 이를 보면 위에서 middlware단에서 request객체에 저장한 _body변수를  재사용 할 수 없는 이유가 설명됩니다.

def request_response(func: typing.Callable) -> ASGIApp:
    """
    Takes a function or coroutine `func(request) -> response`,
    and returns an ASGI application.
    """
    is_coroutine = is_async_callable(func)

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        request = Request(scope, receive=receive, send=send)
        if is_coroutine:
            response = await func(request)
        else:
            response = await run_in_threadpool(func, request)
        await response(scope, receive, send)

    return app

 

 

middleware단에서는 request의 body를 건드리지 않고, router단에서 validation을 확인할 수 있는 방안을 고안하려 합니다.

아래는 fastapi에서 middleware를 통과하여 request에 대한 response를 출력해주는 과정입니다.

FAST API app flow
 
 
uvicorn .run (app)
 
-> Fast API(Starlette).__call__(scope, receive, send)
-> middleware_stack(scope, receive, send)
-> app = Middleware(app=app) for app in middleware_list   # 최초 app은 APIRouter(Router)
-> app.__call__(scope, receive, send) # 미들웨어 하나씩 실행
	-> user_middleware(BaseHTTPMiddleware)가 있는 경우, 
	-> user_middleware.__call__(scope, receive, send)
	-> response = user_middleware(self).call_next(request)
	-> user_middleware의 call_next()는 send = queue.get을 통해 처리, 여기 미들웨어에서도 app을 계속 호출
 
-> 마지막은 미들웨어 스택의 app = APIRouter.__call__(scope, receive, send)
-> APIroute(route).handle(scope, receive, send)    for route in APIRouter(self).routes if route.match(path)
 
# 이건 Starlette.Route.handle
-> route(self).handle(scope, receive, send)-> route(self).app(scope, receive, send) # route(self).app = request_response(self.endpoint)
 
# Fast API의 APIRoute.handle
->  handle(scope, receive, send)-> self.app = request_response(self.get_route_handler())
get_route_handler에서 endpoint의 파라미터를 분석하여  app(request) 형태로 wrapping 함
즉 Fast API endpoint 함수의 파라미터들을 request하나만으로 call할수 있는 형태로 변경 Starlette와 호환되도록
 
 
-> request_response(func)
-> response = func(request)     # Fast API의 func = get_route_handler()
-> await response(scope, receive, send)
 
-> response.__call__(scope, receive, send)
-> 최종 웹서버로 보내는 send 실행 
	## user_middleware 사용한 경우
	-> user_middleware.call_next(request)
	-> user_middleware는 하위의 app(APIRoute)의 send는 Queue에 넣고
	-> queue.get으로 얻은 body(content)를 generator Iter로 response 생성
        -> return response = StreamingResponse(…, content=body_stream()) # call_next return!!
	-> await response(scope,receive, send)
        -> Send! (상위 미들웨어에서 물려받은 send)

'Back End' 카테고리의 다른 글

[ Spring ] JPA란?  (0) 2022.10.24
[ Spring ] Java Config, Configuration Annotation  (0) 2022.10.24
[ Python ] FastAPI - Depends  (0) 2022.06.03
SQLalchemy  (0) 2022.05.26
[ Python ] pydantic  (0) 2022.05.24
Comments