Source code for lightwood.ensemble.embed

from typing import List
import pandas as pd

from lightwood.mixer.base import BaseMixer
from lightwood.ensemble.base import BaseEnsemble
from lightwood.api.types import PredictionArguments
from lightwood.data.encoded_ds import EncodedDs


class Embedder(BaseEnsemble):
    """
    Ensemble that acts as a simple embedder and bypasses all mixers.

    When called, it returns the encoded representation of the data stored in
    (or generated by) an ``EncodedDs`` object, without the target column(s).
    """

    def __init__(self, target, mixers: List[BaseMixer], data: EncodedDs) -> None:
        """
        Set up the embedder.

        :param target: Forwarded unchanged to ``BaseEnsemble.__init__``.
        :param mixers: Accepted for signature compatibility but ignored; an empty mixer list is handed to the base class instead.
        :param data: Encoded dataset, used here to determine the embedding width and also forwarded to the base class.
        """  # noqa
        # No mixer is ever consulted by this ensemble, so the base class
        # deliberately receives an empty list rather than ``mixers``.
        super().__init__(target, [], data)

        # Width of the embedding: size of the last axis of the encoded
        # features, with the target excluded.
        self.embedding_size = data.get_encoded_data(include_target=False).shape[-1]

        # Flagged as prepared right away; nothing in this class performs a
        # fitting step.
        self.prepared = True

    def __call__(self, ds: EncodedDs, args: PredictionArguments = None) -> pd.DataFrame:
        """
        Return the encoded features of ``ds`` as a data frame.

        :param ds: Encoded dataset whose (non-target) encoded data is returned.
        :param args: Accepted for signature compatibility; not used by this method.

        :returns: ``pd.DataFrame`` built from the encoded data of ``ds`` with the target excluded.
        """  # noqa
        # shape: (B, self.embedding_size)
        encoded_representations = ds.get_encoded_data(include_target=False).numpy()
        return pd.DataFrame(encoded_representations)