# Visitor to *typecheck* MiniC files
from typing import List, NoReturn
from MiniCVisitor import MiniCVisitor
from MiniCParser import MiniCParser
from Lib.Errors import MiniCInternalError, MiniCTypeError

from enum import Enum


class BaseType(Enum):
    """The basic types a MiniC expression or variable can have."""
    Float, Integer, Boolean, String = range(4)


# Types accepted by arithmetic operators (besides string concatenation,
# which visitAdditiveExpr handles separately).
_NUMERIC_TYPES = (BaseType.Integer, BaseType.Float)


# Basic Type Checking for MiniC programs.
class MiniCTypingVisitor(MiniCVisitor):
    """Walk a MiniC parse tree and check that it is well typed.

    Expression visitors return the BaseType of the visited expression.
    Declaration visitors record variable types in ``self._memorytypes``.
    Statement visitors return nothing. Every typing error is reported by
    raising MiniCTypeError.
    """

    def __init__(self):
        self._memorytypes = dict()  # id -> types
        # For now, we don't have real functions ...
        self._current_function = "main"

    # ------------------------------------------------------------------
    # error helpers
    # ------------------------------------------------------------------

    def _raise(self, ctx, for_what, *types) -> NoReturn:
        """Raise a MiniCTypeError saying *types* are invalid for *for_what*.

        The position comes from ``ctx.start`` (line and column).
        """
        raise MiniCTypeError(
            'In function {}: Line {} col {}: invalid type for {}: {}'.format(
                self._current_function,
                ctx.start.line, ctx.start.column, for_what,
                ' and '.join(t.name.lower() for t in types)))

    def _assertSameType(self, ctx, for_what, *types) -> None:
        """Raise a MiniCTypeError unless all *types* are equal."""
        if not all(types[0] == t for t in types):
            raise MiniCTypeError(
                'In function {}: Line {} col {}: type mismatch for {}: {}'.format(
                    self._current_function,
                    ctx.start.line, ctx.start.column, for_what,
                    ' and '.join(t.name.lower() for t in types)))

    def _raiseNonType(self, ctx, message) -> NoReturn:
        """Raise a MiniCTypeError with a free-form *message*.

        Used for errors that are not a type mismatch, such as undeclared
        or redeclared variables.
        """
        raise MiniCTypeError(
            'In function {}: Line {} col {}: {}'.format(
                self._current_function,
                ctx.start.line, ctx.start.column, message))

    def _checkLogicalExpr(self, ctx, op_label) -> BaseType:
        """Type a binary logical expression: both operands must be Boolean."""
        ltype = self.visit(ctx.expr(0))
        rtype = self.visit(ctx.expr(1))
        self._assertSameType(ctx, op_label, ltype, rtype)
        # Both operands are known to be equal here, so checking ltype suffices.
        if ltype != BaseType.Boolean:
            self._raise(ctx, "boolean operands", ltype, rtype)
        return BaseType.Boolean

    # type declaration

    def visitVarDecl(self, ctx) -> None:
        """Record the declared type for each name; reject redeclarations."""
        names_str = self.visit(ctx.id_l())
        type_minic = self.visit(ctx.typee())
        for name in names_str:
            if name in self._memorytypes:
                self._raiseNonType(ctx, f"Variable {name} already declared")
            self._memorytypes[name] = type_minic

    def visitBasicType(self, ctx) -> BaseType:
        """Map a basic-type token to its BaseType.

        Raises:
            MiniCInternalError: the token is not one of the known basic
                types (the parser and this visitor are out of sync).
        """
        assert ctx.mytype is not None
        token_type = ctx.mytype.type
        if token_type == MiniCParser.INTTYPE:
            return BaseType.Integer
        elif token_type == MiniCParser.FLOATTYPE:
            return BaseType.Float
        elif token_type == MiniCParser.STRINGTYPE:
            return BaseType.String
        elif token_type == MiniCParser.BOOLTYPE:
            return BaseType.Boolean
        # Previously this fell through and returned None, which only failed
        # later with an unrelated AttributeError.
        raise MiniCInternalError(f"Unknown basic type: {ctx.getText()}")

    def visitIdList(self, ctx) -> List[str]:
        """Return the identifiers of a list, in source order."""
        ids = self.visit(ctx.id_l())
        return ids + [ctx.ID().getText()]

    def visitIdListBase(self, ctx) -> List[str]:
        """Return the single identifier of the last list element."""
        return [ctx.ID().getText()]

    # typing visitors for expressions, statements !

    # visitors for atoms --> type
    def visitParExpr(self, ctx) -> BaseType:
        return self.visit(ctx.expr())

    def visitIntAtom(self, ctx) -> BaseType:
        return BaseType.Integer

    def visitFloatAtom(self, ctx) -> BaseType:
        return BaseType.Float

    def visitBooleanAtom(self, ctx) -> BaseType:
        return BaseType.Boolean

    def visitIdAtom(self, ctx) -> BaseType:
        """Return the declared type of a variable; reject undeclared ones."""
        name = ctx.getText()
        # Explicit guard (instead of catching KeyError) so the KeyError is
        # not chained into the reported type error.
        if name not in self._memorytypes:
            self._raiseNonType(ctx, "Undefined variable {}".format(name))
        return self._memorytypes[name]

    def visitStringAtom(self, ctx) -> BaseType:
        return BaseType.String

    # now visit expr

    def visitAtomExpr(self, ctx) -> BaseType:
        return self.visit(ctx.atom())

    def visitOrExpr(self, ctx) -> BaseType:
        """Logical or: Boolean operands, Boolean result."""
        return self._checkLogicalExpr(ctx, "|")

    def visitAndExpr(self, ctx) -> BaseType:
        """Logical and: Boolean operands, Boolean result."""
        return self._checkLogicalExpr(ctx, "&")

    def visitEqualityExpr(self, ctx) -> BaseType:
        """Equality: operands of any (identical) type, Boolean result."""
        ltype = self.visit(ctx.expr(0))
        rtype = self.visit(ctx.expr(1))
        self._assertSameType(ctx, "==", ltype, rtype)
        return BaseType.Boolean

    def visitRelationalExpr(self, ctx) -> BaseType:
        """Ordering comparison: identical numeric operands, Boolean result."""
        ltype = self.visit(ctx.expr(0))
        rtype = self.visit(ctx.expr(1))
        self._assertSameType(ctx, "relational expr need same type",
                             ltype, rtype)
        if ltype not in _NUMERIC_TYPES:
            self._raise(ctx, "non numeric type", ltype)
        return BaseType.Boolean

    def visitAdditiveExpr(self, ctx) -> BaseType:
        """Addition/subtraction on numbers; '+' also concatenates strings."""
        assert ctx.myop is not None
        ltype = self.visit(ctx.expr(0))
        rtype = self.visit(ctx.expr(1))
        self._assertSameType(ctx, "Additive", ltype, rtype)
        if ltype not in (BaseType.String,) + _NUMERIC_TYPES:
            self._raise(ctx, "additive operands", ltype, rtype)
        if ltype == BaseType.String and ctx.myop.type != MiniCParser.PLUS:
            self._raise(ctx, "Minus not compatible with string", ltype)
        return ltype

    def visitMultiplicativeExpr(self, ctx) -> BaseType:
        """Multiplicative operators: identical numeric operands.

        Previously only strings were rejected, so Boolean operands were
        accepted and produced a Boolean result.
        """
        assert ctx.myop is not None
        ltype = self.visit(ctx.expr(0))
        rtype = self.visit(ctx.expr(1))
        if ltype not in _NUMERIC_TYPES or rtype not in _NUMERIC_TYPES:
            self._raise(ctx, "multiplicative operands", ltype, rtype)
        self._assertSameType(ctx, "multiplicative operands", ltype, rtype)
        return ltype

    def visitNotExpr(self, ctx) -> BaseType:
        """Logical negation: Boolean operand, Boolean result."""
        ltype = self.visit(ctx.expr())
        if ltype != BaseType.Boolean:
            self._raise(ctx, "NOT: only boolean allowed", ltype)
        return BaseType.Boolean

    def visitUnaryMinusExpr(self, ctx) -> BaseType:
        """Unary minus: numeric operand, same type as the result."""
        ltype = self.visit(ctx.expr())
        if ltype not in _NUMERIC_TYPES:
            # Floats are accepted too, hence "numeric" rather than "integer".
            self._raise(ctx, "non numeric type", ltype)
        return ltype

    # visit statements

    def visitPrintlnintStat(self, ctx) -> None:
        etype = self.visit(ctx.expr())
        if etype != BaseType.Integer:
            self._raise(ctx, 'println_int statement', etype)

    def visitPrintlnfloatStat(self, ctx) -> None:
        etype = self.visit(ctx.expr())
        if etype != BaseType.Float:
            self._raise(ctx, 'println_float statement', etype)

    def visitPrintlnboolStat(self, ctx) -> None:
        etype = self.visit(ctx.expr())
        if etype != BaseType.Boolean:
            self._raise(ctx, 'println_bool statement', etype)

    def visitPrintlnstringStat(self, ctx) -> None:
        etype = self.visit(ctx.expr())
        if etype != BaseType.String:
            self._raise(ctx, 'println_string statement', etype)

    def visitAssignStat(self, ctx) -> None:
        """Assignment: the variable must be declared with the value's type."""
        etype = self.visit(ctx.expr())
        name = ctx.ID().getText()
        if name not in self._memorytypes:
            self._raiseNonType(ctx, f"Undefined variable {name}")
        self._assertSameType(ctx, name, self._memorytypes[name], etype)

    def visitWhileStat(self, ctx) -> None:
        """While loop: Boolean condition, then check the body."""
        etype = self.visit(ctx.expr())
        if etype != BaseType.Boolean:
            self._raise(ctx, "non boolean as condition", etype)
        self.visit(ctx.stat_block())

    def visitIfStat(self, ctx) -> None:
        """If statement: Boolean condition, then check every branch block."""
        etype = self.visit(ctx.expr())
        if etype != BaseType.Boolean:
            self._raise(ctx, "non boolean as condition", etype)
        for block in ctx.stat_block():
            self.visit(block)

    def visitForStat(self, ctx) -> None:
        """For loop: integer bounds, optional integer stride, integer counter.

        ``ctx.expr(0)`` and ``ctx.expr(1)`` are the from/to values; a third
        expression, when present, is the stride.
        """
        from_type = self.visit(ctx.expr(0))
        to_type = self.visit(ctx.expr(1))
        if from_type != BaseType.Integer:
            self._raise(ctx, "non-integer from value", from_type)
        if to_type != BaseType.Integer:
            self._raise(ctx, "non-integer to value", to_type)
        counter = ctx.ID().getText()
        if counter not in self._memorytypes:
            self._raiseNonType(ctx, f"Undefined variable {counter}")
        counter_type = self._memorytypes[counter]
        if counter_type != BaseType.Integer:
            self._raise(ctx, "non-integer for loop counter", counter_type)
        if len(ctx.expr()) > 2:
            stride_type = self.visit(ctx.expr(2))
            if stride_type != BaseType.Integer:
                self._raise(ctx, "non-integer stride value", stride_type)
        self.visit(ctx.stat_block())