From 5ddb043ce43e898e327adc08bfe0362c9ac63705 Mon Sep 17 00:00:00 2001 From: "sushrut.ikhar" Date: Tue, 19 Nov 2019 10:35:52 +0530 Subject: [PATCH 1/3] batch leap frame changes --- .../scala/ml/combust/bundle/dsl/Bundle.scala | 1 + .../mleap/runtime/frame/BatchLeapFrame.scala | 156 ++++++++++++++++++ .../src/main/resources/reference.conf | 3 +- .../tensorflow/BatchTensorflowModel.scala | 89 ++++++++++ .../BatchTensorflowTransformer.scala | 30 ++++ .../BatchTensorflowTransformerOp.scala | 74 +++++++++ .../converter/BatchMleapConverter.scala | 53 ++++++ .../converter/BatchTensorflowConverter.scala | 39 +++++ .../BatchTensorflowTransformerSpec.scala | 34 ++++ 9 files changed, 478 insertions(+), 1 deletion(-) create mode 100644 mleap-runtime/src/main/scala/ml/combust/mleap/runtime/frame/BatchLeapFrame.scala create mode 100644 mleap-tensorflow/src/main/scala/ml/combust/mleap/tensorflow/BatchTensorflowModel.scala create mode 100644 mleap-tensorflow/src/main/scala/ml/combust/mleap/tensorflow/BatchTensorflowTransformer.scala create mode 100644 mleap-tensorflow/src/main/scala/ml/combust/mleap/tensorflow/BatchTensorflowTransformerOp.scala create mode 100644 mleap-tensorflow/src/main/scala/ml/combust/mleap/tensorflow/converter/BatchMleapConverter.scala create mode 100644 mleap-tensorflow/src/main/scala/ml/combust/mleap/tensorflow/converter/BatchTensorflowConverter.scala create mode 100644 mleap-tensorflow/src/test/scala/ml/combust/mleap/tensorflow/BatchTensorflowTransformerSpec.scala diff --git a/bundle-ml/src/main/scala/ml/combust/bundle/dsl/Bundle.scala b/bundle-ml/src/main/scala/ml/combust/bundle/dsl/Bundle.scala index ed9fe678a..727d0a937 100644 --- a/bundle-ml/src/main/scala/ml/combust/bundle/dsl/Bundle.scala +++ b/bundle-ml/src/main/scala/ml/combust/bundle/dsl/Bundle.scala @@ -101,6 +101,7 @@ object Bundle { val pipeline = "pipeline" val tensorflow = "tensorflow" + val batch_tensorflow = "batch_tensorflow" } def apply[Transformer <: AnyRef](name: String, diff --git a/mleap-runtime/src/main/scala/ml/combust/mleap/runtime/frame/BatchLeapFrame.scala b/mleap-runtime/src/main/scala/ml/combust/mleap/runtime/frame/BatchLeapFrame.scala new file mode 100644 index 000000000..13d2bd099 --- /dev/null +++ b/mleap-runtime/src/main/scala/ml/combust/mleap/runtime/frame/BatchLeapFrame.scala @@ -0,0 +1,156 @@ +package ml.combust.mleap.runtime.frame + +import java.lang.Iterable + +import ml.combust.mleap.core.types.{BasicType, StructField, StructType} +import ml.combust.mleap.runtime.frame.Row.RowSelector +import ml.combust.mleap.runtime.function.{Selector, UserDefinedFunction} + +import scala.collection.JavaConverters._ +import scala.util.{Failure, Try} + +/** Class for storing a leap frame locally in batches of rows. + * + * @param schema schema of leap frame + */ +case class BatchLeapFrame(override val schema: StructType, + dataset: Seq[Row]) extends LeapFrame[BatchLeapFrame] { + def this(schema: StructType, rows: Iterable[Row]) = this(schema, rows.asScala.toSeq) + + /** Try to select fields to create a new LeapFrame. + * + * Returns a Failure if attempting to select any fields that don't exist. + * + * @param fieldNames field names to select + * @return try new LeapFrame with selected fields + */ + override def select(fieldNames: String*): Try[BatchLeapFrame] = { + schema.indicesOf(fieldNames: _*).flatMap { + indices => + schema.selectIndices(indices: _*).map { + schema2 => + val dataset2 = dataset.map(_.selectIndices(indices: _*)) + BatchLeapFrame(schema2, dataset2) + } + } + } + + /** Try to add a column to the LeapFrame. + * + * Returns a Failure if trying to add a field that already exists. + * + * @param name name of column + * @param selectors row selectors used to generate inputs to udf + * @param udf user defined function for calculating column value + * @return LeapFrame with new column + */ + override def withColumn(name: String, selectors: Selector*) + (udf: UserDefinedFunction): Try[BatchLeapFrame] = { + val rowUDF : UserDefinedFunction = UserDefinedFunction( + {(x:Row) => x}, + udf.output, + udf.inputs + ) + RowUtil.createRowSelectors(schema, selectors: _*)(udf).flatMap { + rowSelectors => + val field = StructField(name, udf.outputTypes.head) + schema.withField(field).map(schema2 => { + val results = (udf.f.asInstanceOf[Seq[Row] => Seq[Row]])(dataset.map(r => udfValue(rowSelectors: _*)(rowUDF)(r))) + val dataset2: Seq[Row] = dataset.zip(results).map { + case (r1, r2) => r1.toSeq :+ r2.head + }.map(x => Row(x: _*)) + BatchLeapFrame(schema2, dataset2) + }) + } + } + + /** Try to add multiple columns to the LeapFrame. + * + * Returns a Failure if trying to add a field that already exists. + * + * @param names names of columns + * @param selectors row selectors used to generate inputs to udf + * @param udf user defined function for calculating column values + * @return LeapFrame with new columns + */ + override def withColumns(names: Seq[String], selectors: Selector*) + (udf: UserDefinedFunction): Try[BatchLeapFrame] = { + val rowUDF : UserDefinedFunction = UserDefinedFunction( + {(x: Row) => x}, + udf.output, + udf.inputs + ) + RowUtil.createRowSelectors(schema, selectors: _*)(rowUDF).flatMap { + rowSelectors => + val fields = names.zip(udf.outputTypes).map { + case (name, dt) => StructField(name, dt) + } + + schema.withFields(fields).map( + schema2 => { + val results = (udf.f.asInstanceOf[Seq[Row] => Seq[Row]])(dataset.map(r => udfValue(rowSelectors: _*)(rowUDF)(r))) + val dataset2: Seq[Row] = dataset.zip(results).map { + case (r1, r2) => r1.toSeq ++ r2.toSeq + }.map(x => Row(x: _*)) + BatchLeapFrame(schema2, dataset2) + }) + } + } + + def udfValue(rowSelectors: RowSelector *)(udf : UserDefinedFunction)(row : Row): Row = { + udf.inputs.length match { + case 0 => + Row() + case 1 => + Row(rowSelectors.head (row) ) + case 2 => + Row(rowSelectors.head (row), rowSelectors (1) (row) ) + case 3 => + Row(rowSelectors.head (row), rowSelectors (1) (row), rowSelectors (2) (row) ) + case 4 => + Row(rowSelectors.head (row), rowSelectors (1) (row), rowSelectors (2) (row), rowSelectors (3) (row) ) + case 5 => + Row(rowSelectors.head (row), rowSelectors (1) (row), rowSelectors (2) (row), rowSelectors (3) (row), rowSelectors (4) (row) ) + } + } + + /** Try dropping column(s) from the LeapFrame. + * + * Returns a Failure if the column does not exist. + * + * @param names names of column to drop + * @return LeapFrame with column(s) dropped + */ + override def drop(names: String *): Try[BatchLeapFrame] = { + for(indices <- schema.indicesOf(names: _*); + schema2 <- schema.dropIndices(indices: _*)) yield { + val dataset2 = dataset.map(_.dropIndices(indices: _*)) + BatchLeapFrame(schema = schema2, dataset = dataset2) + } + } + + /** Try filtering the leap frame using the UDF + * + * @param selectors row selectors used as inputs for the filter + * @param udf filter udf, must return a Boolean + * @return LeapFrame with rows filtered + */ + override def filter(selectors: Selector *) + (udf: UserDefinedFunction): Try[BatchLeapFrame] = { + if(udf.outputTypes.length != 1 || udf.outputTypes.head.base != BasicType.Boolean) { + return Failure(new IllegalArgumentException("must provide a UDF that outputs a boolean for filtering")) + } + + RowUtil.createRowSelectors(schema, selectors: _*)(udf).map { + rowSelectors => + val dataset2 = dataset.filter(_.shouldFilter(rowSelectors: _*)(udf)) + BatchLeapFrame(schema, dataset2) + } + } + + /** Collect all rows into a Seq + * + * @return all rows in the leap frame + */ + override def collect(): Seq[Row] = dataset +} \ No newline at end of file diff --git a/mleap-tensorflow/src/main/resources/reference.conf b/mleap-tensorflow/src/main/resources/reference.conf index 70a80611b..c31e6eddc 100644 --- a/mleap-tensorflow/src/main/resources/reference.conf +++ b/mleap-tensorflow/src/main/resources/reference.conf @@ -1,5 +1,6 @@ ml.combust.mleap.tensorflow.ops = [ - "ml.combust.mleap.tensorflow.TensorflowTransformerOp" + "ml.combust.mleap.tensorflow.TensorflowTransformerOp", + "ml.combust.mleap.tensorflow.BatchTensorflowTransformerOp" ] ml.combust.mleap.registry.default.ops += "ml.combust.mleap.tensorflow.ops" diff --git a/mleap-tensorflow/src/main/scala/ml/combust/mleap/tensorflow/BatchTensorflowModel.scala b/mleap-tensorflow/src/main/scala/ml/combust/mleap/tensorflow/BatchTensorflowModel.scala new file mode 100644 index 000000000..1063323ae --- /dev/null +++ b/mleap-tensorflow/src/main/scala/ml/combust/mleap/tensorflow/BatchTensorflowModel.scala @@ -0,0 +1,89 @@ +package ml.combust.mleap.tensorflow + +import ml.combust.mleap.core.Model +import ml.combust.mleap.core.types.{StructField, StructType, TensorType} +import ml.combust.mleap.tensor.Tensor +import ml.combust.mleap.tensorflow.converter.{BatchMleapConverter, BatchTensorflowConverter} +import org.tensorflow + +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.util.Try + +case class BatchTensorflowModel(graph: tensorflow.Graph, + inputs: Seq[(String, TensorType)], + outputs: Seq[(String, TensorType)], + nodes: Option[Seq[String]] = None) extends Model with AutoCloseable { + @transient + private var session: Option[tensorflow.Session] = None + + def apply(values: Seq[Tensor[_]] *): Seq[Seq[Any]] = { + val garbage: mutable.ArrayBuilder[tensorflow.Tensor[_]] = mutable.ArrayBuilder.make[tensorflow.Tensor[_]]() + + val x = values.transpose + val result = Try { + val tensors: Seq[(String, tensorflow.Tensor[_])] = x.zip(inputs).map { + case (v: Seq[Tensor[_]], (name, dataType)) => + val tensor = BatchMleapConverter.convert(v, dataType) + garbage += tensor + (name, tensor) + } + + withSession { + session => + val runner = session.runner() + + tensors.foreach { + case (name, tensor) => runner.feed(name, tensor) + } + + outputs.foreach { + case (name, _) => runner.fetch(name) + } + + nodes.foreach { + _.foreach { + name => runner.addTarget(name) + } + } + + runner.run().asScala.zip(outputs).map { + case (tensor, (_, dataType)) => + garbage += tensor + BatchTensorflowConverter.convert(tensor, dataType) + } + } + } + + garbage.result.foreach(_.close()) + result.get + + } + + private def withSession[T](f: (tensorflow.Session) => T): T = { + val s = session.getOrElse { + session = Some(new tensorflow.Session(graph)) + session.get + } + + f(s) + } + + override def close(): Unit = { + session.foreach(_.close()) + graph.close() + } + + override def finalize(): Unit = { + close() + super.finalize() + } + + override def inputSchema: StructType = StructType(inputs.map { + case (name, dt) => StructField(name, dt) + }).get + + override def outputSchema: StructType = StructType(outputs.map { + case (name, dt) => StructField(name, dt) + }).get +} \ No newline at end of file diff --git a/mleap-tensorflow/src/main/scala/ml/combust/mleap/tensorflow/BatchTensorflowTransformer.scala b/mleap-tensorflow/src/main/scala/ml/combust/mleap/tensorflow/BatchTensorflowTransformer.scala new file mode 100644 index 000000000..f6819d52b --- /dev/null +++ b/mleap-tensorflow/src/main/scala/ml/combust/mleap/tensorflow/BatchTensorflowTransformer.scala @@ -0,0 +1,30 @@ +package ml.combust.mleap.tensorflow + +import ml.combust.mleap.core.types.{NodeShape, SchemaSpec} +import ml.combust.mleap.runtime.frame.{FrameBuilder, Row, Transformer} +import ml.combust.mleap.runtime.function.{StructSelector, UserDefinedFunction} +import ml.combust.mleap.tensor.Tensor + +import scala.util.Try + +case class BatchTensorflowTransformer(override val uid: String = Transformer.uniqueName("batchTensorflow"), + override val shape: NodeShape, + override val model: BatchTensorflowModel) extends Transformer { + private val f = (tensors: Seq[Row]) => { + model(tensors.map(_.toSeq.map(Tensor.scalar(_))):_*).transpose.map(x=>Row(x:_*)) + } + + val exec: UserDefinedFunction = UserDefinedFunction(f, + outputSchema, + Seq(SchemaSpec(inputSchema))) + + val outputCols: Seq[String] = outputSchema.fields.map(_.name) + val inputCols: Seq[String] = inputSchema.fields.map(_.name) + private val inputSelector: StructSelector = StructSelector(inputCols) + + override def transform[TB <: FrameBuilder[TB]](builder: TB): Try[TB] = { + builder.withColumns(outputCols, inputSelector)(exec) + } + + override def close(): Unit = { model.close() } +} \ No newline at end of file diff --git a/mleap-tensorflow/src/main/scala/ml/combust/mleap/tensorflow/BatchTensorflowTransformerOp.scala b/mleap-tensorflow/src/main/scala/ml/combust/mleap/tensorflow/BatchTensorflowTransformerOp.scala new file mode 100644 index 000000000..41958b47c --- /dev/null +++ b/mleap-tensorflow/src/main/scala/ml/combust/mleap/tensorflow/BatchTensorflowTransformerOp.scala @@ -0,0 +1,74 @@ +package ml.combust.mleap.tensorflow + +import java.nio.file.Files + +import ml.bundle.{BasicType, DataShape} +import ml.combust.bundle.BundleContext +import ml.combust.bundle.dsl._ +import ml.combust.bundle.op.OpModel +import ml.combust.mleap.bundle.ops.MleapOp +import ml.combust.mleap.core +import ml.combust.mleap.core.types.TensorType +import ml.combust.mleap.runtime.MleapContext +import ml.combust.mleap.runtime.types.BundleTypeConverters._ + +class BatchTensorflowTransformerOp extends MleapOp[BatchTensorflowTransformer, BatchTensorflowModel] { + override val Model: OpModel[MleapContext, BatchTensorflowModel] = new OpModel[MleapContext, BatchTensorflowModel] { + override val klazz: Class[BatchTensorflowModel] = classOf[BatchTensorflowModel] + + override def opName: String = Bundle.BuiltinOps.batchTensorflow + + override def store(model: Model, obj: BatchTensorflowModel) + (implicit context: BundleContext[MleapContext]): Model = { + Files.write(context.file("graph.pb"), obj.graph.toGraphDef) + val (inputNames, inputMleapDataTypes) = obj.inputs.unzip + val (inputBasicTypes, inputShapes) = inputMleapDataTypes.map { + dt => (dt.base: BasicType, dt.shape: DataShape) + }.unzip + + val (outputNames, outputMleapDataTypes) = obj.outputs.unzip + val (outputBasicTypes, outputShapes) = outputMleapDataTypes.map { + dt => (dt.base: BasicType, dt.shape: DataShape) + }.unzip + + model.withValue("input_names", Value.stringList(inputNames)). + withValue("input_types", Value.basicTypeList(inputBasicTypes)). + withValue("input_shapes", Value.dataShapeList(inputShapes)). + withValue("output_names", Value.stringList(outputNames)). + withValue("output_types", Value.basicTypeList(outputBasicTypes)). + withValue("output_shapes", Value.dataShapeList(outputShapes)). + withValue("nodes", obj.nodes.map(Value.stringList)) + } + + override def load(model: Model) + (implicit context: BundleContext[MleapContext]): BatchTensorflowModel = { + val graphBytes = Files.readAllBytes(context.file("graph.pb")) + + val inputNames = model.value("input_names").getStringList + val inputTypes = model.value("input_types").getBasicTypeList.map(v => v: core.types.BasicType) + val inputShapes = model.value("input_shapes").getDataShapeList.map(v => v: core.types.DataShape) + + val outputNames = model.value("output_names").getStringList + val outputTypes = model.value("output_types").getBasicTypeList.map(v => v: core.types.BasicType) + val outputShapes = model.value("output_shapes").getDataShapeList.map(v => v: core.types.DataShape) + + val nodes = model.getValue("nodes").map(_.getStringList) + + val inputs = inputNames.zip(inputTypes.zip(inputShapes).map { + case (b, s) => core.types.DataType(b, s).asInstanceOf[TensorType] + }) + val outputs = outputNames.zip(outputTypes.zip(outputShapes).map { + case (b, s) => core.types.DataType(b, s).asInstanceOf[TensorType] + }) + + val graph = new org.tensorflow.Graph() + graph.importGraphDef(graphBytes) + BatchTensorflowModel(graph, + inputs, + outputs, + nodes) + } + } + + override def model(node: BatchTensorflowTransformer): BatchTensorflowModel = node.model +} \ No newline at end of file diff --git a/mleap-tensorflow/src/main/scala/ml/combust/mleap/tensorflow/converter/BatchMleapConverter.scala b/mleap-tensorflow/src/main/scala/ml/combust/mleap/tensorflow/converter/BatchMleapConverter.scala new file mode 100644 index 000000000..b20328c25 --- /dev/null +++ b/mleap-tensorflow/src/main/scala/ml/combust/mleap/tensorflow/converter/BatchMleapConverter.scala @@ -0,0 +1,53 @@ +package ml.combust.mleap.tensorflow.converter + +import java.nio._ + +import ml.combust.mleap.core.types.{BasicType, TensorType} +import ml.combust.mleap.tensor.{ByteString, Tensor} +import org.tensorflow + +object BatchMleapConverter { + def convert(value: Seq[Tensor[_]], tt: TensorType): tensorflow.Tensor[_] = { + + val dimensions: Array[Long] = (value.size +: value.head.dimensions).map(_.toLong).toArray + + tt.base match { + case BasicType.ByteString => + val x: Array[Array[Byte]] = value + .flatMap(_.asInstanceOf[Tensor[ByteString]].mapValues(_.bytes).toDense.values) + .toArray + tensorflow.Tensor.create(x) + case BasicType.Byte => + val x: Array[Byte] = value + .flatMap(_.asInstanceOf[Tensor[Byte]].mapValues(_.toByte).toDense.values) + .toArray + tensorflow.Tensor.create(x) + case BasicType.Int => + val x: Array[Int] = value + .flatMap(_.asInstanceOf[Tensor[Int]].mapValues(_.toInt).toDense.values) + .toArray + tensorflow.Tensor.create(dimensions, + IntBuffer.wrap(x)) + case BasicType.Long => + val x: Array[Long] = value + .flatMap(_.asInstanceOf[Tensor[Long]].mapValues(_.toLong).toDense.values) + .toArray + tensorflow.Tensor.create(dimensions, + LongBuffer.wrap(x)) + case BasicType.Float => + val x: Array[Float] = value + .flatMap(_.asInstanceOf[Tensor[Float]].mapValues(_.toFloat).toDense.values) + .toArray + tensorflow.Tensor.create(dimensions, + FloatBuffer.wrap(x)) + case BasicType.Double => + val x: Array[Double] = value + .flatMap(_.asInstanceOf[Tensor[Double]].mapValues(_.toDouble).toDense.values) + .toArray + tensorflow.Tensor.create(dimensions, + DoubleBuffer.wrap(x)) + case _ => + throw new IllegalArgumentException(s"unsupported tensor type $tt") + } + } +} \ No newline at end of file diff --git a/mleap-tensorflow/src/main/scala/ml/combust/mleap/tensorflow/converter/BatchTensorflowConverter.scala b/mleap-tensorflow/src/main/scala/ml/combust/mleap/tensorflow/converter/BatchTensorflowConverter.scala new file mode 100644 index 000000000..256cae686 --- /dev/null +++ b/mleap-tensorflow/src/main/scala/ml/combust/mleap/tensorflow/converter/BatchTensorflowConverter.scala @@ -0,0 +1,39 @@ +package ml.combust.mleap.tensorflow.converter + +import java.nio._ + +import ml.combust.mleap.core.types.{BasicType, TensorType} +import ml.combust.mleap.tensor.{DenseTensor} +import org.tensorflow + +object BatchTensorflowConverter { + + def convert(tensor: tensorflow.Tensor[_], tt: TensorType): Seq[DenseTensor[_]] = { + val size = tensor.shape().product.toInt + val dimensions: Seq[Int] = tt.dimensions.get + tt.base match { + case BasicType.Byte => + val b = ByteBuffer.allocate(Math.max(1, size)) + tensor.writeTo(b) + b.array().map(x=> DenseTensor(Array(x),dimensions)) + case BasicType.Int => + val b = IntBuffer.allocate(Math.max(1, size)) + tensor.writeTo(b) + b.array().map(x=> DenseTensor(Array(x),dimensions)) + case BasicType.Long => + val b = LongBuffer.allocate(Math.max(1, size)) + tensor.writeTo(b) + b.array().map(x=> DenseTensor(Array(x),dimensions)) + case BasicType.Float => + val b = FloatBuffer.allocate(Math.max(1, size)) + tensor.writeTo(b) + b.array().map(x=> DenseTensor(Array(x),dimensions)) + case BasicType.Double => + val b = DoubleBuffer.allocate(Math.max(1, size)) + tensor.writeTo(b) + b.array().map(x=> DenseTensor(Array(x),dimensions)) + case _ => + throw new RuntimeException(s"unsupported tensorflow type: $tt") + } + } +} \ No newline at end of file diff --git a/mleap-tensorflow/src/test/scala/ml/combust/mleap/tensorflow/BatchTensorflowTransformerSpec.scala b/mleap-tensorflow/src/test/scala/ml/combust/mleap/tensorflow/BatchTensorflowTransformerSpec.scala new file mode 100644 index 000000000..93bba2be2 --- /dev/null +++ b/mleap-tensorflow/src/test/scala/ml/combust/mleap/tensorflow/BatchTensorflowTransformerSpec.scala @@ -0,0 +1,34 @@ +package ml.combust.mleap.tensorflow + +import ml.combust.mleap.core.types.{NodeShape, StructField, StructType, TensorType} +import ml.combust.mleap.runtime.frame.{BatchLeapFrame, Row} +import ml.combust.mleap.tensor.Tensor +import org.scalatest.FunSpec + +class BatchTensorflowTransformerSpec extends FunSpec { + describe("with a scaling tensorflow model") { + it("scales the vector using the model and returns the result") { + val model = BatchTensorflowModel(TestUtil.createAddGraph(), + inputs = Seq(("InputA", TensorType.Float()), ("InputB", TensorType.Float())), + outputs = Seq(("MyResult", TensorType.Float()))) + val shape = NodeShape().withInput("InputA", "input_a"). + withInput("InputB", "input_b"). + withOutput("MyResult", "my_result") + val transformer = BatchTensorflowTransformer(uid = "tensorflow_ab", + shape = shape, + model = model) + val schema = StructType(StructField("input_a", TensorType.Float()), StructField("input_b", TensorType.Float())).get + val dataset = Seq(Row(5.6f, 7.9f), + Row(3.4f, 6.7f), + Row(1.2f, 9.7f)) + val frame = BatchLeapFrame(schema, dataset) + + val data = transformer.transform(frame).get.dataset + assert(data(0)(2) == Tensor.scalar(5.6f + 7.9f)) + assert(data(1)(2) == Tensor.scalar(3.4f + 6.7f)) + assert(data(2)(2) == Tensor.scalar(1.2f + 9.7f)) + + transformer.close() + } + } +} \ No newline at end of file From 3adf0297eed5352e10eee7cf5cab327e1b6ed286 Mon Sep 17 00:00:00 2001 From: "sushrut.ikhar" Date: Tue, 19 Nov 2019 10:42:43 +0530 Subject: [PATCH 2/3] minor fixes --- .../combust/mleap/tensorflow/BatchTensorflowTransformerOp.scala | 2 +- .../mleap/tensorflow/BatchTensorflowTransformerSpec.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mleap-tensorflow/src/main/scala/ml/combust/mleap/tensorflow/BatchTensorflowTransformerOp.scala b/mleap-tensorflow/src/main/scala/ml/combust/mleap/tensorflow/BatchTensorflowTransformerOp.scala index 41958b47c..2b9e1f8ab 100644 --- a/mleap-tensorflow/src/main/scala/ml/combust/mleap/tensorflow/BatchTensorflowTransformerOp.scala +++ b/mleap-tensorflow/src/main/scala/ml/combust/mleap/tensorflow/BatchTensorflowTransformerOp.scala @@ -16,7 +16,7 @@ class BatchTensorflowTransformerOp extends MleapOp[BatchTensorflowTransformer, B override val Model: OpModel[MleapContext, BatchTensorflowModel] = new OpModel[MleapContext, BatchTensorflowModel] { override val klazz: Class[BatchTensorflowModel] = classOf[BatchTensorflowModel] - override def opName: String = Bundle.BuiltinOps.batchTensorflow + override def opName: String = Bundle.BuiltinOps.batch_tensorflow override def store(model: Model, obj: BatchTensorflowModel) (implicit context: BundleContext[MleapContext]): Model = { diff --git a/mleap-tensorflow/src/test/scala/ml/combust/mleap/tensorflow/BatchTensorflowTransformerSpec.scala b/mleap-tensorflow/src/test/scala/ml/combust/mleap/tensorflow/BatchTensorflowTransformerSpec.scala index 93bba2be2..ac7f71ddc 100644 --- a/mleap-tensorflow/src/test/scala/ml/combust/mleap/tensorflow/BatchTensorflowTransformerSpec.scala +++ b/mleap-tensorflow/src/test/scala/ml/combust/mleap/tensorflow/BatchTensorflowTransformerSpec.scala @@ -31,4 +31,4 @@ class BatchTensorflowTransformerSpec extends FunSpec { transformer.close() } } -} \ No newline at end of file +} From fc17a3455c00d16d5b35a9b298f8daa0a0d021da Mon Sep 17 00:00:00 2001 From: "sushrut.ikhar" Date: Tue, 19 Nov 2019 18:53:34 +0530 Subject: [PATCH 3/3] fixing test case --- .../tensorflow/BatchTensorflowModel.scala | 3 +-- .../BatchTensorflowTransformer.scala | 24 +++++++++---------- .../converter/BatchMleapConverter.scala | 12 +++++----- 3 files changed, 19 insertions(+), 20 deletions(-) diff --git a/mleap-tensorflow/src/main/scala/ml/combust/mleap/tensorflow/BatchTensorflowModel.scala b/mleap-tensorflow/src/main/scala/ml/combust/mleap/tensorflow/BatchTensorflowModel.scala index 1063323ae..92a299830 100644 --- a/mleap-tensorflow/src/main/scala/ml/combust/mleap/tensorflow/BatchTensorflowModel.scala +++ b/mleap-tensorflow/src/main/scala/ml/combust/mleap/tensorflow/BatchTensorflowModel.scala @@ -56,8 +56,7 @@ case class BatchTensorflowModel(graph: tensorflow.Graph, } garbage.result.foreach(_.close()) - result.get - + result.get.transpose } private def withSession[T](f: (tensorflow.Session) => T): T = { diff --git a/mleap-tensorflow/src/main/scala/ml/combust/mleap/tensorflow/BatchTensorflowTransformer.scala b/mleap-tensorflow/src/main/scala/ml/combust/mleap/tensorflow/BatchTensorflowTransformer.scala index f6819d52b..8e57e1f9e 100644 --- a/mleap-tensorflow/src/main/scala/ml/combust/mleap/tensorflow/BatchTensorflowTransformer.scala +++ b/mleap-tensorflow/src/main/scala/ml/combust/mleap/tensorflow/BatchTensorflowTransformer.scala @@ -1,29 +1,29 @@ package ml.combust.mleap.tensorflow -import ml.combust.mleap.core.types.{NodeShape, SchemaSpec} -import ml.combust.mleap.runtime.frame.{FrameBuilder, Row, Transformer} -import ml.combust.mleap.runtime.function.{StructSelector, UserDefinedFunction} +import ml.combust.mleap.core.types.NodeShape +import ml.combust.mleap.runtime.frame.{FrameBuilder, Row, SimpleTransformer, Transformer} +import ml.combust.mleap.runtime.function.{FieldSelector, Selector, UserDefinedFunction} import ml.combust.mleap.tensor.Tensor import scala.util.Try -case class BatchTensorflowTransformer(override val uid: String = Transformer.uniqueName("batchTensorflow"), +case class BatchTensorflowTransformer(override val uid: String = Transformer.uniqueName("batch_tensorflow"), override val shape: NodeShape, - override val model: BatchTensorflowModel) extends Transformer { - private val f = (tensors: Seq[Row]) => { - model(tensors.map(_.toSeq.map(Tensor.scalar(_))):_*).transpose.map(x=>Row(x:_*)) + override val model: BatchTensorflowModel) + extends SimpleTransformer { + private val f = (rows: Seq[Row]) => { + model(rows.map(x => x.toSeq.map(Tensor.scalar(_))): _*).map(Row(_: _*)) } - val exec: UserDefinedFunction = UserDefinedFunction(f, - outputSchema, - Seq(SchemaSpec(inputSchema))) + override val exec: UserDefinedFunction = + UserDefinedFunction(f, outputSchema, inputSchema) val outputCols: Seq[String] = outputSchema.fields.map(_.name) val inputCols: Seq[String] = inputSchema.fields.map(_.name) - private val inputSelector: StructSelector = StructSelector(inputCols) + private val inputSelector: Seq[Selector] = inputCols.map(FieldSelector) override def transform[TB <: FrameBuilder[TB]](builder: TB): Try[TB] = { - builder.withColumns(outputCols, inputSelector)(exec) + builder.withColumns(outputCols, inputSelector: _*)(exec) } override def close(): Unit = { model.close() } diff --git a/mleap-tensorflow/src/main/scala/ml/combust/mleap/tensorflow/converter/BatchMleapConverter.scala b/mleap-tensorflow/src/main/scala/ml/combust/mleap/tensorflow/converter/BatchMleapConverter.scala index b20328c25..08d175503 100644 --- a/mleap-tensorflow/src/main/scala/ml/combust/mleap/tensorflow/converter/BatchMleapConverter.scala +++ b/mleap-tensorflow/src/main/scala/ml/combust/mleap/tensorflow/converter/BatchMleapConverter.scala @@ -14,35 +14,35 @@ object BatchMleapConverter { tt.base match { case BasicType.ByteString => val x: Array[Array[Byte]] = value - .flatMap(_.asInstanceOf[Tensor[ByteString]].mapValues(_.bytes).toDense.values) + .flatMap(_.mapValues(_.asInstanceOf[ByteString].bytes).toDense.values) .toArray tensorflow.Tensor.create(x) case BasicType.Byte => val x: Array[Byte] = value - .flatMap(_.asInstanceOf[Tensor[Byte]].mapValues(_.toByte).toDense.values) + .flatMap(_.toDense.values.map(_.asInstanceOf[Byte])) .toArray tensorflow.Tensor.create(x) case BasicType.Int => val x: Array[Int] = value - .flatMap(_.asInstanceOf[Tensor[Int]].mapValues(_.toInt).toDense.values) + .flatMap(_.toDense.values.map(_.asInstanceOf[Int])) .toArray tensorflow.Tensor.create(dimensions, IntBuffer.wrap(x)) case BasicType.Long => val x: Array[Long] = value - .flatMap(_.asInstanceOf[Tensor[Long]].mapValues(_.toLong).toDense.values) + .flatMap(_.toDense.values.map(_.asInstanceOf[Long])) .toArray tensorflow.Tensor.create(dimensions, LongBuffer.wrap(x)) case BasicType.Float => val x: Array[Float] = value - .flatMap(_.asInstanceOf[Tensor[Float]].mapValues(_.toFloat).toDense.values) + .flatMap(_.toDense.values.map(_.asInstanceOf[Float])) .toArray tensorflow.Tensor.create(dimensions, FloatBuffer.wrap(x)) case BasicType.Double => val x: Array[Double] = value - .flatMap(_.asInstanceOf[Tensor[Double]].mapValues(_.toDouble).toDense.values) + .flatMap(_.toDense.values.map(_.asInstanceOf[Double])) .toArray tensorflow.Tensor.create(dimensions, DoubleBuffer.wrap(x))