/* MIT licensed - see LICENSE in the project root directory. */
package org.openzen.packetstreams.qpsp;

import java.io.IOException;
import java.util.Arrays;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.Timer;
import java.util.TimerTask;
import org.openzen.packetstreams.NetworkLogger;
import org.openzen.packetstreams.io.BytesDataOutput;
import org.openzen.packetstreams.qpsp.congestion.CongestionController;

/**
 * Standard packet scheduler, modelled on TCP/QUIC-style loss recovery.
 *
 * <p>Outgoing frames are pulled from the {@link StreamMultiplexer} and bundled into packets of
 * at most {@code connection.maxUDPPacketSize} bytes. Lossless packets are tracked in a
 * {@link TransmissionBuffer} until acknowledged. If no acknowledgement arrives within a timeout
 * derived from the measured round-trip time, they are queued for retransmission and re-sent
 * under a fresh sequence number. A {@link CongestionController} decides whether more packets
 * may be put on the wire.
 *
 * <p>Timer callbacks hop onto the connection's network thread via {@code runOnNetworkThread}.
 * The scheduler's state is not synchronized. It presumably relies on every other method being
 * invoked from that same thread (TODO confirm).
 */
public class StandardPacketScheduler implements PacketScheduler {
	/** Threshold intended for fast retransmission; currently unused (fast retransmit is not implemented). */
	private static final int FAST_RETRANSMIT_PACKETS = 5;
	/** Once more than this many packets were lost without receiving anything, only control frames are sent. */
	private static final int MAX_LOSSES_BEFORE_THROTTLING = 20;
	/** Lower bound for the retransmission timeout, so a 0 ms RTT never triggers an immediate retransmit. */
	private static final int MIN_TRANSMISSION_TIMEOUT_MILLIS = 1;
	/** Weight of each new RTT sample in the exponential moving average. */
	private static final double RTT_SMOOTHING = 0.1;
	
	private final QPSPConnection connection;
	private final StreamMultiplexer multiplexer;
	private final Timer timer = new Timer();
	/** Lossless packets that were sent but not yet acknowledged. */
	private final TransmissionBuffer inFlight = new TransmissionBuffer();
	/** Packets whose retransmission timeout expired, waiting to be re-sent. */
	private final Queue<QPSPTransmittingPacket> retransmit = new PriorityQueue<>();
	private final CongestionController congestionController;
	
	/** RTT of the most recently acknowledged packet, in ms; 1000 ms until the first sample. */
	private int latestRTT = 1000;
	/** Smoothed RTT in ms, kept in floating point so repeated truncation cannot bias it low. */
	private double smoothedRTT = 0;
	/** Whether {@link #smoothedRTT} has been seeded with at least one sample. */
	private boolean hasRTTSample = false;
	private boolean paused = false;
	/** Set once the connection is closed; the timer is cancelled and nothing more is sent. */
	private boolean closed = false;
	private long lastSentPacketTimestamp = System.currentTimeMillis();
	private int packetsLostSinceLastReceived = 0;
	
	/**
	 * Creates a scheduler and registers it with the congestion controller.
	 *
	 * @param connection the connection packets are sent on
	 * @param multiplexer source of outgoing frames
	 * @param congestionController decides when sending must stop
	 */
	public StandardPacketScheduler(QPSPConnection connection, StreamMultiplexer multiplexer, CongestionController congestionController) {
		this.connection = connection;
		this.multiplexer = multiplexer;
		this.congestionController = congestionController;
		
		congestionController.setPacketScheduler(this);
	}
	
	@Override
	public void pause() {
		paused = true;
	}
	
	@Override
	public void resume() {
		paused = false;
		packetsLostSinceLastReceived = 0;
		doResume();
	}
	
	@Override
	public void resumeStreams() {
		doResume();
	}
	
	@Override
	public void onPacketReceived() {
		// Hearing from the peer ends any loss-based throttling.
		packetsLostSinceLastReceived = 0;
		doResume();
	}
	
	/**
	 * Processes an acknowledgement covering {@code fromSeq..toSeq} (inclusive).
	 *
	 * <p>Sequence numbers inside any nack range {@code [seq, seq + length)} are not acknowledged.
	 *
	 * @param fromSeq first sequence number covered by the acknowledgement
	 * @param toSeq last sequence number covered by the acknowledgement
	 * @param nacks ranges inside the acknowledged span that were not received; not modified
	 */
	@Override
	public void onAcknowledged(long fromSeq, long toSeq, NackRange[] nacks) {
		// Sort a copy so the caller's array is left untouched.
		NackRange[] sortedNacks = nacks.clone();
		Arrays.sort(sortedNacks, (a, b) -> Long.compare(a.seq, b.seq));
		
		int nacki = 0;
		for (long ackseq = fromSeq; ackseq <= toSeq; ackseq++) {
			// Skip every nack range that ends at or before ackseq. This has to be a loop:
			// several ranges may lie entirely below ackseq (e.g. ranges starting before fromSeq).
			while (nacki < sortedNacks.length && ackseq >= sortedNacks[nacki].seq + sortedNacks[nacki].length)
				nacki++;
			if (nacki < sortedNacks.length && ackseq >= sortedNacks[nacki].seq)
				continue; // inside a nack range: not acknowledged
			
			ack(inFlight.ack(ackseq));
		}
		
		// Tell the connection the lowest sequence number still awaiting acknowledgement.
		Long lowest = inFlight.getLowestSeq();
		if (lowest != null)
			connection.stopWaiting(lowest);
		
		// TODO: fast retransmit (see FAST_RETRANSMIT_PACKETS) is not implemented.
		doResume();
	}
	
	/**
	 * Records the acknowledgement of one in-flight packet and updates the RTT estimates.
	 *
	 * @param packet the acknowledged packet, or {@code null} if the sequence number was not in flight
	 */
	private void ack(QPSPTransmittingPacket packet) {
		if (packet == null)
			return;
		
		// currentTimeMillis is wall-clock time and may jump backwards; never accept a negative sample.
		int rtt = (int) Math.max(0L, System.currentTimeMillis() - packet.timestamp);
		latestRTT = rtt;
		if (hasRTTSample) {
			smoothedRTT += RTT_SMOOTHING * (rtt - smoothedRTT);
		} else {
			smoothedRTT = rtt;
			hasRTTSample = true;
		}
		
		onAckLossless(packet);
		connection.logger.log(NetworkLogger.CATEGORY_RETRANSMISSION, connection.localID, -1, "Ack " + packet.seq + ", rtt = " + latestRTT + ", " + congestionController.getStateInfo());
	}

	@Override
	public void onStreamClosed(int streamID) {
		// TODO
	}
	
	/** Stops the retransmission timer; no further packets are sent or scheduled afterwards. */
	@Override
	public void onConnectionClosed() {
		closed = true;
		timer.cancel();
	}
	
	/** Returns the mean of the smoothed and the latest RTT, in milliseconds. */
	@Override
	public int getEstimatedRTTInMillis() {
		return (int) ((smoothedRTT + latestRTT) / 2);
	}
	
	@Override
	public long getLastSentPacketTimestamp() {
		return lastSentPacketTimestamp;
	}
	
	/**
	 * Schedules {@code packet} to be queued for retransmission after {@code timeout} ms, unless it
	 * has been acknowledged (or the connection closed) by then.
	 *
	 * @param packet the transmitted lossless packet to watch
	 * @param timeout delay in milliseconds; negative values are treated as 0
	 * @return the created task, so it can be cancelled; it is not scheduled once the connection is closed
	 */
	public TimerTask retransmit(QPSPTransmittingPacket packet, int timeout) {
		TimerTask task = new TimerTask() {
			@Override
			public void run() {
				// Runs on the timer thread: hand over to the network thread before touching any state.
				connection.runOnNetworkThread(() -> {
					if (closed || packet.packet.acknowledged)
						return;
					
					retransmit.add(packet);
					connection.logger.log(NetworkLogger.CATEGORY_RETRANSMISSION, connection.localID, -1, "Scheduling for retransmission: #" + packet.seq + "; " + retransmit.size() + " packets in retransmission");
					onPacketLost(packet);
				});
			}
		};
		// Timer.schedule throws IllegalStateException once the timer is cancelled,
		// and IllegalArgumentException for a negative delay.
		if (!closed)
			timer.schedule(task, Math.max(0, timeout));
		return task;
	}
	
	/** Sends packets until congested, out of data or a send fails; no-op while paused or closed. */
	private void doResume() {
		if (paused || closed)
			return;
		
		while (sendPacket()) {}
	}
	
	/**
	 * Transmits at most one packet.
	 *
	 * <p>Candidates are tried in this order:
	 * <ol>
	 * <li>while throttled (too many consecutive losses), only control-priority frames;</li>
	 * <li>the head of the retransmission queue, unless the multiplexer reports higher-priority data;</li>
	 * <li>freshly bundled frames;</li>
	 * <li>when not throttled, any pending retransmission regardless of priority.</li>
	 * </ol>
	 *
	 * @return {@code true} if a packet was sent and another attempt may follow;
	 *         {@code false} if congested, throttled, out of data, or the send failed
	 */
	private boolean sendPacket() {
		if (congestionController.isCongested())
			return false;
		
		boolean throttled = packetsLostSinceLastReceived > MAX_LOSSES_BEFORE_THROTTLING;
		QPSPPacket next = null;
		if (throttled) {
			// The peer appears unreachable: only let control frames through.
			if (multiplexer.peekPriority() != QPSPConnection.CONTROL_PRIORITY)
				return false;
			next = bundle();
		}
		
		QPSPTransmittingPacket retransmitted = null;
		if (next == null) {
			retransmitted = pollRetransmission(true);
			if (retransmitted != null)
				next = retransmitted.packet;
		}
		
		if (next == null)
			next = bundle();
		
		if (next == null && !throttled) {
			// The multiplexer announced higher-priority data but produced no frame. Returning true
			// here without sending would make doResume() call again with unchanged state and spin
			// forever, so use the slot for a pending retransmission instead.
			retransmitted = pollRetransmission(false);
			if (retransmitted != null)
				next = retransmitted.packet;
		}
		
		if (next == null) {
			connection.logger.log(NetworkLogger.CATEGORY_RETRANSMISSION, connection.localID, -1, "Stopping transmission; no data");
			return false;
		}
		if (next.data.length == 0)
			throw new AssertionError("Refusing to send an empty packet");
		
		try {
			lastSentPacketTimestamp = System.currentTimeMillis();
			long seq;
			if (next.lossy)
				seq = connection.sendLossyPacket(next.data);
			else
				seq = connection.sendLosslessPacket(next.data);
			
			QPSPTransmittingPacket packet = new QPSPTransmittingPacket(seq, next, this, retransmitted);
			if (!next.lossy) {
				// Only lossless packets are tracked and armed for retransmission.
				inFlight.add(packet);
				packet.retransmitAt(getTransmissionTimeout());
			}
			congestionController.onSent(packet);
			
			String active = multiplexer.getActive();
			if (retransmitted == null)
				connection.logger.log(NetworkLogger.CATEGORY_RETRANSMISSION, connection.localID, -1, "Transmitted #" + seq + ", " + congestionController.getStateInfo() + ", active=" + active);
			else
				connection.logger.log(NetworkLogger.CATEGORY_RETRANSMISSION, connection.localID, -1, "Retransmitted #" + retransmitted.seq + " as #" + seq + " (originally " + retransmitted.originalSeq + "), " + congestionController.getStateInfo());

			return true;
		} catch (IOException ex) {
			// NOTE(review): a retransmission that fails here has already been removed from the
			// retransmission queue and is not re-queued; confirm whether that is acceptable.
			ex.printStackTrace();
			return false;
		}
	}
	
	/**
	 * Takes the next packet to retransmit from the retransmission queue, discarding entries that
	 * were acknowledged in the meantime.
	 *
	 * @param respectPriority if {@code true}, return {@code null} instead of a packet whose priority
	 *        is lower than the multiplexer's pending data
	 * @return the packet to retransmit (already reported to the in-flight buffer), or {@code null}
	 */
	private QPSPTransmittingPacket pollRetransmission(boolean respectPriority) {
		while (!retransmit.isEmpty()) {
			QPSPTransmittingPacket candidate = retransmit.peek();
			if (respectPriority && multiplexer.peekPriority() > candidate.packet.priority)
				return null;
			
			retransmit.poll();
			if (candidate.packet.acknowledged) {
				connection.logger.log(NetworkLogger.CATEGORY_RETRANSMISSION, connection.localID, -1, "Not retransmitting #" + candidate.seq + " since already acknowledged");
				continue;
			}
			
			inFlight.retransmitted(candidate);
			connection.logger.log(NetworkLogger.CATEGORY_RETRANSMISSION, connection.localID, -1, "Picked #" + candidate.seq + " for retransmission; " + retransmit.size() + " packets left");
			return candidate;
		}
		return null;
	}
	
	/**
	 * Bundles as many multiplexer frames as fit into one packet.
	 *
	 * <p>The packet is lossy only if every frame is lossy. It is a keepalive if any frame is one,
	 * and its priority is the highest frame priority.
	 *
	 * @return the bundled packet, or {@code null} if the multiplexer yielded no data
	 */
	private QPSPPacket bundle() {
		boolean lossy = true;
		boolean keepalive = false;
		int available = connection.maxUDPPacketSize;
		int priority = Integer.MIN_VALUE;
		
		BytesDataOutput output = new BytesDataOutput();
		// NOTE(review): the 32-byte cutoff and the 5 bytes kept in reserve per request look like
		// header/overhead allowances; their exact meaning is not visible here.
		while (available > 32) {
			FrameData frame = multiplexer.next(available - 5);
			if (frame == null)
				break;
			
			lossy &= frame.lossy;
			keepalive |= frame.keepalive;
			priority = Math.max(priority, frame.priority);
			available -= frame.data.length;
			output.writeRawBytes(frame.data);
		}
		
		if (output.length() == 0)
			return null;
		
		return new QPSPPacket(priority, output.toByteArray(), lossy, keepalive);
	}
	
	/** Forwards an acknowledgement to the congestion controller and tries to send more. */
	private void onAckLossless(QPSPTransmittingPacket packet) {
		congestionController.onAckLossless(packet);
		doResume();
	}
	
	/** Counts a loss, informs the congestion controller and tries to send more. */
	private void onPacketLost(QPSPTransmittingPacket packet) {
		packetsLostSinceLastReceived++;
		congestionController.onPacketLost(packet);
		doResume();
	}
	
	/**
	 * Returns the retransmission timeout: 1.75 times the larger of the smoothed and the latest RTT,
	 * but at least {@code MIN_TRANSMISSION_TIMEOUT_MILLIS}.
	 */
	private int getTransmissionTimeout() {
		int timeout = (int) (Math.max(smoothedRTT, latestRTT) * 1.75);
		return Math.max(MIN_TRANSMISSION_TIMEOUT_MILLIS, timeout);
	}
}
