Skip to content

Commit

Permalink
Merge pull request #1567 from oyvindberg/fix-effectful-copy-in
Browse files Browse the repository at this point in the history
Fix effectful `copyIn` (Fixes #1512)
  • Loading branch information
jatcwang authored Dec 10, 2021
2 parents aa68012 + 9414dc5 commit 650c8b9
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package doobie.postgres.syntax

import cats.Foldable
import cats.effect.Ref
import cats.effect.kernel.Resource
import cats.syntax.all._
import doobie._
Expand Down Expand Up @@ -47,20 +48,23 @@ class FragmentOps(f: Fragment) {
val byteStream: Stream[ConnectionIO, Byte] =
stream.chunkMin(minChunkSize).map(foldToString(_)).through(encode)

Stream.bracketCase(
PHC.pgGetCopyAPI(PFCM.copyIn(f.query.sql))
){
case (copyIn, Resource.ExitCase.Succeeded) =>
PHC.embed(copyIn, PFCI.isActive.ifM(PFCI.endCopy.void, PFCI.unit))
case (copyIn, _) =>
PHC.embed(copyIn, PFCI.cancelCopy)
}.flatMap(copyIn =>
byteStream.chunks.evalMap(bytes =>
PHC.embed(copyIn, PFCI.writeToCopy(bytes.toArray, 0, bytes.size))
) *>
Stream.eval(PHC.embed(copyIn, PFCI.endCopy))
).compile.foldMonoid
// use a reference to capture the number of affected rows, as determined by `endCopy`.
// we need to run that in the finalizer of the `bracket`, and the result from that is ignored.
Ref.of[ConnectionIO, Long](-1L).flatMap { numRowsRef =>
val copyAll: ConnectionIO[Unit] =
Stream.bracketCase(PHC.pgGetCopyAPI(PFCM.copyIn(f.query.sql))){
case (copyIn, Resource.ExitCase.Succeeded) =>
PHC.embed(copyIn, PFCI.endCopy).flatMap(numRowsRef.set)
case (copyIn, _) =>
PHC.embed(copyIn, PFCI.cancelCopy)
}.flatMap { copyIn =>
byteStream.chunks.evalMap(bytes =>
PHC.embed(copyIn, PFCI.writeToCopy(bytes.toArray, 0, bytes.size))
)
}.compile.drain

copyAll.flatMap(_ => numRowsRef.get)
}
}

/** Folds given `F` to string, encoding each `A` with `Text` instance and joining resulting strings with `\n` */
Expand Down
89 changes: 89 additions & 0 deletions modules/postgres/src/test/scala/doobie/postgres/Issue1512.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Copyright (c) 2013-2020 Rob Norris and Contributors
// This software is licensed under the MIT License (MIT).
// For more information see LICENSE or https://opensource.org/licenses/MIT

package doobie.postgres

import cats.effect.IO
import cats.implicits.catsSyntaxApplicativeId
import doobie.ConnectionIO
import doobie.implicits._
import doobie.postgres.implicits._
import doobie.util.transactor.Transactor
import munit.CatsEffectSuite
import org.postgresql.ds.PGSimpleDataSource

import javax.sql.DataSource

class Issue1512 extends CatsEffectSuite {

val minChunkSize = 200

val datasource: DataSource = {
val ds = new PGSimpleDataSource
ds.setUser("postgres")
ds.setPassword("")
ds
}
val xa: Transactor[IO] =
Transactor.fromDataSource[IO](datasource, scala.concurrent.ExecutionContext.global)

val setup: IO[Int] =
sql"""
DROP TABLE IF EXISTS demo;
CREATE TABLE demo(id BIGSERIAL PRIMARY KEY NOT NULL, data BIGINT NOT NULL);
""".update.run
.transact(xa)

test("A stream with a Pure effect inserts items properly") {

setup.unsafeRunSync()

// A pure stream is fine - can copy many items
val count = 10000
val stream = fs2.Stream.emits(1 to count)

sql"COPY demo(data) FROM STDIN".copyIn(stream, minChunkSize).transact(xa).unsafeRunSync()

val queryCount =
sql"SELECT count(*) from demo".query[Int].unique.transact(xa).unsafeRunSync()

assertEquals(queryCount, count)
}

test("A stream with a ConnectionIO effect copies <= than minChunkSize items") {

setup.unsafeRunSync()

// Can copy up to minChunkSize just fine with ConnectionIO
val inputs = 1 to minChunkSize
val stream = fs2.Stream.emits[ConnectionIO, Int](inputs)
.evalMap(i => (i + 2).pure[ConnectionIO])

val copiedRows = sql"COPY demo(data) FROM STDIN".copyIn(stream, minChunkSize).transact(xa).unsafeRunSync()

assertEquals(copiedRows, inputs.size.toLong)

val queryCount =
sql"SELECT count(*) from demo".query[Int].unique.transact(xa).unsafeRunSync()

assertEquals(queryCount, minChunkSize)
}

test("A stream with a ConnectionIO effect copies items with count > minChunkSize") {

setup.unsafeRunSync()

val inputs = 1 to minChunkSize + 1
val stream = fs2.Stream.emits[ConnectionIO, Int](inputs)
.evalMap(i => (i + 2).pure[ConnectionIO])

val copiedRows = sql"COPY demo(data) FROM STDIN".copyIn(stream, minChunkSize).transact(xa).unsafeRunSync()
assertEquals(copiedRows, inputs.size.toLong)

val queryCount =
sql"SELECT count(*) from demo".query[Int].unique.transact(xa).unsafeRunSync()

assertEquals(queryCount, minChunkSize + 1)
}
}

0 comments on commit 650c8b9

Please sign in to comment.