Python Pandas dataframe.first_valid_index()

Pandas的first_valid_index()方法用于查找DataFrame或Series中第一个非NA/null的索引。在示例中,当数据帧或系列包含空值时,该函数返回第一个有效值的索引。如果所有值都是NA/null,则返回None。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Pandas dataframe.first_valid_index()函数返回数据框架中第一个非NA/null值的索引。在Pandas Series的情况下,会返回第一个非NA/null的索引。在pandas Dataframe的情况下,将返回具有单个非NA/null值的索引。

注意:如果所有元素都是非NA/null,返回None。对于空的DataFrame也返回None。

语法:

DataFrame.first_valid_index()

返回 : 标量 : 索引的类型

示例#1:使用first_valid_index()函数来查找数据帧中第一个非NA/null的索引。

# importing pandas as pd
import pandas as pd
  
# Creating the dataframe 
df = pd.DataFrame({"A":[None, None, 2, 4, 5], 
                   "B":[5, None, None, 44, 2],
                   "C":[None, None, None, 1, 5]})
  
# Print the dataframe
df

现在应用first_valid_index()函数。

# applying first_valid_index() function 
df.first_valid_index()

输出 :

注意,在第一行的第二列有非纳值。所以输出为0,表明第0个索引包含一个非纳值。

示例#2:使用first_valid_index()函数查找数据帧中第一个非NA/null索引。

# importing pandas as pd
import pandas as pd
  
# Creating the dataframe 
df = pd.DataFrame({"A":[None, None, 2, 4, 5],
                   "B":[None, None, None, 44, 2],
                   "C":[None, None, None, 1, 5]})
  
# applying first_valid_index() function 
df.first_valid_index()

输出 :

正如我们在数据框架中看到的,前两行只有NA值。因此,输出为2

例子#3:使用first_valid_index()函数来查找一个系列中的第一个非NA/null索引。

# importing pandas as pd
import pandas as pd
  
# Creating the series
ser = pd.Series([None, None, "sam", "alex", "sophia", None])
  
# Print the series
ser

现在应用first_valid_index()函数。

# applying first_valid_index() function 
ser.first_valid_index()

输出 :


输出是2,因为第0和第1个索引是空值。

 

import tkinter as tk from tkinter import ttk, filedialog, messagebox import pandas as pd import numpy as np import matplotlib as mpl import matplotlib.pyplot as plt from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg import tensorflow as tf from tensorflow.keras.models import Model from tensorflow.keras.layers import Input, Dense, Lambda from tensorflow.keras.optimizers import Adam from sklearn.preprocessing import MinMaxScaler import os import time import warnings import matplotlib.dates as mdates warnings.filterwarnings('ignore', category=UserWarning, module='tensorflow') mpl.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'Arial Unicode MS'] mpl.rcParams['axes.unicode_minus'] = False # 关键修复:使用 ASCII 减号 # 设置中文字体支持 plt.rcParams['font.sans-serif'] = ['SimHei'] plt.rcParams['axes.unicode_minus'] = False class PINNModel(tf.keras.Model): def __init__(self, num_layers=4, hidden_units=32, dropout_rate=0.1, **kwargs): super(PINNModel, self).__init__(**kwargs) self.dense_layers = [] self.dropout_layers = [] # 创建隐藏层和对应的Dropout层 for _ in range(num_layers): self.dense_layers.append(Dense(hidden_units, activation='tanh')) self.dropout_layers.append(tf.keras.layers.Dropout(dropout_rate)) self.final_layer = Dense(1, activation='linear') # 添加更多带约束的物理参数 # 基本衰减系数 self.k1_raw = tf.Variable(0.1, trainable=True, dtype=tf.float32, name='k1_raw') self.k1 = tf.math.sigmoid(self.k1_raw) * 0.5 # 约束在0-0.5之间 # 水位依赖的衰减系数 self.k2_raw = tf.Variable(0.01, trainable=True, dtype=tf.float32, name='k2_raw') self.k2 = tf.math.sigmoid(self.k2_raw) * 0.1 # 约束在0-0.1之间 # 非线性项系数 self.alpha_raw = tf.Variable(0.1, trainable=True, dtype=tf.float32, name='alpha_raw') self.alpha = tf.math.sigmoid(self.alpha_raw) * 1.0 # 约束在0-1.0之间 # 外部影响系数(如降雨、温度等) self.beta_raw = tf.Variable(0.05, trainable=True, dtype=tf.float32, name='beta_raw') self.beta = tf.math.sigmoid(self.beta_raw) * 0.2 # 约束在0-0.2之间 def call(self, inputs, training=False): t, h, dt = inputs # 添加更多特征交互项 interaction = tf.concat([ t * h, h * dt, t * dt, (t * h * dt) ], axis=1) # 将时间、水位和时间步长作为输入特征 x = tf.concat([t, h, dt, interaction], axis=1) # 依次通过所有隐藏层和Dropout层 for dense_layer, dropout_layer in zip(self.dense_layers, self.dropout_layers): x = dense_layer(x) x = dropout_layer(x, training=training) # 仅在训练时应用Dropout return self.final_layer(x) def physics_loss(self, t, h_current, dt, training=False): """计算物理损失(改进的离散渗流方程)""" # 预测下一时刻的水位 h_next_pred = self([t, h_current, dt], training=training) # 改进的物理方程:非线性衰减模型 + 外部影响项 # 添加数值保护 exponent = - (self.k1 + self.k2 * h_current) * dt exponent = tf.clip_by_value(exponent, -50.0, 50.0) # 防止指数爆炸 decay_term = h_current * tf.exp(exponent) # 同样保护第二个指数项 beta_exp = -self.beta * dt beta_exp = tf.clip_by_value(beta_exp, -50.0, 50.0) external_term = self.alpha * (1 - tf.exp(beta_exp)) residual = h_next_pred - (decay_term + external_term) return tf.reduce_mean(tf.square(residual)) class DamSeepageModel: def __init__(self, root): self.root = root self.root.title("大坝渗流预测模型(PINNs)") self.root.geometry("1200x800") # 初始化数据 self.train_df = None # 训练集 self.test_df = None # 测试集 self.model = None self.scaler_t = MinMaxScaler(feature_range=(0, 1)) self.scaler_h = MinMaxScaler(feature_range=(0, 1)) self.scaler_dt = MinMaxScaler(feature_range=(0, 1)) self.evaluation_metrics = {} # 创建主界面 self.create_widgets() def create_widgets(self): # 创建主框架 main_frame = ttk.Frame(self.root, padding=10) main_frame.pack(fill=tk.BOTH, expand=True) # 左侧控制面板 control_frame = ttk.LabelFrame(main_frame, text="模型控制", padding=10) control_frame.pack(side=tk.LEFT, fill=tk.Y, padx=5, pady=5) # 文件选择部分 file_frame = ttk.LabelFrame(control_frame, text="数据文件", padding=10) file_frame.pack(fill=tk.X, pady=5) # 训练集选择 ttk.Label(file_frame, text="训练集:").grid(row=0, column=0, sticky=tk.W, pady=5) self.train_file_var = tk.StringVar() ttk.Entry(file_frame, textvariable=self.train_file_var, width=30, state='readonly').grid( row=0, column=1, padx=5) ttk.Button(file_frame, text="选择文件", command=lambda: self.select_file("train")).grid(row=0, column=2) # 测试集选择 ttk.Label(file_frame, text="测试集:").grid(row=1, column=0, sticky=tk.W, pady=5) self.test_file_var = tk.StringVar() ttk.Entry(file_frame, textvariable=self.test_file_var, width=30, state='readonly').grid(row=1, column=1, padx=5) ttk.Button(file_frame, text="选择文件", command=lambda: self.select_file("test")).grid(row=1, column=2) # PINNs参数设置 param_frame = ttk.LabelFrame(control_frame, text="PINNs参数", padding=10) param_frame.pack(fill=tk.X, pady=10) # 验证集切分比例 ttk.Label(param_frame, text="验证集比例:").grid(row=0, column=0, sticky=tk.W, pady=5) self.split_ratio_var = tk.DoubleVar(value=0.2) ttk.Spinbox(param_frame, from_=0, to=1, increment=0.05, textvariable=self.split_ratio_var, width=10).grid(row=0, column=1, padx=5) # 隐藏层数量 ttk.Label(param_frame, text="网络层数:").grid(row=1, column=0, sticky=tk.W, pady=5) self.num_layers_var = tk.IntVar(value=4) ttk.Spinbox(param_frame, from_=2, to=8, increment=1, textvariable=self.num_layers_var, width=10).grid(row=1, column=1, padx=5) # 每层神经元数量 ttk.Label(param_frame, text="神经元数/层:").grid(row=2, column=0, sticky=tk.W, pady=5) self.hidden_units_var = tk.IntVar(value=32) ttk.Spinbox(param_frame, from_=16, to=128, increment=4, textvariable=self.hidden_units_var, width=10).grid(row=2, column=1, padx=5) # 训练轮次 ttk.Label(param_frame, text="训练轮次:").grid(row=3, column=0, sticky=tk.W, pady=5) self.epochs_var = tk.IntVar(value=500) ttk.Spinbox(param_frame, from_=100, to=2000, increment=100, textvariable=self.epochs_var, width=10).grid(row=3, column=1, padx=5) # 物理损失权重 ttk.Label(param_frame, text="物理损失权重:").grid(row=4, column=0, sticky=tk.W, pady=5) self.physics_weight_var = tk.DoubleVar(value=0.5) ttk.Spinbox(param_frame, from_=0.1, to=1.0, increment=0.1, textvariable=self.physics_weight_var, width=10).grid(row=4, column=1, padx=5) # 控制按钮 btn_frame = ttk.Frame(control_frame) btn_frame.pack(fill=tk.X, pady=10) ttk.Button(btn_frame, text="训练模型", command=self.train_model).pack(side=tk.LEFT, padx=5) ttk.Button(btn_frame, text="预测结果", command=self.predict).pack(side=tk.LEFT, padx=5) ttk.Button(btn_frame, text="保存结果", command=self.save_results).pack(side=tk.LEFT, padx=5) ttk.Button(btn_frame, text="重置", command=self.reset).pack(side=tk.RIGHT, padx=5) # 状态栏 self.status_var = tk.StringVar(value="就绪") status_bar = ttk.Label(control_frame, textvariable=self.status_var, relief=tk.SUNKEN, anchor=tk.W) status_bar.pack(fill=tk.X, side=tk.BOTTOM) # 右侧结果显示区域 result_frame = ttk.Frame(main_frame) result_frame.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True, padx=5, pady=5) # 创建标签页 self.notebook = ttk.Notebook(result_frame) self.notebook.pack(fill=tk.BOTH, expand=True) # 损失曲线标签页 self.loss_frame = ttk.Frame(self.notebook) self.notebook.add(self.loss_frame, text="训练损失") # 预测结果标签页 self.prediction_frame = ttk.Frame(self.notebook) self.notebook.add(self.prediction_frame, text="预测结果") # 指标显示 self.metrics_var = tk.StringVar() metrics_label = ttk.Label( self.prediction_frame, textvariable=self.metrics_var, font=('TkDefaultFont', 10, 'bold'), relief='ridge', padding=5 ) metrics_label.pack(fill=tk.X, padx=5, pady=5) # 初始化绘图区域 self.fig, self.ax = plt.subplots(figsize=(10, 6)) self.canvas = FigureCanvasTkAgg(self.fig, master=self.prediction_frame) self.canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True) # 损失曲线画布 self.loss_fig, self.loss_ax = plt.subplots(figsize=(10, 4)) self.loss_canvas = FigureCanvasTkAgg(self.loss_fig, master=self.loss_frame) self.loss_canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True) def select_file(self, file_type): """选择Excel文件并计算时间步长""" try: file_path = filedialog.askopenfilename( title=f"选择{file_type}集Excel文件", filetypes=[("Excel文件", "*.xlsx *.xls"), ("所有文件", "*.*")] ) if not file_path: return df = pd.read_excel(file_path) # 验证必需列是否存在 required_cols = ['year', 'month', 'day', '水位'] missing_cols = [col for col in required_cols if col not in df.columns] if missing_cols: messagebox.showerror("列名错误", f"缺少必需列: {', '.join(missing_cols)}") return # 时间特征处理 time_features = ['year', 'month', 'day'] missing_time_features = [feat for feat in time_features if feat not in df.columns] if missing_time_features: messagebox.showerror("列名错误", f"Excel文件缺少预处理后的时间特征列: {', '.join(missing_time_features)}") return # 创建时间戳列 (增强兼容性) time_cols = ['year', 'month', 'day'] if 'hour' in df.columns: time_cols.append('hour') if 'minute' in df.columns: time_cols.append('minute') if 'second' in df.columns: time_cols.append('second') # 填充缺失的时间单位 for col in ['hour', 'minute', 'second']: if col not in df.columns: df[col] = 0 df['datetime'] = pd.to_datetime(df[time_cols]) # 设置时间索引 df = df.set_index('datetime') # 计算相对时间(天) df['days'] = (df.index - df.index[0]).days # 新增:计算时间步长dt(单位:天) df['dt'] = df.index.to_series().diff().dt.total_seconds() / 86400 # 精确到秒级 # 处理时间步长异常值 if len(df) > 1: # 计算有效时间步长(排除<=0的值) valid_dt = df['dt'][df['dt'] > 0] if len(valid_dt) > 0: avg_dt = valid_dt.mean() else: avg_dt = 1.0 else: avg_dt = 1.0 # 替换非正值 df.loc[df['dt'] <= 0, 'dt'] = avg_dt # 填充缺失值 df['dt'] = df['dt'].fillna(avg_dt) # 保存数据 if file_type == "train": self.train_df = df self.train_file_var.set(os.path.basename(file_path)) self.status_var.set(f"已加载训练集: {len(self.train_df)}条数据") else: self.test_df = df self.test_file_var.set(os.path.basename(file_path)) self.status_var.set(f"已加载测试集: {len(self.test_df)}条数据") except Exception as e: error_msg = f"文件读取失败: {str(e)}\n\n请确保:\n1. 文件不是打开状态\n2. 文件格式正确\n3. 包含必需的时间和水位列" messagebox.showerror("文件错误", error_msg) def calculate_metrics(self, y_true, y_pred): """计算评估指标""" from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score mse = mean_squared_error(y_true, y_pred) rmse = np.sqrt(mse) mae = mean_absolute_error(y_true, y_pred) non_zero_idx = np.where(y_true != 0)[0] if len(non_zero_idx) > 0: mape = np.mean(np.abs((y_true[non_zero_idx] - y_pred[non_zero_idx]) / y_true[non_zero_idx])) * 100 else: mape = float('nan') r2 = r2_score(y_true, y_pred) return { 'MSE': mse, 'RMSE': rmse, 'MAE': mae, 'MAPE': mape, # 修正键名 'R2': r2 } def train_model(self): """训练PINNs模型(带早停机制+训练指标监控)""" if self.train_df is None: messagebox.showwarning("警告", "请先选择训练集文件") return try: self.status_var.set("正在预处理数据...") self.root.update() # 从训练集中切分训练子集和验证子集(时间顺序切分) split_ratio = 1 - self.split_ratio_var.get() split_idx = int(len(self.train_df) * split_ratio) train_subset = self.train_df.iloc[:split_idx] valid_subset = self.train_df.iloc[split_idx:] # 检查数据量是否足够 if len(train_subset) < 2 or len(valid_subset) < 2: messagebox.showerror("数据错误", "训练集数据量不足(至少需要2个时间步)") return # 数据预处理 - 分别归一化不同特征 # 归一化时间特征 t_train = train_subset['days'].values[1:].reshape(-1, 1) self.scaler_t.fit(t_train) t_train_scaled = self.scaler_t.transform(t_train).astype(np.float32) # 归一化水位特征 h_train = train_subset['水位'].values[:-1].reshape(-1, 1) self.scaler_h.fit(h_train) h_train_scaled = self.scaler_h.transform(h_train).astype(np.float32) # 归一化时间步长特征 dt_train = train_subset['dt'].values[1:].reshape(-1, 1) self.scaler_dt.fit(dt_train) dt_train_scaled = self.scaler_dt.transform(dt_train).astype(np.float32) # 归一化标签(下一时刻水位) h_next_train = train_subset['水位'].values[1:].reshape(-1, 1) h_next_train_scaled = self.scaler_h.transform(h_next_train).astype(np.float32) # 准备验证数据(同样进行归一化) t_valid = valid_subset['days'].values[1:].reshape(-1, 1) t_valid_scaled = self.scaler_t.transform(t_valid).astype(np.float32) h_valid = valid_subset['水位'].values[:-1].reshape(-1, 1) h_valid_scaled = self.scaler_h.transform(h_valid).astype(np.float32) dt_valid = valid_subset['dt'].values[1:].reshape(-1, 1) dt_valid_scaled = self.scaler_dt.transform(dt_valid).astype(np.float32) h_next_valid_scaled = self.scaler_h.transform( valid_subset['水位'].values[1:].reshape(-1, 1) ).astype(np.float32) # 原始值用于指标计算 h_next_train_true = h_next_train h_next_valid_true = valid_subset['水位'].values[1:].reshape(-1, 1) # 创建模型和优化器 self.model = PINNModel( num_layers=self.num_layers_var.get(), hidden_units=self.hidden_units_var.get() ) optimizer = Adam(learning_rate=0.001) # 在训练循环中,使用归一化后的数据 train_dataset = tf.data.Dataset.from_tensor_slices( ((t_train_scaled, h_train_scaled, dt_train_scaled), h_next_train_scaled) ) train_dataset = train_dataset.shuffle(buffer_size=1024).batch(32) valid_dataset = tf.data.Dataset.from_tensor_slices( ((t_valid_scaled, h_valid_scaled, dt_valid_scaled), h_next_valid_scaled) ) valid_dataset = valid_dataset.batch(32) # 初始化训练历史记录列表 train_data_loss_history = [] physics_loss_history = [] valid_data_loss_history = [] train_metrics_history = [] valid_metrics_history = [] # 早停机制参数 patience = int(self.epochs_var.get() / 3) min_delta = 1e-4 best_valid_loss = float('inf') wait = 0 best_epoch = 0 best_weights = None start_time = time.time() # 自定义训练循环 for epoch in range(self.epochs_var.get()): # 训练阶段 epoch_train_data_loss = [] epoch_physics_loss = [] # 收集训练预测值(归一化后) train_pred_scaled = [] for step, ((t_batch, h_batch, dt_batch), h_next_batch) in enumerate(train_dataset): with tf.GradientTape() as tape: # 预测下一时刻水位 h_pred = self.model([t_batch, h_batch, dt_batch], training=True) # 添加training=True data_loss = tf.reduce_mean(tf.square(h_next_batch - h_pred)) # 动态调整物理损失权重 current_physics_weight = tf.minimum( self.physics_weight_var.get() * (1.0 + epoch / self.epochs_var.get()), 0.8 ) # 计算物理损失(传入时间步长dt) physics_loss = self.model.physics_loss(t_batch, h_batch, dt_batch,training=True) # 添加training=True loss = data_loss + current_physics_weight * physics_loss grads = tape.gradient(loss, self.model.trainable_variables) optimizer.apply_gradients(zip(grads, self.model.trainable_variables)) epoch_train_data_loss.append(data_loss.numpy()) epoch_physics_loss.append(physics_loss.numpy()) train_pred_scaled.append(h_pred.numpy()) # 保存训练预测值(归一化) # 合并训练预测值(归一化后) train_pred_scaled = np.concatenate(train_pred_scaled, axis=0) # 反归一化得到原始预测值 train_pred_true = self.scaler_h.inverse_transform(train_pred_scaled) # 计算训练集指标(使用原始真实值和预测值) train_metrics = self.calculate_metrics( y_true=h_next_train_true.flatten(), y_pred=train_pred_true.flatten() ) train_metrics_history.append(train_metrics) # 验证阶段 epoch_valid_data_loss = [] valid_pred_scaled = [] for ((t_v_batch, h_v_batch, dt_v_batch), h_v_next_batch) in valid_dataset: h_v_pred = self.model([t_v_batch, h_v_batch, dt_v_batch], training=False) # 验证时不启用Dropout valid_data_loss = tf.reduce_mean(tf.square(h_v_next_batch - h_v_pred)) epoch_valid_data_loss.append(valid_data_loss.numpy()) valid_pred_scaled.append(h_v_pred.numpy()) # 保存验证预测值(归一化) # 合并验证预测值(归一化后) valid_pred_scaled = np.concatenate(valid_pred_scaled, axis=0) # 反归一化得到原始预测值 valid_pred_true = self.scaler_h.inverse_transform(valid_pred_scaled) # 计算验证集指标(使用原始真实值和预测值) valid_metrics = self.calculate_metrics( y_true=h_next_valid_true.flatten(), y_pred=valid_pred_true.flatten() ) valid_metrics_history.append(valid_metrics) # 计算平均损失 avg_train_data_loss = np.mean(epoch_train_data_loss) avg_physics_loss = np.mean(epoch_physics_loss) avg_valid_data_loss = np.mean(epoch_valid_data_loss) # 记录损失 train_data_loss_history.append(avg_train_data_loss) physics_loss_history.append(avg_physics_loss) valid_data_loss_history.append(avg_valid_data_loss) # 早停机制逻辑 current_valid_loss = avg_valid_data_loss if current_valid_loss < best_valid_loss - min_delta: best_valid_loss = current_valid_loss best_epoch = epoch + 1 wait = 0 best_weights = self.model.get_weights() else: wait += 1 if wait >= patience: self.status_var.set(f"触发早停!最佳轮次: {best_epoch},最佳验证损失: {best_valid_loss:.4f}") if best_weights is not None: self.model.set_weights(best_weights) break # 更新状态 if epoch % 1 == 0: # 提取当前训练/验证的关键指标 train_rmse = train_metrics['RMSE'] valid_rmse = valid_metrics['RMSE'] train_r2 = train_metrics['R2'] valid_r2 = valid_metrics['R2'] elapsed = time.time() - start_time self.status_var.set( f"训练中 | 轮次: {epoch + 1}/{self.epochs_var.get()} | " f"训练RMSE: {train_rmse:.4f} | 验证RMSE: {valid_rmse:.4f} | " f"训练R²: {train_r2:.4f} | 验证R²: {valid_r2:.4f} | " f"k1: {self.model.k1.numpy():.6f}, k2: {self.model.k2.numpy():.6f} | 时间: {elapsed:.1f}秒 | 早停等待: {wait}/{patience}" ) self.root.update() # 绘制损失曲线 self.loss_ax.clear() epochs_range = range(1, len(train_data_loss_history) + 1) self.loss_ax.plot(epochs_range, train_data_loss_history, 'b-', label='训练数据损失') self.loss_ax.plot(epochs_range, physics_loss_history, 'r--', label='物理损失') self.loss_ax.plot(epochs_range, valid_data_loss_history, 'g-.', label='验证数据损失') self.loss_ax.set_title('PINNs训练与验证损失') self.loss_ax.set_xlabel('轮次') self.loss_ax.set_ylabel('损失', rotation=0) self.loss_ax.legend() self.loss_ax.grid(True, alpha=0.3) self.loss_ax.set_yscale('log') self.loss_canvas.draw() # 训练完成提示 elapsed = time.time() - start_time if wait >= patience: completion_msg = ( f"早停触发 | 最佳轮次: {best_epoch} | 最佳验证损失: {best_valid_loss:.4f} | " f"最佳验证RMSE: {valid_metrics_history[best_epoch - 1]['RMSE']:.4f} | " f"总时间: {elapsed:.1f}秒" ) else: completion_msg = ( f"训练完成 | 总轮次: {self.epochs_var.get()} | " f"最终训练RMSE: {train_metrics_history[-1]['RMSE']:.4f} | " f"最终验证RMSE: {valid_metrics_history[-1]['RMSE']:.4f} | " f"最终训练R²: {train_metrics_history[-1]['R2']:.4f} | " f"最终验证R²: {valid_metrics_history[-1]['R2']:.4f} | " f"总时间: {elapsed:.1f}秒" ) # 保存训练历史 self.train_history = { 'train_data_loss': train_data_loss_history, 'physics_loss': physics_loss_history, 'valid_data_loss': valid_data_loss_history, 'train_metrics': train_metrics_history, 'valid_metrics': valid_metrics_history } # 保存学习到的物理参数 self.learned_params = { "k1": self.model.k1.numpy(), "k2": self.model.k2.numpy(), "alpha": self.model.alpha.numpy(), "beta": self.model.beta.numpy() } self.status_var.set(completion_msg) messagebox.showinfo("训练完成", f"PINNs模型训练成功完成!\n{completion_msg}") except Exception as e: messagebox.showerror("训练错误", f"模型训练失败:\n{str(e)}") self.status_var.set("训练失败") def predict(self): """使用PINNs模型进行递归预测(带Teacher Forcing和蒙特卡洛Dropout)""" if self.model is None: messagebox.showwarning("警告", "请先训练模型") return if self.test_df is None: messagebox.showwarning("警告", "请先选择测试集文件") return try: self.status_var.set("正在生成预测(使用Teacher Forcing和MC Dropout)...") self.root.update() # 预处理测试数据 - 归一化 t_test = self.test_df['days'].values.reshape(-1, 1) t_test_scaled = self.scaler_t.transform(t_test).astype(np.float32) dt_test = self.test_df['dt'].values.reshape(-1, 1) dt_test_scaled = self.scaler_dt.transform(dt_test).astype(np.float32) h_test = self.test_df['水位'].values.reshape(-1, 1) h_test_scaled = self.scaler_h.transform(h_test).astype(np.float32) # 递归预测参数 n = len(t_test) mc_iterations = 100 # 蒙特卡洛采样次数 teacher_forcing_prob = 0.7 # 使用真实值的概率 # 存储蒙特卡洛采样结果 mc_predictions_scaled = np.zeros((mc_iterations, n, 1), dtype=np.float32) # 进行多次蒙特卡洛采样 for mc_iter in range(mc_iterations): predicted_scaled = np.zeros((n, 1), dtype=np.float32) predicted_scaled[0] = h_test_scaled[0] # 第一个点使用真实值 # 递归预测(带Teacher Forcing) for i in range(1, n): # 决定使用真实值还是预测值(Teacher Forcing) use_actual = np.random.rand() < teacher_forcing_prob if use_actual and i < n - 1: # 不能使用未来值 h_prev = h_test_scaled[i - 1:i] else: h_prev = predicted_scaled[i - 1:i] t_prev = t_test_scaled[i - 1:i] dt_i = dt_test_scaled[i:i + 1] # 启用Dropout进行不确定性估计 h_pred = self.model([t_prev, h_prev, dt_i], training=True) # 启用Dropout进行不确定性估计 predicted_scaled[i] = h_pred.numpy()[0][0] mc_predictions_scaled[mc_iter] = predicted_scaled # 计算预测统计量 mean_pred_scaled = np.mean(mc_predictions_scaled, axis=0) std_pred_scaled = np.std(mc_predictions_scaled, axis=0) # 反归一化结果 predictions = self.scaler_h.inverse_transform(mean_pred_scaled) uncertainty = self.scaler_h.inverse_transform(std_pred_scaled) * 1.96 # 95%置信区间 actual_values = h_test test_time = self.test_df.index # 清除现有图表 self.ax.clear() # 计算合理的y轴范围 - 基于数据集中区域 # 获取实际值和预测值的中位数 median_val = np.median(actual_values) # 计算数据的波动范围(标准差) data_range = np.std(actual_values) * 4 # 4倍标准差覆盖大部分数据 # 设置y轴范围为中心值±数据波动范围 y_center = median_val y_half_range = max(data_range, 10) # 确保最小范围为20个单位 y_min_adjusted = y_center - y_half_range y_max_adjusted = y_center + y_half_range # 确保范围不为零 if y_max_adjusted - y_min_adjusted < 1: y_min_adjusted -= 5 y_max_adjusted += 5 # 绘制结果(带置信区间) self.ax.plot(test_time, actual_values, 'b-', label='真实值', linewidth=2) self.ax.plot(test_time, predictions, 'r--', label='预测均值', linewidth=2) self.ax.fill_between( test_time, (predictions - uncertainty).flatten(), (predictions + uncertainty).flatten(), color='orange', alpha=0.3, label='95%置信区间' ) # 设置自动调整的y轴范围 self.ax.set_ylim(y_min_adjusted, y_max_adjusted) self.ax.set_title('大坝渗流水位预测(PINNs with MC Dropout)') self.ax.set_xlabel('时间') self.ax.set_ylabel('测压管水位', rotation=0) self.ax.legend(loc='best') # 自动选择最佳位置 # 优化时间轴刻度 self.ax.xaxis.set_major_locator(mdates.YearLocator()) self.ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y')) self.ax.xaxis.set_minor_locator(mdates.MonthLocator(interval=2)) self.ax.grid(which='minor', axis='x', linestyle=':', color='gray', alpha=0.3) self.ax.grid(which='major', axis='y', linestyle='-', color='lightgray', alpha=0.5) self.ax.tick_params(axis='x', which='major', rotation=0, labelsize=9) self.ax.tick_params(axis='x', which='minor', length=2) # 计算评估指标(排除第一个点) eval_actual = actual_values[1:].flatten() eval_pred = predictions[1:].flatten() self.evaluation_metrics = self.calculate_metrics(eval_actual, eval_pred) # 添加不确定性指标 avg_uncertainty = np.mean(uncertainty) max_uncertainty = np.max(uncertainty) self.evaluation_metrics['Avg Uncertainty'] = avg_uncertainty self.evaluation_metrics['Max Uncertainty'] = max_uncertainty metrics_text = ( f"MSE: {self.evaluation_metrics['MSE']:.4f} | " f"RMSE: {self.evaluation_metrics['RMSE']:.4f} | " f"MAE: {self.evaluation_metrics['MAE']:.4f} | " f"MAPE: {self.evaluation_metrics['MAPE']:.2f}% | " f"R²: {self.evaluation_metrics['R2']:.4f}\n" f"平均不确定性: {avg_uncertainty:.4f} | 最大不确定性: {max_uncertainty:.4f}" ) self.metrics_var.set(metrics_text) # 在图表上添加指标 self.ax.text( 0.5, 1.05, metrics_text, transform=self.ax.transAxes, ha='center', fontsize=8, bbox=dict(facecolor='white', alpha=0.8) ) params_text = ( f"物理参数: k1={self.learned_params['k1']:.4f}, " f"k2={self.learned_params['k2']:.4f}, " f"alpha={self.learned_params['alpha']:.4f}, " f"beta={self.learned_params['beta']:.4f} | " f"Teacher Forcing概率: {teacher_forcing_prob}" ) self.ax.text( 0.5, 1.12, params_text, transform=self.ax.transAxes, ha='center', fontsize=8, bbox=dict(facecolor='white', alpha=0.8) ) # 调整布局 plt.tight_layout(pad=2.0) self.canvas.draw() # 保存预测结果 self.predictions = predictions self.uncertainty = uncertainty self.actual_values = actual_values self.test_time = test_time self.mc_predictions = mc_predictions_scaled self.status_var.set(f"预测完成(MC Dropout采样{mc_iterations}次)") except Exception as e: messagebox.showerror("预测错误", f"预测失败:\n{str(e)}") self.status_var.set("预测失败") import traceback traceback.print_exc() def save_results(self): """保存预测结果和训练历史数据""" if not hasattr(self, 'predictions') or not hasattr(self, 'train_history'): messagebox.showwarning("警告", "请先生成预测结果并完成训练") return # 选择保存路径 save_path = filedialog.asksaveasfilename( defaultextension=".xlsx", filetypes=[("Excel文件", "*.xlsx"), ("所有文件", "*.*")], title="保存结果" ) if not save_path: return try: # 1. 创建预测结果DataFrame result_df = pd.DataFrame({ '时间': self.test_time, '实际水位': self.actual_values.flatten(), '预测水位': self.predictions.flatten() }) # 2. 创建评估指标DataFrame metrics_df = pd.DataFrame([self.evaluation_metrics]) # 3. 创建训练历史DataFrame history_data = { '轮次': list(range(1, len(self.train_history['train_data_loss']) + 1)), '训练数据损失': self.train_history['train_data_loss'], '物理损失': self.train_history['physics_loss'], '验证数据损失': self.train_history['valid_data_loss'] } # 添加训练集指标 for metric in ['MSE', 'RMSE', 'MAE', 'MAPE', 'R2']: history_data[f'训练集_{metric}'] = [item[metric] for item in self.train_history['train_metrics']] # 添加验证集指标 for metric in ['MSE', 'RMSE', 'MAE', 'MAPE', 'R2']: history_data[f'验证集_{metric}'] = [item[metric] for item in self.train_history['valid_metrics']] history_df = pd.DataFrame(history_data) # 保存到Excel with pd.ExcelWriter(save_path) as writer: result_df.to_excel(writer, sheet_name='预测结果', index=False) metrics_df.to_excel(writer, sheet_name='评估指标', index=False) history_df.to_excel(writer, sheet_name='训练历史', index=False) # 保存图表 chart_path = os.path.splitext(save_path)[0] + "_chart.png" self.fig.savefig(chart_path, dpi=300) # 保存损失曲线图 loss_path = os.path.splitext(save_path)[0] + "_loss.png" self.loss_fig.savefig(loss_path, dpi=300) self.status_var.set(f"结果已保存至: {os.path.basename(save_path)}") messagebox.showinfo("保存成功", f"预测结果和图表已保存至:\n" f"主文件: {save_path}\n" f"预测图表: {chart_path}\n" f"损失曲线: {loss_path}") except Exception as e: messagebox.showerror("保存错误", f"保存结果失败:\n{str(e)}") def reset(self): # 重置归一化器 self.scaler_t = MinMaxScaler(feature_range=(0, 1)) self.scaler_h = MinMaxScaler(feature_range=(0, 1)) self.scaler_dt = MinMaxScaler(feature_range=(0, 1)) """重置程序状态""" self.train_df = None self.test_df = None self.model = None self.train_file_var.set("") self.test_file_var.set("") # 清除训练历史 if hasattr(self, 'train_history'): del self.train_history # 清除图表 if hasattr(self, 'ax'): self.ax.clear() if hasattr(self, 'loss_ax'): self.loss_ax.clear() # 重绘画布 if hasattr(self, 'canvas'): self.canvas.draw() if hasattr(self, 'loss_canvas'): self.loss_canvas.draw() # 清除状态 self.status_var.set("已重置,请选择新数据") # 清除预测结果 if hasattr(self, 'predictions'): del self.predictions # 清除指标文本 if hasattr(self, 'metrics_var'): self.metrics_var.set("") messagebox.showinfo("重置", "程序已重置,可以开始新的分析") if __name__ == "__main__": root = tk.Tk() app = DamSeepageModel(root) root.mainloop() 帮我改进一下模型训练时候的学习率使之可以动态变化
最新发布
07-27
TypeError Traceback (most recent call last) Cell In[135], line 6 3 adata.obs['CellLineage'] = adata.obs['CellLineage'].cat.add_categories('RUNX1+_Adipocytes') 4 adata.obs['CellLineage'] = adata.obs['CellLineage'].cat.add_categories('Beige_Adipocytes') ----> 6 adata.obs.loc[Adipocytes.obs_names,'CellLineage'] = Adipocytes.obs['CellLineage'] 8 adata.obs['CellLineage'] = adata.obs['CellLineage'].cat.remove_unused_categories() File /disk201/lihx/anaconda3/envs/lihxbase/lib/python3.13/site-packages/pandas/core/indexing.py:911, in _LocationIndexer.__setitem__(self, key, value) 908 self._has_valid_setitem_indexer(key) 910 iloc = self if self.name == "iloc" else self.obj.iloc --> 911 iloc._setitem_with_indexer(indexer, value, self.name) File /disk201/lihx/anaconda3/envs/lihxbase/lib/python3.13/site-packages/pandas/core/indexing.py:1942, in _iLocIndexer._setitem_with_indexer(self, indexer, value, name) 1939 # align and set the values 1940 if take_split_path: 1941 # We have to operate column-wise -> 1942 self._setitem_with_indexer_split_path(indexer, value, name) 1943 else: 1944 self._setitem_single_block(indexer, value, name) File /disk201/lihx/anaconda3/envs/lihxbase/lib/python3.13/site-packages/pandas/core/indexing.py:1986, in _iLocIndexer._setitem_with_indexer_split_path(self, indexer, value, name) 1982 self._setitem_with_indexer_2d_value(indexer, value) 1984 elif len(ilocs) == 1 and lplane_indexer == len(value) and not is_scalar(pi): 1985 # We are setting multiple rows in a single column. ... 2314 ) 2316 codes = self.categories.get_indexer(value) 2317 return codes.astype(self._ndarray.dtype, copy=False) TypeError: Cannot setitem on a Categorical with a new category, set the categories first
06-04
import os import pandas as pd import tkinter as tk from tkinter import ttk, filedialog, scrolledtext, messagebox from tkinter.colorchooser import askcolor from difflib import SequenceMatcher import re import openpyxl import threading import numpy as np from openpyxl.utils import get_column_letter import xlrd import gc import hashlib import json import tempfile from concurrent.futures import ThreadPoolExecutor, as_completed import unicodedata class EnhancedSignalComparator: def __init__(self, root): self.root = root self.root.title("增强版信号功能对比工具") self.root.geometry("1200x800") self.root.configure(bg="#f0f0f0") # 初始化变量 self.folder_path = tk.StringVar() self.search_text = tk.StringVar() self.files = [] self.results = {} # 存储信号对比结果 self.highlight_color = "#FFD700" # 默认高亮色 self.search_running = False self.stop_requested = False self.cache_dir = os.path.join(tempfile.gettempdir(), "excel_cache") self.file_cache = {} # 文件缓存 self.column_cache = {} # 列名缓存 self.max_workers = 4 # 最大并发线程数 # 创建缓存目录 os.makedirs(self.cache_dir, exist_ok=True) # 创建界面 self.create_widgets() def create_widgets(self): # 顶部控制面板 control_frame = ttk.Frame(self.root, padding=10) control_frame.pack(fill=tk.X) # 文件夹选择 ttk.Label(control_frame, text="选择文件夹:").grid(row=0, column=0, sticky=tk.W) folder_entry = ttk.Entry(control_frame, textvariable=self.folder_path, width=50) folder_entry.grid(row=0, column=1, padx=5, sticky=tk.EW) ttk.Button(control_frame, text="浏览...", command=self.browse_folder).grid(row=0, column=2) # 搜索输入 ttk.Label(control_frame, text="搜索信号:").grid(row=1, column=0, sticky=tk.W, pady=(10,0)) search_entry = ttk.Entry(control_frame, textvariable=self.search_text, width=50) search_entry.grid(row=1, column=1, padx=5, pady=(10,0), sticky=tk.EW) search_entry.bind("<Return>", lambda event: self.start_search_thread()) ttk.Button(control_frame, text="搜索", command=self.start_search_thread).grid(row=1, column=2, pady=(10,0)) ttk.Button(control_frame, text="停止", command=self.stop_search).grid(row=1, column=3, pady=(10,0), padx=5) # 高级选项 ttk.Label(control_frame, text="并发线程:").grid(row=2, column=0, sticky=tk.W, pady=(10,0)) self.thread_var = tk.StringVar(value="4") ttk.Combobox(control_frame, textvariable=self.thread_var, values=["1", "2", "4", "8"], width=5).grid(row=2, column=1, sticky=tk.W, padx=5, pady=(10,0)) # 文件过滤 ttk.Label(control_frame, text="文件过滤:").grid(row=2, column=2, sticky=tk.W, pady=(10,0)) self.filter_var = tk.StringVar(value="*.xlsx;*.xlsm;*.xls") ttk.Entry(control_frame, textvariable=self.filter_var, width=20).grid(row=2, column=3, sticky=tk.W, padx=5, pady=(10,0)) # 高亮颜色选择 ttk.Label(control_frame, text="高亮颜色:").grid(row=3, column=0, sticky=tk.W, pady=(10,0)) self.color_btn = tk.Button(control_frame, bg=self.highlight_color, width=3, command=self.choose_color) self.color_btn.grid(row=3, column=1, sticky=tk.W, padx=5, pady=(10,0)) # 进度条 self.progress = ttk.Progressbar(control_frame, orient="horizontal", length=200, mode="determinate") self.progress.grid(row=3, column=2, columnspan=2, sticky=tk.EW, padx=5, pady=(10,0)) # 结果标签 self.result_label = ttk.Label(control_frame, text="") self.result_label.grid(row=3, column=4, sticky=tk.W, padx=5, pady=(10,0)) # 对比面板 notebook = ttk.Notebook(self.root) notebook.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) # 表格视图 self.table_frame = ttk.Frame(notebook) notebook.add(self.table_frame, text="表格视图") # 文本对比视图 self.text_frame = ttk.Frame(notebook) notebook.add(self.text_frame, text="行内容对比") # 状态栏 self.status_var = tk.StringVar() status_bar = ttk.Label(self.root, textvariable=self.status_var, relief=tk.SUNKEN, anchor=tk.W) status_bar.pack(side=tk.BOTTOM, fill=tk.X) # 初始化表格和文本区域 self.init_table_view() self.init_text_view() def init_table_view(self): """初始化表格视图""" # 创建树状表格 columns = ("信号", "文件", "行内容摘要") self.tree = ttk.Treeview(self.table_frame, columns=columns, show="headings") # 设置列标题 for col in columns: self.tree.heading(col, text=col) self.tree.column(col, width=200, anchor=tk.W) # 添加滚动条 scrollbar = ttk.Scrollbar(self.table_frame, orient=tk.VERTICAL, command=self.tree.yview) self.tree.configure(yscrollcommand=scrollbar.set) self.tree.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) scrollbar.pack(side=tk.RIGHT, fill=tk.Y) # 绑定选择事件 self.tree.bind("<<TreeviewSelect>>", self.on_table_select) def init_text_view(self): """初始化文本对比视图""" self.text_panes = {} self.text_frame.columnconfigure(0, weight=1) self.text_frame.rowconfigure(0, weight=1) # 创建对比容器 self.compare_container = ttk.Frame(self.text_frame) self.compare_container.grid(row=0, column=0, sticky="nsew", padx=5, pady=5) # 添加差异高亮按钮 btn_frame = ttk.Frame(self.text_frame) btn_frame.grid(row=1, column=0, sticky="ew", padx=5, pady=5) ttk.Button(btn_frame, text="高亮显示差异", command=self.highlight_differences).pack(side=tk.LEFT) ttk.Button(btn_frame, text="导出差异报告", command=self.export_report).pack(side=tk.LEFT, padx=5) ttk.Button(btn_frame, text="清除缓存", command=self.clear_cache).pack(side=tk.LEFT, padx=5) ttk.Button(btn_frame, text="手动指定列名", command=self.manual_column_select).pack(side=tk.LEFT, padx=5) def browse_folder(self): """选择文件夹""" folder = filedialog.askdirectory(title="选择包含Excel文件的文件夹") if folder: self.folder_path.set(folder) self.load_files() def load_files(self): """加载文件夹中的Excel文件(优化特殊字符处理)""" folder = self.folder_path.get() if not folder or not os.path.isdir(folder): return # 获取文件过滤模式 filter_patterns = self.filter_var.get().split(';') self.files = [] for file in os.listdir(folder): file_path = os.path.join(folder, file) # 跳过临时文件 if file.startswith('~$'): continue # 检查文件扩展名 file_lower = file.lower() matched = False for pattern in filter_patterns: # 移除通配符并转换为小写 ext = pattern.replace('*', '').lower() if file_lower.endswith(ext): matched = True break if matched: # 规范化文件名处理特殊字符 normalized_path = self.normalize_file_path(file_path) if normalized_path and os.path.isfile(normalized_path): self.files.append(normalized_path) self.status_var.set(f"找到 {len(self.files)} 个Excel文件") def normalize_file_path(self, path): """规范化文件路径,处理特殊字符""" try: # 尝试直接访问文件 if os.path.exists(path): return path # 尝试Unicode规范化 normalized = unicodedata.normalize('NFC', path) if os.path.exists(normalized): return normalized # 尝试不同编码方案 encodings = ['utf-8', 'shift_jis', 'euc-jp', 'cp932'] for encoding in encodings: try: decoded = path.encode('latin1').decode(encoding) if os.path.exists(decoded): return decoded except: continue # 最终尝试原始路径 return path except Exception as e: self.status_var.set(f"文件路径处理错误: {str(e)}") return path def get_file_hash(self, file_path): """计算文件哈希值用于缓存""" try: hash_md5 = hashlib.md5() with open(file_path, "rb") as f: for chunk in iter(lambda: f.read(4096), b""): hash_md5.update(chunk) return hash_md5.hexdigest() except Exception as e: self.status_var.set(f"计算文件哈希失败: {str(e)}") return str(os.path.getmtime(file_path)) def get_cache_filename(self, file_path): """获取缓存文件名""" file_hash = self.get_file_hash(file_path) return os.path.join(self.cache_dir, f"{os.path.basename(file_path)}_{file_hash}.cache") def load_header_cache(self, file_path): """加载列名缓存""" cache_file = self.get_cache_filename(file_path) if os.path.exists(cache_file): try: with open(cache_file, "r", encoding='utf-8') as f: return json.load(f) except: return None return None def save_header_cache(self, file_path, header_info): """保存列名缓存""" cache_file = self.get_cache_filename(file_path) try: with open(cache_file, "w", encoding='utf-8') as f: json.dump(header_info, f) return True except: return False def find_header_row(self, file_path): """查找列名行(增强版)""" # 检查缓存 cache = self.load_header_cache(file_path) if cache: return cache.get("header_row"), cache.get("signal_col") # 没有缓存则重新查找 if file_path.lower().endswith((".xlsx", ".xlsm")): return self.find_header_row_openpyxl(file_path) elif file_path.lower().endswith(".xls"): return self.find_header_row_xlrd(file_path) return None, None def find_header_row_openpyxl(self, file_path): """使用openpyxl查找列名行(增强版)""" try: wb = openpyxl.load_workbook(file_path, read_only=True, data_only=True) ws = wb.active # 尝试多种列名匹配模式 patterns = [ r'データ名', # 半角片假名 r'データ名', # 全角片假名 r'信号名', # 中文 r'Signal Name', # 英文 r'Data Name', r'信号名称', r'データ名称' ] # 扩大搜索范围:前100行和前100列 for row_idx in range(1, 101): # 1-100行 # 扩大列搜索范围到100列 for col_idx in range(1, 101): # 1-100列 try: cell = ws.cell(row=row_idx, column=col_idx) cell_value = cell.value if not cell_value: continue # 尝试所有匹配模式 cell_str = str(cell_value) for pattern in patterns: if re.search(pattern, cell_str, re.IGNORECASE): # 找到列名行后,尝试确定信号列 signal_col = None # 在同行中查找信号列 for col_idx2 in range(1, 101): # 1-100列 try: cell2 = ws.cell(row=row_idx, column=col_idx2) cell2_value = cell2.value if not cell2_value: continue cell2_str = str(cell2_value) if re.search(pattern, cell2_str, re.IGNORECASE): signal_col = col_idx2 break except: continue # 保存缓存 if signal_col is not None: header_info = {"header_row": row_idx, "signal_col": signal_col} self.save_header_cache(file_path, header_info) wb.close() return row_idx, signal_col except: continue wb.close() except Exception as e: self.status_var.set(f"查找列名行出错: {str(e)}") return None, None def find_header_row_xlrd(self, file_path): """使用xlrd查找列名行(增强版)""" try: wb = xlrd.open_workbook(file_path) ws = wb.sheet_by_index(0) # 尝试多种列名匹配模式 patterns = [ r'データ名', # 半角片假名 r'データ名', # 全角片假名 r'信号名', # 中文 r'Signal Name', # 英文 r'Data Name', r'信号名称', r'データ名称' ] # 扩大搜索范围:前100行和前100列 for row_idx in range(0, 100): # 0-99行 # 扩大列搜索范围到100列 for col_idx in range(0, 100): # 0-99列 try: cell_value = ws.cell_value(row_idx, col_idx) if not cell_value: continue # 尝试所有匹配模式 cell_str = str(cell_value) for pattern in patterns: if re.search(pattern, cell_str, re.IGNORECASE): # 找到列名行后,尝试确定信号列 signal_col = None # 在同行中查找信号列 for col_idx2 in range(0, 100): # 0-99列 try: cell2_value = ws.cell_value(row_idx, col_idx2) if not cell2_value: continue cell2_str = str(cell2_value) if re.search(pattern, cell2_str, re.IGNORECASE): signal_col = col_idx2 break except: continue # 保存缓存 if signal_col is not None: header_info = {"header_row": row_idx, "signal_col": signal_col} self.save_header_cache(file_path, header_info) return row_idx, signal_col except: continue except Exception as e: self.status_var.set(f"查找列名行出错: {str(e)}") return None, None def extract_row_content(self, ws, row_idx, header_row, max_cols=100): """高效提取行内容(最多到100列)""" content = [] # 扩展到100列 for col_idx in range(1, max_cols + 1): try: cell = ws.cell(row=row_idx, column=col_idx) if cell.value is not None and str(cell.value).strip() != '': # 使用列名缓存 col_key = f"{header_row}-{col_idx}" if col_key in self.column_cache: col_name = self.column_cache[col_key] else: col_name_cell = ws.cell(row=header_row, column=col_idx) col_name = col_name_cell.value if col_name_cell.value else f"列{get_column_letter(col_idx)}" self.column_cache[col_key] = col_name content.append(f"{col_name}: {str(cell.value).strip()}") except: continue return "\n".join(content) def start_search_thread(self): """启动搜索线程""" if self.search_running: return self.search_running = True self.stop_requested = False self.max_workers = int(self.thread_var.get()) threading.Thread(target=self.search_files, daemon=True).start() def stop_search(self): """停止搜索""" self.stop_requested = True self.status_var.set("正在停止搜索...") def search_files(self): """在文件中搜索内容(优化特殊文件处理)""" search_term = self.search_text.get().strip() if not search_term: self.status_var.set("请输入搜索内容") self.search_running = False return if not self.files: self.status_var.set("请先选择文件夹") self.search_running = False return # 重置结果和UI self.results = {} for item in self.tree.get_children(): self.tree.delete(item) total_files = len(self.files) processed_files = 0 found_signals = 0 # 使用线程池处理文件 # 在search_files方法中添加详细进度 with ThreadPoolExecutor(max_workers=self.max_workers) as executor: futures = {} for i, file_path in enumerate(self.files): if self.stop_requested: break future = executor.submit(self.process_file, file_path, search_term) futures[future] = (file_path, i) # 保存文件索引 for future in as_completed(futures): if self.stop_requested: break file_path, idx = futures[future] try: found = future.result() found_signals += found processed_files += 1 # 更详细的进度反馈 progress = int(processed_files / total_files * 100) self.progress["value"] = progress self.status_var.set( f"已处理 {processed_files}/{total_files} 个文件 | " f"当前: {os.path.basename(file_path)} | " f"找到: {found_signals} 个匹配" ) self.root.update_idletasks() except Exception as e: self.status_var.set(f"处理文件 {os.path.basename(file_path)} 出错: {str(e)}") # 更新结果 if self.stop_requested: self.status_var.set(f"搜索已停止,已处理 {processed_files}/{total_files} 个文件") elif found_signals == 0: self.status_var.set(f"未找到包含 '{search_term}' 的信号") else: self.status_var.set(f"找到 {len(self.results)} 个匹配信号,共 {found_signals} 处匹配") self.update_text_view() self.progress["value"] = 0 self.search_running = False gc.collect() # 强制垃圾回收释放内存 def process_file(self, file_path, search_term): """处理单个文件(增强异常处理)""" found = 0 try: # 获取列名行和信号列 header_row, signal_col = self.find_header_row(file_path) # 如果自动查找失败,尝试手动模式 if header_row is None or signal_col is None: self.status_var.set(f"文件 {os.path.basename(file_path)} 未找到列名行,尝试手动查找...") header_row, signal_col = self.manual_find_header_row(file_path) if header_row is None or signal_col is None: self.status_var.set(f"文件 {os.path.basename(file_path)} 无法确定列名行,已跳过") return found # 使用pandas处理所有Excel文件类型 found = self.process_file_with_pandas(file_path, search_term, header_row, signal_col) except Exception as e: self.status_var.set(f"处理文件 {os.path.basename(file_path)} 出错: {str(e)}") return found def manual_find_header_row(self, file_path): """手动查找列名行(当自动查找失败时使用)""" try: # 尝试打开文件 if file_path.lower().endswith((".xlsx", ".xlsm")): wb = openpyxl.load_workbook(file_path, read_only=True, data_only=True) ws = wb.active # 扫描整个工作表(最多1000行) for row_idx in range(1, 1001): for col_idx in range(1, 101): try: cell = ws.cell(row=row_idx, column=col_idx) if cell.value and "データ" in str(cell.value): # 找到可能的列名行 return row_idx, col_idx except: continue wb.close() elif file_path.lower().endswith(".xls"): wb = xlrd.open_workbook(file_path) ws = wb.sheet_by_index(0) # 扫描整个工作表(最多1000行) for row_idx in range(0, 1000): for col_idx in range(0, 100): try: cell_value = ws.cell_value(row_idx, col_idx) if cell_value and "データ" in str(cell_value): # 找到可能的列名行 return row_idx, col_idx except: continue except: pass return None, None def get_file_cache_key(self, file_path, header_row, signal_col): """生成唯一的文件缓存键""" file_hash = self.get_file_hash(file_path) return f"{file_hash}_{header_row}_{signal_col}" def process_file_with_pandas(self, file_path, search_term, header_row, signal_col): """使用pandas高效处理Excel文件(优化版)""" found = 0 try: # 使用pandas读取Excel文件 file_ext = os.path.splitext(file_path)[1].lower() engine = 'openpyxl' if file_ext in ['.xlsx', '.xlsm'] else 'xlrd' # 动态确定要读取的列范围(前后10列) start_col = max(1, signal_col - 10) end_col = signal_col + 10 # 读取数据 df = pd.read_excel( file_path, engine=engine, header=header_row-1, usecols=range(start_col-1, end_col), dtype=str ) # 获取信号列名称 signal_col_idx = signal_col - start_col if signal_col_idx < len(df.columns): signal_col_name = df.columns[signal_col_idx] else: self.status_var.set(f"文件 {os.path.basename(file_path)}: 信号列超出范围") return 0 # 搜索匹配的信号 if signal_col_name in df.columns: matches = df[df[signal_col_name].str.contains(search_term, case=False, na=False)] else: self.status_var.set(f"文件 {os.path.basename(file_path)}: 未找到信号列 '{signal_col_name}'") return 0 # 处理匹配行 short_name = os.path.basename(file_path) for idx, row in matches.iterrows(): # 只显示有值的列 row_content = [] for col_name, value in row.items(): # 跳过空值 if pd.notna(value) and str(value).strip() != '': row_content.append(f"{col_name}: {str(value).strip()}") row_content = "\n".join(row_content) signal_value = row[signal_col_name] # 使用复合键确保唯一性 signal_key = f"{signal_value}||{short_name}" # 添加到结果集 self.results[signal_key] = { "signal": signal_value, "file": short_name, "content": row_content } # 添加到表格 summary = row_content[:50] + "..." if len(row_content) > 50 else row_content self.tree.insert("", tk.END, values=(signal_value, short_name, summary)) found += 1 # 每处理10行更新一次UI if found % 10 == 0: self.status_var.set(f"处理 {short_name}: 找到 {found} 个匹配") self.root.update_idletasks() except Exception as e: self.status_var.set(f"处理文件 {os.path.basename(file_path)} 出错: {str(e)}") finally: # 显式释放内存 if 'df' in locals(): del df if 'matches' in locals(): del matches gc.collect() return found def extract_xlrd_row_content(self, ws, row_idx, header_row): """为xls文件高效提取行内容""" content = [] try: row_values = ws.row_values(row_idx) except: return "" # 扩展到100列 for col_idx in range(min(len(row_values), 100)): try: cell_value = row_values[col_idx] if cell_value is not None and str(cell_value).strip() != '': # 使用列名缓存 col_key = f"{header_row}-{col_idx}" if col_key in self.column_cache: col_name = self.column_cache[col_key] else: try: col_name = ws.cell_value(header_row, col_idx) if not col_name: col_name = f"列{col_idx+1}" except: col_name = f"列{col_idx+1}" self.column_cache[col_key] = col_name content.append(f"{col_name}: {str(cell_value).strip()}") except: continue return "\n".join(content) def update_text_view(self): """更新文本对比视图""" # 清除现有文本区域 for widget in self.compare_container.winfo_children(): widget.destroy() if not self.results: return # 获取第一个信号作为默认显示 first_signal_key = next(iter(self.results.keys())) self.display_signal_comparison(first_signal_key) def on_table_select(self, event): """表格选择事件处理""" selected = self.tree.selection() if not selected: return item = self.tree.item(selected[0]) signal_value = item["values"][0] file_name = item["values"][1] # 使用复合键 signal_key = f"{signal_value}||{file_name}" self.display_signal_comparison(signal_key) def display_signal_comparison(self, signal_key): """显示指定信号的对比""" # 清除现有文本区域 for widget in self.compare_container.winfo_children(): widget.destroy() if signal_key not in self.results: return signal_data = self.results[signal_key] # 创建列框架 col_frame = ttk.Frame(self.compare_container) col_frame.grid(row=0, column=0, sticky="nsew", padx=5, pady=5) self.compare_container.columnconfigure(0, weight=1) # 文件名标签 file_label = ttk.Label(col_frame, text=signal_data["file"], font=("Arial", 10, "bold")) file_label.pack(fill=tk.X, pady=(0, 5)) # 信号名标签 signal_label = ttk.Label(col_frame, text=signal_data["signal"], font=("Arial", 9, "italic")) signal_label.pack(fill=tk.X, pady=(0, 5)) # 文本区域 text_area = scrolledtext.ScrolledText(col_frame, wrap=tk.WORD, width=60, height=20) text_area.insert(tk.INSERT, signal_data["content"]) text_area.configure(state="disabled") text_area.pack(fill=tk.BOTH, expand=True) # 保存引用 self.text_panes[signal_key] = text_area def highlight_differences(self): """高亮显示文本差异""" if not self.text_panes: return # 获取所有行内容 all_contents = [] for text_area in self.text_panes.values(): text_area.configure(state="normal") text = text_area.get("1.0", tk.END).strip() text_area.configure(state="disabled") all_contents.append(text) # 如果所有内容相同,则不需要高亮 if len(set(all_contents)) == 1: self.status_var.set("所有文件行内容完全一致") return # 使用第一个文件作为基准 base_text = all_contents[0] # 对比并高亮差异 for i, (file, text_area) in enumerate(self.text_panes.items()): if i == 0: # 基准文件不需要处理 continue text_area.configure(state="normal") text_area.tag_configure("diff", background=self.highlight_color) # 清除之前的高亮 text_area.tag_remove("diff", "1.0", tk.END) # 获取当前文本 compare_text = text_area.get("1.0", tk.END).strip() # 使用序列匹配器查找差异 s = SequenceMatcher(None, base_text, compare_text) # 高亮差异部分 for tag in s.get_opcodes(): opcode = tag[0] start = tag[3] end = tag[4] if opcode != "equal": # 添加高亮标签 text_area.tag_add("diff", f"1.0+{start}c", f"1.0+{end}c") text_area.configure(state="disabled") self.status_var.set("差异已高亮显示") def choose_color(self): """选择高亮颜色""" color = askcolor(title="选择高亮颜色", initialcolor=self.highlight_color) if color[1]: self.highlight_color = color[1] self.color_btn.configure(bg=self.highlight_color) def export_report(self): """导出差异报告""" if not self.results: messagebox.showwarning("警告", "没有可导出的结果") return try: # 创建报告数据结构 report_data = [] for signal, files_data in self.results.items(): for file, content in files_data.items(): report_data.append({ "信号": signal, "文件": file, "行内容": content }) # 转换为DataFrame df = pd.DataFrame(report_data) # 保存到Excel save_path = filedialog.asksaveasfilename( defaultextension=".xlsx", filetypes=[("Excel文件", "*.xlsx")], title="保存差异报告" ) if save_path: df.to_excel(save_path, index=False) self.status_var.set(f"报告已保存到: {save_path}") except Exception as e: messagebox.showerror("错误", f"导出报告失败: {str(e)}") def clear_cache(self): """清除缓存""" try: for file in os.listdir(self.cache_dir): if file.endswith(".cache"): os.remove(os.path.join(self.cache_dir, file)) self.file_cache = {} self.column_cache = {} self.status_var.set("缓存已清除") except Exception as e: self.status_var.set(f"清除缓存失败: {str(e)}") def manual_column_select(self): """手动指定列名位置""" if not self.files: messagebox.showinfo("提示", "请先选择文件夹") return # 创建手动选择窗口 manual_window = tk.Toplevel(self.root) manual_window.title("手动指定列名位置") manual_window.geometry("400x300") # 文件选择 ttk.Label(manual_window, text="选择文件:").pack(pady=(10, 5)) file_var = tk.StringVar() file_combo = ttk.Combobox(manual_window, textvariable=file_var, values=[os.path.basename(f) for f in self.files]) file_combo.pack(fill=tk.X, padx=20, pady=5) file_combo.current(0) # 行号输入 ttk.Label(manual_window, text="列名行号:").pack(pady=(10, 5)) row_var = tk.StringVar(value="1") row_entry = ttk.Entry(manual_window, textvariable=row_var) row_entry.pack(fill=tk.X, padx=20, pady=5) # 列号输入 ttk.Label(manual_window, text="信号列号:").pack(pady=(10, 5)) col_var = tk.StringVar(value="1") col_entry = ttk.Entry(manual_window, textvariable=col_var) col_entry.pack(fill=tk.X, padx=20, pady=5) # 确认按钮 def confirm_selection(): try: file_idx = file_combo.current() file_path = self.files[file_idx] header_row = int(row_var.get()) signal_col = int(col_var.get()) # 保存到缓存 header_info = {"header_row": header_row, "signal_col": signal_col} self.save_header_cache(file_path, header_info) messagebox.showinfo("成功", f"已为 {os.path.basename(file_path)} 设置列名位置:行{header_row} 列{signal_col}") manual_window.destroy() except Exception as e: messagebox.showerror("错误", f"无效输入: {str(e)}") ttk.Button(manual_window, text="确认", command=confirm_selection).pack(pady=20) if __name__ == "__main__": root = tk.Tk() app = EnhancedSignalComparator(root) root.mainloop() 1、扩大搜索范围之后,有一个文件中的信号值反而找不到了,但是信号值搜索范围在5时,能够找到,搜索范围在10时,找不到 2、行内容对比区域,现在只能显示一个文件的内容? 3、表格视图区域,除显示信号值之外,好像还搜索到了其他内容,行内容摘要显示有データ还有一行是DA: R? 4、目前来看显示的内容还是错的,信号值所在行,有内容的单元格对应的列名是错的 5、搜索到的两个文件【ドラフト版】D01D-00-02(HEV車).xlsm与【ドラフト版】D01D-00-03(コンベ車).xlsx显示的内容一致
07-24
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值