9.4 Model Inferencing

A separate Airflow DAG is dedicated to running inference tests against the CNN SageMaker deployment. It verifies the status of the SageMaker endpoint, feeds multiple sample images to the model, and retrieves the predictions returned by the endpoint. The configuration of this DAG includes the essential metadata and execution parameters.

9.4.1 Pipeline Workflow

Similar to the cnn_skin_cancer_workflow DAG, the code begins by importing the modules and libraries required for the Airflow DAG. It then sets two parameters: skin_cancer_container_image, the container image used for the Kubernetes pods, and SECRET_AWS_REGION, a Kubernetes secret holding the AWS region, which is later passed to the container as an environment variable.

import pendulum
from airflow.decorators import dag, task
from airflow.kubernetes.secret import Secret
from airflow.models import Variable

# SET PARAMETERS
skin_cancer_container_image = "seblum/cnn-skin-cancer-model:latest"  # base image for k8s pods

SECRET_AWS_REGION = Secret(
    deploy_type="env", deploy_target="AWS_REGION", secret="airflow-aws-account-information", key="AWS_REGION"
)

The Airflow DAG itself, registered under the dag_id cnn_skin_cancer_sagemaker_test_inference, is then defined, complete with its metadata, scheduling details, and associated tasks. Within this DAG definition there is a single inference task, inference_call_op, established using the @task.kubernetes decorator. It performs inference against a SageMaker endpoint for multiple images and is configured with the previously defined secret and container image.

@dag(
    dag_id="cnn_skin_cancer_sagemaker_test_inference",
    default_args={
        "owner": "seblum",
        "depends_on_past": False,
        "start_date": pendulum.datetime(2021, 1, 1, tz="Europe/Amsterdam"),
    },
    tags=["Inference test on CNN sagemaker deployment"],
    schedule_interval=None,
    max_active_runs=1,
)
def cnn_skin_cancer_sagemaker_inference_test():
    """
    Apache Airflow DAG for testing inference on a CNN SageMaker deployment.
    """
    @task.kubernetes(
        image=skin_cancer_container_image,
        task_id="inference_call_op",
        namespace="airflow",
        in_cluster=True,
        get_logs=True,
        startup_timeout_seconds=300,
        service_account_name="airflow-sa",
        secrets=[
            SECRET_AWS_REGION,
        ],
    )
    def inference_call_op():
        """
        Perform inference on a SageMaker endpoint with multiple images.
        """
        import json

        from src.inference_to_sagemaker import (
            endpoint_status,
            get_image_directory,
            preprocess_image,
            query_endpoint,
            read_imagefile,
        )

Inside the inference_call_op task, a sequence of actions takes place: the SageMaker endpoint status is checked, the sample image data is read and preprocessed, and the actual inference calls are made. Since the workflow comprises only this single step, the task function is called directly after its definition. Finally, the DAG is instantiated by invoking cnn_skin_cancer_sagemaker_inference_test() at module level.


        sagemaker_endpoint_name = "test-cnn-skin-cancer"

        image_directory = get_image_directory()
        print(f"Image directory: {image_directory}")
        filenames = ["1.jpg", "10.jpg", "1003.jpg", "1005.jpg", "1007.jpg"]

        for file in filenames:
            filepath = f"{image_directory}/{file}"
            print("[+] New Inference")
            print(f"[+] FilePath is {filepath}")

            # Check endpoint status
            print("[+] Endpoint Status")
            print(f"Application status is {endpoint_status(sagemaker_endpoint_name)}")

            image = read_imagefile(filepath)

            print("[+] Preprocess Data")
            np_image = preprocess_image(image)

            # Wrap the array in an "instances" field so the MLflow-served model can score it
            payload = json.dumps({"instances": np_image.tolist()})

            print("[+] Prediction")
            predictions = query_endpoint(app_name=sagemaker_endpoint_name, data=payload)
            print(f"Received response for {file}: {predictions}")

    inference_call_op()

cnn_skin_cancer_sagemaker_inference_test()
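
Because schedule_interval is set to None, the DAG never runs on a schedule and has to be triggered manually, for example from the Airflow UI. It can also be triggered programmatically through Airflow's stable REST API; the following is a minimal sketch, assuming the API is enabled and basic authentication is configured (host and credentials are placeholders):

import requests

AIRFLOW_HOST = "http://localhost:8080"  # placeholder, adjust to your deployment

# Trigger a new run of the inference-test DAG via the stable REST API
response = requests.post(
    f"{AIRFLOW_HOST}/api/v1/dags/cnn_skin_cancer_sagemaker_test_inference/dagRuns",
    auth=("admin", "admin"),  # placeholder credentials
    json={"conf": {}},
)
response.raise_for_status()
print(response.json()["dag_run_id"])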

9.4.2 Inference Workflow Code

Collectively, these functions provide comprehensive support for testing and interacting with an Amazon SageMaker endpoint. Their functionality encompasses tasks such as image data preparation and processing, endpoint status verification, and querying the endpoint to obtain predictions or responses.
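
The helpers below live in the src.inference_to_sagemaker module imported by the DAG above. The listings are excerpts, so for them to run standalone the module needs a set of imports roughly like the following (a sketch; the exact import list in the source module may differ):

import json
import os
from pathlib import Path

import boto3
import numpy as np
from PIL import Image
from PIL.JpegImagePlugin import JpegImageFile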

9.4.2.0.1 get_image_directory Function

This function returns the absolute path of the inference_test_images directory, resolved relative to the location of the current script.

def get_image_directory() -> str:
    """
    Get the file path for the 'inference_test_images' directory relative to the current script's location.

    Returns:
        str: The absolute file path to the 'inference_test_images' directory.
    """
    path = f"{Path(__file__).parent.parent}/inference_test_images"
    return path
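
Unlike the other helpers, this function carries no usage example in its docstring. A minimal one, assuming the directory layout described above (an inference_test_images directory one level above the script), would be:

image_dir = get_image_directory()
# e.g. /opt/app/inference_test_images, depending on where the module is located
filepath = f"{image_dir}/1.jpg"
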
9.4.2.0.2 read_imagefile Function

The read_imagefile function reads an image, either from a file path or from a file-like object containing binary data, and returns it as a PIL JpegImageFile object.

def read_imagefile(data) -> JpegImageFile:
    """
    Reads an image file and returns it as a PIL JpegImageFile object.

    Args:
        data: The file path or a file-like object containing the image data.

    Returns:
        PIL.JpegImagePlugin.JpegImageFile: A PIL JpegImageFile object representing the image.

    Example:
        # Read an image file from a file path
        image_path = "example.jpg"
        image = read_imagefile(image_path)

        # Read an image file from binary data
        # (Image.open expects a path or file-like object, so wrap raw bytes in io.BytesIO)
        import io
        with open("example.jpg", "rb") as file:
            image = read_imagefile(io.BytesIO(file.read()))
    """
    image = Image.open(data)
    return image

9.4.2.0.3 preprocess_image Function

The preprocess_image function prepares a JPEG image for the deep learning model: it converts the image to a NumPy array, scales the values to the 0 to 1 range, and reshapes the array to the input shape the model expects, (1, 224, 224, 3).

def preprocess_image(image: JpegImageFile) -> np.ndarray:
    """
    Preprocesses a JPEG image for deep learning models.

    Args:
        image (PIL.JpegImagePlugin.JpegImageFile): A PIL image object in JPEG format.

    Returns:
        np.ndarray: A NumPy array representing the preprocessed image.
                    The image is converted to a NumPy array with data type 'uint8',
                    scaled to values between 0 and 1, and reshaped to (1, 224, 224, 3).

    Example:
        # Load an image using PIL
        image = Image.open("example.jpg")

        # Preprocess the image
        preprocessed_image = preprocess_image(image)
    """
    np_image = np.array(image, dtype="uint8")
    np_image = np_image / 255.0
    np_image = np_image.reshape(1, 224, 224, 3)
    return np_image
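
Note that the reshape to (1, 224, 224, 3) only succeeds if the input is already a 224x224 RGB image; other sizes raise a ValueError. A minimal sketch of guarding against that with a PIL resize (the test images in the repository are presumably already the correct size):

from PIL import Image

image = Image.open("example.jpg")
if image.size != (224, 224):
    # Resize to the model's expected input resolution before preprocessing
    image = image.resize((224, 224))
preprocessed = preprocess_image(image)
assert preprocessed.shape == (1, 224, 224, 3)
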
9.4.2.0.4 endpoint_status Function

The endpoint_status function checks the status of an Amazon SageMaker endpoint. It takes app_name as input, the name of the SageMaker endpoint, and looks up the current status via the SageMaker DescribeEndpoint API.

def endpoint_status(app_name: str) -> str:
    """
    Checks the status of an Amazon SageMaker endpoint.

    Args:
        app_name (str): The name of the SageMaker endpoint to check.

    Returns:
        str: The current status of the SageMaker endpoint.

    Example:
        # Check the status of a SageMaker endpoint
        endpoint_name = "my-endpoint"
        status = endpoint_status(endpoint_name)
        print(f"Endpoint status: {status}")
    """
    AWS_REGION = os.getenv("AWS_REGION")
    sage_client = boto3.client("sagemaker", region_name=AWS_REGION)
    endpoint_description = sage_client.describe_endpoint(EndpointName=app_name)
    return endpoint_description["EndpointStatus"]
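
Since DescribeEndpoint also reports transitional states such as "Creating" or "Updating", a caller that needs to wait until the endpoint actually serves traffic can poll until the status reaches "InService". A minimal sketch built on endpoint_status (the timeout and polling interval are arbitrary choices):

import time

def wait_for_endpoint(app_name: str, timeout_seconds: int = 600) -> None:
    """Poll the endpoint status until it is 'InService' or the timeout expires."""
    deadline = time.time() + timeout_seconds
    while time.time() < deadline:
        status = endpoint_status(app_name)
        if status == "InService":
            return
        if status == "Failed":
            raise RuntimeError(f"Endpoint {app_name} failed to deploy")
        time.sleep(30)  # avoid hammering the DescribeEndpoint API
    raise TimeoutError(f"Endpoint {app_name} not InService after {timeout_seconds}s")
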
9.4.2.0.5 query_endpoint Function

The query_endpoint function sends JSON-encoded input data to an Amazon SageMaker endpoint and returns the decoded prediction from the endpoint's response.

def query_endpoint(app_name: str, data: str) -> dict:
    """
    Queries an Amazon SageMaker endpoint with input data and retrieves predictions.

    Args:
        app_name (str): The name of the SageMaker endpoint to query.
        data (str): Input data in JSON format to send to the endpoint.

    Returns:
        dict: The prediction or response obtained from the SageMaker endpoint.

    Example:
        # Query a SageMaker endpoint with JSON data
        endpoint_name = "my-endpoint"
        input_data = '{"feature1": 0.5, "feature2": 1.2}'
        prediction = query_endpoint(endpoint_name, input_data)
        print(f"Endpoint prediction: {prediction}")
    """
    AWS_REGION = os.getenv("AWS_REGION")
    client = boto3.client("sagemaker-runtime", region_name=AWS_REGION)
    response = client.invoke_endpoint(
        EndpointName=app_name,
        Body=data,
        ContentType="application/json",
    )

    prediction = response["Body"].read().decode("utf-8")
    prediction = json.loads(prediction)
    return prediction
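
Putting the pieces together, a single prediction round-trip looks as follows. The exact response layout depends on how the model was packaged; for an MLflow-served TensorFlow model it is typically a dict with a "predictions" key, though that is an assumption about this particular deployment:

payload = json.dumps({"instances": np_image.tolist()})  # np_image from preprocess_image
response = query_endpoint(app_name="test-cnn-skin-cancer", data=payload)
# Assumed MLflow scoring-server layout: {"predictions": [[...class scores...]]}
scores = response.get("predictions", [[]])[0]
print(f"Raw model output: {scores}")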