汇总PyTorch踩过的10个坑
发布网友
发布时间:2024-10-04 20:29
我来回答
共1个回答
热心网友
时间:2024-10-05 18:28
PyTorch中的交叉熵:PyTorch的交叉熵nn.CrossEntropyLoss在训练阶段内置了softmax操作,因此只需输入原始数据结果,无需额外添加softmax层。这与TensorFlow的tf.softmax_cross_entropy_with_logits类似。
MSELoss和KLDivLoss:在深度学习中,MSELoss和KLDivLoss是常用的损失函数,PyTorch提供nn.MSELoss和nn.KLDivLoss。在使用这些函数时,目标标签(target)需要为不可训练的值,即requires_grad=False。否则,会引发错误。
在验证和测试阶段取消梯度:在模型验证和测试阶段,我们仅需进行前向传播,无需保存梯度。保存梯度会增加内存使用,有时会导致Out Of Memory错误。因此,在验证和测试阶段,建议使用torch.no_grad()取消梯度。
显式指定训练和测试阶段:在PyTorch中,通过model.train()和model.eval()显式指定模型处于训练或测试阶段。这有助于调整模型中的某些参数,如dropout率和Batch Normalization参数。
关于retain_graph的使用:在反向传播过程中,通过backward()函数即可计算梯度。retain_graph参数控制反向传播后的图是否保留。保留图可以用于在后续迭代中复用计算图,特别是在GAN等场景中。
梯度累积:在GPU内存紧张时,可以利用retain_graph参数进行梯度累积,等同于使用更大的batch_size进行训练。通过保留计算图,可以在不增加GPU内存使用的情况下实现大batch_size训练。
dropout的使用:torch.nn.functional.dropout允许用户指定训练阶段是否进行随机神经元丢弃,与torch.nn.Dropout不同,它不保留状态信息。
torch.index_select:torch.index_select用于根据索引选择张量中的元素。在使用时,需注意索引合法性以及索引张量的类型。
BN层的更新:在训练模式下,BN层的running_mean和running_var会自动更新。这可能与预期有所不同,需要注意。
F.interpolate的问题:torch.nn.functional.interpolate函数用于图像插值,但要求输入图像为batch形式。在进行尺寸变换时,需要正确调整图像的形状以适应函数要求。