/* MIT licensed - see LICENSE in the project root directory. */
package org.openzen.packetstreams.qpsp;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.Timer;
import java.util.TimerTask;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import org.openzen.packetstreams.NetworkLogger;
import org.openzen.packetstreams.qpsp.scheduler.CongestionController;

/**
 * Implements the standard (TCP/QUIC-based) packet scheduler.
 */
/**
 * Implements the standard (TCP/QUIC-based) packet scheduler.
 *
 * <p>Reliable (non-lossy) packets are tracked while in flight. Each one is
 * guarded by a {@link TimerTask} that declares it lost after
 * {@link #getTransmissionTimeout()} milliseconds. Lost packets are queued for
 * retransmission, and the scheduler interleaves them with new packets from
 * the {@link StreamMultiplexer} according to their priority. A packet is also
 * fast-retransmitted when a packet far enough ahead of it on the same stream
 * is acknowledged.
 *
 * <p>Timer callbacks never touch scheduler state directly. They post a
 * command to {@code session.commandQueue} instead.
 * NOTE(review): scheduler state is unsynchronized. It is presumably only
 * mutated from the thread that drains the session's command queue —
 * confirm.
 */
public class StandardPacketScheduler implements PacketScheduler {
	/**
	 * An in-flight packet is fast-retransmitted once a packet on the same
	 * stream whose sequence number is more than this far ahead is acknowledged.
	 */
	private static final int FAST_RETRANSMIT_PACKETS = 5;
	/**
	 * Once more than this many packets have been lost since the peer was last
	 * heard from, only keepalive packets are still transmitted.
	 */
	private static final int MAX_LOSSES_WITHOUT_RESPONSE = 20;
	
	private final QPSPSession session;
	private final StreamMultiplexer multiplexer;
	/** Reliable packets that were sent and are awaiting acknowledgement. */
	private final List<InFlightPacket> inFlight = new ArrayList<>();
	/** Packets declared lost that are waiting to be sent again, ordered by the packets' natural order. */
	private final Queue<QPSPPacket> retransmit = new PriorityQueue<>();
	/** Fires the loss timeouts of in-flight packets. */
	private final Timer timer = new Timer();
	private final QPSPEndpoint target;
	private final CongestionController congestionController;
	
	/** Most recent RTT sample in milliseconds (1000 until the first acknowledgement). */
	private int latestRTT = 1000;
	/** Smoothed RTT in milliseconds. 0 means "no sample yet". */
	private int srtt = 0;
	/** Set once the session is closed; no further packets are scheduled afterwards. */
	private boolean closed = false;
	private boolean paused = false;
	private long lastSentPacketTimestamp = System.currentTimeMillis();
	/** Losses since the last packet or acknowledgement from the peer. */
	private int packetsLostSinceLastReceived = 0;
	
	public StandardPacketScheduler(QPSPSession session, StreamMultiplexer multiplexer, QPSPEndpoint target, CongestionController congestionController) {
		this.session = session;
		this.multiplexer = multiplexer;
		this.target = target;
		this.congestionController = congestionController;
		
		congestionController.setPacketScheduler(this);
	}
	
	/** Stops sending packets until {@link #resume()} is called. */
	@Override
	public void pause() {
		paused = true;
	}
	
	/** Re-enables sending, clears the loss counter and sends whatever is allowed. */
	@Override
	public void resume() {
		paused = false;
		packetsLostSinceLastReceived = 0;
		doResume();
	}
	
	/** Tries to send pending packets, e.g. after new stream data became available. */
	@Override
	public void resumeStreams() {
		doResume();
	}
	
	/** Any packet from the peer proves it is alive: clear the loss counter and resume sending. */
	@Override
	public void onPacketReceived() {
		packetsLostSinceLastReceived = 0;
		doResume();
	}

	@Override
	public void onAcknowledged(long streamID, long seq) {
		acknowledge(streamID, seq);
	}

	@Override
	public void onStreamClosed(long streamID) {
		// TODO
	}
	
	/** Cancels all pending loss timers and stops any further transmission. */
	@Override
	public void onSessionClosed() {
		timer.cancel();
		closed = true;
	}
	
	/**
	 * Returns the average of the smoothed RTT and the latest RTT sample.
	 * Before the first acknowledgement this is {@code (0 + 1000) / 2}.
	 */
	@Override
	public int getEstimatedRTTInMillis() {
		return (srtt + latestRTT) / 2;
	}
	
	@Override
	public long getLastSentPacketTimestamp() {
		return lastSentPacketTimestamp;
	}
	
	/** Sends packets until the pipeline is paused, closed, congested, empty or a send fails. */
	private void doResume() {
		// After close the timer is cancelled, and scheduling a new loss
		// timeout would throw IllegalStateException.
		if (paused || closed)
			return;
		
		while (sendPacket()) {}
	}
	
	/**
	 * Picks the next packet (a queued retransmission or a new packet from the
	 * multiplexer) and sends it.
	 *
	 * @return true if a packet was sent and the caller may try to send another
	 */
	private boolean sendPacket() {
		if (congestionController.isCongested())
			return false;
		
		QPSPPacket next;
		boolean isRetransmission;
		if (!retransmit.isEmpty() && retransmit.peek().priority >= multiplexer.peekPriority()) {
			next = retransmit.poll();
			isRetransmission = true;
		} else {
			next = multiplexer.next();
			isRetransmission = false;
			if (next == null && !retransmit.isEmpty()) {
				// The multiplexer has nothing sendable right now. Send the
				// queued retransmission instead of looping without progress.
				next = retransmit.poll();
				isRetransmission = true;
			}
		}
		
		if (next == null)
			return false;
		
		if (!next.keepalive && packetsLostSinceLastReceived > MAX_LOSSES_WITHOUT_RESPONSE) {
			session.logger.log(NetworkLogger.CATEGORY_RETRANSMISSION, -1, "Stopped session transmission due to no response");
			// The packet was already taken from its source queue. Park reliable
			// packets for retransmission so they are not silently dropped.
			if (!next.isLossy())
				retransmit.add(next);
			return false;
		}
		if (!next.isLossy()) {
			InFlightPacket packet = new InFlightPacket(next, isRetransmission);
			inFlight.add(packet);
			
			int timeout = getTransmissionTimeout();
			timer.schedule(packet, timeout);
		}
		
		String active = multiplexer.getActive();
		if (!isRetransmission)
			session.logger.log(NetworkLogger.CATEGORY_RETRANSMISSION, next.localStreamID, "Transmitting #" + next.seq + ", " + congestionController.getStateInfo() + ", active=" + active);
		else
			session.logger.log(NetworkLogger.CATEGORY_RETRANSMISSION, next.localStreamID, "Retransmitting #" + next.seq + ", " + congestionController.getStateInfo());
		
		congestionController.onSent(next);
		try {
			lastSentPacketTimestamp = System.currentTimeMillis();
			target.send(session, next.data);
			return true;
		} catch (IOException ex) {
			// A reliable packet stays in inFlight, so its loss timer will
			// schedule a retransmission.
			ex.printStackTrace();
			return false;
		}
	}
	
	/**
	 * Handles an acknowledgement for packet {@code seq} on stream {@code streamID}:
	 * stops tracking it, updates the RTT estimate, informs the congestion
	 * controller and fast-retransmits packets that fell far behind.
	 */
	private void acknowledge(long streamID, long seq) {
		packetsLostSinceLastReceived = 0;
		
		// Drop queued retransmissions of this packet first, so the doResume()
		// calls below never resend data the peer already has.
		retransmit.removeIf(queued -> queued.localStreamID == streamID && queued.seq == seq);
		
		InFlightPacket toRemove = findInFlight(streamID, seq);
		if (toRemove != null) {
			toRemove.acknowledged = true;
			toRemove.cancel();
			inFlight.remove(toRemove);

			// currentTimeMillis() is not monotonic. Clamp so a clock step can
			// never yield a negative RTT (and a negative timer delay later).
			int rtt = (int)Math.max(0L, System.currentTimeMillis() - toRemove.packet.timestamp);
			// TODO: something we can do about retransmitted packets? their calculation will be off!
			latestRTT = rtt;
			// Exponentially weighted moving average; the first sample seeds it.
			srtt = srtt == 0 ? rtt : (int)(srtt * 0.9f + rtt * 0.1f);

			onAckLossless(toRemove.packet);

			if (toRemove.isRetransmission) {
				session.logger.log(NetworkLogger.CATEGORY_RETRANSMISSION, streamID, "Ack " + seq + ", retransmitted packet, " + congestionController.getStateInfo());
			} else {
				session.logger.log(NetworkLogger.CATEGORY_RETRANSMISSION, streamID, "Ack " + seq + ", rtt = " + latestRTT + ", " + congestionController.getStateInfo());
			}
		} else {
			congestionController.onAckDuplicate(streamID, seq);
		}
		
		fastRetransmit(streamID, seq);
		doResume();
	}
	
	/**
	 * Treats in-flight packets on {@code streamID} that are more than
	 * {@link #FAST_RETRANSMIT_PACKETS} behind the acknowledged {@code seq} as
	 * lost, without waiting for their timers.
	 */
	private void fastRetransmit(long streamID, long seq) {
		List<InFlightPacket> overtaken = new ArrayList<>();
		for (InFlightPacket candidate : inFlight) {
			if (candidate.packet.localStreamID == streamID && candidate.packet.seq < seq - FAST_RETRANSMIT_PACKETS) {
				candidate.cancel();
				overtaken.add(candidate);
			}
		}
		// Removing them from inFlight also makes an already-queued timer command for them a no-op.
		inFlight.removeAll(overtaken);
		for (InFlightPacket packet : overtaken) {
			session.logger.log(NetworkLogger.CATEGORY_RETRANSMISSION, packet.packet.localStreamID, "Fast retransmit " + packet.packet.seq);
			retransmit.add(packet.packet);
			onPacketLost(packet.packet);
		}
	}
	
	private void onAckLossless(QPSPPacket packet) {
		congestionController.onAckLossless(packet);
		doResume();
	}
	
	private void onPacketLost(QPSPPacket packet) {
		packetsLostSinceLastReceived++;
		congestionController.onPacketLost(packet);
		doResume();
	}
	
	/** Returns the tracked in-flight entry for the given packet, or null if none. */
	private InFlightPacket findInFlight(long streamID, long seq) {
		for (InFlightPacket packet : inFlight)
			if (packet.packet.localStreamID == streamID && packet.packet.seq == seq)
				return packet;
		
		return null;
	}
	
	/** Loss timeout in milliseconds: 1.75 times the larger of the smoothed and latest RTT. */
	private int getTransmissionTimeout() {
		return (int)(Math.max(srtt, latestRTT) * 1.75f);
	}
	
	/**
	 * A reliable packet awaiting acknowledgement. It doubles as the timer task
	 * that declares the packet lost when its timeout expires.
	 */
	private class InFlightPacket extends TimerTask {
		private final QPSPPacket packet;
		private final boolean isRetransmission;
		private boolean acknowledged = false;
		
		public InFlightPacket(QPSPPacket packet, boolean isRetransmission) {
			this.packet = packet;
			this.isRetransmission = isRetransmission;
		}

		@Override
		public void run() {
			// Runs on the Timer's thread: hand the loss over to the session's
			// command queue rather than mutating scheduler state here.
			session.commandQueue.offer(() -> {
				// The timer may have fired just before the packet was
				// acknowledged or fast-retransmitted (both remove it from
				// inFlight), or before the session closed. In those cases
				// counting it as lost would duplicate the retransmission.
				if (closed || acknowledged || !inFlight.remove(this))
					return;
				
				session.logger.log(NetworkLogger.CATEGORY_RETRANSMISSION, packet.localStreamID, "Lost packet " + packet.seq);

				retransmit.add(packet);
				onPacketLost(packet);
			});
		}
	}
}
