Source: aiservercloudzoo.js

import AIServerModel from './aiservermodel.js';
import DGError from './dgerror.js';

// Default cloud model zoo: the server host and the organization/zoo path, joined into a full URL.
const DEFAULT_CLOUD_SERVER = 'cs.degirum.com';
const DEFAULT_CLOUD_ZOO = '/degirum/public';
const DEFAULT_CLOUD_URL = ['https://', DEFAULT_CLOUD_SERVER, DEFAULT_CLOUD_ZOO].join('');

// Upper bounds, in milliseconds, for waiting on model initialization and on the zoo rescan.
const MODEL_INIT_TIMEOUT_MS = 20 * 1000;
const RESCAN_TIMEOUT_MS = 20 * 1000;


/**
 * @class
 * @classdesc This class is responsible for managing and loading models from a cloud-based AI model zoo for inference on an AIServer.
 * It handles model discovery, fetching model information, and loading models for use.
 * Use loadModel() to instantiate and configure a model for use.
 * Various utility functions are available to interact with the cloud zoo and AIServer.
 *
 * @example <caption>Usage:</caption>
 * let zoo = dg.connect('ws://localhost:8779', 'https://cs.degirum.com/degirum/public', secretToken);
 * let model = await zoo.loadModel('someModel', { max_q_len: 5, callback: someFunction });
 */
class AIServerCloudZoo {
  /**
   * Do not call this constructor, instead use the `connect` function of the dg_sdk class to create an AIServerCloudZoo instance.
   * The constructor starts an asynchronous rescan of the zoo; listModels() and loadModel() wait for it to finish.
   * If the rescan fails, the error is logged and kept in `this.rescanError`.
   * @param {string} serverUrl - The URL of the AI server (WebSocket URL, `ws://` or `wss://`).
   * @param {string} [zooUrl='https://cs.degirum.com/degirum/public'] - The URL of the cloud model zoo.
   *   If it has no path, the default '/degirum/public' zoo path is appended.
   * @param {string} accessToken - The access token for the cloud zoo.
   * @param {boolean} [isDummy=false] - Whether this is a dummy instance (if we only need to have model listing functionality, without connecting to an AIServer).
   */
  constructor(serverUrl, zooUrl = DEFAULT_CLOUD_URL, accessToken, isDummy = false) {
    // Never write the raw token to the console: console output is routinely copied into bug reports.
    console.log('AIServerCloudZoo constructor: serverUrl:', serverUrl, 'zooUrl:', zooUrl, 'accessToken:', accessToken ? '<provided>' : '<none>');
    this.isDummy = isDummy;

    // Cloud Model Zoo URL for model management, stored without a trailing slash.
    this.zooUrl = zooUrl.endsWith('/') ? zooUrl.slice(0, -1) : zooUrl;
    const { pathname } = new URL(this.zooUrl);
    // If the zooUrl doesn't have a path, append DEFAULT_CLOUD_ZOO: this class always loads
    // models from a cloud zoo and is never expected to perform local inference.
    if (pathname === '/' || pathname === '') {
      this.zooUrl += DEFAULT_CLOUD_ZOO;
    }

    this.accessToken = accessToken;
    this.assets = {}; // model assets (only models supported on the AIServer, unless isDummy)
    this.modelNames = []; // all model names in the cloud zoo (including ones that aren't supported on the system)

    // AI Server URL for model inference
    this.serverUrl = serverUrl;
    // HTTP(S) URL for AIServer REST API calls, e.g. systemInfo(): ws:// -> http://, wss:// -> https://.
    this.httpUrl = typeof serverUrl === 'string'
      ? serverUrl.replace(/^ws(s?):\/\//, 'http$1://')
      : serverUrl;

    // Initialize the assets list. The constructor cannot await, so attach a handler: otherwise a
    // failed rescan would surface as an unhandled promise rejection.
    this.isRescanComplete = false;
    this.rescanError = null;
    this.rescanZoo().catch((error) => {
      this.rescanError = error;
      console.error('AIServerCloudZoo: rescan of the cloud zoo failed:', error);
    });
  }

  /**
   * Returns the full model name in the format 'organization/zooname/simpleModelName'.
   * @private
   * @param {string} simpleModelName - The simple name of the model.
   * @returns {string} The full model name.
   */
  getModelFullName(simpleModelName) {
    // zooUrl never has a trailing slash (see constructor and fetchWithToken redirect handling).
    const { pathname } = new URL(this.zooUrl);
    const zooPath = pathname.startsWith('/') ? pathname.substring(1) : pathname;
    return `${zooPath}/${simpleModelName}`;
  }

  /**
   * Fetches data from the cloud zoo API using the access token.
   * Network failures, unparsable bodies, 5xx responses, and 408/429 responses are retried up to
   * 3 times with exponential backoff. Other 4xx responses (e.g. 401 wrong token) fail immediately,
   * since retrying cannot change them. If the server redirects to another https origin,
   * `this.zooUrl` is moved to that origin and keeps its organization/zoo path.
   * @private
   * @param {string} apiUrl - The API endpoint to fetch from (appended to '<origin>/zoo/v1/public').
   * @param {boolean} [isOctetStream=false] - Whether to expect a blob response.
   * @returns {Promise<Object|Blob>} The fetched data.
   * @throws {DGError} FETCH_WITH_TOKEN_FAILED on a non-retryable HTTP error or when all retries fail.
   */
  async fetchWithToken(apiUrl, isOctetStream = false) {
    const cloudServerPathSuffix = '/zoo/v1/public';
    const maxAttempts = 3;
    const headers = {
      'Accept': isOctetStream ? 'application/octet-stream' : 'application/json',
      'token': this.accessToken,
    };
    // Only the origin of zooUrl is used; the organization/zoo path is part of apiUrl.
    const fullUrl = `${new URL(this.zooUrl).origin}${cloudServerPathSuffix}${apiUrl}`;

    let lastError;
    for (let attempt = 1; attempt <= maxAttempts; attempt++) {
      let response = null;
      try {
        response = await fetch(fullUrl, { headers });
        if (response.ok) {
          if (response.redirected && response.url.startsWith('https://')) {
            // The cloud server moved: switch origin only, keeping the organization/zoo path so
            // getModelFullName() and later requests stay valid.
            const zooPath = new URL(this.zooUrl).pathname;
            this.zooUrl = `${new URL(response.url).origin}${zooPath}`;
          }
          // `return await` so a body-parsing failure lands in the catch below and is retried.
          return await (isOctetStream ? response.blob() : response.json());
        }
      } catch (error) {
        lastError = error;
      }

      if (response && !response.ok) {
        let httpError;
        if (response.status === 401) {
          console.log('Wrong token. Response:', response);
          let errorDetails = 'Invalid token value';
          try {
            errorDetails = (await response.json())?.detail || errorDetails;
          } catch {
            // The body is not JSON; keep the generic message.
          }
          httpError = new DGError(`Unable to connect to server: ${errorDetails}`, "FETCH_WITH_TOKEN_FAILED", { status: response.status }, "Unable to connect to server.");
        } else {
          httpError = new DGError(`HTTP error! status: ${response.status}: ${response.statusText}`, "FETCH_WITH_TOKEN_FAILED", { status: response.status }, "HTTP error occurred.");
        }
        // Client errors (bad token, unknown model, ...) won't change on retry; fail fast.
        // 408 (request timeout) and 429 (rate limited) are transient and are retried.
        const isClientError = response.status >= 400 && response.status < 500;
        if (isClientError && response.status !== 408 && response.status !== 429) {
          throw httpError;
        }
        lastError = httpError;
      }

      if (attempt < maxAttempts) {
        // Exponential backoff: 200 ms, then 400 ms.
        await new Promise((resolve) => setTimeout(resolve, 2 ** attempt * 100));
      }
    }
    throw new DGError(`All retries failed: ${lastError}`, "FETCH_WITH_TOKEN_FAILED", {}, "All retries failed.");
  }

  /**
   * Rescans the cloud zoo to update the list of available models.
   * Fetches from https://cs.degirum.com/zoo/v1/public/models/ORGANIZATION/ZOONAME
   * `modelNames` receives every model in the zoo. `assets` receives every model for a dummy
   * instance; otherwise it receives only non-audio models whose device requirements match the
   * AIServer's devices. If the AIServer's device list cannot be fetched, the error is logged and
   * both lists are left unchanged. `isRescanComplete` is always true once this returns or throws.
   * @private
   * @returns {Promise<void>}
   * @throws {DGError} RESCAN_ZOO_FAILED if the zoo listing cannot be fetched or reports an error.
   */
  async rescanZoo() {
    try {
      const { pathname } = new URL(this.zooUrl);
      const modelsInfo = await this.fetchWithToken('/models' + pathname);

      let systemDeviceKeys = new Set();
      if (!this.isDummy) {
        try {
          systemDeviceKeys = new Set(await this.systemSupportedDeviceTypes());
        } catch (error) {
          // Deliberate best effort: without the device list no model can be matched. loadModel()
          // reports an unreachable AIServer on its own.
          console.error('Error fetching system information for AIServer at', this.serverUrl, error);
          return;
        }
      }

      if (!modelsInfo || modelsInfo.error) {
        throw new DGError(modelsInfo?.error ?? 'Empty response from the cloud zoo.', "RESCAN_ZOO_FAILED", {}, "Error occurred while rescanning zoo.");
      }

      const systemDeviceList = Array.from(systemDeviceKeys);
      // Build into locals and publish at the end, so the lists are never observed half-built.
      const assets = {};
      const modelNames = [];

      for (const [modelName, modelDetails] of Object.entries(modelsInfo)) {
        modelNames.push(modelName);

        if (this.isDummy) {
          assets[modelName] = modelDetails;
          continue;
        }

        // Bypass audio models.
        const preProcess = modelDetails?.PRE_PROCESS;
        if (Array.isArray(preProcess) && preProcess.length > 0 && preProcess[0]?.InputType === 'Audio') {
          console.log('Skipping audio model', modelName);
          continue;
        }

        let modelSupportedTypes;
        try {
          const device = modelDetails.DEVICE[0];
          if (device.SupportedDeviceTypes) {
            // e.g. "OPENVINO/*, ONNX/CPU" -> ["OPENVINO/*", "ONNX/CPU"]
            modelSupportedTypes = device.SupportedDeviceTypes.split(',').map((type) => type.trim());
          } else {
            // Fallback: the model's single RUNTIME/DEVICE pair must be present exactly.
            const capabilityKey = `${device.RuntimeAgent}/${device.DeviceType}`;
            if (systemDeviceKeys.has(capabilityKey)) {
              assets[modelName] = modelDetails;
            } else {
              console.log('Skipping model', modelName, 'because it requires', capabilityKey, 'which is not available on the system.');
            }
            continue;
          }
        } catch (error) {
          console.error('Error processing SupportedDeviceTypes for model', modelName, error);
          continue;
        }

        // Intersection check against SupportedDeviceTypes, honoring wildcards.
        if (this.matchSupportedDevices(modelSupportedTypes, systemDeviceList).length > 0) {
          assets[modelName] = modelDetails;
        } else {
          console.log('Skipping model', modelName, 'because none of the SupportedDeviceTypes match the system devices.');
        }
      }

      this.assets = assets;
      this.modelNames = modelNames;
      console.log('rescanZoo: Rescan complete. Model list:', this.assets);
    } catch (error) {
      throw new DGError('Error rescanning zoo: ' + error, "RESCAN_ZOO_FAILED", { error: error?.message }, "Error occurred while rescanning zoo.");
    } finally {
      // Always mark the rescan as complete so waiters never block indefinitely.
      this.isRescanComplete = true;
    }
  }

  /**
   * Determines if a system device type matches any of the model's supported device types, considering wildcards.
   * @private
   * @param {Array<string>} modelSupportedTypes - An array of strings representing the device types supported by the model.
   * Example: ["OPENVINO/*", "TENSORRT/*", "ONNX/*"]
   * @param {Array<string>} systemDeviceTypes - An array of strings representing the device types available on the system.
   * Example: ["OPENVINO/CPU", "TENSORRT/GPU", "ONNX/CPU"]
   *
   * @returns {Array<string>} - An array of strings representing the intersection of modelSupportedTypes and systemDeviceTypes,
   * with wildcards considered.
   * Example: If modelSupportedTypes is ["OPENVINO/*", "TENSORRT/*"] and systemDeviceTypes is ["OPENVINO/CPU", "TENSORRT/GPU"],
   * it returns ["OPENVINO/CPU", "TENSORRT/GPU"].
   */
  matchSupportedDevices(modelSupportedTypes, systemDeviceTypes) {
    // '*' in either half of a "RUNTIME/DEVICE" pattern matches any value in that half.
    const matchesWildcard = (pattern, type) => {
      const [patternRuntime, patternDevice] = pattern.split('/');
      const [typeRuntime, typeDevice] = type.split('/');

      const runtimeMatches = patternRuntime === '*' || patternRuntime === typeRuntime;
      const deviceMatches = patternDevice === '*' || patternDevice === typeDevice;

      return runtimeMatches && deviceMatches;
    };

    return systemDeviceTypes.filter((systemType) =>
      modelSupportedTypes.some((modelType) => matchesWildcard(modelType, systemType))
    );
  }

  /**
   * Fetches the list of zoos for a given organization.
   * @param {string} organization - The name of the organization.
   * @returns {Promise<Object>} The list of zoos.
   * @throws {DGError} GET_ZOOS_FAILED if the organization is missing or the request fails.
   */
  async getZoos(organization) {
    if (!organization) {
      throw new DGError('Organization name is required to get zoos.', "GET_ZOOS_FAILED", {}, "Organization name is required to get zoos.");
    }
    try {
      return await this.fetchWithToken(`/zoos/${organization}`);
    } catch (error) {
      throw new DGError(`Error fetching zoos for organization ${organization}. Error: ${error.message}`, "GET_ZOOS_FAILED", {}, "Error occurred while fetching zoos for organization.");
    }
  }

  /**
   * Fetches the labels for a specific model from the cloud zoo.
   * queries https://cs.degirum.com/zoo/v1/public/models/degirum/public/example_model/dictionary
   * @param {string} modelName - The name of the model.
   * @returns {Promise<Object>} The model's label dictionary.
   * @throws {DGError} GET_MODEL_LABELS_FAILED if the dictionary cannot be fetched.
   */
  async getModelLabels(modelName) {
    try {
      const dictionaryPath = `/models/${this.getModelFullName(modelName)}/dictionary`;
      return await this.fetchWithToken(dictionaryPath);
    } catch (error) {
      // Keep the underlying reason: the old message ended in a bare ':' with nothing after it.
      throw new DGError(`Error fetching labels for model ${modelName}: ${error?.message ?? error}`, "GET_MODEL_LABELS_FAILED", {}, "Error occurred while fetching labels for model.");
    }
  }

  /**
   * Fetches system information from the AI server (GET <httpUrl>/v1/system_info).
   * @returns {Promise<Object>} The system information.
   * @throws {DGError} NETWORK_ERROR on any failure (unreachable host, HTTP error, or non-JSON reply).
   *   The underlying reason is in `details.cause`.
   */
  async systemInfo() {
    try {
      const response = await fetch(`${this.httpUrl}/v1/system_info`);
      if (!response.ok) {
        throw new DGError(`HTTP error! status: ${response.status}: ${response.statusText}`, "SYSTEM_INFO_FAILED", { status: response.status }, "HTTP error occurred.");
      }
      // Await here so an unparsable body is reported through the same error path below.
      return await response.json();
    } catch (error) {
      throw new DGError(`No AIServer was found at ${this.httpUrl}. Please check the URL.`, "NETWORK_ERROR", { url: this.httpUrl, cause: error?.message ?? String(error) }, "Please check your network connection and the URL.", "Verify the AI Server IP / port and ensure that it is running.");
    }
  }

  /**
   * Downloads a model from the cloud zoo.
   * fetches from https://cs.degirum.com/zoo/v1/public/models/degirum/public/example_model
   * @param {string} modelName - The name of the model to download.
   * @returns {Promise<Blob>} The downloaded model as a Blob.
   * @throws {DGError} DOWNLOAD_MODEL_FAILED if the download fails or yields an empty or untyped blob.
   */
  async downloadModel(modelName) {
    const modelDownloadPath = `/models/${this.getModelFullName(modelName)}`;
    let modelBlob;
    try {
      modelBlob = await this.fetchWithToken(modelDownloadPath, true); // true to accept blob response
    } catch (error) {
      throw new DGError(`Error downloading model ${modelName}: ${error?.message ?? error}`, "DOWNLOAD_MODEL_FAILED", {}, "Error occurred while downloading model.");
    }
    if (!(modelBlob && modelBlob.size > 0 && modelBlob.type)) {
      throw new DGError(`Downloaded blob for model ${modelName} is invalid.`, "DOWNLOAD_MODEL_FAILED", {}, "Downloaded blob for model is invalid.");
    }
    console.log(`Model ${modelName} downloaded successfully. Blob size: ${modelBlob.size} bytes, Type: ${modelBlob.type}`);
    return modelBlob;
  }

  /**
   * Lists all available models in the cloud zoo.
   * Waits (up to RESCAN_TIMEOUT_MS) for the rescan started by the constructor to finish.
   * @example
   * let models = await zoo.listModels();
   * let modelNames = Object.keys(models);
   * @returns {Promise<Object>} An object containing the available models and their params.
   * @throws {DGError} LIST_MODELS_FAILED if the rescan does not finish in time.
   */
  async listModels() {
    try {
      await waitUntil(() => this.isRescanComplete, RESCAN_TIMEOUT_MS);
    } catch (error) {
      throw new DGError('aiservercloudzoo.js: listModels() timed out.', "LIST_MODELS_FAILED", {}, "Timeout occurred while listing models.");
    }
    return this.assets;
  }

  /**
   * Fetches the system supported device types from the AI server.
   * @private
   * @returns {Promise<Array<string>>} A promise that resolves to an array of strings in "RUNTIME/DEVICE" format
   *   (the keys of the `Devices` object returned by systemInfo()).
   * @throws {DGError} SYSTEM_INFO_FAILED if system info cannot be fetched or has no `Devices`.
   */
  async systemSupportedDeviceTypes() {
    // private function, to be called by dg_sdk() instead.
    try {
      let systemInfo;
      try {
        systemInfo = await this.systemInfo();
      } catch (error) {
        throw new DGError('Error fetching system information for AIServer at' + this.httpUrl + ' : ' + error, "FETCH_MODELS_FAILED", {}, "Error fetching sys info.");
      }
      if (!systemInfo?.Devices) {
        throw new DGError('No devices found in system info.', "FETCH_MODELS_FAILED", {}, "No devices found in system info.");
      }
      return Object.keys(systemInfo.Devices);
    } catch (error) {
      throw new DGError('systemSupportedDeviceTypes failed: ' + error, "SYSTEM_INFO_FAILED", {}, "Error occurred while fetching system info.");
    }
  }

  /**
   * Loads a model from the cloud zoo and prepares it for inference.
   * An exact model name is preferred. Otherwise the first supported model whose name contains `modelName` is used.
   * @example
   * let model = await zoo.loadModel('someModel', { inputCropPercentage: 0.5, callback: someFunction } );
   * @param {string} modelName - The name of the model to load (non-empty).
   * @param {Object} [options={}] - Additional options for loading the model.
   * @param {number} [options.max_q_len=10] - The maximum length of the internal inference queue.
   * @param {Function} [options.callback=null] - A callback function to handle inference results.
   * @returns {Promise<AIServerModel>} The loaded model instance.
   * @throws {DGError} DUMMY_INSTANCE, INVALID_INPUT, RESCAN_MODELS_FAILED, NETWORK_ERROR,
   *   MODEL_NOT_FOUND, MODEL_NOT_SUPPORTED, or LOAD_MODEL_FAILED (initialization timeout).
   */
  async loadModel(modelName, options = {}) {
    if (this.isDummy) {
      throw new DGError("Model loading is not supported on dummy instances.", "DUMMY_INSTANCE", {}, "Model loading is not supported on dummy instances.");
    }

    const startTime = performance.now();

    // Default values and destructuring of options
    const {
      max_q_len = 10,
      callback = null,
      ...additionalParams
    } = options;

    // Validate modelName: an empty string would substring-match every model in the zoo.
    if (typeof modelName !== 'string' || modelName.length === 0) {
      throw new DGError("Invalid value for modelName: It should be a non-empty string.", "INVALID_INPUT", { parameter: "modelName", value: modelName }, "Please provide the name of a model from the cloud zoo.");
    }

    // Validate max_q_len
    if (!Number.isInteger(max_q_len) || max_q_len <= 0) {
      throw new DGError("Invalid value for max_q_len: It should be a positive integer.", "INVALID_INPUT", { parameter: "max_q_len", value: max_q_len }, "Please ensure max_q_len is a positive integer.");
    }

    // Validate callback
    if (callback !== null && typeof callback !== "function") {
      throw new DGError("Invalid value for callback: It should be a function.", "INVALID_INPUT", { parameter: "callback", value: callback }, "Please ensure callback is a function.");
    }

    try {
      await waitUntil(() => this.isRescanComplete, RESCAN_TIMEOUT_MS);
    } catch (error) {
      throw new DGError('aiservercloudzoo.js: rescan of models timed out within loadModel().', "RESCAN_MODELS_FAILED", {}, "Timeout occurred while waiting for models to be fetched.");
    }

    // Ping the AIServer so an unreachable server is reported as such, not as a missing model.
    try {
      await this.systemInfo();
    } catch (error) {
      throw new DGError('No response from AIServer at ' + this.serverUrl + '. Please check the server URL and ensure that the AIServer is running.', "NETWORK_ERROR", {}, "AIServer is not responding.");
    }

    // Find the model in the assets cached by rescanZoo(). An exact name wins; otherwise fall back
    // to the first name containing modelName, so e.g. 'yolov8n' is not shadowed by 'yolov8n_big'.
    const assetNames = Object.keys(this.assets);
    const matchingModelName = assetNames.find((key) => key === modelName)
      ?? assetNames.find((key) => key.includes(modelName));

    if (matchingModelName === undefined) {
      if (!this.modelNames.includes(modelName)) {
        const rescanNote = this.rescanError ? ` (zoo rescan failed: ${this.rescanError.message ?? this.rescanError})` : '';
        throw new DGError(`Model not found in the cloud zoo: ${modelName}. List of supported models: ${assetNames}${rescanNote}`, "MODEL_NOT_FOUND", { modelName: modelName }, "Model not found in the cloud zoo.");
      }
      // Model found, but not supported on the system
      throw new DGError(`Model exists in the zoo, but is not supported on the system: ${modelName}`, "MODEL_NOT_SUPPORTED", { modelName: modelName }, "Model exists in the zoo, but is not supported on the system.");
    }
    const modelParams = this.assets[matchingModelName];

    // Extended model name: cloud zoo path + simple model name.
    const extendedModelName = this.getModelFullName(matchingModelName);

    console.log(`Loading model: ${matchingModelName}, queue length: ${max_q_len}, callback specified:`, callback !== null);

    // The label dictionary (cloud) and the device list (AIServer) are independent; fetch concurrently.
    const [labels, systemDeviceTypes] = await Promise.all([
      this.getModelLabels(matchingModelName),
      this.systemSupportedDeviceTypes(), // array of 'RUNTIME/DEVICE' strings
    ]);

    // Deep copy so the model cannot mutate the cached assets (params are plain JSON from the zoo).
    const deepCopiedModelParams = JSON.parse(JSON.stringify(modelParams));
    // The cloud URL and token go into additionalParams, not into model params. Read zooUrl only
    // after the fetches above, so an origin change from a redirect is included.
    additionalParams.cloudURL = this.zooUrl;
    additionalParams.cloudToken = this.accessToken;

    // pack all arguments into a single object for AIServerModel
    const modelCreationParams = {
      modelName: extendedModelName,
      serverUrl: this.serverUrl,
      modelParams: deepCopiedModelParams,
      max_q_len: max_q_len,
      callback: callback,
      labels: labels,
      systemDeviceTypes: systemDeviceTypes,
    };

    const model = new AIServerModel(modelCreationParams, additionalParams);

    // Wait for model.initialized to be true, then return model
    try {
      await waitUntil(() => model.initialized, MODEL_INIT_TIMEOUT_MS);
      console.log('aiservercloudzoo.js: loadModel() took', performance.now() - startTime, 'ms');
      return model;
    } catch (error) {
      throw new DGError('aiservercloudzoo: Timeout occurred while waiting for the model to initialize!', "LOAD_MODEL_FAILED", {}, "Timeout occurred while waiting for the model to initialize.", "Check the connection to the AIServer / internet.");
    }
  }
}

/**
 * Waits until a specified condition is met or a timeout occurs, by polling `predicate`.
 * @private
 * @param {Function} predicate - A function that returns a truthy value once the condition is met.
 * @param {number} [timeout=10000] - The maximum time to wait for the condition to be met, in milliseconds.
 * @param {number} [interval=10] - The interval at which to check the condition, in milliseconds.
 * @returns {Promise<void>} Resolves when the condition is met. Rejects with
 *   'Timed out waiting for condition' on timeout, or with the error thrown by `predicate`.
 */
function waitUntil(predicate, timeout = 10000, interval = 10) {
  const startTime = Date.now();
  return new Promise((resolve, reject) => {
    const checkCondition = () => {
      let isMet;
      try {
        isMet = predicate();
      } catch (error) {
        // Inside a setTimeout callback a throw would be uncaught and leave this promise pending
        // forever (polling stops); surface it to the awaiting caller instead.
        reject(error);
        return;
      }
      if (isMet) {
        resolve();
      } else if (Date.now() - startTime > timeout) {
        reject(new Error('Timed out waiting for condition'));
      } else {
        setTimeout(checkCondition, interval);
      }
    };
    checkCondition();
  });
}

export default AIServerCloudZoo;