下面的plugin中configurePlugin函数仅仅是简单地确认了下输入和输出以及类型。
void MyCustomPluginDynamic::configurePlugin( const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs, const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) { // Validate input arguments assert(nbOutputs == 1); assert(nbInputs == 2); assert(mType == inputs[0].desc.type); } clone这玩意儿干嘛的,顾名思义,就是克隆嘛,将这个plugin对象克隆一份给TensorRT的builder、network或者engine。这个成员函数会调用上述说到的第二个构造函数:
MyCustomPlugin(float in_channel, const std::vector<float>& weight, const std::vector<float>& bias);将要克隆的plugin的权重和参数传递给这个构造函数。
IPluginV2DynamicExt* MyCustomPlugin::clone() const { // auto plugin = new MyCustomPlugin{_in_channel, _weight, _bias}; plugin->setPluginNamespace(mPluginNamespace); return plugin; }clone成员函数主要用于传递不变的权重和参数,将plugin复制n多份,从而可以被不同engine或者builder或者network使用。
getSerializationSize返回序列化时需要写多少字节到buffer中。
size_t MyCustomPlugin::getSerializationSize() const { return (serialized_size(_in_channel) + serialized_size(_weight) + serialized_size(_bias) ); } supportsFormatCombinationTensorRT调用此方法以判断pos索引的输入/输出是否支持inOut[pos].format和inOut[pos].type指定的格式/数据类型。
如果插件支持inOut[pos]处的格式/数据类型,则返回true。 如果是否支持取决于其他的输入/输出格式/数据类型,则插件可以使其结果取决于inOut[0..pos-1]中的格式/数据类型,该格式/数据类型将设置为插件支持的值。 这个函数不需要检查inOut[pos + 1..nbInputs + nbOutputs-1],pos的决定必须仅基于inOut[0..pos]。
bool MyCustomPlugin::supportsFormatCombination( int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) { // 假设有一个输入一个输出 assert(0 <= pos && pos < 2); const auto *in = inOut; const auto *out = inOut + nbInputs; switch (pos) { case 0: return in[0].type == DataType::kFLOAT && in[0].format == nvinfer1::TensorFormat::kLINEAR; case 1: return out[0].type == in[0].type && out[0].format == nvinfer1::TensorFormat::kLINEAR; } } serialize把需要用的数据按照顺序序列化到buffer里头。
void MyCustomPlugin::serialize(void *buffer) const { serialize_value(&buffer, _in_channel); serialize_value(&buffer, _weight); serialize_value(&buffer, _bias); } attachToContext如果这个op使用到了一些其他东西,例如cublas handle,可以直接借助TensorRT内部提供的cublas handle:
void MyCustomPlugin::attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) { mCublas = cublasContext; } MyCustomPluginCreator 插件工厂类总览:
class MyCustomPluginCreator : public BaseCreator { public: MyCustomPluginCreator(); ~MyCustomPluginCreator() override = default; const char* getPluginName() const override; // 不介绍 const char* getPluginVersion() const override; // 不介绍 const PluginFieldCollection* getFieldNames() override; // 不介绍 IPluginV2DynamicExt* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) override; IPluginV2DynamicExt* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override; private: static PluginFieldCollection mFC; static std::vector<PluginField> mPluginAttributes; std::string mNamespace; }; 构造函数创建一个空的mPluginAttributes初始化mFC。
MyCustomPluginCreator::MyCustomPluginCreator() { mPluginAttributes.emplace_back(PluginField("in_channel", nullptr, PluginFieldType::kFLOAT32, 1)); mPluginAttributes.emplace_back(PluginField("weight", nullptr, PluginFieldType::kFLOAT32, 1)); mPluginAttributes.emplace_back(PluginField("bias", nullptr, PluginFieldType::kFLOAT32, 1)); mFC.nbFields = mPluginAttributes.size(); mFC.fields = mPluginAttributes.data(); } createPlugin这个成员函数作用是通过PluginFieldCollection去创建plugin,将op需要的权重和参数一个一个取出来,然后调用上文提到的第一个构造函数:
MyCustomPlugin(int in_channel, nvinfer1::Weights const& weight, nvinfer1::Weights const& bias);去创建plugin。
MyCustomPlugin示例:
IPluginV2DynamicExt* MyCustomPlugin::createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) { int in_channel; std::vector<float> weight; std::vector<float> bias; const PluginField* fields = fc->fields; for (int i = 0; i < fc->nbFields; ++i) { const char* attrName = fields[i].name; if (!strcmp(attrName, "in_channel")) { ASSERT(fields[i].type == PluginFieldType::kINT32); in_channel= *(static_cast<const int32_t*>(fields[i].data)); } else if (!strcmp(attrName, "weight")) { ASSERT(fields[i].type == PluginFieldType::kFLOAT32); int size = fields[i].length; h_weight.reserve(size); const auto* w = static_cast<const float*>(fields[i].data); for (int j = 0; j < size; j++) { h_weight.push_back(*w); w++; } } else if (!strcmp(attrName, "bias")) { ASSERT(fields[i].type == PluginFieldType::kFLOAT32); int size = fields[i].length; h_bias.reserve(size); const auto* w = static_cast<const float*>(fields[i].data); for (int j = 0; j < size; j++) { h_bias.push_back(*w); w++; } } } Weights weightWeights{DataType::kFLOAT, weight.data(), (int64_t) weight.size()}; Weights biasWeights{DataType::kFLOAT, bias.data(), (int64_t)_bias.size()}; MyCustomPlugin* obj = new MyCustomPlugin(in_channel, weightWeights, biasWeights); obj->setPluginNamespace(mNamespace.c_str()); return obj; } deserializePlugin这个函数会被onnx-tensorrt的一个叫做TRT_PluginV2的转换op调用,这个op会读取onnx模型的data数据将其反序列化到network中。
一些官方插件的注意事项使用官方插件会遇到些小问题。
topk问题官方的topk插件最多支持k<=3840。否则会报:
[TensorRT] ERROR: Parameter check failed at: ../builder/Layers.cpp::TopKLayer::3137, condition: k > 0 && k <= MAX_TOPK_K
相关问题:https://github.com/tensorflow/tensorflow/issues/31671
batchednms问题