33 lines
		
	
	
		
			1.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			33 lines
		
	
	
		
			1.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import torch
 | |
| import torch.nn as nn
 | |
| import torch.nn.functional as F
 | |
| import numpy as np
 | |
| 
 | |
| 
 | |
def gaussian_kernel(kernel_size, sigma):
    """Build a normalized, square 2-D Gaussian kernel.

    The Gaussian is centred at ``(kernel_size - 1) / 2`` along both axes, so
    for an even ``kernel_size`` the peak falls between the four middle cells.

    Args:
        kernel_size: Side length of the square kernel.
        sigma: Standard deviation of the Gaussian (the same on both axes).

    Returns:
        A float NumPy array of shape ``(kernel_size, kernel_size)`` whose
        entries sum to 1.
    """
    # Float row/column index grids, exactly what np.fromfunction would pass.
    rows, cols = np.indices((kernel_size, kernel_size), dtype=float)
    center = (kernel_size - 1) / 2
    squared_distance = (rows - center) ** 2 + (cols - center) ** 2

    # The analytic prefactor is kept for fidelity with the continuous density;
    # it cancels out in the normalization below.
    prefactor = 1 / (2 * np.pi * sigma ** 2)
    weights = prefactor * np.exp(-squared_distance / (2 * sigma ** 2))
    return weights / np.sum(weights)
 | |
| 
 | |
| 
 | |
class GaussianBlur(nn.Module):
    """Blur every channel of the input with the same 2-D Gaussian kernel.

    The blur is a depthwise convolution. ``F.conv2d`` is called with
    ``groups=channels``, so each channel is filtered independently by its
    own copy of the kernel.

    Args:
        channels: Number of input channels the module will be applied to.
        kernel_size: Side length of the square kernel.
        sigma: Standard deviation of the Gaussian, passed to ``gaussian_kernel``.
    """

    def __init__(self, channels, kernel_size, sigma):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.sigma = sigma
        # Zero padding of kernel_size // 2 keeps the output spatial size equal
        # to the input size only for odd kernel_size; an even kernel_size
        # yields one extra row and column.
        self.padding = kernel_size // 2

        kernel = torch.tensor(gaussian_kernel(kernel_size, sigma), dtype=torch.float32)
        # Depthwise conv2d weights have shape (channels, 1, kH, kW). Use
        # repeat() (a real copy) rather than expand(): an expanded view aliases
        # a single storage across all channels, so in-place writes into the
        # buffer, such as the copy_ performed by load_state_dict, would raise.
        kernel = kernel.view(1, 1, kernel_size, kernel_size).repeat(channels, 1, 1, 1)
        # Register once, with its final shape, so the buffer is saved in
        # state_dict and follows the module across .to()/.cuda() calls.
        self.register_buffer('kernel', kernel)

    def forward(self, x):
        """Return ``x`` blurred channel-wise.

        Args:
            x: Input tensor whose channel dimension equals ``channels``.
                Presumably ``(N, C, H, W)``, the layout ``F.conv2d`` expects;
                confirm against callers.

        Returns:
            The blurred tensor. With ``kernel_size // 2`` zero padding on each
            side, its spatial size matches the input for odd ``kernel_size``.
        """
        # .to(x) matches the kernel to x's dtype and device (e.g. fp16 input).
        return F.conv2d(x, self.kernel.to(x), padding=self.padding, groups=self.channels)
 | |
| 
 | |
| 
 | |
# Module-level blur for 4-channel input: 7x7 Gaussian kernel with sigma 0.8.
gaussian_filter_2d = GaussianBlur(channels=4, kernel_size=7, sigma=0.8)
 |