/*-
 * Copyright 2016 Diamond Light Source Ltd.
 *
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the Eclipse Public License v1.0
 * which accompanies this distribution, and is available at
 * http://www.eclipse.org/legal/epl-v10.html
 */

package org.eclipse.january.dataset;

import java.util.List;

/**
 * Class to run over a pair of datasets in parallel with NumPy broadcasting of second dataset.
 * <p>
 * The shape of the first dataset defines the iteration shape; the second dataset is reshaped to a
 * broadcast-compatible shape and stepped with broadcast strides alongside the first. After each
 * successful {@link #hasNext()}, {@code aIndex} and {@code bIndex} hold the absolute indices of the
 * current items. When reading is enabled, {@link #hasNext()} refreshes only the cached value of
 * the second dataset.
 */
public class BroadcastSingleIterator extends BroadcastSelfIterator {
	private int[] bShape;
	private int[] aStride;
	private int[] bStride;

	final private int endrank;

	private final int[] aDelta, bDelta;
	private final int aStep, bStep;
	private int aMax, bMax;
	private int aStart, bStart;

	/** true if the iteration shape has a zero-length dimension, i.e. there are no items at all */
	private final boolean empty;

	/** true once {@link #hasNext()} has reported the end of the iteration; cleared by {@link #reset()} */
	private boolean exhausted;

	/**
	 * @param a dataset to iterate over
	 * @param b dataset to iterate over (will broadcast to first)
	 */
	public BroadcastSingleIterator(Dataset a, Dataset b) {
		super(a, b);

		int[] aShape = a.getShapeRef();
		maxShape = aShape;
		List<int[]> fullShapes = BroadcastUtils.broadcastShapesToMax(maxShape, b.getShapeRef());
		bShape = fullShapes.remove(0);

		int rank = maxShape.length;
		endrank = rank - 1;

		bDataset = b.reshape(bShape);
		int[] aOffset = new int[1];
		aStride = AbstractDataset.createStrides(aDataset, aOffset);
		bStride = BroadcastUtils.createBroadcastStrides(bDataset, maxShape);

		pos = new int[rank];
		aDelta = new int[rank];
		aStep = aDataset.getElementsPerItem();
		bDelta = new int[rank];
		bStep = bDataset.getElementsPerItem();
		boolean hasZeroLength = false;
		for (int j = endrank; j >= 0; j--) {
			// amount to subtract from an index when dimension j wraps back to 0
			aDelta[j] = aStride[j] * aShape[j];
			bDelta[j] = bStride[j] * bShape[j];
			if (maxShape[j] == 0) {
				hasZeroLength = true;
			}
		}
		empty = hasZeroLength;
		aStart = aOffset[0];
		bStart = bDataset.getOffset();
		// only zero-rank iteration uses an end index; for rank >= 1 end is detected by the carry loop
		aMax = endrank < 0 ? aStep + aStart : Integer.MIN_VALUE;
		bMax = endrank < 0 ? bStep + bStart : Integer.MIN_VALUE;
		reset();
	}

	/**
	 * Advance to the next position and report whether it exists.
	 * <p>
	 * Once this has returned {@code false}, it keeps returning {@code false} until {@link #reset()}
	 * is called.
	 *
	 * @return true if the iterator moved onto a valid item
	 */
	@Override
	public boolean hasNext() {
		if (exhausted) {
			return false;
		}

		int j = endrank;
		int oldB = bIndex;
		for (; j >= 0; j--) {
			pos[j]++;
			aIndex += aStride[j];
			bIndex += bStride[j];
			if (pos[j] >= maxShape[j]) {
				pos[j] = 0;
				aIndex -= aDelta[j]; // reset these dimensions
				bIndex -= bDelta[j];
			} else {
				break;
			}
		}
		if (j == -1) {
			if (endrank >= 0) {
				// carried out of the outermost dimension: all positions have been visited
				exhausted = true;
				return false;
			}
			// zero-rank: single step per item
			aIndex += aStep;
			bIndex += bStep;
		}

		if (aIndex == aMax || bIndex == bMax) {
			exhausted = true;
			return false;
		}

		if (read) {
			// with broadcast (zero) strides bIndex may not move, in which case the cached value is still valid
			if (oldB != bIndex) {
				if (asDouble) {
					bDouble = bDataset.getElementDoubleAbs(bIndex);
				} else {
					bLong = bDataset.getElementLongAbs(bIndex);
				}
			}
		}

		return true;
	}

	/**
	 * @return shape of first broadcasted dataset (internal array; do not modify)
	 */
	public int[] getFirstShape() {
		return maxShape;
	}

	/**
	 * @return shape of second broadcasted dataset (internal array; do not modify)
	 */
	public int[] getSecondShape() {
		return bShape;
	}

	/**
	 * Rewind so that the next call to {@link #hasNext()} moves onto the first item. Indices are
	 * placed one step before the start position.
	 */
	@Override
	public void reset() {
		for (int i = 0; i <= endrank; i++) {
			pos[i] = 0;
		}

		if (endrank >= 0) {
			pos[endrank] = -1;
			aIndex = aStart - aStride[endrank];
			bIndex = bStart - bStride[endrank];
		} else {
			aIndex = aStart - aStep;
			bIndex = bStart - bStep;
		}

		// a zero-length dimension means there is nothing to visit (and nothing to pre-load)
		exhausted = empty;
		if (empty) {
			return;
		}

		// When the last-dimension stride of b is zero, the first hasNext() will not move bIndex and so
		// will not refresh the cached b value; pre-load it here. The index == 0 checks are kept from
		// the original ("for zero-ranked datasets or extended shape").
		if (aIndex == 0 || bIndex == 0 || (endrank >= 0 && bStride[endrank] == 0)) {
			if (read) {
				storeCurrentValues();
			}
		}
	}
}