import numpy as np
import tensorflow as tf
class MAML:
def __init__(self, d, c, nway, meta_lr=1e-3, train_lr=1e-2):
"""
定义了图片大小、通道、样本类way、学习率
:param d:图片大小
:param c:通道
:param nway:类数
:param meta_lr:
:param train_lr:
"""
self.d = d
self.c = c
self.nway = nway
self.meta_lr = meta_lr
self.train_lr = train_lr
print('img shape:', self.d, self.d, self.c, 'meta-lr:', meta_lr, 'train-lr:', train_lr)
def build(self, support_xb, support_yb, query_xb, query_yb, K, meta_batchsz, mode='train'):
    """Build the MAML meta-learning graph (TensorFlow 1.x graph mode).

    Fetches the conv weights via self.conv_weights(), then (inside the
    meta_task closure) runs the forward pass on the support set, computes
    the softmax cross-entropy loss and its gradients for the first
    inner-loop update, and evaluates on the query set. This is meant to be
    repeated for K inner steps and mapped over meta_batchsz tasks, with
    average loss/accuracy driving the outer (meta) update.

    :param support_xb: [b, setsz, 84*84*3] flattened support images
    :param support_yb: [b, setsz, n-way] one-hot support labels
    :param query_xb: [b, querysz, 84*84*3] flattened query images
    :param query_yb: [b, querysz, n-way] one-hot query labels
    :param K: number of inner-loop (train) update steps
    :param meta_batchsz: number of tasks per meta-batch
    :param mode: 'train'/'eval'/'test'; for training, the train & eval
        networks are built together
    :return: not visible in this chunk — the method body is truncated below
    """
    # Create or reuse the network variables. conv_weights() does NOT include
    # the batch-norm variables, so an extra reuse mechanism is needed for those.
    self.weights = self.conv_weights()  # dict: 5 conv kernels + 5 biases b
    print('self.weights:',self.weights.values())
    # TODO: meta-test is sort of test stage.
    # NOTE(review): `is` tests identity, not equality; this only works due to
    # CPython string interning — should be `mode == 'train'`.
    training = True if mode is 'train' else False  # True/False
    def meta_task(input):
        """Run the inner-loop adaptation for one task.

        tf.map_fn only supports a single parameter, so the per-task tensors
        are packed into (and unpacked from) one tuple.

        :param input: tuple of
            support_x: [setsz, 84*84*3]
            support_y: [setsz, n-way]
            query_x: [querysz, 84*84*3]
            query_y: [querysz, n-way]
        (`training` is captured from the enclosing scope, for batch_norm.)
        :return: not visible in this chunk (truncated)
        """
        support_x, support_y, query_x, query_y = input
        print('input:',input)
        # Accumulate the per-step query ops across the K update steps.
        query_preds, query_losses, query_accs = [], [], []
        # Forward pass on the support set with the shared (meta) weights.
        support_pred = self.forward(support_x, self.weights, training)
        support_loss = tf.nn.softmax_cross_entropy_with_logits(logits=support_pred, labels=support_y)
        # NOTE(review): `dim` is the deprecated TF1 alias of `axis` in
        # tf.nn.softmax, and tf.contrib only exists in TensorFlow 1.x.
        support_acc = tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(support_pred, dim=1), axis=1),tf.argmax(support_y, axis=1))
        # Gradients of the support loss w.r.t. every conv weight, for the
        # first inner-loop update.
        grads = tf.gradients(support_loss, list(self.weights.values()))
        print('梯度:',grads)
        # grad and variable dict
# --- NOTE(review): source truncated here. The remainder of meta_task (the
# first weight update, the K-step inner loop, query-set evaluation) and the
# rest of build() are missing from this chunk. The two lines previously here
# ("MAML核心代码" / article publication timestamp) were blog-page residue,
# not code.