
rf_preprocessing.hxx
/************************************************************************
 *
 *    Copyright 2008-2009 by Ullrich Koethe and Rahul Nair
 *
 *    This file is part of the VIGRA computer vision library.
 *    The VIGRA Website is
 *        http://hci.iwr.uni-heidelberg.de/vigra/
 *    Please direct questions, bug reports, and contributions to
 *        ullrich.koethe@iwr.uni-heidelberg.de    or
 *        vigra@informatik.uni-hamburg.de
 *
 *    Permission is hereby granted, free of charge, to any person
 *    obtaining a copy of this software and associated documentation
 *    files (the "Software"), to deal in the Software without
 *    restriction, including without limitation the rights to use,
 *    copy, modify, merge, publish, distribute, sublicense, and/or
 *    sell copies of the Software, and to permit persons to whom the
 *    Software is furnished to do so, subject to the following
 *    conditions:
 *
 *    The above copyright notice and this permission notice shall be
 *    included in all copies or substantial portions of the
 *    Software.
 *
 *    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND
 *    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
 *    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 *    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
 *    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
 *    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 *    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
 *    OTHER DEALINGS IN THE SOFTWARE.
 *
 ************************************************************************/
#ifndef VIGRA_RF_PREPROCESSING_HXX
#define VIGRA_RF_PREPROCESSING_HXX

#include <algorithm>
#include <cmath>
#include <limits>
#include <set>
#include <stdexcept>
#include <vector>
#include <vigra/mathutil.hxx>
#include "rf_common.hxx"

namespace vigra
{
/** Class used during preprocessing (currently only used during learning).
 *
 * This class is used internally by the Random Forest learn function.
 * Different split functors may need the data processed in different ways
 * (e.g., regression labels should be left untouched, while classification
 * labels must be converted into an integral format).
 *
 * This class only exists in specialized versions, in which the Tag class
 * is fixed.
 *
 * The Tag class is determined by Splitfunctor::Preprocessor_t. Currently
 * it is either ClassificationTag or RegressionTag. See the RegressionTag
 * specialisation for the basic interface, e.g. if you need to write a new
 * kind of preprocessor (soft labels or the like).
 */
template<class Tag, class LabelType, class T1, class C1, class T2, class C2>
class Processor;
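/* Illustrative sketch (not part of this header): the learn() code is expected
 * to pick the right specialization via the split functor's Preprocessor_t tag
 * and construct it from the raw training data. The names "SplitFunctor",
 * "features" and "labels" below are hypothetical placeholders.
 *
 *     typedef Processor<typename SplitFunctor::Preprocessor_t,
 *                       LabelType, T1, C1, T2, C2> Preprocessor_t;
 *     Preprocessor_t preprocessor(features, labels, options, ext_param);
 *     // preprocessor.features() / .response() / .strata() then feed the
 *     // tree building code.
 */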
namespace detail
{

    /* Common helper function used in all Processors.
     * This function analyses the options struct and calculates the real
     * values needed for the current problem (data).
     */
    template<class T>
    void fill_external_parameters(RandomForestOptions const & options,
                                  ProblemSpec<T> & ext_param)
    {
        // set correct value for mtry
        switch(options.mtry_switch_)
        {
            case RF_SQRT:
                ext_param.actual_mtry_ =
                    int(std::floor(
                            std::sqrt(double(ext_param.column_count_))
                            + 0.5));
                break;
            case RF_LOG:
                // this is in Breiman's original paper
                ext_param.actual_mtry_ =
                    int(1 + (std::log(double(ext_param.column_count_))
                             / std::log(2.0)));
                break;
            case RF_FUNCTION:
                ext_param.actual_mtry_ =
                    options.mtry_func_(ext_param.column_count_);
                break;
            case RF_ALL:
                ext_param.actual_mtry_ = ext_param.column_count_;
                break;
            default:
                ext_param.actual_mtry_ = options.mtry_;
        }
        // set correct value for msample
        switch(options.training_set_calc_switch_)
        {
            case RF_CONST:
                ext_param.actual_msample_ = options.training_set_size_;
                break;
            case RF_PROPORTIONAL:
                ext_param.actual_msample_ =
                    static_cast<int>(std::ceil(options.training_set_proportion_ *
                                               ext_param.row_count_));
                break;
            case RF_FUNCTION:
                ext_param.actual_msample_ =
                    options.training_set_func_(ext_param.row_count_);
                break;
            default:
                vigra_precondition(false,
                    "fill_external_parameters(): unexpected training set "
                    "calculation switch");
        }
    }
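    /* Worked example (illustrative only, not part of the library): for a
     * problem with column_count_ == 100 and row_count_ == 1000,
     *
     *     RF_SQRT          -> actual_mtry_    = floor(sqrt(100) + 0.5) = 10
     *     RF_LOG           -> actual_mtry_    = int(1 + log2(100))     = 7
     *     RF_PROPORTIONAL  -> actual_msample_ = ceil(proportion * 1000)
     *
     * The concrete value of training_set_proportion_ depends on the options
     * object and is not fixed by this header.
     */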
    /* Returns true if the MultiArray contains NaNs.
     */
    template<unsigned int N, class T, class C>
    bool contains_nan(MultiArrayView<N, T, C> const & in)
    {
        typedef typename MultiArrayView<N, T, C>::const_iterator Iter;
        Iter i = in.begin(), end = in.end();
        for(; i != end; ++i)
            if(isnan(NumericTraits<T>::toRealPromote(*i)))
                return true;
        return false;
    }

    /* Returns true if the MultiArray contains Infs.
     */
    template<unsigned int N, class T, class C>
    bool contains_inf(MultiArrayView<N, T, C> const & in)
    {
        if(!std::numeric_limits<T>::has_infinity)
            return false;
        typedef typename MultiArrayView<N, T, C>::const_iterator Iter;
        Iter i = in.begin(), end = in.end();
        for(; i != end; ++i)
            if(abs(*i) == std::numeric_limits<T>::infinity())
                return true;
        return false;
    }
} // namespace detail
/** Preprocessor used during classification.
 *
 * This class converts the labels into integral labels which are used by the
 * standard split functor to address memory in the node objects.
 */
template<class LabelType, class T1, class C1, class T2, class C2>
class Processor<ClassificationTag, LabelType, T1, C1, T2, C2>
{
  public:
    typedef Int32 LabelInt;
    typedef MultiArrayView<2, T1, C1>   Feature_t;
    typedef MultiArray<2, T1>           FeatureWithMemory_t;
    typedef MultiArrayView<2, LabelInt> Label_t;
    MultiArrayView<2, T1, C1> const & features_;
    MultiArray<2, LabelInt>           intLabels_;
    MultiArray<2, LabelInt>           strata_;

    template<class T>
    Processor(MultiArrayView<2, T1, C1> const & features,
              MultiArrayView<2, T2, C2> const & response,
              RandomForestOptions & options,
              ProblemSpec<T> & ext_param)
    :
        features_(features) // do not touch the features.
    {
        vigra_precondition(!detail::contains_nan(features), "RandomForest(): Feature matrix "
                                                            "contains NaNs");
        vigra_precondition(!detail::contains_nan(response), "RandomForest(): Response "
                                                            "contains NaNs");
        vigra_precondition(!detail::contains_inf(features), "RandomForest(): Feature matrix "
                                                            "contains inf");
        vigra_precondition(!detail::contains_inf(response), "RandomForest(): Response "
                                                            "contains inf");
        // set some of the problem specific parameters
        ext_param.column_count_ = features.shape(1);
        ext_param.row_count_    = features.shape(0);
        ext_param.problem_type_ = CLASSIFICATION;
        ext_param.used_         = true;
        intLabels_.reshape(response.shape());

        // get the class labels
        if(ext_param.class_count_ == 0)
        {
            // fill a set with the distinct labels and then create the
            // integral labels.
            std::set<T2> labelToInt;
            for(MultiArrayIndex k = 0; k < features.shape(0); ++k)
                labelToInt.insert(response(k, 0));
            std::vector<T2> tmp_(labelToInt.begin(), labelToInt.end());
            ext_param.classes_(tmp_.begin(), tmp_.end());
        }
        for(MultiArrayIndex k = 0; k < features.shape(0); ++k)
        {
            if(std::find(ext_param.classes.begin(), ext_param.classes.end(),
                         response(k, 0)) == ext_param.classes.end())
            {
                throw std::runtime_error("RandomForest(): invalid label in training data.");
            }
            else
            {
                intLabels_(k, 0) = std::find(ext_param.classes.begin(),
                                             ext_param.classes.end(),
                                             response(k, 0))
                                   - ext_param.classes.begin();
            }
        }
        // set class weights
        if(ext_param.class_weights_.size() == 0)
        {
            ArrayVector<T2> tmp(static_cast<std::size_t>(ext_param.class_count_),
                                NumericTraits<T2>::one());
            ext_param.class_weights(tmp.begin(), tmp.end());
        }

        // set mtry and msample
        detail::fill_external_parameters(options, ext_param);

        // set strata
        strata_ = intLabels_;
    }

    /** Access the processed features.
     */
    MultiArrayView<2, T1, C1> const & features()
    {
        return features_;
    }

    /** Access processed labels.
     */
    MultiArrayView<2, LabelInt> response()
    {
        return intLabels_;
    }

    /** Access processed strata.
     */
    ArrayVectorView<LabelInt> strata()
    {
        return ArrayVectorView<LabelInt>(intLabels_.size(), intLabels_.data());
    }

    /** Access strata fraction sized - not used currently.
     */
    ArrayVectorView<double> strata_prob()
    {
        return ArrayVectorView<double>();
    }
};
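/* Illustrative sketch (not part of the library): with a label column holding,
 * say, the values {10, 20, 20, 10}, the preprocessor produces the integral
 * labels {0, 1, 1, 0} and registers the distinct classes in ext_param. The
 * array names and template arguments below are hypothetical.
 *
 *     MultiArray<2, double> features(Shape2(4, 3));
 *     MultiArray<2, int>    labels(Shape2(4, 1));
 *     RandomForestOptions   options;
 *     ProblemSpec<int>      ext_param;
 *     Processor<ClassificationTag, int, double, UnstridedArrayTag,
 *               int, UnstridedArrayTag>
 *         proc(features, labels, options, ext_param);
 *     // proc.response()(k, 0) is the index of labels(k, 0) in ext_param.classes;
 *     // proc.strata() returns the same integral labels as a flat view.
 */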
/** Regression preprocessor - this basically does not do anything with the
 * data.
 */
template<class LabelType, class T1, class C1, class T2, class C2>
class Processor<RegressionTag, LabelType, T1, C1, T2, C2>
{
  public:
    // only views are created - no data copied.
    MultiArrayView<2, T1, C1>      features_;
    MultiArrayView<2, T2, C2>      response_;
    RandomForestOptions const &    options_;
    ProblemSpec<LabelType> const & ext_param_;
    // will only be filled if needed
    MultiArray<2, int> strata_;
    bool strata_filled;

    // copy the views.
    template<class T>
    Processor(MultiArrayView<2, T1, C1> const & features,
              MultiArrayView<2, T2, C2> const & response,
              RandomForestOptions const & options,
              ProblemSpec<T> & ext_param)
    :
        features_(features),
        response_(response),
        options_(options),
        ext_param_(ext_param)
    {
        // set some of the problem specific parameters
        ext_param.column_count_ = features.shape(1);
        ext_param.row_count_    = features.shape(0);
        ext_param.problem_type_ = REGRESSION;
        ext_param.used_         = true;
        detail::fill_external_parameters(options, ext_param);
        vigra_precondition(!detail::contains_nan(features), "Processor(): Feature matrix "
                                                            "contains NaNs");
        vigra_precondition(!detail::contains_nan(response), "Processor(): Response "
                                                            "contains NaNs");
        vigra_precondition(!detail::contains_inf(features), "Processor(): Feature matrix "
                                                            "contains inf");
        vigra_precondition(!detail::contains_inf(response), "Processor(): Response "
                                                            "contains inf");
        strata_ = MultiArray<2, int>(MultiArrayShape<2>::type(response_.shape(0), 1));
        ext_param.response_size_ = response.shape(1);
        ext_param.class_count_   = response_.shape(1);
        std::vector<T2> tmp_(ext_param.class_count_, 0);
        ext_param.classes_(tmp_.begin(), tmp_.end());
    }

    /** Access preprocessed features.
     */
    MultiArrayView<2, T1, C1> & features()
    {
        return features_;
    }

    /** Access preprocessed response.
     */
    MultiArrayView<2, T2, C2> & response()
    {
        return response_;
    }

    /** Access strata - this is not used currently.
     */
    MultiArray<2, int> & strata()
    {
        return strata_;
    }
};
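/* Illustrative sketch (not part of the library): construction mirrors the
 * classification example above, but with RegressionTag the response is handed
 * back unchanged and class_count_ becomes the number of response columns.
 * Names below are hypothetical.
 *
 *     MultiArray<2, double> features(Shape2(100, 5)), targets(Shape2(100, 2));
 *     RandomForestOptions   options;
 *     ProblemSpec<double>   ext_param;
 *     Processor<RegressionTag, double, double, UnstridedArrayTag,
 *               double, UnstridedArrayTag>
 *         proc(features, targets, options, ext_param);
 *     // proc.response() is the original targets view; ext_param.class_count_ == 2.
 */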
} // namespace vigra

#endif // VIGRA_RF_PREPROCESSING_HXX
