Building and Training a Deep Learning Model: model.add(Flatten())

Tags: python  tensorflow  networks  machine learning  neural networks

01 Background
My local environment had problems, so I used code from GitHub as a reference.
Task 2 and Task 3 covered audio data analysis and feature extraction. This task covers building and training a CNN model. Since training depends on the earlier feature-extraction code, the relevant code is included again here.

1.1 导包
In [1]:
# Core libraries
import pandas as pd
import numpy as np
pd.plotting.register_matplotlib_converters()
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.model_selection import GridSearchCV
from sklearn.preprocessing import MinMaxScaler

# Deep learning framework
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, Flatten, Dense, MaxPool2D, Dropout
from tensorflow.keras.utils import to_categorical
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
import tensorflow as tf
import tensorflow.keras

# Audio processing libraries
import os
import librosa
import librosa.display
import glob
/opt/conda/lib/python3.6/site-packages/sklearn/ensemble/weight_boosting.py:29: DeprecationWarning: numpy.core.umath_tests is an internal NumPy module and should not be imported. It will be removed in a future NumPy release.
from numpy.core.umath_tests import inner1d
1.2 Feature extraction and dataset construction
In [2]:
feature = []
label = []

Set up the class labels: each category maps to a distinct integer.

label_dict = {'aloe': 0, 'burger': 1, 'cabbage': 2, 'candied_fruits': 3, 'carrots': 4, 'chips': 5,
              'chocolate': 6, 'drinks': 7, 'fries': 8, 'grapes': 9, 'gummies': 10, 'ice-cream': 11,
              'jelly': 12, 'noodles': 13, 'pickles': 14, 'pizza': 15, 'ribs': 16, 'salmon': 17,
              'soup': 18, 'wings': 19}
label_dict_inv = {v: k for k, v in label_dict.items()}
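As a quick sanity check (using an abbreviated dict for illustration), the inverted dict maps every integer id back to its class name:

```python
label_dict = {'aloe': 0, 'burger': 1, 'cabbage': 2}  # abbreviated for illustration
label_dict_inv = {v: k for k, v in label_dict.items()}

# every id maps back to its original name
assert all(label_dict_inv[i] == name for name, i in label_dict.items())
print(label_dict_inv[1])  # burger
```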
Define a function to extract audio features

In [3]:
from tqdm import tqdm

def extract_features(parent_dir, sub_dirs, max_file=10, file_ext="*.wav"):
    label, feature = [], []
    for sub_dir in sub_dirs:
        for fn in tqdm(glob.glob(os.path.join(parent_dir, sub_dir, file_ext))[:max_file]):  # iterate over the dataset's files
            label_name = fn.split('/')[-2]
            label.extend([label_dict[label_name]])
            X, sample_rate = librosa.load(fn, res_type='kaiser_fast')
            # compute the mel spectrogram and use its time-average as the feature
            mels = np.mean(librosa.feature.melspectrogram(y=X, sr=sample_rate).T, axis=0)
            feature.extend([mels])
    return [feature, label]
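The feature produced per clip is a fixed-length vector: `librosa.feature.melspectrogram` returns an (n_mels, frames) array (128 mel bands by default), and averaging its transpose over axis 0 collapses the time dimension. A minimal numpy sketch, with random arrays standing in for real spectrograms:

```python
import numpy as np

# Stand-in for librosa.feature.melspectrogram output: (n_mels, frames).
# Clips of different lengths still yield the same fixed-size feature.
for n_frames in (100, 250):
    mel_spec = np.random.rand(128, n_frames)
    feature = np.mean(mel_spec.T, axis=0)  # average over the time axis
    print(feature.shape)  # (128,)
```

This is why every clip, regardless of duration, becomes one 128-dimensional row of the dataset.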

In [4]:

Change these directories to match your setup

parent_dir = './train_sample/'
save_dir = "./"
folds = sub_dirs = np.array(['aloe', 'burger', 'cabbage', 'candied_fruits',
                             'carrots', 'chips', 'chocolate', 'drinks', 'fries',
                             'grapes', 'gummies', 'ice-cream', 'jelly', 'noodles', 'pickles',
                             'pizza', 'ribs', 'salmon', 'soup', 'wings'])

Extract the features (feature) and class labels (label)

temp = extract_features(parent_dir,sub_dirs,max_file=100)
100%|██████████| 45/45 [00:11<00:00, 5.04it/s]
100%|██████████| 64/64 [00:14<00:00, 4.72it/s]
100%|██████████| 48/48 [00:15<00:00, 2.87it/s]
100%|██████████| 74/74 [00:26<00:00, 1.51it/s]
100%|██████████| 49/49 [00:14<00:00, 3.51it/s]
100%|██████████| 57/57 [00:16<00:00, 3.13it/s]
100%|██████████| 27/27 [00:07<00:00, 3.38it/s]
100%|██████████| 27/27 [00:07<00:00, 3.20it/s]
100%|██████████| 57/57 [00:15<00:00, 3.44it/s]
100%|██████████| 61/61 [00:17<00:00, 3.75it/s]
100%|██████████| 65/65 [00:20<00:00, 3.64it/s]
100%|██████████| 69/69 [00:21<00:00, 3.24it/s]
100%|██████████| 43/43 [00:12<00:00, 3.59it/s]
100%|██████████| 33/33 [00:08<00:00, 3.85it/s]
100%|██████████| 75/75 [00:23<00:00, 3.06it/s]
100%|██████████| 55/55 [00:17<00:00, 2.97it/s]
100%|██████████| 47/47 [00:14<00:00, 3.33it/s]
100%|██████████| 37/37 [00:11<00:00, 2.99it/s]
100%|██████████| 32/32 [00:07<00:00, 3.17it/s]
100%|██████████| 35/35 [00:10<00:00, 2.80it/s]
In [5]:
temp = np.array(temp)
data = temp.transpose()
/opt/conda/lib/python3.6/site-packages/ipykernel_launcher.py:1: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray
  """Entry point for launching an IPython kernel.
In [6]:

Get the features

X = np.vstack(data[:, 0])

Get the labels

Y = np.array(data[:, 1])
print('Shape of X:', X.shape)
print('Shape of Y:', Y.shape)
Shape of X: (1000, 128)
Shape of Y: (1000,)
In [7]:

In Keras, to_categorical converts a vector of integer class labels into a binary (0/1) one-hot matrix.

Y = to_categorical(Y)
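A numpy equivalent of that mapping (not keras' actual implementation, just the same result) makes the one-hot structure concrete:

```python
import numpy as np

y = np.array([0, 3, 19])          # integer class labels
num_classes = 20
one_hot = np.eye(num_classes)[y]  # one row per label, a 1 at the label's index
print(one_hot.shape)              # (3, 20)
print(int(one_hot[1].argmax()))   # 3
```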
In [8]:
'''Final data'''
print(X.shape)
print(Y.shape)
(1000, 128)
(1000, 20)
In [9]:
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, random_state = 1, stratify=Y)
print('Training set size:', len(X_train))
print('Test set size:', len(X_test))
Training set size: 750
Test set size: 250
In [10]:
X_train = X_train.reshape(-1, 16, 8, 1)
X_test = X_test.reshape(-1, 16, 8, 1)
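The reshape works because 16 * 8 = 128: each 128-dimensional feature vector is laid out as a 16x8 single-channel "image" for the Conv2D layers, without losing any values. A quick check:

```python
import numpy as np

X = np.random.rand(750, 128)     # e.g. the training features
X_img = X.reshape(-1, 16, 8, 1)  # one 16x8x1 "image" per clip
print(X_img.shape)               # (750, 16, 8, 1)
# the reshape is lossless: flattening recovers the original matrix
assert np.allclose(X_img.reshape(-1, 128), X)
```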
02 Building the Model
2.1 Deep learning framework
Keras is a high-level neural network API written in Python that can run on top of TensorFlow, CNTK, or Theano as a backend. Keras has since been merged into TensorFlow and can be used through it.

2.1.1 Building the network structure
The core data structure of Keras is the model, a way of organizing layers. The simplest kind is the Sequential model, a linear stack of layers. For more complex architectures, use the Keras functional API, which allows building arbitrary graphs of layers.

A Sequential model can be created directly like this:

from keras.models import Sequential

model = Sequential()

In [11]:
model = Sequential()
2.1.2 Building the CNN
In [12]:

Input shape

input_dim = (16, 8, 1)
2.1.3 CNN basics
Among the recommended materials, we suggest Hung-yi Lee's lectures on CNNs; his slides are also attached here.

The basic architecture of a CNN

(figure: CNN architecture diagram)

A convolutional neural network (CNN) generally contains these layers:

1) Input layer: receives the data.

2) Convolutional layer: uses convolution kernels for feature extraction and feature mapping (can be repeated multiple times).

3) Activation layer: convolution is still a linear operation, so a non-linear mapping (an activation function) is added.

4) Pooling layer: downsamples and sparsifies the feature maps, reducing the amount of computation (can be repeated multiple times).

5) Flatten: unrolls the 2-D feature maps into a 1-D vector so they can feed into the next layers.

6) Fully connected layer: usually placed at the tail of the CNN to re-fit the features and reduce the loss of feature information (a DNN).

In Keras you can simply use .add() to stack the layers of the network like building blocks:

In [13]:
model.add(Conv2D(64, (3, 3), padding="same", activation="tanh", input_shape=input_dim))  # convolutional layer
model.add(MaxPool2D(pool_size=(2, 2)))  # max pooling
model.add(Conv2D(128, (3, 3), padding="same", activation="tanh"))  # convolutional layer
model.add(MaxPool2D(pool_size=(2, 2)))  # max pooling
model.add(Dropout(0.1))
model.add(Flatten())  # flatten
model.add(Dense(1024, activation="tanh"))
model.add(Dense(20, activation="softmax"))  # output layer: 20 units give the probabilities of the 20 classes
If needed, you can further configure the optimizer via .compile(). A core principle of Keras is to keep things reasonably simple while giving the user full control when required (the ultimate control being the easy extensibility of the source code).

In [14]:

Compile the model: set the loss function, optimizer, and evaluation metric

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
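categorical_crossentropy compares the one-hot target with the softmax output; for a single sample it reduces to minus the log of the probability assigned to the true class. A small numpy sketch of that computation:

```python
import numpy as np

logits = np.random.randn(20)
y_pred = np.exp(logits) / np.exp(logits).sum()  # softmax over 20 classes
y_true = np.zeros(20)
y_true[7] = 1.0                                 # one-hot target, true class 7

loss = -np.sum(y_true * np.log(y_pred))
# equals -log(probability assigned to the true class)
assert np.isclose(loss, -np.log(y_pred[7]))
print(loss > 0)  # True
```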
03 Training and Testing the CNN
3.1 Training the model
Train the previously built model in batches:

In [15]:

Train the model

model.fit(X_train, Y_train, epochs = 90, batch_size = 50, validation_data = (X_test, Y_test))
Epoch 1/90
15/15 [==============================] - 4s 162ms/step - loss: 2.9158 - accuracy: 0.1150 - val_loss: 2.7704 - val_accuracy: 0.1360
Epoch 2/90
15/15 [==============================] - 1s 79ms/step - loss: 2.5593 - accuracy: 0.2049 - val_loss: 2.5741 - val_accuracy: 0.2080
Epoch 3/90
15/15 [==============================] - 1s 81ms/step - loss: 2.2751 - accuracy: 0.3186 - val_loss: 2.4504 - val_accuracy: 0.2520
Epoch 4/90
15/15 [==============================] - 1s 73ms/step - loss: 2.1422 - accuracy: 0.3405 - val_loss: 2.3872 - val_accuracy: 0.2680
Epoch 5/90
15/15 [==============================] - 1s 76ms/step - loss: 1.9961 - accuracy: 0.3965 - val_loss: 2.3609 - val_accuracy: 0.2680
...
(epochs 6-88 omitted: training accuracy climbs past 0.99 while val_accuracy plateaus in the 0.39-0.46 range)
...
Epoch 89/90
15/15 [==============================] - 1s 74ms/step - loss: 0.0182 - accuracy: 0.9971 - val_loss: 5.9385 - val_accuracy: 0.4200
Epoch 90/90
15/15 [==============================] - 1s 80ms/step - loss: 0.0180 - accuracy: 0.9986 - val_loss: 5.9088 - val_accuracy: 0.4320
Out[15]:
<tensorflow.python.keras.callbacks.History at 0x7fca7ad7a6d8>
View the model's summary statistics

In [16]:
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d (Conv2D)              (None, 16, 8, 64)         640
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 8, 4, 64)          0
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 8, 4, 128)         73856
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 4, 2, 128)         0
_________________________________________________________________
dropout (Dropout)            (None, 4, 2, 128)         0
_________________________________________________________________
flatten (Flatten)            (None, 1024)              0
_________________________________________________________________
dense (Dense)                (None, 1024)              1049600
_________________________________________________________________
dense_1 (Dense)              (None, 20)                20500
=================================================================
Total params: 1,144,596
Trainable params: 1,144,596
Non-trainable params: 0
_________________________________________________________________

3.2 Predicting on the test set
Generate predictions for new data

In [19]:
def extract_features(test_dir, file_ext="*.wav"):
    feature = []
    for fn in tqdm(glob.glob(os.path.join(test_dir, file_ext))):  # iterate over all test files
        X, sample_rate = librosa.load(fn, res_type='kaiser_fast')
        # compute the mel spectrogram and use its time-average as the feature
        mels = np.mean(librosa.feature.melspectrogram(y=X, sr=sample_rate).T, axis=0)
        feature.extend([mels])
    return feature
Save the predictions

In [20]:
X_test = extract_features('./test_a/')
100%|██████████| 2000/2000 [10:13<00:00, 3.56it/s]
In [21]:
X_test = np.vstack(X_test)
predictions = model.predict(X_test.reshape(-1, 16, 8, 1))
In [22]:
preds = np.argmax(predictions, axis=1)
preds = [label_dict_inv[x] for x in preds]

path = glob.glob('./test_a/*.wav')
result = pd.DataFrame({'name': path, 'label': preds})

result['name'] = result['name'].apply(lambda x: x.split('/')[-1])
result.to_csv('submit.csv', index=None)
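The decoding step above (argmax over the class probabilities, then the inverse label dict) can be seen in isolation with toy values, using an abbreviated 3-class dict for illustration:

```python
import numpy as np

label_dict_inv = {0: 'aloe', 1: 'burger', 2: 'cabbage'}  # abbreviated
predictions = np.array([[0.1, 0.7, 0.2],
                        [0.8, 0.1, 0.1]])
preds = [label_dict_inv[x] for x in np.argmax(predictions, axis=1)]
print(preds)  # ['burger', 'aloe']
```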
In [23]:
!ls ./test_a/*.wav | wc -l
2000
In [24]:
!wc -l submit.csv
2001 submit.csv

Copyright notice: this is an original article by the blogger, released under the CC 4.0 BY-SA license. Please include the original source link and this notice when reposting.
Original link: https://blog.csdn.net/qq_41920295/article/details/115878169
