from typing import List, Optional
import pandas as pd
from lightwood.mixer.base import BaseMixer
from lightwood.data.encoded_ds import EncodedDs
from lightwood.api.types import PredictionArguments, SubmodelData
[docs]class BaseEnsemble:
"""
Base class for all ensembles.
Ensembles wrap sets of Lightwood mixers, with the objective of generating better predictions based on the output of each mixer.
There are two important methods for any ensemble to work:
1. `__init__()` should prepare all mixers and internal ensemble logic.
2. `__call__()` applies any aggregation rules to generate final predictions based on the output of each mixer.
Class Attributes:
- mixers: List of mixers the ensemble will use.
- supports_proba: For classification tasks, whether the ensemble supports yielding per-class scores rather than only returning the predicted label.
NOTE: this ensemble is not functional. Do not use it when generating custom JsonAI objects, as the learning process will fail.
""" # noqa
data: EncodedDs
mixers: List[BaseMixer]
best_index: int # @TODO: maybe only applicable to BestOf
supports_proba: bool
prepared: bool
submodel_data: List[SubmodelData]
def __init__(self, target, mixers: List[BaseMixer], data: EncodedDs, fit: Optional[bool] = True) -> None:
self.data = data
self.mixers = mixers
self.best_index = 0
self.supports_proba = False
self.submodel_data = []
def __call__(self, ds: EncodedDs, args: PredictionArguments) -> pd.DataFrame:
raise NotImplementedError()