
Implementing an LSTM Model with PyTorch


The dataset used in this article comes from Professor Mu Li's book Dive into Deep Learning (《动手学深度学习》) and is loaded through the accompanying d2l helper package (installable with pip install d2l).

1. Import dependencies

import torch
from torch import nn
import torch.nn.functional as F
import d2l.torch as d2l

2. Load the dataset and set the hyperparameters

batch_size = 32      # sequences per minibatch
num_steps = 35       # time steps (characters) per sequence
hidden_size = 256    # number of LSTM hidden units
num_layers = 1       # number of stacked LSTM layers
device = 'cuda' if torch.cuda.is_available() else 'cpu'
lr = 1               # learning rate for SGD
num_epochs = 1000
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
input_size = len(vocab)  # one-hot input dimension = vocabulary size
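
Before building the model it helps to know what the loader produces. The quick check below is illustrative (not part of the original tutorial) and assumes, as in the book, that each minibatch is a pair of integer index tensors of shape (batch_size, num_steps), with Y being X shifted one character ahead:

for X, Y in train_iter:
    # X holds character indices; Y is the same sequence shifted by one step,
    # so the target at each position is the next character.
    print(X.shape, Y.shape)  # expected: torch.Size([32, 35]) torch.Size([32, 35])
    break
print(len(vocab))  # vocabulary size, also the one-hot input dimension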

3. Build the model

class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers)
        self.linear = nn.Linear(self.hidden_size, self.input_size)

    def init_state(self, batch_size, device):
        # Initialize the hidden state and the cell (memory) state with small random values
        hidden_state = torch.randn(size=(self.num_layers, batch_size, self.hidden_size), device=device) * 0.1
        content_state = torch.randn(size=(self.num_layers, batch_size, self.hidden_size), device=device) * 0.1
        return (hidden_state, content_state)

    def forward(self, inputs, state):
        # inputs: (batch_size, num_steps) -> transpose and one-hot encode to
        # (num_steps, batch_size, input_size), the layout nn.LSTM expects by default
        inputs = F.one_hot(inputs.T, self.input_size)
        inputs = inputs.to(torch.float32)
        # Note: state is a tuple (hidden_state, content_state), i.e. (h, c)
        # Y.shape = (num_steps, batch_size, hidden_size)
        Y, state = self.lstm(inputs, state)
        # Flatten the time and batch dimensions so each row is scored over the vocabulary
        output = self.linear(Y.reshape(-1, Y.shape[-1]))
        return output, state
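
As a sanity check (illustrative only; net_check and x_dummy are throwaway names, not part of the original code), a dummy forward pass confirms that the flattened output lines up with the reshaped labels used in the training loop below:

net_check = LSTM(input_size, hidden_size, num_layers).to(device)
state_check = net_check.init_state(batch_size, device)
# Random character indices standing in for a real minibatch
x_dummy = torch.randint(0, input_size, (batch_size, num_steps), device=device)
out, state_check = net_check(x_dummy, state_check)
print(out.shape)  # (num_steps * batch_size, input_size)
print(state_check[0].shape, state_check[1].shape)  # each (num_layers, batch_size, hidden_size)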

4. Define the training function

def train(net, train_iter, loss_fn, optimizer, num_epochs, batch_size, device):
    state = None
    for epoch in range(num_epochs):
        for x, y in train_iter:
            if state is None:
                state = net.init_state(batch_size, device=device)
            else:
                # Detach the states carried over from the previous batch so gradients
                # do not flow back through the whole history (truncated BPTT)
                hidden_state, content_state = state
                hidden_state.detach_()
                content_state.detach_()
            x = x.to(device)
            y = y.to(device)
            # Labels: (batch_size, num_steps) -> (num_steps * batch_size,),
            # matching the row order of the model output
            y = y.T.reshape(-1)
            output, state = net(x, state)
            l = loss_fn(output, y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            if (epoch + 1) % 100 == 0:
                print(f'epoch {epoch + 1}, loss {l.item():.4f}')
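
The reference training loop in the book additionally clips the gradient norm (which helps keep training stable with lr = 1) and reports perplexity rather than the raw cross-entropy loss. Below is a minimal sketch of how that could be bolted onto the loop above; clip_and_step is a hypothetical helper name and max_norm=1.0 is an assumed value, not part of the original code:

import math

def clip_and_step(net, l, optimizer, max_norm=1.0):
    # Backpropagate, clip the gradient norm, then update the parameters
    optimizer.zero_grad()
    l.backward()
    torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=max_norm)
    optimizer.step()
    # Perplexity = exp of the average per-character cross-entropy loss
    return math.exp(l.item())

Inside the batch loop, the three optimizer lines could then be replaced by ppl = clip_and_step(net, l, optimizer), and the logging line could print ppl instead of the loss.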

5. Define the prediction function

def predict(prefix, predict_len, net, vocab, device=device):
    prefix_len = len(prefix)
    output = []
    state = net.init_state(batch_size=1, device=device)
    # Warm up: feed the prefix characters one by one to build up the state
    for i in range(prefix_len - 1):
        x = torch.tensor(vocab[prefix[i]], device=device).reshape(1, -1)
        _, state = net(x, state)
    # Feed the last prefix character and take the first predicted character
    x = torch.tensor(vocab[prefix[-1]], device=device).reshape(1, -1)
    y, state = net(x, state)
    output.append(vocab.to_tokens(int(y.argmax(dim=1))))
    # Generate predict_len characters, each conditioned on the previous output
    for i in range(predict_len):
        x = torch.tensor(vocab[output[-1]], device=device).reshape(1, -1)
        y, state = net(x, state)
        output.append(vocab.to_tokens(int(y.argmax(dim=1))))
    # Concatenate the prefix with the generated characters
    return prefix + "".join(output)
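
predict always takes the argmax, i.e. greedy decoding. As an aside (not part of the original code; sample_next and temperature are illustrative names), sampling from the softmax distribution instead produces more varied text:

def sample_next(y, temperature=1.0):
    # y: (1, vocab_size) unnormalized scores for the next character
    probs = F.softmax(y / temperature, dim=1)
    idx = torch.multinomial(probs, num_samples=1)
    return int(idx)

Replacing int(y.argmax(dim=1)) with sample_next(y) in predict switches the decoding strategy from greedy to sampled.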

6. Training

net = LSTM(input_size, hidden_size, num_layers)
net.to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr)
train(net, train_iter, loss_fn, optimizer, num_epochs, batch_size, device)

7. Prediction

prefix = 'time traveller'
predict_len = 100
predict(prefix, predict_len, net, vocab, device=device)

The result is as follows:

'time travelleryou can show black is white by argument said filby but you willnever convince mepossibly not said the'

