
TensorFlow 2.0 + LSTM for Text Classification


Task Description

Given a piece of text, predict which category it belongs to.

Typical applications include news topic classification and text sentiment classification.

Principle

An RNN (Recurrent Neural Network) processes the text as a sequence, carrying a hidden state from one step to the next.

(Figure: an unrolled recurrent neural network.)

Problems with plain RNNs (a short numeric sketch follows this list):

  • Vanishing gradients (the gradient shrinks toward zero, so the network parameters can no longer be updated) and exploding gradients (the gradient grows without bound, the parameters saturate, and updates become useless)
  • In practice, vanishing and exploding gradients mean that a plain RNN struggles to learn dependencies across long text sequences
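Backpropagation through time multiplies the gradient by a factor tied to the recurrent weights at every time step, so over a long sequence the product either decays toward zero or grows without bound. The tiny sketch below makes that concrete; the weight values are illustrative only, not taken from this article's model.

# Illustrative only: repeated multiplication by a per-step factor w over T steps.
# |w| < 1 -> the gradient vanishes; |w| > 1 -> the gradient explodes.
for w in (0.9, 1.1):
    for T in (10, 50, 100, 200):
        print(f"w={w}, T={T}, gradient factor ~ {w ** T:.3e}")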

To address these problems, the LSTM (Long Short-Term Memory) network was proposed:

  • 《Understanding LSTM Networks》
  • 《6.8. Long Short-Term Memory (LSTM)》

The LSTM's drawback: it introduces a relatively large number of parameters.

The GRU (Gated Recurrent Unit) addresses this with a simpler gating structure (see the parameter-count sketch after the reference below):

  • 《6.7. Gated Recurrent Units (GRU)》
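As a rough illustration of the size difference, the snippet below builds a single 16-unit LSTM and a single 16-unit GRU on top of the same 16-dimensional embedding used in main.py and prints their recurrent-layer parameter counts. This is a comparison sketch only, not part of the article's model.

import tensorflow as tf

# Compare recurrent-layer parameter counts for the sizes used in main.py
# (vocab_size=5000, embedding_dim=16, 16 recurrent units).
# An LSTM has 4 gate/candidate blocks per unit, a GRU only 3.
for layer in (tf.keras.layers.LSTM(16), tf.keras.layers.GRU(16)):
    m = tf.keras.Sequential([tf.keras.layers.Embedding(5000, 16), layer])
    m(tf.zeros((1, 10), dtype=tf.int32))  # run once so the layers build
    recurrent_params = m.count_params() - 5000 * 16  # subtract the embedding table
    print(type(layer).__name__, recurrent_params)
# Expected roughly: LSTM 2112, GRU 1632 (the exact GRU count depends on reset_after)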

Code

Dataset: THUCNews
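data_processing.py below uses each subdirectory name as the category label, so the dataset is assumed to be unpacked as one folder per category under ./data/news_data, with one news article per text file (the file names here are illustrative):

./data/news_data/
    体育/
        0.txt
        1.txt
        ...
    科技/
        0.txt
        ...
    (one folder per category listed in data_processing.py)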

main.py

import tensorflow as tf
import numpy as np
from data_processing import DataConfig
from data_processing import load_data
import datetime
import os

#  Workaround for 'OMP: Error #15: Initializing libiomp5.dylib, but found libiomp5.dylib already initialized.'
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

data_path = './data/news_data'


"""
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding (Embedding)        (None, None, 16)          80000     
_________________________________________________________________
bidirectional (Bidirectional (None, 32)                4224      
_________________________________________________________________
dropout (Dropout)            (None, 32)                0         
_________________________________________________________________
dense (Dense)                (None, 16)                528       
_________________________________________________________________
dense_1 (Dense)              (None, 10)                170       
=================================================================
Total params: 84,922
Trainable params: 84,922
Non-trainable params: 0
"""


def get_model():
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Embedding(DataConfig.vocab_size, 16))
    # Specify the sigmoid activation to avoid the RNN/LSTM/GRU error 'their signatures do not match.'
    # (Note: this replaces the default tanh activation, so the fused cuDNN kernel is not used.)
    model.add(tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(16, activation='sigmoid')))
    model.add(tf.keras.layers.Dropout(0.3))
    model.add(tf.keras.layers.Dense(16, activation='relu'))
    model.add(tf.keras.layers.Dense(10, activation='softmax'))

    model.summary()
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    return model


if __name__ == '__main__':
    all_x, all_y = load_data(data_path)
    # Re-seeding with the same value shuffles features and labels in the same order
    np.random.seed(1)
    np.random.shuffle(all_x)
    np.random.seed(1)
    np.random.shuffle(all_y)

    # Hold out the first 10,000 shuffled samples for validation
    val_x = all_x[:10000]
    train_x = all_x[10000:]
    val_y = all_y[:10000]
    train_y = all_y[10000:]

    log_dir = 'logs/fit/'+datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

    model = get_model()
    model.fit(train_x,
              train_y,
              epochs=2,
              batch_size=512,
              validation_data=(val_x, val_y),
              verbose=1,
              callbacks=[tensorboard_callback])
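The TensorBoard callback writes its logs under logs/fit/, so the loss and accuracy curves can be inspected during or after training with:

tensorboard --logdir logs/fit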

data_processing.py 

from collections import Counter
import tensorflow as tf
import os


DOCUMENTS = list()
data_path = './data/news_data'


class DataConfig:
    vocab_path = './vocab.txt'
    vocab_size = 5000
    max_length = 200


def build_vocab():
    # Build a character-level vocabulary of the most frequent characters
    # in the document contents, reserving index 0 for <PAD>.
    all_data = []
    for document in DOCUMENTS:
        _, content = document.split('\t', 1)  # drop the category label
        all_data.extend(content)

    counter = Counter(all_data)
    count_pairs = counter.most_common(DataConfig.vocab_size - 1)
    words, _ = list(zip(*count_pairs))
    words = ['<PAD>'] + list(words)
    with open(DataConfig.vocab_path, 'w', encoding='utf-8') as fp:
        fp.write('\n'.join(words) + '\n')


def read_file(dir_path):
    # Recursively walk dir_path; each file becomes one document,
    # labelled with the name of the directory that contains it.
    dir_list = os.listdir(dir_path)
    for sub_dir in dir_list:
        child_dir = os.path.join(dir_path, sub_dir)
        if os.path.isfile(child_dir):
            with open(child_dir, 'r', encoding='utf-8') as file:
                document = ''
                lines = file.readlines()
                for line in lines:
                    document += line.strip()
            DOCUMENTS.append(dir_path[dir_path.rfind('/') + 1:] + '\t' + document)
        else:
            read_file(child_dir)


def load_data(dir_path):
    global DOCUMENTS
    data_x = []
    data_y = []

    read_file(dir_path)
    if not os.path.exists(DataConfig.vocab_path):
        build_vocab()

    with open(DataConfig.vocab_path, 'r', encoding='utf-8') as fp:
        words = [w.strip() for w in fp]
    word_to_id = dict(zip(words, range(len(words))))

    categories = ['科技', '股票', '体育', '娱乐', '时政', '社会', '教育', '财经', '家居', '游戏']
    categories_to_id = dict(zip(categories, range(len(categories))))

    for document in DOCUMENTS:
        y_, x_ = document.split('\t', 1)
        # Character-level ids; characters not in the vocabulary are skipped
        data_x.append([word_to_id[x] for x in x_ if x in word_to_id])
        data_y.append(categories_to_id[y_])
    # Pad/truncate every document to max_length and one-hot encode the labels
    data_x = tf.keras.preprocessing.sequence.pad_sequences(data_x, DataConfig.max_length)
    data_y = tf.keras.utils.to_categorical(data_y, num_classes=len(categories))

    return data_x, data_y


if __name__ == '__main__':
    load_data(data_path)
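Once training has finished, the model can be tried on new text by rebuilding the same character-to-id mapping and category list that load_data uses. The snippet below is a sketch of that idea; the predict_category helper and the commented-out sample headline are illustrative, and it assumes vocab.txt has already been written by build_vocab.

import tensorflow as tf
from data_processing import DataConfig

# Rebuild the character -> id mapping from the vocabulary file.
with open(DataConfig.vocab_path, 'r', encoding='utf-8') as fp:
    words = [w.strip() for w in fp]
word_to_id = dict(zip(words, range(len(words))))

# Must stay in the same order as in data_processing.load_data().
categories = ['科技', '股票', '体育', '娱乐', '时政', '社会', '教育', '财经', '家居', '游戏']


def predict_category(model, text):
    # Same preprocessing as load_data: character ids, unknown characters skipped,
    # then padded/truncated to max_length.
    ids = [word_to_id[c] for c in text if c in word_to_id]
    x = tf.keras.preprocessing.sequence.pad_sequences([ids], DataConfig.max_length)
    probs = model.predict(x)[0]
    return categories[int(probs.argmax())]

# Example usage (assumes `model` is the trained model from main.py):
# print(predict_category(model, '某球队在昨晚的比赛中逆转取胜'))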

References

  • 《Understanding LSTM Networks》
  • 《TensorFlow从零开始学》
