pypots/forecasting/template/model.py
"""
The implementation of YourNewModel for the partially-observed time-series forecasting task.
TODO: modify the above description with your model's information.
"""
# Created by Your Name <Your contact email> TODO: modify the author information.
# License: BSD-3-Clause
from typing import Union, Optional
import torch
from .core import _YourNewModel
# TODO: import the base class from the forecasting package in PyPOTS.
# This template assumes your model is a neural-network forecasting model.
# If it is not a neural network, make your model inherit from BaseForecaster instead.
# from ..base import BaseForecaster
from ..base import BaseNNForecaster
from ...optim.adam import Adam
from ...optim.base import Optimizer
# TODO: define your new model's wrapper here.
# It should be a subclass of a base class defined in PyPOTS task packages (e.g.
# BaseNNForecaster of PyPOTS forecasting task package), and it has to implement all abstract methods of the base class.
# Note that this class is a wrapper of your new model and will be directly exposed to users.
class YourNewModel(BaseNNForecaster):
    """User-facing wrapper of ``_YourNewModel`` for the forecasting task.

    TODO: replace this description with your model's information.

    Parameters
    ----------
    batch_size :
        Forwarded unchanged to ``BaseNNForecaster.__init__``.

    epochs :
        Forwarded unchanged to ``BaseNNForecaster.__init__``.

    patience :
        Forwarded unchanged to ``BaseNNForecaster.__init__``.

    optimizer :
        The optimizer whose ``init_optimizer`` is called with the parameters
        of the wrapped model. If ``None`` (the default), a new ``Adam()``
        instance is created for this model, so that optimizer objects are
        never shared between separately constructed models.

    num_workers :
        Forwarded unchanged to ``BaseNNForecaster.__init__``.

    device :
        Forwarded unchanged to ``BaseNNForecaster.__init__``.

    saving_path :
        Forwarded unchanged to ``BaseNNForecaster.__init__``.

    model_saving_strategy :
        Forwarded unchanged to ``BaseNNForecaster.__init__``.

    verbose :
        Forwarded unchanged to ``BaseNNForecaster.__init__``.

    Notes
    -----
    The meaning of the forwarded arguments is defined by the base class,
    which is not part of this file -- see ``BaseNNForecaster`` for details.
    """

    def __init__(
        self,
        # TODO: add your model's hyper-parameters here
        batch_size: int = 32,
        epochs: int = 100,
        patience: Optional[int] = None,
        optimizer: Optional[Optimizer] = None,
        num_workers: int = 0,
        device: Optional[Union[str, torch.device, list]] = None,
        saving_path: Optional[str] = None,
        model_saving_strategy: Optional[str] = "best",
        verbose: bool = True,
    ):
        super().__init__(
            batch_size,
            epochs,
            patience,
            num_workers,
            device,
            saving_path,
            model_saving_strategy,
            verbose,
        )

        # set up the hyper-parameters
        # TODO: set up your model's hyper-parameters here

        # set up the model
        self.model = _YourNewModel(
            # pass the arguments to your model
        )
        self._print_model_size()
        self._send_model_to_given_device()

        # set up the optimizer
        # A default of `Adam()` in the signature would be evaluated only once,
        # at definition time, and the very same object would then be handed to
        # every model built without an explicit optimizer. Create it here so
        # each model gets its own instance.
        if optimizer is None:
            optimizer = Adam()
        self.optimizer = optimizer
        self.optimizer.init_optimizer(self.model.parameters())

    def _assemble_input_for_training(self, data: list) -> dict:
        """Assemble the model input for the training stage.

        Template placeholder: raises ``NotImplementedError`` until implemented.
        """
        raise NotImplementedError

    def _assemble_input_for_validating(self, data: list) -> dict:
        """Assemble the model input for the validating stage.

        Template placeholder: raises ``NotImplementedError`` until implemented.
        """
        raise NotImplementedError

    def _assemble_input_for_testing(self, data: list) -> dict:
        """Assemble the model input for the testing stage.

        Template placeholder: raises ``NotImplementedError`` until implemented.
        """
        raise NotImplementedError

    def fit(
        self,
        train_set: Union[dict, str],
        val_set: Optional[Union[dict, str]] = None,
        file_type: str = "hdf5",
    ) -> None:
        """Train the model on ``train_set``, optionally using ``val_set``.

        Template placeholder: raises ``NotImplementedError`` until implemented.
        """
        raise NotImplementedError

    def predict(
        self,
        test_set: Union[dict, str],
        file_type: str = "hdf5",
    ) -> dict:
        """Produce the model's results for ``test_set`` as a dict.

        Template placeholder: raises ``NotImplementedError`` until implemented.
        """
        raise NotImplementedError