package eu.dnetlib.data.transform;

import java.io.IOException;
import java.io.InputStream;
import java.io.StringReader;
import java.util.List;
import java.util.Objects;
import java.util.Properties;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

import com.google.protobuf.InvalidProtocolBufferException;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.MongoDatabase;
import eu.dnetlib.data.proto.OafProtos.Oaf;
import eu.dnetlib.data.proto.OafProtos.OafEntity;
import eu.dnetlib.data.transform.xml2.DatasetToProto;
import eu.dnetlib.miscutils.collections.Pair;
import eu.dnetlib.miscutils.datetime.HumanTime;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.time.StopWatch;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.commons.math.stat.descriptive.DescriptiveStatistics;
import org.bson.Document;
import org.dom4j.DocumentException;
import org.dom4j.io.SAXReader;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;

import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;

/**
 * Integration test that reads XML records from MongoDB collections, maps each record to an {@link Oaf}
 * protocol buffer (either through {@link DatasetToProto} or through an XSLT based {@link XsltRowTransformer})
 * and checks that the produced entity identifier contains the objIdentifier found in the record header.
 *
 * <p>The tests need the {@link MongoDatabase} provided by {@link ConfigurationTestConfig} and are
 * currently marked with {@link Ignore}. The properties {@code test.logFreq}, {@code test.batchSize} and
 * {@code test.limit} override the defaults declared below.</p>
 */
@RunWith(SpringJUnit4ClassRunner.class)
@ContextConfiguration(classes = { ConfigurationTestConfig.class })
public class ParserToProtoIT {

	private static final Log log = LogFactory.getLog(ParserToProtoIT.class);

	/** Name of the mongo collection read by the datacite tests. */
	private static final String DATACITE = "datacite";
	/** Name of the mongo collection read by the narcis test. */
	private static final String NARCIS = "narcis";

	private static final int BATCH_SIZE = 10000;
	private static final int LOG_FREQ = 5000;
	private static final int LIMIT = 10000;

	/** Name of the mongo document field holding the XML record and of the row column holding the serialized Oaf. */
	public static final String BODY = "body";

	/** Encoding used to move the extracted XSLT between String and stream, independent of the platform default. */
	private static final String XSLT_ENCODING = "UTF-8";

	private static final String BASE_PATH_PROFILES = "/eu/dnetlib/test/profiles/TransformationRuleDSResources/TransformationRuleDSResourceType/";

	private int batchSize = BATCH_SIZE;
	private int logFreq = LOG_FREQ;
	private int limit = LIMIT;

	@Autowired
	private MongoDatabase db;

	@Autowired
	private Properties testProperties;

	/**
	 * Reads the optional overrides for log frequency, batch size and record limit from the test properties,
	 * falling back to the class constants.
	 */
	@Before
	public void setUp() {
		logFreq = Integer.parseInt(testProperties.getProperty("test.logFreq", String.valueOf(LOG_FREQ)));
		batchSize = Integer.parseInt(testProperties.getProperty("test.batchSize", String.valueOf(BATCH_SIZE)));
		limit = Integer.parseInt(testProperties.getProperty("test.limit", String.valueOf(LIMIT)));
	}

	@Test
	@Ignore
	public void testParseDataciteWithVTD() throws IOException {
		doTest(s -> new Pair<>(s, new DatasetToProto().apply(s)), DATACITE);
	}

	@Test
	@Ignore
	public void testParseDataciteWithXSLT() throws IOException {
		final String xslt = IOUtils.toString(loadFromTransformationProfile("odf2hbase.xml"), XSLT_ENCODING);
		final XsltRowTransformer transformer = XsltRowTransformerFactory.newInstance(xslt);

		doTest(rowToOaf(transformer), DATACITE);
	}

	@Test
	@Ignore
	public void testParseNarcisWithXSLT() throws IOException {
		final String xslt = IOUtils.toString(loadFromTransformationProfile("oaf2hbase.xml"), XSLT_ENCODING);
		final XsltRowTransformer transformer = XsltRowTransformerFactory.newInstance(xslt);

		doTest(rowToOaf(transformer), NARCIS);
	}

	//// HELPERS

	/**
	 * Loads up to {@code limit} record bodies from the given collection into memory, applies the mapper to
	 * each of them while timing the single invocations, and verifies every produced pair.
	 *
	 * @param mapper
	 *            maps the XML record to a pair (source XML, parsed Oaf); a null result fails the test.
	 * @param collectionName
	 *            name of the mongo collection to read the records from.
	 */
	private void doTest(final Function<String, Pair<String, Oaf>> mapper, final String collectionName) {
		final MongoCollection<Document> collection = db.getCollection(collectionName);

		final long collectionSize = collection.count();
		log.info(String.format("found %s records in collection '%s'", collectionSize, collectionName));

		final AtomicInteger read = new AtomicInteger(0);
		final DescriptiveStatistics stats = new DescriptiveStatistics();

		final StopWatch recordTimer = new StopWatch();
		final StopWatch totalTimer = StopWatch.createStarted();

		// load the records in memory first, so that the per-record timings below do not include the mongo I/O
		final List<String> bodies = StreamSupport.stream(collection.find().batchSize(batchSize).spliterator(), false)
				.limit(limit)
				.peek(d -> {
					if (read.incrementAndGet() % logFreq == 0) {
						log.info(String.format("records read so far %s", read.get()));
					}
				})
				.map(d -> (String) d.get(BODY))
				.filter(Objects::nonNull)
				.collect(Collectors.toList());

		for (final String body : bodies) {
			recordTimer.start();
			final Pair<String, Oaf> pair = mapper.apply(body);
			recordTimer.stop();

			stats.addValue(recordTimer.getTime());
			recordTimer.reset();

			assertMatchesSource(pair);
		}

		totalTimer.stop();
		log.info(String.format("processed %s/%s records in %s", read.get(), collectionSize, HumanTime.exactly(totalTimer.getTime())));
		log.info(stats.toString());
	}

	/**
	 * Verifies that the pair is not null, that the Oaf carries an entity and that the entity id contains the
	 * objIdentifier declared in the header of the source record.
	 *
	 * @throws IllegalArgumentException
	 *             when the source record cannot be parsed as XML.
	 */
	private void assertMatchesSource(final Pair<String, Oaf> pair) {
		assertNotNull(pair);
		assertTrue(pair.getValue().hasEntity());

		try {
			final org.dom4j.Document doc = new SAXReader().read(new StringReader(pair.getKey()));
			final OafEntity entity = pair.getValue().getEntity();

			//TODO add more asserts
			assertTrue(entity.getId().contains(doc.valueOf("/*[local-name() = 'record']/*[local-name() = 'header']/*[local-name() = 'objIdentifier']/text()")));

		} catch (DocumentException e) {
			throw new IllegalArgumentException("unable to parse record " + pair.getKey(), e);
		}
	}

	/**
	 * Builds a mapper that applies the XSLT transformer to the XML record and parses the first available
	 * {@code body} column of the resulting rows as an {@link Oaf}.
	 *
	 * @return a function yielding the pair (source XML, parsed Oaf), or null when the transformation produces
	 *         no rows or no row carries a body column, so that the caller can report the failure uniformly.
	 */
	private Function<String, Pair<String, Oaf>> rowToOaf(final XsltRowTransformer transformer) {
		return xml -> {
			final List<Row> rows = transformer.apply(xml);
			if (rows.isEmpty()) {
				return null;
			}

			return rows.stream()
					.map(row -> row.getColumn(BODY))
					.filter(Objects::nonNull)
					.map(c -> c.getValue())
					.map(b -> {
						try {
							return Oaf.parseFrom(b);
						} catch (InvalidProtocolBufferException e) {
							throw new IllegalStateException(e);
						}
					})
					.map(oaf -> new Pair<>(xml, oaf))
					.findFirst()
					.orElse(null);
		};
	}

	/**
	 * Extracts the XSLT stylesheet embedded in a transformation rule profile available on the classpath.
	 *
	 * @param profilePath
	 *            file name of the profile, relative to the base path of the transformation rule profiles.
	 * @return a stream over the stylesheet serialized as XML, encoded in UTF-8.
	 * @throws IOException
	 *             when the profile stream cannot be closed or the stylesheet cannot be encoded.
	 * @throws IllegalArgumentException
	 *             when the profile is not found on the classpath.
	 * @throws IllegalStateException
	 *             when the profile cannot be parsed or does not contain a stylesheet.
	 */
	private InputStream loadFromTransformationProfile(final String profilePath) throws IOException {
		final String resource = BASE_PATH_PROFILES + profilePath;
		log.info("Loading xslt from: " + resource);

		// try-with-resources tolerates a null resource, the explicit check below gives a readable failure
		try (InputStream profile = getClass().getResourceAsStream(resource)) {
			if (profile == null) {
				throw new IllegalArgumentException("transformation profile not found on classpath: " + resource);
			}

			final org.dom4j.Document doc = new SAXReader().read(profile);
			final org.dom4j.Node stylesheet = doc.selectSingleNode("//SCRIPT/CODE/*[local-name()='stylesheet']");
			if (stylesheet == null) {
				throw new IllegalStateException("no stylesheet found in transformation profile: " + resource);
			}

			return IOUtils.toInputStream(stylesheet.asXML(), XSLT_ENCODING);
		} catch (DocumentException e) {
			throw new IllegalStateException("unable to parse transformation profile: " + resource, e);
		}
	}


}
