Skip to content

Added RNN (LSTM) example #25

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 22, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions examples/rnn-lstm-example/create_input_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
maxSeqLength = 250
batchSize = 24

import numpy as np
import tensorflow as tf
import re

wordsList = np.load('other_data/wordsList.npy').tolist()
wordsList = [word.decode('UTF-8') for word in wordsList]
wordVectors = np.load('other_data/wordVectors.npy')
strip_special_chars = re.compile("[^A-Za-z0-9 ]+")

def cleanSentences(string):
string = string.lower().replace("<br />", " ")
return re.sub(strip_special_chars, "", string.lower())

def getSentenceMatrix(sentence):
arr = np.zeros([batchSize, maxSeqLength])
sentenceMatrix = np.zeros([batchSize,maxSeqLength], dtype='int32')
cleanedSentence = cleanSentences(sentence)
split = cleanedSentence.split()
for indexCounter,word in enumerate(split):
try:
sentenceMatrix[0,indexCounter] = wordsList.index(word)
except ValueError:
sentenceMatrix[0,indexCounter] = 399999
return sentenceMatrix

inputText = "That movie was terrible."
inputMatrix = getSentenceMatrix(inputText)
print inputMatrix
print inputMatrix.shape
np.savetxt("inputMatrixNegative.csv", inputMatrix, delimiter=',', fmt="%i")

secondInputText = "That movie was the best one I have ever seen."
secondInputMatrix = getSentenceMatrix(secondInputText)
print secondInputMatrix
print secondInputMatrix.shape
np.savetxt("inputMatrixPositive.csv", secondInputMatrix, delimiter=',', fmt="%i")
24 changes: 24 additions & 0 deletions examples/rnn-lstm-example/freeze.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import tensorflow as tf
import os

model_dir = './model/'
output_node_names = 'add'

checkpoint = tf.train.get_checkpoint_state(model_dir)
input_checkpoint = checkpoint.model_checkpoint_path

absolute_model_dir = "/".join(input_checkpoint.split('/')[:-1])
output_graph = absolute_model_dir + "/frozen_model_lstm.pb"

clear_devices = True

with tf.Session(graph=tf.Graph()) as sess:
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)
saver.restore(sess, input_checkpoint)
output_graph_def = tf.graph_util.convert_variables_to_constants(sess, tf.get_default_graph().as_graph_def(), output_node_names.split(","))

with tf.gfile.GFile(output_graph, "wb") as f:
f.write(output_graph_def.SerializeToString())

print("%d ops in the final graph." % len(output_graph_def.node))

Empty file.
Empty file.