package eu.dnetlib.enabling.aas.client;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;

import org.apache.log4j.Logger;
import org.opensaml.lite.common.SAMLObject;
import org.opensaml.lite.saml2.core.Assertion;
import org.opensaml.lite.saml2.core.Attribute;
import org.opensaml.lite.saml2.core.AttributeStatement;
import org.opensaml.lite.xacml.XACMLConstants;

import eu.dnetlib.enabling.aas.DNetAAError;
import eu.dnetlib.enabling.aas.DNetAAResponse;
import eu.dnetlib.enabling.aas.DNetAuthenticateRequest;
import eu.dnetlib.enabling.aas.xacml.ctx.ActionType;
import eu.dnetlib.enabling.aas.xacml.ctx.AttributeType;
import eu.dnetlib.enabling.aas.xacml.ctx.AttributeValueType;
import eu.dnetlib.enabling.aas.xacml.ctx.EnvironmentType;
import eu.dnetlib.enabling.aas.xacml.ctx.RequestType;
import eu.dnetlib.enabling.aas.xacml.ctx.ResourceType;
import eu.dnetlib.enabling.aas.xacml.ctx.SubjectType;
import eu.dnetlib.enabling.aas.xacml.profile.saml.XACMLAuthzDecisionQueryType;

/**
 * Assertion refreshing helper class.
 * @author mhorst
 *
 */
public class AssertionRefreshingHelper {
	
	protected static final Logger log = Logger.getLogger(AssertionRefreshingHelper.class);

	/** Resource attribute value placed in assertion refreshing requests. */
	public static final String RESOURCE_VALUE_ASSERTION = "assertion";
	/** Action attribute value placed in assertion refreshing requests. */
	public static final String ACTION_VALUE_REFRESH 	= "refresh";
	/** Name of the assertion attribute read by {@link #extractSubjectValue(Assertion)}. */
	public static final String SUBJECT_VALUE_ATTR_ID 	= "yadda:authn:subject:id";
	
	/**
	 * Retrieves the assertions with the given identifiers from an array of SAML objects.
	 * Identifiers without a matching assertion are logged and skipped.
	 * @param samlObjects array of SAML objects, may be null or empty
	 * @param ids assertions identifiers collection, cannot be null
	 * @return collection of found assertions; empty (immutable) when no SAML objects were given
	 */
	public static Collection<Assertion> getAssertions(SAMLObject[] samlObjects, Collection<String> ids) {
		if (samlObjects!=null && samlObjects.length>0) {
			Collection<Assertion> result = new ArrayList<Assertion>();
			for (String id : ids) {
				Assertion foundAssertion = findAssertion(id, samlObjects);
				if (foundAssertion!=null) {
					result.add(foundAssertion);
				} else {
					log.error("couldn't find assertion among requst objects " +
							"for given assertionId: " + id);
				}
			}
			return result;
		} else {
			// builder is method-local, so the synchronized StringBuffer is unnecessary
			StringBuilder strBuff = new StringBuilder();
			for (String id : ids) {
				strBuff.append(id);
				strBuff.append("; ");
			}
			log.error("no assertions found in request object, cannot refresh ids: " +
					strBuff.toString());
			return Collections.emptyList();
		}
	}
	
	/**
	 * Finds the assertion with the given assertionId among an array of SAMLObjects.
	 * Non-assertion entries (including nulls) are skipped.
	 * @param assertionId assertion identifier
	 * @param samlObjects not null array of {@link SAMLObject}
	 * @return assertion for given assertionId, or null when not found
	 */
	public static Assertion findAssertion(String assertionId, SAMLObject[] samlObjects) {
		for (SAMLObject samlObj : samlObjects) {
			if (samlObj instanceof Assertion) {
				if (assertionId.equals(((Assertion)samlObj).getID())) {
					return (Assertion) samlObj;
				}
			} else if (log.isDebugEnabled()) {
				// guard avoids building the message when debug logging is off
				log.debug("not an Assertion instance, got: " + 
						(samlObj!=null?samlObj.getClass().getCanonicalName():"null"));
			}
		}
		return null;
	}
	
	/**
	 * Returns outdated assertion ids.
	 * @param response response object
	 * @return collection of outdated assertions identifiers, never null
	 */
	public static Collection<String> getStaleAssertionsIds(DNetAAResponse response) {
		return getInvalidAssertionsIds(response, DNetAAError.WARN_ASSERTION_OUTDATED);
	}
	
	/**
	 * Returns permanently expired assertion ids which require reauthentication.
	 * @param response response object
	 * @return collection of permanently expired assertions identifiers, never null
	 */
	public static Collection<String> getPermanentlyExpiredAssertionsIds(DNetAAResponse response) {
		return getInvalidAssertionsIds(response, DNetAAError.WARN_ASSERTION_PERM_EXPIRED);
	}
	
	/**
	 * Returns invalid assertion ids according to the error type.
	 * Never returns null: when the response carries no matching error
	 * (or no errors at all) an empty, mutable collection is returned.
	 * The assertion id is taken from each matching error's data, cast to String.
	 * @param response response object
	 * @param errorType type of error related to assertion
	 * @return collection of invalid assertions identifiers
	 */
	protected static Collection<String> getInvalidAssertionsIds(DNetAAResponse response, 
			String errorType) {
		Collection<String> result = new ArrayList<String>();
		DNetAAError[] errors = response.getErrors();
		if (errors == null) {
			return result;
		}
		for (DNetAAError aaError : errors) {
			if (aaError != null && errorType.equals(aaError.getErrorId())) {
				result.add((String) aaError.getData());
			}
		}
		return result;
	}
	
	/**
	 * Checks whether the response reports any outdated or permanently expired assertion.
	 * @param response response object
	 * @return true if at least one error is {@link DNetAAError#WARN_ASSERTION_OUTDATED}
	 * or {@link DNetAAError#WARN_ASSERTION_PERM_EXPIRED}
	 */
	public static boolean containsExpirationRelatedErrors(DNetAAResponse response) {
		DNetAAError[] errors = response.getErrors();
		if (errors == null) {
			return false;
		}
		for (DNetAAError aaError : errors) {
			if (aaError == null) {
				continue;
			}
			if (DNetAAError.WARN_ASSERTION_OUTDATED.equals(aaError.getErrorId()) ||
					DNetAAError.WARN_ASSERTION_PERM_EXPIRED.equals(aaError.getErrorId())) {
				return true;
			}
		}
		return false;
	}
	
	/**
	 * Creates assertion refreshing {@link DNetAuthenticateRequest}.
	 * The XACML request carries:
	 * <ul>
	 * <li>an access subject with {@link XACMLConstants#SUBJECT_ID} set to subjectType (omitted when null),</li>
	 * <li>a subject-param subject with {@link XACMLConstants#SUBJECT_PARAM_ID} set to subjectId (omitted when null),</li>
	 * <li>resource {@link #RESOURCE_VALUE_ASSERTION}, action {@link #ACTION_VALUE_REFRESH}
	 * and an empty environment.</li>
	 * </ul>
	 * @param subjectId login parameter value, may be null
	 * @param subjectType subject value used to select the refreshing policy, may be null
	 * @return assertion refreshing authenticate request
	 */
	public static DNetAuthenticateRequest buildAssertionRefreshingRequest(
			String subjectId, String subjectType) {
		XACMLAuthzDecisionQueryType authnQuery = new XACMLAuthzDecisionQueryType();
		DNetAuthenticateRequest request = new DNetAuthenticateRequest(authnQuery);
		
		RequestType samlRequest = new RequestType();
		authnQuery.setRequest(samlRequest);
//		subjects
		List<SubjectType> subjects = new ArrayList<SubjectType>();
		if (subjectType!=null) {
			subjects.add(createSubject(XACMLConstants.ACCESS_SUBJECT_CATEGORY,
					XACMLConstants.SUBJECT_ID, subjectType));
		}
//		attaching required login parameter if any provided
		if (subjectId!=null) {
			subjects.add(createSubject(XACMLConstants.SUBJECT_PARAM_CATEGORY,
					XACMLConstants.SUBJECT_PARAM_ID, subjectId));
		}
		samlRequest.setSubjects(subjects.toArray(new SubjectType[subjects.size()]));
		
//		resource
		ResourceType samlResource = new ResourceType();
		samlResource.setAttributes(new AttributeType[] {
				createStringAttribute(XACMLConstants.RESOURCE_ID, RESOURCE_VALUE_ASSERTION)});
		samlRequest.setResources(new ResourceType[] {samlResource});
//		action
		ActionType samlAction = new ActionType();
		samlAction.setAttributes(new AttributeType[] {
				createStringAttribute(XACMLConstants.ACTION_ID, ACTION_VALUE_REFRESH)});
		samlRequest.setAction(samlAction);
//		environment
		samlRequest.setEnvironment(new EnvironmentType());
		return request;
	}
	
	/**
	 * Creates a subject of the given category holding a single string attribute.
	 * @param category subject category
	 * @param attributeId identifier of the subject attribute
	 * @param value attribute value
	 * @return populated subject
	 */
	private static SubjectType createSubject(String category, String attributeId, String value) {
		SubjectType subject = new SubjectType();
		subject.setSubjectCategory(category);
		subject.setAttributes(new AttributeType[] {createStringAttribute(attributeId, value)});
		return subject;
	}
	
	/**
	 * Creates a single-valued attribute with {@link XACMLConstants#DATATYPE_STRING} data type.
	 * @param attributeId attribute identifier
	 * @param value attribute value
	 * @return populated attribute
	 */
	private static AttributeType createStringAttribute(String attributeId, String value) {
		AttributeType attr = new AttributeType();
		attr.setAttributeID(attributeId);
		attr.setDataType(XACMLConstants.DATATYPE_STRING);
		AttributeValueType attrValue = new AttributeValueType();
		attrValue.setValue(value);
		attr.setAttributeValues(new AttributeValueType[] {attrValue});
		return attr;
	}
	
	/**
	 * Extracts subject value for selecting proper refreshing policy.
	 * Returns the first value of the first {@link #SUBJECT_VALUE_ATTR_ID} attribute
	 * that has at least one value; attributes without values are skipped.
	 * @param assertion assertion to inspect
	 * @return subject value for selecting proper refreshing policy, or null when absent
	 */
	public static String extractSubjectValue(Assertion assertion) {
		if (assertion.getAttributeStatement()==null) {
			return null;
		}
		for (AttributeStatement currentAS : assertion.getAttributeStatement()) {
			if (currentAS==null || currentAS.getAttributes()==null) {
				continue;
			}
			for (Attribute attr : currentAS.getAttributes()) {
				if (SUBJECT_VALUE_ATTR_ID.equals(attr.getName())
						&& attr.getAttributeValues()!=null) {
					// only the first value matters: the loop returns on its first iteration,
					// and an empty value collection simply falls through to the next attribute
					for (Object value : attr.getAttributeValues()) {
						return (String) value;
					}
				}
			}
		}
//		fallback
		return null;
	}
}
