package eu.dnetlib.lbs.metrics;

import java.lang.reflect.Method;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import io.prometheus.client.Summary;
import org.springframework.stereotype.Component;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.handler.HandlerInterceptorAdapter;

/**
 * Spring MVC interceptor that measures how long each handled request takes and records the
 * duration (in milliseconds) in the Prometheus summary {@code http_response_time_milliseconds},
 * labelled by HTTP method, handler and response status code.
 */
@Component
public class RequestTimingInterceptor extends HandlerInterceptorAdapter implements MetricInterceptor {

	/** Request attribute under which {@link #preHandle} stores the start time (epoch millis, as a {@code Long}). */
	private static final String REQ_PARAM_TIMING = "timing";

	/** Response-time summary; built and registered via {@code register()} when this class is initialised. */
	private static final Summary responseTimeInMs = Summary
			.build()
			.name("http_response_time_milliseconds")
			.labelNames("method", "handler", "status")
			.help("Request completed time in milliseconds")
			.register();

	/**
	 * Stores the current time on the request so that {@link #afterCompletion} can compute the
	 * elapsed time.
	 *
	 * @return always {@code true}; this interceptor never blocks the request
	 */
	@Override
	public boolean preHandle(final HttpServletRequest request, final HttpServletResponse response, final Object handler) throws Exception {
		request.setAttribute(REQ_PARAM_TIMING, System.currentTimeMillis());
		return true;
	}

	/**
	 * Records the time elapsed since {@link #preHandle} in the response-time summary.
	 *
	 * <p>If the start-time attribute is absent or is not a {@code Long} (for example because other
	 * code removed or overwrote it), no sample is recorded. Without this check the method would
	 * fail with a {@code NullPointerException} or {@code ClassCastException}.
	 */
	@Override
	public void afterCompletion(final HttpServletRequest request, final HttpServletResponse response, final Object handler, final Exception ex) throws Exception {
		final Object timingAttr = request.getAttribute(REQ_PARAM_TIMING);
		if (!(timingAttr instanceof Long)) {
			return;
		}
		// currentTimeMillis() is wall-clock time and may step backwards (e.g. a clock adjustment);
		// clamp so that a negative duration is never reported to the summary.
		final long completedTime = Math.max(0L, System.currentTimeMillis() - (Long) timingAttr);

		responseTimeInMs.labels(
				request.getMethod(),
				handlerLabel(handler),
				Integer.toString(response.getStatus())).observe(completedTime);
	}

	/**
	 * Builds the value of the {@code handler} label.
	 *
	 * @param handler the handler passed to the interceptor; may be {@code null}
	 * @return {@code "DeclaringClass.methodName"} for a {@link HandlerMethod}; otherwise the
	 *         handler's {@code toString()}, or {@code "null"} when the handler is {@code null}
	 */
	private static String handlerLabel(final Object handler) {
		if (handler instanceof HandlerMethod) {
			// short form of the controller method instead of its full signature
			final Method method = ((HandlerMethod) handler).getMethod();
			return method.getDeclaringClass().getSimpleName() + "." + method.getName();
		}
		return String.valueOf(handler);
	}

}
