使用 Dask 和 PyTorch 进行大规模图像分析
作者:Nicholas Sofroniew, Genevieve Buckley
概述
本文探讨了如何使用 Dask Array 并行应用预训练的 PyTorch 模型。
我们展示了一个简单示例,将预训练的 UNet 应用于图像堆栈,为每个像素生成特征。
一个实际示例
让我们从一个示例开始,将预训练的 UNet 应用于光片显微镜数据堆栈。
在此示例中,我们
- 将图像数据从 Zarr 加载到多分块的 Dask array 中
- 加载一个对图像进行特征提取的预训练 PyTorch 模型
- 构建一个将模型应用于每个分块的函数
- 使用
dask.array.map_blocks
函数将该函数应用于 Dask array。 - 将结果存储回 Zarr 格式
步骤 1. 加载图像数据
首先,我们将图像数据加载到 Dask array 中。
我们此处使用的示例数据集是斑马鱼胚胎尾部区域的点阵光片显微镜数据。它在 这篇 Science 论文 中有描述(见图 4),并经 Srigokul Upadhyayula 许可提供。
Liu et al. 2018 年,“Observing the cell in its native state: Imaging subcellular dynamics in multicellular organisms” Science, Vol. 360, Issue 6386, eaaq1392 DOI: 10.1126/science.aaq1392 (链接)
这是我们在上一篇 关于 Dask 和 ITK 的博文 中分析过的相同数据。尽管我们现在使用的是新的库并执行不同的分析,但您应该注意到与该工作流程的相似之处。
cd '/Users/nicholassofroniew/Github/image-demos/data/LLSM'
# Load our data
import dask.array as da
imgs = da.from_zarr("AOLLSM_m4_560nm.zarr")
imgs
dask.array<from-zarr, shape=(20, 199, 768, 1024), dtype=float32, chunksize=(1, 1, 768, 1024)>
步骤 2. 加载预训练的 PyTorch 模型
接下来,我们加载预训练的 UNet 模型。
这个 UNet 模型接受一个 2D 图像作为输入,并返回一个 2D x 16 的 array,其中每个像素现在都关联一个长度为 16 的特征向量。
我们感谢 Mars Huang 在一个生物图像语料库上训练了这个特定的 UNet 模型,以产生生物学相关的特征向量,这是他关于 交互式生物图像分割 工作的一部分。这些特征随后可用于更下游的图像处理任务,例如图像分割。
# Load our pretrained UNet¶
import torch
from segmentify.model import UNet, layers
def load_unet(path):
"""Load a pretrained UNet model."""
# load in saved model
pth = torch.load(path)
model_args = pth['model_args']
model_state = pth['model_state']
model = UNet(**model_args)
model.load_state_dict(model_state)
# remove last layer and activation
model.segment = layers.Identity()
model.activate = layers.Identity()
model.eval()
return model
model = load_unet("HPA_3.pth")
步骤 3. 构建一个将模型应用于每个分块的函数
我们构建一个函数,将预训练的 UNet 模型应用于 Dask array 的每个分块。
由于 Dask array 由可轻松转换为 Torch array 的 Numpy array 构成,因此我们能够在大规模数据上利用机器学习的力量。
# Apply UNet featurization
import numpy as np
def unet_featurize(image, model):
"""Featurize pixels in an image using pretrained UNet model.
"""
import numpy as np
import torch
# Extract the 2D image data from the Dask array
# Original Dask array dimensions were (time, z-slice, y, x)
img = image[0, 0, ...]
# Put the data into a shape PyTorch expects
# Expected dimensions are (Batch x Channel x Width x Height)
img = img[None, None, ...]
# convert image to torch Tensor
img = torch.Tensor(img).float()
# pass image through model
with torch.no_grad():
features = model(img).numpy()
# generate feature vectors (w,h,f)
features = np.transpose(features, (0,2,3,1))[0]
# Add back the leading length-one dimensions
result = features[None, None, ...]
return result
注意:非常细心的读者可能会注意到,提取 2D 图像数据然后将其转换为 PyTorch 期望的形状的步骤似乎是多余的。对于我们的特定示例来说确实是多余的,但这很容易并非总是如此。
更详细地解释一下,UNet 需要 4D 输入,维度为 (Batch x Channel x Width x Height)
。原始 Dask array 的维度为 (time, z-slice, y, x)
。在我们的示例中,碰巧这些维度匹配的方式使得移除然后再添加前导维度变得多余,但根据原始 Dask array 的形状,情况可能并非如此。
步骤 4. 将该函数应用于整个 Dask array
现在我们使用 dask.array.map_blocks
将该函数应用于 Dask array 中的数据。
# Apply UNet featurization
out = da.map_blocks(unet_featurize, imgs, model, dtype=np.float32, chunks=(1, 1, imgs.shape[2], imgs.shape[3], 16), new_axis=-1)
out
dask.array<unet_featurize, shape=(20, 199, 768, 1024, 16), dtype=float32, chunksize=(1, 1, 768, 1024, 16)>
步骤 5. 将结果存储回 Zarr 格式
最后,我们将 UNet 模型特征提取的结果存储为 zarr array。
# Trigger computation and store
out.to_zarr("AOLLSM_featurized.zarr", overwrite=True)
现在我们已经保存了输出,这些特征可用于更下游的图像处理任务,例如图像分割。
总结
在此,我们展示了如何将预训练的 PyTorch 模型应用于图像数据的 Dask array。
由于我们的 Dask array 分块是 Numpy array,它们可以轻松转换为 Torch array。通过这种方式,我们能够在大规模数据上利用机器学习的力量。
这个工作流程与我们使用 我们的示例 使用 ITK 的 dask.array.map_blocks
函数进行图像反卷积非常相似。这表明您可以轻松调整同类型的工作流程,使用 Dask 实现多种不同类型的分析。
博客评论由 Disqus 提供支持