# Source code for learn_to_pronounce.fst.fst_trainer

"""
Copyright 2022 Balacoon

trains FST - model to generate pronunciation or spelling
"""

import argparse
import importlib.util
import logging
import os
from importlib.machinery import SourceFileLoader

from pronunciation_generation import PronunciationDictionary

from learn_to_pronounce.fst.fst_evaluator import FSTEvaluator
from learn_to_pronounce.resources.provider import AbstractProvider


def add_fst_arguments(parser: argparse.ArgumentParser):
    """
    Registers FST-training options on a recipe's argument parser.

    Both options are integers and are collected into an argument group
    titled ``fst``. The parser is extended in place.

    Parameters
    ----------
    parser: argparse.ArgumentParser
        parser created by the recipe; receives ``--fst-order`` (default 8)
        and ``--fst-spelling-order`` (default 3)
    """
    fst_group = parser.add_argument_group("fst")
    # (flag, default n-gram order, help text) - registered in this order
    option_specs = (
        ("--fst-order", 8, "Maximum N-gram order to be used in FST"),
        (
            "--fst-spelling-order",
            3,
            "Maximum N-gram order to be used in spelling FST",
        ),
    )
    for flag, default_order, description in option_specs:
        fst_group.add_argument(
            flag, type=int, default=default_order, help=description
        )


class FSTTrainer:
    """
    Trains FST based on provided lexicon. Training is done with phonetisaurus.
    Can be used to train pronunciation or spelling generation.
    """

    # Location of the phonetisaurus training script. It is loaded as a python
    # module to get access to ``G2PModelTrainer``. Kept as a class attribute so
    # non-standard installations can override it.
    PHONETISAURUS_TRAIN_PATH = "/usr/local/bin/phonetisaurus-train"

    # Base names of trained models inside ``work_dir``. Shared between training
    # and evaluation so that the two always refer to the same artifact.
    _PRONUNCIATION_MODEL_NAME = "pronunciation"
    _SPELLING_MODEL_NAME = "spelling"

    def __init__(
        self, provider: AbstractProvider, work_dir: str, args: argparse.Namespace
    ):
        """
        constructor

        Parameters
        ----------
        provider: AbstractProvider
            resources provider that is used to get specific lexicon for training
        work_dir: str
            directory where all intermediate artifacts are stored
        args: argparse.Namespace
            parsed arguments, containing arguments added in
            :func:`add_fst_arguments`
        """
        self._provider = provider
        self._work_dir = work_dir
        self._args = args

    @staticmethod
    def _dump_fst_train_data(pd: PronunciationDictionary, path: str):
        """
        Helper function that stores pronunciation dictionary in a format
        suitable for FST training: one ``<word>\\t<pronunciation>`` line per
        pronunciation.

        Parameters
        ----------
        pd: PronunciationDictionary
            lexicon to store
        path: str
            path of the text file to (over)write
        """
        with open(path, "w", encoding="utf-8") as fp:
            words = pd.get_words()
            # order of words influences result! sort by name for reproducibility
            for word in sorted(words, key=lambda x: x.name()):
                for pronunciation in word.get_pronunciations():
                    fp.write("{}\t{}\n".format(word.name(), pronunciation.to_string()))

    @classmethod
    def _load_phonetisaurus_train(cls):
        """
        Loads the phonetisaurus training script as a python module.

        The script has no ``.py`` extension, so it is loaded explicitly with
        :class:`SourceFileLoader`. ``exec_module`` is used instead of the
        deprecated ``load_module``. The module is given a non-``"__main__"``
        name, so a ``if __name__ == "__main__"`` guard in the script (if any)
        does not fire.

        Returns
        -------
        module
            executed phonetisaurus-train module
        """
        loader = SourceFileLoader("phonetisaurus_train", cls.PHONETISAURUS_TRAIN_PATH)
        spec = importlib.util.spec_from_loader(loader.name, loader)
        module = importlib.util.module_from_spec(spec)
        loader.exec_module(module)
        return module

    def _model_path(self, model_name: str) -> str:
        """
        Returns the path of a trained FST model with the given base name
        inside ``work_dir``.
        """
        return os.path.join(self._work_dir, model_name + ".fst")

    def _train_fst(
        self,
        lexicon: PronunciationDictionary,
        train_data_name: str,
        model_name: str,
        ngram_order: int,
        **phonetisaurus_args
    ) -> str:
        """
        Helper function that trains FST on the given lexicon

        Parameters
        ----------
        lexicon: PronunciationDictionary
            lexicon to train on
        train_data_name: str
            name to give to intermediate file with training data
        model_name: str
            name to give to file with trained model
        ngram_order: int
            maximum n-gram order to be used in the FST training.
            Primary parameter that defines tradeoff between model size
            and accuracy.
        **phonetisaurus_args:
            other named parameters passed directly to
            phonetisaurus_train.G2PModelTrainer

        Returns
        -------
        fst_path: str
            path to trained FST model
        """
        train_data_path = os.path.join(self._work_dir, train_data_name)
        self._dump_fst_train_data(lexicon, train_data_path)
        phonetisaurus_train = self._load_phonetisaurus_train()
        phonetisaurus_trainer = phonetisaurus_train.G2PModelTrainer(
            train_data_path,
            dir_prefix=self._work_dir,
            model_prefix=model_name,
            ngram_order=ngram_order,
            **phonetisaurus_args
        )
        phonetisaurus_trainer.TrainG2PModel()
        # assumes phonetisaurus writes "<dir_prefix>/<model_prefix>.fst"
        return self._model_path(model_name)

    def train_pronunciation(self) -> str:
        """
        Training pronunciation FST

        Returns
        -------
        fst_path: str
            path to trained pronunciation model
        """
        train_lexicon = self._provider.get_lexicon(
            words=self._provider.get_train_words()
        )
        logging.info("Training pronunciation FST on %s words", train_lexicon.size())
        return self._train_fst(
            train_lexicon,
            train_data_name="pronunciation_training_data",
            model_name=self._PRONUNCIATION_MODEL_NAME,
            ngram_order=self._args.fst_order,
            seq2_del=True,
        )

    def evaluate_pronunciation(self) -> None:
        """
        Evaluates the previously trained pronunciation FST on the provider's
        test words.

        Logs a warning and returns without evaluating if the provider has no
        test words.

        Raises
        ------
        FileNotFoundError
            if the pronunciation model is missing from ``work_dir``, i.e.
            :meth:`train_pronunciation` was not run first
        """
        fst_path = self._model_path(self._PRONUNCIATION_MODEL_NAME)
        if not os.path.isfile(fst_path):
            raise FileNotFoundError(
                "Can't run evaluation, missing [{}]. Run training first.".format(
                    fst_path
                )
            )
        test_words = self._provider.get_test_words()
        if not test_words:
            logging.warning(
                "FST evaluation is enabled, but there are no test words in resource directory"
            )
            return
        test_lexicon = self._provider.get_lexicon(words=test_words)
        logging.info("Evaluating pronunciation FST on %s words", test_lexicon.size())
        evaluator = FSTEvaluator(fst_path)
        evaluator.evaluate(test_lexicon)

    def train_spelling(self) -> str:
        """
        Training spelling FST

        Returns
        -------
        fst_path: str
            path to trained spelling model
        """
        spelling_lexicon = self._provider.get_spelling_lexicon()
        logging.info("Training spelling FST on %s words", spelling_lexicon.size())
        return self._train_fst(
            spelling_lexicon,
            train_data_name="spelling_training_data",
            model_name=self._SPELLING_MODEL_NAME,
            ngram_order=self._args.fst_spelling_order,
            seq2_del=True,
        )