Skip to content

Commit 6b3b33c

Browse files
author
Abhishek Nagaraja
committed

File tree

7 files changed

+1282
-0
lines changed

7 files changed

+1282
-0
lines changed

Chapter 7/DS_input.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import os.path
2+
import glob
3+
import tensorflow as tf
4+
5+
# Global constants describing the dataset.
# Note this definition must match the ALPHABET chosen in
# preprocess_Librispeech.py
ALPHABET = "ABCDEFGHIJKLMNOPQRSTUVWXYZ' "  # for LibriSpeech
NUM_CLASSES = len(ALPHABET) + 1  # Additional class for blank character
# Utterance counts per split, used to size an epoch.
# NOTE(review): presumably these match the preprocessed shard counts
# (train-clean-100 / dev / test) — verify against preprocess_Librispeech.py.
NUM_PER_EPOCH_FOR_TRAIN = 28535
NUM_PER_EPOCH_FOR_EVAL = 2703
NUM_PER_EPOCH_FOR_TEST = 2620
13+
14+
15+
def _generate_feats_and_label_batch(filename_queue, batch_size):
    """Construct a queued batch of spectral features and transcriptions.

    Reads serialized SequenceExamples from `filename_queue`, parses the
    per-utterance MFCC frames and transcript label indices, then groups
    utterances of similar length into padded batches via length bucketing.

    Args:
        filename_queue: queue of TFRecord filenames to read data from.
        batch_size: Number of utterances per batch.

    Returns:
        feats: MFCC features — presumably shaped
            [batch_size, max_time_in_batch, 13] after dynamic padding
            (TODO confirm; per-frame feature dim is fixed at 13 below).
        labels: transcript label indices, cast to int32 (sparse, from
            the VarLenFeature parse).
        seq_lens: sequence length of each utterance in the batch.
    """

    # Define how to parse the example.
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    context_features = {
        # Number of feature frames in the utterance.
        "seq_len": tf.FixedLenFeature([], dtype=tf.int64),
        # Variable-length transcript: one integer class per character.
        "labels": tf.VarLenFeature(dtype=tf.int64)
    }
    sequence_features = {
        # mfcc features are 13 dimensional
        "feats": tf.FixedLenSequenceFeature([13, ], dtype=tf.float32)
    }

    # Parse the example (returns a dictionary of tensors)
    context_parsed, sequence_parsed = tf.parse_single_sequence_example(
        serialized=serialized_example,
        context_features=context_features,
        sequence_features=sequence_features
    )

    # Batch utterances of similar length together so padding waste stays
    # small; buckets cover lengths 100..1800 in steps of 100 frames.
    seq_len, (feats, labels) = tf.contrib.training.bucket_by_sequence_length(
        input_length=tf.cast(context_parsed['seq_len'], tf.int32),
        tensors=[sequence_parsed['feats'], context_parsed['labels']],
        batch_size=batch_size,
        bucket_boundaries=list(range(100, 1900, 100)),
        allow_smaller_final_batch=True,  # last batch of an epoch may be short
        num_threads=16,
        dynamic_pad=True)

    return feats, tf.cast(labels, tf.int32), seq_len
58+
59+
60+
def inputs(eval_data, data_dir, batch_size, shuffle=False):
    """Construct batched speech input using the Reader ops.

    Args:
        eval_data: one of 'train', 'val' or 'test', selecting the split.
        data_dir: Path to the directory holding the .tfrecords shards.
        batch_size: Number of utterances per batch.
        shuffle: Whether to shuffle the filename queue.

    Returns:
        feats: MFCC features, dynamically padded per batch.
        labels: transcript label indices (int32).
        seq_lens: sequence length of each utterance in the batch.

    Raises:
        ValueError: if `eval_data` is not a known split, if no shard
            files are found, or if an expected shard file is missing.
    """
    if eval_data == 'train':
        # Training shards are numbered train_1.tfrecords .. train_N.tfrecords.
        num_files = len(glob.glob(os.path.join(data_dir,
                                               'train*/*.tfrecords')))
        filenames = [os.path.join(data_dir, 'train-clean-100/train_' +
                                  str(i) + '.tfrecords')
                     for i in range(1, num_files + 1)]
    elif eval_data == 'val':
        filenames = glob.glob(os.path.join(data_dir, 'dev*/*.tfrecords'))
    elif eval_data == 'test':
        filenames = glob.glob(os.path.join(data_dir, 'test*/*.tfrecords'))
    else:
        # Previously an unknown split fell through and raised a confusing
        # NameError on 'filenames'; fail fast with a clear message instead.
        raise ValueError("eval_data must be 'train', 'val' or 'test', "
                         "got %r" % (eval_data,))

    if not filenames:
        raise ValueError('No .tfrecords files found under: ' + data_dir)

    for filename in filenames:  # 'filename': avoid shadowing builtin 'file'
        if not tf.gfile.Exists(filename):
            raise ValueError('Failed to find file: ' + filename)

    # Create a queue that produces the filenames to read.
    filename_queue = tf.train.string_input_producer(filenames, shuffle=shuffle)

    # Generate a batch of features and labels by building up a queue of
    # examples.
    return _generate_feats_and_label_batch(filename_queue, batch_size)

Chapter 7/DS_test.py

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
import json
2+
import os
3+
import math
4+
import time
5+
import argparse
6+
from datetime import datetime
7+
import deepSpeech
8+
import numpy as np
9+
import tensorflow as tf
10+
from Levenshtein import distance
11+
12+
# Note this definition must match the ALPHABET chosen in
# preprocess_Librispeech.py
ALPHABET = "ABCDEFGHIJKLMNOPQRSTUVWXYZ' "
# Lookup table from integer class index to character, used when
# converting decoded sparse label tensors back into transcripts.
IX_TO_CHAR = {i: ch for (i, ch) in enumerate(ALPHABET)}
16+
17+
18+
def _str2bool(value):
    """Convert a command-line string to a bool.

    argparse's `type=bool` is a well-known trap: bool('False') is True,
    since any non-empty string is truthy. Accept common spellings instead.
    """
    return str(value).strip().lower() in ('1', 'true', 'yes', 'y')


def parse_args():
    """Parse command line arguments and merge in saved model parameters.

    Network-architecture parameters are read back from the
    'deepSpeech_parameters.json' file written to the checkpoint directory
    at training time, so evaluation rebuilds the same graph.

    Returns:
        argparse.Namespace holding both CLI flags and restored parameters.

    Raises:
        OSError: if the saved parameter file cannot be opened.
        KeyError: if the parameter file lacks an expected key.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--eval_dir', type=str,
                        default='../models/librispeech/eval',
                        help='Directory to write event logs')
    parser.add_argument('--checkpoint_dir', type=str,
                        default='../models/librispeech/train',
                        help='Directory where to read model checkpoints.')
    parser.add_argument('--eval_data', type=str, default='val',
                        help="Either 'test' or 'val' or 'train' ")
    parser.add_argument('--batch_size', type=int, default=32,
                        help='Number of feats to process in a batch')
    parser.add_argument('--eval_interval_secs', type=int, default=60 * 5,
                        help='How often to run the eval')
    parser.add_argument('--data_dir', type=str,
                        default='../data/librispeech/processed/',
                        help='Path to the deepSpeech data directory')
    # type=bool would make '--run_once False' evaluate True; parse properly.
    parser.add_argument('--run_once', type=_str2bool, default=False,
                        help='Whether to run eval only once')
    args = parser.parse_args()

    # Read network architecture parameters from the
    # previously saved parameter file.
    param_file = os.path.join(args.checkpoint_dir,
                              'deepSpeech_parameters.json')
    with open(param_file, 'r') as param_fp:  # avoid shadowing builtin 'file'
        params = json.load(param_fp)
    args.num_hidden = params['num_hidden']
    args.num_rnn_layers = params['num_rnn_layers']
    args.rnn_type = params['rnn_type']
    args.num_filters = params['num_filters']
    args.use_fp16 = params['use_fp16']
    args.temporal_stride = params['temporal_stride']
    args.moving_avg_decay = params['moving_avg_decay']
    return args
55+
56+
57+
def sparse_to_labels(sparse_matrix):
    """Convert a batch of index-encoded transcripts into strings.

    Args:
        sparse_matrix: sparse tensor value whose `values` hold character
            class indices and whose `indices[:, 0]` give the utterance
            (batch row) each character belongs to.

    Returns:
        List of decoded transcript strings, one per utterance.
    """
    num_utterances = sparse_matrix.dense_shape[0]
    transcripts = [''] * num_utterances
    # Characters arrive in (row, position) order; appending per row
    # therefore reassembles each transcript left-to-right.
    for pos, char_ix in enumerate(sparse_matrix.values.tolist()):
        utterance = sparse_matrix.indices[pos, 0]
        transcripts[utterance] += IX_TO_CHAR[char_ix]
    return transcripts
63+
64+
65+
def initialize_from_checkpoint(sess, saver):
    """Restore model weights from the latest checkpoint, if one exists.

    Args:
        sess: session to restore the variables into.
        saver: Saver configured with the variables to restore.

    Returns:
        The global step parsed from the checkpoint filename (string),
        or None when no checkpoint is found.
    """
    ckpt = tf.train.get_checkpoint_state(ARGS.checkpoint_dir)
    if not (ckpt and ckpt.model_checkpoint_path):
        print('No checkpoint file found')
        return None

    # Restore the weights from the checkpoint file.
    saver.restore(sess, ckpt.model_checkpoint_path)
    # Checkpoint paths look like .../train/model.ckpt-1234; the trailing
    # number is the global step at which the checkpoint was written.
    basename = ckpt.model_checkpoint_path.split('/')[-1]
    return basename.split('-')[-1]
82+
83+
84+
def inference(predictions_op, true_labels_op, display, sess):
    """ Perform inference per batch on pre-trained model.

    Runs one batch through the decoder and computes the character error
    rate (CER) per utterance via Levenshtein distance.

    Args:
        predictions_op: Prediction op
        true_labels_op: True Labels op
        display: print sample predictions if True
        sess: default session to evaluate the ops.

    Returns:
        char_err_rate: list of CER per utterance.
    """
    char_err_rate = []
    # Perform inference of batch worth of data at a time.
    [predictions, true_labels] = sess.run([predictions_op,
                                           true_labels_op])
    pred_label = sparse_to_labels(predictions[0][0])
    actual_label = sparse_to_labels(true_labels)
    for (label, pred) in zip(actual_label, pred_label):
        # Guard against empty reference transcripts: len(label) == 0
        # previously raised ZeroDivisionError.
        char_err_rate.append(distance(label, pred) / max(len(label), 1))

    if display:
        # Print sample responses. Iterate over the actual batch contents
        # rather than range(ARGS.batch_size): the input pipeline allows a
        # smaller final batch, which previously raised IndexError here.
        for label, pred in zip(actual_label, pred_label):
            print(label + ' vs ' + pred)
    return char_err_rate
109+
110+
111+
def eval_once(saver, summary_writer, predictions_op, summary_op,
              true_labels_op):
    """Run one full evaluation pass over the selected data split.

    Restores the newest checkpoint, starts the input queue runners,
    accumulates per-utterance character error rates over the split,
    prints the mean CER and writes it out as a summary.

    Args:
        saver: Saver.
        summary_writer: Summary writer.
        predictions_op: Op to compute predictions.
        summary_op: Summary op.
        true_labels_op: Op yielding the ground-truth labels.
    """
    with tf.Session() as sess:

        # Initialize weights from checkpoint file.
        global_step = initialize_from_checkpoint(sess, saver)

        # Start the queue runners.
        coord = tf.train.Coordinator()
        try:
            threads = []
            for queue_runners in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                threads.extend(queue_runners.create_threads(sess, coord=coord,
                                                            daemon=True,
                                                            start=True))
            # Only using a subset of the training data.
            # NOTE(review): num_examples stays unbound if ARGS.eval_data is
            # not one of 'train'/'val'/'test' — argparse does not restrict
            # the value; consider an else branch that raises.
            if ARGS.eval_data == 'train':
                num_examples = 2048
            elif ARGS.eval_data == 'val':
                num_examples = 2703
            elif ARGS.eval_data == 'test':
                num_examples = 2620
            num_iter = int(math.ceil(num_examples / ARGS.batch_size))
            step = 0
            char_err_rate = []  # one list of per-utterance CERs per batch
            while step < num_iter and not coord.should_stop():
                char_err_rate.append(inference(predictions_op, true_labels_op,
                                               step == 0, sess))
                step += 1

            # Compute and print mean CER.
            # NOTE(review): char_err_rate is a list of per-batch lists; if
            # the final batch is smaller the shape is ragged — confirm
            # np.mean handles this with the numpy version in use.
            avg_cer = np.mean(char_err_rate)*100
            print('%s: char_err_rate = %.3f %%' % (datetime.now(), avg_cer))

            # Attach the CER to the regular summaries and write them out.
            summary = tf.Summary()
            summary.ParseFromString(sess.run(summary_op))
            summary.value.add(tag='char_err_rate', simple_value=avg_cer)
            summary_writer.add_summary(summary, global_step)
        except Exception as exc:  # pylint: disable=broad-except
            # Propagate the failure to the coordinator so the queue
            # threads shut down instead of hanging.
            coord.request_stop(exc)

        # Stop and join the queue threads before the session closes.
        coord.request_stop()
        coord.join(threads, stop_grace_period_secs=10)
166+
167+
168+
def evaluate():
    """Build the evaluation graph and score checkpoints until stopped.

    Restores the newest checkpoint and runs one full evaluation pass,
    then either returns (ARGS.run_once) or sleeps
    ARGS.eval_interval_secs and repeats.
    """
    with tf.Graph().as_default() as graph:

        # Fetch features, transcripts and utterance lengths for the
        # requested data split.
        feats, labels, seq_lens = deepSpeech.inputs(ARGS.eval_data,
                                                    data_dir=ARGS.data_dir,
                                                    batch_size=ARGS.batch_size,
                                                    use_fp16=ARGS.use_fp16,
                                                    shuffle=True)

        # Build the forward pass; dropout is disabled at test time.
        ARGS.keep_prob = 1.0
        logits = deepSpeech.inference(feats, seq_lens, ARGS)

        # Greedy CTC decoding. Sequence lengths are divided by the
        # temporal stride applied inside the model.
        log_probs = tf.nn.log_softmax(logits)
        strided_seq_lens = tf.div(seq_lens, ARGS.temporal_stride)
        predictions = tf.nn.ctc_greedy_decoder(log_probs, strided_seq_lens)

        # Evaluate using the moving-average (shadow) copies of the
        # learned variables.
        variable_averages = tf.train.ExponentialMovingAverage(
            ARGS.moving_avg_decay)
        saver = tf.train.Saver(variable_averages.variables_to_restore())

        # Summaries collected while the graph was being built.
        summary_op = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(ARGS.eval_dir, graph)

        while True:
            eval_once(saver, summary_writer, predictions, summary_op, labels)
            if ARGS.run_once:
                break
            time.sleep(ARGS.eval_interval_secs)
207+
208+
209+
def main():
    """Create a fresh eval directory and run inference on the
    checkpointed model."""
    # Start from an empty event-log directory on every run.
    if tf.gfile.Exists(ARGS.eval_dir):
        tf.gfile.DeleteRecursively(ARGS.eval_dir)
    tf.gfile.MakeDirs(ARGS.eval_dir)
    evaluate()
217+
218+
219+
if __name__ == '__main__':
    # ARGS is the module-level namespace read by every function above;
    # parse it exactly once at script entry.
    ARGS = parse_args()
    main()

0 commit comments

Comments
 (0)