from typing import List, Optional

import pandas as pd

from lightwood.mixer.base import BaseMixer
from lightwood.ensemble.base import BaseEnsemble
from lightwood.api.types import PredictionArguments
from lightwood.data.encoded_ds import EncodedDs
class Embedder(BaseEnsemble):
    """
    This ensemble acts as a simple embedder that bypasses all mixers.

    When called, it returns the encoded representation of the data stored in
    (or generated by) an EncodedDs object, without the target column.
    """

    def __init__(self, target, mixers: List[BaseMixer], data: EncodedDs) -> None:
        """
        Set up the embedder.

        :param target: forwarded unchanged to the base ensemble constructor.
        :param mixers: accepted for signature compatibility only; it is not used, and an empty list is handed to the base class instead.
        :param data: encoded dataset whose feature encoding determines ``embedding_size``.
        """  # noqa
        # The mixers are deliberately dropped: the base class receives an empty list.
        super().__init__(target, [], data)

        # Width of the encoded feature representation (last axis), target excluded.
        self.embedding_size = data.get_encoded_data(include_target=False).shape[-1]

        # Flagged as prepared right away; no fitting step is performed here.
        self.prepared = True

    def __call__(self, ds: EncodedDs, args: Optional[PredictionArguments] = None) -> pd.DataFrame:
        """
        Return the encoded features of ``ds`` as a dataframe.

        :param ds: encoded dataset to embed.
        :param args: prediction arguments; not used by this ensemble.

        :returns: dataframe wrapping the encoded feature data, target excluded.
        """
        # shape: (B, self.embedding_size)
        encoded_representations = ds.get_encoded_data(include_target=False).numpy()
        return pd.DataFrame(encoded_representations)