使用 PyTorch 进行批量预测
目录
实时笔记本
您可以在实时会话中运行此笔记本 ,或在 Github 上查看。
使用 PyTorch 进行批量预测¶
[ ]:
%matplotlib inline
本示例遵循 Torch 的迁移学习教程。我们将
在一个特定任务(蚂蚁 vs. 蜜蜂)上微调一个预训练的卷积神经网络。
使用 Dask 集群进行该模型的批量预测。
注意: Binder 环境默认不安转本例所需的依赖项。您需要在单元格中执行
!conda install torchvision pytorch-cpu
来安装必要的包。
主要关注点是使用 Dask 集群进行批量预测。
下载数据¶
PyTorch 文档托管了一小部分数据。我们将它下载并本地解压。
[ ]:
import urllib.request
import zipfile
[ ]:
filename, _ = urllib.request.urlretrieve("https://download.pytorch.org/tutorial/hymenoptera_data.zip", "data.zip")
zipfile.ZipFile(filename).extractall()
目录结构看起来像
hymenoptera_data/
train/
ants/
0013035.jpg
...
1030023514_aad5c608f9.jpg
bees/
1092977343_cb42b38d62.jpg
...
2486729079_62df0920be.jpg
train/
ants/
0013025.jpg
...
1030023514_aad5c606d9.jpg
bees/
1092977343_cb42b38e62.jpg
...
2486729079_62df0921be.jpg
按照教程,我们将微调模型。
[ ]:
import torchvision
from tutorial_helper import (imshow, train_model, visualize_model,
dataloaders, class_names, finetune_model)
微调模型¶
我们的基础模型是 resnet18。它预测 1,000 个类别,而我们的只预测 2 个(蚂蚁或蜜蜂)。为了让这个模型在 examples.dask.org 上快速训练,我们只使用几个 epoch。
[ ]:
import dask
[ ]:
%%time
model = finetune_model()
对一些随机图像进行检查,结果看起来还不错
[ ]:
visualize_model(model)
使用 Dask 进行批量预测¶
现在进入主要主题:使用预训练模型在 Dask 集群上进行批量预测。主要有两个复杂之处,都与最小化数据移动量有关
在 workers 上加载数据。 我们将使用
dask.delayed
在 workers 上加载数据,而不是在客户端加载后再发送给 workers。PyTorch 神经网络很大。 我们不希望它们出现在 Dask 任务图中,而且只希望它们移动一次。
[ ]:
from distributed import Client
client = Client(n_workers=2, threads_per_worker=2)
client
在 worker 上加载数据¶
首先,我们将定义一些辅助函数来加载数据并为神经网络进行预处理。我们在此使用 dask.delayed
,以便执行是惰性的,并在集群上发生。有关使用 dask.delayed
的更多信息,请参阅delayed 示例。
[ ]:
import glob
import toolz
import dask
import dask.array as da
import torch
from torchvision import transforms
from PIL import Image
@dask.delayed
def load(path, fs=__builtins__):
with fs.open(path, 'rb') as f:
img = Image.open(f).convert("RGB")
return img
@dask.delayed
def transform(img):
trn = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
return trn(img)
[ ]:
objs = [load(x) for x in glob.glob("hymenoptera_data/val/*/*.jpg")]
要从云存储(例如 Amazon S3)加载数据,您可以使用
import s3fs
fs = s3fs.S3FileSystem(...)
objs = [load(x, fs=fs) for x in fs.glob(...)]
PyTorch 模型期望特定形状的张量,所以我们来转换它们。
[ ]:
tensors = [transform(x) for x in objs]
模型期望批量输入,所以我们来堆叠几个。
[ ]:
batches = [dask.delayed(torch.stack)(batch)
for batch in toolz.partition_all(10, tensors)]
batches[:5]
最后,我们将编写一个小的 predict
辅助函数来预测输出类别(0 或 1)。
[ ]:
@dask.delayed
def predict(batch, model):
with torch.no_grad():
out = model(batch)
_, predicted = torch.max(out, 1)
predicted = predicted.numpy()
return predicted
移动模型¶
PyTorch 神经网络很大,因此我们不希望它们在任务图中重复多次(每个批次一次)。
[ ]:
import pickle
dask.utils.format_bytes(len(pickle.dumps(model)))
相反,我们也将模型本身包装在 dask.delayed
中。这意味着模型在 Dask 图中只出现一次。
此外,由于我们上面进行了微调(如果在 GPU 上可用则运行),我们应该将模型移回 CPU。
[ ]:
dmodel = dask.delayed(model.cpu()) # ensuring model is on the CPU
现在我们将使用(延迟的)predict
方法来获取我们的预测结果。
[ ]:
predictions = [predict(batch, dmodel) for batch in batches]
dask.visualize(predictions[:2])
可视化效果有点杂乱,但大型的 PyTorch 模型是那个方框,它是两个 predict
任务的祖先。
现在,我们可以进行计算,使用 Dask 集群来完成所有工作。由于我们使用的数据集很小,直接使用 dask.compute
将结果带回本地客户端是安全的。对于更大的数据集,您会希望写入磁盘或云存储,或在集群上继续处理预测结果。
[ ]:
predictions = dask.compute(*predictions)
predictions
总结¶
本示例展示了如何使用 PyTorch 和 Dask 对一组图像进行批量预测。我们注意在集群上远程加载数据,并只对大型神经网络进行一次序列化。