Source code for lightwood.mixer.tabtransformer

from typing import Dict, Optional

import torch
from tab_transformer_pytorch import TabTransformer

from lightwood.helpers.device import get_device_from_name
from lightwood.data.encoded_ds import EncodedDs
from lightwood.encoder.base import BaseEncoder
from lightwood.mixer.neural import Neural


class TabTransformerMixer(Neural):
    def __init__(
            self,
            stop_after: float,
            target: str,
            dtype_dict: Dict[str, str],
            target_encoder: BaseEncoder,
            fit_on_dev: bool,
            search_hyperparameters: bool,
            train_args: Optional[dict] = None
    ):
        """
        This mixer trains a TabTransformer network (FT variant), using concatenated encoder outputs for
        each dataset feature as input, to predict the encoded target column representation as output.

        Training logic is based on the Neural mixer; refer to it for more details on each input parameter.
        """  # noqa
        self.train_args = train_args if train_args else {}
        super().__init__(
            stop_after,
            target,
            dtype_dict,
            target_encoder,
            'FTTransformer',
            False,  # fit_on_dev is ignored; always off for this mixer
            search_hyperparameters,
            n_epochs=self.train_args.get('n_epochs', None)
        )
        self.lr = self.train_args.get('lr')
        self.stable = False  # still experimental

    def _init_net(self, ds: EncodedDs):
        self.net_class = TabTransformer
        self.model = TabTransformer(
            categories=(),  # unused, everything is numerical by now
            num_continuous=len(ds[0][0]),
            dim=self.train_args.get('dim', 32),
            dim_out=self.train_args.get('dim_out', len(ds[0][1])),
            depth=self.train_args.get('depth', 6),
            heads=self.train_args.get('heads', 8),
            attn_dropout=self.train_args.get('attn_dropout', 0.1),  # post-attention dropout
            ff_dropout=self.train_args.get('ff_dropout', 0.1),  # feed-forward dropout
            mlp_hidden_mults=self.train_args.get('mlp_hidden_mults', (4, 2)),  # relative multiples of each hidden dimension of the last mlp to logits  # noqa
            # mlp_act=self.train_args.get('mlp_act', nn.ReLU()),  # TODO: import str from nn activations
        )
        self.model.device = get_device_from_name('')
        self.model.to(self.model.device)

    def _net_call(self, x: torch.Tensor) -> torch.Tensor:
        # Ensure a batch dimension, then call TabTransformer with an empty
        # categorical tensor (all features arrive as numerical encodings).
        x = torch.unsqueeze(x, 0) if len(x.shape) < 2 else x
        return self.model(torch.Tensor(), x)

    def fit(self, train_data: EncodedDs, dev_data: EncodedDs) -> None:
        """
        Skip the usual partial_fit call at the end.
        """  # noqa
        self._fit(train_data, dev_data)
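

if __name__ == '__main__':
    # Minimal standalone sketch (an editorial addition, not part of the original
    # module) of the call convention used in `_net_call` above: TabTransformer
    # takes a categorical tensor and a continuous tensor, and this mixer always
    # passes an empty categorical tensor because every feature is numerically
    # encoded. The shapes and hyperparameters below are illustrative assumptions.
    demo = TabTransformer(
        categories=(),      # no categorical columns, mirroring _init_net
        num_continuous=16,  # stands in for len(ds[0][0])
        dim=32,
        dim_out=4,          # stands in for len(ds[0][1])
        depth=6,
        heads=8,
    )
    x_cont = torch.randn(8, 16)         # batch of 8 rows, 16 encoded features
    out = demo(torch.Tensor(), x_cont)  # empty categorical input -> shape (8, 4)
    print(out.shape)

    # Wiring the mixer itself requires a fitted target encoder and EncodedDs
    # objects, normally assembled by Lightwood from a JsonAI spec; a
    # hypothetical sketch (all names below are placeholders):
    #
    #   mixer = TabTransformerMixer(
    #       stop_after=60.0, target='price', dtype_dict=dtype_dict,
    #       target_encoder=target_encoder, fit_on_dev=False,
    #       search_hyperparameters=False, train_args={'dim': 64, 'depth': 4},
    #   )
    #   mixer.fit(train_ds, dev_ds)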