refactor
This commit is contained in:
parent
687364767a
commit
b7f547839a
BIN
example/golang/simple
Executable file
BIN
example/golang/simple
Executable file
Binary file not shown.
@ -7,3 +7,10 @@ require (
|
||||
golang.org/x/sync v0.17.0
|
||||
golang.org/x/text v0.30.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/stretchr/testify v1.11.1
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
@ -1,6 +1,16 @@
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e h1:Lf/gRkoycfOBPa42vU2bbgPurFong6zXeFtPoxholzU=
|
||||
github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e/go.mod h1:uNVvRXArCGbZ508SxYYTC5v1JWoz2voff5pm25jU1Ok=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
|
||||
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
@ -1,14 +1,18 @@
|
||||
package api
|
||||
|
||||
import "efprojects.com/kitten-ipc/types"
|
||||
type ValType string
|
||||
|
||||
// todo check TInt size < 64
|
||||
// todo check not float
|
||||
const (
|
||||
TNoType ValType = ""
|
||||
TInt ValType = "int"
|
||||
TString ValType = "string"
|
||||
TBool ValType = "bool"
|
||||
TBlob ValType = "blob"
|
||||
)
|
||||
|
||||
type Val struct {
|
||||
Name string
|
||||
Type types.ValType
|
||||
Children []Val
|
||||
Name string
|
||||
Type ValType
|
||||
}
|
||||
|
||||
type Method struct {
|
||||
|
||||
19
kitcom/internal/common/write.go
Normal file
19
kitcom/internal/common/write.go
Normal file
@ -0,0 +1,19 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
func WriteFile(destFile string, content []byte) error {
|
||||
f, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open destination file: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if _, err := f.Write(content); err != nil {
|
||||
return fmt.Errorf("write file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -4,18 +4,15 @@ import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"go/format"
|
||||
"os"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
_ "embed"
|
||||
|
||||
"efprojects.com/kitten-ipc/kitcom/internal/api"
|
||||
"efprojects.com/kitten-ipc/types"
|
||||
"efprojects.com/kitten-ipc/kitcom/internal/common"
|
||||
)
|
||||
|
||||
// todo: check int overflow
|
||||
// todo: check float is whole
|
||||
|
||||
//go:embed gogen.tmpl
|
||||
var templateString string
|
||||
|
||||
@ -34,43 +31,41 @@ func (g *GoApiGenerator) Generate(apis *api.Api, destFile string) error {
|
||||
Api: apis,
|
||||
}
|
||||
|
||||
const defaultReceiver = "self"
|
||||
|
||||
tpl := template.New("gogen")
|
||||
tpl = tpl.Funcs(map[string]any{
|
||||
"receiver": func(name string) string {
|
||||
return defaultReceiver
|
||||
return strings.ToLower(name[:1])
|
||||
},
|
||||
"typedef": func(t types.ValType) (string, error) {
|
||||
td, ok := map[types.ValType]string{
|
||||
types.TInt: "int",
|
||||
types.TString: "string",
|
||||
types.TBool: "bool",
|
||||
types.TBlob: "[]byte",
|
||||
"typedef": func(t api.ValType) (string, error) {
|
||||
td, ok := map[api.ValType]string{
|
||||
api.TInt: "int",
|
||||
api.TString: "string",
|
||||
api.TBool: "bool",
|
||||
api.TBlob: "[]byte",
|
||||
}[t]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("cannot generate type %v", t)
|
||||
}
|
||||
return td, nil
|
||||
},
|
||||
"convtype": func(valDef string, t types.ValType) (string, error) {
|
||||
td, ok := map[types.ValType]string{
|
||||
types.TInt: fmt.Sprintf("int(%s.(float64))", valDef),
|
||||
types.TString: fmt.Sprintf("%s.(string)", valDef),
|
||||
types.TBool: fmt.Sprintf("%s.(bool)", valDef),
|
||||
types.TBlob: fmt.Sprintf("%s.([]byte)", valDef),
|
||||
"convtype": func(valDef string, t api.ValType) (string, error) {
|
||||
td, ok := map[api.ValType]string{
|
||||
api.TInt: fmt.Sprintf("int(%s.(float64))", valDef),
|
||||
api.TString: fmt.Sprintf("%s.(string)", valDef),
|
||||
api.TBool: fmt.Sprintf("%s.(bool)", valDef),
|
||||
api.TBlob: fmt.Sprintf("%s.([]byte)", valDef),
|
||||
}[t]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("cannot convert type %v for val %s", t, valDef)
|
||||
}
|
||||
return td, nil
|
||||
},
|
||||
"zerovalue": func(t types.ValType) (string, error) {
|
||||
v, ok := map[types.ValType]string{
|
||||
types.TInt: "0",
|
||||
types.TString: `""`,
|
||||
types.TBool: "false",
|
||||
types.TBlob: "[]byte{}",
|
||||
"zerovalue": func(t api.ValType) (string, error) {
|
||||
v, ok := map[api.ValType]string{
|
||||
api.TInt: "0",
|
||||
api.TString: `""`,
|
||||
api.TBool: "false",
|
||||
api.TBlob: "[]byte{}",
|
||||
}[t]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("cannot generate zero value for type %v", t)
|
||||
@ -86,27 +81,14 @@ func (g *GoApiGenerator) Generate(apis *api.Api, destFile string) error {
|
||||
return fmt.Errorf("execute template: %w", err)
|
||||
}
|
||||
|
||||
if err := g.writeDest(destFile, buf.Bytes()); err != nil {
|
||||
formatted, err := format.Source(buf.Bytes())
|
||||
if err != nil {
|
||||
return fmt.Errorf("format source: %w", err)
|
||||
}
|
||||
|
||||
if err := common.WriteFile(destFile, formatted); err != nil {
|
||||
return fmt.Errorf("write file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *GoApiGenerator) writeDest(destFile string, bytes []byte) error {
|
||||
f, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open destination file: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
formatted, err := format.Source(bytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("format source: %w", err)
|
||||
}
|
||||
|
||||
if _, err := f.Write(formatted); err != nil {
|
||||
return fmt.Errorf("write formatted source: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -28,7 +28,9 @@ func ({{ $e.Name | receiver }} *{{ $e.Name }}) {{ $mtd.Name }}(
|
||||
if err != nil {
|
||||
return {{ range $mtd.Ret }}{{ .Type | zerovalue }}, {{ end }} fmt.Errorf("call to {{ $e.Name }}.{{ $mtd.Name }} failed: %w", err)
|
||||
}
|
||||
_ = results
|
||||
if len(results) < {{ len $mtd.Ret }} {
|
||||
return {{ range $mtd.Ret }}{{ .Type | zerovalue }}, {{ end }} fmt.Errorf("call to {{ $e.Name }}.{{ $mtd.Name }}: expected {{ len $mtd.Ret }} results, got %d", len(results))
|
||||
}
|
||||
{{ range $i, $ret := $mtd.Ret }}
|
||||
{{ if eq $ret.Type "blob" }}
|
||||
results[{{ $i }}], err = base64.StdEncoding.DecodeString(results[{{ $i }}].(string))
|
||||
|
||||
@ -9,7 +9,6 @@ import (
|
||||
|
||||
"efprojects.com/kitten-ipc/kitcom/internal/api"
|
||||
"efprojects.com/kitten-ipc/kitcom/internal/common"
|
||||
"efprojects.com/kitten-ipc/types"
|
||||
)
|
||||
|
||||
var decorComment = regexp.MustCompile(`^//\s?kittenipc:api$`)
|
||||
@ -142,11 +141,11 @@ func fieldToVal(param *ast.Field, returning bool) (*api.Val, error) {
|
||||
case *ast.Ident:
|
||||
switch paramType.Name {
|
||||
case "int":
|
||||
val.Type = types.TInt
|
||||
val.Type = api.TInt
|
||||
case "string":
|
||||
val.Type = types.TString
|
||||
val.Type = api.TString
|
||||
case "bool":
|
||||
val.Type = types.TBool
|
||||
val.Type = api.TBool
|
||||
case "error":
|
||||
if returning {
|
||||
return nil, nil
|
||||
@ -161,7 +160,7 @@ func fieldToVal(param *ast.Field, returning bool) (*api.Val, error) {
|
||||
case *ast.Ident:
|
||||
switch elementType.Name {
|
||||
case "byte":
|
||||
val.Type = types.TBlob
|
||||
val.Type = api.TBlob
|
||||
default:
|
||||
return nil, fmt.Errorf("parameter type %s is not supported yet", elementType.Name)
|
||||
}
|
||||
|
||||
43
kitcom/internal/golang/goparser_test.go
Normal file
43
kitcom/internal/golang/goparser_test.go
Normal file
@ -0,0 +1,43 @@
|
||||
package golang
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"efprojects.com/kitten-ipc/kitcom/internal/api"
|
||||
"efprojects.com/kitten-ipc/kitcom/internal/common"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGoParser(t *testing.T) {
|
||||
parser := &GoApiParser{Parser: &common.Parser{}}
|
||||
parser.AddFile("../../../example/golang/main.go")
|
||||
|
||||
result, err := parser.Parse()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, result.Endpoints, 1)
|
||||
|
||||
ep := result.Endpoints[0]
|
||||
assert.Equal(t, "GoIpcApi", ep.Name)
|
||||
require.Len(t, ep.Methods, 2)
|
||||
|
||||
// Div method
|
||||
div := ep.Methods[0]
|
||||
assert.Equal(t, "Div", div.Name)
|
||||
require.Len(t, div.Params, 2)
|
||||
assert.Equal(t, api.TInt, div.Params[0].Type)
|
||||
assert.Equal(t, "a", div.Params[0].Name)
|
||||
assert.Equal(t, api.TInt, div.Params[1].Type)
|
||||
assert.Equal(t, "b", div.Params[1].Name)
|
||||
require.Len(t, div.Ret, 1)
|
||||
assert.Equal(t, api.TInt, div.Ret[0].Type)
|
||||
|
||||
// XorData method
|
||||
xor := ep.Methods[1]
|
||||
assert.Equal(t, "XorData", xor.Name)
|
||||
require.Len(t, xor.Params, 2)
|
||||
assert.Equal(t, api.TBlob, xor.Params[0].Type)
|
||||
assert.Equal(t, api.TBlob, xor.Params[1].Type)
|
||||
require.Len(t, xor.Ret, 1)
|
||||
assert.Equal(t, api.TBlob, xor.Ret[0].Type)
|
||||
}
|
||||
@ -4,14 +4,13 @@ import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/exec"
|
||||
"text/template"
|
||||
|
||||
_ "embed"
|
||||
|
||||
"efprojects.com/kitten-ipc/kitcom/internal/api"
|
||||
"efprojects.com/kitten-ipc/types"
|
||||
"efprojects.com/kitten-ipc/kitcom/internal/common"
|
||||
)
|
||||
|
||||
//go:embed tsgen.tmpl
|
||||
@ -31,24 +30,24 @@ func (g *TypescriptApiGenerator) Generate(apis *api.Api, destFile string) error
|
||||
|
||||
tpl := template.New("tsgen")
|
||||
tpl = tpl.Funcs(map[string]any{
|
||||
"typedef": func(t types.ValType) (string, error) {
|
||||
td, ok := map[types.ValType]string{
|
||||
types.TInt: "number",
|
||||
types.TString: "string",
|
||||
types.TBool: "boolean",
|
||||
types.TBlob: "Buffer",
|
||||
"typedef": func(t api.ValType) (string, error) {
|
||||
td, ok := map[api.ValType]string{
|
||||
api.TInt: "number",
|
||||
api.TString: "string",
|
||||
api.TBool: "boolean",
|
||||
api.TBlob: "Buffer",
|
||||
}[t]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("cannot generate type %v", t)
|
||||
}
|
||||
return td, nil
|
||||
},
|
||||
"convtype": func(valDef string, t types.ValType) (string, error) {
|
||||
td, ok := map[types.ValType]string{
|
||||
types.TInt: fmt.Sprintf("%s as number", valDef),
|
||||
types.TString: fmt.Sprintf("%s as string", valDef),
|
||||
types.TBool: fmt.Sprintf("%s as boolean", valDef),
|
||||
types.TBlob: fmt.Sprintf("Buffer.from(%s, 'base64')", valDef),
|
||||
"convtype": func(valDef string, t api.ValType) (string, error) {
|
||||
td, ok := map[api.ValType]string{
|
||||
api.TInt: fmt.Sprintf("%s as number", valDef),
|
||||
api.TString: fmt.Sprintf("%s as string", valDef),
|
||||
api.TBool: fmt.Sprintf("%s as boolean", valDef),
|
||||
api.TBlob: fmt.Sprintf("Buffer.from(%s, 'base64')", valDef),
|
||||
}[t]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("cannot convert type %v for val %s", t, valDef)
|
||||
@ -64,24 +63,10 @@ func (g *TypescriptApiGenerator) Generate(apis *api.Api, destFile string) error
|
||||
return fmt.Errorf("execute template: %w", err)
|
||||
}
|
||||
|
||||
if err := g.writeDest(destFile, buf.Bytes()); err != nil {
|
||||
if err := common.WriteFile(destFile, buf.Bytes()); err != nil {
|
||||
return fmt.Errorf("write file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *TypescriptApiGenerator) writeDest(destFile string, bytes []byte) error {
|
||||
f, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open destination file: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if _, err := f.Write(bytes); err != nil {
|
||||
return fmt.Errorf("write formatted source: %w", err)
|
||||
}
|
||||
|
||||
prettierCmd := exec.Command("npx", "prettier", destFile, "--write")
|
||||
if out, err := prettierCmd.CombinedOutput(); err != nil {
|
||||
log.Printf("Prettier returned error: %v", err)
|
||||
|
||||
@ -12,7 +12,6 @@ import (
|
||||
"efprojects.com/kitten-ipc/kitcom/internal/tsgo/core"
|
||||
"efprojects.com/kitten-ipc/kitcom/internal/tsgo/parser"
|
||||
"efprojects.com/kitten-ipc/kitcom/internal/tsgo/tspath"
|
||||
"efprojects.com/kitten-ipc/types"
|
||||
)
|
||||
|
||||
type TypescriptApiParser struct {
|
||||
@ -109,25 +108,25 @@ func (p *TypescriptApiParser) parseFile(sourceFilePath string) ([]api.Endpoint,
|
||||
return endpoints, nil
|
||||
}
|
||||
|
||||
func (p *TypescriptApiParser) fieldToVal(typ *ast.TypeNode) (types.ValType, error) {
|
||||
func (p *TypescriptApiParser) fieldToVal(typ *ast.TypeNode) (api.ValType, error) {
|
||||
switch typ.Kind {
|
||||
case ast.KindNumberKeyword:
|
||||
return types.TInt, nil
|
||||
return api.TInt, nil
|
||||
case ast.KindStringKeyword:
|
||||
return types.TString, nil
|
||||
return api.TString, nil
|
||||
case ast.KindBooleanKeyword:
|
||||
return types.TBool, nil
|
||||
return api.TBool, nil
|
||||
case ast.KindTypeReference:
|
||||
refNode := typ.AsTypeReferenceNode()
|
||||
ident := refNode.TypeName.AsIdentifier()
|
||||
switch ident.Text {
|
||||
case "Buffer":
|
||||
return types.TBlob, nil
|
||||
return api.TBlob, nil
|
||||
default:
|
||||
return types.TNoType, fmt.Errorf("reference type %s is not supported yet", ident.Text)
|
||||
return api.TNoType, fmt.Errorf("reference type %s is not supported yet", ident.Text)
|
||||
}
|
||||
default:
|
||||
return types.TNoType, fmt.Errorf("type %s is not supported yet", typ.Kind)
|
||||
return api.TNoType, fmt.Errorf("type %s is not supported yet", typ.Kind)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
46
kitcom/internal/ts/tsparser_test.go
Normal file
46
kitcom/internal/ts/tsparser_test.go
Normal file
@ -0,0 +1,46 @@
|
||||
package ts
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"efprojects.com/kitten-ipc/kitcom/internal/api"
|
||||
"efprojects.com/kitten-ipc/kitcom/internal/common"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTsParser(t *testing.T) {
|
||||
parser := &TypescriptApiParser{Parser: &common.Parser{}}
|
||||
absPath, err := filepath.Abs("../../../example/ts/src/index.ts")
|
||||
require.NoError(t, err)
|
||||
parser.AddFile(absPath)
|
||||
|
||||
result, err := parser.Parse()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, result.Endpoints, 1)
|
||||
|
||||
ep := result.Endpoints[0]
|
||||
assert.Equal(t, "TsIpcApi", ep.Name)
|
||||
require.Len(t, ep.Methods, 2)
|
||||
|
||||
// Div method
|
||||
div := ep.Methods[0]
|
||||
assert.Equal(t, "Div", div.Name)
|
||||
require.Len(t, div.Params, 2)
|
||||
assert.Equal(t, api.TInt, div.Params[0].Type)
|
||||
assert.Equal(t, "a", div.Params[0].Name)
|
||||
assert.Equal(t, api.TInt, div.Params[1].Type)
|
||||
assert.Equal(t, "b", div.Params[1].Name)
|
||||
require.Len(t, div.Ret, 1)
|
||||
assert.Equal(t, api.TInt, div.Ret[0].Type)
|
||||
|
||||
// XorData method
|
||||
xor := ep.Methods[1]
|
||||
assert.Equal(t, "XorData", xor.Name)
|
||||
require.Len(t, xor.Params, 2)
|
||||
assert.Equal(t, api.TBlob, xor.Params[0].Type)
|
||||
assert.Equal(t, api.TBlob, xor.Params[1].Type)
|
||||
require.Len(t, xor.Ret, 1)
|
||||
assert.Equal(t, api.TBlob, xor.Ret[0].Type)
|
||||
}
|
||||
60
lib/golang/child.go
Normal file
60
lib/golang/child.go
Normal file
@ -0,0 +1,60 @@
|
||||
package golang
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
)
|
||||
|
||||
type ChildIPC struct {
|
||||
*ipcCommon
|
||||
}
|
||||
|
||||
func NewChild(localApis ...any) (*ChildIPC, error) {
|
||||
c := ChildIPC{
|
||||
ipcCommon: &ipcCommon{
|
||||
localApis: mapTypeNames(localApis),
|
||||
pendingCalls: make(map[int64]*pendingCall),
|
||||
errCh: make(chan error, 1),
|
||||
ctx: context.Background(),
|
||||
},
|
||||
}
|
||||
|
||||
socketPath := socketPathFromArgs()
|
||||
if socketPath == "" {
|
||||
return nil, fmt.Errorf("ipc socket path is missing")
|
||||
}
|
||||
c.socketPath = socketPath
|
||||
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
func (c *ChildIPC) Start() error {
|
||||
conn, err := net.Dial("unix", c.socketPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to parent socket: %w", err)
|
||||
}
|
||||
c.conn = conn
|
||||
go c.readConn()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ChildIPC) Wait() error {
|
||||
err := <-c.errCh
|
||||
if err != nil {
|
||||
return fmt.Errorf("ipc error: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// socketPathFromArgs parses --ipc-socket from os.Args without calling flag.Parse(),
|
||||
// which would interfere with the host application's flag handling.
|
||||
func socketPathFromArgs() string {
|
||||
for i, arg := range os.Args {
|
||||
if arg == ipcSocketArg && i+1 < len(os.Args) {
|
||||
return os.Args[i+1]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
261
lib/golang/common.go
Normal file
261
lib/golang/common.go
Normal file
@ -0,0 +1,261 @@
|
||||
package golang
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type IpcCommon interface {
|
||||
Call(method string, params ...any) (Vals, error)
|
||||
ConvType(needType, gotType reflect.Type, arg any) any
|
||||
}
|
||||
|
||||
type callResult struct {
|
||||
vals Vals
|
||||
err error
|
||||
}
|
||||
|
||||
type pendingCall struct {
|
||||
resultChan chan callResult
|
||||
}
|
||||
|
||||
type ipcCommon struct {
|
||||
localApis map[string]any
|
||||
socketPath string
|
||||
conn net.Conn
|
||||
errCh chan error
|
||||
nextId int64
|
||||
pendingCalls map[int64]*pendingCall
|
||||
processingCalls atomic.Int64
|
||||
stopRequested atomic.Bool
|
||||
mu sync.Mutex
|
||||
writeMu sync.Mutex
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (ipc *ipcCommon) readConn() {
|
||||
scn := bufio.NewScanner(ipc.conn)
|
||||
scn.Buffer(nil, maxMessageLength)
|
||||
for scn.Scan() {
|
||||
var msg Message
|
||||
msgBytes := scn.Bytes()
|
||||
if err := json.Unmarshal(msgBytes, &msg); err != nil {
|
||||
ipc.raiseErr(fmt.Errorf("unmarshal message: %w", err))
|
||||
break
|
||||
}
|
||||
ipc.processMsg(msg)
|
||||
}
|
||||
if err := scn.Err(); err != nil {
|
||||
ipc.raiseErr(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (ipc *ipcCommon) processMsg(msg Message) {
|
||||
switch msg.Type {
|
||||
case MsgCall:
|
||||
go ipc.handleCall(msg)
|
||||
case MsgResponse:
|
||||
ipc.handleResponse(msg)
|
||||
}
|
||||
}
|
||||
|
||||
func (ipc *ipcCommon) sendMsg(msg Message) error {
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal message: %w", err)
|
||||
}
|
||||
data = append(data, '\n')
|
||||
|
||||
ipc.writeMu.Lock()
|
||||
_, writeErr := ipc.conn.Write(data)
|
||||
ipc.writeMu.Unlock()
|
||||
if writeErr != nil {
|
||||
return fmt.Errorf("write message: %w", writeErr)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ipc *ipcCommon) handleCall(msg Message) {
|
||||
if ipc.stopRequested.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
ipc.processingCalls.Add(1)
|
||||
defer ipc.processingCalls.Add(-1)
|
||||
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
ipc.sendResponse(msg.Id, nil, fmt.Errorf("handle call panicked: %s", err))
|
||||
}
|
||||
}()
|
||||
|
||||
method, err := ipc.findMethod(msg.Method)
|
||||
if err != nil {
|
||||
ipc.sendResponse(msg.Id, nil, fmt.Errorf("find method: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
argsCount := method.Type().NumIn()
|
||||
if len(msg.Args) != argsCount {
|
||||
ipc.sendResponse(msg.Id, nil, fmt.Errorf("args count mismatch: expected %d, got %d", argsCount, len(msg.Args)))
|
||||
return
|
||||
}
|
||||
|
||||
var args []reflect.Value
|
||||
for i, arg := range msg.Args {
|
||||
paramType := method.Type().In(i)
|
||||
argType := reflect.TypeOf(arg)
|
||||
arg = ipc.ConvType(paramType, argType, arg)
|
||||
args = append(args, reflect.ValueOf(arg))
|
||||
}
|
||||
|
||||
allResultVals := method.Call(args)
|
||||
retResultVals := allResultVals[0 : len(allResultVals)-1]
|
||||
errResultVals := allResultVals[len(allResultVals)-1]
|
||||
|
||||
var results []any
|
||||
for _, resVal := range retResultVals {
|
||||
results = append(results, resVal.Interface())
|
||||
}
|
||||
|
||||
var resErr error
|
||||
if !errResultVals.IsNil() {
|
||||
resErr = errResultVals.Interface().(error)
|
||||
}
|
||||
|
||||
ipc.sendResponse(msg.Id, results, resErr)
|
||||
}
|
||||
|
||||
func (ipc *ipcCommon) findMethod(methodName string) (reflect.Value, error) {
|
||||
parts := strings.Split(methodName, ".")
|
||||
if len(parts) != 2 {
|
||||
return reflect.Value{}, fmt.Errorf("invalid method: %s", methodName)
|
||||
}
|
||||
|
||||
endpointName, methodName := parts[0], parts[1]
|
||||
|
||||
localApi, ok := ipc.localApis[endpointName]
|
||||
if !ok {
|
||||
return reflect.Value{}, fmt.Errorf("endpoint not found: %s", endpointName)
|
||||
}
|
||||
|
||||
method := reflect.ValueOf(localApi).MethodByName(methodName)
|
||||
if !method.IsValid() {
|
||||
return reflect.Value{}, fmt.Errorf("method not found: %s", methodName)
|
||||
}
|
||||
|
||||
return method, nil
|
||||
}
|
||||
|
||||
func (ipc *ipcCommon) sendResponse(id int64, result []any, err error) {
|
||||
msg := Message{
|
||||
Type: MsgResponse,
|
||||
Id: id,
|
||||
Result: result,
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
msg.Error = err.Error()
|
||||
}
|
||||
|
||||
if err := ipc.sendMsg(msg); err != nil {
|
||||
ipc.raiseErr(fmt.Errorf("send response for id=%d: %w", id, err))
|
||||
}
|
||||
}
|
||||
|
||||
func (ipc *ipcCommon) handleResponse(msg Message) {
|
||||
ipc.mu.Lock()
|
||||
call, ok := ipc.pendingCalls[msg.Id]
|
||||
if ok {
|
||||
delete(ipc.pendingCalls, msg.Id)
|
||||
}
|
||||
ipc.mu.Unlock()
|
||||
|
||||
if !ok {
|
||||
ipc.raiseErr(fmt.Errorf("received response for unknown call id: %d", msg.Id))
|
||||
return
|
||||
}
|
||||
|
||||
var res callResult
|
||||
if msg.Error == "" {
|
||||
res = callResult{vals: msg.Result}
|
||||
} else {
|
||||
res = callResult{err: fmt.Errorf("remote error: %s", msg.Error)}
|
||||
}
|
||||
call.resultChan <- res
|
||||
close(call.resultChan)
|
||||
}
|
||||
|
||||
func (ipc *ipcCommon) Call(method string, params ...any) (Vals, error) {
|
||||
if ipc.conn == nil {
|
||||
return nil, fmt.Errorf("ipc is not connected to remote process socket")
|
||||
}
|
||||
|
||||
if ipc.stopRequested.Load() {
|
||||
return nil, fmt.Errorf("ipc is stopping")
|
||||
}
|
||||
|
||||
ipc.mu.Lock()
|
||||
id := ipc.nextId
|
||||
ipc.nextId++
|
||||
call := &pendingCall{
|
||||
resultChan: make(chan callResult, 1),
|
||||
}
|
||||
ipc.pendingCalls[id] = call
|
||||
ipc.mu.Unlock()
|
||||
|
||||
for i := range params {
|
||||
params[i] = ipc.serialize(params[i])
|
||||
}
|
||||
|
||||
msg := Message{
|
||||
Type: MsgCall,
|
||||
Id: id,
|
||||
Method: method,
|
||||
Args: params,
|
||||
}
|
||||
|
||||
if err := ipc.sendMsg(msg); err != nil {
|
||||
ipc.mu.Lock()
|
||||
delete(ipc.pendingCalls, id)
|
||||
ipc.mu.Unlock()
|
||||
return nil, fmt.Errorf("send call: %w", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case result := <-call.resultChan:
|
||||
return result.vals, result.err
|
||||
case <-ipc.ctx.Done():
|
||||
ipc.mu.Lock()
|
||||
delete(ipc.pendingCalls, id)
|
||||
ipc.mu.Unlock()
|
||||
return nil, ipc.ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (ipc *ipcCommon) raiseErr(err error) {
|
||||
select {
|
||||
case ipc.errCh <- err:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (ipc *ipcCommon) closeConn() {
|
||||
_ = ipc.conn.Close()
|
||||
ipc.mu.Lock()
|
||||
pending := ipc.pendingCalls
|
||||
ipc.pendingCalls = make(map[int64]*pendingCall)
|
||||
ipc.mu.Unlock()
|
||||
for _, call := range pending {
|
||||
call.resultChan <- callResult{err: fmt.Errorf("call cancelled due to ipc termination")}
|
||||
close(call.resultChan)
|
||||
}
|
||||
}
|
||||
43
lib/golang/common_test.go
Normal file
43
lib/golang/common_test.go
Normal file
@ -0,0 +1,43 @@
|
||||
package golang
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type testEndpoint struct{}
|
||||
|
||||
func (e *testEndpoint) Hello(name string) (string, error) {
|
||||
return "hello " + name, nil
|
||||
}
|
||||
|
||||
func TestFindMethod(t *testing.T) {
|
||||
ipc := &ipcCommon{
|
||||
localApis: mapTypeNames([]any{&testEndpoint{}}),
|
||||
}
|
||||
|
||||
t.Run("valid method", func(t *testing.T) {
|
||||
method, err := ipc.findMethod("testEndpoint.Hello")
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, method.IsValid())
|
||||
})
|
||||
|
||||
t.Run("invalid format - no dot", func(t *testing.T) {
|
||||
_, err := ipc.findMethod("Hello")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid method")
|
||||
})
|
||||
|
||||
t.Run("unknown endpoint", func(t *testing.T) {
|
||||
_, err := ipc.findMethod("Unknown.Hello")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "endpoint not found")
|
||||
})
|
||||
|
||||
t.Run("unknown method", func(t *testing.T) {
|
||||
_, err := ipc.findMethod("testEndpoint.Unknown")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "method not found")
|
||||
})
|
||||
}
|
||||
@ -2,4 +2,10 @@ module efprojects.com/kitten-ipc
|
||||
|
||||
go 1.25.1
|
||||
|
||||
require github.com/samber/mo v1.16.0
|
||||
require github.com/stretchr/testify v1.11.1
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
@ -1,2 +1,10 @@
|
||||
github.com/samber/mo v1.16.0 h1:qpEPCI63ou6wXlsNDMLE0IIN8A+devbGX/K1xdgr4b4=
|
||||
github.com/samber/mo v1.16.0/go.mod h1:DlgzJ4SYhOh41nP1L9kh9rDNERuf8IqWSAs+gj2Vxag=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
@ -1,550 +0,0 @@
|
||||
package golang
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"efprojects.com/kitten-ipc/types"
|
||||
"github.com/samber/mo"
|
||||
)
|
||||
|
||||
const ipcSocketArg = "--ipc-socket"
|
||||
const maxMessageLength = 1 * 1024 * 1024 * 1024 // 1 gigabyte
|
||||
|
||||
type StdioMode int
|
||||
|
||||
type MsgType int
|
||||
|
||||
type Vals []any
|
||||
|
||||
const (
|
||||
MsgCall MsgType = 1
|
||||
MsgResponse MsgType = 2
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
Type MsgType `json:"type"`
|
||||
Id int64 `json:"id"`
|
||||
Method string `json:"method"`
|
||||
Args Vals `json:"args"`
|
||||
Result Vals `json:"result"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
type IpcCommon interface {
|
||||
Call(method string, params ...any) (Vals, error)
|
||||
ConvType(needType reflect.Type, gotType reflect.Type, arg any) any
|
||||
}
|
||||
|
||||
type pendingCall struct {
|
||||
resultChan chan mo.Result[Vals]
|
||||
resultType reflect.Type
|
||||
}
|
||||
|
||||
type ipcCommon struct {
|
||||
localApis map[string]any
|
||||
socketPath string
|
||||
conn net.Conn
|
||||
errCh chan error
|
||||
nextId int64
|
||||
pendingCalls map[int64]*pendingCall
|
||||
processingCalls atomic.Int64
|
||||
stopRequested atomic.Bool
|
||||
mu sync.Mutex
|
||||
writeMu sync.Mutex
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (ipc *ipcCommon) readConn() {
|
||||
scn := bufio.NewScanner(ipc.conn)
|
||||
scn.Buffer(nil, maxMessageLength)
|
||||
for scn.Scan() {
|
||||
var msg Message
|
||||
msgBytes := scn.Bytes()
|
||||
//log.Println(string(msgBytes))
|
||||
if err := json.Unmarshal(msgBytes, &msg); err != nil {
|
||||
ipc.raiseErr(fmt.Errorf("unmarshal message: %w", err))
|
||||
break
|
||||
}
|
||||
ipc.processMsg(msg)
|
||||
}
|
||||
if err := scn.Err(); err != nil {
|
||||
ipc.raiseErr(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (ipc *ipcCommon) processMsg(msg Message) {
|
||||
switch msg.Type {
|
||||
case MsgCall:
|
||||
go ipc.handleCall(msg)
|
||||
case MsgResponse:
|
||||
ipc.handleResponse(msg)
|
||||
}
|
||||
}
|
||||
|
||||
func (ipc *ipcCommon) sendMsg(msg Message) error {
|
||||
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal message: %w", err)
|
||||
}
|
||||
data = append(data, '\n')
|
||||
|
||||
ipc.writeMu.Lock()
|
||||
_, writeErr := ipc.conn.Write(data)
|
||||
ipc.writeMu.Unlock()
|
||||
if writeErr != nil {
|
||||
return fmt.Errorf("write message: %w", writeErr)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ipc *ipcCommon) handleCall(msg Message) {
|
||||
if ipc.stopRequested.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
ipc.processingCalls.Add(1)
|
||||
defer ipc.processingCalls.Add(-1)
|
||||
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
ipc.sendResponse(msg.Id, nil, fmt.Errorf("handle call panicked: %s", err))
|
||||
}
|
||||
}()
|
||||
|
||||
method, err := ipc.findMethod(msg.Method)
|
||||
if err != nil {
|
||||
ipc.sendResponse(msg.Id, nil, fmt.Errorf("find method: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
argsCount := method.Type().NumIn()
|
||||
if len(msg.Args) != argsCount {
|
||||
ipc.sendResponse(msg.Id, nil, fmt.Errorf("args count mismatch: expected %d, got %d", argsCount, len(msg.Args)))
|
||||
return
|
||||
}
|
||||
|
||||
var args []reflect.Value
|
||||
for i, arg := range msg.Args {
|
||||
paramType := method.Type().In(i)
|
||||
argType := reflect.TypeOf(arg)
|
||||
arg = ipc.ConvType(paramType, argType, arg)
|
||||
args = append(args, reflect.ValueOf(arg))
|
||||
}
|
||||
|
||||
allResultVals := method.Call(args)
|
||||
retResultVals := allResultVals[0 : len(allResultVals)-1]
|
||||
errResultVals := allResultVals[len(allResultVals)-1]
|
||||
|
||||
var results []any
|
||||
for _, resVal := range retResultVals {
|
||||
results = append(results, resVal.Interface())
|
||||
}
|
||||
|
||||
var resErr error
|
||||
if !errResultVals.IsNil() {
|
||||
resErr = errResultVals.Interface().(error)
|
||||
}
|
||||
|
||||
ipc.sendResponse(msg.Id, results, resErr)
|
||||
}
|
||||
|
||||
func (ipc *ipcCommon) findMethod(methodName string) (reflect.Value, error) {
|
||||
parts := strings.Split(methodName, ".")
|
||||
if len(parts) != 2 {
|
||||
return reflect.Value{}, fmt.Errorf("invalid method: %s", methodName)
|
||||
}
|
||||
|
||||
endpointName, methodName := parts[0], parts[1]
|
||||
|
||||
localApi, ok := ipc.localApis[endpointName]
|
||||
if !ok {
|
||||
return reflect.Value{}, fmt.Errorf("endpoint not found: %s", endpointName)
|
||||
}
|
||||
|
||||
method := reflect.ValueOf(localApi).MethodByName(methodName)
|
||||
if !method.IsValid() {
|
||||
return reflect.Value{}, fmt.Errorf("method not found: %s", methodName)
|
||||
}
|
||||
|
||||
return method, nil
|
||||
}
|
||||
|
||||
func (ipc *ipcCommon) sendResponse(id int64, result []any, err error) {
|
||||
msg := Message{
|
||||
Type: MsgResponse,
|
||||
Id: id,
|
||||
Result: result,
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
msg.Error = err.Error()
|
||||
}
|
||||
|
||||
if err := ipc.sendMsg(msg); err != nil {
|
||||
ipc.raiseErr(fmt.Errorf("send response for id=%d: %w", id, err))
|
||||
}
|
||||
}
|
||||
|
||||
func (ipc *ipcCommon) handleResponse(msg Message) {
|
||||
ipc.mu.Lock()
|
||||
call, ok := ipc.pendingCalls[msg.Id]
|
||||
if ok {
|
||||
delete(ipc.pendingCalls, msg.Id)
|
||||
}
|
||||
ipc.mu.Unlock()
|
||||
|
||||
if !ok {
|
||||
ipc.raiseErr(fmt.Errorf("received response for unknown call id: %d", msg.Id))
|
||||
return
|
||||
}
|
||||
|
||||
var res mo.Result[Vals]
|
||||
if msg.Error == "" {
|
||||
res = mo.Ok[Vals](msg.Result)
|
||||
} else {
|
||||
res = mo.Err[Vals](fmt.Errorf("remote error: %s", msg.Error))
|
||||
}
|
||||
call.resultChan <- res
|
||||
close(call.resultChan)
|
||||
}
|
||||
|
||||
func (ipc *ipcCommon) Call(method string, params ...any) (Vals, error) {
|
||||
if ipc.conn == nil {
|
||||
return nil, fmt.Errorf("ipc is not connected to remote process socket")
|
||||
}
|
||||
|
||||
if ipc.stopRequested.Load() {
|
||||
return nil, fmt.Errorf("ipc is stopping")
|
||||
}
|
||||
|
||||
ipc.mu.Lock()
|
||||
id := ipc.nextId
|
||||
ipc.nextId++
|
||||
call := &pendingCall{
|
||||
resultChan: make(chan mo.Result[Vals], 1),
|
||||
}
|
||||
ipc.pendingCalls[id] = call
|
||||
ipc.mu.Unlock()
|
||||
|
||||
for i := range params {
|
||||
params[i] = ipc.serialize(params[i])
|
||||
}
|
||||
|
||||
msg := Message{
|
||||
Type: MsgCall,
|
||||
Id: id,
|
||||
Method: method,
|
||||
Args: params,
|
||||
}
|
||||
|
||||
if err := ipc.sendMsg(msg); err != nil {
|
||||
ipc.mu.Lock()
|
||||
delete(ipc.pendingCalls, id)
|
||||
ipc.mu.Unlock()
|
||||
return nil, fmt.Errorf("send call: %w", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case result := <-call.resultChan:
|
||||
return result.Get()
|
||||
case <-ipc.ctx.Done():
|
||||
ipc.mu.Lock()
|
||||
delete(ipc.pendingCalls, id)
|
||||
ipc.mu.Unlock()
|
||||
return nil, ipc.ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (ipc *ipcCommon) raiseErr(err error) {
|
||||
select {
|
||||
case ipc.errCh <- err:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (ipc *ipcCommon) closeConn() {
|
||||
_ = ipc.conn.Close()
|
||||
ipc.mu.Lock()
|
||||
pending := ipc.pendingCalls
|
||||
ipc.pendingCalls = make(map[int64]*pendingCall)
|
||||
ipc.mu.Unlock()
|
||||
for _, call := range pending {
|
||||
call.resultChan <- mo.Err[Vals](fmt.Errorf("call cancelled due to ipc termination"))
|
||||
close(call.resultChan)
|
||||
}
|
||||
}
|
||||
|
||||
func (ipc *ipcCommon) ConvType(needType reflect.Type, gotType reflect.Type, arg any) any {
|
||||
switch needType.Kind() {
|
||||
case reflect.Int:
|
||||
// JSON decodes any number to float64. If we need int, we should check and convert
|
||||
if gotType.Kind() == reflect.Float64 {
|
||||
floatArg := arg.(float64)
|
||||
if float64(int64(floatArg)) == floatArg && !needType.OverflowInt(int64(floatArg)) {
|
||||
arg = int(floatArg)
|
||||
}
|
||||
}
|
||||
}
|
||||
return arg
|
||||
}
|
||||
|
||||
func (ipc *ipcCommon) serialize(arg any) any {
|
||||
t := reflect.TypeOf(arg)
|
||||
switch t.Kind() {
|
||||
case reflect.Slice:
|
||||
switch t.Elem().Name() {
|
||||
case "uint8":
|
||||
return map[string]any{
|
||||
"t": types.TBlob,
|
||||
"d": base64.StdEncoding.EncodeToString(arg.([]byte)),
|
||||
}
|
||||
}
|
||||
}
|
||||
return arg
|
||||
}
|
||||
|
||||
type ParentIPC struct {
|
||||
*ipcCommon
|
||||
cmd *exec.Cmd
|
||||
listener net.Listener
|
||||
cmdDone chan struct{}
|
||||
cmdErr error
|
||||
}
|
||||
|
||||
func NewParent(cmd *exec.Cmd, localApis ...any) (*ParentIPC, error) {
|
||||
return NewParentWithContext(context.Background(), cmd, localApis...)
|
||||
}
|
||||
|
||||
func NewParentWithContext(ctx context.Context, cmd *exec.Cmd, localApis ...any) (*ParentIPC, error) {
|
||||
p := ParentIPC{
|
||||
ipcCommon: &ipcCommon{
|
||||
localApis: mapTypeNames(localApis),
|
||||
pendingCalls: make(map[int64]*pendingCall),
|
||||
errCh: make(chan error, 1),
|
||||
socketPath: filepath.Join(os.TempDir(), fmt.Sprintf("kitten-ipc-%d-%d.sock", os.Getpid(), rand.Int63())),
|
||||
ctx: ctx,
|
||||
},
|
||||
cmd: cmd,
|
||||
}
|
||||
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
if slices.Contains(cmd.Args, ipcSocketArg) {
|
||||
return nil, fmt.Errorf("you should not use `%s` argument in your command", ipcSocketArg)
|
||||
}
|
||||
cmd.Args = append(cmd.Args, ipcSocketArg, p.socketPath)
|
||||
|
||||
p.errCh = make(chan error, 1)
|
||||
p.cmdDone = make(chan struct{})
|
||||
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
func (p *ParentIPC) Start() error {
|
||||
_ = os.Remove(p.socketPath)
|
||||
listener, err := net.Listen("unix", p.socketPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen unix socket: %w", err)
|
||||
}
|
||||
p.listener = listener
|
||||
defer p.listener.Close()
|
||||
|
||||
err = p.cmd.Start()
|
||||
if err != nil {
|
||||
return fmt.Errorf("cmd start: %w", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
p.cmdErr = p.cmd.Wait()
|
||||
close(p.cmdDone)
|
||||
}()
|
||||
|
||||
return p.acceptConn()
|
||||
}
|
||||
|
||||
func (p *ParentIPC) acceptConn() error {
|
||||
const acceptTimeout = time.Second * 10
|
||||
|
||||
res := make(chan mo.Result[net.Conn], 1)
|
||||
go func() {
|
||||
conn, err := p.listener.Accept()
|
||||
if err != nil {
|
||||
res <- mo.Err[net.Conn](err)
|
||||
} else {
|
||||
res <- mo.Ok[net.Conn](conn)
|
||||
}
|
||||
close(res)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-time.After(acceptTimeout):
|
||||
_ = p.cmd.Process.Kill()
|
||||
return fmt.Errorf("accept timeout")
|
||||
case <-p.cmdDone:
|
||||
return fmt.Errorf("cmd exited before accepting connection: %w", p.cmdErr)
|
||||
case res := <-res:
|
||||
if res.IsError() {
|
||||
_ = p.cmd.Process.Kill()
|
||||
return fmt.Errorf("accept: %w", res.Error())
|
||||
}
|
||||
p.conn = res.MustGet()
|
||||
go p.readConn()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ParentIPC) Stop() error {
|
||||
p.mu.Lock()
|
||||
hasPending := len(p.pendingCalls) > 0
|
||||
p.mu.Unlock()
|
||||
if hasPending {
|
||||
return fmt.Errorf("there are calls pending")
|
||||
}
|
||||
if p.processingCalls.Load() > 0 {
|
||||
return fmt.Errorf("there are calls processing")
|
||||
}
|
||||
p.stopRequested.Store(true)
|
||||
if err := p.cmd.Process.Signal(syscall.SIGINT); err != nil {
|
||||
return fmt.Errorf("send SIGINT: %w", err)
|
||||
}
|
||||
return p.Wait()
|
||||
}
|
||||
|
||||
func (p *ParentIPC) Wait(timeout ...time.Duration) (retErr error) {
|
||||
const defaultTimeout = time.Duration(1<<63 - 1) // max duration in go
|
||||
_timeout := variadicToOption(timeout).OrElse(defaultTimeout)
|
||||
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case err := <-p.errCh:
|
||||
retErr = mergeErr(retErr, fmt.Errorf("ipc internal error: %w", err))
|
||||
break loop
|
||||
case <-p.cmdDone:
|
||||
err := p.cmdErr
|
||||
if err != nil {
|
||||
var exitErr *exec.ExitError
|
||||
if ok := errors.As(err, &exitErr); ok {
|
||||
if !exitErr.Success() {
|
||||
ws, ok := exitErr.Sys().(syscall.WaitStatus)
|
||||
if !(ok && ws.Signaled() && ws.Signal() == syscall.SIGINT && p.stopRequested.Load()) {
|
||||
retErr = mergeErr(retErr, fmt.Errorf("cmd wait: %w", err))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
retErr = mergeErr(retErr, fmt.Errorf("cmd wait: %w", err))
|
||||
}
|
||||
}
|
||||
break loop
|
||||
case <-time.After(_timeout):
|
||||
p.stopRequested.Store(true)
|
||||
if err := p.cmd.Process.Signal(syscall.SIGINT); err != nil {
|
||||
retErr = mergeErr(retErr, fmt.Errorf("send SIGINT: %w", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
p.closeConn()
|
||||
_ = os.Remove(p.socketPath)
|
||||
|
||||
return retErr
|
||||
}
|
||||
|
||||
type ChildIPC struct {
|
||||
*ipcCommon
|
||||
}
|
||||
|
||||
func NewChild(localApis ...any) (*ChildIPC, error) {
|
||||
c := ChildIPC{
|
||||
ipcCommon: &ipcCommon{
|
||||
localApis: mapTypeNames(localApis),
|
||||
pendingCalls: make(map[int64]*pendingCall),
|
||||
errCh: make(chan error, 1),
|
||||
ctx: context.Background(), // todo: NewChildWithContext
|
||||
},
|
||||
}
|
||||
|
||||
socketPath := flag.String("ipc-socket", "", "Path to IPC socket")
|
||||
flag.Parse()
|
||||
|
||||
if *socketPath == "" {
|
||||
return nil, fmt.Errorf("ipc socket path is missing")
|
||||
}
|
||||
c.socketPath = *socketPath
|
||||
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
func (c *ChildIPC) Start() error {
|
||||
conn, err := net.Dial("unix", c.socketPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to parent socket: %w", err)
|
||||
}
|
||||
c.conn = conn
|
||||
go c.readConn()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ChildIPC) Wait() error {
|
||||
err := <-c.errCh
|
||||
if err != nil {
|
||||
return fmt.Errorf("ipc error: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func mergeErr(errs ...error) (ret error) {
|
||||
for _, err := range errs {
|
||||
if err != nil {
|
||||
if ret == nil {
|
||||
ret = err
|
||||
} else {
|
||||
ret = fmt.Errorf("%w; %w", ret, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func mapTypeNames(types []any) map[string]any {
|
||||
result := make(map[string]any)
|
||||
for _, t := range types {
|
||||
if reflect.TypeOf(t).Kind() != reflect.Pointer {
|
||||
panic(fmt.Sprintf("LocalAPI argument must be pointer"))
|
||||
}
|
||||
typeName := reflect.TypeOf(t).Elem().Name()
|
||||
result[typeName] = t
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func variadicToOption[T any](variadic []T) mo.Option[T] {
|
||||
if len(variadic) >= 2 {
|
||||
panic("variadic param count must be 0 or 1")
|
||||
}
|
||||
if len(variadic) == 0 {
|
||||
return mo.None[T]()
|
||||
}
|
||||
return mo.Some(variadic[0])
|
||||
}
|
||||
165
lib/golang/parent.go
Normal file
165
lib/golang/parent.go
Normal file
@ -0,0 +1,165 @@
|
||||
package golang
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ParentIPC struct {
|
||||
*ipcCommon
|
||||
cmd *exec.Cmd
|
||||
listener net.Listener
|
||||
cmdDone chan struct{}
|
||||
cmdErr error
|
||||
}
|
||||
|
||||
func NewParent(cmd *exec.Cmd, localApis ...any) (*ParentIPC, error) {
|
||||
return NewParentWithContext(context.Background(), cmd, localApis...)
|
||||
}
|
||||
|
||||
func NewParentWithContext(ctx context.Context, cmd *exec.Cmd, localApis ...any) (*ParentIPC, error) {
|
||||
p := ParentIPC{
|
||||
ipcCommon: &ipcCommon{
|
||||
localApis: mapTypeNames(localApis),
|
||||
pendingCalls: make(map[int64]*pendingCall),
|
||||
errCh: make(chan error, 1),
|
||||
socketPath: filepath.Join(os.TempDir(), fmt.Sprintf("kitten-ipc-%d-%d.sock", os.Getpid(), rand.Int63())),
|
||||
ctx: ctx,
|
||||
},
|
||||
cmd: cmd,
|
||||
}
|
||||
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
if slices.Contains(cmd.Args, ipcSocketArg) {
|
||||
return nil, fmt.Errorf("you should not use `%s` argument in your command", ipcSocketArg)
|
||||
}
|
||||
cmd.Args = append(cmd.Args, ipcSocketArg, p.socketPath)
|
||||
|
||||
p.errCh = make(chan error, 1)
|
||||
p.cmdDone = make(chan struct{})
|
||||
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
func (p *ParentIPC) Start() error {
|
||||
_ = os.Remove(p.socketPath)
|
||||
listener, err := net.Listen("unix", p.socketPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen unix socket: %w", err)
|
||||
}
|
||||
p.listener = listener
|
||||
defer p.listener.Close()
|
||||
|
||||
err = p.cmd.Start()
|
||||
if err != nil {
|
||||
return fmt.Errorf("cmd start: %w", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
p.cmdErr = p.cmd.Wait()
|
||||
close(p.cmdDone)
|
||||
}()
|
||||
|
||||
return p.acceptConn()
|
||||
}
|
||||
|
||||
type connResult struct {
|
||||
conn net.Conn
|
||||
err error
|
||||
}
|
||||
|
||||
func (p *ParentIPC) acceptConn() error {
|
||||
res := make(chan connResult, 1)
|
||||
go func() {
|
||||
conn, err := p.listener.Accept()
|
||||
res <- connResult{conn: conn, err: err}
|
||||
close(res)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-time.After(time.Duration(defaultAcceptTimeout) * time.Second):
|
||||
_ = p.cmd.Process.Kill()
|
||||
return fmt.Errorf("accept timeout")
|
||||
case <-p.cmdDone:
|
||||
return fmt.Errorf("cmd exited before accepting connection: %w", p.cmdErr)
|
||||
case r := <-res:
|
||||
if r.err != nil {
|
||||
_ = p.cmd.Process.Kill()
|
||||
return fmt.Errorf("accept: %w", r.err)
|
||||
}
|
||||
p.conn = r.conn
|
||||
go p.readConn()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ParentIPC) Stop() error {
|
||||
p.mu.Lock()
|
||||
hasPending := len(p.pendingCalls) > 0
|
||||
p.mu.Unlock()
|
||||
if hasPending {
|
||||
return fmt.Errorf("there are calls pending")
|
||||
}
|
||||
if p.processingCalls.Load() > 0 {
|
||||
return fmt.Errorf("there are calls processing")
|
||||
}
|
||||
p.stopRequested.Store(true)
|
||||
if err := p.cmd.Process.Signal(syscall.SIGINT); err != nil {
|
||||
return fmt.Errorf("send SIGINT: %w", err)
|
||||
}
|
||||
return p.Wait()
|
||||
}
|
||||
|
||||
func (p *ParentIPC) Wait(timeout ...time.Duration) (retErr error) {
|
||||
const maxDuration = time.Duration(1<<63 - 1)
|
||||
_timeout := maxDuration
|
||||
if len(timeout) > 0 {
|
||||
_timeout = timeout[0]
|
||||
}
|
||||
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case err := <-p.errCh:
|
||||
retErr = mergeErr(retErr, fmt.Errorf("ipc internal error: %w", err))
|
||||
break loop
|
||||
case <-p.cmdDone:
|
||||
err := p.cmdErr
|
||||
if err != nil {
|
||||
var exitErr *exec.ExitError
|
||||
if ok := errors.As(err, &exitErr); ok {
|
||||
if !exitErr.Success() {
|
||||
ws, ok := exitErr.Sys().(syscall.WaitStatus)
|
||||
if !(ok && ws.Signaled() && ws.Signal() == syscall.SIGINT && p.stopRequested.Load()) {
|
||||
retErr = mergeErr(retErr, fmt.Errorf("cmd wait: %w", err))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
retErr = mergeErr(retErr, fmt.Errorf("cmd wait: %w", err))
|
||||
}
|
||||
}
|
||||
break loop
|
||||
case <-time.After(_timeout):
|
||||
p.stopRequested.Store(true)
|
||||
if err := p.cmd.Process.Signal(syscall.SIGINT); err != nil {
|
||||
retErr = mergeErr(retErr, fmt.Errorf("send SIGINT: %w", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
p.closeConn()
|
||||
_ = os.Remove(p.socketPath)
|
||||
|
||||
return retErr
|
||||
}
|
||||
23
lib/golang/protocol.go
Normal file
23
lib/golang/protocol.go
Normal file
@ -0,0 +1,23 @@
|
||||
package golang
|
||||
|
||||
const ipcSocketArg = "--ipc-socket"
|
||||
const maxMessageLength = 1 << 30 // 1 GB
|
||||
const defaultAcceptTimeout = 10 // seconds
|
||||
|
||||
type MsgType int
|
||||
|
||||
type Vals []any
|
||||
|
||||
const (
|
||||
MsgCall MsgType = 1
|
||||
MsgResponse MsgType = 2
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
Type MsgType `json:"type"`
|
||||
Id int64 `json:"id"`
|
||||
Method string `json:"method"`
|
||||
Args Vals `json:"args"`
|
||||
Result Vals `json:"result"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
35
lib/golang/serialize.go
Normal file
35
lib/golang/serialize.go
Normal file
@ -0,0 +1,35 @@
|
||||
package golang
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
func (ipc *ipcCommon) serialize(arg any) any {
|
||||
t := reflect.TypeOf(arg)
|
||||
switch t.Kind() {
|
||||
case reflect.Slice:
|
||||
switch t.Elem().Name() {
|
||||
case "uint8":
|
||||
return map[string]any{
|
||||
"t": "blob",
|
||||
"d": base64.StdEncoding.EncodeToString(arg.([]byte)),
|
||||
}
|
||||
}
|
||||
}
|
||||
return arg
|
||||
}
|
||||
|
||||
func (ipc *ipcCommon) ConvType(needType reflect.Type, gotType reflect.Type, arg any) any {
|
||||
switch needType.Kind() {
|
||||
case reflect.Int:
|
||||
// JSON decodes any number to float64. If we need int, we should check and convert
|
||||
if gotType.Kind() == reflect.Float64 {
|
||||
floatArg := arg.(float64)
|
||||
if float64(int64(floatArg)) == floatArg && !needType.OverflowInt(int64(floatArg)) {
|
||||
arg = int(floatArg)
|
||||
}
|
||||
}
|
||||
}
|
||||
return arg
|
||||
}
|
||||
61
lib/golang/serialize_test.go
Normal file
61
lib/golang/serialize_test.go
Normal file
@ -0,0 +1,61 @@
|
||||
package golang
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSerialize(t *testing.T) {
|
||||
ipc := &ipcCommon{}
|
||||
|
||||
t.Run("primitives pass through", func(t *testing.T) {
|
||||
assert.Equal(t, 42, ipc.serialize(42))
|
||||
assert.Equal(t, "hello", ipc.serialize("hello"))
|
||||
assert.Equal(t, true, ipc.serialize(true))
|
||||
assert.Equal(t, 3.14, ipc.serialize(3.14))
|
||||
})
|
||||
|
||||
t.Run("byte slice serializes to blob", func(t *testing.T) {
|
||||
data := []byte{0x01, 0x02, 0x03}
|
||||
result := ipc.serialize(data)
|
||||
m, ok := result.(map[string]any)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "blob", m["t"])
|
||||
assert.Equal(t, "AQID", m["d"]) // base64 of {1,2,3}
|
||||
})
|
||||
|
||||
t.Run("empty byte slice serializes to blob", func(t *testing.T) {
|
||||
data := []byte{}
|
||||
result := ipc.serialize(data)
|
||||
m, ok := result.(map[string]any)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "blob", m["t"])
|
||||
assert.Equal(t, "", m["d"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestConvType(t *testing.T) {
|
||||
ipc := &ipcCommon{}
|
||||
|
||||
t.Run("float64 to int", func(t *testing.T) {
|
||||
result := ipc.ConvType(reflect.TypeOf(0), reflect.TypeOf(0.0), float64(42))
|
||||
assert.Equal(t, 42, result)
|
||||
})
|
||||
|
||||
t.Run("float64 with fractional part stays float", func(t *testing.T) {
|
||||
result := ipc.ConvType(reflect.TypeOf(0), reflect.TypeOf(0.0), float64(42.5))
|
||||
assert.Equal(t, float64(42.5), result)
|
||||
})
|
||||
|
||||
t.Run("string passes through", func(t *testing.T) {
|
||||
result := ipc.ConvType(reflect.TypeOf(""), reflect.TypeOf(""), "hello")
|
||||
assert.Equal(t, "hello", result)
|
||||
})
|
||||
|
||||
t.Run("bool passes through", func(t *testing.T) {
|
||||
result := ipc.ConvType(reflect.TypeOf(true), reflect.TypeOf(true), true)
|
||||
assert.Equal(t, true, result)
|
||||
})
|
||||
}
|
||||
@ -1,12 +0,0 @@
|
||||
package types
|
||||
|
||||
type ValType string
|
||||
|
||||
const (
|
||||
TNoType ValType = "" // zero value constant for ValType (please don't use just "" as zero value!)
|
||||
TInt ValType = "int"
|
||||
TString ValType = "string"
|
||||
TBool ValType = "bool"
|
||||
TBlob ValType = "blob"
|
||||
TArray ValType = "array"
|
||||
)
|
||||
31
lib/golang/util.go
Normal file
31
lib/golang/util.go
Normal file
@ -0,0 +1,31 @@
|
||||
package golang
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
func mergeErr(errs ...error) (ret error) {
|
||||
for _, err := range errs {
|
||||
if err != nil {
|
||||
if ret == nil {
|
||||
ret = err
|
||||
} else {
|
||||
ret = fmt.Errorf("%w; %w", ret, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func mapTypeNames(types []any) map[string]any {
|
||||
result := make(map[string]any)
|
||||
for _, t := range types {
|
||||
if reflect.TypeOf(t).Kind() != reflect.Pointer {
|
||||
panic(fmt.Sprintf("LocalAPI argument must be pointer"))
|
||||
}
|
||||
typeName := reflect.TypeOf(t).Elem().Name()
|
||||
result[typeName] = t
|
||||
}
|
||||
return result
|
||||
}
|
||||
46
lib/golang/util_test.go
Normal file
46
lib/golang/util_test.go
Normal file
@ -0,0 +1,46 @@
|
||||
package golang
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestMergeErr(t *testing.T) {
|
||||
t.Run("all nil returns nil", func(t *testing.T) {
|
||||
assert.NoError(t, mergeErr(nil, nil, nil))
|
||||
})
|
||||
|
||||
t.Run("single error returns it", func(t *testing.T) {
|
||||
err := fmt.Errorf("one")
|
||||
assert.EqualError(t, mergeErr(nil, err, nil), "one")
|
||||
})
|
||||
|
||||
t.Run("multiple errors merged", func(t *testing.T) {
|
||||
err1 := fmt.Errorf("one")
|
||||
err2 := fmt.Errorf("two")
|
||||
result := mergeErr(err1, err2)
|
||||
assert.ErrorContains(t, result, "one")
|
||||
assert.ErrorContains(t, result, "two")
|
||||
})
|
||||
}
|
||||
|
||||
func TestMapTypeNames(t *testing.T) {
|
||||
type Foo struct{}
|
||||
type Bar struct{}
|
||||
|
||||
t.Run("maps pointer types by name", func(t *testing.T) {
|
||||
foo := &Foo{}
|
||||
bar := &Bar{}
|
||||
result := mapTypeNames([]any{foo, bar})
|
||||
assert.Equal(t, foo, result["Foo"])
|
||||
assert.Equal(t, bar, result["Bar"])
|
||||
})
|
||||
|
||||
t.Run("panics on non-pointer", func(t *testing.T) {
|
||||
assert.Panics(t, func() {
|
||||
mapTypeNames([]any{Foo{}})
|
||||
})
|
||||
})
|
||||
}
|
||||
27
lib/ts/src/asyncqueue.test.ts
Normal file
27
lib/ts/src/asyncqueue.test.ts
Normal file
@ -0,0 +1,27 @@
|
||||
import {test} from 'vitest';
|
||||
import {AsyncQueue} from './asyncqueue.js';
|
||||
|
||||
test('put then collect returns items', async ({expect}) => {
|
||||
const q = new AsyncQueue<number>();
|
||||
q.put(1);
|
||||
q.put(2);
|
||||
const items = await q.collect();
|
||||
expect(items).toEqual([1, 2]);
|
||||
});
|
||||
|
||||
test('collect then put resolves on first put', async ({expect}) => {
|
||||
const q = new AsyncQueue<number>();
|
||||
const collectPromise = q.collect();
|
||||
q.put(42);
|
||||
const items = await collectPromise;
|
||||
expect(items).toEqual([42]);
|
||||
});
|
||||
|
||||
test('multiple puts before collect', async ({expect}) => {
|
||||
const q = new AsyncQueue<string>();
|
||||
q.put('a');
|
||||
q.put('b');
|
||||
q.put('c');
|
||||
const items = await q.collect();
|
||||
expect(items).toEqual(['a', 'b', 'c']);
|
||||
});
|
||||
44
lib/ts/src/child.ts
Normal file
44
lib/ts/src/child.ts
Normal file
@ -0,0 +1,44 @@
|
||||
import * as net from 'node:net';
|
||||
import {IPCCommon} from './common.js';
|
||||
import {socketPathFromArgs} from './util.js';
|
||||
|
||||
export class ChildIPC extends IPCCommon {
|
||||
constructor(...localApis: object[]) {
|
||||
super(localApis, socketPathFromArgs());
|
||||
}
|
||||
|
||||
async start(): Promise<void> {
|
||||
return new Promise((resolve, reject) => {
|
||||
this.conn = net.createConnection(this.socketPath, () => {
|
||||
this.readConn();
|
||||
resolve();
|
||||
});
|
||||
this.conn.on('error', reject);
|
||||
});
|
||||
}
|
||||
|
||||
async wait(): Promise<void> {
|
||||
const closePromise = new Promise<void>((resolve) => {
|
||||
this.onClose = () => {
|
||||
if (this.processingCalls === 0) {
|
||||
this.conn?.destroy();
|
||||
resolve();
|
||||
}
|
||||
};
|
||||
if (this.stopRequested && this.processingCalls === 0) {
|
||||
this.conn?.destroy();
|
||||
resolve();
|
||||
}
|
||||
});
|
||||
|
||||
const errorPromise = this.errorQueue.collect().then((errors) => {
|
||||
if (errors.length === 1) {
|
||||
throw errors[0];
|
||||
} else if (errors.length > 1) {
|
||||
throw new Error(errors.map(e => e.toString()).join(', '));
|
||||
}
|
||||
});
|
||||
|
||||
await Promise.race([closePromise, errorPromise]);
|
||||
}
|
||||
}
|
||||
@ -1,46 +1,10 @@
|
||||
import * as net from 'node:net';
|
||||
import * as readline from 'node:readline';
|
||||
import {type ChildProcess, spawn} from 'node:child_process';
|
||||
import * as os from 'node:os';
|
||||
import * as path from 'node:path';
|
||||
import * as fs from 'node:fs';
|
||||
import * as util from 'node:util';
|
||||
import * as crypto from 'node:crypto';
|
||||
import {AsyncQueue} from './asyncqueue.js';
|
||||
import type {CallMessage, CallResult, Message, ResponseMessage, Vals} from './protocol.js';
|
||||
import {MsgType} from './protocol.js';
|
||||
|
||||
const IPC_SOCKET_ARG = 'ipc-socket';
|
||||
|
||||
type JSONSerializable = string | number | boolean;
|
||||
|
||||
enum MsgType {
|
||||
Call = 1,
|
||||
Response = 2,
|
||||
}
|
||||
|
||||
type Vals = any[];
|
||||
|
||||
interface CallMessage {
|
||||
type: MsgType.Call,
|
||||
id: number,
|
||||
method: string;
|
||||
args: Vals;
|
||||
}
|
||||
|
||||
interface ResponseMessage {
|
||||
type: MsgType.Response,
|
||||
id: number,
|
||||
result?: Vals;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
type Message = CallMessage | ResponseMessage;
|
||||
|
||||
interface CallResult {
|
||||
result: Vals;
|
||||
error: Error | null;
|
||||
}
|
||||
|
||||
abstract class IPCCommon {
|
||||
export abstract class IPCCommon {
|
||||
protected localApis: Record<string, any>;
|
||||
protected socketPath: string;
|
||||
protected conn: net.Socket | null = null;
|
||||
@ -75,6 +39,7 @@ abstract class IPCCommon {
|
||||
});
|
||||
|
||||
this.conn.on('close', (hadError: boolean) => {
|
||||
this.rejectPendingCalls(new Error('connection closed'));
|
||||
if (hadError) {
|
||||
this.raiseErr(new Error('connection closed due to error'));
|
||||
}
|
||||
@ -194,7 +159,6 @@ abstract class IPCCommon {
|
||||
}
|
||||
|
||||
public serialize(arg: any): any {
|
||||
// noinspection FallThroughInSwitchStatementJS
|
||||
switch (typeof arg) {
|
||||
case 'string':
|
||||
case 'boolean':
|
||||
@ -212,7 +176,6 @@ abstract class IPCCommon {
|
||||
}
|
||||
|
||||
public deserialize(arg: any): any {
|
||||
// noinspection FallThroughInSwitchStatementJS
|
||||
switch (typeof arg) {
|
||||
case 'string':
|
||||
case 'boolean':
|
||||
@ -248,201 +211,15 @@ abstract class IPCCommon {
|
||||
if (this.onClose) this.onClose();
|
||||
}
|
||||
|
||||
protected rejectPendingCalls(err: Error): void {
|
||||
const pending = this.pendingCalls;
|
||||
this.pendingCalls = {};
|
||||
for (const callback of Object.values(pending)) {
|
||||
callback({result: [], error: err});
|
||||
}
|
||||
}
|
||||
|
||||
protected raiseErr(err: Error): void {
|
||||
this.errorQueue.put(err);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
export class ParentIPC extends IPCCommon {
|
||||
private readonly cmdPath: string;
|
||||
private readonly cmdArgs: string[];
|
||||
private cmd: ChildProcess | null = null;
|
||||
private readonly listener: net.Server;
|
||||
private cmdExitResult: { code: number | null, signal: string | null } | null = null;
|
||||
private cmdExitCallbacks: ((result: { code: number | null, signal: string | null }) => void)[] = [];
|
||||
|
||||
constructor(cmdPath: string, cmdArgs: string[], ...localApis: object[]) {
|
||||
const socketPath = path.join(os.tmpdir(), `kitten-ipc-${ process.pid }-${ crypto.randomInt(2**48 - 1) }.sock`);
|
||||
super(localApis, socketPath);
|
||||
|
||||
this.cmdPath = cmdPath;
|
||||
if (cmdArgs.includes(`--${ IPC_SOCKET_ARG }`)) {
|
||||
throw new Error(`you should not use '--${ IPC_SOCKET_ARG }' argument in your command`);
|
||||
}
|
||||
this.cmdArgs = cmdArgs;
|
||||
|
||||
this.listener = net.createServer();
|
||||
}
|
||||
|
||||
async start(): Promise<void> {
|
||||
try {
|
||||
fs.unlinkSync(this.socketPath);
|
||||
} catch {
|
||||
}
|
||||
|
||||
await new Promise<void>((resolve, reject) => {
|
||||
this.listener.listen(this.socketPath, () => {
|
||||
resolve();
|
||||
});
|
||||
this.listener.on('error', reject);
|
||||
});
|
||||
|
||||
const cmdArgs = [...this.cmdArgs, `--${ IPC_SOCKET_ARG }`, this.socketPath];
|
||||
this.cmd = spawn(this.cmdPath, cmdArgs, {stdio: 'inherit'});
|
||||
|
||||
this.cmd.on('error', (err) => {
|
||||
this.raiseErr(err);
|
||||
});
|
||||
|
||||
this.cmd.on('close', (code, signal) => {
|
||||
const result = { code, signal };
|
||||
this.cmdExitResult = result;
|
||||
for (const cb of this.cmdExitCallbacks) cb(result);
|
||||
this.cmdExitCallbacks = [];
|
||||
});
|
||||
|
||||
await this.acceptConn();
|
||||
}
|
||||
|
||||
private async acceptConn(): Promise<void> {
|
||||
const acceptTimeout = 10000;
|
||||
|
||||
const acceptPromise = new Promise<net.Socket>((resolve, reject) => {
|
||||
this.listener.once('connection', (conn) => {
|
||||
resolve(conn);
|
||||
});
|
||||
this.listener.once('error', reject);
|
||||
});
|
||||
|
||||
const exitPromise = new Promise<net.Socket>((_, reject) => {
|
||||
if (this.cmdExitResult) {
|
||||
reject(new Error(`command exited before connection established`));
|
||||
} else {
|
||||
this.cmdExitCallbacks.push(() => {
|
||||
reject(new Error(`command exited before connection established`));
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
try {
|
||||
this.conn = await timeout(Promise.race([acceptPromise, exitPromise]), acceptTimeout);
|
||||
this.readConn();
|
||||
} catch (e) {
|
||||
if (this.cmd) this.cmd.kill();
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
async wait(): Promise<void> {
|
||||
if (!this.cmd) {
|
||||
throw new Error('Command is not started yet');
|
||||
}
|
||||
|
||||
const exitPromise = new Promise<{ code: number | null, signal: string | null }>((resolve) => {
|
||||
if (this.cmdExitResult) {
|
||||
resolve(this.cmdExitResult);
|
||||
} else {
|
||||
this.cmdExitCallbacks.push(resolve);
|
||||
}
|
||||
});
|
||||
|
||||
try {
|
||||
await Promise.race([
|
||||
exitPromise.then(({ code, signal }) => {
|
||||
if (signal || code) {
|
||||
if (signal) throw new Error(`Process exited with signal ${ signal }`);
|
||||
else throw new Error(`Process exited with code ${ code }`);
|
||||
} else if (!this.ready) {
|
||||
throw new Error('command exited before connection established');
|
||||
}
|
||||
}),
|
||||
this.errorQueue.collect().then((errors) => {
|
||||
if (errors.length === 1) {
|
||||
throw errors[0];
|
||||
} else if (errors.length > 1) {
|
||||
throw new Error(errors.map(e => e.toString()).join(', '));
|
||||
}
|
||||
}),
|
||||
]);
|
||||
} finally {
|
||||
try { fs.unlinkSync(this.socketPath); } catch {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
export class ChildIPC extends IPCCommon {
|
||||
constructor(...localApis: object[]) {
|
||||
super(localApis, socketPathFromArgs());
|
||||
}
|
||||
|
||||
async start(): Promise<void> {
|
||||
return new Promise((resolve, reject) => {
|
||||
this.conn = net.createConnection(this.socketPath, () => {
|
||||
this.readConn();
|
||||
resolve();
|
||||
});
|
||||
this.conn.on('error', reject);
|
||||
});
|
||||
}
|
||||
|
||||
async wait(): Promise<void> {
|
||||
const closePromise = new Promise<void>((resolve) => {
|
||||
this.onClose = () => {
|
||||
if (this.processingCalls === 0) {
|
||||
this.conn?.destroy();
|
||||
resolve();
|
||||
}
|
||||
};
|
||||
if (this.stopRequested && this.processingCalls === 0) {
|
||||
this.conn?.destroy();
|
||||
resolve();
|
||||
}
|
||||
});
|
||||
|
||||
const errorPromise = this.errorQueue.collect().then((errors) => {
|
||||
if (errors.length === 1) {
|
||||
throw errors[0];
|
||||
} else if (errors.length > 1) {
|
||||
throw new Error(errors.map(e => e.toString()).join(', '));
|
||||
}
|
||||
});
|
||||
|
||||
await Promise.race([closePromise, errorPromise]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
function socketPathFromArgs(): string {
|
||||
const {values} = util.parseArgs({
|
||||
options: {
|
||||
[IPC_SOCKET_ARG]: {
|
||||
type: 'string',
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if (!values[IPC_SOCKET_ARG]) {
|
||||
throw new Error('ipc socket path is missing');
|
||||
}
|
||||
|
||||
return values[IPC_SOCKET_ARG];
|
||||
}
|
||||
|
||||
|
||||
function sleep(ms: number): Promise<void> {
|
||||
return new Promise(resolve => setTimeout(resolve, ms));
|
||||
}
|
||||
|
||||
|
||||
// throws on timeout
|
||||
function timeout<T>(prom: Promise<T>, ms: number): Promise<T> {
|
||||
return Promise.race(
|
||||
[
|
||||
prom,
|
||||
new Promise((res, reject) => {
|
||||
setTimeout(() => {reject(new Error('timed out'))}, ms)
|
||||
})]
|
||||
) as Promise<T>;
|
||||
}
|
||||
@ -1 +1,2 @@
|
||||
export {ParentIPC, ChildIPC} from './lib.js';
|
||||
export {ParentIPC} from './parent.js';
|
||||
export {ChildIPC} from './child.js';
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import {test} from 'vitest';
|
||||
import {ParentIPC} from './lib.js';
|
||||
import {ParentIPC} from './parent.js';
|
||||
|
||||
test('test connection timeout', async ({expect}) => {
|
||||
const parentIpc = new ParentIPC('../testdata/sleep15.sh', []);
|
||||
|
||||
126
lib/ts/src/parent.ts
Normal file
126
lib/ts/src/parent.ts
Normal file
@ -0,0 +1,126 @@
|
||||
import * as net from 'node:net';
|
||||
import * as os from 'node:os';
|
||||
import * as path from 'node:path';
|
||||
import * as fs from 'node:fs';
|
||||
import * as crypto from 'node:crypto';
|
||||
import {type ChildProcess, spawn} from 'node:child_process';
|
||||
import {IPCCommon} from './common.js';
|
||||
import {timeout} from './util.js';
|
||||
|
||||
const IPC_SOCKET_ARG = 'ipc-socket';
|
||||
const ACCEPT_TIMEOUT_MS = 10000;
|
||||
|
||||
export class ParentIPC extends IPCCommon {
|
||||
private readonly cmdPath: string;
|
||||
private readonly cmdArgs: string[];
|
||||
private cmd: ChildProcess | null = null;
|
||||
private readonly listener: net.Server;
|
||||
private cmdExitResult: { code: number | null, signal: string | null } | null = null;
|
||||
private cmdExitCallbacks: ((result: { code: number | null, signal: string | null }) => void)[] = [];
|
||||
|
||||
constructor(cmdPath: string, cmdArgs: string[], ...localApis: object[]) {
|
||||
const socketPath = path.join(os.tmpdir(), `kitten-ipc-${ process.pid }-${ crypto.randomInt(2**48 - 1) }.sock`);
|
||||
super(localApis, socketPath);
|
||||
|
||||
this.cmdPath = cmdPath;
|
||||
if (cmdArgs.includes(`--${ IPC_SOCKET_ARG }`)) {
|
||||
throw new Error(`you should not use '--${ IPC_SOCKET_ARG }' argument in your command`);
|
||||
}
|
||||
this.cmdArgs = cmdArgs;
|
||||
|
||||
this.listener = net.createServer();
|
||||
}
|
||||
|
||||
async start(): Promise<void> {
|
||||
try {
|
||||
fs.unlinkSync(this.socketPath);
|
||||
} catch {
|
||||
}
|
||||
|
||||
await new Promise<void>((resolve, reject) => {
|
||||
this.listener.listen(this.socketPath, () => {
|
||||
resolve();
|
||||
});
|
||||
this.listener.on('error', reject);
|
||||
});
|
||||
|
||||
const cmdArgs = [...this.cmdArgs, `--${ IPC_SOCKET_ARG }`, this.socketPath];
|
||||
this.cmd = spawn(this.cmdPath, cmdArgs, {stdio: 'inherit'});
|
||||
|
||||
this.cmd.on('error', (err) => {
|
||||
this.raiseErr(err);
|
||||
});
|
||||
|
||||
this.cmd.on('close', (code, signal) => {
|
||||
const result = { code, signal };
|
||||
this.cmdExitResult = result;
|
||||
for (const cb of this.cmdExitCallbacks) cb(result);
|
||||
this.cmdExitCallbacks = [];
|
||||
});
|
||||
|
||||
await this.acceptConn();
|
||||
}
|
||||
|
||||
private async acceptConn(): Promise<void> {
|
||||
const acceptPromise = new Promise<net.Socket>((resolve, reject) => {
|
||||
this.listener.once('connection', (conn) => {
|
||||
resolve(conn);
|
||||
});
|
||||
this.listener.once('error', reject);
|
||||
});
|
||||
|
||||
const exitPromise = new Promise<net.Socket>((_, reject) => {
|
||||
if (this.cmdExitResult) {
|
||||
reject(new Error(`command exited before connection established`));
|
||||
} else {
|
||||
this.cmdExitCallbacks.push(() => {
|
||||
reject(new Error(`command exited before connection established`));
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
try {
|
||||
this.conn = await timeout(Promise.race([acceptPromise, exitPromise]), ACCEPT_TIMEOUT_MS);
|
||||
this.readConn();
|
||||
} catch (e) {
|
||||
if (this.cmd) this.cmd.kill();
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
async wait(): Promise<void> {
|
||||
if (!this.cmd) {
|
||||
throw new Error('Command is not started yet');
|
||||
}
|
||||
|
||||
const exitPromise = new Promise<{ code: number | null, signal: string | null }>((resolve) => {
|
||||
if (this.cmdExitResult) {
|
||||
resolve(this.cmdExitResult);
|
||||
} else {
|
||||
this.cmdExitCallbacks.push(resolve);
|
||||
}
|
||||
});
|
||||
|
||||
try {
|
||||
await Promise.race([
|
||||
exitPromise.then(({ code, signal }) => {
|
||||
if (signal || code) {
|
||||
if (signal) throw new Error(`Process exited with signal ${ signal }`);
|
||||
else throw new Error(`Process exited with code ${ code }`);
|
||||
} else if (!this.ready) {
|
||||
throw new Error('command exited before connection established');
|
||||
}
|
||||
}),
|
||||
this.errorQueue.collect().then((errors) => {
|
||||
if (errors.length === 1) {
|
||||
throw errors[0];
|
||||
} else if (errors.length > 1) {
|
||||
throw new Error(errors.map(e => e.toString()).join(', '));
|
||||
}
|
||||
}),
|
||||
]);
|
||||
} finally {
|
||||
try { fs.unlinkSync(this.socketPath); } catch {}
|
||||
}
|
||||
}
|
||||
}
|
||||
27
lib/ts/src/protocol.ts
Normal file
27
lib/ts/src/protocol.ts
Normal file
@ -0,0 +1,27 @@
|
||||
export enum MsgType {
|
||||
Call = 1,
|
||||
Response = 2,
|
||||
}
|
||||
|
||||
export type Vals = any[];
|
||||
|
||||
export interface CallMessage {
|
||||
type: MsgType.Call,
|
||||
id: number,
|
||||
method: string;
|
||||
args: Vals;
|
||||
}
|
||||
|
||||
export interface ResponseMessage {
|
||||
type: MsgType.Response,
|
||||
id: number,
|
||||
result?: Vals;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
export type Message = CallMessage | ResponseMessage;
|
||||
|
||||
export interface CallResult {
|
||||
result: Vals;
|
||||
error: Error | null;
|
||||
}
|
||||
68
lib/ts/src/serialize.test.ts
Normal file
68
lib/ts/src/serialize.test.ts
Normal file
@ -0,0 +1,68 @@
|
||||
import {test} from 'vitest';
|
||||
import {ChildIPC} from './child.js';
|
||||
|
||||
// Access serialize/deserialize through a ChildIPC instance's public methods
|
||||
// We create a minimal wrapper to test them
|
||||
class TestableIPC {
|
||||
private ipc: any;
|
||||
|
||||
constructor() {
|
||||
// Access the prototype methods directly
|
||||
this.ipc = Object.create(ChildIPC.prototype);
|
||||
}
|
||||
|
||||
serialize(arg: any): any {
|
||||
return this.ipc.serialize(arg);
|
||||
}
|
||||
|
||||
deserialize(arg: any): any {
|
||||
return this.ipc.deserialize(arg);
|
||||
}
|
||||
}
|
||||
|
||||
test('serialize primitives', ({expect}) => {
|
||||
const t = new TestableIPC();
|
||||
expect(t.serialize(42)).toBe(42);
|
||||
expect(t.serialize('hello')).toBe('hello');
|
||||
expect(t.serialize(true)).toBe(true);
|
||||
expect(t.serialize(3.14)).toBe(3.14);
|
||||
});
|
||||
|
||||
test('serialize buffer to base64', ({expect}) => {
|
||||
const t = new TestableIPC();
|
||||
const buf = Buffer.from([1, 2, 3]);
|
||||
expect(t.serialize(buf)).toBe('AQID');
|
||||
});
|
||||
|
||||
test('deserialize primitives', ({expect}) => {
|
||||
const t = new TestableIPC();
|
||||
expect(t.deserialize(42)).toBe(42);
|
||||
expect(t.deserialize('hello')).toBe('hello');
|
||||
expect(t.deserialize(true)).toBe(true);
|
||||
});
|
||||
|
||||
test('deserialize blob to buffer', ({expect}) => {
|
||||
const t = new TestableIPC();
|
||||
const result = t.deserialize({t: 'blob', d: 'AQID'});
|
||||
expect(Buffer.isBuffer(result)).toBe(true);
|
||||
expect(result).toEqual(Buffer.from([1, 2, 3]));
|
||||
});
|
||||
|
||||
test('serialize then deserialize blob round-trip', ({expect}) => {
|
||||
const t = new TestableIPC();
|
||||
const original = Buffer.from([0xDE, 0xAD, 0xBE, 0xEF]);
|
||||
// Go serializes as {t: 'blob', d: base64}, TS serializes as just base64 string
|
||||
// The TS serialize returns base64 string directly for Buffer
|
||||
const serialized = t.serialize(original);
|
||||
expect(typeof serialized).toBe('string');
|
||||
});
|
||||
|
||||
test('deserialize unknown object throws', ({expect}) => {
|
||||
const t = new TestableIPC();
|
||||
expect(() => t.deserialize({foo: 'bar'})).toThrow('cannot deserialize');
|
||||
});
|
||||
|
||||
test('serialize unsupported object throws', ({expect}) => {
|
||||
const t = new TestableIPC();
|
||||
expect(() => t.serialize({foo: 'bar'})).toThrow('cannot serialize');
|
||||
});
|
||||
29
lib/ts/src/util.ts
Normal file
29
lib/ts/src/util.ts
Normal file
@ -0,0 +1,29 @@
|
||||
import * as util from 'node:util';
|
||||
|
||||
const IPC_SOCKET_ARG = 'ipc-socket';
|
||||
|
||||
export function socketPathFromArgs(): string {
|
||||
const {values} = util.parseArgs({
|
||||
options: {
|
||||
[IPC_SOCKET_ARG]: {
|
||||
type: 'string',
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if (!values[IPC_SOCKET_ARG]) {
|
||||
throw new Error('ipc socket path is missing');
|
||||
}
|
||||
|
||||
return values[IPC_SOCKET_ARG];
|
||||
}
|
||||
|
||||
export function timeout<T>(prom: Promise<T>, ms: number): Promise<T> {
|
||||
return Promise.race(
|
||||
[
|
||||
prom,
|
||||
new Promise((res, reject) => {
|
||||
setTimeout(() => {reject(new Error('timed out'))}, ms)
|
||||
})]
|
||||
) as Promise<T>;
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user