使用深度学习建立一个击掌计数器
本文来自布莱恩·道格拉斯,YouTube内容创建者,用于控制系统和深度学习应用程序
大约十年来,我想实施这个愚蠢的想法,我已经测量了一个人的手的加速,以计算他们全天五个高的次数。我不确定如何使用基于规则的算法开发来实现这一目标,即我熟悉的算法,因此该项目坐在暂停。只是在我制作的时候关于深入学习的MATLAB技术谈话视频系列我意识到深度学习非常适合解决这个问题!
该系列中第4个视频的主题是转移学习,结果是我需要我需要快速获得高五个计数算法的关键概念。在这个博客文章中,我将通过我写的代码的详细信息以及我用来为该视频提供我的高柜台的工具。希望你可以用它作为一个起点来解决你过去10年的困难的分类问题也是如此。
此帖子分为以下部分:
那么,让我们开始吧!
硬件概述
硬件设置非常简单。我有一个加速计,它连接到Arduino Uno.通过I2C总线。然后Arduino通过USB连接到我的电脑。
要感到加速,我正在使用MPU-9250..这是来自TDK Invensense的9度自由惯性测量单元。我使用a而不是将芯片集成到我自己的自定义电路设计中突发板暴露电源、接地和I2C通信引脚。我使用这个芯片的唯一原因是我已经有了一个,但任何加速度计都可以工作,只要它足够小,可以用手快速移动。
您可以看到我的硬件设置非常粗糙地构建面包板和一些跳线,但我认为你不需要设置任何太想要的东西,这有点好。
在Matlab中读取加速度计
为了通过Arduino读取MPU-9250的加速,我正在使用MATLAB支金宝app持包装arduino硬件.此包允许您与Arduino通信,而无需编译代码。另外,有一个内置的mpu9250功能,允许您用一行命令读取传感器。
它只需要三行代码来连接到Arduino,实例化MPU9250对象,并读取加速度计。
数据预处理和标量程
如果您观看了关于深度学习的技术讲座系列的第四个视频,您就会知道我选择将三轴加速度数据转换为图像以利用googlenet.-经过训练以识别图像的网络。特别是,我使用了连续小波变换创建一个缩放图.
标度图是一种时间-频率表示,适用于存在于多个标度的信号。也就是说,信号是低频的,变化缓慢,但偶尔会被高频瞬变中断。事实证明,它们对于可视化加速度数据很有用,因为在缓慢移动的手中,偶尔会出现高频击掌。
我用于使上面绘图的MATLAB代码的清洁版本在以下可折叠块中。
全关%如果您的计算机无法实时运行此操作,请减少示例%率或注释掉标量图部分FS = 50;%在50 Hz处运行a = arduino('com3','uno','图书馆','i2c');对您的Arduino的百分比变更IMU = MPU9250(A);buffer_length_sec = 2;要存储在缓冲区中的数据的秒数Accel = Zeros(Boary(Buffer_Length_sec * FS)+ 1,3);%初始化缓冲t=0:1/fs:(缓冲区长度秒(结束));%的时间向量子图(2,1,1)plot_accel = plot(t,Accel);建立加速图轴([0,buffer_length_sec,-50,50]);子图(2,1,2)plot_scale =图像(零(224,224,3));%设置标度图抽搐%启动计时器last_read_time = 0;我= 0;跑20秒虽然(toc <= 20)current_read_time = toc;If (current_read_time - last_read_time) >= 1/fs I = I + 1;Accel (1:end-1,:) = Accel (2:end,:);FIFO缓冲区中的换档值Accel(结束,:)= ReadAcceleration(IMU);plot_accel(1).ydata = Accel(:,1);plot_accel(2).ydata = Accel(:,2);plot_accel(3).ydata = Accel(:,3);%仅在每三个样本中运行缩放标准以保存计算时间如果mod(i,3)== 0 fb = cwtfilterbank('signallength',length(t),'samplingfricquency',fs,...'voicesperoctave',12);sig = Accel(:,1);[cfs,〜] = wt(fb,sig);cfs_abs = abs(cfs);Accel_i = IMResize(CFS_ABS / 8,[224 224]);fb = cwtfilterbank('signallength',length(t),'samplingfrice',fs,...'voicesperoctave',12);SIG = ACCEL(:,2);[cfs,〜] = wt(fb,sig);cfs_abs = abs(cfs);Accel_i(:,:,2)= Imresize(CFS_ABS / 8,[224 224]); fb = cwtfilterbank('SignalLength', length(t), 'SamplingFrequency', fs, ... 'VoicesPerOctave', 12); sig = accel(:, 3); [cfs, ~] = wt(fb, sig); cfs_abs = abs(cfs); accel_i(:, :, 3) = imresize(cfs_abs/8, [224 224]); if~(isempty(accel_i(accel_i>1))) accel_i(accel_i>1) = 1; end plot_scale.CData = accel_i; end last_read_time = current_read_time; end end
请注意,此代码使用一个名为cwtfilterbank来创建标量图,这是小波工具箱.如果您无权访问此工具箱,并且您不想自己编写代码,请尝试另一种类型的时间频率可视化拍摄。也许A.光谱图将工作或一些您提出的其他算法。无论你选择什么,这里的想法是我们试图创建一个图像,这将使高五个模式的唯一和可识别的功能明显。我已经表明,缩放结构工作,但其他方法也可能工作。
创建培训数据
为了训练人际网络识别击掌,我们需要多个击掌的例子来说明击掌是什么样子,击掌不是什么样子。因为我们将从一个预先培训过的网络开始,所以我们不需要像从头开始培训网络那样多的培训示例。我不知道到底需要多少培训数据才能完全捕获所有可能的击掌游戏的解决方案空间,但是,我收集了100个击掌游戏和100个非击掌游戏的数据,这似乎效果很好。我怀疑我制作的视频可以少一些,但我想如果我真的在制作一个产品,我会用更多的例子。您可以处理标记的训练数据量,并查看其对结果的影响。
收集200张图片似乎很多工作,但我写了一个脚本,一个脚本一个接一个地通过它们循环,并在适当的文件夹中保存图像。我两次运行以下脚本;一旦使用“HIGH_FIVE”标签,并将图像保存到数据/高_五文件夹一旦使用“no_high_five”标签,并将图像保存到数据/ no_high_five.文件夹。
%此脚本收集训练数据并将其放置在指定的位置%标签子文件夹。从中收集3秒的数据%传感器,但只保留和保存最后2秒。%这给了用户一些缓冲时间来启动high five。%程序在图像之间暂停并提示用户继续。%注意,您需要将图形从MATLAB窗口移开,以便%您可以看到加速后,您响应等待提示。关闭全部清除全部%如果您的计算机不能实时运行,请降低采样率FS = 50;%在50 Hz处运行parentDir=pwd;dataDir='data';%%设置要生成的数据的标签%标签='no_high_five';标签='high_five';a = arduino('com3','uno','图书馆','i2c');对您的Arduino的百分比变更IMU = MPU9250(A);buffer_length_sec = 2;要存储在缓冲区中的数据的秒数Accel = Zeros(Boary(Buffer_Length_sec * FS)+ 1,3);%初始化缓冲t=0:1/fs:(缓冲区长度秒(结束));%的时间向量子图(2,1,1)plot_accel = plot(t,Accel);建立加速图轴([0 buffer_length_sec -50 50]);子图(2,1,2)plot_scale =图像(零(224,224,3));%设置标度图对于j = 1:100收集100张图片%提示用户准备好录制下一个五个H =输入(“准备就绪时的命中:”);抽搐%启动计时器last_read_time = 0;我= 0;%运行3秒钟While (toc <= 3) current_read_time = toc;If (current_read_time - last_read_time) >= 1/fs I = I + 1;Accel (1:end-1,:) = Accel (2:end,:);缓冲区中的%换档值Accel(结束,:)= ReadAcceleration(IMU);plot_accel(1).ydata = Accel(:,1);plot_accel(2).ydata = Accel(:,2);plot_accel(3).ydata = Accel(:,3);%每3个样本运行标度图如果mod(i,3)== 0 fb = cwtfilterbank('signallength',length(t),'samplingfricquency',fs,...'voicesperoctave',12);sig = Accel(:,1);[cfs,〜] = wt(fb,sig);cfs_abs = abs(cfs);Accel_i = IMResize(CFS_ABS / 8,[224 224]);fb = cwtfilterbank('signallength',length(t),'samplingfrice',fs,...'voicesperoctave',12);SIG = ACCEL(:,2);[cfs,〜] = wt(fb,sig);cfs_abs = abs(cfs);Accel_i(:,:,2)= Imresize(CFS_ABS / 8,[224 224]); fb = cwtfilterbank('SignalLength', length(t), 'SamplingFrequency', fs, ... 'VoicesPerOctave', 12); sig = accel(:, 3); [cfs, ~] = wt(fb, sig); cfs_abs = abs(cfs); accel_i(:, :, 3) = imresize(cfs_abs/8, [224 224]); if~(isempty(accel_i(accel_i>1))) accel_i(accel_i>1) = 1; end plot_scale.CData = accel_i; end last_read_time = current_read_time; end end%将图像保存到数据文件夹imageRoot=fullfile(parentDir、dataDir);imgLoc=fullfile(imageRoot,char(标签));imFileName=strcat(char(labels),“uu”,num2str(j),“.jpg”);imwrite(plot_scale.CData,fullfile(imgLoc,imFileName),“JPEG”);结束
运行脚本后,我手动通过我的培训数据并删除了我认为将腐败培训的图像。这些是高五个不是在帧的中间或图像中间的图像,其中我知道我的差劲五个动作。在下面的GIF中,我删除了高五个图像49,因为它不在框架的中心。
转移学习和googlenet
在其适当的文件夹中包含所有培训数据,下一步是设置网络。对于这部分,我与Matlab示例一起进行基于小波分析和深度学习的时间序列分类,除了除了通过matlab脚本来运行所有内容,我发现使用的是更容易使用它的网络深层网络设计师应用程序。
我从预先训练过的GoogLeNet开始,利用这个网络的所有知识来识别图像中的物体。GoogLeNet接受的训练是在图像中识别鱼和热狗之类的东西——显然不是我要找的东西——但这正是迁移学习有用的地方。通过迁移学习,我可以保持现有网络的大部分,只替换网络末端的两层,这两层将那些通用特征结合到我正在寻找的特定模式中。当我重新训练这个网络时,基本上只需要训练这两层,这就是为什么迁移学习的训练要快得多。
如果你想确切地知道我是如何替换图层的,以及我使用了什么样的训练参数,我建议你跟随我使用的MATLAB示例,或者观看技术讲座。然而,这里还是一个很好的地方,可以让你尝试一些不同的东西。您可以尝试从一个不同的预训练网络开始,如SqueezeNet,也可以在GoogLeNet中替换更多层,或者更改训练参数。这里有很多选择,我认为偏离我所做的可以帮助你对所有这些变量如何影响结果产生一些直觉。
培训网络
随着网络准备去的,深度网络设计师应用程序的培训非常简单。在“数据”选项卡中,我通过选择我保存一组高五,没有高五图像的文件夹导入培训数据。我还留出20%的图像用于在培训过程中用于验证。
然后在训练选项卡,我设置我的训练选项。这里,我使用的选项和我在MATLAB例子中使用的是一样的,但是,我再次鼓励你们使用这些值看看它们是如何影响结果的。
培训在我的单级CPU上花了4分钟,达到了大约97%的验证准确性。几个小时的工作不太糟糕!
测试High Five计数器
现在我有了一个训练有素的网络,我使用这个函数分类从深度学习工具箱传递每个样本时间尺度图,并让网络返回一个标签。如果返回的标签是high_five,则增加一个计数器。为了避免在整个缓冲区的加速数据中多次计算相同的击掌,我添加了一个超时,它不会计算新的击掌,除非距离上一次击掌至少有2秒。
以下是我用于计算高fives的代码的清理版本。
全关%%更新到培训的网络的名称load trainedGN trainedNetwork = trainedGN;%如果您的计算机无法实时运行此操作,请减少示例%率或注释掉标量图部分FS = 50;%在50 Hz处运行a = arduino('com3','uno','图书馆','i2c');对您的Arduino的百分比变更IMU = MPU9250(A);buffer_length_sec = 2;要存储在缓冲区中的数据的秒数Accel = Zeros(Boary(Buffer_Length_sec * FS)+ 1,3);%初始化缓冲t=0:1/fs:(缓冲区长度秒(结束));%的时间向量%设置图h=数字;h、 位置=[100 900 700];p1=子批次(2,1,1);曲线图加速度=曲线图(t,加速度);绘图加速度(1)。线宽=3;绘图加速度(2)。线宽=3;绘图加速度(3)。线宽=3;p1.FontSize=20;p1.Title.String=‘加速度’;轴线([0T(末端)-5060]);xlabel(‘秒’);ylabel(“加速度,mpss”);网格化;label_string=文本(1.3,45,'无高五');label_string.解释器='none';label_string.FontSize=25;count_string=text(0.1,45,'高位五位计数器:');count_string.解释器='none';count_string.FontSize=15;val_string=text(0.65,45,'0');val_string.解释器='none';val_string.FontSize=15;p2=子批次(2,1,2);缩放加速度=图像(零(2242243));p2.Title.String='scalegram';p2.FontSize=20;末端=0;hfcount=0;抽搐%启动计时器last_read_time = 0;我= 0;%高五计数器运行20秒虽然(toc <= 20)current_read_time = toc;If (current_read_time - last_read_time) >= 1/fs I = I + 1;视线= tela间+ 1;%阅读Accel.Accel (1:end-1,:) = Accel (2:end,:);FIFO缓冲区中的换档值Accel(结束,:)= ReadAcceleration(IMU);plot_accel(1).ydata = Accel(:,1);plot_accel(2).ydata = Accel(:,2);plot_accel(3).ydata = Accel(:,3);%仅在每三个样本中运行缩放标准以保存计算时间如果mod(i,3)=0%缩放图fb = cwtfilterbank('signallength',length(t),'samplingfrice',fs,...'voicesperoctave',12);sig = Accel(:,1);[cfs,〜] = wt(fb,sig);cfs_abs = abs(cfs);Accel_i = IMResize(CFS_ABS / 8,[224 224]);fb = cwtfilterbank('signallength',length(t),'samplingfrice',fs,...'voicesperoctave',12);SIG = ACCEL(:,2);[cfs,〜] = wt(fb,sig);cfs_abs = abs(cfs);Accel_i(:,:,2)= Imresize(CFS_ABS / 8,[224 224]); fb = cwtfilterbank('SignalLength', length(t), 'SamplingFrequency', fs, ... 'VoicesPerOctave', 12); sig = accel(:, 3); [cfs, ~] = wt(fb, sig); cfs_abs = abs(cfs); accel_i(:, :, 3) = imresize(cfs_abs/8, [224 224]);%将像素饱和为1if〜(ActEmpty(Accel_i(Accel_I> 1)))Accel_i(Accel_I> 1)= 1;END SCALE_ACCEL.CDATA = IM2UINT8(ACCEL_I);%分类标量程表(YPred,聚合氯化铝)= (trainedNetwork scale_accel.CData)进行分类;如果strcmp(string(YPred), 'high_five')) label_string. if strcmp(string(YPred), 'high_five')背景色= [1 0 0];label_string。String = "击掌!";%仅当100个样本从上一个high five之后已过时才计算If telapse > 100 hfcount = 1;val_string。字符串(字符串(hfcount);telapse = 0;其他label_string结束。背景色= [1 1 1];label_string。String = "No High Five";端端端端
在这里它正在行动!
- 类别:
- 深度学习
注释
要发表评论,请点击这里登录您的MathWorks帐户或创建新的。