9.4 Model Inferencing
A separate Airflow DAG is dedicated to running inference tests against the CNN SageMaker deployment. It verifies the status of the SageMaker endpoint, sends multiple sample images for inference, and retrieves the resulting predictions from the endpoint. The DAG's configuration includes the necessary metadata and execution parameters.
9.4.1 Pipeline Workflow
Similar to the cnn_skin_cancer_workflow DAG, this code begins by importing the modules and libraries required for the Airflow DAG. It also configures two parameters: skin_cancer_container_image, the container image used for the Kubernetes pods, and SECRET_AWS_REGION, a secret holding the AWS region, which is later passed to the container as an environment variable.
import pendulum
from airflow.decorators import dag, task
from airflow.kubernetes.secret import Secret
from airflow.models import Variable
# SET PARAMETERS
skin_cancer_container_image = "seblum/cnn-skin-cancer-model:latest"  # base image for k8s pods

SECRET_AWS_REGION = Secret(
    deploy_type="env", deploy_target="AWS_REGION", secret="airflow-aws-account-information", key="AWS_REGION"
)
The Airflow DAG itself, named cnn_skin_cancer_sagemaker_inference_test, is then defined, complete with its metadata, scheduling details, and associated tasks. Within this DAG definition, the inference task inference_call_op is created using the @task.kubernetes decorator. This task performs inference against a SageMaker endpoint for multiple images and is configured with the previously defined secret and container image.
@dag(
    dag_id="cnn_skin_cancer_sagemaker_test_inference",
    default_args={
        "owner": "seblum",
        "depends_on_past": False,
        "start_date": pendulum.datetime(2021, 1, 1, tz="Europe/Amsterdam"),
        "tags": ["Inference test on CNN sagemaker deployment"],
    },
    schedule_interval=None,
    max_active_runs=1,
)
def cnn_skin_cancer_sagemaker_inference_test():
    """
    Apache Airflow DAG for testing inference on a CNN SageMaker deployment.
    """
    @task.kubernetes(
        image=skin_cancer_container_image,
        task_id="inference_call_op",
        namespace="airflow",
        in_cluster=True,
        get_logs=True,
        startup_timeout_seconds=300,
        service_account_name="airflow-sa",
        secrets=[
            SECRET_AWS_REGION,
        ],
    )
    def inference_call_op():
        """
        Perform inference on a SageMaker endpoint with multiple images.
        """
        import json

        from src.inference_to_sagemaker import (
            endpoint_status,
            get_image_directory,
            preprocess_image,
            query_endpoint,
            read_imagefile,
        )
Inside the inference_call_op task, a sequence of actions takes place: verifying the SageMaker endpoint status, preparing the image data, preprocessing each image, and running the actual inference. Since the Airflow workflow comprises only a single step, the function is called directly after its definition. The DAG is then executed by invoking the cnn_skin_cancer_sagemaker_inference_test() function.
= "test-cnn-skin-cancer"
sagemaker_endpoint_name
= get_image_directory()
image_directoy print(f"Image directory: {image_directoy}")
= ["1.jpg", "10.jpg", "1003.jpg", "1005.jpg", "1007.jpg"]
filenames
for file in filenames:
= f"{image_directoy}/{file}"
filepath print(f"[+] New Inference")
print(f"[+] FilePath is {filepath}")
# Check endpoint status
print("[+] Endpoint Status")
print(f"Application status is {endpoint_status(sagemaker_endpoint_name)}")
= read_imagefile(filepath)
image
print("[+] Preprocess Data")
= preprocess_image(image)
np_image
# Add instances fiels so np_image can be inferenced by MLflow model
= json.dumps({"instances": np_image.tolist()})
payload
print("[+] Prediction")
= query_endpoint(app_name=sagemaker_endpoint_name, data=payload)
predictions print(f"Received response for {file}: {predictions}")
inference_call_op()
cnn_skin_cancer_sagemaker_inference_test()
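For quick local iteration, the DAG can also be exercised with Airflow's debug facilities. The following is a minimal sketch, assuming Airflow 2.5+ where DAG.test() is available; note that the @task.kubernetes task still needs a reachable Kubernetes cluster and valid AWS credentials, so this is a development convenience rather than part of the workflow.

# Minimal sketch, not part of the original workflow: run the DAG once in-process.
# Assumes Airflow 2.5+, where DAG.test() exists.
if __name__ == "__main__":
    dag = cnn_skin_cancer_sagemaker_inference_test()  # the @dag-decorated function returns a DAG object
    dag.test()  # executes the single inference task without a scheduler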
9.4.2 Inference Workflow Code
Together, the following functions support testing and interacting with an Amazon SageMaker endpoint. They cover preparing and preprocessing image data, verifying the endpoint status, and querying the endpoint for predictions.
9.4.2.0.1 get_image_directory Function
This function is responsible for retrieving the absolute file path for the ‘inference_test_images’ directory, relative to the current script’s location.
from pathlib import Path


def get_image_directory() -> str:
    """
    Get the file path for the 'inference_test_images' directory relative to the current script's location.

    Returns:
        str: The absolute file path to the 'inference_test_images' directory.
    """
    path = f"{Path(__file__).parent.parent}/inference_test_images"
    return path
9.4.2.0.2 read_imagefile Function
The read_imagefile function reads an image, either from a file path or from an open binary file object, and returns it as a PIL JpegImageFile object.
from PIL import Image
from PIL.JpegImagePlugin import JpegImageFile


def read_imagefile(data: str) -> JpegImageFile:
    """
    Reads an image file and returns it as a PIL JpegImageFile object.

    Args:
        data (str): The file path or an open binary file object representing the image.

    Returns:
        PIL.JpegImagePlugin.JpegImageFile: A PIL JpegImageFile object representing the image.

    Example:
        # Read an image file from a file path
        image_path = "example.jpg"
        image = read_imagefile(image_path)

        # Read an image file from an open binary file object
        with open("example.jpg", "rb") as file:
            image = read_imagefile(file)
    """
    image = Image.open(data)
    return image
9.4.2.0.3 preprocess_image Function
The preprocess_image function prepares a JPEG image for the deep learning model. It converts the image to a NumPy array, scales its values into the 0 to 1 range, and reshapes it to (1, 224, 224, 3) to match the input shape the model expects.
import numpy as np
from PIL.JpegImagePlugin import JpegImageFile


def preprocess_image(image: JpegImageFile) -> np.ndarray:
    """
    Preprocesses a JPEG image for deep learning models.

    Args:
        image (PIL.JpegImagePlugin.JpegImageFile): A PIL image object in JPEG format.

    Returns:
        np.ndarray: A NumPy array representing the preprocessed image.
            The image is converted to a NumPy array with data type 'uint8',
            scaled to values between 0 and 1, and reshaped to (1, 224, 224, 3).

    Example:
        # Load an image using PIL
        image = Image.open("example.jpg")

        # Preprocess the image
        preprocessed_image = preprocess_image(image)
    """
    np_image = np.array(image, dtype="uint8")
    np_image = np_image / 255.0
    np_image = np_image.reshape(1, 224, 224, 3)
    return np_image
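As a quick sanity check, the preprocessing contract can be verified with a synthetic image. The following sketch assumes only NumPy and Pillow are installed; Image.fromarray yields a plain PIL Image rather than a JpegImageFile, which is sufficient here since preprocess_image only relies on the array interface.

import numpy as np
from PIL import Image

# Synthetic 224x224 RGB image with mid-gray pixels
dummy = Image.fromarray(np.full((224, 224, 3), 128, dtype="uint8"), mode="RGB")

batch = preprocess_image(dummy)
assert batch.shape == (1, 224, 224, 3)            # single-image batch
assert 0.0 <= batch.min() <= batch.max() <= 1.0   # values scaled into [0, 1]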
9.4.2.0.4 endpoint_status Function
The endpoint_status function checks the status of an Amazon SageMaker endpoint. It takes app_name as input, the name of the endpoint to check.
import os

import boto3


def endpoint_status(app_name: str) -> str:
    """
    Checks the status of an Amazon SageMaker endpoint.

    Args:
        app_name (str): The name of the SageMaker endpoint to check.

    Returns:
        str: The current status of the SageMaker endpoint.

    Example:
        # Check the status of a SageMaker endpoint
        endpoint_name = "my-endpoint"
        status = endpoint_status(endpoint_name)
        print(f"Endpoint status: {status}")
    """
    AWS_REGION = os.getenv("AWS_REGION")
    sage_client = boto3.client("sagemaker", region_name=AWS_REGION)
    endpoint_description = sage_client.describe_endpoint(EndpointName=app_name)
    endpoint_status = endpoint_description["EndpointStatus"]
    return endpoint_status
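endpoint_status returns a single snapshot. Right after a deployment, an endpoint typically passes through the "Creating" state before reaching "InService", so a small polling wrapper can be useful. The wait_for_endpoint helper below is a hypothetical sketch built only on endpoint_status; it is not part of the original code.

import time

def wait_for_endpoint(app_name: str, timeout_seconds: int = 900, poll_interval: int = 30) -> str:
    """Poll endpoint_status until a terminal state or the timeout is reached."""
    deadline = time.time() + timeout_seconds
    while time.time() < deadline:
        status = endpoint_status(app_name)
        if status in ("InService", "Failed"):  # terminal SageMaker endpoint states
            return status
        time.sleep(poll_interval)
    raise TimeoutError(f"Endpoint {app_name} did not become InService within {timeout_seconds}s")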
9.4.2.0.5 query_endpoint Function
The query_endpoint function queries an Amazon SageMaker endpoint with input data in JSON format and retrieves the predictions or responses the endpoint returns for that input.
import json
import os

import boto3


def query_endpoint(app_name: str, data: str) -> dict:
    """
    Queries an Amazon SageMaker endpoint with input data and retrieves predictions.

    Args:
        app_name (str): The name of the SageMaker endpoint to query.
        data (str): Input data in JSON format to send to the endpoint.

    Returns:
        dict: The prediction or response obtained from the SageMaker endpoint.

    Example:
        # Query a SageMaker endpoint with JSON data
        endpoint_name = "my-endpoint"
        input_data = '{"feature1": 0.5, "feature2": 1.2}'
        prediction = query_endpoint(endpoint_name, input_data)
        print(f"Endpoint prediction: {prediction}")
    """
    AWS_REGION = os.getenv("AWS_REGION")
    client = boto3.session.Session().client("sagemaker-runtime", AWS_REGION)
    response = client.invoke_endpoint(
        EndpointName=app_name,
        Body=data,
        ContentType="application/json",
    )
    prediction = response["Body"].read().decode("ascii")
    prediction = json.loads(prediction)
    return prediction
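Putting the helpers together, a single prediction can also be requested outside of Airflow. The sketch below assumes AWS_REGION is set in the environment, that the test-cnn-skin-cancer endpoint is deployed, and that example.jpg is a 224x224 RGB image available locally; the region value and the image name are placeholders.

import json
import os

os.environ.setdefault("AWS_REGION", "eu-central-1")  # placeholder region, adjust as needed

image = read_imagefile("example.jpg")  # hypothetical local test image
payload = json.dumps({"instances": preprocess_image(image).tolist()})
prediction = query_endpoint(app_name="test-cnn-skin-cancer", data=payload)
print(f"Endpoint prediction: {prediction}")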