Skip to content

Commit 9dd757b

Browse files
authored
Merge pull request #25 from anshuman23/dev
Added RNN (LSTM) example
2 parents 6cbdec0 + 23c077a commit 9dd757b

File tree

4 files changed

+63
-0
lines changed

4 files changed

+63
-0
lines changed

Diff for: examples/rnn-lstm-example/create_input_data.py

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
maxSeqLength = 250
2+
batchSize = 24
3+
4+
import numpy as np
5+
import tensorflow as tf
6+
import re
7+
8+
wordsList = np.load('other_data/wordsList.npy').tolist()
9+
wordsList = [word.decode('UTF-8') for word in wordsList]
10+
wordVectors = np.load('other_data/wordVectors.npy')
11+
strip_special_chars = re.compile("[^A-Za-z0-9 ]+")
12+
13+
def cleanSentences(string):
14+
string = string.lower().replace("<br />", " ")
15+
return re.sub(strip_special_chars, "", string.lower())
16+
17+
def getSentenceMatrix(sentence):
18+
arr = np.zeros([batchSize, maxSeqLength])
19+
sentenceMatrix = np.zeros([batchSize,maxSeqLength], dtype='int32')
20+
cleanedSentence = cleanSentences(sentence)
21+
split = cleanedSentence.split()
22+
for indexCounter,word in enumerate(split):
23+
try:
24+
sentenceMatrix[0,indexCounter] = wordsList.index(word)
25+
except ValueError:
26+
sentenceMatrix[0,indexCounter] = 399999
27+
return sentenceMatrix
28+
29+
inputText = "That movie was terrible."
30+
inputMatrix = getSentenceMatrix(inputText)
31+
print inputMatrix
32+
print inputMatrix.shape
33+
np.savetxt("inputMatrixNegative.csv", inputMatrix, delimiter=',', fmt="%i")
34+
35+
secondInputText = "That movie was the best one I have ever seen."
36+
secondInputMatrix = getSentenceMatrix(secondInputText)
37+
print secondInputMatrix
38+
print secondInputMatrix.shape
39+
np.savetxt("inputMatrixPositive.csv", secondInputMatrix, delimiter=',', fmt="%i")

Diff for: examples/rnn-lstm-example/freeze.py

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import tensorflow as tf
2+
import os
3+
4+
model_dir = './model/'
5+
output_node_names = 'add'
6+
7+
checkpoint = tf.train.get_checkpoint_state(model_dir)
8+
input_checkpoint = checkpoint.model_checkpoint_path
9+
10+
absolute_model_dir = "/".join(input_checkpoint.split('/')[:-1])
11+
output_graph = absolute_model_dir + "/frozen_model_lstm.pb"
12+
13+
clear_devices = True
14+
15+
with tf.Session(graph=tf.Graph()) as sess:
16+
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)
17+
saver.restore(sess, input_checkpoint)
18+
output_graph_def = tf.graph_util.convert_variables_to_constants(sess, tf.get_default_graph().as_graph_def(), output_node_names.split(","))
19+
20+
with tf.gfile.GFile(output_graph, "wb") as f:
21+
f.write(output_graph_def.SerializeToString())
22+
23+
print("%d ops in the final graph." % len(output_graph_def.node))
24+

Diff for: examples/rnn-lstm-example/model/README.md

Whitespace-only changes.

Diff for: examples/rnn-lstm-example/other_data/README.md

Whitespace-only changes.

0 commit comments

Comments
 (0)