深度神经网络语言识别

「AI秘籍」系列课程:

  • 人工智能应用数学基础
  • 人工智能Python基础
  • 人工智能基础核心知识
  • 人工智能BI核心知识
  • 人工智能CV核心知识

使用 DNN 和字符 n-gram 对一段文本的语言进行分类(附 Python 代码)

在这里插入图片描述

资料来源,flaticon:https://www.flaticon.com/premium-icon/cyborg_901032

语言识别是自然语言处理 (NLP) 问题中的一个重要步骤。它涉及尝试预测一段文本的自然语言。在采取其他操作(即翻译/情感分析)之前,了解文本的语言非常重要。例如,如果你使用谷歌翻译,你输入的框会显示“检测语言”。这是因为谷歌首先尝试识别你的句子的语言,然后才能翻译它。

在这里插入图片描述

语言识别有几种不同的方法,在本文中,我们将详细探讨其中一种方法。即使用神经网络和字符 n-gram 作为特征。最后,我们表明这种方法可以实现超过 98% 的准确率。在此过程中,我们将讨论关键代码,你可以在GitHub1找到完整的项目。首先,我们将讨论用于训练神经网络的数据集。

数据集

数据集2由 Tatoeba 提供。 完整数据集包含 328 种独特语言的 6,872,356 个句子。为了简化我们的问题,我们将考虑:

  • 6 种拉丁语言:英语、德语、西班牙语、法语、葡萄牙语和意大利语。
  • 长度在 20 到 200 个字符之间的句子。

我们可以在表 1 中看到每种语言的一个句子示例。我们的目标是创建一个可以使用提供的文本预测目标变量的模型。

在这里插入图片描述

我们在下面的代码中加载数据集并进行一些初始处理。我们首先过滤数据集以获取所需长度和语言的句子。我们从每种语言中随机选择 50,000 个句子,这样我们总共有 300,000 行。然后将这些句子分成训练集(70%)、验证集(20%)和测试集(10%)。

# read in full dataset
data = pd.read_csv(data_path + '/public_articles/sentences.csv', 
                            sep='\t', 
                            encoding='utf8', 
                            index_col=0,
                            names=['lang','text'])

# Filter by text length
data = data[data['text'].str.len().between(20, 200)]

# Filter by text language
lang = ['deu', 'eng', 'fra', 'ita', 'por', 'spa']
data = data[data['lang'].isin(lang)]

# Select 50000 rows for each language
data_trim_list = [data[data['lang'] == l].sample(50000, random_state=100) for l in lang]

# Concatenate all the samples
data_trim = pd.concat(data_trim_list)

# Create a random train, valid, test split
data_shuffle = data_trim.sample(frac=1, random_state=100)

train = data_shuffle[:210000]
valid = data_shuffle[210000:270000]
test = data_shuffle[270000:300000]

# Check the shapes to ensure everything is correct
print(f"Train set shape: {train.shape}")
print(f"Validation set shape: {valid.shape}")
print(f"Test set shape: {test.shape}")

特征工程

在拟合模型之前,我们必须将数据集转换为神经网络可以理解的形式。换句话说,我们需要从句子列表中提取特征来创建特征矩阵。我们使用字符 n-gram(n 个连续字符的集合)来实现这一点。这是一种类似于词袋模型的方法,只不过我们使用的是字符而不是单词。

对于我们的语言识别问题,我们将使用字符 3-grams/ trigrams (即 3 个连续字符的集合)。在图 2 中,我们看到了如何使用 trigrams 对句子进行矢量化的示例。首先,我们从句子中获取所有 trigrams 。为了减少特征空间,我们取这些 trigrams 的子集。我们使用这个子集对句子进行矢量化。第一个句子的向量是 [2,0,1,0,0],因为 trigrams “is_”在句子中出现两次,“his”出现一次。

在这里插入图片描述

创建三元特征矩阵的过程类似,但稍微复杂一些。在下一节中,我们将深入研究用于创建矩阵的代码。在此之前,有必要对如何创建特征矩阵进行总体概述。所采取的步骤如下:

  1. 使用训练集,我们从每种语言中选择了 200 个最常见的三字母组
  2. 根据这些 trigrams 创建一个唯一 trigrams 列表。这些语言共享一些共同的 trigrams ,因此我们最终得到了 661 个唯一 trigrams
  3. 通过计算每个句子中每个 trigrams 出现的次数来创建特征矩阵

我们可以在表 2 中看到此类特征矩阵的示例。顶行给出了 661 个 trigrams 中的每一个。然后,每个编号行给出了我们数据集中的一个句子。矩阵中的数字给出了该 trigrams 在句子中出现的次数。例如,“eux”在句子 2 中出现了一次。

表 2:训练特征矩阵

创建特征

在本节中,我们将介绍用于创建表 2 中的训练特征矩阵和验证/测试特征矩阵的代码。我们大量使用了SciKit Learn 提供的CountVectorizer包。此包允许我们根据一些词汇表(即单词/字符列表)对文本进行矢量化。在我们的例子中,词汇表是一组 661 个 trigrams 。

首先,我们必须创建这个词汇表。我们首先从每种语言中获取 200 个最常见的 trigrams 。这是使用下面代码中的*get_trigrams*函数完成的。此函数获取一个句子列表,并将从这些句子中返回 200 个最常见的 trigrams 的列表。

from sklearn.feature_extraction.text import CountVectorizer

def get_trigrams(corpus, n_feat=200):
    """
    Returns a list of the N most common character trigrams from a list of sentences
    params
    ------------
        corpus: list of strings
        n_feat: integer
    """
    # fit the n-gram model
    vectorizer = CountVectorizer(analyzer='char', ngram_range=(3, 3), max_features=n_feat)

    X = vectorizer.fit_transform(corpus)

    # Get model feature names
    feature_names = vectorizer.get_feature_names_out()
    return feature_names

在下面的代码中,我们循环遍历这 6 种语言。对于每种语言,我们从训练集中获取相关句子。然后我们使用get_trigrams函数获取 200 个最常见的 trigrams 并将它们添加到集合中。最后,由于这些语言共享一些共同的 trigrams ,我们得到了一组 661 个独特的 trigrams 。我们用它们来创建一个词汇表。

# obtain trigrams from each language
features = {}
features_set = set()

for l in lang:
    
    # get corpus filtered by language
    corpus = train[train.lang==l]['text']
    
    # get 200 most frequent trigrams
    trigrams = get_trigrams(corpus)
    
    # add to dict and set
    features[l] = trigrams 
    features_set.update(trigrams)

    
# create vocabulary list using feature set
vocab = dict()
for i,f in enumerate(features_set):
    vocab[f]=i

然后,CountVectorisor 包使用词汇表对训练集中的每个句子进行矢量化。结果就是我们之前看到的表 2 中的特征矩阵。

# train count vectoriser using vocabulary
vectorizer = CountVectorizer(analyzer='char',
                             ngram_range=(3, 3),
                            vocabulary=vocab)

# create feature matrix for training set
corpus = train['text']   
X = vectorizer.fit_transform(corpus)
feature_names = vectorizer.get_feature_names_out()

train_feat = pd.DataFrame(data=X.toarray(),columns=feature_names)

在训练模型之前,最后一步是缩放特征矩阵。这将有助于我们的神经网络收敛到最佳参数权重。在下面的代码中,我们使用最小-最大缩放来缩放训练矩阵。

# Scale feature matrix 
train_min = train_feat.min()
train_max = train_feat.max()
train_feat = (train_feat - train_min)/(train_max-train_min)

# Add target variable 
train_feat['lang'] = list(train['lang'])

我们还需要获取验证和测试数据集的特征矩阵。在下面的代码中,我们像对训练集所做的那样对 2 个集合进行矢量化和缩放。值得注意的是,我们使用了词汇表以及从训练集中获得的最小/最大值。这是为了避免任何数据泄露。

# create feature matrix for validation set
corpus = valid['text']   
X = vectorizer.fit_transform(corpus)

valid_feat = pd.DataFrame(data=X.toarray(),columns=feature_names)
valid_feat = (valid_feat - train_min)/(train_max-train_min)
valid_feat['lang'] = list(valid['lang'])

# create feature matrix for test set
corpus = test['text']   
X = vectorizer.fit_transform(corpus)

test_feat = pd.DataFrame(data=X.toarray(),columns=feature_names)
test_feat = (test_feat - train_min)/(train_max-train_min)
test_feat['lang'] = list(test['lang'])

探索 trigrams

现在,我们已经准备好了可用于训练神经网络的数据集。在此之前,探索数据集并建立一些直觉来了解这些特征在预测语言方面的表现会很有用。图 2 给出了每种语言与其他语言共有的 trigrams 数量。例如,英语和德语有 56 个最常见的 trigrams 是共同的。

我们发现西班牙语和葡萄牙语的共同 trigrams 最多,有 128 个共同的 trigrams。这是有道理的,因为在所有语言中,这两种语言在词汇上最相似。这意味着,使用这些特征,我们的模型可能很难区分西班牙语和葡萄牙语,反之亦然。同样,葡萄牙语和德语的共同 trigrams 最少,我们可以预期我们的模型在区分这些语言方面会更好。

图 2: trigrams 特征相似度图

建模

我们使用keras包来训练 DNN。模型的输出层使用 softmax 激活函数。这意味着我们必须将目标变量列表转换为 one-hot 编码列表。这可以通过下面的编码函数来实现。 该函数接收目标变量列表,并返回单次编码向量列表。 例如,[eng,por,por, fra,…] 将变为[[0,1,0,0,0,0],[0,0,0,0,1,0],[0,0,0,0,1,0],[0,0,1,0,0,0],…]。

from sklearn.preprocessing import LabelEncoder
from keras.utils import np_utils

# Fit encoder
encoder = LabelEncoder()
encoder.fit(['deu', 'eng', 'fra', 'ita', 'por', 'spa'])

def encode(y):
    """
    Returns a list of one hot encodings 
    
    Params
    ---------
        y: list of language labels
    """
    
    y_encoded = encoder.transform(y)
    y_dummy = np_utils.to_categorical(y_encoded)
    
    return y_dummy

在选择最终模型结构之前,我进行了一些超参数调整。我改变了隐藏层中的节点数、epoch 数和批处理大小。最终模型选择了在验证集上实现最高准确率的超参数组合。

最终模型有 3 个隐藏层,分别有 500、500 和 250 个节点。输出层有 6 个节点,每个语言一个。隐藏层都具有 ReLU 激活函数,并且如上所述,输出层具有 softmax 激活函数。我们使用 4 个 epoch 和 100 的批处理大小来训练此模型。使用我们的训练集和独热编码目标变量列表,我们在以下代码中训练此 DDN。最终,我们实现了 99.57% 的训练准确率。

from keras.models import Sequential
from keras.layers import Dense

#Get training data
x = train_feat.drop('lang',axis=1)
y = encode(train_feat['lang'])

#Define model
model = Sequential()
model.add(Dense(500, input_dim=661, activation='relu'))
model.add(Dense(500, activation='relu'))
model.add(Dense(250, activation='relu'))
model.add(Dense(6, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

#Train model
model.fit(x, y, epochs=4, batch_size=100)

模型评估

在模型训练过程中,模型可能会偏向训练集和验证集。因此,最好在未见过的测试集上确定模型准确率。测试集的最终准确率为 98.60%。这低于训练准确率 99.57%,表明发生了一些对训练集的过度拟合。

通过查看图 3 中的混淆矩阵,我们可以更好地了解模型对每种语言的表现。红色对角线表示每种语言的正确预测数。非对角线数字表示一种语言被错误预测为另一种语言的次数。例如,德语被错误预测为英语 5 次。我们发现,该模型最常将葡萄牙语混淆为西班牙语(78 次)或将西班牙语混淆为葡萄牙语(88 次)。这是我们在探索特征时看到的结果。

图 3:困惑热图

创建此混淆矩阵的代码如下所示。首先,我们使用上面训练的模型对测试集进行预测。使用这些预测语言和实际语言,我们创建一个混淆矩阵并使用 seaborn 热图对其进行可视化。

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, confusion_matrix
import numpy as np

# x_test 和 y_test 已经定义,并且 model 是一个已训练好的 Keras 模型
x_test = test_feat.drop('lang', axis=1)
y_test = test_feat['lang']

# Use model.predict to get probabilities
predictions_prob = model.predict(x_test)
# Find the index of the highest probability for each sample
labels = np.argmax(predictions_prob, axis=1)
predictions = encoder.inverse_transform(labels)

# Ensure y_test is a 1D array
if y_test.ndim > 1:
    y_test = np.argmax(y_test, axis=1)

# Accuracy on test set
accuracy = accuracy_score(y_test, predictions)
print(f"Accuracy: {accuracy}")

# Create confusion matrix
lang = ['deu', 'eng', 'fra', 'ita', 'por', 'spa']
conf_matrix = confusion_matrix(y_test, predictions)
conf_matrix_df = pd.DataFrame(conf_matrix, columns=lang, index=lang)

# Plot confusion matrix heatmap
plt.figure(figsize=(10, 10), facecolor='w', edgecolor='k')
sns.set(font_scale=1.5)
sns.heatmap(conf_matrix_df, cmap='coolwarm', annot=True, fmt='.5g', cbar=False)
plt.xlabel('Predicted', fontsize=22)
plt.ylabel('Actual', fontsize=22)

plt.savefig('../figures/model_eval.png', format='png', dpi=150)
plt.show()

最后,98.60% 的测试准确率仍有提升空间。在特征选择方面,我们保持简单,只为每种语言选择了 200 个最常见的 trigrams 。更复杂的方法可以帮助我们区分更相似的语言。例如,我们可以选择在西班牙语中很常见但在葡萄牙语中不太常见的 trigrams ,反之亦然。我们还可以尝试不同的模型。希望这对你的语言识别实验来说是一个良好的起点。

参考


  1. 茶桁的公开文章项目文件 https://github.com/hivandu/public_articles ↩︎

  2. Tatoeba 数据集 https://downloads.tatoeba.org/exports/ ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/774726.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

京东金融大数据分析平台总体架构:剖析和解读

京东金融大数据分析平台总体架构:剖析和解读 在现代金融行业中,大数据分析已成为决策支持和业务创新的重要工具。京东金融凭借其强大的大数据分析平台,成功地将海量数据转化为洞察力,为企业和用户提供优质服务。本文将深入探讨京…

浅谈反射机制

1. 何为反射? 反射(Reflection)机制指的是程序在运行的时候能够获取自身的信息。具体来说,反射允许程序在运行时获取关于自己代码的各种信息。如果知道一个类的名称或者它的一个实例对象, 就能把这个类的所有方法和变…

VMware替换关键技术:核心业务系统中,访存密集型应用的性能优化

越来越多用户采用虚拟化、超融合以及云平台环境来承载其核心业务,核心业务的高并发对性能的要求尤为严格,在VMware替换的热潮下,原VMware用户也更为关注新平台在核心业务上的性能表现是否对标,或实现超越。深信服将通过系列解析&a…

当心!不要在SpringBoot中再犯这样严重的错误

1. 简介 在Spring Boot中,Configuration注解用于声明配置类,以定义和注册Bean对象。这些Bean对象可以是普通的业务组件,也可以是特殊的处理器,如BeanPostProcessor或BeanFactoryPostProcessor,用于在Spring容器中对其…

OCC显示渲染性能分析及优化方案

1.背景介绍 君方智能设计平台(ShipMaker),使用OCC中的图形构造功能和图形渲染功能。OCC的图形渲染采用Opengl API 并且将所有图形渲染相关的逻辑放置在TKOpenGL模块中。 性能场景1: 大场景中包含2万个构件,超过300万三角面片时,…

景区智慧公厕解决方案,公厕革命新方式

在智慧旅游的浪潮下,景区智慧公厕解决方案正悄然引领着一场公厕革命,不仅革新了传统公厕的管理模式,更以智能化、人性化的服务理念,为游客提供了前所未有的舒适体验。作为智慧城市建设的重要一环,智慧公厕解决方案正逐…

跟《经济学人》学英文:2024年07月06日这期 Central banks are winning the battle against inflation

Central banks are winning the battle against inflation. But the war is just getting started Politics and protectionism will make life difficult 原文: The trajectory of inflation has not given central bankers much cause for celebration in rece…

时间同步协议详解:从原理到应用的全方位解析

作者介绍 随着信息技术的飞速发展,时间同步技术在通信、导航、电力等多个领域发挥着越来越重要的作用。从日常生活到高精尖的科学实验,精确的时间同步都是确保系统正常运行和任务成功完成的关键因素。本文将对几种主流的时间同步技术进行介绍和对比分析&…

剪画小程序:自媒体工具推荐:视频文案提取!

各位小伙伴,你们好啊! 上周五观看《歌手 2024》第八期时,我再次被何炅老师幽默风趣的主持风格所折服。他的每一句话都仿佛带着魔力,让现场气氛热烈非凡,实在令人羡慕不已! 何炅老师的口才之所以如此出色&a…

代码随想录算法训练营第四十四天|188.买卖股票的最佳时机IV、309.最佳买卖股票时机含冷冻期、714.买卖股票的最佳时机含手续费

188.买卖股票的最佳时机IV 题目链接:188.买卖股票的最佳时机IV 文档讲解:代码随想录 状态:不会 思路: 在股票买卖1使用一维dp的基础上,升级成二维的即可。 定义dp[k1][2],其中 dp[j][0] 表示第j次交易后持…

【Cadence18】如何放置定位孔

在菜单的place->manually会出现Placement对话框, 在Advanced settings中勾选database和library 然后点击Placement list,下拉框中选择Mechanical symbols,勾选你要的定位孔 (如下图的HOLE_1_6R00D2R70-PTH,注意:…

相关技术 检测离型纸

网盘 https://pan.baidu.com/s/1W-k4hl9uhjAG98hqJG11ug?pwdcrpn 离型无纺布.pdf 离型纸剥离机构.pdf 离型纸处理装置及贴胶设备.pdf 离型纸收集机构.pdf 离型纸涂布装置.pdf 防伪印刷离型纸的制造工艺.pdf

gitee代码初次上传步骤

ps. 前提是已经下载安装gitee 一、在本地项目目录下空白处右击,选择“Git Bash Here” 二、初始化 git init 三、添加、提交代码(注意add与点之间的空格) git add . git commit -m 添加注释 四、连接、推送到gitee仓库 git remote add …

E2.【C语言】练习:static部分

#include <stdio.h> int sum(int a) {int c 0;static int b 3;c 1;b 2;return (a b c); } int main() {int i;int a 2;for (i 0; i < 5;i){printf("%d ", sum(a));} } 求执行结果 c是auto类变量(普通的局部变量)&#xff0c;自动产生&#xff0c…

一个项目学习Vue3---Class和Style绑定

看下面一段代码学习此部分内容 <template><button click"stateChang">状态切换</button><div :class"{font-color:classObject.openColor,font-weight:classObject.openWeight}">颜色和粗细变化</div><div :class"…

Java中使用arima预测未来数据

看着已经存在的曲线图数据&#xff0c;想预估下后面曲线图的数据。 import java.util.Vector;public class AR {double[] stdoriginalData{};int p;ARMAMath armamathnew ARMAMath();/*** AR模型* param stdoriginalData* param p //p为MA模型阶数*/public AR(double [] stdori…

通证经济重塑经济格局

在数字化转型的全球浪潮中&#xff0c;通证经济模式犹如一股新兴力量&#xff0c;以其独特的价值传递与共享机制&#xff0c;重塑着经济格局&#xff0c;引领我们步入数字经济的新纪元。 通证&#xff0c;作为这一模式的核心&#xff0c;不仅是权利与权益的数字化凭证&#xf…

Netty学习(NIO基础)

NIO基础 三大组件 Channel and Buffer 常用的只有ByteBuffer Selector&#xff08;选择器&#xff09; 结合服务器的设计演化来理解Selector 多线程版设计 最早在nio设计出现前服务端程序的设计是多线程版设计,即一个客户端对应一个socket连接,一个连接用一个线程处理,每…

静力水准仪:测量与安装的全面指南

静力水准仪作为一种高精度的测量仪器&#xff0c;广泛应用于管廊、大坝、核电站、高层建筑、基坑、隧道、桥梁、地铁等工程领域&#xff0c;用于监测垂直位移和倾斜变化。本文将详细介绍静力水准仪的工作原理、特点及其安装过程中的注意事项&#xff0c;旨在为相关工程人员提供…

字符串函数5-9题(30 天 Pandas 挑战)

字符串函数 1. 相关知识点1.5 字符串的长度条件判断1.6 apply映射操作1.7 python大小写转换1.8 正则表达式匹配2.9 包含字符串查询 2. 题目2.5 无效的推文2.6 计算特殊奖金2.7 修复表中的名字2.8 查找拥有有效邮箱的用户2.9 患某种疾病的患者 1. 相关知识点 1.5 字符串的长度条…