sexp/expression.go
2025-08-23 20:02:18 -03:00

339 lines
6.6 KiB
Go

package main
import (
"fmt"
"maps"
)
// Expression is a node of the s-expression AST. Nodes can be reduced one
// small step at a time, have identifiers bound inside them, and be type
// checked against a type environment.
type Expression interface {
// Step performs one reduction step and returns the resulting expression.
// Several implementations in this file mutate the receiver in place and
// return it while evaluation is unfinished.
Step() (Expression, error)
// Substitute binds the identifier to the given expression inside the
// node's tree, mutating it in place.
Substitute(Id, Expression)
// TypeCheck computes the node's static type in the given environment.
TypeCheck(TypeEnvironment) (Type, error)
}
// Declaration is a binding form. Each variable is reduced to a value and
// bound in the body. The body expressions are then reduced in order, and
// the last one's value is the result of the whole declaration.
type Declaration struct {
// variables holds the bindings; each Symbol's expr is the bound expression.
variables []*Symbol
// body is evaluated in order once every variable has been reduced.
body []Expression
}
// Step performs one small-step reduction of the declaration.
//
// Variables are reduced left to right, one step per call. Before a variable
// is stepped, the variables declared before it (already values at that
// point) are substituted into its expression. This matches TypeCheck, which
// checks each variable in an environment extended with the earlier ones.
//
// Once every variable is a value, the bindings are substituted into the body
// and the first body expression that is not yet a value is reduced one step.
// Step returns the declaration itself while work remains and the last body
// value once everything is reduced. An empty body yields zero.
func (d *Declaration) Step() (Expression, error) {
	for i, v := range d.variables {
		if isValue(v.expr) {
			continue
		}
		// Every earlier variable is already a value here (otherwise the loop
		// would have returned on it), so it is safe to bind them into v.
		for _, prev := range d.variables[:i] {
			v.expr.Substitute(prev.id.value, prev.expr)
		}
		n, err := v.expr.Step()
		if err != nil {
			return nil, err
		}
		v.expr = n
		return d, nil
	}
	// All variables are values: (re)bind them in the body. Later variables are
	// substituted after earlier ones, so a repeated name shadows the earlier.
	for _, v := range d.variables {
		for _, b := range d.body {
			b.Substitute(v.id.value, v.expr)
		}
	}
	if len(d.body) == 0 {
		return &zero, nil
	}
	for i, b := range d.body {
		if isValue(b) {
			continue
		}
		n, err := b.Step()
		if err != nil {
			return nil, err
		}
		d.body[i] = n
		return d, nil
	}
	return d.body[len(d.body)-1], nil
}
// Substitute binds id to expr in every variable initializer and then in
// every body expression. The declaration's own variables are not treated as
// shadowing id here.
func (d *Declaration) Substitute(id Id, expr Expression) {
	for i := range d.variables {
		d.variables[i].expr.Substitute(id, expr)
	}
	for i := range d.body {
		d.body[i].Substitute(id, expr)
	}
}
// TypeCheck checks each variable initializer in an environment that already
// contains the variables declared before it. It then checks the body with
// all variables in scope.
//
// The result is the type of the last body expression, or PrimVoid when the
// body is empty. The caller's environment is never modified, and a nil env
// is treated as empty.
func (d *Declaration) TypeCheck(env TypeEnvironment) (Type, error) {
	// maps.Clone(nil) returns nil and writing to a nil map panics, so copy
	// into a freshly allocated map instead of cloning.
	scope := make(TypeEnvironment, len(env)+len(d.variables))
	maps.Copy(scope, env)
	for _, v := range d.variables {
		typ, err := v.expr.TypeCheck(scope)
		if err != nil {
			return nil, err
		}
		scope[v.id.value] = typ
	}
	// Every body expression must type check, but only the last one's type is
	// the declaration's type.
	var result Type = &PrimVoid
	for _, e := range d.body {
		typ, err := e.TypeCheck(scope)
		if err != nil {
			return nil, err
		}
		result = typ
	}
	return result, nil
}
// Conditional is an if-expression: the predicate selects whether ifTrue or
// ifFalse is evaluated.
type Conditional struct {
predicate Expression // Step requires it to reduce to a *BoolValue
ifTrue Expression // evaluated when the predicate is true
ifFalse Expression // evaluated when the predicate is false
}
// Step performs one small-step reduction of the conditional.
//
// The predicate is reduced first. While it is not yet a value, the
// conditional returns itself so evaluation resumes on the next call. Once
// the predicate is a *BoolValue, the selected branch is reduced one step and
// the result is stored back into that branch's field. The branch value is
// returned as soon as it is fully reduced.
//
// A predicate that reduces to anything other than a *BoolValue is an error.
// Expressions that are already values are never stepped again.
func (c *Conditional) Step() (Expression, error) {
	if !isValue(c.predicate) {
		n, err := c.predicate.Step()
		if err != nil {
			return nil, err
		}
		c.predicate = n
		if !isValue(n) {
			return c, nil
		}
	}
	b, ok := c.predicate.(*BoolValue)
	if !ok {
		return nil, fmt.Errorf("boolean value was expected")
	}
	// Address the selected field so the reduced expression is written back to
	// the branch it came from (the false branch used to overwrite ifTrue).
	branch := &c.ifFalse
	if *b {
		branch = &c.ifTrue
	}
	if isValue(*branch) {
		return *branch, nil
	}
	n, err := (*branch).Step()
	if err != nil {
		return nil, err
	}
	*branch = n
	if isValue(n) {
		return n, nil
	}
	return c, nil
}
// Substitute binds id to expr in the predicate and in both branches, in that
// order.
func (c *Conditional) Substitute(id Id, expr Expression) {
	for _, part := range []Expression{c.predicate, c.ifTrue, c.ifFalse} {
		part.Substitute(id, expr)
	}
}
// TypeCheck requires the predicate to have type PrimBoolean and both
// branches to have equal types. It returns the true branch's type.
func (c *Conditional) TypeCheck(env TypeEnvironment) (Type, error) {
	predType, err := c.predicate.TypeCheck(env)
	if err != nil {
		return nil, err
	}
	// TODO: truthiness — for now only booleans are accepted as predicates.
	if !predType.Equal(&PrimBoolean) {
		return nil, fmt.Errorf("expected boolean expression as predicate")
	}
	thenType, err := c.ifTrue.TypeCheck(env)
	if err != nil {
		return nil, err
	}
	elseType, err := c.ifFalse.TypeCheck(env)
	if err != nil {
		return nil, err
	}
	if thenType.Equal(elseType) {
		return thenType, nil
	}
	return nil, fmt.Errorf("type mismatch between conditionals. can't coerce %v and %v", thenType, elseType)
}
// BuiltInFunction is a primitive function implemented in Go. Its static type
// is the fixed signature t. List.Step applies it by calling f with the
// argument expressions, which have all been reduced to values by then.
type BuiltInFunction struct {
	t *FunctionType                          // signature reported by TypeCheck
	f func([]Expression) (Expression, error) // native implementation
}

// HasValue is an empty marker method; presumably isValue uses it to
// recognise built-ins as values (TODO confirm).
func (fn *BuiltInFunction) HasValue() {
}

// Substitute does nothing: a built-in holds no sub-expressions to rewrite.
func (fn *BuiltInFunction) Substitute(Id, Expression) {
}

// TypeCheck reports the declared signature, independent of the environment.
func (fn *BuiltInFunction) TypeCheck(TypeEnvironment) (Type, error) {
	return fn.t, nil
}

// Step is unsupported on built-ins and panics; they are applied via List.Step.
func (fn *BuiltInFunction) Step() (Expression, error) {
	panic("not implemented")
}

// String renders a placeholder, since the Go implementation has no source form.
func (fn BuiltInFunction) String() string {
	return "(lambda (..) (..))"
}
// func (f *FunctionValue) TypeCheck(TypeEnvironment) (Type, error) { return &PrimBoolean, nil }
// Lambda is a user-defined function. When it is applied, List.Step assigns
// each call argument to the corresponding param's expr. Step then reduces
// the body in order, yielding the last expression's value.
type Lambda struct {
t *FunctionType // NOTE(review): not read anywhere in this part of the file
params []*Symbol
body []Expression
}
// Step performs one reduction step of a lambda whose parameters have been
// bound (List.Step assigns the call arguments to each param's expr).
//
// Each call re-substitutes the parameter bindings into the body. It then
// reduces the first body expression that is not yet a value and returns the
// lambda itself so evaluation can continue. When all body expressions are
// values, the last one is returned; an empty body yields zero.
//
// An unbound parameter is reported as an error instead of panicking.
//
// NOTE(review): params and body are mutated in place, so a Lambda value
// appears to support only a single application — confirm.
func (l *Lambda) Step() (Expression, error) {
	// Validate every binding before substituting any, so a failure leaves
	// the body untouched.
	for _, p := range l.params {
		if p.expr == nil {
			return nil, fmt.Errorf("unbound parameter %s at %s", p.id, p.id.position)
		}
	}
	for _, p := range l.params {
		for _, e := range l.body {
			e.Substitute(p.id.value, p.expr)
		}
	}
	if len(l.body) == 0 {
		return &zero, nil
	}
	for i, expr := range l.body {
		if isValue(expr) {
			continue
		}
		r, err := expr.Step()
		if err != nil {
			return nil, err
		}
		l.body[i] = r
		return l, nil
	}
	return l.body[len(l.body)-1], nil
}
// Substitute binds id to expr in every body expression. The lambda's own
// parameters are not consulted, so a parameter named id does not shadow it
// here. Step re-binds the parameters on every reduction.
func (l *Lambda) Substitute(id Id, expr Expression) {
	for i := range l.body {
		l.body[i].Substitute(id, expr)
	}
}
// TypeCheck types each parameter from its bound expression (p.expr). Each
// parameter is checked in an environment that already holds the earlier
// parameters; the body is then checked with all parameters in scope.
//
// The result is the type of the last body expression, or PrimVoid for an
// empty body. The caller's environment is never modified, and a nil env is
// treated as empty. A parameter with no bound expression is reported as an
// error rather than causing a nil-interface panic.
//
// NOTE(review): this returns the body's type, not l.t, while List.TypeCheck
// requires the callee of an application to be a *FunctionType — confirm
// whether lambda applications are expected to type check.
func (l *Lambda) TypeCheck(env TypeEnvironment) (Type, error) {
	// maps.Clone(nil) returns nil and writing to a nil map panics, so copy
	// into a freshly allocated map instead of cloning.
	scope := make(TypeEnvironment, len(env)+len(l.params))
	maps.Copy(scope, env)
	for _, p := range l.params {
		if p.expr == nil {
			return nil, fmt.Errorf("unbound parameter %s at %s", p.id, p.id.position)
		}
		typ, err := p.expr.TypeCheck(scope)
		if err != nil {
			return nil, err
		}
		scope[p.id.value] = typ
	}
	// Every body expression must type check; the last one's type wins.
	var result Type = &PrimVoid
	for _, e := range l.body {
		typ, err := e.TypeCheck(scope)
		if err != nil {
			return nil, err
		}
		result = typ
	}
	return result, nil
}
// Symbol is an identifier occurrence together with the expression currently
// bound to it. expr is nil while the symbol is unbound; Substitute sets it
// when the identifier matches.
type Symbol struct {
id *Identifier
expr Expression
}
// Param pairs a parameter identifier with a type (presumably its declared
// type — confirm).
// NOTE(review): Param is not referenced in this part of the file; Lambda
// stores its parameters as []*Symbol.
type Param struct {
id *Identifier
t Type
}
// Step replaces the symbol with its bound expression. An unbound symbol is
// an error that reports the identifier and its source position.
func (s *Symbol) Step() (Expression, error) {
	if bound := s.expr; bound != nil {
		return bound, nil
	}
	return nil, fmt.Errorf("unbound symbol %s at %s", s.id, s.id.position)
}
// Substitute binds the symbol to expr when its identifier equals id. Any
// previous binding is overwritten; other identifiers are left alone.
func (s *Symbol) Substitute(id Id, expr Expression) {
	if s.id.value != id {
		return
	}
	s.expr = expr
}
// TypeCheck resolves the symbol's type. The environment takes precedence;
// otherwise the bound expression is type checked. A symbol that is in
// neither is reported as undeclared.
func (s *Symbol) TypeCheck(env TypeEnvironment) (Type, error) {
	typ, found := env[s.id.value]
	switch {
	case found:
		return typ, nil
	case s.expr != nil:
		return s.expr.TypeCheck(env)
	default:
		return nil, fmt.Errorf("undeclared symbol %s at %s", s.id, s.id.position)
	}
}
// List is an s-expression list. Step and TypeCheck treat it as an
// application: els[0] is the callee and els[1:] are the arguments. An empty
// list steps to zero.
type List struct {
els []Expression
}
// Step performs one reduction step of an application.
//
// Elements are reduced left to right, one step per call, until all are
// values. An empty list then yields zero. Otherwise els[0] is applied to
// els[1:]:
//   - A *Lambda gets each argument bound to the matching parameter and is
//     returned so its body can be stepped next.
//   - A *BuiltInFunction is invoked directly.
//   - Any other callee is an error.
//
// A lambda called with the wrong number of arguments is reported as an error
// instead of panicking (too few) or silently dropping arguments (too many).
func (v *List) Step() (Expression, error) {
	for i, expr := range v.els {
		if isValue(expr) {
			continue
		}
		e, err := expr.Step()
		if err != nil {
			return nil, err
		}
		v.els[i] = e
		return v, nil
	}
	if len(v.els) == 0 {
		return &zero, nil
	}
	args := v.els[1:]
	switch fn := v.els[0].(type) {
	case *Lambda:
		if len(args) != len(fn.params) {
			return nil, fmt.Errorf("parameter and arguments number mismatch: expected %d, got %d",
				len(fn.params), len(args))
		}
		for i, p := range fn.params {
			p.expr = args[i]
		}
		return fn, nil
	case *BuiltInFunction:
		return fn.f(args)
	default:
		return nil, fmt.Errorf("not a function")
	}
}
// Substitute binds id to expr in every element of the list, including the
// callee in head position.
func (v *List) Substitute(id Id, expr Expression) {
	for i := range v.els {
		v.els[i].Substitute(id, expr)
	}
}
// TypeCheck types the list as a function application.
//
// Every element is type checked first. The head must then have a
// *FunctionType whose parameter count and parameter types match the
// arguments exactly; the result is the function's return type. An empty
// list has type PrimAny.
func (v *List) TypeCheck(env TypeEnvironment) (Type, error) {
	if len(v.els) == 0 {
		return &PrimAny, nil
	}
	types := make([]Type, 0, len(v.els))
	for _, el := range v.els {
		t, err := el.TypeCheck(env)
		if err != nil {
			return nil, err
		}
		types = append(types, t)
	}
	fnType, isFn := types[0].(*FunctionType)
	if !isFn {
		return nil, fmt.Errorf("expected function call. got %v", types[0])
	}
	argTypes := types[1:]
	// TODO: variadic parameters (last param accepting the remaining args) are
	// not supported yet; arity must match exactly.
	if len(argTypes) != len(fnType.params) {
		return nil, fmt.Errorf("parameter and arguments number mismatch.")
	}
	for i, got := range argTypes {
		if want := fnType.params[i]; !want.Equal(got) {
			return nil, fmt.Errorf("parameter type mismatch. expected %v got %v",
				want, got)
		}
	}
	// TODO: non-call lists (homogeneous ListType values) are not typed here.
	return fnType.ret, nil
}