Skip to content

Commit

Permalink
Feature/configurable thrift (#815)
Browse files Browse the repository at this point in the history
* Tweaking behaviour of optional columns.

* Trigger autocreated query on DB create.

* Adding collection types.

* Simplifying and removing unecessary code.

* Making Thrift serializers compatible and intercheangable.

* Using compact DSL

* Deleting all the codez..

* Removing unused import

* Adding support for all protocols

* Adding models.

* Adding protocol version checks.

* Adding tests for all thrift protocosl

* Separating suites properly.

* Adding e2e for binary protocol.

* Fixing import

* Removing more imports

* Adding twitter future concept back

* Hardcoding protocol to avoid responsibility of doing sensible I/O right now.

* Adding more comprehensive primitive tests.

* Fixing protocol version before bugs are fixed.

* Getting rid of duplicate tests.

* Small correctiong

* Fixing Insert JSON serialisation.

* Correcting tests
  • Loading branch information
alexflav23 authored Mar 21, 2018
1 parent 1dd89b9 commit 0c03b1f
Show file tree
Hide file tree
Showing 50 changed files with 2,239 additions and 394 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ trait TableAliases[T <: CassandraTable[T, R], R] { self: CassandraTable[T, R] =>
) extends Col[Option[RR]] {
override def parse(r: Row): Try[Option[RR]] = ev.fromRow(name, r) match {
case Success(value) => Success(Some(value))
case Failure(_) => Success(None)
case Failure(_) if r.isNull(name) => Success(None)
case Failure(ex) => Failure[Option[RR]](ex)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,8 @@ class PrimitiveMacro(override val c: blackbox.Context) extends BlackboxToolbelt
}
}



def optionPrimitive(tpe: Type): Tree = {
tpe.typeArgs match {
case head :: Nil => q"$prefix.Primitives.option[$head]"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ object Primitives {
1,
"Invalid 8-bits integer value, expecting 1 byte but got " + source.remaining()
) {
case b => source.get(source.position())
case _ => source.get(source.position())
}
}
}
Expand Down Expand Up @@ -359,7 +359,7 @@ object Primitives {
override def serialize(obj: BigDecimal, version: ProtocolVersion): ByteBuffer = {
obj match {
case Primitive.nullValue => Primitive.nullValue
case decimal =>
case _ =>
val bi = obj.bigDecimal.unscaledValue
val bibytes = bi.toByteArray

Expand Down Expand Up @@ -392,7 +392,7 @@ object Primitives {
}

object InetAddressPrimitive extends Primitive[InetAddress] {
override val dataType = CQLSyntax.Types.Inet
override val dataType: String = CQLSyntax.Types.Inet

override def asCql(value: InetAddress): String = CQLQuery.empty.singleQuote(value.getHostAddress)

Expand All @@ -409,14 +409,14 @@ object Primitives {
InetAddress.getByAddress(Bytes.getArray(bytes))
catch {
case e: UnknownHostException =>
throw new InvalidTypeException("Invalid bytes for inet value, got " + bytes.remaining + " bytes")
throw new InvalidTypeException("Invalid bytes for inet value, got " + bytes.remaining + " bytes", e)
}
}
}
}

object BigIntPrimitive extends Primitive[BigInt] {
override val dataType = CQLSyntax.Types.Varint
override val dataType: String = CQLSyntax.Types.Varint

override def asCql(value: BigInt): String = value.toString()

Expand All @@ -428,13 +428,13 @@ object Primitives {
bytes match {
case Primitive.nullValue => Primitive.nullValue
case b if b.remaining() == 0 => Primitive.nullValue
case bt => new BigInteger(Bytes.getArray(bytes))
case _ => new BigInteger(Bytes.getArray(bytes))
}
}
}

object BlobIsPrimitive extends Primitive[ByteBuffer] {
override val dataType = CQLSyntax.Types.Blob
override val dataType: String = CQLSyntax.Types.Blob

override def asCql(value: ByteBuffer): String = Bytes.toHexString(value)

Expand All @@ -444,9 +444,9 @@ object Primitives {
}

object LocalDateIsPrimitive extends Primitive[LocalDate] {
override val dataType = CQLSyntax.Types.Timestamp
override val dataType: String = CQLSyntax.Types.Timestamp

val codec = IntPrimitive
val codec: IntPrimitive.type = IntPrimitive

override def asCql(value: LocalDate): String = ParseUtils.quote(value.toString)

Expand All @@ -461,7 +461,7 @@ object Primitives {
bytes match {
case Primitive.nullValue => Primitive.nullValue
case b if b.remaining() == 0 => Primitive.nullValue
case b @ _ =>
case _ =>
val unsigned = codec.deserialize(bytes, version)
val signed = CodecUtils.fromUnsignedToSignedInt(unsigned)
LocalDate.fromDaysSinceEpoch(signed)
Expand All @@ -474,7 +474,7 @@ object Primitives {
new DateTime(_, DateTimeZone.UTC)
)(LongPrimitive)(CQLSyntax.Types.Timestamp)

val SqlTimestampIsPrimitive = Primitive.manuallyDerive[Timestamp, Long](
val SqlTimestampIsPrimitive: Primitive[Timestamp] = Primitive.manuallyDerive[Timestamp, Long](
ts => ts.getTime,
dt => Timestamp.from(Instant.ofEpochMilli(dt))
)(LongPrimitive)(CQLSyntax.Types.Timestamp)
Expand All @@ -499,7 +499,7 @@ object Primitives {

override def asCql(value: M[RR]): String = converter(value)

override val dataType = cType
override val dataType: String = cType

override def serialize(coll: M[RR], version: ProtocolVersion): ByteBuffer = {
coll match {
Expand Down Expand Up @@ -553,6 +553,7 @@ object Primitives {
)
}


def option[T : Primitive]: Primitive[Option[T]] = {
val ev = implicitly[Primitive[T]]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ case class InsertQuery[
final def json(value: PrepareMark): InsertJsonQuery[Table, Record, Status, String :: PS] = {
new InsertJsonQuery[Table, Record, Status, String :: PS](
table = table,
init = QueryBuilder.Insert.json(init, value.qb.queryString),
init = QueryBuilder.Insert.json(init, value),
usingPart = usingPart,
lightweightPart = lightweightPart,
options = options
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,6 @@ private[builder] abstract class CollectionModifiers(queryBuilder: QueryBuilder)
.forcePad.append(CQLSyntax.Symbols.`}`)
}

def mapType(keyType: String, valueType: String): CQLQuery = {
diamond(CQLSyntax.Collections.map, CQLQuery(List(keyType, valueType)).queryString)
}

def mapType[K, V](key: Primitive[K], value: Primitive[V]): CQLQuery = {
diamond(CQLSyntax.Collections.map, CQLQuery(
List(frozen(key).queryString, frozen(value).queryString)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package com.outworkers.phantom.builder.serializers

import com.outworkers.phantom.builder.query.engine.CQLQuery
import com.outworkers.phantom.builder.query.prepared.PrepareMark
import com.outworkers.phantom.builder.syntax.CQLSyntax

private[phantom] trait InsertQueryBuilder {
Expand All @@ -39,6 +40,16 @@ private[phantom] trait InsertQueryBuilder {
init.pad.append("JSON").pad.append(CQLQuery.escape(jsonString))
}

/**
* Creates a CQL 2.2 JSON insert clause for a prepared mark.
* @param init The initialization query of the Insert clause, generally comprising the "INSERT INTO tableName" part.
* @param mark The prepared question mark.
* @return A CQL query with the JSON prefix appended to the insert.
*/
def json(init: CQLQuery, mark: PrepareMark): CQLQuery = {
init.pad.append("JSON").pad.append(mark.qb)
}

def columns(seq: Seq[CQLQuery]): CQLQuery = {
CQLQuery.empty.wrapn(seq.map(_.queryString))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class CollectionColumn[

override def asCql(v: M[RR]): String = ev.asCql(v)

override val cassandraType = QueryBuilder.Collections.collectionType(
override val cassandraType: String = QueryBuilder.Collections.collectionType(
collection,
vp.dataType,
shouldFreeze,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ private[phantom] abstract class AbstractMapColumn[
Record,
K,
V
](table: CassandraTable[Owner, Record]) extends Column[Owner, Record, Map[K, V]](table)
with CollectionValueDefinition[V] {
](
table: CassandraTable[Owner, Record]
) extends Column[Owner, Record, Map[K, V]](table) with CollectionValueDefinition[V] {

def keyAsCql(v: K): String

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ abstract class DbOps[
def createAsync()(
implicit ex: ExecutionContextExecutor
): F[Seq[Seq[ResultSet]]] = {
ExecutionHelper.sequencedTraverse(tables.map(_.create.ifNotExists().delegate)) { query =>
ExecutionHelper.sequencedTraverse(tables.map(_.autocreate(db.space).delegate)) { query =>
QueryContext.create(query)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ import java.util.concurrent.TimeUnit

import com.datastax.driver.core.VersionNumber
import com.outworkers.phantom.database.DatabaseProvider
import com.outworkers.phantom.dsl.{DateTime, UUID}
import com.outworkers.phantom.tables.TestDatabase
import com.outworkers.util.samplers._
import io.circe.{Encoder, Json}
import org.joda.time.{DateTime, DateTimeZone, LocalDate}
import org.json4s.Formats
import org.scalatest._
Expand Down Expand Up @@ -79,6 +81,11 @@ trait TestDatabaseProvider extends DatabaseProvider[TestDatabase] {
}

trait PhantomSuite extends FlatSpec with PhantomBaseSuite with TestDatabaseProvider {


implicit val datetimeEncoder: Encoder[DateTime] = Encoder.instance(dt => Json.fromLong(dt.getMillis))
implicit val uuidEncoder: Encoder[UUID] = Encoder.instance(uuid => Json.fromString(uuid.toString))

def requireVersion[T](v: VersionNumber)(fn: => T): Unit = if (cassandraVersion.value.compareTo(v) >= 0) {
val _ = fn
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ class PrimitiveRoundtripTests

implicit val bigDecimalArb: Arbitrary[BigDecimal] = Sample.arbitrary[BigDecimal]

private[this] val protocol = ProtocolVersion.V5
val protocolGen: Gen[ProtocolVersion] = Gen.alphaStr.map(_ => ProtocolVersion.V5)

def roundtrip[T : Primitive](gen: Gen[T]): Assertion = {
val ev = Primitive[T]
forAll(gen) { sample =>
forAll(gen, protocolGen) { (sample, protocol) =>
ev.deserialize(ev.serialize(sample, protocol), protocol) shouldEqual sample
}
}
Expand All @@ -65,10 +65,7 @@ class PrimitiveRoundtripTests
}

def roundtrip[T : Primitive : Arbitrary]: Assertion = {
val ev = Primitive[T]
forAll { sample: T =>
ev.deserialize(ev.serialize(sample, protocol), protocol) shouldEqual sample
}
roundtrip(implicitly[Arbitrary[T]].arbitrary)
}

it should "serialize and deserialize a String primitive" in {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,24 @@ class MapOperationsTest extends PhantomSuite {
database.scalaPrimitivesTable.createSchema()
}

it should "support a single item map set operation" in {
val recipe = gen[Recipe]
val (key , value) = gen[(String, String)]

val operation = for {
insertDone <- database.recipes.store(recipe).future()
update <- database.recipes.update.where(_.url eqs recipe.url).modify(_.props set (key, value)).future()

select <- database.recipes.select(_.props).where(_.url eqs recipe.url).one
} yield select

whenReady(operation) { items =>
items.value.get(key) shouldBe defined
items.value.get(key).value shouldEqual value
}
}


it should "support a single item map put operation" in {
val recipe = gen[Recipe]
val item = gen[(String, String)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ import com.outworkers.phantom.builder.primitives.{DerivedField, DerivedTupleFiel
import com.outworkers.phantom.dsl.{?, _}
import com.outworkers.phantom.tables.{DerivedRecord, PrimitiveCassandra22, PrimitiveRecord, Recipe}
import com.outworkers.util.samplers._
import io.circe.generic.auto._
import io.circe.syntax._
import io.circe.{Encoder, Json}

class PreparedInsertQueryTest extends PhantomSuite {

Expand Down Expand Up @@ -81,6 +84,38 @@ class PreparedInsertQueryTest extends PhantomSuite {
}
}


it should "execute a non-asynchronous prepared JSON insert query" in {
val sample = gen[Recipe]

val query = database.recipes.insert.json(?).prepare()

val chain = for {
_ <- query.bind(sample.asJson.noSpaces).future()
res <- database.recipes.select.where(_.url eqs sample.url).one()
} yield res

whenReady(chain) { res =>
res shouldBe defined
res.value shouldEqual sample
}
}

it should "execute an asynchronous prepared JSON insert query" in {
val sample = gen[Recipe]

val chain = for {
query <- database.recipes.insert.json(?).prepareAsync()
_ <- query.bind(sample.asJson.noSpaces).future()
res <- database.recipes.select.where(_.url eqs sample.url).one()
} yield res

whenReady(chain) { res =>
res shouldBe defined
res.value shouldEqual sample
}
}

it should "serialize a primitives insert query" in {
val sample = gen[PrimitiveRecord]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,10 @@ class SASIIntegrationTest extends PhantomSuite {
val samples = genList[MultiSASIRecord]().map(item => item.copy(phoneNumber = pre + item.phoneNumber))

if (cassandraVersion.value >= Version.`3.4.0`) {
val ps = db.multiSasiTable.select.where(_.phoneNumber like prefix(?)).prepareAsync()
val chain = for {
ps <- db.multiSasiTable.select.where(_.phoneNumber like prefix(?)).prepareAsync()
_ <- db.multiSasiTable.storeRecords(samples)
query <- ps.flatMap(_.bind(PrefixValue(pre)).fetch())
query <- ps.bind(PrefixValue(pre)).fetch()
} yield query

whenReady(chain) { results =>
Expand Down Expand Up @@ -95,11 +95,10 @@ class SASIIntegrationTest extends PhantomSuite {
val samples = genList[MultiSASIRecord]().map(item => item.copy(name = item.name + suf))

if (cassandraVersion.value >= Version.`3.4.0`) {

val ps = db.multiSasiTable.select.where(_.name like suffix(?)).prepareAsync()
val chain = for {
ps <- db.multiSasiTable.select.where(_.name like suffix(?)).prepareAsync()
_ <- db.multiSasiTable.storeRecords(samples)
query <- ps.flatMap(_.bind(SuffixValue(suf)).fetch())
query <- ps.bind(SuffixValue(suf)).fetch()
} yield query

whenReady(chain) { results =>
Expand Down Expand Up @@ -180,13 +179,10 @@ class SASIIntegrationTest extends PhantomSuite {
val samples = genList[MultiSASIRecord]().map(item => item.copy(customers = pre))

if (cassandraVersion.value >= Version.`3.4.0`) {

val query = db.multiSasiTable.select.where(_.customers >= ?).prepareAsync()

val chain = for {
bindable <- db.multiSasiTable.select.where(_.customers >= ?).prepareAsync()
_ <- db.multiSasiTable.truncate().future()
_ <- db.multiSasiTable.storeRecords(samples)
bindable <- query
query <- bindable.bind(pre).fetch()
} yield query

Expand Down Expand Up @@ -221,12 +217,12 @@ class SASIIntegrationTest extends PhantomSuite {

if (cassandraVersion.value >= Version.`3.4.0`) {

val ps = db.multiSasiTable.select.where(_.customers <= ?).prepareAsync()

val chain = for {
ps <- db.multiSasiTable.select.where(_.customers <= ?).prepareAsync()
_ <- db.multiSasiTable.truncate().future()
_ <- db.multiSasiTable.storeRecords(samples)
query <- ps.flatMap(_.bind(pre).fetch())
query <- ps.bind(pre).fetch()
} yield query

whenReady(chain) { results =>
Expand Down Expand Up @@ -295,12 +291,12 @@ class SASIIntegrationTest extends PhantomSuite {
val samples = genList[MultiSASIRecord]().map(item => item.copy(customers = pre))

if (cassandraVersion.value >= Version.`3.4.0`) {
val ps = db.multiSasiTable.select.where(_.customers > ?).prepareAsync()

val chain = for {
ps <- db.multiSasiTable.select.where(_.customers > ?).prepareAsync()
_ <- db.multiSasiTable.truncate().future()
_ <- db.multiSasiTable.storeRecords(samples)
query <- ps.flatMap(_.bind(pre).fetch())
query <- ps.bind(pre).fetch()
} yield query

whenReady(chain) { results =>
Expand Down
Loading

0 comments on commit 0c03b1f

Please sign in to comment.