自定义 bert 在 onnxruntime 推理错误:TypeError: run(): incompatible function arguments

本文讲述了如何处理在使用自定义BERT模型进行ONNXRuntime推理时遇到的TypeError,关键在于将PyTorchTensor转换为numpy数组以适应运行要求。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

自定义 bert 在 onnxruntime 推理错误:TypeError: run(): incompatible function arguments

推理代码

    # text embedding
    toks = self.tokenizer([text])
    if self.debug:
        print('toks', toks)

    text_embed = self.text_model_session.run(output_names=['output'], input_feed=toks)

错误提示

Traceback (most recent call last):
  File "/xx/workspace/model/test_onnx.py", line 90, in <module>
    res = inferencer.inference(text, img_path)
  File "/xx/workspace/model/test_onnx.py", line 58, in inference
    text_embed = self.text_model_session.run(output_names=['output'], input_feed=toks)
  File "/xx/miniconda3/envs/py39/lib/python3.9/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 220, in run
    return self._sess.run(output_names, input_feed, run_options)
TypeError: run(): incompatible function arguments. The following argument types are supported:
    1. (self: onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession, arg0: List[str], arg1: Dict[str, object], arg2: onnxruntime.capi.onnxruntime_pybind11_state.RunOptions) -> List[object]

Invoked with: <onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession object at 0x7f975ded1570>, ['output'], {'input_ids': array([[ 101, 3899,  102]]), 'token_type_ids': array([[0, 0, 0]]), 'attention_mask': array([[1, 1, 1]])}, None

核心错误

TypeError: run(): incompatible function arguments. The following argument types are supported:
    1. (self: onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession, arg0: List[str], arg1: Dict[str, object], arg2: onnxruntime.capi.onnxruntime_pybind11_state.RunOptions) -> List[object]

解决方法

核对参数

arg0: List[str]
arg1: Dict[str, object]

对应的参数

output_names=['output'], input_feed=toks

arg0=[‘output’] 参数类型正确
arg1=toks 表面看参数也正常,打印看看toks的每个值的类型

type(toks[‘input_ids’]) 输出为 <class ‘torch.Tensor’>, 实际需要输入类型为 <class ‘numpy.ndarray’>

修改代码

    # text embedding
    toks = self.tokenizer([text])
    if self.debug:
        print('toks', toks)
    
    text_input = {}
    text_input['input_ids'] = toks['input_ids'].numpy()
    text_input['token_type_ids'] = toks['token_type_ids'].numpy()
    text_input['attention_mask'] = toks['attention_mask'].numpy()
    text_embed = self.text_model_session.run(output_names=['output'], input_feed=text_input)

再次执行代码,正常运行,无报错!!

<think>好的,我现在需要回答用户关于如何正确为`onnxruntime.InferenceSession(path_onnx)`提供参数的问题。用户提到了之前关于`self.session.get_inputs()[0].name`的讨论,以及一些引用中的错误信息,比如类型不匹配和模型加载失败的问题。我需要结合这些信息,详细解释`InferenceSession`的参数使用方法,并确保回答结构清晰、易于理解。 首先,我应该回顾ONNX Runtime的官方文档,确认`InferenceSession`的参数选项。用户主要关心的是括号中的参数输入方式,也就是初始化`InferenceSession`时需要传递哪些参数,特别是`path_onnx`之外的参数。 根据之前的讨论,用户在使用自定义BERT模型时遇到了`TypeError: run(): incompatible function arguments`错误,这可能与输入节点的名称或数据格式不匹配有关。另一个引用提到导出ONNX模型时指定opset版本解决了问题,这说明模型版本兼容性也很重要。这些错误可能涉及到`InferenceSession`的参数设置,比如providers、session_options等,这些参数可能影响模型的加载和执行。 接下来,我需要拆解`InferenceSession`的参数,逐一解释每个参数的作用和常见用法。例如,`providers`参数用于指定执行计算的硬件后端(如CPU、CUDA),这会影响性能。另外,用户可能在加载模型时需要设置会话选项,比如线程数、优化级别等,这些都属于`SessionOptions`的配置。 同时,用户提到的导出模型时指定opset版本的问题,可能需要在使用`InferenceSession`时确保运行时支持的opset与模型兼容。虽然这主要是在模型导出时处理,但了解`InferenceSession`的参数如何影响模型加载也很重要。 还要考虑用户可能遇到的常见错误,比如路径错误、模型格式不正确、缺少必要的providers等,并提供相应的解决方法。例如,如果用户没有安装CUDA,却试图使用CUDA作为provider,会导致错误,这时候需要提示用户检查安装或回退到CPU。 最后,需要结合代码示例,展示不同参数的使用场景,比如设置providers、使用SessionOptions配置线程数等。同时,强调如何通过`get_inputs()`和`get_outputs()`方法动态获取输入输出信息,确保输入数据的正确性,避免之前的类型错误。 总结来说,回答需要覆盖以下内容: 1. `InferenceSession`的基本用法和必需参数。 2. 可选参数详解,如providers、session_options。 3. 常见错误及解决方法,结合用户之前的错误案例。 4. 代码示例展示不同参数配置。 5. 动态获取输入输出节点的重要性,避免硬编码名称。 现在需要确保信息准确,结构清晰,语言通俗易懂,符合用户要求的详细步骤和可靠解答。</think>### 一、`onnxruntime.InferenceSession` 的核心作用 **简单来说,这是加载 ONNX 模型并创建推理会话的核心接口**。类比为“启动一台预先设计好的机器(模型)”,而参数则是告诉机器如何启动(硬件加速、线程配置等)。 --- ### 二、参数详解(括号内参数输入方式) #### 1. **必需参数:模型路径** ```python import onnxruntime as ort # 基本用法:直接传入 ONNX 模型路径 session = ort.InferenceSession("model.onnx") ``` - **作用**:从本地文件加载 ONNX 模型。 - **注意**: - 路径必须正确(如文件不存在会报 `FileNotFoundError`)。 - 模型需符合 ONNX 格式标准(若导出错误,参考引用[^2]中提到的 `opset` 版本问题)。 #### 2. **可选参数:`providers`(硬件加速配置)** ```python # 明确指定使用 CPU 或 CUDA 加速 session = ort.InferenceSession( "model.onnx", providers=["CUDAExecutionProvider", "CPUExecutionProvider"] ) ``` - **功能**:指定模型运行的硬件后端,优先级按列表顺序。 - **常见值**: - `"CUDAExecutionProvider"`:NVIDIA GPU 加速(需安装 CUDA 版 ONNX Runtime)。 - `"CPUExecutionProvider"`:默认使用 CPU。 - **典型问题**: - 未安装 GPU 版却指定 CUDA → 报错 `Fail to initialize [CUDAExecutionProvider]`。 - **解决方法**:检查安装版本,或回退到 CPU。 #### 3. **可选参数:`session_options`(会话配置)** ```python # 配置会话选项(如线程数、优化级别) options = ort.SessionOptions() options.intra_op_num_threads = 4 # 设置计算线程数 options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL # 启用优化 session = ort.InferenceSession("model.onnx", session_options=options) ``` - **关键配置项**: - `intra_op_num_threads`:并行计算线程数(提升 CPU 性能)。 - `graph_optimization_level`:模型图优化级别(`ORT_ENABLE_BASIC`/`ORT_ENABLE_EXTENDED`/`ORT_ENABLE_ALL`)。 #### 4. **其他参数:`sess_config`(已弃用)** - **注意**:ONNX Runtime 新版本中已弃用 `sess_config`,改用 `providers` 和 `session_options`。 --- ### 三、实际应用场景与代码示例 #### 场景 1:基础用法(仅加载模型) ```python session = ort.InferenceSession("model.onnx") input_name = session.get_inputs()[0].name output_name = session.get_outputs()[0].name ``` - **作用**:快速加载模型,获取输入/输出节点名称。 #### 场景 2:多硬件环境适配 ```python # 优先尝试 CUDA,失败后回退到 CPU try: session = ort.InferenceSession( "model.onnx", providers=["CUDAExecutionProvider", "CPUExecutionProvider"] ) except ort.capi.onnxruntime_pybind11_state.EPFail: session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"]) ``` - **意义**:增强代码鲁棒性,适应不同部署环境。 #### 场景 3:性能调优 ```python options = ort.SessionOptions() options.enable_profiling = True # 启用性能分析 options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL # 顺序执行模式 session = ort.InferenceSession("model.onnx", session_options=options) ``` - **适用场景**:需要分析模型推理耗时或优化执行流程时。 --- ### 四、常见错误与解决方法 #### 错误 1:`TypeError: run(): incompatible function arguments`(引用[^1]) - **原因**:输入数据格式或名称不匹配。 - **解决步骤**: 1. **动态获取输入名称**: ```python input_name = session.get_inputs()[0].name # 而非硬编码名称 ``` 2. **检查数据形状和类型**: ```python print(session.get_inputs()[0].shape) # 例如 [1, 3, 224, 224] print(session.get_inputs()[0].type) # 例如 tensor(float) ``` #### 错误 2:`ONNXRuntimeError: FAIL : Load model from ...`(引用[^2]) - **原因**:模型文件损坏或版本不兼容。 - **解决步骤**: 1. 重新导出模型并指定 `opset` 版本: ```bash python -m tf2onnx.convert --opset=13 --saved-model ./dir --output model.onnx ``` 2. 检查 ONNX 文件完整性: ```python import onnx onnx.checker.check_model("model.onnx") ``` --- ### 五、最佳实践总结 1. **始终动态获取输入/输出名称**: ```python input_names = [input.name for input in session.get_inputs()] output_names = [output.name for output in session.get_outputs()] ``` 2. **优先使用 GPU 加速**: ```python providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] ``` 3. **验证输入数据格式**: - 形状、数据类型需与 `session.get_inputs()[0].shape` 和 `type` 一致。 --- ### 相关问题 1. 如何检查 ONNX 模型的输入/输出节点信息? 2. ONNX Runtime 支持哪些硬件加速后端? 3. 模型推理时如何实现批量数据处理? [^1]: ONNX Runtime 输入数据格式与节点名称匹配问题 : ONNX 模型导出与版本兼容性注意事项
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

szZack

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值