
rf_algorithm.hxx
/************************************************************************/
/*                                                                      */
/*        Copyright 2008-2009 by Rahul Nair                             */
/*                                                                      */
/*    This file is part of the VIGRA computer vision library.           */
/*    The VIGRA Website is                                              */
/*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
/*    Please direct questions, bug reports, and contributions to        */
/*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
/*        vigra@informatik.uni-hamburg.de                               */
/*                                                                      */
/*    Permission is hereby granted, free of charge, to any person       */
/*    obtaining a copy of this software and associated documentation    */
/*    files (the "Software"), to deal in the Software without           */
/*    restriction, including without limitation the rights to use,      */
/*    copy, modify, merge, publish, distribute, sublicense, and/or      */
/*    sell copies of the Software, and to permit persons to whom the    */
/*    Software is furnished to do so, subject to the following          */
/*    conditions:                                                       */
/*                                                                      */
/*    The above copyright notice and this permission notice shall be    */
/*    included in all copies or substantial portions of the             */
/*    Software.                                                         */
/*                                                                      */
/*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND    */
/*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
/*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
/*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
/*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
/*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
/*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
/*    OTHER DEALINGS IN THE SOFTWARE.                                   */
/*                                                                      */
/************************************************************************/
#ifndef VIGRA_RF_ALGORITHM_HXX
#define VIGRA_RF_ALGORITHM_HXX

#include <vector>
#include <map>
#include <queue>
#include <random>
#include <algorithm>
#include <utility>
#include <fstream>
#include "splices.hxx"
namespace vigra
{

namespace rf
{

/** This namespace contains all algorithms developed for feature
 * selection.
 */
namespace algorithms
{

namespace detail
{
    /** Create a MultiArray containing only the columns supplied between
     *  iterators b and e.
     */
    template<class OrigMultiArray,
             class Iter,
             class DestMultiArray>
    void choose(OrigMultiArray const & in,
                Iter const & b,
                Iter const & e,
                DestMultiArray & out)
    {
        int columnCount = std::distance(b, e);
        int rowCount = in.shape(0);
        out.reshape(typename DestMultiArray::difference_type(rowCount, columnCount));
        int ii = 0;
        for(Iter iter = b; iter != e; ++iter, ++ii)
        {
            columnVector(out, ii) = columnVector(in, *iter);
        }
    }
}


/** Standard random forest error-rate callback functor.
 *
 * Returns the random forest OOB error estimate when invoked.
 */
class RFErrorCallback
{
    RandomForestOptions options;

    public:
    /** Default constructor
     *
     * optionally supply options to the random forest classifier
     * \sa RandomForestOptions
     */
    RFErrorCallback(RandomForestOptions opt = RandomForestOptions())
    : options(opt)
    {}

    /** returns the RF OOB error estimate given features and
     * labels
     */
    template<class Feature_t, class Response_t>
    double operator() (Feature_t const & features,
                       Response_t const & response)
    {
        RandomForest<> rf(options);
        visitors::OOB_Error oob;
        rf.learn(features,
                 response,
                 visitors::create_visitor(oob));
        return oob.oob_breiman;
    }
};
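
/* A minimal usage sketch for RFErrorCallback (assuming `features` and
 * `labels` are MultiArrays shaped as in the examples further below):
 *
 * \code
 * RFErrorCallback err(RandomForestOptions().tree_count(64));
 * double oob_error = err(features, labels); // trains a forest, returns its OOB error
 * \endcode
 */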


/** Structure to hold variable selection results.
 */
class VariableSelectionResult
{
    bool initialized;

    public:
    VariableSelectionResult()
    : initialized(false)
    {}

    typedef std::vector<int>    FeatureList_t;
    typedef std::vector<double> ErrorList_t;
    typedef FeatureList_t::iterator Pivot_t;

    Pivot_t pivot;

    /** List of features (indices into the columns of the feature matrix).
     */
    FeatureList_t selected;

    /** vector of size (number of features)
     *
     * the i-th entry encodes the error rate obtained
     * while using features [0 - i] (including i)
     *
     * if the i-th entry is -1 then no error rate was obtained
     * this may happen if more than one feature is added to the
     * selected list in one step of the algorithm.
     *
     * during initialisation errors.back() is always filled
     */
    ErrorList_t errors;


    /** error rate when no features are used
     */
    double no_features;

    template<class FeatureT,
             class ResponseT,
             class Iter,
             class ErrorRateCallBack>
    bool init(FeatureT const & all_features,
              ResponseT const & response,
              Iter b,
              Iter e,
              ErrorRateCallBack errorcallback)
    {
        bool ret_ = init(all_features, response, errorcallback);
        if(!ret_)
            return false;
        vigra_precondition(std::distance(b, e) == static_cast<std::ptrdiff_t>(selected.size()),
                           "Number of features in ranking != number of features matrix");
        std::copy(b, e, selected.begin());
        return true;
    }

    template<class FeatureT,
             class ResponseT,
             class Iter>
    bool init(FeatureT const & all_features,
              ResponseT const & response,
              Iter b,
              Iter e)
    {
        RFErrorCallback ecallback;
        return init(all_features, response, b, e, ecallback);
    }


    template<class FeatureT,
             class ResponseT>
    bool init(FeatureT const & all_features,
              ResponseT const & response)
    {
        return init(all_features, response, RFErrorCallback());
    }

    /** Initialization routine. Will be called only once in the lifetime
     * of a VariableSelectionResult. Subsequent calls will not reinitialize
     * member variables.
     *
     * This is intentional: it allows variable selection to be continued
     * from the point where an earlier run stopped.
     *
     * returns true if initialization was successful and false if
     * the object was already initialized before.
     */
    template<class FeatureT,
             class ResponseT,
             class ErrorRateCallBack>
    bool init(FeatureT const & all_features,
              ResponseT const & response,
              ErrorRateCallBack errorcallback)
    {
        if(initialized)
        {
            return false;
        }
        initialized = true;
        // calculate error with all features
        selected.resize(all_features.shape(1), 0);
        for(unsigned int ii = 0; ii < selected.size(); ++ii)
            selected[ii] = ii;
        errors.resize(all_features.shape(1), -1);
        errors.back() = errorcallback(all_features, response);

        // calculate error rate if no features are chosen
        // corresponds to max(prior probability) of the classes
        std::map<typename ResponseT::value_type, int> res_map;
        std::vector<int> cts;
        int counter = 0;
        for(int ii = 0; ii < response.shape(0); ++ii)
        {
            if(res_map.find(response(ii, 0)) == res_map.end())
            {
                res_map[response(ii, 0)] = counter;
                ++counter;
                cts.push_back(0);
            }
            cts[res_map[response(ii, 0)]] += 1;
        }
        no_features = double(*(std::max_element(cts.begin(),
                                                cts.end())))
                    / double(response.shape(0));

        /*init not_selected vector;
        not_selected.resize(all_features.shape(1), 0);
        for(int ii = 0; ii < not_selected.size(); ++ii)
        {
            not_selected[ii] = ii;
        }
        initialized = true;
        */
        pivot = selected.begin();
        return true;
    }
};
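
/* Sketch of a typical VariableSelectionResult round trip (hypothetical
 * `features`/`labels` as in the other examples in this file):
 *
 * \code
 * VariableSelectionResult result;
 * forward_selection(features, labels, result);
 * // features are now ranked in result.selected; result.errors[i] is the
 * // error rate of the subset {selected[0], ..., selected[i]} (-1 if unknown)
 * for(unsigned int i = 0; i < result.errors.size(); ++i)
 *     std::cout << result.selected[i] << ": " << result.errors[i] << "\n";
 * \endcode
 */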


/** Perform forward selection
 *
 * \param features IN:     n x p matrix containing n instances with p attributes/features
 *                         used in the variable selection algorithm
 * \param response IN:     n x 1 matrix containing the corresponding response
 * \param result   IN/OUT: VariableSelectionResult struct which will contain the results
 *                         of the algorithm.
 *                         Features between result.selected.begin() and result.pivot will
 *                         be left untouched.
 *                         \sa VariableSelectionResult
 * \param errorcallback
 *                 IN, OPTIONAL:
 *                         Functor that returns the error rate given a set of
 *                         features and labels. Default is the RandomForest OOB Error.
 *
 * Forward selection repeatedly adds the feature that decreases the error rate most.
 *
 * usage:
 * \code
 * MultiArray<2, double> features = createSomeFeatures();
 * MultiArray<2, int>    labels   = createCorrespondingLabels();
 * VariableSelectionResult result;
 * forward_selection(features, labels, result);
 * \endcode
 * To use forward selection but ensure that a specific feature, e.g. feature 5, is always
 * included, one would do the following:
 *
 * \code
 * VariableSelectionResult result;
 * result.init(features, labels);
 * std::swap(result.selected[0], result.selected[5]);
 * result.pivot = result.selected.begin() + 1;
 * forward_selection(features, labels, result);
 * \endcode
 *
 * \sa VariableSelectionResult
 *
 */
template<class FeatureT, class ResponseT, class ErrorRateCallBack>
void forward_selection(FeatureT const & features,
                       ResponseT const & response,
                       VariableSelectionResult & result,
                       ErrorRateCallBack errorcallback)
{
    VariableSelectionResult::FeatureList_t & selected = result.selected;
    VariableSelectionResult::ErrorList_t & errors = result.errors;
    VariableSelectionResult::Pivot_t & pivot = result.pivot;
    int featureCount = features.shape(1);
    // initialize result struct if in use for the first time
    if(!result.init(features, response, errorcallback))
    {
        // result is being reused - just ensure that the number of features
        // is the same.
        vigra_precondition(static_cast<int>(selected.size()) == featureCount,
                           "forward_selection(): Number of features in Feature "
                           "matrix and number of features in previously used "
                           "result struct mismatch!");
    }


    int not_selected_size = std::distance(pivot, selected.end());
    while(not_selected_size > 1)
    {
        std::vector<double> current_errors;
        VariableSelectionResult::Pivot_t next = pivot;
        for(int ii = 0; ii < not_selected_size; ++ii, ++next)
        {
            std::swap(*pivot, *next);
            MultiArray<2, double> cur_feats;
            detail::choose(features,
                           selected.begin(),
                           pivot + 1,
                           cur_feats);
            double error = errorcallback(cur_feats, response);
            current_errors.push_back(error);
            std::swap(*pivot, *next);
        }
        int pos = std::distance(current_errors.begin(),
                                std::min_element(current_errors.begin(),
                                                 current_errors.end()));
        next = pivot;
        std::advance(next, pos);
        std::swap(*pivot, *next);
        errors[std::distance(selected.begin(), pivot)] = current_errors[pos];
#ifdef RN_VERBOSE
        std::copy(current_errors.begin(), current_errors.end(), std::ostream_iterator<double>(std::cerr, ", "));
        std::cerr << "Choosing " << *pivot << " at error of " << current_errors[pos] << std::endl;
#endif
        ++pivot;
        not_selected_size = std::distance(pivot, selected.end());
    }
}
template<class FeatureT, class ResponseT>
void forward_selection(FeatureT const & features,
                       ResponseT const & response,
                       VariableSelectionResult & result)
{
    forward_selection(features, response, result, RFErrorCallback());
}
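
/* The ErrorRateCallBack argument can be any functor with the signature
 * double(features, labels). A minimal hand-rolled sketch (hypothetical,
 * not part of this library) that mirrors RFErrorCallback above but uses
 * a smaller, fixed-size forest:
 *
 * \code
 * struct MyErrorRate
 * {
 *     template<class Feature_t, class Response_t>
 *     double operator()(Feature_t const & features, Response_t const & response)
 *     {
 *         RandomForest<> rf(RandomForestOptions().tree_count(32));
 *         visitors::OOB_Error oob;
 *         rf.learn(features, response, visitors::create_visitor(oob));
 *         return oob.oob_breiman;
 *     }
 * };
 * // forward_selection(features, labels, result, MyErrorRate());
 * \endcode
 */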


/** Perform backward elimination
 *
 * \param features IN:     n x p matrix containing n instances with p attributes/features
 *                         used in the variable selection algorithm
 * \param response IN:     n x 1 matrix containing the corresponding response
 * \param result   IN/OUT: VariableSelectionResult struct which will contain the results
 *                         of the algorithm.
 *                         Features between result.pivot and result.selected.end() will
 *                         be left untouched.
 *                         \sa VariableSelectionResult
 * \param errorcallback
 *                 IN, OPTIONAL:
 *                         Functor that returns the error rate given a set of
 *                         features and labels. Default is the RandomForest OOB Error.
 *
 * Backward elimination repeatedly removes the feature whose removal increases
 * the error rate least.
 *
 * usage:
 * \code
 * MultiArray<2, double> features = createSomeFeatures();
 * MultiArray<2, int>    labels   = createCorrespondingLabels();
 * VariableSelectionResult result;
 * backward_elimination(features, labels, result);
 * \endcode
 * To use backward elimination but ensure that a specific feature, e.g. feature 5, is always
 * excluded, one would do the following:
 *
 * \code
 * VariableSelectionResult result;
 * result.init(features, labels);
 * std::swap(result.selected[result.selected.size()-1], result.selected[5]);
 * result.pivot = result.selected.begin() + result.selected.size() - 1;
 * backward_elimination(features, labels, result);
 * \endcode
 *
 * \sa VariableSelectionResult
 *
 */
template<class FeatureT, class ResponseT, class ErrorRateCallBack>
void backward_elimination(FeatureT const & features,
                          ResponseT const & response,
                          VariableSelectionResult & result,
                          ErrorRateCallBack errorcallback)
{
    int featureCount = features.shape(1);
    VariableSelectionResult::FeatureList_t & selected = result.selected;
    VariableSelectionResult::ErrorList_t & errors = result.errors;
    VariableSelectionResult::Pivot_t & pivot = result.pivot;

    // initialize result struct if in use for the first time
    if(!result.init(features, response, errorcallback))
    {
        // result is being reused - just ensure that the number of features
        // is the same.
        vigra_precondition(static_cast<int>(selected.size()) == featureCount,
                           "backward_elimination(): Number of features in Feature "
                           "matrix and number of features in previously used "
                           "result struct mismatch!");
    }
    pivot = selected.end() - 1;

    int selected_size = std::distance(selected.begin(), pivot);
    while(selected_size > 1)
    {
        VariableSelectionResult::Pivot_t next = selected.begin();
        std::vector<double> current_errors;
        for(int ii = 0; ii < selected_size; ++ii, ++next)
        {
            std::swap(*pivot, *next);
            MultiArray<2, double> cur_feats;
            detail::choose(features,
                           selected.begin(),
                           pivot + 1,
                           cur_feats);
            double error = errorcallback(cur_feats, response);
            current_errors.push_back(error);
            std::swap(*pivot, *next);
        }
        int pos = std::distance(current_errors.begin(),
                                std::min_element(current_errors.begin(),
                                                 current_errors.end()));
        next = selected.begin();
        std::advance(next, pos);
        std::swap(*pivot, *next);
        errors[std::distance(selected.begin(), pivot) - 1] = current_errors[pos];
        selected_size = std::distance(selected.begin(), pivot);
#ifdef RN_VERBOSE
        std::copy(current_errors.begin(), current_errors.end(), std::ostream_iterator<double>(std::cerr, ", "));
        std::cerr << "Eliminating " << *pivot << " at error of " << current_errors[pos] << std::endl;
#endif
        --pivot;
    }
}

template<class FeatureT, class ResponseT>
void backward_elimination(FeatureT const & features,
                          ResponseT const & response,
                          VariableSelectionResult & result)
{
    backward_elimination(features, response, result, RFErrorCallback());
}
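
/* After backward elimination (or forward selection) the best subset size can
 * be read off result.errors. A small sketch, skipping the -1 entries, which
 * mean "no error rate recorded for this size":
 *
 * \code
 * int best = -1;
 * double best_err = 2.0; // error rates lie in [0, 1]
 * for(unsigned int i = 0; i < result.errors.size(); ++i)
 *     if(result.errors[i] >= 0 && result.errors[i] < best_err)
 *     {
 *         best_err = result.errors[i];
 *         best = static_cast<int>(i);
 *     }
 * // features result.selected[0..best] form the chosen subset
 * \endcode
 */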

/** Perform rank selection using a predefined ranking
 *
 * \param features IN:     n x p matrix containing n instances with p attributes/features
 *                         used in the variable selection algorithm
 * \param response IN:     n x 1 matrix containing the corresponding response
 * \param result   IN/OUT: VariableSelectionResult struct which will contain the results
 *                         of the algorithm. The struct should be initialized with the
 *                         predefined ranking.
 *                         \sa VariableSelectionResult
 * \param errorcallback
 *                 IN, OPTIONAL:
 *                         Functor that returns the error rate given a set of
 *                         features and labels. Default is the RandomForest OOB Error.
 *
 * Often a variable importance or score measure is used to create the ordering in which
 * variables have to be selected. This method takes such a ranking and calculates the
 * corresponding error rates.
 *
 * usage:
 * \code
 * MultiArray<2, double> features = createSomeFeatures();
 * MultiArray<2, int>    labels   = createCorrespondingLabels();
 * std::vector<int> ranking = createRanking(features);
 * VariableSelectionResult result;
 * result.init(features, labels, ranking.begin(), ranking.end());
 * rank_selection(features, labels, result);
 * \endcode
 *
 * \sa VariableSelectionResult
 *
 */
template<class FeatureT, class ResponseT, class ErrorRateCallBack>
void rank_selection(FeatureT const & features,
                    ResponseT const & response,
                    VariableSelectionResult & result,
                    ErrorRateCallBack errorcallback)
{
    VariableSelectionResult::FeatureList_t & selected = result.selected;
    VariableSelectionResult::ErrorList_t & errors = result.errors;
    VariableSelectionResult::Pivot_t & iter = result.pivot;
    int featureCount = features.shape(1);
    // initialize result struct if in use for the first time
    if(!result.init(features, response, errorcallback))
    {
        // result is being reused - just ensure that the number of features
        // is the same.
        vigra_precondition(static_cast<int>(selected.size()) == featureCount,
                           "rank_selection(): Number of features in Feature "
                           "matrix and number of features in previously used "
                           "result struct mismatch!");
    }

    for(; iter != selected.end(); ++iter)
    {
        MultiArray<2, double> cur_feats;
        detail::choose(features,
                       selected.begin(),
                       iter + 1,
                       cur_feats);
        double error = errorcallback(cur_feats, response);
        errors[std::distance(selected.begin(), iter)] = error;
#ifdef RN_VERBOSE
        std::copy(selected.begin(), iter + 1, std::ostream_iterator<int>(std::cerr, ", "));
        std::cerr << "Choosing " << *iter << " at error of " << error << std::endl;
#endif
    }
}

template<class FeatureT, class ResponseT>
void rank_selection(FeatureT const & features,
                    ResponseT const & response,
                    VariableSelectionResult & result)
{
    rank_selection(features, response, result, RFErrorCallback());
}


enum ClusterLeafTypes { c_Leaf = 95, c_Node = 99 };

/* View of a node in the hierarchical clustering class.
 * For internal use only.
 * \sa NodeBase
 */
class ClusterNode
: public NodeBase
{
    public:

    typedef NodeBase BT;

    /** constructors **/
    ClusterNode() : NodeBase() {}

    ClusterNode(int nCol,
                BT::T_Container_type & topology,
                BT::P_Container_type & split_param)
    : BT(nCol + 5, 5, topology, split_param)
    {
        status() = 0;
        BT::column_data()[0] = nCol;
        if(nCol == 1)
            BT::typeID() = c_Leaf;
        else
            BT::typeID() = c_Node;
    }

    ClusterNode(BT::T_Container_type const & topology,
                BT::P_Container_type const & split_param,
                int n)
    : NodeBase(5, 5, topology, split_param, n)
    {
        //TODO: is there a more elegant way to do this?
        BT::topology_size_ += BT::column_data()[0];
    }

    ClusterNode(BT & node_)
    : BT(5, 5, node_)
    {
        //TODO: is there a more elegant way to do this?
        BT::topology_size_ += BT::column_data()[0];
        BT::parameter_size_ += 0;
    }

    int index()
    {
        return static_cast<int>(BT::parameters_begin()[1]);
    }
    void set_index(int in)
    {
        BT::parameters_begin()[1] = in;
    }
    double & mean()
    {
        return BT::parameters_begin()[2];
    }
    double & stdev()
    {
        return BT::parameters_begin()[3];
    }
    double & status()
    {
        return BT::parameters_begin()[4];
    }
};

/** Stack entry class for the HClustering class.
 */
struct HC_Entry
{
    int parent;
    int level;
    int addr;
    bool infm;
    HC_Entry(int p, int l, int a, bool in)
    : parent(p), level(l), addr(a), infm(in)
    {}
};


/** Hierarchical clustering class.
 * Performs single linkage clustering.
 * \code
 * HClustering linkage;
 * Matrix<double> distance = get_distance_matrix();
 * linkage.cluster(distance);
 * // Draw clustering tree.
 * Draw<double, int> draw(features, labels, "linkagetree.graph");
 * linkage.breadth_first_traversal(draw);
 * \endcode
 * \sa ClusterImportanceVisitor
 *
 * Once the clustering has taken place, information queries can be made
 * using the breadth_first_traversal() and iterate() methods.
 *
 */
class HClustering
{
public:
    typedef MultiArrayShape<2>::type Shp;
    ArrayVector<int> topology_;
    ArrayVector<double> parameters_;
    int begin_addr;

    // Calculates the distance between two merged clusters
    // (single linkage: the minimum of the two distances).
    double dist_func(double a, double b)
    {
        return std::min(a, b);
    }

    /** Visit each node with a functor,
     * in creation order (should be depth first)
     */
    template<class Functor>
    void iterate(Functor & tester)
    {
        std::vector<int> stack;
        stack.push_back(begin_addr);
        while(!stack.empty())
        {
            ClusterNode node(topology_, parameters_, stack.back());
            stack.pop_back();
            if(!tester(node))
            {
                if(node.columns_size() != 1)
                {
                    stack.push_back(node.child(0));
                    stack.push_back(node.child(1));
                }
            }
        }
    }
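
    /* Any functor whose operator() takes a node and returns a bool can be
     * used with iterate(); returning true stops the descent below that node.
     * A minimal hypothetical sketch that counts leaves:
     *
     * \code
     * struct CountLeaves
     * {
     *     int count;
     *     CountLeaves() : count(0) {}
     *     template<class Node>
     *     bool operator()(Node & node)
     *     {
     *         if(node.columns_size() == 1)
     *             ++count;
     *         return false; // keep traversing
     *     }
     * };
     * \endcode
     */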

    /** Perform breadth first traversal of the hierarchical cluster tree.
     */
    template<class Functor>
    void breadth_first_traversal(Functor & tester)
    {
        std::queue<HC_Entry> queue;
        int level = 0;
        int parent = -1;
        int addr = -1;
        bool infm = false;
        queue.push(HC_Entry(parent, level, begin_addr, infm));
        while(!queue.empty())
        {
            level = queue.front().level;
            parent = queue.front().parent;
            addr = queue.front().addr;
            infm = queue.front().infm;
            ClusterNode node(topology_, parameters_, queue.front().addr);
            ClusterNode parnt;
            if(parent != -1)
            {
                parnt = ClusterNode(topology_, parameters_, parent);
            }
            queue.pop();
            bool istrue = tester(node, level, parnt, infm);
            if(node.columns_size() != 1)
            {
                queue.push(HC_Entry(addr, level + 1, node.child(0), istrue));
                queue.push(HC_Entry(addr, level + 1, node.child(1), istrue));
            }
        }
    }
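
    /* Functors for breadth_first_traversal() take four arguments: the current
     * node, its level, the parent node (check parent.hasData_ first - the
     * root has none), and the bool the functor returned for the parent.
     * CorrectStatus and Draw below follow this signature; a minimal
     * hypothetical sketch:
     *
     * \code
     * struct PrintLevels
     * {
     *     template<class Node>
     *     bool operator()(Node & cur, int level, Node parent, bool infm)
     *     {
     *         std::cerr << "node " << cur.index() << " at level " << level << "\n";
     *         return true; // handed down to the children as `infm`
     *     }
     * };
     * \endcode
     */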

    /** Save to HDF5 - defunct - has to be updated to the new HDF5 interface.
     */
#ifdef HasHDF5
    void save(std::string file, std::string prefix)
    {
        vigra::writeHDF5(file.c_str(), (prefix + "topology").c_str(),
                         MultiArrayView<2, int>(
                             Shp(topology_.size(), 1),
                             topology_.data()));
        vigra::writeHDF5(file.c_str(), (prefix + "parameters").c_str(),
                         MultiArrayView<2, double>(
                             Shp(parameters_.size(), 1),
                             parameters_.data()));
        vigra::writeHDF5(file.c_str(), (prefix + "begin_addr").c_str(),
                         MultiArrayView<2, int>(Shp(1, 1), &begin_addr));
    }
#endif

    /** Perform single linkage clustering
     * \param distance distance matrix used. \sa CorrelationVisitor
     */
    template<class T, class C>
    void cluster(MultiArrayView<2, T, C> distance)
    {
        MultiArray<2, T> dist(distance);
        std::vector<std::pair<int, int> > addr;
        int index = 0;
        for(int ii = 0; ii < distance.shape(0); ++ii)
        {
            addr.push_back(std::make_pair(topology_.size(), ii));
            ClusterNode leaf(1, topology_, parameters_);
            leaf.set_index(index);
            ++index;
            leaf.columns_begin()[0] = ii;
        }

        while(addr.size() != 1)
        {
            // find the two nodes with the smallest distance
            int ii_min = 0;
            int jj_min = 1;
            double min_dist = dist((addr.begin() + ii_min)->second,
                                   (addr.begin() + jj_min)->second);
            for(unsigned int ii = 0; ii < addr.size(); ++ii)
            {
                for(unsigned int jj = ii + 1; jj < addr.size(); ++jj)
                {
                    if(dist((addr.begin() + ii_min)->second,
                            (addr.begin() + jj_min)->second)
                       > dist((addr.begin() + ii)->second,
                              (addr.begin() + jj)->second))
                    {
                        min_dist = dist((addr.begin() + ii)->second,
                                        (addr.begin() + jj)->second);
                        ii_min = ii;
                        jj_min = jj;
                    }
                }
            }

            // merge two nodes
            int col_size = 0;
            // The problem is that creating a new node invalidates the
            // references stored in firstChild and secondChild - hence
            // the extra scope.
            {
                ClusterNode firstChild(topology_,
                                       parameters_,
                                       (addr.begin() + ii_min)->first);
                ClusterNode secondChild(topology_,
                                        parameters_,
                                        (addr.begin() + jj_min)->first);
                col_size = firstChild.columns_size() + secondChild.columns_size();
            }
            int cur_addr = topology_.size();
            begin_addr = cur_addr;
            ClusterNode parent(col_size,
                               topology_,
                               parameters_);
            ClusterNode firstChild(topology_,
                                   parameters_,
                                   (addr.begin() + ii_min)->first);
            ClusterNode secondChild(topology_,
                                    parameters_,
                                    (addr.begin() + jj_min)->first);
            parent.parameters_begin()[0] = min_dist;
            parent.set_index(index);
            ++index;
            std::merge(firstChild.columns_begin(), firstChild.columns_end(),
                       secondChild.columns_begin(), secondChild.columns_end(),
                       parent.columns_begin());
            // merge nodes in addr
            int to_desc;
            int ii_keep;
            if(*parent.columns_begin() == *firstChild.columns_begin())
            {
                parent.child(0) = (addr.begin() + ii_min)->first;
                parent.child(1) = (addr.begin() + jj_min)->first;
                (addr.begin() + ii_min)->first = cur_addr;
                ii_keep = ii_min;
                to_desc = (addr.begin() + jj_min)->second;
                addr.erase(addr.begin() + jj_min);
            }
            else
            {
                parent.child(1) = (addr.begin() + ii_min)->first;
                parent.child(0) = (addr.begin() + jj_min)->first;
                (addr.begin() + jj_min)->first = cur_addr;
                ii_keep = jj_min;
                to_desc = (addr.begin() + ii_min)->second;
                addr.erase(addr.begin() + ii_min);
            }
            // update distances
            for(int jj = 0; jj < static_cast<int>(addr.size()); ++jj)
            {
                if(jj == ii_keep)
                    continue;
                double new_dist = dist_func(
                    dist(to_desc, (addr.begin() + jj)->second),
                    dist((addr.begin() + ii_keep)->second,
                         (addr.begin() + jj)->second));

                dist((addr.begin() + ii_keep)->second,
                     (addr.begin() + jj)->second) = new_dist;
                dist((addr.begin() + jj)->second,
                     (addr.begin() + ii_keep)->second) = new_dist;
            }
        }
    }

};


/** Normalize the status value in the HClustering tree (HClustering visitor).
 */
class NormalizeStatus
{
public:
    double n;
    /** Constructor
     * \param m normalize status() by m
     */
    NormalizeStatus(double m)
    : n(m)
    {}
    template<class Node>
    bool operator()(Node & node)
    {
        node.status() /= n;
        return false;
    }
};


/** Perform permutation importance on HClustering clusters.
 * (See the visit_after_tree() method of visitors::VariableImportance for
 * the basic idea; here the permutation is applied not only to single
 * variables but to whole clusters.)
 */
template<class Iter, class DT>
class PermuteCluster
{
public:
    typedef MultiArrayShape<2>::type Shp;
    Matrix<double> tmp_mem_;
    MultiArrayView<2, double> perm_imp;
    MultiArrayView<2, double> orig_imp;
    Matrix<double> feats_;
    Matrix<int> labels_;
    const int nPerm;
    DT const & dt;
    int index;
    int oob_size;

    template<class Feat_T, class Label_T>
    PermuteCluster(Iter a,
                   Iter b,
                   Feat_T const & feats,
                   Label_T const & labls,
                   MultiArrayView<2, double> p_imp,
                   MultiArrayView<2, double> o_imp,
                   int np,
                   DT const & dt_)
    : tmp_mem_(_spl(a, b).size(), feats.shape(1)),
      perm_imp(p_imp),
      orig_imp(o_imp),
      feats_(_spl(a, b).size(), feats.shape(1)),
      labels_(_spl(a, b).size(), 1),
      nPerm(np),
      dt(dt_),
      index(0),
      oob_size(b - a)
    {
        copy_splice(_spl(a, b),
                    _spl(feats.shape(1)),
                    feats,
                    feats_);
        copy_splice(_spl(a, b),
                    _spl(labls.shape(1)),
                    labls,
                    labels_);
    }

    template<class Node>
    bool operator()(Node & node)
    {
        tmp_mem_ = feats_;
        RandomMT19937 random;
        int class_count = perm_imp.shape(1) - 1;
        // permute the columns of a cluster together
        for(int kk = 0; kk < nPerm; ++kk)
        {
            tmp_mem_ = feats_;
            for(int ii = 0; ii < rowCount(feats_); ++ii)
            {
                // choose a random row in [ii, rowCount) to swap values from
                int rand_row = random.uniformInt(rowCount(feats_) - ii) + ii;
                for(int jj = 0; jj < node.columns_size(); ++jj)
                {
                    if(node.columns_begin()[jj] != feats_.shape(1))
                        tmp_mem_(ii, node.columns_begin()[jj])
                            = tmp_mem_(rand_row, node.columns_begin()[jj]);
                }
            }

            for(int ii = 0; ii < rowCount(tmp_mem_); ++ii)
            {
                if(dt
                   .predictLabel(rowVector(tmp_mem_, ii))
                   == labels_(ii, 0))
                {
                    // per class
                    ++perm_imp(index, labels_(ii, 0));
                    // total
                    ++perm_imp(index, class_count);
                }
            }
        }
        double node_status = perm_imp(index, class_count);
        node_status /= nPerm;
        node_status -= orig_imp(0, class_count);
        node_status *= -1;
        node_status /= oob_size;
        node.status() += node_status;
        ++index;

        return false;
    }
};

/** Convert the clustering tree into a list (HClustering visitor).
 */
class GetClusterVariables
{
public:
    /** NumberOfClusters x NumberOfVariables MultiArrayView containing
     * in each row the variables belonging to a cluster
     */
    MultiArrayView<2, int> variables;
    int index;
    GetClusterVariables(MultiArrayView<2, int> vars)
    : variables(vars), index(0)
    {}
#ifdef HasHDF5
    void save(std::string file, std::string prefix)
    {
        vigra::writeHDF5(file.c_str(), (prefix + "_variables").c_str(),
                         variables);
    }
#endif

    template<class Node>
    bool operator()(Node & node)
    {
        for(int ii = 0; ii < node.columns_size(); ++ii)
            variables(index, ii) = node.columns_begin()[ii];
        ++index;
        return false;
    }
};
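
/* Sketch: dump all clusters of a finished HClustering into a matrix
 * (this mirrors how ClusterImportanceVisitor::visit_at_beginning() below
 * allocates the matrix). For p leaf variables the tree has 2*p-1 nodes,
 * and unused cells stay at -1:
 *
 * \code
 * int p = features.shape(1);
 * MultiArray<2, int> vars(MultiArrayShape<2>::type(2*p - 1, p), -1);
 * GetClusterVariables gcv(vars);
 * linkage.iterate(gcv);
 * \endcode
 */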
/** Corrects the status fields of a linkage clustering (HClustering visitor)
 *
 * such that status(currentNode) = min(status(parent), status(currentNode))
 * \sa cluster_permutation_importance()
 */
class CorrectStatus
{
public:
    template<class Nde>
    bool operator()(Nde & cur, int /*level*/, Nde parent, bool /*infm*/)
    {
        if(parent.hasData_)
            cur.status() = std::min(parent.status(), cur.status());
        return true;
    }
};


/** Draw the current linkage clustering (HClustering visitor)
 *
 * Creates a graphviz .dot file.
 * usage:
 * \code
 * HClustering linkage;
 * Matrix<double> distance = get_distance_matrix();
 * linkage.cluster(distance);
 * Draw<double, int> draw(features, labels, "linkagetree.graph");
 * linkage.breadth_first_traversal(draw);
 * \endcode
 */
template<class T1,
         class T2,
         class C1 = UnstridedArrayTag,
         class C2 = UnstridedArrayTag>
class Draw
{
public:
    typedef MultiArrayShape<2>::type Shp;
    MultiArrayView<2, T1, C1> const & features_;
    MultiArrayView<2, T2, C2> const & labels_;
    std::ofstream graphviz;


    Draw(MultiArrayView<2, T1, C1> const & features,
         MultiArrayView<2, T2, C2> const & labels,
         std::string const gz)
    : features_(features), labels_(labels),
      graphviz(gz.c_str(), std::ios::out)
    {
        graphviz << "digraph G\n{\n node [shape=\"record\"]";
    }
    ~Draw()
    {
        graphviz << "\n}\n";
        graphviz.close();
    }

    template<class Nde>
    bool operator()(Nde & cur, int /*level*/, Nde parent, bool /*infm*/)
    {
        graphviz << "node" << cur.index() << " [style=\"filled\"][label = \" #Feats: " << cur.columns_size() << "\\n";
        graphviz << " status: " << cur.status() << "\\n";
        for(int kk = 0; kk < cur.columns_size(); ++kk)
        {
            graphviz << cur.columns_begin()[kk] << " ";
            if(kk % 15 == 14)
                graphviz << "\\n";
        }
        graphviz << "\"] [color = \"" << cur.status() << " 1.000 1.000\"];\n";
        if(parent.hasData_)
            graphviz << "\"node" << parent.index() << "\" -> \"node" << cur.index() << "\";\n";
        return true;
    }
};

/** Calculate cluster-based permutation importance while learning (RandomForest visitor).
 */
class ClusterImportanceVisitor : public visitors::VisitorBase
{
    public:

    /** List of variables as produced by GetClusterVariables
     */
    MultiArray<2, int> variables;
    /** Corresponding importance measures
     */
    MultiArray<2, double> cluster_importance_;
    /** Corresponding standard deviation
     */
    MultiArray<2, double> cluster_stdev_;
    int repetition_count_;
    bool in_place_;
    HClustering & clustering;


#ifdef HasHDF5
    void save(std::string filename, std::string prefix)
    {
        std::string prefix1 = "cluster_importance_" + prefix;
        writeHDF5(filename.c_str(),
                  prefix1.c_str(),
                  cluster_importance_);
        prefix1 = "vars_" + prefix;
        writeHDF5(filename.c_str(),
                  prefix1.c_str(),
                  variables);
    }
#endif

    ClusterImportanceVisitor(HClustering & clst, int rep_cnt = 10)
    : repetition_count_(rep_cnt), in_place_(false), clustering(clst)
    {}

    /** Allocate enough memory.
     */
    template<class RF, class PR>
    void visit_at_beginning(RF const & rf, PR const & /*pr*/)
    {
        Int32 const class_count = rf.ext_param_.class_count_;
        Int32 const column_count = rf.ext_param_.column_count_ + 1;
        cluster_importance_
            .reshape(MultiArrayShape<2>::type(2*column_count - 1,
                                              class_count + 1));
        cluster_stdev_
            .reshape(MultiArrayShape<2>::type(2*column_count - 1,
                                              class_count + 1));
        variables
            .reshape(MultiArrayShape<2>::type(2*column_count - 1,
                                              column_count), -1);
        GetClusterVariables gcv(variables);
        clustering.iterate(gcv);
    }

    /** Compute permutation-based variable importance.
     * (Only an array of size oob_sample_count x 1 is created
     * - as opposed to oob_sample_count x feature_count in the other method.)
     *
     * \sa FieldProxy
     */
    template<class RF, class PR, class SM, class ST>
    void after_tree_ip_impl(RF & rf, PR & pr, SM & sm, ST & /*st*/, int index)
    {
        typedef MultiArrayShape<2>::type Shp_t;
        Int32 column_count = rf.ext_param_.column_count_ + 1;
        Int32 class_count = rf.ext_param_.class_count_;

        // cast the constness off the features (yep, I know what I am
        // doing here) - the data is not modified.
        typename PR::Feature_t & features
            = const_cast<typename PR::Feature_t &>(pr.features());

        // find the oob indices of the current tree.
        ArrayVector<Int32> oob_indices;
        ArrayVector<Int32>::iterator iter;

        if(rf.ext_param_.actual_msample_ < pr.features().shape(0) - 10000)
        {
            ArrayVector<int> cts(2, 0);
            ArrayVector<Int32> indices(pr.features().shape(0));
            for(int ii = 0; ii < pr.features().shape(0); ++ii)
                indices[ii] = ii;
            std::random_device rd;
            std::mt19937 g(rd());
            std::shuffle(indices.begin(), indices.end(), g);
            for(int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
            {
                if(!sm.is_used()[indices[ii]] && cts[pr.response()(indices[ii], 0)] < 3000)
                {
                    oob_indices.push_back(indices[ii]);
                    ++cts[pr.response()(indices[ii], 0)];
                }
            }
        }
        else
        {
            for(int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
                if(!sm.is_used()[ii])
                    oob_indices.push_back(ii);
        }

        // random number generation
        RandomMT19937 random(RandomSeed);
        UniformIntRandomFunctor<RandomMT19937>
            randint(random);

        // make some space for the results
        MultiArray<2, double>
            oob_right(Shp_t(1, class_count + 1));

        // get the oob success rate with the original samples
        for(iter = oob_indices.begin();
            iter != oob_indices.end();
            ++iter)
        {
            if(rf.tree(index)
               .predictLabel(rowVector(features, *iter))
               == pr.response()(*iter, 0))
            {
                // per class
                ++oob_right[pr.response()(*iter, 0)];
                // total
                ++oob_right[class_count];
            }
        }

        MultiArray<2, double>
            perm_oob_right(Shp_t(2*column_count - 1, class_count + 1));

        PermuteCluster<ArrayVector<Int32>::iterator, typename RF::DecisionTree_t>
            pc(oob_indices.begin(), oob_indices.end(),
               pr.features(),
               pr.response(),
               perm_oob_right,
               oob_right,
               repetition_count_,
               rf.tree(index));
        clustering.iterate(pc);

        perm_oob_right /= repetition_count_;
        for(int ii = 0; ii < rowCount(perm_oob_right); ++ii)
            rowVector(perm_oob_right, ii) -= oob_right;

        perm_oob_right *= -1;
        perm_oob_right /= oob_indices.size();
        cluster_importance_ += perm_oob_right;
    }

    /** Calculate permutation-based importance after every tree has been
     * learned. The default behaviour is that this happens out of place.
     * If you have very big data sets and want to avoid copying of data,
     * set the in_place_ flag to true.
     */
    template<class RF, class PR, class SM, class ST>
    void visit_after_tree(RF & rf, PR & pr, SM & sm, ST & st, int index)
    {
        after_tree_ip_impl(rf, pr, sm, st, index);
    }

    /** Normalise variable importance after the number of trees is known.
     */
    template<class RF, class PR>
    void visit_at_end(RF & rf, PR & /*pr*/)
    {
        NormalizeStatus nrm(rf.tree_count());
        clustering.iterate(nrm);
        cluster_importance_ /= rf.trees_.size();
    }
};

/** Perform hierarchical clustering of variables and assess the importance of clusters.
 *
 * \param features IN:  n x p matrix containing n instances with p attributes/features
 *                      used in the variable selection algorithm
 * \param response IN:  n x 1 matrix containing the corresponding response
 * \param linkage  OUT: hierarchical grouping of variables.
 * \param distance OUT: distance matrix used for creating the linkage.
 *
 * Performs hierarchical clustering of variables and calculates the permutation importance
 * measure of each of the clusters. Use the Draw functor to create human-readable output.
 * The cluster permutation importance measure corresponds to the normal permutation importance
 * measure with all columns corresponding to a cluster permuted together.
 * The importance measure for each cluster is stored as the status() field of each cluster node.
 * \sa HClustering
 *
 * usage:
 * \code
 * MultiArray<2, double> features = createSomeFeatures();
 * MultiArray<2, int>    labels   = createCorrespondingLabels();
 * HClustering linkage;
 * MultiArray<2, double> distance;
 * cluster_permutation_importance(features, labels, linkage, distance);
 *
 * // create graphviz output
 * Draw<double, int> draw(features, labels, "linkagetree.graph");
 * linkage.breadth_first_traversal(draw);
 * \endcode
 *
 */
template<class FeatureT, class ResponseT>
void cluster_permutation_importance(FeatureT const & features,
                                    ResponseT const & response,
                                    HClustering & linkage,
                                    MultiArray<2, double> & distance)
{
    RandomForestOptions opt;
    opt.tree_count(100);
    if(features.shape(0) > 40000)
        opt.samples_per_tree(20000).use_stratification(RF_EQUAL);


    vigra::RandomForest<int> RF(opt);
    visitors::RandomForestProgressVisitor progress;
    visitors::CorrelationVisitor missc;
    RF.learn(features, response,
             create_visitor(missc, progress));
    distance = missc.distance;
    /*
    missc.save(exp_dir + dset.name() + "_result.h5", dset.name() + "MACH");
    */


    // Produce linkage
    linkage.cluster(distance);

    //linkage.save(exp_dir + dset.name() + "_result.h5", "_linkage_CC/");
    vigra::RandomForest<int> RF2(opt);
    ClusterImportanceVisitor ci(linkage);
    RF2.learn(features,
              response,
              create_visitor(progress, ci));


    CorrectStatus cs;
    linkage.breadth_first_traversal(cs);

    //ci.save(exp_dir + dset.name() + "_result.h5", dset.name());
    //Draw<double, int> draw(dset.features(), dset.response(), exp_dir + dset.name() + ".graph");
    //linkage.breadth_first_traversal(draw);
}


template<class FeatureT, class ResponseT>
void cluster_permutation_importance(FeatureT const & features,
                                    ResponseT const & response,
                                    HClustering & linkage)
{
    MultiArray<2, double> distance;
    cluster_permutation_importance(features, response, linkage, distance);
}


/** Produce a ranking of features from an array of scores:
 * indices are emitted in order of decreasing score.
 *
 * Note: the scores are used as keys of a std::map, so features with
 * identical scores collapse to a single entry (the last one wins).
 */
template<class Array1, class Vector1>
void get_ranking(Array1 const & in, Vector1 & out)
{
    std::map<double, int> mymap;
    for(int ii = 0; ii < in.size(); ++ii)
        mymap[in[ii]] = ii;
    for(std::map<double, int>::reverse_iterator iter = mymap.rbegin(); iter != mymap.rend(); ++iter)
    {
        out.push_back(iter->second);
    }
}
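
/* Sketch: feed a score-based ranking into rank_selection() (assuming
 * `importance` holds one score per feature, e.g. from a variable
 * importance visitor):
 *
 * \code
 * std::vector<int> ranking;
 * get_ranking(importance, ranking);
 * VariableSelectionResult result;
 * result.init(features, labels, ranking.begin(), ranking.end());
 * rank_selection(features, labels, result);
 * \endcode
 */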
} //namespace algorithms
} //namespace rf
} //namespace vigra
#endif //VIGRA_RF_ALGORITHM_HXX