pytorch-LSTM()

torch.nn包下实现了LSTM函数,实现LSTM层。多个LSTMcell组合起来是LSTM。

LSTM自动实现了前向传播,不需要自己对序列进行迭代。

LSTM的用到的参数如下:创建LSTM指定如下参数,至少指定前三个参数

input_size:
输入特征维数
hidden_size:
隐层状态的维数
num_layers:
RNN层的个数,在图中竖向的是层数,横向的是seq_len
bias:
隐层状态是否带bias,默认为true
batch_first:
是否输入输出的第一维为batch_size,因为pytorch中batch_size维度默认是第二维度,故此选项可以将 batch_size放在第一维度。如input是(4,1,5),中间的1是batch_size,指定batch_first=True后就是(1,4,5)
dropout:
是否在除最后一个RNN层外的RNN层后面加dropout层
bidirectional:
是否是双向RNN,默认为false,若为true,则num_directions=2,否则为1

为了统一,以后都batch_first=True

LSTM的输入为:LSTM(input,(h0,co))

其中,指定batch_first=True​后,input就是(batch_size,seq_len,input_size)​

(h0,c0)是初始的隐藏层,因为每个LSTM单元其实需要两个隐藏层的。记hidden=(h0,c0)

其中,h0的维度是(num_layers*num_directions, batch_size, hidden_size)

c0维度同h0。注意,即使batch_first=True,这里h0的维度依然是batch_size在第二维度

LSTM的输出为:out,(hn,cn)

其中,out是每一个时间步的最后一个隐藏层h的输出,假如有5个时间步(即seq_len=5),则有5个对应的输出,out的维度是:(batch_size,seq_len,hidden_size)

hidden=(hn,cn),他自己实现了时间步的迭代,每次迭代需要使用上一步的输出和hidden层,最后一步hidden=(hn,cn)记录了最后一各时间步的隐藏层输出,有几层对应几个输出,如果这个是RNN-encoder,则hn,cn就是中间的编码向量。hn的维度是(num_layers*num_directions,batch_size,hidden_size),cn同。

应用LSTM

创建一LSTM:

lstm = torch.nn.LSTM(input_size,hidden_size,num_layers,batch_first=True)

forward使用LSTM层:

out,hidden = lstm(input,hidden)

其中,hidden=(h0,c0)是个tuple

最终得到out,hidden

举例:

import torch
# 实现一个num_layers层的LSTM-RNN
class RNN(torch.nn.Module):
def __init__(self,input_size, hidden_size, num_layers):
super(RNN,self).__init__()
self.input_size = input_size
self.hidden_size=hidden_size
self.num_layers=num_layers
self.lstm = torch.nn.LSTM(input_size=input_size,hidden_size=hidden_size,num_layers=num_layers,batch_first=True) def forward(self,input):
# input应该为(batch_size,seq_len,input_szie)
self.hidden = self.initHidden(input.size(0))
out,self.hidden = lstm(input,self.hidden)
return out,self.hidden def initHidden(self,batch_size):
if self.lstm.bidirectional:
return (torch.rand(self.num_layers*2,batch_size,self.hidden_size),torch.rand(self.num_layers*2,batch_size,self.hidden_size))
else:
return (torch.rand(self.num_layers,batch_size,self.hidden_size),torch.rand(self.num_layers,batch_size,self.hidden_size)) input_size = 12
hidden_size = 10
num_layers = 3
batch_size = 2
model = RNN(input_size,hidden_size,num_layers)
# input (seq_len, batch, input_size) 包含特征的输入序列,如果设置了batch_first,则batch为第一维
input = torch.rand(2,4,12)
model(input)

【pytorch】pytorch-LSTM的更多相关文章

  1. 【翻译】理解 LSTM 网络

    目录 理解 LSTM 网络 递归神经网络 长期依赖性问题 LSTM 网络 LSTM 的核心想法 逐步解析 LSTM 的流程 长短期记忆的变种 结论 鸣谢 本文翻译自 Christopher Olah ...

  2. 【翻译】理解 LSTM 及其图示

    目录 理解 LSTM 及其图示 本文翻译自 Shi Yan 的博文 Understanding LSTM and its diagrams,原文阐释了作者对 Christopher Olah 博文 U ...

  3. 【转载】PyTorch系列 (二):pytorch数据读取

    原文:https://likewind.top/2019/02/01/Pytorch-dataprocess/ Pytorch系列: PyTorch系列(一) - PyTorch使用总览 PyTorc ...

  4. 【转载】Pytorch tutorial 之Datar Loading and Processing

    前言 上文介绍了数据读取.数据转换.批量处理等等.了解到在PyTorch中,数据加载主要有两种方式: 1.自定义的数据集对象.数据集对象被抽象为Dataset类,实现自定义的数据集需要继承Datase ...

  5. 【转载】 pytorch笔记:06)requires_grad和volatile

    原文地址: https://blog.csdn.net/jiangpeng59/article/details/80667335 作者:PJ-Javis 来源:CSDN --------------- ...

  6. 【转载】 Pytorch 细节记录

    原文地址: https://www.cnblogs.com/king-lps/p/8570021.html ---------------------------------------------- ...

  7. 【转载】 pytorch之添加BN

    原文地址: https://blog.csdn.net/weixin_40123108/article/details/83509838 ------------------------------- ...

  8. 【转载】 pytorch自定义网络结构不进行参数初始化会怎样?

    原文地址: https://blog.csdn.net/u011668104/article/details/81670544 ------------------------------------ ...

  9. 【转载】 Pytorch中的学习率调整lr_scheduler,ReduceLROnPlateau

    原文地址: https://blog.csdn.net/happyday_d/article/details/85267561 ------------------------------------ ...

  10. 【转载】 PyTorch学习之六个学习率调整策略

    原文地址: https://blog.csdn.net/shanglianlm/article/details/85143614 ----------------------------------- ...

随机推荐

  1. 图像检索(4):IF-IDF,RootSift,VLAD

    TF-IDF RootSift VLAD TF-IDF TF-IDF是一种用于信息检索的常用加权技术,在文本检索中,用以评估词语对于一个文件数据库中的其中一份文件的重要程度.词语的重要性随着它在文件中 ...

  2. 【我们一起写框架】MVVM的WPF框架(二)—绑定

    MVVM的特点之一是实现数据同步,即,前台页面修改了数据,后台的数据会同步更新. 上一篇我们已经一起编写了框架的基础结构,并且实现了ViewModel反向控制Xaml窗体. 那么现在就要开始实现数据同 ...

  3. DS标签控件文本解析格式

    DS标签控件使用DSL文本渲染引擎,支持DSL引擎代码.目前支持代码如下: <b>粗体</b> 以粗体显示 <i>斜体</i> 以斜体显示 <u& ...

  4. asp.net三层架构增删改查

    数据库 use master if exists (select * from sysdatabases where name='bond') drop database bond create da ...

  5. spring的理解

    看过<fate系列>的博友知道,这是一个七位英灵的圣杯争夺战争.今天主要来谈谈圣杯的容器概念,以便对spring的理解. 圣杯: 圣杯本身是没有实体的,而是将具有魔术回路的存在(人)作为“ ...

  6. 微信小程序 picker 中range-key的坑

    <picker class='fr' bindchange="onChangeBuild" range-key="{{'num'}}" value=&qu ...

  7. glibc溢出提权CVE-2018-1000001总结

    遇到了好几个centos6.5,一直尝试想提权.暂未成功,靶机内核:2.6.32-696.18.7.el6.x86_64. glibc版本:ldd (GNU libc) 2.12 目前编译过程中都发现 ...

  8. arcgis api 3.x for js 热力图优化篇-不依赖地图服务(附源码下载)

    前言 关于本篇功能实现用到的 api 涉及类看不懂的,请参照 esri 官网的 arcgis api 3.x for js:esri 官网 api,里面详细的介绍 arcgis api 3.x 各个类 ...

  9. 【Android】用Cubism 2制作自己的Live2D——软件的安装与破解!

    前言- 上文我们简单的了解了Cubism的情况,但是Cubism 2.X安装好以后如果不进行破解只能使用Free版本,这是我们接受不了的,我们是专业的.是来学习的,怎么能不用Pro版本呢?所以话不多说 ...

  10. C#中++i与i++的区别

    日常编程中经常用到++i与i++,知识点虽然很小,但有时候会犯迷糊,在这里小小的记录一下. ++i 即前递增,顾名思义也就是先自增后传值: 举个栗子 int i=5; int j=++i; 此时i的值 ...