torch.index_select vs. torch.gather

This post covers two useful indexing operations in PyTorch: index_select and gather. index_select copies whole rows or columns (entire slices) along a chosen dimension, while gather picks out individual elements of a multi-dimensional tensor at specific index positions.


torch.index_select

index_select takes a 1-D tensor of indices and selects whole slices along one dimension of the input; for a 2-D matrix this means taking entire rows or entire columns.
import torch

# indices must be a 1-D tensor
x = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [2, 2, 3, 4]])
print(x)
indices = torch.tensor([0, 2])
# select rows (dim=0)
print(torch.index_select(x, 0, indices))
# select columns (dim=1)
print(torch.index_select(x, 1, indices))
tensor([[1, 2, 3, 4],
        [5, 6, 7, 8],
        [2, 2, 3, 4]])
# rows 0 and 2
tensor([[1, 2, 3, 4],
        [2, 2, 3, 4]])
# columns 0 and 2
tensor([[1, 3],
        [5, 7],
        [2, 3]])
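
index_select always copies whole slices along the chosen dimension, so the result keeps every other dimension intact. Below is a minimal sketch (the 3-D tensor y is made up for illustration, not part of the example above) plus the equivalent advanced-indexing forms for the 2-D case.

# Minimal sketch: index_select on a 3-D tensor (y is illustrative only)
y = torch.arange(24).reshape(2, 3, 4)            # shape (2, 3, 4)
indices = torch.tensor([0, 2])
print(torch.index_select(y, 1, indices).shape)   # torch.Size([2, 2, 4])

# For the 2-D example above, index_select matches plain advanced indexing:
print(torch.equal(torch.index_select(x, 0, indices), x[indices]))     # True
print(torch.equal(torch.index_select(x, 1, indices), x[:, indices]))  # True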

torch.gather

gather works on multi-dimensional tensors and picks out the value at each specified index position (not whole rows or columns). For a 2-D input, the rule is out[i][j] = x[index[i][j]][j] when dim=0, and out[i][j] = x[i][index[i][j]] when dim=1; index must have the same number of dimensions as x.

x = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [2, 2, 3, 4]])
print(x)
# shape 3x4

indices = torch.tensor([[1], [2], [3]])
# shape 3x1
# picks the elements at positions [0,1], [1,2], [2,3]
# one element per row; indices gives the column index
print(torch.gather(x, 1, indices))
# gather along dim 1; in dim 0, indices here has the same size as x

indices = torch.tensor([[1, 2, 1, 0]])
# shape 1x4
# picks the elements at positions [1,0], [2,1], [1,2], [0,3]
# one element per column; indices gives the row index
print(torch.gather(x, 0, indices))
# gather along dim 0; in dim 1, indices here has the same size as x

tensor([[1, 2, 3, 4],
        [5, 6, 7, 8],
        [2, 2, 3, 4]])
# one element per row, column index specified
tensor([[2],
        [7],
        [4]])
# one element per column, row index specified
tensor([[5, 2, 7, 4]])
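
A common use of this per-element rule is pulling out one value per row, for example the score of the chosen class for each sample. The sketch below is illustrative only (the names scores and chosen are made up, not from the example above) and verifies the dim=1 rule with an explicit loop.

# Illustrative sketch: gather one value per row, then check
# out[i][0] == scores[i][chosen[i][0]] with a plain loop.
scores = torch.tensor([[0.1, 0.7, 0.2],
                       [0.5, 0.3, 0.2]])
chosen = torch.tensor([[1], [0]])                # one column index per row, shape 2x1

picked = torch.gather(scores, 1, chosen)         # tensor([[0.7000], [0.5000]])

manual = torch.stack([scores[i, chosen[i, 0]] for i in range(scores.size(0))]).unsqueeze(1)
print(torch.equal(picked, manual))               # True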
