diff --git a/example/golang/simple b/example/golang/simple new file mode 100755 index 0000000..0a4f6b2 Binary files /dev/null and b/example/golang/simple differ diff --git a/kitcom/go.mod b/kitcom/go.mod index 31da92f..6d1ad72 100644 --- a/kitcom/go.mod +++ b/kitcom/go.mod @@ -7,3 +7,10 @@ require ( golang.org/x/sync v0.17.0 golang.org/x/text v0.30.0 ) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/testify v1.11.1 + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/kitcom/go.sum b/kitcom/go.sum index 0243394..c6df700 100644 --- a/kitcom/go.sum +++ b/kitcom/go.sum @@ -1,6 +1,16 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e h1:Lf/gRkoycfOBPa42vU2bbgPurFong6zXeFtPoxholzU= github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e/go.mod h1:uNVvRXArCGbZ508SxYYTC5v1JWoz2voff5pm25jU1Ok= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k= golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/kitcom/internal/api/api.go b/kitcom/internal/api/api.go index e647286..6a8e1fd 100644 --- a/kitcom/internal/api/api.go +++ b/kitcom/internal/api/api.go @@ -1,14 +1,18 @@ package api -import "efprojects.com/kitten-ipc/types" +type ValType string -// todo check TInt size < 64 -// todo check not float +const ( + TNoType ValType = "" + TInt ValType = "int" + TString ValType = "string" + TBool ValType = "bool" + TBlob ValType = "blob" +) type Val struct { - Name string - Type types.ValType - Children []Val + Name string + Type ValType } type Method struct { diff --git a/kitcom/internal/common/write.go b/kitcom/internal/common/write.go new file mode 100644 index 0000000..f3e2560 --- /dev/null +++ b/kitcom/internal/common/write.go @@ -0,0 +1,19 @@ +package common + +import ( + "fmt" + "os" +) + +func WriteFile(destFile string, content []byte) error { + f, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + if err != nil { + return fmt.Errorf("open destination file: %w", err) + } + defer f.Close() + + if _, err := f.Write(content); err != nil { + return fmt.Errorf("write file: %w", err) + } + return nil +} diff --git a/kitcom/internal/golang/gogen.go b/kitcom/internal/golang/gogen.go index 2152263..2b4a297 100644 --- a/kitcom/internal/golang/gogen.go +++ b/kitcom/internal/golang/gogen.go @@ -4,18 +4,15 @@ import ( "bytes" "fmt" "go/format" - "os" + "strings" "text/template" _ "embed" "efprojects.com/kitten-ipc/kitcom/internal/api" - "efprojects.com/kitten-ipc/types" + "efprojects.com/kitten-ipc/kitcom/internal/common" ) -// todo: check int overflow -// todo: check float is whole - //go:embed gogen.tmpl var templateString string @@ -34,43 +31,41 @@ func (g *GoApiGenerator) Generate(apis *api.Api, destFile string) error { Api: apis, } - const defaultReceiver = "self" - tpl := template.New("gogen") tpl = tpl.Funcs(map[string]any{ "receiver": func(name string) string { - return defaultReceiver + return strings.ToLower(name[:1]) }, - "typedef": func(t types.ValType) (string, error) { - td, ok := map[types.ValType]string{ - types.TInt: "int", - types.TString: "string", - types.TBool: "bool", - types.TBlob: "[]byte", + "typedef": func(t api.ValType) (string, error) { + td, ok := map[api.ValType]string{ + api.TInt: "int", + api.TString: "string", + api.TBool: "bool", + api.TBlob: "[]byte", }[t] if !ok { return "", fmt.Errorf("cannot generate type %v", t) } return td, nil }, - "convtype": func(valDef string, t types.ValType) (string, error) { - td, ok := map[types.ValType]string{ - types.TInt: fmt.Sprintf("int(%s.(float64))", valDef), - types.TString: fmt.Sprintf("%s.(string)", valDef), - types.TBool: fmt.Sprintf("%s.(bool)", valDef), - types.TBlob: fmt.Sprintf("%s.([]byte)", valDef), + "convtype": func(valDef string, t api.ValType) (string, error) { + td, ok := map[api.ValType]string{ + api.TInt: fmt.Sprintf("int(%s.(float64))", valDef), + api.TString: fmt.Sprintf("%s.(string)", valDef), + api.TBool: fmt.Sprintf("%s.(bool)", valDef), + api.TBlob: fmt.Sprintf("%s.([]byte)", valDef), }[t] if !ok { return "", fmt.Errorf("cannot convert type %v for val %s", t, valDef) } return td, nil }, - "zerovalue": func(t types.ValType) (string, error) { - v, ok := map[types.ValType]string{ - types.TInt: "0", - types.TString: `""`, - types.TBool: "false", - types.TBlob: "[]byte{}", + "zerovalue": func(t api.ValType) (string, error) { + v, ok := map[api.ValType]string{ + api.TInt: "0", + api.TString: `""`, + api.TBool: "false", + api.TBlob: "[]byte{}", }[t] if !ok { return "", fmt.Errorf("cannot generate zero value for type %v", t) @@ -86,27 +81,14 @@ func (g *GoApiGenerator) Generate(apis *api.Api, destFile string) error { return fmt.Errorf("execute template: %w", err) } - if err := g.writeDest(destFile, buf.Bytes()); err != nil { + formatted, err := format.Source(buf.Bytes()) + if err != nil { + return fmt.Errorf("format source: %w", err) + } + + if err := common.WriteFile(destFile, formatted); err != nil { return fmt.Errorf("write file: %w", err) } return nil } - -func (g *GoApiGenerator) writeDest(destFile string, bytes []byte) error { - f, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) - if err != nil { - return fmt.Errorf("open destination file: %w", err) - } - defer f.Close() - - formatted, err := format.Source(bytes) - if err != nil { - return fmt.Errorf("format source: %w", err) - } - - if _, err := f.Write(formatted); err != nil { - return fmt.Errorf("write formatted source: %w", err) - } - return nil -} diff --git a/kitcom/internal/golang/gogen.tmpl b/kitcom/internal/golang/gogen.tmpl index 0d344b4..f03fdcf 100644 --- a/kitcom/internal/golang/gogen.tmpl +++ b/kitcom/internal/golang/gogen.tmpl @@ -28,7 +28,9 @@ func ({{ $e.Name | receiver }} *{{ $e.Name }}) {{ $mtd.Name }}( if err != nil { return {{ range $mtd.Ret }}{{ .Type | zerovalue }}, {{ end }} fmt.Errorf("call to {{ $e.Name }}.{{ $mtd.Name }} failed: %w", err) } - _ = results + if len(results) < {{ len $mtd.Ret }} { + return {{ range $mtd.Ret }}{{ .Type | zerovalue }}, {{ end }} fmt.Errorf("call to {{ $e.Name }}.{{ $mtd.Name }}: expected {{ len $mtd.Ret }} results, got %d", len(results)) + } {{ range $i, $ret := $mtd.Ret }} {{ if eq $ret.Type "blob" }} results[{{ $i }}], err = base64.StdEncoding.DecodeString(results[{{ $i }}].(string)) diff --git a/kitcom/internal/golang/goparser.go b/kitcom/internal/golang/goparser.go index 44d1258..91306db 100644 --- a/kitcom/internal/golang/goparser.go +++ b/kitcom/internal/golang/goparser.go @@ -9,7 +9,6 @@ import ( "efprojects.com/kitten-ipc/kitcom/internal/api" "efprojects.com/kitten-ipc/kitcom/internal/common" - "efprojects.com/kitten-ipc/types" ) var decorComment = regexp.MustCompile(`^//\s?kittenipc:api$`) @@ -142,11 +141,11 @@ func fieldToVal(param *ast.Field, returning bool) (*api.Val, error) { case *ast.Ident: switch paramType.Name { case "int": - val.Type = types.TInt + val.Type = api.TInt case "string": - val.Type = types.TString + val.Type = api.TString case "bool": - val.Type = types.TBool + val.Type = api.TBool case "error": if returning { return nil, nil @@ -161,7 +160,7 @@ func fieldToVal(param *ast.Field, returning bool) (*api.Val, error) { case *ast.Ident: switch elementType.Name { case "byte": - val.Type = types.TBlob + val.Type = api.TBlob default: return nil, fmt.Errorf("parameter type %s is not supported yet", elementType.Name) } diff --git a/kitcom/internal/golang/goparser_test.go b/kitcom/internal/golang/goparser_test.go new file mode 100644 index 0000000..e5429cf --- /dev/null +++ b/kitcom/internal/golang/goparser_test.go @@ -0,0 +1,43 @@ +package golang + +import ( + "testing" + + "efprojects.com/kitten-ipc/kitcom/internal/api" + "efprojects.com/kitten-ipc/kitcom/internal/common" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGoParser(t *testing.T) { + parser := &GoApiParser{Parser: &common.Parser{}} + parser.AddFile("../../../example/golang/main.go") + + result, err := parser.Parse() + require.NoError(t, err) + require.Len(t, result.Endpoints, 1) + + ep := result.Endpoints[0] + assert.Equal(t, "GoIpcApi", ep.Name) + require.Len(t, ep.Methods, 2) + + // Div method + div := ep.Methods[0] + assert.Equal(t, "Div", div.Name) + require.Len(t, div.Params, 2) + assert.Equal(t, api.TInt, div.Params[0].Type) + assert.Equal(t, "a", div.Params[0].Name) + assert.Equal(t, api.TInt, div.Params[1].Type) + assert.Equal(t, "b", div.Params[1].Name) + require.Len(t, div.Ret, 1) + assert.Equal(t, api.TInt, div.Ret[0].Type) + + // XorData method + xor := ep.Methods[1] + assert.Equal(t, "XorData", xor.Name) + require.Len(t, xor.Params, 2) + assert.Equal(t, api.TBlob, xor.Params[0].Type) + assert.Equal(t, api.TBlob, xor.Params[1].Type) + require.Len(t, xor.Ret, 1) + assert.Equal(t, api.TBlob, xor.Ret[0].Type) +} diff --git a/kitcom/internal/ts/tsgen.go b/kitcom/internal/ts/tsgen.go index 7dbb932..3bea749 100644 --- a/kitcom/internal/ts/tsgen.go +++ b/kitcom/internal/ts/tsgen.go @@ -4,14 +4,13 @@ import ( "bytes" "fmt" "log" - "os" "os/exec" "text/template" _ "embed" "efprojects.com/kitten-ipc/kitcom/internal/api" - "efprojects.com/kitten-ipc/types" + "efprojects.com/kitten-ipc/kitcom/internal/common" ) //go:embed tsgen.tmpl @@ -31,24 +30,24 @@ func (g *TypescriptApiGenerator) Generate(apis *api.Api, destFile string) error tpl := template.New("tsgen") tpl = tpl.Funcs(map[string]any{ - "typedef": func(t types.ValType) (string, error) { - td, ok := map[types.ValType]string{ - types.TInt: "number", - types.TString: "string", - types.TBool: "boolean", - types.TBlob: "Buffer", + "typedef": func(t api.ValType) (string, error) { + td, ok := map[api.ValType]string{ + api.TInt: "number", + api.TString: "string", + api.TBool: "boolean", + api.TBlob: "Buffer", }[t] if !ok { return "", fmt.Errorf("cannot generate type %v", t) } return td, nil }, - "convtype": func(valDef string, t types.ValType) (string, error) { - td, ok := map[types.ValType]string{ - types.TInt: fmt.Sprintf("%s as number", valDef), - types.TString: fmt.Sprintf("%s as string", valDef), - types.TBool: fmt.Sprintf("%s as boolean", valDef), - types.TBlob: fmt.Sprintf("Buffer.from(%s, 'base64')", valDef), + "convtype": func(valDef string, t api.ValType) (string, error) { + td, ok := map[api.ValType]string{ + api.TInt: fmt.Sprintf("%s as number", valDef), + api.TString: fmt.Sprintf("%s as string", valDef), + api.TBool: fmt.Sprintf("%s as boolean", valDef), + api.TBlob: fmt.Sprintf("Buffer.from(%s, 'base64')", valDef), }[t] if !ok { return "", fmt.Errorf("cannot convert type %v for val %s", t, valDef) @@ -64,24 +63,10 @@ func (g *TypescriptApiGenerator) Generate(apis *api.Api, destFile string) error return fmt.Errorf("execute template: %w", err) } - if err := g.writeDest(destFile, buf.Bytes()); err != nil { + if err := common.WriteFile(destFile, buf.Bytes()); err != nil { return fmt.Errorf("write file: %w", err) } - return nil -} - -func (g *TypescriptApiGenerator) writeDest(destFile string, bytes []byte) error { - f, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) - if err != nil { - return fmt.Errorf("open destination file: %w", err) - } - defer f.Close() - - if _, err := f.Write(bytes); err != nil { - return fmt.Errorf("write formatted source: %w", err) - } - prettierCmd := exec.Command("npx", "prettier", destFile, "--write") if out, err := prettierCmd.CombinedOutput(); err != nil { log.Printf("Prettier returned error: %v", err) diff --git a/kitcom/internal/ts/tsparser.go b/kitcom/internal/ts/tsparser.go index 5c94724..871aaa2 100644 --- a/kitcom/internal/ts/tsparser.go +++ b/kitcom/internal/ts/tsparser.go @@ -12,7 +12,6 @@ import ( "efprojects.com/kitten-ipc/kitcom/internal/tsgo/core" "efprojects.com/kitten-ipc/kitcom/internal/tsgo/parser" "efprojects.com/kitten-ipc/kitcom/internal/tsgo/tspath" - "efprojects.com/kitten-ipc/types" ) type TypescriptApiParser struct { @@ -109,25 +108,25 @@ func (p *TypescriptApiParser) parseFile(sourceFilePath string) ([]api.Endpoint, return endpoints, nil } -func (p *TypescriptApiParser) fieldToVal(typ *ast.TypeNode) (types.ValType, error) { +func (p *TypescriptApiParser) fieldToVal(typ *ast.TypeNode) (api.ValType, error) { switch typ.Kind { case ast.KindNumberKeyword: - return types.TInt, nil + return api.TInt, nil case ast.KindStringKeyword: - return types.TString, nil + return api.TString, nil case ast.KindBooleanKeyword: - return types.TBool, nil + return api.TBool, nil case ast.KindTypeReference: refNode := typ.AsTypeReferenceNode() ident := refNode.TypeName.AsIdentifier() switch ident.Text { case "Buffer": - return types.TBlob, nil + return api.TBlob, nil default: - return types.TNoType, fmt.Errorf("reference type %s is not supported yet", ident.Text) + return api.TNoType, fmt.Errorf("reference type %s is not supported yet", ident.Text) } default: - return types.TNoType, fmt.Errorf("type %s is not supported yet", typ.Kind) + return api.TNoType, fmt.Errorf("type %s is not supported yet", typ.Kind) } } diff --git a/kitcom/internal/ts/tsparser_test.go b/kitcom/internal/ts/tsparser_test.go new file mode 100644 index 0000000..20b089f --- /dev/null +++ b/kitcom/internal/ts/tsparser_test.go @@ -0,0 +1,46 @@ +package ts + +import ( + "path/filepath" + "testing" + + "efprojects.com/kitten-ipc/kitcom/internal/api" + "efprojects.com/kitten-ipc/kitcom/internal/common" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTsParser(t *testing.T) { + parser := &TypescriptApiParser{Parser: &common.Parser{}} + absPath, err := filepath.Abs("../../../example/ts/src/index.ts") + require.NoError(t, err) + parser.AddFile(absPath) + + result, err := parser.Parse() + require.NoError(t, err) + require.Len(t, result.Endpoints, 1) + + ep := result.Endpoints[0] + assert.Equal(t, "TsIpcApi", ep.Name) + require.Len(t, ep.Methods, 2) + + // Div method + div := ep.Methods[0] + assert.Equal(t, "Div", div.Name) + require.Len(t, div.Params, 2) + assert.Equal(t, api.TInt, div.Params[0].Type) + assert.Equal(t, "a", div.Params[0].Name) + assert.Equal(t, api.TInt, div.Params[1].Type) + assert.Equal(t, "b", div.Params[1].Name) + require.Len(t, div.Ret, 1) + assert.Equal(t, api.TInt, div.Ret[0].Type) + + // XorData method + xor := ep.Methods[1] + assert.Equal(t, "XorData", xor.Name) + require.Len(t, xor.Params, 2) + assert.Equal(t, api.TBlob, xor.Params[0].Type) + assert.Equal(t, api.TBlob, xor.Params[1].Type) + require.Len(t, xor.Ret, 1) + assert.Equal(t, api.TBlob, xor.Ret[0].Type) +} diff --git a/lib/golang/child.go b/lib/golang/child.go new file mode 100644 index 0000000..bfbb878 --- /dev/null +++ b/lib/golang/child.go @@ -0,0 +1,60 @@ +package golang + +import ( + "context" + "fmt" + "net" + "os" +) + +type ChildIPC struct { + *ipcCommon +} + +func NewChild(localApis ...any) (*ChildIPC, error) { + c := ChildIPC{ + ipcCommon: &ipcCommon{ + localApis: mapTypeNames(localApis), + pendingCalls: make(map[int64]*pendingCall), + errCh: make(chan error, 1), + ctx: context.Background(), + }, + } + + socketPath := socketPathFromArgs() + if socketPath == "" { + return nil, fmt.Errorf("ipc socket path is missing") + } + c.socketPath = socketPath + + return &c, nil +} + +func (c *ChildIPC) Start() error { + conn, err := net.Dial("unix", c.socketPath) + if err != nil { + return fmt.Errorf("connect to parent socket: %w", err) + } + c.conn = conn + go c.readConn() + return nil +} + +func (c *ChildIPC) Wait() error { + err := <-c.errCh + if err != nil { + return fmt.Errorf("ipc error: %w", err) + } + return nil +} + +// socketPathFromArgs parses --ipc-socket from os.Args without calling flag.Parse(), +// which would interfere with the host application's flag handling. +func socketPathFromArgs() string { + for i, arg := range os.Args { + if arg == ipcSocketArg && i+1 < len(os.Args) { + return os.Args[i+1] + } + } + return "" +} diff --git a/lib/golang/common.go b/lib/golang/common.go new file mode 100644 index 0000000..ca9dd6e --- /dev/null +++ b/lib/golang/common.go @@ -0,0 +1,261 @@ +package golang + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "net" + "reflect" + "strings" + "sync" + "sync/atomic" +) + +type IpcCommon interface { + Call(method string, params ...any) (Vals, error) + ConvType(needType, gotType reflect.Type, arg any) any +} + +type callResult struct { + vals Vals + err error +} + +type pendingCall struct { + resultChan chan callResult +} + +type ipcCommon struct { + localApis map[string]any + socketPath string + conn net.Conn + errCh chan error + nextId int64 + pendingCalls map[int64]*pendingCall + processingCalls atomic.Int64 + stopRequested atomic.Bool + mu sync.Mutex + writeMu sync.Mutex + ctx context.Context +} + +func (ipc *ipcCommon) readConn() { + scn := bufio.NewScanner(ipc.conn) + scn.Buffer(nil, maxMessageLength) + for scn.Scan() { + var msg Message + msgBytes := scn.Bytes() + if err := json.Unmarshal(msgBytes, &msg); err != nil { + ipc.raiseErr(fmt.Errorf("unmarshal message: %w", err)) + break + } + ipc.processMsg(msg) + } + if err := scn.Err(); err != nil { + ipc.raiseErr(err) + } +} + +func (ipc *ipcCommon) processMsg(msg Message) { + switch msg.Type { + case MsgCall: + go ipc.handleCall(msg) + case MsgResponse: + ipc.handleResponse(msg) + } +} + +func (ipc *ipcCommon) sendMsg(msg Message) error { + data, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("marshal message: %w", err) + } + data = append(data, '\n') + + ipc.writeMu.Lock() + _, writeErr := ipc.conn.Write(data) + ipc.writeMu.Unlock() + if writeErr != nil { + return fmt.Errorf("write message: %w", writeErr) + } + + return nil +} + +func (ipc *ipcCommon) handleCall(msg Message) { + if ipc.stopRequested.Load() { + return + } + + ipc.processingCalls.Add(1) + defer ipc.processingCalls.Add(-1) + + defer func() { + if err := recover(); err != nil { + ipc.sendResponse(msg.Id, nil, fmt.Errorf("handle call panicked: %s", err)) + } + }() + + method, err := ipc.findMethod(msg.Method) + if err != nil { + ipc.sendResponse(msg.Id, nil, fmt.Errorf("find method: %w", err)) + return + } + + argsCount := method.Type().NumIn() + if len(msg.Args) != argsCount { + ipc.sendResponse(msg.Id, nil, fmt.Errorf("args count mismatch: expected %d, got %d", argsCount, len(msg.Args))) + return + } + + var args []reflect.Value + for i, arg := range msg.Args { + paramType := method.Type().In(i) + argType := reflect.TypeOf(arg) + arg = ipc.ConvType(paramType, argType, arg) + args = append(args, reflect.ValueOf(arg)) + } + + allResultVals := method.Call(args) + retResultVals := allResultVals[0 : len(allResultVals)-1] + errResultVals := allResultVals[len(allResultVals)-1] + + var results []any + for _, resVal := range retResultVals { + results = append(results, resVal.Interface()) + } + + var resErr error + if !errResultVals.IsNil() { + resErr = errResultVals.Interface().(error) + } + + ipc.sendResponse(msg.Id, results, resErr) +} + +func (ipc *ipcCommon) findMethod(methodName string) (reflect.Value, error) { + parts := strings.Split(methodName, ".") + if len(parts) != 2 { + return reflect.Value{}, fmt.Errorf("invalid method: %s", methodName) + } + + endpointName, methodName := parts[0], parts[1] + + localApi, ok := ipc.localApis[endpointName] + if !ok { + return reflect.Value{}, fmt.Errorf("endpoint not found: %s", endpointName) + } + + method := reflect.ValueOf(localApi).MethodByName(methodName) + if !method.IsValid() { + return reflect.Value{}, fmt.Errorf("method not found: %s", methodName) + } + + return method, nil +} + +func (ipc *ipcCommon) sendResponse(id int64, result []any, err error) { + msg := Message{ + Type: MsgResponse, + Id: id, + Result: result, + } + + if err != nil { + msg.Error = err.Error() + } + + if err := ipc.sendMsg(msg); err != nil { + ipc.raiseErr(fmt.Errorf("send response for id=%d: %w", id, err)) + } +} + +func (ipc *ipcCommon) handleResponse(msg Message) { + ipc.mu.Lock() + call, ok := ipc.pendingCalls[msg.Id] + if ok { + delete(ipc.pendingCalls, msg.Id) + } + ipc.mu.Unlock() + + if !ok { + ipc.raiseErr(fmt.Errorf("received response for unknown call id: %d", msg.Id)) + return + } + + var res callResult + if msg.Error == "" { + res = callResult{vals: msg.Result} + } else { + res = callResult{err: fmt.Errorf("remote error: %s", msg.Error)} + } + call.resultChan <- res + close(call.resultChan) +} + +func (ipc *ipcCommon) Call(method string, params ...any) (Vals, error) { + if ipc.conn == nil { + return nil, fmt.Errorf("ipc is not connected to remote process socket") + } + + if ipc.stopRequested.Load() { + return nil, fmt.Errorf("ipc is stopping") + } + + ipc.mu.Lock() + id := ipc.nextId + ipc.nextId++ + call := &pendingCall{ + resultChan: make(chan callResult, 1), + } + ipc.pendingCalls[id] = call + ipc.mu.Unlock() + + for i := range params { + params[i] = ipc.serialize(params[i]) + } + + msg := Message{ + Type: MsgCall, + Id: id, + Method: method, + Args: params, + } + + if err := ipc.sendMsg(msg); err != nil { + ipc.mu.Lock() + delete(ipc.pendingCalls, id) + ipc.mu.Unlock() + return nil, fmt.Errorf("send call: %w", err) + } + + select { + case result := <-call.resultChan: + return result.vals, result.err + case <-ipc.ctx.Done(): + ipc.mu.Lock() + delete(ipc.pendingCalls, id) + ipc.mu.Unlock() + return nil, ipc.ctx.Err() + } +} + +func (ipc *ipcCommon) raiseErr(err error) { + select { + case ipc.errCh <- err: + default: + } +} + +func (ipc *ipcCommon) closeConn() { + _ = ipc.conn.Close() + ipc.mu.Lock() + pending := ipc.pendingCalls + ipc.pendingCalls = make(map[int64]*pendingCall) + ipc.mu.Unlock() + for _, call := range pending { + call.resultChan <- callResult{err: fmt.Errorf("call cancelled due to ipc termination")} + close(call.resultChan) + } +} diff --git a/lib/golang/common_test.go b/lib/golang/common_test.go new file mode 100644 index 0000000..0ea21bc --- /dev/null +++ b/lib/golang/common_test.go @@ -0,0 +1,43 @@ +package golang + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +type testEndpoint struct{} + +func (e *testEndpoint) Hello(name string) (string, error) { + return "hello " + name, nil +} + +func TestFindMethod(t *testing.T) { + ipc := &ipcCommon{ + localApis: mapTypeNames([]any{&testEndpoint{}}), + } + + t.Run("valid method", func(t *testing.T) { + method, err := ipc.findMethod("testEndpoint.Hello") + assert.NoError(t, err) + assert.True(t, method.IsValid()) + }) + + t.Run("invalid format - no dot", func(t *testing.T) { + _, err := ipc.findMethod("Hello") + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid method") + }) + + t.Run("unknown endpoint", func(t *testing.T) { + _, err := ipc.findMethod("Unknown.Hello") + assert.Error(t, err) + assert.Contains(t, err.Error(), "endpoint not found") + }) + + t.Run("unknown method", func(t *testing.T) { + _, err := ipc.findMethod("testEndpoint.Unknown") + assert.Error(t, err) + assert.Contains(t, err.Error(), "method not found") + }) +} diff --git a/lib/golang/go.mod b/lib/golang/go.mod index e3ee602..fad8339 100644 --- a/lib/golang/go.mod +++ b/lib/golang/go.mod @@ -2,4 +2,10 @@ module efprojects.com/kitten-ipc go 1.25.1 -require github.com/samber/mo v1.16.0 +require github.com/stretchr/testify v1.11.1 + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/lib/golang/go.sum b/lib/golang/go.sum index e83beb0..c4c1710 100644 --- a/lib/golang/go.sum +++ b/lib/golang/go.sum @@ -1,2 +1,10 @@ -github.com/samber/mo v1.16.0 h1:qpEPCI63ou6wXlsNDMLE0IIN8A+devbGX/K1xdgr4b4= -github.com/samber/mo v1.16.0/go.mod h1:DlgzJ4SYhOh41nP1L9kh9rDNERuf8IqWSAs+gj2Vxag= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/lib/golang/lib.go b/lib/golang/lib.go deleted file mode 100644 index bf1300d..0000000 --- a/lib/golang/lib.go +++ /dev/null @@ -1,550 +0,0 @@ -package golang - -import ( - "bufio" - "context" - "encoding/base64" - "encoding/json" - "errors" - "flag" - "fmt" - "math/rand" - "net" - "os" - "os/exec" - "path/filepath" - "reflect" - "slices" - "strings" - "sync" - "sync/atomic" - "syscall" - "time" - - "efprojects.com/kitten-ipc/types" - "github.com/samber/mo" -) - -const ipcSocketArg = "--ipc-socket" -const maxMessageLength = 1 * 1024 * 1024 * 1024 // 1 gigabyte - -type StdioMode int - -type MsgType int - -type Vals []any - -const ( - MsgCall MsgType = 1 - MsgResponse MsgType = 2 -) - -type Message struct { - Type MsgType `json:"type"` - Id int64 `json:"id"` - Method string `json:"method"` - Args Vals `json:"args"` - Result Vals `json:"result"` - Error string `json:"error"` -} - -type IpcCommon interface { - Call(method string, params ...any) (Vals, error) - ConvType(needType reflect.Type, gotType reflect.Type, arg any) any -} - -type pendingCall struct { - resultChan chan mo.Result[Vals] - resultType reflect.Type -} - -type ipcCommon struct { - localApis map[string]any - socketPath string - conn net.Conn - errCh chan error - nextId int64 - pendingCalls map[int64]*pendingCall - processingCalls atomic.Int64 - stopRequested atomic.Bool - mu sync.Mutex - writeMu sync.Mutex - ctx context.Context -} - -func (ipc *ipcCommon) readConn() { - scn := bufio.NewScanner(ipc.conn) - scn.Buffer(nil, maxMessageLength) - for scn.Scan() { - var msg Message - msgBytes := scn.Bytes() - //log.Println(string(msgBytes)) - if err := json.Unmarshal(msgBytes, &msg); err != nil { - ipc.raiseErr(fmt.Errorf("unmarshal message: %w", err)) - break - } - ipc.processMsg(msg) - } - if err := scn.Err(); err != nil { - ipc.raiseErr(err) - } -} - -func (ipc *ipcCommon) processMsg(msg Message) { - switch msg.Type { - case MsgCall: - go ipc.handleCall(msg) - case MsgResponse: - ipc.handleResponse(msg) - } -} - -func (ipc *ipcCommon) sendMsg(msg Message) error { - - data, err := json.Marshal(msg) - if err != nil { - return fmt.Errorf("marshal message: %w", err) - } - data = append(data, '\n') - - ipc.writeMu.Lock() - _, writeErr := ipc.conn.Write(data) - ipc.writeMu.Unlock() - if writeErr != nil { - return fmt.Errorf("write message: %w", writeErr) - } - - return nil -} - -func (ipc *ipcCommon) handleCall(msg Message) { - if ipc.stopRequested.Load() { - return - } - - ipc.processingCalls.Add(1) - defer ipc.processingCalls.Add(-1) - - defer func() { - if err := recover(); err != nil { - ipc.sendResponse(msg.Id, nil, fmt.Errorf("handle call panicked: %s", err)) - } - }() - - method, err := ipc.findMethod(msg.Method) - if err != nil { - ipc.sendResponse(msg.Id, nil, fmt.Errorf("find method: %w", err)) - return - } - - argsCount := method.Type().NumIn() - if len(msg.Args) != argsCount { - ipc.sendResponse(msg.Id, nil, fmt.Errorf("args count mismatch: expected %d, got %d", argsCount, len(msg.Args))) - return - } - - var args []reflect.Value - for i, arg := range msg.Args { - paramType := method.Type().In(i) - argType := reflect.TypeOf(arg) - arg = ipc.ConvType(paramType, argType, arg) - args = append(args, reflect.ValueOf(arg)) - } - - allResultVals := method.Call(args) - retResultVals := allResultVals[0 : len(allResultVals)-1] - errResultVals := allResultVals[len(allResultVals)-1] - - var results []any - for _, resVal := range retResultVals { - results = append(results, resVal.Interface()) - } - - var resErr error - if !errResultVals.IsNil() { - resErr = errResultVals.Interface().(error) - } - - ipc.sendResponse(msg.Id, results, resErr) -} - -func (ipc *ipcCommon) findMethod(methodName string) (reflect.Value, error) { - parts := strings.Split(methodName, ".") - if len(parts) != 2 { - return reflect.Value{}, fmt.Errorf("invalid method: %s", methodName) - } - - endpointName, methodName := parts[0], parts[1] - - localApi, ok := ipc.localApis[endpointName] - if !ok { - return reflect.Value{}, fmt.Errorf("endpoint not found: %s", endpointName) - } - - method := reflect.ValueOf(localApi).MethodByName(methodName) - if !method.IsValid() { - return reflect.Value{}, fmt.Errorf("method not found: %s", methodName) - } - - return method, nil -} - -func (ipc *ipcCommon) sendResponse(id int64, result []any, err error) { - msg := Message{ - Type: MsgResponse, - Id: id, - Result: result, - } - - if err != nil { - msg.Error = err.Error() - } - - if err := ipc.sendMsg(msg); err != nil { - ipc.raiseErr(fmt.Errorf("send response for id=%d: %w", id, err)) - } -} - -func (ipc *ipcCommon) handleResponse(msg Message) { - ipc.mu.Lock() - call, ok := ipc.pendingCalls[msg.Id] - if ok { - delete(ipc.pendingCalls, msg.Id) - } - ipc.mu.Unlock() - - if !ok { - ipc.raiseErr(fmt.Errorf("received response for unknown call id: %d", msg.Id)) - return - } - - var res mo.Result[Vals] - if msg.Error == "" { - res = mo.Ok[Vals](msg.Result) - } else { - res = mo.Err[Vals](fmt.Errorf("remote error: %s", msg.Error)) - } - call.resultChan <- res - close(call.resultChan) -} - -func (ipc *ipcCommon) Call(method string, params ...any) (Vals, error) { - if ipc.conn == nil { - return nil, fmt.Errorf("ipc is not connected to remote process socket") - } - - if ipc.stopRequested.Load() { - return nil, fmt.Errorf("ipc is stopping") - } - - ipc.mu.Lock() - id := ipc.nextId - ipc.nextId++ - call := &pendingCall{ - resultChan: make(chan mo.Result[Vals], 1), - } - ipc.pendingCalls[id] = call - ipc.mu.Unlock() - - for i := range params { - params[i] = ipc.serialize(params[i]) - } - - msg := Message{ - Type: MsgCall, - Id: id, - Method: method, - Args: params, - } - - if err := ipc.sendMsg(msg); err != nil { - ipc.mu.Lock() - delete(ipc.pendingCalls, id) - ipc.mu.Unlock() - return nil, fmt.Errorf("send call: %w", err) - } - - select { - case result := <-call.resultChan: - return result.Get() - case <-ipc.ctx.Done(): - ipc.mu.Lock() - delete(ipc.pendingCalls, id) - ipc.mu.Unlock() - return nil, ipc.ctx.Err() - } -} - -func (ipc *ipcCommon) raiseErr(err error) { - select { - case ipc.errCh <- err: - default: - } -} - -func (ipc *ipcCommon) closeConn() { - _ = ipc.conn.Close() - ipc.mu.Lock() - pending := ipc.pendingCalls - ipc.pendingCalls = make(map[int64]*pendingCall) - ipc.mu.Unlock() - for _, call := range pending { - call.resultChan <- mo.Err[Vals](fmt.Errorf("call cancelled due to ipc termination")) - close(call.resultChan) - } -} - -func (ipc *ipcCommon) ConvType(needType reflect.Type, gotType reflect.Type, arg any) any { - switch needType.Kind() { - case reflect.Int: - // JSON decodes any number to float64. If we need int, we should check and convert - if gotType.Kind() == reflect.Float64 { - floatArg := arg.(float64) - if float64(int64(floatArg)) == floatArg && !needType.OverflowInt(int64(floatArg)) { - arg = int(floatArg) - } - } - } - return arg -} - -func (ipc *ipcCommon) serialize(arg any) any { - t := reflect.TypeOf(arg) - switch t.Kind() { - case reflect.Slice: - switch t.Elem().Name() { - case "uint8": - return map[string]any{ - "t": types.TBlob, - "d": base64.StdEncoding.EncodeToString(arg.([]byte)), - } - } - } - return arg -} - -type ParentIPC struct { - *ipcCommon - cmd *exec.Cmd - listener net.Listener - cmdDone chan struct{} - cmdErr error -} - -func NewParent(cmd *exec.Cmd, localApis ...any) (*ParentIPC, error) { - return NewParentWithContext(context.Background(), cmd, localApis...) -} - -func NewParentWithContext(ctx context.Context, cmd *exec.Cmd, localApis ...any) (*ParentIPC, error) { - p := ParentIPC{ - ipcCommon: &ipcCommon{ - localApis: mapTypeNames(localApis), - pendingCalls: make(map[int64]*pendingCall), - errCh: make(chan error, 1), - socketPath: filepath.Join(os.TempDir(), fmt.Sprintf("kitten-ipc-%d-%d.sock", os.Getpid(), rand.Int63())), - ctx: ctx, - }, - cmd: cmd, - } - - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - - if slices.Contains(cmd.Args, ipcSocketArg) { - return nil, fmt.Errorf("you should not use `%s` argument in your command", ipcSocketArg) - } - cmd.Args = append(cmd.Args, ipcSocketArg, p.socketPath) - - p.errCh = make(chan error, 1) - p.cmdDone = make(chan struct{}) - - return &p, nil -} - -func (p *ParentIPC) Start() error { - _ = os.Remove(p.socketPath) - listener, err := net.Listen("unix", p.socketPath) - if err != nil { - return fmt.Errorf("listen unix socket: %w", err) - } - p.listener = listener - defer p.listener.Close() - - err = p.cmd.Start() - if err != nil { - return fmt.Errorf("cmd start: %w", err) - } - - go func() { - p.cmdErr = p.cmd.Wait() - close(p.cmdDone) - }() - - return p.acceptConn() -} - -func (p *ParentIPC) acceptConn() error { - const acceptTimeout = time.Second * 10 - - res := make(chan mo.Result[net.Conn], 1) - go func() { - conn, err := p.listener.Accept() - if err != nil { - res <- mo.Err[net.Conn](err) - } else { - res <- mo.Ok[net.Conn](conn) - } - close(res) - }() - - select { - case <-time.After(acceptTimeout): - _ = p.cmd.Process.Kill() - return fmt.Errorf("accept timeout") - case <-p.cmdDone: - return fmt.Errorf("cmd exited before accepting connection: %w", p.cmdErr) - case res := <-res: - if res.IsError() { - _ = p.cmd.Process.Kill() - return fmt.Errorf("accept: %w", res.Error()) - } - p.conn = res.MustGet() - go p.readConn() - } - return nil -} - -func (p *ParentIPC) Stop() error { - p.mu.Lock() - hasPending := len(p.pendingCalls) > 0 - p.mu.Unlock() - if hasPending { - return fmt.Errorf("there are calls pending") - } - if p.processingCalls.Load() > 0 { - return fmt.Errorf("there are calls processing") - } - p.stopRequested.Store(true) - if err := p.cmd.Process.Signal(syscall.SIGINT); err != nil { - return fmt.Errorf("send SIGINT: %w", err) - } - return p.Wait() -} - -func (p *ParentIPC) Wait(timeout ...time.Duration) (retErr error) { - const defaultTimeout = time.Duration(1<<63 - 1) // max duration in go - _timeout := variadicToOption(timeout).OrElse(defaultTimeout) - -loop: - for { - select { - case err := <-p.errCh: - retErr = mergeErr(retErr, fmt.Errorf("ipc internal error: %w", err)) - break loop - case <-p.cmdDone: - err := p.cmdErr - if err != nil { - var exitErr *exec.ExitError - if ok := errors.As(err, &exitErr); ok { - if !exitErr.Success() { - ws, ok := exitErr.Sys().(syscall.WaitStatus) - if !(ok && ws.Signaled() && ws.Signal() == syscall.SIGINT && p.stopRequested.Load()) { - retErr = mergeErr(retErr, fmt.Errorf("cmd wait: %w", err)) - } - } - } else { - retErr = mergeErr(retErr, fmt.Errorf("cmd wait: %w", err)) - } - } - break loop - case <-time.After(_timeout): - p.stopRequested.Store(true) - if err := p.cmd.Process.Signal(syscall.SIGINT); err != nil { - retErr = mergeErr(retErr, fmt.Errorf("send SIGINT: %w", err)) - } - } - } - - p.closeConn() - _ = os.Remove(p.socketPath) - - return retErr -} - -type ChildIPC struct { - *ipcCommon -} - -func NewChild(localApis ...any) (*ChildIPC, error) { - c := ChildIPC{ - ipcCommon: &ipcCommon{ - localApis: mapTypeNames(localApis), - pendingCalls: make(map[int64]*pendingCall), - errCh: make(chan error, 1), - ctx: context.Background(), // todo: NewChildWithContext - }, - } - - socketPath := flag.String("ipc-socket", "", "Path to IPC socket") - flag.Parse() - - if *socketPath == "" { - return nil, fmt.Errorf("ipc socket path is missing") - } - c.socketPath = *socketPath - - return &c, nil -} - -func (c *ChildIPC) Start() error { - conn, err := net.Dial("unix", c.socketPath) - if err != nil { - return fmt.Errorf("connect to parent socket: %w", err) - } - c.conn = conn - go c.readConn() - return nil -} - -func (c *ChildIPC) Wait() error { - err := <-c.errCh - if err != nil { - return fmt.Errorf("ipc error: %w", err) - } - return nil -} - -func mergeErr(errs ...error) (ret error) { - for _, err := range errs { - if err != nil { - if ret == nil { - ret = err - } else { - ret = fmt.Errorf("%w; %w", ret, err) - } - } - } - return -} - -func mapTypeNames(types []any) map[string]any { - result := make(map[string]any) - for _, t := range types { - if reflect.TypeOf(t).Kind() != reflect.Pointer { - panic(fmt.Sprintf("LocalAPI argument must be pointer")) - } - typeName := reflect.TypeOf(t).Elem().Name() - result[typeName] = t - } - return result -} - -func variadicToOption[T any](variadic []T) mo.Option[T] { - if len(variadic) >= 2 { - panic("variadic param count must be 0 or 1") - } - if len(variadic) == 0 { - return mo.None[T]() - } - return mo.Some(variadic[0]) -} diff --git a/lib/golang/parent.go b/lib/golang/parent.go new file mode 100644 index 0000000..e71afe1 --- /dev/null +++ b/lib/golang/parent.go @@ -0,0 +1,165 @@ +package golang + +import ( + "context" + "errors" + "fmt" + "math/rand" + "net" + "os" + "os/exec" + "path/filepath" + "slices" + "syscall" + "time" +) + +type ParentIPC struct { + *ipcCommon + cmd *exec.Cmd + listener net.Listener + cmdDone chan struct{} + cmdErr error +} + +func NewParent(cmd *exec.Cmd, localApis ...any) (*ParentIPC, error) { + return NewParentWithContext(context.Background(), cmd, localApis...) +} + +func NewParentWithContext(ctx context.Context, cmd *exec.Cmd, localApis ...any) (*ParentIPC, error) { + p := ParentIPC{ + ipcCommon: &ipcCommon{ + localApis: mapTypeNames(localApis), + pendingCalls: make(map[int64]*pendingCall), + errCh: make(chan error, 1), + socketPath: filepath.Join(os.TempDir(), fmt.Sprintf("kitten-ipc-%d-%d.sock", os.Getpid(), rand.Int63())), + ctx: ctx, + }, + cmd: cmd, + } + + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + if slices.Contains(cmd.Args, ipcSocketArg) { + return nil, fmt.Errorf("you should not use `%s` argument in your command", ipcSocketArg) + } + cmd.Args = append(cmd.Args, ipcSocketArg, p.socketPath) + + p.errCh = make(chan error, 1) + p.cmdDone = make(chan struct{}) + + return &p, nil +} + +func (p *ParentIPC) Start() error { + _ = os.Remove(p.socketPath) + listener, err := net.Listen("unix", p.socketPath) + if err != nil { + return fmt.Errorf("listen unix socket: %w", err) + } + p.listener = listener + defer p.listener.Close() + + err = p.cmd.Start() + if err != nil { + return fmt.Errorf("cmd start: %w", err) + } + + go func() { + p.cmdErr = p.cmd.Wait() + close(p.cmdDone) + }() + + return p.acceptConn() +} + +type connResult struct { + conn net.Conn + err error +} + +func (p *ParentIPC) acceptConn() error { + res := make(chan connResult, 1) + go func() { + conn, err := p.listener.Accept() + res <- connResult{conn: conn, err: err} + close(res) + }() + + select { + case <-time.After(time.Duration(defaultAcceptTimeout) * time.Second): + _ = p.cmd.Process.Kill() + return fmt.Errorf("accept timeout") + case <-p.cmdDone: + return fmt.Errorf("cmd exited before accepting connection: %w", p.cmdErr) + case r := <-res: + if r.err != nil { + _ = p.cmd.Process.Kill() + return fmt.Errorf("accept: %w", r.err) + } + p.conn = r.conn + go p.readConn() + } + return nil +} + +func (p *ParentIPC) Stop() error { + p.mu.Lock() + hasPending := len(p.pendingCalls) > 0 + p.mu.Unlock() + if hasPending { + return fmt.Errorf("there are calls pending") + } + if p.processingCalls.Load() > 0 { + return fmt.Errorf("there are calls processing") + } + p.stopRequested.Store(true) + if err := p.cmd.Process.Signal(syscall.SIGINT); err != nil { + return fmt.Errorf("send SIGINT: %w", err) + } + return p.Wait() +} + +func (p *ParentIPC) Wait(timeout ...time.Duration) (retErr error) { + const maxDuration = time.Duration(1<<63 - 1) + _timeout := maxDuration + if len(timeout) > 0 { + _timeout = timeout[0] + } + +loop: + for { + select { + case err := <-p.errCh: + retErr = mergeErr(retErr, fmt.Errorf("ipc internal error: %w", err)) + break loop + case <-p.cmdDone: + err := p.cmdErr + if err != nil { + var exitErr *exec.ExitError + if ok := errors.As(err, &exitErr); ok { + if !exitErr.Success() { + ws, ok := exitErr.Sys().(syscall.WaitStatus) + if !(ok && ws.Signaled() && ws.Signal() == syscall.SIGINT && p.stopRequested.Load()) { + retErr = mergeErr(retErr, fmt.Errorf("cmd wait: %w", err)) + } + } + } else { + retErr = mergeErr(retErr, fmt.Errorf("cmd wait: %w", err)) + } + } + break loop + case <-time.After(_timeout): + p.stopRequested.Store(true) + if err := p.cmd.Process.Signal(syscall.SIGINT); err != nil { + retErr = mergeErr(retErr, fmt.Errorf("send SIGINT: %w", err)) + } + } + } + + p.closeConn() + _ = os.Remove(p.socketPath) + + return retErr +} diff --git a/lib/golang/protocol.go b/lib/golang/protocol.go new file mode 100644 index 0000000..639369f --- /dev/null +++ b/lib/golang/protocol.go @@ -0,0 +1,23 @@ +package golang + +const ipcSocketArg = "--ipc-socket" +const maxMessageLength = 1 << 30 // 1 GB +const defaultAcceptTimeout = 10 // seconds + +type MsgType int + +type Vals []any + +const ( + MsgCall MsgType = 1 + MsgResponse MsgType = 2 +) + +type Message struct { + Type MsgType `json:"type"` + Id int64 `json:"id"` + Method string `json:"method"` + Args Vals `json:"args"` + Result Vals `json:"result"` + Error string `json:"error"` +} diff --git a/lib/golang/serialize.go b/lib/golang/serialize.go new file mode 100644 index 0000000..4168b17 --- /dev/null +++ b/lib/golang/serialize.go @@ -0,0 +1,35 @@ +package golang + +import ( + "encoding/base64" + "reflect" +) + +func (ipc *ipcCommon) serialize(arg any) any { + t := reflect.TypeOf(arg) + switch t.Kind() { + case reflect.Slice: + switch t.Elem().Name() { + case "uint8": + return map[string]any{ + "t": "blob", + "d": base64.StdEncoding.EncodeToString(arg.([]byte)), + } + } + } + return arg +} + +func (ipc *ipcCommon) ConvType(needType reflect.Type, gotType reflect.Type, arg any) any { + switch needType.Kind() { + case reflect.Int: + // JSON decodes any number to float64. If we need int, we should check and convert + if gotType.Kind() == reflect.Float64 { + floatArg := arg.(float64) + if float64(int64(floatArg)) == floatArg && !needType.OverflowInt(int64(floatArg)) { + arg = int(floatArg) + } + } + } + return arg +} diff --git a/lib/golang/serialize_test.go b/lib/golang/serialize_test.go new file mode 100644 index 0000000..2ff5cb0 --- /dev/null +++ b/lib/golang/serialize_test.go @@ -0,0 +1,61 @@ +package golang + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSerialize(t *testing.T) { + ipc := &ipcCommon{} + + t.Run("primitives pass through", func(t *testing.T) { + assert.Equal(t, 42, ipc.serialize(42)) + assert.Equal(t, "hello", ipc.serialize("hello")) + assert.Equal(t, true, ipc.serialize(true)) + assert.Equal(t, 3.14, ipc.serialize(3.14)) + }) + + t.Run("byte slice serializes to blob", func(t *testing.T) { + data := []byte{0x01, 0x02, 0x03} + result := ipc.serialize(data) + m, ok := result.(map[string]any) + assert.True(t, ok) + assert.Equal(t, "blob", m["t"]) + assert.Equal(t, "AQID", m["d"]) // base64 of {1,2,3} + }) + + t.Run("empty byte slice serializes to blob", func(t *testing.T) { + data := []byte{} + result := ipc.serialize(data) + m, ok := result.(map[string]any) + assert.True(t, ok) + assert.Equal(t, "blob", m["t"]) + assert.Equal(t, "", m["d"]) + }) +} + +func TestConvType(t *testing.T) { + ipc := &ipcCommon{} + + t.Run("float64 to int", func(t *testing.T) { + result := ipc.ConvType(reflect.TypeOf(0), reflect.TypeOf(0.0), float64(42)) + assert.Equal(t, 42, result) + }) + + t.Run("float64 with fractional part stays float", func(t *testing.T) { + result := ipc.ConvType(reflect.TypeOf(0), reflect.TypeOf(0.0), float64(42.5)) + assert.Equal(t, float64(42.5), result) + }) + + t.Run("string passes through", func(t *testing.T) { + result := ipc.ConvType(reflect.TypeOf(""), reflect.TypeOf(""), "hello") + assert.Equal(t, "hello", result) + }) + + t.Run("bool passes through", func(t *testing.T) { + result := ipc.ConvType(reflect.TypeOf(true), reflect.TypeOf(true), true) + assert.Equal(t, true, result) + }) +} diff --git a/lib/golang/types/types.go b/lib/golang/types/types.go deleted file mode 100644 index d456c47..0000000 --- a/lib/golang/types/types.go +++ /dev/null @@ -1,12 +0,0 @@ -package types - -type ValType string - -const ( - TNoType ValType = "" // zero value constant for ValType (please don't use just "" as zero value!) - TInt ValType = "int" - TString ValType = "string" - TBool ValType = "bool" - TBlob ValType = "blob" - TArray ValType = "array" -) diff --git a/lib/golang/util.go b/lib/golang/util.go new file mode 100644 index 0000000..ba491c6 --- /dev/null +++ b/lib/golang/util.go @@ -0,0 +1,31 @@ +package golang + +import ( + "fmt" + "reflect" +) + +func mergeErr(errs ...error) (ret error) { + for _, err := range errs { + if err != nil { + if ret == nil { + ret = err + } else { + ret = fmt.Errorf("%w; %w", ret, err) + } + } + } + return +} + +func mapTypeNames(types []any) map[string]any { + result := make(map[string]any) + for _, t := range types { + if reflect.TypeOf(t).Kind() != reflect.Pointer { + panic(fmt.Sprintf("LocalAPI argument must be pointer")) + } + typeName := reflect.TypeOf(t).Elem().Name() + result[typeName] = t + } + return result +} diff --git a/lib/golang/util_test.go b/lib/golang/util_test.go new file mode 100644 index 0000000..146363f --- /dev/null +++ b/lib/golang/util_test.go @@ -0,0 +1,46 @@ +package golang + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMergeErr(t *testing.T) { + t.Run("all nil returns nil", func(t *testing.T) { + assert.NoError(t, mergeErr(nil, nil, nil)) + }) + + t.Run("single error returns it", func(t *testing.T) { + err := fmt.Errorf("one") + assert.EqualError(t, mergeErr(nil, err, nil), "one") + }) + + t.Run("multiple errors merged", func(t *testing.T) { + err1 := fmt.Errorf("one") + err2 := fmt.Errorf("two") + result := mergeErr(err1, err2) + assert.ErrorContains(t, result, "one") + assert.ErrorContains(t, result, "two") + }) +} + +func TestMapTypeNames(t *testing.T) { + type Foo struct{} + type Bar struct{} + + t.Run("maps pointer types by name", func(t *testing.T) { + foo := &Foo{} + bar := &Bar{} + result := mapTypeNames([]any{foo, bar}) + assert.Equal(t, foo, result["Foo"]) + assert.Equal(t, bar, result["Bar"]) + }) + + t.Run("panics on non-pointer", func(t *testing.T) { + assert.Panics(t, func() { + mapTypeNames([]any{Foo{}}) + }) + }) +} diff --git a/lib/ts/src/asyncqueue.test.ts b/lib/ts/src/asyncqueue.test.ts new file mode 100644 index 0000000..ef76c40 --- /dev/null +++ b/lib/ts/src/asyncqueue.test.ts @@ -0,0 +1,27 @@ +import {test} from 'vitest'; +import {AsyncQueue} from './asyncqueue.js'; + +test('put then collect returns items', async ({expect}) => { + const q = new AsyncQueue(); + q.put(1); + q.put(2); + const items = await q.collect(); + expect(items).toEqual([1, 2]); +}); + +test('collect then put resolves on first put', async ({expect}) => { + const q = new AsyncQueue(); + const collectPromise = q.collect(); + q.put(42); + const items = await collectPromise; + expect(items).toEqual([42]); +}); + +test('multiple puts before collect', async ({expect}) => { + const q = new AsyncQueue(); + q.put('a'); + q.put('b'); + q.put('c'); + const items = await q.collect(); + expect(items).toEqual(['a', 'b', 'c']); +}); diff --git a/lib/ts/src/child.ts b/lib/ts/src/child.ts new file mode 100644 index 0000000..c0bcef3 --- /dev/null +++ b/lib/ts/src/child.ts @@ -0,0 +1,44 @@ +import * as net from 'node:net'; +import {IPCCommon} from './common.js'; +import {socketPathFromArgs} from './util.js'; + +export class ChildIPC extends IPCCommon { + constructor(...localApis: object[]) { + super(localApis, socketPathFromArgs()); + } + + async start(): Promise { + return new Promise((resolve, reject) => { + this.conn = net.createConnection(this.socketPath, () => { + this.readConn(); + resolve(); + }); + this.conn.on('error', reject); + }); + } + + async wait(): Promise { + const closePromise = new Promise((resolve) => { + this.onClose = () => { + if (this.processingCalls === 0) { + this.conn?.destroy(); + resolve(); + } + }; + if (this.stopRequested && this.processingCalls === 0) { + this.conn?.destroy(); + resolve(); + } + }); + + const errorPromise = this.errorQueue.collect().then((errors) => { + if (errors.length === 1) { + throw errors[0]; + } else if (errors.length > 1) { + throw new Error(errors.map(e => e.toString()).join(', ')); + } + }); + + await Promise.race([closePromise, errorPromise]); + } +} diff --git a/lib/ts/src/lib.ts b/lib/ts/src/common.ts similarity index 50% rename from lib/ts/src/lib.ts rename to lib/ts/src/common.ts index 442f030..f82d413 100644 --- a/lib/ts/src/lib.ts +++ b/lib/ts/src/common.ts @@ -1,46 +1,10 @@ import * as net from 'node:net'; import * as readline from 'node:readline'; -import {type ChildProcess, spawn} from 'node:child_process'; -import * as os from 'node:os'; -import * as path from 'node:path'; -import * as fs from 'node:fs'; -import * as util from 'node:util'; -import * as crypto from 'node:crypto'; import {AsyncQueue} from './asyncqueue.js'; +import type {CallMessage, CallResult, Message, ResponseMessage, Vals} from './protocol.js'; +import {MsgType} from './protocol.js'; -const IPC_SOCKET_ARG = 'ipc-socket'; - -type JSONSerializable = string | number | boolean; - -enum MsgType { - Call = 1, - Response = 2, -} - -type Vals = any[]; - -interface CallMessage { - type: MsgType.Call, - id: number, - method: string; - args: Vals; -} - -interface ResponseMessage { - type: MsgType.Response, - id: number, - result?: Vals; - error?: string; -} - -type Message = CallMessage | ResponseMessage; - -interface CallResult { - result: Vals; - error: Error | null; -} - -abstract class IPCCommon { +export abstract class IPCCommon { protected localApis: Record; protected socketPath: string; protected conn: net.Socket | null = null; @@ -75,6 +39,7 @@ abstract class IPCCommon { }); this.conn.on('close', (hadError: boolean) => { + this.rejectPendingCalls(new Error('connection closed')); if (hadError) { this.raiseErr(new Error('connection closed due to error')); } @@ -194,7 +159,6 @@ abstract class IPCCommon { } public serialize(arg: any): any { - // noinspection FallThroughInSwitchStatementJS switch (typeof arg) { case 'string': case 'boolean': @@ -212,7 +176,6 @@ abstract class IPCCommon { } public deserialize(arg: any): any { - // noinspection FallThroughInSwitchStatementJS switch (typeof arg) { case 'string': case 'boolean': @@ -248,201 +211,15 @@ abstract class IPCCommon { if (this.onClose) this.onClose(); } + protected rejectPendingCalls(err: Error): void { + const pending = this.pendingCalls; + this.pendingCalls = {}; + for (const callback of Object.values(pending)) { + callback({result: [], error: err}); + } + } + protected raiseErr(err: Error): void { this.errorQueue.put(err); } } - - -export class ParentIPC extends IPCCommon { - private readonly cmdPath: string; - private readonly cmdArgs: string[]; - private cmd: ChildProcess | null = null; - private readonly listener: net.Server; - private cmdExitResult: { code: number | null, signal: string | null } | null = null; - private cmdExitCallbacks: ((result: { code: number | null, signal: string | null }) => void)[] = []; - - constructor(cmdPath: string, cmdArgs: string[], ...localApis: object[]) { - const socketPath = path.join(os.tmpdir(), `kitten-ipc-${ process.pid }-${ crypto.randomInt(2**48 - 1) }.sock`); - super(localApis, socketPath); - - this.cmdPath = cmdPath; - if (cmdArgs.includes(`--${ IPC_SOCKET_ARG }`)) { - throw new Error(`you should not use '--${ IPC_SOCKET_ARG }' argument in your command`); - } - this.cmdArgs = cmdArgs; - - this.listener = net.createServer(); - } - - async start(): Promise { - try { - fs.unlinkSync(this.socketPath); - } catch { - } - - await new Promise((resolve, reject) => { - this.listener.listen(this.socketPath, () => { - resolve(); - }); - this.listener.on('error', reject); - }); - - const cmdArgs = [...this.cmdArgs, `--${ IPC_SOCKET_ARG }`, this.socketPath]; - this.cmd = spawn(this.cmdPath, cmdArgs, {stdio: 'inherit'}); - - this.cmd.on('error', (err) => { - this.raiseErr(err); - }); - - this.cmd.on('close', (code, signal) => { - const result = { code, signal }; - this.cmdExitResult = result; - for (const cb of this.cmdExitCallbacks) cb(result); - this.cmdExitCallbacks = []; - }); - - await this.acceptConn(); - } - - private async acceptConn(): Promise { - const acceptTimeout = 10000; - - const acceptPromise = new Promise((resolve, reject) => { - this.listener.once('connection', (conn) => { - resolve(conn); - }); - this.listener.once('error', reject); - }); - - const exitPromise = new Promise((_, reject) => { - if (this.cmdExitResult) { - reject(new Error(`command exited before connection established`)); - } else { - this.cmdExitCallbacks.push(() => { - reject(new Error(`command exited before connection established`)); - }); - } - }); - - try { - this.conn = await timeout(Promise.race([acceptPromise, exitPromise]), acceptTimeout); - this.readConn(); - } catch (e) { - if (this.cmd) this.cmd.kill(); - throw e; - } - } - - async wait(): Promise { - if (!this.cmd) { - throw new Error('Command is not started yet'); - } - - const exitPromise = new Promise<{ code: number | null, signal: string | null }>((resolve) => { - if (this.cmdExitResult) { - resolve(this.cmdExitResult); - } else { - this.cmdExitCallbacks.push(resolve); - } - }); - - try { - await Promise.race([ - exitPromise.then(({ code, signal }) => { - if (signal || code) { - if (signal) throw new Error(`Process exited with signal ${ signal }`); - else throw new Error(`Process exited with code ${ code }`); - } else if (!this.ready) { - throw new Error('command exited before connection established'); - } - }), - this.errorQueue.collect().then((errors) => { - if (errors.length === 1) { - throw errors[0]; - } else if (errors.length > 1) { - throw new Error(errors.map(e => e.toString()).join(', ')); - } - }), - ]); - } finally { - try { fs.unlinkSync(this.socketPath); } catch {} - } - } -} - - -export class ChildIPC extends IPCCommon { - constructor(...localApis: object[]) { - super(localApis, socketPathFromArgs()); - } - - async start(): Promise { - return new Promise((resolve, reject) => { - this.conn = net.createConnection(this.socketPath, () => { - this.readConn(); - resolve(); - }); - this.conn.on('error', reject); - }); - } - - async wait(): Promise { - const closePromise = new Promise((resolve) => { - this.onClose = () => { - if (this.processingCalls === 0) { - this.conn?.destroy(); - resolve(); - } - }; - if (this.stopRequested && this.processingCalls === 0) { - this.conn?.destroy(); - resolve(); - } - }); - - const errorPromise = this.errorQueue.collect().then((errors) => { - if (errors.length === 1) { - throw errors[0]; - } else if (errors.length > 1) { - throw new Error(errors.map(e => e.toString()).join(', ')); - } - }); - - await Promise.race([closePromise, errorPromise]); - } -} - - -function socketPathFromArgs(): string { - const {values} = util.parseArgs({ - options: { - [IPC_SOCKET_ARG]: { - type: 'string', - } - } - }); - - if (!values[IPC_SOCKET_ARG]) { - throw new Error('ipc socket path is missing'); - } - - return values[IPC_SOCKET_ARG]; -} - - -function sleep(ms: number): Promise { - return new Promise(resolve => setTimeout(resolve, ms)); -} - - -// throws on timeout -function timeout(prom: Promise, ms: number): Promise { - return Promise.race( - [ - prom, - new Promise((res, reject) => { - setTimeout(() => {reject(new Error('timed out'))}, ms) - })] - ) as Promise; -} diff --git a/lib/ts/src/index.ts b/lib/ts/src/index.ts index 35f7b93..8055b34 100644 --- a/lib/ts/src/index.ts +++ b/lib/ts/src/index.ts @@ -1 +1,2 @@ -export {ParentIPC, ChildIPC} from './lib.js'; +export {ParentIPC} from './parent.js'; +export {ChildIPC} from './child.js'; diff --git a/lib/ts/src/lib.test.ts b/lib/ts/src/lib.test.ts index 100a487..784c28a 100644 --- a/lib/ts/src/lib.test.ts +++ b/lib/ts/src/lib.test.ts @@ -1,5 +1,5 @@ import {test} from 'vitest'; -import {ParentIPC} from './lib.js'; +import {ParentIPC} from './parent.js'; test('test connection timeout', async ({expect}) => { const parentIpc = new ParentIPC('../testdata/sleep15.sh', []); diff --git a/lib/ts/src/parent.ts b/lib/ts/src/parent.ts new file mode 100644 index 0000000..5edd540 --- /dev/null +++ b/lib/ts/src/parent.ts @@ -0,0 +1,126 @@ +import * as net from 'node:net'; +import * as os from 'node:os'; +import * as path from 'node:path'; +import * as fs from 'node:fs'; +import * as crypto from 'node:crypto'; +import {type ChildProcess, spawn} from 'node:child_process'; +import {IPCCommon} from './common.js'; +import {timeout} from './util.js'; + +const IPC_SOCKET_ARG = 'ipc-socket'; +const ACCEPT_TIMEOUT_MS = 10000; + +export class ParentIPC extends IPCCommon { + private readonly cmdPath: string; + private readonly cmdArgs: string[]; + private cmd: ChildProcess | null = null; + private readonly listener: net.Server; + private cmdExitResult: { code: number | null, signal: string | null } | null = null; + private cmdExitCallbacks: ((result: { code: number | null, signal: string | null }) => void)[] = []; + + constructor(cmdPath: string, cmdArgs: string[], ...localApis: object[]) { + const socketPath = path.join(os.tmpdir(), `kitten-ipc-${ process.pid }-${ crypto.randomInt(2**48 - 1) }.sock`); + super(localApis, socketPath); + + this.cmdPath = cmdPath; + if (cmdArgs.includes(`--${ IPC_SOCKET_ARG }`)) { + throw new Error(`you should not use '--${ IPC_SOCKET_ARG }' argument in your command`); + } + this.cmdArgs = cmdArgs; + + this.listener = net.createServer(); + } + + async start(): Promise { + try { + fs.unlinkSync(this.socketPath); + } catch { + } + + await new Promise((resolve, reject) => { + this.listener.listen(this.socketPath, () => { + resolve(); + }); + this.listener.on('error', reject); + }); + + const cmdArgs = [...this.cmdArgs, `--${ IPC_SOCKET_ARG }`, this.socketPath]; + this.cmd = spawn(this.cmdPath, cmdArgs, {stdio: 'inherit'}); + + this.cmd.on('error', (err) => { + this.raiseErr(err); + }); + + this.cmd.on('close', (code, signal) => { + const result = { code, signal }; + this.cmdExitResult = result; + for (const cb of this.cmdExitCallbacks) cb(result); + this.cmdExitCallbacks = []; + }); + + await this.acceptConn(); + } + + private async acceptConn(): Promise { + const acceptPromise = new Promise((resolve, reject) => { + this.listener.once('connection', (conn) => { + resolve(conn); + }); + this.listener.once('error', reject); + }); + + const exitPromise = new Promise((_, reject) => { + if (this.cmdExitResult) { + reject(new Error(`command exited before connection established`)); + } else { + this.cmdExitCallbacks.push(() => { + reject(new Error(`command exited before connection established`)); + }); + } + }); + + try { + this.conn = await timeout(Promise.race([acceptPromise, exitPromise]), ACCEPT_TIMEOUT_MS); + this.readConn(); + } catch (e) { + if (this.cmd) this.cmd.kill(); + throw e; + } + } + + async wait(): Promise { + if (!this.cmd) { + throw new Error('Command is not started yet'); + } + + const exitPromise = new Promise<{ code: number | null, signal: string | null }>((resolve) => { + if (this.cmdExitResult) { + resolve(this.cmdExitResult); + } else { + this.cmdExitCallbacks.push(resolve); + } + }); + + try { + await Promise.race([ + exitPromise.then(({ code, signal }) => { + if (signal || code) { + if (signal) throw new Error(`Process exited with signal ${ signal }`); + else throw new Error(`Process exited with code ${ code }`); + } else if (!this.ready) { + throw new Error('command exited before connection established'); + } + }), + this.errorQueue.collect().then((errors) => { + if (errors.length === 1) { + throw errors[0]; + } else if (errors.length > 1) { + throw new Error(errors.map(e => e.toString()).join(', ')); + } + }), + ]); + } finally { + try { fs.unlinkSync(this.socketPath); } catch {} + } + } +} diff --git a/lib/ts/src/protocol.ts b/lib/ts/src/protocol.ts new file mode 100644 index 0000000..9e29eb8 --- /dev/null +++ b/lib/ts/src/protocol.ts @@ -0,0 +1,27 @@ +export enum MsgType { + Call = 1, + Response = 2, +} + +export type Vals = any[]; + +export interface CallMessage { + type: MsgType.Call, + id: number, + method: string; + args: Vals; +} + +export interface ResponseMessage { + type: MsgType.Response, + id: number, + result?: Vals; + error?: string; +} + +export type Message = CallMessage | ResponseMessage; + +export interface CallResult { + result: Vals; + error: Error | null; +} diff --git a/lib/ts/src/serialize.test.ts b/lib/ts/src/serialize.test.ts new file mode 100644 index 0000000..957c0b3 --- /dev/null +++ b/lib/ts/src/serialize.test.ts @@ -0,0 +1,68 @@ +import {test} from 'vitest'; +import {ChildIPC} from './child.js'; + +// Access serialize/deserialize through a ChildIPC instance's public methods +// We create a minimal wrapper to test them +class TestableIPC { + private ipc: any; + + constructor() { + // Access the prototype methods directly + this.ipc = Object.create(ChildIPC.prototype); + } + + serialize(arg: any): any { + return this.ipc.serialize(arg); + } + + deserialize(arg: any): any { + return this.ipc.deserialize(arg); + } +} + +test('serialize primitives', ({expect}) => { + const t = new TestableIPC(); + expect(t.serialize(42)).toBe(42); + expect(t.serialize('hello')).toBe('hello'); + expect(t.serialize(true)).toBe(true); + expect(t.serialize(3.14)).toBe(3.14); +}); + +test('serialize buffer to base64', ({expect}) => { + const t = new TestableIPC(); + const buf = Buffer.from([1, 2, 3]); + expect(t.serialize(buf)).toBe('AQID'); +}); + +test('deserialize primitives', ({expect}) => { + const t = new TestableIPC(); + expect(t.deserialize(42)).toBe(42); + expect(t.deserialize('hello')).toBe('hello'); + expect(t.deserialize(true)).toBe(true); +}); + +test('deserialize blob to buffer', ({expect}) => { + const t = new TestableIPC(); + const result = t.deserialize({t: 'blob', d: 'AQID'}); + expect(Buffer.isBuffer(result)).toBe(true); + expect(result).toEqual(Buffer.from([1, 2, 3])); +}); + +test('serialize then deserialize blob round-trip', ({expect}) => { + const t = new TestableIPC(); + const original = Buffer.from([0xDE, 0xAD, 0xBE, 0xEF]); + // Go serializes as {t: 'blob', d: base64}, TS serializes as just base64 string + // The TS serialize returns base64 string directly for Buffer + const serialized = t.serialize(original); + expect(typeof serialized).toBe('string'); +}); + +test('deserialize unknown object throws', ({expect}) => { + const t = new TestableIPC(); + expect(() => t.deserialize({foo: 'bar'})).toThrow('cannot deserialize'); +}); + +test('serialize unsupported object throws', ({expect}) => { + const t = new TestableIPC(); + expect(() => t.serialize({foo: 'bar'})).toThrow('cannot serialize'); +}); diff --git a/lib/ts/src/util.ts b/lib/ts/src/util.ts new file mode 100644 index 0000000..a26af37 --- /dev/null +++ b/lib/ts/src/util.ts @@ -0,0 +1,29 @@ +import * as util from 'node:util'; + +const IPC_SOCKET_ARG = 'ipc-socket'; + +export function socketPathFromArgs(): string { + const {values} = util.parseArgs({ + options: { + [IPC_SOCKET_ARG]: { + type: 'string', + } + } + }); + + if (!values[IPC_SOCKET_ARG]) { + throw new Error('ipc socket path is missing'); + } + + return values[IPC_SOCKET_ARG]; +} + +export function timeout(prom: Promise, ms: number): Promise { + return Promise.race( + [ + prom, + new Promise((res, reject) => { + setTimeout(() => {reject(new Error('timed out'))}, ms) + })] + ) as Promise; +}