Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,38 +2,34 @@

import com.google.common.collect.ImmutableList;
import graphql.PublicApi;
import graphql.introspection.Introspection;
import graphql.schema.GraphQLEnumType;
import graphql.schema.GraphQLFieldDefinition;
import graphql.schema.GraphQLFieldsContainer;
import graphql.schema.GraphQLImplementingType;
import graphql.schema.GraphQLInputObjectField;
import graphql.schema.GraphQLInputObjectType;
import graphql.schema.GraphQLInterfaceType;
import graphql.schema.GraphQLNamedType;
import graphql.schema.GraphQLObjectType;
import graphql.schema.GraphQLScalarType;
import graphql.schema.GraphQLSchema;
import graphql.schema.GraphQLSchemaElement;
import graphql.schema.GraphQLType;
import graphql.schema.GraphQLTypeVisitorStub;
import graphql.schema.GraphQLUnionType;
import graphql.schema.SchemaTraverser;
import graphql.schema.impl.SchemaUtil;
import graphql.schema.transform.VisibleFieldPredicateEnvironment.VisibleFieldPredicateEnvironmentImpl;
import graphql.util.TraversalControl;
import graphql.util.TraverserContext;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static graphql.schema.SchemaTransformer.transformSchema;
import static graphql.schema.SchemaTransformer.transformSchemaWithDeletes;

/**
Expand Down Expand Up @@ -61,84 +57,103 @@ public FieldVisibilitySchemaTransformation(VisibleFieldPredicate visibleFieldPre
}

public final GraphQLSchema apply(GraphQLSchema schema) {
Set<String> observedBeforeTransform = new LinkedHashSet<>();
Set<String> observedAfterTransform = new LinkedHashSet<>();
Set<GraphQLType> markedForRemovalTypes = new HashSet<>();

// query, mutation, and subscription types should not be removed
final Set<String> protectedTypeNames = new HashSet<>();
for (GraphQLObjectType graphQLObjectType : getOperationTypes(schema)) {
protectedTypeNames.add(graphQLObjectType.getName());
}

beforeTransformationHook.run();

new SchemaTraverser(getChildrenFn(schema)).depthFirst(new TypeObservingVisitor(observedBeforeTransform), getRootTypes(schema));
// Find root unused types BEFORE transformation
// These are types that exist in the schema but are NOT reachable from operation types + directives
Set<String> rootUnusedTypes = findRootUnusedTypes(schema);

// remove fields
// we delete all fields that should be deleted
// this assumes the field remove itself is semantically valid
GraphQLSchema interimSchema = transformSchemaWithDeletes(schema,
new FieldRemovalVisitor(visibleFieldPredicate, markedForRemovalTypes));
new FieldRemovalVisitor(visibleFieldPredicate));


// cleanup schema
// now we want to remove all types which are not reachable via root types, directives and the interface implements relationship
SchemaTraverser schemaTraverser = new SchemaTraverser(childrenWithInterfaceImplementations(interimSchema));

// first we observe all types we don't want to delete
Set<String> observedTypes = new LinkedHashSet<>();
TypeObservingVisitor typeObservingVisitor = new TypeObservingVisitor(observedTypes);
schemaTraverser.depthFirst(typeObservingVisitor, getRootTypes(interimSchema));

// Traverse from root unused types that still exist after transformation
// This preserves originally unused types and their dependencies
List<GraphQLSchemaElement> existingRootUnusedTypes = rootUnusedTypes.stream()
.map(interimSchema::getType)
.filter(Objects::nonNull)
.map(type -> (GraphQLSchemaElement) type)
.collect(Collectors.toList());

new SchemaTraverser(getChildrenFn(interimSchema)).depthFirst(new TypeObservingVisitor(observedAfterTransform), getRootTypes(interimSchema));
if (!existingRootUnusedTypes.isEmpty()) {
schemaTraverser.depthFirst(typeObservingVisitor, existingRootUnusedTypes);
}

// remove types that are not used after removing fields - (connected schema only)
GraphQLSchema connectedSchema = transformSchema(interimSchema,
new TypeVisibilityVisitor(protectedTypeNames, observedBeforeTransform, observedAfterTransform));
// then we delete all the types which are not used anymore
GraphQLSchema finalSchema = transformSchemaWithDeletes(interimSchema,
new TypeRemovalVisitor(observedTypes));

// ensure markedForRemovalTypes are not referenced by other schema elements, and delete from the schema
// the ones that aren't.
GraphQLSchema finalSchema = removeUnreferencedTypes(markedForRemovalTypes, connectedSchema);

afterTransformationHook.run();

return finalSchema;
}

// Creates a getChildrenFn that includes interface
private Function<GraphQLSchemaElement, List<GraphQLSchemaElement>> getChildrenFn(GraphQLSchema schema) {
Map<String, List<GraphQLImplementingType>> interfaceImplementations = new SchemaUtil().groupImplementationsForInterfacesAndObjects(schema);

return graphQLSchemaElement -> {
if (!(graphQLSchemaElement instanceof GraphQLInterfaceType)) {
return graphQLSchemaElement.getChildren();
}
ArrayList<GraphQLSchemaElement> children = new ArrayList<>(graphQLSchemaElement.getChildren());
List<GraphQLImplementingType> implementations = interfaceImplementations.get(((GraphQLInterfaceType) graphQLSchemaElement).getName());
if (implementations != null) {
children.addAll(implementations);
/**
* Finds root unused types - types that exist in additional types but are NOT reachable
* from operation types (Query, Mutation, Subscription) and directives.
*/
private Set<String> findRootUnusedTypes(GraphQLSchema schema) {
// Collect all types reachable from operation roots + directives
// Use a traverser that includes interface implementations
Set<String> typesReachableFromRoots = new LinkedHashSet<>();
SchemaTraverser traverser = new SchemaTraverser(childrenWithInterfaceImplementations(schema));
TypeObservingVisitor visitor = new TypeObservingVisitor(typesReachableFromRoots);
traverser.depthFirst(visitor, getRootTypes(schema));

// Root unused types are additional types that are NOT reachable from roots
Set<String> rootUnusedTypes = new LinkedHashSet<>();
for (GraphQLNamedType type : schema.getAdditionalTypes()) {
String typeName = type.getName();
if (!typesReachableFromRoots.contains(typeName) && !isIntrospectionType(typeName)) {
rootUnusedTypes.add(typeName);
}
return children;
};
}
return rootUnusedTypes;
}

private GraphQLSchema removeUnreferencedTypes(Set<GraphQLType> markedForRemovalTypes, GraphQLSchema connectedSchema) {
GraphQLSchema withoutAdditionalTypes = connectedSchema.transform(builder -> {
Set<GraphQLNamedType> additionalTypes = new HashSet<>(connectedSchema.getAdditionalTypes());
additionalTypes.removeAll(markedForRemovalTypes);
builder.clearAdditionalTypes();
builder.additionalTypes(additionalTypes);
});
/**
* Checks if a type is an introspection type that should be protected from removal.
* This includes standard introspection types (starting with "__") and special types
* like _AppliedDirective (starting with "_") added by IntrospectionWithDirectivesSupport.
*/
private static boolean isIntrospectionType(String typeName) {
return Introspection.isIntrospectionTypes(typeName) || typeName.startsWith("_");
}

// remove from markedForRemovalTypes any type that might still be referenced by other schema elements
transformSchema(withoutAdditionalTypes, new AdditionalTypeVisibilityVisitor(markedForRemovalTypes));
/**
* Creates a function that returns children of a schema element, including interface implementations.
* This ensures that when traversing from an interface, we also visit all types that implement it.
*/
private Function<GraphQLSchemaElement, List<GraphQLSchemaElement>> childrenWithInterfaceImplementations(GraphQLSchema schema) {

// finally remove the types on the schema we are certain aren't referenced by any other node.
return transformSchema(connectedSchema, new GraphQLTypeVisitorStub() {
@Override
protected TraversalControl visitGraphQLType(GraphQLSchemaElement node, TraverserContext<GraphQLSchemaElement> context) {
if (node instanceof GraphQLType && markedForRemovalTypes.contains(node)) {
return deleteNode(context);
}
return super.visitGraphQLType(node, context);
return schemaElement -> {
if (!(schemaElement instanceof GraphQLInterfaceType)) {
return schemaElement.getChildren();
}
});
ArrayList<GraphQLSchemaElement> children = new ArrayList<>(schemaElement.getChildren());
List<GraphQLObjectType> implementations = schema.getImplementations((GraphQLInterfaceType) schemaElement);
children.addAll(implementations);
return children;
};
}

private static class TypeObservingVisitor extends GraphQLTypeVisitorStub {

private final Set<String> observedTypes;


private TypeObservingVisitor(Set<String> observedTypes) {
this.observedTypes = observedTypes;
}
Expand All @@ -150,7 +165,8 @@ protected TraversalControl visitGraphQLType(GraphQLSchemaElement node,
node instanceof GraphQLEnumType ||
node instanceof GraphQLInputObjectType ||
node instanceof GraphQLInterfaceType ||
node instanceof GraphQLUnionType) {
node instanceof GraphQLUnionType ||
node instanceof GraphQLScalarType) {
observedTypes.add(((GraphQLNamedType) node).getName());
}

Expand All @@ -161,15 +177,12 @@ protected TraversalControl visitGraphQLType(GraphQLSchemaElement node,
private static class FieldRemovalVisitor extends GraphQLTypeVisitorStub {

private final VisibleFieldPredicate visibilityPredicate;
private final Set<GraphQLType> removedTypes;

private final Set<GraphQLFieldDefinition> fieldDefinitionsToActuallyRemove = new LinkedHashSet<>();
private final Set<GraphQLInputObjectField> inputObjectFieldsToDelete = new LinkedHashSet<>();

private FieldRemovalVisitor(VisibleFieldPredicate visibilityPredicate,
Set<GraphQLType> removedTypes) {
private FieldRemovalVisitor(VisibleFieldPredicate visibilityPredicate) {
this.visibilityPredicate = visibilityPredicate;
this.removedTypes = removedTypes;
}

@Override
Expand All @@ -189,7 +202,6 @@ private TraversalControl visitFieldsContainer(GraphQLFieldsContainer fieldsConta
fieldDefinition, fieldsContainer);
if (!visibilityPredicate.isVisible(environment)) {
fieldDefinitionsToActuallyRemove.add(fieldDefinition);
removedTypes.add(fieldDefinition.getType());
} else {
allFieldsDeleted = false;
}
Expand All @@ -210,7 +222,6 @@ public TraversalControl visitGraphQLInputObjectType(GraphQLInputObjectType input
inputField, inputObjectType);
if (!visibilityPredicate.isVisible(environment)) {
inputObjectFieldsToDelete.add(inputField);
removedTypes.add(inputField.getType());
} else {
allFieldsDeleted = false;
}
Expand Down Expand Up @@ -245,76 +256,43 @@ public TraversalControl visitGraphQLInputObjectField(GraphQLInputObjectField def
}
}

private static class TypeVisibilityVisitor extends GraphQLTypeVisitorStub {
private static class TypeRemovalVisitor extends GraphQLTypeVisitorStub {

private final Set<String> protectedTypeNames;
private final Set<String> observedBeforeTransform;
private final Set<String> observedAfterTransform;

private TypeVisibilityVisitor(Set<String> protectedTypeNames,
Set<String> observedTypes,
Set<String> observedAfterTransform) {
private TypeRemovalVisitor(Set<String> protectedTypeNames) {
this.protectedTypeNames = protectedTypeNames;
this.observedBeforeTransform = observedTypes;
this.observedAfterTransform = observedAfterTransform;
}

@Override
public TraversalControl visitGraphQLInterfaceType(GraphQLInterfaceType node,
TraverserContext<GraphQLSchemaElement> context) {
return super.visitGraphQLInterfaceType(node, context);
}

@Override
public TraversalControl visitGraphQLType(GraphQLSchemaElement node,
TraverserContext<GraphQLSchemaElement> context) {
if (node instanceof GraphQLNamedType) {
String name = ((GraphQLNamedType) node).getName();
if (isIntrospectionType(name)) {
return TraversalControl.CONTINUE;
}
}
if (node instanceof GraphQLObjectType ||
node instanceof GraphQLEnumType ||
node instanceof GraphQLInputObjectType ||
node instanceof GraphQLInterfaceType ||
node instanceof GraphQLUnionType) {
node instanceof GraphQLUnionType ||
node instanceof GraphQLScalarType) {
String name = ((GraphQLNamedType) node).getName();
if (observedBeforeTransform.contains(name) &&
!observedAfterTransform.contains(name)
&& !protectedTypeNames.contains(name)
) {
if (!protectedTypeNames.contains(name)) {
return deleteNode(context);
}
}
return TraversalControl.CONTINUE;
}
}

private static class AdditionalTypeVisibilityVisitor extends GraphQLTypeVisitorStub {

private final Set<GraphQLType> markedForRemovalTypes;

private AdditionalTypeVisibilityVisitor(Set<GraphQLType> markedForRemovalTypes) {
this.markedForRemovalTypes = markedForRemovalTypes;
}

@Override
public TraversalControl visitGraphQLType(GraphQLSchemaElement node,
TraverserContext<GraphQLSchemaElement> context) {

if (node instanceof GraphQLNamedType) {
GraphQLNamedType namedType = (GraphQLNamedType) node;
// we encountered a node referencing one of the marked types, so it should not be removed.
if (markedForRemovalTypes.contains(node)) {
markedForRemovalTypes.remove(namedType);
}
}

return TraversalControl.CONTINUE;
}
}

private List<GraphQLSchemaElement> getRootTypes(GraphQLSchema schema) {
return ImmutableList.<GraphQLSchemaElement>builder()
.addAll(getOperationTypes(schema))
// Include directive definitions as roots, since they won't be removed in the filtering process.
// Some types (enums, input types, etc.) might be reachable only by directive definitions (and
// not by other types or fields).
.addAll(schema.getDirectives())
.build();
}
Expand Down
Loading
Loading