Source code for en_us_normalization.production.classify.multi_token.attached

"""
Copyright Balacoon 2022

tokenize and classify merged tokens
"""

import pynini
from en_us_normalization.production.english_utils import get_data_file_path, UNK_SYMBOLS
from en_us_normalization.production.classify.abbreviation import AbbreviationFst
from en_us_normalization.production.classify.cardinal import CardinalFst
from en_us_normalization.production.classify.word import WordFst
from pynini.lib import pynutil

from learn_to_normalize.grammar_utils.base_fst import BaseFst
from learn_to_normalize.grammar_utils.shortcuts import insert_space


class AttachedTokensFst(BaseFst):
    """
    Attached tokens FST tries to deal with multi-token strings that have a
    `dash` as a separator or no separator at all, for example "look33" or
    "AT&T-wireless". It takes advantage of the fact that the boundary between
    some semiotic classes is fairly obvious.

    Examples of input / output:

    - look33 -> tokens { name: "look" } tokens { cardinal { count: "33" } }
    - AT&T-wireless -> tokens { name: "AT and T" } tokens { name: "wireless" }
    """
    def __init__(
        self,
        cardinal: CardinalFst = None,
        abbreviation: AbbreviationFst = None,
        word: WordFst = None,
    ):
        """
        constructor of transducer handling attached (merged) tokens

        Parameters
        ----------
        cardinal: CardinalFst
            a cardinal to reuse
        abbreviation: AbbreviationFst
            an abbreviation to reuse
        word: WordFst
            a word to reuse
        """
        super().__init__(name="attached")
        # initialize transducers if they are not provided;
        # may be needed in testing
        if cardinal is None:
            cardinal = CardinalFst()
        if abbreviation is None:
            abbreviation = AbbreviationFst()
        if word is None:
            word = WordFst()

        symbols = pynini.string_file(get_data_file_path("symbols.tsv")).optimize()
        # penalize adding more symbols, so if there is another option
        # (for example punctuation), go with that
        multiple_symbols = (
            pynini.closure(UNK_SYMBOLS)
            + symbols
            + pynini.closure(pynutil.add_weight(insert_space, 10) + symbols | UNK_SYMBOLS)
        )
        multiple_symbols = pynutil.insert("name: \"") + multiple_symbols + pynutil.insert("\"")

        cross_hyphen = pynini.cross("-", " } tokens { ")
        optional_cross_hyphen = pynutil.insert(" } tokens { ") + pynini.closure(
            pynutil.delete("-"), 0, 1
        )

        # the boundary between an abbreviation and a word is not obvious,
        # so a dash separator is required
        abbr_plus_word = abbreviation.fst + cross_hyphen + word.fst
        # the boundary between an abbreviation and a number is obvious, so the dash is optional
        abbr_plus_number = abbreviation.fst + optional_cross_hyphen + cardinal.fst
        # try to avoid situations where an all-consonant string is classified as a word
        word_or_abbr = pynutil.add_weight(word.fst, 1.1) | abbreviation.fst
        # the boundary between a word and a number is also obvious
        word_plus_number = word_or_abbr + optional_cross_hyphen + cardinal.fst
        number_plus_word = cardinal.fst + optional_cross_hyphen + word_or_abbr
        # the boundary between a word and symbols is obvious
        word_plus_symbols = word_or_abbr + optional_cross_hyphen + multiple_symbols
        symbols_plus_word = multiple_symbols + optional_cross_hyphen + word_or_abbr
        word_plus_unk_symbols = word_or_abbr + pynini.closure(UNK_SYMBOLS, 1)
        unk_symbols_plus_word = pynini.closure(UNK_SYMBOLS, 1) + word_or_abbr
        # special case for insta ;)
        hashtag = pynini.cross("#", "name: \"hashtag\"") + insert_space + word_or_abbr

        graph = (
            abbr_plus_word
            | abbr_plus_number
            | pynutil.add_weight(word_plus_number, 1.1)
            | pynutil.add_weight(number_plus_word, 1.1)
            | pynutil.add_weight(word_plus_symbols, 20)  # regular word weight is 10, avoid shadowing word + punct
            | pynutil.add_weight(symbols_plus_word, 20)  # regular word weight is 10, avoid shadowing punct + word
            | pynutil.add_weight(word_plus_unk_symbols, 25)  # regular word plus unk symbol, where the latter is deleted
            | pynutil.add_weight(unk_symbols_plus_word, 25)  # unk symbol plus regular word, where the former is deleted
            | hashtag  # hashtag is overshadowed by symbols_plus_word but has higher weight
        )
        self._multi_fst = graph.optimize()
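

# Minimal usage sketch. This block is an illustration, not part of the
# original module: it assumes the surrounding package is importable and that
# BaseFst exposes the composed transducer through a `.fst` property, mirroring
# the `word.fst` / `cardinal.fst` accesses above. Composing a raw string with
# the classifier and taking the shortest path yields the classified token
# string, shaped like the examples in the class docstring.
if __name__ == "__main__":
    attached = AttachedTokensFst().fst  # assumed BaseFst accessor
    # compose the input with the classifier and keep the cheapest analysis
    lattice = pynini.compose("look33", attached)
    print(pynini.shortestpath(lattice).string())
    # expected shape (cf. docstring): a name token for "look"
    # followed by a cardinal token for "33"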