Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#1 Allow more heterogenous types in numeric operations #36

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 41 additions & 15 deletions src/common/NumericOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,45 @@ import java.io.PrintWriter
import scala.reflect.SourceContext

trait LiftNumeric {
this: Base =>
this: Base with NumericOps =>

// HACK The Numeric context bound is not *required* but it is useful to reduce the applicability of this implicit conversion
implicit def numericToNumericRep[T:Numeric:Manifest](x: T) = unit(x)

// Explicit `1 + unit(1)` (partial) support because it needs two implicit conversions
import NumericOpsTypes._
implicit class NumericOpsCls[A : Numeric : Manifest](lhs: A) {
def + (rhs: Rep[A])(implicit op: (A ~ A) ~> A, sc: SourceContext) = numeric_plus(op.lhs(unit(lhs)), op.rhs(rhs))
def - (rhs: Rep[A])(implicit op: (A ~ A) ~> A, sc: SourceContext) = numeric_minus(op.lhs(unit(lhs)), op.rhs(rhs))
def * (rhs: Rep[A])(implicit op: (A ~ A) ~> A, sc: SourceContext) = numeric_times(op.lhs(unit(lhs)), op.rhs(rhs))
def / (rhs: Rep[A])(implicit op: (A ~ A) ~> A, sc: SourceContext) = numeric_divide(op.lhs(unit(lhs)), op.rhs(rhs))
}
}

trait NumericOps extends Variables {

// workaround for infix not working with manifests
implicit def numericToNumericOps[T:Numeric:Manifest](n: T) = new NumericOpsCls(unit(n))
implicit def repNumericToNumericOps[T:Numeric:Manifest](n: Rep[T]) = new NumericOpsCls(n)
implicit def varNumericToNumericOps[T:Numeric:Manifest](n: Var[T]) = new NumericOpsCls(readVar(n))

class NumericOpsCls[T:Numeric:Manifest](lhs: Rep[T]){
def +[A](rhs: A)(implicit c: A => T, pos: SourceContext) = numeric_plus(lhs,unit(c(rhs)))
def +(rhs: Rep[T])(implicit pos: SourceContext) = numeric_plus(lhs,rhs)
def -(rhs: Rep[T])(implicit pos: SourceContext) = numeric_minus(lhs,rhs)
def *(rhs: Rep[T])(implicit pos: SourceContext) = numeric_times(lhs,rhs)
def /(rhs: Rep[T])(implicit pos: SourceContext) = numeric_divide(lhs,rhs)
// Type constraints allowing an eventual type promotion (e.g. Int to Float) before performing the numeric operation
object NumericOpsTypes {
trait Args { type Lhs; type Rhs }
trait ~[A, B] extends Args { type Lhs = A; type Rhs = B }
class ~>[A <: Args, B](val lhs: Rep[A#Lhs] => Rep[B], val rhs: Rep[A#Rhs] => Rep[B])(implicit val Numeric: Numeric[B])
}
import NumericOpsTypes._
implicit def numericSameArgs[A : Numeric] = new ((A ~ A) ~> A) (identity, identity)

/* FIXME
* I’d like to define numeric operators as follows:
* def infix_+[A, B](lhs: A, rhs: B)(implicit someAdditionnalConstraints...)
* But this signature leads to an ambiguous reference to overloaded definition with an infix_+(s: String, a: Any) method defined in EmbeddedControls (?)
*/
implicit class RepNumericOpsCls[A](lhs: Rep[A]) {
def + [B, C](rhs: Rep[B])(implicit op: (A ~ B) ~> C, mC: Manifest[C], sc: SourceContext) = numeric_plus(op.lhs(lhs), op.rhs(rhs))(op.Numeric, mC, sc)
def - [B, C](rhs: Rep[B])(implicit op: (A ~ B) ~> C, mC: Manifest[C], sc: SourceContext) = numeric_minus(op.lhs(lhs), op.rhs(rhs))(op.Numeric, mC, sc)
def * [B, C](rhs: Rep[B])(implicit op: (A ~ B) ~> C, mC: Manifest[C], sc: SourceContext) = numeric_times(op.lhs(lhs), op.rhs(rhs))(op.Numeric, mC, sc)
def / [B, C](rhs: Rep[B])(implicit op: (A ~ B) ~> C, mC: Manifest[C], sc: SourceContext) = numeric_divide(op.lhs(lhs), op.rhs(rhs))(op.Numeric, mC, sc)
}

//def infix_+[T:Numeric:Manifest](lhs: Rep[T], rhs: Rep[T]) = numeric_plus(lhs,rhs)
//def infix_-[T:Numeric:Manifest](lhs: Rep[T], rhs: Rep[T]) = numeric_minus(lhs,rhs)
//def infix_*[T:Numeric:Manifest](lhs: Rep[T], rhs: Rep[T]) = numeric_times(lhs,rhs)
implicit def varNumericToNumericOps[T : Numeric : Manifest](n: Var[T]) = new RepNumericOpsCls(readVar(n))

def numeric_plus[T:Numeric:Manifest](lhs: Rep[T], rhs: Rep[T])(implicit pos: SourceContext): Rep[T]
def numeric_minus[T:Numeric:Manifest](lhs: Rep[T], rhs: Rep[T])(implicit pos: SourceContext): Rep[T]
Expand All @@ -38,6 +54,16 @@ trait NumericOps extends Variables {
//def numeric_signum[T:Numeric](x: T): Rep[Int]
}

/*
* Enable promotion of arguments involved in a numeric operation provided there exists an implicit conversion to perform the promotion.
* For instance, it allows to mix Int values and Double values in a numeric operation.
*/
trait NumericPromotions { this: ImplicitOps with NumericOps =>
import NumericOpsTypes._
implicit def numericPromoteLhs[A : Manifest, B : Numeric : Manifest](implicit aToB: A => B) = new ((A ~ B) ~> B) (lhs = implicit_convert[A, B](_), rhs = identity)
implicit def numericPromoteRhs[A : Manifest, B : Numeric : Manifest](implicit aToB: A => B) = new ((B ~ A) ~> B) (lhs = identity, rhs = implicit_convert[A, B](_))
}

trait NumericOpsExp extends NumericOps with VariablesExp with BaseFatExp {
abstract class DefMN[A:Manifest:Numeric] extends Def[A] {
def mev = manifest[A]
Expand Down
60 changes: 60 additions & 0 deletions test-src/epfl/test13-numeric/TestNumeric.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package scala.virtualization.lms
package epfl
package test13

import common._

// Does nothing but checks the code compiles
trait TestNumeric {

def typed[A](a: => A) {}

trait Usage { this: Base with NumericOps =>

val a = unit(1) + unit(1)
typed[Rep[Int]](a)

val b = unit(1.0) + unit(1.0)
typed[Rep[Double]](b)
}

trait UsageWithLift { this: Base with NumericOps with LiftNumeric =>

val a = 1 + unit(1)
typed[Rep[Int]](a)

val b = 1.0 + unit(1.0)
typed[Rep[Double]](b)

val c = unit(1) + 1
typed[Rep[Int]](c)

val d = unit(1.0) + 1.0
typed[Rep[Double]](d)
}

trait UsageWithPromotions { this: Base with NumericOps with NumericPromotions =>

val a = unit(1) + unit(1.0)
typed[Rep[Double]](a)

val b = unit(1.0) + unit(1)
typed[Rep[Double]](b)
}

trait UsageWithPromotionsAndLift { this: Base with NumericOps with NumericPromotions with LiftNumeric =>

val a = unit(1) + 1.0
typed[Rep[Double]](a)

// val b = 1 + unit(1.0)
// typed[Rep[Double]](b)

val c = unit(1.0) + 1
typed[Rep[Double]](c)

// val d = 1.0 + unit(1)
// typed[Rep[Double]](d)
}

}