Source code for lightwood.encoder.categorical.simple_label

from typing import List, Union
from collections import defaultdict
import pandas as pd
import numpy as np
import torch

from lightwood.encoder.base import BaseEncoder
from lightwood.helpers.constants import _UNCOMMON_WORD


class SimpleLabelEncoder(BaseEncoder):
    """
    Simple encoder that assigns a unique integer to every observed label.

    Allocates an `unknown` label by default to index 0.

    Labels must be exact matches between inference and training (e.g. no .lower() is applied to strings here).
    """  # noqa

    def __init__(self, is_target=False, normalize=True) -> None:
        super().__init__(is_target)
        self.label_map = defaultdict(int)  # UNK category maps to 0
        self.inv_label_map = {}  # invalid encoded values are mapped to `_UNCOMMON_WORD` in `decode`
        self.output_size = 1
        self.n_labels = None
        self.normalize = normalize
    def prepare(self, priming_data: Union[list, pd.Series]) -> None:
        if not isinstance(priming_data, pd.Series):
            priming_data = pd.Series(priming_data)

        for i, v in enumerate(priming_data.unique()):
            if v is not None:
                self.label_map[str(v)] = int(i + 1)  # leave 0 for UNK

        self.n_labels = len(self.label_map)

        for k, v in self.label_map.items():
            self.inv_label_map[v] = k

        self.is_prepared = True
    def encode(self, data: Union[tuple, np.ndarray, pd.Series], normalize=True) -> torch.Tensor:
        """
        :param normalize: can be used to temporarily return unnormalized values
        """
        if not isinstance(data, pd.Series):
            data = pd.Series(data)

        # specific to the Gym class - remove once deprecated!
        if isinstance(data, np.ndarray):
            data = pd.Series(data)

        data = data.astype(str)
        encoded = torch.Tensor(data.map(self.label_map))  # unseen labels map to 0 via the defaultdict

        if normalize and self.normalize:
            encoded /= self.n_labels

        if len(encoded.shape) < 2:
            encoded = encoded.unsqueeze(-1)

        return encoded
    def decode(self, encoded_values: torch.Tensor, normalize=True) -> List[object]:
        """
        :param normalize: set to False if `encoded_values` was produced with normalization disabled
        """
        if normalize and self.normalize:
            encoded_values *= self.n_labels

        values = encoded_values.long().squeeze().tolist()  # long() as inv_label_map expects an int key
        values = [self.inv_label_map.get(v, _UNCOMMON_WORD) for v in values]

        return values
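
A minimal usage sketch follows; it is not part of the module itself. It assumes lightwood is installed so the imports above resolve, and uses made-up labels to illustrate the prepare/encode/decode round trip and the handling of unseen labels.

# --- Usage sketch (illustrative, not part of the original module) ---
if __name__ == '__main__':
    labels = pd.Series(['red', 'green', 'blue', 'green'])

    encoder = SimpleLabelEncoder()
    encoder.prepare(labels)

    encoded = encoder.encode(labels)   # shape (4, 1); values divided by n_labels (here 3)
    decoded = encoder.decode(encoded)  # ['red', 'green', 'blue', 'green']
    # note: decode() scales its input tensor in place when normalization is enabled

    # labels unseen during prepare() encode to index 0 and decode to _UNCOMMON_WORD
    unseen = encoder.decode(encoder.encode(pd.Series(['purple', 'red'])))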