339 lines
6.6 KiB
Go
339 lines
6.6 KiB
Go
package main
|
|
|
|
import (
|
|
"fmt"
|
|
"maps"
|
|
)
|
|
|
|
type Expression interface {
|
|
Step() (Expression, error)
|
|
Substitute(Id, Expression)
|
|
TypeCheck(TypeEnvironment) (Type, error)
|
|
}
|
|
|
|
type Declaration struct {
|
|
variables []*Symbol
|
|
body []Expression
|
|
}
|
|
|
|
func (d *Declaration) Step() (Expression, error) {
|
|
for _, v := range d.variables {
|
|
if isValue(v.expr) {
|
|
continue
|
|
}
|
|
n, err := v.expr.Step()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
v.expr = n
|
|
return d, nil
|
|
}
|
|
for _, v := range d.variables {
|
|
for _, b := range d.body {
|
|
b.Substitute(v.id.value, v.expr)
|
|
}
|
|
}
|
|
if len(d.body) == 0 {
|
|
return &zero, nil
|
|
}
|
|
for i, b := range d.body {
|
|
if isValue(b) {
|
|
continue
|
|
}
|
|
n, err := b.Step()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
d.body[i] = n
|
|
return d, nil
|
|
}
|
|
return d.body[len(d.body)-1], nil
|
|
}
|
|
|
|
func (d *Declaration) Substitute(id Id, expr Expression) {
|
|
for _, v := range d.variables {
|
|
v.expr.Substitute(id, expr)
|
|
}
|
|
for _, b := range d.body {
|
|
b.Substitute(id, expr)
|
|
}
|
|
}
|
|
|
|
func (d *Declaration) TypeCheck(env TypeEnvironment) (Type, error) {
|
|
env = maps.Clone(env)
|
|
for _, v := range d.variables {
|
|
typ, err := v.expr.TypeCheck(env)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
env[v.id.value] = typ
|
|
}
|
|
var body []Type = make([]Type, len(d.body))
|
|
for i, e := range d.body {
|
|
typ, err := e.TypeCheck(env)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
body[i] = typ
|
|
}
|
|
if len(body) > 0 {
|
|
return body[len(body)-1], nil
|
|
} else {
|
|
return &PrimVoid, nil
|
|
}
|
|
}
|
|
|
|
type Conditional struct {
|
|
predicate Expression
|
|
ifTrue Expression
|
|
ifFalse Expression
|
|
}
|
|
|
|
func (c *Conditional) Step() (Expression, error) {
|
|
n, err := c.predicate.Step()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
c.predicate = n
|
|
if !isValue(n) {
|
|
return c, nil
|
|
}
|
|
b, ok := n.(*BoolValue)
|
|
if !ok {
|
|
return nil, fmt.Errorf("boolean value was expected")
|
|
}
|
|
if *b {
|
|
n, err = c.ifTrue.Step()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
c.ifTrue = n
|
|
if isValue(n) {
|
|
return n, nil
|
|
} else {
|
|
return c, nil
|
|
}
|
|
} else {
|
|
n, err = c.ifFalse.Step()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
c.ifTrue = n
|
|
if isValue(n) {
|
|
return n, nil
|
|
} else {
|
|
return c, nil
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Conditional) Substitute(id Id, expr Expression) {
|
|
c.predicate.Substitute(id, expr)
|
|
c.ifTrue.Substitute(id, expr)
|
|
c.ifFalse.Substitute(id, expr)
|
|
}
|
|
|
|
func (c *Conditional) TypeCheck(env TypeEnvironment) (Type, error) {
|
|
typ, err := c.predicate.TypeCheck(env)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if !typ.Equal(&PrimBoolean) { // TODO: truthiness
|
|
return nil, fmt.Errorf("expected boolean expression as predicate")
|
|
}
|
|
typ1, err := c.ifTrue.TypeCheck(env)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
typ2, err := c.ifFalse.TypeCheck(env)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if !typ1.Equal(typ2) {
|
|
return nil, fmt.Errorf("type mismatch between conditionals. can't coerce %v and %v", typ1, typ2)
|
|
}
|
|
return typ1, nil
|
|
}
|
|
|
|
type BuiltInFunction struct {
|
|
t *FunctionType
|
|
f func([]Expression) (Expression, error)
|
|
}
|
|
|
|
func (b *BuiltInFunction) HasValue() {}
|
|
func (b *BuiltInFunction) Substitute(Id, Expression) {}
|
|
func (b *BuiltInFunction) TypeCheck(TypeEnvironment) (Type, error) { return b.t, nil }
|
|
func (b *BuiltInFunction) Step() (Expression, error) { panic("not implemented") }
|
|
func (b BuiltInFunction) String() string { return "(lambda (..) (..))" }
|
|
|
|
// func (f *FunctionValue) TypeCheck(TypeEnvironment) (Type, error) { return &PrimBoolean, nil }
|
|
|
|
type Lambda struct {
|
|
t *FunctionType
|
|
params []*Symbol
|
|
body []Expression
|
|
}
|
|
|
|
func (l *Lambda) Step() (Expression, error) {
|
|
for _, arg := range l.params {
|
|
if arg.expr == nil {
|
|
panic("!")
|
|
}
|
|
for _, e := range l.body {
|
|
e.Substitute(arg.id.value, arg.expr)
|
|
}
|
|
}
|
|
if len(l.body) == 0 {
|
|
return &zero, nil
|
|
}
|
|
for i, expr := range l.body {
|
|
if isValue(expr) {
|
|
continue
|
|
}
|
|
r, err := expr.Step()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
l.body[i] = r
|
|
return l, nil
|
|
}
|
|
return l.body[len(l.body)-1], nil
|
|
}
|
|
|
|
func (l *Lambda) Substitute(id Id, expr Expression) {
|
|
for _, b := range l.body {
|
|
b.Substitute(id, expr)
|
|
}
|
|
}
|
|
|
|
func (l *Lambda) TypeCheck(env TypeEnvironment) (Type, error) {
|
|
env = maps.Clone(env)
|
|
for _, p := range l.params {
|
|
typ, err := p.expr.TypeCheck(env)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
env[p.id.value] = typ
|
|
}
|
|
var body []Type = make([]Type, len(l.body))
|
|
for i, e := range l.body {
|
|
typ, err := e.TypeCheck(env)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
body[i] = typ
|
|
}
|
|
if len(body) > 0 {
|
|
return body[len(body)-1], nil
|
|
} else {
|
|
return &PrimVoid, nil
|
|
}
|
|
}
|
|
|
|
type Symbol struct {
|
|
id *Identifier
|
|
expr Expression
|
|
}
|
|
|
|
type Param struct {
|
|
id *Identifier
|
|
t Type
|
|
}
|
|
|
|
func (v *Symbol) Step() (Expression, error) {
|
|
if v.expr == nil {
|
|
return nil, fmt.Errorf("unbound symbol %s at %s", v.id, v.id.position)
|
|
}
|
|
return v.expr, nil
|
|
}
|
|
|
|
func (v *Symbol) Substitute(id Id, expr Expression) {
|
|
if v.id.value == id {
|
|
v.expr = expr
|
|
}
|
|
}
|
|
|
|
func (s *Symbol) TypeCheck(env TypeEnvironment) (Type, error) {
|
|
if typ, ok := env[s.id.value]; ok {
|
|
return typ, nil
|
|
}
|
|
if s.expr == nil {
|
|
return nil, fmt.Errorf("undeclared symbol %s at %s", s.id, s.id.position)
|
|
}
|
|
return s.expr.TypeCheck(env)
|
|
}
|
|
|
|
type List struct {
|
|
els []Expression
|
|
}
|
|
|
|
func (v *List) Step() (Expression, error) {
|
|
for i, expr := range v.els {
|
|
if !isValue(expr) {
|
|
e, err := expr.Step()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
v.els[i] = e
|
|
return v, nil
|
|
}
|
|
}
|
|
if len(v.els) == 0 {
|
|
return &zero, nil
|
|
}
|
|
args := v.els[1:]
|
|
switch x := v.els[0].(type) {
|
|
case *Lambda:
|
|
for i, p := range x.params {
|
|
p.expr = args[i]
|
|
}
|
|
return x, nil
|
|
case *BuiltInFunction:
|
|
return x.f(args)
|
|
default:
|
|
return nil, fmt.Errorf("not a function")
|
|
}
|
|
}
|
|
|
|
func (v *List) Substitute(id Id, expr Expression) {
|
|
for _, e := range v.els {
|
|
e.Substitute(id, expr)
|
|
}
|
|
}
|
|
|
|
func (v *List) TypeCheck(env TypeEnvironment) (Type, error) {
|
|
var et []Type
|
|
for _, e := range v.els {
|
|
typ, err := e.TypeCheck(env)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
et = append(et, typ)
|
|
}
|
|
if len(et) == 0 {
|
|
return &PrimAny, nil
|
|
}
|
|
tf, ok := et[0].(*FunctionType)
|
|
if !ok {
|
|
return nil, fmt.Errorf("expected function call. got %v", et[0])
|
|
}
|
|
args := et[1:]
|
|
// if !isVariadic(tf.params[len(tf.params)-1]) {
|
|
// }
|
|
if len(args) != len(tf.params) {
|
|
return nil, fmt.Errorf("parameter and arguments number mismatch.")
|
|
}
|
|
for i := range args {
|
|
if !tf.params[i].Equal(args[i]) {
|
|
return nil, fmt.Errorf("parameter type mismatch. expected %v got %v",
|
|
tf.params[i], args[i])
|
|
}
|
|
}
|
|
return tf.ret, nil
|
|
// for _, typ := range et {
|
|
// if !typ.Equal(et[0]) {
|
|
// return nil, fmt.Errorf("array type mismatch")
|
|
// }
|
|
// }
|
|
// return &ListType{et[0]}, nil
|
|
}
|