当前位置: 首页 > news >正文

Keras深度学习实战(38)——图像字幕生成

Keras深度学习实战(38)——图像字幕生成

    • 0. 前言
    • 1. 模型与数据集分析
      • 1.1 数据集分析
      • 1.2 模型分析
    • 2. 实现图像字幕生成模型
      • 2.1 数据集加载与预处理
      • 2.2 模型构建与训练
    • 3. 使用束搜索生成字幕
      • 3.1 束搜索原理
      • 3.2 利用束搜索改进预测结果
    • 小结
    • 系列链接

0. 前言

图像和文本是当今两种主要的信息载体,其中图像具有生动形象的特点,而文本概括性强,能够以简练的形式传递信息。图像字幕生成旨在让计算机自动地使用文本对给定的图像加以描述,在图像检索、人机对话等应用中被广泛使用。
当前,网络中包含了数以十亿计的图片,为我们提供更加丰富娱乐和信息。但是,有视觉障碍的人或互联网速度较慢而无法加载图像时,则无法访问大部分视觉信息,手动添加的图像说明提供了一种更易于访问的方式。然而,现有的人工管理的图像说明字段仅涵盖少数图像,虽然自动生成图像字幕可以帮助解决这个问题,但获取准确的图像字幕是一项具有挑战性的任务。

1. 模型与数据集分析

我们已经学习了如何结合使用卷积神经网络 (Convolutional Neural Networks, CNN),循环神经网络 (Recurrent Neural Network, RNN) 和 CTC 损失转录手写文本图片。在本节中,我们将学习如何融合 CNNRNN 体系结构为给定图片生成字幕信息。

1.1 数据集分析

本节所用数据集称为 Conceptual Captions,由大约 330 万对图像-字幕对组成,它们是通过从数十亿网页中自动提取和筛选的图像字幕注释创建的。由于 Conceptual Captions 中的图像是从网络上提取的,因此它具有更广泛的图像-字幕对,从而可以更好地训练图像-字幕模型。
数据集可以通过 GitHub 进行下载,需要分别下载图像数据文件 Flickr8k_Dataset.zip 与字幕标签数据文件 Flickr8k_text.zip。下载完成后,解压文件后,可以查看图片示例:

图片示例

以上图片数据示例包括以下图像字幕:

  • 两个男孩在踢足球
  • 两个男孩在草地上踢足球
  • 两个男孩在草地上玩耍

1.2 模型分析

在实现图像字幕生成模型之前,我们首先了解为实现图像生成字幕所采用的策略流程:

  • 下载并加载数据集,其中包括各类图像以及与图像相关联的图像说明字幕
  • 使用预训练的 VGG16 模型提取每张图片的特征
  • 预处理每一图像相对应的字幕文本:
    • 将所有单词转换为小写
    • 删除标点符号
    • 为每一个图片字幕添加开始 (<start>) 和结束 (<end>) 标记
  • 为了快速训练模型,我们可以仅使用部分数据集,例如仅保留与女孩相关的照片(图片字幕文本中包含单词girl)
  • 为图像字幕词汇表中的每个不同单词分配一个索引
  • 填充所有图像字幕文本,以使所有句子具有相同的长度,句子中的每个单词均由索引值表示
  • 为了预测第一个单词,模型需要同时利用 VGG16 提取到特征和开始标记 <start> 的嵌入
  • 同理,为了预测第二个单词,模型需要同时利用 VGG16 提取到的特征以及开始标记 <start> 和第一个单词的嵌入
  • 以此类推,直到获取所有的预测单词,遇到结束标记 <end> 为止

2. 实现图像字幕生成模型

在本小节中,我们实现在上一小节定义的图像字幕生成模型。

2.1 数据集加载与预处理

首先,下载并解压缩包含图像及其相应图像字幕标签数据集。

(1) 导入相关的库,并加载图片与图片字幕文本数据集:

import numpy as np
import matplotlib.pyplot as plt
import random
from keras.preprocessing.sequence import pad_sequences
from keras.models import Model
from keras.layers import LSTM, Embedding, TimeDistributed, Dense, RepeatVector, Activation, Flatten
from keras.optimizers import Adam
from keras.layers import Bidirectional, Conv2D, MaxPooling2D, Input, concatenate
from keras.applications.inception_v3 import InceptionV3
import re
import cv2
from keras.applications.vgg16 import VGG16

caption_file = 'Flickr8k_text/Flickr8k.lemma.token.txt'
captions = open(caption_file, 'r').read().strip().split('\n')
d = {}
for i, row in enumerate(captions):
    row = row.split('\t')
    row[0] = row[0][:len(row[0])-2]
    if row[0] in d:
        d[row[0]].append(row[1])
    else:
        d[row[0]] = [row[1]]
total_images = list(d.keys())

(2) 加载图片并使用预训练的 VGG16 模型提取图像的特征:

image_path = '/home/brainiac/learnStar/finished/keras/Flickr8k_Dataset/Flicker8k_Dataset/'
vgg16=VGG16(include_top=False, weights='imagenet', input_shape=(224,224,3))

x = []
y = []
x2 = []
tot_images = ''
for i in range(len(d.keys())):
    if i%100==0:
        print(i)
    for j in range(len(d[total_images[i]])):
        img_path = image_path+total_images[i]
        img = cv2.imread(img_path)
        try:
            img2 = cv2.resize(img, (224, 224))/255
            img3 = vgg16.predict(img2.reshape(1,224,224,3))
            x.append(img3)
            y.append(d[total_images[i]][j])
            tot_images = tot_images + ' '+total_images[i]
        except:
            continue

(3) 将提取到的 VGG16 图像特征转换为 NumPy 数组:

x = np.array(x)
x = x.reshape(x.shape[0],7,7,512)

(4) 创建预处理函数,用于删除字幕文本中的标点符号、将所有单词都转换为小写:

def preprocess(text):
    text=text.lower()
    text=re.sub('[^0-9a-zA-Z]+',' ',text)
    words = text.split()
    words2 = words
    words4=' '.join(words2)
    return(words4)

接下来,使用预处理函数 preprocess 处理所有字幕文本,并附加开始标记 <start> 和结束标记 <end>

caps = []
for key, val in d.items():
    if(key in img_path2):
        for i in val:
            i = preprocess(i)
            print(i)
            caps.append('<start> ' + i + ' <end>')

(5) 为了快速得到训练后的网络,我们可以仅使用数据集的一部分,例如仅使用图片中包含女孩的样本:

caps2 = []
x3 = []
img_path3 = []
for i in range(len(caps)):
    if (('girl') in caps[i]):
        caps2.append(caps[i])
        x3.append(x2[i])
        img_path3.append(img_path2[i])
    elif 'dog' in caps[i]:
        caps2.append(caps[i])
        x3.append(x2[i])
        img_path3.append(img_path2[i])
    else:
        continue
# 如果使用全部数据集,为了保持代码的一致性,构造新变量
# for i in range(len(caps)):
#     caps2.append(caps[i])
#     x3.append(x2[i])
#     img_path3.append(img_path2[i])

(6) 提取图像字幕文本中的所有不重复的单词构成词汇表,并为词汇表中的单词分配不同索引:

words = [i.split() for i in caps2]
unique = []
for i in words:
    unique.extend(i)
	
unique = list(set(unique))
len(unique)
vocab_size = len(unique)

word2idx = {val:(index+1) for index, val in enumerate(unique)}
idx2word = {(index+1):val for index, val in enumerate(unique)}

(7) 计算图片字幕文本的最大长度,用于将所有字幕文本填充为相同长度:

max_len = 0
for c in caps2:
    c = c.split()
    if len(c) > max_len:
        max_len = len(c)
print(max_len)

根据计算出的图片字幕文本的最大长度 max_len,将所有字幕文本长度均填充至长度 max_len

n = np.zeros(vocab_size+1)
y = []
y2 = []
for k in range(len(caps2)):
    t= [word2idx[i] for i in caps2[k].split()]
    y.append(len(t))
    while(len(t)<max_len):
        t.append(word2idx['<end>'])
y2.append(t)

2.2 模型构建与训练

(1) 构建模型,以图片作为输入并从中提取图像特征:

embedding_size = 300
inp = Input(shape=(7,7,512))
inp1 = Conv2D(512, (3,3), activation='relu')(inp)
inp11 = MaxPooling2D(pool_size=(2, 2))(inp1)
inp2 = Flatten()(inp11)
img_emb = Dense(embedding_size, activation='relu') (inp2)
img_emb2 = RepeatVector(max_len)(img_emb)

(2) 构建另一模型,将图片字幕文字作为输入并从中提取文本特征:

inp2 = Input(shape=(max_len,))
cap_emb = Embedding((vocab_size+1), embedding_size, input_length=max_len) (inp2)
cap_emb2 = LSTM(256, return_sequences=True)(cap_emb)
cap_emb3 = TimeDistributed(Dense(300)) (cap_emb2)

(3) 将以上两个模型的输出串联起来,最后通过 softmax 函数计算得出所有可能单词的输出概率:

final1 = concatenate([img_emb2, cap_emb3])
final2 = Bidirectional(LSTM(256, return_sequences=False))(final1)
final3 = Dense(vocab_size+1)(final2)
final4 = Activation('softmax')(final3)

final_model = Model([inp, inp2], final4)

(4) 实例化Adam优化器,编译并模型:

adam = Adam(lr = 0.0001)
final_model.compile(loss='categorical_crossentropy', optimizer = adam, metrics=['acc'])

for i in range(8000):
    x4 = []
    x4_sent = []
    y3 = []
    shortlist_y = random.sample(range(len(y)-100),32)
    for j in range(len(shortlist_y)):
        for k in range(y[shortlist_y[j]]-1):
            # print(y[shortlist_y[j]]-1)
            n = np.zeros(vocab_size+1)      
            x4.append(x3[shortlist_y[j]])
            pad_sent = pad_sequences([y2[shortlist_y[j]][:(k+1)]], maxlen=max_len, padding='post')
            x4_sent.append(pad_sent)
            n[y2[shortlist_y[j]][(k+1)]] = 1
            y3.append(n)
    x4 = np.array(x4)
    x4_sent =np.array(x4_sent)
    x4_sent = x4_sent.reshape(x4_sent.shape[0], x4_sent.shape[2])
    y3 = np.array(y3) 
    history = final_model.fit([x4/np.max(x4), x4_sent], y3, batch_size=32, epochs=2, verbose=1)
    if i%100==0:
        l_train.append(history.history['loss'][0])

每次循环中,采样 32 张图片作为一个 batch 进行处理,训练模型;此外,我们以图像字幕文本中的前 n 个输出单词与利用预训练的 VGG16 提取的图片特征作为输入,模型的输出是图像字母文本中的第 n+1 个单词;同时,我们通过将利用 VGG16 提取的图像特征 (x4) 除以特征最大值 (np.max(x4)),将输入特征数据缩放至 [0, 1] 之间。

(5) 使用随机图片样本,利用训练完成的模型生成图像字幕文本:

l=-25
im_path = image_path+ img_path3[l]
img1 = cv2.imread(im_path)
plt.imshow(img1)

(6) 解码模型输出,得到图像字幕文本:

title = []
p = np.zeros(max_len)
p[0] = word2idx['<start>']
for i in range(max_len-1):
    pred= final_model.predict([x3[l].reshape(1,7,7,512)/np.max(x4), p.reshape(1,max_len)])
    pred2 = np.argmax(pred)
    print(idx2word[pred2])
    p[i+1] = pred2
    if(idx2word[pred2]=='<end>'):
        break
    else:
        title.append(idx2word[pred2])
plt.title(' '.join(title))
plt.show()
print(caps2[l])

以上代码的输出结果如下:

图像字幕生成

可以看到,生成的字幕文本正确地检测到两个男孩并且其中一个男孩戴着帽子。

...
<start> a young boy pull another boy wear a silly hat in a radio flyer wagon <end>

3. 使用束搜索生成字幕

在上一节中,我们基于给定时间戳中的 softmax 概率输出,解码获取具有最高概率的单词。在本节中,我们将通过使用束搜索来改善预测的结果字幕。

3.1 束搜索原理

我们首先介绍束搜索的工作原理,同时讲解如何将束搜索用于生成图像字幕,提高模型效果:

  • 首先,根据输入图片的 VGG16 特征和开始标记 <start> 文本特征,在第 1 个时间戳中提取预测单词概率
  • 与之前方法的不同在于,我们并不使用概率最高的单词作为输出,而是同时考虑概率最高的前 3 个单词
  • 在下一时间戳中,同样获取概率最高的前 3 个单词
  • 循环在第一个时间戳中输出的概率最高的前 3 个预测结果,作为第 2 个时间戳中预测的输入,并为每个输入提取预测结果中概率最高的前 3 个结果:
    • 假设第 1 个时间戳的概率最高的前 3 个预测为 abc
    • 使用 a 作为输入以及基于预训练 VGG16 提取的图像特征来预测时间戳 2 中概率最高的前 3 个预测结果分别为 def,然后对输入 bc 进行类似处理
    • 因此,在经过第 1 个时间戳和第 2 个时间戳后,共可以得到 9 种输出组合
    • 除了存储预测结果组合外,我们还将存储这 9 种结果组合的预测置信度:
      • 例如:如果时间戳 1a 的输出概率为 0.4,时间戳 2d 的输出概率为 0.5,则预测结果组合的预测置信度为 0.4 x 0.5 = 0.2
    • 保留置信度最高的前 3 个组合,并丢弃其余组合
  • 重复以上的步骤,在每个时间戳均保留置信度最高的前 3 个组合,直到到达句子结尾
  • 在以上过程中,值 3 是我们要搜索组合的束长度,可以使用不同的束长度,在每个时间戳中保留不同数量的预测结果,观察使用不同束长度时模型性能。

接下来,我们将实现束搜索,将其用于图像字幕生成。

3.2 利用束搜索改进预测结果

(1) 定义函数 get_top3,该函数以图片的 VGG16 特征作为输入,以及单词序列及其之前时间戳的相应置信度作为输入,并返回当前时间戳的前 3 个预测:

def get_top3(img, string_with_conf):
    tokens, confidence = string_with_conf
    p = np.zeros((1, max_len))
    p[0, :len(tokens)] = np.array(tokens)
    # 预测下一时间戳结果
    pred = final_model.predict([img.reshape(1,7,7,512)/12, p])
    best_pred = list(np.argsort(pred)[0][-beamsize:])
    best_confs = list(pred[0,best_pred])
    top_best = [(tokens + list([best_pred[i]]), confidence*best_confs[i]) for i in range(beamsize)]
    return top_best

在函数 get_top3 中,我们分离了 string_with_conf 参数中提供的单词 ID 及其对应的置信度。此外,我们将分词序列存储在数组中,并使用该序列进行预测。
然后,在下一个时间戳中提取前 3 个预测并将其存储在 best_pred 中。此外,除了单词 ID戳 的最佳预测以外,我们还存储与当前时间戳中前 3 项预测相关的置信度。最后,返回第 2 个时间戳的三个预测。

(2) 在句子的最大可能长度范围内循环,并在所有时间戳中提取前 3 个可能的单词组合:

start_token = word2idx['<start>']
best_strings = [([start_token], 1)]
for i in range(max_len):
     new_best_strings = []
     for string in best_strings:
         strings = get_top3(x4[l], string)
         new_best_strings.extend(strings) 
         best_strings = sorted(new_best_strings, key=lambda x: x[1], reverse=True)[:beamsize]

(3) 循环遍历前面获得的 best_strings 以打印输出:

for i in range(3):
     string = best_strings[i][0]
     print('============')
     for j in string:
         print(idx2word[j])
         if(idx2word[j]=='<end>'):
             break

我们在上一节中测试的同一张图片的输出语句如下:

束搜索改进结果
可以看到,图像字幕生成模型对于示例图像,关于图像的描述,第一句和第二句有所不同,而第二句恰好与第三句相同,这是因为组合的可能性更高。

...
<start> one child pull another sit in a red wagon along the sand <end>
<start> a toddler in dirty jean attampts to pull a young child in a wagon <end>
<start> a toddler in dirty jean attampts to pull a young child in a wagon <end>

小结

随着大规模数据集的出现,深度学习因其出色的计算能力在很多传统的计算机视觉任务上取得了巨大的成功,尤其是图像识别领域的图像字幕生成任务。本文利用深度学习技术设计出能够连接图像与自然语言的模型,从而实现图像字幕生成。本文设计的模型主要包含两个部分,一个是图像特征提取部分,另一个是语言建模与生成部分。同时,为了提高图像字幕生成模型的性能,我们使用束搜索对模型进行改进。

系列链接

Keras深度学习实战(1)——神经网络基础与模型训练过程详解
Keras深度学习实战(2)——使用Keras构建神经网络
Keras深度学习实战(3)——神经网络性能优化技术
Keras深度学习实战(4)——深度学习中常用激活函数和损失函数详解
Keras深度学习实战(5)——批归一化详解
Keras深度学习实战(6)——深度学习过拟合问题及解决方法
Keras深度学习实战(7)——卷积神经网络详解与实现
Keras深度学习实战(8)——使用数据增强提高神经网络性能
Keras深度学习实战(9)——卷积神经网络的局限性
Keras深度学习实战(10)——迁移学习详解
Keras深度学习实战(11)——可视化神经网络中间层输出
Keras深度学习实战(12)——面部特征点检测
Keras深度学习实战(13)——目标检测基础详解
Keras深度学习实战(14)——从零开始实现R-CNN目标检测
Keras深度学习实战(15)——从零开始实现YOLO目标检测
Keras深度学习实战(16)——自编码器详解
Keras深度学习实战(17)——使用U-Net架构进行图像分割
Keras深度学习实战(18)——语义分割详解
Keras深度学习实战(19)——使用对抗攻击生成可欺骗神经网络的图像
Keras深度学习实战(20)——DeepDream模型详解
Keras深度学习实战(21)——神经风格迁移详解
Keras深度学习实战(22)——生成对抗网络详解与实现
Keras深度学习实战(23)——DCGAN详解与实现
Keras深度学习实战(24)——从零开始构建单词向量
Keras深度学习实战(25)——使用skip-gram和CBOW模型构建单词向量
Keras深度学习实战(26)——文档向量详解
Keras深度学习实战(27)——循环神经详解与实现
Keras深度学习实战(28)——利用单词向量构建情感分析模型
Keras深度学习实战(29)——长短时记忆网络详解与实现
Keras深度学习实战(30)——使用文本生成模型进行文学创作
Keras深度学习实战(31)——构建电影推荐系统
Keras深度学习实战(32)——基于LSTM预测股价
Keras深度学习实战(33)——基于LSTM的序列预测模型
Keras深度学习实战(34)——构建聊天机器人
Keras深度学习实战(35)——构建机器翻译模型
Keras深度学习实战(36)——基于编码器-解码器的机器翻译模型
Keras深度学习实战(37)——手写文字识别

相关文章:

  • 含电热联合系统的微电网运行优化附Matlab代码
  • SecXOps 核心技术能力划分
  • PyTorch学习笔记-TensorBoard
  • 牛顿法与拟牛顿法摘记
  • Collectors.collectingAndThen()
  • verilog练习——组合逻辑
  • 【Python+百度API】实现人脸识别和颜值检测系统(包括人脸数量、年龄、颜值评分、性别、种族、表情检测)(超详细 附源码)
  • Python:每日一题之顺子日期
  • 简单的Hystrix熔断
  • 036-JList列表控件使用案例讲解
  • 【网页设计】期末大作业html+css+js(在线鲜花盆栽网站)
  • Python用PyMC3实现贝叶斯线性回归模型
  • VuePress构建一个文档管理网站
  • PyQt5学习笔记--多线程处理、数据交互
  • Node.js 入门教程 3 如何安装 Node.js
  • Android Studio实现志愿者系统
  • 智能驾驶 车牌检测和识别(一)《CCPD车牌数据集》
  • 【c语言进阶】动态内存管理知识大全(下)
  • 前端实现水印的两种方式(DOM和Canvas)
  • Linux——进程
  • 【看表情包学Linux】冯诺依曼架构 | 理解操作系统 | 基于 Pintos 实现新的用户级程序的系统调用
  • 单链表——简单的增删查改
  • 电子技术——MOS放大器基础
  • 嵌入式 学习
  • 【Linux】vim编辑器的使用
  • JAVA基础(File类的重命名和删除功能)
  • Ajax重构
  • 1005 继续(3n+1)猜想 PTA
  • JAVA基础(File类的判断功能)
  • 1002 写出这个数(水题)
  • JSP基本语法
  • JAVA基础(File类的获取功能)
  • python初步入门
  • JavaBean的应用
  • JAVA基础(输出指定目录下指定后缀的文件名)