Type check
This commit is contained in:
parent
8f5b16878f
commit
0cd6040a1f
7 changed files with 625 additions and 256 deletions
334
expression.go
Normal file
334
expression.go
Normal file
|
|
@ -0,0 +1,334 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"maps"
|
||||
)
|
||||
|
||||
type Expression interface {
|
||||
Step() (Expression, error)
|
||||
Substitute(Id, Expression)
|
||||
TypeCheck(TypeEnvironment) (Type, error)
|
||||
}
|
||||
|
||||
type Declaration struct {
|
||||
variables []*Symbol
|
||||
body []Expression
|
||||
}
|
||||
|
||||
func (d *Declaration) Step() (Expression, error) {
|
||||
for _, v := range d.variables {
|
||||
if isValue(v.expr) {
|
||||
continue
|
||||
}
|
||||
n, err := v.expr.Step()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
v.expr = n
|
||||
return d, nil
|
||||
}
|
||||
for _, v := range d.variables {
|
||||
for _, b := range d.body {
|
||||
b.Substitute(v.id.value, v.expr)
|
||||
}
|
||||
}
|
||||
if len(d.body) == 0 {
|
||||
return &zero, nil
|
||||
}
|
||||
for i, b := range d.body {
|
||||
if isValue(b) {
|
||||
continue
|
||||
}
|
||||
n, err := b.Step()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.body[i] = n
|
||||
return d, nil
|
||||
}
|
||||
return d.body[len(d.body)-1], nil
|
||||
}
|
||||
|
||||
func (d *Declaration) Substitute(id Id, expr Expression) {
|
||||
for _, v := range d.variables {
|
||||
v.expr.Substitute(id, expr)
|
||||
}
|
||||
for _, b := range d.body {
|
||||
b.Substitute(id, expr)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Declaration) TypeCheck(env TypeEnvironment) (Type, error) {
|
||||
env = maps.Clone(env)
|
||||
for _, v := range d.variables {
|
||||
typ, err := v.expr.TypeCheck(env)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
env[v.id.value] = typ
|
||||
}
|
||||
var body []Type = make([]Type, len(d.body))
|
||||
for i, e := range d.body {
|
||||
typ, err := e.TypeCheck(env)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
body[i] = typ
|
||||
}
|
||||
if len(body) > 0 {
|
||||
return body[len(body)-1], nil
|
||||
} else {
|
||||
return &PrimVoid, nil
|
||||
}
|
||||
}
|
||||
|
||||
type Conditional struct {
|
||||
predicate Expression
|
||||
ifTrue Expression
|
||||
ifFalse Expression
|
||||
}
|
||||
|
||||
func (c *Conditional) Step() (Expression, error) {
|
||||
n, err := c.predicate.Step()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.predicate = n
|
||||
if !isValue(n) {
|
||||
return c, nil
|
||||
}
|
||||
b, ok := n.(*BoolValue)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("boolean value was expected")
|
||||
}
|
||||
if *b {
|
||||
n, err = c.ifTrue.Step()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.ifTrue = n
|
||||
if isValue(n) {
|
||||
return n, nil
|
||||
} else {
|
||||
return c, nil
|
||||
}
|
||||
} else {
|
||||
n, err = c.ifFalse.Step()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.ifTrue = n
|
||||
if isValue(n) {
|
||||
return n, nil
|
||||
} else {
|
||||
return c, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conditional) Substitute(id Id, expr Expression) {
|
||||
c.predicate.Substitute(id, expr)
|
||||
c.ifTrue.Substitute(id, expr)
|
||||
c.ifFalse.Substitute(id, expr)
|
||||
}
|
||||
|
||||
func (c *Conditional) TypeCheck(env TypeEnvironment) (Type, error) {
|
||||
typ, err := c.predicate.TypeCheck(env)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !typ.Equal(&PrimBoolean) { // TODO: truthiness
|
||||
return nil, fmt.Errorf("expected boolean expression as predicate")
|
||||
}
|
||||
typ1, err := c.ifTrue.TypeCheck(env)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
typ2, err := c.ifFalse.TypeCheck(env)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !typ1.Equal(typ2) {
|
||||
return nil, fmt.Errorf("type mismatch between conditionals. can't coerce %v and %v", typ1, typ2)
|
||||
}
|
||||
return typ1, nil
|
||||
}
|
||||
|
||||
type BuiltInFunction struct {
|
||||
t *FunctionType
|
||||
f func([]Expression) (Expression, error)
|
||||
}
|
||||
|
||||
func (b *BuiltInFunction) HasValue() {}
|
||||
func (b *BuiltInFunction) Substitute(Id, Expression) {}
|
||||
func (b *BuiltInFunction) TypeCheck(TypeEnvironment) (Type, error) { return b.t, nil }
|
||||
func (b *BuiltInFunction) Step() (Expression, error) { panic("not implemented") }
|
||||
func (b BuiltInFunction) String() string { return "(lambda (..) (..))" }
|
||||
|
||||
// func (f *FunctionValue) TypeCheck(TypeEnvironment) (Type, error) { return &PrimBoolean, nil }
|
||||
|
||||
type Lambda struct {
|
||||
t *FunctionType
|
||||
params []*Symbol
|
||||
body []Expression
|
||||
}
|
||||
|
||||
func (l *Lambda) Step() (Expression, error) {
|
||||
for _, arg := range l.params {
|
||||
if arg.expr == nil {
|
||||
panic("!")
|
||||
}
|
||||
for _, e := range l.body {
|
||||
e.Substitute(arg.id.value, arg.expr)
|
||||
}
|
||||
}
|
||||
if len(l.body) == 0 {
|
||||
return &zero, nil
|
||||
}
|
||||
for i, expr := range l.body {
|
||||
if isValue(expr) {
|
||||
continue
|
||||
}
|
||||
r, err := expr.Step()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
l.body[i] = r
|
||||
return l, nil
|
||||
}
|
||||
return l.body[len(l.body)-1], nil
|
||||
}
|
||||
|
||||
func (l *Lambda) Substitute(id Id, expr Expression) {
|
||||
for _, b := range l.body {
|
||||
b.Substitute(id, expr)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Lambda) TypeCheck(env TypeEnvironment) (Type, error) {
|
||||
env = maps.Clone(env)
|
||||
for _, p := range l.params {
|
||||
typ, err := p.expr.TypeCheck(env)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
env[p.id.value] = typ
|
||||
}
|
||||
var body []Type = make([]Type, len(l.body))
|
||||
for i, e := range l.body {
|
||||
typ, err := e.TypeCheck(env)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
body[i] = typ
|
||||
}
|
||||
if len(body) > 0 {
|
||||
return body[len(body)-1], nil
|
||||
} else {
|
||||
return &PrimVoid, nil
|
||||
}
|
||||
}
|
||||
|
||||
type Symbol struct {
|
||||
id *Identifier
|
||||
expr Expression
|
||||
}
|
||||
|
||||
func (v *Symbol) Step() (Expression, error) {
|
||||
if v.expr == nil {
|
||||
return nil, fmt.Errorf("unbound symbol %s at %s", v.id, v.id.position)
|
||||
}
|
||||
return v.expr, nil
|
||||
}
|
||||
|
||||
func (v *Symbol) Substitute(id Id, expr Expression) {
|
||||
if v.id.value == id {
|
||||
v.expr = expr
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Symbol) TypeCheck(env TypeEnvironment) (Type, error) {
|
||||
if typ, ok := env[s.id.value]; ok {
|
||||
return typ, nil
|
||||
}
|
||||
if s.expr == nil {
|
||||
return nil, fmt.Errorf("undeclared symbol %s at %s", s.id, s.id.position)
|
||||
}
|
||||
return s.expr.TypeCheck(env)
|
||||
}
|
||||
|
||||
type List struct {
|
||||
els []Expression
|
||||
}
|
||||
|
||||
func (v *List) Step() (Expression, error) {
|
||||
for i, expr := range v.els {
|
||||
if !isValue(expr) {
|
||||
e, err := expr.Step()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
v.els[i] = e
|
||||
return v, nil
|
||||
}
|
||||
}
|
||||
if len(v.els) == 0 {
|
||||
return &zero, nil
|
||||
}
|
||||
args := v.els[1:]
|
||||
switch x := v.els[0].(type) {
|
||||
case *Lambda:
|
||||
for i, p := range x.params {
|
||||
p.expr = args[i]
|
||||
}
|
||||
return x, nil
|
||||
case *BuiltInFunction:
|
||||
return x.f(args)
|
||||
default:
|
||||
return nil, fmt.Errorf("not a function")
|
||||
}
|
||||
}
|
||||
|
||||
func (v *List) Substitute(id Id, expr Expression) {
|
||||
for _, e := range v.els {
|
||||
e.Substitute(id, expr)
|
||||
}
|
||||
}
|
||||
|
||||
func (v *List) TypeCheck(env TypeEnvironment) (Type, error) {
|
||||
var et []Type
|
||||
for _, e := range v.els {
|
||||
typ, err := e.TypeCheck(env)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
et = append(et, typ)
|
||||
}
|
||||
if len(et) == 0 {
|
||||
return &PrimAny, nil
|
||||
}
|
||||
tf, ok := et[0].(*FunctionType)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("expected function call. got %v", et[0])
|
||||
}
|
||||
args := et[1:]
|
||||
// if !isVariadic(tf.params[len(tf.params)-1]) {
|
||||
// }
|
||||
if len(args) != len(tf.params) {
|
||||
return nil, fmt.Errorf("parameter and arguments number mismatch.")
|
||||
}
|
||||
for i := range args {
|
||||
if !tf.params[i].Equal(args[i]) {
|
||||
return nil, fmt.Errorf("parameter type mismatch. expected %v got %v",
|
||||
tf.params[i], args[i])
|
||||
}
|
||||
}
|
||||
return tf.ret, nil
|
||||
// for _, typ := range et {
|
||||
// if !typ.Equal(et[0]) {
|
||||
// return nil, fmt.Errorf("array type mismatch")
|
||||
// }
|
||||
// }
|
||||
// return &ListType{et[0]}, nil
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue