bikes.utils.splitters

Split dataframes into subsets (e.g., train/valid/test).

  1"""Split dataframes into subsets (e.g., train/valid/test)."""
  2
  3# %% IMPORTS
  4
  5import abc
  6import typing as T
  7
  8import numpy as np
  9import numpy.typing as npt
 10import pydantic as pdt
 11from sklearn import model_selection
 12
 13from bikes.core import schemas
 14
 15# %% TYPES
 16
 17Index = npt.NDArray[np.int64]
 18TrainTestIndex = tuple[Index, Index]
 19TrainTestSplits = T.Iterator[TrainTestIndex]
 20
 21# %% SPLITTERS
 22
 23
 24class Splitter(abc.ABC, pdt.BaseModel, strict=True, frozen=True, extra="forbid"):
 25    """Base class for a splitter.
 26
 27    Use splitters to split data in sets.
 28    e.g., split between a train/test subsets.
 29
 30    # https://scikit-learn.org/stable/glossary.html#term-CV-splitter
 31    """
 32
 33    KIND: str
 34
 35    @abc.abstractmethod
 36    def split(
 37        self, inputs: schemas.Inputs, targets: schemas.Targets, groups: Index | None = None
 38    ) -> TrainTestSplits:
 39        """Split a dataframe into subsets.
 40
 41        Args:
 42            inputs (schemas.Inputs): model inputs.
 43            targets (schemas.Targets): model targets.
 44            groups (Index | None, optional): group labels.
 45
 46        Returns:
 47            TrainTestSplits: iterator over the dataframe train/test splits.
 48        """
 49
 50    @abc.abstractmethod
 51    def get_n_splits(
 52        self, inputs: schemas.Inputs, targets: schemas.Targets, groups: Index | None = None
 53    ) -> int:
 54        """Get the number of splits generated.
 55
 56        Args:
 57            inputs (schemas.Inputs): models inputs.
 58            targets (schemas.Targets): model targets.
 59            groups (Index | None, optional): group labels.
 60
 61        Returns:
 62            int: number of splits generated.
 63        """
 64
 65
 66class TrainTestSplitter(Splitter):
 67    """Split a dataframe into a train and test set.
 68
 69    Parameters:
 70        shuffle (bool): shuffle the dataset. Default is False.
 71        test_size (int | float): number/ratio for the test set.
 72        random_state (int): random state for the splitter object.
 73    """
 74
 75    KIND: T.Literal["TrainTestSplitter"] = "TrainTestSplitter"
 76
 77    shuffle: bool = False  # required (time sensitive)
 78    test_size: int | float = 24 * 30 * 2  # 2 months
 79    random_state: int = 42
 80
 81    @T.override
 82    def split(
 83        self, inputs: schemas.Inputs, targets: schemas.Targets, groups: Index | None = None
 84    ) -> TrainTestSplits:
 85        index = np.arange(len(inputs))  # return integer position
 86        train_index, test_index = model_selection.train_test_split(
 87            index, shuffle=self.shuffle, test_size=self.test_size, random_state=self.random_state
 88        )
 89        yield train_index, test_index
 90
 91    @T.override
 92    def get_n_splits(
 93        self, inputs: schemas.Inputs, targets: schemas.Targets, groups: Index | None = None
 94    ) -> int:
 95        return 1
 96
 97
 98class TimeSeriesSplitter(Splitter):
 99    """Split a dataframe into fixed time series subsets.
100
101    Parameters:
102        gap (int): gap between splits.
103        n_splits (int): number of split to generate.
104        test_size (int | float): number or ratio for the test dataset.
105    """
106
107    KIND: T.Literal["TimeSeriesSplitter"] = "TimeSeriesSplitter"
108
109    gap: int = 0
110    n_splits: int = 4
111    test_size: int | float = 24 * 30 * 2  # 2 months
112
113    @T.override
114    def split(
115        self, inputs: schemas.Inputs, targets: schemas.Targets, groups: Index | None = None
116    ) -> TrainTestSplits:
117        splitter = model_selection.TimeSeriesSplit(n_splits=self.n_splits, test_size=self.test_size)
118        yield from splitter.split(inputs)
119
120    @T.override
121    def get_n_splits(
122        self, inputs: schemas.Inputs, targets: schemas.Targets, groups: Index | None = None
123    ) -> int:
124        return self.n_splits
125
126
127SplitterKind = TrainTestSplitter | TimeSeriesSplitter
Index = numpy.ndarray[typing.Any, numpy.dtype[numpy.int64]]
TrainTestIndex = tuple[numpy.ndarray[typing.Any, numpy.dtype[numpy.int64]], numpy.ndarray[typing.Any, numpy.dtype[numpy.int64]]]
TrainTestSplits = typing.Iterator[tuple[numpy.ndarray[typing.Any, numpy.dtype[numpy.int64]], numpy.ndarray[typing.Any, numpy.dtype[numpy.int64]]]]
class Splitter(abc.ABC, pydantic.main.BaseModel):
25class Splitter(abc.ABC, pdt.BaseModel, strict=True, frozen=True, extra="forbid"):
26    """Base class for a splitter.
27
28    Use splitters to split data in sets.
29    e.g., split between a train/test subsets.
30
31    # https://scikit-learn.org/stable/glossary.html#term-CV-splitter
32    """
33
34    KIND: str
35
36    @abc.abstractmethod
37    def split(
38        self, inputs: schemas.Inputs, targets: schemas.Targets, groups: Index | None = None
39    ) -> TrainTestSplits:
40        """Split a dataframe into subsets.
41
42        Args:
43            inputs (schemas.Inputs): model inputs.
44            targets (schemas.Targets): model targets.
45            groups (Index | None, optional): group labels.
46
47        Returns:
48            TrainTestSplits: iterator over the dataframe train/test splits.
49        """
50
51    @abc.abstractmethod
52    def get_n_splits(
53        self, inputs: schemas.Inputs, targets: schemas.Targets, groups: Index | None = None
54    ) -> int:
55        """Get the number of splits generated.
56
57        Args:
58            inputs (schemas.Inputs): models inputs.
59            targets (schemas.Targets): model targets.
60            groups (Index | None, optional): group labels.
61
62        Returns:
63            int: number of splits generated.
64        """

Base class for a splitter.

Use splitters to split data into sets, e.g., split between train/test subsets.

https://scikit-learn.org/stable/glossary.html#term-CV-splitter

KIND: str
@abc.abstractmethod
def split( self, inputs: pandera.typing.pandas.DataFrame[bikes.core.schemas.InputsSchema], targets: pandera.typing.pandas.DataFrame[bikes.core.schemas.TargetsSchema], groups: numpy.ndarray[typing.Any, numpy.dtype[numpy.int64]] | None = None) -> Iterator[tuple[numpy.ndarray[Any, numpy.dtype[numpy.int64]], numpy.ndarray[Any, numpy.dtype[numpy.int64]]]]:
36    @abc.abstractmethod
37    def split(
38        self, inputs: schemas.Inputs, targets: schemas.Targets, groups: Index | None = None
39    ) -> TrainTestSplits:
40        """Split a dataframe into subsets.
41
42        Args:
43            inputs (schemas.Inputs): model inputs.
44            targets (schemas.Targets): model targets.
45            groups (Index | None, optional): group labels.
46
47        Returns:
48            TrainTestSplits: iterator over the dataframe train/test splits.
49        """

Split a dataframe into subsets.

Arguments:
  • inputs (schemas.Inputs): model inputs.
  • targets (schemas.Targets): model targets.
  • groups (Index | None, optional): group labels.
Returns:

TrainTestSplits: iterator over the dataframe train/test splits.

@abc.abstractmethod
def get_n_splits( self, inputs: pandera.typing.pandas.DataFrame[bikes.core.schemas.InputsSchema], targets: pandera.typing.pandas.DataFrame[bikes.core.schemas.TargetsSchema], groups: numpy.ndarray[typing.Any, numpy.dtype[numpy.int64]] | None = None) -> int:
51    @abc.abstractmethod
52    def get_n_splits(
53        self, inputs: schemas.Inputs, targets: schemas.Targets, groups: Index | None = None
54    ) -> int:
55        """Get the number of splits generated.
56
57        Args:
58            inputs (schemas.Inputs): models inputs.
59            targets (schemas.Targets): model targets.
60            groups (Index | None, optional): group labels.
61
62        Returns:
63            int: number of splits generated.
64        """

Get the number of splits generated.

Arguments:
  • inputs (schemas.Inputs): model inputs.
  • targets (schemas.Targets): model targets.
  • groups (Index | None, optional): group labels.
Returns:

int: number of splits generated.

model_config = {'strict': True, 'frozen': True, 'extra': 'forbid'}
model_fields = {'KIND': FieldInfo(annotation=str, required=True)}
model_computed_fields = {}
Inherited Members
pydantic.main.BaseModel
BaseModel
model_extra
model_fields_set
model_construct
model_copy
model_dump
model_dump_json
model_json_schema
model_parametrized_name
model_post_init
model_rebuild
model_validate
model_validate_json
model_validate_strings
dict
json
parse_obj
parse_raw
parse_file
from_orm
construct
copy
schema
schema_json
validate
update_forward_refs
class TrainTestSplitter(Splitter):
67class TrainTestSplitter(Splitter):
68    """Split a dataframe into a train and test set.
69
70    Parameters:
71        shuffle (bool): shuffle the dataset. Default is False.
72        test_size (int | float): number/ratio for the test set.
73        random_state (int): random state for the splitter object.
74    """
75
76    KIND: T.Literal["TrainTestSplitter"] = "TrainTestSplitter"
77
78    shuffle: bool = False  # required (time sensitive)
79    test_size: int | float = 24 * 30 * 2  # 2 months
80    random_state: int = 42
81
82    @T.override
83    def split(
84        self, inputs: schemas.Inputs, targets: schemas.Targets, groups: Index | None = None
85    ) -> TrainTestSplits:
86        index = np.arange(len(inputs))  # return integer position
87        train_index, test_index = model_selection.train_test_split(
88            index, shuffle=self.shuffle, test_size=self.test_size, random_state=self.random_state
89        )
90        yield train_index, test_index
91
92    @T.override
93    def get_n_splits(
94        self, inputs: schemas.Inputs, targets: schemas.Targets, groups: Index | None = None
95    ) -> int:
96        return 1

Split a dataframe into a train and test set.

Arguments:
  • shuffle (bool): shuffle the dataset. Default is False.
  • test_size (int | float): number/ratio for the test set.
  • random_state (int): random state for the splitter object.
KIND: Literal['TrainTestSplitter']
shuffle: bool
test_size: int | float
random_state: int
@T.override
def split( self, inputs: pandera.typing.pandas.DataFrame[bikes.core.schemas.InputsSchema], targets: pandera.typing.pandas.DataFrame[bikes.core.schemas.TargetsSchema], groups: numpy.ndarray[typing.Any, numpy.dtype[numpy.int64]] | None = None) -> Iterator[tuple[numpy.ndarray[Any, numpy.dtype[numpy.int64]], numpy.ndarray[Any, numpy.dtype[numpy.int64]]]]:
82    @T.override
83    def split(
84        self, inputs: schemas.Inputs, targets: schemas.Targets, groups: Index | None = None
85    ) -> TrainTestSplits:
86        index = np.arange(len(inputs))  # return integer position
87        train_index, test_index = model_selection.train_test_split(
88            index, shuffle=self.shuffle, test_size=self.test_size, random_state=self.random_state
89        )
90        yield train_index, test_index

Split a dataframe into subsets.

Arguments:
  • inputs (schemas.Inputs): model inputs.
  • targets (schemas.Targets): model targets.
  • groups (Index | None, optional): group labels.
Returns:

TrainTestSplits: iterator over the dataframe train/test splits.

@T.override
def get_n_splits( self, inputs: pandera.typing.pandas.DataFrame[bikes.core.schemas.InputsSchema], targets: pandera.typing.pandas.DataFrame[bikes.core.schemas.TargetsSchema], groups: numpy.ndarray[typing.Any, numpy.dtype[numpy.int64]] | None = None) -> int:
92    @T.override
93    def get_n_splits(
94        self, inputs: schemas.Inputs, targets: schemas.Targets, groups: Index | None = None
95    ) -> int:
96        return 1

Get the number of splits generated.

Arguments:
  • inputs (schemas.Inputs): model inputs.
  • targets (schemas.Targets): model targets.
  • groups (Index | None, optional): group labels.
Returns:

int: number of splits generated.

model_config = {'strict': True, 'frozen': True, 'extra': 'forbid'}
model_fields = {'KIND': FieldInfo(annotation=Literal['TrainTestSplitter'], required=False, default='TrainTestSplitter'), 'shuffle': FieldInfo(annotation=bool, required=False, default=False), 'test_size': FieldInfo(annotation=Union[int, float], required=False, default=1440), 'random_state': FieldInfo(annotation=int, required=False, default=42)}
model_computed_fields = {}
Inherited Members
pydantic.main.BaseModel
BaseModel
model_extra
model_fields_set
model_construct
model_copy
model_dump
model_dump_json
model_json_schema
model_parametrized_name
model_post_init
model_rebuild
model_validate
model_validate_json
model_validate_strings
dict
json
parse_obj
parse_raw
parse_file
from_orm
construct
copy
schema
schema_json
validate
update_forward_refs
class TimeSeriesSplitter(Splitter):
 99class TimeSeriesSplitter(Splitter):
100    """Split a dataframe into fixed time series subsets.
101
102    Parameters:
103        gap (int): gap between splits.
104        n_splits (int): number of split to generate.
105        test_size (int | float): number or ratio for the test dataset.
106    """
107
108    KIND: T.Literal["TimeSeriesSplitter"] = "TimeSeriesSplitter"
109
110    gap: int = 0
111    n_splits: int = 4
112    test_size: int | float = 24 * 30 * 2  # 2 months
113
114    @T.override
115    def split(
116        self, inputs: schemas.Inputs, targets: schemas.Targets, groups: Index | None = None
117    ) -> TrainTestSplits:
118        splitter = model_selection.TimeSeriesSplit(n_splits=self.n_splits, test_size=self.test_size)
119        yield from splitter.split(inputs)
120
121    @T.override
122    def get_n_splits(
123        self, inputs: schemas.Inputs, targets: schemas.Targets, groups: Index | None = None
124    ) -> int:
125        return self.n_splits

Split a dataframe into fixed time series subsets.

Arguments:
  • gap (int): gap between splits.
  • n_splits (int): number of splits to generate.
  • test_size (int | float): number or ratio for the test dataset.
KIND: Literal['TimeSeriesSplitter']
gap: int
n_splits: int
test_size: int | float
@T.override
def split( self, inputs: pandera.typing.pandas.DataFrame[bikes.core.schemas.InputsSchema], targets: pandera.typing.pandas.DataFrame[bikes.core.schemas.TargetsSchema], groups: numpy.ndarray[typing.Any, numpy.dtype[numpy.int64]] | None = None) -> Iterator[tuple[numpy.ndarray[Any, numpy.dtype[numpy.int64]], numpy.ndarray[Any, numpy.dtype[numpy.int64]]]]:
114    @T.override
115    def split(
116        self, inputs: schemas.Inputs, targets: schemas.Targets, groups: Index | None = None
117    ) -> TrainTestSplits:
118        splitter = model_selection.TimeSeriesSplit(n_splits=self.n_splits, test_size=self.test_size)
119        yield from splitter.split(inputs)

Split a dataframe into subsets.

Arguments:
  • inputs (schemas.Inputs): model inputs.
  • targets (schemas.Targets): model targets.
  • groups (Index | None, optional): group labels.
Returns:

TrainTestSplits: iterator over the dataframe train/test splits.

@T.override
def get_n_splits( self, inputs: pandera.typing.pandas.DataFrame[bikes.core.schemas.InputsSchema], targets: pandera.typing.pandas.DataFrame[bikes.core.schemas.TargetsSchema], groups: numpy.ndarray[typing.Any, numpy.dtype[numpy.int64]] | None = None) -> int:
121    @T.override
122    def get_n_splits(
123        self, inputs: schemas.Inputs, targets: schemas.Targets, groups: Index | None = None
124    ) -> int:
125        return self.n_splits

Get the number of splits generated.

Arguments:
  • inputs (schemas.Inputs): model inputs.
  • targets (schemas.Targets): model targets.
  • groups (Index | None, optional): group labels.
Returns:

int: number of splits generated.

model_config = {'strict': True, 'frozen': True, 'extra': 'forbid'}
model_fields = {'KIND': FieldInfo(annotation=Literal['TimeSeriesSplitter'], required=False, default='TimeSeriesSplitter'), 'gap': FieldInfo(annotation=int, required=False, default=0), 'n_splits': FieldInfo(annotation=int, required=False, default=4), 'test_size': FieldInfo(annotation=Union[int, float], required=False, default=1440)}
model_computed_fields = {}
Inherited Members
pydantic.main.BaseModel
BaseModel
model_extra
model_fields_set
model_construct
model_copy
model_dump
model_dump_json
model_json_schema
model_parametrized_name
model_post_init
model_rebuild
model_validate
model_validate_json
model_validate_strings
dict
json
parse_obj
parse_raw
parse_file
from_orm
construct
copy
schema
schema_json
validate
update_forward_refs