/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.shuffle.reader;

import java.io.IOException;
import java.nio.ByteBuffer;

import com.esotericsoftware.kryo.io.Input;
import com.google.common.annotations.VisibleForTesting;
import io.netty.buffer.ByteBufInputStream;
import io.netty.buffer.Unpooled;
import org.apache.spark.executor.ShuffleReadMetrics;
import org.apache.spark.serializer.DeserializationStream;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.serializer.SerializerInstance;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Product2;
import scala.Tuple2;
import scala.collection.AbstractIterator;
import scala.collection.Iterator;
import scala.runtime.BoxedUnit;

import org.apache.uniffle.client.api.ShuffleReadClient;
import org.apache.uniffle.client.response.CompressedShuffleBlock;
import org.apache.uniffle.common.RssShuffleUtils;
import org.apache.uniffle.common.exception.RssException;

/**
 * Iterator over the key/value records obtained from a {@link ShuffleReadClient}.
 *
 * <p>Each call to {@link #hasNext()} that finds the current record iterator exhausted pulls the
 * next compressed block from the client, decompresses it and opens a deserialization stream over
 * the result. A {@code null} block (or a block without a buffer) is treated as the end of the data.
 *
 * <p>Callers are expected to invoke {@link #cleanup()} when they are done so that the streams,
 * the last decompression buffer and the read client are released.
 *
 * @param <K> key type of the returned records
 * @param <C> value type of the returned records
 */
public class RssShuffleDataIterator<K, C> extends AbstractIterator<Product2<K, C>> {

  private static final Logger LOG = LoggerFactory.getLogger(RssShuffleDataIterator.class);

  // Records of the block currently being consumed; null until the first block is read.
  private Iterator<Tuple2<Object, Object>> recordsIterator = null;
  private final SerializerInstance serializerInstance;
  // Set to null by cleanup() once the client has been closed.
  private ShuffleReadClient shuffleReadClient;
  private final ShuffleReadMetrics shuffleReadMetrics;
  // Accumulated timings in milliseconds, reported once the data is exhausted.
  private long readTime = 0;
  private long serializeTime = 0;
  private long decompressTime = 0;
  // Never assigned inside this class; only closed defensively in clearDeserializationStream().
  private Input deserializationInput = null;
  private DeserializationStream deserializationStream = null;
  private ByteBufInputStream byteBufInputStream = null;
  private long compressedBytesLength = 0;
  private long unCompressedBytesLength = 0;
  // Decompressed content of the current block; destroyed explicitly when it is a direct buffer.
  private ByteBuffer uncompressedData;

  /**
   * Creates an iterator that reads blocks from the given client.
   *
   * @param serializer serializer used to create the instance that deserializes the records
   * @param shuffleReadClient client the compressed blocks are read from
   * @param shuffleReadMetrics metrics updated with fetch wait time, bytes read and records read
   */
  public RssShuffleDataIterator(
      Serializer serializer,
      ShuffleReadClient shuffleReadClient,
      ShuffleReadMetrics shuffleReadMetrics) {
    this.serializerInstance = serializer.newInstance();
    this.shuffleReadClient = shuffleReadClient;
    this.shuffleReadMetrics = shuffleReadMetrics;
  }

  /**
   * Opens a key/value iterator over the given (already decompressed) data. Any streams opened by a
   * previous call are closed first.
   *
   * @param data buffer holding the serialized records
   * @return iterator over the deserialized key/value pairs
   */
  public Iterator<Tuple2<Object, Object>> createKVIterator(ByteBuffer data) {
    clearDeserializationStream();
    byteBufInputStream = new ByteBufInputStream(Unpooled.wrappedBuffer(data), true);
    deserializationStream = serializerInstance.deserializeStream(byteBufInputStream);
    return deserializationStream.asKeyValueIterator();
  }

  /** Closes every stream that is currently open and forgets the references to them. */
  private void clearDeserializationStream() {
    if (byteBufInputStream != null) {
      try {
        byteBufInputStream.close();
      } catch (IOException e) {
        // Best effort: keep going, but keep the cause in the log.
        LOG.warn("Can't close ByteBufInputStream, memory may be leaked.", e);
      }
    }
    if (deserializationInput != null) {
      deserializationInput.close();
    }
    if (deserializationStream != null) {
      deserializationStream.close();
    }
    deserializationInput = null;
    deserializationStream = null;
    byteBufInputStream = null;
  }

  /**
   * Destroys the current decompression buffer when it is a direct buffer and forgets it.
   *
   * <p>The field is cleared before the buffer is destroyed so that a failure further down the
   * line can never lead to the same buffer being destroyed a second time.
   *
   * @throws RssException if the direct buffer cannot be destroyed
   */
  private void releaseUncompressedData() {
    ByteBuffer stale = uncompressedData;
    uncompressedData = null;
    if (stale == null || !stale.isDirect()) {
      return;
    }
    try {
      RssShuffleUtils.destroyDirectByteBuffer(stale);
    } catch (Exception e) {
      throw new RssException("Destroy DirectByteBuffer failed!", e);
    }
  }

  /**
   * Reports whether another record is available, reading and decoding the next block when the
   * current one is exhausted.
   *
   * @return {@code true} if {@link #next()} can return a record; {@code false} once the client
   *     returns no more data
   * @throws RssException if the previous direct buffer cannot be destroyed
   */
  @Override
  public boolean hasNext() {
    if (recordsIterator == null || !recordsIterator.hasNext()) {
      // read next segment
      long startFetch = System.currentTimeMillis();
      CompressedShuffleBlock compressedBlock = shuffleReadClient.readShuffleBlockData();
      // A missing block, or a block without a buffer, is handled as the end of the data below.
      ByteBuffer compressedData = null;
      if (compressedBlock != null) {
        compressedData = compressedBlock.getByteBuffer();
      }
      long fetchDuration = System.currentTimeMillis() - startFetch;
      shuffleReadMetrics.incFetchWaitTime(fetchDuration);
      if (compressedData != null) {
        long compressedDataLength = compressedData.limit() - compressedData.position();
        compressedBytesLength += compressedDataLength;
        shuffleReadMetrics.incRemoteBytesRead(compressedDataLength);
        // Close the streams that read from the previous buffer before its memory is freed.
        clearDeserializationStream();
        // Direct ByteBuffers that are not collected in time make it easy for the executor to be
        // killed by the cluster manager (such as YARN) for using too much off-heap memory.
        releaseUncompressedData();
        long startDecompress = System.currentTimeMillis();
        uncompressedData = RssShuffleUtils.decompressData(
            compressedData, compressedBlock.getUncompressLength());
        unCompressedBytesLength += compressedBlock.getUncompressLength();
        long decompressDuration = System.currentTimeMillis() - startDecompress;
        decompressTime += decompressDuration;
        // create new iterator for shuffle data
        long startSerialization = System.currentTimeMillis();
        recordsIterator = createKVIterator(uncompressedData);
        long serializationDuration = System.currentTimeMillis() - startSerialization;
        readTime += fetchDuration;
        serializeTime += serializationDuration;
      } else {
        // finish reading records, check data consistent
        shuffleReadClient.checkProcessedBlockIds();
        shuffleReadClient.logStatics();
        LOG.info(
            "Fetch {} bytes cost {} ms and {} ms to serialize, {} ms to decompress with unCompressionLength[{}]",
            compressedBytesLength, readTime, serializeTime, decompressTime, unCompressedBytesLength);
        return false;
      }
    }
    return recordsIterator.hasNext();
  }

  /**
   * Returns the next record and counts it in the read metrics.
   *
   * @return the next key/value pair of the current block
   */
  @Override
  @SuppressWarnings("unchecked") // records are deserialized as Tuple2<Object, Object>
  public Product2<K, C> next() {
    shuffleReadMetrics.incRecordsRead(1L);
    return (Product2<K, C>) recordsIterator.next();
  }

  /**
   * Releases everything held by this iterator: the open streams, the last decompression buffer
   * and the read client. Calling it again finds nothing left to release.
   *
   * @return {@link BoxedUnit#UNIT}
   */
  public BoxedUnit cleanup() {
    clearDeserializationStream();
    try {
      // The buffer of the last block is not freed by hasNext(), so it has to be freed here.
      releaseUncompressedData();
    } catch (RssException e) {
      // Best effort: a failure to free the buffer must not prevent the client from being closed.
      LOG.warn("Can't destroy the last uncompressed DirectByteBuffer, memory may be leaked.", e);
    }
    if (shuffleReadClient != null) {
      shuffleReadClient.close();
    }
    shuffleReadClient = null;
    return BoxedUnit.UNIT;
  }

  /** Returns the metrics object this iterator reports to. */
  @VisibleForTesting
  protected ShuffleReadMetrics getShuffleReadMetrics() {
    return shuffleReadMetrics;
  }
}

