001/*- 002 ******************************************************************************* 003 * Copyright (c) 2011, 2016 Diamond Light Source Ltd. 004 * All rights reserved. This program and the accompanying materials 005 * are made available under the terms of the Eclipse Public License v1.0 006 * which accompanies this distribution, and is available at 007 * http://www.eclipse.org/legal/epl-v10.html 008 * 009 * Contributors: 010 * Peter Chang - initial API and implementation and/or initial documentation 011 *******************************************************************************/ 012 013package org.eclipse.january.dataset; 014 015import java.util.Arrays; 016import java.util.List; 017 018/** 019 * Class to run over a single dataset with NumPy broadcasting to promote shapes 020 * which have lower rank and outputs to a second dataset 021 */ 022public class SingleInputBroadcastIterator extends IndexIterator { 023 private int[] maxShape; 024 private int[] aShape; 025 private final Dataset aDataset; 026 private final Dataset oDataset; 027 private int[] aStride; 028 private int[] oStride; 029 030 final private int endrank; 031 032 /** 033 * position in dataset 034 */ 035 private final int[] pos; 036 private final int[] aDelta; 037 private final int[] oDelta; // this being non-null means output is different from inputs 038 private final int aStep, oStep; 039 private int aMax; 040 private int aStart, oStart; 041 private final boolean outputA; 042 043 /** 044 * Index in array 045 */ 046 public int aIndex, oIndex; 047 048 /** 049 * Current value in array 050 */ 051 public double aDouble; 052 053 /** 054 * Current value in array 055 */ 056 public long aLong; 057 058 private boolean asDouble = true; 059 060 /** 061 * @param a 062 * @param o (can be null for new dataset, or a) 063 */ 064 public SingleInputBroadcastIterator(Dataset a, Dataset o) { 065 this(a, o, false); 066 } 067 068 /** 069 * @param a 070 * @param o (can be null for new dataset, or a) 071 * @param createIfNull (by default, can create float or complex datasets) 072 */ 073 public SingleInputBroadcastIterator(Dataset a, Dataset o, boolean createIfNull) { 074 this(a, o, createIfNull, false, true); 075 } 076 077 /** 078 * @param a 079 * @param o (can be null for new dataset, or a) 080 * @param createIfNull 081 * @param allowInteger if true, can create integer datasets 082 * @param allowComplex if true, can create complex datasets 083 */ 084 @SuppressWarnings("deprecation") 085 public SingleInputBroadcastIterator(Dataset a, Dataset o, boolean createIfNull, boolean allowInteger, boolean allowComplex) { 086 List<int[]> fullShapes = BroadcastUtils.broadcastShapes(a.getShapeRef(), o == null ? null : o.getShapeRef()); 087 088 checkItemSize(a, o); 089 090 maxShape = fullShapes.remove(0); 091 092 oStride = null; 093 if (o != null && !Arrays.equals(maxShape, o.getShapeRef())) { 094 throw new IllegalArgumentException("Output does not match broadcasted shape"); 095 } 096 aShape = fullShapes.remove(0); 097 098 int rank = maxShape.length; 099 endrank = rank - 1; 100 101 aDataset = a.reshape(aShape); 102 aStride = BroadcastUtils.createBroadcastStrides(aDataset, maxShape); 103 outputA = o == a; 104 if (outputA) { 105 oStride = aStride; 106 oDelta = null; 107 oStep = 0; 108 oDataset = aDataset; 109 } else if (o != null) { 110 oStride = BroadcastUtils.createBroadcastStrides(o, maxShape); 111 oDelta = new int[rank]; 112 oStep = o.getElementsPerItem(); 113 oDataset = o; 114 } else if (createIfNull) { 115 int is = aDataset.getElementsPerItem(); 116 int dt = aDataset.getDType(); 117 if (aDataset.isComplex() && !allowComplex) { 118 is = 1; 119 dt = DTypeUtils.getBestFloatDType(dt); 120 } else if (!aDataset.hasFloatingPointElements() && !allowInteger) { 121 dt = DTypeUtils.getBestFloatDType(dt); 122 } 123 oDataset = DatasetFactory.zeros(is, maxShape, dt); 124 oStride = BroadcastUtils.createBroadcastStrides(oDataset, maxShape); 125 oDelta = new int[rank]; 126 oStep = oDataset.getElementsPerItem(); 127 } else { 128 oDelta = null; 129 oStep = 0; 130 oDataset = o; 131 } 132 133 pos = new int[rank]; 134 aDelta = new int[rank]; 135 aStep = aDataset.getElementsPerItem(); 136 for (int j = endrank; j >= 0; j--) { 137 aDelta[j] = aStride[j] * aShape[j]; 138 if (oDelta != null) { 139 oDelta[j] = oStride[j] * maxShape[j]; 140 } 141 } 142 if (endrank < 0) { 143 aMax = aStep; 144 } else { 145 aMax = Integer.MIN_VALUE; // use max delta 146 for (int j = endrank; j >= 0; j--) { 147 if (aDelta[j] > aMax) { 148 aMax = aDelta[j]; 149 } 150 } 151 } 152 aStart = aDataset.getOffset(); 153 aMax += aStart; 154 oStart = oDelta == null ? 0 : oDataset.getOffset(); 155 asDouble = aDataset.hasFloatingPointElements(); 156 reset(); 157 } 158 159 /** 160 * @return true if output from iterator is double 161 */ 162 public boolean isOutputDouble() { 163 return asDouble; 164 } 165 166 /** 167 * Set to output doubles 168 * @param asDouble 169 */ 170 public void setOutputDouble(boolean asDouble) { 171 if (this.asDouble != asDouble) { 172 this.asDouble = asDouble; 173 storeCurrentValues(); 174 } 175 } 176 177 private static void checkItemSize(Dataset a, Dataset o) { 178 final int isa = a.getElementsPerItem(); 179 if (o != null) { 180 final int iso = o.getElementsPerItem(); 181 if (isa != 1 && iso != isa) { 182 throw new IllegalArgumentException("Can not output to dataset whose number of elements per item mismatch inputs'"); 183 } 184 } 185 } 186 187 @Override 188 public int[] getShape() { 189 return maxShape; 190 } 191 192 @Override 193 public boolean hasNext() { 194 int j = endrank; 195 int oldA = aIndex; 196 for (; j >= 0; j--) { 197 pos[j]++; 198 aIndex += aStride[j]; 199 if (oDelta != null) 200 oIndex += oStride[j]; 201 if (pos[j] >= maxShape[j]) { 202 pos[j] = 0; 203 aIndex -= aDelta[j]; // reset these dimensions 204 if (oDelta != null) 205 oIndex -= oDelta[j]; 206 } else { 207 break; 208 } 209 } 210 if (j == -1) { 211 if (endrank >= 0) { 212 aIndex = aMax; 213 return false; 214 } 215 aIndex += aStep; 216 if (oDelta != null) 217 oIndex += oStep; 218 } 219 if (outputA) { 220 oIndex = aIndex; 221 } 222 223 if (aIndex == aMax) 224 return false; 225 226 if (oldA != aIndex) { 227 if (asDouble) { 228 aDouble = aDataset.getElementDoubleAbs(aIndex); 229 } else { 230 aLong = aDataset.getElementLongAbs(aIndex); 231 } 232 } 233 234 return true; 235 } 236 237 /** 238 * @return output dataset (can be null) 239 */ 240 public Dataset getOutput() { 241 return oDataset; 242 } 243 244 @Override 245 public int[] getPos() { 246 return pos; 247 } 248 249 @Override 250 public void reset() { 251 for (int i = 0; i <= endrank; i++) 252 pos[i] = 0; 253 254 if (endrank >= 0) { 255 pos[endrank] = -1; 256 aIndex = aStart - aStride[endrank]; 257 oIndex = oStart - (oStride == null ? 0 : oStride[endrank]); 258 } else { 259 aIndex = -aStep; 260 oIndex = -oStep; 261 } 262 263 // for zero-ranked datasets 264 if (aIndex == 0) { 265 storeCurrentValues(); 266 if (aMax == aIndex) 267 aMax++; 268 } 269 } 270 271 private void storeCurrentValues() { 272 if (aIndex >= 0) { 273 if (asDouble) { 274 aDouble = aDataset.getElementDoubleAbs(aIndex); 275 } else { 276 aLong = aDataset.getElementLongAbs(aIndex); 277 } 278 } 279 } 280}