import Replicate, { Prediction } from "replicate";
import {
  ReplicateIncrediblyFastWhisperInput,
  ReplicateWhisperXPredictionOutput,
  ReplicateWhisperXInput,
  Transcription,
  ReplicateIncrediblyFastWhisperPredictionOuput,
  ReplicateWhisperXSpanishInput,
  ReplicateWhisperXSpanishPredictionOutput,
} from "@/api-lib";
import { LanguageCode } from "@/constants";

/**
 * Credentials read by the `ReplicateClient` constructor.
 */
type Env = {
  /** Hugging Face token, forwarded verbatim as the WhisperX `huggingface_access_token` input. */
  HUGGINGFACE_ACCESS_TOKEN: string;
  /** Replicate API token, handed to the Replicate SDK as `auth`. */
  REPLICATE_API_KEY: string;
};

/**
 * Submits speech-to-text predictions to Replicate and interprets the
 * `Prediction` objects they produce.
 *
 * The `run*` methods only create a prediction (registered with a webhook) and
 * resolve once `predictions.create` has returned; they do not return the
 * prediction or its transcript. The static helpers dispatch on
 * `prediction.version` to read a prediction's input audio URL and its output
 * segments.
 */
export class ReplicateClient {
  private readonly client: Replicate;
  private readonly huggingfaceAccessToken: string;
  /** URL registered as the webhook (for the `"completed"` event) of every prediction created here. */
  private readonly webhookUrl: string =
    "https://1transcribe.com/api/webhook/replicate";

  // Pinned Replicate model version hashes. They are used both to create
  // predictions and, in the static helpers, to recognise which model produced
  // a prediction -- changing a hash means predictions created under the old
  // hash are no longer recognised by getAudioFileUrl/getOutputSegments.
  static readonly whisperXModelVersion =
    "1395a1d7aa48a01094887250475f384d4bae08fd0616f9c405bb81d4174597ea";
  static readonly incrediblyFastWhisperModelVersion =
    "3ab86df6c8f54c11309d4d1f930ac292bad43ace52d10c80d87eb258b3c9f79c";
  static readonly whisperXSpanishModelVersion =
    "f96c64c0ec3ea5e696c6afd9d30df640b548f44744091b1c2f44c6af01b5838f";

  /**
   * Language names sent as the `language` input of incredibly-fast-whisper,
   * keyed by the app's `LanguageCode`. Built once instead of on every lookup.
   */
  private static readonly incrediblyFastWhisperLanguages: Readonly<
    Record<LanguageCode, string>
  > = {
    af: "afrikaans",
    sq: "albanian",
    am: "amharic",
    ar: "arabic",
    hy: "armenian",
    as: "assamese",
    az: "azerbaijani",
    ba: "bashkir",
    eu: "basque",
    be: "belarusian",
    bn: "bengali",
    bs: "bosnian",
    br: "breton",
    bg: "bulgarian",
    yue: "cantonese",
    ca: "catalan",
    zh: "chinese",
    hr: "croatian",
    cs: "czech",
    da: "danish",
    nl: "dutch",
    en: "english",
    et: "estonian",
    fo: "faroese",
    fi: "finnish",
    fr: "french",
    gl: "galician",
    ka: "georgian",
    de: "german",
    el: "greek",
    gu: "gujarati",
    ht: "haitian creole",
    ha: "hausa",
    haw: "hawaiian",
    he: "hebrew",
    hi: "hindi",
    hu: "hungarian",
    is: "icelandic",
    id: "indonesian",
    it: "italian",
    ja: "japanese",
    jw: "javanese",
    kn: "kannada",
    kk: "kazakh",
    km: "khmer",
    ko: "korean",
    lo: "lao",
    la: "latin",
    lv: "latvian",
    ln: "lingala",
    lt: "lithuanian",
    lb: "luxembourgish",
    mk: "macedonian",
    mg: "malagasy",
    ms: "malay",
    ml: "malayalam",
    mt: "maltese",
    mi: "maori",
    mr: "marathi",
    mn: "mongolian",
    my: "myanmar",
    ne: "nepali",
    no: "norwegian",
    nn: "nynorsk",
    oc: "occitan",
    ps: "pashto",
    fa: "persian",
    pl: "polish",
    pt: "portuguese",
    pa: "punjabi",
    ro: "romanian",
    ru: "russian",
    sa: "sanskrit",
    sr: "serbian",
    sn: "shona",
    sd: "sindhi",
    si: "sinhala",
    sk: "slovak",
    sl: "slovenian",
    so: "somali",
    es: "spanish",
    su: "sundanese",
    sw: "swahili",
    sv: "swedish",
    tl: "tagalog",
    tg: "tajik",
    ta: "tamil",
    tt: "tatar",
    te: "telugu",
    th: "thai",
    bo: "tibetan",
    tr: "turkish",
    tk: "turkmen",
    uk: "ukrainian",
    ur: "urdu",
    uz: "uzbek",
    vi: "vietnamese",
    cy: "welsh",
    yi: "yiddish",
    yo: "yoruba",
  };

  /**
   * @param env Supplies the Replicate API token (passed to the SDK as `auth`)
   *   and the Hugging Face token forwarded to WhisperX.
   */
  constructor(env: Env) {
    this.client = new Replicate({ auth: env.REPLICATE_API_KEY });
    this.huggingfaceAccessToken = env.HUGGINGFACE_ACCESS_TOKEN;
  }

  /**
   * Translates an app language code into the language name sent to
   * incredibly-fast-whisper. Codes missing from the table yield `undefined`
   * at runtime.
   */
  private toIncrediblyFastWhisperLanguageCode(languageCode: string) {
    // A string-keyed view of the same object lets an arbitrary `string` index
    // it without an implicit-`any` error; no assertion is needed because the
    // mapped type is assignable to a string index signature.
    const languageNames: Readonly<Record<string, string>> =
      ReplicateClient.incrediblyFastWhisperLanguages;
    return languageNames[languageCode];
  }

  /**
   * Returns the audio URL a prediction was created with, reading the input
   * field its model uses (`audio_file` for WhisperX, `audio` for the other two).
   * Returns `""` when the version is not one of the three pinned models.
   */
  static getAudioFileUrl = (prediction: Prediction): string => {
    switch (prediction.version) {
      case ReplicateClient.whisperXModelVersion:
        return (prediction.input as ReplicateWhisperXInput).audio_file;

      case ReplicateClient.incrediblyFastWhisperModelVersion:
        return (prediction.input as ReplicateIncrediblyFastWhisperInput).audio;

      case ReplicateClient.whisperXSpanishModelVersion:
        return (prediction.input as ReplicateWhisperXSpanishInput).audio;

      default:
        return "";
    }
  };

  /**
   * Normalises a prediction's output into WhisperX-style segments:
   * - WhisperX: `output.segments` as-is;
   * - WhisperX Spanish: the output itself is the segment list;
   * - incredibly-fast-whisper: each chunk becomes `{ start, end, text }` from
   *   `timestamp[0]`, `timestamp[1]` and `text`.
   * Returns `[]` when the version is not one of the three pinned models.
   *
   * NOTE(review): `prediction.output` is not null-checked. Predictions that did
   * not succeed presumably carry no output, so callers should check
   * `prediction.status` first -- confirm in the webhook handler.
   */
  static getOutputSegments = (
    prediction: Prediction
  ): ReplicateWhisperXPredictionOutput["segments"] => {
    switch (prediction.version) {
      case ReplicateClient.whisperXModelVersion:
        return (prediction.output as ReplicateWhisperXPredictionOutput)
          .segments;

      case ReplicateClient.whisperXSpanishModelVersion:
        return prediction.output as ReplicateWhisperXSpanishPredictionOutput;

      case ReplicateClient.incrediblyFastWhisperModelVersion: {
        // Braces keep `chunks` scoped to this clause (no-case-declarations).
        const { chunks } =
          prediction.output as ReplicateIncrediblyFastWhisperPredictionOuput;

        return chunks.map((chunk) => ({
          start: chunk.timestamp[0],
          end: chunk.timestamp[1],
          text: chunk.text,
        }));
      }

      default:
        return [];
    }
  };

  /**
   * Creates a prediction for the given model version and registers this
   * client's webhook for the `"completed"` event. Resolves once Replicate has
   * accepted the request.
   */
  private async createPrediction(
    version: string,
    input: Parameters<Replicate["predictions"]["create"]>[0]["input"]
  ): Promise<void> {
    await this.client.predictions.create({
      version,
      input,
      webhook: this.webhookUrl,
      webhook_events_filter: ["completed"],
    });
  }

  /**
   * Starts a WhisperX transcription of `params.fileUrl`.
   *
   * Spanish (`"es"`) requests without speaker labels are delegated to
   * `runWhisperXSpanish`. When `numberOfSpeakers` is truthy it is sent as
   * `max_speakers`.
   */
  async runWhisperX(
    params: Pick<
      Transcription,
      | "fileUrl"
      | "id"
      | "isRecording"
      | "withSpeakerLabels"
      | "languageCode"
      | "numberOfSpeakers"
    >
  ): Promise<void> {
    if (params.languageCode === "es" && !params.withSpeakerLabels) {
      return await this.runWhisperXSpanish(params);
    }

    const input: ReplicateWhisperXInput = {
      audio_file: params.fileUrl!,
      language: params.languageCode!,
      language_detection_min_prob: 0,
      language_detection_max_tries: 5,
      batch_size: 64,
      temperature: 0,
      vad_onset: 0.5,
      vad_offset: 0.363,
      align_output: false,
      diarization: params.withSpeakerLabels!,
      huggingface_access_token: this.huggingfaceAccessToken,
      // NOTE(review): debug output is enabled for every job -- confirm this is intended.
      debug: true,
    };

    if (params.numberOfSpeakers) {
      input.max_speakers = params.numberOfSpeakers;
    }

    await this.createPrediction(ReplicateClient.whisperXModelVersion, input);
  }

  /**
   * Starts an incredibly-fast-whisper transcription of `params.fileUrl`, with
   * chunk-level timestamps and diarisation disabled.
   */
  async runIncrediblyFastWhisper(
    params: Pick<Transcription, "fileUrl" | "languageCode">
  ): Promise<void> {
    const input: ReplicateIncrediblyFastWhisperInput = {
      task: "transcribe",
      audio: params.fileUrl!,
      language: this.toIncrediblyFastWhisperLanguageCode(params.languageCode!),
      timestamp: "chunk",
      batch_size: 64,
      diarise_audio: false,
    };

    await this.createPrediction(
      ReplicateClient.incrediblyFastWhisperModelVersion,
      input
    );
  }

  /** Starts a Spanish WhisperX transcription of `params.fileUrl` without diarization. */
  async runWhisperXSpanish(
    params: Pick<Transcription, "fileUrl">
  ): Promise<void> {
    const input: ReplicateWhisperXSpanishInput = {
      audio: params.fileUrl!,
      // NOTE(review): debug output is enabled for every job -- confirm this is intended.
      debug: true,
      batch_size: 32,
      diarization: false,
    };

    await this.createPrediction(
      ReplicateClient.whisperXSpanishModelVersion,
      input
    );
  }
}
