Fits a KAF model for regression (using mean squared error loss) or for binary or multiclass classification.
Usage
kaf_fit(
x,
y,
task = c("auto", "regression", "binary", "multiclass"),
hidden = c(64, 64),
num_grids = 16,
dropout = 0,
use_layernorm = TRUE,
fourier_init_scale = 0.01,
epochs = 1000,
lr = 0.001,
batch_size = NULL,
shuffle = TRUE,
validation_split = 0,
x_val = NULL,
y_val = NULL,
weight_decay = 0,
standardize_x = TRUE,
standardize_y = NULL,
patience = NULL,
verbose = TRUE,
print_every = 100,
seed = NULL,
restore_best = TRUE,
min_delta = 0
)
Arguments
- x
Matrix, data frame, vector, or 2D torch tensor of predictors.
- y
Vector, matrix, data frame, or torch tensor of targets.
- task
Character. One of "auto", "regression", "binary", or "multiclass". With "auto", factor, character, and logical targets are treated as classification; numeric targets are treated as regression.
- hidden
Integer vector. Hidden layer sizes.
- num_grids
Integer. Number of Fourier frequencies per KAF layer.
- dropout
Numeric. Dropout probability.
- use_layernorm
Logical. Whether to apply layer normalization.
- fourier_init_scale
Numeric. Initial scale of the Fourier branch.
- epochs
Integer. Maximum number of training epochs.
- lr
Numeric. Learning rate.
- batch_size
Optional integer. Mini-batch size. If NULL, full-batch training is used.
- shuffle
Logical. Whether to shuffle training rows each epoch.
- validation_split
Numeric in [0, 1). Fraction of rows to reserve for validation. Ignored if x_val and y_val are supplied.
- x_val
Optional validation predictors.
- y_val
Optional validation targets.
- weight_decay
Numeric. Adam weight decay.
- standardize_x
Logical. Whether to standardize predictors using the training-set mean and standard deviation.
- standardize_y
Logical or NULL. Whether to standardize regression targets using the training-set mean and standard deviation. If NULL, targets are standardized for regression and not standardized for classification. When targets are standardized, predictions are automatically transformed back to the original target scale.
- patience
Optional integer. Number of consecutive epochs without improvement before training stops early. If NULL, early stopping is disabled.
- verbose
Logical. Whether to print progress.
- print_every
Integer. When verbose = TRUE, progress is printed every print_every epochs.
- seed
Optional integer random seed.
- restore_best
Logical. Whether to restore the best observed model state after training.
- min_delta
Numeric. Minimum loss improvement required to update the best model state.
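The task = "auto" rule described above can be sketched in base R as follows. This is a minimal illustration, not the package's code: infer_task() is a hypothetical helper, and the split between "binary" (at most two distinct values) and "multiclass" (more than two) is an assumption, since the documentation only states which target types count as classification.

```r
# Hypothetical helper mirroring the documented task = "auto" rule.
# Assumption: classification targets with at most two distinct values are
# "binary"; targets with more than two distinct values are "multiclass".
infer_task <- function(y) {
  if (is.factor(y) || is.character(y) || is.logical(y)) {
    # Count observed classes, ignoring missing values
    n_classes <- length(unique(y[!is.na(y)]))
    if (n_classes <= 2) "binary" else "multiclass"
  } else if (is.numeric(y)) {
    "regression"
  } else {
    stop("Unsupported target type: ", class(y)[1])
  }
}

infer_task(c(1.5, 2.3, 0.7))               # "regression"
infer_task(c(TRUE, FALSE, TRUE))           # "binary"
infer_task(factor(c("a", "b", "c", "a")))  # "multiclass"
```

Note that integer vectors are numeric in R, so integer-coded class labels would be treated as regression under this rule; converting them with factor() first makes the intent explicit.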
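To make the interaction of batch_size and shuffle concrete, the sketch below builds the row indices used in one epoch. make_batches() is a hypothetical helper, and letting the last batch be smaller than batch_size is an assumption about how a remainder is handled.

```r
# Hypothetical helper: row indices for each mini-batch in one epoch.
# batch_size = NULL means a single full batch, as documented.
make_batches <- function(n, batch_size = NULL, shuffle = TRUE) {
  # Reorder rows at random when shuffle = TRUE
  idx <- if (shuffle) sample.int(n) else seq_len(n)
  if (is.null(batch_size) || batch_size >= n) {
    return(list(idx))
  }
  # Group consecutive positions into batches of batch_size rows;
  # assumption: the last batch keeps any remainder and may be smaller.
  split(idx, ceiling(seq_along(idx) / batch_size))
}

set.seed(1)
batches <- make_batches(10, batch_size = 4)
unname(lengths(batches))  # batch sizes 4, 4, 2
```

Because shuffling happens inside the helper, calling it once per epoch gives a fresh row order each time, which is what shuffle = TRUE describes.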
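The validation_split argument reserves a fraction of the rows for validation. A minimal sketch, assuming rows are drawn at random and the validation size is floor(validation_split * n); split_validation() is hypothetical and the package may round or sample differently.

```r
# Hypothetical helper: reserve a fraction of rows for validation.
# Assumption: validation rows are sampled at random and the validation
# set has floor(validation_split * n) rows.
split_validation <- function(n, validation_split = 0) {
  stopifnot(validation_split >= 0, validation_split < 1)
  n_val <- floor(validation_split * n)
  val_idx <- if (n_val > 0) sample.int(n, n_val) else integer(0)
  # Training rows are everything not reserved for validation
  list(train = setdiff(seq_len(n), val_idx), val = val_idx)
}

set.seed(42)
s <- split_validation(100, validation_split = 0.25)
c(length(s$train), length(s$val))  # 75 25
```

As documented, this split would only apply when x_val and y_val are not supplied.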
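standardize_x and standardize_y use training-set statistics, and standardized predictions are mapped back to the original scale. The sketch below shows that round trip in base R. fit_scaler(), apply_scaler(), and invert_scaler() are hypothetical names; using the sample standard deviation from sd() and replacing a zero standard deviation with 1 are assumptions.

```r
# Hypothetical sketch of training-set standardization and its inverse.
# Means and standard deviations come from the training rows only; new or
# validation data are transformed with these same statistics.
fit_scaler <- function(x) {
  x <- as.matrix(x)
  means <- colMeans(x)
  sds <- apply(x, 2, sd)   # assumption: sample sd (n - 1 denominator)
  sds[sds == 0] <- 1       # assumption: guard against constant columns
  list(means = means, sds = sds)
}
apply_scaler <- function(x, s) {
  # Subtract column means, then divide by column standard deviations
  sweep(sweep(as.matrix(x), 2, s$means, "-"), 2, s$sds, "/")
}
invert_scaler <- function(z, s) {
  # Undo the transform: multiply by sds, then add the means back
  sweep(sweep(as.matrix(z), 2, s$sds, "*"), 2, s$means, "+")
}

y_train <- c(10, 20, 30, 40)
s <- fit_scaler(y_train)
z <- apply_scaler(y_train, s)    # the model would be trained on z
y_back <- invert_scaler(z, s)    # predictions mapped back to the y scale
```

Fitting the statistics on the training rows alone keeps information from validation data out of the preprocessing step.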
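The patience, min_delta, and restore_best arguments together describe an early-stopping rule. The sketch below replays that rule over a vector of per-epoch losses. early_stop() is hypothetical, and several details are assumptions: which loss is monitored (validation or training), that an improvement must exceed min_delta strictly, and that restore_best would restore the state from best_epoch.

```r
# Hypothetical sketch of patience / min_delta / restore_best logic over
# per-epoch monitored losses (which loss is monitored is an assumption).
early_stop <- function(losses, patience = NULL, min_delta = 0) {
  best_loss <- Inf
  best_epoch <- NA_integer_
  wait <- 0
  for (epoch in seq_along(losses)) {
    if (losses[epoch] < best_loss - min_delta) {
      # Improvement larger than min_delta: update the best state
      best_loss <- losses[epoch]
      best_epoch <- epoch
      wait <- 0
    } else {
      # No sufficient improvement: count toward patience
      wait <- wait + 1
      if (!is.null(patience) && wait >= patience) {
        return(list(stopped_at = epoch, best_epoch = best_epoch))
      }
    }
  }
  # patience = NULL (or never exhausted): run all epochs
  list(stopped_at = length(losses), best_epoch = best_epoch)
}

losses <- c(1.00, 0.80, 0.79, 0.795, 0.81, 0.82, 0.70)
early_stop(losses, patience = 3, min_delta = 0.05)
# stops at epoch 5; with restore_best = TRUE the epoch-2 state would be kept
```

With min_delta = 0 the drop from 0.80 to 0.79 counts as an improvement, so the same losses stop one epoch later with epoch 3 as the best state.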