research/deep_speech/data/featurizer.py
#  Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
# ==============================================================================
"""Utility class for extracting features from the text and audio input."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import codecs
import numpy as np


def compute_spectrogram_feature(samples, sample_rate, stride_ms=10.0,
                                window_ms=20.0, max_freq=None, eps=1e-14):
  """Compute the spectrograms for the input samples(waveforms).

  More about spectrogram computation, please refer to:
  https://en.wikipedia.org/wiki/Short-time_Fourier_transform.
  """
  if max_freq is None:
    max_freq = sample_rate / 2
  if max_freq > sample_rate / 2:
    raise ValueError("max_freq must not be greater than half of sample rate.")

  if stride_ms > window_ms:
    raise ValueError("Stride size must not be greater than window size.")

  stride_size = int(0.001 * sample_rate * stride_ms)
  window_size = int(0.001 * sample_rate * window_ms)

  # Extract strided windows
  truncate_size = (len(samples) - window_size) % stride_size
  samples = samples[:len(samples) - truncate_size]
  nshape = (window_size, (len(samples) - window_size) // stride_size + 1)
  nstrides = (samples.strides[0], samples.strides[0] * stride_size)
  windows = np.lib.stride_tricks.as_strided(
      samples, shape=nshape, strides=nstrides)
  # Sanity-check that the strided view lines up with the raw samples.
  assert np.all(
      windows[:, 1] == samples[stride_size:(stride_size + window_size)])

  # Window weighting, squared Fast Fourier Transform (fft), scaling
  weighting = np.hanning(window_size)[:, None]
  fft = np.fft.rfft(windows * weighting, axis=0)
  fft = np.absolute(fft)
  fft = fft**2
  scale = np.sum(weighting**2) * sample_rate
  fft[1:-1, :] *= (2.0 / scale)
  fft[(0, -1), :] /= scale
  # Prepare fft frequency list
  freqs = float(sample_rate) / window_size * np.arange(fft.shape[0])

  # Compute spectrogram feature
  ind = np.where(freqs <= max_freq)[0][-1] + 1
  specgram = np.log(fft[:ind, :] + eps)
  return np.transpose(specgram, (1, 0))

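# An illustrative usage sketch, not part of the original module: it computes
# the spectrogram feature of a synthetic one-second 440 Hz sine wave sampled
# at 16 kHz. The function name is hypothetical.
def _example_spectrogram_usage():
  sample_rate = 16000
  t = np.arange(sample_rate) / float(sample_rate)
  waveform = np.sin(2 * np.pi * 440.0 * t).astype(np.float32)
  features = compute_spectrogram_feature(waveform, sample_rate)
  # With the default 20 ms window and 10 ms stride, the result has shape
  # (num_frames, num_frequency_bins) == (99, 161).
  return features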

class AudioFeaturizer(object):
  """Class to extract spectrogram features from the audio input."""

  def __init__(self,
               sample_rate=16000,
               window_ms=20.0,
               stride_ms=10.0):
    """Initialize the audio featurizer class according to the configs.

    Args:
      sample_rate: an integer specifying the sample rate of the input waveform.
      window_ms: a float for the length of a spectrogram frame, in ms.
      stride_ms: a float for the frame stride, in ms.
    """
    self.sample_rate = sample_rate
    self.window_ms = window_ms
    self.stride_ms = stride_ms

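  # An illustrative sketch, not part of the original module: the featurizer
  # only stores configuration, and other parts of the pipeline are expected to
  # pass these fields to compute_spectrogram_feature, roughly as below. The
  # method name is hypothetical.
  def example_featurize(self, samples):
    return compute_spectrogram_feature(
        samples, self.sample_rate,
        stride_ms=self.stride_ms, window_ms=self.window_ms)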

def compute_label_feature(text, token_to_idx):
  """Convert string to a list of integers."""
  tokens = list(text.strip().lower())
  feats = [token_to_idx[token] for token in tokens]
  return feats

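# An illustrative check, not part of the original module: with a toy
# character-to-index table, the transcript is lower-cased and mapped character
# by character. The function name and toy table are hypothetical.
def _example_label_feature():
  toy_vocab = {"h": 0, "i": 1}
  assert compute_label_feature("Hi", toy_vocab) == [0, 1]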

class TextFeaturizer(object):
  """Extract text feature based on char-level granularity.

  By looking up the vocabulary table, each input string (one line of transcript)
  will be converted to a sequence of integer indexes.
  """

  def __init__(self, vocab_file):
    """Initialize the featurizer by loading the character vocabulary.

    Args:
      vocab_file: path to a vocabulary file with one token per line; lines
        starting with "#" are treated as comments.
    """
    lines = []
    with codecs.open(vocab_file, "r", "utf-8") as fin:
      lines.extend(fin.readlines())
    self.token_to_index = {}
    self.index_to_token = {}
    self.speech_labels = ""
    index = 0
    for line in lines:
      line = line.rstrip("\n")  # Strip the trailing newline, if any.
      if line.startswith("#"):
        # Skip comment lines.
        continue
      self.token_to_index[line] = index
      self.index_to_token[index] = line
      self.speech_labels += line
      index += 1
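

# An illustrative end-to-end sketch, not part of the original module: it builds
# a TextFeaturizer from a character vocabulary file and encodes a transcript
# with compute_label_feature. The function name and default file name are
# hypothetical.
def _example_text_featurizer_usage(vocab_file="vocabulary.txt"):
  featurizer = TextFeaturizer(vocab_file)
  # This assumes the vocabulary covers every character in the transcript,
  # including the space character.
  return compute_label_feature("hello world", featurizer.token_to_index)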