MABAlgo
- class gobrec.mabs.mab_algo.MABAlgo(seed: int | None = None)
Abstract class for Multi-Armed Bandit algorithms.
This class defines the interface that MAB algorithms must implement to be used in the GOBRec recommender system. It also provides common label-encoding functionality that maps item IDs to integer indices.
- Attributes:
- seed : int | None
Random seed for reproducibility.
- rng : np.random.Generator
Random number generator, initialized with the provided seed.
- label_encoder : LabelEncoder
A label encoder to convert item IDs to integer indices and vice versa.
- num_arms : int
The number of unique arms (items) known to the algorithm. It starts as None and is updated when fitting with new item IDs.
- num_features : int
The number of features in the context vectors. It is None until the first call to fit, which sets it; subsequent fits assert that the context vectors have the same number of features.
Methods
- fit(contexts, decisions, rewards)
Fit the MAB algorithm with the provided contexts, item IDs, and rewards.
- predict(contexts)
Predict the expected rewards for the given contexts.
- reset()
Reset the MAB algorithm to its initial state.
- abstract fit(contexts: ndarray, decisions: ndarray, rewards: ndarray)
Fit the MAB algorithm with the provided contexts, item IDs, and rewards.
- Parameters:
- contexts : np.ndarray
A 2D array of shape (n_samples, n_features) in which each row is the context vector of one sample.
- decisions : np.ndarray
A 1D array of shape (n_samples,) containing the item ID (arm, or decision) associated with each sample. Elements can be strings or integers.
- rewards : np.ndarray
A 1D array of shape (n_samples,) containing the observed rewards (ratings). Values can be integers or floats.
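The three arrays expected by `fit` can be assembled from an interaction log as sketched below. The record layout `(context, item_id, reward)` and the helper `to_fit_arrays` are hypothetical; the checks only restate the documented shapes (one context row, one item ID, and one reward per sample).

```python
import numpy as np


def to_fit_arrays(log):
    """Hypothetical helper: turn (context, item_id, reward) records into fit() inputs."""
    contexts = np.asarray([c for c, _, _ in log], dtype=np.float64)  # (n_samples, n_features)
    decisions = np.asarray([d for _, d, _ in log])                   # (n_samples,), str or int IDs
    rewards = np.asarray([r for _, _, r in log], dtype=np.float64)   # (n_samples,)
    if contexts.ndim != 2:
        raise ValueError("contexts must be 2D with shape (n_samples, n_features)")
    if not (len(contexts) == len(decisions) == len(rewards)):
        raise ValueError("contexts, decisions and rewards must share n_samples")
    return contexts, decisions, rewards


# Hypothetical log with string item IDs and binary rewards.
log = [
    ([0.1, 0.5, 1.0], "item_42", 1),
    ([0.3, 0.2, 0.0], "item_7", 0),
    ([0.9, 0.4, 0.5], "item_42", 1),
]
contexts, decisions, rewards = to_fit_arrays(log)
# model.fit(contexts, decisions, rewards)  # model: any concrete MABAlgo subclass
```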
- abstract predict(contexts: ndarray) → Tensor
Predict the expected rewards for the given contexts.
- Parameters:
- contexts : np.ndarray
A 2D array of shape (n_samples, n_features) in which each row is a context vector.
- Returns:
- expected_rewards : torch.Tensor
A 2D tensor of shape (n_samples, n_arms) containing the expected reward of each arm (item) for each context. Columns are indexed by the encoded item IDs, not the original ones; to recover the original item IDs, use the label_encoder.inverse_transform method.
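Since the columns of the returned matrix are encoded arm indices, turning scores into recommendations takes two steps: rank the columns per row, then map the winning indices back to item IDs. The sketch below assumes a hypothetical `classes` array in which `classes[i]` is the original ID for encoded index `i`, which is the lookup `label_encoder.inverse_transform` performs. It works on NumPy data; a CPU tensor converts through `np.asarray`, while a GPU tensor should first be moved with `.cpu().numpy()`.

```python
import numpy as np


def top_k_items(expected_rewards, classes, k):
    """Return the original IDs of the k highest-scoring arms for each sample.

    expected_rewards: shape (n_samples, n_arms); column i is encoded arm i.
    classes: hypothetical array mapping encoded index -> original item ID.
    """
    scores = np.asarray(expected_rewards)
    # Sorting the negated scores gives descending order; "stable" keeps ties in index order.
    top_idx = np.argsort(-scores, axis=1, kind="stable")[:, :k]
    # Fancy indexing maps every encoded index back to its item ID, shape (n_samples, k).
    return np.asarray(classes)[top_idx]
```

When staying in PyTorch, `torch.topk(scores, k, dim=1)` returns the same top indices directly. If `label_encoder` behaves like scikit-learn's `LabelEncoder`, whose `inverse_transform` expects a 1D array, flatten the index matrix first and reshape the result back to `(n_samples, k)`.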