GremlinTemplate.java

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.azure.spring.data.gremlin.query;

import com.azure.spring.data.gremlin.annotation.EdgeFrom;
import com.azure.spring.data.gremlin.annotation.EdgeTo;
import com.azure.spring.data.gremlin.annotation.GeneratedValue;
import com.azure.spring.data.gremlin.common.GremlinEntityType;
import com.azure.spring.data.gremlin.common.GremlinFactory;
import com.azure.spring.data.gremlin.common.GremlinUtils;
import com.azure.spring.data.gremlin.conversion.MappingGremlinConverter;
import com.azure.spring.data.gremlin.conversion.script.GremlinScriptLiteral;
import com.azure.spring.data.gremlin.conversion.script.GremlinScriptLiteralEdge;
import com.azure.spring.data.gremlin.conversion.script.GremlinScriptLiteralGraph;
import com.azure.spring.data.gremlin.conversion.script.GremlinScriptLiteralVertex;
import com.azure.spring.data.gremlin.conversion.source.GremlinSource;
import com.azure.spring.data.gremlin.conversion.source.GremlinSourceEdge;
import com.azure.spring.data.gremlin.conversion.source.GremlinSourceGraph;
import com.azure.spring.data.gremlin.conversion.source.GremlinSourceVertex;
import com.azure.spring.data.gremlin.exception.GremlinEntityInformationException;
import com.azure.spring.data.gremlin.exception.GremlinInvalidEntityIdFieldException;
import com.azure.spring.data.gremlin.exception.GremlinQueryException;
import com.azure.spring.data.gremlin.exception.GremlinUnexpectedEntityTypeException;
import com.azure.spring.data.gremlin.mapping.GremlinPersistentEntity;
import com.azure.spring.data.gremlin.mapping.GremlinPersistentProperty;
import com.azure.spring.data.gremlin.query.query.GremlinQuery;
import com.azure.spring.data.gremlin.query.query.QueryFindScriptGenerator;
import com.azure.spring.data.gremlin.query.query.QueryScriptGenerator;
import org.apache.commons.lang3.reflect.FieldUtils;
import org.apache.tinkerpop.gremlin.driver.Client;
import org.apache.tinkerpop.gremlin.driver.Result;
import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.data.mapping.PersistentProperty;
import org.springframework.data.mapping.model.ConvertingPropertyAccessor;
import org.springframework.lang.NonNull;
import org.springframework.util.Assert;

import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ExecutionException;

import static java.util.stream.Collectors.toList;

/**
 * {@link GremlinOperations} implementation that turns domain objects into Gremlin scripts via their
 * {@link GremlinSource}, submits those scripts through a {@link Client} obtained from the configured
 * {@link GremlinFactory}, and maps the returned {@link Result}s back to domain objects with the
 * {@link MappingGremlinConverter}.
 *
 * <p>Several operations write the requested id or the converted entity into the {@link GremlinSource}
 * they are given, so a source instance should be treated as mutable, per-call state.
 */
public class GremlinTemplate implements GremlinOperations, ApplicationContextAware {

    private final GremlinFactory factory;
    private final MappingGremlinConverter mappingConverter;

    // Created lazily by getGremlinClient(); stays null until the first query is issued.
    private Client gremlinClient;
    private ApplicationContext context;

    /**
     * Creates a template bound to the given client factory and entity converter.
     *
     * @param factory factory used to obtain the Gremlin {@link Client} on first use
     * @param converter converter used to write entities into sources and read them back
     */
    public GremlinTemplate(@NonNull GremlinFactory factory, @NonNull MappingGremlinConverter converter) {
        this.factory = factory;
        this.mappingConverter = converter;
    }

    /**
     * @return the converter this template was constructed with
     */
    @Override
    public MappingGremlinConverter getMappingConverter() {
        return this.mappingConverter;
    }

    /**
     * @return the application context set through {@link #setApplicationContext}, or {@code null}
     *         if none has been set
     */
    public ApplicationContext getApplicationContext() {
        return this.context;
    }

    /**
     * Stores the application context handed in by the container.
     *
     * @param context the application context to remember
     * @throws BeansException declared by {@link ApplicationContextAware}; not thrown by this implementation
     */
    @Override
    public void setApplicationContext(@NonNull ApplicationContext context) throws BeansException {
        this.context = context;
    }

    /**
     * Returns the Gremlin client, asking the factory for it on the first call and caching the
     * result for later calls.
     *
     * <p>The lazy initialisation is not synchronized; concurrent first calls may each ask the
     * factory for a client.
     *
     * @return the cached client
     */
    public Client getGremlinClient() {
        if (this.gremlinClient == null) {
            this.gremlinClient = this.factory.getGremlinClient();
        }

        return this.gremlinClient;
    }

    /**
     * Executes the given scripts and returns all results in one flat list.
     *
     * <p>The scripts are first grouped by {@link GremlinUtils#toParallelQueryList}; the groups are
     * processed one after another, while the scripts inside a group are submitted in parallel.
     *
     * @param queries the Gremlin scripts to run
     * @return the concatenated results of every script
     * @throws GremlinQueryException if waiting for any script's results fails
     */
    @NonNull
    private List<Result> executeQuery(@NonNull List<String> queries) {
        final List<List<String>> parallelQueries = GremlinUtils.toParallelQueryList(queries);

        return parallelQueries.stream().flatMap(q -> executeQueryParallel(q).stream()).collect(toList());
    }

    /**
     * Submits every script of one group in parallel, then waits for all of them and flattens the
     * results into a single list.
     *
     * @param queries scripts that may be executed concurrently
     * @return the results of all scripts; empty when {@code queries} is empty
     * @throws GremlinQueryException if a script fails or the waiting thread is interrupted; in the
     *         latter case the thread's interrupt flag is restored before the exception is thrown
     */
    @NonNull
    private List<Result> executeQueryParallel(@NonNull List<String> queries) {
        // Nothing to submit: do not touch (and possibly create) the client, as before.
        if (queries.isEmpty()) {
            return Collections.emptyList();
        }

        // Resolve the client once on the calling thread. getGremlinClient() initialises lazily
        // without synchronisation, so calling it from the parallel lambdas below could race on
        // the first use and request several clients from the factory.
        final Client client = getGremlinClient();

        return queries.parallelStream()
                .map(q -> client.submit(q).all())
                .collect(toList()).parallelStream().flatMap(f -> {
                    try {
                        return f.get().stream();
                    } catch (InterruptedException e) {
                        // Restore the interrupt flag so callers up the stack can still observe it.
                        Thread.currentThread().interrupt();
                        throw new GremlinQueryException("unable to complete query from gremlin", e);
                    } catch (ExecutionException e) {
                        throw new GremlinQueryException("unable to complete query from gremlin", e);
                    }
                })
                .collect(toList());
    }

    /**
     * Runs the delete-all script generated by a {@link GremlinScriptLiteralGraph}.
     */
    @Override
    public void deleteAll() {
        final GremlinScriptLiteral script = new GremlinScriptLiteralGraph();
        final List<String> queryList = script.generateDeleteAllScript();

        executeQuery(queryList);
    }

    /**
     * Runs the delete-all script of the source created for the given entity type.
     *
     * @param type the entity type whose source supplies the delete-all script
     */
    @Override
    public void deleteAll(GremlinEntityType type) {
        final GremlinSource<?> source = type.createGremlinSource();

        executeQuery(source.getGremlinScriptLiteral().generateDeleteAllScript());
    }

    /**
     * Runs the delete-all-by-class script generated for the given source.
     *
     * @param source the source describing the domain class to delete
     * @param <T> the domain type
     */
    @Override
    public <T> void deleteAll(GremlinSource<T> source) {
        executeQuery(source.getGremlinScriptLiteral().generateDeleteAllByClassScript(source));
    }

    /**
     * Writes the object into the source and executes the resulting insert script.
     *
     * @param object the domain object to insert
     * @param source the source that receives the converted object and supplies the script
     * @param <T> the domain type
     * @return the raw results of the insert script
     */
    private <T> List<Result> insertInternal(@NonNull T object, @NonNull GremlinSource<T> source) {
        this.mappingConverter.write(object, source);

        return executeQuery(source.getGremlinScriptLiteral().generateInsertScript(source));
    }

    /**
     * Inserts the object and returns the domain instance rebuilt from the query results.
     *
     * @param object the domain object to insert
     * @param source the source describing the object
     * @param <T> the domain type
     * @return the domain object read back from the results, or {@code null} when the insert script
     *         returned no results
     * @throws GremlinInvalidEntityIdFieldException if the source is not a graph source, its id field
     *         is annotated with {@link GeneratedValue}, and the source already carries an id
     */
    @Override
    public <T> T insert(@NonNull T object, GremlinSource<T> source) {
        final boolean entityGraph = source instanceof GremlinSourceGraph;

        if (!entityGraph && source.getIdField().isAnnotationPresent(GeneratedValue.class)
                && source.getId().isPresent()) {
            throw new GremlinInvalidEntityIdFieldException("The entity meant to be created has a non-null id "
                    + "that is marked as @GeneratedValue");
        }

        // The current implementation doesn't support creating graphs that contain both edges
        // and vertices that have null (generated) ids. In this case, vertex and edge creation
        // need to be performed in two consecutive steps.
        // TODO(SOON) Add this verification in the GremlinSourceGraphWriter

        final List<Result> results = insertInternal(object, source);

        if (!results.isEmpty()) {
            if (entityGraph) {
                return recoverGraphDomain((GremlinSourceGraph<T>) source, results);
            } else {
                return recoverDomain(source, results);
            }
        }

        return null;
    }

    /**
     * Finds a vertex by id. The id is written into the given source before querying.
     *
     * @param id the vertex id
     * @param source the source describing the vertex domain class
     * @param <T> the domain type
     * @return the vertex domain object, or {@code null} if the query returned no results
     * @throws GremlinUnexpectedEntityTypeException if the source is not a {@link GremlinSourceVertex}
     */
    @Override
    public <T> T findVertexById(@NonNull Object id, GremlinSource<T> source) {
        if (source instanceof GremlinSourceVertex) {
            source.setId(id);
            return this.findByIdInternal(source);
        }

        throw new GremlinUnexpectedEntityTypeException("should be vertex domain for findVertexById");
    }

    /**
     * Resolves the value to store in an {@link EdgeFrom}/{@link EdgeTo} field.
     *
     * <p>Fields of type {@code String}, {@code Long} or {@code Integer} receive the vertex id itself.
     * Primitive fields are rejected. Any other field type is treated as a vertex domain class and
     * the vertex is loaded by id.
     *
     * @param field the annotated edge field
     * @param vertexId the id of the vertex the edge is attached to
     * @return the id or the loaded vertex domain object
     * @throws GremlinUnexpectedEntityTypeException if the field has a primitive type
     */
    private Object getEdgeAnnotatedFieldValue(@NonNull Field field, @NonNull Object vertexId) {
        if (field.getType() == String.class || field.getType() == Long.class || field.getType() == Integer.class) {
            return vertexId;
        } else if (field.getType().isPrimitive()) {
            throw new GremlinUnexpectedEntityTypeException("only String/Long/Integer type of Id Field is allowed");
        } else {
            return this.findVertexById(vertexId, GremlinUtils.toGremlinSource(field.getType()));
        }
    }

    /**
     * Returns the single field of the domain class that carries the given annotation.
     *
     * @param domainClass the class to inspect
     * @param annotationClass the annotation to look for
     * @return the one annotated field
     * @throws GremlinEntityInformationException if zero or more than one field carries the annotation
     */
    @NonNull
    private Field getEdgeAnnotatedField(@NonNull Class<?> domainClass,
                                        @NonNull Class<? extends Annotation> annotationClass) {
        final List<Field> fields = FieldUtils.getFieldsListWithAnnotation(domainClass, annotationClass);

        if (fields.size() != 1) {
            throw new GremlinEntityInformationException("should be only one Annotation");
        }

        return fields.get(0);
    }

    /**
     * Find Edge need another two query to obtain edgeFrom and edgeTo.
     * This function will do that and make edge domain completion.
     *
     * <p>The {@link EdgeFrom} and {@link EdgeTo} fields of the domain object are populated from the
     * source's from/to vertex ids through the converter's property accessor.
     *
     * @param domain the edge domain object to complete
     * @param source the edge source holding the from/to vertex ids
     * @param <T> the domain type
     */
    @SuppressWarnings("unchecked")
    private <T> void completeEdge(@NonNull T domain, @NonNull GremlinSourceEdge<T> source) {
        final ConvertingPropertyAccessor<T> accessor = this.mappingConverter.getPropertyAccessor(domain);
        final GremlinPersistentEntity<T> persistentEntity = (GremlinPersistentEntity<T>) this.mappingConverter
            .getPersistentEntity(domain.getClass());

        final Field fromField = this.getEdgeAnnotatedField(domain.getClass(), EdgeFrom.class);
        final Field toField = this.getEdgeAnnotatedField(domain.getClass(), EdgeTo.class);

        final PersistentProperty<GremlinPersistentProperty> propertyFrom = persistentEntity
            .getPersistentProperty(fromField.getName());
        final PersistentProperty<GremlinPersistentProperty> propertyTo = persistentEntity
            .getPersistentProperty(toField.getName());

        Assert.notNull(propertyFrom, "persistence property should not be null");
        Assert.notNull(propertyTo, "persistence property should not be null");

        accessor.setProperty(propertyFrom, this.getEdgeAnnotatedFieldValue(fromField, source.getVertexIdFrom()));
        accessor.setProperty(propertyTo, this.getEdgeAnnotatedFieldValue(toField, source.getVertexIdTo()));
    }

    /**
     * Finds an edge by id.
     *
     * @param id the edge id
     * @param source the source describing the edge domain class
     * @param <T> the domain type
     * @return the edge domain object, or {@code null} if the query returned no results
     * @throws GremlinUnexpectedEntityTypeException if the source is not a {@link GremlinSourceEdge}
     */
    @Override
    public <T> T findEdgeById(@NonNull Object id, @NonNull GremlinSource<T> source) {
        if (source instanceof GremlinSourceEdge) {
            return this.findById(id, source);
        }

        throw new GremlinUnexpectedEntityTypeException("should be edge domain for findEdge");
    }

    /**
     * Executes the find-by-id script of a source whose id has already been set.
     *
     * @param source the source carrying the id to look up
     * @param <T> the domain type
     * @return the domain object, or {@code null} if the query returned no results
     */
    private <T> T findByIdInternal(@NonNull GremlinSource<T> source) {
        final List<String> queryList = source.getGremlinScriptLiteral().generateFindByIdScript(source);
        final List<Result> results = this.executeQuery(queryList);

        if (results.isEmpty()) {
            return null;
        }

        return recoverDomain(source, results);
    }

    /**
     * Finds a vertex or edge by id. The id is written into the given source before querying.
     *
     * @param id the entity id
     * @param source the source describing the domain class
     * @param <T> the domain type
     * @return the domain object, or {@code null} if the query returned no results
     * @throws UnsupportedOperationException if the source is a {@link GremlinSourceGraph}
     */
    @Override
    public <T> T findById(@NonNull Object id, @NonNull GremlinSource<T> source) {
        if (source instanceof GremlinSourceGraph) {
            throw new UnsupportedOperationException("Gremlin graph cannot be findById.");
        }

        source.setId(id);

        return findByIdInternal(source);
    }

    /**
     * Writes the object into the source and executes the resulting update script.
     *
     * @param object the domain object to update
     * @param source the source that receives the converted object and supplies the script
     * @param <T> the domain type
     * @return the same {@code object} instance that was passed in
     */
    private <T> T updateInternal(@NonNull T object, @NonNull GremlinSource<T> source) {
        this.mappingConverter.write(object, source);

        final List<String> queryList = source.getGremlinScriptLiteral().generateUpdateScript(source);

        executeQuery(queryList);

        return object;
    }

    /**
     * Updates an existing entity. For non-graph sources the entity must already exist.
     *
     * @param object the domain object to update
     * @param source the source describing the object
     * @param <T> the domain type
     * @return the same {@code object} instance that was passed in
     * @throws GremlinQueryException if the source is not a graph source and either carries no id or
     *         no entity with that id exists
     */
    @Override
    public <T> T update(@NonNull T object, @NonNull GremlinSource<T> source) {
        final Optional<Object> optional = source.getId();

        if (!(source instanceof GremlinSourceGraph)
                && (!optional.isPresent() || notExistsById(optional.get(), source))) {
            throw new GremlinQueryException("cannot update the object doesn't exist");
        }

        return this.updateInternal(object, source);
    }

    /**
     * Inserts the entity when it does not exist yet and updates it otherwise.
     *
     * <p>A graph source is inserted when the graph is empty. A non-graph source is inserted when it
     * has no id or no entity with that id exists. Every other case is an update.
     *
     * @param object the domain object to save
     * @param source the source describing the object
     * @param <T> the domain type
     * @return the result of {@link #insert} or of the update, depending on the branch taken
     */
    @Override
    public <T> T save(@NonNull T object, @NonNull GremlinSource<T> source) {
        final Optional<Object> optional = source.getId();
        final boolean entityGraph = source instanceof GremlinSourceGraph;

        if (entityGraph && this.isEmptyGraph(source)) {
            return insert(object, source);
        } else if (!entityGraph && (!optional.isPresent() || notExistsById(optional.get(), source))) {
            return insert(object, source);
        } else {
            return updateInternal(object, source);
        }
    }

    /**
     * Finds all entities described by the given vertex or edge source.
     *
     * @param source the source describing the domain class
     * @param <T> the domain type
     * @return one domain object per result; an empty list when there are no results
     * @throws UnsupportedOperationException if the source is a {@link GremlinSourceGraph}
     */
    @Override
    public <T> List<T> findAll(@NonNull GremlinSource<T> source) {
        if (source instanceof GremlinSourceGraph) {
            throw new UnsupportedOperationException("Gremlin graph cannot be findAll.");
        }

        final List<String> queryList = source.getGremlinScriptLiteral().generateFindAllScript(source);
        final List<Result> results = executeQuery(queryList);

        if (results.isEmpty()) {
            return Collections.emptyList();
        }

        return recoverDomainList(source, results);
    }

    /**
     * Deletes the entity with the given id. The id is written into the source before the
     * delete-by-id script is generated.
     *
     * @param id the entity id
     * @param source the source describing the domain class
     * @param <T> the domain type
     */
    @Override
    public <T> void deleteById(@NonNull Object id, @NonNull GremlinSource<T> source) {
        source.setId(id);

        final List<String> queryList = source.getGremlinScriptLiteral().generateDeleteByIdScript(source);

        executeQuery(queryList);
    }

    /**
     * Reports whether the graph is empty, i.e. whether the graph literal's is-empty script returns
     * no results.
     *
     * @param source a graph source
     * @param <T> the domain type
     * @return {@code true} if the is-empty script produced no results
     * @throws GremlinQueryException if the source is not a {@link GremlinSourceGraph}
     */
    @Override
    public <T> boolean isEmptyGraph(@NonNull GremlinSource<T> source) {
        if (source instanceof GremlinSourceGraph) {
            final GremlinScriptLiteralGraph literalGraph = (GremlinScriptLiteralGraph) source.getGremlinScriptLiteral();
            final List<String> queryList = literalGraph.generateIsEmptyScript();
            final List<Result> results = this.executeQuery(queryList);

            return results.isEmpty();
        }

        throw new GremlinQueryException("only graph domain is allowed.");
    }

    /**
     * Counts vertices by running the vertex literal's count script.
     *
     * @return the number of results the count script returned (the result rows themselves are
     *         counted; no value is parsed out of them)
     */
    @Override
    public long vertexCount() {
        final GremlinScriptLiteral script = new GremlinScriptLiteralVertex();
        final List<String> queryList = script.generateCountScript(new GremlinSourceVertex<>());
        final List<Result> results = this.executeQuery(queryList);

        return results.size();
    }

    /**
     * Counts edges by running the edge literal's count script.
     *
     * @return the number of results the count script returned (the result rows themselves are
     *         counted; no value is parsed out of them)
     */
    @Override
    public long edgeCount() {
        final GremlinScriptLiteral script = new GremlinScriptLiteralEdge();
        final List<String> queryList = script.generateCountScript(new GremlinSourceEdge<>());
        final List<Result> results = this.executeQuery(queryList);

        return results.size();
    }

    /**
     * Loads the results into the source and converts the source into a domain object. Edge domain
     * objects additionally get their from/to fields populated.
     *
     * @param source the source to read the results into
     * @param results the raw query results
     * @param <T> the domain type
     * @return the converted domain object
     */
    private <T> T recoverDomain(@NonNull GremlinSource<T> source, @NonNull List<Result> results) {
        final T domain;
        final Class<T> domainClass = source.getDomainClass();

        source.doGremlinResultRead(results);
        domain = this.mappingConverter.read(domainClass, source);

        if (source instanceof GremlinSourceEdge) {
            this.completeEdge(domain, (GremlinSourceEdge<T>) source);
        }

        return domain;
    }

    /**
     * Converts each result into its own domain object, reusing the same source instance for every
     * result in turn.
     *
     * @param source the source used as the conversion buffer
     * @param results the raw query results, one per entity
     * @param <T> the domain type
     * @return the converted domain objects, in result order
     */
    private <T> List<T> recoverDomainList(@NonNull GremlinSource<T> source, @NonNull List<Result> results) {
        return results.stream().map(r -> recoverDomain(source, Collections.singletonList(r))).collect(toList());
    }

    /**
     * Loads the results into the graph source through its results reader and converts the source
     * into a graph domain object.
     *
     * @param source the graph source to read the results into
     * @param results the raw query results
     * @param <T> the domain type
     * @return the converted graph domain object
     */
    private <T> T recoverGraphDomain(@NonNull GremlinSourceGraph<T> source, @NonNull List<Result> results) {
        final T domain;
        final Class<T> domainClass = source.getDomainClass();

        source.getResultsReader().read(results, source);
        domain = source.doGremlinSourceRead(domainClass, mappingConverter);
        return domain;
    }

    /**
     * Negation of {@link #existsById(Object, GremlinSource)}.
     *
     * @param id the entity id
     * @param source the source describing the domain class
     * @param <T> the domain type
     * @return {@code true} if no entity with the given id was found
     */
    private <T> boolean notExistsById(@NonNull Object id, @NonNull GremlinSource<T> source) {
        return !existsById(id, source);
    }

    /**
     * Checks existence by performing a full {@link #findById(Object, GremlinSource)} lookup.
     *
     * @param id the entity id
     * @param source the source describing the domain class
     * @param <T> the domain type
     * @return {@code true} if the lookup returned a non-null domain object
     * @throws UnsupportedOperationException if the source is a {@link GremlinSourceGraph}
     */
    @Override
    public <T> boolean existsById(@NonNull Object id, @NonNull GremlinSource<T> source) {
        return findById(id, source) != null;
    }

    /**
     * Runs the script generated by a {@link QueryFindScriptGenerator} for the given query and maps
     * every result to a domain object.
     *
     * @param query the query to translate into a Gremlin script
     * @param source the source describing the domain class
     * @param <T> the domain type
     * @return one domain object per result; an empty list when there are no results
     */
    @Override
    public <T> List<T> find(@NonNull GremlinQuery query, @NonNull GremlinSource<T> source) {
        final QueryScriptGenerator generator = new QueryFindScriptGenerator<>(source);
        final List<String> queryList = generator.generate(query);
        final List<Result> results = this.executeQuery(queryList);

        if (results.isEmpty()) {
            return Collections.emptyList();
        }

        return this.recoverDomainList(source, results);
    }
}