Source code for tinycio.util.miscutil

from __future__ import annotations
import typing
from typing import Union

import torch
import numpy as np
from tqdm import tqdm

from ..globals import TINYCIO_VERSION

def version() -> str:
    """
    Report the installed tinycio version.

    :return: the version string stored in ``TINYCIO_VERSION`` ("major.minor.patch")
    """
    current = TINYCIO_VERSION
    return current
def version_check_minor(ver_str: str) -> bool:
    """
    Verify tinycio version compatibility.

    Check whether the major and minor components of `ver_str` match the
    current tinycio version. The patch component is ignored.

    :param ver_str: Version string to compare ("major.minor[.patch]")
    :return: True if major.minor match, else False (including when `ver_str`
        has fewer than two dot-separated components)
    """
    current = TINYCIO_VERSION.split('.')[:2]
    candidate = ver_str.split('.')[:2]
    # Slicing instead of indexing [0]/[1] avoids an IndexError on short strings
    # such as "1"; a string without a minor component simply does not match.
    return len(candidate) == 2 and candidate == current
def remap(x: Union[float, torch.Tensor], from_start: float, from_end: float, to_start: float, to_end: float) -> Union[float, torch.Tensor]:
    """
    Linearly remap scalar or tensor.

    Maps `from_start` to `to_start` and `from_end` to `to_end`, then clamps the
    result to the target interval. Either range may be descending
    (e.g. remapping [0, 1] onto [1, 0] inverts the input).

    :param x: Input value or tensor
    :param from_start: Start of input range (must differ from `from_end`)
    :param from_end: End of input range
    :param to_start: Start of target range
    :param to_end: End of target range
    :return: Remapped value clamped to the interval spanned by `to_start` and `to_end`
    """
    res = (x - from_start) / (from_end - from_start) * (to_end - to_start) + to_start
    # Clamp with ordered bounds: torch.clamp / np.clip given min > max would
    # collapse every output to `to_end` for a descending target range.
    lo, hi = min(to_start, to_end), max(to_start, to_end)
    return torch.clamp(res, lo, hi) if torch.is_tensor(res) else np.clip(res, lo, hi)
def remap_to_01(x: Union[float, torch.Tensor], start: float, end: float) -> Union[float, torch.Tensor]:
    """
    Normalize a value from [start, end] into the unit interval.

    :param x: Input value or tensor
    :param start: Value that maps to 0
    :param end: Value that maps to 1
    :return: Normalized value clamped to [0, 1]
    """
    normalized = (x - start) / (end - start)
    if not torch.is_tensor(normalized):
        return np.clip(normalized, 0., 1.)
    return torch.clamp(normalized, 0., 1.)
def remap_from_01(x: Union[float, torch.Tensor], start: float, end: float) -> Union[float, torch.Tensor]:
    """
    Remap [0, 1] value back to specified range.

    0 maps to `start` and 1 maps to `end`. The range may be descending
    (start > end).

    :param x: Normalized value or tensor
    :param start: Target range start
    :param end: Target range end
    :return: Rescaled value clamped to the interval spanned by `start` and `end`
    """
    res = x * (end - start) + start
    # Order the bounds so a descending range does not collapse every output to `end`.
    lo, hi = min(start, end), max(start, end)
    return torch.clamp(res, lo, hi) if torch.is_tensor(res) else np.clip(res, lo, hi)
def smoothstep(edge0: float, edge1: float, x: torch.Tensor) -> torch.Tensor:
    """
    Smooth Hermite step from 0 to 1 as `x` moves from `edge0` to `edge1`.

    :param edge0: Input value where the output starts rising from 0
    :param edge1: Input value where the output reaches 1
    :param x: Input tensor
    :return: 3t^2 - 2t^3, where t is `x` normalized to [0, 1] and clamped
    """
    span = edge1 - edge0
    t = torch.clamp((x - edge0) / span, 0.0, 1.0)
    hermite = t * t * (3 - 2 * t)
    return hermite
def softsign(x: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
    """
    Softsign nonlinearity, x / (1 + abs(x)); compresses any input into (-1, 1).

    :param x: Input scalar or tensor
    :return: Softsign result
    """
    if torch.is_tensor(x):
        return x / (1 + x.abs())
    return x / (1 + np.abs(x))
def fract(x: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
    """
    Return the fractional part of the input, x - floor(x).

    Negative inputs yield values in [0, 1) (e.g. -1.25 -> 0.75).

    :param x: Input scalar or tensor
    :return: Fractional part
    """
    floor_fn = torch.floor if torch.is_tensor(x) else np.floor
    return x - floor_fn(x)
def serialize_tensor(val: torch.Tensor) -> Union[float, typing.List[float]]:
    """
    Convert a tensor into a plain Python scalar or list.

    :param val: Tensor to serialize
    :return: Python scalar (via ``item()``) if `val` has exactly one element,
        otherwise its flattened elements as a list (empty for an empty tensor)
    """
    # The annotation uses ``typing.List`` so it resolves through the module's
    # ``import typing``; a bare ``List`` is not in scope for get_type_hints().
    if val.numel() == 1:
        return val.item()
    return val.flatten().tolist()
def trilinear_interpolation(im_3d: torch.Tensor, indices: Union[ColorImage, torch.Tensor]) -> torch.Tensor:
    """
    Trilinearly interpolate a 3D grid of vectors (internal helper).

    :param im_3d: 4D tensor of shape (D, H, W, C). The first three dimensions are
        the grid axes indexed by `indices`; the last holds the stored values.
    :param indices: Fractional grid coordinates of shape (3, H', W'). Entry 0
        indexes dimension 0 of `im_3d`, entry 1 dimension 1 and entry 2 dimension 2.
        Coordinates are assumed to be non-negative and within the grid.
    :return: Interpolated values of shape (C, H', W').
    """
    import itertools

    # Lower/upper integer neighbours and fractional offset along each grid axis.
    # The upper neighbour is clamped to that axis's own size, so coordinates on
    # the far edge stay in bounds even for non-cubic grids.
    lower = [indices[axis].floor().to(torch.long) for axis in range(3)]
    upper = [indices[axis].ceil().clamp(0, im_3d.size(axis) - 1).to(torch.long) for axis in range(3)]
    frac = [(indices[axis] - lower[axis]).float() for axis in range(3)]

    # Accumulate the 8 corners in c000, c001, ..., c111 order. Each corner value is
    # (H', W', C) after gathering and is permuted to (C, H', W'). The (H', W')
    # weights broadcast over the channel dimension.
    result = None
    for corner in itertools.product((0, 1), repeat=3):
        idx = tuple(upper[a] if bit else lower[a] for a, bit in enumerate(corner))
        weight = None
        for a, bit in enumerate(corner):
            factor = frac[a] if bit else 1 - frac[a]
            weight = factor if weight is None else weight * factor
        term = weight * im_3d[idx].permute(2, 0, 1)
        result = term if result is None else result + term
    return result
class _ProgressBar(tqdm):
    """
    tqdm progress bar driven by fitting progress.

    Provides ``update_fit_status(batches_done, steps_per_batch, steps_total, loss)``,
    which moves the bar to an absolute step count using ``tqdm.update(delta_n)``.
    """

    def update_fit_status(self, batches_done=1, steps_per_batch=1, steps_total=None, loss=''):
        """
        Set the bar to ``batches_done * steps_per_batch`` completed steps.

        :param batches_done: Number of batches completed so far (absolute, not a delta)
        :param steps_per_batch: Steps counted per batch
        :param steps_total: If given, replaces the bar's total step count
        :param loss: Loss shown in the description. Numeric values are formatted
            with 5 decimals; strings (including the default '') are shown verbatim.
        :return: Result of ``tqdm.update``
        """
        if steps_total is not None:
            self.total = steps_total
        # '{:0.5f}' raises ValueError for str, which made the default loss='' crash,
        # so only non-string losses go through the fixed-point format.
        if isinstance(loss, str):
            loss_text = loss
        else:
            loss_text = '{:0.5f}'.format(loss)
        self.set_description('Loss: ' + loss_text)
        # tqdm.update expects an increment; convert the absolute position to a delta.
        return self.update(batches_done * steps_per_batch - self.n)
def progress_bar():
    """
    Create the progress bar shown while fitting.

    The result is a tqdm instance, so it can be used in a ``with`` block.

    :return: ``_ProgressBar`` labelled "Fitting" that counts in " steps" units
    """
    settings = {
        'unit': ' steps',
        'unit_scale': True,
        'unit_divisor': 1024,
        'miniters': 1,
        'desc': "Fitting",
    }
    return _ProgressBar(**settings)