// act/pkg/schema/schema.go (406 lines, 10 KiB, Go)

package schema
import (
_ "embed"
"encoding/json"
"errors"
"fmt"
"math"
"regexp"
"strconv"
"strings"
"github.com/rhysd/actionlint"
"gopkg.in/yaml.v3"
)
// workflowSchema is the raw JSON text of workflow_schema.json, embedded at
// build time and decoded by GetWorkflowSchema.
//
//go:embed workflow_schema.json
var workflowSchema string
// actionSchema is the raw JSON text of action_schema.json, embedded at build
// time and decoded by GetActionSchema.
//
//go:embed action_schema.json
var actionSchema string
// functions matches a Context entry that declares a callable function in the
// form name(min,max) with no spaces: group 1 is the function name, group 2 the
// minimum argument count, and group 3 the maximum count or the literal MAX.
var functions = regexp.MustCompile(`^([a-zA-Z0-9_]+)\(([0-9]+),([0-9]+|MAX)\)$`)
// Schema is a set of named type definitions decoded from one of the embedded
// schema JSON files (see GetWorkflowSchema and GetActionSchema).
type Schema struct {
// Definitions maps a definition name to its Definition. Presumably filled
// from the JSON's "definitions" object (encoding/json matches field names
// case-insensitively). Names missing here may still resolve to a built-in
// primitive through GetDefinition.
Definitions map[string]Definition
}
// GetDefinition returns the definition registered under name.
//
// If the schema does not declare name, a built-in primitive is synthesized
// for "any", "sequence", "mapping", "number", "string", "boolean" and "null".
// Any other unknown name yields the zero Definition, which has no shape set.
func (s *Schema) GetDefinition(name string) Definition {
	if declared, found := s.Definitions[name]; found {
		return declared
	}

	var builtin Definition
	switch name {
	case "any":
		builtin.OneOf = &[]string{"sequence", "mapping", "number", "boolean", "string", "null"}
	case "sequence":
		builtin.Sequence = &SequenceDefinition{ItemType: "any"}
	case "mapping":
		builtin.Mapping = &MappingDefinition{LooseKeyType: "any", LooseValueType: "any"}
	case "number":
		builtin.Number = &NumberDefinition{}
	case "string":
		builtin.String = &StringDefinition{}
	case "boolean":
		builtin.Boolean = &BooleanDefinition{}
	case "null":
		builtin.Null = &NullDefinition{}
	}
	return builtin
}
// Definition describes one schema type. Which shape it validates is chosen by
// which pointer field is non-nil. Node.UnmarshalYAML checks them in the order
// Mapping, Sequence, OneOf, then (for scalar nodes) String, Number, Boolean,
// AllowedValues, Null.
type Definition struct {
// Context lists the named values, plus name(min,max) function declarations,
// that expressions inside this definition may reference.
Context []string
Mapping *MappingDefinition
Sequence *SequenceDefinition
// OneOf names alternative definitions; a value is valid if it matches any
// of them.
OneOf *[]string `json:"one-of"`
// AllowedValues restricts a scalar to one of these exact strings.
AllowedValues *[]string `json:"allowed-values"`
String *StringDefinition
Number *NumberDefinition
Boolean *BooleanDefinition
Null *NullDefinition
}
// MappingDefinition describes a YAML mapping.
type MappingDefinition struct {
// Properties maps each known key to the property its value must satisfy.
Properties map[string]MappingProperty
// LooseKeyType presumably names the definition that unlisted keys should
// satisfy.
// NOTE(review): checkMapping never validates keys against it; confirm intent.
LooseKeyType string `json:"loose-key-type"`
// LooseValueType names the definition used for the values of keys not
// listed in Properties. When empty, such keys are reported as unknown
// properties.
LooseValueType string `json:"loose-value-type"`
}
// MappingProperty is the schema entry for one known key of a mapping.
type MappingProperty struct {
	// Type names the definition the key's value must satisfy.
	Type string
	// Required presumably marks the key as mandatory.
	// NOTE(review): nothing in this file currently enforces it.
	Required bool
}

// UnmarshalJSON accepts either a bare JSON string, taken as Type, or a JSON
// object carrying the MappingProperty fields.
func (s *MappingProperty) UnmarshalJSON(data []byte) error {
	// Shorthand form: just the type name.
	if err := json.Unmarshal(data, &s.Type); err == nil {
		return nil
	}
	// Object form. The local type has MappingProperty's fields but not this
	// method, so decoding into it does not recurse back here.
	type rawProperty MappingProperty
	return json.Unmarshal(data, (*rawProperty)(s))
}
// SequenceDefinition describes a YAML sequence whose items must all satisfy
// the definition named by ItemType.
type SequenceDefinition struct {
ItemType string `json:"item-type"`
}
// StringDefinition describes a scalar string value.
type StringDefinition struct {
// Constant, when non-empty, is the only value accepted.
Constant string
// IsExpression means the value is a bare expression, written without the
// ${{ }} delimiters, that is parsed and checked against the node's Context.
IsExpression bool `json:"is-expression"`
}
// NumberDefinition marks a definition whose scalar must decode as a float64.
type NumberDefinition struct {
}
// BooleanDefinition marks a definition whose scalar must decode as a bool.
type BooleanDefinition struct {
}
// NullDefinition marks a definition intended to accept a null scalar; see
// Node.UnmarshalYAML for how it is checked.
type NullDefinition struct {
}
// GetWorkflowSchema decodes the embedded workflow_schema.json into a new
// Schema.
func GetWorkflowSchema() *Schema {
	return decodeEmbeddedSchema(workflowSchema)
}

// GetActionSchema decodes the embedded action_schema.json into a new Schema.
func GetActionSchema() *Schema {
	return decodeEmbeddedSchema(actionSchema)
}

// decodeEmbeddedSchema parses raw schema JSON into a freshly allocated
// Schema. The decode error is discarded; callers receive whatever was
// decoded. The input is compiled into the binary rather than supplied at
// runtime.
func decodeEmbeddedSchema(raw string) *Schema {
	parsed := new(Schema)
	_ = json.Unmarshal([]byte(raw), parsed)
	return parsed
}
// Node validates a YAML value against one definition of a Schema. The check
// runs in its UnmarshalYAML method (the yaml.v3 Unmarshaler signature), so
// decoding YAML into a Node performs the validation.
type Node struct {
// Definition is the definition name, resolved through Schema.GetDefinition.
Definition string
// Schema supplies the definitions to validate against.
Schema *Schema
// Context lists the named values expressions may reference, plus
// name(min,max) entries that declare callable functions. When nil,
// UnmarshalYAML seeds it from the definition's own Context.
Context []string
}
// FunctionInfo describes a function callable from an expression.
type FunctionInfo struct {
// name is compared case-insensitively with the callee name.
name string
// min and max are the inclusive bounds on the number of arguments.
min int
max int
}
// checkSingleExpression validates one parsed ${{ }} expression against the
// node's Context. It returns every violation found, joined, or nil.
//
// When s.Context is empty, only expressions whose leading token is an int,
// float or string literal are considered at all; anything else yields
// "expressions are not allowed here". The expression tree is then walked:
//   - each function call must name a built-in or Context-declared function
//     (see GetFunctions; names are matched case-insensitively) with an
//     argument count within that function's [min, max] bounds;
//   - each variable access must name an entry of s.Context
//     (case-insensitive).
func (s *Node) checkSingleExpression(exprNode actionlint.ExprNode) error {
	if len(s.Context) == 0 {
		switch exprNode.Token().Kind {
		case actionlint.TokenKindInt, actionlint.TokenKindFloat, actionlint.TokenKindString:
			// A literal leading token does not make the whole expression a
			// literal (e.g. `'a' == github.token`). All three kinds therefore
			// continue to the walk below, which reports any variable access
			// because Context is empty. Go cases do not fall through, so the
			// shared case list is required here.
		default:
			return fmt.Errorf("expressions are not allowed here")
		}
	}

	funcs := s.GetFunctions()
	var err error
	actionlint.VisitExprNode(exprNode, func(node, _ actionlint.ExprNode, entering bool) {
		if !entering {
			return
		}
		switch n := node.(type) {
		case *actionlint.FuncCallNode:
			for _, fn := range *funcs {
				if !strings.EqualFold(n.Callee, fn.name) {
					continue
				}
				if fn.min > len(n.Args) {
					err = errors.Join(err, fmt.Errorf("Missing parameters for %s expected >= %v got %v", n.Callee, fn.min, len(n.Args)))
				}
				if fn.max < len(n.Args) {
					err = errors.Join(err, fmt.Errorf("Too many parameters for %s expected <= %v got %v", n.Callee, fn.max, len(n.Args)))
				}
				return
			}
			err = errors.Join(err, fmt.Errorf("Unknown Function Call %s", n.Callee))
		case *actionlint.VariableNode:
			for _, name := range s.Context {
				if strings.EqualFold(n.Name, name) {
					return
				}
			}
			err = errors.Join(err, fmt.Errorf("Unknown Variable Access %s", n.Name))
		}
	})
	return err
}
func (s *Node) GetFunctions() *[]FunctionInfo {
funcs := &[]FunctionInfo{}
AddFunction(funcs, "contains", 2, 2)
AddFunction(funcs, "endsWith", 2, 2)
AddFunction(funcs, "format", 1, 255)
AddFunction(funcs, "join", 1, 2)
AddFunction(funcs, "startsWith", 2, 2)
AddFunction(funcs, "toJson", 1, 1)
AddFunction(funcs, "fromJson", 1, 1)
for _, v := range s.Context {
i := strings.Index(v, "(")
if i == -1 {
continue
}
smatch := functions.FindStringSubmatch(v)
if len(smatch) > 0 {
functionName := smatch[1]
minParameters, _ := strconv.ParseInt(smatch[2], 10, 32)
maxParametersRaw := smatch[3]
var maxParameters int64
if strings.EqualFold(maxParametersRaw, "MAX") {
maxParameters = math.MaxInt32
} else {
maxParameters, _ = strconv.ParseInt(maxParametersRaw, 10, 32)
}
*funcs = append(*funcs, FunctionInfo{
name: functionName,
min: int(minParameters),
max: int(maxParameters),
})
}
}
return funcs
}
// checkExpression scans node.Value for ${{ ... }} placeholders. It reports
// whether at least one "${{" was found. The returned error joins the parse and
// validation errors of every placeholder, each prefixed with the node's
// location. After a placeholder fails to parse, scanning resumes right after
// its opening "${{".
func (s *Node) checkExpression(node *yaml.Node) (bool, error) {
	const opener = "${{"
	rest := node.Value
	found := false
	var errs error
	for {
		start := strings.Index(rest, opener)
		if start < 0 {
			return found, errs
		}
		found = true
		rest = rest[start+len(opener):]

		lexer := actionlint.NewExprLexer(rest)
		expr, perr := actionlint.NewExprParser().Parse(lexer)
		if perr != nil {
			errs = errors.Join(errs, fmt.Errorf("%sFailed to parse: %s", formatLocation(node), perr.Message))
			continue
		}
		// Skip past the text the lexer consumed for this placeholder.
		rest = rest[lexer.Offset():]
		if cerr := s.checkSingleExpression(expr); cerr != nil {
			errs = errors.Join(errs, fmt.Errorf("%s%w", formatLocation(node), cerr))
		}
	}
}
// AddFunction appends to *funcs a FunctionInfo named s that accepts between
// i1 and i2 arguments, inclusive.
func AddFunction(funcs *[]FunctionInfo, s string, i1, i2 int) {
	info := FunctionInfo{name: s, min: i1, max: i2}
	*funcs = append(*funcs, info)
}
// UnmarshalYAML validates node against the schema definition named by
// s.Definition. It has the yaml.v3 Unmarshaler signature, so decoding YAML
// into a Node runs the validation. Apart from seeding s.Context, nothing is
// stored.
//
// Behavior, in order:
//   - A document node is unwrapped to its root. An empty document has no
//     root and is validated as-is, producing a shape error instead of an
//     index panic.
//   - If s.Context is nil, it is seeded from the definition's Context.
//   - A value containing ${{ }} is validated as an expression and accepted
//     without checking it against the definition's shape.
//   - Otherwise the value is checked against the definition's mapping,
//     sequence, one-of or scalar (string, number, boolean, allowed-values,
//     null) shape.
//
// It returns an error for a nil node, and errors.ErrUnsupported when the
// definition declares none of the known shapes.
func (s *Node) UnmarshalYAML(node *yaml.Node) error {
	if node == nil {
		return errors.New("schema: cannot validate a nil YAML node")
	}
	if node.Kind == yaml.DocumentNode && len(node.Content) > 0 {
		return s.UnmarshalYAML(node.Content[0])
	}
	def := s.Schema.GetDefinition(s.Definition)
	if s.Context == nil {
		s.Context = def.Context
	}
	isExpr, err := s.checkExpression(node)
	if err != nil {
		return err
	}
	if isExpr {
		return nil
	}

	switch {
	case def.Mapping != nil:
		return s.checkMapping(node, def)
	case def.Sequence != nil:
		return s.checkSequence(node, def)
	case def.OneOf != nil:
		return s.checkOneOf(def, node)
	}

	if node.Kind != yaml.ScalarNode {
		return fmt.Errorf("%sExpected a scalar got %v", formatLocation(node), getStringKind(node.Kind))
	}
	switch {
	case def.String != nil:
		return s.checkString(node, def)
	case def.Number != nil:
		var num float64
		return node.Decode(&num)
	case def.Boolean != nil:
		var b bool
		return node.Decode(&b)
	case def.AllowedValues != nil:
		val := node.Value // distinct name: do not shadow the receiver s
		for _, v := range *def.AllowedValues {
			if val == v {
				return nil
			}
		}
		return fmt.Errorf("%sExpected one of %s got %s", formatLocation(node), strings.Join(*def.AllowedValues, ","), val)
	case def.Null != nil:
		// Decoding into *byte leaves the pointer nil only for a null scalar.
		// Small numbers (e.g. `5`) decode successfully into the byte, so a
		// non-nil result is not null and must be rejected too.
		var myNull *byte
		if err := node.Decode(&myNull); err != nil {
			return err
		}
		if myNull != nil {
			return fmt.Errorf("%sExpected null got %s", formatLocation(node), node.Value)
		}
		return nil
	}
	return errors.ErrUnsupported
}
// checkString validates a scalar against def.String.
//   - When Constant is non-empty, the value must equal it exactly.
//   - When IsExpression is set, the value is parsed as a bare expression
//     (no ${{ }} delimiters) and checked with checkSingleExpression.
//
// Errors are prefixed with the node's location.
func (s *Node) checkString(node *yaml.Node, def Definition) error {
	spec := def.String
	value := node.Value
	if spec.Constant != "" && value != spec.Constant {
		return fmt.Errorf("%sExpected %s got %s", formatLocation(node), spec.Constant, value)
	}
	if !spec.IsExpression {
		return nil
	}

	// Append the closing "}}" so the value lexes like the text that follows
	// "${{" in a placeholder, which is what checkExpression feeds the lexer.
	lexer := actionlint.NewExprLexer(value + "}}")
	expr, perr := actionlint.NewExprParser().Parse(lexer)
	if perr != nil {
		return fmt.Errorf("%sFailed to parse: %s", formatLocation(node), perr.Message)
	}
	if cerr := s.checkSingleExpression(expr); cerr != nil {
		return fmt.Errorf("%s%w", formatLocation(node), cerr)
	}
	return nil
}
// checkOneOf tries each alternative in def.OneOf in order. It returns nil as
// soon as one validates node. Otherwise it returns the joined failures, each
// labelled with the alternative's name. Each attempt uses a fresh Context
// made of this node's Context followed by the alternative's own Context.
func (s *Node) checkOneOf(def Definition, node *yaml.Node) error {
	var failures error
	for _, candidate := range *def.OneOf {
		names := append([]string{}, s.Context...)
		names = append(names, s.Schema.GetDefinition(candidate).Context...)
		attempt := &Node{Definition: candidate, Schema: s.Schema, Context: names}

		err := attempt.UnmarshalYAML(node)
		if err == nil {
			return nil
		}
		failures = errors.Join(failures, fmt.Errorf("%sFailed to match %s: %w", formatLocation(node), candidate, err))
	}
	return failures
}
// getStringKind returns a lowercase, human-readable name for a yaml.Kind. It
// is used in "Expected a ... got ..." messages. Unrecognized kinds map to
// "unknown".
func getStringKind(k yaml.Kind) string {
	var name string
	switch k {
	case yaml.DocumentNode:
		name = "document"
	case yaml.SequenceNode:
		name = "sequence"
	case yaml.MappingNode:
		name = "mapping"
	case yaml.ScalarNode:
		name = "scalar"
	case yaml.AliasNode:
		name = "alias"
	default:
		name = "unknown"
	}
	return name
}
// checkSequence requires node to be a YAML sequence and validates every item
// against def.Sequence.ItemType. Each item gets a fresh Context made of this
// node's Context followed by the item definition's Context. All item errors
// are returned joined; the result is nil when every item is valid.
func (s *Node) checkSequence(node *yaml.Node, def Definition) error {
	if node.Kind != yaml.SequenceNode {
		return fmt.Errorf("%sExpected a sequence got %v", formatLocation(node), getStringKind(node.Kind))
	}
	itemType := def.Sequence.ItemType
	var errs error
	for _, item := range node.Content {
		names := append([]string{}, s.Context...)
		names = append(names, s.Schema.GetDefinition(itemType).Context...)
		child := &Node{Definition: itemType, Schema: s.Schema, Context: names}
		errs = errors.Join(errs, child.UnmarshalYAML(item))
	}
	return errs
}
// formatLocation renders node's source position as the "Line: N Column M: "
// prefix used at the start of validation error messages.
func formatLocation(node *yaml.Node) string {
	return "Line: " + strconv.Itoa(node.Line) + " Column " + strconv.Itoa(node.Column) + ": "
}
func (s *Node) checkMapping(node *yaml.Node, def Definition) error {
if node.Kind != yaml.MappingNode {
return fmt.Errorf("%sExpected a mapping got %v", formatLocation(node), getStringKind(node.Kind))
}
insertDirective := regexp.MustCompile(`\${{\s*insert\s*}}`)
var allErrors error
for i, k := range node.Content {
if i%2 == 0 {
if insertDirective.MatchString(k.Value) {
if len(s.Context) == 0 {
allErrors = errors.Join(allErrors, fmt.Errorf("%sinsert is not allowed here", formatLocation(k)))
}
continue
}
isExpr, err := s.checkExpression(k)
if err != nil {
allErrors = errors.Join(allErrors, err)
continue
}
if isExpr {
continue
}
vdef, ok := def.Mapping.Properties[k.Value]
if !ok {
if def.Mapping.LooseValueType == "" {
allErrors = errors.Join(allErrors, fmt.Errorf("%sUnknown Property %v", formatLocation(k), k.Value))
continue
}
vdef = MappingProperty{Type: def.Mapping.LooseValueType}
}
if err := (&Node{
Definition: vdef.Type,
Schema: s.Schema,
Context: append(append([]string{}, s.Context...), s.Schema.GetDefinition(vdef.Type).Context...),
}).UnmarshalYAML(node.Content[i+1]); err != nil {
allErrors = errors.Join(allErrors, err)
continue
}
}
}
return allErrors
}