PyTorch实战:5分钟搞定LSTM/GRU处理变长文本分类(附完整代码)

张开发
2026/6/23 16:21:07 15 分钟阅读
PyTorch实战:5分钟搞定LSTM/GRU处理变长文本分类(附完整代码)
# PyTorch实战：5分钟搞定LSTM/GRU处理变长文本分类（附完整代码）

在自然语言处理中，文本分类是最基础也最实用的任务之一。但现实世界的数据往往充满挑战——社交媒体评论长短不一，商品评价有的洋洋洒洒数百字，有的只有“好用”两个字。这种变长序列输入给传统神经网络带来了难题：如何高效处理不同长度的文本？如何避免因填充过多空白导致的资源浪费？

PyTorch框架提供了优雅的解决方案。通过`pack_padded_sequence`和`pad_packed_sequence`这对黄金组合，配合LSTM/GRU等循环神经网络，我们可以构建能自适应变长输入的文本分类器。本文将手把手带你实现一个工业级可用的变长文本分类模型，从数据预处理到模型训练，全程只需5分钟，代码即可跑通。

## 1. 环境准备与数据预处理

首先安装PyTorch。建议使用conda创建虚拟环境。

> 注意：torchtext 已停止开发，0.18（2024年4月）是最后一个版本。每个 torchtext 版本都绑定一个特定的 PyTorch 版本，0.18 对应 torch 2.3.0。因此不能直接安装“最新版 PyTorch”，而应安装与 torchtext 匹配的版本：

```bash
conda create -n textcls python=3.8
conda activate textcls
pip install torch==2.3.0 torchtext==0.18.0 pandas scikit-learn
```

我们以电商评论情感分析为例。假设原始数据为如下格式：

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# 示例数据
data = {
    'text': ['质量很好，物流也快', '不推荐，做工粗糙', '性价比超高，会回购', '一般般'],
    'label': [1, 0, 1, 0]  # 1=正面, 0=负面
}
df = pd.DataFrame(data)

# 划分训练/测试集
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)
```

关键步骤是构建词汇表和处理变长序列。使用TorchText可以简化这一过程。需要注意：`basic_english`分词器按空格切分，中文句子中间没有空格，整句会被当成一个“词”，模型将无法泛化到新句子。因此这里对中文按字切分：

```python
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.utils import get_tokenizer

en_tokenizer = get_tokenizer('basic_english')  # 英文语料使用
tokenizer = list  # 中文按字切分（实际项目可换用jieba等分词工具）

def yield_tokens(data_iter):
    for text in data_iter:
        yield tokenizer(text)

vocab = build_vocab_from_iterator(yield_tokens(train_df['text']), specials=['<unk>', '<pad>'])
vocab.set_default_index(vocab['<unk>'])

text_pipeline = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: int(x)
```
## 2. 构建数据加载器

用PyTorch的DataLoader处理变长序列，有两个关键点：

1. 按序列长度排序；
2. 使用`pad_sequence`统一长度。

```python
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
import torch

class TextDataset(Dataset):
    def __init__(self, df):
        self.texts = df['text'].tolist()
        self.labels = df['label'].tolist()

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = torch.tensor(text_pipeline(self.texts[idx]), dtype=torch.long)
        label = torch.tensor(label_pipeline(self.labels[idx]), dtype=torch.float)
        return text, label

def collate_batch(batch):
    text_list, label_list, lengths = [], [], []
    for (_text, _label) in batch:
        text_list.append(_text)
        label_list.append(_label)
        lengths.append(len(_text))
    # 按长度降序排序（pack_padded_sequence 使用 enforce_sorted=True 时要求如此）
    sorted_indices = torch.argsort(torch.tensor(lengths), descending=True)
    text_list = [text_list[i] for i in sorted_indices]
    label_list = torch.stack([label_list[i] for i in sorted_indices])
    lengths = torch.tensor([lengths[i] for i in sorted_indices])
    # 填充为 (batch, seq_len)：batch_first=True，无需转置
    padded_text = pad_sequence(text_list, batch_first=True, padding_value=vocab['<pad>'])
    return padded_text, label_list, lengths

train_loader = DataLoader(TextDataset(train_df), batch_size=4, shuffle=True, collate_fn=collate_batch)
test_loader = DataLoader(TextDataset(test_df), batch_size=4, shuffle=False, collate_fn=collate_batch)
```
## 3. 实现LSTM/GRU模型

核心在于正确使用`pack_padded_sequence`和`pad_packed_sequence`：

```python
import torch.nn as nn

class TextClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=vocab['<pad>'])
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_dim * 2, num_classes)  # 双向LSTM需*2

    def forward(self, text, lengths):
        # text: (batch, seq_len)
        embedded = self.embedding(text)  # (batch, seq_len, embed_dim)
        # 打包变长序列（lengths 必须是CPU上的张量或列表）
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths.cpu(), batch_first=True, enforce_sorted=True)
        packed_output, (hidden, cell) = self.lstm(packed)
        # 解包：分类任务不需要，只用最后的隐藏状态；做序列标注或注意力时才需要
        # output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
        # 拼接双向LSTM的最终隐藏状态：hidden[-2]为最后一层前向，hidden[-1]为最后一层反向
        hidden = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)
        return self.fc(hidden)
```

> 提示：两者的接口并不完全兼容，把LSTM换成GRU需要修改两处。GRU没有细胞状态，只返回`(output, h_n)`，因此除了替换层的定义，`forward`中的解包也要相应修改：

```python
self.gru = nn.GRU(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
# forward 中：
packed_output, hidden = self.gru(packed)
```

## 4. 训练与优化技巧

训练循环需要特别注意设备管理和GPU内存优化：

```python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = TextClassifier(len(vocab), 100, 256, 1).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

def train(model, iterator, optimizer, criterion):
    model.train()
    epoch_loss = 0
    for batch in iterator:
        text, labels, lengths = batch
        text, labels = text.to(device), labels.to(device)  # lengths 留在CPU即可
        optimizer.zero_grad()
        predictions = model(text, lengths).squeeze(1)
        loss = criterion(predictions, labels)
        loss.backward()
        # 梯度裁剪，防止梯度爆炸
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
        optimizer.step()
        epoch_loss += loss.item()
    return epoch_loss / len(iterator)

for epoch in range(10):
    train_loss = train(model, train_loader, optimizer, criterion)
    print(f'Epoch: {epoch+1:02} | Train Loss: {train_loss:.3f}')
```
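## 5. 实际应用与性能优化

部署时可以考虑以下优化策略。

**批次划分策略**

| 策略 | 说明 |
| --- | --- |
| 动态批次 | 将长度相近的样本组合成同一批次，减少填充 |
| 最大令牌数 | 控制每批次的总token数，而非样本数 |

下面是“动态批次”的一个实现。它每次产出一个批次的下标列表，因此要通过`batch_sampler`参数（而不是`sampler`）传给DataLoader：

```python
from torch.utils.data import Sampler

class BucketSampler(Sampler):
    def __init__(self, lengths, batch_size, shuffle=True):
        self.lengths = lengths
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        # 按长度排序后分组，使同一批次内的样本长度相近
        indices = torch.argsort(torch.tensor(self.lengths), descending=True).tolist()
        batches = [indices[i:i + self.batch_size] for i in range(0, len(indices), self.batch_size)]
        if self.shuffle:
            # 打乱批次之间的顺序，避免每个epoch都按固定的长度顺序训练
            batches = [batches[i] for i in torch.randperm(len(batches)).tolist()]
        return iter(batches)

    def __len__(self):
        return (len(self.lengths) + self.batch_size - 1) // self.batch_size

train_lengths = [len(text_pipeline(t)) for t in train_df['text']]
bucket_loader = DataLoader(TextDataset(train_df),
                           batch_sampler=BucketSampler(train_lengths, batch_size=4),
                           collate_fn=collate_batch)
```

**混合精度训练**可大幅减少显存占用（仅适用于CUDA）。下面的代码替换`train()`循环体中的前向与反向部分。使用GradScaler时，必须先`unscale_`再做梯度裁剪，否则裁剪阈值作用在被放大的梯度上：

```python
from torch.cuda.amp import GradScaler, autocast  # PyTorch 2.4+ 建议改用 torch.amp.GradScaler('cuda') 与 torch.autocast('cuda')

scaler = GradScaler()

optimizer.zero_grad()
with autocast():
    predictions = model(text, lengths).squeeze(1)
    loss = criterion(predictions, labels)
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
scaler.step(optimizer)
scaler.update()
```

完整代码已封装为可复用组件，适用于各类变长序列分类任务。实际项目中还可以进一步：

- 添加注意力机制，提升长文本上的效果；
- 结合CNN，实现混合（Hybrid）模型；
- 用Word2Vec、GloVe等预训练词向量初始化Embedding层，或直接微调BERT等预训练模型。

变长序列处理是NLP工程师的必备技能。掌握PyTorch这套方案后，无论是处理用户评论、医疗记录还是金融时间序列，都能游刃有余。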

更多文章