# -*- coding: utf-8 -*-
# @Time : 7/21/25
# @Author : Yaojie Shen
# @Project : Deep-Learning-Utils
# @File : sample.py
from collections.abc import Sequence
from typing import Union, List, Any
import numpy as np
try:
import torch
except ImportError:
torch = None
[docs]
def sample_evenly(
input_data: Union[List[Any], np.ndarray, 'torch.Tensor', Sequence[Any]],
n: int
) -> Union[np.ndarray, List[Any], 'torch.Tensor']:
"""
Evenly sample N elements from input_data. Supports list, numpy array, or torch tensor.
The input_data can be empty, and n can be less than or equal to 0, in which case it will return empty data.
Args:
input_data: List, numpy array, or torch tensor to sample from.
n: Number of elements to sample.
Returns:
Sampled data in the same type as input_data.
"""
if isinstance(input_data, np.ndarray):
length = input_data.shape[0]
if length == 0 or n <= 0:
return input_data[:0]
indices = np.linspace(0, length - 1, num=n).round().astype(int)
return input_data[indices]
elif torch and isinstance(input_data, torch.Tensor):
length = input_data.size(0)
if length == 0 or n <= 0:
return input_data[:0]
indices = torch.linspace(0, length - 1, steps=n).round().long()
return input_data[indices]
else:
# Treat as a list or other iterable data structure with length.
length = len(input_data)
if length == 0 or n <= 0:
return []
indices = np.linspace(0, length - 1, num=n).round().astype(int)
return [input_data[i] for i in indices]
[docs]
def sample_randomly(
input_data: Union[List[Any], np.ndarray, 'torch.Tensor', Sequence[Any]],
n: int, ordered: bool = False, seed: int = None, put_back: bool = False
) -> Union[np.ndarray, List[Any], 'torch.Tensor']:
"""
Randomly sample N elements from input_data. Supports list, numpy array, or torch tensor.
Args:
input_data: List, numpy array, or torch tensor to sample from.
n: Number of elements to sample.
ordered: Whether to return sampled elements in the original order.
seed: Random seed for reproducibility.
put_back: If True, sample with replacement.
Returns:
Sampled data in the same type as input_data.
"""
assert n > 0, \
f"n must be positive, got {n}"
assert len(input_data) > 0, \
f"input_data must have at least one element, got {len(input_data)}"
if not put_back:
assert n <= len(input_data), \
f"n must be less than or equal to the length of input_data (without replacement), got {n} vs. {len(input_data)}"
rng = np.random.default_rng(seed)
if isinstance(input_data, np.ndarray):
if put_back:
indices = rng.integers(0, len(input_data), size=n)
else:
indices = rng.choice(len(input_data), size=n, replace=False)
if ordered:
indices = np.sort(indices)
return input_data[indices]
elif torch and isinstance(input_data, torch.Tensor):
generator = torch.Generator()
if seed is not None:
generator.manual_seed(seed)
if put_back:
indices = torch.randint(0, input_data.size(0), (n,), generator=generator)
else:
indices = torch.randperm(input_data.size(0), generator=generator)[:n]
if ordered:
indices, _ = torch.sort(indices)
return input_data[indices]
else:
length = len(input_data)
if put_back:
indices = rng.integers(0, length, size=n)
else:
indices = rng.choice(length, size=n, replace=False)
if ordered:
indices = np.sort(indices)
return [input_data[i] for i in indices]
__all__ = [
"sample_evenly",
"sample_randomly"
]