From 36b5f03658adc685b047209bb1c89f5c42e22a73 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Tue, 1 Oct 2024 10:31:53 -0700 Subject: [PATCH 1/2] Fix behavior of run when wait is set --- index.js | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/index.js b/index.js index 712bc59..c54b80b 100644 --- a/index.js +++ b/index.js @@ -165,6 +165,15 @@ class Replicate { throw new Error("Invalid model version identifier"); } + // When `wait` is set, the server may respond + // with the prediction output directly. + // If it hasn't finished, the prediction object is returned + // with an `id` property that can be used to poll for completion. + if (wait && !("id" in prediction)) { + const output = prediction; + return output; + } + // Call progress callback with the initial prediction object if (progress) { progress(prediction); From e6b41f885d6173489cf6493c04e101f013b95a60 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Tue, 1 Oct 2024 10:34:58 -0700 Subject: [PATCH 2/2] Fix type definitions for predictions.create methods --- index.d.ts | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/index.d.ts b/index.d.ts index eabcc9b..4a8bbbf 100644 --- a/index.d.ts +++ b/index.d.ts @@ -162,7 +162,7 @@ declare module "replicate" { signal?: AbortSignal; }, progress?: (prediction: Prediction) => void - ): Promise; + ): Promise; stream( identifier: `${string}/${string}` | `${string}/${string}:${string}`, @@ -215,9 +215,9 @@ declare module "replicate" { stream?: boolean; webhook?: string; webhook_events_filter?: WebhookEventType[]; - block?: boolean; + wait?: boolean | number | { mode?: "poll"; interval?: number }; } - ): Promise; + ): Promise; }; get( deployment_owner: string, @@ -304,9 +304,9 @@ declare module "replicate" { stream?: boolean; webhook?: string; webhook_events_filter?: WebhookEventType[]; - block?: boolean; + wait?: boolean | number | { mode?: "poll"; interval?: number }; } & ({ version: string } | { model: string }) - ): Promise; + ): Promise; get(prediction_id: string): Promise; cancel(prediction_id: string): Promise; list(): Promise>;