Исходный код recs_searcher.augmentation._char_aug
from typing import List, Union, Optional, Literal
import numpy as np
from ._base import BaseAugmentation
from ._actions import CHAR_ACTIONS
[документация]
class CharAugmentation(BaseAugmentation):
    """Augmentation at the character level.

    Every input string is split into characters, a subset of positions is
    selected through the inherited ``_aug_indexing`` helper, and a single
    action (``delete``, ``insert``, ``multiply`` or ``swap``) is applied at
    each selected position.
    """

    def __init__(
        self,
        unit_prob: float = 0.3,
        min_aug: int = 1,
        max_aug: int = 5,
        mult_num: int = 5,
        action: Optional[Literal["delete", "multiply", "swap", "insert"]] = None,
        seed: Union[int, None] = None,
    ) -> None:
        """Configure the augmentation.

        Args:
            unit_prob: Forwarded to ``_aug_indexing`` for every text
                (presumably the share of characters to alter -- confirm
                against ``BaseAugmentation``).
            min_aug: Forwarded to ``BaseAugmentation.__init__``.
            max_aug: Forwarded to ``BaseAugmentation.__init__``.
            mult_num: Exclusive upper bound of the repeat count used by the
                ``multiply`` action.
            action: Name of the action to apply. When ``None``, one entry of
                ``CHAR_ACTIONS`` is drawn at random once, at construction.
            seed: Forwarded to ``BaseAugmentation.__init__``.
        """
        super().__init__(
            min_aug=min_aug,
            max_aug=max_aug,
            seed=seed,
        )
        self.mult_num = mult_num
        self.unit_prob = unit_prob
        if action is None:
            # np.random.choice returns a NumPy scalar; keep a plain str so the
            # attribute has the same type regardless of how it was chosen.
            action = str(np.random.choice(CHAR_ACTIONS))
        # Always set the attribute: ``_transform`` reads ``self.action``.
        self.action = action

    @property
    def actions_list(self) -> List[str]:
        """Return the supported action names.

        Returns:
            List[str]: The module-level ``CHAR_ACTIONS`` collection.
        """
        return CHAR_ACTIONS

    def __delete(self) -> str:
        """Delete a character.

        Returns:
            str: Empty string, used as the replacement for the character.
        """
        return ""

    def __insert(self, char: str, vocab: List[str]) -> str:
        """Insert a random character after ``char``.

        Args:
            char (str): A symbol from the word.
            vocab (List[str]): Candidate symbols to draw the new one from.

        Returns:
            str: The original symbol followed by the new symbol.
        """
        return char + np.random.choice(vocab)

    def __multiply(self, char: str) -> str:
        """Repeat a character a random number of times.

        Spaces and basic punctuation are returned unchanged.

        Args:
            char (str): A symbol from the word.

        Returns:
            str: The symbol repeated ``n`` times, with ``n`` drawn from
            ``[1, mult_num)``; the symbol itself when ``mult_num <= 1``.
        """
        if char in (" ", ",", ".", "?", "!", "-"):
            return char
        if self.mult_num <= 1:
            # randint(1, high) raises ValueError when high <= 1; a single
            # repetition (i.e. no change) is the only sensible outcome.
            return char
        n = np.random.randint(1, self.mult_num)
        return char * n

    def _transform(self, array: List[str]) -> List[str]:
        """Apply the configured action to every text in ``array``.

        Args:
            array (List[str]): Texts to augment.

        Returns:
            List[str]: Augmented texts, in the same order as the input.

        Raises:
            NameError: If ``self.action`` is not a supported action and at
                least one position was selected for augmentation.
        """
        transformed_array = []
        for text in array:
            typo_text_arr = list(text)
            aug_idxs = self._aug_indexing(typo_text_arr, self.unit_prob, clip=True)
            # Built once per text. Sorted because set iteration order of
            # strings varies between interpreter runs (hash randomisation),
            # which would defeat a fixed RNG seed.
            vocab = sorted(set(text)) if self.action == "insert" else None
            for idx in aug_idxs:
                if self.action == "delete":
                    typo_text_arr[idx] = self.__delete()
                elif self.action == "insert":
                    typo_text_arr[idx] = self.__insert(typo_text_arr[idx], vocab)
                elif self.action == "multiply":
                    typo_text_arr[idx] = self.__multiply(typo_text_arr[idx])
                elif self.action == "swap":
                    # Swap with the left neighbour; index 0 swaps with itself.
                    sw = max(0, idx - 1)
                    typo_text_arr[sw], typo_text_arr[idx] = (
                        typo_text_arr[idx],
                        typo_text_arr[sw],
                    )
                else:
                    raise NameError(
                        f"Augmentation action {self.action!r} is not available; "
                        f"see CharAugmentation.actions_list for the supported "
                        f"actions: {list(CHAR_ACTIONS)}"
                    )
            transformed_array.append("".join(typo_text_arr))
        return transformed_array