
How to save a decision tree in ONNX format for deployment?

Last Updated : 25 Nov, 2024

To save a decision tree in ONNX format for deployment, you can use the skl2onnx library, which converts scikit-learn models to the ONNX format. ONNX (Open Neural Network Exchange) is an open format for machine learning models, so a model saved as ONNX can be loaded by any compatible runtime on different platforms and from different programming languages. Let's save a decision tree model in ONNX format with a step-by-step guide:

Saving a Decision Tree in ONNX Format

The skl2onnx package provides utilities to convert scikit-learn models into the ONNX format. The conversion process involves specifying the input data types and shapes that the model expects during inference.

Steps to save a decision tree model as ONNX:

Step 1: Install Required Libraries

Ensure that you have the necessary libraries installed:

pip install scikit-learn skl2onnx onnx

Step 2: Train the Decision Tree Model in scikit-learn

For this example, we’ll train a simple decision tree classifier.

Python
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Load dataset and split into training and testing sets
data = load_iris()
X_train, X_test, y_train, y_test = train_test_split(data.data, data.target, random_state=42)

# Train a decision tree model
model = DecisionTreeClassifier()
model.fit(X_train, y_train)
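Before converting, it can help to record how the scikit-learn model behaves on the held-out data, so there is a reference to compare the exported model against later. The sketch below repeats the training step so it runs on its own; passing random_state to the classifier is an addition not in the code above, used here only to make the fitted tree reproducible between runs.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

data = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    data.data, data.target, random_state=42
)

# random_state fixes how ties between equally good splits are broken
model = DecisionTreeClassifier(random_state=42)
model.fit(X_train, y_train)

# Reference predictions and accuracy from scikit-learn itself
sk_pred = model.predict(X_test)
accuracy = model.score(X_test, y_test)
print(f"Test accuracy: {accuracy:.3f}")
print("Tree depth:", model.get_depth(), "leaves:", model.get_n_leaves())
```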

Step 3: Convert the Model to ONNX Format

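Use skl2onnx to convert the trained model to ONNX format. The initial_types parameter declares the name, element type, and shape of the model's input. Here it is a single float32 tensor named 'input' whose first dimension is None (any number of rows) and whose second dimension is the number of features in the training data. Note that FloatTensorType is float32 even though load_iris returns float64 data, so inputs should be cast to float32 when the model is run.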

Python
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType

# Declare the model input: a float32 tensor named 'input' with a variable
# number of rows (None) and one column per feature
initial_type = [('input', FloatTensorType([None, X_train.shape[1]]))]

# Convert the model
onnx_model = convert_sklearn(model, initial_types=initial_type)

Step 4: Save the Model as an ONNX File

Save the ONNX model to a file, making it ready for deployment.

Python
import onnx

# Save the model
onnx.save_model(onnx_model, "decision_tree_model.onnx")

Step 5: Verify the ONNX Model

Reload the saved file and run the ONNX checker to confirm that the model is well-formed.

Python
# Load the saved ONNX model
onnx_model = onnx.load("decision_tree_model.onnx")

# Check the model for errors
onnx.checker.check_model(onnx_model)
print("ONNX model is valid and ready for deployment.")

Output:

ONNX model is valid and ready for deployment.

Here is the complete code as a single script.

Python
# Step 1 (installing the libraries) is run once from the shell:
#   pip install scikit-learn skl2onnx onnx

# Step 2: Train the Decision Tree Model in scikit-learn
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

data = load_iris()
X_train, X_test, y_train, y_test = train_test_split(data.data, data.target, random_state=42)

# Train a decision tree model
model = DecisionTreeClassifier()
model.fit(X_train, y_train)

# Step 3: Convert the Model to ONNX Format
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType

# Declare the model input: a float32 tensor named 'input' with a variable
# number of rows (None) and one column per feature
initial_type = [('input', FloatTensorType([None, X_train.shape[1]]))]

# Convert the model
onnx_model = convert_sklearn(model, initial_types=initial_type)

# Step 4: Save the Model as an ONNX File
import onnx

# Save the model
onnx.save_model(onnx_model, "decision_tree_model.onnx")

# Step 5: Verify the ONNX Model
# Load the saved ONNX model
onnx_model = onnx.load("decision_tree_model.onnx")

# Check the model for errors
onnx.checker.check_model(onnx_model)
print("ONNX model is valid and ready for deployment.")

Saving a decision tree, or any other scikit-learn model that skl2onnx supports, in ONNX format is straightforward. The resulting .onnx file does not depend on scikit-learn, so it can be deployed to any platform that has an ONNX-compatible runtime.

