001/*-
002 *******************************************************************************
003 * Copyright (c) 2011, 2016 Diamond Light Source Ltd.
004 * All rights reserved. This program and the accompanying materials
005 * are made available under the terms of the Eclipse Public License v1.0
006 * which accompanies this distribution, and is available at
007 * http://www.eclipse.org/legal/epl-v10.html
008 *
009 * Contributors:
010 *    Peter Chang - initial API and implementation and/or initial documentation
011 *******************************************************************************/
012
013package org.eclipse.january.dataset;
014
015import java.util.Arrays;
016import java.util.List;
017
018/**
019 * Class to run over a pair of datasets in parallel with NumPy broadcasting to promote shapes
020 * which have lower rank and outputs to a third dataset
021 */
022public class BroadcastPairIterator extends BroadcastIterator {
023        private int[] aShape;
024        private int[] bShape;
025        private int[] aStride;
026        private int[] bStride;
027        private int[] oStride;
028
029        final private int endrank;
030
031        private final int[] aDelta, bDelta;
032        private final int[] oDelta; // this being non-null means output is different from inputs
033        private final int aStep, bStep, oStep;
034        private int aMax, bMax;
035        private int aStart, bStart, oStart;
036
037        /**
038         * 
039         * @param a dataset to iterate over
040         * @param b dataset to iterate over
041         * @param o output (can be null for new dataset, or a)
042         * @param createIfNull if true, create new dataset if o is null
043         */
044        public BroadcastPairIterator(Dataset a, Dataset b, Dataset o, boolean createIfNull) {
045                super(a, b, o);
046                List<int[]> fullShapes = BroadcastUtils.broadcastShapes(a.getShapeRef(), b.getShapeRef(), o == null ? null : o.getShapeRef());
047
048                maxShape = fullShapes.remove(0);
049
050                oStride = null;
051                if (o != null && !Arrays.equals(maxShape, o.getShapeRef())) {
052                        throw new IllegalArgumentException("Output does not match broadcasted shape");
053                }
054                aShape = fullShapes.remove(0);
055                bShape = fullShapes.remove(0);
056
057                int rank = maxShape.length;
058                endrank = rank - 1;
059
060                aDataset = a.reshape(aShape);
061                bDataset = b.reshape(bShape);
062                aStride = BroadcastUtils.createBroadcastStrides(aDataset, maxShape);
063                bStride = BroadcastUtils.createBroadcastStrides(bDataset, maxShape);
064                if (outputA) {
065                        oStride = aStride;
066                        oDelta = null;
067                        oStep = 0;
068                } else if (outputB) {
069                        oStride = bStride;
070                        oDelta = null;
071                        oStep = 0;
072                } else if (o != null) {
073                        oStride = BroadcastUtils.createBroadcastStrides(o, maxShape);
074                        oDelta = new int[rank];
075                        oStep = o.getElementsPerItem();
076                } else if (createIfNull) {
077                        oDataset = BroadcastUtils.createDataset(a, b, maxShape);
078                        oStride = BroadcastUtils.createBroadcastStrides(oDataset, maxShape);
079                        oDelta = new int[rank];
080                        oStep = oDataset.getElementsPerItem();
081                } else {
082                        oDelta = null;
083                        oStep = 0;
084                }
085
086                pos = new int[rank];
087                aDelta = new int[rank];
088                aStep = aDataset.getElementsPerItem();
089                bDelta = new int[rank];
090                bStep = bDataset.getElementsPerItem();
091                for (int j = endrank; j >= 0; j--) {
092                        aDelta[j] = aStride[j] * aShape[j];
093                        bDelta[j] = bStride[j] * bShape[j];
094                        if (oDelta != null) {
095                                oDelta[j] = oStride[j] * maxShape[j];
096                        }
097                }
098                aStart = aDataset.getOffset();
099                bStart = bDataset.getOffset();
100                aMax = endrank < 0 ? aStep + aStart: Integer.MIN_VALUE;
101                bMax = endrank < 0 ? bStep + bStart: Integer.MIN_VALUE;
102                oStart = oDelta == null ? 0 : oDataset.getOffset();
103                reset();
104        }
105
106        @Override
107        public boolean hasNext() {
108                int j = endrank;
109                int oldA = aIndex;
110                int oldB = bIndex;
111                for (; j >= 0; j--) {
112                        pos[j]++;
113                        aIndex += aStride[j];
114                        bIndex += bStride[j];
115                        if (oDelta != null) {
116                                oIndex += oStride[j];
117                        }
118                        if (pos[j] >= maxShape[j]) {
119                                pos[j] = 0;
120                                aIndex -= aDelta[j]; // reset these dimensions
121                                bIndex -= bDelta[j];
122                                if (oDelta != null) {
123                                        oIndex -= oDelta[j];
124                                }
125                        } else {
126                                break;
127                        }
128                }
129                if (j == -1) {
130                        if (endrank >= 0) {
131                                return false;
132                        }
133                        aIndex += aStep;
134                        bIndex += bStep;
135                        if (oDelta != null) {
136                                oIndex += oStep;
137                        }
138                }
139                if (outputA) {
140                        oIndex = aIndex;
141                } else if (outputB) {
142                        oIndex = bIndex;
143                }
144
145                if (aIndex == aMax || bIndex == bMax) {
146                        return false;
147                }
148
149                if (read) {
150                        if (oldA != aIndex) {
151                                if (asDouble) {
152                                        aDouble = aDataset.getElementDoubleAbs(aIndex);
153                                } else {
154                                        aLong = aDataset.getElementLongAbs(aIndex);
155                                }
156                        }
157                        if (oldB != bIndex) {
158                                if (asDouble) {
159                                        bDouble = bDataset.getElementDoubleAbs(bIndex);
160                                } else {
161                                        bLong = bDataset.getElementLongAbs(bIndex);
162                                }
163                        }
164                }
165
166                return true;
167        }
168
169        /**
170         * @return shape of first broadcasted dataset
171         */
172        public int[] getFirstShape() {
173                return aShape;
174        }
175
176        /**
177         * @return shape of second broadcasted dataset
178         */
179        public int[] getSecondShape() {
180                return bShape;
181        }
182
183        @Override
184        public void reset() {
185                for (int i = 0; i <= endrank; i++) {
186                        pos[i] = 0;
187                }
188
189                if (endrank >= 0) {
190                        pos[endrank] = -1;
191                        aIndex = aStart - aStride[endrank];
192                        bIndex = bStart - bStride[endrank];
193                        oIndex = oStart - (oStride == null ? 0 : oStride[endrank]);
194                } else {
195                        aIndex = aStart - aStep;
196                        bIndex = bStart - bStep;
197                        oIndex = oStart - oStep;
198                }
199
200                if (aIndex == 0 || bIndex == 0 || (endrank >= 0 && (aStride[endrank] == 0 || bStride[endrank] == 0))) { // for zero-ranked datasets or extended shape
201                        if (read) {
202                                storeCurrentValues();
203                        }
204                }
205        }
206}