DirectEnsemble

class DirectEnsemble(pipelines: List[etna.pipeline.base.BasePipeline], n_jobs: int = 1, joblib_params: Optional[Dict[str, Any]] = None)

Bases: etna.ensembles.mixins.EnsembleMixin, etna.ensembles.mixins.SaveEnsembleMixin, etna.pipeline.base.BasePipeline

DirectEnsemble is a pipeline that forecasts future values by merging the forecasts of its base pipelines.

The ensemble takes several pipelines at initialization, and these pipelines are expected to have different forecasting horizons. For each future point, the ensemble's forecast is the forecast of the base pipeline with the shortest horizon that covers this point.
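The selection rule can be sketched in plain Python. This only illustrates the rule as described here and is not etna's implementation; `pick_pipeline_per_step` and the `horizons` mapping (pipeline name to horizon) are hypothetical names introduced for the sketch.

```python
from typing import Dict, List


def pick_pipeline_per_step(horizons: Dict[str, int]) -> List[str]:
    """For each future step 1..max(horizons), name the pipeline whose forecast is used.

    Hypothetical helper: it illustrates the rule, it is not etna's code.
    """
    if not horizons:
        raise ValueError("At least one pipeline is required")
    max_horizon = max(horizons.values())
    chosen = []
    for step in range(1, max_horizon + 1):
        # Pipelines whose horizon reaches this step.
        covering = {name: h for name, h in horizons.items() if h >= step}
        # Among them, take the one with the shortest horizon.
        chosen.append(min(covering, key=covering.get))
    return chosen


steps = pick_pipeline_per_step({"prophet": 3, "naive": 5})
print(steps)  # ['prophet', 'prophet', 'prophet', 'naive', 'naive']
```

With horizons 3 and 5, as in the example that follows, the first three points would come from the horizon-3 pipeline and the last two from the horizon-5 pipeline.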

Examples

>>> from etna.datasets import generate_ar_df
>>> from etna.datasets import TSDataset
>>> from etna.ensembles import DirectEnsemble
>>> from etna.models import NaiveModel
>>> from etna.models import ProphetModel
>>> from etna.pipeline import Pipeline
>>> df = generate_ar_df(periods=30, start_time="2021-06-01", ar_coef=[1.2], n_segments=3)
>>> df_ts_format = TSDataset.to_dataset(df)
>>> ts = TSDataset(df_ts_format, "D")
>>> prophet_pipeline = Pipeline(model=ProphetModel(), transforms=[], horizon=3)
>>> naive_pipeline = Pipeline(model=NaiveModel(lag=10), transforms=[], horizon=5)
>>> ensemble = DirectEnsemble(pipelines=[prophet_pipeline, naive_pipeline])
>>> _ = ensemble.fit(ts=ts)
>>> forecast = ensemble.forecast()
>>> forecast
segment    segment_0 segment_1 segment_2
feature       target    target    target
timestamp
2021-07-01    -10.37   -232.60    163.16
2021-07-02    -10.59   -242.05    169.62
2021-07-03    -11.41   -253.82    177.62
2021-07-04     -5.85   -139.57     96.99
2021-07-05     -6.11   -167.69    116.59

Initialize DirectEnsemble.

Parameters
  • pipelines (List[etna.pipeline.base.BasePipeline]) – List of pipelines that should be used in ensemble

  • n_jobs (int) – Number of jobs to run in parallel

  • joblib_params (Optional[Dict[str, Any]]) – Additional parameters for joblib.Parallel

Raises

ValueError – If two or more pipelines have the same horizon.
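A check of this kind could look like the sketch below. It is an assumption about the shape of the validation, not the library's code; `check_unique_horizons` and the error message are hypothetical. A plausible reason for the constraint is that with distinct horizons, "the shortest horizon that covers a point" always identifies exactly one pipeline.

```python
from collections import Counter
from typing import List


def check_unique_horizons(horizons: List[int]) -> None:
    """Raise ValueError if any horizon value occurs more than once (hypothetical helper)."""
    duplicates = sorted(h for h, count in Counter(horizons).items() if count > 1)
    if duplicates:
        raise ValueError(f"Pipelines share the same horizon(s): {duplicates}")


check_unique_horizons([3, 5])  # distinct horizons: passes silently

try:
    check_unique_horizons([3, 5, 3])  # horizon 3 appears twice
except ValueError as err:
    message = str(err)
    print(message)
```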

Methods

backtest(ts, metrics[, n_folds, mode, ...])

Run backtest with the pipeline.

fit(ts)

Fit pipelines in ensemble.

forecast([ts, prediction_interval, ...])

Make a forecast of the next points of a dataset.

load(path[, ts])

Load an object.

params_to_tune()

Get hyperparameter grid to tune.

predict(ts[, start_timestamp, ...])

Make in-sample predictions on dataset in a given range.

save(path)

Save the object.

set_params(**params)

Return new object instance with modified parameters.

to_dict()

Collect all information about etna object in dict.

fit(ts: etna.datasets.tsdataset.TSDataset) → etna.ensembles.direct_ensemble.DirectEnsemble

Fit pipelines in ensemble.

Parameters

ts (etna.datasets.tsdataset.TSDataset) – TSDataset to fit the ensemble on

Returns

The fitted ensemble (self)

Return type

etna.ensembles.direct_ensemble.DirectEnsemble

params_to_tune() → Dict[str, etna.distributions.distributions.BaseDistribution]

Get hyperparameter grid to tune.

Not implemented for this class.

Returns

Grid with hyperparameters.

Return type

Dict[str, etna.distributions.distributions.BaseDistribution]