Module julius.bands
Decomposition of a signal over frequency bands in the waveform domain.
Functions
def split_bands(signal: torch.Tensor, sample_rate: float, n_bands: Optional[int] = None, cutoffs: Optional[Sequence[float]] = None, pad: bool = True, zeros: float = 8, fft: Optional[bool] = None)
Functional version of SplitBands; refer to this class for more information.
>>> x = torch.randn(6, 4, 1024)
>>> list(split_bands(x, sample_rate=64, cutoffs=[12, 24]).shape)
[3, 6, 4, 1024]
Classes
class SplitBands (sample_rate: float, n_bands: Optional[int] = None, cutoffs: Optional[Sequence[float]] = None, pad: bool = True, zeros: float = 8, fft: Optional[bool] = None)
Decomposes a signal over the given frequency bands in the waveform domain using a cascade of low-pass filters as implemented by LowPassFilters. You can either specify the frequency cutoffs explicitly, or just the number of bands, in which case the cutoff frequencies are spaced evenly on the mel scale.
Args
sample_rate (float): Sample rate of the input signal in Hz.
n_bands (int or None): Number of bands, when the cutoffs are not given explicitly with cutoffs. In that case, the cutoff frequencies are evenly spaced on the mel scale. Exactly one of n_bands and cutoffs must be provided.
cutoffs (list[float] or None): List of frequency cutoffs in Hz.
pad (bool): If True, pad the input with zeros at the edges. Since SplitBands always filters with stride 1, the output then has the same length as the input.
zeros (float): Number of zero crossings to keep. See LowPassFilters for more information.
fft (bool or None): See LowPassFilters for more information.
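When only n_bands is given, the class source computes the cutoffs as mel_frequencies(n_bands + 1, 0, sample_rate / 2)[1:-1]: n_bands + 1 points evenly spaced in mel between 0 Hz and Nyquist, with both endpoints dropped. The sketch below reimplements that spacing in plain Python, assuming the common HTK-style formula mel = 2595 * log10(1 + f / 700); hz_to_mel, mel_to_hz and mel_cutoffs are hypothetical helpers, not part of julius.

```python
import math
from typing import List


def hz_to_mel(f: float) -> float:
    # HTK-style mel scale (assumed; julius may differ in details)
    return 2595.0 * math.log10(1.0 + f / 700.0)


def mel_to_hz(m: float) -> float:
    # Inverse of hz_to_mel
    return 700.0 * (10.0 ** (m / 2595.0) - 1.0)


def mel_cutoffs(n_bands: int, sample_rate: float) -> List[float]:
    """Hypothetical helper: n_bands - 1 cutoffs in Hz, evenly spaced in mel."""
    if n_bands < 1:
        raise ValueError(f"n_bands must be at least 1 (got {n_bands})")
    low, high = hz_to_mel(0.0), hz_to_mel(sample_rate / 2)
    n_points = n_bands + 1
    # n_bands + 1 points evenly spaced in mel, from 0 Hz up to Nyquist
    mels = [low + (high - low) * i / (n_points - 1) for i in range(n_points)]
    # Drop both endpoints: 0 Hz and Nyquist are not cutoffs
    return [mel_to_hz(m) for m in mels[1:-1]]


print(mel_cutoffs(4, 16000))
```

With n_bands=1 the list is empty, which corresponds to the case where SplitBands builds no filters and returns the input as a single band.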
Note
The sum of all the bands will always be the input signal.
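Why the sum is exact: let x be the input, K = len(cutoffs), and \ell_k the output of the low-pass filter for the k-th cutoff (k = 0, ..., K-1); these symbols are introduced here for illustration. Following forward in the class source, the bands b_0, ..., b_K are built as differences of consecutive low-pass outputs, so their sum telescopes:

```latex
\begin{aligned}
b_0 &= \ell_0, \qquad b_k = \ell_k - \ell_{k-1} \ (1 \le k \le K-1), \qquad b_K = x - \ell_{K-1},\\
\sum_{k=0}^{K} b_k &= \ell_0 + \sum_{k=1}^{K-1} (\ell_k - \ell_{k-1}) + (x - \ell_{K-1}) = \ell_{K-1} + x - \ell_{K-1} = x.
\end{aligned}
```

This holds for any choice of low-pass filters, up to floating-point rounding, and gives K + 1 = len(cutoffs) + 1 bands. With K = 0, the module returns the input unchanged as a single band.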
Warning
Unlike LowPassFilters, which expects cutoffs as fractions of the sample rate, SplitBands takes the cutoff frequencies in Hz, together with the sample rate.
Shape
- Input: [*, T]
- Output: [B, *, T'], with T' = T if pad is True. If n_bands was provided, B = n_bands; otherwise B = len(cutoffs) + 1.
>>> bands = SplitBands(sample_rate=128, n_bands=10)
>>> x = torch.randn(6, 4, 1024)
>>> list(bands(x).shape)
[10, 6, 4, 1024]
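The constructor's argument handling in the class source boils down to a few rules: exactly one of n_bands and cutoffs, n_bands at least 1, no cutoff above sample_rate / 2, and cutoffs divided by sample_rate before being handed to LowPassFilters. The sketch below restates those rules in plain Python so that the band count B from the Shape section can be worked out without building any filters; plan_bands is a hypothetical helper, not part of julius.

```python
from typing import List, Optional, Sequence, Tuple


def plan_bands(sample_rate: float,
               n_bands: Optional[int] = None,
               cutoffs: Optional[Sequence[float]] = None,
               ) -> Tuple[int, Optional[List[float]]]:
    """Hypothetical helper: returns (B, cutoffs as fractions of sample_rate or None)."""
    # Exactly one of n_bands / cutoffs must be given
    if (cutoffs is None) == (n_bands is None):
        raise ValueError("Provide either n_bands or cutoffs, but not both.")
    if cutoffs is None:
        if n_bands < 1:
            raise ValueError(f"n_bands must be at least 1 (got {n_bands})")
        # Mel-spaced cutoffs would be derived later; only the count matters here
        return n_bands, None
    if any(c > 0.5 * sample_rate for c in cutoffs):
        raise ValueError("A cutoff above sample_rate/2 does not make sense.")
    # LowPassFilters expects cutoffs as fractions of the sample rate
    normalized = [c / sample_rate for c in cutoffs]
    return len(cutoffs) + 1, normalized


print(plan_bands(64, cutoffs=[12, 24]))  # B = 3, as in the split_bands example
print(plan_bands(128, n_bands=10))       # B = 10, as in the SplitBands example
```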
class SplitBands(torch.nn.Module):
    """
    Decomposes a signal over the given frequency bands in the waveform domain using
    a cascade of low pass filters as implemented by `julius.lowpass.LowPassFilters`.
    You can either specify explicitly the frequency cutoffs, or just the number of bands,
    in which case the frequency cutoffs will be spread out evenly in mel scale.

    Args:
        sample_rate (float): Sample rate of the input signal in Hz.
        n_bands (int or None): number of bands, when not giving them explicitly with `cutoffs`.
            In that case, the cutoff frequencies will be evenly spaced in mel-space.
        cutoffs (list[float] or None): list of frequency cutoffs in Hz.
        pad (bool): if True, appropriately pad the input with zero over the edge. If `stride=1`,
            the output will have the same length as the input.
        zeros (float): Number of zero crossings to keep. See `LowPassFilters` for more information.
        fft (bool or None): See `LowPassFilters` for more info.

    ..note::
        The sum of all the bands will always be the input signal.

    ..warning::
        Unlike `julius.lowpass.LowPassFilters`, the cutoff frequencies must be provided in Hz along
        with the sample rate.

    Shape:

        - Input: `[*, T]`
        - Output: `[B, *, T']`, with `T'=T` if `pad` is True.
            If `n_bands` was provided, `B = n_bands` otherwise `B = len(cutoffs) + 1`

    >>> bands = SplitBands(sample_rate=128, n_bands=10)
    >>> x = torch.randn(6, 4, 1024)
    >>> list(bands(x).shape)
    [10, 6, 4, 1024]
    """

    def __init__(self, sample_rate: float, n_bands: Optional[int] = None,
                 cutoffs: Optional[Sequence[float]] = None, pad: bool = True,
                 zeros: float = 8, fft: Optional[bool] = None):
        super().__init__()
        if (cutoffs is None) + (n_bands is None) != 1:
            raise ValueError("You must provide either n_bands, or cutoffs, but not both.")

        self.sample_rate = sample_rate
        self.n_bands = n_bands
        self._cutoffs = list(cutoffs) if cutoffs is not None else None
        self.pad = pad
        self.zeros = zeros
        self.fft = fft

        if cutoffs is None:
            if n_bands is None:
                raise ValueError("You must provide one of n_bands or cutoffs.")
            if not n_bands >= 1:
                raise ValueError(f"n_bands must be at least one (got {n_bands})")
            cutoffs = mel_frequencies(n_bands + 1, 0, sample_rate / 2)[1:-1]
        else:
            # Skip the check for an empty list, on which max() would raise
            if len(cutoffs) > 0 and max(cutoffs) > 0.5 * sample_rate:
                raise ValueError("A cutoff above sample_rate/2 does not make sense.")
        if len(cutoffs) > 0:
            self.lowpass = LowPassFilters(
                [c / sample_rate for c in cutoffs], pad=pad, zeros=zeros, fft=fft)
        else:
            # Here I cannot make both TorchScript and MyPy happy.
            # I miss the good old times, before all this madness was created.
            self.lowpass = None  # type: ignore

    def forward(self, input):
        if self.lowpass is None:
            return input[None]
        lows = self.lowpass(input)
        low = lows[0]
        bands = [low]
        for low_and_band in lows[1:]:
            # Get a bandpass filter by subtracting lowpasses
            band = low_and_band - low
            bands.append(band)
            low = low_and_band
        # Last band is whatever is left in the signal
        bands.append(input - low)
        return torch.stack(bands)

    @property
    def cutoffs(self):
        if self._cutoffs is not None:
            return self._cutoffs
        elif self.lowpass is not None:
            return [c * self.sample_rate for c in self.lowpass.cutoffs]
        else:
            return []

    def __repr__(self):
        return simple_repr(self, overrides={"cutoffs": self._cutoffs})
Ancestors
- torch.nn.modules.module.Module
Class variables
var call_super_init : bool
var dump_patches : bool
var training : bool
Instance variables
prop cutoffs
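Cutoff frequencies in Hz: the cutoffs passed to the constructor if any, otherwise the mel-spaced cutoffs derived from n_bands, or an empty list when there is a single band.
@property
def cutoffs(self):
    if self._cutoffs is not None:
        return self._cutoffs
    elif self.lowpass is not None:
        return [c * self.sample_rate for c in self.lowpass.cutoffs]
    else:
        return []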
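The fallback order in cutoffs can be restated in plain Python. In this sketch, cutoffs_hz is a hypothetical helper whose stored and normalized arguments stand in for the module's _cutoffs attribute and for lowpass.cutoffs (fractions of the sample rate); it only mirrors the property's logic.

```python
from typing import List, Optional, Sequence


def cutoffs_hz(stored: Optional[Sequence[float]],
               normalized: Optional[Sequence[float]],
               sample_rate: float) -> List[float]:
    """Hypothetical mirror of SplitBands.cutoffs."""
    if stored is not None:
        # Cutoffs passed by the user are returned as given, already in Hz
        return list(stored)
    if normalized is not None:
        # Cutoffs derived from n_bands: scale the fractions back to Hz
        return [c * sample_rate for c in normalized]
    # A single band: no filters, hence no cutoffs
    return []


print(cutoffs_hz(None, [0.25], 64))
```

Returning the stored list first means user-given values come back exactly as passed, without a divide-then-multiply round trip through the sample rate.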
Methods
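def forward(self, input)
Split input into frequency bands. Returns a tensor of shape [B, *, T'] with the bands stacked along a new first dimension: band 0 is the output of the first low-pass filter, each following band is the difference between consecutive low-pass outputs, and the last band is what remains of the input after subtracting the last low-pass output. With no cutoffs (a single band), the input is returned with a new leading dimension.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.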
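The band construction in forward can be reproduced without torch. This is a sketch under stated assumptions: split_bands_sketch and moving_average are hypothetical names, and the centered moving average is only a stand-in low-pass filter (julius uses windowed-sinc filters through LowPassFilters). Because each band is a difference of consecutive low-pass outputs, the bands sum back to the input whatever filters are used.

```python
import random
from typing import Callable, List, Sequence

Signal = List[float]


def moving_average(x: Sequence[float], width: int) -> Signal:
    """Stand-in low-pass filter (NOT julius's windowed sinc): centered
    moving average with zero padding, so the output has len(x) samples."""
    half = width // 2
    out = []
    for i in range(len(x)):
        lo, hi = i - half, i - half + width
        total = sum(x[j] for j in range(max(lo, 0), min(hi, len(x))))
        out.append(total / width)
    return out


def split_bands_sketch(x: Sequence[float],
                       lowpasses: Sequence[Callable[[Sequence[float]], Signal]],
                       ) -> List[Signal]:
    """Mirror of SplitBands.forward; lowpasses ordered by increasing cutoff."""
    if not lowpasses:
        return [list(x)]  # a single band: the input itself
    lows = [lp(x) for lp in lowpasses]
    low = lows[0]
    bands = [low]  # lowest band: output of the first low-pass
    for low_and_band in lows[1:]:
        # Band-pass = difference of two consecutive low-passes
        bands.append([a - b for a, b in zip(low_and_band, low)])
        low = low_and_band
    # Last band: whatever the last low-pass did not keep
    bands.append([a - b for a, b in zip(x, low)])
    return bands


random.seed(0)
x = [random.gauss(0.0, 1.0) for _ in range(64)]
# Wider window = lower cutoff, so decreasing widths mean increasing cutoffs
filters = [lambda s, w=w: moving_average(s, w) for w in (9, 5, 3)]
bands = split_bands_sketch(x, filters)
recon = [sum(vals) for vals in zip(*bands)]
print(len(bands), max(abs(r - v) for r, v in zip(recon, x)))
```

Three filters give four bands, matching B = len(cutoffs) + 1, and the printed reconstruction error is at the level of floating-point rounding.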