博客
关于我
一种通用的载入本地数据集的方法
阅读量:252 次
发布时间:2019-03-01

本文共 4045 字,大约阅读时间需要 13 分钟。

程序目录

1.说明

1.1 数据集放置格式说明

数据集文件夹下的不同类别图片需要先进行整理,放在不同的子文件夹,放置格式如图所示:

数据集存放格式
这里只有2类,当然多个分类也行,这个对分类类别的数量没有要求。

1.2 函数引用说明

在其他程序中引用这里的函数,引用方法如下:

import sys      #绝对路径引用,不然引用load_data会报错#load_data所在程序路径sys.path.append(r'E:\Pycharm\project\yeah&ok\load_data')	from load_data import load_data_func,test_image,augment

一般只需要引用load_data_func和test_image即可。

1.3 加载数据集程序中函数的使用方法说明

经过载入数据集程序的处理后,加载数据集就很简单了,加载方法如下:

ata_dir = 'E:\Pycharm\project\yeah&ok\dataset'Batch_size = 32     #批处理尺寸train_dataset,test_dataset = load_data_func(data_dir,batchsize=Batch_size)test_image(train_dataset)	#显示9张图像

然后就能继续进行网络结构的搭建,进行训练等步骤了。

2.配置库文件(开始)

import tensorflow as tfimport matplotlib.pyplot as pltimport numpy as npimport pathlibimport randomimport tensorflow_datasets as tfds

3.主函数

主要函数:功能是输入数据集路径与批处理大小,返回训练集与测试集。

def load_data_func(data_dir,batch_size):    data_root = pathlib.Path(data_dir)  #读取路径,创建path对象    print(data_dir)    print(data_root)    all_image_path = list(data_root.glob('*/*'))    #*/*是获取文件夹下的所有文件及其子文件    print(all_image_path)    all_image_path = [str(path) for path in all_image_path] #获取所有图片的完整路径    print(all_image_path)    random.shuffle(all_image_path)  #打乱    label_names = sorted(item.name for item in data_root.glob('*/') if item.is_dir())    #获取图像文件夹名字    label_to_index = dict((name, index) for index, name in enumerate(label_names))  #创建字典对象,设置图像名称的映射为整数    print(label_to_index)   #OK:0,Yeah:1	# 获取所有图像对应的标签    all_image_label = [label_to_index[pathlib.Path(p).parent.name]for p in all_image_path]  #获取每个图象的父类名称,并变成数值,0101...    print(len(all_image_label))	#显示获取的数据量    index_to_label = dict((v,k) for k,v in label_to_index.items())  #获取数值对应的标签名字,以备后用    image_patn = all_image_path[5]    image_show = (1 + load_preprocess_image(image_patn)) / 2.  # 要变成image/255.才能正常显示    plt.imshow(image_show)  # 这里是测试图片能不能正常显示    plt.show()    path_ds = tf.data.Dataset.from_tensor_slices(all_image_path)    image_dataset = path_ds.map(load_preprocess_image)  # 这里才是把所有图片提取出来,前面的都是路径    label_dataset = tf.data.Dataset.from_tensor_slices(all_image_label)    dataset = tf.data.Dataset.zip((image_dataset, label_dataset))  # 做成数据集,zip将label和image对应起来    image_count = len(all_image_path)  # 数据集的数量    test_count = int(image_count * 0.2)    train_count = image_count - test_count    print(test_count, train_count)    train_dataset = dataset.skip(test_count)  # 跳过test_count构成数据集    test_dataset = dataset.take(test_count)  # 取test_count构成数据集    BATCH_SIZE = batch_size # buffer_size = train_count    train_dataset = train_dataset.shuffle(buffer_size=150).repeat(3).batch(BATCH_SIZE)     # 数据集数量不够则加个.repeat()    test_dataset = test_dataset.batch(BATCH_SIZE)    # 数据增强,OK,之前打乱过了,只需要对训练集数据增强    train_dataset = train_dataset.map(augment)    return train_dataset,test_dataset

4.从路径提取图片,并进行归一化处理

def load_preprocess_image(img_path):    img_raw = tf.io.read_file(img_path)           #读取路径    img_tensor = tf.image.decode_jpeg(img_raw,channels=3)   #解码图片 decode_image通用,但不会返回shape,改成对应的格式    img_tensor = tf.image.resize(img_tensor,[160,160])      #改变图片大小    img_tensor = tf.cast(img_tensor, tf.float32)  #转换数据类型    img = img_tensor/127.5-1                   #标准化,归一化    return img

5.对图片进行数据增强的函数

根据需要选择。

def augment(image,label):    #随机进行水平翻转    image = tf.image.random_flip_left_right(image)    #随机设置对比度    image = tf.image.random_contrast(image,lower=0.0,upper=1.0)    #垂直翻转    image = tf.image.random_flip_up_down(image)    #设置亮度    image = tf.image.random_brightness(image,max_delta=0.5)    #设置色度    image = tf.image.random_hue(image,max_delta=0.3)    #设置饱和度    image = tf.image.random_saturation(image,lower=0.3,upper=0.5)    return image,label

6.显示9张图片,可以用来看数据增强后图片效果

这个函数会比较耗费时间,不需要每次都调用它。

def test_image(train_dataset):    #用一次就行了    plt.figure(figsize=(12,12))    for batch in tfds.as_numpy(train_dataset):  #这里耗时间很久。。尽量不用        for i in range(9):            image, label = (1+batch[0][i])/2., batch[1][i]   #image前面进行了归一化,因此这里要先恢复过来,才能正常显示图像            plt.subplot(3,3,i+1)            plt.imshow(image)            plt.grid(False)        break    plt.show()

转载地址:http://whcv.baihongyu.com/

你可能感兴趣的文章
Multimodal Unsupervised Image-to-Image Translation多通道无监督图像翻译
查看>>
MySQL Cluster与MGR集群实战
查看>>
multipart/form-data与application/octet-stream的区别、application/x-www-form-urlencoded
查看>>
mysql cmake 报错,MySQL云服务器应用及cmake报错解决办法
查看>>
Multiple websites on single instance of IIS
查看>>
mysql CONCAT()函数拼接有NULL
查看>>
multiprocessing.Manager 嵌套共享对象不适用于队列
查看>>
multiprocessing.pool.map 和带有两个参数的函数
查看>>
MYSQL CONCAT函数
查看>>
multiprocessing.Pool:map_async 和 imap 有什么区别?
查看>>
MySQL Connector/Net 句柄泄露
查看>>
multiprocessor(中)
查看>>
mysql CPU使用率过高的一次处理经历
查看>>
Multisim中555定时器使用技巧
查看>>
MySQL CRUD 数据表基础操作实战
查看>>
multisim变压器反馈式_穿过隔离栅供电:认识隔离式直流/ 直流偏置电源
查看>>
mysql csv import meets charset
查看>>
multivariate_normal TypeError: ufunc ‘add‘ output (typecode ‘O‘) could not be coerced to provided……
查看>>
MySQL DBA 数据库优化策略
查看>>
multi_index_container
查看>>