深度学习异常检测Anomalib算法训练+推理+转化+onnx

张开发
2026/6/30 15:36:16 15 分钟阅读
深度学习异常检测Anomalib算法训练+推理+转化+onnx
目录一、环境安装配置python 版本 必须大于3.10二、数据集三、训练源码 GitHub - open-edge-platform/anomalib: An anomaly detection library comprising state-of-the-art algorithms and features such as experiment management, hyper-parameter optimization, and edge inference. · GitHubtrain.pydetect.py 推理代码模型转化 export.py官方的转化代码一、环境安装配置python 版本 必须大于3.10cuda和cudnn各个版本的Pytorch下载网页版onnxncnnpt模型转化工具_cuda国内镜像下载网站-CSDN博客你需要的所有东西都可以下载如果有链接打不开 就直接去网上搜官网 从官网进入下载标准环境cuda 11.8 cudnn 对应11.8的即可torch 2.4.0cu118torchaudio 2.4.0cu118torchmetrics 1.8.0torchvision 0.19.0cu118下面是我的环境torch 2.1.0torchaudio 2.1.0torchmetrics 1.9.0torchvision 0.16.0因为我的是11.7的cuda 不想换了 所以就用了 但是注意 我的环境版本只能训练Patchcore 模型 其他的模型都无法训练 版本不够所以如果你们想训练其他的Anomalib模型需要标准的高版本环境二、数据集数据集中创建这俩文件夹 放入你的正样本负样本是用来验证的也需要放。两边的数据集需要最少10张图片。图片命名不能用中文三、训练源码 GitHub - open-edge-platform/anomalib: An anomaly detection library comprising state-of-the-art algorithms and features such as experiment management, hyper-parameter optimization, and edge inference. · GitHub你可以在连接中下载 也可以直接pip 安装库pip install anomalib -i https://pypi.tuna.tsinghua.edu.cn/simple在你的工程中创建三个py文件 train.py detect.py export.pytrain.pyimport multiprocessing from anomalib.data import Folder from anomalib.models import Patchcore from anomalib.engine import Engine def main(): datamodule Folder( nameflat_enameled_wire, #训练模型后存放的位置 rootrC:\Users\Administrator\Desktop\tscc\200t, #数据集 normal_dirgood, #数据集的good abnormal_dirbad, #数据集中的bad # # val_split_modenone, # 无验证集 # test_split_modenone, # 无自动测试集 normal_split_ratio1.0, # 全部正常图用于训练 train_batch_size1, num_workers0, # Windows 必须 0 ) datamodule.setup() model Patchcore( backbonewide_resnet50_2, #预训练模型 pre_trainedTrue, # 必须开否则没法训 开启就是下载预训练模型 并使用 coreset_sampling_ratio0.05, num_neighbors5 ) engine Engine( devices1, max_epochs1, ) engine.train(datamoduledatamodule, modelmodel) if __name__ __main__: multiprocessing.freeze_support() main()预训练模型如果网络不好是无法下载的 可以直接用下方链接下载 预训练模型下载好后将其放到 这个路径中 如果没有就创建 C:\Users\Administrator\.cache\torch\hub\checkpointshttps://download.pytorch.org/models/wide_resnet50_2-95faca4d.pthdetect.py 推理代码import multiprocessing from anomalib.data import Folder # Import the model and engine from anomalib.models import Patchcore from anomalib.engine import Engine def main(): # Create the datamodule datamodule Folder( nameflat_enameled_wire_infer4, # rootrH:\anomalib-main\ttt\train\11, rootrC:\Users\Administrator\Desktop\tscc\test22, normal_dirgood, abnormal_dirbad, # taskclassification, ) # Setup the datamodule datamodule.setup() model Patchcore(pre_trainedFalse) # engine Engine(taskclassification) engine Engine() engine.predict( datamoduledatamodule, modelmodel, ckpt_pathrH:\anomalib-main\results\Patchcore\flat_enameled_wire\v7\weights\lightning\model.ckpt, ) if __name__ __main__: multiprocessing.freeze_support() # Optional, if your script might be frozen into an executable main()模型转化 export.py我用的是torch 转化的 如果是高版本的可以直接用官方的import os import torch os.environ[HF_HUB_OFFLINE] 1 os.environ[TRANSFORMERS_OFFLINE] 1 os.environ[KMP_DUPLICATE_LIB_OK] TRUE CKPT_PATH rH:\anomalib-main\results\Patchcore\flat_enameled_wire\v5\weights\lightning\model.ckpt INPUT_SIZE (3, 224, 224) ONNX_SAVE_PATH rH:\anomalib-main\results\Patchcore\flat_enameled_wire\v5\weights\lightning\model_final.onnx from anomalib.models import Patchcore # 加载模型 model Patchcore(backbonewide_resnet50_2, pre_trainedFalse) checkpoint torch.load(CKPT_PATH, map_locationcpu) model.load_state_dict(checkpoint[state_dict]) model.eval() dummy_input torch.randn(1, *INPUT_SIZE) torch.onnx.export( model, dummy_input, ONNX_SAVE_PATH, opset_version14, do_constant_foldingTrue, input_names[input], output_names[output], # dynamo 参数彻底删掉 ) print(f导出成功{ONNX_SAVE_PATH})官方的转化代码from anomalib.models import Patchcore from anomalib.engine import Engine model Patchcore() engine Engine(taskclassification) onnx_model engine.export( modelmodel, export_typeonnx, export_rootNone, input_size[244, 244], transformNone, compression_typeNone, datamoduleNone, metricNone, ov_argsNone, ckpt_pathE:\\proj\\anomalib\\myProj\\model.ckpt, # 存放model.ckpt ) print(onnx_model)

更多文章