How To › Library › Add an ONNX Model

How to Add an ONNX Model

Use ONNX when you want a trained inference model to run as a module in the same simulation as mechanistic modules, exchanging signals with them on the same communication-step schedule.

Step 1: export the model

PyTorch:

import torch

model = MyClassifier()  # your trained torch.nn.Module
model.eval()  # export inference behaviour (e.g. dropout / batch-norm in eval mode)

# The dummy input only fixes the traced shapes: (batch, 10 features).
dummy_input = torch.randn(1, 10)

torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",  # place this file at data/assets/model.onnx inside the model package
    input_names=["input"],
    output_names=["probabilities"],
    # Mark dimension 0 as dynamic so the exported graph has the variable batch
    # size declared in the manifest ([-1, 10] in, [-1, 3] out); without this
    # the batch dimension is baked into the graph as 1.
    dynamic_axes={"input": {0: "batch"}, "probabilities": {0: "batch"}},
)

Step 2: wrap it as a BioModule

import numpy as np
import biosimulant as biosim
 
 
class ONNXClassifier(biosim.BioModule):
    """Run an ONNX classifier as a BioModule.

    In each communication window, the most recently received ``state_vector``
    is fed through an ONNX Runtime session. The first graph output is published
    as ``probabilities`` and its argmax as ``classification``.

    Args:
        model_path: Path of the ``.onnx`` file handed to ``InferenceSession``.
    """

    def __init__(self, model_path: str = "data/assets/model.onnx"):
        # Imported here rather than at module level, so importing this module
        # does not itself require onnxruntime.
        import onnxruntime as ort

        self.session = ort.InferenceSession(model_path)
        # The wrapper always feeds the first graph input and reads the first output.
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name
        self._input_data = None  # latest input as a (1, n) float32 batch, or None
        self._prediction = None  # 1-D float32 output of the last window, or None
        self._emitted_at = 0.0  # end time of the last advanced window

    def inputs(self):
        """Declare the single input port: a float32 array of shape (10,)."""
        return {
            "state_vector": biosim.SignalSpec.array(
                dtype="float32",
                shape=(10,),
            )
        }

    def outputs(self):
        """Declare the int64 class index and the float32 (3,) output array."""
        return {
            "classification": biosim.SignalSpec.scalar(dtype="int64"),
            "probabilities": biosim.SignalSpec.array(dtype="float32", shape=(3,)),
        }

    def set_inputs(self, signals):
        """Cache the latest ``state_vector`` value for the next window.

        If the signal is absent, the previously cached input is kept.
        """
        sig = signals.get("state_vector")
        if sig is not None:
            # Add a leading batch axis: the exported graph takes (batch, features).
            self._input_data = np.asarray(sig.value, dtype=np.float32).reshape(1, -1)

    def advance_window(self, start: float, end: float) -> None:
        """Run inference on the cached input and stamp results with ``end``.

        With no input received yet, the prediction is cleared and
        ``get_outputs`` will emit nothing.
        """
        self._emitted_at = end
        if self._input_data is None:
            self._prediction = None
            return
        result = self.session.run([self.output_name], {self.input_name: self._input_data})
        # result[0] is the requested output for the batch; [0] selects the only row.
        self._prediction = np.asarray(result[0][0], dtype=np.float32)

    def get_outputs(self):
        """Return the class index and output array, or ``{}`` before any prediction."""
        if self._prediction is None:
            return {}
        specs = self.outputs()  # build the spec dict once per call
        class_idx = int(np.argmax(self._prediction))
        return {
            "classification": biosim.ScalarSignal(
                source="classifier",
                name="classification",
                value=class_idx,
                emitted_at=self._emitted_at,
                spec=specs["classification"],
            ),
            "probabilities": biosim.ArraySignal(
                source="classifier",
                name="probabilities",
                value=self._prediction.tolist(),
                emitted_at=self._emitted_at,
                spec=specs["probabilities"],
            ),
        }

    def snapshot(self):
        """Capture prediction, cached input and emit time as plain Python values."""
        return {
            "prediction": None if self._prediction is None else self._prediction.tolist(),
            # Saved so a restored module does not run on input received after the
            # checkpoint was taken.
            "input_data": None if self._input_data is None else self._input_data.ravel().tolist(),
            "emitted_at": self._emitted_at,
        }

    def restore(self, snapshot):
        """Restore state produced by ``snapshot``.

        Snapshots without an ``input_data`` entry (older format) leave the
        currently cached input untouched.
        """
        pred = snapshot.get("prediction")
        self._prediction = None if pred is None else np.asarray(pred, dtype=np.float32)
        if "input_data" in snapshot:
            data = snapshot["input_data"]
            self._input_data = (
                None if data is None else np.asarray(data, dtype=np.float32).reshape(1, -1)
            )
        self._emitted_at = float(snapshot.get("emitted_at", 0.0))

Step 3: declare the manifest

# Model package manifest for the ONNX classifier wrapper.
schema_version: "2.0"
title: "Neuron State Classifier"
description: "ONNX classifier for neuron state vectors"
standard: onnx
 
biosim:
  # Import target of the wrapper class, presumably "<module path>:<class name>"
  # (here ONNXClassifier in src/onnx_classifier.py) — confirm against the loader.
  entrypoint: "src.onnx_classifier:ONNXClassifier"
  # Presumably forwarded as keyword arguments to ONNXClassifier.__init__.
  # NOTE(review): confirm whether this relative path is resolved against the
  # package root or against the process working directory.
  init_kwargs:
    model_path: "data/assets/model.onnx"
  # Interval between signal exchanges; same value as runtime.communication_step
  # in the Step 4 lab.
  communication_step: 0.001
 
# Kernel-facing signal ports. These mirror ONNXClassifier.inputs() and
# ONNXClassifier.outputs() (names, signal types, dtypes, shapes).
io:
  inputs:
    - name: state_vector
      signal_type: array
      dtype: float32
      shape: [10]
  outputs:
    - name: classification
      signal_type: scalar
      dtype: int64
    - name: probabilities
      signal_type: array
      dtype: float32
      shape: [3]
 
# Tensor-level description of the ONNX graph itself. The tensor names match the
# input_names / output_names passed to torch.onnx.export in Step 1.
onnx:
  task: classification
  model_file: data/assets/model.onnx
  # NOTE(review): -1 presumably marks a variable batch dimension; an export
  # without dynamic_axes fixes that dimension to 1 — confirm they agree.
  inputs:
    - name: input
      dtype: float32
      shape: [-1, 10]
  outputs:
    - name: probabilities
      dtype: float32
      shape: [-1, 3]
 
runtime:
  python_version: "3.12"
  # onnxruntime provides InferenceSession; numpy is used for array handling
  # in the wrapper.
  dependencies:
    packages:
      - onnxruntime>=1.16
      - numpy>=1.24

Step 4: compose it in a lab

# Lab composition: a published mechanistic neuron model feeding the local
# ONNX classifier package.
models:
  - package: biosimulant/neuro-mechanistic
    version: 1.0.0
    alias: neuron
 
  # Local package holding the Step 3 manifest. NOTE(review): the relative path
  # is presumably resolved from this lab file's directory — confirm.
  # The alias matches the source="classifier" hard-coded in
  # ONNXClassifier.get_outputs().
  - path: ../models/state-classifier
    alias: classifier
 
# Route the neuron's state_vector output into the classifier's state_vector input.
# NOTE(review): assumes the neuron package emits a float32 array of length 10,
# as the classifier's input spec requires — confirm.
wiring:
  - from: neuron.state_vector
    to:
      - classifier.state_vector
 
# 0.05 / 0.001 = 50 communication steps (units presumably seconds).
runtime:
  duration: 0.05
  communication_step: 0.001

The kernel treats ONNX wrappers like any other module: inference runs inside `advance_window()`, and signals are exchanged only at communication-step boundaries.

Next steps