Module julius.lowpass

FIR windowed sinc lowpass filters.

# File under the MIT license, see https://github.com/adefossez/julius/LICENSE for details.
# Author: adefossez, 2020
"""
FIR windowed sinc lowpass filters.
"""

import math
from typing import Sequence, Optional

import torch
from torch.nn import functional as F

from .core import sinc
from .fftconv import fft_conv1d
from .utils import simple_repr


class LowPassFilters(torch.nn.Module):
    """
    Bank of low pass filters. Note that a high pass or band pass filter can easily
    be implemented by subtracting the same signal processed with low pass filters at different
    cutoff frequencies (see `julius.bands.SplitBands` for instance).
    This uses a windowed sinc filter, very similar to the one used in
    `julius.resample`. However, because we do not change the sample rate here,
    this filter can be much more efficiently implemented using the FFT convolution from
    `julius.fftconv`.

    Args:
        cutoffs (list[float]): list of cutoff frequencies, in [0, 0.5], expressed as `f/f_s` where
            `f_s` is the sample rate and `f` is the cutoff frequency.
            The upper limit is 0.5, because a signal sampled at `f_s` contains only
            frequencies under `f_s / 2`.
        stride (int): how much to decimate the output. Keep in mind that decimation
            of the output is only acceptable if the cutoff frequency is under `1/ (2 * stride)`
            of the original sampling rate.
        pad (bool): if True, appropriately pad the input at the edges (by replicating the
            edge values). If `stride=1`, the output will have the same length as the input.
        zeros (float): Number of zero crossings to keep.
            Controls the receptive field of the Finite Impulse Response filter.
            For lowpass filters with a low cutoff frequency, e.g. 40 Hz at 44.1 kHz,
            it is a bad idea to set this to a high value.
            The default of 8 is likely appropriate for most uses. Lower values
            will result in a faster filter, but with a slower attenuation around the
            cutoff frequency.
        fft (bool or None): if True, uses `julius.fftconv` rather than PyTorch convolutions.
            If False, uses PyTorch convolutions. If None, either one will be chosen automatically
            depending on the effective filter size.


    .. warning::
        All the filters will use the same filter size, aligned on the lowest
        frequency provided. If you combine a lot of filters with very diverse frequencies, it might
        be more efficient to split them over multiple modules with similar frequencies.

    .. note::
        A lowpass with a cutoff frequency of 0 is defined as the null function
        by convention here. This allows for a highpass with a cutoff of 0 to
        be equal to identity, as defined in `julius.filters.HighPassFilters`.

    Shape:

        - Input: `[*, T]`
        - Output: `[F, *, T']`, with `T'=T` if `pad` is True and `stride` is 1, and
            `F` is the number of cutoff frequencies.

    >>> lowpass = LowPassFilters([1/4])
    >>> x = torch.randn(4, 12, 21, 1024)
    >>> list(lowpass(x).shape)
    [1, 4, 12, 21, 1024]
    """

    def __init__(self, cutoffs: Sequence[float], stride: int = 1, pad: bool = True,
                 zeros: float = 8, fft: Optional[bool] = None):
        super().__init__()
        self.cutoffs = list(cutoffs)
        if min(self.cutoffs) < 0:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if max(self.cutoffs) > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")
        self.stride = stride
        self.pad = pad
        self.zeros = zeros
        self.half_size = int(zeros / min([c for c in self.cutoffs if c > 0]) / 2)
        if fft is None:
            fft = self.half_size > 32
        self.fft = fft
        window = torch.hann_window(2 * self.half_size + 1, periodic=False)
        time = torch.arange(-self.half_size, self.half_size + 1)
        filters = []
        for cutoff in cutoffs:
            if cutoff == 0:
                filter_ = torch.zeros_like(time)
            else:
                filter_ = 2 * cutoff * window * sinc(2 * cutoff * math.pi * time)
                # Normalize filter to have sum = 1, otherwise we will have a small leakage
                # of the constant component in the input signal.
                filter_ /= filter_.sum()
            filters.append(filter_)
        self.register_buffer("filters", torch.stack(filters)[:, None])

    def forward(self, input):
        shape = list(input.shape)
        input = input.view(-1, 1, shape[-1])
        if self.pad:
            input = F.pad(input, (self.half_size, self.half_size), mode='replicate')
        if self.fft:
            out = fft_conv1d(input, self.filters, stride=self.stride)
        else:
            out = F.conv1d(input, self.filters, stride=self.stride)
        shape.insert(0, len(self.cutoffs))
        shape[-1] = out.shape[-1]
        return out.permute(1, 0, 2).reshape(shape)

    def __repr__(self):
        return simple_repr(self)


class LowPassFilter(torch.nn.Module):
    """
    Same as `LowPassFilters` but applies a single low pass filter.

    Shape:

        - Input: `[*, T]`
        - Output: `[*, T']`, with `T'=T` if `pad` is True and `stride` is 1.

    >>> lowpass = LowPassFilter(1/4, stride=2)
    >>> x = torch.randn(4, 124)
    >>> list(lowpass(x).shape)
    [4, 62]
    """

    def __init__(self, cutoff: float, stride: int = 1, pad: bool = True,
                 zeros: float = 8, fft: Optional[bool] = None):
        super().__init__()
        self._lowpasses = LowPassFilters([cutoff], stride, pad, zeros, fft)

    @property
    def cutoff(self):
        return self._lowpasses.cutoffs[0]

    @property
    def stride(self):
        return self._lowpasses.stride

    @property
    def pad(self):
        return self._lowpasses.pad

    @property
    def zeros(self):
        return self._lowpasses.zeros

    @property
    def fft(self):
        return self._lowpasses.fft

    def forward(self, input):
        return self._lowpasses(input)[0]

    def __repr__(self):
        return simple_repr(self)


def lowpass_filters(input: torch.Tensor,  cutoffs: Sequence[float],
                    stride: int = 1, pad: bool = True,
                    zeros: float = 8, fft: Optional[bool] = None):
    """
    Functional version of `LowPassFilters`, refer to this class for more information.
    """
    return LowPassFilters(cutoffs, stride, pad, zeros, fft).to(input)(input)


def lowpass_filter(input: torch.Tensor,  cutoff: float,
                   stride: int = 1, pad: bool = True,
                   zeros: float = 8, fft: Optional[bool] = None):
    """
    Same as `lowpass_filters` but with a single cutoff frequency.
    Output will not have a dimension inserted in the front.
    """
    return lowpass_filters(input, [cutoff], stride, pad, zeros, fft)[0]

Functions

def lowpass_filter(input: torch.Tensor, cutoff: float, stride: int = 1, pad: bool = True, zeros: float = 8, fft: Optional[bool] = None)

Same as lowpass_filters() but with a single cutoff frequency. Output will not have a dimension inserted in the front.
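For example, a minimal usage sketch (the 16 kHz sample rate and 3 kHz cutoff are arbitrary illustrative values) showing how an absolute cutoff in Hz maps to the normalized f/f_s convention used throughout this module:

import torch
from julius.lowpass import lowpass_filter

sample_rate = 16_000              # illustrative value
cutoff_hz = 3_000                 # illustrative value
x = torch.randn(2, sample_rate)   # [batch, T]

# Cutoffs are normalized as f / f_s, so 3 kHz at 16 kHz becomes 0.1875.
y = lowpass_filter(x, cutoff_hz / sample_rate)
print(list(y.shape))              # [2, 16000] with the default pad=True, stride=1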

def lowpass_filters(input: torch.Tensor, cutoffs: Sequence[float], stride: int = 1, pad: bool = True, zeros: float = 8, fft: Optional[bool] = None)

Functional version of LowPassFilters, refer to this class for more information.
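As the class docstring notes, a band can be selected by differencing lowpass outputs. A small sketch (the cutoff values are arbitrary illustrative choices):

import torch
from julius.lowpass import lowpass_filters

x = torch.randn(1, 8000)
# Subtracting the lower-cutoff output from the higher-cutoff one keeps
# roughly the band between the two cutoffs.
lows = lowpass_filters(x, [0.05, 0.2])   # shape [2, 1, 8000]
band = lows[1] - lows[0]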


Classes

class LowPassFilter (cutoff: float, stride: int = 1, pad: bool = True, zeros: float = 8, fft: Optional[bool] = None)

Same as LowPassFilters but applies a single low pass filter.

Shape

  • Input: [*, T]
  • Output: [*, T'], with T'=T if pad is True and stride is 1.
>>> lowpass = LowPassFilter(1/4, stride=2)
>>> x = torch.randn(4, 124)
>>> list(lowpass(x).shape)
[4, 62]
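The stride argument decimates the output, which is only safe when the normalized cutoff stays under 1 / (2 * stride). A sketch with illustrative numbers (stride=4, so the cutoff should stay below 0.125):

import torch
from julius.lowpass import LowPassFilter

# With stride=4 the output is decimated 4x; a cutoff of 0.1 < 1 / (2 * 4) avoids aliasing.
lowpass = LowPassFilter(0.1, stride=4)
x = torch.randn(3, 1024)
print(list(lowpass(x).shape))  # [3, 256]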


Ancestors

  • torch.nn.modules.module.Module


Instance variables

var cutoff
var fft
var pad
var stride
var zeros

These read-only properties mirror the settings of the underlying LowPassFilters instance.

Methods

def forward(self, input)

Applies the low pass filter to input; equivalent to calling the underlying LowPassFilters and taking its single output.

class LowPassFilters (cutoffs: Sequence[float], stride: int = 1, pad: bool = True, zeros: float = 8, fft: Optional[bool] = None)

Bank of low pass filters. Note that a high pass or band pass filter can easily be implemented by subtracting the same signal processed with low pass filters at different cutoff frequencies (see SplitBands for instance). This uses a windowed sinc filter, very similar to the one used in julius.resample. However, because we do not change the sample rate here, this filter can be much more efficiently implemented using the FFT convolution from julius.fftconv.

Args

cutoffs : list[float]
list of cutoff frequencies, in [0, 0.5], expressed as f/f_s where f_s is the sample rate and f is the cutoff frequency. The upper limit is 0.5, because a signal sampled at f_s contains only frequencies under f_s / 2.
stride : int
how much to decimate the output. Keep in mind that decimation of the output is only acceptable if the cutoff frequency is under 1/ (2 * stride) of the original sampling rate.
pad : bool
if True, appropriately pad the input at the edges (by replicating the edge values). If stride=1, the output will have the same length as the input.
zeros : float
Number of zero crossings to keep. Controls the receptive field of the Finite Impulse Response filter. For lowpass filters with a low cutoff frequency, e.g. 40 Hz at 44.1 kHz, it is a bad idea to set this to a high value. The default of 8 is likely appropriate for most uses. Lower values will result in a faster filter, but with a slower attenuation around the cutoff frequency. See the sketch after this argument list for how zeros translates into an effective filter size.
fft : bool or None
if True, uses julius.fftconv rather than PyTorch convolutions. If False, uses PyTorch convolutions. If None, either one will be chosen automatically depending on the effective filter size.
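The following sketch (with arbitrary example cutoffs) reproduces the filter-size arithmetic from the constructor, showing how zeros and the lowest cutoff determine the effective filter length and the automatic FFT choice when fft=None:

zeros = 8
cutoffs = [0.05, 0.2]                     # normalized cutoffs, example values
half_size = int(zeros / min(c for c in cutoffs if c > 0) / 2)
filter_length = 2 * half_size + 1         # 161 taps here, driven by the lowest cutoff
use_fft = half_size > 32                  # automatic choice when fft=None
print(filter_length, use_fft)             # 161 True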

Warning

All the filters will use the same filter size, aligned on the lowest frequency provided. If you combine a lot of filters with very diverse frequencies, it might be more efficient to split them over multiple modules with similar frequencies.

Note

A lowpass with a cutoff frequency of 0 is defined as the null function by convention here. This allows for a highpass with a cutoff of 0 to be equal to identity, as defined in HighPassFilters.

Shape

  • Input: [*, T]
  • Output: [F, *, T'], with T'=T if pad is True and stride is 1, and F is the number of cutoff frequencies.
>>> lowpass = LowPassFilters([1/4])
>>> x = torch.randn(4, 12, 21, 1024)
>>> list(lowpass(x).shape)
[1, 4, 12, 21, 1024]
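Because each filter is normalized to sum to 1 (see the comment in the constructor), a constant (DC) signal passes through essentially unchanged. A quick sketch checking this (the cutoff value is arbitrary):

import torch
from julius.lowpass import LowPassFilters

lowpass = LowPassFilters([0.1])
x = torch.full((1, 256), 0.5)
y = lowpass(x)[0]
print(torch.allclose(y, x, atol=1e-4))  # expected: True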


Ancestors

  • torch.nn.modules.module.Module


Methods

def forward(self, input)

Applies the bank of low pass filters to input, prepending a dimension of size F (the number of cutoff frequencies) to the output shape.