Skip to content

Commit

Permalink
refactor: sccp lattice manipulation
Browse files Browse the repository at this point in the history
* changed the index type of lattice to `IRVariable` (it should only
  have variables)
* add getter/setter methods that also check types to cache future
  regretions
* add `_eval_lattice_with_op()` to evaluate the lattice with an operant
* revert `_eval()` changes
* unify the handling of `store`
  • Loading branch information
harkal committed May 8, 2024
1 parent d4a0818 commit e39076a
Showing 1 changed file with 31 additions and 20 deletions.
51 changes: 31 additions & 20 deletions vyper/venom/passes/sccp/sccp.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class FlowWorkItem:

WorkListItem = Union[FlowWorkItem, SSAWorkListItem]
LatticeItem = Union[LatticeEnum, IRLiteral]
Lattice = dict[IROperand, LatticeItem]
Lattice = dict[IRVariable, LatticeItem]


class SCCP(IRPass):
Expand Down Expand Up @@ -143,15 +143,21 @@ def _handle_SSA_work_item(self, work_item: SSAWorkListItem):
elif len(self.cfg_in_exec[work_item.inst.parent]) > 0:
self._visit_expr(work_item.inst)

def _from_lattice(self, op: IROperand):
def _get_lattice(self, op: IROperand) -> LatticeItem:
assert isinstance(op, IRVariable), "Can't get lattice for non-variable"
lat = self.lattice[op]
assert lat is not None, f"Got undefined var {op}"
return lat

def _set_lattice(self, op: IROperand, value: LatticeItem):
assert isinstance(op, IRVariable), "Can't set lattice for non-variable"
self.lattice[op] = value

def _eval_lattice_with_op(self, op: IROperand) -> IRLiteral | LatticeEnum:
if isinstance(op, IRLiteral):
return op

if isinstance(op, IRLabel):
return LatticeEnum.BOTTOM

assert isinstance(op, IRVariable) # IRLabel would be an error
return self.lattice[op]
return self._get_lattice(op)

def _visit_phi(self, inst: IRInstruction):
assert inst.opcode == "phi", "Can't visit non phi instruction"
Expand All @@ -160,27 +166,25 @@ def _visit_phi(self, inst: IRInstruction):
bb = self.fn.get_basic_block(bb_label.name)
if bb not in self.cfg_in_exec[inst.parent]:
continue
in_vars.append(self.lattice[var])
in_vars.append(self._get_lattice(var))
value = reduce(_meet, in_vars, LatticeEnum.TOP) # type: ignore
assert inst.output in self.lattice, "Got undefined var for phi"

if value != self.lattice[inst.output]:
self.lattice[inst.output] = value
if value != self._get_lattice(inst.output):
self._set_lattice(inst.output, value)
self._add_ssa_work_items(inst)

def _visit_expr(self, inst: IRInstruction):
opcode = inst.opcode
if opcode in ["store", "alloca"]:
if isinstance(inst.operands[0], IRLiteral):
self.lattice[inst.output] = inst.operands[0] # type: ignore
else:
self.lattice[inst.output] = self.lattice[inst.operands[0]] # type: ignore
assert inst.output is not None, "Got store/alloca without output"
self._set_lattice(inst.output, self._eval_lattice_with_op(inst.operands[0]))
self._add_ssa_work_items(inst)
elif opcode == "jmp":
target = self.fn.get_basic_block(inst.operands[0].value)
self.work_list.append(FlowWorkItem(inst.parent, target))
elif opcode == "jnz":
lat = self._from_lattice(inst.operands[0])
lat = self._eval_lattice_with_op(inst.operands[0])

assert lat != LatticeEnum.TOP, f"Got undefined var at jmp at {inst.parent}"
if lat == LatticeEnum.BOTTOM:
Expand All @@ -194,7 +198,7 @@ def _visit_expr(self, inst: IRInstruction):
target = self.fn.get_basic_block(inst.operands[2].name)
self.work_list.append(FlowWorkItem(inst.parent, target))
elif opcode == "djmp":
lat = self._from_lattice(inst.operands[0])
lat = self._eval_lattice_with_op(inst.operands[0])
assert lat != LatticeEnum.TOP, f"Got undefined var at jmp at {inst.parent}"
if lat == LatticeEnum.BOTTOM:
for op in inst.operands[1:]:
Expand All @@ -212,7 +216,7 @@ def _visit_expr(self, inst: IRInstruction):
self._eval(inst)
else:
if inst.output is not None:
self.lattice[inst.output] = LatticeEnum.BOTTOM
self._set_lattice(inst.output, LatticeEnum.BOTTOM)

def _eval(self, inst) -> LatticeItem:
"""
Expand All @@ -223,7 +227,14 @@ def _eval(self, inst) -> LatticeItem:
"""
opcode = inst.opcode

ops = [self._from_lattice(op) for op in inst.operands]
ops = []
for op in inst.operands:
if isinstance(op, IRVariable):
ops.append(self.lattice[op])
elif isinstance(op, IRLabel):
return LatticeEnum.BOTTOM
else:
ops.append(op)

ret = None
if LatticeEnum.BOTTOM in ops:
Expand Down Expand Up @@ -281,7 +292,7 @@ def _replace_constants(self, inst: IRInstruction):
case of jumps and asserts as needed.
"""
if inst.opcode == "jnz":
lat = self._from_lattice(inst.operands[0])
lat = self._eval_lattice_with_op(inst.operands[0])

if isinstance(lat, IRLiteral):
if lat.value == 0:
Expand All @@ -293,7 +304,7 @@ def _replace_constants(self, inst: IRInstruction):
self.cfg_dirty = True

elif inst.opcode in ("assert", "assert_unreachable"):
lat = self._from_lattice(inst.operands[0])
lat = self._eval_lattice_with_op(inst.operands[0])

if isinstance(lat, IRLiteral):
if lat.value > 0:
Expand Down

0 comments on commit e39076a

Please sign in to comment.