本文采用的数据集来自于李沐老师的《动手学深度学习》书籍
import torch
from torch import nn
import torch.nn.functional as F
import d2l.torch as d2l
# Hyperparameters for the character-level LSTM language model.
batch_size = 32    # sequences per minibatch
num_steps = 35     # characters (time steps) per training sequence
hidden_size = 256  # width of the LSTM hidden state
num_layers = 1     # number of stacked LSTM layers
# Fall back to CPU when no GPU is present so the script still runs.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
lr = 1             # SGD learning rate (large lr is the d2l convention for this example)
num_epochs = 1000
# Load the "Time Machine" character dataset: an iterator of (X, Y) index
# minibatches plus the character vocabulary (d2l's sequential partitioning).
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
# One-hot input width (and output logit width) equals the vocabulary size.
input_size = len(vocab)
class LSTM(nn.Module):
    """Character-level language model: an LSTM encoder + linear decoder.

    The decoder maps each hidden state back to vocabulary logits, so the
    model predicts the next character at every time step.
    """

    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        self.input_size = input_size    # vocabulary size (one-hot width)
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers)
        self.linear = nn.Linear(self.hidden_size, self.input_size)

    def init_state(self, batch_size, device):
        """Return a fresh ``(hidden_state, cell_state)`` pair for the LSTM.

        Each tensor has shape ``(num_layers, batch_size, hidden_size)``.
        NOTE(review): states are initialized with small Gaussian noise; the
        conventional choice is zeros — kept as-is to preserve behavior.
        """
        hidden_state = torch.randn(
            size=(self.num_layers, batch_size, self.hidden_size),
            device=device) * 0.1
        cell_state = torch.randn(
            size=(self.num_layers, batch_size, self.hidden_size),
            device=device) * 0.1
        return (hidden_state, cell_state)

    def forward(self, inputs, state):
        """Run one forward pass.

        inputs: ``(batch_size, num_steps)`` integer token indices.
        state:  ``(hidden_state, cell_state)`` tuple from ``init_state`` or a
                previous call.
        Returns ``(output, state)`` where output has shape
        ``(num_steps * batch_size, input_size)``.
        """
        # Transpose to time-major (num_steps, batch_size), then one-hot to
        # float, which is the layout nn.LSTM expects by default.
        inputs = F.one_hot(inputs.T, self.input_size).to(torch.float32)
        # Y: (num_steps, batch_size, hidden_size); state is the updated (h, c).
        Y, state = self.lstm(inputs, state)
        # Flatten time and batch dims so the decoder sees one row per step.
        output = self.linear(Y.reshape(-1, Y.shape[-1]))
        return output, state
def train(net, train_iter, loss_fn, optimizer, num_epochs, batch_size, device):
    """Train ``net`` with truncated backpropagation through time.

    net:        model exposing ``init_state(batch_size, device)`` and
                ``forward(x, state) -> (logits, state)``.
    train_iter: iterable of ``(x, y)`` index minibatches, each of shape
                ``(batch_size, num_steps)``.
    Prints the last minibatch loss every 100 epochs.
    """
    for epoch in range(num_epochs):
        # Reset the recurrent state at each epoch: the iterator restarts from
        # the beginning of the corpus, so carrying the state over from the
        # previous epoch's last batch would be inconsistent.
        state = None
        for x, y in train_iter:
            if state is None:
                state = net.init_state(batch_size, device=device)
            else:
                # Truncated BPTT: keep the state values but detach them from
                # the graph so gradients stop at the minibatch boundary.
                for s in state:
                    s.detach_()
            x = x.to(device)
            # Labels (batch, steps) -> (steps*batch,) to match the model's
            # time-major flattened logits.
            y = y.to(device).T.reshape(-1)
            output, state = net(x, state)
            l = loss_fn(output, y)
            optimizer.zero_grad()
            l.backward()
            # Gradient clipping: with lr=1 an unclipped RNN easily explodes.
            torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)
            optimizer.step()
        if (epoch + 1) % 100 == 0:
            print(f'epoch {epoch + 1}, loss {l.item():.4f}')
def predict(prefix, predict_len, net, vocab, device='cuda'):
    """Greedily generate characters following ``prefix``.

    Warms the recurrent state on the prefix, then repeatedly feeds the last
    predicted character back in, taking the argmax at every step.  Returns
    ``prefix`` plus ``predict_len + 1`` generated characters (one prediction
    comes from the final prefix character itself).
    """
    outputs = []
    state = net.init_state(batch_size=1, device=device)

    def _step(char):
        # Feed one character as a (1, 1) index tensor through the net.
        x = torch.tensor([[vocab[char]]], device=device)
        return net(x, state)

    with torch.no_grad():  # inference only: no autograd bookkeeping needed
        # Warm up the state on every prefix character except the last.
        for ch in prefix[:-1]:
            _, state = _step(ch)
        # The final prefix character yields the first prediction.
        y, state = _step(prefix[-1])
        outputs.append(vocab.to_tokens(int(y.argmax(dim=1))))
        # Autoregressive generation: feed each prediction back in.
        for _ in range(predict_len):
            y, state = _step(outputs[-1])
            outputs.append(vocab.to_tokens(int(y.argmax(dim=1))))
    return prefix + ''.join(outputs)
# Build the model and move it to the chosen device.
net = LSTM(input_size, hidden_size, num_layers)
net.to(device)
# Cross-entropy over the flattened (num_steps*batch, vocab) logits.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr)
train(net, train_iter, loss_fn, optimizer, num_epochs, batch_size, device)
# Generate 100 characters following the prompt "time traveller".
prefix = 'time traveller'
predict_len = 100
predict(prefix, predict_len, net, vocab, device=device)
结果如下
‘time travelleryou can show black is white by argument said filby but you willnever convince mepossibly not said the’
(完整输出因篇幅原因在此省略。)