import time
from copy import deepcopy
from typing import Dict, Optional, List

import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader
from type_infer.dtype import dtype

from lightwood.api.types import PredictionArguments, TimeseriesSettings
from lightwood.encoder.base import BaseEncoder
from lightwood.data.encoded_ds import EncodedDs, ConcatedEncodedDs
from lightwood.mixer.neural import Neural
from lightwood.mixer.helpers.ar_net import ArNet
from lightwood.mixer.helpers.default_net import DefaultNet


class NeuralTs(Neural):
    def __init__(
            self,
            stop_after: float,
            target: str,
            dtype_dict: Dict[str, str],
            timeseries_settings: TimeseriesSettings,
            target_encoder: BaseEncoder,
            net: str,
            fit_on_dev: bool,
            search_hyperparameters: bool,
            ts_analysis: Dict[str, Dict],
            n_epochs: Optional[int] = None,
            use_stl: Optional[bool] = False
    ):
"""
Subclassed Neural mixer used for time series forecasting scenarios.
:param stop_after: How long the total fitting process should take
:param target: Name of the target column
:param dtype_dict: Data type dictionary
:param timeseries_settings: TimeseriesSettings object for time-series tasks, refer to its documentation for available settings.
:param target_encoder: Reference to the encoder used for the target
:param net: The network type to use (`DeafultNet` or `ArNet`)
:param fit_on_dev: If we should fit on the dev dataset
:param search_hyperparameters: If the network should run a more through hyperparameter search (currently disabled)
:param n_epochs: amount of epochs that the network will be trained for. Supersedes all other early stopping criteria if specified.
""" # noqa
        super().__init__(
            stop_after,
            target,
            dtype_dict,
            target_encoder,
            net,
            fit_on_dev,
            search_hyperparameters,
            n_epochs,
        )
        self.timeseries_settings = timeseries_settings
        assert self.timeseries_settings.is_timeseries
        self.ts_analysis = ts_analysis
        self.net_class = DefaultNet if net == 'DefaultNet' else ArNet
        self.stable = True
        self.use_stl = use_stl

    def _select_criterion(self) -> torch.nn.Module:
        if self.dtype_dict[self.target] in (dtype.integer, dtype.float, dtype.num_tsarray, dtype.quantity):
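            # mean absolute error for numerical targets: less sensitive to outliers than MSE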
            criterion = nn.L1Loss()
        else:
            criterion = super()._select_criterion()
        return criterion

    def _fit(self, train_data: EncodedDs, dev_data: EncodedDs) -> None:
        """
        :param train_data: The network is fit/trained on this
        :param dev_data: Data used for early stopping and hyperparameter determination
        """  # noqa
        self.started = time.time()
        original_train = deepcopy(train_data.data_frame)
        original_dev = deepcopy(dev_data.data_frame)

        # train_data is usually a ConcatedEncodedDs; use ~10% of its rows per batch, clamped to [40, 200]
        self.batch_size = min(200, int(len(train_data) / 10))
        self.batch_size = max(40, self.batch_size)
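
        # batches are drawn in order (shuffle=False), preserving the temporal ordering of the encoded rows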
        dev_dl = DataLoader(dev_data, batch_size=self.batch_size, shuffle=False)
        train_dl = DataLoader(train_data, batch_size=self.batch_size, shuffle=False)

        self.lr = 1e-4
        self.num_hidden = 1

        # Find a suitable learning rate by trying out several candidates,
        # keeping the weights of the best-performing run
        self._init_net(train_data)
        self.lr, self.model = self._find_lr(train_data)

        # Keep on training
        optimizer = self._select_optimizer(self.model, lr=self.lr)
        criterion = self._select_criterion()
        scaler = GradScaler()

        # Only 0.8 of the remaining time budget is used to allow some time for the final tuning and partial fit
        self.model, epoch_to_best_model, _ = self._max_fit(
            train_dl, dev_dl, criterion, optimizer, scaler, (self.stop_after - (time.time() - self.started)) * 0.8,
            return_model_after=20000)
        self.epochs_to_best += epoch_to_best_model

        # restore the original dataframes
        train_data.data_frame = original_train
        dev_data.data_frame = original_dev

        if self.fit_on_dev:
            self.partial_fit(dev_data, train_data)

    def fit(self, train_data: EncodedDs, dev_data: EncodedDs) -> None:
        self._fit(train_data, dev_data)

    def __call__(self, ds: EncodedDs,
                 args: PredictionArguments = PredictionArguments()
                 ) -> pd.DataFrame:
        original_df = deepcopy(ds.data_frame)
        self.model = self.model.eval()
        decoded_predictions = []
        all_probs: List[List[float]] = []
        rev_map = {}
        length = sum(ds.encoded_ds_lengths) if isinstance(ds, ConcatedEncodedDs) else len(ds)
        pred_cols = [f'prediction_{i}' for i in range(self.timeseries_settings.horizon)]
        ydf = pd.DataFrame(0,  # zero-filled
                           index=np.arange(length),
                           dtype=object,
                           columns=pred_cols)
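
        # predict row by row; for multi-step forecasts, each sample decodes into `horizon` values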
        with torch.no_grad():
            for idx, (X, Y) in enumerate(ds):
                X = X.to(self.model.device)
                Yh = self.model(X)
                Yh = torch.unsqueeze(Yh, 0) if len(Yh.shape) < 2 else Yh
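
                # pass any decoder dependencies (e.g. group-by column values) so the target
                # encoder can decode with the correct per-series context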
                kwargs = {}
                for dep in self.target_encoder.dependencies:
                    kwargs['dependency_data'] = {dep: ds.data_frame.iloc[idx][[dep]].values}

                if args.predict_proba and self.supports_proba:
                    decoded_prediction, probs, rev_map = self.target_encoder.decode_probabilities(Yh, **kwargs)
                    all_probs.append(probs)
                else:
                    decoded_prediction = self.target_encoder.decode(Yh, **kwargs)

                decoded_predictions.extend(decoded_prediction)
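
        # stack the decoded forecasts into one column per forecast step, then collapse them
        # into a single list-valued `prediction` column (unwrapped to a scalar if horizon == 1)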
        decoded_predictions = np.array(decoded_predictions)
        if len(decoded_predictions.shape) == 1:
            decoded_predictions = np.expand_dims(decoded_predictions, axis=1)
        ydf[pred_cols] = decoded_predictions
        ydf['prediction'] = ydf.values.tolist()
        if self.timeseries_settings.horizon == 1:
            ydf['prediction'] = [p[0] for p in ydf['prediction']]

        if args.predict_proba and self.supports_proba:
            raw_predictions = np.array(all_probs).squeeze(axis=1)
            for idx, label in enumerate(rev_map.values()):
                ydf[f'__mdb_proba_{label}'] = raw_predictions[:, idx]

        # TODO: make this part of the base mixer class? to avoid repetitive code
        # and ensure other contribs don't accidentally modify the df
        ds.data_frame = original_df
        return ydf[['prediction']]
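

# A minimal usage sketch (hypothetical, for illustration): in practice this mixer is
# constructed, fit and called by lightwood's generated predictor code, with the target
# encoder, TimeseriesSettings, `ts_analysis` dict and encoded datasets all produced by
# earlier pipeline stages. Every name below (`ts_settings`, `target_enc`, `ts_analysis`,
# `train_ds`, `dev_ds`) is an assumption, not part of this module.
#
# mixer = NeuralTs(
#     stop_after=3600.0,                  # time budget, in seconds
#     target='sales',
#     dtype_dict={'sales': dtype.num_tsarray},
#     timeseries_settings=ts_settings,    # e.g. TimeseriesSettings with is_timeseries=True, horizon=3
#     target_encoder=target_enc,          # fitted encoder for the target column
#     net='DefaultNet',
#     fit_on_dev=True,
#     search_hyperparameters=False,
#     ts_analysis=ts_analysis,            # output of the time series analysis phase
# )
# mixer.fit(train_ds, dev_ds)             # both are EncodedDs / ConcatedEncodedDs instances
# forecasts = mixer(dev_ds)               # DataFrame with a list-valued `prediction` column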