import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence


class HAN(nn.Module):

    def __init__(self, vocab_size: int, embedding_size: int, hidden_size_words: int,
                 hidden_size_sent: int, batch_size: int, num_classes: int,
                 device="cpu", dropout_prop=0):
        """
        Implementation of a Hierarchical Attention Network (HAN).

        :param vocab_size: size of the input vocabulary.
        :param embedding_size: size of the word embeddings.
        :param hidden_size_words: number of hidden units for the word encoder.
        :param hidden_size_sent: number of hidden units for the sentence encoder.
        :param batch_size: size of the minibatches passed to the HAN.
        :param num_classes: number of output classes in the classification task.
        :param device: device on which the model is run ("cpu" or "cuda").
        :param dropout_prop: dropout probability applied before the output layer.
        """
        super(HAN, self).__init__()
        self._hidden_size_words = hidden_size_words
        self._hidden_size_sent = hidden_size_sent
        self._embedding_dim = (vocab_size, embedding_size)
        self._num_classes = num_classes
        self._batch_size = batch_size
        self._dropout_prop = dropout_prop
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        # Word- and sentence-level attention encoders (defined elsewhere in this module).
        self._word_encoder = word_encoder(self._embedding_dim[1], self._hidden_size_words)
        self._sentence_encoder = sentence_encoder(self._hidden_size_words * 2, self._hidden_size_sent)
        self._linear1 = nn.Linear(self._hidden_size_sent * 2, self._num_classes)

    def forward(self, seqs, seq_lens, hid_state_word, hid_state_sent, return_attention_weights=False):
        """
        :param seqs: list of padded sentence batches, one per sentence position, each of
            dim (batch_size x seq_length).
        :param seq_lens: list of sequence lengths matching each entry in seqs.
        :param hid_state_word: initial hidden state for the word encoder.
        :param hid_state_sent: initial hidden state for the sentence encoder.
        :param return_attention_weights: if True, also return the attention weights.
        :return: tensor of shape (batch_size, num_classes) and, optionally, the attention
            vectors for the word and sentence encoders.
        """
        batched_sentences = None
        batch_length = len(seqs)
        if return_attention_weights:
            word_weights = []
        for i, seqdata in enumerate(zip(seqs, seq_lens)):
            seq, seq_len = seqdata
            # Embed the words of the i-th sentence across all documents in the batch.
            embedded = self.embedding(seq)
            x_packed = pack_padded_sequence(embedded, seq_len, batch_first=True, enforce_sorted=False)
            # The word encoder returns a sentence vector and (hidden state, attention weights).
            we_out, hid_state = self._word_encoder(x_packed, hid_state_word)
            # Accumulate the sentence vectors produced so far.
            if batched_sentences is None:
                batched_sentences = we_out
            else:
                batched_sentences = torch.cat((batched_sentences, we_out), 0)
            # Run the sentence encoder over the accumulated sentence vectors.
            out_sent, hid_sent = self._sentence_encoder(batched_sentences.permute(1, 0, 2), hid_state_sent)
            if return_attention_weights:
                word_weights.append(hid_state[1].data)
                if i == batch_length - 1:  # last sentence: keep the sentence-level attention weights
                    sentence_weights = hid_sent[1].data
        # Apply dropout to the document vector, then project to class probabilities.
        out_sent_dropout = F.dropout(out_sent.squeeze(0), p=self._dropout_prop)
        prediction_out = F.softmax(self._linear1(out_sent_dropout), dim=1)
        if return_attention_weights:
            return prediction_out, [word_weights, sentence_weights]
        else:
            return prediction_out
    def init_hidden_sent(self):
        """Initial hidden state for the sentence encoder (2 directions x batch x hidden)."""
        return torch.zeros(2, self._batch_size, self._hidden_size_sent)

    def init_hidden_word(self):
        """Initial hidden state for the word encoder (2 directions x batch x hidden)."""
        return torch.zeros(2, self._batch_size, self._hidden_size_words)
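

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how a forward pass might be driven. It assumes that
# `word_encoder` and `sentence_encoder` are the attention-based GRU encoders
# defined elsewhere in this module, and that documents have already been
# preprocessed into `seqs` (one padded LongTensor of shape
# (batch_size, seq_length) per sentence position) with matching `seq_lens`.
# The shapes and preprocessing shown here are assumptions for illustration.
if __name__ == "__main__":
    model = HAN(vocab_size=20000, embedding_size=100, hidden_size_words=50,
                hidden_size_sent=50, batch_size=4, num_classes=3, dropout_prop=0.1)
    # Four documents, each represented here as two sentences of ten tokens.
    seqs = [torch.randint(1, 20000, (4, 10)) for _ in range(2)]
    seq_lens = [torch.full((4,), 10, dtype=torch.long) for _ in range(2)]
    probs = model(seqs, seq_lens, model.init_hidden_word(), model.init_hidden_sent())
    print(probs.shape)  # expected: (4, num_classes)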