I wanted to see the execution status of Airflow tasks on a timeline across DAGs,
so I wrote a script that visualizes them as a Mermaid Gantt chart. This post is a record of it.

The Airflow version I verified this against is 2.10.5.

The script I wrote

Script:

import datetime
import json
import os
import pickle
import sys

import pytz
import requests
from requests.auth import HTTPBasicAuth


class AuthedSession(object):
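    # Minimal wrapper around requests that attaches a fixed auth object to every request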
    def __init__(self, auth):
        self.auth = auth

    def request(self, method, url, params):
        return requests.request(method, url, params=params, auth=self.auth)


def basic_authed_airflow_client(endpoint_url, user, password):
    return AirflowApiClient(endpoint_url, authed_session=AuthedSession(HTTPBasicAuth(user, password)))


class AirflowApiClient(object):
    LOOP_COUNT_MAX = 100

    def __init__(self, endpoint_url, authed_session=None):
        self.endpoint_url = endpoint_url
        self.authed_session = authed_session
        self.cache_db = {}

    def _call_api(self, method, endpoint, params=None, cache=True):
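        # Issue an HTTP request to the Airflow API; GET responses are cached in self.cache_db keyed by URL and query params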
        request_url = f"{self.endpoint_url}/{endpoint}"
        params = {} if params is None else params
        key = "&".join([request_url, *[f"{k}={v}" for k, v in params.items()]])
        resp_obj = self.cache_db.get(key, None)
        if method != "GET" or not cache or not resp_obj:
            resp = self.authed_session.request(method=method, url=request_url, params=params)
            if resp.status_code != 200:
                print(resp.text, file=sys.stderr)
                resp.raise_for_status()
            resp_obj = json.loads(resp.text)
            if method == "GET":
                self.cache_db[key] = resp_obj
        return resp_obj

    def _get_all(self, entry_name, fn, url_vars, params, cache):
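        # Page through a list endpoint with the offset parameter until total_entries items have been collected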
        params = {} if params is None else params
        entries = []
        total_entries = -1
        for _ in range(self.LOOP_COUNT_MAX):
            resp = fn(
                **{
                    "params": {"offset": len(entries), **params},
                    "cache": cache,
                    **url_vars,
                }
            )
            total_entries = resp.get("total_entries", 0)
            entries.extend(resp.get(entry_name))
            if len(entries) >= total_entries:
                break
        if len(entries) != total_entries:
            raise Exception("too many entries")
        return entries

    def get_version(self, cache=True):
        return self._call_api("GET", "api/v1/version", cache=cache)

    def get_dags(self, params=None, cache=True):
        return self._call_api("GET", "api/v1/dags", params, cache)

    def get_dag_runs(self, dag_id, params=None, cache=True):
        return self._call_api("GET", f"api/v1/dags/{dag_id}/dagRuns", params, cache)

    def get_task_instance(self, dag_id, dag_run_id, params, cache=True):
        return self._call_api(
            "GET",
            f"api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances",
            params,
            cache,
        )

    def get_dags_all(self, params=None, cache=True):
        return self._get_all("dags", self.get_dags, {}, params, cache)

    def get_dag_runs_all(self, dag_id, params=None, cache=True):
        return self._get_all("dag_runs", self.get_dag_runs, {"dag_id": dag_id}, params, cache)

    def get_task_instance_all(self, dag_id, dag_run_id, params=None, cache=True):
        return self._get_all(
            "task_instances",
            self.get_task_instance,
            {"dag_id": dag_id, "dag_run_id": dag_run_id},
            params,
            cache,
        )

    def load_cache(self, filename):
        if os.path.exists(filename):
            with open(filename, "rb") as fp:
                self.cache_db = pickle.load(fp)

    def save_cache(self, filename):
        with open(filename, "wb") as fp:
            pickle.dump(self.cache_db, fp)


def get_dag_ti_list(cli, range_start, range_end, skip_dags=None, filter_pools=None, cache=True):
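    # Collect TaskInstances that overlap [range_start, range_end), grouped per DAG and DAG run,
    # with start times converted to Asia/Tokyo for display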
    format_ymdhmsz = "%Y-%m-%dT%H:%M:%S%z"
    format_ymdhmsfz = "%Y-%m-%dT%H:%M:%S.%f%z"
    datetime_range = {
        "start": range_start.astimezone(pytz.timezone("UTC")),
        "end": range_end.astimezone(pytz.timezone("UTC")),
    }
    skip_dags = skip_dags if skip_dags else []
    dag_run_range = {
        "start_date_gte": (datetime_range["start"] - datetime.timedelta(hours=6)).strftime("%Y-%m-%dT%H:%M:%S+00:00"),
        "start_date_lte": (datetime_range["end"]).strftime("%Y-%m-%dT%H:%M:%S+00:00"),
        # "end_date_gte": (datetime_range["start"]).strftime("%Y-%m-%dT%H:%M:%S+00:00"),
        # "end_date_lte": (datetime_range["end"] + datetime.timedelta(hours=6)).strftime("%Y-%m-%dT%H:%M:%S+00:00"),
    }

    dag_list = {}
    try:
        if cache:
            cli.load_cache("api_cache_db.pickle")
        for dag_id in [e["dag_id"] for e in cli.get_dags_all()]:
            if dag_id in skip_dags:
                continue
            dag_run_list = {}
            for dag_run in cli.get_dag_runs_all(dag_id, dag_run_range):
                ti_list = []
                for ti in cli.get_task_instance_all(dag_id, dag_run["dag_run_id"], {}):
                    if not ti["start_date"]:
                        continue
                    s = datetime.datetime.strptime(ti["start_date"], format_ymdhmsfz) if ti["start_date"] else None
                    e = datetime.datetime.strptime(ti["end_date"], format_ymdhmsfz) if ti["end_date"] else None
                    if e is not None:
                        d = (e - s).total_seconds()
                    else:
                        d = 0
                    if filter_pools and ti["pool"] not in filter_pools:
                        continue
                    if s < datetime_range["end"] and (e is None or e > datetime_range["start"]):
                        ti_list.append(
                            {
                                "name": ti["task_id"],
                                "start": s.astimezone(pytz.timezone("Asia/Tokyo")),
                                "duration": d,
                                "state": ti["state"],
                            }
                        )
                if len(ti_list) > 0:
                    if dag_run["run_type"] == "manual":
                        e = (
                            datetime.datetime.strptime(dag_run["data_interval_end"], format_ymdhmsfz)
                            .astimezone(pytz.timezone("Asia/Tokyo"))
                            .strftime("%H:%M")
                        )
                        dag_run_key = f"{e} (manual)"
                    else:
                        e = (
                            datetime.datetime.strptime(dag_run["data_interval_end"], format_ymdhmsz)
                            .astimezone(pytz.timezone("Asia/Tokyo"))
                            .strftime("%H:%M")
                        )
                        dag_run_key = f"{e}"
                    dag_run_list[dag_run_key] = sorted(ti_list, key=lambda e: e["start"])
            if len(dag_run_list) > 0:
                dag_list[dag_id] = {e[0]: e[1] for e in sorted(dag_run_list.items(), key=lambda e: e[1][0]["start"])}

        dag_list = {e[0]: e[1] for e in sorted(dag_list.items(), key=lambda e: list(e[1].values())[0][0]["start"])}
        dag_ti_list = {
            f"{dag_id}@{dag_run_id}": ti_list
            for dag_id, dag_runs in dag_list.items()
            for dag_run_id, ti_list in dag_runs.items()
        }
    finally:
        cli.save_cache("api_cache_db.pickle")
    return dag_ti_list


def convert_dag_ti_list_to_mermaid_gantt(dag_ti_list):
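    # Render the collected TaskInstances as a Mermaid gantt chart, one section per DAG run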
    mermaid_str = []
    mermaid_str.append("```mermaid")
    mermaid_str.append("gantt")
    mermaid_str.append("  title Task Instances")
    mermaid_str.append("  dateFormat HH:mm:ss")
    mermaid_str.append("  axisFormat %H:%M")
    for section in dag_ti_list:
        mermaid_str.append(f"  section {section.replace('@', '<br/>@').replace(':', '#colon;')}")
        for task in dag_ti_list[section]:
            short_name = task["name"].split(".")[-1]
            state = f" ({task['state']})" if task["state"] != "success" else ""
            tag = "crit" if task["state"] != "success" else "active"
            start = task["start"].strftime("%H:%M:%S")
            mermaid_str.append(f"    {short_name}{state}: {tag}, {start}, {task['duration']}s")
    mermaid_str.append("```")
    return "\n".join(mermaid_str)


if __name__ == "__main__":
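    # Target time window (JST); adjust the range and the connection settings below to your environment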
    range_start = datetime.datetime.strptime("2025-05-12T14:00:00+09:00", "%Y-%m-%dT%H:%M:%S%z")
    range_end = datetime.datetime.strptime("2025-05-12T14:30:00+09:00", "%Y-%m-%dT%H:%M:%S%z")

    cli = basic_authed_airflow_client(
        "http://localhost:8080",
        "<username here>",
        "<password here>",
    )
    dag_ti_list = get_dag_ti_list(
        cli,
        range_start,
        range_end,
        cache=False,
    )
    with open("task_instances.md", "w") as fp:
        fp.write(convert_dag_ti_list_to_mermaid_gantt(dag_ti_list))

Image of the generated Gantt chart:

A Gantt chart like the one below is generated, which should give you a picture of
how congested processing is in each time slot.

airflow_gantt01.jpg

Brief explanation

The overall flow is as follows.

  1. Fetch the TaskInstances in the target period via the Airflow REST API (the get_dag_ti_list function)
  2. Output the TaskInstances in Mermaid format (the convert_dag_ti_list_to_mermaid_gantt function)

Fetching TaskInstances in the target period via the Airflow REST API

The Airflow REST API is documented here:

Airflow API (Stable) (2.10.5)
https://airflow.apache.org/docs/apache-airflow/2.10.5/stable-rest-api-ref.html

For authentication, Basic auth is used.
The API authentication backend is switched with the following setting:

auth_backends | Configuration | Apache Airflow
https://airflow.apache.org/docs/apache-airflow/2.10.5/configurations-ref.html#auth-backends

Concretely, airflow.cfg was changed as follows:

[api]
# auth_backends = airflow.api.auth.backend.session
auth_backends = airflow.api.auth.backend.basic_auth

# For authentication on Airflow 3.0.0, see the following:
# Public API | Apache Airflow
# https://airflow.apache.org/docs/apache-airflow/3.0.0/security/api.html

# For authentication against Google Cloud Composer, the following is a useful reference:
# Access the Airflow REST API | Cloud Composer | Google Cloud
# https://cloud.google.com/composer/docs/composer-3/access-airflow-api?hl=ja
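
As a quick sanity check that Basic auth is accepted, you can call the version endpoint directly (the host and credentials below are placeholders for your environment):

import requests
from requests.auth import HTTPBasicAuth

# Placeholder host and credentials; replace with your own
resp = requests.get(
    "http://localhost:8080/api/v1/version",
    auth=HTTPBasicAuth("<username here>", "<password here>"),
)
resp.raise_for_status()
print(resp.json())  # e.g. {"version": "2.10.5", "git_version": "..."}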

Outputting the TaskInstances in Mermaid format

The Mermaid Gantt chart syntax is documented here:

Gantt diagrams | Mermaid
https://mermaid.js.org/syntax/gantt.html
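
For reference, the markdown written to task_instances.md by convert_dag_ti_list_to_mermaid_gantt looks roughly like the following (the DAG name, run time, and task names here are made up):

```mermaid
gantt
  title Task Instances
  dateFormat HH:mm:ss
  axisFormat %H:%M
  section example_dag<br/>@14#colon;30
    extract: active, 14:00:12, 35.0s
    transform: active, 14:00:50, 120.0s
    load (failed): crit, 14:03:00, 4.0s
```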

That's all.