How to feed live audio to a Keras model in real time?


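I'm trying to build a voice-activated goal horn. I have created and trained a CNN model, but I don't know how to use it to make predictions on live data.

I want to make real-time class predictions from the last second of audio captured by the built-in microphone, classifying the most recent audio as "YES_GOAL", "YES_WIN" or "NO_GOAL".

The end goal of my project is to play a goal horn in iTunes every time I scream "GOAL!".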

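When I try to run the code, I get:

ValueError: Input 0 of layer dense is incompatible with the layer: expected axis -1 of input shape to have value 2200 but received input with shape [32, 1]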
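The error means the array reaching the network does not have the shape it was trained on: the training script feeds MFCC matrices from data.json (plus a channel axis), while the callback passes raw samples from the ring buffer. How many MFCC frames a clip yields depends on its length and the STFT settings. The sketch below reproduces librosa's frame count, assuming its default `center=True` padding; `n_fft=2048`, `hop_length=512` and `N_MFCC = 13` are assumed values, not taken from the question, and should be replaced with whatever was used to build data.json.

```python
def expected_num_frames(num_samples: int, n_fft: int = 2048,
                        hop_length: int = 512, center: bool = True) -> int:
    """Frames produced by librosa's STFT (and so librosa.feature.mfcc) for a clip.

    With center=True, librosa pads n_fft // 2 samples on both sides before framing.
    n_fft=2048 and hop_length=512 are assumed defaults, not values from the question.
    """
    if center:
        num_samples += 2 * (n_fft // 2)
    if num_samples < n_fft:
        raise ValueError("clip is shorter than one FFT window")
    return 1 + (num_samples - n_fft) // hop_length


SAMPLE_RATE = 22050  # training sample rate from the question
N_MFCC = 13          # assumed number of MFCC coefficients

for seconds in (1, 3):
    frames = expected_num_frames(seconds * SAMPLE_RATE)
    print(f"{seconds} s of audio -> model input shape (1, {frames}, {N_MFCC}, 1)")
```

Under these assumptions one second (22050 samples) gives 44 frames and three seconds give 130, so a network ending in `Flatten` + `Dense` that was trained on clips of one length will reject windows of another length; the live window has to match the training clip length.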

Here is my code so far:

import pyaudio
import librosa
import numpy as np
import time
import subprocess
import os
import sys
#import kbHitMod
import tensorflow.keras as keras

MODEL_PATH = "/Users/schoolwork/Documents/Goal_Horn_Project_Stuff/Goal Horn Program/Goal_Model.model"

GOAL_TRACK = "1 New York Islanders Overtime Goal and Win Horn || NYCB Live: Home of the Nassau Veterans Memorial Coliseum"
WIN_TRACK = "2 New York Islanders Win Horn || NYCB Live: Home of the Nassau Veterans Memorial Coliseum"
OT_GOAL_TRACK = "3 New York Islanders Goal Horn || NYCB Live Home of the Nassau Veterans Memorial Coliseum"
QUIET_TRACK = "4 pure silence"

PAUSE_COMMAND = "osascript -e 'tell application \"iTunes\" to pause'"

class RingBuffer:
    """ class that implements a not-yet-full buffer """
    def __init__(self,size_max):
        self.max = size_max
        self.data = []

    class __Full:
        """ class that implements a full buffer """
        def append(self,x):
            """ Append an element overwriting the oldest one. """
            self.data[self.cur] = x
            self.cur = (self.cur+1) % self.max
        def get(self):
            """ return list of elements in correct order """
            return self.data[self.cur:]+self.data[:self.cur]

    def append(self,x):
        """append an element at the end of the buffer"""
        self.data.append(x)
        if len(self.data) == self.max:
            self.cur = 0
            # Permanently change self's class from non-full to full
            self.__class__ = self.__Full

    def get(self):
        """ Return a list of elements from the oldest to the newest. """
        return self.data

# ring buffer will keep the last 1 second worth of audio
ringBuffer = RingBuffer(1 * 22050)

overtime = False
print("\nOvertime mode: off\n")

def play(track_name):
    subprocess.getoutput("osascript -e 'tell application \"iTunes\" to play (first track of playlist \"Library\" whose name is \"4 pure silence\")'")

    subprocess.getoutput("osascript -e 'tell application \"iTunes\" to play (first track of playlist \"Library\" whose name is \"" + track_name + "\")'")

# load the model once at start-up; loading it inside the callback would reload it for every audio chunk
model = keras.models.load_model(MODEL_PATH, compile=True)

def callback(in_data, frame_count, time_info, flag):

    state = subprocess.getoutput("osascript -e 'tell application \"iTunes\" to player state as string'")

    # np.fromstring is deprecated for binary input; np.frombuffer reads the raw bytes directly
    audio_data = np.frombuffer(in_data, dtype=np.float32)

    # we trained on audio with a sample rate of 22050 so we need to convert it
    # (librosa 0.10+ requires orig_sr and target_sr as keyword arguments)
    audio_data = librosa.resample(audio_data, orig_sr=44100, target_sr=22050)

    # append individual samples so the buffer holds 22050 samples (1 second), not 22050 chunks
    for sample in audio_data:
        ringBuffer.append(sample)

    # machine learning model takes live audio as input and
    # decides if the last 1 second of audio contains a goal
    if model.predict_classes(ringBuffer.get()) == "YES_GOAL" and state == "paused":
        # GOAL!!
        if overtime:
            play(GOAL_TRACK)
        else:
            play(OT_GOAL_TRACK)

    # decides if the last 1 second of audio contains a win
    elif model.predict_classes(ringBuffer.get()) == "YES_WIN" and state == "paused":
        play(WIN_TRACK)

    return (in_data, pyaudio.paContinue)

pa = pyaudio.PyAudio()

stream = pa.open(format = pyaudio.paFloat32,channels = 1,rate = 44100,output = False,input = True,stream_callback=callback)

# start the stream
stream.start_stream()

i = 0 # This is just an alternative to breaking the loop with kbHitMod

while stream.is_active():
    time.sleep(0.25)
    
    """
    kb = kbHitMod.KBHit() # detects if a key has been pressed

    ot = kb.getch()
    if ot == "o":
        if overtime == False:
            overtime = True
            print("Overtime mode: ON\n")
        else:
            overtime = False
            print("Overtime mode: off\n")
    elif ot == "q":
        print("Quitting... Goodbye!\n")
        break
    """
    i += 1

    if i >= 100:
        break

            
stream.close()
pa.terminate()

play(QUIET_TRACK)
subprocess.getoutput(PAUSE_COMMAND)
print("Program terminated. \n")
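A one-second window at 22050 Hz has to hold 22050 individual samples, not 22050 callback chunks. As a sketch (the class name `SampleWindow` is hypothetical), `collections.deque` with `maxlen` gives the same drop-the-oldest behaviour as `RingBuffer` in a few lines. PyAudio invokes the stream callback on a separate thread, so a lock guards the deque; the idea is that the callback only calls `extend()`, while the main `while stream.is_active()` loop takes a `snapshot()` and runs the model, so slow inference does not stall audio capture.

```python
import threading
from collections import deque


class SampleWindow:
    """Keeps the most recent `size` audio samples; the oldest are dropped first.

    The PyAudio callback runs on its own thread, so a lock guards the deque
    while the main loop copies it out.
    """

    def __init__(self, size: int):
        self._samples = deque(maxlen=size)
        self._lock = threading.Lock()

    def extend(self, samples) -> None:
        # Adds every sample individually; deque discards the oldest ones once full.
        with self._lock:
            self._samples.extend(samples)

    def is_full(self) -> bool:
        with self._lock:
            return len(self._samples) == self._samples.maxlen

    def snapshot(self) -> list:
        # Copies the samples out (oldest -> newest) so they can be processed without the lock.
        with self._lock:
            return list(self._samples)


window = SampleWindow(22050)  # one second at 22050 Hz
window.extend(range(30000))
print(window.is_full(), window.snapshot()[0], window.snapshot()[-1])  # True 7950 29999
```

Waiting for `is_full()` before predicting avoids feeding the model a window shorter than the clips it was trained on.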

My model:

import json
import numpy as np
from sklearn.model_selection import train_test_split
import tensorflow.keras as keras

DATA_PATH = "/Users/schoolwork/Documents/Goal_Horn_Project_Stuff/Goal Horn Program/data.json"

MODEL_PATH = "/Users/schoolwork/Documents/Goal_Horn_Project_Stuff/Goal Horn Program/Goal_Model.model"

def load_data(data_path):
    """Loads training dataset from json file.

        :param data_path (str): Path to json file containing data
        :return X (ndarray): Inputs
        :return y (ndarray): Targets
    """

    with open(data_path,"r") as fp:
        data = json.load(fp)

        X = np.array(data["mfcc"])
        y = np.array(data["labels"])
        return X,y

def prepare_datasets(test_size,validation_size):

    # load data
    X,y = load_data(DATA_PATH)

    # create train/test split
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size)

    # create train/validation split
    X_train, X_validation, y_train, y_validation = train_test_split(X_train, y_train, test_size=validation_size)

    # each MFCC is a 2D matrix, so X is 3D: (num_samples, num_frames, n_mfcc);
    # Conv2D needs a channel axis, so add one -> 4D: (num_samples, num_frames, n_mfcc, 1)
    X_train = X_train[..., np.newaxis]
    X_validation = X_validation[..., np.newaxis]
    X_test = X_test[..., np.newaxis]

    return X_train, X_validation, X_test, y_train, y_validation, y_test

def build_model(input_shape):

    # create model
    model = keras.Sequential()

    # 1st conv layer
    model.add(keras.layers.Conv2D(32, (3, 3), activation="relu", input_shape=input_shape))
    model.add(keras.layers.MaxPool2D((3, 3), strides=(2, 2), padding="same"))
    model.add(keras.layers.BatchNormalization())

    # 2nd conv layer ((3, 3) kernel assumed, matching the 1st layer)
    model.add(keras.layers.Conv2D(32, (3, 3), activation="relu", padding="same"))
    model.add(keras.layers.BatchNormalization())

    # 3rd conv layer (input_shape is only needed on the first layer)
    model.add(keras.layers.Conv2D(32, (2, 2), activation="relu"))
    model.add(keras.layers.MaxPool2D((2, 2), padding="same"))
    model.add(keras.layers.BatchNormalization())

    # flatten the output and feed it into dense layer
    model.add(keras.layers.Flatten())
    model.add(keras.layers.Dense(64,activation="relu"))
    model.add(keras.layers.Dropout(0.3))

    # output layer
    model.add(keras.layers.Dense(3,activation="softmax"))

    return model

def predict(model,X,y):

    X = X[np.newaxis,...]

    # prediction = [ [0.1,0.2,...] ]
    prediction = model.predict(X) # X -> (1, num_frames, n_mfcc, 1)

    # extract index with max_value
    predicted_index = np.argmax(prediction, axis=-1) # e.g. [1] for 3 classes
    print("Expected index: {},Predicted index: {}".format(y,predicted_index))


if __name__ == "__main__":
    # create train,validation and test sets
    X_train, X_validation, X_test, y_train, y_validation, y_test = prepare_datasets(0.25, 0.2)

    # build the CNN net
    input_shape = (X_train.shape[1],X_train.shape[2],X_train.shape[3])
    model = build_model(input_shape)

    # compile the network
    optimizer = keras.optimizers.Adam(learning_rate=0.0001)
    model.compile(optimizer=optimizer,loss="sparse_categorical_crossentropy",metrics=["accuracy"])

    # train the CNN
    model.fit(X_train, y_train, validation_data=(X_validation, y_validation), batch_size=32, epochs=30)

    # evaluate the CNN on the test set
    test_error,test_accuracy = model.evaluate(X_test,y_test,verbose=1)
    print("Accuracy on test set is: {}".format(test_accuracy))

    # make prediction on a sample
    X = X_test[2]
    y = y_test[2]

    print(X_test.shape)

    predict(model, X, y)

    model.save(MODEL_PATH)
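For live prediction, the last second of audio has to go through the same preprocessing as the training data before it reaches the network. A minimal sketch, assuming data.json stores each example as an MFCC matrix shaped (num_frames, n_mfcc) computed with `librosa.feature.mfcc`; `N_MFCC`, `N_FFT` and `HOP_LENGTH` are assumed values that must be replaced with the ones actually used to build data.json. librosa is imported inside `mfcc_features()` so the reshaping helper can be used on its own.

```python
import numpy as np

SAMPLE_RATE = 22050  # the rate the model was trained at (from the question)
N_MFCC = 13          # assumption: use the value that produced data.json
N_FFT = 2048         # assumption
HOP_LENGTH = 512     # assumption


def mfcc_features(samples, sr=SAMPLE_RATE):
    """Return an MFCC matrix shaped (num_frames, n_mfcc) for a 1-D array of samples.

    Assumes data.json stores MFCCs in this (frames, coefficients) orientation.
    """
    import librosa  # imported here so to_model_input() works without librosa installed

    y = np.asarray(samples, dtype=np.float32)
    mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=N_MFCC,
                                n_fft=N_FFT, hop_length=HOP_LENGTH)
    return mfcc.T  # librosa returns (n_mfcc, num_frames); transpose to (num_frames, n_mfcc)


def to_model_input(mfcc):
    """Add batch and channel axes: (num_frames, n_mfcc) -> (1, num_frames, n_mfcc, 1)."""
    mfcc = np.asarray(mfcc, dtype=np.float32)
    if mfcc.ndim != 2:
        raise ValueError(f"expected a 2-D MFCC matrix, got shape {mfcc.shape}")
    return mfcc[np.newaxis, ..., np.newaxis]


# Placeholder matrix with the shape one second gives under the assumed settings
# (44 frames x 13 coefficients); live code would call to_model_input(mfcc_features(samples)).
x = to_model_input(np.zeros((44, 13)))
print(x.shape)  # (1, 44, 13, 1)
```

The result has the same rank as `X_test[2][np.newaxis, ...]` inside the training script's `predict()`, which is what the first `Conv2D` layer expects; a raw 1-D list of samples has neither the frame/coefficient axes nor the channel axis.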

Solution

It turned out to be very simple.

With the model loaded from its path, all I had to do was pass my data to the model as if it were a function argument.
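Calling a Keras model like a function, `model(x, training=False)`, runs one forward pass in inference mode. The Keras documentation recommends this over `predict()` for small inputs that fit in a single batch, and `training=False` matters here because the network contains `BatchNormalization` layers. Note also that `predict_classes` has been removed from recent TensorFlow releases, and it returned an integer class index rather than a label string, so comparing it with "YES_GOAL" could never match. The sketch below maps the arg-max back to a label, assuming the order in `LABELS` (hypothetical; it must match the integer labels in data.json) and an input already shaped (1, num_frames, n_mfcc, 1).

```python
import numpy as np

# Assumed label order: it must match the integer labels used when building data.json.
LABELS = ["NO_GOAL", "YES_GOAL", "YES_WIN"]


def classify(model, model_input, labels=LABELS):
    """Return (label, probabilities) for a single pre-shaped input.

    `model` is called like a function; for a Keras model, model(x, training=False)
    runs one forward pass with BatchNormalization in inference mode.
    """
    probabilities = np.asarray(model(model_input, training=False))[0]
    return labels[int(np.argmax(probabilities))], probabilities


# Stand-in for the loaded Keras model so the sketch runs without TensorFlow.
def fake_model(x, training=False):
    return np.array([[0.1, 0.8, 0.1]])


label, probs = classify(fake_model, np.zeros((1, 44, 13, 1), dtype=np.float32))
print(label)  # YES_GOAL
```

With the real network, `np.asarray` converts the returned tensor to a NumPy array, and the label can then be compared with "YES_GOAL" or "YES_WIN" as the original callback intended.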
