import argparse
import json
import math
import os
import time
from datetime import datetime

import numpy as np
import tensorflow as tf
from Levenshtein import distance

import deepSpeech

# Note: this definition must match the ALPHABET chosen in
# preprocess_Librispeech.py
ALPHABET = "ABCDEFGHIJKLMNOPQRSTUVWXYZ' "
IX_TO_CHAR = {i: ch for (i, ch) in enumerate(ALPHABET)}
# tf.nn.ctc_greedy_decoder strips the CTC blank label (index num_classes - 1)
# from its output, so IX_TO_CHAR only needs entries for the 28 real symbols.


def parse_args():
    """ Parses command line arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--eval_dir', type=str,
                        default='../models/librispeech/eval',
                        help='Directory to write event logs')
    parser.add_argument('--checkpoint_dir', type=str,
                        default='../models/librispeech/train',
                        help='Directory where to read model checkpoints.')
    parser.add_argument('--eval_data', type=str, default='val',
                        choices=['train', 'val', 'test'],
                        help="Either 'train', 'val' or 'test'")
    parser.add_argument('--batch_size', type=int, default=32,
                        help='Number of feats to process in a batch')
    parser.add_argument('--eval_interval_secs', type=int, default=60 * 5,
                        help='How often to run the eval')
    parser.add_argument('--data_dir', type=str,
                        default='../data/librispeech/processed/',
                        help='Path to the deepSpeech data directory')
    parser.add_argument('--run_once', action='store_true',
                        help='Whether to run eval only once')
    args = parser.parse_args()

    # Read saved parameters from file
    param_file = os.path.join(args.checkpoint_dir,
                              'deepSpeech_parameters.json')
    with open(param_file, 'r') as json_file:
        params = json.load(json_file)
    # Read network architecture parameters from
    # previously saved parameter file.
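    # The JSON file is expected to provide at least the keys read below;
    # an illustrative example (values are assumptions, not taken from this
    # repository):
    #   {"num_hidden": 1760, "num_rnn_layers": 7, "rnn_type": "uni-dir",
    #    "num_filters": 32, "use_fp16": false, "temporal_stride": 2,
    #    "moving_avg_decay": 0.9999}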
    args.num_hidden = params['num_hidden']
    args.num_rnn_layers = params['num_rnn_layers']
    args.rnn_type = params['rnn_type']
    args.num_filters = params['num_filters']
    args.use_fp16 = params['use_fp16']
    args.temporal_stride = params['temporal_stride']
    args.moving_avg_decay = params['moving_avg_decay']
    return args


def sparse_to_labels(sparse_matrix):
    """ Convert an index-based sparse transcript matrix to strings."""
    results = [''] * sparse_matrix.dense_shape[0]
    for i, val in enumerate(sparse_matrix.values.tolist()):
        # The first index column identifies the utterance within the batch.
        results[sparse_matrix.indices[i, 0]] += IX_TO_CHAR[val]
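    # Worked example: indices [[0, 0], [0, 1], [1, 0]] with values [7, 4, 24]
    # ('H', 'E', 'Y') and dense_shape [2, 2] yields ['HE', 'Y'] for a batch
    # of two utterances.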
    return results


def initialize_from_checkpoint(sess, saver):
    """ Initialize variables on the graph."""

    # Initialize variables from a checkpoint file, if provided.
    ckpt = tf.train.get_checkpoint_state(ARGS.checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
        # Restore from checkpoint.
        saver.restore(sess, ckpt.model_checkpoint_path)
        # Assuming model_checkpoint_path looks something like:
        # /my-favorite-path/train/model.ckpt-0,
        # extract global_step from it.
        checkpoint_path = ckpt.model_checkpoint_path
        global_step = checkpoint_path.split('/')[-1].split('-')[-1]
        return global_step
    else:
        print('No checkpoint file found')
        return None


def inference(predictions_op, true_labels_op, display, sess):
    """ Perform inference per batch on pre-trained model.
    This function performs inference and computes the CER per utterance.
    Args:
        predictions_op: Prediction op
        true_labels_op: True Labels op
        display: print sample predictions if True
        sess: default session to evaluate the ops.
    Returns:
        char_err_rate: list of CER per utterance.
    """
    char_err_rate = []
    # Perform inference on one batch's worth of data at a time.
    [predictions, true_labels] = sess.run([predictions_op,
                                           true_labels_op])
    pred_label = sparse_to_labels(predictions[0][0])
    actual_label = sparse_to_labels(true_labels)
    for (label, pred) in zip(actual_label, pred_label):
        char_err_rate.append(distance(label, pred) / len(label))

    if display:
        # Print sample responses
        for i in range(ARGS.batch_size):
            print(actual_label[i] + ' vs ' + pred_label[i])
    return char_err_rate


def eval_once(saver, summary_writer, predictions_op, summary_op,
              true_labels_op):
    """Run Eval once.

    Args:
        saver: Saver.
        summary_writer: Summary writer.
        predictions_op: Op to compute predictions.
        summary_op: Summary op.
        true_labels_op: Op that yields the ground-truth label batch.
    """
    with tf.Session() as sess:

        # Initialize weights from checkpoint file.
        global_step = initialize_from_checkpoint(sess, saver)
        if global_step is None:
            # No checkpoint to evaluate yet.
            return

        # Start the queue runners.
        coord = tf.train.Coordinator()
        try:
            threads = []
            for queue_runner in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                threads.extend(queue_runner.create_threads(
                    sess, coord=coord, daemon=True, start=True))
            # Only using a subset of the training data
            if ARGS.eval_data == 'train':
                num_examples = 2048
            elif ARGS.eval_data == 'val':
                num_examples = 2703
            elif ARGS.eval_data == 'test':
                num_examples = 2620
            num_iter = int(math.ceil(num_examples / ARGS.batch_size))
            step = 0
            char_err_rate = []
            while step < num_iter and not coord.should_stop():
                char_err_rate.append(inference(predictions_op, true_labels_op,
                                               step == 0, sess))
                step += 1

            # Compute and print mean CER
            avg_cer = np.mean(char_err_rate) * 100
            print('%s: char_err_rate = %.3f %%' % (datetime.now(), avg_cer))

            # Add summary ops
            summary = tf.Summary()
            summary.ParseFromString(sess.run(summary_op))
            summary.value.add(tag='char_err_rate', simple_value=avg_cer)
            summary_writer.add_summary(summary, global_step)
        except Exception as exc:  # pylint: disable=broad-except
            coord.request_stop(exc)

        # Close threads
        coord.request_stop()
        coord.join(threads, stop_grace_period_secs=10)


def evaluate():
    """ Evaluate deepSpeech model for a number of steps."""

    with tf.Graph().as_default() as graph:

        # Get feats and labels for deepSpeech.
        feats, labels, seq_lens = deepSpeech.inputs(ARGS.eval_data,
                                                    data_dir=ARGS.data_dir,
                                                    batch_size=ARGS.batch_size,
                                                    use_fp16=ARGS.use_fp16,
                                                    shuffle=True)

        # Build the op that computes logits predictions from the
        # inference model.
        ARGS.keep_prob = 1.0  # Disable dropout during testing.
        logits = deepSpeech.inference(feats, seq_lens, ARGS)

        # Calculate predictions.
        output_log_prob = tf.nn.log_softmax(logits)
        decoder = tf.nn.ctc_greedy_decoder
        # Scale sequence lengths by the temporal stride applied in the
        # model's convolutional layer before handing them to the decoder.
        strided_seq_lens = tf.div(seq_lens, ARGS.temporal_stride)
        predictions = decoder(output_log_prob, strided_seq_lens)
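        # ctc_greedy_decoder returns a (decoded, neg_sum_logits) pair, where
        # decoded is a single-element list of SparseTensors; this is why
        # inference() indexes the fetched result as predictions[0][0].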

        # Restore the moving average version of the learned variables for eval.
        variable_averages = tf.train.ExponentialMovingAverage(
            ARGS.moving_avg_decay)
        variables_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        # Build the summary operation based on the TF collection of Summaries.
        summary_op = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(ARGS.eval_dir, graph)

        while True:
            eval_once(saver, summary_writer, predictions, summary_op, labels)

            if ARGS.run_once:
                break
            time.sleep(ARGS.eval_interval_secs)


def main():
    """
    Create eval directory and perform inference on checkpointed model.
    """
    if tf.gfile.Exists(ARGS.eval_dir):
        tf.gfile.DeleteRecursively(ARGS.eval_dir)
    tf.gfile.MakeDirs(ARGS.eval_dir)
    evaluate()


if __name__ == '__main__':
    ARGS = parse_args()
    main()
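
# Example invocation (paths are illustrative; adjust to your setup):
#   python <this_script>.py --eval_data val \
#       --checkpoint_dir ../models/librispeech/train \
#       --eval_dir ../models/librispeech/eval --run_once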