// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
import { ModelExecutionError, ModelLoader as ModelLoaderBase, ModelLoadError, } from '../../core/on_device_model/types.js';
import { signal } from '../../core/reactive/signal.js';
import { assertExhaustive, assertExists, assertNotReached, } from '../../core/utils/assert.js';
import { chunkContentByWord, } from '../../core/utils/utils.js';
import { isCannedResponse, isInvalidFormatResponse, parseResponse, trimRepeatedBulletPoints, } from './on_device_model_utils.js';
import { FormatFeature, LoadModelResult as MojoLoadModelResult, ModelStateMonitorReceiver, ModelStateType, OnDeviceModelRemote, SafetyFeature, SessionRemote, StreamingResponderCallbackRouter, } from './types.js';
// The minimum token length for title generation and summarization. Inputs
// that `executeRaw` measures below this size are rejected with
// `UNSUPPORTED_TRANSCRIPTION_IS_TOO_SHORT`.
const MIN_TOKEN_LENGTH = 200;
// The maximum content input length, in words, for the T&S model. It is used by
// `contentIsUnsafe` as the chunk size when splitting content.
// Based on tests, 1k input tokens is the most efficient.
// According to the Gemini tokenizer documentation
// (https://ai.google.dev/gemini-api/docs/tokens), 100 tokens are roughly
// equivalent to 60-80 English words, so 1k tokens is roughly 600-800 words;
// 700 was chosen as the average of that range.
const MAX_TS_MODEL_INPUT_WORD_LENGTH = 700;
// The config for judging repetition in the model response.
// The model response is split into bullet points, and each bullet point is
// checked by string length and by word-level LCS (longest common subsequence)
// score.
// If a threshold is exceeded, the response is shown as invalid or the
// repetition is trimmed.
// The thresholds were chosen based on real model responses.
const MODEL_REPETITION_CONFIG = {
    // Passed to `trimRepeatedBulletPoints` as the length limit.
    maxLength: 1000,
    // Passed to `trimRepeatedBulletPoints` as the LCS score threshold.
    lcsScoreThreshold: 0.9,
};
/**
 * Base class wrapping a loaded on-device model remote.
 *
 * Subclasses are expected to provide
 * `executeInRemoteSession(content, language, session)`, which `execute()`
 * calls with a freshly started session.
 */
class OnDeviceModel {
    /**
     * @param remote Remote of the loaded model. Used to start sessions and to
     *     classify text safety.
     * @param pageRemote Remote used to format model input and to validate
     *     safety results.
     * @param modelInfo Model metadata. `modelId` and `inputTokenLimit` are
     *     read by this class.
     */
    constructor(remote, pageRemote, modelInfo) {
        this.remote = remote;
        this.pageRemote = pageRemote;
        this.modelInfo = modelInfo;
        // TODO(pihsun): Handle disconnection error
    }
    /**
     * Starts a new session on the model, runs the subclass-specific execution
     * in it, and returns that execution's result.
     *
     * The session is closed whether the execution resolves or rejects, so a
     * failing execution does not leak the session pipe.
     *
     * @param content The content to run the model on.
     * @param language The language passed through to the execution.
     * @return Whatever `executeInRemoteSession` resolves to.
     */
    async execute(content, language) {
        const session = new SessionRemote();
        this.remote.startSession(session.$.bindNewPipeAndPassReceiver(), null);
        try {
            return await this.executeInRemoteSession(content, language, session);
        }
        finally {
            session.$.close();
        }
    }
    /**
     * Get input token size through the model.
     * Share the session from params without creating new session.
     *
     * @param text The text to measure.
     * @param session The session used to query the token size.
     * @return The `size` reported by `session.getSizeInTokens`.
     */
    async getInputTokenSize(text, session) {
        const inputPieces = { pieces: [{ text }] };
        const { size } = await session.getSizeInTokens(inputPieces);
        return size;
    }
    /**
     * Conduct the model execute.
     * Check input token size first and then execute.
     * Share the session from params without creating new session.
     *
     * @param text The already formatted prompt text.
     * @param session The session to append the input to and generate from.
     * @param expectedBulletPointCount Passed to `isInvalidFormatResponse` to
     *     validate the parsed response.
     * @param language Passed to `trimRepeatedBulletPoints`.
     * @return `{kind: 'error', error}` when the input is too short or too
     *     long, or when the response is canned, badly formatted or has no
     *     valid bullet point; otherwise `{kind: 'success', result}` where
     *     `result` is the remaining bullet points joined by newlines.
     */
    async executeRaw(text, session, expectedBulletPointCount, language) {
        const inputPieces = { pieces: [{ text }] };
        const size = await this.getInputTokenSize(text, session);
        if (size < MIN_TOKEN_LENGTH) {
            console.warn(`Skip GenAI model execution: too small token size: ${size}.`);
            return {
                kind: 'error',
                error: ModelExecutionError.UNSUPPORTED_TRANSCRIPTION_IS_TOO_SHORT,
            };
        }
        if (size > this.modelInfo.inputTokenLimit) {
            console.warn(`Skip GenAI model execution: too large token size: ${size}.`);
            return {
                kind: 'error',
                error: ModelExecutionError.UNSUPPORTED_TRANSCRIPTION_IS_TOO_LONG,
            };
        }
        const responseRouter = new StreamingResponderCallbackRouter();
        // TODO(pihsun): Error handling.
        const { promise, resolve } = Promise.withResolvers();
        // Streamed chunks are collected here and joined once generation
        // completes.
        const response = [];
        const onResponseId = responseRouter.onResponse.addListener((chunk) => {
            response.push(chunk.text);
        });
        const onCompleteId = responseRouter.onComplete.addListener((_) => {
            responseRouter.removeListener(onResponseId);
            responseRouter.removeListener(onCompleteId);
            responseRouter.$.close();
            resolve(response.join('').trimStart());
        });
        // NOTE(review): `maxTokens: 0` / `maxOutputTokens: 0` presumably mean
        // "no explicit limit" - confirm against the mojo definition.
        session.append({
            maxTokens: 0,
            input: inputPieces,
        }, null);
        session.generate({
            maxOutputTokens: 0,
            constraint: null,
        }, responseRouter.$.bindNewPipeAndPassRemote());
        const result = await promise;
        // When the model returns the canned response, show the same UI as
        // unsafe content for now.
        if (isCannedResponse(result)) {
            console.warn('Invalid GenAI result: canned response.');
            return { kind: 'error', error: ModelExecutionError.UNSAFE };
        }
        const parsedResult = parseResponse(result);
        // TODO(yuanchieh): retry inference with higher temperature.
        if (isInvalidFormatResponse(parsedResult, expectedBulletPointCount)) {
            console.warn('Invalid GenAI result: invalid format.');
            return { kind: 'error', error: ModelExecutionError.UNSAFE };
        }
        const finalBulletPoints = trimRepeatedBulletPoints(parsedResult, MODEL_REPETITION_CONFIG.maxLength, language, MODEL_REPETITION_CONFIG.lcsScoreThreshold);
        // Show unsafe content if no valid bullet point.
        if (finalBulletPoints.length === 0) {
            console.warn('Invalid GenAI result: no valid bullet point.');
            return { kind: 'error', error: ModelExecutionError.UNSAFE };
        }
        // To align with model response type, concatenated bullet points back to one
        // string.
        const finalResult = finalBulletPoints.join('\n');
        return { kind: 'success', result: finalResult };
    }
    /**
     * Checks whether any chunk of the content is judged unsafe.
     *
     * Each chunk is classified through the model remote and then validated
     * through the page remote. Chunks for which classification returns a null
     * `safetyInfo` are skipped.
     *
     * @param content The text to check.
     * @param safetyFeature The safety feature passed to
     *     `validateSafetyResult`.
     * @param language Passed to `chunkContentByWord`.
     * @return True as soon as one chunk is reported as not safe, false if no
     *     chunk is.
     */
    async contentIsUnsafe(content, safetyFeature, language) {
        // Split the content into chunks due to model performance considerations.
        const contentChunks = chunkContentByWord(content, MAX_TS_MODEL_INPUT_WORD_LENGTH, language);
        for (const chunk of contentChunks) {
            const { safetyInfo } = await this.remote.classifyTextSafety(chunk);
            if (safetyInfo === null) {
                continue;
            }
            const { isSafe } = await this.pageRemote.validateSafetyResult(safetyFeature, chunk, safetyInfo);
            if (!isSafe) {
                return true;
            }
        }
        return false;
    }
    /**
     * Closes the model remote.
     */
    close() {
        this.remote.$.close();
    }
    /**
     * Formats the fields into a model prompt through the page remote.
     *
     * @param feature The format feature to use.
     * @param fields The key/value fields to fill into the prompt.
     * @return The `result` of `formatModelInput`; callers treat null as a
     *     formatting failure.
     */
    async formatInput(feature, fields) {
        const { result } = await this.pageRemote.formatModelInput(this.modelInfo.modelId, feature, fields);
        return result;
    }
    /**
     * Formats the prompt with the specified `formatFeature`, runs the prompt
     * through the model, and returns the result.
     *
     * The prompt is safety-checked with `requestSafetyFeature` before
     * execution and the response is safety-checked with
     * `responseSafetyFeature` afterwards.
     *
     * The key of the fields of each different model / formatFeature
     * combination can be found in
     * //google3/chromeos/odml_foundations/lib/inference/features/models/.
     *
     * @return `{kind: 'error', error}` when formatting fails (GENERAL), when
     *     the prompt or response is unsafe (UNSAFE), or when `executeRaw`
     *     fails; otherwise `{kind: 'success', result}`.
     */
    async formatAndExecute(formatFeature, requestSafetyFeature, responseSafetyFeature, fields, session, language, expectedBulletPointCount) {
        const prompt = await this.formatInput(formatFeature, fields);
        if (prompt === null) {
            console.error('formatInput returns null, wrong model?');
            return { kind: 'error', error: ModelExecutionError.GENERAL };
        }
        if (await this.contentIsUnsafe(prompt, requestSafetyFeature, language)) {
            console.warn('Unsafe GenAI prompt.');
            return { kind: 'error', error: ModelExecutionError.UNSAFE };
        }
        const response = await this.executeRaw(prompt, session, expectedBulletPointCount, language);
        if (response.kind === 'error') {
            return response;
        }
        if (await this.contentIsUnsafe(response.result, responseSafetyFeature, language)) {
            console.warn('Unsafe GenAI result.');
            return { kind: 'error', error: ModelExecutionError.UNSAFE };
        }
        return { kind: 'success', result: response.result };
    }
}
/**
 * Model that turns a transcription into a bullet-point summary.
 */
export class SummaryModel extends OnDeviceModel {
    /**
     * Runs summarization of `content` in the given session.
     *
     * The number of requested bullet points is derived from the token size of
     * the raw content.
     *
     * @return The error from `formatAndExecute`, or
     *     `{kind: 'success', result}` with the summary text.
     */
    async executeInRemoteSession(content, language, session) {
        const tokenCount = await this.getInputTokenSize(content, session);
        const bulletCount = this.getExpectedBulletPoints(tokenCount);
        // The v2 safety feature is used for the large model. Only the
        // response side differs.
        let responseSafetyFeature;
        if (this.modelInfo.isLargeModel) {
            responseSafetyFeature = SafetyFeature.kAudioSummaryResponseV2;
        }
        else {
            responseSafetyFeature = SafetyFeature.kAudioSummaryResponse;
        }
        const promptFields = {
            transcription: content,
            language,
            /**
             * The key format is requested by the model.
             * See
             * http://google3/chromeos/odml_foundations/lib/inference/features/models/audio_summary_v2.cc.
             */
            /* eslint-disable-next-line @typescript-eslint/naming-convention */
            bullet_points_request: this.formatBulletPointRequest(bulletCount),
        };
        const outcome = await this.formatAndExecute(FormatFeature.kAudioSummary, SafetyFeature.kAudioSummaryRequest, responseSafetyFeature, promptFields, session, language, bulletCount);
        // TODO(pihsun): `Result` monadic helper class?
        return outcome.kind === 'error' ?
            outcome :
            { kind: 'success', result: outcome.result };
    }
    /**
     * Maps the input token size to the number of bullet points to request.
     */
    getExpectedBulletPoints(inputTokenSize) {
        // The Xss (non-large) model always gets a fixed 3 bullet points.
        if (!this.modelInfo.isLargeModel) {
            return 3;
        }
        // Each entry is [exclusive upper bound of token size, bullet points].
        const tiers = [
            [250, 1],
            [600, 2],
            [4000, 3],
            [6600, 4],
            [9300, 5],
        ];
        for (const [upperBound, bulletPoints] of tiers) {
            if (inputTokenSize < upperBound) {
                return bulletPoints;
            }
        }
        return 6;
    }
    /**
     * Renders the bullet point count in the wording the model prompt expects,
     * e.g. "1 bullet point" or "3 bullet points".
     */
    formatBulletPointRequest(request) {
        if (request <= 0) {
            assertNotReached('Got non-positive bullet point request.');
        }
        const pluralSuffix = request > 1 ? 's' : '';
        return `${request} bullet point${pluralSuffix}`;
    }
}
/**
 * Model that suggests titles for a transcription.
 */
export class TitleSuggestionModel extends OnDeviceModel {
    /**
     * Runs title suggestion on `content` in the given session. The model
     * input only needs the transcription and its language.
     *
     * @return The error from `formatAndExecute`, or
     *     `{kind: 'success', result}` with at most three titles.
     */
    async executeInRemoteSession(content, language, session) {
        // The v2 safety feature is used for the large model. Only the
        // response side differs.
        let responseSafetyFeature;
        if (this.modelInfo.isLargeModel) {
            responseSafetyFeature = SafetyFeature.kAudioTitleResponseV2;
        }
        else {
            responseSafetyFeature = SafetyFeature.kAudioTitleResponse;
        }
        const promptFields = {
            transcription: content,
            language,
        };
        const outcome = await this.formatAndExecute(FormatFeature.kAudioTitle, SafetyFeature.kAudioTitleRequest, responseSafetyFeature, promptFields, session, language, 3);
        if (outcome.kind === 'error') {
            return outcome;
        }
        // Each title line should start with `- `; other lines are dropped.
        const bulletPrefix = '- ';
        const titles = outcome.result
            .split('\n')
            .filter((line) => line.startsWith(bulletPrefix))
            .map((line) => line.substring(bulletPrefix.length));
        return { kind: 'success', result: titles.slice(0, 3) };
    }
}
/**
 * Translates a mojo model state into the app's `ModelState` shape.
 *
 * The checks run in a fixed order and the first match wins, so a real enum
 * value is matched before `MIN_VALUE` / `MAX_VALUE`.
 */
export function mojoModelStateToModelState(state) {
    const { type } = state;
    if (type === ModelStateType.kNotInstalled) {
        return { kind: 'notInstalled' };
    }
    if (type === ModelStateType.kInstalling) {
        // Progress is required while installing.
        return { kind: 'installing', progress: assertExists(state.progress) };
    }
    if (type === ModelStateType.kInstalled) {
        return { kind: 'installed' };
    }
    if (type === ModelStateType.kNeedsReboot) {
        return { kind: 'needsReboot' };
    }
    if (type === ModelStateType.kError) {
        return { kind: 'error' };
    }
    if (type === ModelStateType.kUnavailable) {
        return { kind: 'unavailable' };
    }
    if (type === ModelStateType.MIN_VALUE || type === ModelStateType.MAX_VALUE) {
        return assertNotReached(`Got MIN_VALUE or MAX_VALUE from mojo ModelStateType: ${type}`);
    }
    // Any other value is unexpected.
    assertExhaustive(type);
}
/**
 * Translates a failed mojo load result into a `ModelLoadError`.
 *
 * Must not be called with the success result. The checks run in a fixed
 * order and the first match wins.
 */
function mojoLoadModelResultToModelLoadError(result) {
    const isLoadFailure = result === MojoLoadModelResult.kFailedToLoadLibrary ||
        result === MojoLoadModelResult.kGpuBlocked;
    if (isLoadFailure) {
        return ModelLoadError.LOAD_FAILURE;
    }
    if (result === MojoLoadModelResult.kCrosNeedReboot) {
        return ModelLoadError.NEEDS_REBOOT;
    }
    if (result === MojoLoadModelResult.kSuccess) {
        return assertNotReached(`Try transforming success load result to error`);
    }
    if (result === MojoLoadModelResult.MIN_VALUE ||
        result === MojoLoadModelResult.MAX_VALUE) {
        return assertNotReached(`Got MIN_VALUE or MAX_VALUE from mojo LoadModelResult: ${result}`);
    }
    // Any other value is unexpected.
    assertExhaustive(result);
}
/**
 * Shared loader logic for on-device models.
 *
 * Subclasses are expected to set `featureType` and to implement
 * `createModel(remote)`.
 */
class ModelLoader extends ModelLoaderBase {
    /**
     * @param remote Remote used to query model info, monitor model state and
     *     load the model.
     * @param platformHandler Used to look up language pack info.
     */
    constructor(remote, platformHandler) {
        super();
        this.remote = remote;
        this.platformHandler = platformHandler;
        // Starts as unavailable until `init()` receives the actual state.
        this.state = signal({ kind: 'unavailable' });
        // Filled in by `init()`.
        this.modelInfoInternal = null;
    }
    /**
     * The model info fetched by `init()`; asserts that it has been set.
     */
    get modelInfo() {
        return assertExists(this.modelInfoInternal);
    }
    /**
     * Fetches the model info for `featureType` and starts monitoring the
     * model state, mirroring every update into `this.state`.
     */
    async init() {
        const { modelInfo } = await this.remote.getModelInfo(this.featureType);
        this.modelInfoInternal = modelInfo;
        const applyState = (mojoState) => {
            this.state.value = mojoModelStateToModelState(mojoState);
        };
        const monitor = new ModelStateMonitorReceiver({ update: applyState });
        const { modelId } = this.modelInfo;
        // This should be relatively quick since in recorder_app_ui.cc we just
        // return the cached state here, but we await here to avoid UI showing
        // temporary unavailable state.
        const { state: initialState } = await this.remote.addModelMonitor(modelId, monitor.$.bindNewPipeAndPassRemote());
        applyState(initialState);
    }
    /**
     * Loads the model.
     *
     * @return `{kind: 'success', model}` with the model built by
     *     `createModel`, or `{kind: 'error', error}` when loading fails.
     */
    async load() {
        const modelRemote = new OnDeviceModelRemote();
        const { modelId } = this.modelInfo;
        const { result } = await this.remote.loadModel(modelId, modelRemote.$.bindNewPipeAndPassReceiver());
        if (result === MojoLoadModelResult.kSuccess) {
            return { kind: 'success', model: this.createModel(modelRemote) };
        }
        console.error('Load model failed:', result);
        return {
            kind: 'error',
            error: mojoLoadModelResultToModelLoadError(result),
        };
    }
    /**
     * Loads the model, executes it on `content`, and closes the model
     * afterwards whether or not the execution succeeds.
     *
     * @return An UNSUPPORTED_LANGUAGE error when GenAI is not supported for
     *     `language`, the load error when loading fails, or the execution
     *     result otherwise.
     */
    async loadAndExecute(content, language) {
        const { isGenAiSupported } = this.platformHandler.getLangPackInfo(language);
        if (!isGenAiSupported) {
            console.warn(`Skip GenAI model execution: unsupported language: ${language}.`);
            return { kind: 'error', error: ModelExecutionError.UNSUPPORTED_LANGUAGE };
        }
        const loaded = await this.load();
        if (loaded.kind === 'error') {
            return loaded;
        }
        const { model } = loaded;
        try {
            return await model.execute(content, language);
        }
        finally {
            model.close();
        }
    }
}
/**
 * Loader for the summary model.
 */
export class SummaryModelLoader extends ModelLoader {
    constructor(...args) {
        super(...args);
        // Feature used by the base loader when fetching the model info.
        this.featureType = FormatFeature.kAudioSummary;
    }
    /**
     * Wraps the loaded model remote in a `SummaryModel`.
     */
    createModel(remote) {
        const pageRemote = this.remote;
        return new SummaryModel(remote, pageRemote, this.modelInfo);
    }
}
/**
 * Loader for the title suggestion model.
 */
export class TitleSuggestionModelLoader extends ModelLoader {
    constructor(...args) {
        super(...args);
        // Feature used by the base loader when fetching the model info.
        this.featureType = FormatFeature.kAudioTitle;
    }
    /**
     * Wraps the loaded model remote in a `TitleSuggestionModel`.
     */
    createModel(remote) {
        const pageRemote = this.remote;
        return new TitleSuggestionModel(remote, pageRemote, this.modelInfo);
    }
}
