使用pytorch进行语义分割模型训练

这篇文章我主要介绍一下我搭建的语义分割任务框架,这个框架可以训练很多语义分割模型。

我主要是在PASCAL VOC上训练了FCN网络,希望对大家能有所帮助。

项目架构

上图就是项目架构了,我介绍几个主要的东西

checkpoint:用来存放中间的结果文件

dataset:用来存放加载数据集的文件

model:用来存放网络模型

pic:存放混淆矩阵可视化图片

util:用来保存工具脚本

eval.py:计算测试集性能指标的代码

train.py:训练代码

下面上代码

dataset/pascal_data.py

import torch
import torchvision.transforms as tfs
import os
import scipy.io as scio
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import random


# PASCAL VOC语义分割增强数据集
prefix = "G:/data/VOCBSD/"

# 超参数,设置裁剪的尺寸
CROP = 256

class PASCAL_BSD(object):
    def __init__(self, mode="train", change=False):
        super(PASCAL_BSD, self).__init__()
        # 读取数据的模式
        self.mode = mode
        # 类别标签,一共有20+1个类
        self.classes = ['background','aeroplane','bicycle','bird','boat',
           'bottle','bus','car','cat','chair','cow','diningtable',
           'dog','horse','motorbike','person','potted plant',
           'sheep','sofa','train','tv/monitor']
        # 颜色标签,分别对应21个类别
        self.colormap = [[0,0,0],[128,0,0],[0,128,0], [128,128,0], [0,0,128],
            [128,0,128],[0,128,128],[128,128,128],[64,0,0],[192,0,0],
            [64,128,0],[192,128,0],[64,0,128],[192,0,128],
            [64,128,128],[192,128,128],[0,64,0],[128,64,0],
            [0,192,0],[128,192,0],[0,64,128]]

        self.im_tfs = tfs.Compose([
            tfs.ToTensor(),
            tfs.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        # 将mat格式的数据转换成png格式
        if (change == True):
            self.mat2png()

        self.image_name = []
        self.label_name = []
        self.readImage()
        print("%s->成功加载%d张图片"%(self.mode, len(self.image_name)))

    # 读取图像和标签信息
    def readImage(self):
        img_root = prefix + "JPEGImage/"
        label_root = prefix + "SegmentationClass/"
        if(self.mode == "train"):
            with open(prefix+"train.txt", "r") as f:
                list_dir = f.readlines()
        elif(self.mode == "val"):
            with open(prefix + "val.txt", "r") as f:
                list_dir = f.readlines()
        for item in list_dir:
            self.image_name.append(img_root + item.split("\n")[0] + ".jpg")
            self.label_name.append(label_root + item.split("\n")[0] + ".png")

    # 数据处理,输入Image对象,返回tensor对象
    def data_process(self, img, img_gt):
        if(self.mode == "train"):
            # 以50%的概率左右翻转
            a = random.random()
            if(a > 0.5):
                img = img.transpose(Image.FLIP_LEFT_RIGHT)
                img_gt = img_gt.transpose(Image.FLIP_LEFT_RIGHT)
            # 以50%的概率上下翻转
            a = random.random()
            if(a > 0.5):
                img = img.transpose(Image.FLIP_TOP_BOTTOM)
                img_gt = img_gt.transpose(Image.FLIP_TOP_BOTTOM)
            # 以50%的概率像素矩阵转置
            a = random.random()
            if (a > 0.5):
                img = img.transpose(Image.TRANSPOSE)
                img_gt = img_gt.transpose(Image.TRANSPOSE)
            a = random.random()
            # 进行随机裁剪
            width, height = img.size
            st = random.randint(0,20)
            box = (st, st, width-1, height-1)
            img = img.crop(box)
            img_gt = img_gt.crop(box)

        img = img.resize((CROP, CROP))
        img_gt = img_gt.resize((CROP, CROP))

        img = self.im_tfs(img)
        img_gt = np.array(img_gt)
        img_gt = torch.from_numpy(img_gt)

        """
        plt.subplot(1,2,1), plt.imshow(img.permute(1,2,0))
        plt.subplot(1,2,2), plt.imshow(img_gt)
        plt.show()
        """


        return img, img_gt

    def add_noise(self, img, gama=0.2):
        noise = torch.randn(img.shape[0], img.shape[1], img.shape[2])
        noise = noise * gam
评论 21
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值