Module julius.fftconv
Implementation of an FFT-based 1D convolution in PyTorch. While cuDNN uses the FFT for small kernel sizes, it does not for long ones, e.g. 512. This module implements efficient FFT-based convolutions for such long kernels. A typical application is evaluating FIR filters with a long receptive field, usually applied with a stride of 1.
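The core idea can be sketched in a few lines with `torch.fft`: transform the input and the zero-padded kernel, multiply in the frequency domain, and transform back. This is a simplified, hypothetical helper (`naive_fft_conv1d`), not this module's code. It assumes stride 1, no padding and no bias, and it processes the whole signal with a single FFT rather than in chunks. Because `torch.nn.functional.conv1d` actually computes a cross-correlation, the kernel spectrum is conjugated.

```python
import torch
import torch.nn.functional as F


def naive_fft_conv1d(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: FFT equivalent of F.conv1d(x, w) (stride 1, no padding, no bias).

    x: [B, C, T] input, w: [D, C, K] kernel with K <= T.
    Returns a tensor of shape [B, D, T - K + 1].
    """
    T, K = x.shape[-1], w.shape[-1]
    X = torch.fft.rfft(x, n=T)  # [B, C, T // 2 + 1]
    W = torch.fft.rfft(w, n=T)  # kernel zero-padded to length T: [D, C, T // 2 + 1]
    # Cross-correlation in time = product with the conjugate spectrum;
    # the einsum also sums over the input channels C.
    Y = torch.einsum("bcf,dcf->bdf", X, W.conj())
    y = torch.fft.irfft(Y, n=T)  # circular cross-correlation, [B, D, T]
    # Lags 0 .. T - K never wrap around the end of the signal, so they are exact.
    return y[..., : T - K + 1]


x = torch.randn(2, 3, 1000, dtype=torch.float64)
w = torch.randn(4, 3, 300, dtype=torch.float64)
print(torch.allclose(naive_fft_conv1d(x, w), F.conv1d(x, w)))  # True
```

The single FFT here has the length of the whole signal; that is simple but wasteful for long inputs, which is one reason to work on chunks instead.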
Functions
def fft_conv1d(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: int = 1, padding: int = 0, block_ratio: float = 5)
Same as `torch.nn.functional.conv1d`, but using the FFT to compute the convolution. Please check the PyTorch documentation for more information.
Args
- input (Tensor): input signal of shape [B, C, T].
- weight (Tensor): weight of the convolution, of shape [D, C, K], with D the number of output channels.
- bias (Tensor or None): if not None, bias term for the convolution.
- stride (int): stride of the convolution.
- padding (int): padding to apply to the input.
- block_ratio (float): can be tuned for speed. The input is split into chunks of size int(block_ratio * kernel_size).
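The `block_ratio` argument implies that the input is processed in chunks. One standard way to do this is overlap-save, sketched below under the assumption that each chunk holds `int(block_ratio * kernel_size)` samples. The helper `chunked_fft_conv1d` is hypothetical and may differ from what `fft_conv1d` does internally; it again assumes stride 1, no padding and no bias. Chunking ties the FFT length to the kernel size rather than to the signal length.

```python
import torch
import torch.nn.functional as F


def chunked_fft_conv1d(x: torch.Tensor, w: torch.Tensor, block_ratio: float = 5.0) -> torch.Tensor:
    """Hypothetical overlap-save sketch of F.conv1d(x, w) (stride 1, no padding, no bias).

    x: [B, C, T], w: [D, C, K] with K <= T. Each chunk holds int(block_ratio * K)
    samples and yields block - K + 1 exact outputs.
    """
    T, K = x.shape[-1], w.shape[-1]
    block = max(int(block_ratio * K), K)  # FFT size used for every chunk
    hop = block - K + 1                   # exact (non-wrapped) outputs per chunk
    out_len = T - K + 1
    n_chunks = -(-out_len // hop)         # ceiling division
    # Zero-pad on the right so the last chunk is complete; consecutive chunks overlap by K - 1.
    x = F.pad(x, (0, (n_chunks - 1) * hop + block - T))
    frames = x.unfold(2, block, hop)      # [B, C, n_chunks, block]
    X = torch.fft.rfft(frames, n=block)
    W = torch.fft.rfft(w, n=block)        # kernel spectrum, computed once: [D, C, block // 2 + 1]
    Y = torch.einsum("bcnf,dcf->bdnf", X, W.conj())
    y = torch.fft.irfft(Y, n=block)[..., :hop]  # keep the valid lags of each chunk
    y = y.reshape(y.shape[0], y.shape[1], -1)   # stitch the chunks back together
    return y[..., :out_len]


x = torch.randn(1, 2, 1000, dtype=torch.float64)
w = torch.randn(3, 2, 64, dtype=torch.float64)
print(torch.allclose(chunked_fft_conv1d(x, w, block_ratio=5), F.conv1d(x, w)))  # True
```

A larger `block_ratio` wastes fewer samples per chunk (only K - 1 overlap) but makes each FFT longer, which is the trade-off the parameter lets you tune.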
Shape
- Inputs: `input` is [B, C, T], `weight` is [D, C, K] and `bias` is [D].
- Output: [B, D, T_out], where T_out = (T + 2 * padding - K) // stride + 1.
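The output length follows the usual rule for an undilated 1D convolution. The small helper below (hypothetical, named `conv1d_output_length` here) makes the arithmetic concrete and reproduces the 225 samples of the `FFTConv1d` example on this page: (1024 - 128) // 4 + 1 = 225.

```python
def conv1d_output_length(length: int, kernel_size: int, stride: int = 1, padding: int = 0) -> int:
    """Hypothetical helper: output length of an undilated 1D convolution.

    Mirrors the rule used by torch.nn.functional.conv1d: the input is padded by
    `padding` samples on each side, then the kernel is slid with step `stride`.
    """
    return (length + 2 * padding - kernel_size) // stride + 1


print(conv1d_output_length(1024, 128, stride=4))     # 225, as in the FFTConv1d example
print(conv1d_output_length(1024, 512))               # 513: stride 1, no padding
print(conv1d_output_length(1024, 512, padding=256))  # 1025
```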
Note
This function is faster than `torch.nn.functional.conv1d` only in specific cases. Typically, the kernel size should be of the order of 256 to see any real gain, for a stride of 1.
Warning
Dilation and groups are not supported at the moment. This function might use more memory than the default Conv1d implementation.
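Since the crossover point depends on hardware, kernel size and signal length, it is worth measuring on your own setup. The rough, CPU-only harness below compares `torch.nn.functional.conv1d` with a hypothetical single-FFT helper (`fft_conv1d_sketch`); it does not call this module, so substitute `fft_conv1d` to benchmark the real implementation. On a GPU you would also need `torch.cuda.synchronize()` before reading the clock, because CUDA kernels run asynchronously.

```python
import time

import torch
import torch.nn.functional as F


def fft_conv1d_sketch(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Hypothetical single-FFT equivalent of F.conv1d(x, w) (stride 1, no padding)."""
    T, K = x.shape[-1], w.shape[-1]
    Y = torch.einsum("bcf,dcf->bdf", torch.fft.rfft(x, n=T), torch.fft.rfft(w, n=T).conj())
    return torch.fft.irfft(Y, n=T)[..., : T - K + 1]


def best_time(fn, *args, repeats: int = 5) -> float:
    """Fastest wall-clock time in seconds over `repeats` calls (CPU tensors assumed)."""
    best = float("inf")
    with torch.no_grad():
        fn(*args)  # warm-up call, not timed
        for _ in range(repeats):
            start = time.perf_counter()
            fn(*args)
            best = min(best, time.perf_counter() - start)
    return best


def compare(kernel_sizes=(16, 256, 1024), channels=4, length=8000):
    """Map each kernel size to (direct conv1d time, FFT time), in seconds."""
    results = {}
    for K in kernel_sizes:
        x = torch.randn(1, channels, length)
        w = torch.randn(channels, channels, K)
        results[K] = (best_time(F.conv1d, x, w), best_time(fft_conv1d_sketch, x, w))
    return results


if __name__ == "__main__":
    for K, (direct, fft) in compare().items():
        print(f"K={K:5d}  conv1d: {direct * 1e3:7.2f} ms  fft: {fft * 1e3:7.2f} ms")
```

Taking the minimum over several runs reduces noise from other processes; the absolute numbers are machine-specific and only the trend across kernel sizes is meaningful.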
Classes
class FFTConv1d (in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, padding: int = 0, bias: bool = True)
Same as `torch.nn.Conv1d`, but based on `fft_conv1d()`. Please check the PyTorch documentation for more information.
Args
- in_channels (int): number of input channels.
- out_channels (int): number of output channels.
- kernel_size (int): kernel size of the convolution.
- stride (int): stride of the convolution.
- padding (int): padding to apply to the input.
- bias (bool): if True, use a bias term.
Note
This module is faster than `torch.nn.Conv1d` only in specific cases. Typically, `kernel_size` should be of the order of 256 to see any real gain, for a stride of 1.
Warning
Dilation and groups are not supported at the moment. This module might use more memory than the default Conv1d implementation.
>>> fftconv = FFTConv1d(12, 24, 128, 4)
>>> x = torch.randn(4, 12, 1024)
>>> print(list(fftconv(x).shape))
[4, 24, 225]
class FFTConv1d(torch.nn.Module):
    """
    Same as `torch.nn.Conv1d` but based on `fft_conv1d`.
    Please check PyTorch documentation for more information.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        kernel_size (int): kernel size of convolution.
        stride (int): stride of convolution.
        padding (int): padding to apply to the input.
        bias (bool): if True, use a bias term.

    ..note::
        This module is faster than `torch.nn.Conv1d` only in specific cases.
        Typically, `kernel_size` should be of the order of 256 to see any real gain,
        for a stride of 1.

    ..warning::
        Dilation and groups are not supported at the moment. This module might use
        more memory than the default Conv1d implementation.

    >>> fftconv = FFTConv1d(12, 24, 128, 4)
    >>> x = torch.randn(4, 12, 1024)
    >>> print(list(fftconv(x).shape))
    [4, 24, 225]
    """
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int,
                 stride: int = 1, padding: int = 0, bias: bool = True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        conv = torch.nn.Conv1d(in_channels, out_channels, kernel_size, bias=bias)
        self.weight = conv.weight
        self.bias = conv.bias

    def forward(self, input: torch.Tensor):
        return fft_conv1d(
            input, self.weight, self.bias, self.stride, self.padding)

    def __repr__(self):
        return simple_repr(self, overrides={"bias": self.bias is not None})
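`__init__` builds a throwaway `torch.nn.Conv1d` only to take its `weight` and `bias` parameters, so they get the same shapes, default initialization and attribute names as a regular `Conv1d`. One consequence, sketched below with a hypothetical pure-PyTorch stand-in (`BorrowedConv1dParams`, which uses `F.conv1d` instead of the FFT), is that a `Conv1d` state dict can be loaded directly, since both expose the keys `weight` and `bias`.

```python
import torch
import torch.nn.functional as F


class BorrowedConv1dParams(torch.nn.Module):
    """Hypothetical stand-in reusing nn.Conv1d parameters, as FFTConv1d.__init__ does."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int,
                 stride: int = 1, padding: int = 0, bias: bool = True):
        super().__init__()
        self.stride = stride
        self.padding = padding
        # Build a regular Conv1d only to borrow its initialized parameters.
        conv = torch.nn.Conv1d(in_channels, out_channels, kernel_size, bias=bias)
        self.weight = conv.weight  # Parameter of shape [out_channels, in_channels, kernel_size]
        self.bias = conv.bias      # Parameter of shape [out_channels], or None when bias=False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Direct convolution here; FFTConv1d calls fft_conv1d instead.
        return F.conv1d(x, self.weight, self.bias, self.stride, self.padding)


reference = torch.nn.Conv1d(3, 5, 7, stride=2, padding=1)
module = BorrowedConv1dParams(3, 5, 7, stride=2, padding=1)
module.load_state_dict(reference.state_dict())  # keys "weight" and "bias" match
x = torch.randn(2, 3, 50)
print(torch.allclose(module(x), reference(x)))  # True
```

Note that `stride` and `padding` are plain attributes rather than parameters, so they are not part of the state dict and must be passed to the constructor.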
Ancestors
- torch.nn.modules.module.Module
Class variables
var call_super_init : bool
var dump_patches : bool
var training : bool
Methods
def forward(self, input: torch.Tensor)
Apply `fft_conv1d()` to `input` using this module's `weight`, `bias`, `stride` and `padding`.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.