pytest をつかって FastAPI アプリケーションのテストコードを実装する

プログラミング

はじめに

前回の記事では、FastAPISQLModel を使って簡単な ToDoアプリを作成しました。

この記事では、ToDoアプリのコードをテストするテストコードを pytest を使って実装していきます。

以下、使用している Python は Ver 3.10 で、必要なライブラリは次の通りです。

fastapi==0.95.1
sqlmodel==0.0.8
alembic==1.10.4
pytest==7.4.0

この記事に記載しているプログラム一式は次の GitHub で利用することができます。

GitHub - Joichiro433/Blog-fastapi-todo
Contribute to Joichiro433/Blog-fastapi-todo development by creating an account on GitHub.

フォルダ構成

ソースコードのフォルダ構成は次の通りです。前回から tests/ ディレクトリが新しく追加され、ここにアプリケーションのテストコードを配置していきます。

.
├── alembic.ini
├── cruds/
│   └── task.py
├── models/
│   └── task.py
├── routers/
│   └── task.py
├── migrations/
│   ├── env.py
│   └── script.py.mako
├── tests/
│   ├── conftest.py
│   └── test_task.py
├── db.py
├── main.py
└── settings.py

テストコードの実装

conftest.py

テストコードの前準備を行うコードを conftest.py に実装していきます。pytest では、conftest.py という名前のファイルをはじめに自動的にインポートします。そこで、複数のファイルをまたいで使用する共通のフィクスチャをここに定義することで、同じディレクトリ、またはそのサブディレクトリにある全てのテストファイルから利用することができます。

今回テストコードは test_task.py の一つのファイルに全て記載しますが、例えば複数にファイルにテストコードを分ける場合には共通する処理を conftest.py にまとめておけるので便利です。

import sys
from pathlib import Path

app_dir = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(app_dir))

import pytest
from fastapi.testclient import TestClient
from sqlmodel import Session, SQLModel, create_engine
from sqlmodel.pool import StaticPool

from db import get_session
from main import app


@pytest.fixture(name="session")
def session_fixture():
    engine = create_engine(
        url="sqlite://",
        connect_args={"check_same_thread": False}, 
        poolclass=StaticPool)
    SQLModel.metadata.create_all(engine)
    with Session(engine) as session:
        yield session


@pytest.fixture(name="client")
def client_fixture(session: Session):
    def get_session_override():
        return session
    
    app.dependency_overrides[get_session] = get_session_override
    client = TestClient(app)
    yield client
    app.dependency_overrides.clear()

コード冒頭では、モジュール検索パス(sys.path)に新たなパスを追加しています。モジュール検索パスとは、Pythonがモジュールをインポートする際に参照するディレクトリのリストのことです。

import sys
from pathlib import Path

app_dir = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(app_dir))
  • Path(__file__).resolve().parent.parent: これにより現在の2つ上のディレクトリの絶対パスを取得します。__file__ は現在のファイルのパスを表す特殊な変数で、.resolve() メソッドでこのパスの絶対パスを取得し、その後.parent.parentで2つ上のディレクトリを取得します。
  • sys.path.insert(0, str(app_dir)):ここで sys.path の先頭に app_dir のパスを追加しています。sys.path はPythonのモジュール検索パスを表すリストで、このリストのパスは順に探索されます。そのため、リストの先頭に追加したパスは最初に探索されることになります。

このフィクスチャは、SQLiteのメモリ上のデータベースに対してセッションを作成し、テストケースで使用します。

@pytest.fixture(name="session")
def session_fixture():
    engine = create_engine(
        url="sqlite://",
        connect_args={"check_same_thread": False}, 
        poolclass=StaticPool)
    SQLModel.metadata.create_all(engine)
    with Session(engine) as session:
        yield session
  1. create_engine() を使用してSQLiteのメモリ上のデータベースエンジンを作成します。接続URLにsqlite://を指定することでメモリ上のSQLiteデータベースを作成します。また、check_same_thread=Falseを指定しています。これはSQLiteがデフォルトで同一スレッドからしかアクセスを許可しないという制約を回避するためのものです。
  2. SQLModel.metadata.create_all(engine) でデータベースのテーブルを作成します。これはSQLModelのモデルクラスからテーブルの定義を読み取り、存在しないテーブルを作成します。
  3. 最後に with Session(engine) as session:でセッションを作成します。このセッションは、テストケースでデータベース操作を行うために使用します。yield session でセッションをテストケースに提供します。

このフィクスチャは他のテストケースやフィクスチャから session という名前で引用することができ、データベースへの接続を提供します。

このフィクスチャは、FastAPIのテストクライアントを生成します。引数 session により、このフィクスチャは前述の session_fixture を使用します。

@pytest.fixture(name="client")
def client_fixture(session: Session):
    def get_session_override():
        return session
    
    app.dependency_overrides[get_session] = get_session_override
    client = TestClient(app)
    yield client
    app.dependency_overrides.clear()
  1. app.dependency_overrides[get_session] = get_session_override により、get_session 関数が返すべきセッションを上書きしています。FastAPIの dependency_overrides はテストなどの目的で依存性を上書きするための仕組みです。
  2. client = TestClient(app) により、FastAPIのアプリケーションに対するテストクライントを作成します。TestClient は requestsライブラリと互換性のあるAPIを持ち、HTTPリクエストを送信するための機能を提供します。そして、yield clientにより、テストケースにテストクライントを提供します。
  3. app.dependency_overrides.clear() はクリーンアップの処理で、全ての依存性の上書きをクリアしています。これはテストケース間で依存性の上書きが影響を及ぼすことを防ぐために行っています。

test_task.py

test_task.py には ToDoアプリのテストコードを記載していきます。前回実装した API は以下の通りでした。

HTTPメソッドパス説明リクエストレスポンス
GET/api/tasksタスクを全て取得するタスクの一覧
GET/api/tasks/{id}特定idのタスクを取得するタスク
POST/api/tasksタスクを新しく登録するタスクの内容登録したタスク
PATCH/api/tasks/{id}タスクを更新するタスクの状態更新したタスク
DELETE/api/tasks/{id}タスクを削除する

これらの API のテストを満足するコードを実装します。

from fastapi.testclient import TestClient
from sqlmodel import Session

from models.task import Task


def test_create_task(client: TestClient):
    resp = client.post(
        url='/tasks/',
        json={'title': 'test task', 'done': False})
    data: dict = resp.json()

    assert resp.status_code == 200
    assert data == {'title': 'test task', 'done': False}


def test_read_tasks(session: Session, client: TestClient):
    task1 = Task(title='task1', done=False)
    task2 = Task(title='task2', done=True)
    session.add(task1)
    session.add(task2)
    session.commit()

    resp = client.get(url='/tasks/')
    data: list[dict] = resp.json()
    assert resp.status_code == 200
    assert len(data) == 2
    assert data[0] == {'title': 'task1', 'done': False, 'id': 1}
    assert data[1] == {'title': 'task2', 'done': True, 'id': 2}

    resp = client.get(
        url='/tasks/', 
        params={'done': True})
    data: list[dict] = resp.json()
    assert resp.status_code == 200
    assert len(data) == 1
    assert data[0] == {'title': 'task2', 'done': True, 'id': 2}

    resp = client.get(
        url='/tasks/', 
        params={'done': False})
    data: list[dict] = resp.json()
    assert resp.status_code == 200
    assert len(data) == 1
    assert data[0] == {'title': 'task1', 'done': False, 'id': 1}
    

def test_read_task(session: Session, client: TestClient):
    task1 = Task(title='task1', done=False)
    task2 = Task(title='task2', done=True)
    session.add(task1)
    session.add(task2)
    session.commit()

    resp = client.get(url='/tasks/1')
    data: dict = resp.json()
    assert resp.status_code == 200
    assert data == {'title': 'task1', 'done': False, 'id': 1}

    resp = client.get(url='/tasks/2')
    data: dict = resp.json()
    assert resp.status_code == 200
    assert data == {'title': 'task2', 'done': True, 'id': 2}

    resp = client.get(url='/tasks/3')
    assert resp.status_code == 404


def test_update_task(session: Session, client: TestClient):
    task1 = Task(title='task1', done=False)
    session.add(task1)
    session.commit()

    resp = client.patch(
        url='/tasks/1',
        json={'title': 'new task', 'done': True})
    data: dict = resp.json()
    assert resp.status_code == 200
    assert data == {'title': 'new task', 'done': True, 'id': 1}

    resp = client.patch(
        url='/tasks/2',
        json={'title': 'new task', 'done': True})
    assert resp.status_code == 404


def test_delete_task(session: Session, client: TestClient):
    task1 = Task(title='task1', done=False)
    session.add(task1)
    session.commit()

    resp = client.delete(url='/tasks/1')
    assert resp.status_code == 200

    resp = client.get(url='/tasks/')
    data: list[dict] = resp.json()
    assert len(data) == 0

以上で テストコード の実装は完了です。

テストを実行する

実装したテストコードを実行してみましょう。以下のコマンドを実行するだけでOKです。

$ pytest tests/
## 出力

======================================================== test session starts ========================================================
platform darwin -- Python 3.10.11, pytest-7.4.0, pluggy-1.0.0
rootdir: /xxxxx/
plugins: anyio-3.6.2
collected 5 items                                                                                                                   

tests/test_task.py .....                                                                                                      [100%]

========================================================= 5 passed in 0.05s =========================================================

全てのテストが Pass しました!

まとめ

  • FastAPI のテストコードを実装した
タイトルとURLをコピーしました