|
| 1 | +{ |
| 2 | +"cells": [ |
| 3 | + { |
| 4 | +"cell_type":"markdown", |
| 5 | +"metadata": {}, |
| 6 | +"source": [ |
| 7 | +"# 选择参数并将 domain1 的源图像迁移到至少 3 个不同的域。" |
| 8 | + ] |
| 9 | + }, |
| 10 | + { |
| 11 | +"cell_type":"markdown", |
| 12 | +"metadata": {}, |
| 13 | +"source": [ |
| 14 | +"* $\\alpha$ : 二维矩阵掩码控制着在幅度频谱图中有多少尺度的低频信息会被交换,在中心比例为 $\\alpha$ 处为1,其他地方为0\n", |
| 15 | +"* $\\lambda$:该参数控制着调整分布信息的内插比例" |
| 16 | + ] |
| 17 | + }, |
| 18 | + { |
| 19 | +"cell_type":"code", |
| 20 | +"execution_count":68, |
| 21 | +"metadata": {}, |
| 22 | +"outputs": [], |
| 23 | +"source": [ |
| 24 | +"import os\n", |
| 25 | +"from sklearn.manifold import TSNE\n", |
| 26 | +"import cv2\n", |
| 27 | +"import numpy as np\n", |
| 28 | +"import matplotlib.pyplot as plt\n", |
| 29 | +"import os\n", |
| 30 | +"import random\n", |
| 31 | +"\n", |
| 32 | +"alpha =0.2\n", |
| 33 | +"lam = 0.5\n", |
| 34 | +"target_domain = r'Domain5'\n", |
| 35 | +"source_domain_folder = r'G:\\FedICRA\\data\\FAZ\\Domain1\\train\\imgs'" |
| 36 | + ] |
| 37 | + }, |
| 38 | + { |
| 39 | +"cell_type":"code", |
| 40 | +"execution_count":69, |
| 41 | +"metadata": {}, |
| 42 | +"outputs": [], |
| 43 | +"source": [ |
| 44 | +"\n", |
| 45 | +"def domain_generalization(source_domain_folder: str,target_domain: str, alpha:float, lam:float):\n", |
| 46 | +" '''\n", |
| 47 | +" 领域泛化的包装函数,用于实现对某个源图像域的领域泛化。\n", |
| 48 | +"\n", |
| 49 | +" 参数:\n", |
| 50 | +" source_domain_folder: str\n", |
| 51 | +" 源域图像的文件夹路径\n", |
| 52 | +" target_domain: str\n", |
| 53 | +" 表示目标域的字符串,'Domain2','Domain3','Domain4', 'Domain5'\n", |
| 54 | +" 返回值:无返回值\n", |
| 55 | +" '''\n", |
| 56 | +" # 生成新的文件夹\n", |
| 57 | +" generated_image_folder = create_directory(target_domain)\n", |
| 58 | +" source_domain_paths = get_source_domain_paths(source_domain_folder)\n", |
| 59 | +" # 打印所有图像文件的绝对路径\n", |
| 60 | +" for image_path in source_domain_paths:\n", |
| 61 | +" spect_amp_source, spect_pha_source = get_source_domain_image_spec(image_path)\n", |
| 62 | +" spect_amp_target = get_random_target_domain_image_amp(target_domain)\n", |
| 63 | +" spect_amp_generated = get_spect_amp_generated(alpha, lam, spect_amp_source,spect_amp_target)\n", |
| 64 | +"\n", |
| 65 | +" complex_spectrum = spect_amp_generated * np.exp(1j * spect_pha_source)\n", |
| 66 | +" # 逆频谱中心化\n", |
| 67 | +" reconstructed_image = np.fft.ifftshift(complex_spectrum)\n", |
| 68 | +"\n", |
| 69 | +" # 应用逆傅里叶变换\n", |
| 70 | +" reconstructed_image = np.fft.ifft2(reconstructed_image)\n", |
| 71 | +" # 生成图像\n", |
| 72 | +" reconstructed_image = np.abs(reconstructed_image).astype(np.uint8)\n", |
| 73 | +" # save\n", |
| 74 | +" parts = image_path.split('\\\\') # 使用 '\\\\' 进行路径分割,得到每个部分\n", |
| 75 | +"\n", |
| 76 | +" img_name = parts[-1]\n", |
| 77 | +" generated_image_path = os.path.join(generated_image_folder,img_name) # 假设保存在Domain文件夹下\\\n", |
| 78 | +" cv2.imwrite(generated_image_path,reconstructed_image)\n", |
| 79 | +"\n", |
| 80 | +"domain_generalization(source_domain_folder=source_domain_folder,target_domain=target_domain,alpha=alpha,lam=lam)" |
| 81 | + ] |
| 82 | + }, |
| 83 | + { |
| 84 | +"cell_type":"code", |
| 85 | +"execution_count":70, |
| 86 | +"metadata": {}, |
| 87 | +"outputs": [], |
| 88 | +"source": [ |
| 89 | +"## 得到所有源域内的图像路径,在该函数中是domain1\n", |
| 90 | +"def get_source_domain_paths(folder):\n", |
| 91 | +" source_domain_paths = []\n", |
| 92 | +" for root, dirs, files in os.walk(folder):\n", |
| 93 | +" for file in files:\n", |
| 94 | +" file_path = os.path.join(root, file)\n", |
| 95 | +" # 判断文件是否为图像文件(这里以常见的图片格式为例,可以根据实际情况扩展)\n", |
| 96 | +" if file_path.lower().endswith(('.png')):\n", |
| 97 | +" source_domain_paths.append(file_path)\n", |
| 98 | +" return source_domain_paths\n", |
| 99 | +"\n", |
| 100 | +"source_domain_paths = get_source_domain_paths(folder)" |
| 101 | + ] |
| 102 | + }, |
| 103 | + { |
| 104 | +"cell_type":"code", |
| 105 | +"execution_count":71, |
| 106 | +"metadata": {}, |
| 107 | +"outputs": [], |
| 108 | +"source": [ |
| 109 | +"# 从目标域中随机获得一个频谱图,在该例子中是domain2,project要求的是3个域\n", |
| 110 | +"def get_random_target_domain_image_amp(target_domain):\n", |
| 111 | +" target_domain_paths = []\n", |
| 112 | +" target_domain = r'Domain2'\n", |
| 113 | +" current_directory = os.getcwd() # 获取当前工作目录\n", |
| 114 | +"\n", |
| 115 | +" save_directory = os.path.join(current_directory, target_domain) # 假设保存在Domain文件夹\n", |
| 116 | +" for root, dirs, files in os.walk(save_directory):\n", |
| 117 | +" for file in files:\n", |
| 118 | +" file_path = os.path.join(root, file)\n", |
| 119 | +" # 判断文件是否为图像文件\n", |
| 120 | +" if 'amp' in file_path.lower():\n", |
| 121 | +" target_domain_paths.append(file_path)\n", |
| 122 | +"\n", |
| 123 | +"\n", |
| 124 | +" random_target_domain_path = random.choice(target_domain_paths)\n", |
| 125 | +" img = cv2.imread(random_target_domain_path)\n", |
| 126 | +" img_resize = cv2.resize(img, (200,200))\n", |
| 127 | +" gray_image = cv2.cvtColor(img_resize, cv2.COLOR_BGR2GRAY) # 灰度调整\n", |
| 128 | +" spect = np.fft.fft2(gray_image)\n", |
| 129 | +" spect = np.fft.fftshift(spect) # 频谱中心化\n", |
| 130 | +" spect_amp_target = np.abs(spect)\n", |
| 131 | +"\n", |
| 132 | +" return spect_amp_target\n", |
| 133 | +"\n" |
| 134 | + ] |
| 135 | + }, |
| 136 | + { |
| 137 | +"cell_type":"code", |
| 138 | +"execution_count":72, |
| 139 | +"metadata": {}, |
| 140 | +"outputs": [], |
| 141 | +"source": [ |
| 142 | +"def get_source_domain_image_spec(image_path):\n", |
| 143 | +" img = cv2.imread(image_path)\n", |
| 144 | +" img_resize = cv2.resize(img, (200,200))\n", |
| 145 | +" gray_image = cv2.cvtColor(img_resize, cv2.COLOR_BGR2GRAY) # 灰度调整\n", |
| 146 | +" spect = np.fft.fft2(gray_image)\n", |
| 147 | +" spect = np.fft.fftshift(spect) # 频谱中心化\n", |
| 148 | +" spect_amp_source = np.abs(spect)\n", |
| 149 | +" spect_pha_source = np.angle(spect)\n", |
| 150 | +"\n", |
| 151 | +" return spect_amp_source,spect_pha_source\n", |
| 152 | +"\n", |
| 153 | +"# spect_amp_source,spect_pha_source = get_source_domain_image_spec(r'G:\\FedICRA\\data\\FAZ\\Domain1\\train\\imgs\\002_D_10.png')" |
| 154 | + ] |
| 155 | + }, |
| 156 | + { |
| 157 | +"cell_type":"code", |
| 158 | +"execution_count":73, |
| 159 | +"metadata": {}, |
| 160 | +"outputs": [ |
| 161 | + { |
| 162 | +"name":"stdout", |
| 163 | +"output_type":"stream", |
| 164 | +"text": [ |
| 165 | +"[[1. 1. 1. ... 1. 1. 1.]\n", |
| 166 | +" [1. 1. 1. ... 1. 1. 1.]\n", |
| 167 | +" [1. 1. 1. ... 1. 1. 1.]\n", |
| 168 | +" ...\n", |
| 169 | +" [1. 1. 1. ... 1. 1. 1.]\n", |
| 170 | +" [1. 1. 1. ... 1. 1. 1.]\n", |
| 171 | +" [1. 1. 1. ... 1. 1. 1.]]\n" |
| 172 | + ] |
| 173 | + } |
| 174 | + ], |
| 175 | +"source": [ |
| 176 | +"\n", |
| 177 | +"# 生成掩码矩阵M,参数由shape和alpha进行调整,一个控制矩阵的大小,一个控制中心掩码区(全部为1)的比例\n", |
| 178 | +"def generate_binary_mask(shape, alpha):\n", |
| 179 | +" alpha /= 2\n", |
| 180 | +" h, w = shape\n", |
| 181 | +" center_h, center_w = h//2, w//2\n", |
| 182 | +" mask = np.zeros((h,w))\n", |
| 183 | +"\n", |
| 184 | +" alpha_h = int(alpha * h)\n", |
| 185 | +" alpha_w = int(alpha * w)\n", |
| 186 | +"\n", |
| 187 | +" mask[center_h - alpha_h:center_h+alpha_h, center_w - alpha_w:center_w + alpha_w] = 1\n", |
| 188 | +"\n", |
| 189 | +" return mask\n", |
| 190 | +"\n", |
| 191 | +"mask = generate_binary_mask((200,200), 1)\n", |
| 192 | +"print(mask)" |
| 193 | + ] |
| 194 | + }, |
| 195 | + { |
| 196 | +"cell_type":"code", |
| 197 | +"execution_count":74, |
| 198 | +"metadata": {}, |
| 199 | +"outputs": [], |
| 200 | +"source": [ |
| 201 | +"# 由文章公式给出计算生成频谱图的代码\n", |
| 202 | +"def get_spect_amp_generated(alpha: float, lam: float, spect_amp_source, spect_amp_target):\n", |
| 203 | +" mask = generate_binary_mask(spect_amp_source.shape,alpha)\n", |
| 204 | +"\n", |
| 205 | +" spect_amp_generated = mask * ((1-lam)*spect_amp_source+lam*spect_amp_target) + (1-mask) * spect_amp_source\n", |
| 206 | +" return spect_amp_generated\n", |
| 207 | +"\n", |
| 208 | +"\n" |
| 209 | + ] |
| 210 | + }, |
| 211 | + { |
| 212 | +"cell_type":"code", |
| 213 | +"execution_count":75, |
| 214 | +"metadata": {}, |
| 215 | +"outputs": [], |
| 216 | +"source": [ |
| 217 | +"def create_directory(target_domain):\n", |
| 218 | +" current_directory = os.getcwd() # 获取当前工作目录\n", |
| 219 | +"\n", |
| 220 | +" save_directory = os.path.join(current_directory, f'Domain1_to_{target_domain}') # 假设保存在Domain文件夹下\n", |
| 221 | +" if not os.path.exists(save_directory):\n", |
| 222 | +" os.makedirs(save_directory)\n", |
| 223 | +" return save_directory\n" |
| 224 | + ] |
| 225 | + }, |
| 226 | + { |
| 227 | +"cell_type":"markdown", |
| 228 | +"metadata": {}, |
| 229 | +"source": [] |
| 230 | + } |
| 231 | + ], |
| 232 | +"metadata": { |
| 233 | +"kernelspec": { |
| 234 | +"display_name":"pilot", |
| 235 | +"language":"python", |
| 236 | +"name":"python3" |
| 237 | + }, |
| 238 | +"language_info": { |
| 239 | +"codemirror_mode": { |
| 240 | +"name":"ipython", |
| 241 | +"version":3 |
| 242 | + }, |
| 243 | +"file_extension":".py", |
| 244 | +"mimetype":"text/x-python", |
| 245 | +"name":"python", |
| 246 | +"nbconvert_exporter":"python", |
| 247 | +"pygments_lexer":"ipython3", |
| 248 | +"version":"3.10.13" |
| 249 | + } |
| 250 | + }, |
| 251 | +"nbformat":4, |
| 252 | +"nbformat_minor":2 |
| 253 | +} |