Source code for lightwood.mixer.unit

from typing import List, Optional

import torch
import pandas as pd

from lightwood.helpers.log import log
from lightwood.mixer.base import BaseMixer
from lightwood.encoder.base import BaseEncoder
from lightwood.data.encoded_ds import EncodedDs
from lightwood.api.types import PredictionArguments


[docs]class Unit(BaseMixer): def __init__(self, stop_after: float, target_encoder: BaseEncoder): """ The "Unit" mixer serves as a simple wrapper around a target encoder, essentially borrowing the encoder's functionality for predictions. In other words, it simply arg-maxes the output of the encoder Used with encoders that already fine-tune on the targets (namely, pre-trained text ML models). Attributes: :param target_encoder: An instance of a Lightwood BaseEncoder. This encoder is used to decode predictions. :param stop_after (float): Time budget (in seconds) to train this mixer. """ # noqa super().__init__(stop_after) self.target_encoder = target_encoder self.supports_proba = False self.stable = True
[docs] def fit(self, train_data: EncodedDs, dev_data: EncodedDs) -> None: log.info("Unit mixer does not require training, it passes through predictions from its encoders.")
[docs] def partial_fit(self, train_data: EncodedDs, dev_data: EncodedDs, args: Optional[dict] = None) -> None: pass
def __call__(self, ds: EncodedDs, args: PredictionArguments = PredictionArguments()) -> pd.DataFrame: """ Makes predictions using the provided EncodedDs dataset. Mixer decodes predictions using the target encoder and returns them in a pandas DataFrame. :returns ydf (pd.DataFrame): a data frame containing the decoded predictions. """ if args.predict_proba: # @TODO: depending on the target encoder, this might be enabled log.warning('This model does not output probability estimates') decoded_predictions: List[object] = [] for X, _ in ds: decoded_prediction = self.target_encoder.decode(torch.unsqueeze(X, 0)) decoded_predictions.extend(decoded_prediction) ydf = pd.DataFrame({"prediction": decoded_predictions}) return ydf