from typing import Dict, List, Tuple, Optional
import numpy as np
from dataprep_ml import StatisticalAnalysis
from lightwood.helpers.log import log
from type_infer.dtype import dtype
from lightwood.ensemble import BaseEnsemble
from lightwood.analysis.base import BaseAnalysisBlock
from lightwood.data.encoded_ds import EncodedDs
from lightwood.encoder.text.pretrained import PretrainedLangEncoder
from lightwood.api.types import ModelAnalysis, ProblemDefinition, PredictionArguments
def model_analyzer(
predictor: BaseEnsemble,
data: EncodedDs,
train_data: EncodedDs,
stats_info: StatisticalAnalysis,
target: str,
pdef: ProblemDefinition,
dtype_dict: Dict[str, str],
accuracy_functions,
ts_analysis: Dict,
        analysis_blocks: Optional[List[BaseAnalysisBlock]] = None
) -> Tuple[ModelAnalysis, Dict[str, object]]:
"""
Analyses model on a validation subset to evaluate accuracy, estimate feature importance and generate a
calibration model to estimating confidence in future predictions.
Additionally, any user-specified analysis blocks (see class `BaseAnalysisBlock`) are also called here.
:return:
runtime_analyzer: This dictionary object gets populated in a sequential fashion with data generated from
any `.analyze()` block call. This dictionary object is stored in the predictor itself, and used when
calling the `.explain()` method of all analysis blocks when generating predictions.
model_analysis: `ModelAnalysis` object that contains core analysis metrics, not necessarily needed when predicting.
"""
    # `analysis_blocks` defaults to None rather than a mutable `[]` default argument
    analysis_blocks = analysis_blocks if analysis_blocks is not None else []

    runtime_analyzer = {}
data_type = dtype_dict[target]
tss = pdef.timeseries_settings
# retrieve encoded data representations
encoded_train_data = train_data
encoded_val_data = data
data = encoded_val_data.data_frame
    input_cols = [col for col in data.columns if col != target]
if not pdef.embedding_only:
# predictive task
is_numerical = data_type in (dtype.integer, dtype.float, dtype.num_tsarray, dtype.quantity)
is_classification = data_type in (dtype.categorical, dtype.binary, dtype.cat_tsarray)
is_multi_ts = tss.is_timeseries and tss.horizon > 1
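        # flag whether any input column uses a pretrained text encoder (passed to the analysis blocks below)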
        has_pretrained_text_enc = any(isinstance(enc, PretrainedLangEncoder)
                                      for enc in encoded_train_data.encoders.values())
# raw predictions for validation dataset
args = {} if not is_classification else {"predict_proba": True}
normal_predictions = None
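        # predictions are only computed when there is at least one analysis block to consume them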
        if analysis_blocks:
if tss.is_timeseries:
# we retrieve the first entry per group (closest to supervision cutoff)
if tss.group_by:
encoded_val_data.data_frame['__mdb_val_idx'] = np.arange(len(encoded_val_data))
idxs = encoded_val_data.data_frame.groupby(by=tss.group_by).first()['__mdb_val_idx'].values
encoded_val_data.data_frame = encoded_val_data.data_frame.iloc[idxs, :]
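                    # keep any pre-built encoded tensor caches aligned with the subsampled frame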
if encoded_val_data.cache_built:
encoded_val_data.X_cache = encoded_val_data.X_cache[idxs, :]
encoded_val_data.Y_cache = encoded_val_data.Y_cache[idxs, :]
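            # raw ensemble predictions, re-indexed to match the (possibly subsampled) validation frame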
normal_predictions = predictor(encoded_val_data, args=PredictionArguments.from_dict(args))
normal_predictions = normal_predictions.set_index(encoded_val_data.data_frame.index)
# ------------------------- #
# Run analysis blocks, both core and user-defined
# ------------------------- #
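        # shared context handed to every block's `.analyze()` call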
kwargs = {
'predictor': predictor,
'target': target,
'input_cols': input_cols,
'dtype_dict': dtype_dict,
'normal_predictions': normal_predictions,
'data': encoded_val_data.data_frame,
'train_data': train_data,
'encoded_val_data': encoded_val_data,
'is_classification': is_classification,
'is_numerical': is_numerical,
'is_multi_ts': is_multi_ts,
'stats_info': stats_info,
'tss': tss,
'ts_analysis': ts_analysis,
'accuracy_functions': accuracy_functions,
'has_pretrained_text_enc': has_pretrained_text_enc
}
for block in analysis_blocks:
log.info("The block %s is now running its analyze() method", block.__class__.__name__)
runtime_analyzer = block.analyze(runtime_analyzer, **kwargs)
# ------------------------- #
# Populate ModelAnalysis object
# ------------------------- #
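    # every field falls back to an empty default, as the blocks that populate them are optional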
model_analysis = ModelAnalysis(
accuracies=runtime_analyzer.get('score_dict', {}),
accuracy_histogram=runtime_analyzer.get('acc_histogram', {}),
accuracy_samples=runtime_analyzer.get('acc_samples', {}),
train_sample_size=len(encoded_train_data),
test_sample_size=len(encoded_val_data),
confusion_matrix=runtime_analyzer.get('cm', []),
column_importances=runtime_analyzer.get('column_importances', {}),
histograms=stats_info.histograms,
dtypes=dtype_dict,
submodel_data=predictor.submodel_data
)
return model_analysis, runtime_analyzer
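

# ------------------------- #
# Illustrative sketch (hypothetical, not part of the original module): a minimal
# custom analysis block, assuming only the `BaseAnalysisBlock.analyze()` contract
# used above, i.e. blocks receive the running `runtime_analyzer` dict plus the
# shared kwargs, and return the (possibly extended) dict.
# ------------------------- #
class ExampleValSizeBlock(BaseAnalysisBlock):
    def analyze(self, info: Dict[str, object], **kwargs) -> Dict[str, object]:
        # record how many validation rows the analysis ran on
        info['n_val_rows'] = len(kwargs['normal_predictions'])
        return info

# Passing `analysis_blocks=[ExampleValSizeBlock()]` to `model_analyzer` would then
# expose `n_val_rows` in the returned `runtime_analyzer`.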