list

查看可用模型

import torch
 
torch.hub.list("ultralytics/yolov5")
# 返回: ['custom', 'yolov5l', 'yolov5l6', 'yolov5m', 'yolov5m6', 'yolov5n', 'yolov5n6', 'yolov5s', 'yolov5s6', 'yolov5x', 'yolov5x6']

help

查看模型帮助

import torch
 
torch.hub.help("pytorch/vision", "resnet18")

load

加载模型

import torch
 
# 加载
# github 仓库和模型
model = torch.hub.load("ultralytics/yolov5", "yolov5s")
 
# 本地模型
model = torch.hub.load(
                       "ultralytics/yolov5",
                       "custom",
                       path="path/best.pt"
                       )
 
# 本地版本和模型
model = torch.hub.load(
                       "path/yolov5",
                       "custom",
                       path="path/best.pt",
                       source="local"
                       )
 
 

参数

  • device: GPU
  • _verbose: 静默加载

设置

# 设置
model.conf = 0.25   # 置信度
model.classes = [0] # 分类

使用

res = model("file.jpg")
 
res.print()
res.save()
res.show()  # 预览图片
res.names: 所有 classess
res.xyxy[0]
res.files: 文件名 # 列表形式
 
# pandas 输出
res.pandas().xyxy[0]
res.pandas().xyxy[0].sort_values("xmin") # 排序从左到右
# json 输出
res.pandas().xyxy[0].to_json(orient="records")

pandas 包含:

  • names: 所有 classess#字典
  • files: 文件名#列表
  • xyxy
  • xyxyn: 包含归一化
  • xywhn: 归一化
  • xywhn
  • n
  • t
  • s

加载到设备

model.cpu()  # CPU
model.cuda()  # GPU
model.to(device)  # i.e. device=torch.device(0)

download_url_to_file

import torch
 
# 下载文件到本地
torch.hub.download_url_to_file('http://url/img.jpg', '/save_path/file_name')

参考