keras_tqdm installation failure
`keras-tqdm` is a third-party library that adds progress bars to the Keras training loop. It wires Keras callbacks into the `tqdm` library to give intuitive training-progress tracking.
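For context, the library is used through its callback classes. A minimal sketch with a tiny throwaway model, just to show where the callback plugs in (note that `keras-tqdm` was written against the old standalone `keras` package, so pairing it with a modern `tensorflow.keras` may itself fail; that ties into the version-conflict point below):
```python
import numpy as np
from tensorflow import keras
from keras_tqdm import TQDMCallback  # use TQDMNotebookCallback in Jupyter

# Tiny stand-in model and data to demonstrate the callback
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")
x, y = np.random.rand(64, 4), np.random.rand(64, 1)

# verbose=0 silences Keras's built-in bar; keras-tqdm draws its own
model.fit(x, y, epochs=3, verbose=0, callbacks=[TQDMCallback()])
```
If the installation itself fails, the usual causes are the following: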
1. **Missing dependencies**: make sure the required Python packages are already installed: `numpy`, `tensorflow` (or `keras`), and `tqdm` itself. If any of them is missing, the `keras-tqdm` install will fail.
2. **Version conflicts**: incompatible versions of Keras and `tqdm` can also cause trouble. Check both versions in your `pip` package list and confirm they can work together.
3. **Permission problems**: without write access to the installation directory, the install cannot complete. Try installing with `sudo` for administrator rights.
4. **Network problems**: an unstable connection during download can abort the install; switch to a stable network and retry.
5. **Leftovers from a failed attempt**: an earlier failed attempt may leave residual files behind. Run `pip uninstall keras-tqdm` to remove any partial install first, then reinstall (see the command sequence after this list).
6. **Package index problems**: if your package manager's index is unreliable, switch to a dependable source such as Anaconda's `conda-forge` channel or the official PyPI index.
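A minimal command sequence for the checks above (standard `pip` commands; the index URL is only an example):
```bash
# 1. Verify the prerequisites are present and note their versions
pip show tensorflow tqdm

# 5. Remove any half-installed copy, then reinstall
pip uninstall -y keras-tqdm
pip install keras-tqdm

# 6. If the configured index is flaky, point at a known-good one
pip install keras-tqdm -i https://pypi.org/simple
```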
Related question
```python
import tensorflow as tf
import numpy as np
import cv2
import os
import json
from tqdm import tqdm


class ObjectRecognitionDeployer:
    def __init__(self, model_path, class_labels):
        """
        Initialize the deployer.
        :param model_path: path to the model file (Keras or TFLite)
        :param class_labels: list of class labels
        """
        self.class_labels = class_labels
        self.model_path = model_path
        self.interpreter = None
        self.input_details = None
        self.output_details = None

        # Load according to model type
        if model_path.endswith('.tflite'):
            self.load_tflite_model(model_path)
        else:
            self.model = tf.keras.models.load_model(model_path)
            self.input_shape = self.model.input_shape[1:3]

    def load_tflite_model(self, model_path):
        """Load and configure a TFLite model."""
        self.interpreter = tf.lite.Interpreter(model_path=model_path)
        self.interpreter.allocate_tensors()
        # Input/output details
        self.input_details = self.interpreter.get_input_details()
        self.output_details = self.interpreter.get_output_details()
        # Remember the input shape
        self.input_shape = tuple(self.input_details[0]['shape'][1:3])
        # Safely print the model metadata
        self.print_model_metadata(model_path)

    def print_model_metadata(self, model_path):
        """Safely print TFLite model metadata."""
        try:
            from tflite_support import metadata
            displayer = metadata.MetadataDisplayer.with_model_file(model_path)
            print("--- Model metadata ---")
            print(displayer.get_metadata_json())
            print("--- Associated files ---")
            print(displayer.get_packed_associated_file_list())
        except (ImportError, ValueError) as e:
            print(f"Warning: could not read model metadata - {str(e)}")
            print("Falling back to input/output details:")
            print(f"Inputs: {self.input_details}")
            print(f"Outputs: {self.output_details}")

    def preprocess_image(self, image, input_size, input_dtype=np.float32):
        """
        Preprocess an image.
        :param image: input image (numpy array or file path)
        :param input_size: model input size (height, width)
        :param input_dtype: expected input dtype
        :return: preprocessed image tensor
        """
        if isinstance(image, str):
            if not os.path.exists(image):
                raise FileNotFoundError(f"Image file not found: {image}")
            img = cv2.imread(image)
            if img is None:
                raise ValueError(f"Could not read image: {image}")
        else:
            img = image

        # Resize and convert color space
        img = cv2.resize(img, (input_size[1], input_size[0]))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Normalize according to dtype
        if input_dtype == np.uint8:
            img = img.astype(np.uint8)              # quantized models take uint8
        else:
            img = img.astype(np.float32) / 255.0    # float models take float32

        # Add the batch dimension
        img = np.expand_dims(img, axis=0)
        return img

    def predict(self, image):
        """
        Run a prediction.
        :param image: input image (numpy array or file path)
        :return: (class name, confidence)
        """
        if self.interpreter is not None:
            return self.predict_tflite(image)  # TFLite inference
        return self.predict_keras(image)       # Keras inference

    def predict_keras(self, image):
        """Predict with the Keras model."""
        img = self.preprocess_image(image, self.input_shape, np.float32)
        predictions = self.model.predict(img, verbose=0)[0]
        class_idx = np.argmax(predictions)
        confidence = predictions[class_idx]
        return self.class_labels[class_idx], confidence

    def predict_tflite(self, image):
        """Predict with the TFLite model."""
        input_dtype = self.input_details[0]['dtype']
        img = self.preprocess_image(image, self.input_shape, input_dtype)
        self.interpreter.set_tensor(self.input_details[0]['index'], img)
        self.interpreter.invoke()
        output_data = self.interpreter.get_tensor(self.output_details[0]['index'])
        predictions = output_data[0]
        class_idx = np.argmax(predictions)
        confidence = predictions[class_idx]
        # Dequantize the output if the model is quantized
        if self.output_details[0]['dtype'] == np.uint8:
            scale, zero_point = self.output_details[0]['quantization']
            confidence = scale * (confidence - zero_point)
        return self.class_labels[class_idx], confidence

    def benchmark(self, image, runs=100):
        """
        Benchmark model performance.
        :param image: test image
        :param runs: number of runs
        :return: average inference time (ms), memory usage (MB)
        """
        self.predict(image)  # warm-up run
        start_time = tf.timestamp()
        for _ in range(runs):
            self.predict(image)
        end_time = tf.timestamp()
        avg_time_ms = (end_time - start_time).numpy() * 1000 / runs

        if self.interpreter:
            # Memory footprint of the input tensor
            input_size = self.input_details[0]['shape']
            dtype_size = np.dtype(self.input_details[0]['dtype']).itemsize
            mem_usage = np.prod(input_size) * dtype_size / (1024 * 1024)
        else:
            # Rough estimate for the Keras model, assuming float32 weights
            mem_usage = self.model.count_params() * 4 / (1024 * 1024)
        return avg_time_ms, mem_usage

    def create_metadata(self, output_path):
        """
        Create and save a model metadata file.
        :param output_path: output path for the metadata file
        """
        metadata = {
            "model_type": "tflite" if self.model_path.endswith('.tflite') else "keras",
            "class_labels": self.class_labels,
            "input_size": self.input_shape,
            "input_dtype": str(self.input_details[0]['dtype']) if self.interpreter else "float32",
            "quantization": None
        }
        if self.interpreter and self.input_details[0]['dtype'] == np.uint8:
            metadata["quantization"] = {
                "input_scale": float(self.input_details[0]['quantization'][0]),
                "input_zero_point": int(self.input_details[0]['quantization'][1]),
                "output_scale": float(self.output_details[0]['quantization'][0]),
                "output_zero_point": int(self.output_details[0]['quantization'][1])
            }
        with open(output_path, 'w') as f:
            json.dump(metadata, f, indent=4)
        return metadata

    def convert_to_tflite_with_metadata(self, output_path, quantize=False,
                                        representative_data_dir=None):
        """
        Convert the Keras model to TFLite and attach metadata.
        :param output_path: output TFLite file path
        :param quantize: whether to quantize
        :param representative_data_dir: directory of representative data
        """
        if not self.model_path.endswith(('.keras', '.h5')):
            raise ValueError("A Keras model is required for conversion")
        keras_model = tf.keras.models.load_model(self.model_path)
        converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
        if quantize:
            converter.optimizations = [tf.lite.Optimize.DEFAULT]
            # Representative dataset generator for calibration
            converter.representative_dataset = lambda: self.representative_dataset(
                representative_data_dir, input_size=self.input_shape
            )
            # Force integer input/output types
            converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
            converter.inference_input_type = tf.uint8
            converter.inference_output_type = tf.uint8
        tflite_model = converter.convert()
        with open(output_path, 'wb') as f:
            f.write(tflite_model)
        print(f"TFLite model saved to: {output_path}")
        self.add_tflite_metadata(output_path)
        return output_path

    def representative_dataset(self, data_dir=None, input_size=(224, 224), num_samples=100):
        """
        Generate a representative dataset for quantization calibration.
        :param data_dir: directory of real data
        :param input_size: input size (height, width)
        :param num_samples: number of samples
        """
        # Prefer real data when available
        if data_dir and os.path.exists(data_dir):
            image_files = [os.path.join(data_dir, f) for f in os.listdir(data_dir)
                           if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
            image_files = image_files[:min(len(image_files), num_samples)]
            print(f"Calibrating with {len(image_files)} real images")
            for img_path in tqdm(image_files, desc="Quantization calibration"):
                try:
                    img = cv2.imread(img_path)
                    if img is None:
                        continue
                    img = cv2.resize(img, (input_size[1], input_size[0]))
                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                    img = img.astype(np.float32) / 255.0  # float32, normalized
                    img = np.expand_dims(img, axis=0)
                    yield [img]
                except Exception as e:
                    print(f"Error processing image {img_path}: {str(e)}")
        else:
            # Fall back to random data
            print(f"Calibrating with {num_samples} random samples")
            for _ in range(num_samples):
                data = np.random.rand(1, input_size[0], input_size[1], 3).astype(np.float32)
                yield [data]

    def add_tflite_metadata(self, model_path):
        """Attach metadata files to a TFLite model."""
        # Labels file
        labels_path = os.path.join(os.path.dirname(model_path), "labels.txt")
        with open(labels_path, 'w') as f:
            for label in self.class_labels:
                f.write(f"{label}\n")
        # Metadata file
        metadata_path = os.path.join(os.path.dirname(model_path), "metadata.json")
        self.create_metadata(metadata_path)
        print(f"Metadata created: {metadata_path}")
        print(f"Labels file created: {labels_path}")


# Usage example
if __name__ == "__main__":
    CLASS_LABELS = ['book', 'cup', 'glasses', 'phone', 'shoe']
    deployer = ObjectRecognitionDeployer(
        model_path='optimized_model.keras',
        class_labels=CLASS_LABELS
    )
    # Convert to TFLite with metadata
    tflite_path = 'model_quantized.tflite'
    REPRESENTATIVE_DATA_DIR = 'path/to/representative_dataset'  # replace with a real path
    deployer.convert_to_tflite_with_metadata(
        tflite_path,
        quantize=True,
        representative_data_dir=REPRESENTATIVE_DATA_DIR
    )
    # Reload the converted model
    tflite_deployer = ObjectRecognitionDeployer(
        model_path=tflite_path,
        class_labels=CLASS_LABELS
    )
    # Test a prediction
    test_image = 'test_image.jpg'
    class_name, confidence = tflite_deployer.predict(test_image)
    print(f"Prediction: {class_name}, confidence: {confidence:.2f}")
    # Benchmark
    avg_time, mem_usage = tflite_deployer.benchmark(test_image)
    print(f"Average inference time: {avg_time:.2f} ms")
    print(f"Memory usage: {mem_usage:.2f} MB")
    # Export metadata
    metadata = deployer.create_metadata('model_metadata.json')
    print("Model metadata:", json.dumps(metadata, indent=4))
```
I have already run the code above successfully, and my import for the ObjectRecognitionDeployer class is `from 计算机视觉.test2 import ObjectRecognitionDeployer`.
# Complete PyQt5 Object Recognition System Integration
Building on your `ObjectRecognitionDeployer` class, here is a complete PyQt5 object recognition system with dynamic model loading, performance monitoring, and a user-friendly interface.
```python
import sys
import os
import cv2
import numpy as np
import time
import psutil
import json
from datetime import datetime
from PyQt5.QtWidgets import (
    QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout, QGridLayout,
    QLabel, QPushButton, QFileDialog, QFrame, QProgressBar, QMessageBox,
    QGroupBox, QComboBox, QTabWidget, QTableWidget, QTableWidgetItem,
    QHeaderView, QListWidget, QListWidgetItem, QSplitter, QStatusBar
)
from PyQt5.QtGui import QPixmap, QImage, QFont, QPalette, QColor, QIcon
from PyQt5.QtCore import Qt, QSize, QTimer, QThread, pyqtSignal
# 导入您的模型部署器
from 计算机视觉.test2 import ObjectRecognitionDeployer
class PerformanceMonitor(QThread):
    """后台线程监控系统性能"""
    update_signal = pyqtSignal(float, float, float)  # CPU, RAM, GPU
    def __init__(self, parent=None):
        super().__init__(parent)
        self._running = True  # cleared by stop() so run() can exit its loop
    def stop(self):
        self._running = False
    def run(self):
        """持续监控系统性能"""
        while self._running:
try:
# CPU使用率
cpu_percent = psutil.cpu_percent(interval=1)
# 内存使用率
mem = psutil.virtual_memory()
ram_percent = mem.percent
ram_used = mem.used / (1024 ** 3) # GB
# GPU监控 (需要额外库如pynvml)
gpu_percent = 0.0
try:
import pynvml
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
util = pynvml.nvmlDeviceGetUtilizationRates(handle)
gpu_percent = util.gpu
pynvml.nvmlShutdown()
except ImportError:
pass
except pynvml.NVMLError:
pass
# 发送信号
self.update_signal.emit(cpu_percent, ram_percent, gpu_percent)
# 睡眠以避免过高频率
time.sleep(1)
except Exception as e:
print(f"性能监控错误: {str(e)}")
time.sleep(5)
class ObjectRecognitionGUI(QMainWindow):
def __init__(self):
"""
初始化物体识别GUI应用
"""
super().__init__()
# 窗口设置
self.setWindowTitle("智能物体识别系统")
self.setGeometry(100, 100, 1400, 900)
self.setWindowIcon(QIcon("icon.png")) # 替换为实际图标路径
# 创建状态栏
self.statusBar().setFont(QFont("Arial", 9))
self.statusBar().showMessage("就绪")
# 配置模型路径和类别标签
self.configure_model_paths()
# 初始化UI
self.init_ui()
# 初始化模型
self.init_model()
# 当前选择的图像
self.current_image = None
self.image_path = None
# 性能统计
self.inference_times = []
self.performance_history = []
# 更新内存使用定时器
self.memory_timer = QTimer(self)
self.memory_timer.timeout.connect(self.update_memory_usage)
self.memory_timer.start(1000) # 每秒更新一次
# 启动性能监控线程
self.performance_monitor = PerformanceMonitor()
self.performance_monitor.update_signal.connect(self.update_performance_stats)
self.performance_monitor.start()
def configure_model_paths(self):
"""
配置模型路径和类别标签
使用当前文件夹作为基础路径
"""
# 获取当前脚本所在目录
base_dir = os.path.dirname(os.path.abspath(__file__))
# 模型路径 - 假设模型文件名为 "model_quantized.tflite"
self.model_path = os.path.join(base_dir, "model_quantized.tflite")
# 类别标签 - 从文件加载或使用默认值
self.class_labels = self.load_class_labels(base_dir)
# 检查模型文件是否存在
self.model_exists = os.path.exists(self.model_path)
if not self.model_exists:
# 尝试其他可能的模型文件名
possible_names = ["model.tflite", "model.h5", "model.pb", "model.onnx"]
for name in possible_names:
test_path = os.path.join(base_dir, name)
if os.path.exists(test_path):
self.model_path = test_path
self.model_exists = True
break
def load_class_labels(self, base_dir):
"""从文件加载类别标签"""
# 尝试从labels.txt加载
labels_path = os.path.join(base_dir, "labels.txt")
if os.path.exists(labels_path):
try:
with open(labels_path, 'r') as f:
return [line.strip() for line in f.readlines()]
except Exception as e:
print(f"加载标签文件失败: {str(e)}")
# 尝试从metadata.json加载
metadata_path = os.path.join(base_dir, "metadata.json")
if os.path.exists(metadata_path):
try:
with open(metadata_path, 'r') as f:
metadata = json.load(f)
return metadata.get("class_labels", [])
except Exception as e:
print(f"加载元数据失败: {str(e)}")
# 默认类别标签
return ['book', 'cup', 'glasses', 'phone', 'shoe']
def init_ui(self):
"""初始化用户界面"""
# 创建中央部件
central_widget = QWidget()
self.setCentralWidget(central_widget)
# 主布局
main_layout = QHBoxLayout(central_widget)
main_layout.setSpacing(15)
main_layout.setContentsMargins(15, 15, 15, 15)
# 使用分割器
splitter = QSplitter(Qt.Horizontal)
# 左侧面板 - 图像显示
left_frame = QFrame()
left_frame.setFrameShape(QFrame.StyledPanel)
left_layout = QVBoxLayout(left_frame)
# 图像标题
image_title = QLabel("图像预览")
image_title.setFont(QFont("Arial", 12, QFont.Bold))
left_layout.addWidget(image_title)
# 图像显示区域
self.image_label = QLabel()
self.image_label.setAlignment(Qt.AlignCenter)
self.image_label.setMinimumSize(600, 400)
self.image_label.setStyleSheet("""
background-color: #f0f0f0;
border: 1px solid #cccccc;
border-radius: 5px;
""")
left_layout.addWidget(self.image_label, 1) # 设置拉伸因子
# 图像路径显示
self.image_path_label = QLabel("未选择图像")
self.image_path_label.setStyleSheet("color: #666666; font-style: italic;")
left_layout.addWidget(self.image_path_label)
splitter.addWidget(left_frame)
# 右侧面板 - 控制和结果
right_frame = QFrame()
right_layout = QVBoxLayout(right_frame)
right_layout.setSpacing(15)
# 标签页控件
self.tab_widget = QTabWidget()
self.tab_widget.setFont(QFont("Arial", 10))
# 识别标签页
recognition_tab = QWidget()
self.init_recognition_tab(recognition_tab)
self.tab_widget.addTab(recognition_tab, "物体识别")
# 性能标签页
performance_tab = QWidget()
self.init_performance_tab(performance_tab)
self.tab_widget.addTab(performance_tab, "性能监控")
# 模型标签页
model_tab = QWidget()
self.init_model_tab(model_tab)
self.tab_widget.addTab(model_tab, "模型管理")
right_layout.addWidget(self.tab_widget)
splitter.addWidget(right_frame)
# 设置分割比例
splitter.setSizes([700, 500])
main_layout.addWidget(splitter)
# 添加状态栏信息
model_status = "已加载" if self.model_exists else "未找到"
self.statusBar().addPermanentWidget(QLabel(f"模型状态: {model_status}"))
self.statusBar().addPermanentWidget(QLabel(f"类别数: {len(self.class_labels)}"))
# 设置初始标签页
self.tab_widget.setCurrentIndex(0)
def init_recognition_tab(self, tab):
"""初始化识别标签页"""
layout = QVBoxLayout(tab)
layout.setSpacing(10)
# 控制面板
control_group = QGroupBox("控制面板")
control_layout = QVBoxLayout(control_group)
# 模型选择
model_layout = QHBoxLayout()
model_layout.addWidget(QLabel("当前模型:"))
self.model_label = QLabel(os.path.basename(self.model_path))
self.model_label.setStyleSheet("color: #3d85c6;")
model_layout.addWidget(self.model_label)
model_layout.addStretch()
# 重新加载模型按钮
self.btn_reload = QPushButton("重新加载模型")
self.btn_reload.setFont(QFont("Arial", 9))
self.btn_reload.setStyleSheet("background-color: #e0e0e0; padding: 5px;")
self.btn_reload.clicked.connect(self.reload_model)
model_layout.addWidget(self.btn_reload)
control_layout.addLayout(model_layout)
# 按钮行
button_layout = QHBoxLayout()
# 选择图像按钮
self.btn_select = QPushButton("选择图像")
self.btn_select.setFont(QFont("Arial", 10))
self.btn_select.setStyleSheet("""
QPushButton {
background-color: #4a86e8;
color: white;
border-radius: 5px;
padding: 8px;
}
QPushButton:hover {
background-color: #3a76d8;
}
""")
self.btn_select.clicked.connect(self.select_image)
button_layout.addWidget(self.btn_select)
# 预测按钮
self.btn_predict = QPushButton("运行预测")
self.btn_predict.setFont(QFont("Arial", 10))
self.btn_predict.setStyleSheet("""
QPushButton {
background-color: #6aa84f;
color: white;
border-radius: 5px;
padding: 8px;
}
QPushButton:hover {
background-color: #5a983f;
}
""")
self.btn_predict.clicked.connect(self.run_prediction)
button_layout.addWidget(self.btn_predict)
# 性能测试按钮
self.btn_benchmark = QPushButton("性能测试")
self.btn_benchmark.setFont(QFont("Arial", 10))
self.btn_benchmark.setStyleSheet("""
QPushButton {
background-color: #e69138;
color: white;
border-radius: 5px;
padding: 8px;
}
QPushButton:hover {
background-color: #d68128;
}
""")
self.btn_benchmark.clicked.connect(self.run_benchmark)
button_layout.addWidget(self.btn_benchmark)
control_layout.addLayout(button_layout)
layout.addWidget(control_group)
# 结果面板
result_group = QGroupBox("预测结果")
result_layout = QVBoxLayout(result_group)
# 类别标签
class_layout = QHBoxLayout()
class_layout.addWidget(QLabel("识别类别:"))
self.class_label = QLabel("")
self.class_label.setFont(QFont("Arial", 12, QFont.Bold))
self.class_label.setStyleSheet("color: #3d85c6;")
class_layout.addWidget(self.class_label)
class_layout.addStretch()
result_layout.addLayout(class_layout)
# 置信度
conf_layout = QHBoxLayout()
conf_layout.addWidget(QLabel("置信度:"))
self.confidence_label = QLabel("")
self.confidence_label.setStyleSheet("color: #6aa84f;")
conf_layout.addWidget(self.confidence_label)
conf_layout.addStretch()
result_layout.addLayout(conf_layout)
# 类别概率分布
prob_group = QGroupBox("类别概率分布")
prob_layout = QVBoxLayout(prob_group)
# 进度条容器
self.progress_bars = {}
for label in self.class_labels:
label_layout = QHBoxLayout()
# 标签
lbl_widget = QLabel(label)
lbl_widget.setFixedWidth(100)
label_layout.addWidget(lbl_widget)
# 进度条
pb = QProgressBar()
pb.setRange(0, 100)
pb.setValue(0)
pb.setFormat("%p%")
pb.setStyleSheet(self.get_progressbar_style(0))
pb.setFixedHeight(25)
label_layout.addWidget(pb, 1) # 设置拉伸因子为1
# 百分比标签
percent_label = QLabel("0%")
percent_label.setFixedWidth(50)
percent_label.setAlignment(Qt.AlignRight | Qt.AlignVCenter)
label_layout.addWidget(percent_label)
# 存储引用
self.progress_bars[label] = {
'progress': pb,
'percent': percent_label
}
prob_layout.addLayout(label_layout)
result_layout.addWidget(prob_group)
layout.addWidget(result_group, 1) # 设置拉伸因子
def init_performance_tab(self, tab):
"""初始化性能监控标签页"""
layout = QVBoxLayout(tab)
# 实时性能指标
perf_group = QGroupBox("实时性能")
perf_layout = QGridLayout(perf_group)
# CPU使用率
cpu_layout = QVBoxLayout()
cpu_layout.addWidget(QLabel("CPU使用率"))
self.cpu_progress = QProgressBar()
self.cpu_progress.setRange(0, 100)
self.cpu_progress.setValue(0)
self.cpu_progress.setFormat("%p%")
self.cpu_progress.setStyleSheet("""
QProgressBar::chunk {
background-color: #4caf50;
}
""")
cpu_layout.addWidget(self.cpu_progress)
perf_layout.addLayout(cpu_layout, 0, 0)
# 内存使用率
ram_layout = QVBoxLayout()
ram_layout.addWidget(QLabel("内存使用率"))
self.ram_progress = QProgressBar()
self.ram_progress.setRange(0, 100)
self.ram_progress.setValue(0)
self.ram_progress.setFormat("%p%")
self.ram_progress.setStyleSheet("""
QProgressBar::chunk {
background-color: #2196f3;
}
""")
ram_layout.addWidget(self.ram_progress)
perf_layout.addLayout(ram_layout, 0, 1)
# GPU使用率
gpu_layout = QVBoxLayout()
gpu_layout.addWidget(QLabel("GPU使用率"))
self.gpu_progress = QProgressBar()
self.gpu_progress.setRange(0, 100)
self.gpu_progress.setValue(0)
self.gpu_progress.setFormat("%p%")
self.gpu_progress.setStyleSheet("""
QProgressBar::chunk {
background-color: #ff9800;
}
""")
gpu_layout.addWidget(self.gpu_progress)
perf_layout.addLayout(gpu_layout, 0, 2)
# 模型性能指标
model_perf_group = QGroupBox("模型性能指标")
model_perf_layout = QVBoxLayout(model_perf_group)
# 创建表格
self.performance_table = QTableWidget()
self.performance_table.setColumnCount(4)
self.performance_table.setHorizontalHeaderLabels(["时间", "推理时间(ms)", "内存占用(MB)", "置信度"])
self.performance_table.horizontalHeader().setSectionResizeMode(QHeaderView.Stretch)
self.performance_table.setEditTriggers(QTableWidget.NoEditTriggers)
model_perf_layout.addWidget(self.performance_table)
# 添加布局
layout.addWidget(perf_group)
layout.addWidget(model_perf_group, 1) # 设置拉伸因子
def init_model_tab(self, tab):
"""初始化模型管理标签页"""
layout = QVBoxLayout(tab)
# 当前模型信息
model_info_group = QGroupBox("当前模型信息")
model_info_layout = QVBoxLayout(model_info_group)
# 模型路径
path_layout = QHBoxLayout()
path_layout.addWidget(QLabel("模型路径:"))
self.model_path_label = QLabel(self.model_path)
self.model_path_label.setStyleSheet("color: #666666;")
path_layout.addWidget(self.model_path_label)
model_info_layout.addLayout(path_layout)
# 模型类型
type_layout = QHBoxLayout()
type_layout.addWidget(QLabel("模型类型:"))
self.model_type_label = QLabel("TFLite" if self.model_path.endswith('.tflite') else "Keras")
type_layout.addWidget(self.model_type_label)
model_info_layout.addLayout(type_layout)
# 输入尺寸
input_layout = QHBoxLayout()
input_layout.addWidget(QLabel("输入尺寸:"))
self.input_size_label = QLabel("N/A")
input_layout.addWidget(self.input_size_label)
model_info_layout.addLayout(input_layout)
# 类别标签
label_layout = QHBoxLayout()
label_layout.addWidget(QLabel("类别标签:"))
self.label_list = QListWidget()
self.label_list.addItems(self.class_labels)
label_layout.addWidget(self.label_list)
model_info_layout.addLayout(label_layout)
layout.addWidget(model_info_group)
# 模型操作
model_ops_group = QGroupBox("模型操作")
model_ops_layout = QVBoxLayout(model_ops_group)
# 加载新模型按钮
self.btn_load_new = QPushButton("加载新模型")
self.btn_load_new.setFont(QFont("Arial", 10))
self.btn_load_new.setStyleSheet("background-color: #4a86e8; color: white; padding: 8px;")
self.btn_load_new.clicked.connect(self.load_new_model)
model_ops_layout.addWidget(self.btn_load_new)
# 转换模型按钮
self.btn_convert = QPushButton("转换为TFLite格式")
self.btn_convert.setFont(QFont("Arial", 10))
self.btn_convert.setStyleSheet("background-color: #6aa84f; color: white; padding: 8px;")
self.btn_convert.clicked.connect(self.convert_model)
model_ops_layout.addWidget(self.btn_convert)
# 导出元数据按钮
self.btn_export_meta = QPushButton("导出元数据")
self.btn_export_meta.setFont(QFont("Arial", 10))
self.btn_export_meta.setStyleSheet("background-color: #e69138; color: white; padding: 8px;")
self.btn_export_meta.clicked.connect(self.export_metadata)
model_ops_layout.addWidget(self.btn_export_meta)
layout.addWidget(model_ops_group)
def init_model(self):
"""初始化模型"""
try:
# 检查模型文件是否存在
if not self.model_exists:
raise FileNotFoundError(f"模型文件不存在: {self.model_path}")
# 加载模型
self.deployer = ObjectRecognitionDeployer(self.model_path, self.class_labels)
# 更新UI显示
self.input_size_label.setText(f"{self.deployer.input_shape[0]}×{self.deployer.input_shape[1]}")
# 显示状态信息
self.statusBar().showMessage(f"模型加载成功: {os.path.basename(self.model_path)}", 5000)
except Exception as e:
# 提供更详细的错误信息
error_msg = (
f"无法加载模型: {str(e)}\n\n"
f"模型路径: {self.model_path}\n"
f"当前目录: {os.path.dirname(os.path.abspath(__file__))}\n\n"
"请确保:\n"
"1. 模型文件存在于当前目录\n"
"2. 模型文件名正确\n"
"3. 模型格式兼容"
)
QMessageBox.critical(self, "模型加载错误", error_msg)
# 创建空部署器以避免崩溃
self.deployer = None
def reload_model(self):
"""重新加载模型"""
# 重新配置路径
self.configure_model_paths()
# 重新初始化模型
self.init_model()
# 更新UI
self.model_label.setText(os.path.basename(self.model_path))
self.model_path_label.setText(self.model_path)
self.model_type_label.setText("TFLite" if self.model_path.endswith('.tflite') else "Keras")
# 更新标签列表
self.label_list.clear()
self.label_list.addItems(self.class_labels)
# 更新进度条
self.progress_bars = {}
# 这里需要重新初始化进度条UI,为简洁起见省略具体实现
# 以下方法保持不变(select_image, display_image, run_prediction, update_probability_bars,
# get_progressbar_style, run_benchmark, update_memory_usage)...
def update_performance_stats(self, cpu_percent, ram_percent, gpu_percent):
"""更新性能统计"""
self.cpu_progress.setValue(int(cpu_percent))
self.ram_progress.setValue(int(ram_percent))
self.gpu_progress.setValue(int(gpu_percent))
# 记录性能历史
self.performance_history.append({
"time": datetime.now().strftime("%H:%M:%S"),
"cpu": cpu_percent,
"ram": ram_percent,
"gpu": gpu_percent
})
# 保留最近100条记录
if len(self.performance_history) > 100:
self.performance_history.pop(0)
def add_performance_record(self, inference_time, mem_usage, confidence):
"""添加性能记录到表格"""
row_position = self.performance_table.rowCount()
self.performance_table.insertRow(row_position)
# 添加时间
time_item = QTableWidgetItem(datetime.now().strftime("%H:%M:%S"))
self.performance_table.setItem(row_position, 0, time_item)
# 添加推理时间
time_item = QTableWidgetItem(f"{inference_time:.2f}")
time_item.setTextAlignment(Qt.AlignCenter)
self.performance_table.setItem(row_position, 1, time_item)
# 添加内存占用
mem_item = QTableWidgetItem(f"{mem_usage:.2f}")
mem_item.setTextAlignment(Qt.AlignCenter)
self.performance_table.setItem(row_position, 2, mem_item)
# 添加置信度
conf_item = QTableWidgetItem(f"{confidence:.2%}")
conf_item.setTextAlignment(Qt.AlignCenter)
self.performance_table.setItem(row_position, 3, conf_item)
# 滚动到最后一行
self.performance_table.scrollToBottom()
def load_new_model(self):
"""加载新模型"""
file_path, _ = QFileDialog.getOpenFileName(
self,
"选择模型文件",
"",
"模型文件 (*.tflite *.h5 *.keras *.pb *.onnx)"
)
if file_path:
# 更新模型路径
self.model_path = file_path
# 重新加载模型
self.reload_model()
def convert_model(self):
"""将模型转换为TFLite格式"""
if not self.deployer:
QMessageBox.warning(self, "警告", "请先加载有效的模型")
return
if not hasattr(self.deployer, 'convert_to_tflite_with_metadata'):
QMessageBox.warning(self, "警告", "当前模型不支持转换")
return
# 选择保存路径
save_path, _ = QFileDialog.getSaveFileName(
self,
"保存TFLite模型",
"",
"TFLite模型 (*.tflite)"
)
if not save_path:
return
# 选择代表性数据集目录
data_dir = QFileDialog.getExistingDirectory(
self,
"选择代表性数据集目录",
""
)
# 显示进度对话框
progress_dialog = QMessageBox(self)
progress_dialog.setWindowTitle("模型转换中")
progress_dialog.setText("正在转换模型,请稍候...")
progress_dialog.setStandardButtons(QMessageBox.NoButton)
progress_dialog.show()
QApplication.processEvents() # 确保UI更新
try:
# 执行转换
self.deployer.convert_to_tflite_with_metadata(
save_path,
quantize=True,
representative_data_dir=data_dir if data_dir else None
)
# 更新模型路径
self.model_path = save_path
self.reload_model()
QMessageBox.information(self, "转换成功", f"模型已成功转换为TFLite格式并保存到:\n{save_path}")
except Exception as e:
QMessageBox.critical(self, "转换失败", f"模型转换失败: {str(e)}")
finally:
progress_dialog.done(0)
def export_metadata(self):
"""导出模型元数据"""
if not self.deployer:
QMessageBox.warning(self, "警告", "请先加载有效的模型")
return
# 选择保存路径
save_path, _ = QFileDialog.getSaveFileName(
self,
"保存元数据文件",
"",
"JSON文件 (*.json)"
)
if not save_path:
return
try:
# 创建元数据
metadata = self.deployer.create_metadata(save_path)
QMessageBox.information(self, "导出成功", f"模型元数据已成功导出到:\n{save_path}")
except Exception as e:
QMessageBox.critical(self, "导出失败", f"元数据导出失败: {str(e)}")
def closeEvent(self, event):
"""窗口关闭事件处理"""
# 停止定时器和线程
self.memory_timer.stop()
        self.performance_monitor.stop()  # let run() leave its loop; quit() alone cannot stop it
self.performance_monitor.wait()
event.accept()
# 主程序入口
if __name__ == "__main__":
# 创建应用
app = QApplication(sys.argv)
# 设置应用样式
app.setStyle("Fusion")
# 创建主窗口
window = ObjectRecognitionGUI()
window.show()
# 运行主循环
sys.exit(app.exec_())
```
## System Features in Detail
### 1. Multi-Tab Interface Design
- **Object recognition tab**: the core workflow, covering image selection, model prediction, and result display
- **Performance monitoring tab**: live system resource usage and model performance metrics
- **Model management tab**: model loading, conversion, and metadata export
### 2. Dynamic Model Management
- **Automatic model detection**: finds model files in the current directory automatically
- **Model hot-reload**: reload a model without restarting the application
- **Model conversion**: convert Keras models to TFLite format
- **Metadata export**: export the model's metadata
### 3. Real-Time Performance Monitoring
- **System resource monitoring**: live CPU, memory, and GPU utilization
- **Model performance log**: records inference time, memory usage, and confidence for every run
- **Background monitoring thread**: a dedicated thread samples system performance
### 4. User-Friendly Interface
- **Status bar**: shows model status and class count
- **Color-coded progress bars**: bar color changes with the probability value (see the sketch after this list)
- **Detailed error handling**: clear error messages with suggested fixes
- **Responsive layout**: a splitter keeps the layout flexible
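The listing above omits the helper methods it references (`select_image`, `run_prediction`, `get_progressbar_style`, `update_probability_bars`, and so on). A hypothetical sketch of the two progress-bar helpers, assuming probabilities in [0, 1] and the `self.progress_bars` structure built in `init_recognition_tab`:
```python
def get_progressbar_style(self, percent):
    """Return a chunk color keyed to the probability percentage (hypothetical helper)."""
    if percent >= 70:
        color = "#6aa84f"   # green: high confidence
    elif percent >= 40:
        color = "#e69138"   # orange: medium confidence
    else:
        color = "#cccccc"   # gray: low confidence
    return f"QProgressBar::chunk {{ background-color: {color}; }}"

def update_probability_bars(self, probabilities):
    """Update each class's bar from a {label: probability} dict (hypothetical helper)."""
    for label, prob in probabilities.items():
        widgets = self.progress_bars.get(label)
        if not widgets:
            continue
        percent = int(prob * 100)
        widgets['progress'].setValue(percent)
        widgets['progress'].setStyleSheet(self.get_progressbar_style(percent))
        widgets['percent'].setText(f"{percent}%")
```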
## Usage
### 1. Dependencies
```bash
pip install pyqt5 opencv-python numpy pillow psutil tensorflow
# Optional: GPU monitoring support
pip install pynvml
```
### 2. File Layout
```
物体识别系统/
├── main.py                   # main program
├── model_quantized.tflite    # model file
├── labels.txt                # class label file
├── metadata.json             # model metadata
├── icon.png                  # application icon
└── 计算机视觉/                # model deployer package
    └── test2.py              # ObjectRecognitionDeployer implementation
```
### 3. Launching the App
```bash
python main.py
```
### 4. Main Workflow
1. **Select an image**: click "选择图像" to load a test image
2. **Run prediction**: click "运行预测" to recognize the object
3. **Benchmark**: click "性能测试" to evaluate model performance
4. **Model management** (on the "模型管理" tab):
   - load a new model
   - convert the model to TFLite format
   - export the model metadata
## Advanced Features
### 1. Background Performance Monitoring
Performance monitoring runs in a `QThread` so the main UI is never blocked:
```python
class PerformanceMonitor(QThread):
    update_signal = pyqtSignal(float, float, float)  # CPU, RAM, GPU
    def run(self):
        while self._running:
            # Collect system performance data
            cpu_percent = psutil.cpu_percent()
            # ...other metrics...
            self.update_signal.emit(cpu_percent, ram_percent, gpu_percent)
            time.sleep(1)
```
### 2. Model Hot-Reload
Load a new model without restarting the application:
```python
def reload_model(self):
"""重新加载模型"""
self.configure_model_paths() # 重新配置路径
self.init_model() # 重新初始化模型
# 更新UI...
```
### 3. Model Conversion
Hooks into `ObjectRecognitionDeployer`'s conversion method:
```python
def convert_model(self):
"""将模型转换为TFLite格式"""
# 获取保存路径和数据目录
save_path = ...
data_dir = ...
# 执行转换
self.deployer.convert_to_tflite_with_metadata(
save_path,
quantize=True,
representative_data_dir=data_dir
)
```
## Performance Optimization Tips
1. **Asynchronous processing**: run slow model inference in a `QThread` so the UI never freezes (see the sketch after this list)
2. **Image caching**: cache preprocessed images to avoid recomputation
3. **Model warm-up**: run one inference at startup so the model is initialized ahead of time
4. **Resource limits**: for large models, cap the number of concurrent inferences
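For item 1, a minimal sketch of such an inference worker (a hypothetical class, assuming the `deployer.predict` API from above):
```python
from PyQt5.QtCore import QThread, pyqtSignal

class InferenceWorker(QThread):
    """Run deployer.predict off the UI thread and report back via signals."""
    finished = pyqtSignal(str, float)  # class name, confidence
    failed = pyqtSignal(str)

    def __init__(self, deployer, image_path, parent=None):
        super().__init__(parent)
        self.deployer = deployer
        self.image_path = image_path

    def run(self):
        try:
            class_name, confidence = self.deployer.predict(self.image_path)
            self.finished.emit(class_name, float(confidence))
        except Exception as e:
            self.failed.emit(str(e))

# Usage: keep a reference on self so the thread is not garbage-collected
# self.worker = InferenceWorker(self.deployer, self.image_path)
# self.worker.finished.connect(self.show_prediction)  # hypothetical slot
# self.worker.start()
```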
TypeError: in user code: TypeError: outer_factory.<locals>.inner_factory.<locals>.tf__combine_features() takes 1 positional argument but 2 were given. Below is the source code that produces this error; please help me fix it:
```python
import os
import re
import glob
import tensorflow as tf
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib as mpl
from sklearn.model_selection import train_test_split
import imageio
import sys
from skimage.transform import resize
import traceback
from tensorflow.keras import layers, models, Input, Model
from tensorflow.keras.optimizers import Adam
from pathlib import Path
from tensorflow.keras.losses import MeanSquaredError
from tensorflow.keras.metrics import MeanAbsoluteError
from skimage import measure, morphology, filters

# =============== Configuration ==============================
BASE_DIR = "F:/2025.7.2wavelengthtiff"  # root directory
START_WAVELENGTH = 788.55500            # start wavelength
END_WAVELENGTH = 788.55600              # end wavelength
STEP = 0.00005                          # wavelength step
BATCH_SIZE = 8                          # batch size
IMAGE_SIZE = (256, 256)                 # image size
TARGET_CHANNELS = 1                     # target channels - grayscale
TEST_SIZE = 0.2                         # test split ratio
RANDOM_SEED = 42                        # random seed
MODEL_SAVE_PATH = Path.home() / "Documents" / "wavelength_model.keras"  # .keras format
# ============================================================

# Chinese font support for matplotlib
try:
    mpl.rcParams['font.sans-serif'] = ['SimHei']
    mpl.rcParams['axes.unicode_minus'] = False
    print("Chinese font support enabled")
except Exception:
    print("Warning: could not set a Chinese font; chart labels may not render")


def generate_folder_names(start, end, step):
    """Generate the list of wavelength folder names."""
    num_folders = int((end - start) / step) + 1
    folder_names = []
    for i in range(num_folders):
        wavelength = start + i * step
        folder_names.append(f"{wavelength:.5f}")
    return folder_names


def find_best_match_file(folder_path, target_wavelength):
    """Find the TIFF file in a folder whose wavelength is closest to the target."""
    tiff_files = (glob.glob(os.path.join(folder_path, "*.tiff")) +
                  glob.glob(os.path.join(folder_path, "*.tif")))
    if not tiff_files:
        return None
    best_match = None
    min_diff = float('inf')
    for file_path in tiff_files:
        filename = os.path.basename(file_path)
        match = re.search(r'\s*([\d.]+)_', filename)
        if not match:
            continue
        try:
            file_wavelength = float(match.group(1))
            diff = abs(file_wavelength - target_wavelength)
            if diff < min_diff:
                min_diff = diff
                best_match = file_path
        except ValueError:
            continue
    return best_match


def extract_shape_features(binary_image):
    """Extract shape and boundary features from a binarized image."""
    features = np.zeros(6, dtype=np.float32)  # initialize feature vector
    try:
        contours = measure.find_contours(binary_image, 0.5)
        if not contours:
            return features
        main_contour = max(contours, key=len)
        contour_length = len(main_contour)
        label_image = morphology.label(binary_image)
        region = measure.regionprops(label_image)[0]
        contour_area = region.area
        hull = morphology.convex_hull_image(label_image)
        hull_area = np.sum(hull)
        solidity = region.solidity
        eccentricity = region.eccentricity
        orientation = region.orientation
        features[0] = contour_length / 1000  # normalized
        features[1] = contour_area / 1000    # normalized
        features[2] = solidity
        features[3] = eccentricity
        features[4] = orientation
        features[5] = hull_area / 1000       # convex hull area
    except Exception as e:
        print(f"Shape feature extraction error: {e}")
        traceback.print_exc()
    return features


def load_and_preprocess_image(file_path):
    """Load and preprocess a TIFF image."""
    try:
        image = imageio.imread(file_path)
        if image.dtype == np.uint16:
            image = image.astype(np.float32) / 65535.0
        elif image.dtype == np.uint8:
            image = image.astype(np.float32) / 255.0
        else:
            image = image.astype(np.float32)
        if np.max(image) > 1.0:
            image = image / np.max(image)
        if len(image.shape) == 3 and image.shape[2] > 1:
            image = (0.299 * image[:, :, 0] + 0.587 * image[:, :, 1] +
                     0.114 * image[:, :, 2])
        image = np.expand_dims(image, axis=-1)
        image = resize(image, (IMAGE_SIZE[0], IMAGE_SIZE[1]), anti_aliasing=True)
        blurred = filters.gaussian(image[..., 0], sigma=1.0)
        thresh = filters.threshold_otsu(blurred)
        binary = blurred > thresh * 0.8
        return image, binary
    except Exception as e:
        print(f"Image load failed: {e}, substituting a blank image")
        return (np.zeros((IMAGE_SIZE[0], IMAGE_SIZE[1], 1), dtype=np.float32),
                np.zeros((IMAGE_SIZE[0], IMAGE_SIZE[1]), dtype=bool))


def create_tiff_dataset(file_paths):
    """Create a TensorFlow dataset from a list of file paths."""
    dataset = tf.data.Dataset.from_tensor_slices(file_paths)

    def load_wrapper(file_path):
        file_path_str = file_path.numpy().decode('utf-8')
        image, binary = load_and_preprocess_image(file_path_str)
        features = extract_shape_features(binary)
        return image, features

    def tf_load_wrapper(file_path):
        result = tf.py_function(
            func=load_wrapper,
            inp=[file_path],
            Tout=(tf.float32, tf.float32)
        )
        image = result[0]
        features = result[1]
        image.set_shape((IMAGE_SIZE[0], IMAGE_SIZE[1], 1))  # single channel
        features.set_shape((6,))                            # 6 shape features
        return image, features

    dataset = dataset.map(tf_load_wrapper,
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.batch(BATCH_SIZE).prefetch(tf.data.experimental.AUTOTUNE)
    return dataset


def load_and_prepare_data():
    """Load all data and prepare the train/test sets."""
    folder_names = generate_folder_names(START_WAVELENGTH, END_WAVELENGTH, STEP)
    print(f"\nNumber of generated folders: {len(folder_names)}")
    print(f"First folder: {folder_names[0]}")
    print(f"Last folder: {folder_names[-1]}")
    valid_files = []
    wavelengths = []
    print("\nScanning folders and matching files...")
    for folder_name in tqdm(folder_names, desc="Processing folders"):
        folder_path = os.path.join(BASE_DIR, folder_name)
        if not os.path.isdir(folder_path):
            continue
        try:
            target_wavelength = float(folder_name)
            file_path = find_best_match_file(folder_path, target_wavelength)
            if file_path:
                valid_files.append(file_path)
                wavelengths.append(target_wavelength)
        except ValueError:
            continue
    print(f"\nValid files found: {len(valid_files)}/{len(folder_names)}")
    if not valid_files:
        raise ValueError("No valid files found; check the path and folder names")

    wavelengths = np.array(wavelengths)
    min_wavelength = np.min(wavelengths)
    max_wavelength = np.max(wavelengths)
    wavelength_range = max_wavelength - min_wavelength
    wavelengths_normalized = (wavelengths - min_wavelength) / wavelength_range
    print(f"Wavelength range: {min_wavelength:.6f} to {max_wavelength:.6f}, "
          f"span: {wavelength_range:.6f}")

    train_files, test_files, train_wavelengths, test_wavelengths = train_test_split(
        valid_files, wavelengths_normalized, test_size=TEST_SIZE,
        random_state=RANDOM_SEED
    )
    print(f"Training set size: {len(train_files)}")
    print(f"Test set size: {len(test_files)}")

    train_dataset = create_tiff_dataset(train_files)
    test_dataset = create_tiff_dataset(test_files)
    train_labels = tf.data.Dataset.from_tensor_slices(train_wavelengths)
    test_labels = tf.data.Dataset.from_tensor_slices(test_wavelengths)

    # The "fixed" combine_features function (still the version that triggers the error)
    def combine_features(data):
        """Combine image features and label into the format the model expects."""
        image_features, label = data
        image, shape_features = image_features
        return (image, shape_features), label

    train_dataset_unet = tf.data.Dataset.zip((train_dataset, train_labels)).map(
        combine_features, num_parallel_calls=tf.data.experimental.AUTOTUNE
    )
    test_dataset_unet = tf.data.Dataset.zip((test_dataset, test_labels)).map(
        combine_features, num_parallel_calls=tf.data.experimental.AUTOTUNE
    )
    train_dataset_cnn_dense = tf.data.Dataset.zip(
        (train_dataset.map(lambda x: x[0]), train_labels))
    test_dataset_cnn_dense = tf.data.Dataset.zip(
        (test_dataset.map(lambda x: x[0]), test_labels))
    return (train_dataset_unet, test_dataset_unet, train_dataset_cnn_dense,
            test_dataset_cnn_dense, valid_files, min_wavelength, wavelength_range)


def build_unet_model(input_shape, shape_feature_size):
    """Build a U-Net model that takes both an image input and shape features."""
    # Image input
    image_input = Input(shape=input_shape, name='image_input')
    # Encoder
    conv1 = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(image_input)
    conv1 = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(conv1)
    pool1 = layers.MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(pool1)
    conv2 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(conv2)
    pool2 = layers.MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(pool2)
    conv3 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(conv3)
    pool3 = layers.MaxPooling2D(pool_size=(2, 2))(conv3)
    # Bottleneck
    conv4 = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(pool3)
    conv4 = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(conv4)
    drop4 = layers.Dropout(0.5)(conv4)
    # Decoder
    up5 = layers.Conv2D(128, (2, 2), activation='relu', padding='same')(
        layers.UpSampling2D(size=(2, 2))(drop4))
    merge5 = layers.Concatenate()([conv3, up5])
    conv5 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(merge5)
    conv5 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(conv5)
    up6 = layers.Conv2D(64, (2, 2), activation='relu', padding='same')(
        layers.UpSampling2D(size=(2, 2))(conv5))
    merge6 = layers.Concatenate()([conv2, up6])
    conv6 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(merge6)
    conv6 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(conv6)
    up7 = layers.Conv2D(32, (2, 2), activation='relu', padding='same')(
        layers.UpSampling2D(size=(2, 2))(conv6))
    merge7 = layers.Concatenate()([conv1, up7])
    conv7 = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(merge7)
    conv7 = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(conv7)
    # Shape feature input
    shape_input = Input(shape=(shape_feature_size,), name='shape_input')
    shape_dense = layers.Dense(128, activation='relu')(shape_input)
    shape_dense = layers.Dense(64, activation='relu')(shape_dense)
    # Merge image features with shape features
    flat = layers.GlobalAveragePooling2D()(conv7)
    combined = layers.Concatenate()([flat, shape_dense])
    # Output layer
    outputs = layers.Dense(1, activation='linear')(combined)
    model = Model(inputs=[image_input, shape_input], outputs=outputs)
    model.compile(optimizer=Adam(learning_rate=1e-4),
                  loss='mean_squared_error', metrics=['mae'])
    return model


def build_dense_model(input_shape):
    """Build a simple fully connected network."""
    inputs = Input(shape=input_shape)
    x = layers.Flatten()(inputs)
    x = layers.Dense(128, activation='relu')(x)
    x = layers.Dense(64, activation='relu')(x)
    outputs = layers.Dense(1, activation='linear')(x)
    model = Model(inputs=[inputs], outputs=[outputs])
    model.compile(optimizer=Adam(learning_rate=1e-4),
                  loss='mean_squared_error', metrics=['mae'])
    return model


def build_cnn_model(input_shape):
    """Build a conventional CNN model."""
    inputs = Input(shape=input_shape)
    x = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)
    x = layers.MaxPooling2D((2, 2))(x)
    x = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(x)
    x = layers.MaxPooling2D((2, 2))(x)
    x = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x)
    x = layers.MaxPooling2D((2, 2))(x)
    x = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(x)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(256, activation='relu')(x)
    outputs = layers.Dense(1, activation='linear')(x)
    model = Model(inputs=[inputs], outputs=[outputs])
    model.compile(optimizer=Adam(learning_rate=1e-4),
                  loss='mean_squared_error', metrics=['mae'])
    return model


def train_models(train_dataset_unet, test_dataset_unet, train_dataset_cnn_dense,
                 test_dataset_cnn_dense, input_shape, shape_feature_size,
                 wavelength_range):
    """Train all three models."""
    unet_model = build_unet_model(input_shape, shape_feature_size)
    dense_model = build_dense_model(input_shape)
    cnn_model = build_cnn_model(input_shape)
    unet_model.summary()
    dense_model.summary()
    cnn_model.summary()
    callbacks = [
        tf.keras.callbacks.EarlyStopping(patience=30, restore_best_weights=True,
                                         monitor='val_loss', min_delta=1e-6),
        tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5,
                                             patience=10, min_lr=1e-7)
    ]
    unet_history = unet_model.fit(train_dataset_unet, epochs=300,
                                  validation_data=test_dataset_unet,
                                  callbacks=callbacks, verbose=2)
    dense_history = dense_model.fit(train_dataset_cnn_dense, epochs=300,
                                    validation_data=test_dataset_cnn_dense,
                                    callbacks=callbacks, verbose=2)
    cnn_history = cnn_model.fit(train_dataset_cnn_dense, epochs=300,
                                validation_data=test_dataset_cnn_dense,
                                callbacks=callbacks, verbose=2)
    return unet_model, dense_model, cnn_model


def predict_with_voting(models, test_image_path, min_wavelength, wavelength_range):
    """Predict with several models and combine the results by voting."""
    image, binary = load_and_preprocess_image(test_image_path)
    image = np.expand_dims(image, axis=0)
    shape_features = extract_shape_features(binary)
    shape_features = np.expand_dims(shape_features, axis=0)
    predictions = []
    for model in models:
        if model.name == 'unet_model':
            predicted_normalized = model.predict([image, shape_features],
                                                 verbose=0)[0][0]
        else:
            predicted_normalized = model.predict(image, verbose=0)[0][0]
        predicted_wavelength = predicted_normalized * wavelength_range + min_wavelength
        predictions.append(predicted_wavelength)
    # Voting mechanism (here: the mean)
    final_prediction = np.mean(predictions)
    print(f"Final predicted wavelength: {final_prediction:.8f} nm")
    return final_prediction


def main():
    """Main entry point."""
    print(f"TensorFlow version: {tf.__version__}")
    try:
        (train_dataset_unet, test_dataset_unet, train_dataset_cnn_dense,
         test_dataset_cnn_dense, all_files, min_wavelength,
         wavelength_range) = load_and_prepare_data()
        print(f"Min wavelength: {min_wavelength:.6f}, span: {wavelength_range:.6f}")
    except Exception as e:
        print(f"Data loading failed: {str(e)}")
        traceback.print_exc()
        return
    print("\nStarting model training...")
    try:
        unet_model, dense_model, cnn_model = train_models(
            train_dataset_unet, test_dataset_unet, train_dataset_cnn_dense,
            test_dataset_cnn_dense, (IMAGE_SIZE[0], IMAGE_SIZE[1], 1), 6,
            wavelength_range)
    except Exception as e:
        print(f"Model training failed: {str(e)}")
        traceback.print_exc()
        return
    print("\nPicking a random test image for prediction...")
    try:
        for (images, features), labels in test_dataset_unet.take(1):
            if images.shape[0] > 0:
                test_image = images[0].numpy()
                test_features = features[0].numpy()
                labels_np = labels.numpy()
                if labels_np.ndim == 0:
                    true_wavelength_normalized = labels_np.item()
                else:
                    true_wavelength_normalized = labels_np[0]
                true_wavelength = (true_wavelength_normalized * wavelength_range
                                   + min_wavelength)
                test_image_path = "f:/phD/代码/test_image.tiff"
                imageio.imwrite(test_image_path,
                                (test_image[..., 0] * 255).astype(np.uint8))
                predicted_wavelength = predict_with_voting(
                    [unet_model, dense_model, cnn_model], test_image_path,
                    min_wavelength, wavelength_range)
                print(f"True wavelength: {true_wavelength:.6f} nm")
                print(f"Predicted wavelength: {predicted_wavelength:.6f} nm")
                print(f"Absolute error: "
                      f"{abs(predicted_wavelength - true_wavelength):.8f} nm")
                print(f"Relative error: "
                      f"{abs(predicted_wavelength - true_wavelength) / wavelength_range * 100:.4f}%")
            else:
                print("Error: no samples in the test batch")
    except Exception as e:
        print(f"Testing failed: {str(e)}")
        traceback.print_exc()
    print("\nYou can also test with your own image:")
    model = tf.keras.models.load_model(MODEL_SAVE_PATH)
    image_path = input("Enter the path of the image to test (e.g. 'test_image.tiff'): ")
    predicted = predict_with_voting([unet_model, dense_model, cnn_model],
                                    image_path, min_wavelength, wavelength_range)
    print(f"Predicted wavelength: {predicted:.6f} nm")
    print("\nDone.")


if __name__ == "__main__":
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    try:
        from skimage import filters, measure, morphology
    except ImportError:
        print("Installing required libraries...")
        import subprocess
        subprocess.run([sys.executable, "-m", "pip", "install",
                        "scikit-image", "imageio"])
        from skimage import filters, measure, morphology
    main()
```
### Resolving the TensorFlow TypeError: Positional-Argument Count Mismatch
#### Error Analysis
The message `TypeError: tf__combine_features() takes 1 positional argument but 2 were given` means:
1. Your `combine_features()` function is defined to take **only 1 positional argument**
2. But the call site passed it **2 positional arguments**
3. This typically happens when:
   - the function's effective signature changes after `@tf.function` tracing
   - a class method does not account for the implicit `self` parameter
   - arguments are passed in a mismatched way (positional vs. keyword)
---
#### Solutions (Step by Step)
##### 1. Check the Function Definition Against the Call
```python
# Broken example: the definition accepts only 1 argument
@tf.function
def combine_features(inputs):  # a single parameter
    return tf.concat(inputs, axis=-1)

# Broken call: 2 positional arguments are passed
output = combine_features(feature1, feature2)  # raises the TypeError
```
**Fixes**:
```python
# Option A: accept a variable number of arguments
@tf.function
def combine_features(*inputs):  # variadic parameters
    return tf.concat(inputs, axis=-1)

# Option B: pack the arguments at the call site
output = combine_features([feature1, feature2])  # pass one list argument
```
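This generic fix maps directly onto the code posted above. `tf.data.Dataset.map` invokes its function with the components of a tuple element as *separate* positional arguments, and `Dataset.zip((train_dataset, train_labels))` produces two-component elements `((image, shape_features), label)`, so `combine_features(data)` is in fact called with two arguments. Declaring both parameters resolves the error:
```python
def combine_features(image_features, label):
    """Dataset.map passes the element's components as separate arguments."""
    image, shape_features = image_features
    return (image, shape_features), label

train_dataset_unet = tf.data.Dataset.zip((train_dataset, train_labels)).map(
    combine_features, num_parallel_calls=tf.data.experimental.AUTOTUNE
)
```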
##### 2. Handle `self` in Class Methods
```python
class FeatureModel(tf.keras.Model):
    @tf.function  # the decorator changes argument handling
    def combine_features(self, inputs):  # self is the implicit first argument
        ...

# The call effectively passes two arguments: (self, inputs)
model = FeatureModel()
model.combine_features([feat1, feat2])  # correct: one positional argument
```
##### 3. Define an Explicit Input Signature (Recommended)
```python
@tf.function(
    input_signature=[tf.TensorSpec(shape=(None, 128), dtype=tf.float32)] * 2
)  # explicitly accepts 2 tensors
def combine_features(feat1, feat2):  # correctly defined with two parameters
    return tf.concat([feat1, feat2], axis=-1)
```
##### 4. Check Decorator Stacking Order
```python
# Wrong: tf.function should be outermost
@some_decorator
@tf.function  # should sit at the top of the decorator stack
def combine_features(*args):
    ...

# Correct order
@tf.function  # outermost decorator
@some_decorator
def combine_features(*args):
    ...
```
---
#### Key Preventive Measures
1. **Use one consistent argument-passing style**:
   ```python
   # Prefer keyword-argument calls
   combine_features(inputs=[tensor1, tensor2])
   ```
2. **Use `tf.nest` for nested structures**:
   ```python
   @tf.function
   def combine_features(inputs):
       flat_inputs = tf.nest.flatten(inputs)  # unpack nested arguments
       return tf.concat(flat_inputs, axis=-1)
   ```
3. **Test the undecorated function first**:
   ```python
   # Temporarily drop @tf.function to check the plain-Python logic
   def combine_features(inputs): ...  # make sure the base logic is correct
   ```
> **Debugging tip**: call `.get_concrete_function(...).pretty_printed_signature()` on the `tf.function`-wrapped function to inspect the traced signature[^1].
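For example, assuming the two-parameter `combine_features` from step 3, the traced signature can be printed like this:
```python
# Obtain the concrete function for specific input specs and print
# the signature AutoGraph actually traced
concrete = combine_features.get_concrete_function(
    tf.TensorSpec([None, 128], tf.float32),
    tf.TensorSpec([None, 128], tf.float32),
)
print(concrete.pretty_printed_signature())
```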