# Task 5 — `ExerciseChecker` (`&dyn Backend`, 2σ tolerance, statevector check)

> **Index:** [README](README.md). **Spec:** [design](../../specs/2026-04-29-quantum-tutor-design.md).

## Goal

Verify whether a submitted circuit satisfies an `ExerciseCriteria`. `ExerciseChecker` takes a `&dyn Backend`, so it can be tested with mocks and stays compatible with the V1.5 IBM integration (CLAUDE.md §3, spec §3). It runs a fixed 1024 shots and accepts a required outcome when `count ≥ (min_ratio − 2σ) × N`, where `σ = √(min_ratio × (1 − min_ratio) / N)` and `N` is the shot count. It also validates `statevector_check` when present.

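## Prerequisites

- Task 2 merged (`CurriculumLoader` available).
- Task 4 merged (`CircuitAnalyzer` is not strictly required here, but is typically present at this point).
- The tests construct `ExerciseCriteria`, `RequiredOutcome` and `StatevectorCheck` directly and call `ExerciseCriteria::default()`, so `ExerciseCriteria` must implement `Default`.
- The Step 2 code uses `HashMap` without importing it. If `src/tutor.rs` does not already have it in scope, add `use std::collections::HashMap;`.
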
## Files

- Modify: `src/tutor.rs`

## Steps

- [ ] **Step 1 (Red): Append failing tests to the existing `#[cfg(test)]` block in `src/tutor.rs`. They will not compile until Step 2 adds `ExerciseChecker`.**

```rust
use crate::executor::{Backend, LocalSimulator};
|
||
|
||
/// Single qubit flipped by `x` before measurement (ideal outcome: "1").
const X_CIRCUIT: &str = concat!(
    "OPENQASM 3.0;\n",
    "include \"stdgates.inc\";\n",
    "qubit[1] q;\n",
    "bit[1] c;\n",
    "x q[0];\n",
    "c = measure q;"
);

/// Single qubit measured with no gates applied (ideal outcome: "0").
const IDENTITY_CIRCUIT: &str = concat!(
    "OPENQASM 3.0;\n",
    "include \"stdgates.inc\";\n",
    "qubit[1] q;\n",
    "bit[1] c;\n",
    "c = measure q;"
);

/// Single qubit put into equal superposition by `h` (ideally ~50/50 "0"/"1").
const H_CIRCUIT: &str = concat!(
    "OPENQASM 3.0;\n",
    "include \"stdgates.inc\";\n",
    "qubit[1] q;\n",
    "bit[1] c;\n",
    "h q[0];\n",
    "c = measure q;"
);

fn backend() -> LocalSimulator { LocalSimulator::new() }
|
||
|
||
#[test]
fn x_circuit_passes_exercise_requiring_bitstring_1() {
    let criteria = ExerciseCriteria {
        required_outcomes: vec![RequiredOutcome { bitstring: "1".into(), min_ratio: Some(0.99) }],
        forbidden_outcomes: vec!["0".into()],
        statevector_check: None,
    };
    let result = ExerciseChecker::check_circuit(&backend() as &dyn Backend, X_CIRCUIT, &criteria);
    // Report `error` too: on a validation/backend failure `counts` is empty and uninformative.
    assert!(result.passed, "error: {:?}, counts: {:?}", result.error, result.counts);
}

#[test]
fn identity_circuit_fails_exercise_requiring_bitstring_1() {
    let criteria = ExerciseCriteria {
        required_outcomes: vec![RequiredOutcome { bitstring: "1".into(), min_ratio: Some(0.99) }],
        forbidden_outcomes: vec!["0".into()],
        statevector_check: None,
    };
    let result =
        ExerciseChecker::check_circuit(&backend() as &dyn Backend, IDENTITY_CIRCUIT, &criteria);
    // The circuit is valid, so it must fail on the count criteria — not because validation or
    // execution errored (which would also yield `passed == false` and mask a real bug).
    assert!(result.error.is_none(), "unexpected error: {:?}", result.error);
    assert!(!result.passed, "counts: {:?}", result.counts);
}

#[test]
fn h_circuit_passes_balanced_outcomes_with_2sigma_tolerance() {
    // min_ratio is 0.4, not 0.5. With N = 1024 the 2σ threshold at 0.5 is 480 counts, and a
    // fair H circuit falls below that roughly 2% of the time per outcome, so the test would be
    // flaky. At 0.4 the threshold is ≈378, more than 8σ below the expected 512.
    let criteria = ExerciseCriteria {
        required_outcomes: vec![
            RequiredOutcome { bitstring: "0".into(), min_ratio: Some(0.4) },
            RequiredOutcome { bitstring: "1".into(), min_ratio: Some(0.4) },
        ],
        forbidden_outcomes: vec![],
        statevector_check: None,
    };
    let result = ExerciseChecker::check_circuit(&backend() as &dyn Backend, H_CIRCUIT, &criteria);
    assert!(result.passed, "error: {:?}, counts: {:?}", result.error, result.counts);
}

#[test]
fn invalid_circuit_returns_diagnostics_not_panic() {
    let criteria = ExerciseCriteria::default();
    let result =
        ExerciseChecker::check_circuit(&backend() as &dyn Backend, "not valid qasm", &criteria);
    assert!(!result.passed);
    assert!(result.error.is_some());
    // Every failure path in `check_circuit` returns before producing counts, so an invalid
    // circuit must never surface measurement results.
    assert!(result.counts.is_empty(), "counts: {:?}", result.counts);
}

#[test]
fn statevector_check_validates_bell_state_amplitudes() {
    const BELL: &str = "OPENQASM 3.0;\ninclude \"stdgates.inc\";\nqubit[2] q;\nbit[2] c;\nh q[0];\ncx q[0], q[1];\nc = measure q;";
    // Indices 0 (|00⟩) and 3 (|11⟩) denote the same basis states under either qubit-ordering
    // convention, so this check does not depend on the backend's endianness.
    // NOTE(review): BELL ends with `c = measure q;`. This test assumes the backend reports the
    // statevector from before measurement collapse; a post-measurement statevector would have
    // only one of indices 0/3 non-zero. Confirm against `LocalSimulator`.
    let criteria = ExerciseCriteria {
        required_outcomes: vec![],
        forbidden_outcomes: vec![],
        statevector_check: Some(StatevectorCheck {
            non_zero_amplitude_indices: vec![0, 3],
            zero_amplitude_indices: vec![1, 2],
            tolerance: 1e-6,
        }),
    };
    let result = ExerciseChecker::check_circuit(&backend() as &dyn Backend, BELL, &criteria);
    assert!(result.passed, "error: {:?}, counts: {:?}", result.error, result.counts);
}
```

- [ ] **Step 2: Implement `ExerciseChecker` (above the `#[cfg(test)]` block)**

```rust
use crate::executor::{Backend, MAX_LOCAL_QUBITS};
|
||
use crate::types::{CircuitSource, ShotCount, ValidationDiagnostic};
|
||
use crate::validator::CircuitValidator;
|
||
|
||
pub struct CheckResult {
|
||
pub passed: bool,
|
||
pub counts: HashMap<String, u64>,
|
||
pub diagnostics: Vec<ValidationDiagnostic>,
|
||
pub error: Option<String>,
|
||
}
|
||
|
||
/// Stateless grader: runs a submitted circuit on a `Backend` and checks the measured
/// outcomes (and, optionally, the statevector) against an `ExerciseCriteria`.
pub struct ExerciseChecker;

impl ExerciseChecker {
|
||
const CHECK_SHOTS: u32 = 1024;
|
||
/// Spec §3 — pass if `count ≥ (min_ratio - 2σ) × N`.
|
||
const SIGMA_MULTIPLIER: f64 = 2.0;
|
||
|
||
pub fn check_circuit(
|
||
backend: &dyn Backend,
|
||
circuit_source: &str,
|
||
criteria: &ExerciseCriteria,
|
||
) -> CheckResult {
|
||
let source = CircuitSource(circuit_source.to_string());
|
||
|
||
let validator = CircuitValidator::new(MAX_LOCAL_QUBITS);
|
||
let validation = match validator.validate(&source) {
|
||
Err(e) => return CheckResult {
|
||
passed: false, counts: HashMap::new(), diagnostics: vec![],
|
||
error: Some(e.to_string()),
|
||
},
|
||
Ok(v) => v,
|
||
};
|
||
if !validation.is_valid {
|
||
let summary = validation
|
||
.diagnostics
|
||
.iter()
|
||
.map(|d| d.message.as_str())
|
||
.collect::<Vec<_>>()
|
||
.join("; ");
|
||
return CheckResult {
|
||
passed: false, counts: HashMap::new(),
|
||
diagnostics: validation.diagnostics, error: Some(summary),
|
||
};
|
||
}
|
||
|
||
let need_sv = criteria.statevector_check.is_some();
|
||
let result = match backend.run(&source, ShotCount(Self::CHECK_SHOTS), need_sv) {
|
||
Err(e) => return CheckResult {
|
||
passed: false, counts: HashMap::new(), diagnostics: vec![],
|
||
error: Some(e.to_string()),
|
||
},
|
||
Ok(r) => r,
|
||
};
|
||
|
||
let total = result.shots as f64;
|
||
let counts_pass = Self::counts_pass(&result.counts, criteria, total);
|
||
let sv_pass = match (&criteria.statevector_check, &result.statevector) {
|
||
(Some(check), Some(sv)) => Self::statevector_pass(sv, check),
|
||
(Some(_), None) => false,
|
||
(None, _) => true,
|
||
};
|
||
|
||
CheckResult {
|
||
passed: counts_pass && sv_pass,
|
||
counts: result.counts,
|
||
diagnostics: vec![],
|
||
error: None,
|
||
}
|
||
}
|
||
|
||
fn counts_pass(
|
||
counts: &HashMap<String, u64>,
|
||
criteria: &ExerciseCriteria,
|
||
total: f64,
|
||
) -> bool {
|
||
for req in &criteria.required_outcomes {
|
||
let count = counts.get(&req.bitstring).copied().unwrap_or(0) as f64;
|
||
match req.min_ratio {
|
||
Some(min_ratio) => {
|
||
let sigma = (min_ratio * (1.0 - min_ratio) / total).sqrt();
|
||
let threshold = (min_ratio - Self::SIGMA_MULTIPLIER * sigma) * total;
|
||
if count < threshold { return false; }
|
||
}
|
||
None => if count == 0.0 { return false; },
|
||
}
|
||
}
|
||
for forbidden in &criteria.forbidden_outcomes {
|
||
if counts.get(forbidden).copied().unwrap_or(0) > 0 { return false; }
|
||
}
|
||
true
|
||
}
|
||
|
||
fn statevector_pass(sv: &[(f64, f64)], check: &StatevectorCheck) -> bool {
|
||
let magnitude = |idx: usize| -> Option<f64> {
|
||
sv.get(idx).map(|(r, i)| (r * r + i * i).sqrt())
|
||
};
|
||
for &idx in &check.non_zero_amplitude_indices {
|
||
match magnitude(idx) {
|
||
Some(m) if m > check.tolerance => {}
|
||
_ => return false,
|
||
}
|
||
}
|
||
for &idx in &check.zero_amplitude_indices {
|
||
match magnitude(idx) {
|
||
Some(m) if m <= check.tolerance => {}
|
||
_ => return false,
|
||
}
|
||
}
|
||
true
|
||
}
|
||
}
|
||
```

- [ ] **Step 3: Run tests (Green)**

```bash
cargo test tutor::tests 2>&1 | grep -E "test result|FAILED"
```

Expected: `test result: ok. 11 passed` (the five tests added in Step 1 plus the tests already in `tutor::tests`).

- [ ] **Step 4: Commit**

```bash
git add src/tutor.rs
git commit -m "feat: ExerciseChecker with Backend injection, 2σ tolerance, statevector check"
```