Py学习  »  Python

mobilenetv2 tflite不需要python3的输出大小

Romzie • 4 年前 • 386 次点击  

我的mobilenetV2固态硬盘模型遇到问题。 我用详细的步骤转换了它 here tflite_convert 对于相关步骤。

这很好,我可以执行一个推断,但输出大小不是我所期望的。

interpreter.get_output_details()

告诉我我要取回10个检测箱:

[{'shape': array([ 1, 10,  4], dtype=int32), 'index': 252, 'name': 'TFLite_Detection_PostProcess', 'quantization': (0.0, 0), 'dtype': <class 'numpy.float32'>}, {'shape': array([ 1, 10], dtype=int32), 'index': 253, 'name': 'TFLite_Detection_PostProcess:1', 'quantization': (0.0, 0), 'dtype': <class 'numpy.float32'>}, {'shape': array([ 1, 10], dtype=int32), 'index': 254, 'name': 'TFLite_Detection_PostProcess:2', 'quantization': (0.0, 0), 'dtype': <class 'numpy.float32'>}, {'shape': array([1], dtype=int32), 'index': 255, 'name': 'TFLite_Detection_PostProcess:3', 'quantization': (0.0, 0), 'dtype': <class 'numpy.float32'>}]

到目前为止还不错,但是 pipeline.config

post_processing {
    batch_non_max_suppression {
        score_threshold: 9.99999993922529e-09
        iou_threshold: 0.6000000238418579
        max_detections_per_class: 100                                                            
        max_total_detections: 100
    }
    score_converter: SIGMOID
}

所以我希望检测的输出数是100,因为在经典的tensorflow中运行相同的模型会得到100个框。

有办法改变输出张量的大小吗?转换或运行时?

我在下面添加经典tensorflow中的张量输出细节:

[<tf.Tensor 'prefix/detection_boxes:0' shape=<unknown> dtype=float32>, <tf.Tensor 'prefix/detection_scores:0' shape=<unknown> dtype=float32>, <tf.Tensor 'prefix/detection_classes:0' shape=<unknown> dtype=float32>, <tf.Tensor 'prefix/num_detections:0' shape=<unknown> dtype=float32>]

任何关于这件事的线索都将非常感激。

如果已经有人问过类似的问题,请原谅,但我显然没有找到。谢谢。

Python社区是高质量的Python/Django开发社区
本文地址:http://www.python88.com/topic/53059
 
386 次点击  
文章 [ 1 ]  |  最新文章 4 年前
Romzie
Reply   •   1 楼
Romzie    5 年前

在重读 export_tflite_ssd_graph.py 脚本中,似乎有一个选项来设置保持的最大检测次数。

把这个设为100解决了我的问题。我感觉不好。

python3 object_detection/export_tflite_ssd_graph.py \                                            
    --pipeline_config_path=$model_dir/pipeline.config \                                          
    --trained_checkpoint_prefix=$model_dir/model.ckpt \                                          
    --output_directory=$output_dir \                                                             
    --add_post_processing_op=true

python3 object_detection/export_tflite_ssd_graph.py \                                            
    --pipeline_config_path=$model_dir/pipeline.config \                                          
    --trained_checkpoint_prefix=$model_dir/model.ckpt \                                          
    --output_directory=$output_dir \                                                             
    --add_post_processing_op=true \                                                              
    --max_detections=100