AWS Step Functionsの単体テストをローカル環境で実施したい場合、
AWSのドキュメントに「AWS Step Functions Local」を使った方法が記載されており、
AWS Step Functionsから呼び出す処理(Lamda,SQSなど)をモックする方法もあります。

ステートマシンのローカルテスト | docs.aws.amazon.com
https://docs.aws.amazon.com/ja_jp/step-functions/latest/dg/sfn-local.html

モックサービス統合の使用 | ステートマシンのローカルテスト | docs.aws.amazon.com
https://docs.aws.amazon.com/ja_jp/step-functions/latest/dg/sfn-local-test-sm-exec.html

上記のドキュメントの手順に従うと、実施の度に手動でテストを動かすことになるので、
このエントリでは、繰り返しの実施をしやすいようにpytestに組み込むことを考えます。

テスト実施環境の構成

モックサービス統合では、Mock configuration fileに、
テストケース毎にモックの挙動を記載してテストを実施する事ができます。
これらのテストケースをpytestのテストとして実施することを考えます。

ディレクトリ構造

このエントリで作成するテスト実施環境は、次のディレクトリ構成をとります。

  • functions
    • (statemachine).json … テスト対象 (State machine definition)
  • tests
    • test_unittests.py … テスト実行コード
    • testcases.json … テストケース (Mock configuration fileを拡張したもの)

テスト対象「(statemachine).json」は、
「モックサービス統合の使用」に説明がある
「State machine definition」の内容を記載したjsonファイルです。

テストケース「testcases.json」は、
「モックサービス統合の使用」に説明がある
「Mock configuration file」にAssertionの情報を追記して拡張したjsonファイルです。

テスト実行コード「test_unittests.py」は、
テストケース「testcases.json」をpytestのテストとして実行するコードです。

Mock configuration fileの拡張

テストケース「testcases.json」は、
「Mock configuration file」を拡張したものです。

「Mock configuration file」ではテストケース毎にMockの挙動を指定しますが、
ここでは、テスト実行時のAssertionをおこなうための情報を追記します。
具体的には、 階層 StateMachines/(statemachine)/TestCases/(testcase)/(state)/_ex
の配下に、次の項目を指定します。

  • input … StateMachineへの入力
  • status … StateMachineの処理ステータス SUCCEEDED,FAILED
  • output … StateMachineの処理結果

テストコードでは、
各テストケース実行時にinputの入力を与えてStateMachineを開始させ、
実施終了後にstatus,outputが期待値と一致しているかを検証します。

tests/testcases.jsonの例

{
    "StateMachines": {
        "test": {
            "TestCases": {
                "case1": {
                    "LambdaState": "MockedLambdaOK",
                    "_ex": {
                        "input": {},
                        "status": "SUCCEEDED",
                        "output": {"StatusCode": 200, "body": "Hello from Lambda!"}
                    }
                },
                "case2": {
                    "LambdaState": "MockedLambdaNG",
                    "_ex": {
                        "input": {},
                        "status": "FAILED",
                        "output": null
                    }
                }
            }
        }
    },
    "MockedResponses": {
        "MockedLambdaOK": {
            "0": {
                "Return": {
                    "StatusCode": 200,
                    "Payload": {
                        "StatusCode": 200,
                        "body": "Hello from Lambda!"
                    }
                }
            }
        },
        "MockedLambdaNG": {
            "0": {
                "Throw":{
                    "Error":"Lambda.ResourceNotReadyException",
                    "Cause":"Lambda resource is not ready."
                }
            }
        }
    }
}

テスト実施環境の作成

Java/Pythonの実行環境が構築済みで、パスが通っている状態を前提とします。
参考に、実行を確認したversionを記載しておきます。

  • Java 1.8.0_402
  • Python 3.11.6

以下のコマンドのように、
作業ディレクトリの作成、Python仮想環境と依存ライブラリのインストールを行います。

mkdir stepfunctionslocal_unittests && cd $_
python -m venv venv
. venv/bin/activate
pip install pytest boto3 requests
mkdir tests functions

以下の内容のテスト実行コードを配置します。

tests/test_unittests.py

import os
import requests
import shutil
import subprocess
import boto3
import time
import pytest
import json
import tempfile

with open(os.path.abspath(os.path.join(os.path.dirname(__file__), "testcases.json"))) as fp:
    test_cases = json.loads(fp.read())


class StepFunctionsLocalService():
    def __init__(self, mock_file=None):
        self.cache_dir = os.path.join(os.environ.get("HOME"), ".cache", "aws_local", "stepfunctionlocal")
        self.jars_dir = os.path.join(self.cache_dir, "jars")
        self.args = [
            os.path.join(os.environ.get("JAVA_HOME"), "bin", "java") if os.environ.get("JAVA_HOME") else "java",
            "-jar", os.path.join(self.jars_dir, "StepFunctionsLocal.jar")
        ]
        self.client = boto3.client(
            'stepfunctions',
            endpoint_url="http://127.0.0.1:8083/",
            aws_access_key_id="dummy",
            aws_secret_access_key="dummy",
        )
        self.mock_file = mock_file

    def up(self):
        self.download()
        envs = {"SFN_MOCK_CONFIG": self.mock_file} if self.mock_file else {}
        self.proc = subprocess.Popen(self.args, env=envs)

    def down(self):
        if self.proc is None:
            return
        try:
            self.proc.terminate()
        except:  # noqa
            pass

    def wait_for_healthy(self):
        healthy = False
        for trial in range(5):
            if self.is_healthy():
                healthy = True
                break
            time.sleep(1)
        if not healthy:
            raise Exception(f"Step Functions Local service health check failed")

    def is_healthy(self):
        resp = self.client.list_state_machines()
        if resp:
            return True
        return False

    def download(self):
        url = 'https://s3.amazonaws.com/stepfunctionslocal/StepFunctionsLocal.zip'
        filename = 'StepFunctionsLocal.zip'
        os.makedirs(self.cache_dir, exist_ok=True)
        zip_file = os.path.join(self.cache_dir, filename)
        jars_dir = os.path.join(self.cache_dir, "jars")
        if not os.path.isfile(zip_file):
            zip_data = requests.get(url).content
            with open(zip_file, mode='wb') as fp:
                fp.write(zip_data)
        if not os.path.isdir(jars_dir):
            shutil.unpack_archive(zip_file, jars_dir)


@pytest.fixture(scope="session")
def svc():
    with tempfile.TemporaryDirectory() as tmp_dir:
        mock_file = os.path.join(tmp_dir, "mock.json")
        conf = test_cases.copy()
        for k in [
            (state_machine, test_case, key)
            for state_machine in conf["StateMachines"]
            for test_case in conf["StateMachines"][state_machine]["TestCases"]
            for key in conf["StateMachines"][state_machine]["TestCases"][test_case]
            if key.startswith("_")
        ]:
            del conf["StateMachines"][k[0]]["TestCases"][k[1]][k[2]]
        with open(mock_file, "w") as fp:
            fp.write(json.dumps(conf))
        svc = StepFunctionsLocalService(mock_file=mock_file)
        try:
            svc.up()
            svc.wait_for_healthy()
            yield svc
        finally:
            svc.down()


@pytest.mark.parametrize(
    "state_machine,test_case,input,expect_status,expect_output",
    [
        pytest.param(
            state_machine,
            test_case,
            test_cases["StateMachines"][state_machine]["TestCases"][test_case]["_ex"]["input"],
            test_cases["StateMachines"][state_machine]["TestCases"][test_case]["_ex"]["status"],
            test_cases["StateMachines"][state_machine]["TestCases"][test_case]["_ex"]["output"],
            id=f"{state_machine}_{test_case}"
        )
        for state_machine in test_cases["StateMachines"]
        for test_case in test_cases["StateMachines"][state_machine]["TestCases"]
    ]
)
def test_functions(svc,state_machine,test_case,input,expect_status,expect_output):
    cli = svc.client

    dirname = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "functions"))
    with open(os.path.join(dirname, f"{state_machine}.json")) as fp:
        func_body = fp.read()
    resp = cli.create_state_machine(
        name=state_machine,
        definition=func_body,
        roleArn="arn:aws:iam::000000000000:role/stepfunctions-role"
    )
    state_machine_arn = resp["stateMachineArn"]

    resp=cli.start_execution(
        stateMachineArn=state_machine_arn + "#" + test_case,
        input=json.dumps(input)
    )
    execution_arn = resp["executionArn"]

    status = 'RUNNING'
    for trial in range(10):
        resp = cli.describe_execution(executionArn=execution_arn)
        status = resp["status"]
        if status != 'RUNNING':
            break
        time.sleep(0.5)

    # dump history
    resp = cli.get_execution_history(executionArn=execution_arn)
    for e in resp["events"]:
        print("----")
        for k, v in e.items():
            print(k, v)

    resp = cli.describe_execution(executionArn=execution_arn)
    assert resp["status"] == expect_status
    assert json.loads(resp.get("output", "{}")).get("Payload") == expect_output

次の内容の、サンプルのテスト対象・テストケースを配置します。

tests/testcases.json

{
    "StateMachines": {
        "test": {
            "TestCases": {
                "case1": {
                    "LambdaState": "MockedLambdaOK",
                    "_ex": {
                        "input": {},
                        "status": "SUCCEEDED",
                        "output": {"StatusCode": 200, "body": "Hello from Lambda!"}
                    }
                },
                "case2": {
                    "LambdaState": "MockedLambdaNG",
                    "_ex": {
                        "input": {},
                        "status": "FAILED",
                        "output": null
                    }
                }
            }
        }
    },
    "MockedResponses": {
        "MockedLambdaOK": {
            "0": {
                "Return": {
                    "StatusCode": 200,
                    "Payload": {
                        "StatusCode": 200,
                        "body": "Hello from Lambda!"
                    }
                }
            }
        },
        "MockedLambdaNG": {
            "0": {
                "Throw":{
                    "Error":"Lambda.ResourceNotReadyException",
                    "Cause":"Lambda resource is not ready."
                }
            }
        }
    }
}

functions/test.json

{
  "Comment": "DummyFunction",
  "StartAt": "LambdaState",
  "States": {
    "LambdaState": {
      "Type": "Task",
      "Resource": "arn:aws:states:::lambda:invoke",
      "Parameters": {
        "Payload.$": "$",
        "FunctionName": "HelloWorldFunction"
      },
      "End": true
    }
  }
}

以下のコマンドで、テストを実行します。

pytest tests/

以上。