diff --git a/src/main/java/graphql/EngineRunningState.java b/src/main/java/graphql/EngineRunningState.java index 965fdb59fd..43b584805f 100644 --- a/src/main/java/graphql/EngineRunningState.java +++ b/src/main/java/graphql/EngineRunningState.java @@ -14,7 +14,9 @@ import static graphql.Assert.assertTrue; import static graphql.execution.EngineRunningObserver.RunningState.NOT_RUNNING; +import static graphql.execution.EngineRunningObserver.RunningState.NOT_RUNNING_FINISH; import static graphql.execution.EngineRunningObserver.RunningState.RUNNING; +import static graphql.execution.EngineRunningObserver.RunningState.RUNNING_START; @Internal public class EngineRunningState { @@ -26,6 +28,9 @@ public class EngineRunningState { @Nullable private volatile ExecutionId executionId; + // if true the last decrementRunning() call will be ignored + private volatile boolean finished; + private final AtomicInteger isRunning = new AtomicInteger(0); @VisibleForTesting @@ -148,7 +153,7 @@ private void decrementRunning() { return; } assertTrue(isRunning.get() > 0); - if (isRunning.decrementAndGet() == 0) { + if (isRunning.decrementAndGet() == 0 && !finished) { changeOfState(NOT_RUNNING); } } @@ -193,16 +198,20 @@ private void run(Runnable runnable) { /** * Only used once outside of this class: when the execution starts */ - public T call(Supplier supplier) { + public CompletableFuture engineRun(Supplier> engineRun) { if (engineRunningObserver == null) { - return supplier.get(); - } - incrementRunning(); - try { - return supplier.get(); - } finally { - decrementRunning(); + return engineRun.get(); } + isRunning.incrementAndGet(); + changeOfState(RUNNING_START); + + CompletableFuture erCF = engineRun.get(); + erCF = erCF.whenComplete((result, throwable) -> { + finished = true; + changeOfState(NOT_RUNNING_FINISH); + }); + decrementRunning(); + return erCF; } diff --git a/src/main/java/graphql/GraphQL.java b/src/main/java/graphql/GraphQL.java index 8c077a2a53..5d8dfc87d5 100644 --- a/src/main/java/graphql/GraphQL.java +++ b/src/main/java/graphql/GraphQL.java @@ -413,7 +413,7 @@ public CompletableFuture executeAsync(UnaryOperator executeAsync(ExecutionInput executionInput) { EngineRunningState engineRunningState = new EngineRunningState(executionInput); - return engineRunningState.call(() -> { + return engineRunningState.engineRun(() -> { ExecutionInput executionInputWithId = ensureInputHasId(executionInput); engineRunningState.updateExecutionId(executionInputWithId.getExecutionId()); diff --git a/src/main/java/graphql/execution/EngineRunningObserver.java b/src/main/java/graphql/execution/EngineRunningObserver.java index 008623eedc..c75f47706f 100644 --- a/src/main/java/graphql/execution/EngineRunningObserver.java +++ b/src/main/java/graphql/execution/EngineRunningObserver.java @@ -14,6 +14,10 @@ public interface EngineRunningObserver { enum RunningState { + /** + * Represents that the engine is running, for the first time + */ + RUNNING_START, /** * Represents that the engine code is actively running its own code */ @@ -22,6 +26,10 @@ enum RunningState { * Represents that the engine code is asynchronously waiting for fetching to happen */ NOT_RUNNING, + /** + * Represents that the engine is finished + */ + NOT_RUNNING_FINISH } diff --git a/src/test/groovy/graphql/EngineRunningTest.groovy b/src/test/groovy/graphql/EngineRunningTest.groovy index 46104653d4..592f7d2a1d 100644 --- a/src/test/groovy/graphql/EngineRunningTest.groovy +++ b/src/test/groovy/graphql/EngineRunningTest.groovy @@ -23,7 +23,9 @@ import static graphql.ExecutionInput.newExecutionInput import static graphql.execution.EngineRunningObserver.ENGINE_RUNNING_OBSERVER_KEY import static graphql.execution.EngineRunningObserver.RunningState import static graphql.execution.EngineRunningObserver.RunningState.NOT_RUNNING +import static graphql.execution.EngineRunningObserver.RunningState.NOT_RUNNING_FINISH import static graphql.execution.EngineRunningObserver.RunningState.RUNNING +import static graphql.execution.EngineRunningObserver.RunningState.RUNNING_START class EngineRunningTest extends Specification { @@ -70,13 +72,13 @@ class EngineRunningTest extends Specification { when: def er = graphQL.executeAsync(ei) then: - states == [RUNNING, NOT_RUNNING] + states == [RUNNING_START, NOT_RUNNING] when: states.clear() cf.complete(new PreparsedDocumentEntry(document)) then: - states == [RUNNING, NOT_RUNNING] + states == [RUNNING, NOT_RUNNING_FINISH] er.get().data == [hello: "world"] @@ -114,13 +116,13 @@ class EngineRunningTest extends Specification { when: def er = graphQL.executeAsync(ei) then: - states == [RUNNING, NOT_RUNNING] + states == [RUNNING_START, NOT_RUNNING] when: states.clear() cf.complete(new InstrumentationState() {}) then: - states == [RUNNING, NOT_RUNNING] + states == [RUNNING, NOT_RUNNING_FINISH] er.get().data == [hello: "world"] @@ -158,14 +160,14 @@ class EngineRunningTest extends Specification { when: def er = graphQL.executeAsync(ei) then: - states == [RUNNING, NOT_RUNNING] + states == [RUNNING_START, NOT_RUNNING] when: states.clear() cf.complete(ExecutionResultImpl.newExecutionResult().data([hello: "world-modified"]).build()) then: er.get().data == [hello: "world-modified"] - states == [RUNNING, NOT_RUNNING] + states == [RUNNING, NOT_RUNNING_FINISH] } @@ -195,7 +197,7 @@ class EngineRunningTest extends Specification { def er = graphQL.execute(ei) then: er.data == [hello: "world"] - states == [RUNNING, NOT_RUNNING] + states == [RUNNING_START, NOT_RUNNING_FINISH] } def "multiple async DF"() { @@ -251,7 +253,7 @@ class EngineRunningTest extends Specification { when: def er = graphQL.executeAsync(ei) then: - states == [RUNNING, NOT_RUNNING] + states == [RUNNING_START, NOT_RUNNING] when: states.clear(); @@ -270,7 +272,7 @@ class EngineRunningTest extends Specification { states.clear() cf2.complete("world2") then: - states == [RUNNING, NOT_RUNNING] + states == [RUNNING, NOT_RUNNING_FINISH] er.get().data == [hello: "world", hello2: "world2", foo: [name: "FooName"], someStaticValue: [staticValue: "staticValue"]] } @@ -299,14 +301,14 @@ class EngineRunningTest extends Specification { when: def er = graphQL.executeAsync(ei) then: - states == [RUNNING, NOT_RUNNING] + states == [RUNNING_START, NOT_RUNNING] when: states.clear(); cf.complete("world") then: - states == [RUNNING, NOT_RUNNING] + states == [RUNNING, NOT_RUNNING_FINISH] er.get().data == [hello: "world"] } @@ -334,7 +336,7 @@ class EngineRunningTest extends Specification { when: def er = graphQL.executeAsync(ei) then: - states == [RUNNING, NOT_RUNNING] + states == [RUNNING_START, NOT_RUNNING] when: states.clear(); @@ -342,7 +344,7 @@ class EngineRunningTest extends Specification { then: er.get().data == [hello: "world"] - states == [RUNNING, NOT_RUNNING] + states == [RUNNING, NOT_RUNNING_FINISH] } @@ -387,8 +389,7 @@ class EngineRunningTest extends Specification { then: result.errors.collect { it.message } == ["recovered"] - // we expect simply going from running to finshed - states == [RUNNING, NOT_RUNNING] + states == [RUNNING, NOT_RUNNING_FINISH] } def "async datafetcher failing with async exception handler"() { @@ -429,7 +430,7 @@ class EngineRunningTest extends Specification { def er = graphQL.executeAsync(ei) then: - states == [RUNNING, NOT_RUNNING] + states == [RUNNING_START, NOT_RUNNING] when: states.clear() @@ -445,8 +446,7 @@ class EngineRunningTest extends Specification { then: result.errors.collect { it.message } == ["recovered"] - // we expect simply going from running to finshed - new ArrayList<>(states) == [RUNNING, NOT_RUNNING] + states == [RUNNING, NOT_RUNNING_FINISH] } @@ -480,7 +480,7 @@ class EngineRunningTest extends Specification { when: def er = graphQL.executeAsync(ei) then: - states == [RUNNING, NOT_RUNNING] + states == [RUNNING_START, NOT_RUNNING] when: states.clear(); @@ -494,7 +494,7 @@ class EngineRunningTest extends Specification { cf2.complete("world2") then: - states == [RUNNING, NOT_RUNNING] + states == [RUNNING, NOT_RUNNING_FINISH] er.get().data == [hello: "world", hello2: "world2"] } }