本文共 4045 字,大约阅读时间需要 13 分钟。
数据集文件夹下的不同类别图片需要先进行整理,放在不同的子文件夹,放置格式如图所示:
在其他程序中引用这里的函数,引用方法如下:
import sys #绝对路径引用,不然引用load_data会报错#load_data所在程序路径sys.path.append(r'E:\Pycharm\project\yeah&ok\load_data') from load_data import load_data_func,test_image,augment
一般只需要引用load_data_func和test_image即可。
经过载入数据集程序的处理后,加载数据集就很简单了,加载方法如下:
ata_dir = 'E:\Pycharm\project\yeah&ok\dataset'Batch_size = 32 #批处理尺寸train_dataset,test_dataset = load_data_func(data_dir,batchsize=Batch_size)test_image(train_dataset) #显示9张图像
然后就能继续进行网络结构的搭建,进行训练等步骤了。
import tensorflow as tfimport matplotlib.pyplot as pltimport numpy as npimport pathlibimport randomimport tensorflow_datasets as tfds
主要函数:功能是输入数据集路径与批处理大小,返回训练集与测试集。
def load_data_func(data_dir,batch_size): data_root = pathlib.Path(data_dir) #读取路径,创建path对象 print(data_dir) print(data_root) all_image_path = list(data_root.glob('*/*')) #*/*是获取文件夹下的所有文件及其子文件 print(all_image_path) all_image_path = [str(path) for path in all_image_path] #获取所有图片的完整路径 print(all_image_path) random.shuffle(all_image_path) #打乱 label_names = sorted(item.name for item in data_root.glob('*/') if item.is_dir()) #获取图像文件夹名字 label_to_index = dict((name, index) for index, name in enumerate(label_names)) #创建字典对象,设置图像名称的映射为整数 print(label_to_index) #OK:0,Yeah:1 # 获取所有图像对应的标签 all_image_label = [label_to_index[pathlib.Path(p).parent.name]for p in all_image_path] #获取每个图象的父类名称,并变成数值,0101... print(len(all_image_label)) #显示获取的数据量 index_to_label = dict((v,k) for k,v in label_to_index.items()) #获取数值对应的标签名字,以备后用 image_patn = all_image_path[5] image_show = (1 + load_preprocess_image(image_patn)) / 2. # 要变成image/255.才能正常显示 plt.imshow(image_show) # 这里是测试图片能不能正常显示 plt.show() path_ds = tf.data.Dataset.from_tensor_slices(all_image_path) image_dataset = path_ds.map(load_preprocess_image) # 这里才是把所有图片提取出来,前面的都是路径 label_dataset = tf.data.Dataset.from_tensor_slices(all_image_label) dataset = tf.data.Dataset.zip((image_dataset, label_dataset)) # 做成数据集,zip将label和image对应起来 image_count = len(all_image_path) # 数据集的数量 test_count = int(image_count * 0.2) train_count = image_count - test_count print(test_count, train_count) train_dataset = dataset.skip(test_count) # 跳过test_count构成数据集 test_dataset = dataset.take(test_count) # 取test_count构成数据集 BATCH_SIZE = batch_size # buffer_size = train_count train_dataset = train_dataset.shuffle(buffer_size=150).repeat(3).batch(BATCH_SIZE) # 数据集数量不够则加个.repeat() test_dataset = test_dataset.batch(BATCH_SIZE) # 数据增强,OK,之前打乱过了,只需要对训练集数据增强 train_dataset = train_dataset.map(augment) return train_dataset,test_dataset
def load_preprocess_image(img_path): img_raw = tf.io.read_file(img_path) #读取路径 img_tensor = tf.image.decode_jpeg(img_raw,channels=3) #解码图片 decode_image通用,但不会返回shape,改成对应的格式 img_tensor = tf.image.resize(img_tensor,[160,160]) #改变图片大小 img_tensor = tf.cast(img_tensor, tf.float32) #转换数据类型 img = img_tensor/127.5-1 #标准化,归一化 return img
根据需要选择。
def augment(image,label): #随机进行水平翻转 image = tf.image.random_flip_left_right(image) #随机设置对比度 image = tf.image.random_contrast(image,lower=0.0,upper=1.0) #垂直翻转 image = tf.image.random_flip_up_down(image) #设置亮度 image = tf.image.random_brightness(image,max_delta=0.5) #设置色度 image = tf.image.random_hue(image,max_delta=0.3) #设置饱和度 image = tf.image.random_saturation(image,lower=0.3,upper=0.5) return image,label
这个函数会比较耗费时间,不需要每次都调用它。
def test_image(train_dataset): #用一次就行了 plt.figure(figsize=(12,12)) for batch in tfds.as_numpy(train_dataset): #这里耗时间很久。。尽量不用 for i in range(9): image, label = (1+batch[0][i])/2., batch[1][i] #image前面进行了归一化,因此这里要先恢复过来,才能正常显示图像 plt.subplot(3,3,i+1) plt.imshow(image) plt.grid(False) break plt.show()
转载地址:http://whcv.baihongyu.com/