How to Add an ONNX Model
Use an ONNX model when you want a trained inference model to run as a module alongside mechanistic modules, advancing on the same communication steps and exchanging signals with them.
Step 1: export the model
PyTorch:
import torch

# `MyClassifier` stands in for your own torch.nn.Module. It should map a
# (batch, 10) float32 tensor to (batch, 3) class scores, matching the shapes
# declared in the manifest below.
model = MyClassifier()
model.eval()  # switch to inference mode before exporting

# Example input used to trace the graph: one sample with 10 features.
dummy_input = torch.randn(1, 10)

torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=["input"],
    output_names=["probabilities"],
    # Mark the batch dimension as dynamic so the exported graph accepts the
    # `shape: [-1, 10]` / `shape: [-1, 3]` declared in the manifest; without
    # this the batch size is fixed to that of `dummy_input` (1).
    dynamic_axes={"input": {0: "batch"}, "probabilities": {0: "batch"}},
)
# Move the exported file to data/assets/model.onnx, the path used by the
# wrapper and the manifest below.

# Step 2: wrap it as a BioModule
import numpy as np
import biosimulant as biosim
class ONNXClassifier(biosim.BioModule):
    """BioModule that runs an ONNX classifier on a ``state_vector`` input.

    Each communication window feeds the most recently received
    ``state_vector`` (10 float32 values) through an ``onnxruntime`` session.
    It then publishes the first output row as ``probabilities`` (3 float32
    values) and its argmax as ``classification``.
    """

    def __init__(self, model_path: str = "data/assets/model.onnx"):
        """Load the ONNX model and cache its input/output tensor names.

        Args:
            model_path: Path to the ``.onnx`` file to load.
        """
        # Imported inside __init__, so onnxruntime is only needed once the
        # module is actually instantiated.
        import onnxruntime as ort

        self.session = ort.InferenceSession(model_path)
        # Only the first declared input and output tensors of the graph are used.
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name
        self._input_data: np.ndarray | None = None  # latest input, shape (1, n)
        self._prediction: np.ndarray | None = None  # first row of the model output
        self._emitted_at = 0.0  # end time of the last advanced window

    def inputs(self):
        """Declare the single input port: a float32 vector of length 10."""
        return {
            "state_vector": biosim.SignalSpec.array(
                dtype="float32",
                shape=(10,),
            )
        }

    def outputs(self):
        """Declare the output ports: the class index and 3 class probabilities."""
        return {
            "classification": biosim.SignalSpec.scalar(dtype="int64"),
            "probabilities": biosim.SignalSpec.array(dtype="float32", shape=(3,)),
        }

    def set_inputs(self, signals) -> None:
        """Store the latest ``state_vector`` as a (1, n) float32 batch.

        If no ``state_vector`` signal is present, the previously stored input
        is kept and reused on the next window.
        """
        sig = signals.get("state_vector")
        if sig is not None:
            # Add a leading batch dimension of 1 for onnxruntime.
            self._input_data = np.asarray(sig.value, dtype=np.float32).reshape(1, -1)

    def advance_window(self, start: float, end: float) -> None:
        """Run inference once for the window ending at ``end``.

        Outputs are timestamped with ``end``. If no input has been stored yet,
        the prediction is cleared so ``get_outputs`` emits nothing.
        """
        self._emitted_at = end
        if self._input_data is None:
            self._prediction = None
            return
        result = self.session.run([self.output_name], {self.input_name: self._input_data})
        # result[0] is the requested output tensor; [0] selects the single batch row.
        self._prediction = np.asarray(result[0][0], dtype=np.float32)

    def get_outputs(self):
        """Return the current ``classification`` and ``probabilities`` signals.

        Returns an empty dict when there is no prediction.
        """
        if self._prediction is None:
            return {}
        specs = self.outputs()  # build the spec dict once per call
        class_idx = int(np.argmax(self._prediction))
        # NOTE(review): `source` is hard-coded to "classifier", the alias used in
        # the lab config; confirm whether the kernel rewrites it when the module
        # is wired under a different alias.
        return {
            "classification": biosim.ScalarSignal(
                source="classifier",
                name="classification",
                value=class_idx,
                emitted_at=self._emitted_at,
                spec=specs["classification"],
            ),
            "probabilities": biosim.ArraySignal(
                source="classifier",
                name="probabilities",
                value=self._prediction.tolist(),
                emitted_at=self._emitted_at,
                spec=specs["probabilities"],
            ),
        }

    def snapshot(self):
        """Capture the module state as plain Python values (lists/floats/None).

        Includes the stored input, because it determines what the next
        ``advance_window`` computes.
        """
        return {
            "prediction": None if self._prediction is None else self._prediction.tolist(),
            "input": None if self._input_data is None else self._input_data.ravel().tolist(),
            "emitted_at": self._emitted_at,
        }

    def restore(self, snapshot) -> None:
        """Restore state produced by :meth:`snapshot`.

        Snapshots without an ``"input"`` key leave the current stored input
        untouched.
        """
        pred = snapshot.get("prediction")
        self._prediction = None if pred is None else np.asarray(pred, dtype=np.float32)
        if "input" in snapshot:
            inp = snapshot["input"]
            self._input_data = (
                None if inp is None else np.asarray(inp, dtype=np.float32).reshape(1, -1)
            )
        self._emitted_at = float(snapshot.get("emitted_at", 0.0))

# Step 3: declare the manifest
# NOTE(review): key nesting below was reconstructed after indentation was lost;
# confirm the placement of `communication_step` and `io` against the manifest schema.
schema_version: "2.0"
title: "Neuron State Classifier"
description: "ONNX classifier for neuron state vectors"
standard: onnx
biosim:
  entrypoint: "src.onnx_classifier:ONNXClassifier"
  init_kwargs:
    model_path: "data/assets/model.onnx"  # same default as ONNXClassifier.__init__
  communication_step: 0.001
  io:
    inputs:
      - name: state_vector
        signal_type: array
        dtype: float32
        shape: [10]
    outputs:
      - name: classification
        signal_type: scalar
        dtype: int64
      - name: probabilities
        signal_type: array
        dtype: float32
        shape: [3]
onnx:
  task: classification
  model_file: data/assets/model.onnx
  inputs:
    - name: input            # matches input_names in torch.onnx.export
      dtype: float32
      shape: [-1, 10]        # -1: dynamic batch dimension
  outputs:
    - name: probabilities    # matches output_names in torch.onnx.export
      dtype: float32
      shape: [-1, 3]
runtime:
  python_version: "3.12"
  dependencies:
    packages:
      - onnxruntime>=1.16
      - numpy>=1.24

# Step 4: compose it in a lab
models:
  - package: biosimulant/neuro-mechanistic
    version: 1.0.0
    alias: neuron
  - path: ../models/state-classifier
    alias: classifier
wiring:
  # Route the neuron's state_vector output into the classifier's state_vector input.
  - from: neuron.state_vector
    to:
      - classifier.state_vector
runtime:
  duration: 0.05            # 50 communication windows of 0.001
  communication_step: 0.001

# The current kernel treats ONNX wrappers like any other module: they advance
# inside advance_window() and exchange signals only at communication boundaries.