35#ifndef VIGRA_RF_ALGORITHM_HXX
36#define VIGRA_RF_ALGORITHM_HXX
59 template<
class OrigMultiArray,
62 void choose(OrigMultiArray
const & in,
71 for(Iter iter = b; iter != e; ++iter, ++ii)
101 template<
class Feature_t,
class Response_t>
103 Response_t
const & response)
117class VariableSelectionResult
122 VariableSelectionResult()
126 typedef std::vector<int> FeatureList_t;
127 typedef std::vector<double> ErrorList_t;
128 typedef FeatureList_t::iterator Pivot_t;
// init() overload that additionally receives a feature ranking as the
// iterator range [b, e). The parameter lines declaring b and e are missing
// from this listing.
154 template<
class FeatureT,
157 class ErrorRateCallBack>
158 bool init(FeatureT
const & all_features,
159 ResponseT
const & response,
162 ErrorRateCallBack errorcallback)
// Run the three-argument init() first and keep its result in ret_.
164 bool ret_ = init(all_features, response, errorcallback);
// The ranking must contain exactly one entry per element of `selected`.
167 vigra_precondition(std::distance(b, e) ==
static_cast<std::ptrdiff_t
>(
selected.size()),
168 "Number of features in ranking != number of features matrix");
// init() overload taking a ranking range but, judging from the visible
// parameters, no caller-supplied error callback.
// NOTE(review): `ecallback` is declared on a line not shown here -- confirm
// that it is a default-constructed callback.
173 template<
class FeatureT,
176 bool init(FeatureT
const & all_features,
177 ResponseT
const & response,
// Forward to the overload that takes features, response, range and callback.
182 return init(all_features, response, b, e, ecallback);
// Convenience overload: initialise from features and response only, using a
// default-constructed RFErrorCallback as the error-rate callback.
186 template<
class FeatureT,
188 bool init(FeatureT
const & all_features,
189 ResponseT
const & response)
// Delegate to the overload that takes an explicit callback.
191 return init(all_features, response, RFErrorCallback());
// Core init(): sets up `selected`, `errors` and `no_features` from the full
// feature matrix and the response. Several source lines are missing from this
// listing, so the notes below describe the visible statements only.
203 template<
class FeatureT,
205 class ErrorRateCallBack>
206 bool init(FeatureT
const & all_features,
207 ResponseT
const & response,
208 ErrorRateCallBack errorcallback)
// One slot per feature column (shape(1)), zero-filled; the body of the loop
// that assigns the slots is not visible here.
216 selected.resize(all_features.shape(1), 0);
217 for(
unsigned int ii = 0; ii <
selected.size(); ++ii)
// One error slot per feature column, pre-filled with -1. The last slot gets
// the error reported by the callback for the complete feature matrix.
219 errors.resize(all_features.shape(1), -1);
220 errors.back() = errorcallback(all_features, response);
// res_map maps each distinct value found in column 0 of `response` to an
// integer (`counter`, maintained on lines not shown); cts holds one
// occurrence count per such integer.
224 std::map<typename ResponseT::value_type, int> res_map;
225 std::vector<int> cts;
227 for(
int ii = 0; ii < response.shape(0); ++ii)
// First occurrence of this response value: register it in res_map.
229 if(res_map.find(response(ii, 0)) == res_map.end())
231 res_map[response(ii, 0)] = counter;
// Count this row for its response value.
235 cts[res_map[response(ii,0)]] +=1;
// no_features is derived from the largest count in cts divided by the number
// of response rows (part of this expression is missing from the listing).
237 no_features = double(*(std::max_element(cts.begin(),
239 / double(response.shape(0));
// forward_selection(): greedy selection that repeatedly moves the candidate
// feature with the lowest callback error into the position marked by
// result.pivot. The line carrying the function name and first parameters, and
// several body lines, are missing from this listing.
294template<
class FeatureT,
class ResponseT,
class ErrorRateCallBack>
296 ResponseT
const & response,
298 ErrorRateCallBack errorcallback)
// Work directly on the members of the caller's result object.
300 VariableSelectionResult::FeatureList_t & selected = result.selected;
301 VariableSelectionResult::ErrorList_t & errors = result.errors;
302 VariableSelectionResult::Pivot_t & pivot = result.pivot;
303 int featureCount = features.shape(1);
// NOTE(review): the braces around this branch are not visible; the
// precondition below presumably runs when init() returns false, i.e. when
// `result` was already used before -- confirm.
305 if(!result.init(features, response, errorcallback))
309 vigra_precondition(
static_cast<int>(selected.size()) == featureCount,
310 "forward_selection(): Number of features in Feature "
311 "matrix and number of features in previously used "
312 "result struct mismatch!");
// Candidates are the entries from pivot to the end of `selected`.
316 int not_selected_size = std::distance(pivot, selected.end());
317 while(not_selected_size > 1)
319 std::vector<double> current_errors;
320 VariableSelectionResult::Pivot_t next = pivot;
// Try every candidate in turn at the pivot position.
321 for(
int ii = 0; ii < not_selected_size; ++ii, ++next)
// Temporarily swap the candidate into the pivot slot ...
323 std::swap(*pivot, *next);
// ... build cur_feats via detail::choose (arguments not visible here) and
// record the callback's error for it ...
325 detail::choose( features,
329 double error = errorcallback(cur_feats, response);
330 current_errors.push_back(error);
// ... then undo the swap so the next candidate sees the original order.
331 std::swap(*pivot, *next);
// Position of the smallest recorded error among the candidates.
333 int pos = std::distance(current_errors.begin(),
334 std::min_element(current_errors.begin(),
335 current_errors.end()));
// Swap the best candidate into the pivot slot and store its error at the
// pivot's index in `errors`.
337 std::advance(next, pos);
338 std::swap(*pivot, *next);
339 errors[std::distance(selected.begin(), pivot)] = current_errors[pos];
// Diagnostic output on std::cerr: all candidate errors, then the choice.
341 std::copy(current_errors.begin(), current_errors.end(), std::ostream_iterator<double>(std::cerr,
", "));
342 std::cerr <<
"Choosing " << *pivot <<
" at error of " << current_errors[pos] << std::endl;
// Recompute the number of remaining candidates (the statement that moves
// pivot is on a line not shown here).
345 not_selected_size = std::distance(pivot, selected.end());
348template<
class FeatureT,
class ResponseT>
350 ResponseT
const & response,
351 VariableSelectionResult & result)
// backward_elimination(): greedy elimination that repeatedly swaps one
// feature from the front range [selected.begin(), pivot) into the pivot slot,
// picking the one whose trial yields the lowest callback error. The line with
// the function name and first parameters, and several body lines, are missing
// from this listing.
396template<
class FeatureT,
class ResponseT,
class ErrorRateCallBack>
398 ResponseT
const & response,
400 ErrorRateCallBack errorcallback)
402 int featureCount = features.shape(1);
// Work directly on the members of the caller's result object.
403 VariableSelectionResult::FeatureList_t & selected = result.selected;
404 VariableSelectionResult::ErrorList_t & errors = result.errors;
405 VariableSelectionResult::Pivot_t & pivot = result.pivot;
// NOTE(review): the braces around this branch are not visible; the
// precondition below presumably runs when init() returns false -- confirm.
408 if(!result.init(features, response, errorcallback))
412 vigra_precondition(
static_cast<int>(selected.size()) == featureCount,
413 "backward_elimination(): Number of features in Feature "
414 "matrix and number of features in previously used "
415 "result struct mismatch!");
// Start with the pivot on the last element of `selected`.
417 pivot = selected.end() - 1;
// Number of entries in front of the pivot.
419 int selected_size = std::distance(selected.begin(), pivot);
420 while(selected_size > 1)
422 VariableSelectionResult::Pivot_t next = selected.begin();
423 std::vector<double> current_errors;
// Try every entry in front of the pivot in turn.
424 for(
int ii = 0; ii < selected_size; ++ii, ++next)
// Temporarily swap the entry with the pivot slot ...
426 std::swap(*pivot, *next);
// ... build cur_feats via detail::choose (arguments not visible here) and
// record the callback's error ...
428 detail::choose( features,
432 double error = errorcallback(cur_feats, response);
433 current_errors.push_back(error);
// ... then restore the original order.
434 std::swap(*pivot, *next);
// Position of the smallest recorded error.
436 int pos = std::distance(current_errors.begin(),
437 std::min_element(current_errors.begin(),
438 current_errors.end()));
// Swap that entry into the pivot slot for good.
439 next = selected.begin();
440 std::advance(next, pos);
441 std::swap(*pivot, *next);
// Store the error one slot before the pivot's index, then recompute the size
// of the front range (the statement that moves pivot is not visible here).
443 errors[std::distance(selected.begin(), pivot)-1] = current_errors[pos];
444 selected_size = std::distance(selected.begin(), pivot);
// Diagnostic output on std::cerr: all trial errors, then the eliminated entry.
446 std::copy(current_errors.begin(), current_errors.end(), std::ostream_iterator<double>(std::cerr,
", "));
447 std::cerr <<
"Eliminating " << *pivot <<
" at error of " << current_errors[pos] << std::endl;
453template<
class FeatureT,
class ResponseT>
455 ResponseT
const & response,
456 VariableSelectionResult & result)
// rank_selection(): walks result.pivot (aliased as `iter`) over `selected`
// and stores, for every position, the error the callback reports for the
// feature subset built by detail::choose. The line with the function name and
// first parameters, and several body lines, are missing from this listing.
//
// Fixes relative to the previous text:
//  * the precondition message named forward_selection(), which made a
//    failure in this function point at the wrong routine;
//  * the diagnostic printed *(iter+1), which dereferences selected.end() on
//    the last loop iteration; the feature handled in this step is *iter.
493template<
class FeatureT,
class ResponseT,
class ErrorRateCallBack>
495 ResponseT
const & response,
497 ErrorRateCallBack errorcallback)
// Work directly on the members of the caller's result object.
499 VariableSelectionResult::FeatureList_t & selected = result.selected;
500 VariableSelectionResult::ErrorList_t & errors = result.errors;
501 VariableSelectionResult::Pivot_t & iter = result.pivot;
502 int featureCount = features.shape(1);
// NOTE(review): the braces around this branch are not visible; the
// precondition below presumably runs when init() returns false -- confirm.
504 if(!result.init(features, response, errorcallback))
508 vigra_precondition(
static_cast<int>(selected.size()) == featureCount,
509 "rank_selection(): Number of features in Feature "
510 "matrix and number of features in previously used "
511 "result struct mismatch!");
// Visit every remaining position of the ranking.
515 for(; iter != selected.end(); ++iter)
// Build cur_feats via detail::choose (arguments not visible here) and store
// the callback's error at the index of the current position.
519 detail::choose( features,
523 double error = errorcallback(cur_feats, response);
524 errors[std::distance(selected.begin(), iter)] = error;
// Diagnostic output on std::cerr: the entries up to and including *iter,
// then the entry handled in this step together with its error.
526 std::copy(selected.begin(), iter+1, std::ostream_iterator<int>(std::cerr,
", "));
527 std::cerr <<
"Choosing " << *iter <<
" at error of " << error << std::endl;
533template<
class FeatureT,
class ResponseT>
535 ResponseT
const & response,
536 VariableSelectionResult & result)
543enum ClusterLeafTypes{c_Leaf = 95, c_Node = 99};
559 ClusterNode(
int nCol,
560 BT::T_Container_type & topology,
561 BT::P_Container_type & split_param)
562 : BT(nCol + 5, 5,topology, split_param)
572 ClusterNode( BT::T_Container_type
const & topology,
573 BT::P_Container_type
const & split_param,
575 :
NodeBase(5 , 5,topology, split_param, n)
581 ClusterNode( BT & node_)
586 BT::parameter_size_ += 0;
592 void set_index(
int in)
// Construct an entry from its four fields: p -> parent, l -> level,
// a -> addr, in -> infm.
618 HC_Entry(
int p,
int l,
int a,
bool in)
619 : parent(p), level(l), addr(a), infm(in)
// Combine two distances into one by taking the smaller of the two.
648 double dist_func(
double a,
double b)
650 return std::min(a, b);
// Stack-driven walk over the cluster nodes, starting at begin_addr. The
// signature line, the call into the functor and the pop of the stack are on
// lines missing from this listing.
656 template<
class Functor>
660 std::vector<int> stack;
661 stack.push_back(begin_addr);
662 while(!stack.empty())
// View the node stored at the address on top of the stack.
664 ClusterNode node(topology_, parameters_, stack.back());
// A node with more than one column has two children; schedule both.
668 if(node.columns_size() != 1)
670 stack.push_back(node.child(0));
671 stack.push_back(node.child(1));
// Queue-driven (first-in, first-out) walk over the cluster nodes starting at
// begin_addr. Each visited node is passed to `tester` together with its
// level, its parent node and the result `tester` returned for the parent.
// The signature line and the local declarations are missing from this listing.
679 template<
class Functor>
683 std::queue<HC_Entry> queue;
// Seed the queue with the start node.
688 queue.push(
HC_Entry(parent,level,begin_addr, infm));
689 while(!queue.empty())
// Unpack the entry at the front of the queue.
691 level = queue.front().level;
692 parent = queue.front().parent;
693 addr = queue.front().addr;
694 infm = queue.front().infm;
695 ClusterNode node(topology_, parameters_, queue.front().addr);
// View of the parent node (the condition guarding this line is not visible).
699 parnt = ClusterNode(topology_, parameters_, parent);
// Let the functor inspect the node; its result is handed on to the children.
702 bool istrue = tester(node, level, parnt, infm);
// A node with more than one column has two children; enqueue both one level
// deeper, with this node as their parent.
703 if(node.columns_size() != 1)
705 queue.push(
HC_Entry(addr, level +1,node.child(0),istrue));
706 queue.push(
HC_Entry(addr, level +1,node.child(1),istrue));
713 void save(std::string file, std::string prefix)
718 Shp(topology_.
size(),1),
722 Shp(parameters_.
size(), 1),
723 parameters_.
data()));
// Builds the cluster tree bottom-up from a distance matrix: one leaf per row
// of `distance`, then repeated merging of the closest pair until a single
// entry remains. Many lines (signature, local declarations, braces, index
// bookkeeping) are missing from this listing; notes cover visible code only.
733 template<
class T,
class C>
// addr holds one (node address in topology_, matrix index) pair per live
// cluster.
737 std::vector<std::pair<int, int> > addr;
// Create one single-column leaf node per row of the distance matrix.
739 for(
int ii = 0; ii < distance.
shape(0); ++ii)
// The leaf is about to be appended, so its address is the current size.
741 addr.push_back(std::make_pair(topology_.size(), ii));
742 ClusterNode leaf(1, topology_, parameters_);
743 leaf.set_index(index);
745 leaf.columns_begin()[0] = ii;
// Merge until only one cluster is left.
748 while(addr.size() != 1)
// Search all pairs (ii < jj) for the smallest entry of `dist`; ii_min/jj_min
// are updated on lines not shown here.
753 double min_dist = dist((addr.begin()+ii_min)->second,
754 (addr.begin()+jj_min)->second);
755 for(
unsigned int ii = 0; ii < addr.size(); ++ii)
757 for(
unsigned int jj = ii+1; jj < addr.size(); ++jj)
759 if( dist((addr.begin()+ii_min)->second,
760 (addr.begin()+jj_min)->second)
761 > dist((addr.begin()+ii)->second,
762 (addr.begin()+jj)->second))
764 min_dist = dist((addr.begin()+ii)->second,
765 (addr.begin()+jj)->second);
// The merged node needs room for the columns of both children.
777 ClusterNode firstChild(topology_,
779 (addr.begin() +ii_min)->first);
780 ClusterNode secondChild(topology_,
782 (addr.begin() +jj_min)->first);
783 col_size = firstChild.columns_size() + secondChild.columns_size();
// The new parent is appended to topology_; remember it as the current root.
785 int cur_addr = topology_.size();
786 begin_addr = cur_addr;
788 ClusterNode parent(col_size,
// Views of the two children are created again at this point.
791 ClusterNode firstChild(topology_,
793 (addr.begin() +ii_min)->first);
794 ClusterNode secondChild(topology_,
796 (addr.begin() +jj_min)->first);
// The parent's first parameter records the distance at which the merge
// happened.
797 parent.parameters_begin()[0] = min_dist;
798 parent.set_index(index);
// The parent's column list is the std::merge of both children's lists.
800 std::merge(firstChild.columns_begin(), firstChild.columns_end(),
801 secondChild.columns_begin(),secondChild.columns_end(),
802 parent.columns_begin());
// If the parent's first column comes from firstChild, firstChild becomes
// child(0): the ii_min entry now refers to the parent, and the jj_min entry
// is removed (its matrix index is kept in to_desc).
806 if(*parent.columns_begin() == *firstChild.columns_begin())
808 parent.child(0) = (addr.begin()+ii_min)->first;
809 parent.child(1) = (addr.begin()+jj_min)->first;
810 (addr.begin()+ii_min)->first = cur_addr;
812 to_desc = (addr.begin()+jj_min)->second;
813 addr.erase(addr.begin()+jj_min);
// Mirrored case (the `else` itself is on a line not shown): secondChild
// becomes child(0), the jj_min entry refers to the parent and the ii_min
// entry is removed.
817 parent.child(1) = (addr.begin()+ii_min)->first;
818 parent.child(0) = (addr.begin()+jj_min)->first;
819 (addr.begin()+jj_min)->first = cur_addr;
821 to_desc = (addr.begin()+ii_min)->second;
822 addr.erase(addr.begin()+ii_min);
// Update the distances of the surviving entry (ii_keep, set on a line not
// shown) to every remaining cluster: combine the removed and the surviving
// entry's distances with dist_func and store the result in both directions.
826 for(
int jj = 0 ; jj < static_cast<int>(addr.size()); ++jj)
830 double bla = dist_func(
831 dist(to_desc, (addr.begin()+jj)->second),
832 dist((addr.begin()+ii_keep)->second,
833 (addr.begin()+jj)->second));
835 dist((addr.begin()+ii_keep)->second,
836 (addr.begin()+jj)->second) = bla;
837 dist((addr.begin()+jj)->second,
838 (addr.begin()+ii_keep)->second) = bla;
859 bool operator()(Node& node)
872template<
class Iter,
class DT>
887 template<
class Feat_T,
class Label_T>
888 PermuteCluster(Iter a,
890 Feat_T
const & feats,
891 Label_T
const & labls,
896 :tmp_mem_(_spl(a, b).size(), feats.shape(1)),
899 feats_(_spl(a,b).size(), feats.shape(1)),
900 labels_(_spl(a,b).size(),1),
906 copy_splice(_spl(a,b),
907 _spl(feats.shape(1)),
910 copy_splice(_spl(a,b),
911 _spl(labls.shape(1)),
// Functor call for one cluster node: over nPerm repetitions, overwrites the
// node's columns in tmp_mem_ with values taken from other rows, counts
// entries in perm_imp, and finally adds a normalised score to node.status().
// Several lines (index computation, prediction, conditions) are missing from
// this listing.
917 bool operator()(Node& node)
// The last column of perm_imp is used as the overall (all-class) counter.
921 int class_count = perm_imp.shape(1) - 1;
923 for(
int kk = 0; kk < nPerm; ++kk)
926 for(
int ii = 0; ii <
rowCount(feats_); ++ii)
// For each column of the node, copy the value from row `index` into row ii.
// NOTE(review): how `index` is chosen is not visible here.
929 for(
int jj = 0; jj < node.columns_size(); ++jj)
// A column number equal to feats_.shape(1) is skipped.
931 if(node.columns_begin()[jj] != feats_.shape(1))
932 tmp_mem_(ii, node.columns_begin()[jj])
933 = tmp_mem_(index, node.columns_begin()[jj]);
937 for(
int ii = 0; ii <
rowCount(tmp_mem_); ++ii)
// Per-label and overall counters; the condition guarding these increments is
// on a line not shown here.
944 ++perm_imp(index,labels_(ii, 0));
946 ++perm_imp(index, class_count);
// Score for this node: overall count averaged over the repetitions, minus the
// overall entry of orig_imp, divided by oob_size; accumulated into the node.
950 double node_status = perm_imp(index, class_count);
951 node_status /= nPerm;
952 node_status -= orig_imp(0, class_count);
954 node_status /= oob_size;
955 node.status() += node_status;
964class GetClusterVariables
976 void save(std::string file, std::string prefix)
// Copy the node's column list into row `index` of `variables`, one column
// number per entry. (`index` is maintained on lines not shown here.)
984 bool operator()(Node& node)
986 for(
int ii = 0; ii < node.columns_size(); ++ii)
987 variables(index, ii) = node.columns_begin()[ii];
// Traversal functor call: clamp the current node's status so that it never
// exceeds its parent's status. The level and flag arguments are unused.
1001 bool operator()(Nde & cur,
int , Nde parent,
bool )
1004 cur.status() = std::min(parent.status(), cur.status());
1031 std::ofstream graphviz;
1036 std::string
const gz)
1037 :features_(features), labels_(labels),
1038 graphviz(gz.c_str(), std::ios::out)
1040 graphviz <<
"digraph G\n{\n node [shape=\"record\"]";
1044 graphviz <<
"\n}\n";
// Traversal functor call: write one graphviz node statement for `cur` and one
// edge statement from `parent` to `cur` into the `graphviz` stream. The level
// and flag arguments are unused.
1049 bool operator()(Nde & cur,
int , Nde parent,
bool )
// Node statement: the label starts with the number of columns ...
1051 graphviz <<
"node" << cur.index() <<
" [style=\"filled\"][label = \" #Feats: "<< cur.columns_size() <<
"\\n";
// ... followed by the node status ...
1052 graphviz <<
" status: " << cur.status() <<
"\\n";
// ... and the list of column numbers.
1053 for(
int kk = 0; kk < cur.columns_size(); ++kk)
1055 graphviz << cur.columns_begin()[kk] <<
" ";
// Close the label; the status is also written as the first component of the
// colour attribute.
1059 graphviz <<
"\"] [color = \"" <<cur.status() <<
" 1.000 1.000\"];\n";
// Edge from the parent node to this node.
1061 graphviz <<
"\"node" << parent.index() <<
"\" -> \"node" << cur.index() <<
"\";\n";
1081 int repetition_count_;
1087 void save(std::string filename, std::string prefix)
1089 std::string prefix1 =
"cluster_importance_" + prefix;
1093 prefix1 =
"vars_" + prefix;
1101 : repetition_count_(rep_cnt), clustering(clst)
1107 template<
class RF,
class PR>
1110 Int32 const class_count = rf.ext_param_.class_count_;
1111 Int32 const column_count = rf.ext_param_.column_count_+1;
1122 clustering.iterate(gcv);
// Per-tree step of the cluster importance computation: collects sample
// indices not used by the tree, counts the correct predictions on them, runs
// PermuteCluster over the clustering and converts the counts into a
// normalised difference. The signature line, local declarations and several
// statements are missing from this listing.
1132 template<
class RF,
class PR,
class SM,
class ST>
1136 Int32 column_count = rf.ext_param_.column_count_ +1;
1137 Int32 class_count = rf.ext_param_.class_count_;
// Constness is cast away to obtain a mutable reference to the features.
// NOTE(review): presumably because the permutation step writes to them --
// confirm.
1141 typename PR::Feature_t & features
1142 =
const_cast<typename PR::Feature_t &
>(pr.features());
1146 ArrayVector<Int32>::iterator
// Case 1: the per-tree sample is more than 10000 rows smaller than the data.
1149 if(rf.ext_param_.actual_msample_ < pr.features().shape(0)- 10000)
// Shuffle all row indices with a std::mt19937 seeded from std::random_device.
1153 for(
int ii = 0; ii < pr.features().shape(0); ++ii)
1154 indices.push_back(ii); ;
1155 std::random_device rd;
1156 std::mt19937 g(rd());
1157 std::shuffle(indices.
begin(), indices.
end(), g);
// Take rows for which sm.is_used() is false, at most 3000 per response value.
1158 for(
int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
1160 if(!sm.is_used()[indices[ii]] && cts[pr.response()(indices[ii], 0)] < 3000)
1162 oob_indices.push_back(indices[ii]);
1163 ++cts[pr.response()(indices[ii], 0)];
// Case 2 (the `else` is on a line not shown): take every row for which
// sm.is_used() is false.
1169 for(
int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
1170 if(!sm.is_used()[ii])
1171 oob_indices.push_back(ii);
// oob_right: one row with a counter per class plus one overall counter.
1181 oob_right(Shp_t(1, class_count + 1));
// For each collected sample whose predicted label equals its response,
// increment the class counter and the overall counter.
1184 for(iter = oob_indices.
begin();
1185 iter != oob_indices.
end();
1189 .predictLabel(
rowVector(features, *iter))
1190 == pr.response()(*iter, 0))
1193 ++oob_right[pr.response()(*iter,0)];
1195 ++oob_right[class_count];
// perm_oob_right: 2*column_count-1 rows with the same column layout.
1200 perm_oob_right (Shp_t(2* column_count-1, class_count + 1));
// Run the PermuteCluster functor over every node of the clustering.
1203 pc(oob_indices.
begin(), oob_indices.
end(),
1210 clustering.iterate(pc);
// Average over the repetitions, subtract the unpermuted counts from every
// row, flip the sign and divide by the number of collected samples.
1212 perm_oob_right /= repetition_count_;
1213 for(
int ii = 0; ii <
rowCount(perm_oob_right); ++ii)
1214 rowVector(perm_oob_right, ii) -= oob_right;
1216 perm_oob_right *= -1;
1217 perm_oob_right /= oob_indices.
size();
1226 template<
class RF,
class PR,
class SM,
class ST>
1234 template<
class RF,
class PR>
1238 clustering.iterate(nrm);
1274template<
class FeatureT,
class ResponseT>
1276 ResponseT
const & response,
1283 if(features.shape(0) > 40000)
1290 RF.
learn(features, response,
1291 create_visitor(missc, progress));
1306 create_visitor(progress, ci));
1319template<
class FeatureT,
class ResponseT>
1321 ResponseT
const & response,
1322 HClustering & linkage)
// get_ranking(): appends to `out` the int values stored in a map keyed by
// double, visiting the keys from largest to smallest. The loop body that
// fills the map from `in` is missing from this listing.
// NOTE(review): std::map keeps one entry per key, so if the map is keyed by
// the values of `in`, equal values would share a single entry -- confirm
// whether ties can occur.
1329template<
class Array1,
class Vector1>
1330void get_ranking(Array1
const & in, Vector1 & out)
1332 std::map<double, int> mymap;
1333 for(
int ii = 0; ii < in.size(); ++ii)
// Reverse iteration over the ordered map: largest key first.
1335 for(std::map<double, int>::reverse_iterator iter = mymap.rbegin(); iter!= mymap.rend(); ++iter)
1337 out.push_back(iter->second);
const_iterator begin() const
Definition array_vector.hxx:223
const_pointer data() const
Definition array_vector.hxx:209
size_type size() const
Definition array_vector.hxx:358
const_iterator end() const
Definition array_vector.hxx:237
Definition array_vector.hxx:514
TinyVector< MultiArrayIndex, N > type
Definition multi_shape.hxx:272
Base class for, and view to, MultiArray.
Definition multi_array.hxx:705
const difference_type & shape() const
Definition multi_array.hxx:1650
Main MultiArray class containing the memory management.
Definition multi_array.hxx:2479
Topology_type column_data() const
Definition rf_nodeproxy.hxx:159
INT & typeID()
Definition rf_nodeproxy.hxx:136
NodeBase()
Definition rf_nodeproxy.hxx:237
Parameter_type parameters_begin() const
Definition rf_nodeproxy.hxx:207
Options object for the random forest.
Definition rf_common.hxx:171
RandomForestOptions & use_stratification(RF_OptionTag in)
specify stratification strategy
Definition rf_common.hxx:374
RandomForestOptions & samples_per_tree(double in)
specify the fraction of the total number of samples used per tree for learning.
Definition rf_common.hxx:411
RandomForestOptions & tree_count(unsigned int in)
Definition rf_common.hxx:500
Random forest version 2 (see also RandomForest for version 3)
Definition random_forest.hxx:148
void learn(MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &response, Visitor_t visitor, Split_t split, Stop_t stop, Random_t const &random)
learn on data with custom config and random number generator
Definition random_forest.hxx:941
UInt32 uniformInt() const
Definition random.hxx:471
Definition matrix.hxx:125
Definition rf_algorithm.hxx:1069
MultiArray< 2, double > cluster_importance_
Definition rf_algorithm.hxx:1077
MultiArray< 2, int > variables
Definition rf_algorithm.hxx:1074
void visit_at_end(RF &rf, PR &)
Definition rf_algorithm.hxx:1235
void visit_after_tree(RF &rf, PR &pr, SM &sm, ST &st, int index)
Definition rf_algorithm.hxx:1227
MultiArray< 2, double > cluster_stdev_
Definition rf_algorithm.hxx:1080
void after_tree_ip_impl(RF &rf, PR &pr, SM &sm, ST &, int index)
Definition rf_algorithm.hxx:1133
void visit_at_beginning(RF const &rf, PR const &)
Definition rf_algorithm.hxx:1108
Definition rf_algorithm.hxx:998
Definition rf_algorithm.hxx:965
MultiArrayView< 2, int > variables
Definition rf_algorithm.hxx:970
Definition rf_algorithm.hxx:640
void iterate(Functor &tester)
Definition rf_algorithm.hxx:657
void cluster(MultiArrayView< 2, T, C > distance)
Definition rf_algorithm.hxx:734
void breadth_first_traversal(Functor &tester)
Definition rf_algorithm.hxx:680
Definition rf_algorithm.hxx:849
NormalizeStatus(double m)
Definition rf_algorithm.hxx:855
Definition rf_algorithm.hxx:874
Definition rf_algorithm.hxx:85
double operator()(Feature_t const &features, Response_t const &response)
Definition rf_algorithm.hxx:102
RFErrorCallback(RandomForestOptions opt=RandomForestOptions())
Definition rf_algorithm.hxx:94
Definition rf_algorithm.hxx:118
double no_features
Definition rf_algorithm.hxx:152
ErrorList_t errors
Definition rf_algorithm.hxx:147
FeatureList_t selected
Definition rf_algorithm.hxx:134
bool init(FeatureT const &all_features, ResponseT const &response, ErrorRateCallBack errorcallback)
Definition rf_algorithm.hxx:206
Definition rf_visitors.hxx:1499
MultiArray< 2, double > distance
Definition rf_visitors.hxx:1527
Definition rf_visitors.hxx:865
double oob_breiman
Definition rf_visitors.hxx:875
Definition rf_visitors.hxx:1463
Definition rf_visitors.hxx:103
Definition rf_algorithm.hxx:52
void backward_elimination(FeatureT const &features, ResponseT const &response, VariableSelectionResult &result, ErrorRateCallBack errorcallback)
Definition rf_algorithm.hxx:397
void rank_selection(FeatureT const &features, ResponseT const &response, VariableSelectionResult &result, ErrorRateCallBack errorcallback)
Definition rf_algorithm.hxx:494
void forward_selection(FeatureT const &features, ResponseT const &response, VariableSelectionResult &result, ErrorRateCallBack errorcallback)
Definition rf_algorithm.hxx:295
void cluster_permutation_importance(FeatureT const &features, ResponseT const &response, HClustering &linkage, MultiArray< 2, double > &distance)
Definition rf_algorithm.hxx:1275
detail::VisitorNode< A > create_visitor(A &a)
Definition rf_visitors.hxx:345
void writeHDF5(...)
Store array data in an HDF5 file.
MultiArrayIndex columnCount(const MultiArrayView< 2, T, C > &x)
Definition matrix.hxx:684
MultiArrayView< 2, T, C > rowVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition matrix.hxx:697
MultiArrayIndex rowCount(const MultiArrayView< 2, T, C > &x)
Definition matrix.hxx:671
MultiArrayView< 2, T, C > columnVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition matrix.hxx:727
RandomNumberGenerator< detail::RandomState< detail::MT19937 > > RandomMT19937
Definition random.hxx:630
detail::SelectIntegerType< 32, detail::SignedIntTypes >::type Int32
32-bit signed int
Definition sized_int.hxx:175
Definition metaprogramming.hxx:123
Definition rf_algorithm.hxx:613