import theano.tensor as T
from agentnet.utils.layers import DictLayer
from lasagne.init import GlorotUniform


class AttentionLayer(DictLayer):
    def __init__(self,
                 input_sequence,
                 controller_state,
                 num_units,
                 mask_input=None,
                 nonlinearity=T.tanh,
                 weights_nonlinearity=T.nnet.softmax,
                 W_enc=GlorotUniform(),
                 W_dec=GlorotUniform(),
                 W_out=GlorotUniform(),
                 ):
        """
        Implements basic Bahdanau-style attention. The implementation is inspired by tfnn@yandex.

        In short, attention lets the network decide which part of the sequence/image to look at
        at each step, using a small one-layer block that scores every (input_element, controller) pair
        as "how much do I want to see this input_element right now".
        You can read more about it here: http://distill.pub/2016/augmented-rnns/ .

        This layer outputs a dict with keys "attn" and "probs":
        - attn - inputs processed with attention, shape [batch_size, enc_units]
        - probs - attention probability for each time-step, shape [batch_size, seq_length]

        This layer assumes the input sequence/image/video/etc. to have exactly 1 spatial dimension
        (see the format notes below and the sketch after this list).
        - rnn/emb format [batch, seq_len, units] works out of the box
        - 1d convolution format [batch, units, seq_len] needs dimshuffle(conv, [0, 2, 1])
        - 2d convolution format [batch, units, dim1, dim2] needs a two-step procedure:
          - step1 = dimshuffle(conv, [0, 2, 3, 1])
          - step2 = reshape(step1, [-1, dim1*dim2, units])
        - higher dimensionality follows the same principle as the 2d example above
        - reshape and dimshuffle can both be found in lasagne.layers (aliases to ReshapeLayer and DimshuffleLayer)
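
        For example, a minimal sketch for the 1d convolution case (here `conv` stands for a
        hypothetical Conv1DLayer output of shape [batch, units, seq_len] and `dec` for a
        single-tick decoder state layer; both names are illustrative):

            from lasagne.layers import dimshuffle
            seq = dimshuffle(conv, [0, 2, 1])               # -> [batch, seq_len, units]
            attn = AttentionLayer(seq, dec, num_units=128)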

        When calling get_output, you can pass the flag hard_attention=True to replace soft attention
        with an argmax over the logits (see the usage sketch below).

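        Rough usage sketch (here `enc` and `dec` are hypothetical layers with the shapes described
        in the parameter list below; how you unpack the result follows the agentnet DictLayer conventions):

            import lasagne
            attn = AttentionLayer(enc, dec, num_units=128)
            soft = lasagne.layers.get_output(attn)                        # {'attn': ..., 'probs': ...}
            hard = lasagne.layers.get_output(attn, hard_attention=True)   # argmax instead of softmax
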
        :param input_sequence: sequence of inputs to be processed with attention
        :type input_sequence: lasagne.layers.Layer with shape [batch, seq_length, units]

        :param controller_state: single time-step state of the decoder (usually an lstm/gru/rnn hidden state)
        :type controller_state: lasagne.layers.Layer with shape [batch, units]

        :param num_units: number of hidden units in the attention intermediate activation
        :type num_units: int

        :param nonlinearity: nonlinearity applied to the attention intermediate activation
        :type nonlinearity: function(x) -> x that works with theano tensors

        :param weights_nonlinearity: nonlinearity that converts logits of shape [batch, seq_length]
            into attention weights of the same shape
            (you can provide softmax with a tunable temperature, gumbel-softmax or anything of the sort)
        :type weights_nonlinearity: function(x) -> x that works with theano tensors

        :param mask_input: mask for input_sequence (like other lasagne masks). Default is no mask
        :type mask_input: lasagne.layers.Layer with shape [batch, seq_length]

        The remaining params (W_enc, W_dec, W_out) can each be a theano shared variable, expression,
        numpy array or callable: an initial value, expression or initializer for the weights.
        W_enc and W_dec should be matrices of shape ``(num_inputs, num_units)``; W_out has shape ``(num_units, 1)``.
        See :func:`lasagne.utils.create_param` for more information.

        The roles of those params are:
        W_enc - weights from the encoder (each state) to the hidden layer
        W_dec - weights from the decoder (current state) to the hidden layer
        W_out - hidden-to-logit weights
        No logit biases are introduced because softmax is invariant to adding a bias to each logit.
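
        In other words, for each time-step i the attention logit is computed as
            logit_i = nonlinearity(enc_i.dot(W_enc) + dec.dot(W_dec)).dot(W_out)
        and the attention weights are weights_nonlinearity(logits) (softmax by default).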

        """
        assert len(input_sequence.output_shape) == 3, "input_sequence must be 3-dimensional (batch, time, units)"
        assert len(controller_state.output_shape) == 2, "controller_state must be 2-dimensional for a single tick (batch, units)"
        assert mask_input is None or len(mask_input.output_shape) == 2, "mask_input must be 2-dimensional (batch, time) or None"

        batch_size, seq_len, enc_units = input_sequence.output_shape
        dec_units = controller_state.output_shape[-1]

        incomings = [input_sequence, controller_state]
        if mask_input is not None:
            incomings.append(mask_input)

        output_shapes = {'attn': (batch_size, enc_units),
                         'probs': (batch_size, seq_len)}

        super(AttentionLayer, self).__init__(incomings, output_shapes)

        self.W_enc = self.add_param(W_enc, (enc_units, num_units), name='enc_to_hid')
        self.W_dec = self.add_param(W_dec, (dec_units, num_units), name='dec_to_hid')
        self.W_out = self.add_param(W_out, (num_units, 1), name='hid_to_logit')
        self.nonlinearity = nonlinearity
        self.weights_nonlinearity = weights_nonlinearity

    def get_output_for(self, inputs, hard_attention=False, **kwargs):
        """
        :param inputs: should consist of (enc_seq, dec) or (enc_seq, dec, inp_mask)
            Shapes are
            enc_seq: [batch_size, seq_length, enc_units]
            dec: [batch_size, dec_units]
            inp_mask: [batch_size, seq_length] if any

        :param hard_attention: if True, replaces soft attention with an argmax over the logits

        ---------------------------------
        :returns: dict with keys "attn" and "probs"
            - attn - inputs processed with attention, shape [batch_size, enc_units]
            - probs - attention probability for each time-step, shape [batch_size, seq_length]
        """
        assert len(inputs) in (2, 3), "inputs should be (enc_seq, dec) or (enc_seq, dec, inp_mask)"
        mask_provided = len(inputs) == 3

        # parse inputs
        enc_seq, dec = inputs[:2]
        if mask_provided:
            mask = inputs[-1]

        # Hidden layer activations, shape [batch, seq_len, num_units]
        hid = self.nonlinearity(
            enc_seq.dot(self.W_enc) +
            dec.dot(self.W_dec)[:, None, :]
        )
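        # note: enc_seq.dot(self.W_enc) is [batch, seq_len, num_units], while dec.dot(self.W_dec)[:, None, :]
        # is [batch, 1, num_units] and broadcasts over the time axis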

        # Logits from hidden. Mask implementation from tfnn
        logits = hid.dot(self.W_out)[:, :, 0]  # [batch_size, seq_len]

        if mask_provided:  # subtract a large number from mask==0 time-steps
            logits -= (1 - mask) * 1000  # (written to match the tfnn implementation)
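            # after weights_nonlinearity (softmax by default), such positions get near-zero attention probability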

        if not hard_attention:
            # regular soft attention, weights from weights_nonlinearity (softmax by default)
            probs = self.weights_nonlinearity(logits)  # [batch_size, seq_len]

            # Compose attention as a probability-weighted sum over encoder states.
            attn = T.sum(probs[:, :, None] * enc_seq, axis=1)

            return {'attn': attn, 'probs': probs}

        else:  # hard attention
            # use argmax over logits
            max_i = logits.argmax(axis=-1)
            batch_size = enc_seq.shape[0]
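            # advanced indexing: pick the most-attended encoder state for each batch element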
            attn = enc_seq[T.arange(batch_size), max_i]

            # one-hot probabilities
            one_hot = T.extra_ops.to_one_hot(max_i, logits.shape[1])

            return {'attn': attn, 'probs': one_hot}