import torch
from torch import nn


class AE_LSTM(nn.Module):
    """Deliberately broken example: the constructor name is misspelled.

    The method below is named ``init`` instead of ``__init__``, so Python never
    calls it when the class is instantiated. ``AE_LSTM(in_channel=1)`` therefore
    falls through to the inherited ``nn.Module.__init__``, which rejects the
    ``in_channel`` keyword with a ``TypeError`` ("unexpected keyword argument
    'in_channel'"). Consequently ``self.encoder`` is never built.
    """

    # BUG (intentional, for illustration): must be spelled ``__init__``.
    def init(self, in_channel=1):
        # Also misspelled; it is never reached, because ``init`` itself is
        # never invoked by the constructor.
        super(AE_LSTM, self).init()
        self.encoder = nn.Sequential(
            nn.LSTM(in_channel, 16, batch_first=True, bidirectional=True),
            nn.InstanceNorm1d(16),
            nn.ReLU(),
            nn.LSTM(16, 32, batch_first=True, bidirectional=True),
            nn.InstanceNorm1d(32),
            nn.ReLU(),
            nn.LSTM(32, 32, batch_first=True, bidirectional=True),
            nn.InstanceNorm1d(32),
            nn.ReLU(),
            # nn.MaxPool1d(kernel_size=8),  # alternative pooling kept by the author
            nn.AvgPool1d(kernel_size=8),
            nn.Conv1d(32, 32, kernel_size=1, stride=1, padding=0),
        )
然后实例化模型（例如 `AE_LSTM(in_channel=1)`）时就会报错：
`__init__() got an unexpected keyword argument 'in_channel'`
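修改 BUG：这里犯了个低级错误——把构造函数写成了 `init`，没有注意到它必须写作 `__init__()`（前后各两个下划线）。否则 Python 在实例化时不会调用它，`in_channel` 参数会被直接传给父类 `nn.Module.__init__`，从而报出上面的错误。改正后的代码如下：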
import torch
from torch import nn


class AE_LSTM(nn.Module):
    """LSTM-based encoder from the corrected example (constructor spelled ``__init__``).

    Args:
        in_channel: Number of input features fed to the first LSTM (default 1).

    Attributes:
        encoder: ``nn.Sequential`` of three bidirectional LSTMs, each followed by
            ``InstanceNorm1d`` and ``ReLU``, then ``AvgPool1d(kernel_size=8)`` and a
            1x1 ``Conv1d(32, 32)``.

    NOTE(review): this only fixes the constructor error. This snippet defines no
    ``forward``, and the encoder as written cannot run a forward pass:
    ``nn.LSTM`` returns a tuple ``(output, (h_n, c_n))`` that ``nn.Sequential``
    hands unchanged to the next layer; with ``bidirectional=True`` each LSTM
    emits ``2 * hidden_size`` features, which does not match the next layer's
    input size (32 vs 16, 64 vs 32, 64 vs 32); and ``batch_first=True`` yields
    ``(batch, seq, feature)`` while ``InstanceNorm1d``/``AvgPool1d``/``Conv1d``
    expect ``(batch, channels, length)``. Wrap each LSTM (unpack the output and
    transpose) before calling ``encoder``.
    """

    def __init__(self, in_channel=1):
        # The double underscores on both sides are required: Python only calls
        # ``__init__`` on instantiation, so ``in_channel`` is now accepted here
        # instead of reaching ``nn.Module.__init__``.
        super(AE_LSTM, self).__init__()
        self.encoder = nn.Sequential(
            nn.LSTM(in_channel, 16, batch_first=True, bidirectional=True),
            nn.InstanceNorm1d(16),
            nn.ReLU(),
            nn.LSTM(16, 32, batch_first=True, bidirectional=True),
            nn.InstanceNorm1d(32),
            nn.ReLU(),
            nn.LSTM(32, 32, batch_first=True, bidirectional=True),
            nn.InstanceNorm1d(32),
            nn.ReLU(),
            # nn.MaxPool1d(kernel_size=8),  # alternative pooling kept by the author
            nn.AvgPool1d(kernel_size=8),
            nn.Conv1d(32, 32, kernel_size=1, stride=1, padding=0),
        )
这样就不会再报这个错了。
参考文章
发表评论