diff --git a/sql-to-dbsp-compiler/SQL-compiler/src/main/java/org/dbsp/sqlCompiler/compiler/visitors/inner/ExpressionTranslator.java b/sql-to-dbsp-compiler/SQL-compiler/src/main/java/org/dbsp/sqlCompiler/compiler/visitors/inner/ExpressionTranslator.java index 18a6dca13fb..cae8f64c233 100644 --- a/sql-to-dbsp-compiler/SQL-compiler/src/main/java/org/dbsp/sqlCompiler/compiler/visitors/inner/ExpressionTranslator.java +++ b/sql-to-dbsp-compiler/SQL-compiler/src/main/java/org/dbsp/sqlCompiler/compiler/visitors/inner/ExpressionTranslator.java @@ -80,8 +80,14 @@ protected void clear() { this.translationMap.clear(); } + boolean done(IDBSPInnerNode node) { + return this.maybeGet(node) != null; + } + @Override public void postorder(DBSPApplyExpression node) { + if (this.done(node)) + return; DBSPExpression function = this.getE(node.function); DBSPExpression[] args = this.get(node.arguments); DBSPExpression result = new DBSPApplyExpression(function, node.getType(), args); @@ -90,6 +96,8 @@ public void postorder(DBSPApplyExpression node) { @Override public void postorder(DBSPApplyMethodExpression node) { + if (this.done(node)) + return; DBSPExpression function = this.getE(node.function); DBSPExpression[] args = this.get(node.arguments); DBSPExpression self = this.getE(node.self); @@ -99,6 +107,8 @@ public void postorder(DBSPApplyMethodExpression node) { @Override public void postorder(DBSPArrayExpression node) { + if (this.done(node)) + return; if (node.data == null) { this.map(node, node); } else { @@ -109,6 +119,8 @@ public void postorder(DBSPArrayExpression node) { @Override public void postorder(DBSPAssignmentExpression node) { + if (this.done(node)) + return; DBSPExpression left = this.getE(node.left); DBSPExpression right = this.getE(node.right); this.map(node, new DBSPAssignmentExpression(left, right)); @@ -116,6 +128,8 @@ public void postorder(DBSPAssignmentExpression node) { @Override public void postorder(DBSPBinaryExpression node) { + if (this.done(node)) + return; DBSPExpression left = this.getE(node.left); DBSPExpression right = this.getE(node.right); this.map(node, new DBSPBinaryExpression(node.getNode(), @@ -127,6 +141,8 @@ public void postorder(DBSPBinaryExpression node) { @Override public void postorder(DBSPTimeAddSub node) { + if (this.done(node)) + return; DBSPExpression left = this.getE(node.left); DBSPExpression right = this.getE(node.right); this.map(node, new DBSPTimeAddSub(node.getNode(), @@ -138,6 +154,8 @@ public void postorder(DBSPTimeAddSub node) { @Override public void postorder(DBSPBlockExpression node) { + if (this.done(node)) + return; List statements = Linq.map(node.contents, c -> this.get(c).to(DBSPStatement.class)); DBSPExpression lastExpression = this.getEN(node.lastExpression); @@ -146,36 +164,48 @@ public void postorder(DBSPBlockExpression node) { @Override public VisitDecision preorder(DBSPType type) { + if (this.done(type)) + return VisitDecision.STOP; this.set(type, type); return VisitDecision.STOP; } @Override public void postorder(DBSPBorrowExpression node) { + if (this.done(node)) + return; DBSPExpression expression = this.getE(node.expression); this.map(node, expression.borrow(node.mut)); } @Override public void postorder(DBSPCastExpression node) { + if (this.done(node)) + return; DBSPExpression expression = this.getE(node.source); this.map(node, new DBSPCastExpression(node.getNode(), expression, node.getType(), node.safe)); } @Override public void postorder(DBSPCloneExpression node) { + if (this.done(node)) + return; DBSPExpression expression = this.getE(node.expression); this.map(node, expression.applyClone()); } @Override public void postorder(DBSPClosureExpression node) { + if (this.done(node)) + return; DBSPExpression body = this.getE(node.body); this.map(node, new DBSPClosureExpression(node.getNode(), body, node.parameters)); } @Override public void postorder(DBSPConditionalIncrementExpression node) { + if (this.done(node)) + return; DBSPExpression left = this.getE(node.left); DBSPExpression right = this.getE(node.right); DBSPExpression condition = this.getEN(node.condition); @@ -185,6 +215,8 @@ public void postorder(DBSPConditionalIncrementExpression node) { @Override public void postorder(DBSPFold fold) { + if (this.done(fold)) + return; DBSPExpression zero = this.getE(fold.zero); DBSPClosureExpression increment = this.getE(fold.increment).to(DBSPClosureExpression.class); DBSPClosureExpression postProcessing = this.getE(fold.postProcess).to(DBSPClosureExpression.class); @@ -194,6 +226,8 @@ public void postorder(DBSPFold fold) { @Override public void postorder(DBSPMinMax aggregator) { + if (this.done(aggregator)) + return; DBSPExpression post = this.getEN(aggregator.postProcessing); @Nullable DBSPClosureExpression postClosure = post != null ? post.to(DBSPClosureExpression.class) : null; DBSPMinMax result = new DBSPMinMax(aggregator.getNode(), aggregator.getType(), postClosure, aggregator.aggregation); @@ -202,6 +236,8 @@ public void postorder(DBSPMinMax aggregator) { @Override public void postorder(DBSPConstructorExpression node) { + if (this.done(node)) + return; DBSPExpression function = this.getE(node.function); DBSPExpression[] arguments = this.get(node.arguments); this.map(node, new DBSPConstructorExpression(function, node.getType(), arguments)); @@ -209,6 +245,8 @@ public void postorder(DBSPConstructorExpression node) { @Override public void postorder(DBSPCustomOrdExpression node) { + if (this.done(node)) + return; DBSPExpression source = this.getE(node.source); DBSPExpression comparator = this.getE(node.comparator); this.map(node, new DBSPCustomOrdExpression(node.getNode(), source, comparator.to(DBSPComparatorExpression.class))); @@ -216,18 +254,24 @@ public void postorder(DBSPCustomOrdExpression node) { @Override public void postorder(DBSPCustomOrdField node) { + if (this.done(node)) + return; DBSPExpression expression = this.getE(node.expression); this.map(node, new DBSPCustomOrdField(expression, node.fieldNo)); } @Override public void postorder(DBSPDerefExpression node) { + if (this.done(node)) + return; DBSPExpression expression = this.getE(node.expression); this.map(node, new DBSPDerefExpression(expression)); } @Override public void postorder(DBSPDirectComparatorExpression node) { + if (this.done(node)) + return; DBSPExpression source = this.getE(node.source); this.map(node, new DBSPDirectComparatorExpression( node.getNode(), source.to(DBSPComparatorExpression.class), node.ascending)); @@ -235,22 +279,30 @@ public void postorder(DBSPDirectComparatorExpression node) { @Override public void postorder(DBSPExpression node) { + if (this.done(node)) + return; this.map(node, node); } @Override public void postorder(DBSPExpressionStatement node) { + if (this.done(node)) + return; DBSPExpression expression = this.getE(node.expression); this.map(node, new DBSPExpressionStatement(expression)); } @Override public void postorder(DBSPComment node) { + if (this.done(node)) + return; this.map(node, node); } @Override public void postorder(DBSPFieldComparatorExpression node) { + if (this.done(node)) + return; DBSPExpression source = this.getE(node.source); this.map(node, new DBSPFieldComparatorExpression( node.getNode(), source.to(DBSPComparatorExpression.class), @@ -259,12 +311,16 @@ public void postorder(DBSPFieldComparatorExpression node) { @Override public void postorder(DBSPFieldExpression node) { + if (this.done(node)) + return; DBSPExpression expression = this.getE(node.expression); this.map(node, new DBSPFieldExpression(node.getNode(), expression, node.fieldNo)); } @Override public void postorder(DBSPFlatmap node) { + if (this.done(node)) + return; DBSPExpression collectionExpression = this.getE(node.collectionExpression); List rightProjections = null; if (node.rightProjections != null) @@ -286,6 +342,8 @@ public void postorder(DBSPForExpression node) { @Override public void postorder(DBSPGeoPointConstructor node) { + if (this.done(node)) + return; DBSPExpression left = this.getEN(node.left); DBSPExpression right = this.getEN(node.right); this.map(node, new DBSPGeoPointConstructor(node.getNode(), left, right, node.type)); @@ -293,6 +351,8 @@ public void postorder(DBSPGeoPointConstructor node) { @Override public void postorder(DBSPHandleErrorExpression node) { + if (this.done(node)) + return; DBSPExpression source = this.getE(node.source); this.map(node, new DBSPHandleErrorExpression(node.getNode(), node.index, node.runtimeBehavior, source, node.hasSourcePosition)); @@ -300,11 +360,15 @@ public void postorder(DBSPHandleErrorExpression node) { @Override public void postorder(DBSPStructItem node) { + if (this.done(node)) + return; this.map(node, node); } @Override public void postorder(DBSPIfExpression node) { + if (this.done(node)) + return; DBSPExpression condition = this.getE(node.condition); DBSPExpression positive = this.getE(node.positive); DBSPExpression negative = this.getEN(node.negative); @@ -313,16 +377,22 @@ public void postorder(DBSPIfExpression node) { @Override public void postorder(DBSPIndexedZSetExpression node) { + if (this.done(node)) + return; this.map(node, node); } public void postorder(DBSPIsNullExpression node) { + if (this.done(node)) + return; DBSPExpression expression = this.getE(node.expression); this.map(node, new DBSPIsNullExpression(node.getNode(), expression)); } @Override public VisitDecision preorder(DBSPLetExpression node) { + if (this.done(node)) + return VisitDecision.STOP; // This one is done in preorder node.initializer.accept(this); DBSPExpression initializer = this.getE(node.initializer); @@ -336,6 +406,8 @@ public VisitDecision preorder(DBSPLetExpression node) { @Override public void postorder(DBSPLetStatement node) { + if (this.done(node)) + return; DBSPExpression initializer = this.getEN(node.initializer); DBSPStatement result; if (initializer != null) @@ -347,11 +419,15 @@ public void postorder(DBSPLetStatement node) { @Override public void postorder(DBSPLiteral node) { + if (this.done(node)) + return; this.map(node, node); } @Override public void postorder(DBSPMapExpression node) { + if (this.done(node)) + return; List keys = null; if (node.keys != null) keys = Linq.map(node.keys, this::getE); @@ -363,28 +439,38 @@ public void postorder(DBSPMapExpression node) { @Override public void postorder(DBSPNoComparatorExpression node) { + if (this.done(node)) + return; this.map(node, node); } @Override public void postorder(DBSPPathExpression node) { + if (this.done(node)) + return; this.map(node, node); } @Override public void postorder(DBSPQualifyTypeExpression node) { + if (this.done(node)) + return; DBSPExpression expression = this.getE(node.expression); this.map(node, new DBSPQualifyTypeExpression(expression, node.types)); } @Override public void postorder(DBSPQuestionExpression node) { + if (this.done(node)) + return; DBSPExpression source = this.getE(node.source); this.map(node, source.question()); } @Override public void postorder(DBSPRawTupleExpression node) { + if (this.done(node)) + return; if (node.fields != null) { DBSPExpression[] fields = this.get(node.fields); DBSPExpression result = new DBSPRawTupleExpression( @@ -397,18 +483,24 @@ public void postorder(DBSPRawTupleExpression node) { @Override public void postorder(DBSPReturnExpression node) { + if (this.done(node)) + return; DBSPExpression argument = this.getE(node.argument); this.map(node, new DBSPReturnExpression(node.getNode(), argument)); } @Override public void postorder(DBSPSomeExpression node) { + if (this.done(node)) + return; DBSPExpression expression = this.getE(node.expression); this.map(node, new DBSPSomeExpression(node.getNode(), expression)); } @Override public void postorder(DBSPSortExpression node) { + if (this.done(node)) + return; DBSPExpression comparator = this.getE(node.comparator); this.map(node, new DBSPSortExpression(node.getNode(), node.elementType, comparator.to(DBSPComparatorExpression.class))); @@ -416,12 +508,16 @@ public void postorder(DBSPSortExpression node) { @Override public void postorder(DBSPStaticExpression node) { + if (this.done(node)) + return; DBSPExpression initializer = this.getE(node.initializer); this.map(node, new DBSPStaticExpression(node.getNode(), initializer, node.getName())); } @Override public void postorder(DBSPTupleExpression node) { + if (this.done(node)) + return; if (node.fields != null) { DBSPExpression[] fields = this.get(node.fields); DBSPExpression result = new DBSPTupleExpression(node.getNode(), node.getType().to(DBSPTypeTuple.class), fields); @@ -433,12 +529,16 @@ public void postorder(DBSPTupleExpression node) { @Override public void postorder(DBSPUnaryExpression node) { + if (this.done(node)) + return; DBSPExpression source = this.getE(node.source); this.map(node, new DBSPUnaryExpression(node.getNode(), node.type, node.opcode, source)); } @Override public void postorder(DBSPUnsignedUnwrapExpression node) { + if (this.done(node)) + return; DBSPExpression source = this.getE(node.source); this.map(node, new DBSPUnsignedUnwrapExpression( node.getNode(), source, node.type, node.ascending, node.nullsLast)); @@ -446,6 +546,8 @@ public void postorder(DBSPUnsignedUnwrapExpression node) { @Override public void postorder(DBSPUnsignedWrapExpression node) { + if (this.done(node)) + return; DBSPExpression source = this.getE(node.source); this.map(node, new DBSPUnsignedWrapExpression( node.getNode(), source, node.ascending, node.nullsLast)); @@ -453,46 +555,62 @@ public void postorder(DBSPUnsignedWrapExpression node) { @Override public void postorder(DBSPUnwrapCustomOrdExpression node) { + if (this.done(node)) + return; DBSPExpression expression = this.getE(node.expression); this.map(node, new DBSPUnwrapCustomOrdExpression(expression)); } @Override public void postorder(DBSPUnwrapExpression node) { + if (this.done(node)) + return; DBSPExpression expression = this.getE(node.expression); this.map(node, new DBSPUnwrapExpression(node.message, expression)); } @Override public void postorder(DBSPFailExpression node) { + if (this.done(node)) + return; this.map(node, node); } @Override public void postorder(DBSPVariablePath node) { + if (this.done(node)) + return; this.map(node, node); } @Override public void postorder(DBSPVariantExpression node) { + if (this.done(node)) + return; DBSPExpression value = this.getEN(node.value); this.map(node, new DBSPVariantExpression(value, node.getType().mayBeNull)); } @Override public void postorder(DBSPLazyExpression node) { + if (this.done(node)) + return; DBSPExpression expression = this.getE(node.expression); this.map(node, new DBSPLazyExpression(expression)); } @Override public void postorder(DBSPWindowBoundExpression node) { + if (this.done(node)) + return; DBSPExpression representation = this.getE(node.representation); this.map(node, new DBSPWindowBoundExpression(node.getNode(), node.isPreceding, representation)); } @Override public void postorder(DBSPZSetExpression node) { + if (this.done(node)) + return; Map data = new HashMap<>(); for (Map.Entry e : node.data.entrySet()) { DBSPExpression key = this.getE(e.getKey()); @@ -506,6 +624,8 @@ public void postorder(DBSPZSetExpression node) { @Override public void postorder(LinearAggregate node) { + if (this.done(node)) + return; DBSPExpression map = this.getE(node.map); DBSPExpression postProcess = this.getE(node.postProcess); DBSPExpression emptySetResult = this.getE(node.emptySetResult); @@ -515,11 +635,15 @@ public void postorder(LinearAggregate node) { @Override public void postorder(NoExpression node) { + if (this.done(node)) + return; this.map(node, node); } @Override public void postorder(NonLinearAggregate node) { + if (this.done(node)) + return; DBSPExpression zero = this.getE(node.zero); DBSPExpression increment = this.getE(node.increment); DBSPExpression postProcess = this.getEN(node.postProcess); @@ -539,6 +663,8 @@ public IDBSPInnerNode apply(IDBSPInnerNode node) { @Override public void postorder(DBSPFunction function) { + if (this.done(function)) + return; DBSPExpression body = this.getEN(function.body); DBSPFunction result = new DBSPFunction(function.getNode(), function.name, function.parameters, function.returnType, body, function.annotations); @@ -551,6 +677,8 @@ public void postorder(DBSPFunction function) { @Override public void postorder(DBSPAggregateList aggregate) { + if (this.done(aggregate)) + return; DBSPExpression rowVar = this.getE(aggregate.rowVar); List implementations = Linq.map(aggregate.aggregates, c -> { @@ -568,12 +696,16 @@ public void postorder(DBSPAggregateList aggregate) { @Override public void postorder(DBSPFunctionItem item) { + if (this.done(item)) + return; IDBSPInnerNode result = this.get(item.function); this.map(item, new DBSPFunctionItem(result.to(DBSPFunction.class))); } @Override public void postorder(DBSPStaticItem item) { + if (this.done(item)) + return; DBSPExpression expression = this.getE(item.expression); DBSPItem result = new DBSPStaticItem(expression.to(DBSPStaticExpression.class)); this.map(item, result); diff --git a/sql-to-dbsp-compiler/SQL-compiler/src/main/java/org/dbsp/sqlCompiler/compiler/visitors/inner/Simplify.java b/sql-to-dbsp-compiler/SQL-compiler/src/main/java/org/dbsp/sqlCompiler/compiler/visitors/inner/Simplify.java index 1412926206b..966089f6c07 100644 --- a/sql-to-dbsp-compiler/SQL-compiler/src/main/java/org/dbsp/sqlCompiler/compiler/visitors/inner/Simplify.java +++ b/sql-to-dbsp-compiler/SQL-compiler/src/main/java/org/dbsp/sqlCompiler/compiler/visitors/inner/Simplify.java @@ -528,6 +528,22 @@ public void postorder(DBSPIfExpression expression) { negative.isCompileTimeConstant() && positive.equivalent(negative)) { result = positive; + } else if (negative != null && + // if (condition) then { true } else { false } == condition + positive.getType().sameType(condition.getType()) && + positive.is(DBSPBoolLiteral.class) && + positive.to(DBSPBoolLiteral.class).hasValue(true) && + negative.is(DBSPBoolLiteral.class) && + negative.to(DBSPBoolLiteral.class).hasValue(false)) { + result = condition; + } else if (negative != null && + // if (condition) then { false } else { true } == !condition + positive.getType().sameType(condition.getType()) && + positive.is(DBSPBoolLiteral.class) && + positive.to(DBSPBoolLiteral.class).hasValue(false) && + negative.is(DBSPBoolLiteral.class) && + negative.to(DBSPBoolLiteral.class).hasValue(true)) { + result = condition.not(); } else if (condition != expression.condition || positive != expression.positive || negative != expression.negative) { @@ -903,7 +919,9 @@ protected void map(DBSPExpression expression, DBSPExpression result) { .appendSupplier(result::toString) .newline(); } - Utilities.enforce(expression.getType().sameType(result.getType())); + Utilities.enforce(expression.getType().sameType(result.getType()), + () -> "Expression with type " + expression.getType() + " has type " + result.getType() + + " after simplification"); super.map(expression, result); } } diff --git a/sql-to-dbsp-compiler/SQL-compiler/src/main/java/org/dbsp/sqlCompiler/compiler/visitors/outer/ShareIndexes.java b/sql-to-dbsp-compiler/SQL-compiler/src/main/java/org/dbsp/sqlCompiler/compiler/visitors/outer/ShareIndexes.java index bbdb7e1f14f..cfb229e0438 100644 --- a/sql-to-dbsp-compiler/SQL-compiler/src/main/java/org/dbsp/sqlCompiler/compiler/visitors/outer/ShareIndexes.java +++ b/sql-to-dbsp-compiler/SQL-compiler/src/main/java/org/dbsp/sqlCompiler/compiler/visitors/outer/ShareIndexes.java @@ -21,10 +21,16 @@ import org.dbsp.sqlCompiler.compiler.DBSPCompiler; import org.dbsp.sqlCompiler.compiler.frontend.calciteObject.CalciteRelNode; import org.dbsp.sqlCompiler.compiler.visitors.inner.EquivalenceContext; +import org.dbsp.sqlCompiler.compiler.visitors.inner.ExpressionTranslator; import org.dbsp.sqlCompiler.compiler.visitors.inner.Projection; +import org.dbsp.sqlCompiler.compiler.visitors.inner.ResolveReferences; +import org.dbsp.sqlCompiler.ir.DBSPParameter; +import org.dbsp.sqlCompiler.ir.IDBSPInnerNode; import org.dbsp.sqlCompiler.ir.IDBSPOuterNode; import org.dbsp.sqlCompiler.ir.expression.DBSPClosureExpression; +import org.dbsp.sqlCompiler.ir.expression.DBSPDerefExpression; import org.dbsp.sqlCompiler.ir.expression.DBSPExpression; +import org.dbsp.sqlCompiler.ir.expression.DBSPFieldExpression; import org.dbsp.sqlCompiler.ir.expression.DBSPRawTupleExpression; import org.dbsp.sqlCompiler.ir.expression.DBSPTupleExpression; import org.dbsp.sqlCompiler.ir.expression.DBSPVariablePath; @@ -524,29 +530,100 @@ public void postorder(DBSPMapIndexOperator operator) { } } - record VarAndExpression(DBSPVariablePath var, DBSPExpression expression) {} + record ParameterIndexMap(DBSPVariablePath var, Map indexRemap) {} + + record ParameterIndexMapSet(Map map) { + public ParameterIndexMapSet() { + this(new HashMap<>()); + } + + void add(DBSPParameter param, ParameterIndexMap map) { + Utilities.putNew(this.map, param, map); + } + + @Nullable + ParameterIndexMap get(DBSPParameter param) { + return this.map.get(param); + } + } + + /** For a parameter param this is given a map from integer to integer and a variable. + * If the map[a] = b, this rewrites (*param).a to (*var).b */ + static class ParameterIndexRewriter extends ExpressionTranslator { + final ResolveReferences resolver; + final ParameterIndexMapSet rewriteMap; + + public ParameterIndexRewriter(DBSPCompiler compiler, ParameterIndexMapSet rewriteMap) { + super(compiler); + this.resolver = new ResolveReferences(compiler, false); + this.rewriteMap = rewriteMap; + } + + @Override + public void startVisit(IDBSPInnerNode node) { + super.startVisit(node); + this.resolver.apply(node); + } + + @Override + public void postorder(DBSPVariablePath var) { + if (this.maybeGet(var) != null) { + // Already translated + return; + } + var decl = this.resolver.reference.getDeclaration(var); + if (decl.is(DBSPParameter.class)) { + var map = this.rewriteMap.get(decl.to(DBSPParameter.class)); + if (map != null) { + this.map(var, map.var.deepCopy()); + return; + } + } + super.postorder(var); + } + + @Override + public void postorder(DBSPFieldExpression field) { + if (this.maybeGet(field) != null) { + // Already translated + return; + } + if (field.expression.is(DBSPDerefExpression.class)) { + var deref = field.expression.to(DBSPDerefExpression.class); + if (deref.expression.is(DBSPVariablePath.class)) { + var var = deref.expression.to(DBSPVariablePath.class); + var decl = this.resolver.reference.getDeclaration(var); + if (decl.is(DBSPParameter.class)) { + var map = this.rewriteMap.get(decl.to(DBSPParameter.class)); + if (map != null) { + Integer newField = map.indexRemap.get(field.fieldNo); + if (newField == null) + newField = field.fieldNo; + this.map(field, map.var.deepCopy().deref().field(newField)); + return; + } + } + } + } + super.postorder(field); + } + } record JoinSource(WideMapIndexBuilder builder, int consumerIndex) { - /** Create an expression that represents the value part of the field of the new join input */ - public VarAndExpression createValue(DBSPType expectedType) { - DBSPTypeTuple tuple = expectedType.deref().to(DBSPTypeTuple.class); + public ParameterIndexMap getParameterRemap() { DBSPMapIndexOperator source = this.builder.get(); // This is the new input for this join input var newVar = source.getOutputIndexedZSetType().elementType.ref().var(); - List valueFields = new ArrayList<>(); + // This is the list of fields from the value produced by the MapIndex that this join consumes List outputIndexes = builder.outputIndexes.get(consumerIndex); - int i = 0; - for (int index: outputIndexes) { - DBSPExpression field = newVar.deref().field(index); - if (field.getType().mayBeNull && !tuple.getFieldType(i).mayBeNull) - field = field.neverFailsUnwrap(); - valueFields.add(field.applyCloneIfNeeded()); - i++; + Map remap = new HashMap<>(); + for (int i = 0; i < outputIndexes.size(); i++) { + int index = outputIndexes.get(i); + remap.put(i, index); } - DBSPExpression expr = new DBSPTupleExpression(valueFields, builder.valueNullable).borrow(); - Utilities.enforce(expr.getType().sameType(expectedType)); - return new VarAndExpression(newVar, expr); + + return new ParameterIndexMap(newVar, remap); } @Override @@ -623,24 +700,25 @@ DBSPClosureExpression rewriteJoinClosure( DBSPClosureExpression closure, JoinInputs inputs) { - DBSPVariablePath keyVar = closure.parameters[0].type.var(); - DBSPVariablePath leftVar = closure.parameters[1].type.var(); - DBSPVariablePath rightVar = closure.parameters[2].type.var(); - DBSPExpression leftValue = leftVar; + DBSPVariablePath keyVar = closure.parameters[0].asVariable(); + DBSPVariablePath leftVar = closure.parameters[1].asVariable(); + DBSPVariablePath rightVar = closure.parameters[2].asVariable(); + ParameterIndexMapSet set = new ParameterIndexMapSet(); + if (inputs.left != null) { - var pair = inputs.left.createValue(closure.parameters[1].type); - leftValue = pair.expression; - leftVar = pair.var; + var remap = inputs.left.getParameterRemap(); + leftVar = remap.var; + set.add(closure.parameters[1], remap); } - DBSPExpression rightValue = rightVar; if (inputs.right != null) { - var pair = inputs.right.createValue(closure.parameters[2].type); - rightValue = pair.expression; - rightVar = pair.var; + var remap = inputs.right.getParameterRemap(); + rightVar = remap.var; + set.add(closure.parameters[2], remap); } - DBSPExpression call = closure.call(keyVar, leftValue, rightValue); - DBSPExpression reduced = call.reduce(this.compiler); - return reduced.closure(keyVar, leftVar, rightVar); + + ParameterIndexRewriter rewriter = new ParameterIndexRewriter(this.compiler, set); + DBSPClosureExpression result = rewriter.apply(closure).to(DBSPClosureExpression.class); + return result.body.closure(keyVar, leftVar, rightVar); } @Override diff --git a/sql-to-dbsp-compiler/SQL-compiler/src/main/java/org/dbsp/sqlCompiler/ir/expression/literal/DBSPBoolLiteral.java b/sql-to-dbsp-compiler/SQL-compiler/src/main/java/org/dbsp/sqlCompiler/ir/expression/literal/DBSPBoolLiteral.java index 5911de4b2b4..89043474575 100644 --- a/sql-to-dbsp-compiler/SQL-compiler/src/main/java/org/dbsp/sqlCompiler/ir/expression/literal/DBSPBoolLiteral.java +++ b/sql-to-dbsp-compiler/SQL-compiler/src/main/java/org/dbsp/sqlCompiler/ir/expression/literal/DBSPBoolLiteral.java @@ -63,6 +63,12 @@ public DBSPBoolLiteral(@Nullable Boolean b, boolean nullable) { throw new InternalCompilerError("Null value with non-nullable type", this); } + public boolean hasValue(@Nullable Boolean value) { + if (this.value == null) + return value == null; + return this.value == value; + } + @Override public DBSPExpression deepCopy() { return new DBSPBoolLiteral(this.getNode(), this.type, this.value); diff --git a/sql-to-dbsp-compiler/SQL-compiler/src/test/java/org/dbsp/sqlCompiler/compiler/sql/simple/IncrementalRegression2Tests.java b/sql-to-dbsp-compiler/SQL-compiler/src/test/java/org/dbsp/sqlCompiler/compiler/sql/simple/IncrementalRegression2Tests.java index d374da49d69..4d094cb814b 100644 --- a/sql-to-dbsp-compiler/SQL-compiler/src/test/java/org/dbsp/sqlCompiler/compiler/sql/simple/IncrementalRegression2Tests.java +++ b/sql-to-dbsp-compiler/SQL-compiler/src/test/java/org/dbsp/sqlCompiler/compiler/sql/simple/IncrementalRegression2Tests.java @@ -446,4 +446,60 @@ public void endVisit() { } }); } + + @Test + public void issue5935() { + var ccs = this.getCCS(""" + CREATE TABLE orders ( + order_id INT NOT NULL PRIMARY KEY, + order_date DATE NOT NULL, + billing_customer_id INT, + shipping_customer_id INT + ); + + CREATE TABLE customers ( + customer_id INT NOT NULL PRIMARY KEY, + customer_name VARCHAR(200) NOT NULL, + region TINYINT NOT NULL + ); + + CREATE VIEW V AS SELECT + o.order_id, + o.order_date, + bc.customer_name AS billing_name, + bc.region AS billing_region, + sc.customer_name AS shipping_name, + sc.region AS shipping_region + FROM orders AS o + LEFT JOIN customers AS bc + ON o.billing_customer_id = bc.customer_id + LEFT JOIN customers AS sc + ON o.shipping_customer_id = sc.customer_id + ORDER BY o.order_id;"""); + // Validated on Postgres + ccs.stepWeightOne(""" + INSERT INTO CUSTOMERS VALUES + (1, 'Alice', 0), + (2, 'Bob', 1), + (3, 'Carol', 2), + (4, 'Dave', 3); + INSERT INTO orders (order_id, order_date, billing_customer_id, shipping_customer_id) VALUES + -- both billing + shipping + (101, DATE '2024-01-10', 1, 2), + -- billing only + (102, DATE '2024-01-11', 3, NULL), + -- shipping only + (103, DATE '2024-01-12', NULL, 4), + -- neither + (104, DATE '2024-01-13', NULL, NULL), + -- same customer for both roles + (105, DATE '2024-01-14', 2, 2);""", """ + order_id | order_date | billing_name | billing_region | shipping_name | shipping_region + ------------------------------------------------------------------------------------------ + 101 | 2024-01-10 | Alice| 0 | Bob| 1 + 102 | 2024-01-11 | Carol| 2 |NULL | + 103 | 2024-01-12 |NULL | | Dave| 3 + 104 | 2024-01-13 |NULL | |NULL | + 105 | 2024-01-14 | Bob| 1 | Bob| 1"""); + } } diff --git a/sql-to-dbsp-compiler/SQL-compiler/src/test/java/org/dbsp/sqlCompiler/compiler/sql/tools/CompilerCircuitStream.java b/sql-to-dbsp-compiler/SQL-compiler/src/test/java/org/dbsp/sqlCompiler/compiler/sql/tools/CompilerCircuitStream.java index 121c71b4462..155212f8733 100644 --- a/sql-to-dbsp-compiler/SQL-compiler/src/test/java/org/dbsp/sqlCompiler/compiler/sql/tools/CompilerCircuitStream.java +++ b/sql-to-dbsp-compiler/SQL-compiler/src/test/java/org/dbsp/sqlCompiler/compiler/sql/tools/CompilerCircuitStream.java @@ -72,6 +72,24 @@ public void step(String script, String expected) { this.stream.addPair(input, output); } + /** Like step, but every record in the output has weight one, and the + * weight column is omitted for the expected output. */ + public void stepWeightOne(String script, String expected) { + String[] lines = expected.split("\n"); + boolean inHeader = true; + for (int i = 0; i < lines.length; i++) { + var l = lines[i]; + if (inHeader && l.contains("---")) { + inHeader = false; + continue; + } + if (inHeader) + continue; + lines[i] = l + "| 1"; + } + this.step(script, String.join("\n", lines)); + } + public void step(Change input, Change output) { this.stream.addPair(input, output); }