/* MIT licensed - see LICENSE in the project root directory. */
package org.openzen.packetstreams.qpsp;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.Timer;
import java.util.TimerTask;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import org.openzen.packetstreams.NetworkLogger;
import org.openzen.packetstreams.qpsp.scheduler.CongestionController;

/**
 * Implements the standard (TCP/QUIC-based) packet scheduler.
 *
 * <p>Threading model: all bookkeeping ({@code inFlight}, {@code retransmit}, the RTT
 * estimates) is mutated by a single worker thread that drains {@code commandQueue}.
 * Public callbacks and the retransmission timer tasks do not touch that state directly;
 * they enqueue a command instead. The RTT estimates and the last-sent timestamp are
 * exposed through public getters that may be called from other threads, so those fields
 * are {@code volatile}.
 *
 * <p>Non-lossy packets are tracked in {@code inFlight} and get a retransmission timer of
 * {@code 1.5 * max(srtt, latestRTT)} milliseconds. If the timer fires before the packet is
 * acknowledged, the packet is moved to the {@code retransmit} queue and the congestion
 * controller is told it was lost.
 */
public class StandardPacketScheduler implements PacketScheduler {
	private final QPSPSession session;
	private final StreamMultiplexer multiplexer;
	/** Non-lossy packets sent but not yet acknowledged; each entry is also a scheduled timer task. */
	private final List<InFlightPacket> inFlight = new ArrayList<>();
	/** Packets whose retransmission timer expired, ordered by QPSPPacket's natural ordering. */
	private final Queue<QPSPPacket> retransmit = new PriorityQueue<>();
	private final Timer timer = new Timer();
	private final QPSPEndpoint target;
	/** Work items executed one at a time, in order, by the worker thread. */
	private final BlockingQueue<Runnable> commandQueue = new LinkedBlockingQueue<>();
	private final CongestionController congestionController;
	
	// Written on the worker thread; read via public getters, possibly from other threads.
	private volatile int latestRTT = 1000;
	private volatile int srtt = 0; // 0 = no RTT sample taken yet
	// Only read and written on the worker thread.
	private boolean closed = false;
	private volatile long lastSentPacketTimestamp = System.currentTimeMillis();
	
	public StandardPacketScheduler(QPSPSession session, StreamMultiplexer multiplexer, QPSPEndpoint target, CongestionController congestionController) {
		this.session = session;
		this.multiplexer = multiplexer;
		this.target = target;
		this.congestionController = congestionController;
		
		congestionController.setPacketScheduler(this);
		
		new Thread(this::runCommandLoop).start();
	}
	
	/**
	 * Worker loop: executes queued commands sequentially until the session-closed command
	 * has run.
	 *
	 * <p>A command that throws is reported and skipped rather than terminating the loop:
	 * this thread is the only one processing acks, losses and sends, so letting it die
	 * would silently stall the whole session.
	 */
	private void runCommandLoop() {
		while (!closed) {
			Runnable command;
			try {
				command = commandQueue.take();
			} catch (InterruptedException ex) {
				// This class never interrupts its own worker, so an interrupt is an external
				// request to stop. Restore the flag and exit; looping would just re-throw.
				Thread.currentThread().interrupt();
				return;
			}
			
			try {
				command.run();
			} catch (RuntimeException ex) {
				ex.printStackTrace();
			}
		}
	}
	
	@Override
	public void resume() {
		commandQueue.offer(() -> {
			doResume();
		});
	}

	@Override
	public void onAcknowledged(long streamID, long seq) {
		commandQueue.offer(() -> acknowledge(streamID, seq));
	}

	@Override
	public void onStreamClosed(long streamID) {
		// TODO
	}
	
	@Override
	public void onSessionClosed() {
		// Runs on the worker thread; setting closed ends runCommandLoop after this command.
		commandQueue.offer(() -> {
			timer.cancel();
			closed = true;
		});
	}
	
	/** Returns the average of the smoothed RTT and the most recent RTT sample, in milliseconds. */
	@Override
	public int getEstimatedRTTInMillis() {
		return (srtt + latestRTT) / 2;
	}
	
	/** Returns the wall-clock time (ms) at which a packet was last handed to the endpoint. */
	@Override
	public long getLastSentPacketTimestamp() {
		return lastSentPacketTimestamp;
	}
	
	/** Sends packets until the congestion controller blocks, nothing is left, or a send fails. */
	private void doResume() {
		while (sendPacket()) {}
	}
	
	/**
	 * Sends a single packet: either a pending retransmission or the next packet from the
	 * multiplexer, whichever has the higher priority.
	 *
	 * @return true if a packet was sent and the caller may try to send another; false if
	 *         congested, nothing is pending, or the send failed
	 */
	private boolean sendPacket() {
		if (congestionController.isCongested())
			return false;
		
		QPSPPacket next;
		boolean isRetransmission;
		if (!retransmit.isEmpty() && retransmit.peek().priority >= multiplexer.peekPriority()) {
			next = retransmit.poll();
			isRetransmission = true;
		} else {
			next = multiplexer.next();
			isRetransmission = false;
			if (next == null && !retransmit.isEmpty()) {
				// The multiplexer advertised a higher priority but produced nothing. Send the
				// pending retransmission instead of reporting progress without sending, which
				// would make doResume() spin on the same state forever.
				next = retransmit.poll();
				isRetransmission = true;
			}
		}
		if (next == null)
			return false; // both the multiplexer and the retransmit queue are empty
		
		if (!next.isLossy()) {
			InFlightPacket packet = new InFlightPacket(next, isRetransmission);
			inFlight.add(packet);
			
			int timeout = getTransmissionTimeout();
			timer.schedule(packet, timeout);
		}
		if (!isRetransmission)
			session.logger.log(NetworkLogger.CATEGORY_RETRANSMISSION, next.localStreamID, "Transmitting #" + next.seq + ", " + congestionController.getStateInfo());
		else
			session.logger.log(NetworkLogger.CATEGORY_RETRANSMISSION, next.localStreamID, "Retransmitting #" + next.seq + ", " + congestionController.getStateInfo());
		
		congestionController.onSent(next);
		try {
			lastSentPacketTimestamp = System.currentTimeMillis();
			target.send(session, next.data);
			return true;
		} catch (IOException ex) {
			// A non-lossy packet stays in inFlight, so its timer will retransmit it later.
			ex.printStackTrace();
			return false;
		}
	}
	
	/**
	 * Handles an acknowledgement for (streamID, seq): stops its retransmission timer,
	 * updates the RTT estimates, informs the congestion controller, drops any queued
	 * retransmission of the same packet, and then tries to send more.
	 */
	private void acknowledge(long streamID, long seq) {
		InFlightPacket toRemove = findInFlight(streamID, seq);
		if (toRemove != null) {
			toRemove.cancel();
			inFlight.remove(toRemove);

			int rtt = (int)(System.currentTimeMillis() - toRemove.packet.timestamp);
			// TODO: something we can do about retransmitted packets? their calculation will be off!
			// NOTE(review): Karn's rule (skip samples from retransmitted packets) would fix the bias,
			// but without timeout backoff it could freeze the estimate if every packet times out.
			latestRTT = rtt;
			// Exponentially weighted moving average; the first sample seeds it directly.
			srtt = srtt == 0 ? rtt : (int)(srtt * 0.9f + rtt * 0.1f);

			onAckLossless(toRemove.packet);

			if (toRemove.isRetransmission) {
				session.logger.log(NetworkLogger.CATEGORY_RETRANSMISSION, streamID, "Ack " + seq + ", retransmitted packet, " + congestionController.getStateInfo());
			} else {
				session.logger.log(NetworkLogger.CATEGORY_RETRANSMISSION, streamID, "Ack " + seq + ", rtt = " + latestRTT + ", " + congestionController.getStateInfo());
			}
		} else {
			congestionController.onAckDuplicate(streamID, seq);
		}
		
		// The packet may already be queued for retransmission (its timer fired before the ack
		// arrived); it no longer needs to be resent.
		QPSPPacket pendingRetransmit = findRetransmit(streamID, seq);
		if (pendingRetransmit != null)
			retransmit.remove(pendingRetransmit);
		doResume();
	}
	
	private void onAckLossless(QPSPPacket packet) {
		congestionController.onAckLossless(packet);
		resume();
	}
	
	private void onPacketLost(QPSPPacket packet) {
		congestionController.onPacketLost(packet);
		doResume();
	}
	
	/** Linear search for the in-flight entry of (streamID, seq); null if none. */
	private InFlightPacket findInFlight(long streamID, long seq) {
		for (InFlightPacket packet : inFlight)
			if (packet.packet.localStreamID == streamID && packet.packet.seq == seq)
				return packet;
		
		return null;
	}
	
	/** Linear search of the retransmit queue for (streamID, seq); null if none. */
	private QPSPPacket findRetransmit(long streamID, long seq) {
		for (QPSPPacket packet : retransmit)
			if (packet.localStreamID == streamID && packet.seq == seq)
				return packet;
		
		return null;
	}
	
	/** Retransmission timeout in milliseconds: 1.5x the larger of the smoothed and latest RTT. */
	private int getTransmissionTimeout() {
		return (int)(Math.max(srtt, latestRTT) * 1.5f);
	}
	
	/**
	 * A sent, not-yet-acknowledged non-lossy packet. Also serves as its retransmission timer
	 * task: when the timer fires, the packet is declared lost on the worker thread.
	 */
	private class InFlightPacket extends TimerTask {
		private final QPSPPacket packet;
		private final boolean isRetransmission;
		
		public InFlightPacket(QPSPPacket packet, boolean isRetransmission) {
			this.packet = packet;
			this.isRetransmission = isRetransmission;
		}

		@Override
		public void run() {
			commandQueue.offer(() -> {
				// The timer may fire just before a queued ack for this packet is processed.
				// TimerTask.cancel() in acknowledge() cannot retract this command once queued.
				// If the ack won, the packet is no longer in flight and must not be marked lost.
				if (!inFlight.remove(this))
					return;
				
				session.logger.log(NetworkLogger.CATEGORY_RETRANSMISSION, packet.localStreamID, "Lost packet " + packet.seq);

				retransmit.add(packet);
				onPacketLost(packet);
			});
		}
	}
}
