# Source code for learn_to_pronounce.fst.fst_evaluator

"""
Copyright 2022 Balacoon

Evaluates a trained FST model on a held-out lexicon
"""

import tqdm
import logging
from typing import List, Tuple

import edlib
from pronunciation_generation import FSTPronunciationGenerator, Pronunciation, PronunciationDictionary, Word


class PronunciationComparator:
    """
    Compares pronunciations, tracks total WER and PER.
    Follows the evaluation strategy from Phonetisaurus: compares the top-1 generated
    pronunciation with all the reference pronunciations of the given word and scores
    the hypothesis against the most similar reference.
    """

    def __init__(self, with_stress: bool = True):
        """
        constructor of pronunciation comparator

        Parameters
        ----------
        with_stress: bool
            flag whether to take into account stress when pronunciations are compared
        """
        self._with_stress = with_stress
        self._reset()

    def _reset(self):
        """
        Resets all accumulators for metrics
        """
        self._total_words = 0
        self._correct_words = 0
        self._total_phonemes = 0
        self._incorrect_phonemes = 0

    def compare(
        self,
        reference_pronunciations: "List[Pronunciation]",
        hypothesis_pronunciation: "Pronunciation",
    ):
        """
        Compares pronunciations, updates metrics.

        Each call accounts for exactly one word: the word counter grows by one,
        the phoneme counter grows by the length of the closest reference (counted
        as the number of tokens that were aligned), and the error counter grows
        by the edit distance to that reference.

        Parameters
        ----------
        reference_pronunciations: List[Pronunciation]
            list of correct pronunciations for given word, should not be empty
        hypothesis_pronunciation: Pronunciation
            hypothesis of pronunciation by PronunciationGenerator

        Raises
        ------
        ValueError
            if no reference pronunciations are provided. Metrics are left untouched.
        """
        if not reference_pronunciations:
            # nothing to compare to; fail before any accumulator is modified
            raise ValueError("at least one reference pronunciation is required")

        # hypothesis does not depend on the reference, so tokenize it only once
        hypothesis_phonemes = hypothesis_pronunciation.to_string(
            with_stress=self._with_stress
        ).split()

        min_distance = None
        min_distance_length = 0
        for reference_pronunciation in reference_pronunciations:
            reference_phonemes = reference_pronunciation.to_string(
                with_stress=self._with_stress
            ).split()
            distance = edlib.align(hypothesis_phonemes, reference_phonemes)[
                "editDistance"
            ]
            # strict comparison: on ties the earliest reference is kept
            if min_distance is None or distance < min_distance:
                min_distance = distance
                min_distance_length = len(reference_phonemes)
            if not distance:
                # exact match, no other reference can be closer
                break

        self._total_words += 1
        if not min_distance:
            self._correct_words += 1
        self._incorrect_phonemes += min_distance
        # the selected reference is counted once; counting the first reference
        # in addition would inflate the denominator of PER
        self._total_phonemes += min_distance_length

    def get_metrics(self) -> Tuple[float, float]:
        """
        returns WER and PER in percents given all compared pronunciations

        Returns
        -------
        wer: float
            word error rate in percent, 0.0 if no words were compared
        per: float
            phoneme error rate in percent, 0.0 if no reference phonemes were accumulated
        """
        wer = 0.0
        if self._total_words:
            wer = (
                100.0
                * (self._total_words - self._correct_words)
                / float(self._total_words)
            )
        per = 0.0
        if self._total_phonemes:
            per = 100.0 * self._incorrect_phonemes / float(self._total_phonemes)
        return wer, per


class FSTEvaluator:
    """
    Scores an FST pronunciation model against a lexicon.
    Generated pronunciations are compared with the ground truth ones,
    resulting WER and PER are written to the log.
    """

    def __init__(self, fst_path: str):
        """
        creates evaluator for the given FST model

        Parameters
        ----------
        fst_path: str
            path to FST model to evaluate
        """
        self._fst = FSTPronunciationGenerator(fst_path)

    def evaluate(self, lexicon: PronunciationDictionary):
        """
        Generates a pronunciation for every word of the lexicon, compares it
        to the ground truth both with and without stress, logs the metrics.

        Parameters
        ----------
        lexicon: PronunciationDictionary
            words and ground truth pronunciations to evaluate on
        """
        stressed = PronunciationComparator()
        stressless = PronunciationComparator(with_stress=False)

        for reference in tqdm.tqdm(lexicon.get_words()):
            # fresh word with the same spelling, to be filled in by the FST
            generated = Word(reference.name())
            self._fst.phoneticize(generated)
            expected = reference.get_pronunciations()
            predicted = generated.get_pronunciation()
            for scorer in (stressed, stressless):
                scorer.compare(expected, predicted)

        # header is logged before metrics are requested, for each comparator in turn
        reports = (
            ("Performance taking into account stress marks:", stressed),
            (
                "Performance WITHOUT taking into account stress marks (stressless):",
                stressless,
            ),
        )
        for header, scorer in reports:
            logging.info(header)
            wer, per = scorer.get_metrics()
            logging.info("WER,%: {:.2f}; PER,%: {:.2f}".format(wer, per))