PyTorch 版 PSENet 的训练与部署方式（PyTorch 训练好的模型如何部署）深度揭秘

随心笔谈2年前发布 编辑
169 0
🌐 经济型:买域名、轻量云服务器、用途:游戏 网站等 《腾讯云》特点:特价机便宜 适合初学者用 点我优惠购买
🚀 拓展型:买域名、轻量云服务器、用途:游戏 网站等 《阿里云》特点:中档服务器便宜 域名备案事多 点我优惠购买
🛡️ 稳定型:买域名、轻量云服务器、用途:游戏 网站等 《西部数码》 特点:比上两家略贵但是稳定性超好事也少 点我优惠购买

import torch
import numpy as np
import argparse
import os
import os.path as osp
import sys
import time
import json
from mmcv import Config
import cv2
from torchvision import transforms
from dataset import build_data_loader
from models import build_model
from models.utils import fuse_module
from utils import ResultFormat, AverageMeter
def prepare_image(image, target_size, device="cuda:0", normalize=True):
    """Preprocess a BGR image into a batched tensor for PSENet inference.

    Steps: BGR -> RGB, resize to ``target_size``, ``ToTensor`` (HxWxC uint8
    in [0, 255] -> CxHxW float in [0, 1]) and, by default, ImageNet mean/std
    normalization, which is what PSENet's test dataset applies before the
    forward pass. Without it the network receives inputs on a different scale
    than it was trained on.

    :param image: original image as returned by ``cv2.imread`` (BGR, HxWx3).
    :param target_size: output size as ``(width, height)`` -- the order
        ``cv2.resize`` expects.
    :param device: device the returned tensor is moved to.
    :param normalize: apply ImageNet mean/std normalization; pass False to
        reproduce the previous un-normalized input.
    :return: float tensor of shape ``(1, 3, height, width)`` on ``device``.
    """
    img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, target_size)
    tensor = transforms.ToTensor()(img)
    if normalize:
        tensor = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225])(tensor)
    # add the batch dimension: (C, H, W) -> (1, C, H, W)
    tensor = tensor.unsqueeze(0)
    return tensor.to(torch.device(device))
def report_speed(outputs, speed_meters):
    """Feed one forward pass's stage timings into meters and print averages.

    Every entry of ``outputs`` whose key contains ``'time'`` is treated as a
    stage duration. It is pushed into the meter of the same name, and that
    meter's running average is printed. The durations are summed into
    ``speed_meters['total_time']``, from whose average FPS is printed.

    :param outputs: dict returned by the model (timing entries plus results).
    :param speed_meters: mapping name -> object exposing ``update(value)`` and
        ``avg``; must hold every timing key of ``outputs`` plus ``'total_time'``.
    """
    stage_times = [(name, outputs[name]) for name in outputs if 'time' in name]
    for name, seconds in stage_times:
        meter = speed_meters[name]
        meter.update(seconds)
        print('%s: %.4f' % (name, meter.avg))

    total_meter = speed_meters['total_time']
    total_meter.update(sum(seconds for _, seconds in stage_times))
    print('FPS: %.1f' % (1.0 / total_meter.avg))
def load_model(cfg, checkpoint="psenet_r50_ic15_1024_finetune/checkpoint_580ep.pth.tar"):
    """Build the model from ``cfg``, load trained weights and fuse conv+bn.

    :param cfg: mmcv ``Config``; ``cfg.model`` is passed to ``build_model``.
    :param checkpoint: path to the checkpoint file. Its ``'state_dict'`` entry
        is loaded into the model. Pass None to skip loading weights.
    :return: the model moved to CUDA, in eval mode, with conv/bn fused.
    :raises FileNotFoundError: if ``checkpoint`` is given but is not a file.
    """
    model = build_model(cfg.model)
    model = model.cuda()
    model.eval()

    if checkpoint is not None:
        if not os.path.isfile(checkpoint):
            # The original used a bare `raise` here, which (with no active
            # exception) raises an unrelated RuntimeError.
            raise FileNotFoundError("No checkpoint found at '{}'".format(checkpoint))
        print("Loading model and optimizer from checkpoint '{}'".format(checkpoint))
        sys.stdout.flush()
        # Load onto CPU first; load_state_dict copies into the CUDA parameters.
        state = torch.load(checkpoint, map_location="cpu")
        # Checkpoints saved through (Distributed)DataParallel prefix keys with
        # 'module.'. Strip it only when present, so plain checkpoints also load.
        prefix = "module."
        weights = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state['state_dict'].items()
        }
        model.load_state_dict(weights)

    # fuse conv and bn
    model = fuse_module(model)
    return model
if __name__ == '__main__':
    src_dir = "testimg/"
    save_dir = "test_save/"
    # Network input size. cv2.resize takes (width, height); PSENet's
    # img_metas take (height, width) -- keep the two orders straight.
    input_w, input_h = 1376, 1024
    os.makedirs(save_dir, exist_ok=True)

    cfg = Config.fromfile("PSENet/config/psenet/psenet_r50_ic15_1024_finetune.py")
    for d in [cfg, cfg.data.test]:
        d.update(dict(
            report_speed=False
        ))
    if cfg.report_speed:
        speed_meters = dict(
            backbone_time=AverageMeter(500),
            neck_time=AverageMeter(500),
            det_head_time=AverageMeter(500),
            det_pse_time=AverageMeter(500),
            rec_time=AverageMeter(500),
            total_time=AverageMeter(500)
        )

    model = load_model(cfg)
    model.eval()

    count = 0
    for img_name in os.listdir(src_dir):
        img_path = osp.join(src_dir, img_name)
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread returns None for unreadable / non-image files
            print("skip unreadable file:", img_path)
            continue

        tensor = prepare_image(img, target_size=(input_w, input_h))
        img_metas = dict(
            # both sizes are (height, width)
            org_img_size=torch.tensor([[img.shape[0], img.shape[1]]]),
            img_size=torch.tensor([[input_h, input_w]]),
        )
        data = dict(imgs=tensor, img_metas=img_metas, cfg=cfg)

        with torch.no_grad():
            outputs = model(**data)
        if cfg.report_speed:
            report_speed(outputs, speed_meters)

        for bbox in outputs['bboxes']:
            # Each bbox is assumed to be a flat [x0, y0, x1, y1, ...] vertex
            # list (as the original [0],[1],[4],[5] indexing implies). Drawing
            # the full polygon is correct for rotated rects and polygons alike.
            pts = np.asarray(bbox, dtype=np.int32).reshape(-1, 1, 2)
            cv2.polylines(img, [pts], isClosed=True, color=(0, 0, 255), thickness=3)

        count += 1
        cv2.imwrite(osp.join(save_dir, img_name), img)
        print("img test:", count)

© 版权声明

相关文章