diff --git a/src/main/java/graphql/schema/validation/NoDefaultValueCircularRefs.java b/src/main/java/graphql/schema/validation/NoDefaultValueCircularRefs.java new file mode 100644 index 0000000000..70822426da --- /dev/null +++ b/src/main/java/graphql/schema/validation/NoDefaultValueCircularRefs.java @@ -0,0 +1,229 @@ +package graphql.schema.validation; + +import graphql.Internal; +import graphql.language.ArrayValue; +import graphql.language.ObjectField; +import graphql.language.ObjectValue; +import graphql.language.Value; +import graphql.schema.GraphQLInputObjectField; +import graphql.schema.GraphQLInputObjectType; +import graphql.schema.GraphQLSchemaElement; +import graphql.schema.GraphQLType; +import graphql.schema.GraphQLTypeVisitorStub; +import graphql.schema.InputValueWithState; +import graphql.util.TraversalControl; +import graphql.util.TraverserContext; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static graphql.schema.GraphQLTypeUtil.unwrapAll; + +/** + * Validates that {@code InputObjectDefaultValueHasCycle(inputObject)} is {@code false} + * for every input object type, as required by the Input Object type validation rules + * in the GraphQL specification. + *
+ * For example, consider this type configuration: + * + * input A { b:B = {} } + * input B { a:A = {} } + * + *
+ * The default values used in these types form a cycle that can create an infinitely large + * value. This validator rejects default values that can create these kinds of cycles. + * + * @see Input Objects Type Validation + */ +@Internal +public class NoDefaultValueCircularRefs extends GraphQLTypeVisitorStub { + + // Coordinates already fully traversed without finding a cycle, used to avoid duplicate error reports + // when the same coordinate is reachable from multiple input object types. + private final Set fullyExplored = new LinkedHashSet<>(); + + // The spec's "visitedFields" set, tracked as coordinate strings ("Type.field"). + // The spec creates a new immutable set at each step; this implementation mutates and backtracks + // for the same effect. + private final LinkedHashSet visitedFields = new LinkedHashSet<>(); + + @Override + public TraversalControl visitGraphQLInputObjectType(GraphQLInputObjectType type, TraverserContext context) { + SchemaValidationErrorCollector errorCollector = context.getVarFromParents(SchemaValidationErrorCollector.class); + + // Implements InputObjectDefaultValueHasCycle(inputObject) from the spec: + // "If defaultValue is not provided, initialize it to an empty unordered map." + inputObjectDefaultValueHasCycle(type, ObjectValue.newObjectValue().build(), errorCollector); + + return TraversalControl.CONTINUE; + } + + /** + * Implements {@code InputObjectDefaultValueHasCycle(inputObject, defaultValue, visitedFields)} + * from the spec, for literal (AST) default values. + */ + private void inputObjectDefaultValueHasCycle( + GraphQLInputObjectType inputObject, + Value defaultValue, + SchemaValidationErrorCollector errorCollector + ) { + // "If defaultValue is a list: for each itemValue in defaultValue..." + if (defaultValue instanceof ArrayValue) { + for (Value itemValue : ((ArrayValue) defaultValue).getValues()) { + inputObjectDefaultValueHasCycle(inputObject, itemValue, errorCollector); + } + return; + } + + // "Otherwise, if defaultValue is an unordered map..." + if (!(defaultValue instanceof ObjectValue)) { + return; + } + + ObjectValue objectValue = (ObjectValue) defaultValue; + Map> defaultValueMap = new LinkedHashMap<>(); + for (ObjectField field : objectValue.getObjectFields()) { + defaultValueMap.put(field.getName(), field.getValue()); + } + + // "For each field in inputObject: if InputFieldDefaultValueHasCycle(...)" + for (GraphQLInputObjectField field : inputObject.getFieldDefinitions()) { + GraphQLType namedFieldType = unwrapAll(field.getType()); + if (!(namedFieldType instanceof GraphQLInputObjectType)) { + continue; + } + + GraphQLInputObjectType fieldInputObject = (GraphQLInputObjectType) namedFieldType; + String fieldName = field.getName(); + if (defaultValueMap.containsKey(fieldName)) { + // "Let fieldDefaultValue be the value for fieldName in defaultValue. + // If fieldDefaultValue exists: InputObjectDefaultValueHasCycle(namedFieldType, fieldDefaultValue, visitedFields)" + inputObjectDefaultValueHasCycle(fieldInputObject, defaultValueMap.get(fieldName), errorCollector); + } else { + // "Otherwise: let fieldDefaultValue be the default value of field..." + inputFieldDefaultValueHasCycle(field, fieldInputObject, inputObject.getName(), errorCollector); + } + } + } + + /** + * Implements {@code InputObjectDefaultValueHasCycle(inputObject, defaultValue, visitedFields)} + * from the spec, for external (programmatic Map/List) default values. + */ + private void inputObjectDefaultValueHasCycle( + GraphQLInputObjectType inputObject, + Object defaultValue, + SchemaValidationErrorCollector errorCollector + ) { + // "If defaultValue is a list: for each itemValue in defaultValue..." + if (defaultValue instanceof Iterable) { + for (Object itemValue : (Iterable) defaultValue) { + if (itemValue != null) { + inputObjectDefaultValueHasCycle(inputObject, itemValue, errorCollector); + } + } + return; + } + + // "Otherwise, if defaultValue is an unordered map..." + if (!(defaultValue instanceof Map)) { + return; + } + + @SuppressWarnings("unchecked") + Map defaultValueMap = (Map) defaultValue; + + // "For each field in inputObject: if InputFieldDefaultValueHasCycle(...)" + for (GraphQLInputObjectField field : inputObject.getFieldDefinitions()) { + GraphQLType namedFieldType = unwrapAll(field.getType()); + if (!(namedFieldType instanceof GraphQLInputObjectType)) { + continue; + } + + GraphQLInputObjectType fieldInputObject = (GraphQLInputObjectType) namedFieldType; + String fieldName = field.getName(); + if (defaultValueMap.containsKey(fieldName)) { + // "Let fieldDefaultValue be the value for fieldName in defaultValue. + // If fieldDefaultValue exists: InputObjectDefaultValueHasCycle(namedFieldType, fieldDefaultValue, visitedFields)" + Object fieldDefaultValue = defaultValueMap.get(fieldName); + if (fieldDefaultValue != null) { + inputObjectDefaultValueHasCycle(fieldInputObject, fieldDefaultValue, errorCollector); + } + } else { + // "Otherwise: let fieldDefaultValue be the default value of field..." + inputFieldDefaultValueHasCycle(field, fieldInputObject, inputObject.getName(), errorCollector); + } + } + } + + /** + * Implements the "Otherwise" branch of {@code InputFieldDefaultValueHasCycle(field, defaultValue, visitedFields)} + * from the spec — called when the field is not present in the parent's default value, + * so the field's own default will be used at runtime. + */ + private void inputFieldDefaultValueHasCycle( + GraphQLInputObjectField field, + GraphQLInputObjectType namedFieldType, + String parentTypeName, + SchemaValidationErrorCollector errorCollector + ) { + // "Let fieldDefaultValue be the default value of field. + // If fieldDefaultValue does not exist: return false." + InputValueWithState fieldDefaultValue = field.getInputFieldDefaultValue(); + if (fieldDefaultValue.isNotSet()) { + return; + } + + String coordinate = parentTypeName + "." + field.getName(); + + // "If field is within visitedFields: return true." + if (visitedFields.contains(coordinate)) { + // Cycle found — collect intermediate nodes (everything after the coordinate itself) + List intermediaries = new ArrayList<>(); + boolean found = false; + for (String entry : visitedFields) { + if (found) { + intermediaries.add(entry); + } + if (entry.equals(coordinate)) { + found = true; + } + } + + String message; + if (intermediaries.isEmpty()) { + message = "Invalid circular reference. The default value of Input Object field " + + coordinate + " references itself."; + } else { + message = "Invalid circular reference. The default value of Input Object field " + + coordinate + " references itself via the default values of: " + + String.join(", ", intermediaries) + "."; + } + + errorCollector.addError(new SchemaValidationError( + SchemaValidationErrorType.DefaultValueCircularRef, message)); + return; + } + + if (fullyExplored.contains(coordinate)) { + return; + } + fullyExplored.add(coordinate); + + // "Let nextVisitedFields be a new set containing field and everything from visitedFields. + // Return InputObjectDefaultValueHasCycle(namedFieldType, fieldDefaultValue, nextVisitedFields)." + visitedFields.add(coordinate); + + if (fieldDefaultValue.isLiteral() && fieldDefaultValue.getValue() instanceof Value) { + inputObjectDefaultValueHasCycle(namedFieldType, (Value) fieldDefaultValue.getValue(), errorCollector); + } else if (fieldDefaultValue.isExternal() && fieldDefaultValue.getValue() != null) { + inputObjectDefaultValueHasCycle(namedFieldType, fieldDefaultValue.getValue(), errorCollector); + } + + visitedFields.remove(coordinate); + } +} diff --git a/src/main/java/graphql/schema/validation/SchemaValidationErrorType.java b/src/main/java/graphql/schema/validation/SchemaValidationErrorType.java index b1caecc8e4..b7ce1cc631 100644 --- a/src/main/java/graphql/schema/validation/SchemaValidationErrorType.java +++ b/src/main/java/graphql/schema/validation/SchemaValidationErrorType.java @@ -25,5 +25,6 @@ public enum SchemaValidationErrorType implements SchemaValidationErrorClassifica OneOfNotInhabited, RequiredInputFieldCannotBeDeprecated, RequiredFieldArgumentCannotBeDeprecated, - RequiredDirectiveArgumentCannotBeDeprecated + RequiredDirectiveArgumentCannotBeDeprecated, + DefaultValueCircularRef } diff --git a/src/main/java/graphql/schema/validation/SchemaValidator.java b/src/main/java/graphql/schema/validation/SchemaValidator.java index 1f676fb77e..fac409378a 100644 --- a/src/main/java/graphql/schema/validation/SchemaValidator.java +++ b/src/main/java/graphql/schema/validation/SchemaValidator.java @@ -19,6 +19,7 @@ public class SchemaValidator { public SchemaValidator() { rules.add(new NoUnbrokenInputCycles()); + rules.add(new NoDefaultValueCircularRefs()); rules.add(new TypesImplementInterfaces()); rules.add(new TypeAndFieldRule()); rules.add(new DefaultValuesAreValid()); diff --git a/src/test/groovy/graphql/CircularInputDefaultValuesTest.groovy b/src/test/groovy/graphql/CircularInputDefaultValuesTest.groovy new file mode 100644 index 0000000000..43a92975e3 --- /dev/null +++ b/src/test/groovy/graphql/CircularInputDefaultValuesTest.groovy @@ -0,0 +1,90 @@ +package graphql + +import graphql.schema.validation.InvalidSchemaException +import spock.lang.Specification + +/** + * Tests for mutually recursive input types with default values. + * + * These schemas are now rejected at build time by NoDefaultValueCircularRefs, + * which detects circular references in input object field default values. + * + * Previously, graphql-java accepted these schemas at build time but hit a + * StackOverflowError at query execution time when the circular defaults were + * expanded in ValuesResolverConversion.defaultValueToInternalValue. + */ +class CircularInputDefaultValuesTest extends Specification { + + def "mutually recursive input types with default values - rejected at schema build time"() { + when: + TestUtil.schema(''' + type Query { + test(arg: A): String + } + input A { b: B = {} } + input B { a: A = {} } + ''') + + then: + def e = thrown(InvalidSchemaException) + e.message.contains("Invalid circular reference") + } + + def "self-referential input type with default value - rejected at schema build time"() { + when: + TestUtil.schema(''' + type Query { + test(arg: A): String + } + input A { a: A = {} } + ''') + + then: + def e = thrown(InvalidSchemaException) + e.message.contains("Invalid circular reference") + } + + def "mutually recursive input types with default values - rejected before query execution"() { + when: + TestUtil.schema(''' + type Query { + test(arg: A): String + } + input A { b: B = {} } + input B { a: A = {} } + ''') + + then: + def e = thrown(InvalidSchemaException) + e.message.contains("Invalid circular reference") + } + + def "self-referential input type with default value - rejected before query execution"() { + when: + TestUtil.schema(''' + type Query { + test(arg: A): String + } + input A { a: A = {} } + ''') + + then: + def e = thrown(InvalidSchemaException) + e.message.contains("Invalid circular reference") + } + + def "mutually recursive defaults via argument default - rejected at schema build time"() { + when: + TestUtil.schema(''' + type Query { + test(arg: A = {}): String + } + input A { b: B = {} } + input B { a: A = {} } + ''') + + then: + def e = thrown(InvalidSchemaException) + e.message.contains("Invalid circular reference") + } +} diff --git a/src/test/groovy/graphql/schema/diffing/SchemaDiffingTest.groovy b/src/test/groovy/graphql/schema/diffing/SchemaDiffingTest.groovy index 46481fa7a5..bd263631ab 100644 --- a/src/test/groovy/graphql/schema/diffing/SchemaDiffingTest.groovy +++ b/src/test/groovy/graphql/schema/diffing/SchemaDiffingTest.groovy @@ -1446,20 +1446,20 @@ class SchemaDiffingTest extends Specification { def schema1 = schema(''' input I { name: String - field: I = {name: "default name"} + field: I = {name: "default name", field: null} } type Query { foo(arg: I): String - } + } ''') def schema2 = schema(''' input I { name: String - field: [I] = [{name: "default name"}] + field: [I] = [{name: "default name", field: null}] } type Query { foo(arg: I): String - } + } ''') when: diff --git a/src/test/groovy/graphql/schema/validation/NoDefaultValueCircularRefsTest.groovy b/src/test/groovy/graphql/schema/validation/NoDefaultValueCircularRefsTest.groovy new file mode 100644 index 0000000000..bd35006c25 --- /dev/null +++ b/src/test/groovy/graphql/schema/validation/NoDefaultValueCircularRefsTest.groovy @@ -0,0 +1,202 @@ +package graphql.schema.validation + +import graphql.TestUtil +import spock.lang.Specification + +class NoDefaultValueCircularRefsTest extends Specification { + def "self-referential default value is rejected"() { + when: + TestUtil.schema(''' + type Query { test(arg: A): String } + input A { x: A = {} } + ''') + + then: + def e = thrown(InvalidSchemaException) + e.message.contains("Invalid circular reference. The default value of Input Object field A.x references itself.") + } + + def "mutual recursion through defaults is rejected"() { + when: + TestUtil.schema(''' + type Query { test(arg: A): String } + input A { b: B = {} } + input B { a: A = {} } + ''') + + then: + def e = thrown(InvalidSchemaException) + e.message.contains("Invalid circular reference") + e.message.contains("A.b") + } + + def "transitive cycle through three types is rejected"() { + when: + TestUtil.schema(''' + type Query { test(arg: B): String } + input B { x: B2 = {} } + input B2 { x: B3 = {} } + input B3 { x: B = {} } + ''') + + then: + def e = thrown(InvalidSchemaException) + e.message.contains("Invalid circular reference. The default value of Input Object field B.x references itself via the default values of: B2.x, B3.x.") + } + + def "self-reference through list wrapping"() { + when: + TestUtil.schema(''' + type Query { test(arg: C): String } + input C { x: [C] = [{}] } + ''') + + then: + def e = thrown(InvalidSchemaException) + e.message.contains("Invalid circular reference. The default value of Input Object field C.x references itself.") + } + + def "nested default value that eventually cycles"() { + when: + TestUtil.schema(''' + type Query { test(arg: D): String } + input D { x: D = { x: { x: {} } } } + ''') + + then: + def e = thrown(InvalidSchemaException) + e.message.contains("Invalid circular reference. The default value of Input Object field D.x references itself.") + } + + def "cross-field cycle through defaults"() { + when: + TestUtil.schema(''' + type Query { test(arg: E): String } + input E { + x: E = { x: null } + y: E = { y: null } + } + ''') + + then: + def e = thrown(InvalidSchemaException) + e.message.contains("Invalid circular reference. The default value of Input Object field E.x references itself via the default values of: E.y.") + } + + def "cycle through non-null wrapping"() { + when: + TestUtil.schema(''' + type Query { test(arg: F): String } + input F { x: F2! = {} } + input F2 { x: F = { x: {} } } + ''') + + then: + def e = thrown(InvalidSchemaException) + e.message.contains("Invalid circular reference. The default value of Input Object field F2.x references itself.") + } + + def "partial default with non-provided recursive field"() { + when: + TestUtil.schema(''' + type Query { test(arg: A): String } + input A { x: B = {name: "hi"} } + input B { + name: String + a: A = {} + } + ''') + + then: + def e = thrown(InvalidSchemaException) + e.message.contains("Invalid circular reference. The default value of Input Object field A.x references itself via the default values of: B.a.") + } + + def "multiple independent cycles are reported"() { + when: + TestUtil.schema(''' + type Query { test(a: A, b: P): String } + input A { x: A = {} } + input P { x: P = {} } + ''') + + then: + def e = thrown(InvalidSchemaException) + e.message.contains("A.x references itself") + e.message.contains("P.x references itself") + } + + def "explicit field in default breaks cycle"() { + when: + def schema = TestUtil.schema(''' + type Query { test(arg: A): String } + input A { b: B = {a: null} } + input B { a: A = {} } + ''') + + then: + noExceptionThrown() + schema.getType("A") != null + } + + def "recursive field without default does not cycle"() { + when: + def schema = TestUtil.schema(''' + type Query { test(arg: A): String } + input A { b: B = {} } + input B { a: A } + ''') + + then: + noExceptionThrown() + schema.getType("A") != null + } + + def "scalar default value does not cycle"() { + when: + def schema = TestUtil.schema(''' + type Query { test(arg: A): String } + input A { name: String = "hi" } + ''') + + then: + noExceptionThrown() + schema.getType("A") != null + } + + def "null literal default does not cycle"() { + when: + def schema = TestUtil.schema(''' + type Query { test(arg: A): String } + input A { x: A = null } + ''') + + then: + noExceptionThrown() + schema.getType("A") != null + } + + def "empty list default does not cycle"() { + when: + def schema = TestUtil.schema(''' + type Query { test(arg: A): String } + input A { x: [A] = [] } + ''') + + then: + noExceptionThrown() + schema.getType("A") != null + } + + def "explicit null on recursive field breaks self-reference"() { + when: + def schema = TestUtil.schema(''' + type Query { test(arg: A): String } + input A { x: A = {x: null} } + ''') + + then: + noExceptionThrown() + schema.getType("A") != null + } +} diff --git a/src/test/groovy/graphql/schema/validation/SchemaValidatorTest.groovy b/src/test/groovy/graphql/schema/validation/SchemaValidatorTest.groovy index 706542df8d..a854129d41 100644 --- a/src/test/groovy/graphql/schema/validation/SchemaValidatorTest.groovy +++ b/src/test/groovy/graphql/schema/validation/SchemaValidatorTest.groovy @@ -11,15 +11,16 @@ class SchemaValidatorTest extends Specification { def validator = new SchemaValidator() def rules = validator.rules then: - rules.size() == 9 + rules.size() == 10 rules[0] instanceof NoUnbrokenInputCycles - rules[1] instanceof TypesImplementInterfaces - rules[2] instanceof TypeAndFieldRule - rules[3] instanceof DefaultValuesAreValid - rules[4] instanceof AppliedDirectivesAreValid - rules[5] instanceof AppliedDirectiveArgumentsAreValid - rules[6] instanceof InputAndOutputTypesUsedAppropriately - rules[7] instanceof OneOfInputObjectRules - rules[8] instanceof DeprecatedInputObjectAndArgumentsAreValid + rules[1] instanceof NoDefaultValueCircularRefs + rules[2] instanceof TypesImplementInterfaces + rules[3] instanceof TypeAndFieldRule + rules[4] instanceof DefaultValuesAreValid + rules[5] instanceof AppliedDirectivesAreValid + rules[6] instanceof AppliedDirectiveArgumentsAreValid + rules[7] instanceof InputAndOutputTypesUsedAppropriately + rules[8] instanceof OneOfInputObjectRules + rules[9] instanceof DeprecatedInputObjectAndArgumentsAreValid } }