# Source code for schemist.splitting

"""Tools for splitting tabular datasets, optionally based on chemical features."""

from typing import Dict, Iterable, List, Optional, Tuple, Union
from collections import defaultdict
from math import ceil
from random import random, seed

try:
    from itertools import batched
except ImportError:
    from carabiner.itertools import batched

from tqdm.auto import tqdm

from .converting import convert_string_representation, _convert_input_to_smiles
from .typing import DataSplits

# def _train_test_splits

def _train_test_val_sizes(total: int, 
                          train: float = 1., 
                          test: float = 0.) -> Tuple[int]:
    
    n_train = int(ceil(train * total))
    n_test = int(ceil(test * total))
    n_val = total - n_train - n_test

    return n_train, n_test, n_val
    

def _random_chunk(strings: str,
                  train: float = 1., 
                  test: float = 0.,
                  carry: Optional[Dict[str, List[int]]] = None,
                  start_from: int = 0) -> Dict[str, List[int]]:
    
    carry = carry or defaultdict(list)

    train_test: float = train + test

    for i, _ in enumerate(strings):

        random_number: float = random()

        if random_number < train:

            key = 'train'

        elif random_number < train_test:

            key = 'test'

        else:

            key = 'validation'

        carry[key].append(start_from + i)

    return carry


def split_random(strings: Union[str, Iterable[str]],
                 train: float = 1.,
                 test: float = 0.,
                 chunksize: Optional[int] = None,
                 set_seed: Optional[int] = None,
                 *args, **kwargs) -> DataSplits:

    """Randomly split items into train, test, and validation sets.

    Each item is assigned independently. It goes to ``'train'`` with
    probability ``train``, to ``'test'`` with probability ``test``, and to
    ``'validation'`` otherwise. Split sizes are therefore only approximately
    proportional to the requested fractions.

    Parameters
    ----------
    strings : str or iterable of str
        Items to split. Only their positions are used, not their values.
        A single string is treated as one item.
    train : float, optional
        Probability of assigning an item to the training set. Default: 1.
    test : float, optional
        Probability of assigning an item to the test set. Default: 0.
    chunksize : int, optional
        If given, consume ``strings`` in batches of this size. Indices and
        assignments are the same as the unchunked path, because one random
        draw is still made per item, in order.
    set_seed : int, optional
        If given, seed the global :mod:`random` generator with this value so
        the split is reproducible. The generator's previous state is
        restored afterwards.
    *args, **kwargs
        Accepted and ignored, so that :func:`split` can forward arbitrary
        arguments.

    Returns
    -------
    DataSplits
        Built from a mapping of split name (``'train'``, ``'test'``,
        ``'validation'``) to lists of row indices.

    """
    from random import getstate, setstate

    if isinstance(strings, str):
        # A bare string would otherwise be iterated character by character.
        strings = [strings]

    # Only touch the global RNG when asked to, and put it back afterwards.
    # Previously ``seed(None)`` was called unconditionally, reseeding the
    # caller's generator from system entropy on every call.
    saved_state = None
    if set_seed is not None:
        saved_state = getstate()
        seed(set_seed)

    try:
        if chunksize is None:
            idx = _random_chunk(strings=strings, train=train, test=test)
        else:
            idx = defaultdict(list)
            for i, chunk in enumerate(batched(strings, chunksize)):
                idx = _random_chunk(strings=chunk,
                                    train=train,
                                    test=test,
                                    carry=idx,
                                    start_from=i * chunksize)
    finally:
        if saved_state is not None:
            setstate(saved_state)

    return DataSplits(**idx)
@_convert_input_to_smiles
def _scaffold_chunk(strings: Iterable[str],
                    carry: Optional[Dict[str, List[int]]] = None,
                    start_from: int = 0) -> Dict[str, List[int]]:

    """Group the row indices of one chunk by molecular scaffold.

    Scaffolds are computed with ``convert_string_representation(...,
    output_representation='scaffold')``. Each row's global index
    (``start_from + position``) is appended under its scaffold.
    NOTE(review): the ``_convert_input_to_smiles`` decorator presumably
    normalizes the input representation first; confirm its contract.

    Parameters
    ----------
    strings : iterable of str
        Molecule strings in this chunk.
    carry : dict, optional
        Accumulator from previous chunks, mapping scaffold to row indices.
        It is updated in place and returned. A new ``defaultdict(list)`` is
        created if omitted.
    start_from : int, optional
        Global row index of this chunk's first item. Default: 0.

    Returns
    -------
    dict
        ``carry``, updated with this chunk's rows.

    """
    # Test against None (not truthiness) so a caller-supplied empty
    # accumulator is the one that gets filled.
    if carry is None:
        carry = defaultdict(list)

    these_scaffolds = convert_string_representation(
        strings=strings,
        output_representation='scaffold',
    )

    for j, scaff in enumerate(these_scaffolds):
        # setdefault lets a plain dict be used as ``carry``, not only a
        # defaultdict.
        carry.setdefault(scaff, []).append(start_from + j)

    return carry


def _scaffold_aggregator(scaffold_sets: Dict[str, List[int]],
                         train: float = 1.,
                         test: float = 0.,
                         progress: bool = False) -> DataSplits:

    """Greedily assign whole scaffold groups to train, test, and validation.

    Groups are visited from largest to smallest. Ties are broken by the
    larger lowest row index first. A group goes to ``'train'`` if it fits
    within the train quota, otherwise to ``'test'`` if it fits within the
    test quota, otherwise to ``'validation'``. Groups are never divided, so
    the final sizes only approximate the requested fractions.

    Parameters
    ----------
    scaffold_sets : dict
        Mapping of scaffold to the row indices sharing it.
    train : float, optional
        Target fraction of rows for the training split. Default: 1.
    test : float, optional
        Target fraction of rows for the test split. Default: 0.
    progress : bool, optional
        Show a ``tqdm`` progress bar over the groups. Default: False.

    Returns
    -------
    DataSplits
        Built from a mapping of split name to row indices.

    """
    # Only the index lists matter from here on. Each list is sorted so that
    # ``members[0]`` is its lowest index. Empty groups are skipped; they
    # would otherwise break the ``members[0]`` sort key.
    groups = sorted(
        (sorted(members) for members in scaffold_sets.values() if members),
        key=lambda members: (len(members), members[0]),
        reverse=True,
    )

    nrows = sum(len(members) for members in groups)
    n_train, n_test, _ = _train_test_val_sizes(nrows, train, test)

    idx = defaultdict(list)
    iterator = tqdm(groups) if progress else groups

    for members in iterator:
        # Reading idx['train'] / idx['test'] on the defaultdict creates
        # those keys in the same order as before, so DataSplits receives the
        # same keyword arguments.
        if len(idx['train']) + len(members) <= n_train:
            key = 'train'
        elif len(idx['test']) + len(members) <= n_test:
            key = 'test'
        else:
            key = 'validation'
        idx[key] += members

    return DataSplits(**idx)
def split_scaffold(strings: Union[str, Iterable[str]],
                   train: float = 1.,
                   test: float = 0.,
                   chunksize: Optional[int] = None,
                   progress: bool = True,
                   *args, **kwargs) -> DataSplits:

    """Split molecules so that each scaffold falls in exactly one split.

    Rows are first grouped by scaffold (``_scaffold_chunk``). Whole groups
    are then assigned greedily to train, test, and validation
    (``_scaffold_aggregator``).

    Parameters
    ----------
    strings : str or iterable of str
        Molecule strings to split. A single string is treated as one item.
    train : float, optional
        Target fraction of rows for the training split. Default: 1.
    test : float, optional
        Target fraction of rows for the test split. Default: 0.
    chunksize : int, optional
        If given, compute scaffolds in batches of this size, accumulating
        the groups across batches.
    progress : bool, optional
        Show a progress bar while assigning groups. Default: True.
    *args, **kwargs
        Accepted and ignored, so that :func:`split` can forward arbitrary
        arguments.

    Returns
    -------
    DataSplits
        Built from a mapping of split name to row indices.

    """
    if isinstance(strings, str):
        # Without this, ``batched`` would chunk a bare string into its
        # individual characters.
        strings = [strings]

    if chunksize is None:
        scaffold_sets = _scaffold_chunk(strings)
    else:
        scaffold_sets = defaultdict(list)
        for i, chunk in enumerate(batched(strings, chunksize)):
            scaffold_sets = _scaffold_chunk(chunk,
                                            carry=scaffold_sets,
                                            start_from=i * chunksize)

    return _scaffold_aggregator(scaffold_sets,
                                train=train,
                                test=test,
                                progress=progress)
# Splitter functions looked up by name in ``split``.
# TODO(review): a 'simpd' entry (split_simpd) was stubbed out here; add it
# to this registry if that splitter is implemented.
_SPLITTERS = {
    'scaffold': split_scaffold,
    'random': split_random,
}

# Splitters that keep groups of rows together, each mapped to its
# (per-chunk grouping function, group-to-split aggregator) pair.
_GROUPED_SPLITTERS = {
    'scaffold': (_scaffold_chunk, _scaffold_aggregator),
}

# Internal invariant: every grouped splitter is also a registered splitter.
assert set(_GROUPED_SPLITTERS) <= set(_SPLITTERS)  ## Should never fail!
def split(split_type: str,
          *args, **kwargs) -> DataSplits:

    """Split data with the splitter registered under ``split_type``.

    Parameters
    ----------
    split_type : str
        Name of the splitter. Must be a key of ``_SPLITTERS``, e.g.
        ``'scaffold'`` or ``'random'``.
    *args, **kwargs
        Forwarded unchanged to the selected splitter.

    Returns
    -------
    DataSplits
        Whatever the selected splitter returns.

    Raises
    ------
    KeyError
        If ``split_type`` is not a registered splitter name.

    """
    try:
        splitter = _SPLITTERS[split_type]
    except KeyError:
        # Keep KeyError for existing callers, but say what was expected.
        raise KeyError(
            f"Unknown split_type {split_type!r}; "
            f"expected one of {sorted(_SPLITTERS)}."
        ) from None

    return splitter(*args, **kwargs)