Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
package graphql.schema.validation;

import graphql.Internal;
import graphql.language.ArrayValue;
import graphql.language.ObjectField;
import graphql.language.ObjectValue;
import graphql.language.Value;
import graphql.schema.GraphQLInputObjectField;
import graphql.schema.GraphQLInputObjectType;
import graphql.schema.GraphQLSchemaElement;
import graphql.schema.GraphQLType;
import graphql.schema.GraphQLTypeVisitorStub;
import graphql.schema.InputValueWithState;
import graphql.util.TraversalControl;
import graphql.util.TraverserContext;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static graphql.schema.GraphQLTypeUtil.unwrapAll;

/**
* Validates that {@code InputObjectDefaultValueHasCycle(inputObject)} is {@code false}
* for every input object type, as required by the Input Object type validation rules
* in the GraphQL specification.
* <br>
* For example, consider this type configuration:
* <code>
* input A { b:B = {} }
* input B { a:A = {} }
* </code>
* <br>
* The default values used in these types form a cycle that can create an infinitely large
* value. This validator rejects default values that can create these kinds of cycles.
*
* @see <a href="https://spec.graphql.org/draft/#sec-Input-Objects.Type-Validation">Input Objects Type Validation</a>
*/
@Internal
public class NoDefaultValueCircularRefs extends GraphQLTypeVisitorStub {

// Coordinates already fully traversed without finding a cycle, used to avoid duplicate error reports
// when the same coordinate is reachable from multiple input object types.
private final Set<String> fullyExplored = new LinkedHashSet<>();

// The spec's "visitedFields" set, tracked as coordinate strings ("Type.field").
// The spec creates a new immutable set at each step; this implementation mutates and backtracks
// for the same effect.
private final LinkedHashSet<String> visitedFields = new LinkedHashSet<>();

@Override
public TraversalControl visitGraphQLInputObjectType(GraphQLInputObjectType type, TraverserContext<GraphQLSchemaElement> context) {
SchemaValidationErrorCollector errorCollector = context.getVarFromParents(SchemaValidationErrorCollector.class);

// Implements InputObjectDefaultValueHasCycle(inputObject) from the spec:
// "If defaultValue is not provided, initialize it to an empty unordered map."
inputObjectDefaultValueHasCycle(type, ObjectValue.newObjectValue().build(), errorCollector);

return TraversalControl.CONTINUE;
}

/**
* Implements {@code InputObjectDefaultValueHasCycle(inputObject, defaultValue, visitedFields)}
* from the spec, for literal (AST) default values.
*/
private void inputObjectDefaultValueHasCycle(
GraphQLInputObjectType inputObject,
Value<?> defaultValue,
SchemaValidationErrorCollector errorCollector
) {
// "If defaultValue is a list: for each itemValue in defaultValue..."
if (defaultValue instanceof ArrayValue) {
for (Value<?> itemValue : ((ArrayValue) defaultValue).getValues()) {
inputObjectDefaultValueHasCycle(inputObject, itemValue, errorCollector);
}
return;
}

// "Otherwise, if defaultValue is an unordered map..."
if (!(defaultValue instanceof ObjectValue)) {
return;
}

ObjectValue objectValue = (ObjectValue) defaultValue;
Map<String, Value<?>> defaultValueMap = new LinkedHashMap<>();
for (ObjectField field : objectValue.getObjectFields()) {
defaultValueMap.put(field.getName(), field.getValue());
}

// "For each field in inputObject: if InputFieldDefaultValueHasCycle(...)"
for (GraphQLInputObjectField field : inputObject.getFieldDefinitions()) {
GraphQLType namedFieldType = unwrapAll(field.getType());
if (!(namedFieldType instanceof GraphQLInputObjectType)) {
continue;
}

GraphQLInputObjectType fieldInputObject = (GraphQLInputObjectType) namedFieldType;
String fieldName = field.getName();
if (defaultValueMap.containsKey(fieldName)) {
// "Let fieldDefaultValue be the value for fieldName in defaultValue.
// If fieldDefaultValue exists: InputObjectDefaultValueHasCycle(namedFieldType, fieldDefaultValue, visitedFields)"
inputObjectDefaultValueHasCycle(fieldInputObject, defaultValueMap.get(fieldName), errorCollector);
} else {
// "Otherwise: let fieldDefaultValue be the default value of field..."
inputFieldDefaultValueHasCycle(field, fieldInputObject, inputObject.getName(), errorCollector);
}
}
}

/**
* Implements {@code InputObjectDefaultValueHasCycle(inputObject, defaultValue, visitedFields)}
* from the spec, for external (programmatic Map/List) default values.
*/
private void inputObjectDefaultValueHasCycle(
GraphQLInputObjectType inputObject,
Object defaultValue,
SchemaValidationErrorCollector errorCollector
) {
// "If defaultValue is a list: for each itemValue in defaultValue..."
if (defaultValue instanceof Iterable) {
for (Object itemValue : (Iterable<?>) defaultValue) {
if (itemValue != null) {
inputObjectDefaultValueHasCycle(inputObject, itemValue, errorCollector);
}
}
return;
}

// "Otherwise, if defaultValue is an unordered map..."
if (!(defaultValue instanceof Map)) {
return;
}

@SuppressWarnings("unchecked")
Map<String, Object> defaultValueMap = (Map<String, Object>) defaultValue;

// "For each field in inputObject: if InputFieldDefaultValueHasCycle(...)"
for (GraphQLInputObjectField field : inputObject.getFieldDefinitions()) {
GraphQLType namedFieldType = unwrapAll(field.getType());
if (!(namedFieldType instanceof GraphQLInputObjectType)) {
continue;
}

GraphQLInputObjectType fieldInputObject = (GraphQLInputObjectType) namedFieldType;
String fieldName = field.getName();
if (defaultValueMap.containsKey(fieldName)) {
// "Let fieldDefaultValue be the value for fieldName in defaultValue.
// If fieldDefaultValue exists: InputObjectDefaultValueHasCycle(namedFieldType, fieldDefaultValue, visitedFields)"
Object fieldDefaultValue = defaultValueMap.get(fieldName);
if (fieldDefaultValue != null) {
inputObjectDefaultValueHasCycle(fieldInputObject, fieldDefaultValue, errorCollector);
}
} else {
// "Otherwise: let fieldDefaultValue be the default value of field..."
inputFieldDefaultValueHasCycle(field, fieldInputObject, inputObject.getName(), errorCollector);
}
}
}

/**
* Implements the "Otherwise" branch of {@code InputFieldDefaultValueHasCycle(field, defaultValue, visitedFields)}
* from the spec — called when the field is not present in the parent's default value,
* so the field's own default will be used at runtime.
*/
private void inputFieldDefaultValueHasCycle(
GraphQLInputObjectField field,
GraphQLInputObjectType namedFieldType,
String parentTypeName,
SchemaValidationErrorCollector errorCollector
) {
// "Let fieldDefaultValue be the default value of field.
// If fieldDefaultValue does not exist: return false."
InputValueWithState fieldDefaultValue = field.getInputFieldDefaultValue();
if (fieldDefaultValue.isNotSet()) {
return;
}

String coordinate = parentTypeName + "." + field.getName();

// "If field is within visitedFields: return true."
if (visitedFields.contains(coordinate)) {
// Cycle found — collect intermediate nodes (everything after the coordinate itself)
List<String> intermediaries = new ArrayList<>();
boolean found = false;
for (String entry : visitedFields) {
if (found) {
intermediaries.add(entry);
}
if (entry.equals(coordinate)) {
found = true;
}
}

String message;
if (intermediaries.isEmpty()) {
message = "Invalid circular reference. The default value of Input Object field "
+ coordinate + " references itself.";
} else {
message = "Invalid circular reference. The default value of Input Object field "
+ coordinate + " references itself via the default values of: "
+ String.join(", ", intermediaries) + ".";
}

errorCollector.addError(new SchemaValidationError(
SchemaValidationErrorType.DefaultValueCircularRef, message));
return;
}

if (fullyExplored.contains(coordinate)) {
return;
}
fullyExplored.add(coordinate);

// "Let nextVisitedFields be a new set containing field and everything from visitedFields.
// Return InputObjectDefaultValueHasCycle(namedFieldType, fieldDefaultValue, nextVisitedFields)."
visitedFields.add(coordinate);

if (fieldDefaultValue.isLiteral() && fieldDefaultValue.getValue() instanceof Value) {
inputObjectDefaultValueHasCycle(namedFieldType, (Value<?>) fieldDefaultValue.getValue(), errorCollector);
} else if (fieldDefaultValue.isExternal() && fieldDefaultValue.getValue() != null) {
inputObjectDefaultValueHasCycle(namedFieldType, fieldDefaultValue.getValue(), errorCollector);
}

visitedFields.remove(coordinate);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,6 @@ public enum SchemaValidationErrorType implements SchemaValidationErrorClassifica
OneOfNotInhabited,
RequiredInputFieldCannotBeDeprecated,
RequiredFieldArgumentCannotBeDeprecated,
RequiredDirectiveArgumentCannotBeDeprecated
RequiredDirectiveArgumentCannotBeDeprecated,
DefaultValueCircularRef
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ public class SchemaValidator {

public SchemaValidator() {
rules.add(new NoUnbrokenInputCycles());
rules.add(new NoDefaultValueCircularRefs());
rules.add(new TypesImplementInterfaces());
rules.add(new TypeAndFieldRule());
rules.add(new DefaultValuesAreValid());
Expand Down
90 changes: 90 additions & 0 deletions src/test/groovy/graphql/CircularInputDefaultValuesTest.groovy
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package graphql

import graphql.schema.validation.InvalidSchemaException
import spock.lang.Specification

/**
* Tests for mutually recursive input types with default values.
*
* These schemas are now rejected at build time by NoDefaultValueCircularRefs,
* which detects circular references in input object field default values.
*
* Previously, graphql-java accepted these schemas at build time but hit a
* StackOverflowError at query execution time when the circular defaults were
* expanded in ValuesResolverConversion.defaultValueToInternalValue.
*/
class CircularInputDefaultValuesTest extends Specification {

def "mutually recursive input types with default values - rejected at schema build time"() {
when:
TestUtil.schema('''
type Query {
test(arg: A): String
}
input A { b: B = {} }
input B { a: A = {} }
''')

then:
def e = thrown(InvalidSchemaException)
e.message.contains("Invalid circular reference")
}

def "self-referential input type with default value - rejected at schema build time"() {
when:
TestUtil.schema('''
type Query {
test(arg: A): String
}
input A { a: A = {} }
''')

then:
def e = thrown(InvalidSchemaException)
e.message.contains("Invalid circular reference")
}

def "mutually recursive input types with default values - rejected before query execution"() {
when:
TestUtil.schema('''
type Query {
test(arg: A): String
}
input A { b: B = {} }
input B { a: A = {} }
''')

then:
def e = thrown(InvalidSchemaException)
e.message.contains("Invalid circular reference")
}

def "self-referential input type with default value - rejected before query execution"() {
when:
TestUtil.schema('''
type Query {
test(arg: A): String
}
input A { a: A = {} }
''')

then:
def e = thrown(InvalidSchemaException)
e.message.contains("Invalid circular reference")
}

def "mutually recursive defaults via argument default - rejected at schema build time"() {
when:
TestUtil.schema('''
type Query {
test(arg: A = {}): String
}
input A { b: B = {} }
input B { a: A = {} }
''')

then:
def e = thrown(InvalidSchemaException)
e.message.contains("Invalid circular reference")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1446,20 +1446,20 @@ class SchemaDiffingTest extends Specification {
def schema1 = schema('''
input I {
name: String
field: I = {name: "default name"}
field: I = {name: "default name", field: null}
}
type Query {
foo(arg: I): String
}
}
''')
def schema2 = schema('''
input I {
name: String
field: [I] = [{name: "default name"}]
field: [I] = [{name: "default name", field: null}]
}
type Query {
foo(arg: I): String
}
}
''')

when:
Expand Down
Loading
Loading