OpenCV的softcascade代码解读(2)

if( !_update || !data ) //准备好训练用的数据,并确定只包含正负样本两类,分配保存弱分类器的存储空间
    {
        clear();
        data = new CvDTreeTrainData( _train_data, _tflag, _responses, _var_idx,  //准备训练数据,但是这里怎么还要用到boost的参数_params呢?
            _sample_idx, _var_type, _missing_mask, _params, true, true );

if( data->get_num_classes() != 2 )
            CV_ERROR( CV_StsNotImplemented,
            "Boosted trees can only be used for 2-class classification." );
        CV_CALL( storage = cvCreateMemStorage() );
        weak = cvCreateSeq( 0, sizeof(CvSeq), sizeof(CvBoostTree*), storage ); //这是CvBoost类中保存弱分类器的向量?
        storage = 0;
    }
    else
    {
        data->set_data( _train_data, _tflag, _responses, _var_idx,
            _sample_idx, _var_type, _missing_mask, _params, true, true, true );
    }

if ( (_params.boost_type == LOGIT) || (_params.boost_type == GENTLE) )
        data->do_responses_copy();

update_weights( 0 ); //将各样本权重平均分配

for( i = 0; i < params.weak_count; i++ ) //训练weak_count个弱分类器
    {
        CvBoostTree* tree = new CvBoostTree;
        if( !tree->train( data, subsample_mask, this ) ) //主要的训练函数,subsample_mask似乎是一个输出参数,查了其初始值是值为0的指针,记录弱分类器正确分类的样本,也许初始值是全0的向量?
               //第三个参数是训练出的弱分类器要连接的‘宿主’分类器
        {
            delete tree;
            break;
        }
        //cvCheckArr( get_weak_response());
        cvSeqPush( weak, &tree );
        update_weights( tree ); //这里是不是根据训练出的弱分类器的分类情况调整各样本的权重?
        trim_weights();
        if( cvCountNonZero(subsample_mask) == 0 )
            break;
    }
 
    if(weak->total > 0)//释放存储空间
    {
        get_active_vars(); // recompute active_vars* maps and condensed_idx's in the splits.
        data->is_classifier = true;
        data->free_train_data();
        ok = true;
    }
    else
        clear();

__END__;

return ok;
}

//CvBoostTree::train()函数定义如下,它用来训练单个弱分类器,它进一步调用了CvDTree::do_train()函数:
CvBoostTree::train( CvDTreeTrainData* _train_data,
                    const CvMat* _subsample_idx, CvBoost* _ensemble )
{
    clear();
    ensemble = _ensemble;
    data = _train_data;
    data->shared = true;
    return do_train( _subsample_idx );
}
//CvDTree::do_train()函数定义如下(在文件tree.cpp中,头文件为ml.hpp):

bool CvDTree::do_train( const CvMat* _subsample_idx )
{
    bool result = false;

CV_FUNCNAME( "CvDTree::do_train" );

__BEGIN__;

root = data->subsample_data( _subsample_idx ); //明显是选择参与训练的样本

CV_CALL( try_split_node(root));

if( root->split )
    {
        CV_Assert( root->left );
        CV_Assert( root->right );

if( data->params.cv_folds > 0 )
            CV_CALL( prune_cv() );

if( !data->shared )
            data->free_train_data();

result = true;
    }

__END__;

return result;
}

内容版权声明:除非注明,否则皆为本站原创文章。

转载注明出处:http://www.heiqu.com/b28fd9f04c615ef47ca5ae9f988b73c3.html