Skip to content

Commit

Permalink
Inlining and vector access rewrites (#427)
Browse files Browse the repository at this point in the history
  • Loading branch information
eldritchconundrum authored Jun 2, 2024
1 parent a00caa9 commit 0e40701
Show file tree
Hide file tree
Showing 15 changed files with 213 additions and 147 deletions.
28 changes: 21 additions & 7 deletions src/inlining.fs
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,12 @@ type VariableInlining(options: Options.Options) =
// Mark variables as inlinable when possible.
// Variables are always safe to inline when all of:
// - the variable is used only once in the current block
// - the variable is not used in a sub-block (e.g. inside a loop), for runtime performance
// - the variable is not used in a loop sub-block, for runtime performance
// - the init value refers only to variables that are never written to, and functions that are builtin and pure
let markSafelyInlinableLocals block =
// Variables that are defined in this scope.
// The boolean indicate if the variable initialization is const.
let localDefs = Dictionary<string, (Ident * bool)>()
// List of all expressions in the current block. Do not look in sub-blocks.
let mutable localExpr = []
for stmt: Stmt in block do
match stmt with
| Decl (_, declElts) ->
Expand All @@ -51,14 +49,30 @@ type VariableInlining(options: Options.Options) =
| None ->
localDefs.[def.name.Name] <- (def.name, true)
| Some init ->
localExpr <- init :: localExpr
let isConst = Analyzer(options).varUsesInStmt (Expr init) |> Seq.forall isEffectivelyConst
localDefs.[def.name.Name] <- (def.name, isConst)
| _ -> ()
// List of all expressions under the current block, but do not look in loops.
let mutable localExprs = []
let rec addLocalExprs stmt =
match stmt with
| Decl (_, declElts) ->
for def in declElts do
match def.init with
| None -> ()
| Some init -> localExprs <- init :: localExprs
| Expr e
| Jump (_, Some e) -> localExpr <- e :: localExpr
| Directive _ | Verbatim _ | Jump (_, None) | Block _ | If _| ForE _ | ForD _ | While _ | DoWhile _ | Switch _ -> ()
| Jump (_, Some e) -> localExprs <- e :: localExprs
| If (cond, th, el) ->
localExprs <- cond :: localExprs
addLocalExprs th
el |> Option.iter addLocalExprs
| Block stmts -> stmts |> Seq.iter addLocalExprs
| Directive _ | Verbatim _ | Jump (_, None) | ForE _ | ForD _ | While _ | DoWhile _ | Switch _ -> ()
for stmt: Stmt in block do
addLocalExprs stmt

let localReferences = countReferences [for e in localExpr -> Expr e]
let localReferences = countReferences [for e in localExprs -> Expr e]
let allReferences = countReferences block

for def in localDefs do
Expand Down
39 changes: 38 additions & 1 deletion src/rewriter.fs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ open Analyzer

let private commaSeparatedExprs = List.reduce (fun a b -> FunCall(Op ",", [a; b]))

let private isKnownToHaveTypeFloat = function
| Float _ -> true
| ResolvedVariableUse (_, vd) -> vd.decl.name.Name = "float"
| _ -> false

[<RequireQualifiedAccess>]
type private OptimizationPass =
| First
Expand Down Expand Up @@ -252,7 +257,8 @@ type private RewriterImpl(options: Options.Options, optimizationPass: Optimizati
// basic type of the object being constructed, the scalar construction rules (above) are used to convert
// the parameters."
let useInts = function
| Float (f, _) as e when Decimal.Round(f) = f ->
| Float (f, _) as e when optimizationPass = OptimizationPass.Second // only do this after other transforms that can apply only to floats.
&& Decimal.Round(f) = f ->
try
let candidate = Int (int f, "")
if (Printer.exprToS candidate).Length <= (Printer.exprToS e).Length then
Expand Down Expand Up @@ -290,6 +296,32 @@ type private RewriterImpl(options: Options.Options, optimizationPass: Optimizati
let args = dropLastSwizzle vecSize args
FunCall (Var constr, args)

let simplifyVecDot (vecName: string) args (field: string) e =
let vecSize = vecName.ToCharArray() |> Array.last |> string |> int
if not (
args |> Seq.forall Effects.isPure && // check that arguments can be reordered
args |> List.length = vecSize // check that the Nth swizzle index maps to the Nth arg
)
then e
else
let indexes = field.ToCharArray() |> Seq.map Builtin.swizzleIndex |> Seq.toList
match indexes |> List.length with
| 1 -> // vec3(a,b,c).y -> b
match List.tryItem indexes.Head args with
| Some arg when isKnownToHaveTypeFloat arg -> arg
| _ -> e
| 2 | 3 | 4 -> // vec3(a,b,c).yx -> vec2(b,a)
// find whether the repeated fields are repeatable exprs (don't repeat function calls or long exprs)
let repeatedIndexes = indexes |> List.countBy id |> List.filter (fun (_, count) -> count > 1) |> List.map (fun (key, _) -> key)
let isRepeatableExpr = function
| Var _ | Int _ | Float _ -> true
| _ -> false
if repeatedIndexes |> List.forall (fun index -> isRepeatableExpr args[index]) then
let constructor = Ident("vec" + string (indexes |> List.length))
FunCall(Var constructor, indexes |> List.map (fun index -> args[index]))
else e
| _ -> e

let simplifyExpr (didInline: bool ref) env = function
| FunCall(Var v, passedArgs) as e when v.ToBeInlined ->
match env.fns.TryFind (v.Name, passedArgs.Length) with
Expand Down Expand Up @@ -317,6 +349,11 @@ type private RewriterImpl(options: Options.Options, optimizationPass: Optimizati
| FunCall(Var constr, args) when constr.Name = "vec2" || constr.Name = "vec3" || constr.Name = "vec4" ->
simplifyVec constr args

| Dot(FunCall(Var constr, args), field) as e when (constr.Name = "vec2" || constr.Name = "vec3" || constr.Name = "vec4") ->
let e = simplifyVecDot constr.Name args field e
match e with
| Dot(e, field) when options.canonicalFieldNames <> "" -> Dot(e, options.renameField field)
| _ -> e
| Dot(e, field) when options.canonicalFieldNames <> "" -> Dot(e, options.renameField field)

| ResolvedVariableUse (_, vd) as e when vd.decl.name.ToBeInlined ->
Expand Down
14 changes: 7 additions & 7 deletions tests/compression_results.log
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
clod.frag (and others) 8706 => 1475.856
mouton/mouton.vert (and others) 16601 => 2399.791
clod.frag (and others) 8696 => 1476.293
mouton/mouton.vert (and others) 16589 => 2402.445
audio-flight-v2.frag 4471 => 874.321
buoy.frag 4005 => 599.099
controllable-machinery.frag 7671 => 1220.647
ed-209.frag 7677 => 1339.532
ed-209.frag 7665 => 1345.543
elevated.hlsl 3406 => 603.219
endeavour.frag 2567 => 530.460
from-the-seas-to-the-stars.frag 14173 => 2278.764
frozen-wasteland.frag 4511 => 809.183
kinder_painter.frag 2832 => 442.132
leizex.frag 2252 => 506.309
lunaquatic.frag 5227 => 1044.673
mandelbulb.frag 2322 => 532.387
mandelbulb.frag 2325 => 539.673
ohanami.frag 3179 => 703.261
orchard.frag 5384 => 1002.066
oscars_chair.frag 4648 => 986.069
robin.frag 6199 => 1039.306
slisesix.frag 4443 => 886.946
terrarium.frag 3575 => 747.677
slisesix.frag 4433 => 884.925
terrarium.frag 3571 => 744.738
the_real_party_is_in_your_pocket.frag 11974 => 1767.561
valley_ball.glsl 4307 => 881.820
yx_long_way_from_home.frag 2925 => 598.425
Total: 133055 => 23269.505
Total: 133010 => 23280.933
7 changes: 2 additions & 5 deletions tests/real/ed-209.frag.expected
Original file line number Diff line number Diff line change
Expand Up @@ -276,11 +276,8 @@ MarchData room(vec3 p)
float doorHole=sdBox(p,frameInner+vec3(0,0,1)),backWall=length(p.z-8.);
r.d=min(backWall,max(length(p.z),-doorHole+.1));
if(r.d==backWall)
{
float ocp=min(max(min(abs(sdOctogon(xy,2.6)),abs(sdOctogon(xy,1.9))),min(.7-abs(xy.x+1.2),-xy.y)),max(abs(sdOctogon(xy,1.2)),min(xy.x,.7-abs(xy.y))));
if(ocp<.3)
r.mat=vec3(.39,.57,.71);
}
if(min(max(min(abs(sdOctogon(xy,2.6)),abs(sdOctogon(xy,1.9))),min(.7-abs(xy.x+1.2),-xy.y)),max(abs(sdOctogon(xy,1.2)),min(xy.x,.7-abs(xy.y))))<.3)
r.mat=vec3(.39,.57,.71);
doorHole=max(sdBox(p,frameInner+vec3(.4,.4,.1)),-doorHole);
backWall=frameInner.x*.5;
p.x-=frameInner.x;
Expand Down
126 changes: 64 additions & 62 deletions tests/real/mandelbulb.expected
Original file line number Diff line number Diff line change
Expand Up @@ -5,114 +5,116 @@
const char *mandelbulb_frag =
"uniform vec2 resolution;"
"uniform float time;"
"bool f(vec4 v,vec3 f,vec3 o,out vec2 i)"
"bool f(vec3 v,vec3 f,out vec2 i)"
"{"
"vec3 y=f-v.xyz;"
"float c=dot(y,o),e=c*c-dot(y,y)+v.w*v.w;"
"vec4 t=vec4(0,0,0,1.25);"
"vec3 y=v-t.xyz;"
"float c=dot(y,f),e=c*c-dot(y,y)+t.w*t.w;"
"if(e<0.)"
"return false;"
"e=sqrt(e);"
"i.x=-c-e;"
"i.y=-c+e;"
"return true;"
"}"
"bool f(vec3 v,out float f,out vec4 y)"
"bool v(vec3 v,out float f,out vec4 y)"
"{"
"vec4 e=vec4(100);"
"vec3 i=v;"
"float o=dot(i,i);"
"if(o>1e2)"
"return f=.5*log(o)/pow(8.,0.),y=vec4(1),false;"
"float t=dot(i,i);"
"if(t>1e2)"
"return f=.5*log(t)/pow(8.,0.),y=vec4(1),false;"
"for(int x=1;x<7;x++)"
"{"
"\n#if 0\n"
"float z=sqrt(dot(i,i)),c=acos(i.y/z),s=atan(i.x,i.z);"
"float z=sqrt(dot(i,i)),o=acos(i.y/z),c=atan(i.x,i.z);"
"z=pow(z,8.);"
"o*=8.;"
"c*=8.;"
"s*=8.;"
"i=v+z*vec3(sin(c)*sin(s),cos(c),sin(c)*cos(s));"
"i=v+z*vec3(sin(o)*sin(c),cos(o),sin(o)*cos(c));"
"\n#else\n"
"float t=i.x,d=t*t,n=d*d,a=i.y,l=a*a,m=i.z,w=m*m,g=w*w,p=d+w,r=inversesqrt(p*p*p*p*p*p*p),C=n+l*l+g-6.*l*w-6.*d*l+2.*w*d,F=d-l+w;"
"i.x=v.x+64.*t*a*m*(d-w)*F*(n-6.*d*w+g)*C*r;"
"float n=i.x,s=n*n,d=s*s,a=i.y,l=a*a,m=i.z,w=m*m,g=w*w,p=s+w,r=inversesqrt(p*p*p*p*p*p*p),C=d+l*l+g-6.*l*w-6.*s*l+2.*w*s,F=s-l+w;"
"i.x=v.x+64.*n*a*m*(s-w)*F*(d-6.*s*w+g)*C*r;"
"i.y=v.y+-16.*l*p*F*F+C*C;"
"i.z=v.z+-8.*a*F*(n*n-28.*n*d*w+70.*n*g-28.*d*w*g+g*g)*C*r;"
"i.z=v.z+-8.*a*F*(d*d-28.*d*s*w+70.*d*g-28.*s*w*g+g*g)*C*r;"
"\n#endif\n"
"o=dot(i,i);"
"e=min(e,vec4(i.xyz*i.xyz,o));"
"if(o>1e2)"
"return y=e,f=.5*log(o)/pow(8.,float(x)),false;"
"t=dot(i,i);"
"e=min(e,vec4(i.xyz*i.xyz,t));"
"if(t>1e2)"
"return y=e,f=.5*log(t)/pow(8.,float(x)),false;"
"}"
"y=e;"
"f=0.;"
"return true;"
"}"
"bool f(vec3 v,vec3 o,out float y,out vec3 i,out vec4 c)"
"bool f(vec3 i,vec3 s,out float y,out vec3 t,out vec4 o)"
"{"
"vec4 e=vec4(0,0,0,1.25);"
"vec2 s;"
"if(!f(e,v,o,s))"
"float p=1.;"
"vec2 e;"
"if(!f(i,s,e))"
"return false;"
"if(s.y<.001)"
"if(e.y<.001)"
"return false;"
"if(s.x<.001)"
"s.x=.001;"
"if(s.y>1e20)"
"s.y=1e20;"
"float x;"
"vec3 z;"
"float t=1./sqrt(2.);"
"for(float d=s.x;d<s.y;)"
"if(e.x<.001)"
"e.x=.001;"
"if(e.y>1e20)"
"e.y=1e20;"
"float c;"
"vec3 x;"
"vec4 n;"
"p=1./sqrt(1.+p*p);"
"for(float f=e.x;f<e.y;)"
"{"
"vec3 m=v+o*d;"
"float n=clamp(.001*d*t,1e-6,.005),p=n*.1;"
"vec3 z=i+s*f;"
"float d=clamp(.001*f*p,1e-6,.005),w=d*.1;"
"vec4 r;"
"float w;"
"if(f(m,w,e))"
"return y=d,i=normalize(z),c=e,true;"
"float l;"
"if(v(z,l,n))"
"return y=f,t=normalize(x),o=n,true;"
"float m;"
"v(z+vec3(w,0,0),m,r);"
"float g;"
"f(m+vec3(p,0,0),g,r);"
"v(z+vec3(0,w,0),g,r);"
"float C;"
"f(m+vec3(0,p,0),C,r);"
"float a;"
"f(m+vec3(0,0,p),a,r);"
"z=vec3(g-w,C-w,a-w);"
"x=.5*w*p/length(z);"
"if(x<n)"
"return c=e,i=normalize(z),y=d,true;"
"d+=x;"
"v(z+vec3(0,0,w),C,r);"
"x=vec3(m-l,g-l,C-l);"
"c=.5*l*w/length(x);"
"if(c<d)"
"return o=n,t=normalize(x),y=f,true;"
"f+=c;"
"}"
"return false;"
"}"
"void main()"
"{"
"vec2 v=-1.+2.*gl_FragCoord.xy/resolution.xy,i=v*vec2(1.33,1);"
"vec3 y=vec3(.577),o=vec3(-.707,0,.707);"
"float e=1.4+.2*cos(6.28318*time/20.);"
"vec3 c=vec3(e*sin(6.28318*time/20.),.3-.4*sin(6.28318*time/20.),e*cos(6.28318*time/20.)),s=normalize(vec3(0,.1,0)-c),d=normalize(cross(s,vec3(0,1,0)));"
"d=normalize(i.x*d+i.y*normalize(cross(d,s))+1.5*s);"
"vec3 z;"
"vec4 p;"
"if(f(c,d,e,s,p))"
"vec3 y=vec3(.577);"
"float t=1.4+.2*cos(6.28318*time/20.);"
"vec3 c=vec3(t*sin(6.28318*time/20.),.3-.4*sin(6.28318*time/20.),t*cos(6.28318*time/20.)),e=normalize(vec3(0,.1,0)-c),s=normalize(cross(e,vec3(0,1,0)));"
"s=normalize(i.x*s+i.y*normalize(cross(s,e))+1.5*e);"
"vec3 x;"
"vec4 o;"
"if(f(c,s,t,e,o))"
"{"
"vec3 i=c+e*d;"
"float v=clamp(.2+.8*dot(y,s),0.,1.);"
"vec3 i=c+t*s;"
"float v=clamp(.2+.8*dot(y,e),0.,1.);"
"v*=v;"
"float t=clamp(.3+.7*dot(o,s),0.,1.),w=clamp(1.25*p.w-.4,0.,1.);"
"float p=clamp(.3+.7*dot(vec3(-.707,0,.707),e),0.,1.),w=clamp(1.25*o.w-.4,0.,1.);"
"w=w*w*.5+.5*w;"
"float r;"
"vec3 C;"
"vec4 x;"
"vec3 n;"
"vec4 g;"
"if(v>.001)"
"if(f(i,y,r,C,x))"
"if(f(i,y,r,n,g))"
"v=.1;"
"z=mix(mix(mix(vec3(1),vec3(.8,.6,.2),sqrt(p.x)*1.25),vec3(.8,.3,.3),sqrt(p.y)*1.25),vec3(.7,.4,.3),sqrt(p.z)*1.25)*((.5+.5*s.y)*vec3(.14,.15,.16)*.8+v*vec3(1,.85,.4)+.5*t*vec3(.08,.1,.14))*vec3(pow(w,.8),w,pow(w,1.1));"
"z=1.5*(z*.15+.85*sqrt(z));"
"x=mix(mix(mix(vec3(1),vec3(.8,.6,.2),sqrt(o.x)*1.25),vec3(.8,.3,.3),sqrt(o.y)*1.25),vec3(.7,.4,.3),sqrt(o.z)*1.25)*((.5+.5*e.y)*vec3(.14,.15,.16)*.8+v*vec3(1,.85,.4)+.5*p*vec3(.08,.1,.14))*vec3(pow(w,.8),w,pow(w,1.1));"
"x=1.5*(x*.15+.85*sqrt(x));"
"}"
"else"
" z=1.3*vec3(1,.98,.9)*(.7+.3*d.y);"
" x=1.3*vec3(1,.98,.9)*(.7+.3*s.y);"
"v=v*.5+.5;"
"z=clamp(z*(.7+4.8*v.x*v.y*(1.-v.x)*(1.-v.y)),0.,1.);"
"gl_FragColor=vec4(z,1);"
"x=clamp(x*(.7+4.8*v.x*v.y*(1.-v.x)*(1.-v.y)),0.,1.);"
"gl_FragColor=vec4(x,1);"
"}";

#endif // MANDELBULB_EXPECTED_
Loading

0 comments on commit 0e40701

Please sign in to comment.