mod anthropic_client;
mod distill;
mod example;
mod format_prompt;
mod git;
mod headless;
mod load_project;
mod metrics;
mod paths;
mod predict;
mod progress;
mod pull_examples;
mod reorder_patch;
mod retrieve_context;
mod score;
mod split_commit;
mod synthesize;
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use edit_prediction::EditPredictionStore;
use gpui::Application;
use reqwest_client::ReqwestClient;
use serde::{Deserialize, Serialize};
use std::fmt::Display;
use std::{path::PathBuf, sync::Arc};

use crate::distill::run_distill;
use crate::example::{Example, group_examples_by_repo, read_example_files, write_examples};
use crate::format_prompt::run_format_prompt;
use crate::load_project::run_load_project;
use crate::paths::FAILED_EXAMPLES_DIR;
use crate::predict::run_prediction;
use crate::progress::Progress;
use crate::retrieve_context::run_context_retrieval;
use crate::score::run_scoring;
use crate::split_commit::SplitCommitArgs;
use crate::synthesize::{SynthesizeConfig, run_synthesize};

// Top-level command-line arguments for the `ep` binary.
//
// NOTE: plain `//` comments are used on purpose for the undocumented items
// here: clap's derive turns `///` doc comments into user-visible help text,
// so adding them would change the `--help` output.
#[derive(Parser, Debug)]
#[command(name = "ep")]
struct EpArgs {
    // When set, `main` prints the shell environment and returns immediately.
    #[arg(long, default_value_t = false)]
    printenv: bool,
    // Chunk size used when splitting the per-repository example groups into
    // batches; the groups of one batch are awaited together.
    #[clap(long, default_value_t = 10, global = true)]
    max_parallelism: usize,
    // Upper bound on the number of examples kept after loading and filtering.
    #[clap(long, global = true)]
    limit: Option<usize>,
    /// Filter examples by name
    #[clap(long, global = true)]
    name: Option<String>,
    /// Filter examples by repository
    #[clap(long, global = true)]
    repo: Option<String>,
    // Subcommand to run; when absent, `main` prints the help text and returns.
    #[command(subcommand)]
    command: Option<Command>,
    // Positional inputs: example file paths or `captured-after:` specifiers
    // (see `INPUTS_HELP`).
    #[clap(global = true, help = INPUTS_HELP)]
    inputs: Vec<PathBuf>,
    // Explicit output path; ignored by `output_path` when `in_place` is set.
    #[arg(long, short, global = true)]
    output: Option<PathBuf>,
    // Use the single input path as the output path (see `output_path`).
    #[arg(long, short, global = true)]
    in_place: bool,
    // Make `handle_error` panic on the first failing example instead of
    // logging the failure and continuing.
    #[arg(long, short, global = true)]
    failfast: bool,
}

/// Help text attached to the positional `inputs` argument of [`EpArgs`].
///
/// This string is shown to the user verbatim, so any edit to it changes the
/// CLI's `--help` output.
const INPUTS_HELP: &str = r#"
Inputs can be file paths or special specifiers:

  path
      Path to an example(s) file (.md, .json, or .jsonl)

  captured-after:{timestamp}
      Fetch captured examples from Snowflake after the given RFC3339 timestamp.

      You can specify this multiple times and mix it with file inputs.

      Required environment variables to connect to Snowflake:
          EP_SNOWFLAKE_API_KEY
          EP_SNOWFLAKE_BASE_URL

      Optional:
          EP_SNOWFLAKE_ROLE

Examples:

  # Predict from a file
  ep predict examples.jsonl

  # Predict from captured examples after a timestamp
  ep predict captured-after:2025-01-01T00:00:00Z

  # Mix file inputs and captured-after in the same invocation
  ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
"#;

// Subcommands of the `ep` binary.
//
// The `///` comments on the variants are picked up by clap as the help text
// of each subcommand, so they are user-facing strings rather than plain
// documentation.
//
// `Clean`, `Synthesize` and `SplitCommit` are handled directly in `main`
// before any examples are loaded; the remaining variants are run once per
// loaded example.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    /// Parse markdown examples and output a combined .jsonl file
    ParseExample,
    /// Create git worktrees for each example and load file contents
    LoadProject,
    /// Retrieve context for input examples.
    Context,
    /// Generate a prompt string for a specific model
    FormatPrompt(FormatPromptArgs),
    /// Runs edit prediction
    Predict(PredictArgs),
    /// Computes a score based on actual and expected patches
    Score(PredictArgs),
    /// Prepares a distillation dataset by copying expected outputs to
    /// predicted outputs and removing actual outputs and prompts.
    Distill,
    /// Print aggregated scores
    Eval(PredictArgs),
    /// Generate eval examples by analyzing git commits from a repository
    Synthesize(SynthesizeArgs),
    /// Remove git repositories and worktrees
    Clean,
    /// Generate an evaluation example by splitting a chronologically-ordered commit
    SplitCommit(SplitCommitArgs),
}

/// Renders a command roughly the way it would be typed on the command line.
///
/// `handle_error` embeds this rendering in its "Re-run this example with"
/// hint. Only the subcommand name and its primary option are printed; other
/// options (for example `repetitions`) are not part of the output.
impl Display for Command {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Commands that do not take a provider are written out directly; the
        // three provider-based commands fall through to the shared tail below.
        let (subcommand, predict_args) = match self {
            Command::ParseExample => return f.write_str("parse-example"),
            Command::LoadProject => return f.write_str("load-project"),
            Command::Context => return f.write_str("context"),
            Command::Distill => return f.write_str("distill"),
            Command::Clean => return f.write_str("clean"),
            Command::SplitCommit(_) => return f.write_str("split-commit"),
            Command::Synthesize(synthesize_args) => {
                return write!(f, "synthesize --repo={}", synthesize_args.repo);
            }
            Command::FormatPrompt(prompt_args) => {
                let format_value = prompt_args.prompt_format.to_possible_value().unwrap();
                return write!(
                    f,
                    "format-prompt --prompt-format={}",
                    format_value.get_name()
                );
            }
            Command::Predict(predict_args) => ("predict", predict_args),
            Command::Score(predict_args) => ("score", predict_args),
            Command::Eval(predict_args) => ("eval", predict_args),
        };

        let provider_value = predict_args.provider.to_possible_value().unwrap();
        // The provider name is formatted with `{:?}`, so it appears wrapped in
        // double quotes in the output.
        write!(
            f,
            "{} --provider={:?}",
            subcommand,
            provider_value.get_name()
        )
    }
}

// Arguments of the `FormatPrompt` subcommand.
// (`//` rather than `///` so clap's generated help text is unchanged.)
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
    // Prompt format forwarded to `run_format_prompt`.
    #[clap(long)]
    prompt_format: PromptFormat,
}

// Values accepted by the `prompt_format` option of `FormatPromptArgs`.
// (`//` rather than `///` so clap's generated help text is unchanged.)
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PromptFormat {
    Teacher,
    Zeta2,
}

// Arguments shared by the `Predict`, `Score` and `Eval` subcommands.
// (`//` rather than `///` so clap's generated help text is unchanged.)
#[derive(Debug, Args, Clone)]
struct PredictArgs {
    // Prediction provider to use.
    #[clap(long)]
    provider: PredictionProvider,
    // Forwarded to `run_prediction` by the `Predict` subcommand; defaults to 1.
    #[clap(long, default_value_t = 1)]
    repetitions: usize,
}

// Values accepted by the `provider` option of `PredictArgs`.
// (`//` rather than `///` so clap's generated help text is unchanged.)
#[derive(Clone, Copy, Debug, PartialEq, ValueEnum, Serialize, Deserialize)]
enum PredictionProvider {
    Sweep,
    Mercury,
    Zeta1,
    Zeta2,
    Teacher,
    TeacherNonBatching,
}

// Arguments of the `Synthesize` subcommand; `main` copies them into a
// `SynthesizeConfig` together with the output directory.
// The `///` comments on the fields are clap help text and are user-facing.
#[derive(Debug, Args, Clone)]
struct SynthesizeArgs {
    /// Repository URL (git@github.com:owner/repo or https://...)
    #[clap(long)]
    repo: String,

    /// Number of examples to generate
    #[clap(long, default_value_t = 5)]
    count: usize,

    /// Maximum commits to scan before giving up
    #[clap(long, default_value_t = 100)]
    max_commits: usize,

    /// Ignore state file and reprocess all commits
    #[clap(long)]
    fresh: bool,
}

impl EpArgs {
    /// Resolves where results should be written.
    ///
    /// Without `--in-place` this is simply a copy of `--output` (which may be
    /// `None`). With `--in-place` the single input path doubles as the output
    /// path.
    ///
    /// # Panics
    ///
    /// Panics when `--in-place` is set and the number of inputs is not
    /// exactly one.
    fn output_path(&self) -> Option<PathBuf> {
        if !self.in_place {
            return self.output.clone();
        }

        match self.inputs.as_slice() {
            [only_input] => Some(only_input.clone()),
            _ => panic!("--in-place requires exactly one input file"),
        }
    }
}

/// Collects the examples named by `args.inputs`.
///
/// Inputs recognised by `pull_examples::parse_captured_after_input` are
/// treated as `captured-after` timestamps and fetched over `http_client`;
/// every other input is read as an example file. The combined list is sorted
/// by repository and revision, narrowed by the `--name` / `--repo` substring
/// filters, and finally cut down to `--limit` entries.
///
/// # Errors
///
/// Returns the error produced by fetching captured examples, if any.
async fn load_examples(
    http_client: Arc<dyn http_client::HttpClient>,
    args: &EpArgs,
) -> anyhow::Result<Vec<Example>> {
    // Partition the raw inputs into timestamp specifiers and plain file paths.
    let mut snowflake_timestamps = Vec::new();
    let mut example_paths = Vec::new();
    for input in &args.inputs {
        let lossy_input = input.to_string_lossy();
        match pull_examples::parse_captured_after_input(lossy_input.as_ref()) {
            Some(timestamp) => snowflake_timestamps.push(timestamp.to_string()),
            None => example_paths.push(input.clone()),
        }
    }

    let mut examples = read_example_files(&example_paths);
    Progress::global().set_total_examples(examples.len());

    // How many more examples `--limit` still allows once the file examples are
    // counted; `None` means no limit was given.
    // NOTE(review): this budget is computed before the name/repo filters run,
    // so filtered-out file examples still count against it.
    let snowflake_budget = args
        .limit
        .map(|limit| limit.saturating_sub(examples.len()));

    if snowflake_budget == Some(0) {
        log::info!(
            "skipping captured-after inputs because --limit is already satisfied by example files"
        );
    } else if !snowflake_timestamps.is_empty() {
        snowflake_timestamps.sort();

        let max_rows_per_timestamp = snowflake_budget.unwrap_or(5000);
        let fetched_examples = pull_examples::fetch_captured_examples_after(
            http_client,
            &snowflake_timestamps,
            max_rows_per_timestamp,
        )
        .await?;
        examples.extend(fetched_examples);
    }

    crate::example::sort_examples_by_repo_and_rev(&mut examples);

    if let Some(wanted_name) = args.name.as_ref() {
        examples.retain(|candidate| candidate.spec.name.contains(wanted_name));
    }
    if let Some(wanted_repo) = args.repo.as_ref() {
        examples.retain(|candidate| candidate.spec.repository_url.contains(wanted_repo));
    }

    // `truncate` leaves the vector untouched when it is already short enough.
    if let Some(limit) = args.limit {
        examples.truncate(limit);
    }

    // Report the final count now that filtering and limiting are done.
    Progress::global().set_total_examples(examples.len());

    Ok(examples)
}

/// Entry point of the `ep` binary.
///
/// Flow:
/// 1. Parse the arguments; `--printenv` and a missing subcommand return early.
/// 2. `Clean`, `Synthesize` and `SplitCommit` run directly, without loading
///    any examples or starting the headless application.
/// 3. Every other subcommand starts a headless application, loads the
///    examples, runs the subcommand once per example, and then writes the
///    examples and/or prints a report.
///
/// # Panics
///
/// Panics if `--in-place` is used without exactly one input, if `Synthesize`
/// is run without `--output`, if the example pipeline returns a fatal error,
/// or (via `handle_error`) when an example fails in fail-fast mode.
fn main() {
    let args = EpArgs::parse();

    if args.printenv {
        ::util::shell_env::print_env();
        return;
    }

    // Effective output path: `--output`, or the single input when `--in-place`.
    let output = args.output_path();
    let command = match &args.command {
        Some(cmd) => cmd.clone(),
        None => {
            EpArgs::command().print_help().unwrap();
            return;
        }
    };

    match &command {
        Command::Clean => {
            match std::fs::remove_dir_all(&*paths::DATA_DIR) {
                Ok(()) => {}
                // A missing data directory means there is nothing to clean,
                // which is a success rather than a crash.
                Err(error) if error.kind() == std::io::ErrorKind::NotFound => {}
                Err(error) => {
                    eprintln!("Error: {:?}", error);
                    std::process::exit(1);
                }
            }
            return;
        }
        Command::Synthesize(synth_args) => {
            let Some(output_dir) = args.output else {
                panic!("output dir is required");
            };
            let config = SynthesizeConfig {
                repo_url: synth_args.repo.clone(),
                count: synth_args.count,
                max_commits: synth_args.max_commits,
                output_dir,
                fresh: synth_args.fresh,
            };
            smol::block_on(async {
                if let Err(e) = run_synthesize(config).await {
                    eprintln!("Error: {:?}", e);
                    std::process::exit(1);
                }
            });
            return;
        }
        Command::SplitCommit(split_commit_args) => {
            if let Err(error) =
                split_commit::run_split_commit(split_commit_args, &args.inputs, output.as_ref())
            {
                eprintln!("{error:#}");
                std::process::exit(1);
            }
            return;
        }
        _ => {}
    }

    let http_client = Arc::new(ReqwestClient::new());
    let app = Application::headless().with_http_client(http_client);

    app.run(move |cx| {
        let app_state = Arc::new(headless::init(cx));
        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);

        cx.spawn(async move |cx| {
            let result = async {
                let mut examples = load_examples(app_state.client.http_client(), &args).await?;

                if let Command::Predict(args) = &command {
                    predict::sync_batches(&args.provider).await?;
                }

                // With a single example there is nothing else to continue
                // with, so a failure is always treated as fail-fast.
                let failfast_on_single_example = examples.len() == 1;

                let mut grouped_examples = group_examples_by_repo(&mut examples);
                // `chunks_mut` panics on a chunk size of zero, so clamp
                // `--max-parallelism 0` to one group per batch.
                let example_batches = grouped_examples.chunks_mut(args.max_parallelism.max(1));

                for example_batch in example_batches {
                    // Each repository group is processed sequentially inside
                    // its own future; the futures of one batch are awaited
                    // together below.
                    let futures = example_batch.into_iter().map(|repo_examples| async {
                        for example in repo_examples.iter_mut() {
                            let result = async {
                                match &command {
                                    Command::ParseExample => {}
                                    Command::LoadProject => {
                                        run_load_project(example, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    Command::Context => {
                                        run_context_retrieval(
                                            example,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::FormatPrompt(args) => {
                                        run_format_prompt(
                                            example,
                                            args.prompt_format,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Predict(args) => {
                                        run_prediction(
                                            example,
                                            Some(args.provider),
                                            args.repetitions,
                                            app_state.clone(),
                                            cx.clone(),
                                        )
                                        .await?;
                                    }
                                    Command::Distill => {
                                        run_distill(example).await?;
                                    }
                                    Command::Score(args) | Command::Eval(args) => {
                                        run_scoring(example, &args, app_state.clone(), cx.clone())
                                            .await?;
                                    }
                                    // These subcommands returned from `main`
                                    // before the application was started.
                                    Command::Clean
                                    | Command::Synthesize(_)
                                    | Command::SplitCommit(_) => {
                                        unreachable!()
                                    }
                                }
                                anyhow::Ok(())
                            }
                            .await;

                            if let Err(error) = result {
                                handle_error(
                                    error,
                                    &args,
                                    &command,
                                    &app_state,
                                    failfast_on_single_example,
                                    example,
                                )
                                .await;
                            }
                        }
                    });
                    futures::future::join_all(futures).await;
                }
                Progress::global().finalize();

                // `Eval` only writes the examples back when an output target
                // was requested. Check the effective `output` (which also
                // covers `--in-place`) rather than the raw `--output` flag.
                if output.is_some() || !matches!(command, Command::Eval(_)) {
                    write_examples(&examples, output.as_ref());
                }

                match &command {
                    Command::Predict(args) => predict::sync_batches(&args.provider).await?,
                    Command::Eval(_) => score::print_report(&examples),
                    _ => (),
                };

                anyhow::Ok(())
            }
            .await;

            if let Err(e) = result {
                panic!("Fatal error: {:?}", e);
            }

            let _ = cx.update(|cx| cx.quit());
        })
        .detach();
    });
}

async fn handle_error(
    error: anyhow::Error,
    args: &EpArgs,
    command: &Command,
    app_state: &Arc<headless::EpAppState>,
    failfast_on_single_example: bool,
    example: &Example,
) {
    Progress::global().increment_failed();
    let example_name = example.spec.filename();
    let failed_example_path = FAILED_EXAMPLES_DIR.join(format!("{}.json", example_name));
    app_state
        .fs
        .write(
            &failed_example_path,
            &serde_json::to_vec_pretty(&example).unwrap(),
        )
        .await
        .unwrap();
    let err_path = FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example_name));
    app_state
        .fs
        .write(&err_path, format!("{error:?}").as_bytes())
        .await
        .unwrap();

    let file_path = example
        .repo_name()
        .unwrap()
        .worktree_path()
        .join(&example.spec.cursor_path);

    let msg = format!(
        indoc::indoc! {"
            While processing \"{}\":

            {:?}

            Written to: \x1b[36m{}\x1b[0m

            Cursor File: \x1b[36m{}\x1b[0m

            Explore this example data with:
            fx \x1b[36m{}\x1b[0m

            Re-run this example with:
            cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
        "},
        example.spec.name,
        error,
        err_path.display(),
        file_path.display(),
        failed_example_path.display(),
        command,
        failed_example_path.display(),
    );
    if args.failfast || failfast_on_single_example {
        Progress::global().finalize();
        panic!("{}", msg);
    } else {
        log::error!("{}", msg);
    }
}
