CAP/MiniC/TPoptim/OptimSSA.py

428 lines
16 KiB
Python
Raw Permalink Normal View History

2024-10-06 19:58:11 +02:00
"""
CAP, SSA Intro, Elimination and Optimisations
Optimisations on SSA.
"""
from enum import Enum
from typing import List, Dict, Tuple, cast
from Lib.Errors import MiniCInternalError
from Lib.Operands import (Operand, Temporary, Immediate, A, ZERO)
from Lib.Statement import (Statement, Instruction, Label, AbsoluteJump)
from Lib.CFG import (BlockInstr, Terminator, Block, CFG)
from Lib.Terminator import (Return, BranchingTerminator)
from Lib.PhiNode import PhiNode
from Lib import RiscV
def div_rd_0(a: int, b: int) -> int:
"""Division rounded towards 0 (integer division in Python rounds down)."""
return -(-a // b) if (a < 0) ^ (b < 0) else a // b
def mod_rd_0(a: int, b: int) -> int:
"""Modulo rounded towards 0 (integer division in Python rounds down)."""
return -(-a % b) if (a < 0) ^ (b < 0) else a % b
class Lattice(Enum):
Bottom = 0
Top = 1
LATTICE_VALUE = int | Lattice # Type for our values: Bottom < int < Top
def join(v1: LATTICE_VALUE, v2: LATTICE_VALUE) -> LATTICE_VALUE:
"""Compute the join of the two lattice values."""
match v1, v2:
case Lattice.Top, _:
return Lattice.Top
case _, Lattice.Top:
return Lattice.Top
case Lattice.Bottom, _:
return v2
case _, Lattice.Bottom:
return v1
case _, _: # both int
if v1 == v2:
return v1
else:
return Lattice.Top
def joinl(values: List[LATTICE_VALUE]) -> LATTICE_VALUE:
"""Compute the join of the list of lattice values."""
res = Lattice.Bottom
for v in values:
res = join(res, v)
return res
class CondConstantPropagation:
"""
Class that optimises a CFG under SSA form
following the algorithm "Sparse Conditionnal Constant Propagation".
"""
cfg: CFG
# CFG under SSA form to optimise
valueness: Dict[Operand, LATTICE_VALUE]
# Values of each variable v:
# valueness[v] = Lattice.Bottom if no evidence that v is assigned
# valueness[v] = n if we found evidence that only n is assigned to v
# valueness[v] = Lattice.Top if we found evidence that v is assigned
# to at least two different values
executability: Dict[Tuple[Block | None, Block], bool]
# Exectuability of an edge (B, C):
# executability[B, C] = False if no evidence that the edge (B, C) can ever be executed
# executability[B, C] = True if (B, C) may be executed (over-approximation)
# There is an initial edge from None to the start block
modified_flag: bool
# Flag to check if we reach the fixpoint
debug: bool
# Print valueness and executability at each step if True
all_vars: List[Operand]
# All the variables of the CFG
all_blocks: List[Block]
# All the blocks of the CFG
def __init__(self, cfg: CFG, debug: bool):
self.cfg = cfg
self.valueness = dict()
self.executability = dict()
self.debug = debug
self.all_vars = list(cfg.gather_defs().keys())
self.all_blocks = cfg.get_blocks()
# Initialisation of valueness and executability
for var in self.all_vars:
self.valueness[var] = Lattice.Bottom
for block in self.all_blocks:
for succ in cfg.out_blocks(block):
self.executability[block, succ] = False
# Add an initial edge from None to the start block
start_blk = self.cfg.get_block(self.cfg.get_start())
self.executability[None, start_blk] = False
def dump(self) -> None: # pragma: no cover
"""
For debug purposes: print valueness and executability.
"""
print("Valueness:")
for x, v in self.valueness.items():
print("{0}: {1}".format(x, v))
print("Executability:")
for (B, C), v in self.executability.items():
print("{0} -> {1}: {2}".format(B.get_label() if B is not None else "",
C.get_label(), v))
def set_valueness(self, v: Operand, x: LATTICE_VALUE) -> None:
"""
Update the valueness of a variable `v` by performing a join with
its current value.
"""
old_x = self.valueness[v]
new_x = join(x, old_x)
if new_x != old_x:
self.modified_flag = True
self.valueness[v] = new_x
def set_executability(self, B: Block | None, C: Block) -> None:
"""
Mark the edge from `B` to `C` as executable.
"""
old_x = self.executability[B, C]
if not old_x:
self.modified_flag = True
self.executability[B, C] = True
def is_constant(self, op: Operand) -> bool:
"""True if the value of `op` is constant."""
return isinstance(self.valueness.get(op, None), int)
def is_executable(self, B: Block) -> bool:
"""True if the block `B` may be executed."""
return B in (C for ((_, C), b) in self.executability.items() if b)
def compute(self) -> None:
"""
Compute executability for all edges and valueness for all variables
using a fixpoint algorithm.
"""
# 1. For any v coming from outside the CFG (parameters, function calls),
# set valueness[v] = Top. These are exactly the registers of A.
for var in A:
self.valueness[var] = Lattice.Top
# 2. The start block is executable, with an initial edge coming from None.
start_blk = self.cfg.get_block(self.cfg.get_start())
self.set_executability(None, start_blk)
# Start the fixpoint.
self.modified_flag = True
while self.modified_flag:
# Whenever executability or valueness is modified,
# modified_flag is set to True (see set_executability and set_valueness)
# so that the fixpoint continues.
self.modified_flag = False
if self.debug:
self.dump()
# 3. For any executable block B with only one successor C,
# set executability[B, C] = True.
for B in self.all_blocks:
nexts = self.cfg.out_blocks(B)
if self.is_executable(B) and len(nexts) == 1:
C = nexts[0]
self.set_executability(B, C)
for B in self.all_blocks:
if self.is_executable(B):
for stat in B.get_all_statements():
self.propagate_in(B, stat)
def propagate_in(self, B: Block, stat: Statement) -> None:
"""
Propagate valueness and executability to the given statement `stat`
located in the given executable block `B`.
See the `compute` function for more context.
"""
# 4. For any executable assignment v <- op (x, y),
# set valueness[v] = eval (op, x, y)
# TODO (Exercise 4)
# 5. For any executable assignment v <- phi (x1, ..., xn),
# set valueness[v] = join(x1, .., xn)
# TODO (Exercise 6)
# 6. For any executable conditional branch to blocks B1 and B2,
# set executability[B1] = True and/or executability[B2] = True
# depending on the valueness of its condition
# TODO (Exercise 6)
def get_executable_srcs(self, B: Block, phi: PhiNode) -> List[Operand]:
"""
Given a phi node `phi` belonging to the block `B`,
return its operands coming from an executable edge.
"""
return [x for lbl, x in phi.get_srcs().items()
if self.executability[self.cfg.get_block(lbl), B]]
def get_operands(self, ins: Instruction) -> List[LATTICE_VALUE]:
"""
Returns the valueness of the operands of the given instruction `ins`.
Also takes into account immediate values and the zero register.
"""
args: List[LATTICE_VALUE] = []
for x in ins.used():
if isinstance(x, Temporary):
args.append(self.valueness[x])
elif isinstance(x, Immediate):
args.append(x._val)
elif (x == ZERO):
args.append(0)
elif isinstance(x, Label):
continue
else:
args.append(Lattice.Top)
return args
def eval_arith_instr(self, ins: Instruction) -> LATTICE_VALUE:
"""
Computes the result of an arithmetic instruction in the valueness lattice,
from the valueness of its operands.
"""
args = self.get_operands(ins)
name: str = ins.ins
if Lattice.Top in args:
return Lattice.Top
elif Lattice.Bottom in args:
return Lattice.Bottom
args = cast(List[int], args)
if name == "add" or name == "addi":
return args[0] + args[1]
elif name == "mul":
return args[0] * args[1]
elif name == "div":
return div_rd_0(args[0], args[1])
elif name == "rem":
return mod_rd_0(args[0], args[1])
elif name == "sub":
return args[0] - args[1]
elif name == "and":
return args[0] & args[1]
elif name == "or":
return args[0] | args[1]
elif name == "xor" or name == "xori":
return args[0] ^ args[1]
elif name == "li":
assert (isinstance(ins.used()[0], Immediate))
return args[0]
elif name == "mv":
return args[0]
raise MiniCInternalError("Instruction modifying a temporary not in\
['add', 'addi', 'mul', 'div', 'rem',\
'sub', 'and', 'or', 'xor', 'xori', 'li', 'mv']")
def eval_bool_instr(self, ins: BranchingTerminator) -> LATTICE_VALUE:
"""
Computes the result of the comparison of a branching instruction
in the valueness lattice, from the valueness of its operands.
"""
args = self.get_operands(ins)
name: str = ins.ins
if Lattice.Top in args:
return Lattice.Top
elif Lattice.Bottom in args:
return Lattice.Bottom
args = cast(List[int], args)
if name == "blt":
return args[0] < args[1]
elif name == "bgt":
return args[0] > args[1]
elif name == "beq":
return args[0] == args[1]
elif name == "bne":
return args[0] != args[1]
elif name == "ble":
return args[0] <= args[1]
elif name == "bge":
return args[0] >= args[1]
elif name == "beqz":
return args[0] == 0
elif name == "bnez":
return args[0] != 0
raise MiniCInternalError("Condition of a CondJump not in ['blt',\
'bgt', 'beq', 'bne', 'ble', 'bge',\
'beqz', 'bnez']")
def replacePhi(self, B: Block, ins: PhiNode) -> PhiNode:
"""
Replace a phi node that has constant operands
according to the valueness computation.
"""
to_remove: List[Label] = [] # List of block's labels with no executable edge to B
for Bi_label, xi in ins.get_srcs().items():
Bi = self.cfg.get_block(Bi_label)
if self.executability[Bi, B]:
if self.is_constant(xi):
# Add a LI instruction in the block from where xi comes,
# at the end of its body (i.e. just before its Terminator),
# and replace xi by this new temporary
new_xi = self.cfg.fdata.fresh_tmp()
ins.srcs[Bi_label] = new_xi
imm = Immediate(self.valueness[xi])
li_ins = RiscV.li(new_xi, imm)
Bi.add_instruction(li_ins)
else:
to_remove.append(Bi_label)
for Bi_label in to_remove:
del ins.srcs[Bi_label]
return ins
def replaceInstruction(self, ins: BlockInstr) -> List[BlockInstr]:
"""
Replace an instruction that has constant operands
according to the valueness computation.
"""
# Add some LI instructions before the instruction `ins`
li_instrs: List[BlockInstr] = []
# Replace the constant variables with the temporaries defined by the new LI instructions
subst: Dict[Operand, Operand] = {}
# Compute `li_instrs` and `subst`
# TODO (Exercise 5)
new_ins = ins.substitute(subst)
return li_instrs + [new_ins]
def replaceTerminator(self, ins: Terminator) -> Tuple[List[BlockInstr], Terminator]:
"""
Replace a terminator that has constant operands
according to the valueness computation.
Return the list of LI instructions to do before,
and the new terminator.
"""
# Add some LI instructions at the end of the body of the block
li_instrs: List[BlockInstr] = []
# Replace the constant variables with the temporaries defined by the new LI instructions
subst: Dict[Operand, Operand] = {}
# Compute `li_instrs` and `subst`
# TODO (Exercise 5)
new_ins = ins.substitute(subst)
return li_instrs, new_ins
def rewriteCFG(self) -> None:
"""Update the CFG."""
# a. Whenever executability[B, C] = False, delete this edge
for (B, C) in [(B, C) for ((B, C), b) in self.executability.items()
if not b and B is not None]:
# Remove the edge
self.cfg.remove_edge(B, C)
# Update the corresponding terminator
targets = B.get_terminator().targets()
targets.remove(C.get_label())
if len(targets) == 0:
B.set_terminator(Return())
elif len(targets) == 1:
B.set_terminator(AbsoluteJump(targets[0]))
else:
raise MiniCInternalError(
"rewriteCFG: A terminator has more than 2 targets: {}"
.format(targets + [C.get_label()]))
# b. Whenever valueness[x] = c, substitute c for x and delete assignment to x
for block in self.all_blocks:
if self.is_executable(block):
new_phis: List[PhiNode] = []
for ins in block.get_phis():
assert (isinstance(ins, PhiNode))
v = ins.defined()[0]
if self.is_constant(v):
# We do not keep instructions defining operands of constant values
continue
else:
new_phis.append(self.replacePhi(block, ins))
new_instrs: List[BlockInstr] = []
for ins in block.get_body():
defs = ins.defined()
if len(defs) == 1 and self.is_constant(defs[0]):
# We do not keep instructions defining operands of constant values
continue
elif isinstance(ins, Instruction):
# We replace others instructions
new_instrs.extend(self.replaceInstruction(ins))
else:
# We do nothing for comments
new_instrs.append(ins)
term_instrs, new_term = self.replaceTerminator(block.get_terminator())
block.set_phis(cast(List[Statement], new_phis))
block._instructions = new_instrs + term_instrs
block.set_terminator(new_term)
# c. Whenever a block B is not executable, delete B
# There are no edge implicating B, for such an edge would not be
# executable, whence would have been deleted beforehand
for block in self.all_blocks:
if not self.is_executable(block):
del self.cfg._blocks[block.get_label()]
def OptimSSA(cfg: CFG, debug: bool) -> None:
"""Optimise a CFG under SSA form."""
optim = CondConstantPropagation(cfg, debug)
optim.compute()
optim.rewriteCFG()