refactor
This commit is contained in:
parent
687364767a
commit
b7f547839a
BIN
example/golang/simple
Executable file
BIN
example/golang/simple
Executable file
Binary file not shown.
@ -7,3 +7,10 @@ require (
|
|||||||
golang.org/x/sync v0.17.0
|
golang.org/x/sync v0.17.0
|
||||||
golang.org/x/text v0.30.0
|
golang.org/x/text v0.30.0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
|
github.com/stretchr/testify v1.11.1
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
|
)
|
||||||
|
|||||||
@ -1,6 +1,16 @@
|
|||||||
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e h1:Lf/gRkoycfOBPa42vU2bbgPurFong6zXeFtPoxholzU=
|
github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e h1:Lf/gRkoycfOBPa42vU2bbgPurFong6zXeFtPoxholzU=
|
||||||
github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e/go.mod h1:uNVvRXArCGbZ508SxYYTC5v1JWoz2voff5pm25jU1Ok=
|
github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e/go.mod h1:uNVvRXArCGbZ508SxYYTC5v1JWoz2voff5pm25jU1Ok=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||||
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
|
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
|
||||||
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
|
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
|||||||
@ -1,14 +1,18 @@
|
|||||||
package api
|
package api
|
||||||
|
|
||||||
import "efprojects.com/kitten-ipc/types"
|
type ValType string
|
||||||
|
|
||||||
// todo check TInt size < 64
|
const (
|
||||||
// todo check not float
|
TNoType ValType = ""
|
||||||
|
TInt ValType = "int"
|
||||||
|
TString ValType = "string"
|
||||||
|
TBool ValType = "bool"
|
||||||
|
TBlob ValType = "blob"
|
||||||
|
)
|
||||||
|
|
||||||
type Val struct {
|
type Val struct {
|
||||||
Name string
|
Name string
|
||||||
Type types.ValType
|
Type ValType
|
||||||
Children []Val
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Method struct {
|
type Method struct {
|
||||||
|
|||||||
19
kitcom/internal/common/write.go
Normal file
19
kitcom/internal/common/write.go
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
func WriteFile(destFile string, content []byte) error {
|
||||||
|
f, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("open destination file: %w", err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
if _, err := f.Write(content); err != nil {
|
||||||
|
return fmt.Errorf("write file: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@ -4,18 +4,15 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"go/format"
|
"go/format"
|
||||||
"os"
|
"strings"
|
||||||
"text/template"
|
"text/template"
|
||||||
|
|
||||||
_ "embed"
|
_ "embed"
|
||||||
|
|
||||||
"efprojects.com/kitten-ipc/kitcom/internal/api"
|
"efprojects.com/kitten-ipc/kitcom/internal/api"
|
||||||
"efprojects.com/kitten-ipc/types"
|
"efprojects.com/kitten-ipc/kitcom/internal/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
// todo: check int overflow
|
|
||||||
// todo: check float is whole
|
|
||||||
|
|
||||||
//go:embed gogen.tmpl
|
//go:embed gogen.tmpl
|
||||||
var templateString string
|
var templateString string
|
||||||
|
|
||||||
@ -34,43 +31,41 @@ func (g *GoApiGenerator) Generate(apis *api.Api, destFile string) error {
|
|||||||
Api: apis,
|
Api: apis,
|
||||||
}
|
}
|
||||||
|
|
||||||
const defaultReceiver = "self"
|
|
||||||
|
|
||||||
tpl := template.New("gogen")
|
tpl := template.New("gogen")
|
||||||
tpl = tpl.Funcs(map[string]any{
|
tpl = tpl.Funcs(map[string]any{
|
||||||
"receiver": func(name string) string {
|
"receiver": func(name string) string {
|
||||||
return defaultReceiver
|
return strings.ToLower(name[:1])
|
||||||
},
|
},
|
||||||
"typedef": func(t types.ValType) (string, error) {
|
"typedef": func(t api.ValType) (string, error) {
|
||||||
td, ok := map[types.ValType]string{
|
td, ok := map[api.ValType]string{
|
||||||
types.TInt: "int",
|
api.TInt: "int",
|
||||||
types.TString: "string",
|
api.TString: "string",
|
||||||
types.TBool: "bool",
|
api.TBool: "bool",
|
||||||
types.TBlob: "[]byte",
|
api.TBlob: "[]byte",
|
||||||
}[t]
|
}[t]
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", fmt.Errorf("cannot generate type %v", t)
|
return "", fmt.Errorf("cannot generate type %v", t)
|
||||||
}
|
}
|
||||||
return td, nil
|
return td, nil
|
||||||
},
|
},
|
||||||
"convtype": func(valDef string, t types.ValType) (string, error) {
|
"convtype": func(valDef string, t api.ValType) (string, error) {
|
||||||
td, ok := map[types.ValType]string{
|
td, ok := map[api.ValType]string{
|
||||||
types.TInt: fmt.Sprintf("int(%s.(float64))", valDef),
|
api.TInt: fmt.Sprintf("int(%s.(float64))", valDef),
|
||||||
types.TString: fmt.Sprintf("%s.(string)", valDef),
|
api.TString: fmt.Sprintf("%s.(string)", valDef),
|
||||||
types.TBool: fmt.Sprintf("%s.(bool)", valDef),
|
api.TBool: fmt.Sprintf("%s.(bool)", valDef),
|
||||||
types.TBlob: fmt.Sprintf("%s.([]byte)", valDef),
|
api.TBlob: fmt.Sprintf("%s.([]byte)", valDef),
|
||||||
}[t]
|
}[t]
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", fmt.Errorf("cannot convert type %v for val %s", t, valDef)
|
return "", fmt.Errorf("cannot convert type %v for val %s", t, valDef)
|
||||||
}
|
}
|
||||||
return td, nil
|
return td, nil
|
||||||
},
|
},
|
||||||
"zerovalue": func(t types.ValType) (string, error) {
|
"zerovalue": func(t api.ValType) (string, error) {
|
||||||
v, ok := map[types.ValType]string{
|
v, ok := map[api.ValType]string{
|
||||||
types.TInt: "0",
|
api.TInt: "0",
|
||||||
types.TString: `""`,
|
api.TString: `""`,
|
||||||
types.TBool: "false",
|
api.TBool: "false",
|
||||||
types.TBlob: "[]byte{}",
|
api.TBlob: "[]byte{}",
|
||||||
}[t]
|
}[t]
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", fmt.Errorf("cannot generate zero value for type %v", t)
|
return "", fmt.Errorf("cannot generate zero value for type %v", t)
|
||||||
@ -86,27 +81,14 @@ func (g *GoApiGenerator) Generate(apis *api.Api, destFile string) error {
|
|||||||
return fmt.Errorf("execute template: %w", err)
|
return fmt.Errorf("execute template: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := g.writeDest(destFile, buf.Bytes()); err != nil {
|
formatted, err := format.Source(buf.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("format source: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := common.WriteFile(destFile, formatted); err != nil {
|
||||||
return fmt.Errorf("write file: %w", err)
|
return fmt.Errorf("write file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GoApiGenerator) writeDest(destFile string, bytes []byte) error {
|
|
||||||
f, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("open destination file: %w", err)
|
|
||||||
}
|
|
||||||
defer f.Close()
|
|
||||||
|
|
||||||
formatted, err := format.Source(bytes)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("format source: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := f.Write(formatted); err != nil {
|
|
||||||
return fmt.Errorf("write formatted source: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|||||||
@ -28,7 +28,9 @@ func ({{ $e.Name | receiver }} *{{ $e.Name }}) {{ $mtd.Name }}(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return {{ range $mtd.Ret }}{{ .Type | zerovalue }}, {{ end }} fmt.Errorf("call to {{ $e.Name }}.{{ $mtd.Name }} failed: %w", err)
|
return {{ range $mtd.Ret }}{{ .Type | zerovalue }}, {{ end }} fmt.Errorf("call to {{ $e.Name }}.{{ $mtd.Name }} failed: %w", err)
|
||||||
}
|
}
|
||||||
_ = results
|
if len(results) < {{ len $mtd.Ret }} {
|
||||||
|
return {{ range $mtd.Ret }}{{ .Type | zerovalue }}, {{ end }} fmt.Errorf("call to {{ $e.Name }}.{{ $mtd.Name }}: expected {{ len $mtd.Ret }} results, got %d", len(results))
|
||||||
|
}
|
||||||
{{ range $i, $ret := $mtd.Ret }}
|
{{ range $i, $ret := $mtd.Ret }}
|
||||||
{{ if eq $ret.Type "blob" }}
|
{{ if eq $ret.Type "blob" }}
|
||||||
results[{{ $i }}], err = base64.StdEncoding.DecodeString(results[{{ $i }}].(string))
|
results[{{ $i }}], err = base64.StdEncoding.DecodeString(results[{{ $i }}].(string))
|
||||||
|
|||||||
@ -9,7 +9,6 @@ import (
|
|||||||
|
|
||||||
"efprojects.com/kitten-ipc/kitcom/internal/api"
|
"efprojects.com/kitten-ipc/kitcom/internal/api"
|
||||||
"efprojects.com/kitten-ipc/kitcom/internal/common"
|
"efprojects.com/kitten-ipc/kitcom/internal/common"
|
||||||
"efprojects.com/kitten-ipc/types"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var decorComment = regexp.MustCompile(`^//\s?kittenipc:api$`)
|
var decorComment = regexp.MustCompile(`^//\s?kittenipc:api$`)
|
||||||
@ -142,11 +141,11 @@ func fieldToVal(param *ast.Field, returning bool) (*api.Val, error) {
|
|||||||
case *ast.Ident:
|
case *ast.Ident:
|
||||||
switch paramType.Name {
|
switch paramType.Name {
|
||||||
case "int":
|
case "int":
|
||||||
val.Type = types.TInt
|
val.Type = api.TInt
|
||||||
case "string":
|
case "string":
|
||||||
val.Type = types.TString
|
val.Type = api.TString
|
||||||
case "bool":
|
case "bool":
|
||||||
val.Type = types.TBool
|
val.Type = api.TBool
|
||||||
case "error":
|
case "error":
|
||||||
if returning {
|
if returning {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@ -161,7 +160,7 @@ func fieldToVal(param *ast.Field, returning bool) (*api.Val, error) {
|
|||||||
case *ast.Ident:
|
case *ast.Ident:
|
||||||
switch elementType.Name {
|
switch elementType.Name {
|
||||||
case "byte":
|
case "byte":
|
||||||
val.Type = types.TBlob
|
val.Type = api.TBlob
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("parameter type %s is not supported yet", elementType.Name)
|
return nil, fmt.Errorf("parameter type %s is not supported yet", elementType.Name)
|
||||||
}
|
}
|
||||||
|
|||||||
43
kitcom/internal/golang/goparser_test.go
Normal file
43
kitcom/internal/golang/goparser_test.go
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
package golang
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"efprojects.com/kitten-ipc/kitcom/internal/api"
|
||||||
|
"efprojects.com/kitten-ipc/kitcom/internal/common"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGoParser(t *testing.T) {
|
||||||
|
parser := &GoApiParser{Parser: &common.Parser{}}
|
||||||
|
parser.AddFile("../../../example/golang/main.go")
|
||||||
|
|
||||||
|
result, err := parser.Parse()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, result.Endpoints, 1)
|
||||||
|
|
||||||
|
ep := result.Endpoints[0]
|
||||||
|
assert.Equal(t, "GoIpcApi", ep.Name)
|
||||||
|
require.Len(t, ep.Methods, 2)
|
||||||
|
|
||||||
|
// Div method
|
||||||
|
div := ep.Methods[0]
|
||||||
|
assert.Equal(t, "Div", div.Name)
|
||||||
|
require.Len(t, div.Params, 2)
|
||||||
|
assert.Equal(t, api.TInt, div.Params[0].Type)
|
||||||
|
assert.Equal(t, "a", div.Params[0].Name)
|
||||||
|
assert.Equal(t, api.TInt, div.Params[1].Type)
|
||||||
|
assert.Equal(t, "b", div.Params[1].Name)
|
||||||
|
require.Len(t, div.Ret, 1)
|
||||||
|
assert.Equal(t, api.TInt, div.Ret[0].Type)
|
||||||
|
|
||||||
|
// XorData method
|
||||||
|
xor := ep.Methods[1]
|
||||||
|
assert.Equal(t, "XorData", xor.Name)
|
||||||
|
require.Len(t, xor.Params, 2)
|
||||||
|
assert.Equal(t, api.TBlob, xor.Params[0].Type)
|
||||||
|
assert.Equal(t, api.TBlob, xor.Params[1].Type)
|
||||||
|
require.Len(t, xor.Ret, 1)
|
||||||
|
assert.Equal(t, api.TBlob, xor.Ret[0].Type)
|
||||||
|
}
|
||||||
@ -4,14 +4,13 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"text/template"
|
"text/template"
|
||||||
|
|
||||||
_ "embed"
|
_ "embed"
|
||||||
|
|
||||||
"efprojects.com/kitten-ipc/kitcom/internal/api"
|
"efprojects.com/kitten-ipc/kitcom/internal/api"
|
||||||
"efprojects.com/kitten-ipc/types"
|
"efprojects.com/kitten-ipc/kitcom/internal/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
//go:embed tsgen.tmpl
|
//go:embed tsgen.tmpl
|
||||||
@ -31,24 +30,24 @@ func (g *TypescriptApiGenerator) Generate(apis *api.Api, destFile string) error
|
|||||||
|
|
||||||
tpl := template.New("tsgen")
|
tpl := template.New("tsgen")
|
||||||
tpl = tpl.Funcs(map[string]any{
|
tpl = tpl.Funcs(map[string]any{
|
||||||
"typedef": func(t types.ValType) (string, error) {
|
"typedef": func(t api.ValType) (string, error) {
|
||||||
td, ok := map[types.ValType]string{
|
td, ok := map[api.ValType]string{
|
||||||
types.TInt: "number",
|
api.TInt: "number",
|
||||||
types.TString: "string",
|
api.TString: "string",
|
||||||
types.TBool: "boolean",
|
api.TBool: "boolean",
|
||||||
types.TBlob: "Buffer",
|
api.TBlob: "Buffer",
|
||||||
}[t]
|
}[t]
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", fmt.Errorf("cannot generate type %v", t)
|
return "", fmt.Errorf("cannot generate type %v", t)
|
||||||
}
|
}
|
||||||
return td, nil
|
return td, nil
|
||||||
},
|
},
|
||||||
"convtype": func(valDef string, t types.ValType) (string, error) {
|
"convtype": func(valDef string, t api.ValType) (string, error) {
|
||||||
td, ok := map[types.ValType]string{
|
td, ok := map[api.ValType]string{
|
||||||
types.TInt: fmt.Sprintf("%s as number", valDef),
|
api.TInt: fmt.Sprintf("%s as number", valDef),
|
||||||
types.TString: fmt.Sprintf("%s as string", valDef),
|
api.TString: fmt.Sprintf("%s as string", valDef),
|
||||||
types.TBool: fmt.Sprintf("%s as boolean", valDef),
|
api.TBool: fmt.Sprintf("%s as boolean", valDef),
|
||||||
types.TBlob: fmt.Sprintf("Buffer.from(%s, 'base64')", valDef),
|
api.TBlob: fmt.Sprintf("Buffer.from(%s, 'base64')", valDef),
|
||||||
}[t]
|
}[t]
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", fmt.Errorf("cannot convert type %v for val %s", t, valDef)
|
return "", fmt.Errorf("cannot convert type %v for val %s", t, valDef)
|
||||||
@ -64,24 +63,10 @@ func (g *TypescriptApiGenerator) Generate(apis *api.Api, destFile string) error
|
|||||||
return fmt.Errorf("execute template: %w", err)
|
return fmt.Errorf("execute template: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := g.writeDest(destFile, buf.Bytes()); err != nil {
|
if err := common.WriteFile(destFile, buf.Bytes()); err != nil {
|
||||||
return fmt.Errorf("write file: %w", err)
|
return fmt.Errorf("write file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *TypescriptApiGenerator) writeDest(destFile string, bytes []byte) error {
|
|
||||||
f, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("open destination file: %w", err)
|
|
||||||
}
|
|
||||||
defer f.Close()
|
|
||||||
|
|
||||||
if _, err := f.Write(bytes); err != nil {
|
|
||||||
return fmt.Errorf("write formatted source: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
prettierCmd := exec.Command("npx", "prettier", destFile, "--write")
|
prettierCmd := exec.Command("npx", "prettier", destFile, "--write")
|
||||||
if out, err := prettierCmd.CombinedOutput(); err != nil {
|
if out, err := prettierCmd.CombinedOutput(); err != nil {
|
||||||
log.Printf("Prettier returned error: %v", err)
|
log.Printf("Prettier returned error: %v", err)
|
||||||
|
|||||||
@ -12,7 +12,6 @@ import (
|
|||||||
"efprojects.com/kitten-ipc/kitcom/internal/tsgo/core"
|
"efprojects.com/kitten-ipc/kitcom/internal/tsgo/core"
|
||||||
"efprojects.com/kitten-ipc/kitcom/internal/tsgo/parser"
|
"efprojects.com/kitten-ipc/kitcom/internal/tsgo/parser"
|
||||||
"efprojects.com/kitten-ipc/kitcom/internal/tsgo/tspath"
|
"efprojects.com/kitten-ipc/kitcom/internal/tsgo/tspath"
|
||||||
"efprojects.com/kitten-ipc/types"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type TypescriptApiParser struct {
|
type TypescriptApiParser struct {
|
||||||
@ -109,25 +108,25 @@ func (p *TypescriptApiParser) parseFile(sourceFilePath string) ([]api.Endpoint,
|
|||||||
return endpoints, nil
|
return endpoints, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *TypescriptApiParser) fieldToVal(typ *ast.TypeNode) (types.ValType, error) {
|
func (p *TypescriptApiParser) fieldToVal(typ *ast.TypeNode) (api.ValType, error) {
|
||||||
switch typ.Kind {
|
switch typ.Kind {
|
||||||
case ast.KindNumberKeyword:
|
case ast.KindNumberKeyword:
|
||||||
return types.TInt, nil
|
return api.TInt, nil
|
||||||
case ast.KindStringKeyword:
|
case ast.KindStringKeyword:
|
||||||
return types.TString, nil
|
return api.TString, nil
|
||||||
case ast.KindBooleanKeyword:
|
case ast.KindBooleanKeyword:
|
||||||
return types.TBool, nil
|
return api.TBool, nil
|
||||||
case ast.KindTypeReference:
|
case ast.KindTypeReference:
|
||||||
refNode := typ.AsTypeReferenceNode()
|
refNode := typ.AsTypeReferenceNode()
|
||||||
ident := refNode.TypeName.AsIdentifier()
|
ident := refNode.TypeName.AsIdentifier()
|
||||||
switch ident.Text {
|
switch ident.Text {
|
||||||
case "Buffer":
|
case "Buffer":
|
||||||
return types.TBlob, nil
|
return api.TBlob, nil
|
||||||
default:
|
default:
|
||||||
return types.TNoType, fmt.Errorf("reference type %s is not supported yet", ident.Text)
|
return api.TNoType, fmt.Errorf("reference type %s is not supported yet", ident.Text)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return types.TNoType, fmt.Errorf("type %s is not supported yet", typ.Kind)
|
return api.TNoType, fmt.Errorf("type %s is not supported yet", typ.Kind)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
46
kitcom/internal/ts/tsparser_test.go
Normal file
46
kitcom/internal/ts/tsparser_test.go
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
package ts
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"efprojects.com/kitten-ipc/kitcom/internal/api"
|
||||||
|
"efprojects.com/kitten-ipc/kitcom/internal/common"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTsParser(t *testing.T) {
|
||||||
|
parser := &TypescriptApiParser{Parser: &common.Parser{}}
|
||||||
|
absPath, err := filepath.Abs("../../../example/ts/src/index.ts")
|
||||||
|
require.NoError(t, err)
|
||||||
|
parser.AddFile(absPath)
|
||||||
|
|
||||||
|
result, err := parser.Parse()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, result.Endpoints, 1)
|
||||||
|
|
||||||
|
ep := result.Endpoints[0]
|
||||||
|
assert.Equal(t, "TsIpcApi", ep.Name)
|
||||||
|
require.Len(t, ep.Methods, 2)
|
||||||
|
|
||||||
|
// Div method
|
||||||
|
div := ep.Methods[0]
|
||||||
|
assert.Equal(t, "Div", div.Name)
|
||||||
|
require.Len(t, div.Params, 2)
|
||||||
|
assert.Equal(t, api.TInt, div.Params[0].Type)
|
||||||
|
assert.Equal(t, "a", div.Params[0].Name)
|
||||||
|
assert.Equal(t, api.TInt, div.Params[1].Type)
|
||||||
|
assert.Equal(t, "b", div.Params[1].Name)
|
||||||
|
require.Len(t, div.Ret, 1)
|
||||||
|
assert.Equal(t, api.TInt, div.Ret[0].Type)
|
||||||
|
|
||||||
|
// XorData method
|
||||||
|
xor := ep.Methods[1]
|
||||||
|
assert.Equal(t, "XorData", xor.Name)
|
||||||
|
require.Len(t, xor.Params, 2)
|
||||||
|
assert.Equal(t, api.TBlob, xor.Params[0].Type)
|
||||||
|
assert.Equal(t, api.TBlob, xor.Params[1].Type)
|
||||||
|
require.Len(t, xor.Ret, 1)
|
||||||
|
assert.Equal(t, api.TBlob, xor.Ret[0].Type)
|
||||||
|
}
|
||||||
60
lib/golang/child.go
Normal file
60
lib/golang/child.go
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
package golang
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ChildIPC struct {
|
||||||
|
*ipcCommon
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewChild(localApis ...any) (*ChildIPC, error) {
|
||||||
|
c := ChildIPC{
|
||||||
|
ipcCommon: &ipcCommon{
|
||||||
|
localApis: mapTypeNames(localApis),
|
||||||
|
pendingCalls: make(map[int64]*pendingCall),
|
||||||
|
errCh: make(chan error, 1),
|
||||||
|
ctx: context.Background(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
socketPath := socketPathFromArgs()
|
||||||
|
if socketPath == "" {
|
||||||
|
return nil, fmt.Errorf("ipc socket path is missing")
|
||||||
|
}
|
||||||
|
c.socketPath = socketPath
|
||||||
|
|
||||||
|
return &c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChildIPC) Start() error {
|
||||||
|
conn, err := net.Dial("unix", c.socketPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("connect to parent socket: %w", err)
|
||||||
|
}
|
||||||
|
c.conn = conn
|
||||||
|
go c.readConn()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChildIPC) Wait() error {
|
||||||
|
err := <-c.errCh
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("ipc error: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// socketPathFromArgs parses --ipc-socket from os.Args without calling flag.Parse(),
|
||||||
|
// which would interfere with the host application's flag handling.
|
||||||
|
func socketPathFromArgs() string {
|
||||||
|
for i, arg := range os.Args {
|
||||||
|
if arg == ipcSocketArg && i+1 < len(os.Args) {
|
||||||
|
return os.Args[i+1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
261
lib/golang/common.go
Normal file
261
lib/golang/common.go
Normal file
@ -0,0 +1,261 @@
|
|||||||
|
package golang
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
type IpcCommon interface {
|
||||||
|
Call(method string, params ...any) (Vals, error)
|
||||||
|
ConvType(needType, gotType reflect.Type, arg any) any
|
||||||
|
}
|
||||||
|
|
||||||
|
type callResult struct {
|
||||||
|
vals Vals
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
type pendingCall struct {
|
||||||
|
resultChan chan callResult
|
||||||
|
}
|
||||||
|
|
||||||
|
type ipcCommon struct {
|
||||||
|
localApis map[string]any
|
||||||
|
socketPath string
|
||||||
|
conn net.Conn
|
||||||
|
errCh chan error
|
||||||
|
nextId int64
|
||||||
|
pendingCalls map[int64]*pendingCall
|
||||||
|
processingCalls atomic.Int64
|
||||||
|
stopRequested atomic.Bool
|
||||||
|
mu sync.Mutex
|
||||||
|
writeMu sync.Mutex
|
||||||
|
ctx context.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ipc *ipcCommon) readConn() {
|
||||||
|
scn := bufio.NewScanner(ipc.conn)
|
||||||
|
scn.Buffer(nil, maxMessageLength)
|
||||||
|
for scn.Scan() {
|
||||||
|
var msg Message
|
||||||
|
msgBytes := scn.Bytes()
|
||||||
|
if err := json.Unmarshal(msgBytes, &msg); err != nil {
|
||||||
|
ipc.raiseErr(fmt.Errorf("unmarshal message: %w", err))
|
||||||
|
break
|
||||||
|
}
|
||||||
|
ipc.processMsg(msg)
|
||||||
|
}
|
||||||
|
if err := scn.Err(); err != nil {
|
||||||
|
ipc.raiseErr(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ipc *ipcCommon) processMsg(msg Message) {
|
||||||
|
switch msg.Type {
|
||||||
|
case MsgCall:
|
||||||
|
go ipc.handleCall(msg)
|
||||||
|
case MsgResponse:
|
||||||
|
ipc.handleResponse(msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ipc *ipcCommon) sendMsg(msg Message) error {
|
||||||
|
data, err := json.Marshal(msg)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("marshal message: %w", err)
|
||||||
|
}
|
||||||
|
data = append(data, '\n')
|
||||||
|
|
||||||
|
ipc.writeMu.Lock()
|
||||||
|
_, writeErr := ipc.conn.Write(data)
|
||||||
|
ipc.writeMu.Unlock()
|
||||||
|
if writeErr != nil {
|
||||||
|
return fmt.Errorf("write message: %w", writeErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ipc *ipcCommon) handleCall(msg Message) {
|
||||||
|
if ipc.stopRequested.Load() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ipc.processingCalls.Add(1)
|
||||||
|
defer ipc.processingCalls.Add(-1)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if err := recover(); err != nil {
|
||||||
|
ipc.sendResponse(msg.Id, nil, fmt.Errorf("handle call panicked: %s", err))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
method, err := ipc.findMethod(msg.Method)
|
||||||
|
if err != nil {
|
||||||
|
ipc.sendResponse(msg.Id, nil, fmt.Errorf("find method: %w", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
argsCount := method.Type().NumIn()
|
||||||
|
if len(msg.Args) != argsCount {
|
||||||
|
ipc.sendResponse(msg.Id, nil, fmt.Errorf("args count mismatch: expected %d, got %d", argsCount, len(msg.Args)))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var args []reflect.Value
|
||||||
|
for i, arg := range msg.Args {
|
||||||
|
paramType := method.Type().In(i)
|
||||||
|
argType := reflect.TypeOf(arg)
|
||||||
|
arg = ipc.ConvType(paramType, argType, arg)
|
||||||
|
args = append(args, reflect.ValueOf(arg))
|
||||||
|
}
|
||||||
|
|
||||||
|
allResultVals := method.Call(args)
|
||||||
|
retResultVals := allResultVals[0 : len(allResultVals)-1]
|
||||||
|
errResultVals := allResultVals[len(allResultVals)-1]
|
||||||
|
|
||||||
|
var results []any
|
||||||
|
for _, resVal := range retResultVals {
|
||||||
|
results = append(results, resVal.Interface())
|
||||||
|
}
|
||||||
|
|
||||||
|
var resErr error
|
||||||
|
if !errResultVals.IsNil() {
|
||||||
|
resErr = errResultVals.Interface().(error)
|
||||||
|
}
|
||||||
|
|
||||||
|
ipc.sendResponse(msg.Id, results, resErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ipc *ipcCommon) findMethod(methodName string) (reflect.Value, error) {
|
||||||
|
parts := strings.Split(methodName, ".")
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return reflect.Value{}, fmt.Errorf("invalid method: %s", methodName)
|
||||||
|
}
|
||||||
|
|
||||||
|
endpointName, methodName := parts[0], parts[1]
|
||||||
|
|
||||||
|
localApi, ok := ipc.localApis[endpointName]
|
||||||
|
if !ok {
|
||||||
|
return reflect.Value{}, fmt.Errorf("endpoint not found: %s", endpointName)
|
||||||
|
}
|
||||||
|
|
||||||
|
method := reflect.ValueOf(localApi).MethodByName(methodName)
|
||||||
|
if !method.IsValid() {
|
||||||
|
return reflect.Value{}, fmt.Errorf("method not found: %s", methodName)
|
||||||
|
}
|
||||||
|
|
||||||
|
return method, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ipc *ipcCommon) sendResponse(id int64, result []any, err error) {
|
||||||
|
msg := Message{
|
||||||
|
Type: MsgResponse,
|
||||||
|
Id: id,
|
||||||
|
Result: result,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
msg.Error = err.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ipc.sendMsg(msg); err != nil {
|
||||||
|
ipc.raiseErr(fmt.Errorf("send response for id=%d: %w", id, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ipc *ipcCommon) handleResponse(msg Message) {
|
||||||
|
ipc.mu.Lock()
|
||||||
|
call, ok := ipc.pendingCalls[msg.Id]
|
||||||
|
if ok {
|
||||||
|
delete(ipc.pendingCalls, msg.Id)
|
||||||
|
}
|
||||||
|
ipc.mu.Unlock()
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
ipc.raiseErr(fmt.Errorf("received response for unknown call id: %d", msg.Id))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var res callResult
|
||||||
|
if msg.Error == "" {
|
||||||
|
res = callResult{vals: msg.Result}
|
||||||
|
} else {
|
||||||
|
res = callResult{err: fmt.Errorf("remote error: %s", msg.Error)}
|
||||||
|
}
|
||||||
|
call.resultChan <- res
|
||||||
|
close(call.resultChan)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ipc *ipcCommon) Call(method string, params ...any) (Vals, error) {
|
||||||
|
if ipc.conn == nil {
|
||||||
|
return nil, fmt.Errorf("ipc is not connected to remote process socket")
|
||||||
|
}
|
||||||
|
|
||||||
|
if ipc.stopRequested.Load() {
|
||||||
|
return nil, fmt.Errorf("ipc is stopping")
|
||||||
|
}
|
||||||
|
|
||||||
|
ipc.mu.Lock()
|
||||||
|
id := ipc.nextId
|
||||||
|
ipc.nextId++
|
||||||
|
call := &pendingCall{
|
||||||
|
resultChan: make(chan callResult, 1),
|
||||||
|
}
|
||||||
|
ipc.pendingCalls[id] = call
|
||||||
|
ipc.mu.Unlock()
|
||||||
|
|
||||||
|
for i := range params {
|
||||||
|
params[i] = ipc.serialize(params[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := Message{
|
||||||
|
Type: MsgCall,
|
||||||
|
Id: id,
|
||||||
|
Method: method,
|
||||||
|
Args: params,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ipc.sendMsg(msg); err != nil {
|
||||||
|
ipc.mu.Lock()
|
||||||
|
delete(ipc.pendingCalls, id)
|
||||||
|
ipc.mu.Unlock()
|
||||||
|
return nil, fmt.Errorf("send call: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case result := <-call.resultChan:
|
||||||
|
return result.vals, result.err
|
||||||
|
case <-ipc.ctx.Done():
|
||||||
|
ipc.mu.Lock()
|
||||||
|
delete(ipc.pendingCalls, id)
|
||||||
|
ipc.mu.Unlock()
|
||||||
|
return nil, ipc.ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ipc *ipcCommon) raiseErr(err error) {
|
||||||
|
select {
|
||||||
|
case ipc.errCh <- err:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ipc *ipcCommon) closeConn() {
|
||||||
|
_ = ipc.conn.Close()
|
||||||
|
ipc.mu.Lock()
|
||||||
|
pending := ipc.pendingCalls
|
||||||
|
ipc.pendingCalls = make(map[int64]*pendingCall)
|
||||||
|
ipc.mu.Unlock()
|
||||||
|
for _, call := range pending {
|
||||||
|
call.resultChan <- callResult{err: fmt.Errorf("call cancelled due to ipc termination")}
|
||||||
|
close(call.resultChan)
|
||||||
|
}
|
||||||
|
}
|
||||||
43
lib/golang/common_test.go
Normal file
43
lib/golang/common_test.go
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
package golang
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
type testEndpoint struct{}
|
||||||
|
|
||||||
|
func (e *testEndpoint) Hello(name string) (string, error) {
|
||||||
|
return "hello " + name, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFindMethod(t *testing.T) {
|
||||||
|
ipc := &ipcCommon{
|
||||||
|
localApis: mapTypeNames([]any{&testEndpoint{}}),
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("valid method", func(t *testing.T) {
|
||||||
|
method, err := ipc.findMethod("testEndpoint.Hello")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, method.IsValid())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid format - no dot", func(t *testing.T) {
|
||||||
|
_, err := ipc.findMethod("Hello")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "invalid method")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unknown endpoint", func(t *testing.T) {
|
||||||
|
_, err := ipc.findMethod("Unknown.Hello")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "endpoint not found")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unknown method", func(t *testing.T) {
|
||||||
|
_, err := ipc.findMethod("testEndpoint.Unknown")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "method not found")
|
||||||
|
})
|
||||||
|
}
|
||||||
@ -2,4 +2,10 @@ module efprojects.com/kitten-ipc
|
|||||||
|
|
||||||
go 1.25.1
|
go 1.25.1
|
||||||
|
|
||||||
require github.com/samber/mo v1.16.0
|
require github.com/stretchr/testify v1.11.1
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
|
)
|
||||||
|
|||||||
@ -1,2 +1,10 @@
|
|||||||
github.com/samber/mo v1.16.0 h1:qpEPCI63ou6wXlsNDMLE0IIN8A+devbGX/K1xdgr4b4=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
github.com/samber/mo v1.16.0/go.mod h1:DlgzJ4SYhOh41nP1L9kh9rDNERuf8IqWSAs+gj2Vxag=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
|||||||
@ -1,550 +0,0 @@
|
|||||||
package golang
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"context"
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"flag"
|
|
||||||
"fmt"
|
|
||||||
"math/rand"
|
|
||||||
"net"
|
|
||||||
"os"
|
|
||||||
"os/exec"
|
|
||||||
"path/filepath"
|
|
||||||
"reflect"
|
|
||||||
"slices"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"syscall"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"efprojects.com/kitten-ipc/types"
|
|
||||||
"github.com/samber/mo"
|
|
||||||
)
|
|
||||||
|
|
||||||
const ipcSocketArg = "--ipc-socket"
|
|
||||||
const maxMessageLength = 1 * 1024 * 1024 * 1024 // 1 gigabyte
|
|
||||||
|
|
||||||
type StdioMode int
|
|
||||||
|
|
||||||
type MsgType int
|
|
||||||
|
|
||||||
type Vals []any
|
|
||||||
|
|
||||||
const (
|
|
||||||
MsgCall MsgType = 1
|
|
||||||
MsgResponse MsgType = 2
|
|
||||||
)
|
|
||||||
|
|
||||||
type Message struct {
|
|
||||||
Type MsgType `json:"type"`
|
|
||||||
Id int64 `json:"id"`
|
|
||||||
Method string `json:"method"`
|
|
||||||
Args Vals `json:"args"`
|
|
||||||
Result Vals `json:"result"`
|
|
||||||
Error string `json:"error"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type IpcCommon interface {
|
|
||||||
Call(method string, params ...any) (Vals, error)
|
|
||||||
ConvType(needType reflect.Type, gotType reflect.Type, arg any) any
|
|
||||||
}
|
|
||||||
|
|
||||||
type pendingCall struct {
|
|
||||||
resultChan chan mo.Result[Vals]
|
|
||||||
resultType reflect.Type
|
|
||||||
}
|
|
||||||
|
|
||||||
type ipcCommon struct {
|
|
||||||
localApis map[string]any
|
|
||||||
socketPath string
|
|
||||||
conn net.Conn
|
|
||||||
errCh chan error
|
|
||||||
nextId int64
|
|
||||||
pendingCalls map[int64]*pendingCall
|
|
||||||
processingCalls atomic.Int64
|
|
||||||
stopRequested atomic.Bool
|
|
||||||
mu sync.Mutex
|
|
||||||
writeMu sync.Mutex
|
|
||||||
ctx context.Context
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ipc *ipcCommon) readConn() {
|
|
||||||
scn := bufio.NewScanner(ipc.conn)
|
|
||||||
scn.Buffer(nil, maxMessageLength)
|
|
||||||
for scn.Scan() {
|
|
||||||
var msg Message
|
|
||||||
msgBytes := scn.Bytes()
|
|
||||||
//log.Println(string(msgBytes))
|
|
||||||
if err := json.Unmarshal(msgBytes, &msg); err != nil {
|
|
||||||
ipc.raiseErr(fmt.Errorf("unmarshal message: %w", err))
|
|
||||||
break
|
|
||||||
}
|
|
||||||
ipc.processMsg(msg)
|
|
||||||
}
|
|
||||||
if err := scn.Err(); err != nil {
|
|
||||||
ipc.raiseErr(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ipc *ipcCommon) processMsg(msg Message) {
|
|
||||||
switch msg.Type {
|
|
||||||
case MsgCall:
|
|
||||||
go ipc.handleCall(msg)
|
|
||||||
case MsgResponse:
|
|
||||||
ipc.handleResponse(msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ipc *ipcCommon) sendMsg(msg Message) error {
|
|
||||||
|
|
||||||
data, err := json.Marshal(msg)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("marshal message: %w", err)
|
|
||||||
}
|
|
||||||
data = append(data, '\n')
|
|
||||||
|
|
||||||
ipc.writeMu.Lock()
|
|
||||||
_, writeErr := ipc.conn.Write(data)
|
|
||||||
ipc.writeMu.Unlock()
|
|
||||||
if writeErr != nil {
|
|
||||||
return fmt.Errorf("write message: %w", writeErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ipc *ipcCommon) handleCall(msg Message) {
|
|
||||||
if ipc.stopRequested.Load() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ipc.processingCalls.Add(1)
|
|
||||||
defer ipc.processingCalls.Add(-1)
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
if err := recover(); err != nil {
|
|
||||||
ipc.sendResponse(msg.Id, nil, fmt.Errorf("handle call panicked: %s", err))
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
method, err := ipc.findMethod(msg.Method)
|
|
||||||
if err != nil {
|
|
||||||
ipc.sendResponse(msg.Id, nil, fmt.Errorf("find method: %w", err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
argsCount := method.Type().NumIn()
|
|
||||||
if len(msg.Args) != argsCount {
|
|
||||||
ipc.sendResponse(msg.Id, nil, fmt.Errorf("args count mismatch: expected %d, got %d", argsCount, len(msg.Args)))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var args []reflect.Value
|
|
||||||
for i, arg := range msg.Args {
|
|
||||||
paramType := method.Type().In(i)
|
|
||||||
argType := reflect.TypeOf(arg)
|
|
||||||
arg = ipc.ConvType(paramType, argType, arg)
|
|
||||||
args = append(args, reflect.ValueOf(arg))
|
|
||||||
}
|
|
||||||
|
|
||||||
allResultVals := method.Call(args)
|
|
||||||
retResultVals := allResultVals[0 : len(allResultVals)-1]
|
|
||||||
errResultVals := allResultVals[len(allResultVals)-1]
|
|
||||||
|
|
||||||
var results []any
|
|
||||||
for _, resVal := range retResultVals {
|
|
||||||
results = append(results, resVal.Interface())
|
|
||||||
}
|
|
||||||
|
|
||||||
var resErr error
|
|
||||||
if !errResultVals.IsNil() {
|
|
||||||
resErr = errResultVals.Interface().(error)
|
|
||||||
}
|
|
||||||
|
|
||||||
ipc.sendResponse(msg.Id, results, resErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ipc *ipcCommon) findMethod(methodName string) (reflect.Value, error) {
|
|
||||||
parts := strings.Split(methodName, ".")
|
|
||||||
if len(parts) != 2 {
|
|
||||||
return reflect.Value{}, fmt.Errorf("invalid method: %s", methodName)
|
|
||||||
}
|
|
||||||
|
|
||||||
endpointName, methodName := parts[0], parts[1]
|
|
||||||
|
|
||||||
localApi, ok := ipc.localApis[endpointName]
|
|
||||||
if !ok {
|
|
||||||
return reflect.Value{}, fmt.Errorf("endpoint not found: %s", endpointName)
|
|
||||||
}
|
|
||||||
|
|
||||||
method := reflect.ValueOf(localApi).MethodByName(methodName)
|
|
||||||
if !method.IsValid() {
|
|
||||||
return reflect.Value{}, fmt.Errorf("method not found: %s", methodName)
|
|
||||||
}
|
|
||||||
|
|
||||||
return method, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ipc *ipcCommon) sendResponse(id int64, result []any, err error) {
|
|
||||||
msg := Message{
|
|
||||||
Type: MsgResponse,
|
|
||||||
Id: id,
|
|
||||||
Result: result,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
msg.Error = err.Error()
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := ipc.sendMsg(msg); err != nil {
|
|
||||||
ipc.raiseErr(fmt.Errorf("send response for id=%d: %w", id, err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ipc *ipcCommon) handleResponse(msg Message) {
|
|
||||||
ipc.mu.Lock()
|
|
||||||
call, ok := ipc.pendingCalls[msg.Id]
|
|
||||||
if ok {
|
|
||||||
delete(ipc.pendingCalls, msg.Id)
|
|
||||||
}
|
|
||||||
ipc.mu.Unlock()
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
ipc.raiseErr(fmt.Errorf("received response for unknown call id: %d", msg.Id))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var res mo.Result[Vals]
|
|
||||||
if msg.Error == "" {
|
|
||||||
res = mo.Ok[Vals](msg.Result)
|
|
||||||
} else {
|
|
||||||
res = mo.Err[Vals](fmt.Errorf("remote error: %s", msg.Error))
|
|
||||||
}
|
|
||||||
call.resultChan <- res
|
|
||||||
close(call.resultChan)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ipc *ipcCommon) Call(method string, params ...any) (Vals, error) {
|
|
||||||
if ipc.conn == nil {
|
|
||||||
return nil, fmt.Errorf("ipc is not connected to remote process socket")
|
|
||||||
}
|
|
||||||
|
|
||||||
if ipc.stopRequested.Load() {
|
|
||||||
return nil, fmt.Errorf("ipc is stopping")
|
|
||||||
}
|
|
||||||
|
|
||||||
ipc.mu.Lock()
|
|
||||||
id := ipc.nextId
|
|
||||||
ipc.nextId++
|
|
||||||
call := &pendingCall{
|
|
||||||
resultChan: make(chan mo.Result[Vals], 1),
|
|
||||||
}
|
|
||||||
ipc.pendingCalls[id] = call
|
|
||||||
ipc.mu.Unlock()
|
|
||||||
|
|
||||||
for i := range params {
|
|
||||||
params[i] = ipc.serialize(params[i])
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := Message{
|
|
||||||
Type: MsgCall,
|
|
||||||
Id: id,
|
|
||||||
Method: method,
|
|
||||||
Args: params,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := ipc.sendMsg(msg); err != nil {
|
|
||||||
ipc.mu.Lock()
|
|
||||||
delete(ipc.pendingCalls, id)
|
|
||||||
ipc.mu.Unlock()
|
|
||||||
return nil, fmt.Errorf("send call: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case result := <-call.resultChan:
|
|
||||||
return result.Get()
|
|
||||||
case <-ipc.ctx.Done():
|
|
||||||
ipc.mu.Lock()
|
|
||||||
delete(ipc.pendingCalls, id)
|
|
||||||
ipc.mu.Unlock()
|
|
||||||
return nil, ipc.ctx.Err()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ipc *ipcCommon) raiseErr(err error) {
|
|
||||||
select {
|
|
||||||
case ipc.errCh <- err:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ipc *ipcCommon) closeConn() {
|
|
||||||
_ = ipc.conn.Close()
|
|
||||||
ipc.mu.Lock()
|
|
||||||
pending := ipc.pendingCalls
|
|
||||||
ipc.pendingCalls = make(map[int64]*pendingCall)
|
|
||||||
ipc.mu.Unlock()
|
|
||||||
for _, call := range pending {
|
|
||||||
call.resultChan <- mo.Err[Vals](fmt.Errorf("call cancelled due to ipc termination"))
|
|
||||||
close(call.resultChan)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ipc *ipcCommon) ConvType(needType reflect.Type, gotType reflect.Type, arg any) any {
|
|
||||||
switch needType.Kind() {
|
|
||||||
case reflect.Int:
|
|
||||||
// JSON decodes any number to float64. If we need int, we should check and convert
|
|
||||||
if gotType.Kind() == reflect.Float64 {
|
|
||||||
floatArg := arg.(float64)
|
|
||||||
if float64(int64(floatArg)) == floatArg && !needType.OverflowInt(int64(floatArg)) {
|
|
||||||
arg = int(floatArg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return arg
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ipc *ipcCommon) serialize(arg any) any {
|
|
||||||
t := reflect.TypeOf(arg)
|
|
||||||
switch t.Kind() {
|
|
||||||
case reflect.Slice:
|
|
||||||
switch t.Elem().Name() {
|
|
||||||
case "uint8":
|
|
||||||
return map[string]any{
|
|
||||||
"t": types.TBlob,
|
|
||||||
"d": base64.StdEncoding.EncodeToString(arg.([]byte)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return arg
|
|
||||||
}
|
|
||||||
|
|
||||||
type ParentIPC struct {
|
|
||||||
*ipcCommon
|
|
||||||
cmd *exec.Cmd
|
|
||||||
listener net.Listener
|
|
||||||
cmdDone chan struct{}
|
|
||||||
cmdErr error
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewParent(cmd *exec.Cmd, localApis ...any) (*ParentIPC, error) {
|
|
||||||
return NewParentWithContext(context.Background(), cmd, localApis...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewParentWithContext(ctx context.Context, cmd *exec.Cmd, localApis ...any) (*ParentIPC, error) {
|
|
||||||
p := ParentIPC{
|
|
||||||
ipcCommon: &ipcCommon{
|
|
||||||
localApis: mapTypeNames(localApis),
|
|
||||||
pendingCalls: make(map[int64]*pendingCall),
|
|
||||||
errCh: make(chan error, 1),
|
|
||||||
socketPath: filepath.Join(os.TempDir(), fmt.Sprintf("kitten-ipc-%d-%d.sock", os.Getpid(), rand.Int63())),
|
|
||||||
ctx: ctx,
|
|
||||||
},
|
|
||||||
cmd: cmd,
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd.Stdout = os.Stdout
|
|
||||||
cmd.Stderr = os.Stderr
|
|
||||||
|
|
||||||
if slices.Contains(cmd.Args, ipcSocketArg) {
|
|
||||||
return nil, fmt.Errorf("you should not use `%s` argument in your command", ipcSocketArg)
|
|
||||||
}
|
|
||||||
cmd.Args = append(cmd.Args, ipcSocketArg, p.socketPath)
|
|
||||||
|
|
||||||
p.errCh = make(chan error, 1)
|
|
||||||
p.cmdDone = make(chan struct{})
|
|
||||||
|
|
||||||
return &p, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *ParentIPC) Start() error {
|
|
||||||
_ = os.Remove(p.socketPath)
|
|
||||||
listener, err := net.Listen("unix", p.socketPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("listen unix socket: %w", err)
|
|
||||||
}
|
|
||||||
p.listener = listener
|
|
||||||
defer p.listener.Close()
|
|
||||||
|
|
||||||
err = p.cmd.Start()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("cmd start: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
p.cmdErr = p.cmd.Wait()
|
|
||||||
close(p.cmdDone)
|
|
||||||
}()
|
|
||||||
|
|
||||||
return p.acceptConn()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *ParentIPC) acceptConn() error {
|
|
||||||
const acceptTimeout = time.Second * 10
|
|
||||||
|
|
||||||
res := make(chan mo.Result[net.Conn], 1)
|
|
||||||
go func() {
|
|
||||||
conn, err := p.listener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
res <- mo.Err[net.Conn](err)
|
|
||||||
} else {
|
|
||||||
res <- mo.Ok[net.Conn](conn)
|
|
||||||
}
|
|
||||||
close(res)
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-time.After(acceptTimeout):
|
|
||||||
_ = p.cmd.Process.Kill()
|
|
||||||
return fmt.Errorf("accept timeout")
|
|
||||||
case <-p.cmdDone:
|
|
||||||
return fmt.Errorf("cmd exited before accepting connection: %w", p.cmdErr)
|
|
||||||
case res := <-res:
|
|
||||||
if res.IsError() {
|
|
||||||
_ = p.cmd.Process.Kill()
|
|
||||||
return fmt.Errorf("accept: %w", res.Error())
|
|
||||||
}
|
|
||||||
p.conn = res.MustGet()
|
|
||||||
go p.readConn()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *ParentIPC) Stop() error {
|
|
||||||
p.mu.Lock()
|
|
||||||
hasPending := len(p.pendingCalls) > 0
|
|
||||||
p.mu.Unlock()
|
|
||||||
if hasPending {
|
|
||||||
return fmt.Errorf("there are calls pending")
|
|
||||||
}
|
|
||||||
if p.processingCalls.Load() > 0 {
|
|
||||||
return fmt.Errorf("there are calls processing")
|
|
||||||
}
|
|
||||||
p.stopRequested.Store(true)
|
|
||||||
if err := p.cmd.Process.Signal(syscall.SIGINT); err != nil {
|
|
||||||
return fmt.Errorf("send SIGINT: %w", err)
|
|
||||||
}
|
|
||||||
return p.Wait()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *ParentIPC) Wait(timeout ...time.Duration) (retErr error) {
|
|
||||||
const defaultTimeout = time.Duration(1<<63 - 1) // max duration in go
|
|
||||||
_timeout := variadicToOption(timeout).OrElse(defaultTimeout)
|
|
||||||
|
|
||||||
loop:
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case err := <-p.errCh:
|
|
||||||
retErr = mergeErr(retErr, fmt.Errorf("ipc internal error: %w", err))
|
|
||||||
break loop
|
|
||||||
case <-p.cmdDone:
|
|
||||||
err := p.cmdErr
|
|
||||||
if err != nil {
|
|
||||||
var exitErr *exec.ExitError
|
|
||||||
if ok := errors.As(err, &exitErr); ok {
|
|
||||||
if !exitErr.Success() {
|
|
||||||
ws, ok := exitErr.Sys().(syscall.WaitStatus)
|
|
||||||
if !(ok && ws.Signaled() && ws.Signal() == syscall.SIGINT && p.stopRequested.Load()) {
|
|
||||||
retErr = mergeErr(retErr, fmt.Errorf("cmd wait: %w", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
retErr = mergeErr(retErr, fmt.Errorf("cmd wait: %w", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break loop
|
|
||||||
case <-time.After(_timeout):
|
|
||||||
p.stopRequested.Store(true)
|
|
||||||
if err := p.cmd.Process.Signal(syscall.SIGINT); err != nil {
|
|
||||||
retErr = mergeErr(retErr, fmt.Errorf("send SIGINT: %w", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
p.closeConn()
|
|
||||||
_ = os.Remove(p.socketPath)
|
|
||||||
|
|
||||||
return retErr
|
|
||||||
}
|
|
||||||
|
|
||||||
type ChildIPC struct {
|
|
||||||
*ipcCommon
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewChild(localApis ...any) (*ChildIPC, error) {
|
|
||||||
c := ChildIPC{
|
|
||||||
ipcCommon: &ipcCommon{
|
|
||||||
localApis: mapTypeNames(localApis),
|
|
||||||
pendingCalls: make(map[int64]*pendingCall),
|
|
||||||
errCh: make(chan error, 1),
|
|
||||||
ctx: context.Background(), // todo: NewChildWithContext
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
socketPath := flag.String("ipc-socket", "", "Path to IPC socket")
|
|
||||||
flag.Parse()
|
|
||||||
|
|
||||||
if *socketPath == "" {
|
|
||||||
return nil, fmt.Errorf("ipc socket path is missing")
|
|
||||||
}
|
|
||||||
c.socketPath = *socketPath
|
|
||||||
|
|
||||||
return &c, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ChildIPC) Start() error {
|
|
||||||
conn, err := net.Dial("unix", c.socketPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("connect to parent socket: %w", err)
|
|
||||||
}
|
|
||||||
c.conn = conn
|
|
||||||
go c.readConn()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ChildIPC) Wait() error {
|
|
||||||
err := <-c.errCh
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("ipc error: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func mergeErr(errs ...error) (ret error) {
|
|
||||||
for _, err := range errs {
|
|
||||||
if err != nil {
|
|
||||||
if ret == nil {
|
|
||||||
ret = err
|
|
||||||
} else {
|
|
||||||
ret = fmt.Errorf("%w; %w", ret, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func mapTypeNames(types []any) map[string]any {
|
|
||||||
result := make(map[string]any)
|
|
||||||
for _, t := range types {
|
|
||||||
if reflect.TypeOf(t).Kind() != reflect.Pointer {
|
|
||||||
panic(fmt.Sprintf("LocalAPI argument must be pointer"))
|
|
||||||
}
|
|
||||||
typeName := reflect.TypeOf(t).Elem().Name()
|
|
||||||
result[typeName] = t
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
func variadicToOption[T any](variadic []T) mo.Option[T] {
|
|
||||||
if len(variadic) >= 2 {
|
|
||||||
panic("variadic param count must be 0 or 1")
|
|
||||||
}
|
|
||||||
if len(variadic) == 0 {
|
|
||||||
return mo.None[T]()
|
|
||||||
}
|
|
||||||
return mo.Some(variadic[0])
|
|
||||||
}
|
|
||||||
165
lib/golang/parent.go
Normal file
165
lib/golang/parent.go
Normal file
@ -0,0 +1,165 @@
|
|||||||
|
package golang
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"slices"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ParentIPC struct {
|
||||||
|
*ipcCommon
|
||||||
|
cmd *exec.Cmd
|
||||||
|
listener net.Listener
|
||||||
|
cmdDone chan struct{}
|
||||||
|
cmdErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewParent(cmd *exec.Cmd, localApis ...any) (*ParentIPC, error) {
|
||||||
|
return NewParentWithContext(context.Background(), cmd, localApis...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewParentWithContext(ctx context.Context, cmd *exec.Cmd, localApis ...any) (*ParentIPC, error) {
|
||||||
|
p := ParentIPC{
|
||||||
|
ipcCommon: &ipcCommon{
|
||||||
|
localApis: mapTypeNames(localApis),
|
||||||
|
pendingCalls: make(map[int64]*pendingCall),
|
||||||
|
errCh: make(chan error, 1),
|
||||||
|
socketPath: filepath.Join(os.TempDir(), fmt.Sprintf("kitten-ipc-%d-%d.sock", os.Getpid(), rand.Int63())),
|
||||||
|
ctx: ctx,
|
||||||
|
},
|
||||||
|
cmd: cmd,
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
|
||||||
|
if slices.Contains(cmd.Args, ipcSocketArg) {
|
||||||
|
return nil, fmt.Errorf("you should not use `%s` argument in your command", ipcSocketArg)
|
||||||
|
}
|
||||||
|
cmd.Args = append(cmd.Args, ipcSocketArg, p.socketPath)
|
||||||
|
|
||||||
|
p.errCh = make(chan error, 1)
|
||||||
|
p.cmdDone = make(chan struct{})
|
||||||
|
|
||||||
|
return &p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ParentIPC) Start() error {
|
||||||
|
_ = os.Remove(p.socketPath)
|
||||||
|
listener, err := net.Listen("unix", p.socketPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("listen unix socket: %w", err)
|
||||||
|
}
|
||||||
|
p.listener = listener
|
||||||
|
defer p.listener.Close()
|
||||||
|
|
||||||
|
err = p.cmd.Start()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cmd start: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
p.cmdErr = p.cmd.Wait()
|
||||||
|
close(p.cmdDone)
|
||||||
|
}()
|
||||||
|
|
||||||
|
return p.acceptConn()
|
||||||
|
}
|
||||||
|
|
||||||
|
type connResult struct {
|
||||||
|
conn net.Conn
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ParentIPC) acceptConn() error {
|
||||||
|
res := make(chan connResult, 1)
|
||||||
|
go func() {
|
||||||
|
conn, err := p.listener.Accept()
|
||||||
|
res <- connResult{conn: conn, err: err}
|
||||||
|
close(res)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-time.After(time.Duration(defaultAcceptTimeout) * time.Second):
|
||||||
|
_ = p.cmd.Process.Kill()
|
||||||
|
return fmt.Errorf("accept timeout")
|
||||||
|
case <-p.cmdDone:
|
||||||
|
return fmt.Errorf("cmd exited before accepting connection: %w", p.cmdErr)
|
||||||
|
case r := <-res:
|
||||||
|
if r.err != nil {
|
||||||
|
_ = p.cmd.Process.Kill()
|
||||||
|
return fmt.Errorf("accept: %w", r.err)
|
||||||
|
}
|
||||||
|
p.conn = r.conn
|
||||||
|
go p.readConn()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ParentIPC) Stop() error {
|
||||||
|
p.mu.Lock()
|
||||||
|
hasPending := len(p.pendingCalls) > 0
|
||||||
|
p.mu.Unlock()
|
||||||
|
if hasPending {
|
||||||
|
return fmt.Errorf("there are calls pending")
|
||||||
|
}
|
||||||
|
if p.processingCalls.Load() > 0 {
|
||||||
|
return fmt.Errorf("there are calls processing")
|
||||||
|
}
|
||||||
|
p.stopRequested.Store(true)
|
||||||
|
if err := p.cmd.Process.Signal(syscall.SIGINT); err != nil {
|
||||||
|
return fmt.Errorf("send SIGINT: %w", err)
|
||||||
|
}
|
||||||
|
return p.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ParentIPC) Wait(timeout ...time.Duration) (retErr error) {
|
||||||
|
const maxDuration = time.Duration(1<<63 - 1)
|
||||||
|
_timeout := maxDuration
|
||||||
|
if len(timeout) > 0 {
|
||||||
|
_timeout = timeout[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
loop:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case err := <-p.errCh:
|
||||||
|
retErr = mergeErr(retErr, fmt.Errorf("ipc internal error: %w", err))
|
||||||
|
break loop
|
||||||
|
case <-p.cmdDone:
|
||||||
|
err := p.cmdErr
|
||||||
|
if err != nil {
|
||||||
|
var exitErr *exec.ExitError
|
||||||
|
if ok := errors.As(err, &exitErr); ok {
|
||||||
|
if !exitErr.Success() {
|
||||||
|
ws, ok := exitErr.Sys().(syscall.WaitStatus)
|
||||||
|
if !(ok && ws.Signaled() && ws.Signal() == syscall.SIGINT && p.stopRequested.Load()) {
|
||||||
|
retErr = mergeErr(retErr, fmt.Errorf("cmd wait: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
retErr = mergeErr(retErr, fmt.Errorf("cmd wait: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break loop
|
||||||
|
case <-time.After(_timeout):
|
||||||
|
p.stopRequested.Store(true)
|
||||||
|
if err := p.cmd.Process.Signal(syscall.SIGINT); err != nil {
|
||||||
|
retErr = mergeErr(retErr, fmt.Errorf("send SIGINT: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
p.closeConn()
|
||||||
|
_ = os.Remove(p.socketPath)
|
||||||
|
|
||||||
|
return retErr
|
||||||
|
}
|
||||||
23
lib/golang/protocol.go
Normal file
23
lib/golang/protocol.go
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
package golang
|
||||||
|
|
||||||
|
const ipcSocketArg = "--ipc-socket"
|
||||||
|
const maxMessageLength = 1 << 30 // 1 GB
|
||||||
|
const defaultAcceptTimeout = 10 // seconds
|
||||||
|
|
||||||
|
type MsgType int
|
||||||
|
|
||||||
|
type Vals []any
|
||||||
|
|
||||||
|
const (
|
||||||
|
MsgCall MsgType = 1
|
||||||
|
MsgResponse MsgType = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
type Message struct {
|
||||||
|
Type MsgType `json:"type"`
|
||||||
|
Id int64 `json:"id"`
|
||||||
|
Method string `json:"method"`
|
||||||
|
Args Vals `json:"args"`
|
||||||
|
Result Vals `json:"result"`
|
||||||
|
Error string `json:"error"`
|
||||||
|
}
|
||||||
35
lib/golang/serialize.go
Normal file
35
lib/golang/serialize.go
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
package golang
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (ipc *ipcCommon) serialize(arg any) any {
|
||||||
|
t := reflect.TypeOf(arg)
|
||||||
|
switch t.Kind() {
|
||||||
|
case reflect.Slice:
|
||||||
|
switch t.Elem().Name() {
|
||||||
|
case "uint8":
|
||||||
|
return map[string]any{
|
||||||
|
"t": "blob",
|
||||||
|
"d": base64.StdEncoding.EncodeToString(arg.([]byte)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return arg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ipc *ipcCommon) ConvType(needType reflect.Type, gotType reflect.Type, arg any) any {
|
||||||
|
switch needType.Kind() {
|
||||||
|
case reflect.Int:
|
||||||
|
// JSON decodes any number to float64. If we need int, we should check and convert
|
||||||
|
if gotType.Kind() == reflect.Float64 {
|
||||||
|
floatArg := arg.(float64)
|
||||||
|
if float64(int64(floatArg)) == floatArg && !needType.OverflowInt(int64(floatArg)) {
|
||||||
|
arg = int(floatArg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return arg
|
||||||
|
}
|
||||||
61
lib/golang/serialize_test.go
Normal file
61
lib/golang/serialize_test.go
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
package golang
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSerialize(t *testing.T) {
|
||||||
|
ipc := &ipcCommon{}
|
||||||
|
|
||||||
|
t.Run("primitives pass through", func(t *testing.T) {
|
||||||
|
assert.Equal(t, 42, ipc.serialize(42))
|
||||||
|
assert.Equal(t, "hello", ipc.serialize("hello"))
|
||||||
|
assert.Equal(t, true, ipc.serialize(true))
|
||||||
|
assert.Equal(t, 3.14, ipc.serialize(3.14))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("byte slice serializes to blob", func(t *testing.T) {
|
||||||
|
data := []byte{0x01, 0x02, 0x03}
|
||||||
|
result := ipc.serialize(data)
|
||||||
|
m, ok := result.(map[string]any)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, "blob", m["t"])
|
||||||
|
assert.Equal(t, "AQID", m["d"]) // base64 of {1,2,3}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty byte slice serializes to blob", func(t *testing.T) {
|
||||||
|
data := []byte{}
|
||||||
|
result := ipc.serialize(data)
|
||||||
|
m, ok := result.(map[string]any)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, "blob", m["t"])
|
||||||
|
assert.Equal(t, "", m["d"])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvType(t *testing.T) {
|
||||||
|
ipc := &ipcCommon{}
|
||||||
|
|
||||||
|
t.Run("float64 to int", func(t *testing.T) {
|
||||||
|
result := ipc.ConvType(reflect.TypeOf(0), reflect.TypeOf(0.0), float64(42))
|
||||||
|
assert.Equal(t, 42, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("float64 with fractional part stays float", func(t *testing.T) {
|
||||||
|
result := ipc.ConvType(reflect.TypeOf(0), reflect.TypeOf(0.0), float64(42.5))
|
||||||
|
assert.Equal(t, float64(42.5), result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("string passes through", func(t *testing.T) {
|
||||||
|
result := ipc.ConvType(reflect.TypeOf(""), reflect.TypeOf(""), "hello")
|
||||||
|
assert.Equal(t, "hello", result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("bool passes through", func(t *testing.T) {
|
||||||
|
result := ipc.ConvType(reflect.TypeOf(true), reflect.TypeOf(true), true)
|
||||||
|
assert.Equal(t, true, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
@ -1,12 +0,0 @@
|
|||||||
package types
|
|
||||||
|
|
||||||
type ValType string
|
|
||||||
|
|
||||||
const (
|
|
||||||
TNoType ValType = "" // zero value constant for ValType (please don't use just "" as zero value!)
|
|
||||||
TInt ValType = "int"
|
|
||||||
TString ValType = "string"
|
|
||||||
TBool ValType = "bool"
|
|
||||||
TBlob ValType = "blob"
|
|
||||||
TArray ValType = "array"
|
|
||||||
)
|
|
||||||
31
lib/golang/util.go
Normal file
31
lib/golang/util.go
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
package golang
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
func mergeErr(errs ...error) (ret error) {
|
||||||
|
for _, err := range errs {
|
||||||
|
if err != nil {
|
||||||
|
if ret == nil {
|
||||||
|
ret = err
|
||||||
|
} else {
|
||||||
|
ret = fmt.Errorf("%w; %w", ret, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func mapTypeNames(types []any) map[string]any {
|
||||||
|
result := make(map[string]any)
|
||||||
|
for _, t := range types {
|
||||||
|
if reflect.TypeOf(t).Kind() != reflect.Pointer {
|
||||||
|
panic(fmt.Sprintf("LocalAPI argument must be pointer"))
|
||||||
|
}
|
||||||
|
typeName := reflect.TypeOf(t).Elem().Name()
|
||||||
|
result[typeName] = t
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
46
lib/golang/util_test.go
Normal file
46
lib/golang/util_test.go
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
package golang
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMergeErr(t *testing.T) {
|
||||||
|
t.Run("all nil returns nil", func(t *testing.T) {
|
||||||
|
assert.NoError(t, mergeErr(nil, nil, nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("single error returns it", func(t *testing.T) {
|
||||||
|
err := fmt.Errorf("one")
|
||||||
|
assert.EqualError(t, mergeErr(nil, err, nil), "one")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("multiple errors merged", func(t *testing.T) {
|
||||||
|
err1 := fmt.Errorf("one")
|
||||||
|
err2 := fmt.Errorf("two")
|
||||||
|
result := mergeErr(err1, err2)
|
||||||
|
assert.ErrorContains(t, result, "one")
|
||||||
|
assert.ErrorContains(t, result, "two")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapTypeNames(t *testing.T) {
|
||||||
|
type Foo struct{}
|
||||||
|
type Bar struct{}
|
||||||
|
|
||||||
|
t.Run("maps pointer types by name", func(t *testing.T) {
|
||||||
|
foo := &Foo{}
|
||||||
|
bar := &Bar{}
|
||||||
|
result := mapTypeNames([]any{foo, bar})
|
||||||
|
assert.Equal(t, foo, result["Foo"])
|
||||||
|
assert.Equal(t, bar, result["Bar"])
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("panics on non-pointer", func(t *testing.T) {
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
mapTypeNames([]any{Foo{}})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
27
lib/ts/src/asyncqueue.test.ts
Normal file
27
lib/ts/src/asyncqueue.test.ts
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
import {test} from 'vitest';
|
||||||
|
import {AsyncQueue} from './asyncqueue.js';
|
||||||
|
|
||||||
|
test('put then collect returns items', async ({expect}) => {
|
||||||
|
const q = new AsyncQueue<number>();
|
||||||
|
q.put(1);
|
||||||
|
q.put(2);
|
||||||
|
const items = await q.collect();
|
||||||
|
expect(items).toEqual([1, 2]);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('collect then put resolves on first put', async ({expect}) => {
|
||||||
|
const q = new AsyncQueue<number>();
|
||||||
|
const collectPromise = q.collect();
|
||||||
|
q.put(42);
|
||||||
|
const items = await collectPromise;
|
||||||
|
expect(items).toEqual([42]);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('multiple puts before collect', async ({expect}) => {
|
||||||
|
const q = new AsyncQueue<string>();
|
||||||
|
q.put('a');
|
||||||
|
q.put('b');
|
||||||
|
q.put('c');
|
||||||
|
const items = await q.collect();
|
||||||
|
expect(items).toEqual(['a', 'b', 'c']);
|
||||||
|
});
|
||||||
44
lib/ts/src/child.ts
Normal file
44
lib/ts/src/child.ts
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
import * as net from 'node:net';
|
||||||
|
import {IPCCommon} from './common.js';
|
||||||
|
import {socketPathFromArgs} from './util.js';
|
||||||
|
|
||||||
|
export class ChildIPC extends IPCCommon {
|
||||||
|
constructor(...localApis: object[]) {
|
||||||
|
super(localApis, socketPathFromArgs());
|
||||||
|
}
|
||||||
|
|
||||||
|
async start(): Promise<void> {
|
||||||
|
return new Promise((resolve, reject) => {
|
||||||
|
this.conn = net.createConnection(this.socketPath, () => {
|
||||||
|
this.readConn();
|
||||||
|
resolve();
|
||||||
|
});
|
||||||
|
this.conn.on('error', reject);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
async wait(): Promise<void> {
|
||||||
|
const closePromise = new Promise<void>((resolve) => {
|
||||||
|
this.onClose = () => {
|
||||||
|
if (this.processingCalls === 0) {
|
||||||
|
this.conn?.destroy();
|
||||||
|
resolve();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
if (this.stopRequested && this.processingCalls === 0) {
|
||||||
|
this.conn?.destroy();
|
||||||
|
resolve();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
const errorPromise = this.errorQueue.collect().then((errors) => {
|
||||||
|
if (errors.length === 1) {
|
||||||
|
throw errors[0];
|
||||||
|
} else if (errors.length > 1) {
|
||||||
|
throw new Error(errors.map(e => e.toString()).join(', '));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
await Promise.race([closePromise, errorPromise]);
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -1,46 +1,10 @@
|
|||||||
import * as net from 'node:net';
|
import * as net from 'node:net';
|
||||||
import * as readline from 'node:readline';
|
import * as readline from 'node:readline';
|
||||||
import {type ChildProcess, spawn} from 'node:child_process';
|
|
||||||
import * as os from 'node:os';
|
|
||||||
import * as path from 'node:path';
|
|
||||||
import * as fs from 'node:fs';
|
|
||||||
import * as util from 'node:util';
|
|
||||||
import * as crypto from 'node:crypto';
|
|
||||||
import {AsyncQueue} from './asyncqueue.js';
|
import {AsyncQueue} from './asyncqueue.js';
|
||||||
|
import type {CallMessage, CallResult, Message, ResponseMessage, Vals} from './protocol.js';
|
||||||
|
import {MsgType} from './protocol.js';
|
||||||
|
|
||||||
const IPC_SOCKET_ARG = 'ipc-socket';
|
export abstract class IPCCommon {
|
||||||
|
|
||||||
type JSONSerializable = string | number | boolean;
|
|
||||||
|
|
||||||
enum MsgType {
|
|
||||||
Call = 1,
|
|
||||||
Response = 2,
|
|
||||||
}
|
|
||||||
|
|
||||||
type Vals = any[];
|
|
||||||
|
|
||||||
interface CallMessage {
|
|
||||||
type: MsgType.Call,
|
|
||||||
id: number,
|
|
||||||
method: string;
|
|
||||||
args: Vals;
|
|
||||||
}
|
|
||||||
|
|
||||||
interface ResponseMessage {
|
|
||||||
type: MsgType.Response,
|
|
||||||
id: number,
|
|
||||||
result?: Vals;
|
|
||||||
error?: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
type Message = CallMessage | ResponseMessage;
|
|
||||||
|
|
||||||
interface CallResult {
|
|
||||||
result: Vals;
|
|
||||||
error: Error | null;
|
|
||||||
}
|
|
||||||
|
|
||||||
abstract class IPCCommon {
|
|
||||||
protected localApis: Record<string, any>;
|
protected localApis: Record<string, any>;
|
||||||
protected socketPath: string;
|
protected socketPath: string;
|
||||||
protected conn: net.Socket | null = null;
|
protected conn: net.Socket | null = null;
|
||||||
@ -75,6 +39,7 @@ abstract class IPCCommon {
|
|||||||
});
|
});
|
||||||
|
|
||||||
this.conn.on('close', (hadError: boolean) => {
|
this.conn.on('close', (hadError: boolean) => {
|
||||||
|
this.rejectPendingCalls(new Error('connection closed'));
|
||||||
if (hadError) {
|
if (hadError) {
|
||||||
this.raiseErr(new Error('connection closed due to error'));
|
this.raiseErr(new Error('connection closed due to error'));
|
||||||
}
|
}
|
||||||
@ -194,7 +159,6 @@ abstract class IPCCommon {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public serialize(arg: any): any {
|
public serialize(arg: any): any {
|
||||||
// noinspection FallThroughInSwitchStatementJS
|
|
||||||
switch (typeof arg) {
|
switch (typeof arg) {
|
||||||
case 'string':
|
case 'string':
|
||||||
case 'boolean':
|
case 'boolean':
|
||||||
@ -212,7 +176,6 @@ abstract class IPCCommon {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public deserialize(arg: any): any {
|
public deserialize(arg: any): any {
|
||||||
// noinspection FallThroughInSwitchStatementJS
|
|
||||||
switch (typeof arg) {
|
switch (typeof arg) {
|
||||||
case 'string':
|
case 'string':
|
||||||
case 'boolean':
|
case 'boolean':
|
||||||
@ -248,201 +211,15 @@ abstract class IPCCommon {
|
|||||||
if (this.onClose) this.onClose();
|
if (this.onClose) this.onClose();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected rejectPendingCalls(err: Error): void {
|
||||||
|
const pending = this.pendingCalls;
|
||||||
|
this.pendingCalls = {};
|
||||||
|
for (const callback of Object.values(pending)) {
|
||||||
|
callback({result: [], error: err});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
protected raiseErr(err: Error): void {
|
protected raiseErr(err: Error): void {
|
||||||
this.errorQueue.put(err);
|
this.errorQueue.put(err);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
export class ParentIPC extends IPCCommon {
|
|
||||||
private readonly cmdPath: string;
|
|
||||||
private readonly cmdArgs: string[];
|
|
||||||
private cmd: ChildProcess | null = null;
|
|
||||||
private readonly listener: net.Server;
|
|
||||||
private cmdExitResult: { code: number | null, signal: string | null } | null = null;
|
|
||||||
private cmdExitCallbacks: ((result: { code: number | null, signal: string | null }) => void)[] = [];
|
|
||||||
|
|
||||||
constructor(cmdPath: string, cmdArgs: string[], ...localApis: object[]) {
|
|
||||||
const socketPath = path.join(os.tmpdir(), `kitten-ipc-${ process.pid }-${ crypto.randomInt(2**48 - 1) }.sock`);
|
|
||||||
super(localApis, socketPath);
|
|
||||||
|
|
||||||
this.cmdPath = cmdPath;
|
|
||||||
if (cmdArgs.includes(`--${ IPC_SOCKET_ARG }`)) {
|
|
||||||
throw new Error(`you should not use '--${ IPC_SOCKET_ARG }' argument in your command`);
|
|
||||||
}
|
|
||||||
this.cmdArgs = cmdArgs;
|
|
||||||
|
|
||||||
this.listener = net.createServer();
|
|
||||||
}
|
|
||||||
|
|
||||||
async start(): Promise<void> {
|
|
||||||
try {
|
|
||||||
fs.unlinkSync(this.socketPath);
|
|
||||||
} catch {
|
|
||||||
}
|
|
||||||
|
|
||||||
await new Promise<void>((resolve, reject) => {
|
|
||||||
this.listener.listen(this.socketPath, () => {
|
|
||||||
resolve();
|
|
||||||
});
|
|
||||||
this.listener.on('error', reject);
|
|
||||||
});
|
|
||||||
|
|
||||||
const cmdArgs = [...this.cmdArgs, `--${ IPC_SOCKET_ARG }`, this.socketPath];
|
|
||||||
this.cmd = spawn(this.cmdPath, cmdArgs, {stdio: 'inherit'});
|
|
||||||
|
|
||||||
this.cmd.on('error', (err) => {
|
|
||||||
this.raiseErr(err);
|
|
||||||
});
|
|
||||||
|
|
||||||
this.cmd.on('close', (code, signal) => {
|
|
||||||
const result = { code, signal };
|
|
||||||
this.cmdExitResult = result;
|
|
||||||
for (const cb of this.cmdExitCallbacks) cb(result);
|
|
||||||
this.cmdExitCallbacks = [];
|
|
||||||
});
|
|
||||||
|
|
||||||
await this.acceptConn();
|
|
||||||
}
|
|
||||||
|
|
||||||
private async acceptConn(): Promise<void> {
|
|
||||||
const acceptTimeout = 10000;
|
|
||||||
|
|
||||||
const acceptPromise = new Promise<net.Socket>((resolve, reject) => {
|
|
||||||
this.listener.once('connection', (conn) => {
|
|
||||||
resolve(conn);
|
|
||||||
});
|
|
||||||
this.listener.once('error', reject);
|
|
||||||
});
|
|
||||||
|
|
||||||
const exitPromise = new Promise<net.Socket>((_, reject) => {
|
|
||||||
if (this.cmdExitResult) {
|
|
||||||
reject(new Error(`command exited before connection established`));
|
|
||||||
} else {
|
|
||||||
this.cmdExitCallbacks.push(() => {
|
|
||||||
reject(new Error(`command exited before connection established`));
|
|
||||||
});
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
try {
|
|
||||||
this.conn = await timeout(Promise.race([acceptPromise, exitPromise]), acceptTimeout);
|
|
||||||
this.readConn();
|
|
||||||
} catch (e) {
|
|
||||||
if (this.cmd) this.cmd.kill();
|
|
||||||
throw e;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async wait(): Promise<void> {
|
|
||||||
if (!this.cmd) {
|
|
||||||
throw new Error('Command is not started yet');
|
|
||||||
}
|
|
||||||
|
|
||||||
const exitPromise = new Promise<{ code: number | null, signal: string | null }>((resolve) => {
|
|
||||||
if (this.cmdExitResult) {
|
|
||||||
resolve(this.cmdExitResult);
|
|
||||||
} else {
|
|
||||||
this.cmdExitCallbacks.push(resolve);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
try {
|
|
||||||
await Promise.race([
|
|
||||||
exitPromise.then(({ code, signal }) => {
|
|
||||||
if (signal || code) {
|
|
||||||
if (signal) throw new Error(`Process exited with signal ${ signal }`);
|
|
||||||
else throw new Error(`Process exited with code ${ code }`);
|
|
||||||
} else if (!this.ready) {
|
|
||||||
throw new Error('command exited before connection established');
|
|
||||||
}
|
|
||||||
}),
|
|
||||||
this.errorQueue.collect().then((errors) => {
|
|
||||||
if (errors.length === 1) {
|
|
||||||
throw errors[0];
|
|
||||||
} else if (errors.length > 1) {
|
|
||||||
throw new Error(errors.map(e => e.toString()).join(', '));
|
|
||||||
}
|
|
||||||
}),
|
|
||||||
]);
|
|
||||||
} finally {
|
|
||||||
try { fs.unlinkSync(this.socketPath); } catch {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
export class ChildIPC extends IPCCommon {
|
|
||||||
constructor(...localApis: object[]) {
|
|
||||||
super(localApis, socketPathFromArgs());
|
|
||||||
}
|
|
||||||
|
|
||||||
async start(): Promise<void> {
|
|
||||||
return new Promise((resolve, reject) => {
|
|
||||||
this.conn = net.createConnection(this.socketPath, () => {
|
|
||||||
this.readConn();
|
|
||||||
resolve();
|
|
||||||
});
|
|
||||||
this.conn.on('error', reject);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
async wait(): Promise<void> {
|
|
||||||
const closePromise = new Promise<void>((resolve) => {
|
|
||||||
this.onClose = () => {
|
|
||||||
if (this.processingCalls === 0) {
|
|
||||||
this.conn?.destroy();
|
|
||||||
resolve();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
if (this.stopRequested && this.processingCalls === 0) {
|
|
||||||
this.conn?.destroy();
|
|
||||||
resolve();
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
const errorPromise = this.errorQueue.collect().then((errors) => {
|
|
||||||
if (errors.length === 1) {
|
|
||||||
throw errors[0];
|
|
||||||
} else if (errors.length > 1) {
|
|
||||||
throw new Error(errors.map(e => e.toString()).join(', '));
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
await Promise.race([closePromise, errorPromise]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
function socketPathFromArgs(): string {
|
|
||||||
const {values} = util.parseArgs({
|
|
||||||
options: {
|
|
||||||
[IPC_SOCKET_ARG]: {
|
|
||||||
type: 'string',
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
if (!values[IPC_SOCKET_ARG]) {
|
|
||||||
throw new Error('ipc socket path is missing');
|
|
||||||
}
|
|
||||||
|
|
||||||
return values[IPC_SOCKET_ARG];
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
function sleep(ms: number): Promise<void> {
|
|
||||||
return new Promise(resolve => setTimeout(resolve, ms));
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
// throws on timeout
|
|
||||||
function timeout<T>(prom: Promise<T>, ms: number): Promise<T> {
|
|
||||||
return Promise.race(
|
|
||||||
[
|
|
||||||
prom,
|
|
||||||
new Promise((res, reject) => {
|
|
||||||
setTimeout(() => {reject(new Error('timed out'))}, ms)
|
|
||||||
})]
|
|
||||||
) as Promise<T>;
|
|
||||||
}
|
|
||||||
@ -1 +1,2 @@
|
|||||||
export {ParentIPC, ChildIPC} from './lib.js';
|
export {ParentIPC} from './parent.js';
|
||||||
|
export {ChildIPC} from './child.js';
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
import {test} from 'vitest';
|
import {test} from 'vitest';
|
||||||
import {ParentIPC} from './lib.js';
|
import {ParentIPC} from './parent.js';
|
||||||
|
|
||||||
test('test connection timeout', async ({expect}) => {
|
test('test connection timeout', async ({expect}) => {
|
||||||
const parentIpc = new ParentIPC('../testdata/sleep15.sh', []);
|
const parentIpc = new ParentIPC('../testdata/sleep15.sh', []);
|
||||||
|
|||||||
126
lib/ts/src/parent.ts
Normal file
126
lib/ts/src/parent.ts
Normal file
@ -0,0 +1,126 @@
|
|||||||
|
import * as net from 'node:net';
|
||||||
|
import * as os from 'node:os';
|
||||||
|
import * as path from 'node:path';
|
||||||
|
import * as fs from 'node:fs';
|
||||||
|
import * as crypto from 'node:crypto';
|
||||||
|
import {type ChildProcess, spawn} from 'node:child_process';
|
||||||
|
import {IPCCommon} from './common.js';
|
||||||
|
import {timeout} from './util.js';
|
||||||
|
|
||||||
|
const IPC_SOCKET_ARG = 'ipc-socket';
|
||||||
|
const ACCEPT_TIMEOUT_MS = 10000;
|
||||||
|
|
||||||
|
export class ParentIPC extends IPCCommon {
|
||||||
|
private readonly cmdPath: string;
|
||||||
|
private readonly cmdArgs: string[];
|
||||||
|
private cmd: ChildProcess | null = null;
|
||||||
|
private readonly listener: net.Server;
|
||||||
|
private cmdExitResult: { code: number | null, signal: string | null } | null = null;
|
||||||
|
private cmdExitCallbacks: ((result: { code: number | null, signal: string | null }) => void)[] = [];
|
||||||
|
|
||||||
|
constructor(cmdPath: string, cmdArgs: string[], ...localApis: object[]) {
|
||||||
|
const socketPath = path.join(os.tmpdir(), `kitten-ipc-${ process.pid }-${ crypto.randomInt(2**48 - 1) }.sock`);
|
||||||
|
super(localApis, socketPath);
|
||||||
|
|
||||||
|
this.cmdPath = cmdPath;
|
||||||
|
if (cmdArgs.includes(`--${ IPC_SOCKET_ARG }`)) {
|
||||||
|
throw new Error(`you should not use '--${ IPC_SOCKET_ARG }' argument in your command`);
|
||||||
|
}
|
||||||
|
this.cmdArgs = cmdArgs;
|
||||||
|
|
||||||
|
this.listener = net.createServer();
|
||||||
|
}
|
||||||
|
|
||||||
|
async start(): Promise<void> {
|
||||||
|
try {
|
||||||
|
fs.unlinkSync(this.socketPath);
|
||||||
|
} catch {
|
||||||
|
}
|
||||||
|
|
||||||
|
await new Promise<void>((resolve, reject) => {
|
||||||
|
this.listener.listen(this.socketPath, () => {
|
||||||
|
resolve();
|
||||||
|
});
|
||||||
|
this.listener.on('error', reject);
|
||||||
|
});
|
||||||
|
|
||||||
|
const cmdArgs = [...this.cmdArgs, `--${ IPC_SOCKET_ARG }`, this.socketPath];
|
||||||
|
this.cmd = spawn(this.cmdPath, cmdArgs, {stdio: 'inherit'});
|
||||||
|
|
||||||
|
this.cmd.on('error', (err) => {
|
||||||
|
this.raiseErr(err);
|
||||||
|
});
|
||||||
|
|
||||||
|
this.cmd.on('close', (code, signal) => {
|
||||||
|
const result = { code, signal };
|
||||||
|
this.cmdExitResult = result;
|
||||||
|
for (const cb of this.cmdExitCallbacks) cb(result);
|
||||||
|
this.cmdExitCallbacks = [];
|
||||||
|
});
|
||||||
|
|
||||||
|
await this.acceptConn();
|
||||||
|
}
|
||||||
|
|
||||||
|
private async acceptConn(): Promise<void> {
|
||||||
|
const acceptPromise = new Promise<net.Socket>((resolve, reject) => {
|
||||||
|
this.listener.once('connection', (conn) => {
|
||||||
|
resolve(conn);
|
||||||
|
});
|
||||||
|
this.listener.once('error', reject);
|
||||||
|
});
|
||||||
|
|
||||||
|
const exitPromise = new Promise<net.Socket>((_, reject) => {
|
||||||
|
if (this.cmdExitResult) {
|
||||||
|
reject(new Error(`command exited before connection established`));
|
||||||
|
} else {
|
||||||
|
this.cmdExitCallbacks.push(() => {
|
||||||
|
reject(new Error(`command exited before connection established`));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
try {
|
||||||
|
this.conn = await timeout(Promise.race([acceptPromise, exitPromise]), ACCEPT_TIMEOUT_MS);
|
||||||
|
this.readConn();
|
||||||
|
} catch (e) {
|
||||||
|
if (this.cmd) this.cmd.kill();
|
||||||
|
throw e;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async wait(): Promise<void> {
|
||||||
|
if (!this.cmd) {
|
||||||
|
throw new Error('Command is not started yet');
|
||||||
|
}
|
||||||
|
|
||||||
|
const exitPromise = new Promise<{ code: number | null, signal: string | null }>((resolve) => {
|
||||||
|
if (this.cmdExitResult) {
|
||||||
|
resolve(this.cmdExitResult);
|
||||||
|
} else {
|
||||||
|
this.cmdExitCallbacks.push(resolve);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
try {
|
||||||
|
await Promise.race([
|
||||||
|
exitPromise.then(({ code, signal }) => {
|
||||||
|
if (signal || code) {
|
||||||
|
if (signal) throw new Error(`Process exited with signal ${ signal }`);
|
||||||
|
else throw new Error(`Process exited with code ${ code }`);
|
||||||
|
} else if (!this.ready) {
|
||||||
|
throw new Error('command exited before connection established');
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
this.errorQueue.collect().then((errors) => {
|
||||||
|
if (errors.length === 1) {
|
||||||
|
throw errors[0];
|
||||||
|
} else if (errors.length > 1) {
|
||||||
|
throw new Error(errors.map(e => e.toString()).join(', '));
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
]);
|
||||||
|
} finally {
|
||||||
|
try { fs.unlinkSync(this.socketPath); } catch {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
27
lib/ts/src/protocol.ts
Normal file
27
lib/ts/src/protocol.ts
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
export enum MsgType {
|
||||||
|
Call = 1,
|
||||||
|
Response = 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
export type Vals = any[];
|
||||||
|
|
||||||
|
export interface CallMessage {
|
||||||
|
type: MsgType.Call,
|
||||||
|
id: number,
|
||||||
|
method: string;
|
||||||
|
args: Vals;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ResponseMessage {
|
||||||
|
type: MsgType.Response,
|
||||||
|
id: number,
|
||||||
|
result?: Vals;
|
||||||
|
error?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export type Message = CallMessage | ResponseMessage;
|
||||||
|
|
||||||
|
export interface CallResult {
|
||||||
|
result: Vals;
|
||||||
|
error: Error | null;
|
||||||
|
}
|
||||||
68
lib/ts/src/serialize.test.ts
Normal file
68
lib/ts/src/serialize.test.ts
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
import {test} from 'vitest';
|
||||||
|
import {ChildIPC} from './child.js';
|
||||||
|
|
||||||
|
// Access serialize/deserialize through a ChildIPC instance's public methods
|
||||||
|
// We create a minimal wrapper to test them
|
||||||
|
class TestableIPC {
|
||||||
|
private ipc: any;
|
||||||
|
|
||||||
|
constructor() {
|
||||||
|
// Access the prototype methods directly
|
||||||
|
this.ipc = Object.create(ChildIPC.prototype);
|
||||||
|
}
|
||||||
|
|
||||||
|
serialize(arg: any): any {
|
||||||
|
return this.ipc.serialize(arg);
|
||||||
|
}
|
||||||
|
|
||||||
|
deserialize(arg: any): any {
|
||||||
|
return this.ipc.deserialize(arg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
test('serialize primitives', ({expect}) => {
|
||||||
|
const t = new TestableIPC();
|
||||||
|
expect(t.serialize(42)).toBe(42);
|
||||||
|
expect(t.serialize('hello')).toBe('hello');
|
||||||
|
expect(t.serialize(true)).toBe(true);
|
||||||
|
expect(t.serialize(3.14)).toBe(3.14);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('serialize buffer to base64', ({expect}) => {
|
||||||
|
const t = new TestableIPC();
|
||||||
|
const buf = Buffer.from([1, 2, 3]);
|
||||||
|
expect(t.serialize(buf)).toBe('AQID');
|
||||||
|
});
|
||||||
|
|
||||||
|
test('deserialize primitives', ({expect}) => {
|
||||||
|
const t = new TestableIPC();
|
||||||
|
expect(t.deserialize(42)).toBe(42);
|
||||||
|
expect(t.deserialize('hello')).toBe('hello');
|
||||||
|
expect(t.deserialize(true)).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
test('deserialize blob to buffer', ({expect}) => {
|
||||||
|
const t = new TestableIPC();
|
||||||
|
const result = t.deserialize({t: 'blob', d: 'AQID'});
|
||||||
|
expect(Buffer.isBuffer(result)).toBe(true);
|
||||||
|
expect(result).toEqual(Buffer.from([1, 2, 3]));
|
||||||
|
});
|
||||||
|
|
||||||
|
test('serialize then deserialize blob round-trip', ({expect}) => {
|
||||||
|
const t = new TestableIPC();
|
||||||
|
const original = Buffer.from([0xDE, 0xAD, 0xBE, 0xEF]);
|
||||||
|
// Go serializes as {t: 'blob', d: base64}, TS serializes as just base64 string
|
||||||
|
// The TS serialize returns base64 string directly for Buffer
|
||||||
|
const serialized = t.serialize(original);
|
||||||
|
expect(typeof serialized).toBe('string');
|
||||||
|
});
|
||||||
|
|
||||||
|
test('deserialize unknown object throws', ({expect}) => {
|
||||||
|
const t = new TestableIPC();
|
||||||
|
expect(() => t.deserialize({foo: 'bar'})).toThrow('cannot deserialize');
|
||||||
|
});
|
||||||
|
|
||||||
|
test('serialize unsupported object throws', ({expect}) => {
|
||||||
|
const t = new TestableIPC();
|
||||||
|
expect(() => t.serialize({foo: 'bar'})).toThrow('cannot serialize');
|
||||||
|
});
|
||||||
29
lib/ts/src/util.ts
Normal file
29
lib/ts/src/util.ts
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
import * as util from 'node:util';
|
||||||
|
|
||||||
|
const IPC_SOCKET_ARG = 'ipc-socket';
|
||||||
|
|
||||||
|
export function socketPathFromArgs(): string {
|
||||||
|
const {values} = util.parseArgs({
|
||||||
|
options: {
|
||||||
|
[IPC_SOCKET_ARG]: {
|
||||||
|
type: 'string',
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!values[IPC_SOCKET_ARG]) {
|
||||||
|
throw new Error('ipc socket path is missing');
|
||||||
|
}
|
||||||
|
|
||||||
|
return values[IPC_SOCKET_ARG];
|
||||||
|
}
|
||||||
|
|
||||||
|
export function timeout<T>(prom: Promise<T>, ms: number): Promise<T> {
|
||||||
|
return Promise.race(
|
||||||
|
[
|
||||||
|
prom,
|
||||||
|
new Promise((res, reject) => {
|
||||||
|
setTimeout(() => {reject(new Error('timed out'))}, ms)
|
||||||
|
})]
|
||||||
|
) as Promise<T>;
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user