if( !_update || !data ) //准备好训练用的数据,并确定只包含正负样本两类,分配保存弱分类器的存储空间
{
clear();
data = new CvDTreeTrainData( _train_data, _tflag, _responses, _var_idx, //准备训练数据,但是这里怎么还要用到boost的参数_params呢?
_sample_idx, _var_type, _missing_mask, _params, true, true );
if( data->get_num_classes() != 2 )
CV_ERROR( CV_StsNotImplemented,
"Boosted trees can only be used for 2-class classification." );
CV_CALL( storage = cvCreateMemStorage() );
weak = cvCreateSeq( 0, sizeof(CvSeq), sizeof(CvBoostTree*), storage ); //这是CvBoost类中保存弱分类器的向量?
storage = 0;
}
else
{
data->set_data( _train_data, _tflag, _responses, _var_idx,
_sample_idx, _var_type, _missing_mask, _params, true, true, true );
}
if ( (_params.boost_type == LOGIT) || (_params.boost_type == GENTLE) )
data->do_responses_copy();
update_weights( 0 ); //将各样本权重平均分配
for( i = 0; i < params.weak_count; i++ ) //训练weak_count个弱分类器
{
CvBoostTree* tree = new CvBoostTree;
if( !tree->train( data, subsample_mask, this ) ) //主要的训练函数,subsample_mask似乎是一个输出参数,查了其初始值是值为0的指针,记录弱分类器正确分类的样本,也许初始值是全0的向量?
//第三个参数是训练出的弱分类器要连接的‘宿主’分类器
{
delete tree;
break;
}
//cvCheckArr( get_weak_response());
cvSeqPush( weak, &tree );
update_weights( tree ); //这里是不是根据训练出的弱分类器的分类情况调整各样本的权重?
trim_weights();
if( cvCountNonZero(subsample_mask) == 0 )
break;
}
if(weak->total > 0)//释放存储空间
{
get_active_vars(); // recompute active_vars* maps and condensed_idx's in the splits.
data->is_classifier = true;
data->free_train_data();
ok = true;
}
else
clear();
__END__;
return ok;
}
// CvBoostTree::train() is defined as follows. It trains a single weak classifier and in turn calls CvDTree::do_train():
CvBoostTree::train( CvDTreeTrainData* _train_data,
const CvMat* _subsample_idx, CvBoost* _ensemble )
{
clear();
ensemble = _ensemble;
data = _train_data;
data->shared = true;
return do_train( _subsample_idx );
}
// CvDTree::do_train() is defined as follows (in tree.cpp; the header is ml.hpp):
// Builds the tree from the training data already attached to `data`.
//
// Obtains the root node from data->subsample_data(_subsample_idx), attempts to
// split it with try_split_node(), and reports success only if the root ended
// up with a split.
//
// @param _subsample_idx  forwarded unchanged to data->subsample_data().
// @return true if root->split is set after try_split_node(); false if the
//         root was left unsplit.
// NOTE(review): CV_FUNCNAME / __BEGIN__ / CV_CALL / __END__ are macros defined
// elsewhere; presumably a failing CV_CALL skips ahead to __END__, leaving
// `result` false — confirm against the macro definitions.
bool CvDTree::do_train( const CvMat* _subsample_idx )
{
bool result = false;
CV_FUNCNAME( "CvDTree::do_train" );
__BEGIN__;
root = data->subsample_data( _subsample_idx ); // presumably selects the samples that take part in training — confirm in CvDTreeTrainData
CV_CALL( try_split_node(root));
if( root->split )
{
// A root that was split is expected to have both children.
CV_Assert( root->left );
CV_Assert( root->right );
// prune_cv() runs only when params.cv_folds is positive.
if( data->params.cv_folds > 0 )
CV_CALL( prune_cv() );
// Shared training data is left untouched; it is only released here when
// this tree is its sole user.
if( !data->shared )
data->free_train_data();
result = true;
}
__END__;
return result;
}