summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xcpu_dsl.py153
1 files changed, 96 insertions, 57 deletions
diff --git a/cpu_dsl.py b/cpu_dsl.py
index 1f43203..6401827 100755
--- a/cpu_dsl.py
+++ b/cpu_dsl.py
@@ -20,6 +20,14 @@ class Block:
self.addOp(NormalOp(parts))
return self
+ def processOps(self, prog, fieldVals, output, otype, oplist):
+ for i in range(0, len(oplist)):
+ if i + 1 < len(oplist) and oplist[i+1].op == 'update_flags':
+ flagUpdates, _ = prog.flags.parseFlagUpdate(oplist[i+1].params[0])
+ else:
+ flagUpdates = None
+ oplist[i].generate(prog, self, fieldVals, output, otype, flagUpdates)
+
def resolveLocal(self, name):
return None
@@ -121,8 +129,7 @@ class Instruction(Block):
output.append('\n\tuint{sz}_t {name};'.format(sz=self.locals[var], name=var))
self.newLocals = []
fieldVals,_ = self.getFieldVals(value)
- for op in self.implementation:
- op.generate(prog, self, fieldVals, output, otype)
+ self.processOps(prog, fieldVals, output, otype, self.implementation)
begin = '\nvoid ' + self.generateName(value) + '(' + prog.context_type + ' *context)\n{'
if prog.needFlagCoalesce:
begin += prog.flags.coalesceFlags(prog, otype)
@@ -189,8 +196,7 @@ class SubRoutine(Block):
for name in self.locals:
size = self.locals[name]
output.append('\n\tuint{size}_t {sub}_{local};'.format(size=size, sub=self.name, local=name))
- for op in self.implementation:
- op.generate(prog, self, argValues, output, otype)
+ self.processOps(prog, argValues, output, otype, self.implementation)
prog.popScope()
def __str__(self):
@@ -209,15 +215,36 @@ class Op:
self.impls = {}
self.outOp = ()
def cBinaryOperator(self, op):
- def _impl(prog, params):
+ def _impl(prog, params, rawParams, flagUpdates):
if op == '-':
a = params[1]
b = params[0]
else:
a = params[0]
b = params[1]
- return '\n\t{dst} = {a} {op} {b};'.format(
- dst = params[2], a = a, b = b, op = op
+ needsCarry = needsOflow = needsHalf = False
+ if flagUpdates:
+ for flag in flagUpdates:
+ calc = prog.flags.flagCalc[flag]
+ if calc == 'carry':
+ needsCarry = True
+ elif calc == 'half-carry':
+ needsHalf = True
+ elif calc == 'overflow':
+ needsOflow = True
+ decl = ''
+ if needsCarry or needsOflow or needsHalf:
+ size = prog.paramSize(rawParams[2])
+ if needsCarry:
+ size *= 2
+ decl,name = prog.getTemp(size)
+ dst = prog.carryFlowDst = name
+ prog.lastA = a
+ prog.lastB = b
+ else:
+ dst = params[2]
+ return decl + '\n\t{dst} = {a} {op} {b};'.format(
+ dst = dst, a = a, b = b, op = op
)
self.impls['c'] = _impl
self.outOp = (2,)
@@ -244,11 +271,13 @@ class Op:
return not self.evalFun is None
def numArgs(self):
return self.evalFun.__code__.co_argcount
- def generate(self, otype, prog, params, rawParams):
+ def generate(self, otype, prog, params, rawParams, flagUpdates):
if self.impls[otype].__code__.co_argcount == 2:
return self.impls[otype](prog, params)
- else:
+ elif self.impls[otype].__code__.co_argcount == 3:
return self.impls[otype](prog, params, rawParams)
+ else:
+ return self.impls[otype](prog, params, rawParams, flagUpdates)
def _xchgCImpl(prog, params, rawParams):
@@ -264,36 +293,26 @@ def _dispatchCImpl(prog, params):
return '\n\timpl_{tbl}[{op}](context);'.format(tbl = table, op = params[0])
def _updateFlagsCImpl(prog, params, rawParams):
- i = 0
- last = ''
- autoUpdate = set()
- explicit = {}
- for c in params[0]:
- if c.isdigit():
- if last.isalpha():
- num = int(c)
- if num > 1:
- raise Exception(c + ' is not a valid digit for update_flags')
- explicit[last] = num
- last = c
- else:
- raise Exception('Digit must follow flag letter in update_flags')
- else:
- if last.isalpha():
- autoUpdate.add(last)
- last = c
- if last.isalpha():
- autoUpdate.add(last)
+ autoUpdate, explicit = prog.flags.parseFlagUpdate(params[0])
output = []
#TODO: handle autoUpdate flags
for flag in autoUpdate:
calc = prog.flags.flagCalc[flag]
calc,_,resultBit = calc.partition('-')
- lastDst = prog.resolveParam(prog.lastDst, None, {})
+ if prog.carryFlowDst:
+ lastDst = prog.carryFlowDst
+ else:
+ lastDst = prog.resolveParam(prog.lastDst, None, {})
storage = prog.flags.getStorage(flag)
- if calc == 'bit' or calc == 'sign':
+ if calc == 'bit' or calc == 'sign' or calc == 'carry' or calc == 'half':
+ myRes = lastDst
if calc == 'sign':
resultBit = prog.paramSize(prog.lastDst) - 1
+ elif calc == 'carry':
+ resultBit = prog.paramSize(prog.lastDst)
+ elif calc == 'half':
+ resultBit = 4
+ myRes = '({a} ^ {b} ^ {res})'.format(a = prog.lastA, b = prog.lastB, res = lastDst)
else:
resultBit = int(resultBit)
if type(storage) is tuple:
@@ -302,7 +321,7 @@ def _updateFlagsCImpl(prog, params, rawParams):
if storageBit == resultBit:
#TODO: optimize this case
output.append('\n\t{reg} = ({reg} & ~{mask}U) | ({res} & {mask}U);'.format(
- reg = reg, mask = 1 << resultBit, res = lastDst
+ reg = reg, mask = 1 << resultBit, res = myRes
))
else:
if resultBit > storageBit:
@@ -312,11 +331,11 @@ def _updateFlagsCImpl(prog, params, rawParams):
op = '<<'
shift = storageBit - resultBit
output.append('\n\t{reg} = ({reg} & ~{mask}U) | ({res} {op} {shift}U & {mask}U);'.format(
- reg = reg, mask = 1 << storageBit, res = lastDst, op = op, shift = shift
+ reg = reg, mask = 1 << storageBit, res = myRes, op = op, shift = shift
))
else:
reg = prog.resolveParam(storage, None, {})
- output.append('\n\t{reg} = {res} & {mask}U;'.format(reg=reg, res=lastDst, mask = 1 << resultBit))
+ output.append('\n\t{reg} = {res} & {mask}U;'.format(reg=reg, res=myRes, mask = 1 << resultBit))
elif calc == 'zero':
if type(storage) is tuple:
reg,storageBit = storage
@@ -328,15 +347,16 @@ def _updateFlagsCImpl(prog, params, rawParams):
reg = prog.resolveParam(storage, None, {})
output.append('\n\t{reg} = {res} == 0;'.format(
reg = reg, res = lastDst
- ))
- elif calc == 'half-carry':
- pass
- elif calc == 'carry':
- pass
+ ))
elif calc == 'overflow':
pass
elif calc == 'parity':
pass
+ else:
+ raise Exception('Unknown flag calc type: ' + calc)
+ if prog.carryFlowDst:
+ output.append('\n\t{dst} = {tmpdst};'.format(dst = prog.resolveParam(prog.lastDst, None, {}), tmpdst = prog.carryFlowDst))
+ prog.carryFlowDst = None
#TODO: combine explicit flags targeting the same storage location
for flag in explicit:
location = prog.flags.getStorage(flag)
@@ -458,9 +478,9 @@ class NormalOp:
self.op = parts[0]
self.params = parts[1:]
- def generate(self, prog, parent, fieldVals, output, otype):
+ def generate(self, prog, parent, fieldVals, output, otype, flagUpdates):
procParams = []
- allParamsConst = True
+ allParamsConst = flagUpdates is None
opDef = _opMap.get(self.op)
for param in self.params:
allowConst = (self.op in prog.subroutines or len(procParams) != len(self.params) - 1) and param in parent.regValues
@@ -502,9 +522,9 @@ class NormalOp:
if prog.isReg(dst):
shortProc = (procParams[0], procParams[-1])
shortParams = (self.params[0], self.params[-1])
- output.append(_opMap['mov'].generate(otype, prog, shortProc, shortParams))
+ output.append(_opMap['mov'].generate(otype, prog, shortProc, shortParams, None))
else:
- output.append(opDef.generate(otype, prog, procParams, self.params))
+ output.append(opDef.generate(otype, prog, procParams, self.params, flagUpdates))
elif self.op in prog.subroutines:
prog.subroutines[self.op].inline(prog, procParams, output, otype, parent)
else:
@@ -558,7 +578,7 @@ class Switch(ChildBlock):
return self.current_locals[name]
return self.parent.localSize(name)
- def generate(self, prog, parent, fieldVals, output, otype):
+ def generate(self, prog, parent, fieldVals, output, otype, flagUpdates):
prog.pushScope(self)
param = prog.resolveParam(self.param, parent, fieldVals)
if type(param) is int:
@@ -568,16 +588,14 @@ class Switch(ChildBlock):
output.append('\n\t{')
for local in self.case_locals[param]:
output.append('\n\tuint{0}_t {1};'.format(self.case_locals[param][local], local))
- for op in self.cases[param]:
- op.generate(prog, self, fieldVals, output, otype)
+ self.processOps(prog, fieldVals, output, otype, self.cases[param])
output.append('\n\t}')
elif self.default:
self.current_locals = self.default_locals
output.append('\n\t{')
for local in self.default_locals:
output.append('\n\tuint{0}_t {1};'.format(self.default[local], local))
- for op in self.default:
- op.generate(prog, self, fieldVals, output, otype)
+ self.processOps(prog, fieldVals, output, otype, self.default)
output.append('\n\t}')
else:
output.append('\n\tswitch(' + param + ')')
@@ -588,8 +606,7 @@ class Switch(ChildBlock):
output.append('\n\tcase {0}U: '.format(case) + '{')
for local in self.case_locals[case]:
output.append('\n\tuint{0}_t {1};'.format(self.case_locals[case][local], local))
- for op in self.cases[case]:
- op.generate(prog, self, fieldVals, output, otype)
+ self.processOps(prog, fieldVals, output, otype, self.cases[case])
output.append('\n\tbreak;')
output.append('\n\t}')
if self.default:
@@ -598,8 +615,7 @@ class Switch(ChildBlock):
output.append('\n\tdefault: {')
for local in self.default_locals:
output.append('\n\tuint{0}_t {1};'.format(self.default_locals[local], local))
- for op in self.default:
- op.generate(prog, self, fieldVals, output, otype)
+ self.processOps(prog, fieldVals, output, otype, self.default)
output.append('\n\t}')
prog.popScope()
@@ -666,15 +682,13 @@ class If(ChildBlock):
self.curLocals = self.locals
for local in self.locals:
output.append('\n\tuint{sz}_t {nm};'.format(sz=self.locals[local], nm=local))
- for op in self.body:
- op.generate(prog, self, fieldVals, output, otype)
+ self.processOps(prog, fieldVals, output, otype, self.body)
def _genFalseBody(self, prog, fieldVals, output, otype):
self.curLocals = self.elseLocals
for local in self.elseLocals:
output.append('\n\tuint{sz}_t {nm};'.format(sz=self.elseLocals[local], nm=local))
- for op in self.elseBody:
- op.generate(prog, self, fieldVals, output, otype)
+ self.processOps(prog, fieldVals, output, otype, self.elsebody)
def _genConstParam(self, param, prog, fieldVals, output, otype):
if param:
@@ -682,7 +696,7 @@ class If(ChildBlock):
else:
self._genFalseBody(prog, fieldVals, output, otype)
- def generate(self, prog, parent, fieldVals, output, otype):
+ def generate(self, prog, parent, fieldVals, output, otype, flagUpdates):
self.regValues = parent.regValues
try:
self._genConstParam(prog.checkBool(self.cond), prog, fieldVals, output, otype)
@@ -829,6 +843,28 @@ class Flags:
else:
return loc
+ def parseFlagUpdate(self, flagString):
+ last = ''
+ autoUpdate = set()
+ explicit = {}
+ for c in flagString:
+ if c.isdigit():
+ if last.isalpha():
+ num = int(c)
+ if num > 1:
+ raise Exception(c + ' is not a valid digit for update_flags')
+ explicit[last] = num
+ last = c
+ else:
+ raise Exception('Digit must follow flag letter in update_flags')
+ else:
+ if last.isalpha():
+ autoUpdate.add(last)
+ last = c
+ if last.isalpha():
+ autoUpdate.add(last)
+ return (autoUpdate, explicit)
+
def disperseFlags(self, prog, otype):
bitToFlag = [None] * (self.maxBit+1)
src = prog.resolveReg(self.flagReg, None, {})
@@ -949,6 +985,9 @@ class Program:
self.scopes = []
self.currentScope = None
self.lastOp = None
+ self.carryFlowDst = None
+ self.lastA = None
+ self.lastB = None
def __str__(self):
pieces = []