Viewing the execution status of Airflow tasks with a Mermaid Gantt chart
I wanted to see the execution status of Airflow tasks on a cross-DAG timeline,
so I wrote a script that visualizes it as a Mermaid Gantt chart; this post is a record of that.
The Airflow version I tested against is 2.10.5.
The script I wrote
Script:
```python
import datetime
import json
import os
import pickle
import sys

import pytz
import requests
from requests.auth import HTTPBasicAuth


# Thin wrapper that attaches fixed credentials to every request.
class AuthedSession(object):
    def __init__(self, auth):
        self.auth = auth

    def request(self, method, url, params):
        return requests.request(method, url, params=params, auth=self.auth)


def basic_authed_airflow_client(endpoint_url, user, password):
    return AirflowApiClient(endpoint_url, authed_session=AuthedSession(HTTPBasicAuth(user, password)))


class AirflowApiClient(object):
    LOOP_COUNT_MAX = 100

    def __init__(self, endpoint_url, authed_session=None):
        self.endpoint_url = endpoint_url
        self.authed_session = authed_session
        self.cache_db = {}  # GET-response cache, persisted with pickle

    def _call_api(self, method, endpoint, params=None, cache=True):
        request_url = f"{self.endpoint_url}/{endpoint}"
        params = {} if params is None else params
        # Cache key: the request URL joined with the query parameters.
        key = "&".join([request_url, *[f"{k}={v}" for k, v in params.items()]])
        resp_obj = self.cache_db.get(key, None)
        # Only GET responses are cached; anything else always hits the API.
        if method != "GET" or not cache or not resp_obj:
            resp = self.authed_session.request(method=method, url=request_url, params=params)
            if resp.status_code != 200:
                print(resp.text, file=sys.stderr)
                resp.raise_for_status()
            resp_obj = json.loads(resp.text)
            if method == "GET":
                self.cache_db[key] = resp_obj
        return resp_obj

    def _get_all(self, entry_name, fn, url_vars, params, cache):
        params = {} if params is None else params
        entries = []
        total_entries = -1
        # Page through the list endpoint with an increasing offset until all
        # total_entries have been collected (bounded by LOOP_COUNT_MAX).
        for _ in range(self.LOOP_COUNT_MAX):
            resp = fn(
                **{
                    "params": {"offset": len(entries), **params},
                    "cache": cache,
                    **url_vars,
                }
            )
            total_entries = resp.get("total_entries", 0)
            entries.extend(resp.get(entry_name))
            if len(entries) >= total_entries:
                break
        if len(entries) != total_entries:
            raise Exception(f"unexpected number of entries: got {len(entries)}, expected {total_entries}")
        return entries

    def get_version(self, cache=True):
        return self._call_api("GET", "api/v1/version", cache=cache)

    def get_dags(self, params=None, cache=True):
        return self._call_api("GET", "api/v1/dags", params, cache)

    def get_dag_runs(self, dag_id, params=None, cache=True):
        return self._call_api("GET", f"api/v1/dags/{dag_id}/dagRuns", params, cache)

    def get_task_instance(self, dag_id, dag_run_id, params, cache=True):
        return self._call_api(
            "GET",
            f"api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances",
            params,
            cache,
        )

    def get_dags_all(self, params=None, cache=True):
        return self._get_all("dags", self.get_dags, {}, params, cache)

    def get_dag_runs_all(self, dag_id, params=None, cache=True):
        return self._get_all("dag_runs", self.get_dag_runs, {"dag_id": dag_id}, params, cache)

    def get_task_instance_all(self, dag_id, dag_run_id, params=None, cache=True):
        return self._get_all(
            "task_instances",
            self.get_task_instance,
            {"dag_id": dag_id, "dag_run_id": dag_run_id},
            params,
            cache,
        )

    def load_cache(self, filename):
        if os.path.exists(filename):
            with open(filename, "rb") as fp:
                self.cache_db = pickle.load(fp)

    def save_cache(self, filename):
        with open(filename, "wb") as fp:
            pickle.dump(self.cache_db, fp)


def get_dag_ti_list(cli, range_start, range_end, skip_dags=None, filter_pools=None, cache=True):
    format_ymdhmsz = "%Y-%m-%dT%H:%M:%S%z"
    format_ymdhmsfz = "%Y-%m-%dT%H:%M:%S.%f%z"
    datetime_range = {
        "start": range_start.astimezone(pytz.timezone("UTC")),
        "end": range_end.astimezone(pytz.timezone("UTC")),
    }
    skip_dags = skip_dags if skip_dags else []
    # DAG runs that started up to 6 hours before the target range may still
    # have task instances inside it, so widen the query window accordingly.
    dag_run_range = {
        "start_date_gte": (datetime_range["start"] - datetime.timedelta(hours=6)).strftime("%Y-%m-%dT%H:%M:%S+00:00"),
        "start_date_lte": (datetime_range["end"]).strftime("%Y-%m-%dT%H:%M:%S+00:00"),
        # "end_date_gte": (datetime_range["start"]).strftime("%Y-%m-%dT%H:%M:%S+00:00"),
        # "end_date_lte": (datetime_range["end"] + datetime.timedelta(hours=6)).strftime("%Y-%m-%dT%H:%M:%S+00:00"),
    }
    dag_list = {}
    try:
        if cache:
            cli.load_cache("api_cache_db.pickle")
        for dag_id in [e["dag_id"] for e in cli.get_dags_all()]:
            if dag_id in skip_dags:
                continue
            dag_run_list = {}
            for dag_run in cli.get_dag_runs_all(dag_id, dag_run_range):
                ti_list = []
                for ti in cli.get_task_instance_all(dag_id, dag_run["dag_run_id"], {}):
                    if not ti["start_date"]:
                        continue
                    s = datetime.datetime.strptime(ti["start_date"], format_ymdhmsfz)
                    e = datetime.datetime.strptime(ti["end_date"], format_ymdhmsfz) if ti["end_date"] else None
                    # A missing end_date means the task instance is still running.
                    d = (e - s).total_seconds() if e is not None else 0
                    if filter_pools and ti["pool"] not in filter_pools:
                        continue
                    # Keep task instances that overlap the target range; check
                    # "e is None" first so None is never compared with a datetime.
                    if s < datetime_range["end"] and (e is None or e > datetime_range["start"]):
                        ti_list.append(
                            {
                                "name": ti["task_id"],
                                "start": s.astimezone(pytz.timezone("Asia/Tokyo")),
                                "duration": d,
                                "state": ti["state"],
                            }
                        )
                if len(ti_list) > 0:
                    # Label each DAG run with its data_interval_end in JST; manual
                    # runs carry fractional seconds, hence the different format.
                    if dag_run["run_type"] == "manual":
                        e = (
                            datetime.datetime.strptime(dag_run["data_interval_end"], format_ymdhmsfz)
                            .astimezone(pytz.timezone("Asia/Tokyo"))
                            .strftime("%H:%M")
                        )
                        dag_run_key = f"{e} (manual)"
                    else:
                        e = (
                            datetime.datetime.strptime(dag_run["data_interval_end"], format_ymdhmsz)
                            .astimezone(pytz.timezone("Asia/Tokyo"))
                            .strftime("%H:%M")
                        )
                        dag_run_key = f"{e}"
                    dag_run_list[dag_run_key] = sorted(ti_list, key=lambda e: e["start"])
            if len(dag_run_list) > 0:
                dag_list[dag_id] = {e[0]: e[1] for e in sorted(dag_run_list.items(), key=lambda e: e[1][0]["start"])}
        # Order DAGs by the start time of their earliest task instance, then
        # flatten to a "{dag_id}@{dag_run_key}" -> ti_list mapping.
        dag_list = {e[0]: e[1] for e in sorted(dag_list.items(), key=lambda e: list(e[1].values())[0][0]["start"])}
        dag_ti_list = {
            f"{dag_id}@{dag_run_id}": ti_list
            for dag_id, dag_runs in dag_list.items()
            for dag_run_id, ti_list in dag_runs.items()
        }
    finally:
        cli.save_cache("api_cache_db.pickle")
    return dag_ti_list


def convert_dag_ti_list_to_mermaid_gantt(dag_ti_list):
    mermaid_str = []
    mermaid_str.append("```mermaid")
    mermaid_str.append("gantt")
    mermaid_str.append(" title Task Instances")
    mermaid_str.append(" dateFormat HH:mm:ss")
    mermaid_str.append(" axisFormat %H:%M")
    for section in dag_ti_list:
        # ":" is special in Mermaid gantt syntax, so escape it; break the
        # section label onto a second line after the DAG id.
        mermaid_str.append(f" section {section.replace('@', '<br/>@').replace(':', '#colon;')}")
        for task in dag_ti_list[section]:
            short_name = task["name"].split(".")[-1]
            # Non-successful tasks get their state appended and the "crit" tag
            # so they stand out in the chart.
            state = f" ({task['state']})" if task["state"] != "success" else ""
            tag = "crit" if task["state"] != "success" else "active"
            start = task["start"].strftime("%H:%M:%S")
            mermaid_str.append(f" {short_name}{state}: {tag}, {start}, {task['duration']}s")
    mermaid_str.append("```")
    return "\n".join(mermaid_str)


if __name__ == "__main__":
    range_start = datetime.datetime.strptime("2025-05-12T14:00:00+09:00", "%Y-%m-%dT%H:%M:%S%z")
    range_end = datetime.datetime.strptime("2025-05-12T14:30:00+09:00", "%Y-%m-%dT%H:%M:%S%z")
    cli = basic_authed_airflow_client(
        "http://localhost:8080",
        "<username here>",
        "<password here>",
    )
    dag_ti_list = get_dag_ti_list(
        cli,
        range_start,
        range_end,
        cache=False,
    )
    with open("task_instances.md", "w") as fp:
        fp.write(convert_dag_ti_list_to_mermaid_gantt(dag_ti_list))
```
Example of the generated Gantt chart:
A Gantt chart like the one below is produced, which should give a picture of how
congested processing is in each time slot.
Brief explanation
The overall flow is as follows.
- Fetch the TaskInstances in the target period via the Airflow REST API (the get_dag_ti_list function; its return value is sketched just below this list)
- Emit the TaskInstances in Mermaid gantt syntax (the convert_dag_ti_list_to_mermaid_gantt function)
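For orientation, here is a minimal sketch of the intermediate data structure: get_dag_ti_list returns a dict keyed by "{dag_id}@{dag_run_key}", and each value is a list of task entries that convert_dag_ti_list_to_mermaid_gantt turns into one gantt section. The DAG id, task names, and times below are made-up illustrations, not real output.
```python
import datetime

import pytz

# Illustrative only: the shape that get_dag_ti_list returns and that
# convert_dag_ti_list_to_mermaid_gantt consumes. All values are made up.
jst = pytz.timezone("Asia/Tokyo")
dag_ti_list = {
    "example_dag@14:00": [  # key follows the "{dag_id}@{dag_run_key}" convention
        {
            "name": "extract",  # task_id; the converter keeps only the part after the last "."
            "start": jst.localize(datetime.datetime(2025, 5, 12, 14, 0, 5)),
            "duration": 42.0,  # seconds; 0 while the task is still running
            "state": "success",
        },
        {
            "name": "load",
            "start": jst.localize(datetime.datetime(2025, 5, 12, 14, 1, 0)),
            "duration": 10.5,
            "state": "failed",
        },
    ],
}
```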
Fetching the TaskInstances in the target period via the Airflow REST API
The Airflow REST API is documented here:
Airflow API (Stable) (2.10.5)
https://airflow.apache.org/docs/apache-airflow/2.10.5/stable-rest-api-ref.html
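As a minimal standalone sketch of what the client class above does, the task instances of a single DAG run can be fetched with one authenticated GET; the dag_id and dag_run_id placeholders follow the same `<... here>` convention as the script.
```python
import requests
from requests.auth import HTTPBasicAuth

# Minimal sketch: list the task instances of one DAG run via the stable REST
# API with HTTP Basic auth. "<dag_id here>" and "<dag_run_id here>" are placeholders.
resp = requests.get(
    "http://localhost:8080/api/v1/dags/<dag_id here>/dagRuns/<dag_run_id here>/taskInstances",
    params={"limit": 100, "offset": 0},
    auth=HTTPBasicAuth("<username here>", "<password here>"),
)
resp.raise_for_status()
# The list endpoints return total_entries plus a page of entries, which is
# why the script pages with an "offset" parameter in _get_all.
for ti in resp.json()["task_instances"]:
    print(ti["task_id"], ti["state"], ti["start_date"], ti["end_date"])
```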
For authentication, the script uses HTTP Basic auth.
The API authentication backend is switched with the following setting:
auth_backends | Configuration | Apache Airflow
https://airflow.apache.org/docs/apache-airflow/2.10.5/configurations-ref.html#auth-backends
Concretely, airflow.cfg is rewritten as follows:
```
[api]
# auth_backends = airflow.api.auth.backend.session
auth_backends = airflow.api.auth.backend.basic_auth
```
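Incidentally, the same setting can also be supplied without editing the file, via Airflow's standard `AIRFLOW__{SECTION}__{KEY}` environment-variable convention: set `AIRFLOW__API__AUTH_BACKENDS=airflow.api.auth.backend.basic_auth` before starting the webserver.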
For authentication in Airflow 3.0.0, see the following.
Public API | Apache Airflow
https://airflow.apache.org/docs/apache-airflow/3.0.0/security/api.html
For authentication on Google Cloud Composer, the following is a useful reference.
Access the Airflow REST API | Cloud Composer | Google Cloud
https://cloud.google.com/composer/docs/composer-3/access-airflow-api?hl=ja
Emitting the TaskInstances in Mermaid syntax
Mermaid's Gantt chart syntax is documented here:
Gantt diagrams | Mermaid
https://mermaid.js.org/syntax/gantt.html
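For illustration, the Markdown that convert_dag_ti_list_to_mermaid_gantt writes looks roughly like this, built from the made-up entries sketched earlier (successful tasks get the active tag; anything else is annotated with its state and tagged crit):
```mermaid
gantt
 title Task Instances
 dateFormat HH:mm:ss
 axisFormat %H:%M
 section example_dag<br/>@14#colon;00
 extract: active, 14:00:05, 42.0s
 load (failed): crit, 14:01:00, 10.5s
```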
That's all.