离散傅里叶变换,及其应用

本文不保证读者能都读懂。写本文的主要目的是测试数学公式功能,文字讲解基本没有,请谨慎阅读。

傅里叶变换

连续傅里叶变换是一个 RCRC 的变换。

F[f]=f^(ξ)=f(x) e2πixξdx(ξC)......(1)

它的作用是对于一个时域的复数函数,求出其频谱。

当自变量x表示时间(以秒为单位),变换变量ξ表示频率(以赫兹为单位)。在适当条件下,f^可由逆变换(inverse Fourier transform)由下式确定f:

F1[f^]=f(x)=f^(ξ) e2πiξxdξ(xR)......(2)

它的作用是对于一个频域的复数函数,求出其时域表示。

离散傅里叶变换(DFT)

考虑将上述 ff^ 函数,将它们离散取值。考虑数列 gg^ ,下标x[0,n],则将上(1),(2)式适当变形得

F[g]=g^(n)=1Nk=0N1g(k)e2πiknN......(1,Decrete,Normalized)

F1[g^]=g(k)=1Nn=0N1g^(n)e2πiknN......(2,Decrete,Normalized)

以上两个变换有一种美妙的性质:都符合卷积定理,即可以将某一个域内的卷积对偶于其对应域内的乘法。

为了简化起见,可以将(1,D,N) (2,D,N)两式改变一下常数项,得到非正则化的以下两式,不改变主要性质。

F[g]=g^(n)=k=0N1g(k)e2πiknN......(1,Decrete)

F1[g^]=g(k)=1Nn=0N1g^(n)e2πiknN......(2,Decrete)

为了方便起见,设符合方程 xN=1 的复数 x 称为 N 次单位复根。共有N个,符合下式:

x=WNn=e2πnN

单位复根有以下性质:

WaNan=WNn WNn=(WNn) WNn=WNn+rN,rN


那么,从矩阵角度考察DFT

F=[11111WN1WN2WN(N1)1WN2WN4WN2(N1)1WN(N1)WN2(N1)WN(N1)2]......(1,MatrixForm)

F1=1N[11111WN1WN2WNN11WN2WN4WN2(N1)1WNN1WN2(N1)WN(N1)2]......(2,MatrixForm)

显然,以上两矩阵互逆。

考察矩阵1,Fkn=WNkn,那么

F=FT 即,沿主对角线对称

F(2r)n=F(2r)(n+N2) 即,偶数行的左半部分与右半部分相等

F(2r+1)n=WN2NF(2r+1)(n+N2)=F(2r+1)(n+N2) 即,奇数行左右半部分互为相反数


那么考虑, F[g]=g^(n)=k=0N1g(k)WNkn=k=0N21g(k)WNkn+k=N2N1g(k)WNkn

=k=0N21g(k)WNkn+k=0N21g(k+N2)WNk(n+N2)

g^ 奇偶分开讨论,

g^(2r)=k=0N21(g(k)+g(k+N2))WN2rk

即,原矩阵的这些位置 pic 等于一个N/2规模的DFT矩阵。

g^(2r+1)=k=0N21(g(k)g(k+N2))WNkWN2rk

这样就得到了两个N2点DFT。

不断分治,就得到了一个O(Nlog2N)的做法。

C++代码

typedef complex<double> T;
T dwg(int n, int flag)
{
  return T(cos(pi * 2 / n), flag * sin(pi * 2 / n));
}
void fft(T *a, int m, int flag) // 求2^m点FFT
{
  int nn = 1 << m;
  while (m) {
    int n2 = 1 << (m - 1);
    int n = n2 << 1;
    T wn(dwg(n, -flag));
    for (int j = 0; j < nn; j += n) {
      T w = 1;
      T *l = a + j, *r = a + n2 + j;
      T *rb = a + n + j;
      T t;
      for (; r < rb; ++l, ++r, w *= wn) {
        t = *l - *r;
        *l += *r;
        *r = t * w;
      }
    }
    --m;
  }
  for (int i = 0; i < n; ++i) { // 位反转
    if (R[i] > i) {
      swap(a[i], a[R[i]]);
    }
  }
}

对于正向FFT,取flag=1;对于逆FFT,取flag=1

例题

uoj上的多项式乘法

#include "bits/stdc++.h"
using namespace std;
const int BUF_SIZE = (int)1e6 + 10;

struct fastIO
{
  char buf[BUF_SIZE];
  int cur;
  FILE *in, *out;
  fastIO()
  {
    cur = BUF_SIZE;
    in = stdin;
    out = stdout;
  }
  inline char nextChar()
  {
    if (cur == BUF_SIZE) {
      fread(buf, BUF_SIZE, 1, in);
      cur = 0;
    }
    return buf[cur++];
  }
  inline int nextInt()
  {
    int x = 0;
    char c = nextChar();
    while (!('0' <= c && c <= '9')) c = nextChar();
    while ('0' <= c && c <= '9') {
      x = x * 10 + c - '0';
      c = nextChar();
    }
    return x;
  }
  inline void printChar(char ch)
  {
    if (cur == BUF_SIZE) {
      fwrite(buf, BUF_SIZE, 1, out);
      cur = 0;
    }
    buf[cur++] = ch;
  }
  inline void printInt(int x)
  {
    if (x >= 10) printInt(x / 10);
    printChar(x % 10 + '0');
  }
  inline void close()
  {
    if (cur > 0) {
      fwrite(buf, cur, 1, out);
    }
    cur = 0;
  }
} fin, fout;

const int lg2maxn = 18;
const int maxn = (1 << lg2maxn);
const double pi = acos(-1);
namespace nscpx
{
  typedef double T;
  struct cpx
  {
    T x, y;
    cpx(T a = 0, T b = 0):x(a), y(b) {}
    cpx operator-() const
    {
      return cpx(-x, -y);
    }
    cpx &operator+=(const cpx &a)
    {
      x += a.x;
      y += a.y;
      return *this;
    }
    cpx &operator-=(const cpx &a)
    {
      x -= a.x;
      y -= a.y;
      return *this;
    }
    cpx &operator*=(const cpx &a)
    {
      T nx;
      nx = x * a.x - y * a.y;
      y = y * a.x + x * a.y;
      x = nx;
      return *this;
    }
    cpx operator+(const cpx &a) const
    {
      return cpx(x + a.x, y + a.y);
    }
    cpx operator-(const cpx &a) const
    {
      return cpx(x - a.x, y - a.y);
    }
    cpx operator*(const cpx &a) const
    {
      return cpx(x * a.x - y * a.y, x * a.y + y * a.x);
    }
  };
}
using nscpx::cpx;
int R[maxn];
typedef cpx T;
T dwg(int n, int flag)
{
  return T(cos(pi * 2 / n), flag * sin(pi * 2 / n));
}
void ffft(T *a, int m, int flag)
{
  int nn = 1 << m;
  while (m) {
    int n2 = 1 << (m - 1);
    int n = n2 << 1;
    T wn(dwg(n, -flag));
    for (int j = 0; j < nn; j += n) {
      T w = 1;
      T *l = a + j, *r = a + n2 + j;
      T *rb = a + n + j;
      T t;
      for (; r < rb; ++l, ++r, w *= wn) {
        t = *l - *r;
        *l += *r;
        *r = t * w;
      }
    }
    --m;
  }
}

void rev(T *a, int m)
{
  int n = 1 << m;
  for (int i = 0; i < n; ++i) {
    if (R[i] > i) {
      swap(a[i], a[R[i]]);
    }
  }
}

int n, m;
T a[maxn], b[maxn];
#define lowbit(x) ((x)&-(x))

int main()
{
  n = fin.nextInt();
  m = fin.nextInt();
  int maxnn = m + n + 1;
  while (maxnn != lowbit(maxnn)) maxnn += lowbit(maxnn);
  int shit = 0;
  while (maxnn > (1 << shit)) ++shit;
  assert(maxnn == (1 << shit));
  for (int i = 0; i < maxnn; ++i) {
    R[i] = (R[i >> 1] >> 1) | ((i & 1) << (shit - 1));
  }
  for (int i = 0; i <= n; ++i) {
    a[i].x = fin.nextInt();
  }
  for (int i = 0; i <= m; ++i) {
    b[i].x = fin.nextInt();
  }
  ffft(a, shit, 1);
  ffft(b, shit, 1);
  for (int i = 0; i < maxnn; ++i) {
    a[i] *= b[i];
  }
  rev(a, shit);
  ffft(a, shit, -1);
  rev(a, shit);
  fout.cur = 0;
  for (int i = 0; i <= (n + m); ++i) {
    fout.printInt(a[i].x / maxnn + 0.5);
    fout.printChar(' ');
  }
  fout.printChar('\n');
  fout.close();
  return 0;
}

钦此。