#!/usr/bin/env python3
# -*- coding: utf-8 -*
"""
频域图像增强实验代码
基于 OpenCV 和 NumPy 实现

功能：
1. 傅里叶变换与频谱显示
2. 理想低通/高通滤波器
3. 巴特沃斯低通/高通滤波器
4. 高斯低通/高通滤波器
5. 不同截止频率对比
6. 振铃效应分析

使用方法：
1. 准备测试图片（如：car.jpg）放在同目录
2. 运行：python 频域图像增强实验代码.py
3. 查看输出结果
"""

import cv2
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams

# 设置中文字体
rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
rcParams['axes.unicode_minus'] = False


def fft2_image(image):
    """
    对图像进行二维傅里叶变换，并进行频谱中心化
    """
    # 转换为灰度图
    if len(image.shape) == 3:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    
    # 转换为 float64
    f = np.float64(image)
    
    # 傅里叶变换
    F = np.fft.fft2(f)
    
    # 频谱中心化
    Fshift = np.fft.fftshift(F)
    
    return Fshift


def ifft2_image(Fshift):
    """
    逆傅里叶变换，返回空域图像
    """
    # 逆中心化
    F = np.fft.ifftshift(Fshift)
    
    # 逆傅里叶变换
    f = np.fft.ifft2(F)
    
    # 取实部
    img = np.real(f)
    
    # 归一化到 0-255
    img = np.clip(img, 0, 255)
    img = np.uint8(img)
    
    return img


def get_distance_matrix(M, N):
    """
    获取频率距离矩阵 D(u,v)
    """
    u = np.arange(M)
    v = np.arange(N)
    
    # 中心化
    u = np.where(u > M/2, u - M, u)
    v = np.where(v > N/2, v - N, v)
    
    # 创建网格
    V, U = np.meshgrid(v, u)
    
    # 计算距离
    D = np.sqrt(U**2 + V**2)
    
    return D


def ideal_lowpass(M, N, D0):
    """
    理想低通滤波器
    """
    D = get_distance_matrix(M, N)
    H = np.double(D <= D0)
    return H


def ideal_highpass(M, N, D0):
    """
    理想高通滤波器
    """
    D = get_distance_matrix(M, N)
    H = np.double(D > D0)
    return H


def butterworth_lowpass(M, N, D0, n=2):
    """
    巴特沃斯低通滤波器
    """
    D = get_distance_matrix(M, N)
    H = 1 / (1 + (D / D0)**(2*n))
    return H


def butterworth_highpass(M, N, D0, n=2):
    """
    巴特沃斯高通滤波器
    """
    D = get_distance_matrix(M, N)
    H = 1 / (1 + (D0 / D)**(2*n))
    H[D == 0] = 0  # 避免除以0
    return H


def gaussian_lowpass(M, N, D0):
    """
    高斯低通滤波器
    """
    D = get_distance_matrix(M, N)
    H = np.exp(-(D**2) / (2 * D0**2))
    return H


def gaussian_highpass(M, N, D0):
    """
    高斯高通滤波器
    """
    D = get_distance_matrix(M, N)
    H = 1 - np.exp(-(D**2) / (2 * D0**2))
    return H


def apply_filter(Fshift, H):
    """
    应用滤波器
    """
    G = Fshift * H
    return G


def display_spectrum(Fshift, title="频谱"):
    """
    显示频谱（对数变换增强可视化）
    """
    magnitude = np.abs(Fshift)
    magnitude_log = np.log(1 + magnitude)
    
    # 归一化到 0-255
    magnitude_norm = cv2.normalize(magnitude_log, None, 0, 255, cv2.NORM_MINMAX)
    
    return np.uint8(magnitude_norm)


def task1_fourier_transform(image_path):
    """
    任务1：傅里叶变换与频谱显示
    """
    print("=" * 50)
    print("任务1：傅里叶变换与频谱显示")
    print("=" * 50)
    
    # 读取图像
    img = cv2.imread(image_path)
    if img is None:
        print(f"错误：无法读取图像 {image_path}")
        return
    
    # 转换为灰度图
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    
    # 傅里叶变换
    Fshift = fft2_image(gray)
    
    # 显示频谱
    spectrum = display_spectrum(Fshift)
    
    # 保存结果
    plt.figure(figsize=(12, 5))
    
    plt.subplot(1, 2, 1)
    plt.imshow(gray, cmap='gray')
    plt.title('原图')
    plt.axis('off')
    
    plt.subplot(1, 2, 2)
    plt.imshow(spectrum, cmap='gray')
    plt.title('频谱（对数变换后）')
    plt.axis('off')
    
    plt.tight_layout()
    plt.savefig('task1_fourier_transform.png', dpi=150)
    print("结果已保存：task1_fourier_transform.png")
    plt.close()


def task2_ideal_filter(image_path, D0_values=[30, 60, 120]):
    """
    任务2：理想低通和高通滤波器
    """
    print("=" * 50)
    print("任务2：理想低通和高通滤波器")
    print("=" * 50)
    
    # 读取图像
    img = cv2.imread(image_path)
    if img is None:
        print(f"错误：无法读取图像 {image_path}")
        return
    
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    M, N = gray.shape
    
    # 傅里叶变换
    Fshift = fft2_image(gray)
    
    # 理想低通滤波
    plt.figure(figsize=(15, 10))
    
    for i, D0 in enumerate(D0_values):
        # 构建滤波器
        H = ideal_lowpass(M, N, D0)
        
        # 应用滤波器
        G = apply_filter(Fshift, H)
        
        # 逆变换
        result = ifft2_image(G)
        
        # 显示
        plt.subplot(3, 3, i*3 + 1)
        plt.imshow(H, cmap='gray')
        plt.title(f'理想低通滤波器 D0={D0}')
        plt.axis('off')
        
        plt.subplot(3, 3, i*3 + 2)
        plt.imshow(np.log(1 + np.abs(G)), cmap='gray')
        plt.title(f'滤波后频谱 D0={D0}')
        plt.axis('off')
        
        plt.subplot(3, 3, i*3 + 3)
        plt.imshow(result, cmap='gray')
        plt.title(f'滤波结果 D0={D0}')
        plt.axis('off')
    
    plt.tight_layout()
    plt.savefig('task2_ideal_lowpass.png', dpi=150)
    print("结果已保存：task2_ideal_lowpass.png")
    plt.close()
    
    # 理想高通滤波
    plt.figure(figsize=(15, 10))
    
    for i, D0 in enumerate(D0_values):
        # 构建滤波器
        H = ideal_highpass(M, N, D0)
        
        # 应用滤波器
        G = apply_filter(Fshift, H)
        
        # 逆变换
        result = ifft2_image(G)
        
        # 显示
        plt.subplot(3, 3, i*3 + 1)
        plt.imshow(H, cmap='gray')
        plt.title(f'理想高通滤波器 D0={D0}')
        plt.axis('off')
        
        plt.subplot(3, 3, i*3 + 2)
        plt.imshow(np.log(1 + np.abs(G)), cmap='gray')
        plt.title(f'滤波后频谱 D0={D0}')
        plt.axis('off')
        
        plt.subplot(3, 3, i*3 + 3)
        plt.imshow(result, cmap='gray')
        plt.title(f'滤波结果 D0={D0}')
        plt.axis('off')
    
    plt.tight_layout()
    plt.savefig('task2_ideal_highpass.png', dpi=150)
    print("结果已保存：task2_ideal_highpass.png")
    plt.close()


def task3_butterworth_filter(image_path, D0=60, n_values=[1, 2, 4]):
    """
    任务3：巴特沃斯低通和高通滤波器
    """
    print("=" * 50)
    print("任务3：巴特沃斯低通和高通滤波器")
    print("=" * 50)
    
    # 读取图像
    img = cv2.imread(image_path)
    if img is None:
        print(f"错误：无法读取图像 {image_path}")
        return
    
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    M, N = gray.shape
    
    # 傅里叶变换
    Fshift = fft2_image(gray)
    
    # 巴特沃斯低通滤波
    plt.figure(figsize=(15, 5))
    
    for i, n in enumerate(n_values):
        # 构建滤波器
        H = butterworth_lowpass(M, N, D0, n)
        
        # 应用滤波器
        G = apply_filter(Fshift, H)
        
        # 逆变换
        result = ifft2_image(G)
        
        # 显示
        plt.subplot(2, 3, i + 1)
        plt.imshow(H, cmap='gray')
        plt.title(f'BLPF D0={D0}, n={n}')
        plt.axis('off')
        
        plt.subplot(2, 3, i + 4)
        plt.imshow(result, cmap='gray')
        plt.title(f'结果 n={n}')
        plt.axis('off')
    
    plt.tight_layout()
    plt.savefig('task3_butterworth_lowpass.png', dpi=150)
    print("结果已保存：task3_butterworth_lowpass.png")
    plt.close()
    
    # 巴特沃斯高通滤波
    plt.figure(figsize=(15, 5))
    
    for i, n in enumerate(n_values):
        # 构建滤波器
        H = butterworth_highpass(M, N, D0, n)
        
        # 应用滤波器
        G = apply_filter(Fshift, H)
        
        # 逆变换
        result = ifft2_image(G)
        
        # 显示
        plt.subplot(2, 3, i + 1)
        plt.imshow(H, cmap='gray')
        plt.title(f'BHPF D0={D0}, n={n}')
        plt.axis('off')
        
        plt.subplot(2, 3, i + 4)
        plt.imshow(result, cmap='gray')
        plt.title(f'结果 n={n}')
        plt.axis('off')
    
    plt.tight_layout()
    plt.savefig('task3_butterworth_highpass.png', dpi=150)
    print("结果已保存：task3_butterworth_highpass.png")
    plt.close()


def task4_gaussian_filter(image_path, D0_values=[30, 60, 120]):
