Source code for lightwood.ensemble.identity

from typing import List, Optional

import pandas as pd

from lightwood.mixer.base import BaseMixer
from lightwood.ensemble.base import BaseEnsemble
from lightwood.api.types import PredictionArguments
from lightwood.data.encoded_ds import EncodedDs


class IdentityEnsemble(BaseEnsemble):
    """
    This ensemble performs no aggregation. User can define an "active mixer" and calling the ensemble will call said mixer.

    Ideal for use cases with single mixers where (potentially expensive) evaluation runs are done internally, as in `BestOf`.
    """  # noqa

    def __init__(self, target, mixers: List[BaseMixer], data: EncodedDs, args: PredictionArguments) -> None:
        """
        Register the mixers, select the first one as active, and prime it with a single-row prediction.

        :param target: forwarded unchanged to ``BaseEnsemble.__init__``.
        :param mixers: candidate mixers; index 0 is the active mixer after construction.
        :param data: encoded dataset; its first row is used to prime the active mixer, so it must contain at least one row.
        :param args: prediction arguments passed to the priming call.
        """  # noqa
        super().__init__(target, mixers, data=data)
        self._active_mixer = 0
        # Build a one-row dataset (first row of ``data``) with the same encoders and target.
        single_row_ds = EncodedDs(data.encoders, data.data_frame.iloc[[0]], data.target)
        _ = self.mixers[self._active_mixer](single_row_ds, args)['prediction']  # prime mixer for storage, needed because NHitsMixer.model (neuralforecast.NHITS) is not serializable without this, oddly enough. Eventually, check this again and remove if possible! # noqa
        self.prepared = True

    def __call__(self, ds: EncodedDs, args: Optional[PredictionArguments] = None) -> pd.DataFrame:
        """
        Delegate prediction to the currently active mixer.

        :param ds: encoded dataset to predict on.
        :param args: prediction arguments forwarded to the mixer (may be ``None``).
        :return: whatever the active mixer returns for ``ds``.
        :raises RuntimeError: if the ensemble has not been prepared.
        """
        # Explicit check instead of ``assert`` so it is not stripped under ``python -O``.
        if not self.prepared:
            raise RuntimeError('IdentityEnsemble must be prepared before it can be called.')
        mixer = self.mixers[self.active_mixer]
        return mixer(ds, args=args)

    @property
    def active_mixer(self) -> int:
        """Index into ``self.mixers`` of the mixer used by ``__call__``."""
        return self._active_mixer

    @active_mixer.setter
    def active_mixer(self, idx: int) -> None:
        """
        Select which mixer ``__call__`` delegates to.

        :param idx: index into ``self.mixers``; must satisfy ``0 <= idx < len(self.mixers)``.
        :raises IndexError: if ``idx`` is out of range.
        """
        # Explicit check instead of ``assert``: under ``-O`` an assert would vanish and a bad
        # (e.g. negative) index would be stored silently, only failing (or wrapping) at call time.
        if not 0 <= idx < len(self.mixers):
            raise IndexError(f'The ensemble has {len(self.mixers)} mixers, please provide a valid index.')
        self._active_mixer = idx