/* MIT licensed - see LICENSE in the project root directory. */
package org.openzen.packetstreams.qpsp.congestion;

import org.openzen.packetstreams.qpsp.QPSPTransmittingPacket;
import org.openzen.packetstreams.qpsp.StandardPacketScheduler;

/**
 * Implements a variant on NewReno as used in TCP or QUIC.
 *
 * <p>The congestion window is tracked in bytes. While {@code congestionWindow < ssthresh} the
 * controller is in slow start and grows the window by every acknowledged byte. Afterwards it grows
 * by roughly one {@code DEFAULT_MSS} per full window acknowledged (congestion avoidance).
 *
 * <p>A loss that falls outside the current recovery period shrinks the window by
 * {@code LOSS_REDUCTION_FACTOR}. It then starts a new recovery period that covers every packet sent
 * so far. Further losses or acks of packets inside that period neither shrink nor grow the window.
 *
 * <p>NOTE(review): this class performs no synchronization; it presumably is driven from a single
 * thread by the scheduler — confirm.
 */
public class StandardCongestionController implements CongestionController {
	// taken from the QUIC spec draft @ https://tools.ietf.org/id/draft-ietf-quic-recovery-12.html#rfc.section.4
	private static final int DEFAULT_MSS = 1460;
	private static final int INITIAL_WINDOW = 10 * DEFAULT_MSS;
	// upped from QUIC (was 2 * DEFAULT_MSS), has a very large effect on network with high loss and high RTT
	private static final int MINIMUM_WINDOW = 4 * DEFAULT_MSS;
	private static final float LOSS_REDUCTION_FACTOR = 0.6f;
	
	/** Bytes of non-lossy packets that were sent but are neither acknowledged nor declared lost. */
	private int bytesInFlight = 0;
	/** Current congestion window, in bytes. */
	private int congestionWindow = INITIAL_WINDOW;
	/** Packets with a sequence number up to and including this value belong to the current recovery period. */
	private long endOfRecovery = 0;
	/** Highest sequence number passed to {@link #onSent}; closes the recovery period when a loss occurs. */
	private long largestSentSeq = 0;
	/** Slow-start threshold in bytes; effectively unbounded until the first loss. */
	private int ssthresh = Integer.MAX_VALUE;
	private StandardPacketScheduler scheduler;
	
	@Override
	public void setPacketScheduler(StandardPacketScheduler scheduler) {
		this.scheduler = scheduler;
	}
	
	/**
	 * Reports whether more bytes are in flight than the congestion window allows.
	 *
	 * @return {@code true} if {@code bytesInFlight} strictly exceeds the congestion window
	 */
	@Override
	public boolean isCongested() {
		return bytesInFlight > congestionWindow;
	}
	
	/**
	 * Records a transmission. Only non-lossy packets count towards bytes in flight, but every
	 * packet's sequence number is remembered so that a later loss can close the recovery period.
	 *
	 * @param packet the packet that was just sent
	 */
	@Override
	public void onSent(QPSPTransmittingPacket packet) {
		largestSentSeq = Math.max(largestSentSeq, packet.seq);
		if (!packet.packet.lossy)
			bytesInFlight += packet.packet.data.length;
	}
	
	/**
	 * Handles the acknowledgement of a lossless packet. The packet is removed from bytes in flight,
	 * unless it was already declared lost. The window grows only if the packet was sent after the
	 * current recovery period.
	 *
	 * @param packet the acknowledged packet
	 */
	@Override
	public void onAckLossless(QPSPTransmittingPacket packet) {
		int size = packet.packet.data.length;
		
		// a packet previously declared lost was already subtracted in onPacketLost
		if (!packet.lost)
			bytesInFlight -= size;
		
		if (!inRecovery(packet.seq)) {
			int increment;
			if (congestionWindow < ssthresh)
				increment = size; // slow start: grow by every acknowledged byte
			else
				// congestion avoidance: ~one MSS per window; long math avoids int overflow of the product
				increment = (int) ((long) DEFAULT_MSS * size / congestionWindow);
			
			// saturate instead of wrapping: a negative window would make isCongested() true forever
			congestionWindow = saturatingAdd(congestionWindow, increment);
		}
	}
	
	/**
	 * Duplicate acknowledgements are deliberately ignored by this controller.
	 *
	 * @param seq sequence number of the duplicate ack (unused)
	 */
	@Override
	public void onAckDuplicate(long seq) {
		
	}
	
	/**
	 * Handles a packet declared lost. The packet is removed from bytes in flight exactly once.
	 * If the loss falls outside the current recovery period, this starts a new congestion event:
	 * the window is reduced, clamped to the minimum window, and adopted as the new slow-start
	 * threshold.
	 *
	 * @param packet the lost packet; its {@code lost} flag is set by this method
	 */
	@Override
	public void onPacketLost(QPSPTransmittingPacket packet) {
		if (!packet.lost) {
			bytesInFlight -= packet.packet.data.length;
			packet.lost = true;
		}
		
		if (!inRecovery(packet.seq)) {
			// As in draft-12's CongestionEvent (end_of_recovery = largest_sent_packet): the recovery
			// period covers everything sent so far, so a burst of losses from one flight reduces the
			// window only once. The max() keeps it at least as long as the lost packet's own seq.
			endOfRecovery = Math.max(largestSentSeq, packet.seq);
			congestionWindow = (int) (congestionWindow * LOSS_REDUCTION_FACTOR);
			congestionWindow = Math.max(congestionWindow, getMinimumWindow());
			ssthresh = congestionWindow;
		}
	}
	
	/**
	 * Returns a short diagnostic string of the form {@code window=<bytesInFlight>/<congestionWindow>}.
	 */
	@Override
	public String getStateInfo() {
		return "window=" + bytesInFlight + "/" + congestionWindow;
	}
	
	/** Returns whether the given sequence number falls within the current recovery period. */
	private boolean inRecovery(long seq) {
		return seq <= endOfRecovery;
	}
	
	/**
	 * Returns the lower bound for the congestion window after a loss.
	 *
	 * <p>NOTE(review): {@code 4 * RTT in millis} is compared directly with a byte count. This is
	 * presumably a deliberate heuristic that widens the floor on high-RTT links — confirm.
	 *
	 * @return the larger of {@code MINIMUM_WINDOW} and {@code 4 * estimated RTT (ms)}, or
	 *         {@code MINIMUM_WINDOW} if no scheduler has been set yet
	 */
	private int getMinimumWindow() {
		if (scheduler == null)
			return MINIMUM_WINDOW;
		return Math.max(4 * scheduler.getEstimatedRTTInMillis(), MINIMUM_WINDOW);
	}
	
	/** Adds a non-negative increment to a window size, clamping at {@link Integer#MAX_VALUE} instead of overflowing. */
	private static int saturatingAdd(int window, int increment) {
		return (int) Math.min((long) window + increment, Integer.MAX_VALUE);
	}
}
