ModelMixin

class ModelMixin

Bases: object

Mixin for perturbation models.

abstract fit(traindata, valdata=None)

Fit the model to the training data, optionally using validation data.

Parameters:
  • traindata (PlibData[DataFrame]) – Training data.

  • valdata (Optional[PlibData[DataFrame]]) – Validation data.
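As an illustration of how a concrete model might implement `fit`, the sketch below uses a hypothetical stand-in base class `ModelMixinSketch` that mirrors the two abstract methods documented here, plus a hypothetical `MeanBaseline` model. It assumes plain pandas DataFrames in place of `PlibData[DataFrame]` and a label column named "values"; the real `PlibData` wrapper is not modelled.

```python
from abc import ABC, abstractmethod
from typing import Optional

import numpy as np
import pandas as pd


class ModelMixinSketch(ABC):
    """Hypothetical stand-in mirroring the abstract part of ModelMixin."""

    @abstractmethod
    def fit(self, traindata: pd.DataFrame, valdata: Optional[pd.DataFrame] = None) -> None:
        ...

    @abstractmethod
    def predict(self, data_x: pd.DataFrame, batch_size: Optional[int] = None) -> np.ndarray:
        ...


class MeanBaseline(ModelMixinSketch):
    """Hypothetical model: predicts the training mean of the "values" column."""

    def __init__(self) -> None:
        self.mean_: Optional[float] = None
        self.val_mse_: Optional[float] = None

    def fit(self, traindata, valdata=None):
        # Learn the single parameter from the label column.
        self.mean_ = float(traindata["values"].mean())
        # Validation data is optional; if given, record the validation MSE.
        if valdata is not None:
            preds = self.predict(valdata.drop(columns="values"))
            self.val_mse_ = float(np.mean((preds - valdata["values"].to_numpy()) ** 2))

    def predict(self, data_x, batch_size=None):
        # batch_size is accepted but ignored: this model predicts all rows at once.
        if self.mean_ is None:
            raise RuntimeError("call fit() before predict()")
        return np.full(len(data_x), self.mean_)


# Usage: fit on a toy training set and score it on a toy validation set.
train = pd.DataFrame({"feature": [0, 1, 2], "values": [1.0, 2.0, 3.0]})
val = pd.DataFrame({"feature": [3, 4], "values": [2.0, 4.0]})
model = MeanBaseline()
model.fit(train, valdata=val)
print(model.predict(val.drop(columns="values")))  # [2. 2.]
```

Passing `valdata=None` simply skips the validation step, which is one reasonable reading of the optional parameter.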

load_state(model_state)

Restore the model from the given state.

Parameters:
  • model_state – State of the model to restore.

abstract predict(data_x, batch_size=None)

Predict values for the given data.

Parameters:
  • data_x (PlibData[DataFrame]) – Data without labels, i.e. without the “values” column.

  • batch_size (Optional[int]) – Batch size for prediction. Some models may not support this option.

Return type:

numpy.typing.NDArray

Returns:

Value predictions.
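One way a model could honour `batch_size` is to predict over row chunks and concatenate the results. The sketch below is an assumption, not the library's implementation: `predict_in_batches` and the stand-in `double` model are hypothetical, and a plain pandas DataFrame stands in for `PlibData[DataFrame]`.

```python
from typing import Callable, Optional

import numpy as np
import pandas as pd


def predict_in_batches(
    predict_fn: Callable[[pd.DataFrame], np.ndarray],
    data_x: pd.DataFrame,
    batch_size: Optional[int] = None,
) -> np.ndarray:
    """Apply predict_fn to data_x, optionally in chunks of batch_size rows."""
    if batch_size is None:
        # No batching requested: predict all rows in one call.
        return predict_fn(data_x)
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")
    chunks = [
        predict_fn(data_x.iloc[start:start + batch_size])
        for start in range(0, len(data_x), batch_size)
    ]
    # Concatenate per-batch predictions, preserving the original row order.
    # For empty input there are no chunks, so fall back to a single call.
    return np.concatenate(chunks) if chunks else predict_fn(data_x)


def double(df: pd.DataFrame) -> np.ndarray:
    """Hypothetical stand-in model: predicts twice the "dose" feature."""
    return df["dose"].to_numpy() * 2


data_x = pd.DataFrame({"dose": [0.5, 1.0, 2.0, 4.0, 8.0]})
print(predict_in_batches(double, data_x, batch_size=2))  # [ 1.  2.  4.  8. 16.]
```

Batching bounds peak memory per call without changing the predictions, provided the model treats rows independently; models whose predictions depend on the whole input at once are one reason some implementations might not support `batch_size`.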

save(path_to_model, model_pars)

Save the model.

Parameters:
  • path_to_model – Path where the model should be saved.

  • model_pars – Model parameters.
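The documentation does not state the on-disk format or what `model_state` contains. Purely as a hypothetical sketch, assuming the fitted state is a picklable dict and `model_pars` is a dict of hyperparameters, a `save`/`load_state` round trip could look like the following; `PicklableModel`, the `mean_` attribute, and the `"state"`/`"pars"` keys are all illustrative names.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any, Dict, Optional, Union


class PicklableModel:
    """Hypothetical model illustrating a save/load_state round trip."""

    def __init__(self) -> None:
        self.mean_: Optional[float] = None

    def save(self, path_to_model: Union[str, Path], model_pars: Dict[str, Any]) -> None:
        # Assumption: fitted state and hyperparameters go into one pickle file.
        payload = {"state": {"mean_": self.mean_}, "pars": model_pars}
        with open(path_to_model, "wb") as fh:
            pickle.dump(payload, fh)

    def load_state(self, model_state: Dict[str, Any]) -> None:
        # Assumption: model_state is the dict that save() stored under "state".
        self.mean_ = model_state["mean_"]


# Usage: save a "fitted" model, read the file back, and restore a fresh instance.
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "model.pkl"
    model = PicklableModel()
    model.mean_ = 2.0
    model.save(path, {"lr": 0.01})
    with open(path, "rb") as fh:
        payload = pickle.load(fh)

restored = PicklableModel()
restored.load_state(payload["state"])
print(restored.mean_, payload["pars"])  # 2.0 {'lr': 0.01}
```

Only unpickle files from trusted sources, since `pickle.load` can execute arbitrary code.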