summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xcpu_dsl.py36
1 files changed, 31 insertions, 5 deletions
diff --git a/cpu_dsl.py b/cpu_dsl.py
index d131a70..859521b 100755
--- a/cpu_dsl.py
+++ b/cpu_dsl.py
@@ -698,7 +698,7 @@ class NormalOp:
def generate(self, prog, parent, fieldVals, output, otype, flagUpdates):
procParams = []
- allParamsConst = flagUpdates is None
+ allParamsConst = flagUpdates is None and not prog.conditional
opDef = _opMap.get(self.op)
for param in self.params:
allowConst = (self.op in prog.subroutines or len(procParams) != len(self.params) - 1) and param in parent.regValues
@@ -835,9 +835,12 @@ class Switch(ChildBlock):
self.processOps(prog, fieldVals, output, otype, self.default)
output.append('\n\t}')
else:
+ oldCond = prog.conditional
+ prog.conditional = True
output.append('\n\tswitch(' + param + ')')
output.append('\n\t{')
for case in self.cases:
+ temp = prog.temp.copy()
self.current_locals = self.case_locals[case]
self.regValues = dict(self.parent.regValues)
output.append('\n\tcase {0}U: '.format(case) + '{')
@@ -846,14 +849,18 @@ class Switch(ChildBlock):
self.processOps(prog, fieldVals, output, otype, self.cases[case])
output.append('\n\tbreak;')
output.append('\n\t}')
+ prog.temp = temp
if self.default:
+ temp = prog.temp.copy()
self.current_locals = self.default_locals
self.regValues = dict(self.parent.regValues)
output.append('\n\tdefault: {')
for local in self.default_locals:
output.append('\n\tuint{0}_t {1};'.format(self.default_locals[local], local))
self.processOps(prog, fieldVals, output, otype, self.default)
+ prog.temp = temp
output.append('\n\t}')
+ prog.conditional = oldCond
prog.popScope()
def __str__(self):
@@ -908,7 +915,7 @@ class If(ChildBlock):
if op.op == 'local':
name = op.params[0]
size = op.params[1]
- self.locals[name] = size
+ self.curLocals[name] = size
elif op.op == 'else':
self.curLocals = self.elseLocals
self.curBody = self.elseBody
@@ -919,21 +926,25 @@ class If(ChildBlock):
return self.curLocals.get(name)
def resolveLocal(self, name):
- if name in self.locals:
+ if name in self.curLocals:
return name
return self.parent.resolveLocal(name)
def _genTrueBody(self, prog, fieldVals, output, otype):
self.curLocals = self.locals
+ subOut = []
+ self.processOps(prog, fieldVals, subOut, otype, self.body)
for local in self.locals:
output.append('\n\tuint{sz}_t {nm};'.format(sz=self.locals[local], nm=local))
- self.processOps(prog, fieldVals, output, otype, self.body)
+ output += subOut
def _genFalseBody(self, prog, fieldVals, output, otype):
self.curLocals = self.elseLocals
+ subOut = []
+ self.processOps(prog, fieldVals, subOut, otype, self.elseBody)
for local in self.elseLocals:
output.append('\n\tuint{sz}_t {nm};'.format(sz=self.elseLocals[local], nm=local))
- self.processOps(prog, fieldVals, output, otype, self.elsebody)
+ output += subOut
def _genConstParam(self, param, prog, fieldVals, output, otype):
if param:
@@ -947,23 +958,37 @@ class If(ChildBlock):
self._genConstParam(prog.checkBool(self.cond), prog, fieldVals, output, otype)
except Exception:
if self.cond in _ifCmpImpl[otype]:
+ oldCond = prog.conditional
+ prog.conditional = True
+ temp = prog.temp.copy()
output.append(_ifCmpImpl[otype][self.cond](prog, parent, fieldVals, output))
self._genTrueBody(prog, fieldVals, output, otype)
+ prog.temp = temp
if self.elseBody:
+ temp = prog.temp.copy()
output.append('\n\t} else {')
self._genFalseBody(prog, fieldVals, output, otype)
+ prog.temp = temp
output.append('\n\t}')
+ prog.conditional = oldCond
else:
cond = prog.resolveParam(self.cond, parent, fieldVals)
if type(cond) is int:
self._genConstParam(cond, prog, fieldVals, output, otype)
else:
+ temp = prog.temp.copy()
output.append('\n\tif ({cond}) '.format(cond=cond) + '{')
+ oldCond = prog.conditional
+ prog.conditional = True
self._genTrueBody(prog, fieldVals, output, otype)
+ prog.temp = temp
if self.elseBody:
+ temp = prog.temp.copy()
output.append('\n\t} else {')
self._genFalseBody(prog, fieldVals, output, otype)
+ prog.temp = temp
output.append('\n\t}')
+ prog.conditional = oldCond
def __str__(self):
@@ -1241,6 +1266,7 @@ class Program:
self.lastA = None
self.lastB = None
self.lastBFlow = None
+ self.conditional = False
def __str__(self):
pieces = []