解决pytorch 模型复制的一些问题
直接使用会出现当更新model2时,model1的权重也会更新,这和自己的初始目的不同。经评论指出可以使用:来实现深拷贝,手上没有pytorch环境,具体还没测试过,谁测试过可以和我说下有没有用。pytorch 中有帮助我们制作数据生成器的模块,其中有 Dataset、TensorDataset、DataLoader 等类可以来创建数据入口。之前在 tensorflow 中可以用 dataset.from_generator() 的形式,pytorch 中也类似,目前我了解到的有两种方法可以实现。第一种就继承 pytorch 定义的 dataset,改写其中的方法即可。第二种就是转换,先把我们准备好的数据转化成 pytorch 的变量,然后传入 TensorDataset,再构造 DataLoader。损失函数定义模型定义完之后,意味着给出输入,就可以得到输出的结果。那么就来比较 outputs 和 targets 之间的区别,那么就需要用到损失函数来描述。
用户评论