From e669d760015048d5a9a68914622337e977e45b6a Mon Sep 17 00:00:00 2001 From: Timothy Hunter Date: Thu, 19 May 2016 08:10:30 -0700 Subject: [PATCH] start work on catalyst implementation --- project/Dependencies.scala | 2 +- .../org/apache/spark/LoggingWrapper.scala | 14 + .../org/tensorframes/ColumnInformation.scala | 2 +- .../tensorframes/ExperimentalOperations.scala | 2 +- .../scala/org/tensorframes/Operations.scala | 17 +- .../catalyst/CatalystOperations.scala | 80 ++++ .../catalyst/TestMapBlockPlan.scala | 72 ++++ .../scala/org/tensorframes/dsl/DslImpl.scala | 2 +- .../org/tensorframes/dsl/Implicits.scala | 4 +- .../org/tensorframes/dsl/Operation.scala | 2 +- src/main/scala/org/tensorframes/dsl/Ops.scala | 4 +- .../scala/org/tensorframes/dsl/Paths.scala | 2 +- .../scala/org/tensorframes/impl/DataOps.scala | 2 +- .../org/tensorframes/impl/DebugRowOps.scala | 337 +---------------- .../tensorframes/impl/PythonInterface.scala | 6 +- .../tensorframes/impl/SchemaTransforms.scala | 347 ++++++++++++++++++ .../org/tensorframes/impl/TensorFlowOps.scala | 2 +- .../org/tensorframes/impl/datatypes.scala | 2 +- .../org/tensorframes/test/DslOperations.scala | 4 +- .../scala/org/tensorframes/test/dsl.scala | 2 +- .../tensorframes/BasicOperationsSuite.scala | 5 +- .../tensorframes/CommonOperationsSuite.scala | 2 +- .../org/tensorframes/DSLOperationsSuite.scala | 2 +- .../org/tensorframes/DebugRowOpsSuite.scala | 2 +- .../tensorframes/ExtraOperationsSuite.scala | 2 +- .../scala/org/tensorframes/SlicingSuite.scala | 2 +- .../tensorframes/TFInitializationSuite.scala | 2 +- .../TrimmingOperationsSuite.scala | 2 +- .../org/tensorframes/dsl/BasicOpsSuite.scala | 2 +- .../org/tensorframes/dsl/BasicSuite.scala | 2 +- .../org/tensorframes/dsl/ExtractNodes.scala | 2 +- .../perf/ConvertBackPerformanceSuite.scala | 2 +- .../perf/ConvertPerformanceSuite.scala | 2 +- .../tensorframes/perf/PerformanceSuite.scala | 2 +- .../scala/org/tensorframes/type_suites.scala | 3 +- 35 files changed, 576 
insertions(+), 363 deletions(-) create mode 100644 src/main/scala/org/apache/spark/LoggingWrapper.scala create mode 100644 src/main/scala/org/tensorframes/catalyst/CatalystOperations.scala create mode 100644 src/main/scala/org/tensorframes/catalyst/TestMapBlockPlan.scala create mode 100644 src/main/scala/org/tensorframes/impl/SchemaTransforms.scala diff --git a/project/Dependencies.scala b/project/Dependencies.scala index 5c8a90c..6db3c23 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -12,7 +12,7 @@ import xml.transform.{RuleTransformer, RewriteRule} object Dependencies { // The spark version - val targetSparkVersion = "1.6.1" + val targetSparkVersion = "2.0.0-SNAPSHOT" val targetJCPPVersion = "1.2" diff --git a/src/main/scala/org/apache/spark/LoggingWrapper.scala b/src/main/scala/org/apache/spark/LoggingWrapper.scala new file mode 100644 index 0000000..7d62e7a --- /dev/null +++ b/src/main/scala/org/apache/spark/LoggingWrapper.scala @@ -0,0 +1,14 @@ +package org.apache.spark + +import org.apache.spark.internal.Logging + +trait LoggingWrapper extends Logging { + + override protected def logInfo(msg: => String): Unit = { + super.logInfo(msg) + } + + override protected def logTrace(msg: => String): Unit = { + super.logTrace(msg) + } +} diff --git a/src/main/scala/org/tensorframes/ColumnInformation.scala b/src/main/scala/org/tensorframes/ColumnInformation.scala index 3e0dac5..b390b53 100644 --- a/src/main/scala/org/tensorframes/ColumnInformation.scala +++ b/src/main/scala/org/tensorframes/ColumnInformation.scala @@ -1,6 +1,6 @@ package org.tensorframes -import org.apache.spark.Logging +import org.apache.spark.{LoggingWrapper => Logging} import org.apache.spark.sql.types._ diff --git a/src/main/scala/org/tensorframes/ExperimentalOperations.scala b/src/main/scala/org/tensorframes/ExperimentalOperations.scala index 05d4548..86171c4 100644 --- a/src/main/scala/org/tensorframes/ExperimentalOperations.scala +++ 
b/src/main/scala/org/tensorframes/ExperimentalOperations.scala @@ -1,6 +1,6 @@ package org.tensorframes -import org.apache.spark.Logging +import org.apache.spark.{LoggingWrapper => Logging} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions.col import org.apache.spark.sql.types.{ArrayType, DataType, NumericType} diff --git a/src/main/scala/org/tensorframes/Operations.scala b/src/main/scala/org/tensorframes/Operations.scala index 089fadc..46e1049 100644 --- a/src/main/scala/org/tensorframes/Operations.scala +++ b/src/main/scala/org/tensorframes/Operations.scala @@ -1,6 +1,6 @@ package org.tensorframes -import org.apache.spark.sql.{GroupedData, Row, DataFrame} +import org.apache.spark.sql.{RelationalGroupedDataset, Row, DataFrame} import org.tensorflow.framework.GraphDef @@ -123,13 +123,26 @@ trait OperationsInterface { * @param shapeHints some hints for the shape. * @return */ - def aggregate(data: GroupedData, graph: GraphDef, shapeHints: ShapeDescription): DataFrame + def aggregate(data: RelationalGroupedDataset, graph: GraphDef, shapeHints: ShapeDescription): DataFrame /** * A string that contains detailed information about a dataframe, in particular relevant information * with respect to TensorFlow. + * * @param df * @return */ def explain(df: DataFrame): String +} + +object OperationsInterface { + // Developer API that may get removed in the future. + // If true, will try to use Catalyst hooks to accelerate data transfers between Spark and + // TensorFlow. This is highly experimental and may not work for all tensor shapes. + // Only use it if you are developing the Catalyst integration. 
+ def enableCatalystOptimizations(set: Boolean): Unit = { + enableCatalystHook = set + } + + private[tensorframes] var enableCatalystHook: Boolean = false } \ No newline at end of file diff --git a/src/main/scala/org/tensorframes/catalyst/CatalystOperations.scala b/src/main/scala/org/tensorframes/catalyst/CatalystOperations.scala new file mode 100644 index 0000000..a2fcbb4 --- /dev/null +++ b/src/main/scala/org/tensorframes/catalyst/CatalystOperations.scala @@ -0,0 +1,80 @@ +package org.tensorframes.catalyst + +import org.apache.spark.{LoggingWrapper => Logging} +import org.apache.spark.sql._ +import org.tensorflow.framework.GraphDef +import org.tensorframes.impl.{TensorFlowOps, DebugRowOps} +import org.tensorframes.impl.SchemaTransforms._ +import org.tensorframes.{ShapeDescription, OperationsInterface} + +/** + * Optimized implementation of the TensorFrames operation that hooks directly into the catalyst + * compiler. + */ +object CatalystOperations extends OperationsInterface with Logging { + + override def mapRows( + dataframe: DataFrame, + graph: GraphDef, + shapeHints: ShapeDescription): DataFrame = { + DebugRowOps.mapRows(dataframe, graph, shapeHints) + } + + override def mapBlocks( + dataframe: DataFrame, + graph: GraphDef, + shapeHints: ShapeDescription): DataFrame = { + DebugRowOps.mapBlocks(dataframe, graph, shapeHints) + } + + override def mapBlocksTrimmed( + dataframe: DataFrame, + graph: GraphDef, + shapeHints: ShapeDescription): DataFrame = { + DebugRowOps.mapBlocksTrimmed(dataframe, graph, shapeHints) + } + + override def reduceRows( + dataframe: DataFrame, + graph: GraphDef, + shapeHints: ShapeDescription): Row = { + DebugRowOps.reduceRows(dataframe, graph, shapeHints) + } + + override def reduceBlocks( + dataframe: DataFrame, + graph: GraphDef, + shapeHints: ShapeDescription): Row = { + DebugRowOps.reduceBlocks(dataframe, graph, shapeHints) + } + + override def aggregate( + data: RelationalGroupedDataset, + graph: GraphDef, + shapeHints: 
ShapeDescription): DataFrame = { + DebugRowOps.aggregate(data, graph, shapeHints) + } + + override def explain(df: DataFrame): String = { + DebugRowOps.explain(df) + } + + private def mapBlocks( + dataframe: DataFrame, + graph: GraphDef, + shapeHints: ShapeDescription, + appendInput: Boolean): DataFrame = { + val sc = dataframe.sqlContext.sparkContext + val transform = mapBlocksSchema(dataframe.schema, graph, shapeHints, appendInput) + + logDebug(s"mapBlocks: TF input schema = ${transform.inputSchema}," + + s" complete output schema = ${transform.outputSchema}") + + val gProto = sc.broadcast(TensorFlowOps.graphSerial(graph)) + val child = TFHooks.logicalPlan(dataframe) + val plan = TestMapBlockPlan(child, gProto, transform) + TFHooks.ofRows(dataframe.sparkSession, plan) + } + +} + diff --git a/src/main/scala/org/tensorframes/catalyst/TestMapBlockPlan.scala b/src/main/scala/org/tensorframes/catalyst/TestMapBlockPlan.scala new file mode 100644 index 0000000..0f76bb4 --- /dev/null +++ b/src/main/scala/org/tensorframes/catalyst/TestMapBlockPlan.scala @@ -0,0 +1,72 @@ +package org.tensorframes.catalyst + +import org.tensorframes.impl.MapBlocksSchema + +import org.apache.spark.SparkContext +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{SQLContext, Strategy, TFHooks} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode} +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.types.DataType + +case class TestMapBlockPlan private[tensorframes]( + override val child: LogicalPlan, + graphDefSerial: Broadcast[Array[Byte]], + transform: MapBlocksSchema) extends UnaryNode { + + def output: Seq[Attribute] = { + val dt: DataType = transform.outputSchema + val attr = AttributeReference("obj", dt, nullable = 
false)() + attr :: Nil + } + + // TODO: this could be cleaned, based on the requested inputs + override def references: AttributeSet = child.outputSet +} + +case class TestMapBlockExec( + child: SparkPlan, + logicalPlan: TestMapBlockPlan) extends SparkPlan { // UnaryExecNode + + override def children: Seq[SparkPlan] = child :: Nil + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitions { it => + // This is where stuff happens + ??? + } + } + + override def output: Seq[Attribute] = logicalPlan.output + +} + +object TestStrategy extends Strategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case p: TestMapBlockPlan => + val sparkContext = SparkContext.getOrCreate() + val session = SQLContext.getOrCreate(sparkContext).sparkSession + val childPlan = TFHooks.planLater(session, p.child) + TestMapBlockExec(childPlan, p) :: Nil + case _ => + Nil + } + + def ensureLoaded(): Unit = { + assert(loaded_) + } + + private lazy val loaded_ : Boolean = { + val spark = SparkContext.getOrCreate() + val sql = SQLContext.getOrCreate(spark) + sql.experimental.extraStrategies ++= Seq(TestStrategy) + true + } +} + diff --git a/src/main/scala/org/tensorframes/dsl/DslImpl.scala b/src/main/scala/org/tensorframes/dsl/DslImpl.scala index 0af0aa8..269934d 100644 --- a/src/main/scala/org/tensorframes/dsl/DslImpl.scala +++ b/src/main/scala/org/tensorframes/dsl/DslImpl.scala @@ -3,7 +3,7 @@ package org.tensorframes.dsl import javax.annotation.Nullable import org.tensorflow.framework.{AttrValue, DataType, GraphDef, TensorShapeProto} -import org.apache.spark.Logging +import org.apache.spark.{LoggingWrapper => Logging} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.NumericType diff --git a/src/main/scala/org/tensorframes/dsl/Implicits.scala b/src/main/scala/org/tensorframes/dsl/Implicits.scala index 6004e9d..637b3da 100644 --- 
a/src/main/scala/org/tensorframes/dsl/Implicits.scala +++ b/src/main/scala/org/tensorframes/dsl/Implicits.scala @@ -2,7 +2,7 @@ package org.tensorframes.dsl import scala.languageFeature.implicitConversions -import org.apache.spark.sql.{GroupedData, Row, DataFrame} +import org.apache.spark.sql.{RelationalGroupedDataset, Row, DataFrame} import org.tensorflow.framework.GraphDef import org.tensorframes.{ExperimentalOperations, OperationsInterface, ShapeDescription, dsl} @@ -102,7 +102,7 @@ trait DFImplicits { * * This is useful for aggregation. */ - implicit class RichGroupedData(dg: GroupedData) { + implicit class RichGroupedData(dg: RelationalGroupedDataset) { def aggregate(graphDef: GraphDef, shapeDescription: ShapeDescription): DataFrame = { ops.aggregate(dg, graphDef, shapeDescription) } diff --git a/src/main/scala/org/tensorframes/dsl/Operation.scala b/src/main/scala/org/tensorframes/dsl/Operation.scala index 20a2916..c57bbcb 100644 --- a/src/main/scala/org/tensorframes/dsl/Operation.scala +++ b/src/main/scala/org/tensorframes/dsl/Operation.scala @@ -1,6 +1,6 @@ package org.tensorframes.dsl -import org.apache.spark.Logging +import org.apache.spark.{LoggingWrapper => Logging} import org.apache.spark.sql.types.NumericType import org.tensorflow.framework.{NodeDef, AttrValue} import org.tensorframes.{dsl => tf, ShapeDescription, Shape} diff --git a/src/main/scala/org/tensorframes/dsl/Ops.scala b/src/main/scala/org/tensorframes/dsl/Ops.scala index e650bff..beaa41a 100644 --- a/src/main/scala/org/tensorframes/dsl/Ops.scala +++ b/src/main/scala/org/tensorframes/dsl/Ops.scala @@ -1,6 +1,6 @@ package org.tensorframes.dsl -import org.apache.spark.sql.{GroupedData, Row, DataFrame} +import org.apache.spark.sql.{RelationalGroupedDataset, Row, DataFrame} import org.tensorflow.framework.GraphDef import org.tensorframes.{ExperimentalOperations, ShapeDescription, OperationsInterface} import org.tensorframes.impl.DebugRowOps @@ -40,7 +40,7 @@ object Ops extends OperationsInterface 
with DslOperations with ExperimentalOpera } override def aggregate( - data: GroupedData, + data: RelationalGroupedDataset, graph: GraphDef, shapeHints: ShapeDescription): DataFrame = { ops.aggregate(data, graph, shapeHints) diff --git a/src/main/scala/org/tensorframes/dsl/Paths.scala b/src/main/scala/org/tensorframes/dsl/Paths.scala index 6af62f5..9820447 100644 --- a/src/main/scala/org/tensorframes/dsl/Paths.scala +++ b/src/main/scala/org/tensorframes/dsl/Paths.scala @@ -2,7 +2,7 @@ package org.tensorframes.dsl import scala.collection.mutable -import org.apache.spark.Logging +import org.apache.spark.{LoggingWrapper => Logging} /** * Operations that try to give a convenient way to express paths in expressions. diff --git a/src/main/scala/org/tensorframes/impl/DataOps.scala b/src/main/scala/org/tensorframes/impl/DataOps.scala index e299d48..3bc3122 100644 --- a/src/main/scala/org/tensorframes/impl/DataOps.scala +++ b/src/main/scala/org/tensorframes/impl/DataOps.scala @@ -5,7 +5,7 @@ import scala.reflect.ClassTag import org.bytedeco.javacpp.{tensorflow => jtf} -import org.apache.spark.Logging +import org.apache.spark.{LoggingWrapper => Logging} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.{GenericRow, GenericRowWithSchema} import org.apache.spark.sql.types.{NumericType, StructType} diff --git a/src/main/scala/org/tensorframes/impl/DebugRowOps.scala b/src/main/scala/org/tensorframes/impl/DebugRowOps.scala index c37748f..ec35103 100644 --- a/src/main/scala/org/tensorframes/impl/DebugRowOps.scala +++ b/src/main/scala/org/tensorframes/impl/DebugRowOps.scala @@ -1,13 +1,13 @@ package org.tensorframes.impl import org.apache.commons.lang3.SerializationUtils -import org.apache.spark.Logging +import org.apache.spark.{LoggingWrapper => Logging} import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.expressions.{MutableRow, GenericRowWithSchema} import org.apache.spark.sql.expressions.{MutableAggregationBuffer, 
UserDefinedAggregateFunction} import org.apache.spark.sql.types._ import org.apache.spark.sql.functions.col -import org.apache.spark.sql.{GroupedData, DataFrame, Row} +import org.apache.spark.sql.{RelationalGroupedDataset, DataFrame, Row} import org.bytedeco.javacpp.{tensorflow => jtf} import org.tensorflow.framework.GraphDef import org.tensorframes._ @@ -16,261 +16,6 @@ import org.tensorframes.test.DslOperations import scala.collection.mutable import scala.util.{Failure, Success, Try} -/** - * The different schemas required for the block reduction. - * - * Here is the order (block reduction in a map phase followed by pair-wise aggregation in a - * reduce phase): - * mapInput -> output -> reduceInput -> output - * - * Call 'x' a variable required by the transform, and 'y' an extra column - * - * @param mapInput the schema of the column block. Contains 'x_input' and 'y' - * @param mapTFCols the indexes of the columns required by the transform - * @param output contains 'x' only - * @param reduceInput contains 'x_input' only ('y' column has been dropped in the map phase). - */ -case class ReduceBlockSchema( - mapInput: StructType, - mapTFCols: List[Int], - output: StructType, - reduceInput: StructType) extends Serializable - - -/** - * All the schema transformations that are done by the basic TF operations. - * - * These methods describe the schema transforms performed on a DataFrame. They include - * all the validation steps that should be performed before passing data to TensorFlow. - * - * After calling these methods, the implementation can assume the schemas are valid and complete enough. - */ -// Implementation is separated for python accessors -// TODO: these methods are pretty complicated, add more documentation! -private[impl] trait SchemaTransforms extends Logging { - def get[A](x: Option[A], msg: String) = x.getOrElse { - throw new Exception(msg) - } - - def check(b: Boolean, msg: String): Unit = if (! 
b) { - throw new Exception(msg) - } - - - /** - * Validates and computes the transformation schema under reducing. - * - * A graph may access a subset of the rows. All the schemas are for blocks of data. - * - * For each output x of the graph, there must be: - * - a placeholder called x_input with one extra dimension (left unknown) - * - a corresponding column labeled 'x' with the same row structure as the output - * - * @param schema the schema of the dataframe - * @param graph the graph - * @param shapeHints the shape hints obtained for this graph - * @return a triplet containing the input block schema, the output block schema, and the - * requested inputs, which may be a subset of the input. - */ - // TODO(tjh) all the inputs and outputs are created by hashmaps, which makes their order not - // deterministic. Change that. - def reduceBlocksSchema( - schema: StructType, - graph: GraphDef, - shapeHints: ShapeDescription): ReduceBlockSchema = { - val summary = TensorFlowOps.analyzeGraph(graph, shapeHints) - .map(x => x.name -> x).toMap - val fieldsByName = schema.fields.map(f => f.name -> f).toMap - val fieldNameList = fieldsByName.keySet.toSeq.sorted.mkString(", ") - val outputNameList = summary.filter(_._2.isOutput).keySet.toSeq.sorted.mkString(", ") - val suffix = "_input" - - // Initial check: all the fields are here: - val outputs = summary.filter(_._2.isOutput) - - // Check that the outputs of the graph are a subset of the columns of the dataframe. - val missingColInputs = (outputs.keySet -- fieldsByName.keySet).toSeq.sorted - check(missingColInputs.isEmpty, { - val missing = missingColInputs.mkString(", ") - s"Based on the TF graph, some inputs are missing: $missing. " + - s"Dataframe columns: $fieldNameList; Outputs: $outputNameList" }) - - // Initial check: the inputs are all there, and they are the only ones. 
- val inputs = summary.filter(_._2.isInput) - val expectedInputs = outputs.keySet.map(_ + suffix) - val extraInputs = (inputs.keySet -- expectedInputs).toSeq.sorted - logDebug(s"reduceRows: expectedInputs=$expectedInputs") - check(extraInputs.isEmpty, - s"Extra graph inputs have been found: ${extraInputs.mkString(", ")}. " + - s"Dataframe columns: $fieldNameList") - - val missingInputs = (expectedInputs -- inputs.keySet).toSeq.sorted - check(missingInputs.isEmpty, - s"Some inputs are missing in the graph: ${missingInputs.mkString(", ")}. " + - s"Dataframe columns: $fieldNameList") - - // Check that for each output, the field is present with the right schema. - // WARNING: keeping the order of the fields is important -> do not iterate over the outputs. - val fields = schema.filter(f => outputs.contains(f.name)).map { f => - val ci = ColumnInformation(f) - val stf = get(ci.stf, - s"Data column '${f.name}' has not been analyzed yet, cannot run TF on this dataframe") - // Check that the output is compatible (its presence has already been checked. - val out = summary(f.name) - check(out.isOutput, s"Graph node '${out.name}' should be an output") - - check(stf.dataType == out.scalarType, s"Output '${f.name}' has type ${out.scalarType}" + - s" but the column type " + - s"is ${stf.dataType}") - - // Take the tail, we only compare cells - val cellShape = stf.shape.tail - check(out.shape.checkMorePreciseThan(cellShape), - s"Output '${f.name}' has shape ${out.shape}, not compatible with the shape " + - s"of field elements $cellShape") - // The input block may be too precise with respect to the lead dimension (number of rows), - // which is usually incorrect when pairwise reductions are performed. - // Always assume they are unknown for now. 
- val shape = cellShape.prepend(Shape.Unknown) - val inputStf = stf.copy(shape = shape) - - val inputName = f.name + suffix - val in = get(summary.get(inputName), - s"The graph needs to have a placeholder input called $inputName.") - assert(in.isPlaceholder, s"Node $inputName should be a placeholder") - assert(in.isInput, s"Node $inputName should be an input") - check(inputStf.shape.checkMorePreciseThan(in.shape), - s"The data column '${f.name}' has shape ${inputStf.shape}, not compatible with shape" + - s" ${in.shape} requested by the TF graph") - check(inputStf.dataType == in.scalarType, - s"The type of node '${in.name}' (${inputStf.dataType}) is not compatible with the data" + - s" type of the column (${in.scalarType})") - val m = ColumnInformation(f, inputStf).merged - logDebug(s">>> $m -> ${ColumnInformation(m).stf}") - m - } - val outputSchema = StructType(fields.toArray) - // The input schema is simply the block schema, with a different name for the variables. - // We still pass all the variables because the filtering is done on the indices selected. 
- val inputSchema = StructType(schema.map { f => - if (outputs.contains(f.name)) { - widenLeadDim(f.copy(name = f.name + "_input")) - } else { f } - }) - val inputReduceSchema = StructType(schema - .filter(f => outputs.contains(f.name)) - .map(f => widenLeadDim(f.copy(name=f.name + "_input")))) - val requestedIndexes = schema.zipWithIndex - .filter { case (f, idx) => outputs.contains(f.name)} - .map(_._2) .toList - ReduceBlockSchema(inputSchema, requestedIndexes, outputSchema, inputReduceSchema) - } - - def reduceRowsSchema( - schema: StructType, - graph: GraphDef, - shapeHints: ShapeDescription): StructType = { - val summary = TensorFlowOps.analyzeGraph(graph, shapeHints) - .map(x => x.name -> x).toMap - val fieldsByName = schema.fields.map(f => f.name -> f).toMap - val fieldNameList = fieldsByName.keySet.toSeq.sorted.mkString(", ") - val outputNameList = summary.filter(_._2.isOutput).keySet.toSeq.sorted.mkString(", ") - val suffixes = Seq("_1", "_2") - - // Initial check: all the fields are here: - val outputs = summary.filter(_._2.isOutput) - // Check that there are precisely as many outputs as columns: - if ((outputs.keySet -- fieldsByName.keySet).nonEmpty) { - val extra = (outputs.keySet -- fieldsByName.keySet).toSeq.sorted.mkString(", ") - val s = s"Some extra outputs were found in the reducer: $extra. " + - s"Dataframe columns: $fieldNameList; Outputs: $outputNameList" - throw new Exception(s) - } - if ((fieldsByName.keySet -- outputs.keySet).nonEmpty) { - val extra = (fieldsByName.keySet -- outputs.keySet).toSeq.sorted.mkString(", ") - val s = s"Some outputs are missing in the reducer: $extra. 
" + - s"Dataframe columns: $fieldNameList; Outputs: $outputNameList" - throw new Exception(s) - } - - // Initial check: the inputs are all there: - val inputs = summary.filter(_._2.isInput) - val expectedInputs = suffixes.flatMap(suff => fieldsByName.keys.map(_ + suff)).toSet - logDebug(s"reduceRows: expectedInputs=$expectedInputs") - if ((inputs.keySet -- expectedInputs).nonEmpty) { - val extra = (inputs.keySet -- expectedInputs).toSeq.sorted.mkString(", ") - throw new Exception( - s"Extra graph inputs have been found: $extra. Dataframe columns: $fieldNameList") - } - if ((expectedInputs -- inputs.keySet).nonEmpty) { - val extra = (expectedInputs -- inputs.keySet).toSeq.sorted.mkString(", ") - throw new Exception( - s"Some inputs are missing in th graph: $extra. Dataframe columns: $fieldNameList") - } - - // Check that all the fields are here - for { - f <- fieldsByName.values - suffix <- suffixes - } { - val stf = ColumnInformation(f).stf.getOrElse { throw new Exception( - s"Data column '${f.name}' has not been analyzed yet, cannot run TF on this dataframe") - } - // Check that the output is compatible (its presence has already been checked. - val out = summary(f.name) - if (!out.isOutput) { - throw new Exception( - s"Graph node '${out.name}' should be an output") - } - if (stf.dataType != out.scalarType) { - val s = s"Output '${f.name}' has type ${out.scalarType} but the column type " + - s"is ${stf.dataType}" - throw new Exception(s) - } - // Take the tail, we only compare cells - val cellShape = stf.shape.tail - if (! 
out.shape.checkMorePreciseThan(cellShape)) { - throw new Exception( - s"Output '${f.name}' has shape ${out.shape}, not compatible with the shapes" + - s"of field elements ${cellShape}") - } - - // Check that the 2 inputs are compatible: - for (suffix <- Seq("_1", "_2")) { - val inputName = f.name + suffix - val in = summary.getOrElse(inputName, throw new Exception( - s"The graph needs to have a placeholder input called $inputName.")) - assert(in.isPlaceholder, s"Node $inputName should be a placeholder") - assert(in.isInput, s"Node $inputName should be an input") - if (! cellShape.checkMorePreciseThan(in.shape)) { - throw new Exception( - s"The data column '${f.name}' has shape ${stf.shape} (not compatible) with shape" + - s" ${in.shape} requested by the TF graph") - } - if (stf.dataType != in.scalarType) { - throw new Exception( - s"The type of node '${in.name}' (${stf.dataType}) is not compatible with the data" + - s" type of the column (${in.scalarType})") - } - } - } - // Same schema as output - schema - } - - // Sets the lead column to Unknown - private def widenLeadDim(f: StructField): StructField = { - ColumnInformation(f).stf match { - case Some(ci) if ci.shape.numDims >= 1 => - val s = ci.shape.tail.prepend(Shape.Unknown) - ColumnInformation(f, ci.copy(shape = s)).merged - case _ => f // Nothing to do - } - } -} - -object SchemaTransforms extends SchemaTransforms - /** * A simple and slow implementation of the basic operations that maximizes correctness of * implementation and works with older versions of Spark (based on RDDs). 
@@ -304,70 +49,10 @@ class DebugRowOps shapeHints: ShapeDescription, appendInput: Boolean): DataFrame = { val sc = dataframe.sqlContext.sparkContext - val summary = TensorFlowOps.analyzeGraph(graph, shapeHints) - .map(x => x.name -> x).toMap - val inputs = summary.filter(_._2.isInput) - val outputs = summary.filter(_._2.isOutput) - val fieldsByName = dataframe.schema.fields.map(f => f.name -> f).toMap - val cols = dataframe.schema.fieldNames.mkString(", ") - // The input schema, after validation with TF that it contains all the spark TF info. - val inputSchema: StructType = { - - inputs.values.foreach { in => - val f = get(fieldsByName.get(in.name), - s"Graph input ${in.name} found, but no column to match it. Dataframe columns: $cols") - - val stf = ColumnInformation(f).stf.getOrElse { - throw new Exception( - s"Data column ${f.name} has not been analyzed yet, cannot run TF on this dataframe") - } - if (! stf.shape.checkMorePreciseThan(in.shape)) { - throw new Exception( - s"The data column '${f.name}' has shape ${stf.shape} (not compatible) with shape" + - s" ${in.shape} requested by the TF graph") - } - // We do not support autocasting for now. - if (stf.dataType != in.scalarType) { - throw new Exception( - s"The type of node '${in.name}' (${stf.dataType}) is not compatible with the data type " + - s"of the column (${in.scalarType})") - } - // The input has to be either a constant or a placeholder - if (! in.isPlaceholder) { - throw new Exception( - s"Invalid type for input node ${in.name}. It has to be a placeholder") - } - } - dataframe.schema - } - - // The output schema from the data generated by TF. - val outputTFSchema: StructType = { - // The order of the output columns is decided for now by their names. - val fields = outputs.values.toSeq.sortBy(_.name).map { out => - if (fieldsByName.contains(out.name)) { - throw new Exception(s"TF graph has an output node called '${out.name}'," + - s" but this column already exists. 
Input columns: ${cols}") - } - logInfo(s"mapBlocks: out = $out") - ColumnInformation.structField(out.name, out.scalarType, out.shape) - } - StructType(fields.toArray) - } - // The column indices requested by TF - val requestedTFInput: Array[Int] = { - val colIdxs = dataframe.schema.fieldNames.zipWithIndex.toMap - inputs.keys.toArray.map { name => colIdxs(name) } - } - // Full output schema, including data being passed through and validated for duplicates. - // The first columns are the TF columns, followed by all the other columns. - val outputSchema: StructType = if (appendInput) { - StructType(outputTFSchema ++ dataframe.schema.fields) - } else { - StructType(outputTFSchema) - } + val transform = mapBlocksSchema(dataframe.schema, graph, shapeHints, appendInput) - logDebug(s"mapBlocks: TF input schema = $inputSchema, complete output schema = $outputSchema") + logDebug(s"mapBlocks: TF input schema = ${transform.inputSchema}," + + s" complete output schema = ${transform.outputSchema}") val gProto = sc.broadcast(TensorFlowOps.graphSerial(graph)) val transformRdd = dataframe.rdd.mapPartitions { it => @@ -376,16 +61,16 @@ class DebugRowOps if (it.hasNext) { DebugRowOpsImpl.performMap( it.toArray, - inputSchema, - requestedTFInput, + transform.inputSchema, + transform.mapTFCols, gProto.value, - outputTFSchema, + transform.outputTFSchema, appendInput) } else { mutable.Iterable.empty.iterator } } - dataframe.sqlContext.createDataFrame(transformRdd, outputSchema) + dataframe.sqlContext.createDataFrame(transformRdd, transform.outputSchema) } @@ -532,7 +217,7 @@ class DebugRowOps } override def aggregate( - data: GroupedData, + data: RelationalGroupedDataset, graph: GraphDef, shapeHints: ShapeDescription): DataFrame = { // The constraints on the graph are the same as blocked data. @@ -691,7 +376,7 @@ object DebugRowOpsImpl extends Logging { * @param groupedData the grouped data * @return the dataframe, if it succeeded. 
*/ - def backingDF(groupedData: GroupedData): Try[DataFrame] = { + def backingDF(groupedData: RelationalGroupedDataset): Try[DataFrame] = { Try { groupedData.getClass.getDeclaredMethods.foreach { m => logDebug(s"method: ${m.getName}") diff --git a/src/main/scala/org/tensorframes/impl/PythonInterface.scala b/src/main/scala/org/tensorframes/impl/PythonInterface.scala index de31551..44ec240 100644 --- a/src/main/scala/org/tensorframes/impl/PythonInterface.scala +++ b/src/main/scala/org/tensorframes/impl/PythonInterface.scala @@ -5,7 +5,7 @@ import java.util import scala.collection.JavaConverters._ import org.apache.log4j.PropertyConfigurator -import org.apache.spark.sql.{GroupedData, DataFrame, Row} +import org.apache.spark.sql.{RelationalGroupedDataset, DataFrame, Row} import org.apache.spark.sql.types.StructType import org.tensorflow.framework.GraphDef import org.tensorframes._ @@ -60,7 +60,7 @@ private[tensorframes] trait PythonInterface { self: OperationsInterface with Exp new PythonOpBuilder(this, ReduceRow, dataFrame) } - def aggregate_blocks(groupedData: GroupedData): PythonOpBuilder = { + def aggregate_blocks(groupedData: RelationalGroupedDataset): PythonOpBuilder = { new PythonOpBuilder(this, AggregateBlock, null, groupedData) } @@ -84,7 +84,7 @@ class PythonOpBuilder( interface: OperationsInterface with ExperimentalOperations, op: PythonInterface.Operation, df: DataFrame = null, - groupedData: GroupedData = null) { + groupedData: RelationalGroupedDataset = null) { import PythonInterface._ private var _shapeHints: ShapeDescription = ShapeDescription.empty private var _graph: GraphDef = null diff --git a/src/main/scala/org/tensorframes/impl/SchemaTransforms.scala b/src/main/scala/org/tensorframes/impl/SchemaTransforms.scala new file mode 100644 index 0000000..7b7fda7 --- /dev/null +++ b/src/main/scala/org/tensorframes/impl/SchemaTransforms.scala @@ -0,0 +1,347 @@ +package org.tensorframes.impl + +import org.apache.spark.{LoggingWrapper => Logging} +import 
org.apache.spark.sql.types.{StructField, StructType} +import org.tensorflow.framework.GraphDef +import org.tensorframes.{Shape, ColumnInformation, ShapeDescription} + + +/** + * The different schemas required for the block reduction. + * + * Here is the order (block reduction in a map phase followed by pair-wise aggregation in a + * reduce phase): + * mapInput -> output -> reduceInput -> output + * + * Call 'x' a variable required by the transform, and 'y' an extra column + * + * @param mapInput the schema of the column block. Contains 'x_input' and 'y' + * @param mapTFCols the indexes of the columns required by the transform + * @param output contains 'x' only + * @param reduceInput contains 'x_input' only ('y' column has been dropped in the map phase). + */ +case class ReduceBlockSchema( + mapInput: StructType, + mapTFCols: List[Int], + output: StructType, + reduceInput: StructType) extends Serializable + +/** + * The schemas required by the block mapping. + * + * Trimming can be inferred from the condition outputTFSchema == outputSchema + * + * @param inputSchema the schema of the input dataframe + * @param mapTFCols the list of indexes in the previous schema of the columns required by the TF + * mapping. + * @param outputTFSchema the schema of the columns created by TF + * @param outputSchema the complete schema of the final dataframe. + */ +case class MapBlocksSchema( + inputSchema: StructType, + mapTFCols: Array[Int], + outputTFSchema: StructType, + outputSchema: StructType) extends Serializable + +/** + * All the schema transformations that are done by the basic TF operations. + * + * These methods describe the schema transforms performed on a DataFrame. They include + * all the validation steps that should be performed before passing data to TensorFlow. + * + * After calling these methods, the implementation can assume the schemas are valid and complete + * enough. 
+ */ +// Implementation is separated for python accessors +// TODO: these methods are pretty complicated, add more documentation! +private[impl] trait SchemaTransforms extends Logging { + def get[A](x: Option[A], msg: String) = x.getOrElse { + throw new Exception(msg) + } + + def check(b: Boolean, msg: String): Unit = if (! b) { + throw new Exception(msg) + } + + + /** + * Validates and computes the transformation schema under reducing. + * + * A graph may access a subset of the rows. All the schemas are for blocks of data. + * + * For each output x of the graph, there must be: + * - a placeholder called x_input with one extra dimension (left unknown) + * - a corresponding column labeled 'x' with the same row structure as the output + * + * @param schema the schema of the dataframe + * @param graph the graph + * @param shapeHints the shape hints obtained for this graph + * @return a triplet containing the input block schema, the output block schema, and the + * requested inputs, which may be a subset of the input. + */ + // TODO(tjh) all the inputs and outputs are created by hashmaps, which makes their order not + // deterministic. Change that. + def reduceBlocksSchema( + schema: StructType, + graph: GraphDef, + shapeHints: ShapeDescription): ReduceBlockSchema = { + val summary = TensorFlowOps.analyzeGraph(graph, shapeHints) + .map(x => x.name -> x).toMap + val fieldsByName = schema.fields.map(f => f.name -> f).toMap + val fieldNameList = fieldsByName.keySet.toSeq.sorted.mkString(", ") + val outputNameList = summary.filter(_._2.isOutput).keySet.toSeq.sorted.mkString(", ") + val suffix = "_input" + + // Initial check: all the fields are here: + val outputs = summary.filter(_._2.isOutput) + + // Check that the outputs of the graph are a subset of the columns of the dataframe. 
+ val missingColInputs = (outputs.keySet -- fieldsByName.keySet).toSeq.sorted + check(missingColInputs.isEmpty, { + val missing = missingColInputs.mkString(", ") + s"Based on the TF graph, some inputs are missing: $missing. " + + s"Dataframe columns: $fieldNameList; Outputs: $outputNameList" }) + + // Initial check: the inputs are all there, and they are the only ones. + val inputs = summary.filter(_._2.isInput) + val expectedInputs = outputs.keySet.map(_ + suffix) + val extraInputs = (inputs.keySet -- expectedInputs).toSeq.sorted + logDebug(s"reduceRows: expectedInputs=$expectedInputs") + check(extraInputs.isEmpty, + s"Extra graph inputs have been found: ${extraInputs.mkString(", ")}. " + + s"Dataframe columns: $fieldNameList") + + val missingInputs = (expectedInputs -- inputs.keySet).toSeq.sorted + check(missingInputs.isEmpty, + s"Some inputs are missing in the graph: ${missingInputs.mkString(", ")}. " + + s"Dataframe columns: $fieldNameList") + + // Check that for each output, the field is present with the right schema. + // WARNING: keeping the order of the fields is important -> do not iterate over the outputs. + val fields = schema.filter(f => outputs.contains(f.name)).map { f => + val ci = ColumnInformation(f) + val stf = get(ci.stf, + s"Data column '${f.name}' has not been analyzed yet, cannot run TF on this dataframe") + // Check that the output is compatible (its presence has already been checked). 
+ val out = summary(f.name) + check(out.isOutput, s"Graph node '${out.name}' should be an output") + + check(stf.dataType == out.scalarType, s"Output '${f.name}' has type ${out.scalarType}" + + s" but the column type " + + s"is ${stf.dataType}") + + // Take the tail, we only compare cells + val cellShape = stf.shape.tail + check(out.shape.checkMorePreciseThan(cellShape), + s"Output '${f.name}' has shape ${out.shape}, not compatible with the shape " + + s"of field elements $cellShape") + // The input block may be too precise with respect to the lead dimension (number of rows), + // which is usually incorrect when pairwise reductions are performed. + // Always assume they are unknown for now. + val shape = cellShape.prepend(Shape.Unknown) + val inputStf = stf.copy(shape = shape) + + val inputName = f.name + suffix + val in = get(summary.get(inputName), + s"The graph needs to have a placeholder input called $inputName.") + assert(in.isPlaceholder, s"Node $inputName should be a placeholder") + assert(in.isInput, s"Node $inputName should be an input") + check(inputStf.shape.checkMorePreciseThan(in.shape), + s"The data column '${f.name}' has shape ${inputStf.shape}, not compatible with shape" + + s" ${in.shape} requested by the TF graph") + check(inputStf.dataType == in.scalarType, + s"The type of node '${in.name}' (${inputStf.dataType}) is not compatible with the data" + + s" type of the column (${in.scalarType})") + val m = ColumnInformation(f, inputStf).merged + logDebug(s">>> $m -> ${ColumnInformation(m).stf}") + m + } + val outputSchema = StructType(fields.toArray) + // The input schema is simply the block schema, with a different name for the variables. + // We still pass all the variables because the filtering is done on the indices selected. 
+ val inputSchema = StructType(schema.map { f => + if (outputs.contains(f.name)) { + widenLeadDim(f.copy(name = f.name + "_input")) + } else { f } + }) + val inputReduceSchema = StructType(schema + .filter(f => outputs.contains(f.name)) + .map(f => widenLeadDim(f.copy(name=f.name + "_input")))) + val requestedIndexes = schema.zipWithIndex + .filter { case (f, idx) => outputs.contains(f.name)} + .map(_._2) .toList + ReduceBlockSchema(inputSchema, requestedIndexes, outputSchema, inputReduceSchema) + } + + def reduceRowsSchema( + schema: StructType, + graph: GraphDef, + shapeHints: ShapeDescription): StructType = { + val summary = TensorFlowOps.analyzeGraph(graph, shapeHints) + .map(x => x.name -> x).toMap + val fieldsByName = schema.fields.map(f => f.name -> f).toMap + val fieldNameList = fieldsByName.keySet.toSeq.sorted.mkString(", ") + val outputNameList = summary.filter(_._2.isOutput).keySet.toSeq.sorted.mkString(", ") + val suffixes = Seq("_1", "_2") + + // Initial check: all the fields are here: + val outputs = summary.filter(_._2.isOutput) + // Check that there are precisely as many outputs as columns: + if ((outputs.keySet -- fieldsByName.keySet).nonEmpty) { + val extra = (outputs.keySet -- fieldsByName.keySet).toSeq.sorted.mkString(", ") + val s = s"Some extra outputs were found in the reducer: $extra. " + + s"Dataframe columns: $fieldNameList; Outputs: $outputNameList" + throw new Exception(s) + } + if ((fieldsByName.keySet -- outputs.keySet).nonEmpty) { + val extra = (fieldsByName.keySet -- outputs.keySet).toSeq.sorted.mkString(", ") + val s = s"Some outputs are missing in the reducer: $extra. 
" + + s"Dataframe columns: $fieldNameList; Outputs: $outputNameList" + throw new Exception(s) + } + + // Initial check: the inputs are all there: + val inputs = summary.filter(_._2.isInput) + val expectedInputs = suffixes.flatMap(suff => fieldsByName.keys.map(_ + suff)).toSet + logDebug(s"reduceRows: expectedInputs=$expectedInputs") + if ((inputs.keySet -- expectedInputs).nonEmpty) { + val extra = (inputs.keySet -- expectedInputs).toSeq.sorted.mkString(", ") + throw new Exception( + s"Extra graph inputs have been found: $extra. Dataframe columns: $fieldNameList") + } + if ((expectedInputs -- inputs.keySet).nonEmpty) { + val extra = (expectedInputs -- inputs.keySet).toSeq.sorted.mkString(", ") + throw new Exception( + s"Some inputs are missing in th graph: $extra. Dataframe columns: $fieldNameList") + } + + // Check that all the fields are here + for { + f <- fieldsByName.values + suffix <- suffixes + } { + val stf = ColumnInformation(f).stf.getOrElse { throw new Exception( + s"Data column '${f.name}' has not been analyzed yet, cannot run TF on this dataframe") + } + // Check that the output is compatible (its presence has already been checked. + val out = summary(f.name) + if (!out.isOutput) { + throw new Exception( + s"Graph node '${out.name}' should be an output") + } + if (stf.dataType != out.scalarType) { + val s = s"Output '${f.name}' has type ${out.scalarType} but the column type " + + s"is ${stf.dataType}" + throw new Exception(s) + } + // Take the tail, we only compare cells + val cellShape = stf.shape.tail + if (! 
out.shape.checkMorePreciseThan(cellShape)) { + throw new Exception( + s"Output '${f.name}' has shape ${out.shape}, not compatible with the shapes " + + s"of field elements ${cellShape}") + } + + // Check that the 2 inputs are compatible: + for (suffix <- Seq("_1", "_2")) { + val inputName = f.name + suffix + val in = summary.getOrElse(inputName, throw new Exception( + s"The graph needs to have a placeholder input called $inputName.")) + assert(in.isPlaceholder, s"Node $inputName should be a placeholder") + assert(in.isInput, s"Node $inputName should be an input") + if (! cellShape.checkMorePreciseThan(in.shape)) { + throw new Exception( + s"The data column '${f.name}' has shape ${stf.shape} (not compatible) with shape" + + s" ${in.shape} requested by the TF graph") + } + if (stf.dataType != in.scalarType) { + throw new Exception( + s"The type of node '${in.name}' (${stf.dataType}) is not compatible with the data" + + s" type of the column (${in.scalarType})") + } + } + } + // Same schema as output + schema + } + + def mapBlocksSchema( + schema: StructType, + graph: GraphDef, + shapeHints: ShapeDescription, + appendInput: Boolean): MapBlocksSchema = { + val summary = TensorFlowOps.analyzeGraph(graph, shapeHints) + .map(x => x.name -> x).toMap + val inputs = summary.filter(_._2.isInput) + val outputs = summary.filter(_._2.isOutput) + val fieldsByName = schema.fields.map(f => f.name -> f).toMap + val cols = schema.fieldNames.mkString(", ") + + inputs.values.foreach { in => + val f = get(fieldsByName.get(in.name), + s"Graph input ${in.name} found, but no column to match it. Dataframe columns: $cols") + + val stf = ColumnInformation(f).stf.getOrElse { + throw new Exception( + s"Data column ${f.name} has not been analyzed yet, cannot run TF on this dataframe") + } + if (! 
stf.shape.checkMorePreciseThan(in.shape)) { + throw new Exception( + s"The data column '${f.name}' has shape ${stf.shape} (not compatible) with shape" + + s" ${in.shape} requested by the TF graph") + } + // We do not support autocasting for now. + if (stf.dataType != in.scalarType) { + throw new Exception( + s"The type of node '${in.name}' (${stf.dataType}) is not compatible with the data type " + + s"of the column (${in.scalarType})") + } + // The input has to be either a constant or a placeholder + if (! in.isPlaceholder) { + throw new Exception( + s"Invalid type for input node ${in.name}. It has to be a placeholder") + } + } + + // The output schema from the data generated by TF. + val outputTFSchema: StructType = { + // The order of the output columns is decided for now by their names. + val fields = outputs.values.toSeq.sortBy(_.name).map { out => + if (fieldsByName.contains(out.name)) { + throw new Exception(s"TF graph has an output node called '${out.name}'," + + s" but this column already exists. Input columns: ${cols}") + } + logInfo(s"mapBlocks: out = $out") + ColumnInformation.structField(out.name, out.scalarType, out.shape) + } + StructType(fields.toArray) + } + // The column indices requested by TF + val requestedTFInput: Array[Int] = { + val colIdxs = schema.fieldNames.zipWithIndex.toMap + inputs.keys.toArray.map { name => colIdxs(name) } + } + + // Full output schema, including data being passed through and validated for duplicates. + // The first columns are the TF columns, followed by all the other columns. 
+ val outputSchema: StructType = if (appendInput) { + StructType(outputTFSchema ++ schema.fields) + } else { + StructType(outputTFSchema) + } + + MapBlocksSchema(schema, requestedTFInput, outputTFSchema, outputSchema) + } + + // Sets the lead column to Unknown + private def widenLeadDim(f: StructField): StructField = { + ColumnInformation(f).stf match { + case Some(ci) if ci.shape.numDims >= 1 => + val s = ci.shape.tail.prepend(Shape.Unknown) + ColumnInformation(f, ci.copy(shape = s)).merged + case _ => f // Nothing to do + } + } +} + +object SchemaTransforms extends SchemaTransforms diff --git a/src/main/scala/org/tensorframes/impl/TensorFlowOps.scala b/src/main/scala/org/tensorframes/impl/TensorFlowOps.scala index 26bc410..0458b4e 100644 --- a/src/main/scala/org/tensorframes/impl/TensorFlowOps.scala +++ b/src/main/scala/org/tensorframes/impl/TensorFlowOps.scala @@ -1,6 +1,6 @@ package org.tensorframes.impl -import org.apache.spark.Logging +import org.apache.spark.{LoggingWrapper => Logging} import org.apache.spark.sql.types.NumericType import org.bytedeco.javacpp.{BytePointer, tensorflow => jtf} import org.tensorflow.framework.GraphDef diff --git a/src/main/scala/org/tensorframes/impl/datatypes.scala b/src/main/scala/org/tensorframes/impl/datatypes.scala index 4594225..f0117d4 100644 --- a/src/main/scala/org/tensorframes/impl/datatypes.scala +++ b/src/main/scala/org/tensorframes/impl/datatypes.scala @@ -2,7 +2,7 @@ package org.tensorframes.impl import java.nio._ -import org.apache.spark.Logging +import org.apache.spark.{LoggingWrapper => Logging} import org.apache.spark.sql.Row import org.apache.spark.sql.types.{LongType, DoubleType, IntegerType, NumericType} import org.bytedeco.javacpp.{tensorflow => jtf} diff --git a/src/main/scala/org/tensorframes/test/DslOperations.scala b/src/main/scala/org/tensorframes/test/DslOperations.scala index 5b51296..7915332 100644 --- a/src/main/scala/org/tensorframes/test/DslOperations.scala +++ 
b/src/main/scala/org/tensorframes/test/DslOperations.scala @@ -1,6 +1,6 @@ package org.tensorframes.test -import org.apache.spark.sql.{GroupedData, DataFrame, Row} +import org.apache.spark.sql.{RelationalGroupedDataset, DataFrame, Row} import org.tensorflow.framework.GraphDef import org.tensorframes.{OperationsInterface, ShapeDescription} import org.tensorframes.impl.{GraphNodeSummary, TensorFlowOps} @@ -44,7 +44,7 @@ trait DslOperations extends OperationsInterface { make(ns, df, reduceBlocks) } - def aggregate(gdf: GroupedData, node1: dsl.Node, nodes: dsl.Node*): DataFrame = { + def aggregate(gdf: RelationalGroupedDataset, node1: dsl.Node, nodes: dsl.Node*): DataFrame = { val ns = node1 +: nodes val g = dsl.buildGraph(ns: _*) val info = extraInfo(ns) diff --git a/src/main/scala/org/tensorframes/test/dsl.scala b/src/main/scala/org/tensorframes/test/dsl.scala index 9bc9a91..72666a7 100644 --- a/src/main/scala/org/tensorframes/test/dsl.scala +++ b/src/main/scala/org/tensorframes/test/dsl.scala @@ -2,7 +2,7 @@ package org.tensorframes.test import java.nio.file.{Files, Paths} -import org.apache.spark.Logging +import org.apache.spark.{LoggingWrapper => Logging} import org.apache.spark.sql.types.{DoubleType, NumericType} import org.tensorflow.framework._ import org.tensorframes.Shape diff --git a/src/test/scala/org/tensorframes/BasicOperationsSuite.scala b/src/test/scala/org/tensorframes/BasicOperationsSuite.scala index 3d60d75..f49189f 100644 --- a/src/test/scala/org/tensorframes/BasicOperationsSuite.scala +++ b/src/test/scala/org/tensorframes/BasicOperationsSuite.scala @@ -6,7 +6,7 @@ import org.tensorframes.dsl._ import org.tensorframes.dsl.Implicits._ import org.tensorframes.impl.DebugRowOps -import org.apache.spark.Logging +import org.apache.spark.{LoggingWrapper => Logging} import org.apache.spark.sql.Row // Some basic operations that stress shape transforms mostly. 
@@ -207,7 +207,8 @@ class BasicOperationsSuite val x = reduce_sum(x1, Seq(0)) named "x" val df2 = df.groupBy("key").aggregate(x).select("key", "x") df2.printSchema() - assert(df2.collect() === Array(Row(1, 2.1), Row(2, 2.0))) + val rows = df2.collect().sortBy { case Row(x: Int, _) => x } + assert(rows === Array(Row(1, 2.1), Row(2, 2.0))) } testGraph("2-tensors - 1") { diff --git a/src/test/scala/org/tensorframes/CommonOperationsSuite.scala b/src/test/scala/org/tensorframes/CommonOperationsSuite.scala index 0d608ec..48eaf7c 100644 --- a/src/test/scala/org/tensorframes/CommonOperationsSuite.scala +++ b/src/test/scala/org/tensorframes/CommonOperationsSuite.scala @@ -1,6 +1,6 @@ package org.tensorframes -import org.apache.spark.Logging +import org.apache.spark.{LoggingWrapper => Logging} import org.apache.spark.sql.Row import org.scalatest.FunSuite import org.tensorframes.impl.{SupportedOperations, ScalarTypeOperation, DebugRowOps} diff --git a/src/test/scala/org/tensorframes/DSLOperationsSuite.scala b/src/test/scala/org/tensorframes/DSLOperationsSuite.scala index e14890a..68ef6c7 100644 --- a/src/test/scala/org/tensorframes/DSLOperationsSuite.scala +++ b/src/test/scala/org/tensorframes/DSLOperationsSuite.scala @@ -4,7 +4,7 @@ import org.scalatest.FunSuite import org.tensorframes.dsl._ import org.tensorframes.dsl.Implicits._ -import org.apache.spark.Logging +import org.apache.spark.{LoggingWrapper => Logging} import org.apache.spark.sql.Row class DSLOperationsSuite diff --git a/src/test/scala/org/tensorframes/DebugRowOpsSuite.scala b/src/test/scala/org/tensorframes/DebugRowOpsSuite.scala index e4ddda9..bfb40a2 100644 --- a/src/test/scala/org/tensorframes/DebugRowOpsSuite.scala +++ b/src/test/scala/org/tensorframes/DebugRowOpsSuite.scala @@ -1,6 +1,6 @@ package org.tensorframes -import org.apache.spark.Logging +import org.apache.spark.{LoggingWrapper => Logging} import org.apache.spark.sql.Row import org.apache.spark.sql.types.{DoubleType, StructType} import 
org.scalatest.FunSuite diff --git a/src/test/scala/org/tensorframes/ExtraOperationsSuite.scala b/src/test/scala/org/tensorframes/ExtraOperationsSuite.scala index e2f7cdc..3001c38 100644 --- a/src/test/scala/org/tensorframes/ExtraOperationsSuite.scala +++ b/src/test/scala/org/tensorframes/ExtraOperationsSuite.scala @@ -1,6 +1,6 @@ package org.tensorframes -import org.apache.spark.Logging +import org.apache.spark.{LoggingWrapper => Logging} import org.apache.spark.sql.types.{DoubleType, IntegerType} import org.scalatest.FunSuite diff --git a/src/test/scala/org/tensorframes/SlicingSuite.scala b/src/test/scala/org/tensorframes/SlicingSuite.scala index d525d6f..c5023c6 100644 --- a/src/test/scala/org/tensorframes/SlicingSuite.scala +++ b/src/test/scala/org/tensorframes/SlicingSuite.scala @@ -6,7 +6,7 @@ import org.tensorframes.impl.DebugRowOps import org.tensorframes.{dsl => tf} import org.tensorframes.dsl.Implicits._ -import org.apache.spark.Logging +import org.apache.spark.{LoggingWrapper => Logging} import org.apache.spark.sql.Row import org.apache.spark.sql.types.{DoubleType, IntegerType} diff --git a/src/test/scala/org/tensorframes/TFInitializationSuite.scala b/src/test/scala/org/tensorframes/TFInitializationSuite.scala index cf104c9..7597891 100644 --- a/src/test/scala/org/tensorframes/TFInitializationSuite.scala +++ b/src/test/scala/org/tensorframes/TFInitializationSuite.scala @@ -1,6 +1,6 @@ package org.tensorframes -import org.apache.spark.Logging +import org.apache.spark.{LoggingWrapper => Logging} import org.scalatest.FunSuite import org.tensorframes.impl.TensorFlowOps diff --git a/src/test/scala/org/tensorframes/TrimmingOperationsSuite.scala b/src/test/scala/org/tensorframes/TrimmingOperationsSuite.scala index bb21232..6a4cb7a 100644 --- a/src/test/scala/org/tensorframes/TrimmingOperationsSuite.scala +++ b/src/test/scala/org/tensorframes/TrimmingOperationsSuite.scala @@ -1,6 +1,6 @@ package org.tensorframes -import org.apache.spark.Logging +import 
org.apache.spark.{LoggingWrapper => Logging} import org.apache.spark.sql.Row import org.scalatest.FunSuite import org.tensorframes.dsl.GraphScoping diff --git a/src/test/scala/org/tensorframes/dsl/BasicOpsSuite.scala b/src/test/scala/org/tensorframes/dsl/BasicOpsSuite.scala index 04ea6eb..544c58e 100644 --- a/src/test/scala/org/tensorframes/dsl/BasicOpsSuite.scala +++ b/src/test/scala/org/tensorframes/dsl/BasicOpsSuite.scala @@ -1,6 +1,6 @@ package org.tensorframes.dsl -import org.apache.spark.Logging +import org.apache.spark.{LoggingWrapper => Logging} import org.scalatest.FunSuite import org.tensorframes.dsl.ExtractNodes._ import org.tensorframes.{dsl => tf} diff --git a/src/test/scala/org/tensorframes/dsl/BasicSuite.scala b/src/test/scala/org/tensorframes/dsl/BasicSuite.scala index 3dc24ea..40b9fa0 100644 --- a/src/test/scala/org/tensorframes/dsl/BasicSuite.scala +++ b/src/test/scala/org/tensorframes/dsl/BasicSuite.scala @@ -1,6 +1,6 @@ package org.tensorframes.dsl -import org.apache.spark.Logging +import org.apache.spark.{LoggingWrapper => Logging} import org.scalatest.FunSuite import org.tensorframes.{dsl => tf} import org.tensorframes.dsl.Implicits._ diff --git a/src/test/scala/org/tensorframes/dsl/ExtractNodes.scala b/src/test/scala/org/tensorframes/dsl/ExtractNodes.scala index 84a080d..17cb0be 100644 --- a/src/test/scala/org/tensorframes/dsl/ExtractNodes.scala +++ b/src/test/scala/org/tensorframes/dsl/ExtractNodes.scala @@ -3,7 +3,7 @@ package org.tensorframes.dsl import java.io.{BufferedReader, InputStreamReader, File} import java.nio.file.Files import java.nio.charset.StandardCharsets -import org.apache.spark.Logging +import org.apache.spark.{LoggingWrapper => Logging} import org.scalatest.ShouldMatchers import scala.collection.JavaConverters._ diff --git a/src/test/scala/org/tensorframes/perf/ConvertBackPerformanceSuite.scala b/src/test/scala/org/tensorframes/perf/ConvertBackPerformanceSuite.scala index fb48d0b..e8226c9 100644 --- 
a/src/test/scala/org/tensorframes/perf/ConvertBackPerformanceSuite.scala +++ b/src/test/scala/org/tensorframes/perf/ConvertBackPerformanceSuite.scala @@ -5,7 +5,7 @@ import org.scalatest.FunSuite import org.tensorframes.{ColumnInformation, Shape, TensorFramesTestSparkContext} import org.tensorframes.impl.{DataOps, SupportedOperations} -import org.apache.spark.Logging +import org.apache.spark.{LoggingWrapper => Logging} import org.apache.spark.sql.Row import org.apache.spark.sql.types._ diff --git a/src/test/scala/org/tensorframes/perf/ConvertPerformanceSuite.scala b/src/test/scala/org/tensorframes/perf/ConvertPerformanceSuite.scala index 53b4390..ad1a791 100644 --- a/src/test/scala/org/tensorframes/perf/ConvertPerformanceSuite.scala +++ b/src/test/scala/org/tensorframes/perf/ConvertPerformanceSuite.scala @@ -3,7 +3,7 @@ package org.tensorframes.perf import org.bytedeco.javacpp.{tensorflow => jtf} import org.scalatest.FunSuite -import org.apache.spark.Logging +import org.apache.spark.{LoggingWrapper => Logging} import org.apache.spark.sql.Row import org.apache.spark.sql.types._ diff --git a/src/test/scala/org/tensorframes/perf/PerformanceSuite.scala b/src/test/scala/org/tensorframes/perf/PerformanceSuite.scala index 1e1be56..3e83f01 100644 --- a/src/test/scala/org/tensorframes/perf/PerformanceSuite.scala +++ b/src/test/scala/org/tensorframes/perf/PerformanceSuite.scala @@ -1,6 +1,6 @@ package org.tensorframes.perf -import org.apache.spark.Logging +import org.apache.spark.{LoggingWrapper => Logging} import org.apache.spark.sql.functions._ import org.scalatest.FunSuite import org.tensorframes.TensorFramesTestSparkContext diff --git a/src/test/scala/org/tensorframes/type_suites.scala b/src/test/scala/org/tensorframes/type_suites.scala index 36e7372..523ad41 100644 --- a/src/test/scala/org/tensorframes/type_suites.scala +++ b/src/test/scala/org/tensorframes/type_suites.scala @@ -182,7 +182,8 @@ trait BasicMonoidTests[T] { self: CommonOperationsSuite[T] => val x1 = 
placeholder(dtype, Shape(Shape.Unknown)) named "x_input" val x = reduce_sum(x1, Seq(0)) named "x" val df2 = ops.aggregate(df.groupBy("key"), x).select("key", "x") - assert(df2.collect() === Seq(Row("a", 21.0), Row("b", 20.0)).u) + val rows = df2.collect().sortBy { case Row(s: String, _) => s } + assert(rows === Seq(Row("a", 21.0), Row("b", 20.0)).u) } }