TabNet: a first look

https://arxiv.org/abs/1908.07442

With the growth of neural networks, DNNs have shown strong performance on text, image, and audio data.
However, on everyday tabular data we have yet to see success comparable to XGBoost/LightGBM.

Tree based models have advantages in:

  • decision manifolds that act like hyperplane boundaries (arbitrarily expandable, partition tabular data well)
  • interpretable
  • fast to train

DNNs have advantages in:

  • able to encode data (representation learning)
  • reduce feature engineering
  • online learning

Discussion: why DNNs underperform on tabular data

  • highly non-linear, with few constraints on the learned function, so they overfit more easily than tree ensembles
  • adding more layers causes overparameterization, which may be why DNNs underperform on tabular data

It would be ideal to have a framework for tabular data that is end-to-end, learns representations, and supports online updates.

Building decision boundaries with a DNN

(figures from the paper: the decision manifold of a tree vs. a DNN built from mask + FC + ReLU)

we can treat the mask as a decision boundary

the mask + FC + ReLU block behaves like a vanilla decision tree
this is an additive model: each output represents the weight of one condition contributing to the final decision
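As a toy illustration (my own sketch, not code from the paper), a mask that selects one feature, followed by an FC layer and a ReLU, reproduces a single tree split; summing such outputs gives the additive model:

```python
import numpy as np

# Hypothetical sketch: mask + FC + ReLU emulating one decision-tree split
# on feature x1 at threshold a (names and values are illustrative).
def masked_condition(x, mask, w, b):
    # mask selects features, the FC encodes a linear condition,
    # ReLU keeps only the region where the condition holds
    return max(w * float((mask * x).sum()) + b, 0.0)

x = np.array([2.0, -1.0])        # one sample with features [x1, x2]
mask = np.array([1.0, 0.0])      # attend only to x1
a = 0.5
left  = masked_condition(x, mask,  1.0, -a)  # > 0 iff x1 > a
right = masked_condition(x, mask, -1.0,  a)  # > 0 iff x1 < a
decision = left - right          # additive aggregation of condition weights
```

Here only one of the two complementary conditions fires for any given sample, just like the two branches of a split.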

TabNet model structure

this is a more elaborate additive model (input of shape batch × features → a single output vector per sample)

Layers

  • BN: batch normalization
  • Feature transformer: FC layers that compute the feature embedding
    • FC + GLU gating
    • split into shared and step-dependent blocks
    • h(X) = (W*X+b) \odot \sigma(V*X+c), where \odot is element-wise multiplication
  • Split:
    • [d[i],a[i]] = f_i(M[i] \odot f)
    • d[i] feeds the output, a[i] is used to compute the next step's mask
  • Attentive transformer
    • calculates the mask
    • Sparsemax (like a sparse softmax: it can output exact zeros)
    • M[i] = Sparsemax(P[i-1] \cdot h_i(a[i-1]))
      • a[i-1] comes from the previous step's Split
      • h_i() is FC + BN
      • P[i] is the prior scale, P[i] = \prod_{j=1}^{i}(\gamma - M[j]); it tracks how much each feature has already been used in previous steps and down-weights those features
      • Loss: L_{sparse} = \sum_{i=1}^{N_{steps}} \sum_{b=1}^{B} \sum_{j=1}^{D} \frac{-M_{b,j}[i]}{N_{steps} \cdot B} \log(M_{b,j}[i] + \epsilon), an entropy-style penalty that pushes mask entries toward 0 or 1, keeping M sparse
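A minimal numpy sketch (my own, not the paper's code) of sparsemax and the prior-scale update, showing how a feature that dominates the step-1 mask gets down-weighted at step 2:

```python
import numpy as np

def sparsemax(z):
    """Sparsemax: Euclidean projection onto the probability simplex.
    Unlike softmax, it can assign exact zeros to low-scoring entries."""
    z_sorted = np.sort(z)[::-1]
    k = np.arange(1, len(z) + 1)
    cssv = np.cumsum(z_sorted)
    support = 1 + k * z_sorted > cssv        # entries kept in the support
    k_max = k[support][-1]
    tau = (cssv[k_max - 1] - 1) / k_max      # threshold
    return np.maximum(z - tau, 0.0)

# Toy attentive-transformer step: gamma and the logits are made up.
gamma = 1.3
logits = np.array([2.0, 1.0, 0.1])           # stand-in for h_i(a[i-1])
M1 = sparsemax(logits)                       # step-1 mask: all weight on feature 0
P1 = gamma - M1                              # prior scales after step 1
M2 = sparsemax(P1 * logits)                  # step-2 mask: feature 0 is damped
```

With these numbers, M1 puts all its mass on feature 0, so P1 shrinks that feature's logit at step 2 and the mask shifts toward feature 1.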

Overall: TabNet uses a sequential multi-step design to construct an additive NN framework

  • Attentive transformer
    • uses the previous step's result to compute the current step's mask, while keeping the mask sparse and non-repetitive
    • different data points get different masks, so each sample can attend to different features (instance-wise selection; trees split the whole batch on the same features)
  • Feature transformer
    • processes the features selected at the current step (the selection itself comes from the attentive transformer's mask)
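The multi-step loop above can be sketched as follows (a heavy simplification of mine: random linear maps stand in for the transformers, no GLU, and a ReLU-plus-normalize stand-in replaces sparsemax):

```python
import numpy as np

rng = np.random.default_rng(0)
D, Nd, Na, steps, gamma = 4, 3, 3, 2, 1.3   # hypothetical sizes
x = rng.normal(size=D)                      # one normalized input row
W_att = rng.normal(size=(Na, D))            # attentive FC (shared across steps here)
W_ft = rng.normal(size=(D, Nd + Na))        # feature transformer (no GLU in this sketch)

def simple_sparse_norm(z):
    """Stand-in for sparsemax: ReLU then L1-normalize (illustrative only)."""
    p = np.maximum(z, 0.0)
    return p / p.sum() if p.sum() > 0 else np.full_like(p, 1.0 / len(p))

prior = np.ones(D)                          # P[0]: no feature used yet
a = np.zeros(Na)
out = np.zeros(Nd)
for i in range(steps):
    mask = simple_sparse_norm(prior * (a @ W_att))  # attentive transformer
    prior = prior * (gamma - mask)                  # update prior scales
    h = (mask * x) @ W_ft                           # feature transformer on masked input
    d, a = h[:Nd], h[Nd:]                           # Split
    out = out + np.maximum(d, 0.0)                  # additive aggregation of d[i]
```

The point of the sketch is the data flow: mask → masked features → split into (d, a) → accumulate d, feed a forward.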

TabNet evaluation

Paper data/evaluation:

Performs similarly to, or better than, GBM-based models

private data: offer-activation prediction

tabular data containing customer profiles and sales records

LGBM ROC AUC: 0.83
TabNet: 0.80

TabNet required roughly 100x the machine cost to match LGBM's training speed

Kaggle

data: gene expression

https://www.kaggle.com/c/lish-moa/data?select=train_features.csv

target: whether the sample had a positive response for each MoA target

All top solutions use TabNet in their ensembles, but TabNet does not carry a large weight in them.

Conclusion

TabNet shows real promise for applying NNs to tabular data. Although its performance is still not quite on par with GBM models, it can serve as the online-learning component in mixed ensembles for tabular data.

Comments