混合知识路由网络(针对大尺度范围的对象检测)
此代码实现参考了jwyang/faster-rcnn.pytorch。
依赖;
PyTorch = 0.3.1(不支持pytorch 0.4或更高版本)
Torchvision >= 0.2.0
cython
pyyaml
easydict
opencv-python
matplotlib
numpy
scipy
tensorboardX
CUDA 8.0
gcc >= 4.9
编译:
sh make.sh
数据准备:创建data文件夹。下载ADE20K 数据集。
还可以使用数据集VG1000,VG3000,或MSCOCO and PACAL VOC graphs。
训练:要使用ResNet101的预训练模型,下载后放到data/pretrained_model/下。
CUDA_VISIBLE_DEVICES=$GPU_ID python trainval_baseline.py \
--dataset vg --bs $BATCH_SIZE --nw $WORKER_NUMBER \
--log_dir $LOG_DIR --save_dir $WHERE_YOU_WANT
训练HKRM:
CUDA_VISIBLE_DEVICES=$GPU_ID python trainval_HKRM.py \
--dataset vg --bs $BATCH_SIZE --nw $WORKER_NUMBER \
--log_dir $LOG_DIR --save_dir $WHERE_YOU_WANT \
--init --net HKRM --attr_size 256 --rela_size 256 --spat_size 256
测试:
python test_net.py --dataset vg --net HKRM \
--load_dir $YOUR_SAVE_DIR \
--checksession $SESSION --checkepoch $EPOCH --checkpoint $CHECKPOINT
本文暂时没有评论,来添加一个吧(●'◡'●)